TFT-Strategist/tests/test_recommendation.py

223 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Tests for the team composition recommendation module.
"""
import os
import sys
import unittest
from unittest import mock
# Add the project root directory to the Python path so that `src` is importable
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.recommendation.recommendation_engine import RecommendationEngine, TeamComposition
from src.data_provider import DataQueryAPI
from src.scoring import TeamScorer
class TestTeamComposition(unittest.TestCase):
    """Unit tests for the TeamComposition class."""

    def test_add_chess(self):
        """A chess piece is stored once; adding it again changes nothing."""
        piece = {"displayName": "布兰德", "price": "4", "jobIds": "10172", "raceIds": "10154"}
        composition = TeamComposition()

        composition.add_chess(piece)

        # Exactly one piece is stored, and the string price "4" is expected
        # to show up as the integer total cost 4.
        self.assertEqual(len(composition.chess_list), 1)
        self.assertEqual(composition.chess_list[0], piece)
        self.assertEqual(composition.size, 1)
        self.assertEqual(composition.total_cost, 4)

        # Adding the very same piece a second time must not grow the list.
        composition.add_chess(piece)
        self.assertEqual(len(composition.chess_list), 1)

    def test_calculate_synergies(self):
        """Synergy counts and activated levels are computed from the roster."""
        # (displayName, price, jobIds, raceIds) for each piece on the team.
        roster = (
            ("布兰德", "4", "10172", "10154"),
            ("阿利斯塔", "3", "10157", ""),
            ("其他棋子", "2", "10172", "10154"),
        )
        composition = TeamComposition()
        for display_name, price, job_ids, race_ids in roster:
            composition.add_chess({
                "displayName": display_name,
                "price": price,
                "jobIds": job_ids,
                "raceIds": race_ids,
            })

        # Stand-in for DataQueryAPI. Both lookups build a fresh table on every
        # call and return None for an unknown id.
        def lookup_synergy(synergy_id):
            return {
                "10172": {"jobId": "10172", "name": "高级工程师"},
                "10157": {"jobId": "10157", "name": "重装战士"},
                "10154": {"raceId": "10154", "name": "街头恶魔"},
            }.get(synergy_id)

        def lookup_levels(synergy_id):
            return {
                "10172": {"2": "效果1", "4": "效果2"},
                "10157": {"2": "效果1", "4": "效果2"},
                "10154": {"3": "效果1", "5": "效果2"},
            }.get(synergy_id)

        api_stub = mock.Mock()
        api_stub.get_synergy_by_id.side_effect = lookup_synergy
        api_stub.get_synergy_levels.side_effect = lookup_levels

        composition.calculate_synergies(api_stub)

        # Two pieces carry 10172, one carries 10157, two carry 10154.
        expected_counts = (("10172", 2), ("10157", 1), ("10154", 2))
        for synergy_id, expected in expected_counts:
            self.assertEqual(composition.synergy_counts.get(synergy_id), expected)

        active_jobs = composition.synergy_levels['job']
        active_races = composition.synergy_levels['race']

        # Of the job synergies only 10172 reaches a threshold (2 pieces, level
        # 2); 10157 has a single piece, below its lowest stubbed threshold.
        self.assertEqual(len(active_jobs), 1)
        first_job = active_jobs[0]
        self.assertEqual(first_job['name'], "高级工程师")
        self.assertEqual(first_job['level'], 2)

        # 10154 has two pieces but its lowest stubbed threshold is 3, so no
        # race synergy is expected to be active.
        self.assertEqual(len(active_races), 0)

    def test_to_dict(self):
        """to_dict() exposes every public field of the composition."""
        composition = TeamComposition()
        composition.chess_list = [{"displayName": "布兰德", "price": "4"}]
        composition.synergy_counts = {"10172": 1}
        composition.synergy_levels = {'job': [{'name': '高级工程师', 'level': 2}], 'race': []}
        composition.total_cost = 4
        composition.score = 10.5

        exported = composition.to_dict()

        # Each exported entry must equal the attribute of the same name.
        mirrored_fields = (
            'chess_list',
            'synergy_counts',
            'synergy_levels',
            'total_cost',
            'size',
            'score',
        )
        for field in mirrored_fields:
            self.assertEqual(exported[field], getattr(composition, field))
class TestRecommendationEngine(unittest.TestCase):
    """Unit tests for the RecommendationEngine class."""

    def setUp(self):
        """Build an engine wired to a mocked data API and a mocked scorer."""
        self.mock_api, self.mock_scorer = mock.Mock(), mock.Mock()
        self.engine = RecommendationEngine(api=self.mock_api, scorer=self.mock_scorer)

    def test_recommend_team_basic(self):
        """With max_results=1 the highest-scoring candidate is returned."""
        # Replacement for TeamComposition.calculate_synergies: does nothing,
        # so the mocked API is never consulted through the teams themselves.
        def skip_synergies(api):
            pass

        def make_team(display_name, price, score):
            # Candidate team holding one piece and a preset score.
            team = TeamComposition()
            team.chess_list = [{"displayName": display_name, "price": price}]
            team.score = score
            team.calculate_synergies = skip_synergies
            return team

        stronger = make_team("布兰德", "4", 10.5)
        weaker = make_team("阿利斯塔", "3", 8.2)

        # Bypass real candidate generation and hand back the two fixed teams.
        candidate_source = mock.Mock(return_value=[stronger, weaker])
        self.engine._generate_candidate_teams = candidate_source

        # The scorer simply echoes whatever score the team already carries.
        self.mock_scorer.score_team.side_effect = lambda team: team.score

        top = self.engine.recommend_team(population=8, max_results=1)

        self.assertEqual(len(top), 1)
        self.assertEqual(top[0].score, 10.5)  # the higher of 10.5 and 8.2
        candidate_source.assert_called_once()

    def test_recommend_team_with_required_chess(self):
        """A required chess piece is looked up by name through the API."""
        self.mock_api.get_chess_by_name.return_value = {
            "displayName": "布兰德",
            "price": "4",
        }
        # No candidates are needed; only the lookup call is under test.
        self.engine._generate_candidate_teams = mock.Mock(return_value=[])

        self.engine.recommend_team(required_chess=[{"name": "布兰德"}])

        self.mock_api.get_chess_by_name.assert_called_once_with("布兰德")

    def test_recommend_team_with_required_synergies(self):
        """A required synergy is looked up by name through the API."""
        self.mock_api.get_synergy_by_name.return_value = {
            "jobId": "10172",
            "name": "高级工程师",
        }
        self.mock_api.get_synergy_levels.return_value = {
            "2": "效果1",
            "4": "效果2",
        }
        # No candidates are needed; only the lookup call is under test.
        self.engine._generate_candidate_teams = mock.Mock(return_value=[])

        wanted = [{"name": "高级工程师", "level": 2}]
        self.engine.recommend_team(required_synergies=wanted)

        self.mock_api.get_synergy_by_name.assert_called_once_with("高级工程师")
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()