129 lines
4.7 KiB
Python
129 lines
4.7 KiB
Python
import configparser
import logging
from datetime import datetime, timedelta
from decimal import Decimal

from fastapi import HTTPException

from backend.app.db import get_connection
|
||
|
||
def get_pricing_strategy(strategy_id, cursor):
    """Fetch the per-player segmented pricing strategy with the given ID.

    Args:
        strategy_id: value matched against ``pricing_strategies.strategy_id``.
        cursor: an open DB cursor; the row comes back in whatever shape the
            cursor produces (a dict when created with ``dictionary=True``).

    Returns:
        The single matching row, or ``None`` when no row matches
        (DB-API ``fetchone`` semantics).

    Note:
        Column order is deliberate and must not change: callers may unpack
        the row positionally, and ``segment3_price`` comes BEFORE
        ``segment3_threshold``.
    """
    query = (
        "SELECT segment1_threshold, segment1_price, "
        "segment2_threshold, segment2_price, "
        "segment3_price, segment3_threshold "
        "FROM pricing_strategies "
        "WHERE strategy_id = %s"
    )
    params = (strategy_id,)
    cursor.execute(query, params)
    row = cursor.fetchone()
    return row
|
||
|
||
def get_table_pricing_strategy(strategy_id, cursor):
    """Fetch the table-fee segmented pricing strategy with the given ID.

    Args:
        strategy_id: value matched against
            ``table_pricing_strategies.strategy_id``.
        cursor: an open DB cursor; the row comes back in whatever shape the
            cursor produces (a dict when created with ``dictionary=True``).

    Returns:
        The single matching row, or ``None`` when no row matches
        (DB-API ``fetchone`` semantics).

    Note:
        Columns are selected in the same order as ``get_pricing_strategy``
        (``segment3_price`` before ``segment3_threshold``) so both rows can be
        consumed by the same pricing routine.
    """
    columns = (
        "segment1_threshold, segment1_price, "
        "segment2_threshold, segment2_price, "
        "segment3_price, segment3_threshold"
    )
    cursor.execute(
        "SELECT " + columns + " FROM table_pricing_strategies WHERE strategy_id = %s",
        (strategy_id,),
    )
    return cursor.fetchone()
|
||
|
||
def calculate_segmented_price(duration, strategy):
    """Compute a price under a three-tier (segmented) per-minute strategy.

    Tiers:
        * minutes in ``(0, segment1_threshold]`` cost ``segment1_price`` each;
        * minutes in ``(segment1_threshold, segment2_threshold]`` cost
          ``segment2_price`` each;
        * minutes beyond ``segment2_threshold`` cost ``segment3_price`` each,
          billed only up to ``segment3_threshold`` when that is set (a cap on
          billable time); ``None`` means no cap.

    Args:
        duration: total duration in minutes (e.g. a ``Decimal``).
        strategy: the strategy row. A mapping with the keys
            ``segment1_threshold``, ``segment1_price``, ``segment2_threshold``,
            ``segment2_price``, ``segment3_price``, ``segment3_threshold`` is
            read by name, so column order does not matter. Any other mapping,
            or a plain sequence, is unpacked positionally in that same order
            (the legacy behavior).

    Returns:
        The price. Durations of zero or less cost ``Decimal(0)``.
    """
    keys = (
        "segment1_threshold", "segment1_price",
        "segment2_threshold", "segment2_price",
        "segment3_price", "segment3_threshold",
    )
    if all(key in strategy for key in keys):
        # Look up by name so a reordered SELECT cannot swap prices and thresholds.
        s1, p1, s2, p2, p3, s3 = (strategy[key] for key in keys)
    else:
        values = strategy.values() if hasattr(strategy, "values") else strategy
        s1, p1, s2, p2, p3, s3 = values

    if duration <= 0:
        return Decimal(0)
    if duration <= s1:
        return duration * p1

    first_tier = s1 * p1
    if duration <= s2:
        return first_tier + (duration - s1) * p2

    second_tier = (s2 - s1) * p2
    # Cap the billable time at segment3_threshold when configured.
    billable = duration if s3 is None else min(duration, s3)
    # A cap below segment2_threshold (misconfiguration) must not make the
    # third tier negative.
    third_minutes = max(billable - s2, 0)
    return first_tier + second_tier + third_minutes * p3
