-
Notifications
You must be signed in to change notification settings - Fork 77
.1584872158410848:3083b4cb15d65d36f079b043dc22ac14_69e3433c6c6178a73288d93d.69e34b786c6178a73288d98f.69e34b786059ef4a20f968e9:Trae CN.T(2026/4/18 17:14:32) #144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -187,7 +187,134 @@ def extract_answer(text): | |
| return None | ||
|
|
||
|
|
||
| def compute_score(predict_str: str, ground_truth: str, extra_info=None) -> float: | ||
def compute_dynamic_weights(history_stats: dict = None) -> dict:
    """
    Compute dynamic reward weights from historical tool-call statistics.

    Args:
        history_stats: Optional historical statistics. Keys read here:
            - avg_tool_calls: average number of tool calls (default 1.0)
            - tool_success_rate: success rate of tool calls (default 0.5)
            - history_size: amount of history; 0 or missing means "no history"

    Returns:
        A dict that always contains the same keys, with or without history:
            - acc_weight: weight of the accuracy reward
            - format_weight: weight of the format reward
            - tool_weight: weight of the tool-usage reward
            - base_acc_weight / base_tool_weight: the unadjusted base weights
            - tool_call_factor / success_factor: accuracy-side adjustment factors
            - avg_tool_calls / tool_success_rate: the statistics that were used
    """
    # Base weights (identical to the original fixed weights).
    base_acc = 0.8
    base_format = 0.2
    base_tool = 1.2

    stats = history_stats or {}
    avg_tool_calls = stats.get("avg_tool_calls", 1.0)
    tool_success_rate = stats.get("tool_success_rate", 0.5)

    if stats.get("history_size", 0) == 0:
        # No history: fall back to the base weights. The diagnostic keys are
        # still included so callers can index them unconditionally; returning
        # only the three weights here made such lookups raise KeyError.
        return {
            "acc_weight": base_acc,
            "format_weight": base_format,
            "tool_weight": base_tool,
            "base_acc_weight": base_acc,
            "base_tool_weight": base_tool,
            "tool_call_factor": 1.0,
            "success_factor": 1.0,
            "avg_tool_calls": avg_tool_calls,
            "tool_success_rate": tool_success_rate,
        }

    # Hyper-parameters.
    expected_tool_calls = 1.5   # desired number of tool calls
    min_tool_calls = 0.3        # lowest expected number of tool calls
    max_tool_calls = 3.0        # highest expected number of tool calls
    min_success_rate = 0.3      # lowest acceptable tool success rate
    target_success_rate = 0.7   # ideal tool success rate

    slightly_low = expected_tool_calls * 0.8
    slightly_high = expected_tool_calls * 1.5

    # 1. Adjustment based on the number of tool calls.
    #    - Too few calls: encourage tool use -> lower acc weight, raise tool weight.
    #    - Too many calls: possible over-reliance -> raise acc weight, lower tool weight.
    if avg_tool_calls < min_tool_calls:
        # Severely under-using tools: encourage strongly.
        tool_call_factor = 0.7
        tool_weight_factor = 1.5
    elif avg_tool_calls > max_tool_calls:
        # Over-using tools.
        tool_call_factor = 1.4
        tool_weight_factor = 0.6
    elif avg_tool_calls < slightly_low:
        # Slightly too few calls: interpolate 0.7 -> 1.0 and 1.5 -> 1.0.
        ratio = (avg_tool_calls - min_tool_calls) / (slightly_low - min_tool_calls)
        tool_call_factor = 0.7 + 0.3 * ratio
        tool_weight_factor = 1.5 - 0.5 * ratio
    elif avg_tool_calls > slightly_high:
        # Slightly too many calls: interpolate 1.0 -> 1.4 and 1.0 -> 0.6.
        ratio = (avg_tool_calls - slightly_high) / (max_tool_calls - slightly_high)
        tool_call_factor = 1.0 + 0.4 * ratio
        tool_weight_factor = 1.0 - 0.4 * ratio
    else:
        # Number of tool calls is within the reasonable band.
        tool_call_factor = 1.0
        tool_weight_factor = 1.0

    # 2. Adjustment based on the tool success rate.
    #    - Low success: tool use is ineffective -> raise acc weight, lower tool weight.
    #    - High success: tool use is effective -> keep or slightly raise tool weight.
    if tool_success_rate < min_success_rate:
        success_factor = 1.6
        success_tool_factor = 0.3
    elif tool_success_rate > target_success_rate:
        success_factor = 0.9
        success_tool_factor = 1.2
    else:
        # Mid range: interpolate 1.6 -> 0.9 and 0.3 -> 1.2.
        ratio = (tool_success_rate - min_success_rate) / (target_success_rate - min_success_rate)
        success_factor = 1.6 - 0.7 * ratio
        success_tool_factor = 0.3 + 0.9 * ratio

    # 3. Combine the factors and clamp each weight to its allowed range.
    final_acc_factor = tool_call_factor * success_factor
    final_tool_factor = tool_weight_factor * success_tool_factor

    dynamic_acc_weight = max(0.3, min(1.8, base_acc * final_acc_factor))  # [0.3, 1.8]

    # The format weight stays relatively stable; it moves mildly in the
    # opposite direction of the accuracy factor so that the format penalty
    # is neither over-amplified nor over-shrunk.
    dynamic_format_weight = base_format * (1.0 + 0.3 * (1.0 - final_acc_factor))
    dynamic_format_weight = max(0.1, min(0.4, dynamic_format_weight))  # [0.1, 0.4]

    dynamic_tool_weight = max(0.2, min(2.0, base_tool * final_tool_factor))  # [0.2, 2.0]

    return {
        "acc_weight": dynamic_acc_weight,
        "format_weight": dynamic_format_weight,
        "tool_weight": dynamic_tool_weight,
        "base_acc_weight": base_acc,
        "base_tool_weight": base_tool,
        "tool_call_factor": tool_call_factor,
        "success_factor": success_factor,
        "avg_tool_calls": avg_tool_calls,
        "tool_success_rate": tool_success_rate,
    }
