461 lines
20 KiB
Python
461 lines
20 KiB
Python
"""
|
||
阵容推荐引擎 - 负责根据用户需求生成最优阵容
|
||
|
||
此模块实现了阵容推荐的核心逻辑,包括:
|
||
1. 基于羁绊、棋子和人口约束生成可行阵容
|
||
2. 根据评分标准对阵容进行排序
|
||
3. 返回最优阵容推荐结果
|
||
"""
|
||
from typing import Dict, List, Optional, Any, Set, Tuple, Union
|
||
import logging
|
||
import itertools
|
||
from dataclasses import dataclass, field
|
||
from src.data_provider import DataQueryAPI
|
||
from src.scoring.scoring_system import TeamScorer
|
||
from src.config import get_global_weights_config
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||
)
|
||
logger = logging.getLogger("TFT-Strategist-RecommendationEngine")
|
||
|
||
@dataclass
|
||
class TeamComposition:
|
||
"""表示一个云顶之弈阵容"""
|
||
chess_list: List[Dict[str, Any]] = field(default_factory=list)
|
||
synergy_counts: Dict[str, int] = field(default_factory=dict)
|
||
synergy_levels: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)
|
||
total_cost: int = 0
|
||
score: float = 0.0
|
||
|
||
@property
|
||
def size(self) -> int:
|
||
"""阵容人口数"""
|
||
return len(self.chess_list)
|
||
|
||
def add_chess(self, chess: Dict[str, Any]) -> None:
|
||
"""添加棋子到阵容"""
|
||
if chess not in self.chess_list:
|
||
self.chess_list.append(chess)
|
||
self.total_cost += int(chess.get('price', 0))
|
||
# 更新羁绊计数
|
||
|
||
def calculate_synergies(self, api: DataQueryAPI) -> None:
|
||
"""计算阵容的所有羁绊及其激活等级"""
|
||
self.synergy_counts = {}
|
||
self.synergy_levels = {'job': [], 'race': []}
|
||
|
||
# 统计所有羁绊的数量
|
||
for chess in self.chess_list:
|
||
# 处理职业
|
||
for job_id in chess.get('jobIds', '').split(','):
|
||
if job_id:
|
||
self.synergy_counts[job_id] = self.synergy_counts.get(job_id, 0) + 1
|
||
|
||
# 处理特质
|
||
for race_id in chess.get('raceIds', '').split(','):
|
||
if race_id:
|
||
self.synergy_counts[race_id] = self.synergy_counts.get(race_id, 0) + 1
|
||
|
||
# 确定各羁绊激活的等级
|
||
for synergy_id, count in self.synergy_counts.items():
|
||
# 确保synergy_id是字符串
|
||
synergy_id_str = str(synergy_id)
|
||
synergy = api.get_synergy_by_id(synergy_id_str)
|
||
if not synergy:
|
||
logger.warning(f"未找到羁绊: {{'id': {synergy_id}}}")
|
||
continue
|
||
|
||
# 确定激活的等级
|
||
levels = api.get_synergy_levels(synergy_id_str)
|
||
active_levels = []
|
||
|
||
for level_str, effect in levels.items():
|
||
level = int(level_str)
|
||
if count >= level:
|
||
active_level = {
|
||
'id': synergy_id_str,
|
||
'name': synergy.get('name', ''),
|
||
'level': level,
|
||
'count': count,
|
||
'effect': effect
|
||
}
|
||
active_levels.append(active_level)
|
||
|
||
if active_levels:
|
||
# 按等级排序
|
||
active_levels.sort(key=lambda x: x['level'])
|
||
# 将羁绊归类为职业或特质
|
||
if 'jobId' in synergy:
|
||
self.synergy_levels['job'].extend(active_levels)
|
||
else:
|
||
self.synergy_levels['race'].extend(active_levels)
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
"""将阵容转换为字典表示"""
|
||
return {
|
||
'chess_list': self.chess_list,
|
||
'synergy_counts': self.synergy_counts,
|
||
'synergy_levels': self.synergy_levels,
|
||
'total_cost': self.total_cost,
|
||
'size': self.size,
|
||
'score': self.score
|
||
}
|
||
|
||
|
||
class RecommendationEngine:
|
||
"""
|
||
阵容推荐引擎,负责根据用户需求生成最优阵容
|
||
"""
|
||
|
||
def __init__(self, api: DataQueryAPI, scorer: TeamScorer, config_path: Optional[str] = None, config_obj: Optional[Dict[str, Any]] = None):
|
||
"""
|
||
初始化阵容推荐引擎
|
||
|
||
Args:
|
||
api: 数据查询API实例,如果为None则创建一个新的实例
|
||
scorer: 阵容评分系统实例,如果为None则创建一个新的实例
|
||
config_path: 配置文件路径,用于全局权重配置
|
||
config_obj: 配置对象,优先于config_path使用
|
||
"""
|
||
self.api = api if api else DataQueryAPI()
|
||
|
||
# 加载全局权重配置
|
||
if config_obj:
|
||
self.weights_config = config_obj
|
||
else:
|
||
self.weights_config = get_global_weights_config(config_path).config_data
|
||
|
||
# 创建阵容评分系统
|
||
self.scorer = scorer if scorer else TeamScorer(config_path=config_path)
|
||
|
||
def recommend_team(
|
||
self,
|
||
population: int = 9,
|
||
required_synergies: Optional[List[Dict[str, Any]]] = None,
|
||
required_chess: Optional[List[Dict[str, Any]]] = None,
|
||
max_results: int = 5
|
||
) -> List[TeamComposition]:
|
||
"""
|
||
根据用户需求推荐阵容
|
||
|
||
Args:
|
||
population: 阵容人口数,默认为9
|
||
required_synergies: 必须包含的羁绊列表,每个羁绊为一个字典,包含id和最低激活等级
|
||
required_chess: 必须包含的棋子列表
|
||
max_results: 最多返回的推荐阵容数量
|
||
|
||
Returns:
|
||
List[TeamComposition]: 推荐的阵容列表,按评分从高到低排序
|
||
"""
|
||
if population < 1 or population > 10:
|
||
logger.warning(f"无效的人口数 {population},已调整为默认值9")
|
||
population = 9
|
||
|
||
required_synergies = required_synergies or []
|
||
required_chess = required_chess or []
|
||
|
||
# 1. 处理必选棋子
|
||
base_team = TeamComposition()
|
||
for chess_info in required_chess:
|
||
if isinstance(chess_info, dict) and 'name' in chess_info:
|
||
chess = self.api.get_chess_by_name(chess_info['name'])
|
||
elif isinstance(chess_info, dict) and 'id' in chess_info:
|
||
chess = self.api.get_chess_by_id(str(chess_info['id']))
|
||
else:
|
||
chess = self.api.get_chess_by_name(str(chess_info))
|
||
|
||
if chess:
|
||
base_team.add_chess(chess)
|
||
else:
|
||
logger.warning(f"未找到棋子: {chess_info}")
|
||
|
||
# 如果必选棋子已经超过了人口限制,则直接返回
|
||
if base_team.size > population:
|
||
logger.warning(f"必选棋子数量({base_team.size})超过了人口限制({population})")
|
||
base_team.calculate_synergies(self.api)
|
||
base_team.score = self.scorer.score_team(base_team)
|
||
return [base_team]
|
||
|
||
# 2. 获取羁绊所需的棋子集合
|
||
synergy_chess_sets = []
|
||
for synergy_info in required_synergies:
|
||
if isinstance(synergy_info, dict) and 'name' in synergy_info:
|
||
synergy = self.api.get_synergy_by_name(synergy_info['name'])
|
||
min_level = synergy_info.get('level', 1)
|
||
elif isinstance(synergy_info, dict) and 'id' in synergy_info:
|
||
synergy = self.api.get_synergy_by_id(str(synergy_info['id']))
|
||
min_level = synergy_info.get('level', 1)
|
||
else:
|
||
synergy = self.api.get_synergy_by_name(str(synergy_info))
|
||
min_level = 1
|
||
|
||
if not synergy:
|
||
logger.warning(f"未找到羁绊: {synergy_info}")
|
||
continue
|
||
|
||
synergy_id = synergy.get('jobId') or synergy.get('raceId')
|
||
chess_list = self.api.get_chess_by_synergy(synergy_id)
|
||
|
||
# 获取该羁绊的激活等级
|
||
levels = self.api.get_synergy_levels(synergy_id)
|
||
target_level = 0
|
||
for level_str in levels.keys():
|
||
level = int(level_str)
|
||
if level >= min_level:
|
||
target_level = level
|
||
break
|
||
|
||
if target_level == 0:
|
||
logger.warning(f"羁绊 {synergy.get('name', '')} 找不到满足最低等级 {min_level} 的激活条件")
|
||
continue
|
||
|
||
# 考虑羁绊在全局配置中的权重
|
||
weight = self.weights_config.get('synergy_weights', {}).get(synergy.get('name', ''), 1.0)
|
||
|
||
synergy_chess_sets.append({
|
||
'synergy': synergy,
|
||
'chess_list': chess_list,
|
||
'target_level': target_level,
|
||
'weight': weight # 添加权重信息
|
||
})
|
||
|
||
# 3. 根据已选棋子和羁绊要求生成候选阵容
|
||
candidate_teams = self._generate_candidate_teams(base_team, synergy_chess_sets, population)
|
||
|
||
# 如果没有生成候选阵容,则尝试填充最佳棋子
|
||
if not candidate_teams:
|
||
remaining_slots = population - base_team.size
|
||
logger.info(f"未找到满足所有羁绊要求的阵容,尝试填充最佳棋子 (剩余槽位: {remaining_slots})")
|
||
candidate_teams = self._fill_team_with_best_chess(base_team, remaining_slots)
|
||
|
||
# 4. 计算每个阵容的评分并排序
|
||
for team in candidate_teams:
|
||
team.calculate_synergies(self.api)
|
||
team.score = self.scorer.score_team(team)
|
||
|
||
# 按分数从高到低排序
|
||
candidate_teams.sort(key=lambda t: t.score, reverse=True)
|
||
|
||
# 返回分数最高的几个阵容
|
||
return candidate_teams[:max_results]
|
||
|
||
def _generate_candidate_teams(
|
||
self,
|
||
base_team: TeamComposition,
|
||
synergy_chess_sets: List[Dict[str, Any]],
|
||
population: int
|
||
) -> List[TeamComposition]:
|
||
"""
|
||
根据羁绊要求生成候选阵容
|
||
|
||
Args:
|
||
base_team: 基础阵容,包含必选棋子
|
||
synergy_chess_sets: 羁绊需要的棋子集合列表
|
||
population: 人口限制
|
||
|
||
Returns:
|
||
List[TeamComposition]: 候选阵容列表
|
||
"""
|
||
# 如果没有羁绊要求,则直接填充最佳棋子
|
||
if not synergy_chess_sets:
|
||
remaining_slots = population - base_team.size
|
||
return self._fill_team_with_best_chess(base_team, remaining_slots)
|
||
|
||
# 排序羁绊集合,根据权重值从高到低排序
|
||
synergy_chess_sets.sort(key=lambda s: s.get('weight', 1.0), reverse=True)
|
||
|
||
# 开始构建组合
|
||
candidate_teams = []
|
||
|
||
# 记录已经在基础阵容中的棋子
|
||
base_chess_names = {chess.get('displayName') for chess in base_team.chess_list}
|
||
|
||
# 1. 从每个羁绊中选择需要的棋子数量
|
||
# 创建每个羁绊至少需要的棋子数组合
|
||
min_chess_options = []
|
||
for synergy_set in synergy_chess_sets:
|
||
chess_list = synergy_set['chess_list']
|
||
target_level = synergy_set['target_level']
|
||
|
||
# 计算需要添加的棋子数量
|
||
already_in_base = sum(1 for chess in base_team.chess_list
|
||
if any(job_id in chess.get('jobIds', '').split(',')
|
||
for job_id in [synergy_set['synergy'].get('jobId')])
|
||
or any(race_id in chess.get('raceIds', '').split(',')
|
||
for race_id in [synergy_set['synergy'].get('raceId')]))
|
||
|
||
needed_count = max(0, target_level - already_in_base)
|
||
if needed_count <= 0:
|
||
# 如果基础阵容已经满足该羁绊要求,则跳过
|
||
continue
|
||
|
||
# 从该羁绊的棋子中选择还未在基础阵容中的
|
||
available_chess = [chess for chess in chess_list
|
||
if chess.get('displayName') not in base_chess_names]
|
||
|
||
# 如果可用棋子不足以满足羁绊要求,则返回空列表
|
||
if len(available_chess) < needed_count:
|
||
logger.warning(f"羁绊 {synergy_set['synergy'].get('name', '')} 可用棋子不足,无法满足等级 {target_level} 的要求")
|
||
return []
|
||
|
||
# 所有可能的棋子组合
|
||
chess_combinations = list(itertools.combinations(available_chess, needed_count))
|
||
|
||
# 将权重因素纳入棋子选择决策
|
||
weighted_combinations = []
|
||
for combo in chess_combinations:
|
||
# 计算组合的权重分数
|
||
combo_weight = sum(self.weights_config.get('chess_weights', {}).get(chess.get('displayName', ''), 1.0) for chess in combo)
|
||
weighted_combinations.append((combo, combo_weight))
|
||
|
||
# 按权重从高到低排序
|
||
weighted_combinations.sort(key=lambda x: x[1], reverse=True)
|
||
|
||
# 选择权重较高的前几个组合
|
||
top_combinations = [combo for combo, _ in weighted_combinations[:min(5, len(weighted_combinations))]]
|
||
min_chess_options.append(top_combinations)
|
||
|
||
# 如果某个羁绊无法满足,则返回空列表
|
||
if not all(min_chess_options):
|
||
return []
|
||
|
||
# 2. 尝试组合不同羁绊的棋子,生成可行阵容
|
||
# 为每个羁绊选择一种棋子组合方式
|
||
for chess_selection in itertools.product(*min_chess_options):
|
||
# 创建新的候选阵容
|
||
candidate = TeamComposition()
|
||
# 添加基础阵容的棋子
|
||
for chess in base_team.chess_list:
|
||
candidate.add_chess(chess)
|
||
|
||
# 添加各羁绊选中的棋子
|
||
all_selected_chess = set()
|
||
for combo in chess_selection:
|
||
for chess in combo:
|
||
if chess.get('displayName') not in all_selected_chess:
|
||
candidate.add_chess(chess)
|
||
all_selected_chess.add(chess.get('displayName'))
|
||
|
||
# 检查是否超出人口限制
|
||
if candidate.size <= population:
|
||
# 如果有剩余人口,填充最佳棋子
|
||
if candidate.size < population:
|
||
remaining_slots = population - candidate.size
|
||
filled_candidates = self._fill_team_with_best_chess(candidate, remaining_slots)
|
||
candidate_teams.extend(filled_candidates)
|
||
else:
|
||
candidate_teams.append(candidate)
|
||
|
||
return candidate_teams
|
||
|
||
def _fill_team_with_best_chess(
|
||
self,
|
||
base_team: TeamComposition,
|
||
remaining_slots: int
|
||
) -> List[TeamComposition]:
|
||
"""
|
||
用最佳棋子填充剩余槽位
|
||
|
||
Args:
|
||
base_team: 基础阵容
|
||
remaining_slots: 剩余槽位数量
|
||
|
||
Returns:
|
||
List[TeamComposition]: 填充后的阵容列表
|
||
"""
|
||
if remaining_slots <= 0:
|
||
return [base_team]
|
||
|
||
# 计算每个棋子的羁绊价值
|
||
all_chess = self.api.get_all_chess()
|
||
base_chess_names = {chess.get('displayName') for chess in base_team.chess_list}
|
||
|
||
# 排除已在基础阵容中的棋子
|
||
available_chess = [chess for chess in all_chess
|
||
if chess.get('displayName') not in base_chess_names]
|
||
|
||
# 计算基础阵容中已有的羁绊
|
||
base_synergies = {}
|
||
for chess in base_team.chess_list:
|
||
# 职业
|
||
for job_id in chess.get('jobIds', '').split(','):
|
||
if job_id:
|
||
base_synergies[job_id] = base_synergies.get(job_id, 0) + 1
|
||
|
||
# 特质
|
||
for race_id in chess.get('raceIds', '').split(','):
|
||
if race_id:
|
||
base_synergies[race_id] = base_synergies.get(race_id, 0) + 1
|
||
|
||
# 评估每个可用棋子的价值
|
||
chess_values = []
|
||
for chess in available_chess:
|
||
value = 0
|
||
|
||
# 根据费用评估基础价值
|
||
cost = int(chess.get('price', 1))
|
||
cost_multiplier = float(self.weights_config.get('cost_weights', {}).get(str(cost), 1.0))
|
||
value += cost * cost_multiplier
|
||
|
||
# 根据羁绊评估额外价值
|
||
for job_id in chess.get('jobIds', '').split(','):
|
||
if job_id and job_id in base_synergies:
|
||
# 如果能与基础阵容形成羁绊,增加价值
|
||
synergy = self.api.get_synergy_by_id(job_id)
|
||
if synergy:
|
||
# 检查是否能激活新的羁绊等级
|
||
levels = self.api.get_synergy_levels(job_id)
|
||
current_count = base_synergies[job_id]
|
||
for level_str in sorted(levels.keys()):
|
||
level = int(level_str)
|
||
if current_count + 1 >= level > current_count:
|
||
# 如果添加这个棋子后能激活新的等级
|
||
weight = self.weights_config.get('synergy_weights', {}).get(synergy.get('name', ''), 1.0)
|
||
level_multiplier = float(self.weights_config.get('synergy_level_weights', {}).get(level_str, 1.0))
|
||
value += weight * level_multiplier * 10
|
||
break
|
||
# 即使不能激活新等级,也增加一些价值
|
||
value += 1
|
||
|
||
for race_id in chess.get('raceIds', '').split(','):
|
||
if race_id and race_id in base_synergies:
|
||
# 如果能与基础阵容形成羁绊,增加价值
|
||
synergy = self.api.get_synergy_by_id(race_id)
|
||
if synergy:
|
||
# 检查是否能激活新的羁绊等级
|
||
levels = self.api.get_synergy_levels(race_id)
|
||
current_count = base_synergies[race_id]
|
||
for level_str in sorted(levels.keys()):
|
||
level = int(level_str)
|
||
if current_count + 1 >= level > current_count:
|
||
# 如果添加这个棋子后能激活新的等级
|
||
weight = self.weights_config.get('synergy_weights', {}).get(synergy.get('name', ''), 1.0)
|
||
level_multiplier = float(self.weights_config.get('synergy_level_weights', {}).get(level_str, 1.0))
|
||
value += weight * level_multiplier * 10
|
||
break
|
||
# 即使不能激活新等级,也增加一些价值
|
||
value += 1
|
||
|
||
# 添加棋子自定义权重
|
||
chess_weight = self.weights_config.get('chess_weights', {}).get(chess.get('displayName', ''), 1.0)
|
||
value *= chess_weight
|
||
|
||
chess_values.append((chess, value))
|
||
|
||
# 按价值从高到低排序
|
||
chess_values.sort(key=lambda x: x[1], reverse=True)
|
||
|
||
# 创建候选阵容,添加价值最高的棋子
|
||
candidate = TeamComposition()
|
||
# 添加基础阵容的棋子
|
||
for chess in base_team.chess_list:
|
||
candidate.add_chess(chess)
|
||
|
||
# 添加价值最高的剩余槽位数量的棋子
|
||
for i in range(min(remaining_slots, len(chess_values))):
|
||
candidate.add_chess(chess_values[i][0])
|
||
|
||
return [candidate] |