|
||
|
||
def calculate_overtime_fee(start_datetime, end_datetime):
|
||
"""计算超时费用(深夜时段收费)"""
|
||
config = configparser.ConfigParser()
|
||
config.read('backend/config.conf')
|
||
stay_up_late = config['stay_up_late']
|
||
start_time = datetime.strptime(stay_up_late['start_time'], '%H:%M:%S').time()
|
||
end_time = datetime.strptime(stay_up_late['end_time'], '%H:%M:%S').time()
|
||
price_per_minute = Decimal(stay_up_late['price_minutes'])
|
||
|
||
overtime_fee = Decimal(0)
|
||
current_time = start_datetime
|
||
|
||
while current_time < end_datetime:
|
||
if start_time <= current_time.time() or current_time.time() < end_time:
|
||
overtime_fee += price_per_minute
|
||
current_time += timedelta(minutes=1)
|
||
|
||
return overtime_fee
|
||
|
||
def calculate_order_price(order_id):
    """Compute an order's price and persist it on the order row.

    Locks the order row (joined with its game table) via
    ``SELECT ... FOR UPDATE``, then computes:

    * per-player price: segmented price under the order's pricing strategy,
      multiplied by ``num_players``;
    * table price: segmented price under the table's pricing strategy;
    * late-night surcharge from ``calculate_overtime_fee``.

    It writes ``payable_price`` (the sum of all three), ``game_process_time``
    (duration in minutes) and ``overtime_fee`` back to ``orders`` and commits.

    Args:
        order_id: primary key of the order.

    Returns:
        True on success. False if an unexpected error occurred; in that case
        the transaction is rolled back and the error is logged with its
        traceback.

    Raises:
        HTTPException: 404 if the order does not exist or has no end time;
            400 if the end time precedes the start time or a pricing strategy
            is missing. The transaction is rolled back before re-raising.
    """
    connection = get_connection()
    try:
        with connection.cursor(dictionary=True) as cursor:
            # Fetch and lock the order together with its table's strategy.
            cursor.execute("""
                SELECT o.start_datetime, o.end_datetime, o.num_players,
                       o.pricing_strategy_id, t.table_pricing_strategy_id,
                       o.game_table_id
                FROM orders o
                JOIN game_tables t ON o.game_table_id = t.table_id
                WHERE o.order_id = %s
                FOR UPDATE
            """, (order_id,))
            order = cursor.fetchone()

            if not order or not order['end_datetime']:
                raise HTTPException(status_code=404, detail="订单不存在或未完成")

            # NOTE(review): tzinfo is dropped, not converted; this assumes both
            # timestamps are in the same zone — confirm against the schema.
            start_time = order['start_datetime'].replace(tzinfo=None)
            end_time = order['end_datetime'].replace(tzinfo=None)
            if end_time < start_time:
                # A negative duration would otherwise yield a negative price.
                raise HTTPException(status_code=400, detail="订单结束时间早于开始时间")

            # Session length in minutes (fractional minutes kept).
            duration = (end_time - start_time).total_seconds() / 60
            duration_dec = Decimal(str(duration))

            pricing_strategy = get_pricing_strategy(order['pricing_strategy_id'], cursor)
            table_strategy = get_table_pricing_strategy(order['table_pricing_strategy_id'], cursor)
            if not pricing_strategy or not table_strategy:
                raise HTTPException(status_code=400, detail="价格策略未找到")

            unit_price = calculate_segmented_price(duration_dec, pricing_strategy)
            table_price = calculate_segmented_price(duration_dec, table_strategy)
            base_price = (unit_price * order['num_players']) + table_price
            overtime_fee = calculate_overtime_fee(start_time, end_time)
            total_price = base_price + overtime_fee

            cursor.execute("""
                UPDATE orders
                SET payable_price = %s,
                    game_process_time = %s,
                    overtime_fee = %s
                WHERE order_id = %s
            """, (total_price, duration_dec, overtime_fee, order_id))

            connection.commit()
            return True

    except HTTPException:
        # Deliberate client-facing errors: undo any work and let them propagate
        # instead of collapsing them into a bare False.
        connection.rollback()
        raise
    except Exception:
        connection.rollback()
        logging.getLogger(__name__).exception(
            "Error calculating order price for order %s", order_id
        )
        return False
    finally:
        connection.close()
|