|
|
||
|
|
||
| def compute_score(predict_str: str, ground_truth: str, extra_info=None, history_stats=None) -> float: | ||
| """ | ||
| 计算视觉语言任务的奖励,支持基于历史统计的动态权重 | ||
|
|
||
| 参数: | ||
| predict_str: 模型生成的响应字符串 | ||
| ground_truth: 标准答案 | ||
|
Comment on lines
+308
to
+314
|
||
| extra_info: 额外信息(如问题文本) | ||
| history_stats: 历史统计数据,用于计算动态权重 | ||
| """ | ||
| is_format_error = False | ||
| # predict_str = "<think>" + predict_str | ||
| count_think_1 = predict_str.count("<think>") | ||
|
|
@@ -208,14 +335,6 @@ def compute_score(predict_str: str, ground_truth: str, extra_info=None) -> float | |
|
|
||
| answer_text = predict_str.split("<answer>")[-1].split("</answer>")[0].strip() | ||
|
|
||
| # pattern = re.compile(r'<\|im_start\|>assistant(.*?)$', re.DOTALL) # 匹配最后一个 target 后的所有内容 | ||
| # match = pattern.search(predict_str) | ||
| # if match: | ||
| # answer_text = match.group(1).strip() | ||
| # print(f'DEBUG{answer_text=}') | ||
| # else: | ||
| # answer_text = "" | ||
|
|
||
| question_text = extra_info['question'] | ||
| full_prompt = get_prompt(answer_text, ground_truth, question_text) | ||
|
|
||
|
|
@@ -257,26 +376,54 @@ def compute_score(predict_str: str, ground_truth: str, extra_info=None) -> float | |
| acc_reward = 0.0 | ||
| is_format_error = True | ||
|
|
||
| # 计算各奖励分量 | ||
| tool_reward_base = 1.0 if count_vision_1 > 0 else 0.0 | ||
| tool_reward = 1.0 if count_vision_1 > 0 and acc_reward > 0.5 else 0.0 | ||
| format_reward = -1.0 if is_format_error else 0.0 | ||
| # reward 1 | ||
| # return 0.8 * acc_reward + 0.2 * format_reward + 0.4 * tool_reward_base | ||
| # reward 2 | ||
| return 0.8 * acc_reward + 0.2 * format_reward + 1.2 * tool_reward | ||
|
|
||
| # reward 2 | ||
| # return 1.0 * acc_reward + 0.2 * format_reward + 1.0 * tool_reward + 0.2 * tool_reward_base | ||
| # reward 3 | ||
| # tool_reward_alpha = 1.2 if count_vision_1 > 0 else 0.0 | ||
| # return 1.0 * acc_reward * tool_reward_alpha + 0.2 * format_reward | ||
| # reward 4 | ||
| # extra_reward = tool_reward_base * (count_vision_1 - 1) * (1 - acc_reward) | ||
| # return 0.8 * acc_reward + 0.2 * format_reward + 0.4 * tool_reward_base + 0.2 * extra_reward | ||
|
|
||
|
|
||
|
|
||
| def compute_common_reasoning(predict_str: str, ground_truth: str, extra_info=None) -> float: | ||
|
|
||
| # 获取动态权重 | ||
| weights = compute_dynamic_weights(history_stats) | ||
|
|
||
| # 使用动态权重计算最终奖励 | ||
| final_reward = ( | ||
| weights["acc_weight"] * acc_reward + | ||
| weights["format_weight"] * format_reward + | ||
| weights["tool_weight"] * tool_reward | ||
| ) | ||
|
|
||
| # 打印调试信息(显示权重调整情况) | ||
| if history_stats and history_stats.get("history_size", 0) > 0: | ||
| print(f" [Dynamic Weights] acc={weights['acc_weight']:.2f} (base={weights['base_acc_weight']:.2f}), " | ||
| f"tool={weights['tool_weight']:.2f} (base={weights['base_tool_weight']:.2f}), " | ||
| f"history_calls={weights['avg_tool_calls']:.2f}, " | ||
| f"history_success={weights['tool_success_rate']:.2f}") | ||
|
|
||
| # 返回字典形式,包含详细信息以便追踪 | ||
| return { | ||
| "score": final_reward, | ||
| "acc": acc_reward, | ||
| "format": format_reward, | ||
| "tool": tool_reward, | ||
| "tool_calls": count_vision_1, | ||
| "acc_weight": weights["acc_weight"], | ||
| "format_weight": weights["format_weight"], | ||
| "tool_weight": weights["tool_weight"], | ||
| "base_acc_weight": weights["base_acc_weight"], | ||
| "base_tool_weight": weights["base_tool_weight"] | ||
| } | ||
|
|
||
|
|
||
|
|
||
| def compute_common_reasoning(predict_str: str, ground_truth: str, extra_info=None, history_stats=None) -> float: | ||
| """ | ||
| 计算通用推理任务的奖励,支持基于历史统计的动态权重 | ||
|
|
||
| 参数: | ||
| predict_str: 模型生成的响应字符串 | ||
| ground_truth: 标准答案 | ||
| extra_info: 额外信息(如问题文本) | ||
| history_stats: 历史统计数据,用于计算动态权重 | ||
| """ | ||
| is_format_error = False | ||
| # predict_str = "<think>" + predict_str | ||
| count_think_1 = predict_str.count("<think>") | ||
|
|
@@ -335,11 +482,35 @@ def compute_common_reasoning(predict_str: str, ground_truth: str, extra_info=Non | |
| print(f' [ERROR] judgement format invalid: {judgement}') | ||
| continue | ||
|
|
||
| # 计算各奖励分量 | ||
| tool_reward_base = 1.0 if count_vision_1 > 0 else 0.0 | ||
| tool_reward = 1.0 if count_vision_1 > 0 and acc_reward > 0.5 else 0.0 | ||
| format_reward = -1.0 if is_format_error else 0.0 | ||
|
|
||
| # 获取动态权重 | ||
| weights = compute_dynamic_weights(history_stats) | ||
|
|
||
| # 使用动态权重计算最终奖励 | ||
| final_reward = ( | ||
| weights["acc_weight"] * acc_reward + | ||
| weights["format_weight"] * format_reward + | ||
| weights["tool_weight"] * tool_reward | ||
| ) | ||
|
|
||
| print(f' [DEBUG] query={extra_info["question"]}, {ground_truth=}, {answer_text=}, {acc_reward=}, {format_reward=}') | ||
| return 0.8 * acc_reward + 0.2 * format_reward + 1.2 * tool_reward | ||
| print(f' [Dynamic Weights] acc={weights["acc_weight"]:.2f}, tool={weights["tool_weight"]:.2f}') | ||
|
|
||
| # 返回字典形式,包含详细信息以便追踪 | ||
| return { | ||
| "score": final_reward, | ||
| "acc": acc_reward, | ||
| "format": format_reward, | ||
| "tool": tool_reward, | ||
| "tool_calls": count_vision_1, | ||
| "acc_weight": weights["acc_weight"], | ||
| "format_weight": weights["format_weight"], | ||
| "tool_weight": weights["tool_weight"] | ||
| } | ||
|
|
||
|
|
||
| def rule_math_verify(ground_truth, model_answer): | ||
|
|
@@ -385,7 +556,16 @@ def generative_verify(query, ground_truth, model_answer): | |
| print(f' [ERROR math] verify bug output: ') | ||
|
|
||
|
|
||
| def compute_score_math(predict_str: str, ground_truth: str, extra_info=None) -> float: | ||
| def compute_score_math(predict_str: str, ground_truth: str, extra_info=None, history_stats=None) -> float: | ||
| """ | ||
| 计算数学视觉推理任务的奖励,支持基于历史统计的动态权重 | ||
|
|
||
| 参数: | ||
| predict_str: 模型生成的响应字符串 | ||
| ground_truth: 标准答案 | ||
| extra_info: 额外信息(如问题文本) | ||
| history_stats: 历史统计数据,用于计算动态权重 | ||
| """ | ||
| is_format_error = False | ||
| # predict_str = "<think>" + predict_str | ||
| count_think_1 = predict_str.count("<think>") | ||
|
|
@@ -411,8 +591,40 @@ def compute_score_math(predict_str: str, ground_truth: str, extra_info=None) -> | |
| acc_reward = 1.0 if generative_verify(extra_info['question'], ground_truth, model_answer) else 0.0 | ||
|
|
||
| format_reward = -1.0 if is_format_error else 0.0 | ||
|
|
||
| # 数学任务通常不涉及工具调用,但为了接口一致性,我们仍然支持动态权重 | ||
| # 这里使用不同的基础权重(数学任务更强调准确性) | ||
| if history_stats is None or history_stats.get("history_size", 0) == 0: | ||
| # 没有历史数据,使用原始固定权重 | ||
| final_reward = 1.2 * acc_reward + 0.4 * format_reward | ||
| acc_weight = 1.2 | ||
| format_weight = 0.4 | ||
| else: | ||
| # 有历史数据,使用动态权重 | ||
| # 数学任务的基础权重与视觉任务不同 | ||
| base_acc = 1.2 | ||
| base_format = 0.4 | ||
|
|
||
| # 基于历史统计进行调整 | ||
| # 对于数学任务,我们主要关注格式正确性和答案准确性 | ||
| # 如果历史格式错误率高,增加格式权重 | ||
| # 这里我们简化处理,直接使用基础权重,但可以根据需要扩展 | ||
|
|
||
| final_reward = base_acc * acc_reward + base_format * format_reward | ||
| acc_weight = base_acc | ||
| format_weight = base_format | ||
|
|
||
| print(f' [DEBUG] query={extra_info["question"]}, {ground_truth=}, {model_answer=}, {acc_reward=}, {format_reward=}') | ||
| return 1.2 * acc_reward + 0.4 * format_reward | ||
| print(f' [Weights] acc={acc_weight:.2f}, format={format_weight:.2f}') | ||
|
|
||
| # 返回字典形式,包含详细信息以便追踪 | ||
| return { | ||
| "score": final_reward, | ||
| "acc": acc_reward, | ||
| "format": format_reward, | ||
| "acc_weight": acc_weight, | ||
| "format_weight": format_weight | ||
| } | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`_default_compute_score()` now passes through dict results from `vl_agent` (and returns them as-is). Some reward managers (e.g., `PrimeRewardManager.verify()`) assume the reward function returns numeric scalars and will fail when converting dicts to tensors. Either ensure `vl_agent` scoring functions return a float here, or make the dict-to-float extraction explicit in `_default_compute_score()` (e.g., return `res["score"]` while optionally exposing the extra fields elsewhere).