")
+
+ # Replace |
+ output_str = output_str.replace(" | ", " | ")
+
+ return output_str
+
+def convert_markdown_to_html(markdown_content):
+ # Define a regex pattern to find Markdown tables with newlines
+ markdown_content = markdown_content.replace('\r', '')+'\n'
+ pattern = re.compile(r'\|\s*.*?\s*\|\n', re.DOTALL)
+
+ # Find all matches in the Markdown content
+ matches = pattern.findall(markdown_content)
+
+ for match in matches:
+ html_table = markdown_to_html(match)
+ markdown_content = markdown_content.replace(match, html_table, 1) # Only replace the first occurrence
+
+ res_html = convert_table(replace_table_with_placeholder(markdown_content))
+
+ return res_html
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/__init__.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/deep_research.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/deep_research.py
new file mode 100644
index 00000000..76e0ca69
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/deep_research.py
@@ -0,0 +1,183 @@
+from typing import Any, Dict, List
+
+import pandas as pd
+from datasets import load_dataset
+from json_repair import repair_json
+
+from vlmeval.smp import dump, get_intermediate_file_path, load
+from vlmeval.utils.mp_util import track_progress_rich
+from ..text_base import TextBaseDataset
+from ..utils.judge_util import build_judge
+
+
+def extract_final_answer(answer_with_thinking: str, start_tag='', end_tag=''):
+ answer_with_thinking = str(answer_with_thinking)
+ start_index = answer_with_thinking.rfind(start_tag)
+ if start_index != -1:
+ end_index = answer_with_thinking.find(end_tag, start_index)
+ if end_index != -1:
+ return answer_with_thinking[start_index + len(start_tag):end_index].strip()
+ return None
+
+
+def eval_model_output(ques_dict, judge):
+ newline = '\n'
+ prompt = f"""
+You are an expert in systematically validating and evaluating \
+LLM-generated solutions. Your task is to rigorously analyze the \
+correctness of a provided solution by comparing it step-by-step \
+against the reference solution, and output **only** a structured \
+verification list—with no additional text.
+
+## Instructions
+1. Break down the given LLM solution into individual steps and \
+evaluate each one against the corresponding reference solution steps.
+2. For each step, include the following three components:
+ - **solution_step**: The specific part of the LLM solution \
+being evaluated.
+ - **reason**: A clear, critical explanation of whether the \
+step contains errors, omissions, or deviations from the reference \
+approach. Be stringent in your assessment.
+ - **judge**: Your verdict: either `"correct"` or `"incorrect"`.
+3. If the final LLM answer is incorrect, you must identify \
+at least one step in your analysis as incorrect.
+4. Justify your judgments rigorously, pointing out even minor \
+inaccuracies or logical flaws.
+5. Do not attempt to answer the original question—your role \
+is strictly to evaluate.
+6. Output **only** a list of dictionaries in the exact format \
+provided below. Do not include any other text or comments.
+
+## Question
+{ques_dict['question']}
+
+## Reference Solution Steps
+{newline.join(ques_dict['steps'])}
+
+## Reference Answer
+{ques_dict['answer']}
+
+## LLM Solution Steps
+{ques_dict['prediction']}
+
+## LLM Answer
+{extract_final_answer(ques_dict['prediction'])}
+
+## Output Example
+[
+ {{"solution_step": "step content", \
+"reason": "reason of the judgement", \
+"judge": "correct or incorrect"}},
+ {{"solution_step": "step content", \
+"reason": "reason of the judgement", \
+"judge": "correct or incorrect"}},
+]
+"""
+
+ try:
+ messages = [
+ {"role": "system", "value": "You are a helpful assistant.", "type": "text"},
+ {"role": "user", "value": prompt, "type": "text"},
+ ]
+ llm_judge = judge.generate(messages)
+ start_index = llm_judge.find('[')
+ end_index = llm_judge.rfind(']') + 1
+ llm_judge = eval(repair_json(llm_judge[start_index:end_index]))
+ correct_step_count = 0
+ for step in llm_judge:
+ if step["judge"] == "correct":
+ correct_step_count += 1
+ step_level_acc = correct_step_count / len(llm_judge)
+ except Exception as e:
+ print(e)
+ llm_judge = None
+ step_level_acc = 0
+
+ ques_dict['exact_match'] = 1 if (
+ ques_dict['answer'] == ques_dict['prediction']
+ or ques_dict['answer'] == extract_final_answer(
+ ques_dict['prediction']
+ )
+ ) else 0
+ ques_dict['llm_judge'] = llm_judge
+ ques_dict['step_level_acc'] = step_level_acc
+ return ques_dict
+
+
+class SGI_Bench_Deep_Research(TextBaseDataset):
+ TYPE = 'QA'
+
+ @classmethod
+ def supported_datasets(cls):
+ return ["SGI-DeepResearch"]
+
+ def load_data(self, dataset):
+ hf = load_dataset("InternScience/SGI-DeepResearch", split="test")
+
+ rows: List[Dict[str, Any]] = []
+ idx = 0
+ for prob in hf:
+ rows.append(
+ {
+ "index": idx,
+ "id": prob["idx"],
+ "question": prob["question"],
+ "steps": prob["steps"],
+ "answer": prob["answer"],
+ "discipline": prob["discipline"],
+ "direction": prob["direction"],
+ "type": prob["type"]
+ }
+ )
+ idx += 1
+ return pd.DataFrame(rows)
+
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+ question = line['question'] + """
+You can reason step by step before giving the final answer. \
+The final answer should be enclosed by and .
+
+Example:
+Step 1. ...
+Step 2. ...
+...
+1.00
+"""
+
+ msgs = [{'type': 'text', 'value': question}]
+ return msgs
+
+ def evaluate(self, eval_file, **judge_kwargs):
+ data = load(eval_file)
+ data = pd.DataFrame(data)
+
+ if judge_kwargs.get('model') is None:
+ judge_kwargs['model'] = 'o4-mini'
+ if judge_kwargs.get('max_tokens') is None:
+ judge_kwargs['max_tokens'] = None
+
+ inp_list = []
+ judge = build_judge(**judge_kwargs)
+ for item in data.to_dict(orient="records"):
+ inp_list.append({"ques_dict": item, "judge": judge})
+ out_list = track_progress_rich(
+ func=eval_model_output,
+ tasks=inp_list,
+ nproc=judge_kwargs.get('nproc', 48)
+ )
+
+ exact_match = sum([item['exact_match'] for item in out_list]) / len(out_list)
+ step_level_acc = sum([item['step_level_acc'] for item in out_list]) / len(out_list)
+
+ result = {
+ 'Exact Match': exact_match,
+ 'Step Level Acc': step_level_acc
+ }
+
+ score_file = get_intermediate_file_path(eval_file, '_score', 'json')
+ result_file = get_intermediate_file_path(eval_file, '_result', 'json')
+ dump(out_list, score_file)
+ dump(result, result_file)
+ return result
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/dry_experiment.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/dry_experiment.py
new file mode 100644
index 00000000..44e00f6f
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/dry_experiment.py
@@ -0,0 +1,599 @@
+import ast
+import os
+import platform
+import shutil
+import subprocess
+import time
+from pathlib import Path
+from typing import Any, Dict, List
+
+import pandas as pd
+import requests
+from datasets import load_dataset
+from json_repair import repair_json
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+
+from vlmeval.smp import LMUDataRoot, dump, get_intermediate_file_path, load
+from vlmeval.utils.mp_util import track_progress_rich
+from ..text_base import TextBaseDataset
+from ..utils.judge_util import build_judge
+
+save_dir = "./outputs/sgi_code_logs"
+tmp_data_dir = "./outputs/sgi_tmp_data"
+
+env = os.environ.copy()
+env["PYTHONIOENCODING"] = "utf-8"
+
+
+def run_script_in_folder(folder_path):
+ """
+ Run data.py (if exists) and main.py in the given folder,
+ print immediate status, and return execution results.
+ """
+ script_name = 'data_en.py'
+ script_path_full = folder_path / script_name
+ try:
+ result = subprocess.run(
+ ["conda", "run", "-n", "dryexp", "python", script_name],
+ capture_output=True,
+ text=True,
+ timeout=10 * 60, # 10-minute timeout
+ encoding="utf-8",
+ cwd=str(folder_path),
+ env=env,
+ shell=platform.system() == "Windows"
+ )
+ if result.returncode == 0:
+ # print(f"✅")
+ result = (str(script_path_full), True, "")
+ else:
+ print("❌")
+ error_message = (
+ result.stderr.strip()
+ if result.stderr else "Unknown error"
+ )
+ result = (str(script_path_full), False, error_message)
+ except subprocess.TimeoutExpired:
+ print("❌")
+ print(" Error: Execution timed out after 10 minutes")
+ result = (
+ str(script_path_full), False,
+ "Execution timed out after 10 minutes"
+ )
+ except Exception as e:
+ print("❌")
+ print(f" Error: {e}")
+ result = (str(script_path_full), False, str(e))
+ return result
+
+
+def run_script(ques_dict):
+ ques_dict['unit_test'] = []
+ for unit_test_idx in range(5):
+ folder_path = os.path.join(
+ save_dir, ques_dict['idx'],
+ f"unit_test_{unit_test_idx}"
+ )
+ unit_test_dict = {}
+
+ try:
+ # Run the script and capture output
+ start_time = time.time()
+ result = subprocess.run(
+ ["conda", "run", "-n", "dryexp", "python", 'main_model.py'],
+ capture_output=True,
+ text=True,
+ timeout=300, # 5 minutes timeout
+ encoding="utf-8",
+ cwd=str(folder_path),
+ env=env,
+ shell=platform.system() == "Windows"
+ )
+ end_time = time.time()
+ elapsed = end_time - start_time
+ model_code_output = f"{result.stderr}\n{result.stdout}".strip()
+
+ if result.returncode == 0:
+ # print(f"✅")
+ unit_test_dict["model_error"] = "[No Error]"
+ unit_test_dict["model_runtime"] = elapsed
+ unit_test_dict["model_returncode"] = result.returncode
+ unit_test_dict["model_code_output"] = model_code_output
+ else:
+ # print(f"❌")
+ # print(f" Error: {error_message}")
+ unit_test_dict["model_error"] = (
+ "[WRONG]" + result.stderr.strip()
+ if result.stderr else "Unknown error"
+ )
+ unit_test_dict["model_runtime"] = elapsed
+ unit_test_dict["model_returncode"] = result.returncode
+ unit_test_dict["model_code_output"] = model_code_output
+ except subprocess.TimeoutExpired:
+ # print(f"❌")
+ # print(f" Error: Execution timed out after 5 minutes")
+ unit_test_dict["model_error"] = (
+ "[WRONG]Execution timed out after 5 minutes"
+ )
+ unit_test_dict["model_runtime"] = 300.0
+ unit_test_dict["model_returncode"] = -1 # Terminated
+ unit_test_dict["model_code_output"] = unit_test_dict["model_error"]
+ except Exception as e:
+ # print(f"❌")
+ # print(f" Error: {e}")
+ unit_test_dict["model_error"] = "[WRONG]" + str(e)
+ unit_test_dict["model_runtime"] = -1
+ unit_test_dict["model_returncode"] = 1 # Error
+ unit_test_dict["model_code_output"] = unit_test_dict["model_error"]
+ ques_dict['unit_test'].append(unit_test_dict)
+ return ques_dict
+
+
+def eval_model_output(ques_dict, judge):
+ for unit_test_idx in range(5):
+ unit_test_dict = ques_dict['unit_test'][unit_test_idx]
+ correct_output = ques_dict[f"unit_test_{unit_test_idx}_output"]
+ unit_test_dict['exact_match'] = (
+ 1 if (unit_test_dict['model_code_output']
+ == correct_output) else 0
+ )
+
+ if unit_test_dict["exact_match"]:
+ unit_test_dict["llm_judge"] = {
+ "judgment": "correct",
+ "reason": "Exact match."
+ }
+ unit_test_dict['pass'] = 1
+ ques_dict['unit_test'][unit_test_idx] = unit_test_dict
+ continue
+
+ if (unit_test_dict["model_error"].startswith("[WRONG]")
+ or unit_test_dict["model_returncode"] != 0):
+ unit_test_dict["llm_judge"] = {
+ "judgment": "incorrect",
+ "reason": "There are problems running "
+ "the completed code."
+ }
+ unit_test_dict['pass'] = 0
+ ques_dict['unit_test'][unit_test_idx] = unit_test_dict
+ continue
+
+ prompt = f"""
+You are an expert in evaluating model output accuracy. Your task \
+is to precisely determine whether the model output matches the \
+reference output and provide a brief explanation.
+
+## Instructions
+1. Check all numerical values and ensure strict accuracy—every \
+digit must match exactly. Any inconsistency should be considered \
+incorrect.
+2. For training-related loss values or metrics, if the difference \
+between model output and reference output loss or metric values \
+is greater than 2%, consider it incorrect.
+3. The output should be a dictionary without any other text in \
+the following format:
+example = {{
+ "judgment": "Placeholder, use 'correct' if outputs match, \
+'incorrect' otherwise",
+ "reason": "Brief explanation placeholder"
+}}
+
+## Reference Output
+{correct_output}
+
+## Model Output
+{unit_test_dict["model_code_output"]}
+"""
+
+ try:
+ messages = [
+ {"role": "system",
+ "value": "You are a helpful assistant.",
+ "type": "text"},
+ {"role": "user", "value": prompt, "type": "text"},
+ ]
+ llm_judge = judge.generate(messages)
+ start_index = llm_judge.find('{')
+ end_index = llm_judge.rfind('}') + 1
+ llm_judge = eval(repair_json(llm_judge[start_index:end_index]))
+ except Exception as e:
+ print(e)
+ llm_judge = None
+
+ unit_test_dict['llm_judge'] = llm_judge
+ if llm_judge and isinstance(llm_judge, dict):
+ unit_test_dict['pass'] = (
+ 1 if llm_judge.get('judgment') == 'correct'
+ else 0
+ )
+ else:
+ unit_test_dict['pass'] = 0
+ ques_dict['unit_test'][unit_test_idx] = unit_test_dict
+
+ ques_dict['pass_nums'] = sum(
+ [ut['pass'] for ut in ques_dict['unit_test']]
+ )
+ ques_dict['model_average_runtime'] = [
+ ut['model_runtime']
+ for ut in ques_dict['unit_test']
+ if ut['model_runtime'] > 0
+ ]
+ avg_rt = ques_dict['model_average_runtime']
+ ques_dict['model_average_runtime'] = (
+ sum(avg_rt) / len(avg_rt) if len(avg_rt) > 0
+ else -1
+ )
+ ques_dict['se'] = sum(
+ [1 if ut['model_returncode'] == 0 else 0
+ for ut in ques_dict['unit_test']]
+ ) / 5
+ return ques_dict
+
+
+def download_file(url: str, dir_path: str):
+ os.makedirs(dir_path, exist_ok=True)
+ filename = url.split("/")[-1]
+ save_path = os.path.join(dir_path, filename)
+
+ if os.path.exists(save_path):
+ return save_path
+ session = requests.Session()
+ retries = Retry(
+ total=3, backoff_factor=1,
+ status_forcelist=[500, 502, 503, 504]
+ )
+ session.mount('http://', HTTPAdapter(max_retries=retries))
+ session.mount('https://', HTTPAdapter(max_retries=retries))
+
+ try:
+ with session.get(url, stream=True, timeout=30) as response:
+ response.raise_for_status()
+ with open(save_path, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ if chunk:
+ f.write(chunk)
+
+ return save_path
+
+ except Exception as e:
+ if os.path.exists(save_path):
+ os.remove(save_path)
+ print(f"Error downloading {url}: {e}")
+ raise e
+
+
+def extract_final_answer(
+ answer_with_thinking: str,
+ start_tag='',
+ end_tag=''
+):
+ answer_with_thinking = str(answer_with_thinking)
+ start_index = answer_with_thinking.rfind(start_tag)
+ if start_index != -1:
+ end_index = answer_with_thinking.find(end_tag, start_index)
+ if end_index != -1:
+ return answer_with_thinking[
+ start_index + len(start_tag):end_index
+ ].strip()
+ return None
+
+
+def check_syntax(code_string):
+ try:
+ # Try to compile the code string
+ compile(code_string, '', 'exec')
+ return True
+ except SyntaxError:
+ return False
+
+
+def get_function_lines(file_content):
+ node = ast.parse(file_content)
+
+ function_lines = {}
+
+ for item in node.body:
+ if isinstance(item, ast.FunctionDef):
+ func_name = item.name
+ start_line = item.lineno
+ end_line = item.end_lineno
+ function_lines[func_name] = (start_line, end_line)
+
+ return function_lines
+
+
+def replace_code(
+ content_1, start_line_1, end_line_1,
+ content_2, start_line_2, end_line_2
+):
+ lines_1 = content_1.splitlines(keepends=True)
+ lines_2 = content_2.splitlines(keepends=True)
+
+ lines_1[start_line_1 - 1:end_line_1] = lines_2[start_line_2 - 1:end_line_2]
+
+ return ''.join(lines_1)
+
+
+def replace_function(main_code, new_code, function_name):
+ assert check_syntax(main_code), "wrong main_code"
+ assert check_syntax(new_code), "wrong new_code"
+ functions_dict_1 = get_function_lines(main_code)
+ functions_dict_2 = get_function_lines(new_code)
+
+ start_line_1, end_line_1 = functions_dict_1[function_name]
+ start_line_2, end_line_2 = functions_dict_2[function_name]
+
+ main_code_after_replacing = replace_code(
+ main_code, start_line_1, end_line_1,
+ new_code, start_line_2, end_line_2
+ )
+ assert check_syntax(main_code_after_replacing), \
+ "wrong main_code after replacing"
+ return main_code_after_replacing
+
+
+class SGI_Bench_Dry_Experiment(TextBaseDataset):
+ TYPE = 'QA'
+
+ @classmethod
+ def supported_datasets(cls):
+ return ["SGI-DryExperiment"]
+
+ def load_data(self, dataset):
+ hf = load_dataset("InternScience/SGI-DryExperiment", split="test")
+
+ rows: List[Dict[str, Any]] = []
+ idx = 0
+ for prob in hf:
+ rows.append(
+ {
+ "index": idx,
+ "idx": prob["idx"],
+ "question": prob["question"],
+ "data_code": prob["data_code"],
+ "main_code": prob["main_code"],
+ "incomplete_main_code": prob["incomplete_main_code"],
+ "incomplete_functions": prob["incomplete_functions"],
+ "unit_test_0_data": prob["unit_test_0_data"],
+ "unit_test_0_output": prob["unit_test_0_output"],
+ "unit_test_1_data": prob["unit_test_1_data"],
+ "unit_test_1_output": prob["unit_test_1_output"],
+ "unit_test_2_data": prob["unit_test_2_data"],
+ "unit_test_2_output": prob["unit_test_2_output"],
+ "unit_test_3_data": prob["unit_test_3_data"],
+ "unit_test_3_output": prob["unit_test_3_output"],
+ "unit_test_4_data": prob["unit_test_4_data"],
+ "unit_test_4_output": prob["unit_test_4_output"],
+ "function_type": prob["function_type"],
+ "runtime": prob["runtime"],
+ "discipline": prob["discipline"],
+ "direction": prob["direction"],
+ }
+ )
+ idx += 1
+ return pd.DataFrame(rows)
+
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+ question = line['question'] + """
+Output the completed function enclosed within and tags.
+
+Example 1:
+
+def hello():
+ print("Hello")
+
+
+Example 2:
+
+def add(a, b):
+ return a+b
+
+def minus(a, b):
+ return a-b
+
+
+"""
+
+ msgs = [{'type': 'text', 'value': question}]
+ return msgs
+
+ def evaluate(self, eval_file, **judge_kwargs):
+ save_dir_last = 'sgi_code_logs'
+ global save_dir
+ work_dir = str(Path(eval_file).parents[0])
+ save_dir = os.path.join(work_dir, save_dir_last)
+ tmp_data_dir_last = 'sgi_tmp_data'
+ global tmp_data_dir
+ tmp_data_dir = os.path.join(LMUDataRoot(), tmp_data_dir_last)
+ data = load(eval_file)
+ data = pd.DataFrame(data)
+
+ # 输入数据准备
+ data_flag = os.path.join(
+ save_dir, 'data_construction.json'
+ )
+ if not os.path.exists(data_flag):
+ os.makedirs(os.path.join(save_dir), exist_ok=True)
+ os.makedirs(os.path.join(tmp_data_dir), exist_ok=True)
+ os.makedirs(os.path.join(tmp_data_dir, "0206"), exist_ok=True)
+ os.makedirs(os.path.join(tmp_data_dir, "0200"), exist_ok=True)
+ os.makedirs(os.path.join(tmp_data_dir, "0236"), exist_ok=True)
+
+ _base = "https://raw.githubusercontent.com/InternScience/SGI-Bench/main/evaluation/task_3_dry_experiment/data" # noqa: E501
+ download_file(
+ f"{_base}/SGI_DryExperiment_0206/t10k-images-idx3-ubyte.gz",
+ tmp_data_dir + "/0206")
+ download_file(
+ f"{_base}/SGI_DryExperiment_0206/t10k-labels-idx1-ubyte.gz",
+ tmp_data_dir + "/0206")
+ download_file(
+ f"{_base}/SGI_DryExperiment_0206/train-images-idx3-ubyte.gz",
+ tmp_data_dir + "/0206")
+ download_file(
+ f"{_base}/SGI_DryExperiment_0206/train-labels-idx1-ubyte.gz",
+ tmp_data_dir + "/0206")
+
+ download_file(
+ f"{_base}/SGI_DryExperiment_0200/adult.data",
+ tmp_data_dir + "/0200")
+ download_file(
+ f"{_base}/SGI_DryExperiment_0200/adult.test",
+ tmp_data_dir + "/0200")
+
+ download_file(
+ f"{_base}/SGI_DryExperiment_0236/3d-user-study-data.zip",
+ tmp_data_dir + "/0236")
+
+ code_dir_list = []
+ for index, item in data.iterrows():
+ for unit_test_idx in range(5):
+ code_dir = os.path.join(
+ save_dir, item['idx'],
+ f"unit_test_{unit_test_idx}"
+ )
+ code_dir_list.append(
+ {'folder_path': Path(code_dir)}
+ )
+ os.makedirs(code_dir, exist_ok=True)
+ data_dir = os.path.join(
+ save_dir, item['idx'],
+ f"unit_test_{unit_test_idx}", 'data'
+ )
+ os.makedirs(data_dir, exist_ok=True)
+
+ data_py = os.path.join(
+ code_dir, "data_en.py"
+ )
+ with open(data_py, "w", encoding="utf-8") as f:
+ f.write(
+ item[f"unit_test_{unit_test_idx}_data"]
+ )
+ main_py = os.path.join(
+ code_dir, "main_en.py"
+ )
+ with open(main_py, "w", encoding="utf-8") as f:
+ f.write(item["main_code"])
+
+ for i in range(5):
+ dst = os.path.join(
+ save_dir,
+ f"SGI_DryExperiment_0206/unit_test_{i}/data/mnist_raw"
+ )
+ shutil.copytree(
+ tmp_data_dir + "/0206", dst,
+ dirs_exist_ok=True
+ )
+
+ for i in range(5):
+ dst = os.path.join(
+ save_dir,
+ f"SGI_DryExperiment_0200/unit_test_{i}/data"
+ )
+ shutil.copytree(
+ tmp_data_dir + "/0200", dst,
+ dirs_exist_ok=True
+ )
+
+ for i in range(5):
+ dst = os.path.join(
+ save_dir,
+ f"SGI_DryExperiment_0236/unit_test_{i}"
+ f"/data/em_3d_user_study"
+ )
+ shutil.copytree(
+ tmp_data_dir + "/0236", dst,
+ dirs_exist_ok=True
+ )
+
+ all_results = track_progress_rich(
+ tasks=code_dir_list,
+ func=run_script_in_folder,
+ nproc=judge_kwargs.get("nproc", 4)
+ )
+ dump(all_results, os.path.join(save_dir, 'data_construction.json'))
+ # 输入数据准备
+
+ # 代码保存
+ for index, item in data.iterrows():
+ main_code = item['incomplete_main_code']
+ incomplete_functions = item['incomplete_functions']
+ answer = extract_final_answer(item['prediction'])
+ for incomplete_function in eval(incomplete_functions):
+ try:
+ main_code = replace_function(
+ main_code, answer,
+ incomplete_function
+ )
+ except Exception:
+ pass
+ for unit_test_idx in range(5):
+ save_path = os.path.join(
+ save_dir, item['idx'],
+ f"unit_test_{unit_test_idx}",
+ "main_model.py"
+ )
+ with open(save_path, 'w', encoding='utf-8') as f:
+ f.write(main_code)
+ # 代码保存
+
+ # 代码运行
+ inp_list = [
+ {"ques_dict": item}
+ for item in data.to_dict(orient="records")
+ ]
+ out_list = track_progress_rich(
+ tasks=inp_list, func=run_script, nproc=100
+ )
+ # 代码运行
+ if judge_kwargs.get('model') is None:
+ judge_kwargs['model'] = 'o4-mini'
+ if judge_kwargs.get('max_tokens') is None:
+ judge_kwargs['max_tokens'] = None
+ # 代码评测
+ judge = build_judge(**judge_kwargs)
+ in_list = [
+ {"ques_dict": item, "judge": judge}
+ for item in out_list
+ ]
+ out_list = track_progress_rich(
+ tasks=in_list,
+ func=eval_model_output, nproc=100
+ )
+ # 代码评测
+
+ PassAll_5 = sum(
+ [1 if (item['pass_nums'] == 5) else 0
+ for item in out_list]
+ ) / len(out_list)
+ PassAll_3 = sum(
+ [1 if (item['pass_nums'] >= 3) else 0
+ for item in out_list]
+ ) / len(out_list)
+ PassAll_1 = sum(
+ [1 if (item['pass_nums'] >= 1) else 0
+ for item in out_list]
+ ) / len(out_list)
+ runtimes = [
+ item['model_average_runtime']
+ for item in out_list
+ if item['model_average_runtime'] > 0
+ ]
+ AET = sum(runtimes) / len(runtimes)
+ SER = sum([item['se'] for item in out_list]) / len(out_list)
+
+ result = {
+ 'PassAll@5': PassAll_5,
+ 'PassAll@3': PassAll_3,
+ 'PassAll@1': PassAll_1,
+ 'AET': AET,
+ 'SER': SER
+ }
+
+ score_file = get_intermediate_file_path(eval_file, '_score', 'json')
+ result_file = get_intermediate_file_path(eval_file, '_result', 'json')
+ dump(out_list, score_file)
+ dump(result, result_file)
+ return result
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/dry_experiment_requirements.txt b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/dry_experiment_requirements.txt
new file mode 100644
index 00000000..823dcaba
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/dry_experiment_requirements.txt
@@ -0,0 +1,67 @@
+absl-py==2.3.1
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.15
+aiosignal==1.4.0
+async-timeout==5.0.1
+attrs==25.3.0
+certifi==2025.8.3
+cftime==1.6.4.post1
+charset-normalizer==3.4.3
+colorama==0.4.6
+contourpy==1.3.2
+cycler==0.12.1
+datasets==2.15.0
+dill==0.3.7
+filelock==3.19.1
+fire==0.7.1
+fonttools==4.59.2
+frozenlist==1.7.0
+fsspec==2023.10.0
+h5py==3.10.0
+huggingface-hub==0.34.4
+idna==3.10
+imageio==2.37.0
+joblib==1.5.2
+kiwisolver==1.4.9
+lazy_loader==0.4
+matplotlib==3.7.2
+ml_dtypes==0.5.3
+multidict==6.6.4
+multiprocess==0.70.15
+netCDF4==1.6.4
+networkx==3.4.2
+numpy==1.24.3
+opt_einsum==3.4.0
+packaging==25.0
+pandas==2.0.3
+pathlib==1.0.1
+patsy==1.0.1
+Pillow==10.1.0
+propcache==0.3.2
+pyarrow==21.0.0
+pyarrow-hotfix==0.7
+pyparsing==3.0.9
+python-dateutil==2.9.0.post0
+pytz==2025.2
+PyWavelets==1.4.1
+PyYAML==6.0.2
+rdkit==2023.9.5
+requests==2.31.0
+scikit-image==0.22.0
+scikit-learn==1.3.2
+scipy==1.11.4
+seaborn==0.12.2
+six==1.17.0
+statsmodels==0.14.0
+termcolor==3.1.0
+threadpoolctl==3.6.0
+tifffile==2025.5.10
+toolz==1.0.0
+tqdm==4.66.2
+typing_extensions==4.15.0
+tzdata==2025.2
+urllib3==2.5.0
+xarray==2023.6.0
+xgboost==1.7.6
+xxhash==3.5.0
+yarl==1.20.1
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/experimental_reasoning.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/experimental_reasoning.py
new file mode 100644
index 00000000..14f74fbc
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/experimental_reasoning.py
@@ -0,0 +1,306 @@
+import ast
+import base64
+import io
+import os
+import os.path as osp
+import re
+from typing import Any, Dict, List
+
+import pandas as pd
+from datasets import load_dataset
+
+from vlmeval.smp import (decode_base64_to_image_file, dump, encode_image_to_base64,
+ get_intermediate_file_path, load, read_ok, toliststr)
+from vlmeval.utils.mp_util import track_progress_rich
+from ..image_base import ImageBaseDataset
+from ..utils.judge_util import build_judge
+
+
+def extract_answer_from_response(response):
+ match = re.search(r"\\boxed\{([A-Za-z])\}", response)
+ if match:
+ return match.group(1)
+ else:
+ return None
+
+
+def mm_reasoning_is_correct(pred, gold):
+ try:
+ ans = extract_answer_from_response(pred).strip()
+ except Exception:
+ return False
+ return ans.lower() == gold.lower()
+
+
+def b64_encode_image(img) -> str:
+ buffered = io.BytesIO()
+ img.save(buffered, format="PNG")
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+
+def judge_aux(judge, row):
+ reference_steps = row["steps"]
+ reference_steps = "\n".join([f"{i + 1}. {step}" for i, step in enumerate(reference_steps)])
+
+ judge_prompt = (
+ f"You are a strict evaluator assessing the "
+ f"**validity of the model prediction's reasoning "
+ f"process**. You must score this reasoning validity "
+ f"on a scale from 0 to 10, where 0 means the "
+ f"reasoning is completely invalid and 10 means the "
+ f"reasoning is fully rigorous.\n"
+ f"# Input\n"
+ f"Question:\n"
+ f"```\n"
+ f"{row['question']}\n"
+ f"```\n"
+ f"Reference Reasoning:\n"
+ f"```\n"
+ f"{reference_steps}\n"
+ f"```\n"
+ f"Model Prediction:\n"
+ f"```\n"
+ f"{row['prediction']}\n"
+ f"```\n"
+ f"# Evaluation Rules\n"
+ f"1. First, identify the **complete reasoning "
+ f"process** from the model prediction (ignore only "
+ f"the final answer if it is not accompanied by "
+ f"reasoning).\n"
+ f"2. Evaluate reasoning validity against two core "
+ f"criteria:\n"
+ f" - **Logical Coherence**: Check if the reasoning "
+ f"steps are sequential, self-consistent, and free of "
+ f"contradictions (e.g., no conflicting premises or "
+ f"illogical deductions).\n"
+ f" - **Alignment with Reference Reasoning**: Check "
+ f"if the reasoning direction, key premises, and "
+ f"deduction logic match the reference reasoning "
+ f"(partial alignment counts for partial credit).\n"
+ f"3. Deduct points for:\n"
+ f" - Irrelevant content (reasoning that does not "
+ f"address the question or key conditions).\n"
+ f" - Missing key reasoning steps (even if the "
+ f"final answer is correct).\n"
+ f" - Flawed logic (e.g., circular reasoning, "
+ f"false premises leading to conclusions).\n"
+ f"4. Do not prioritize the correctness of the "
+ f"**final answer**\u2014a correct answer with invalid "
+ f"reasoning still scores low, while an incorrect "
+ f"answer with partially valid reasoning may score "
+ f"higher.\n"
+ f"# Scoring Guide\n"
+ f"- **10**: Reasoning is fully rigorous, logically "
+ f"coherent (no contradictions), and perfectly "
+ f"aligned with the reference reasoning (all key "
+ f"steps and logic match).\n"
+ f"- **7-9**: Reasoning is mostly coherent, with "
+ f"minor logical gaps or partial misalignment with "
+ f"the reference reasoning (no major "
+ f"contradictions).\n"
+ f"- **4-6**: Reasoning has obvious logical flaws "
+ f"(e.g., one missing key step, minor "
+ f"contradictions) or limited alignment with the "
+ f"reference reasoning (only some core logic "
+ f"matches).\n"
+ f"- **1-3**: Reasoning is barely valid, with severe "
+ f"logical flaws (e.g., multiple contradictions) or "
+ f"almost no alignment with the reference reasoning "
+ f"(only tangentially related to the question).\n"
+ f"- **0**: Reasoning is completely invalid, "
+ f"contradictory (self-conflicting logic), or "
+ f"irrelevant (no connection to the question or key "
+ f"conditions).\n"
+ f"# Strict Output format example\n"
+ f"6"
+ )
+ try:
+ msgs = []
+ msgs.append({'role': 'system', 'value': 'You are a helpful assistant.'})
+ msgs.append({'role': 'user', 'type': 'text', 'value': judge_prompt})
+
+ images = ast.literal_eval(row['step_images'])
+ for image in images:
+ msgs.append({'role': 'user', 'value': image, 'type': 'image'})
+ llm_judge = judge.generate(msgs).strip()
+ pattern = r"(\d+)"
+ match = re.search(pattern, llm_judge)
+ rv_score = float(match.group(1)) if match else 0.0
+ except Exception:
+ rv_score = 0.0
+
+ mcc_score = mm_reasoning_is_correct(row['prediction'], chr(ord('A') + int(row['answer'])))
+ return dict(mcc_score=mcc_score, rv_score=rv_score)
+
+
+class SGI_Bench_Experimental_Reasoning(ImageBaseDataset):
+ TYPE = 'MCQ '
+
+ @classmethod
+ def supported_datasets(cls):
+ return ["SGI-Experimental-Reasoning"]
+
+ def dump_images(self, line):
+ step_dir = osp.join(self.img_root, 'step_images')
+ os.makedirs(self.img_root, exist_ok=True)
+ os.makedirs(step_dir, exist_ok=True)
+
+ results = {}
+
+ def _process_field(key_name, path_key_name, save_root):
+ tgt_paths = []
+ if key_name in line:
+ content = line[key_name]
+ if path_key_name in line and isinstance(line[path_key_name], list):
+ fnames = line[path_key_name]
+ else:
+ count = len(content) if isinstance(content, list) else 1
+ fnames = [f"{line['index']}_{i}.png" for i in range(count)]
+ imgs = content if isinstance(content, list) else [content]
+ for img, fname in zip(imgs, fnames):
+ full_path = osp.join(save_root, fname)
+ if not read_ok(full_path):
+ decode_base64_to_image_file(img, full_path)
+ tgt_paths.append(full_path)
+
+ elif path_key_name in line:
+ paths = toliststr(line[path_key_name])
+ read_ok_flag = [read_ok(x) for x in paths]
+
+ if not all(read_ok_flag):
+ paths_abs = [osp.join(save_root, x) for x in paths]
+ read_ok_flag = [read_ok(x) for x in paths_abs]
+ assert read_ok_flag, f"Field `{key_name}` missing and files not found: {paths}"
+ tgt_paths = paths_abs
+ else:
+ tgt_paths = paths
+
+ return tgt_paths
+
+ if 'image' in line or 'image_path' in line:
+ results['image'] = _process_field('image', 'image_path', self.img_root)
+ if 'step_images' in line or 'step_image_path' in line:
+ results['step_images'] = _process_field('step_images', 'step_image_path', step_dir)
+
+ return results
+
+ def load_data(self, dataset):
+ hf = load_dataset("InternScience/SGI-Reasoning", split="test")
+
+ rows: List[Dict[str, Any]] = []
+ idx = 0
+
+ for prob in hf:
+ current_row = {
+ "index": idx, #
+ "id": prob["idx"],
+ "question": prob["question"],
+ "image": [encode_image_to_base64(img) for img in prob["images"]],
+ "options": prob["options"],
+ "steps": prob["steps"],
+ "step_images": [encode_image_to_base64(img) for img in prob["step_images"]],
+ "answer": prob["answer"],
+ "image_type": prob["image_type"],
+ "discipline": prob["discipline"],
+ "direction": prob["direction"],
+ "type": prob["type"]
+ }
+ saved_paths = self.dump_images(current_row)
+ if 'image' in saved_paths:
+ current_row['image'] = saved_paths['image']
+
+ if 'step_images' in saved_paths:
+ current_row['step_images'] = saved_paths['step_images']
+ rows.append(current_row)
+ idx += 1
+
+ return pd.DataFrame(rows)
+
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+ question = (
+ "Please solve the following multiple-choice "
+ "question step-by-step. Each question is "
+ "provided with several options labeled A, B, "
+ "C, D, E, etc. Carefully analyze the question "
+ "and each option, reason step-by-step, then "
+ "select the single most correct option.\n\n"
+ "Your final output **must** include both "
+ "**the reasoning** and **the final answer**. "
+ "The final answer must meet two core "
+ "requirements:\n"
+ "1. It consists solely of the corresponding "
+ "letter of the correct option (e.g., A, B, C, "
+ "D, E, etc.);\n"
+ "2. This letter is enclosed in the \\boxed{} "
+ "format. Example: \\boxed{A}"
+ "\n\nQuestion:\n" + line['question']
+ + "\n\nOptions:\n"
+ )
+ for i, option in enumerate(line['options']):
+ option_label = chr(ord('A') + i)
+ question += f"{option_label}. {option}\n"
+
+ msgs = []
+ if isinstance(line['image'], list):
+ for p in line['image']:
+ msgs.append({'type': 'image', 'value': p})
+ elif isinstance(line['image'], str):
+ msgs.append({'type': 'image', 'value': line['image']})
+ msgs.append({'type': 'text', 'value': question})
+ return msgs
+
+ def evaluate(self, eval_file, **judge_kwargs):
+ data = load(eval_file)
+ data = pd.DataFrame(data)
+
+ data['mcc'] = 0
+ data['rv'] = 0
+
+ all_mcc, all_rv = [], []
+ if judge_kwargs.get('model') is None:
+ judge_kwargs['model'] = 'o4-mini'
+ if judge_kwargs.get('max_tokens') is None:
+ judge_kwargs['max_tokens'] = None
+ judge = build_judge(**judge_kwargs)
+
+ tups = []
+ indices = []
+ tmp_file = get_intermediate_file_path(eval_file, '_judge_tmp', 'pkl')
+ if osp.exists(tmp_file):
+ ans = load(tmp_file)
+ else:
+ ans = {}
+
+ for index, row in data.iterrows():
+ if index in ans:
+ continue
+ tups.append(dict(judge=judge, row=row))
+ indices.append(index)
+
+ if len(indices) > 0:
+ track_progress_rich(
+ judge_aux,
+ tasks=tups,
+ nproc=judge_kwargs.get('nproc', 32),
+ save=tmp_file,
+ keys=indices
+ )
+ ans = load(tmp_file)
+
+ for index, res in ans.items():
+ rv_score = res['rv_score']
+ mcc_score = res['mcc_score']
+ all_mcc.append(mcc_score)
+ data.at[index, 'mcc'] = 1 if mcc_score else 0
+ all_rv.append(rv_score)
+ data['rv'] = data['rv'].astype(float)
+ data.at[index, 'rv'] = rv_score / 10.0
+
+ score_file = get_intermediate_file_path(eval_file, '_score', 'csv')
+ result = {"MCC": sum(all_mcc) / len(all_mcc), "RV": sum(all_rv) / (10.0 * len(all_rv))}
+ result_file = get_intermediate_file_path(eval_file, '_result', 'json')
+ dump(data, score_file)
+ dump(result, result_file)
+ return result
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/idea_generation.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/idea_generation.py
new file mode 100644
index 00000000..6ea6b54a
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/idea_generation.py
@@ -0,0 +1,896 @@
+import ast
+import json
+import os.path as osp
+import re
+import time
+from datetime import datetime
+from typing import Any, Dict, List
+
+import networkx as nx
+import numpy as np
+import pandas as pd
+from datasets import load_dataset
+
+from vlmeval.smp import dump, get_intermediate_file_path, load
+from vlmeval.smp.log import get_logger
+from vlmeval.utils.mp_util import track_progress_rich
+from ..text_base import TextBaseDataset
+from ..utils.judge_util import build_judge
+from .utils import (flip_evaluation_result, format_idea_data, get_context_from_data,
+ get_evaluation_prompt_modified, parse_evaluation_result)
+
+embedding_model = None
+logger = get_logger(__name__)
+
+
+def parse_generated_idea(text: str) -> Dict[str, Any]:
+ """Parse the generated research proposal text into a structured dictionary"""
+ json_block_pattern = r"```(?:json)?\s*([\s\S]*?)```"
+ json_block_match = re.search(json_block_pattern, text)
+ if json_block_match:
+ json_str = json_block_match.group(1).strip()
+ try:
+ parsed_data = json.loads(json_str)
+ return parsed_data
+ except json.JSONDecodeError:
+ pass
+ try:
+ parsed_data = json.loads(text)
+ return parsed_data
+ except json.JSONDecodeError:
+ pass
+ result = {}
+ idea_patterns = [
+ r"[\"']?Idea[\"']?\s*:\s*[\"'](.*?)[\"']",
+ r"1\.\s*Idea[:\s-]+(.*?)(?=\n\s*(?:2\.|Implementation))",
+ ]
+ for pattern in idea_patterns:
+ match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+ if match:
+ result["Idea"] = match.group(1).strip()
+ break
+ steps_patterns = [
+ r"[\"']?ImplementationSteps[\"']?\s*:\s*\{(.*?)\}",
+ r"2\.\s*Implementation Steps[:\s-]+(.*?)(?=\n\s*(?:3\.|Implementation Order))",
+ ]
+ for pattern in steps_patterns:
+ match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+ if match:
+ steps_text = match.group(1).strip()
+ steps_dict = {}
+ step_matches = re.findall(r"[\"'](\d+)[\"']\s*:\s*[\"'](.*?)[\"']", steps_text)
+ for step_num, step_desc in step_matches:
+ steps_dict[step_num] = step_desc.strip()
+ if steps_dict:
+ result["ImplementationSteps"] = steps_dict
+ break
+ order_patterns = [
+ r"[\"']?ImplementationOrder[\"']?\s*:\s*\[(.*?)\]",
+ r"3\.\s*Implementation Order[:\s-]+(.*?)(?=\n\s*(?:4\.|Dataset))",
+ ]
+ for pattern in order_patterns:
+ match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+ if match:
+ order_text = match.group(1).strip()
+ order_list = re.findall(r'["\']([^"\']+)["\']', order_text)
+ if order_list:
+ result["ImplementationOrder"] = order_list
+ break
+ dataset_patterns = [
+ r"[\"']?Dataset[\"']?\s*:\s*[\"'](.*?)[\"'](?=\s*,\s*[\"'])",
+ r"4\.\s*Dataset[:\s-]+(.*?)(?=\n\s*(?:5\.|Evaluation))",
+ ]
+ for pattern in dataset_patterns:
+ match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+ if match:
+ result["Dataset"] = match.group(1).strip()
+ break
+ metrics_patterns = [
+ r"[\"']?EvaluationMetrics[\"']?\s*:\s*\{(.*?)\}(?=\s*,\s*[\"'])",
+ r"5\.\s*Evaluation Metrics[:\s-]+(.*?)(?=\n\s*(?:6\.|Expected))",
+ ]
+ for pattern in metrics_patterns:
+ match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+ if match:
+ metrics_text = match.group(1).strip()
+ metrics_dict = {}
+ metric_matches = re.findall(r"[\"']([^\"']+)[\"']\s*:\s*[\"'](.*?)[\"']", metrics_text)
+ for metric_name, metric_desc in metric_matches:
+ metrics_dict[metric_name.strip()] = metric_desc.strip()
+ if metrics_dict:
+ result["EvaluationMetrics"] = metrics_dict
+ break
+ outcome_patterns = [
+ r"[\"']?ExpectedOutcome[\"']?\s*:\s*[\"'](.*?)[\"']",
+ r"6\.\s*Expected Outcome[:\s-]+(.*?)$",
+ ]
+ for pattern in outcome_patterns:
+ match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+ if match:
+ result["ExpectedOutcome"] = match.group(1).strip()
+ break
+ if not result:
+ result["full_text"] = text
+ return result
+
+
+# 所有原始工具函数
+def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ a_norm = np.linalg.norm(a, axis=1, keepdims=True)
+ b_norm = np.linalg.norm(b, axis=1, keepdims=True)
+ a_norm = np.where(a_norm == 0, 1, a_norm)
+ b_norm = np.where(b_norm == 0, 1, b_norm)
+ a_normalized = a / a_norm
+ b_normalized = b / b_norm
+ return np.dot(a_normalized, b_normalized.T)
+
+
+def edge_jaccard(G1, G2):
+ edges1 = set(G1.edges())
+ edges2 = set(G2.edges())
+ if not edges1 and not edges2:
+ return 1.0
+ return len(edges1 & edges2) / len(edges1 | edges2)
+
+
+def node_text_similarity(G1, G2):
+ texts1 = [G1.nodes[n]['text'] for n in G1.nodes()]
+ texts2 = [G2.nodes[n]['text'] for n in G2.nodes()]
+ if not texts1 or not texts2:
+ logger.warning("node_text_similarity: One of the graphs has no node texts.")
+ return 0.0
+ try:
+ combined_text1 = ' '.join(texts1)
+ combined_text2 = ' '.join(texts2)
+ if len(combined_text1.strip()) < 3 or len(combined_text2.strip()) < 3:
+ logger.warning("node_text_similarity: One of the texts is too short to compare.")
+ return 0.0
+ words1 = set(combined_text1.lower().split())
+ words2 = set(combined_text2.lower().split())
+ if not words1 or not words2:
+ return 0.0
+ intersection = words1.intersection(words2)
+ union = words1.union(words2)
+ jaccard_sim = len(intersection) / len(union) if union else 0.0
+ return jaccard_sim
+ except Exception:
+ return 0.0
+
+
+def graph_similarity(dict1, dict2, alpha=0.5):
+ if not all(k in dict1 for k in ["ImplementationSteps", "ImplementationOrder"]) or \
+ not all(k in dict2 for k in ["ImplementationSteps", "ImplementationOrder"]):
+ logger.warning("graph_similarity: One of the graphs is missing necessary keys.")
+ return 0.0
+ if not dict1["ImplementationSteps"] or not dict1["ImplementationOrder"] or \
+ not dict2["ImplementationSteps"] or not dict2["ImplementationOrder"]:
+ logger.warning("graph_similarity: One of the graphs is missing necessary keys.")
+ return 0.0
+ try:
+ G1 = nx.DiGraph()
+ G2 = nx.DiGraph()
+ for k, v in dict1["ImplementationSteps"].items():
+ G1.add_node(str(k), text=v)
+ for k, v in dict2["ImplementationSteps"].items():
+ G2.add_node(str(k), text=v)
+ if len(G1.nodes()) == 0 or len(G2.nodes()) == 0:
+ return 0.0
+
+ def process_order_items(order_list, graph, step_keys):
+ edges_added = False
+ if all(o.isdigit() for o in order_list):
+ nodes = sorted([o for o in order_list if o in step_keys])
+ for i in range(len(nodes) - 1):
+ graph.add_edge(nodes[i], nodes[i + 1])
+ edges_added = True
+ else:
+ for o in order_list:
+ if "-" in o:
+ try:
+ src, dst = o.split("-")
+ if src in step_keys and dst in step_keys:
+ graph.add_edge(src, dst)
+ edges_added = True
+ except Exception:
+ pass
+ return edges_added
+
+ step_keys1 = [str(k) for k in dict1["ImplementationSteps"].keys()]
+ step_keys2 = [str(k) for k in dict2["ImplementationSteps"].keys()]
+ edges_added_G1 = process_order_items(dict1["ImplementationOrder"], G1, step_keys1)
+ edges_added_G2 = process_order_items(dict2["ImplementationOrder"], G2, step_keys2)
+ if not edges_added_G1:
+ nodes1 = sorted([n for n in G1.nodes()])
+ for i in range(len(nodes1) - 1):
+ G1.add_edge(nodes1[i], nodes1[i + 1])
+ edges_added_G1 = True
+ if not edges_added_G2:
+ nodes2 = sorted([n for n in G2.nodes()])
+ for i in range(len(nodes2) - 1):
+ G2.add_edge(nodes2[i], nodes2[i + 1])
+ edges_added_G2 = True
+ if not edges_added_G1 or not edges_added_G2:
+ logger.warning(
+ "graph_similarity: One of the graphs has no edges, only node text similarity will be computed.")
+ return node_text_similarity(G1, G2)
+ edge_sim = edge_jaccard(G1, G2)
+ text_sim = node_text_similarity(G1, G2)
+ return alpha * edge_sim + (1 - alpha) * text_sim
+ except Exception:
+ return 0.0
+
+
+def calculate_semantic_repetition(text: str) -> float:
+ sentences = [s.strip() for s in re.split(r'[.!?。!?]', text) if len(s.strip()) > 10]
+ if len(sentences) < 2:
+ return 0.0
+ try:
+ if embedding_model is None:
+ logger.warning("embedding_model is not available, cannot compute semantic repetition")
+ return 0.0
+ sentence_embeddings = embedding_model.encode(sentences)
+ similarity_matrix = cosine_similarity(sentence_embeddings, sentence_embeddings)
+ upper_triangle = []
+ for i in range(len(sentences)):
+ for j in range(i + 1, len(sentences)):
+ upper_triangle.append(similarity_matrix[i][j])
+ if not upper_triangle:
+ return 0.0
+ avg_similarity = np.mean(upper_triangle)
+ penalty = max(0, (avg_similarity - 0.2) * 10)
+ return min(penalty, 10.0)
+ except Exception as e:
+ logger.error(f"calculate_semantic_repetition error: {e}")
+ return 0.0
+
+
+def get_vote_from_model(model, original_idea_data, generated_idea_data, context=None, swap_positions=False):
+ original_idea_text = format_idea_data(original_idea_data)
+ generated_idea_text = format_idea_data(generated_idea_data)
+
+ # determine positions for evaluation
+ if swap_positions:
+ # swap positions: generated idea as A, original idea as B
+ prompt = get_evaluation_prompt_modified(generated_idea_text, original_idea_text, context)
+ positions_swapped = True
+ else:
+ # default positions: original idea as A, generated idea as B
+ prompt = get_evaluation_prompt_modified(original_idea_text, generated_idea_text, context)
+ positions_swapped = False
+
+ MAX_RETRIES = 5
+ retry_count = 0
+ while retry_count < MAX_RETRIES:
+ try:
+ response = model.generate(message=dict(type='text', value=prompt), temperature=0.1)
+ if response is None:
+ retry_count += 1
+ logger.warning(f"model {model.model} API call failed, retry {retry_count}")
+ time.sleep(1)
+ continue
+ evaluation_result = parse_evaluation_result(response)
+ if evaluation_result is None:
+ retry_count += 1
+ logger.warning(f"model {model.model} evaluation result parse error, retry {retry_count}")
+ time.sleep(1)
+ continue
+ if positions_swapped:
+ evaluation_result = flip_evaluation_result(evaluation_result)
+ return evaluation_result
+ except Exception as e:
+ retry_count += 1
+ logger.error(f"model {model.model} evaluation error: {e}, try {retry_count}")
+ time.sleep(1)
+
+ logger.warning(f"model {model.model} evaluation failed after {MAX_RETRIES} retries")
+ return None
+
+
+def compare_ideas_with_voting(original_idea_data, generated_idea_data, context=None, judge_models=None):
+ dimensions = ["effectiveness", "novelty", "detailedness", "feasibility", "overall"]
+ vote_counts = {
+ dim: {"original": 0, "generated": 0} for dim in dimensions
+ }
+ all_evaluations = []
+
+ for model in judge_models:
+ for swap in [False, True]: # each model votes twice, once with normal positions, once with swapped positions
+ evaluation = get_vote_from_model(
+ model=model,
+ original_idea_data=original_idea_data,
+ generated_idea_data=generated_idea_data,
+ context=context,
+ swap_positions=swap
+ )
+ if evaluation:
+ vote_detail = {
+ "model": model,
+ "positions_swapped": swap,
+ "results": {}
+ }
+ for dim in dimensions:
+ dim_result = evaluation.get(dim, {})
+ judgment = dim_result.get("judgment", "")
+ reason = dim_result.get("reason", "No reason provided")
+ if judgment == "win_A":
+ vote_counts[dim]["original"] += 1
+ result = "original_wins"
+ elif judgment == "win_B":
+ vote_counts[dim]["generated"] += 1
+ result = "generated_wins"
+ else:
+ logger.warning(f"error: {judgment}")
+ continue
+ vote_detail["results"][dim] = {
+ "result": result,
+ "reason": reason
+ }
+ all_evaluations.append(vote_detail)
+ else:
+ logger.error(f"model {model} evaluation failed, could not get votes")
+
+ final_results = {}
+ for dim in dimensions:
+ original_votes = vote_counts[dim]["original"]
+ generated_votes = vote_counts[dim]["generated"]
+ lose_gate = 2
+ if dim == "novelty":
+ win_gate = 4
+ else:
+ win_gate = 3
+
+ if generated_votes > win_gate:
+ result = "win"
+ reason = f"Generated idea received {generated_votes} votes, Original idea received {original_votes} votes."
+ elif generated_votes <= lose_gate:
+ result = "lose"
+ reason = f"Original idea received {original_votes} votes, Generated idea received {generated_votes} votes."
+ else:
+ result = "tie"
+ reason = f"Generated idea received {generated_votes} votes, Original idea received {original_votes} votes."
+
+ final_results[dim] = {
+ "res": result,
+ "reason": reason,
+ "vote_detail": {
+ "original_votes": original_votes,
+ "generated_votes": generated_votes
+ }
+ }
+
+ return {
+ "final_results": final_results,
+ "all_evaluations": all_evaluations
+ }
+
+
+# ImprovedIdeaEvaluator class
+class ImprovedIdeaEvaluator:
+ def __init__(self, idea_dict: dict):
+ self.idea_dict = idea_dict
+ self.original_data = {k: v for k, v in idea_dict.items() if k not in ["generated_idea_text", "generated_data"]}
+ self.original_data["Idea"] = self.original_data.get("core_idea", "")
+ self.original_data["RelatedWork"] = ast.literal_eval(self.original_data.get("related_work", "{}"))
+ self.original_data["ExistingSolutions"] = ast.literal_eval(self.original_data.get("existing_solutions", "{}"))
+ self.original_data["ImplementationSteps"] = ast.literal_eval(
+ self.original_data.get("implementation_steps", "{}"))
+ self.original_data["ImplementationOrder"] = ast.literal_eval(
+ self.original_data.get("implementation_order", "[]"))
+ self.original_data["EvaluationMetrics"] = ast.literal_eval(self.original_data.get("evaluation_metrics", "{}"))
+ self.original_data["Dataset"] = self.original_data.get("data", "")
+ self.original_data["ExpectedOutcome"] = self.original_data.get("expected_outcome", "")
+ self.generated_data = idea_dict["generated_data"]
+ self.idea = self.generated_data.get("Idea", "")
+ self.generated_data["Idea"] = self.idea
+ self.implementation_steps = self.generated_data.get("ImplementationSteps", {})
+ self.implementation_order = self.generated_data.get("ImplementationOrder", {})
+ self.dataset = self.generated_data.get("Dataset", "")
+ self.generated_data["Dataset"] = self.dataset
+ self.evaluation_metrics = self.generated_data.get("EvaluationMetrics", "")
+ self.expected_outcome = self.generated_data.get("ExpectedOutcome", "")
+ self.raw_scores = {
+ "novelty_similarity": 0.0,
+ "cutting_edge": 0.0,
+ "effectiveness_objective": 0.0,
+ "feasibility_objective": 0.0,
+ "completeness": 0.0,
+ "length_penalty": 0.0,
+ "repetition_penalty": 0.0
+ }
+ self.scores = {
+ "novelty_objective": 0.0,
+ "feasibility_objective": 0.0,
+ "detailedness_objective": 0.0,
+ "effectiveness_objective": 0.0,
+ "novelty": "",
+ "effectiveness": "",
+ "detailedness": "",
+ "feasibility": "",
+ }
+ self.details = {}
+
+ def evaluate_novelty_objective(self) -> None:
+ try:
+ text_to_compare = self.idea
+ related_work = self.original_data.get("RelatedWork", {})
+ existing_methods = self.original_data.get("ExistingSolutions", {})
+ all_existing_text = []
+ all_existing_text.extend(related_work.values())
+ all_existing_text.extend(existing_methods.values())
+ if all_existing_text and embedding_model is not None:
+ idea_embedding = embedding_model.encode([text_to_compare])
+ similarities = []
+ for existing_text in all_existing_text:
+ existing_embedding = embedding_model.encode([existing_text])
+ similarity = cosine_similarity(
+ idea_embedding.reshape(1, -1),
+ existing_embedding.reshape(1, -1)
+ )[0][0]
+ similarities.append(similarity)
+ avg_similarity = np.mean(similarities)
+ novelty_similarity_score = (1 - avg_similarity) * 10
+ novelty_similarity_score = max(0, min(10, novelty_similarity_score))
+ else:
+ novelty_similarity_score = 0.0
+ self.raw_scores["novelty_similarity"] = novelty_similarity_score
+ ref_related_work = self.original_data.get("related_work_test", "")
+ idea_embedding = embedding_model.encode([self.idea])
+ similarities = []
+ ref_related_work = ast.literal_eval(ref_related_work)
+ for key, value in ref_related_work.items():
+ snippet_data = f"{key}: {value}"
+ snippet_embedding = embedding_model.encode([snippet_data])
+ similarity = cosine_similarity(
+ idea_embedding.reshape(1, -1),
+ snippet_embedding.reshape(1, -1)
+ )[0][0]
+ similarities.append(similarity)
+ avg_similarity = np.mean(similarities)
+ cutting_edge_score = (1 - avg_similarity) * 10
+ cutting_edge_score = max(0, min(10, cutting_edge_score))
+ self.raw_scores["cutting_edge"] = cutting_edge_score
+ except Exception as e:
+ logger.error(f"Error in novelty evaluation: {e}")
+ self.raw_scores["novelty_similarity"] = 0.0
+ self.raw_scores["cutting_edge"] = 0.0
+ self.details["novelty_similarity"] = f"error: {str(e)}"
+ self.details["cutting_edge"] = f"error: {str(e)}"
+
+ def evaluate_effectiveness_objective(self) -> None:
+ try:
+ original_terms = self.original_data.get("keywords", [])
+ if embedding_model is None:
+ self.scores["effectiveness_objective"] = 0.0
+ self.details["effectiveness_objective"] = "embedding_model is not available"
+ return
+ terms_text = ", ".join([str(term) for term in original_terms])
+ idea_text = self.idea
+ try:
+ embeddings = embedding_model.encode([terms_text, idea_text], normalize_embeddings=True)
+ similarity = np.dot(embeddings[0], embeddings[1])
+ prof_score = similarity * 10
+ self.scores["effectiveness_objective"] = max(0, min(10, prof_score))
+ except Exception as e:
+ logger.error(f"Error computing embedding similarity: {e}")
+ matched_terms = []
+ generated_text_lower = idea_text.lower() if isinstance(idea_text, str) else ""
+ for term in original_terms:
+ term_str = str(term).lower()
+ if term_str in generated_text_lower:
+ matched_terms.append(term)
+ hit_rate = len(matched_terms) / len(original_terms) if original_terms else 0
+ self.scores["effectiveness_objective"] = hit_rate * 10
+ similarity = hit_rate
+ except Exception as e:
+ logger.error(f"Error in effectiveness_objective evaluation: {e}")
+ self.scores["effectiveness_objective"] = 0.0
+
+ def evaluate_completeness(self) -> None:
+ required_sections = [
+ "Idea",
+ "ImplementationSteps",
+ "ImplementationOrder",
+ "EvaluationMetrics",
+ "Dataset",
+ "ExpectedOutcome"
+ ]
+ section_found = {
+ "Idea": self.idea is not None,
+ "ImplementationSteps": self.implementation_steps is not None,
+ "ImplementationOrder": self.implementation_order is not None,
+ "EvaluationMetrics": self.evaluation_metrics is not None,
+ "Data": self.dataset is not None,
+ "ExpectedOutcome": self.expected_outcome is not None
+ }
+ total_sections = len(required_sections)
+ completed_sections = sum(section_found.values())
+ self.raw_scores["completeness"] = (completed_sections / total_sections) * 10
+ self.details["completeness"] = {
+ "total_sections": total_sections,
+ "completed_sections": completed_sections,
+ "completion_rate": completed_sections / total_sections,
+ }
+ missing_sections = [section for section, found in section_found.items() if not found]
+ if missing_sections:
+ logger.warning(f"Missing required sections: {', '.join(missing_sections)}")
+
+ def evaluate_feasibility_objective(self) -> None:
+ try:
+ generated_implementation = {
+ "ImplementationSteps": self.implementation_steps,
+ "ImplementationOrder": self.implementation_order
+ }
+ original_implementation = {
+ "ImplementationSteps": self.original_data["ImplementationSteps"],
+ "ImplementationOrder": self.original_data["ImplementationOrder"]
+ }
+ similarity = graph_similarity(
+ generated_implementation,
+ original_implementation,
+ alpha=0.6
+ )
+ self.scores["feasibility_objective"] = similarity * 10
+ self.details["feasibility_objective"] = {
+ "score": similarity,
+ }
+ except Exception as e:
+ logger.error(f"Error evaluating feasibility objective: {e}")
+ self.scores["feasibility_objective"] = 0.0
+ self.details["feasibility_objective"] = {"error": str(e)}
+
+ def evaluate_penalties(self) -> None:
+ if self.idea:
+ char_count = len(self.idea)
+ penalty = 0.0
+ if char_count > 700:
+ excess_chars = char_count - 700
+ penalty += excess_chars / 100.0
+ elif char_count < 300:
+ deficit_chars = 300 - char_count
+ penalty += deficit_chars / 100.0
+ self.raw_scores["length_penalty"] = min(penalty, 10.0)
+ else:
+ self.raw_scores["length_penalty"] = 0.0
+ if isinstance(self.idea, str):
+ self.raw_scores["repetition_penalty"] = calculate_semantic_repetition(self.idea)
+ else:
+ self.raw_scores["repetition_penalty"] = 0.0
+ self.details["penalties"] = {
+ "text_length": len(self.idea),
+ "length_penalty": self.raw_scores["length_penalty"],
+ "repetition_penalty": self.raw_scores["repetition_penalty"]
+ }
+
+ def LLM_multi_rounds(self, llm_judges):
+ try:
+ context = get_context_from_data(self.original_data)
+ evaluation_results = compare_ideas_with_voting(
+ original_idea_data=self.original_data,
+ generated_idea_data=self.generated_data,
+ context=context,
+ judge_models=llm_judges
+ )
+ summary = {
+ "evaluation_details": evaluation_results,
+ "timestamp": datetime.now().isoformat()
+ }
+ self.scores["novelty_subjective"] = evaluation_results["final_results"]["novelty"]["res"]
+ self.scores["effectiveness_subjective"] = evaluation_results["final_results"]["effectiveness"]["res"]
+ self.scores["detailedness_subjective"] = evaluation_results["final_results"]["detailedness"]["res"]
+ self.scores["feasibility_subjective"] = evaluation_results["final_results"]["feasibility"]["res"]
+ return {
+ "success": True,
+ "result": summary
+ }
+ except Exception as e:
+ logger.error(f"Error in LLM_multi_rounds: {e}")
+ return {
+ "success": False,
+ "error": str(e)
+ }
+
+ def merge_scores(self) -> None:
+ self.scores["novelty_objective"] = (
+ 0.5 * self.raw_scores["novelty_similarity"]
+ + 0.5 * self.raw_scores["cutting_edge"]
+ )
+ self.scores["detailedness_objective"] = (
+ 0.2 * self.raw_scores["completeness"]
+ + 0.4 * (10 - self.raw_scores["repetition_penalty"])
+ + 0.4 * (10 - self.raw_scores["length_penalty"])
+ )
+
+ def calculate_final_score(self, llm_judges) -> Dict[str, Any]:
+ self.LLM_multi_rounds(llm_judges)
+ self.evaluate_novelty_objective()
+ self.evaluate_effectiveness_objective()
+ self.evaluate_completeness()
+ self.evaluate_feasibility_objective()
+ self.evaluate_penalties()
+ self.merge_scores()
+
+ self.idea_dict.update({
+ "effectiveness_objective": float(self.scores["effectiveness_objective"]) * 10,
+ "novelty_objective": float(self.scores["novelty_objective"]) * 10,
+ "detailedness_objective": float(self.scores["detailedness_objective"]) * 10,
+ "feasibility_objective": float(self.scores["feasibility_objective"]) * 10,
+ "effectiveness_subjective": 100 if self.scores["effectiveness_subjective"] == 'win' else 0,
+ "novelty_subjective": 100 if self.scores["novelty_subjective"] == 'win' else 0,
+ "detailedness_subjective": 100 if self.scores["detailedness_subjective"] == 'win' else 0,
+ "feasibility_subjective": 100 if self.scores["feasibility_subjective"] == 'win' else 0,
+ })
+
+ self.idea_dict.update({
+ "effectiveness": (self.idea_dict["effectiveness_objective"] + self.idea_dict[
+ "effectiveness_subjective"]) / 2,
+ "novelty": (self.idea_dict["novelty_objective"] + self.idea_dict["novelty_subjective"]) / 2,
+ "detailedness": (self.idea_dict["detailedness_objective"] + self.idea_dict["detailedness_subjective"]) / 2,
+ "feasibility": (self.idea_dict["feasibility_objective"] + self.idea_dict["feasibility_subjective"]) / 2,
+ })
+
+ self.idea_dict["final_score"] = (
+ self.idea_dict["effectiveness"]
+ + self.idea_dict["novelty"]
+ + self.idea_dict["detailedness"]
+ + self.idea_dict["feasibility"]
+ ) / 4
+
+ return self.idea_dict
+
+
+def evaluate_single_idea(ques_dict, llm_judges):
+ try:
+ evaluator = ImprovedIdeaEvaluator(ques_dict)
+ evaluation_result = evaluator.calculate_final_score(llm_judges=llm_judges)
+ output = evaluation_result
+ return output
+ except Exception as e:
+ logger.error(f"evaluation error: {e}")
+ output = {
+ "error": str(e),
+ "final_score": 0.0
+ }
+ return output
+
+
+class SGI_Bench_Idea_Generation(TextBaseDataset):
+ TYPE = 'QA'
+ example = {
+ "Idea": (
+ "We propose an adaptive optimization framework based on a dynamic feature interaction "
+ "network. This framework captures feature correlations through a hierarchical attention "
+ "mechanism and combines it with a data distribution-aware dynamic weight adjustment "
+ "strategy to improve the model's adaptability to heterogeneous data while ensuring "
+ "computational efficiency."
+ ),
+ "ImplementationSteps": {
+ "1": (
+ "Data preprocessing: missing value filling, outlier handling, feature "
+ "normalization and type conversion, and building a basic feature set"
+ ),
+ "2": (
+ "Feature engineering: generating statistically derived features, time series "
+ "features, and cross-features, and building a feature candidate pool"
+ ),
+ "3": (
+ "Model architecture design: building a basic network module, integrating a "
+ "hierarchical attention mechanism with a dynamic interaction layer"
+ ),
+ "4": (
+ "Dynamic weight mechanism implementation: designing a data distribution-aware "
+ "weight adjustment function and embedding it into the network's intermediate layers"
+ ),
+ "5": (
+ "Model training and tuning: adopting a phased training strategy, using grid "
+ "search and early stopping to optimize hyperparameters"
+ ),
+ "6": (
+ "Performance Verification: Conduct comparative experiments on multiple datasets "
+ "to analyze model performance differences in different scenarios."
+ )
+ },
+ "ImplementationOrder": ["1-2", "2-3", "3-4", "4-5", "1-5", "5-6"],
+ "Dataset": (
+ "Contains three types of public datasets and one actual business data: "
+ "1) Public structured dataset (approximately 500,000 samples, 30+ features); "
+ "2) Text-numeric mixed dataset (approximately 200,000 samples, including text "
+ "embedding features); 3) Time series sparse dataset (approximately 100,000 samples, "
+ "spanning 1 year); 4) Real transaction data from an e-commerce platform "
+ "(approximately 1 million samples, including user behavior and product attribute "
+ "features)"
+ ),
+ "EvaluationMetrics": {
+ "Prediction Accuracy": (
+ "AUC and F1-score are used for classification tasks; MAE and RMSE are used for "
+ "regression tasks to evaluate the basic predictive ability of the model."
+ ),
+ "Robustness": (
+ "Performance decay rate is calculated through data perturbation testing (adding "
+ "noise and simulating feature loss) to measure model stability."
+ ),
+ "Efficiency": (
+ "Record model training time, inference latency, and memory usage to evaluate "
+ "computing resource consumption."
+ ),
+ "Interpretability": (
+ "Use SHAP values and feature importance ranking to quantify the feature "
+ "contribution to model decisions."
+ ),
+ "Generalization": (
+ "Performance retention across datasets to evaluate the model's adaptability "
+ "to unseen data."
+ )
+ },
+ "ExpectedOutcome": (
+ "The proposed framework outperforms existing mainstream methods in comprehensive "
+ "performance (accuracy, robustness, and efficiency) across multiple datasets, "
+ "particularly in scenarios with uneven data distribution and cross-scenario migration. "
+ "It also enhances model interpretability through a dynamic feature interaction "
+ "mechanism, providing effective support for practical business decision-making."
+ )
+ }
+
+ @classmethod
+ def supported_datasets(cls):
+ return ["SGI-IdeaGeneration"]
+
+ def load_data(self, dataset):
+ hf = load_dataset("InternScience/SGI-IdeaGeneration", split="test")
+ rows: List[Dict[str, Any]] = []
+ idx = 0
+ for prob in hf:
+ rows.append({
+ "index": idx,
+ "id": prob.get("idx", idx),
+ "question": prob["question"],
+ "discipline": prob["discipline"],
+ "core_idea": prob["core_idea"],
+ "related_work": prob["related_work"],
+ "related_work_test": prob.get("related_work_test", "{}"),
+ "existing_solutions": prob["existing_solutions"],
+ "implementation_steps": prob["implementation_steps"],
+ "implementation_order": prob["implementation_order"],
+ "data": prob["data"],
+ "evaluation_metrics": prob["evaluation_metrics"],
+ "expected_outcome": prob["expected_outcome"],
+ "keywords": prob.get("keywords", [])
+ })
+ idx += 1
+ return pd.DataFrame(rows)
+
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+
+ prompt = line['question'] + f"""\n\n### Example:
+```json
+{json.dumps(self.example, indent=4)}
+```"""
+ msgs = [{'type': 'text', 'value': prompt}]
+ return msgs
+
+ def evaluate(self, eval_file, **judge_kwargs):
+ data = load(eval_file)
+ data = pd.DataFrame(data)
+ global embedding_model
+ # 尝试加载嵌入模型进行评估
+ if embedding_model is None:
+ try:
+ from sentence_transformers import SentenceTransformer
+ logger.info("Loading SentenceTransformer embedding model...")
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+ logger.info("SentenceTransformer embedding model loaded successfully.")
+ except Exception as e:
+ logger.error(f"Failed to load SentenceTransformer model: {e}")
+ embedding_model = None
+
+ data['generated_data'] = None
+ data['generated_data'] = data['generated_data'].astype(object)
+
+ data['generated_idea_text'] = None
+ data['generated_idea_text'] = data['generated_idea_text'].astype(object)
+
+ # 处理每个生成的想法
+ # Default judge models
+ JUDGE_MODELS = ["gpt-5.1", "gemini-3-pro-preview", "claude-sonnet-4-5-20250929"]
+ # JUDGE_MODELS = ["gpt-5.1-2025-11-13", "gemini-3-pro-preview", "claude-sonnet-4-5-20250929"]
+ llm_judges = [build_judge(**{**judge_kwargs, 'model': i}) for i in JUDGE_MODELS]
+ tups = []
+ indices = []
+ for idx, row in data.iterrows():
+ prediction = row['prediction']
+
+ # 解析生成的想法
+ if isinstance(prediction, str):
+ parsed_data = parse_generated_idea(prediction)
+ data.at[idx, 'generated_data'] = parsed_data
+ data.at[idx, 'generated_idea_text'] = prediction
+
+ # 构建评估所需的原始数据字典
+ ques_dict = row.to_dict()
+ ques_dict['generated_data'] = parsed_data
+ ques_dict['generated_idea_text'] = prediction
+
+ # 评估单个想法
+ tups.append((ques_dict, llm_judges))
+ indices.append(idx)
+
+ tmp_file = get_intermediate_file_path(eval_file, '_judge_tmp', 'pkl')
+ ans = {}
+ if osp.exists(tmp_file):
+ ans = load(tmp_file)
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
+ indices = [i for i in indices if i not in ans]
+
+ if len(indices):
+ track_progress_rich(
+ func=evaluate_single_idea,
+ tasks=tups,
+ nproc=judge_kwargs.get('nproc', 32),
+ save=tmp_file,
+ keys=indices,
+ )
+ ans = load(tmp_file)
+
+ for idx, evaluation_result in ans.items():
+ # 将评估结果添加到数据中
+ for key, value in evaluation_result.items():
+ if key not in ['generated_data', 'generated_idea_text']:
+ data.loc[idx, key] = value
+
+ # 计算平均分数
+ successful_evaluations = data[~data['final_score'].isna()]
+ if len(successful_evaluations) > 0:
+ avg_effectiveness_objective = successful_evaluations['effectiveness_objective'].mean()
+ avg_novelty_objective = successful_evaluations['novelty_objective'].mean()
+ avg_detailedness_objective = successful_evaluations['detailedness_objective'].mean()
+ avg_feasibility_objective = successful_evaluations['feasibility_objective'].mean()
+
+ avg_effectiveness_subjective = successful_evaluations['effectiveness_subjective'].mean()
+ avg_novelty_subjective = successful_evaluations['novelty_subjective'].mean()
+ avg_detailedness_subjective = successful_evaluations['detailedness_subjective'].mean()
+ avg_feasibility_subjective = successful_evaluations['feasibility_subjective'].mean()
+
+ effectiveness_score = successful_evaluations['effectiveness'].mean()
+ novelty_score = successful_evaluations['novelty'].mean()
+ detailedness_score = successful_evaluations['detailedness'].mean()
+ feasibility_score = successful_evaluations['feasibility'].mean()
+
+ avg_final_score = successful_evaluations['final_score'].mean()
+
+ result = {
+ "final_score": float(avg_final_score),
+ "effectiveness": float(effectiveness_score),
+ "novelty": float(novelty_score),
+ "detailedness": float(detailedness_score),
+ "feasibility": float(feasibility_score),
+ "details": {
+ "effectiveness_objective": float(avg_effectiveness_objective),
+ "effectiveness_subjective": float(avg_effectiveness_subjective),
+ "novelty_objective": float(avg_novelty_objective),
+ "novelty_subjective": float(avg_novelty_subjective),
+ "detailedness_objective": float(avg_detailedness_objective),
+ "detailedness_subjective": float(avg_detailedness_subjective),
+ "feasibility_objective": float(avg_feasibility_objective),
+ "feasibility_subjective": float(avg_feasibility_subjective),
+ "successful_evaluations": len(successful_evaluations),
+ "total_evaluations": len(data)
+ }
+ }
+ else:
+ result = {
+ "final_score": 0.0,
+ "effectiveness": 0.0,
+ "novelty": 0.0,
+ "detailedness": 0.0,
+ "feasibility": 0.0,
+ "error": "No successful evaluations"
+ }
+
+ # 保存结果
+ score_file = get_intermediate_file_path(eval_file, '_score', 'csv')
+ result_file = get_intermediate_file_path(eval_file, '_result', 'json')
+ dump(data, score_file)
+ dump(result, result_file)
+
+ return result
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/readme.md b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/readme.md
new file mode 100644
index 00000000..a81f1dd7
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/readme.md
@@ -0,0 +1,12 @@
+## 整体说明
+SGI-Bench-1.0 包含5个子数据集(deep research , dry experiment , wet experiment , experimental reasoning, idea generation)
+
+## 注意事项
+1. dry experiment , experimental reasoning和deep research以及idea generation使用了模型进行评估,需要设置`OPENAI_API_KEY`,以及`OPENAI_API_BASE`环境变量
+2. dry experiment评测过程中需要下载文件,默认路径是`./outputs`,可以通过`--judge-args`命令行参数传入`work_dir`参数进行控制。 评测之前还需要运行以下命令
+3. idea generation 的评测需要额外安装`sentence_transformers`包
+```
+conda create -n dryexp python=3.10.18
+conda activate dryexp
+pip install -r vlmeval/dataset/SGI_Bench_1_0/dry_experiment_requirements.txt
+```
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/utils.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/utils.py
new file mode 100644
index 00000000..307ca3a8
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/utils.py
@@ -0,0 +1,232 @@
+import re
+
+# ############################ Idea Generation ##############################
+
+
+def format_idea_data(idea_data):
+ fields = [
+ "Idea",
+ "ImplementationSteps",
+ "ImplementationOrder",
+ "Dataset",
+ "EvaluationMetrics",
+ "ExpectedOutcome"
+ ]
+
+ formatted_text = ""
+ for field in fields:
+ if field in idea_data and idea_data[field]:
+ formatted_text += f"{field}: {idea_data[field]}\n\n"
+
+ return formatted_text.strip()
+
+
+def get_context_from_data(data):
+ context_fields = [
+ "related_work",
+ "challenge",
+ "limitation",
+ "motivation",
+ "task_objective",
+ "existing_solutions"
+ ]
+ context = ""
+ for field in context_fields:
+ if field in data and data[field]:
+ context += f"{field}: {data[field]}\n\n"
+
+ return context.strip()
+
+
+def flip_evaluation_result(result):
+ flipped = {}
+ mapping = {
+ "win_A": "win_B",
+ "win_B": "win_A"
+ }
+
+ for key, value in result.items():
+ if isinstance(value, dict) and "judgment" in value:
+ flipped[key] = {
+ "judgment": mapping.get(value["judgment"], value["judgment"]),
+ "reason": value.get("reason", "")
+ }
+ else:
+ flipped[key] = mapping.get(value, value)
+
+ return flipped
+
+
+def get_evaluation_prompt_modified(hypothesis_A, hypothesis_B, context=None):
+ context_text = f"Context:\n{context}\n\n" if context else ""
+
+ prompt = f"""
+You are assisting researchers tasked with comparing TWO research hypotheses (Hypothesis A and Hypothesis B).
+Your job is to evaluate both hypotheses across five separate dimensions defined below, and to choose a winner
+(either Hypothesis A or Hypothesis B) for each dimension. Ties are NOT allowed — you MUST pick one winner per
+dimension. Base your judgments on scientific principles and the provided context only.
+
+##Background context:
+{context_text}
+
+##Hypothesis A:
+{hypothesis_A}
+
+##Hypothesis B:
+{hypothesis_B}
+
+##Definition of each dimension:
+###1) Effectiveness
+Which hypothesis is more likely to produce a successful experimental or empirical outcome in service of the stated
+research objective? Evaluate the likelihood that, if implemented using standard practices in the relevant discipline,
+the hypothesis will achieve the intended measurable result. Focus on mechanistic plausibility, causal logic, and
+whether the hypothesis addresses the core problem directly.
+
+###2)Novelty
+Novelty: Which hypothesis presents more innovative or original approaches? Compare the similarity between the idea
+and the related work and existing solutions in the background to assess its novelty. A lower similarity to the core
+idea indicates greater novelty.
+
+###3) Detailedness (Level of Specification)
+Which hypothesis provides clearer, more actionable, and more complete specification of mechanisms, assumptions,
+experimental steps, required variables, and dependencies? Detailedness rewards clarity that would enable a competent
+researcher to design an experiment or implementation with minimal ambiguity.
+
+###4) Feasibility
+Which hypothesis presents a more realistic and implementable solution given current technological constraints?
+
+###5) Overall
+Considering the overall aspects together but emphasizing conceptual coherence and scientific grounding, which
+hypothesis is superior overall? This is a synthesis judgment: prefer the hypothesis that is logically consistent,
+grounded in accepted principles, avoids critical unstated assumptions or contradictions, and is most defensible as
+a scientific proposition.
+
+Unified constraints:
+- Use only the provided context and widely accepted scientific principles in the relevant discipline. Do NOT invent
+facts external to the context unless they are broadly standard domain knowledge.
+- When a dimension explicitly says to ignore other factors (e.g., Novelty should ignore feasibility), strictly follow
+that guidance for that dimension. When evaluating a certain dimension, it should focus on this dimension itself and
+ignore the influence of other dimensions.
+- Be concise but specific: for each dimension provide a short judgment line (exact format below) plus 1–3 sentences
+of succinct reasoning grounded in the definitions above.
+- Format must match exactly (case-insensitive for "Win A/Win B") and include a reason after "because".
+
+
+##Output format (MUST FOLLOW EXACTLY)
+
+Format your response exactly as follows:
+Effectiveness: [Win A/Win B] because ...
+Novelty: [Win A/Win B] because ...
+Detailedness: [Win A/Win B] because ...
+Feasibility: [Win A/Win B] because ...
+Overall: [Win A/Win B] because ...
+"""
+ return prompt
+
+
+def parse_evaluation_result(result):
+ dimensions = ["effectiveness", "novelty", "detailedness", "feasibility", "overall"]
+ parsed_results = {}
+ all_valid = True
+
+ for dim in dimensions:
+ judgment = extract_win_lose(result, dim.capitalize())
+ reason = extract_reason(result, dim.capitalize())
+
+ if judgment is None:
+ all_valid = False
+ break
+
+ parsed_results[dim] = {
+ "judgment": judgment,
+ "reason": reason
+ }
+
+ if not all_valid:
+ return None
+
+ return parsed_results
+
+
+def extract_win_lose(result_text, dimension):
+ pattern = rf"{dimension}\s*:\s*\[\s*(Win\s*A|Win\s*B)\s*\]"
+ match = re.search(pattern, result_text, re.IGNORECASE)
+ if match:
+ judgment = match.group(1).strip().upper()
+ if "WIN A" in judgment:
+ return "win_A"
+ else:
+ return "win_B"
+
+ backup_pattern = rf"{dimension}\s*:\s*(Win\s*A|Win\s*B)\s+"
+ match = re.search(backup_pattern, result_text, re.IGNORECASE)
+ if match:
+ judgment = match.group(1).strip().upper()
+ if "WIN A" in judgment:
+ return "win_A"
+ else:
+ return "win_B"
+
+ line_pattern = rf"{dimension}[^\n]*?(Win\s*A|Win\s*B)"
+ match = re.search(line_pattern, result_text, re.IGNORECASE)
+ if match:
+ judgment = match.group(1).strip().upper()
+ if "WIN A" in judgment:
+ return "win_A"
+ else:
+ return "win_B"
+
+ return None
+
+
+def extract_reason(result_text, dimension):
+ pattern = rf"{dimension}\s*:\s*\[[^\]]+\]\s*because\s*(.*?)(?=\n\w|$)"
+ match = re.search(pattern, result_text, re.IGNORECASE | re.DOTALL)
+ if match:
+ reason = match.group(1).strip()
+ return reason
+
+ backup_pattern = rf"{dimension}\s*:[^\n]*?(because|due to|as|since)([^\n]+)"
+ match = re.search(backup_pattern, result_text, re.IGNORECASE)
+ if match:
+ reason = match.group(2).strip()
+ return reason
+
+ fallback_pattern = rf"{dimension}\s*:[^\n]*(.*?)(?=\n\w+:|$)"
+ match = re.search(fallback_pattern, result_text, re.IGNORECASE | re.DOTALL)
+ if match:
+ text = match.group(1).strip()
+ reason = re.sub(r"\[(Win\s*A|Win\s*B)\]", "", text).strip()
+ return reason
+
+ return "No specific reason provided"
+
+
+# ############################ Idea Generation ##############################
+
+
+def mean(lst: list):
+ assert len(lst) > 0, "list length must > 0"
+ return sum(lst) / len(lst)
+
+
+def show_results(results: list[dict], metric_name: str, category_name: str = None, precision: int = 2, scale=1):
+ category_dict = {}
+ for item in results:
+ if category_name is None:
+ item_category = 'default'
+ else:
+ item_category = item[category_name]
+ if item_category not in category_dict:
+ category_dict[item_category] = []
+
+ item_metric = float(item[metric_name])
+ category_dict[item_category].append(item_metric)
+
+ for k, v in category_dict.items():
+ category_dict[k] = round(mean(v) * scale, precision)
+
+ if category_name is None:
+ return category_dict['default']
+ else:
+ return category_dict
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/wet_experiment.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/wet_experiment.py
new file mode 100644
index 00000000..0c6cffed
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/SGI_Bench_1_0/wet_experiment.py
@@ -0,0 +1,360 @@
+import re
+from itertools import combinations
+from typing import Any, Dict, List
+
+import pandas as pd
+from datasets import load_dataset
+
+from vlmeval.smp import dump, get_intermediate_file_path, load
+from ..text_base import TextBaseDataset
+
+
+def parse_experiment_steps(text):
+ # Regular expression to match experiment steps until
+ # encountering a single line containing only a right
+ # parenthesis ")"
+ # Match format: variable_name = (parameter_list)
+ # Capture groups:
+ # 1: output variable name (e.g., "multimer_cells")
+ # 2: action name (e.g., "Incubate cells with MHC
+ # multimers")
+ # 3: parameter list (e.g.,
+ # "cells=washed_cells,\nmultimer_pool=...")
+ # Condition: parameter list continues until a single
+ # line of ")" (whitespace allowed around it)
+ step_pattern = (
+ r'(\w+)\s*=\s*<([^>]+)>\(\s*([\s\S]*?)'
+ r'(?=\n\s*\)\s*$)'
+ )
+ # Regular expression to match each parameter line
+ # Match format: key=value or key=value,
+ # Capture groups:
+ # 1: parameter key (e.g., "cells")
+ # 2: parameter value (e.g., "washed_cells" or
+ # "\"tetramer pool (23 nM each)\"")
+ # (?:,)? : optionally match a trailing comma at the
+ # end of the line, ignore the comma
+ param_pattern = (
+ r'^\s*(\w+)\s*=\s*(.*?)\s*(?:,)?\s*$'
+ )
+ steps = []
+
+ for match in re.finditer(step_pattern, text, re.MULTILINE):
+ output_var = match.group(1).strip() # Extract output variable name
+ action_name = match.group(2).strip() # Extract action name
+ params = match.group(3).strip() # Extract parameter list
+
+ param_dict = {}
+ # Split the parameter list by lines,
+ # ignoring empty lines and single-line ")"
+ param_lines = [
+ line.strip()
+ for line in params.split('\n')
+ if line.strip() and line.strip() != ')'
+ ]
+ for line in param_lines:
+ param_match = re.match(param_pattern, line)
+ if param_match:
+ key = param_match.group(1) # Extract parameter key
+ value = param_match.group(2).strip()
+ # If the value starts and ends with
+ # double quotes, remove the quotes
+ if value.startswith('"') and value.endswith('"'):
+ value = value[1:-1]
+ param_dict[key] = value
+
+ # Build the step dictionary
+ steps.append({
+ "action": action_name,
+ "input": param_dict,
+ "output": output_var
+ })
+
+ return steps
+
+
+def identify_variable_types(steps):
+ """
+ Identify raw variables and generated variables in the experimental steps.
+ Raw variables: variables that never appear as outputs in any step.
+ Generated variables: variables that appear as outputs of some step.
+
+ Returns:
+ original_vars (set): set of raw variables
+ generated_vars (set): set of generated variables (function outputs)
+ output_to_step_map (dict): mapping from output variable name to the index of
+ its generating step (for reverse lookup)
+ """
+ generated_vars = set()
+ all_input_vars = set()
+ output_to_step_map = {}
+
+ for idx, step in enumerate(steps):
+ output_var = step["output"]
+ generated_vars.add(output_var)
+ output_to_step_map[output_var] = idx # Store the step index
+
+ for input_val in step["input"].values():
+ # Simple check whether it is a variable (non-string literal, non-numeric)
+ # If input_val is a string and does not start and end with quotes,
+ # and is not purely numeric, consider it a variable
+ if (
+ isinstance(input_val, str)
+ and not (input_val.startswith('"') and input_val.endswith('"'))
+ and not (
+ input_val.replace('.', '', 1).isdigit()
+ or (
+ input_val.startswith('-')
+ and input_val[1:].replace('.', '', 1).isdigit()
+ )
+ )
+ ):
+ all_input_vars.add(input_val)
+
+ # Raw variables are those input variables that are not in the set of output variables of any step
+ original_vars = all_input_vars - generated_vars
+
+ return original_vars, generated_vars, output_to_step_map
+
+
+def compare_exp_steps(gt_steps, pred_steps):
+ def kendall_tau_distance(seq1, seq2):
+ if len(seq1) != len(seq2):
+ return 0.0
+ n = len(seq1)
+ if n <= 1:
+ return 1.0
+ inversions = 0
+ for i, j in combinations(range(n), 2):
+ if (seq1[i] < seq1[j] and seq2[i] > seq2[j]) or (seq1[i] > seq1[j] and seq2[i] < seq2[j]):
+ inversions += 1
+ max_inversions = n * (n - 1) / 2
+ return 1.0 - (inversions / max_inversions if max_inversions > 0 else 0.0)
+
+ results = {
+ "order_similarity": 0.0,
+ "error_rate": 0.0,
+ "details": []
+ }
+
+ actions_gt = [step["action"] for step in gt_steps]
+ actions_pred = [step["action"] for step in pred_steps]
+
+ results["order_similarity"] = kendall_tau_distance(actions_gt, actions_pred)
+
+ # Identify variable types and build output mappings
+ original_vars_gt, generated_vars_gt, output_to_step_map_gt = identify_variable_types(gt_steps)
+ original_vars_pred, generated_vars_pred, output_to_step_map_pred = identify_variable_types(pred_steps)
+ # output_to_step_map_pred is only used to judge whether an input is a generated variable
+
+ # Dictionary mapping variable names in pred_steps to corresponding variables in gt_steps
+ var_map_pred2gt = {}
+
+ error_count = 0
+ min_len = min(len(gt_steps), len(pred_steps))
+
+ for i in range(min_len):
+ step_gt = gt_steps[i]
+ step_pred = pred_steps[i]
+ detail = {
+ "step": i + 1,
+ "action_gt": step_gt["action"],
+ "action_pred": step_pred["action"],
+ "status": "✅ success",
+ "message": ""
+ }
+
+ # 1. Check whether the action names match
+ if step_gt["action"] != step_pred["action"]:
+ detail["status"] = "❌ error"
+ detail["message"] += f"Action mismatch: expected '{step_gt['action']}', got '{step_pred['action']}'. "
+ error_count += 1
+ results["details"].append(detail)
+ continue
+
+ # 2. Check the set of parameter keys
+ keys_gt = set(step_gt["input"].keys())
+ keys_pred = set(step_pred["input"].keys())
+ if keys_gt != keys_pred:
+ detail["status"] = "❌ error"
+ detail["message"] += f"Parameter keys mismatch: expected {keys_gt}, got {keys_pred}. "
+ error_count += 1
+ results["details"].append(detail)
+ continue
+
+ # 3. Check argument passing
+ is_step_error = False # Flag whether the current step has parameter errors
+ for key in keys_gt:
+ value_gt = step_gt["input"][key]
+ value_pred = step_pred["input"][key]
+
+ # Determine whether the parameter is a raw variable or a generated variable
+ is_input_var_gt_generated = value_gt in generated_vars_gt
+ is_input_var_pred_generated = value_pred in generated_vars_pred
+
+ # Case 1: Both gt_steps and pred_steps inputs are generated variables (outputs from previous steps)
+ if is_input_var_gt_generated and is_input_var_pred_generated:
+ # Try mapping variables from pred_steps to the corresponding variables in gt_steps
+ mapped_value_pred = var_map_pred2gt.get(value_pred)
+
+ # If the variable from pred_steps successfully maps to the corresponding variable in gt_steps,
+ # and the mapped value matches the expected value in gt_steps
+ if mapped_value_pred == value_gt:
+ pass # Match succeeds; continue
+ else:
+ detail["status"] = "❌ error"
+ detail["message"] += (
+ f"Parameter '{key}' generated variable reference mismatch: "
+ f"expected from '{value_gt}', got from '{value_pred}' "
+ f"(mapped as '{mapped_value_pred}'). "
+ )
+ is_step_error = True
+ # Case 2: Both inputs are raw variables (literals or inputs not defined as function outputs)
+ elif not is_input_var_gt_generated and not is_input_var_pred_generated:
+ # For raw variables, do not strictly require identical values;
+ # even if values differ, consider it correct
+ pass
+ # Case 3: Type mismatch (one is a raw variable, the other is a generated variable)
+ else:
+ detail["status"] = "❌ error"
+ detail["message"] += (
+ f"Parameter '{key}' type mismatch: "
+ f"expected {'generated variable' if is_input_var_gt_generated else 'raw variable'}, "
+ f"got {'generated variable' if is_input_var_pred_generated else 'raw variable'}. "
+ )
+ is_step_error = True
+
+ # If the current step has no parameter errors, update the variable mapping
+ if not is_step_error:
+ # Only when the action and parameters both match,
+ # map the output variable in pred_steps to the output variable in gt_steps
+ var_map_pred2gt[step_pred["output"]] = step_gt["output"]
+ else:
+ # If the step has errors, increment the error count
+ error_count += 1
+
+ results["details"].append(detail)
+
+ # Handle the case where lengths are inconsistent
+ if len(gt_steps) != len(pred_steps):
+ error_count += abs(len(gt_steps) - len(pred_steps))
+ if len(pred_steps) > len(gt_steps):
+ for i in range(min_len, len(pred_steps)):
+ results["details"].append({
+ "step": i + 1,
+ "action_gt": None,
+ "action_pred": pred_steps[i]["action"],
+ "status": "❌ error",
+ "message": "Extra step."
+ })
+ elif len(gt_steps) > len(pred_steps):
+ for i in range(min_len, len(gt_steps)):
+ results["details"].append({
+ "step": i + 1,
+ "action_gt": gt_steps[i]["action"],
+ "action_pred": None,
+ "status": "❌ error",
+ "message": "Missing step."
+ })
+
+ results["parameter_acc"] = 1 - (error_count / max(len(gt_steps), len(pred_steps)))
+
+ return results
+
+
+def extract_final_answer(answer_with_thinking: str, start_tag='', end_tag=''):
+ answer_with_thinking = str(answer_with_thinking)
+ start_index = answer_with_thinking.rfind(start_tag)
+ if start_index != -1:
+ end_index = answer_with_thinking.find(end_tag, start_index)
+ if end_index != -1:
+ return answer_with_thinking[start_index + len(start_tag):end_index].strip()
+ return None
+
+
+class SGI_Bench_Wet_Experiment(TextBaseDataset):
+ TYPE = 'QA'
+
+ @classmethod
+ def supported_datasets(cls):
+ return ["SGI-WetExperiment"]
+
+ def load_data(self, dataset):
+ hf = load_dataset("InternScience/SGI-WetExperiment", split="test")
+
+ rows: List[Dict[str, Any]] = []
+ idx = 0
+ for prob in hf:
+ rows.append(
+ {
+ "index": idx,
+ "id": prob["idx"],
+ "question": prob["question"],
+ "answer": prob["answer"],
+ "discipline": prob["discipline"],
+ "direction": prob["direction"]
+ }
+ )
+ idx += 1
+ return pd.DataFrame(rows)
+
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+ question = line['question'] + """
+The final answer should be enclosed by and .
+
+Example:
+
+dataset = (
+ source="imagenet"
+)
+
+model_init = (
+ model_type="CNN"
+)
+
+model_trained = (
+ model=model_init,
+ data=dataset
+)
+
+metrics = (
+ model=model_trained,
+ data=dataset
+)
+
+"""
+
+ msgs = [{'type': 'text', 'value': question}]
+ return msgs
+
+ def evaluate(self, eval_file, **judge_kwargs):
+ data = load(eval_file)
+ data = pd.DataFrame(data)
+
+ data['action_sequence_similarity'] = 0
+ data['parameter_accuracy'] = 0
+ for index, row in data.iterrows():
+ target_steps = row['answer']
+ target_steps = parse_experiment_steps(target_steps)
+ extracted_text = extract_final_answer(row['prediction'])
+ if extracted_text:
+ prediction_steps = parse_experiment_steps(extracted_text)
+ else:
+ prediction_steps = []
+
+ steps_result = compare_exp_steps(target_steps, prediction_steps)
+
+ data.loc[index, 'action_sequence_similarity'] = steps_result['order_similarity']
+ data.loc[index, 'parameter_accuracy'] = steps_result['parameter_acc']
+
+ score_file = get_intermediate_file_path(eval_file, '_score', 'csv')
+ result = {
+ "Action_Sequence_Similarity": data['action_sequence_similarity'].mean(),
+ "Parameter_Accuracy": data['parameter_accuracy'].mean(),
+ }
+ result_file = get_intermediate_file_path(eval_file, '_result', 'json')
+ dump(data, score_file)
+ dump(result, result_file)
+ return result
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/__init__.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/__init__.py
new file mode 100644
index 00000000..bcfd4544
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/__init__.py
@@ -0,0 +1,494 @@
+import copy
+import os.path as osp
+import warnings
+
+import numpy as np
+import pandas as pd
+
+from vlmeval.smp import LMUDataRoot, dump, get_intermediate_file_path, load, localize_df, toliststr
+from .aetcbench import AETCBench
+from .asclepius import Asclepius
+# Fork-specific datasets
+from .av_3d_grounding import ThreeDAVGroundingBench
+from .av_prompt_following import AVPromptFollowingBench
+from .av_speakerbench import AVSpeakerBench
+from .avspecial_collision_bench import AVSpecialCollisionBench
+from .avspecial_environment_bench import AVSpecialEnvironmentBench
+from .avspecial_ood_reasoning_bench import AVSpecialOODReasoningBench
+from .avspecial_stop_behavior_bench import AVSpecialStopBehaviorBench
+from .blink_depth import BlinkDepth
+from .blink_spatial import BlinkSpatial
+from .camera_bench import CameraBench
+from .camera_intrinsic import CameraIntrinsicBench
+from .causalvqa import CausalVQA
+from .CGAVCounting.cg_av_counting import CGAVCounting
+from .cgbench import (CGBench_MCQ_Grounding, CGBench_MCQ_Grounding_Mini, CGBench_OpenEnded,
+ CGBench_OpenEnded_Mini)
+from .chartbench import ChartBench
+from .chartcap import ChartCapDataset
+from .chartmimic import ChartMimic
+from .chartmuseum import ChartMuseum
+from .chartqapro import ChartQAPro
+from .chartx import ChartX
+from .charxiv import CharXiv
+from .cmmmu import CMMMU
+from .cosmos_cab_image import CosmosCABImage
+from .cosmos_cab_video import CosmosCABVideoCamera, CosmosCABVideoGeneral
+from .cosmos_erqa import CosmosERQA
+from .worldbench import WorldBench
+from .cosmos_reason import CosmosReason
+from .creation import CreationMMBenchDataset
+from .cv_bench import CVBench
+from .da2k import DA2K
+from .design2code import Design2Code
+from .dream import DREAM
+from .dsrbench import DSRBench
+from .dude import DUDE
+from .dynamath import Dynamath
+from .EgoExoBench.egoexobench import EgoExoBench_MCQ
+from .embspatialbench import EmbSpatialBench
+from .emma import EMMADataset
+from .eriq import ERIQBench
+from .erqa import ERQADataset
+from .erqabench import ERQABench
+from .flames import FlamesDataset
+from .foxbench import FoxBench
+from .gobench import GOBenchDataset
+from .groundingme import GroundingME
+from .gsm8k_v import GSM8KVDataset
+from .GUI.osworld_g import OSWorld_G
+from .GUI.screenspot import ScreenSpot
+from .GUI.screenspot_pro import ScreenSpot_Pro
+from .GUI.screenspot_v2 import ScreenSpotV2
+from .GUI.vbgd import VBGD
+from .GUI.venusbench import VenusBench_GD
+from .health_surgi_bench import HealthSurgiBench
+from .hipho import HiPhODataset
+from .IFBench.ifbench import IFBench
+from .image_base import ImageBaseDataset, img_root_map
+from .image_caption import ImageCaptionDataset
+from .image_ccocr import CCOCRDataset
+from .image_mcq import (CVQA, LEGO, SCAM, AffordanceDataset, CustomMCQDataset,
+ GMAIMMBenchDataset, HRBenchDataset, ImageMCQDataset, MedXpertQA_MM_test,
+ MicroBench, MMERealWorld, MMMUDataset, MMMUProDataset, MSEarthMCQ,
+ MUIRDataset, NaturalBenchDataset, OmniEarthMCQBench, OmniMedVQA, PuzzleVQA,
+ TDBench, TopViewRS, TreeBench, VisualPuzzles, VisuLogic, VLMBlind,
+ VMCBenchDataset, WeMath, XLRSBench, _3DSRBench)
+from .image_mt import MMDUDataset
+from .image_shortqa import ImageShortQADataset, PathVQA_TEST, PathVQA_VAL
+from .image_vqa import (BMMR, CRPE, LENS, MMNIAH, AyaVisionBench, CoreCognition, CountBenchQA,
+ CustomVQADataset, ImageVQADataset, LLaVABench, LLaVABench_KO, LogicVista,
+ MathCanvas, MathVerse, MathVision, MathVista, MME_CoT, MMEReasoning,
+ MMReason, MMSci_Captioning, MMVet, MMVMBench, MTVQADataset, OCR_Reasoning,
+ OCRBench, OCRBench_v2, OlympiadBench, Omni3DBench, Physics_yale, PhyX,
+ QSpatial, SeePhys, TableVQABench, TallyQA, TDBenchGrounding, VGRPBench,
+ VizWiz, VLMsAreBiased, VTCBench, WildDocBenchmark, ZEROBench)
+from .image_yorn import ImageYORNDataset
+from .intphys2 import IntPhys2
+from .its_collision import ITSCollision
+from .lingoqa import LingoQA
+from .locate_anything_bench import LocateAnythingBench
+from .longvideobench import LongVideoBench
+from .lv_event_verification import LVEventVerification
+from .lvs import LVSDataset, LVSHallucinationDataset
+from .lvs_ai_hallucination import LVSAIHallucinationDataset
+from .m3oralbench import M3oralBenchDataset
+from .m4bench import M4Bench
+from .macbench import MaCBench
+from .matbench import MATBench
+from .medqbench_caption import MedqbenchCaptionDataset
+from .medqbench_mcq import MedqbenchMCQDataset
+from .medqbench_paired_description import MedqbenchPairedDescriptionDataset
+from .megabench import MEGABench
+from .metropolis2d.astro_2d_dataset import Astro2DDetectionDataset, _build_astro2dbench_dataset
+from .metropolis2d.detection_2d_dataset import Metropolis2DDetectionDataset
+from .metropolis2d.grounding_2d_dataset import Metropolis2DGroundingDataset
+from .metropolis_dvc import MetropolisDVC
+from .metropolis_event_verification import MetropolisEventVerification
+from .metropolis_temporal import MetropolisTemporal
+from .metropolis_vqa import MetropolisVQA
+from .miabench import MIABench
+from .mindcubebench import MindCubeBench
+from .mlvu import MLVU, MLVU_MCQ, MLVU_OpenEnded
+from .mmalignbench import MMAlignBench
+from .mmbench_video import MMBenchVideo
+from .mmesci import MMESCIDataset
+from .mmgenbench import MMGenBench
+from .mmhelix import MMHELIX
+from .mmifeval import MMIFEval
+from .mmlongbench import MMLongBench
+from .mmmath import MMMath
+from .mmoral_opg_closed import MMOral_OPG_CLOSED
+from .mmoral_opg_open import MMOral_OPG_OPEN
+from .mmsafetybench import MMSafetyBenchDataset
+from .mmsibench import MMSIBench, MMSIVideoBench
+from .moat import MOAT
+from .moviechat1k import MovieChat1k
+from .mssbench import MSSBenchDataset
+from .mvbench import MVBench, MVBench_MP4
+from .mvpbench import MVPBench
+from .mvu_eval import MVUEval
+from .NPMM import NPMM
+from .oceanocr import OceanOCRBench
+from .odinw13 import ODinW13Dataset
+from .olmOCRBench.olmocrbench import olmOCRBench
+from .Omni3D.omni3d_dataset import Omni3DDetectionDataset
+from .OmniDocBench.omnidocbench import OmniDocBench
+from .omnispatialbench import OmniSpatialBench
+from .omtgbench import OMTGBench
+from .ost_bench import OSTDataset
+from .plotqa import PlotQA
+from .qbench_video import QBench_Video, QBench_Video_MCQ, QBench_Video_VQA
+from .reasonmap_plus import ReasonMap_Plus
+from .refcoco import RefCOCODataset
+from .refspatial import RefSpatialDataset
+from .refspatialbench import RefSpatialBench
+from .robospatial_home import RoboSpatialHome
+from .robospatialbench import RoboSpatialBench
+from .sarena import SArena
+from .sat_bench import SATBench
+from .sfebench import SFE
+from .SGI_Bench_1_0.deep_research import SGI_Bench_Deep_Research
+from .SGI_Bench_1_0.dry_experiment import SGI_Bench_Dry_Experiment
+from .SGI_Bench_1_0.experimental_reasoning import SGI_Bench_Experimental_Reasoning
+from .SGI_Bench_1_0.idea_generation import SGI_Bench_Idea_Generation
+from .SGI_Bench_1_0.wet_experiment import SGI_Bench_Wet_Experiment
+from .share_robot_bench import ShareRobotTrajectory
+from .simplevqa import SimpleVQA
+from .sitebench import SiteBenchImage, SiteBenchVideo
+from .siuo import SIUODataset
+from .siuo_gen import SIUOGenDataset
+from .siuo_mcq import SIUOMCQDataset
+from .slidevqa import SlideVQA
+from .sop_action_recognition import SOPActionRecognition
+from .sop_temporal_localization import SOPTemporalLocalization
+from .spar_bench import SPARBench, SPARBenchTiny
+from .sparbench import SparBench
+from .spatial457 import Spatial457
+from .spatialvizbench import SpatialVizBench
+from .spbench import SPBench
+from .ssi_bench import SSIBenchDataset
+from .starebench import StareBench
+from .stibench import STIBench
+from .tailgating import TailgatingVerification
+from .tamperbench import MVTamperBench
+from .tempcompass import TempCompass, TempCompass_Captioning, TempCompass_MCQ, TempCompass_YorN
+from .temporal_localization_bench import TemporalLocalization
+from .text_mcq import CustomTextMCQDataset, TextMCQDataset
+from .uni_svg import UniSVG
+from .utils import DEBUG_MESSAGE, build_judge, extract_answer_from_item, prefetch_answer
+from .v2pbench import V2PBench
+from .vantage_pointing import VANTAGE_2DPointing
+from .vantage_sot import VANTAGE_SOT
+from .vcr import VCRDataset
+from .vcrbench import VCRBench
+from .vdc import VDC
+from .video_concat_dataset import ConcatVideoDataset
+from .video_holmes import Video_Holmes
+from .motionbench import MotionBench
+from .video_mmlu import Video_MMLU_CAP, Video_MMLU_QA
+from .videomme import VideoMME
+from .videommmu import VideoMMMU
+from .videophy2 import VideoPhy2
+from .videott import VideoTT
+from .viewspatialbench import ViewSpatialBench
+from .visfactor import VisFactor
+from .vl_rewardbench import VLRewardBench
+from .vladbench import VLADBench
+from .vlm2bench import VLM2Bench
+from .vlmbias import VLMBias
+from .vlrmbench import VLRMBench
+from .vsibench import VsiBench, VsiSuperCount, VsiSuperRecall
+from .vss_oobe_clips import VSSoobeClips
+from .warehouse_near_miss import WarehouseNearMiss
+from .warehouse_spatial_ai import WarehouseSpatialAI
+from .where2place import Where2Place
+from .wildvision import WildVision
+from .worldsense import WorldSense
+from .worldvqa import WorldVQA
+from .xstest import XSTestDataset
+
+from .video_dataset_config import supported_video_datasets # isort: skip
+
+
+class ConcatDataset(ImageBaseDataset):
+ # This dataset takes multiple dataset names as input and aggregate them into a single dataset.
+ # Each single dataset should not have a field named `SUB_DATASET`
+
+ DATASET_SETS = {
+ 'MMMB': ['MMMB_ar', 'MMMB_cn', 'MMMB_en', 'MMMB_pt', 'MMMB_ru', 'MMMB_tr'],
+ 'MTL_MMBench_DEV': [
+ 'MMBench_dev_ar', 'MMBench_dev_cn', 'MMBench_dev_en',
+ 'MMBench_dev_pt', 'MMBench_dev_ru', 'MMBench_dev_tr'
+ ],
+ 'ScreenSpot_Pro': [
+ 'ScreenSpot_Pro_Development', 'ScreenSpot_Pro_Creative', 'ScreenSpot_Pro_CAD',
+ 'ScreenSpot_Pro_Scientific', 'ScreenSpot_Pro_Office', 'ScreenSpot_Pro_OS'
+ ],
+ 'ScreenSpot': ['ScreenSpot_Mobile', 'ScreenSpot_Desktop', 'ScreenSpot_Web'],
+ 'ScreenSpot_v2': ['ScreenSpot_v2_Mobile', 'ScreenSpot_v2_Desktop', 'ScreenSpot_v2_Web'],
+ 'M4Bench': ['State_Invariance', 'State_Comparison', 'Spatial_Perception', 'Instance_Comparison', 'Detailed_Difference'], # noqa: E501
+ }
+
+ def __init__(self, dataset, **kwargs):
+ # Forward profile-level kwargs (e.g. `model_family` from `--profile`)
+ # to each child so concat-wrapped benchmarks can per-family dispatch.
+ # Children that don't consume a given kwarg absorb it via
+ # ImageBaseDataset's `**kwargs` catch-all.
+ datasets = self.DATASET_SETS[dataset]
+ self.dataset_map = {}
+ # The name of the compliation
+ self.dataset_name = dataset
+ self.datasets = datasets
+ for dname in datasets:
+ dataset = build_dataset(dname, **kwargs)
+ assert dataset is not None, dataset
+ self.dataset_map[dname] = dataset
+ TYPES = [x.TYPE for x in self.dataset_map.values()]
+ MODALITIES = [x.MODALITY for x in self.dataset_map.values()]
+ assert np.all([x == TYPES[0] for x in TYPES]), (datasets, TYPES)
+ assert np.all([x == MODALITIES[0] for x in MODALITIES]), (datasets, MODALITIES)
+ self.TYPE = TYPES[0]
+ self.MODALITY = MODALITIES[0]
+ data_all = []
+ for dname in datasets:
+ data = self.dataset_map[dname].data
+ data['SUB_DATASET'] = [dname] * len(data)
+ if 'image' in data:
+ data_new = localize_df(data, dname, nproc=16)
+ data_all.append(data_new)
+ else:
+ data_all.append(data)
+
+ data = pd.concat(data_all)
+ data['original_index'] = data.pop('index')
+ data['index'] = np.arange(len(data))
+ self.data = data
+
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+ idx = line['original_index']
+ dname = line['SUB_DATASET']
+ org_data = self.dataset_map[dname].data
+ org_line = copy.deepcopy(org_data[org_data['index'] == idx]).iloc[0]
+ return self.dataset_map[dname].build_prompt(org_line)
+
+ def dump_image(self, line):
+ # Assert all images are pre-dumped
+ assert 'image' not in line
+ assert 'image_path' in line
+ tgt_path = toliststr(line['image_path'])
+ return tgt_path
+
+ @classmethod
+ def supported_datasets(cls):
+ return list(cls.DATASET_SETS)
+
+ def evaluate(self, eval_file, **judge_kwargs):
+ # First, split the eval_file by dataset
+ data_all = load(eval_file)
+ for dname in self.datasets:
+ tgt = eval_file.replace(self.dataset_name, dname)
+ data_sub = data_all[data_all['SUB_DATASET'] == dname]
+ data_sub.pop('index')
+ data_sub['index'] = data_sub.pop('original_index')
+ data_sub.pop('SUB_DATASET')
+ dump(data_sub, tgt)
+ # Then, evaluate each dataset separately
+ df_all = []
+ dict_all = {}
+ # One of the vars will be used to aggregate results
+ for dname in self.datasets:
+ tgt = eval_file.replace(self.dataset_name, dname)
+ res = self.dataset_map[dname].evaluate(tgt, **judge_kwargs)
+ if isinstance(res, pd.DataFrame):
+ res['DATASET'] = [dname] * len(res)
+ df_all.append(res)
+ elif isinstance(res, dict):
+ res = {f'{dname}:{k}': v for k, v in res.items()}
+ dict_all.update(res)
+ else:
+ raise NotImplementedError(f'Unknown result type {type(res)}')
+
+ if len(df_all):
+ result = pd.concat(df_all)
+ score_file = get_intermediate_file_path(eval_file, '_acc', 'csv')
+ dump(result, score_file)
+ return result
+ else:
+ score_file = get_intermediate_file_path(eval_file, '_score', 'json')
+ dump(dict_all, score_file)
+ return dict_all
+
+
+# Add new supported dataset class here
+IMAGE_DATASET = [
+ ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset,
+ MathVision, LENS, MMMUDataset, OCRBench, MathVista, LLaVABench, LLaVABench_KO, VGRPBench, MMVet, # noqa: E501
+ MTVQADataset, TableVQABench, MMLongBench, VCRDataset, MMDUDataset, DUDE,
+ SlideVQA, MUIRDataset, CCOCRDataset, GMAIMMBenchDataset, MMERealWorld,
+ HRBenchDataset, CRPE, MathVerse, NaturalBenchDataset, MIABench,
+ OlympiadBench, SeePhys, WildVision, MMMath, QSpatial, Dynamath, GSM8KVDataset, MMGenBench, VizWiz, # noqa: E501
+ MMNIAH, CMMMU, VLRewardBench, WeMath, LogicVista, MMMUProDataset,
+ CreationMMBenchDataset, ImageShortQADataset, MMAlignBench, OmniDocBench,
+ VLM2Bench, VMCBenchDataset, EMMADataset, MME_CoT, MOAT, MedXpertQA_MM_test,
+ LEGO, MMSci_Captioning, Physics_yale, ScreenSpot_Pro, ScreenSpot, VenusBench_GD,
+ ScreenSpotV2, OSWorld_G, VBGD, MMIFEval, Spatial457, SPARBench, SPARBenchTiny, VisuLogic, CVBench, PathVQA_VAL,
+ PathVQA_TEST, TDBench, TDBenchGrounding, MicroBench, CharXiv, OmniMedVQA,
+ WildDocBenchmark, MSEarthMCQ, OCR_Reasoning, PhyX, VLMBlind, CountBenchQA,
+ ZEROBench, SCAM, Omni3DBench, Omni3DDetectionDataset, TallyQA, _3DSRBench, BMMR, AffordanceDataset,
+ MMEReasoning, GOBenchDataset, SFE, ChartMimic, MMVMBench, XLRSBench,
+ OmniEarthMCQBench, VisFactor, OSTDataset, OCRBench_v2, TreeBench, CVQA, M4Bench,
+ AyaVisionBench, TopViewRS, VLMBias, MMHELIX, MedqbenchMCQDataset, MathCanvas, MMReason,
+ MedqbenchPairedDescriptionDataset, MedqbenchCaptionDataset, ChartMuseum, ChartQAPro, ReasonMap_Plus, # noqa: E501
+ olmOCRBench, OceanOCRBench, MATBench, VLRMBench, RefCOCODataset, RefSpatialDataset,
+ ERQADataset, SimpleVQA, HiPhODataset, MaCBench,
+ UniSVG, SArena, VLMsAreBiased, MMESCIDataset, CoreCognition, GroundingME,
+ FoxBench, VTCBench, Asclepius, PlotQA, ChartX, ChartBench, ChartCapDataset, WorldVQA, PuzzleVQA, VisualPuzzles, # noqa: E501
+ MMSafetyBenchDataset, MSSBenchDataset, SIUODataset, SIUOGenDataset, SIUOMCQDataset, M3oralBenchDataset, # noqa: E501
+ Design2Code, VLADBench, SSIBenchDataset, NPMM, SGI_Bench_Experimental_Reasoning, MMOral_OPG_OPEN, MMOral_OPG_CLOSED, # noqa: E501
+ CVBench, BlinkDepth, BlinkSpatial, CosmosERQA, RoboSpatialHome, SATBench, Where2Place,
+ ShareRobotTrajectory, CameraIntrinsicBench, ThreeDAVGroundingBench, WarehouseSpatialAI,
+ Metropolis2DDetectionDataset, Metropolis2DGroundingDataset, Astro2DDetectionDataset, HealthSurgiBench,
+ LocateAnythingBench, CosmosCABImage, VANTAGE_2DPointing,
+ AVSpecialOODReasoningBench, ODinW13Dataset, WorldBench,
+]
+
+# add by EASI team
+IMAGE_DATASET += [
+ MindCubeBench, EmbSpatialBench, ViewSpatialBench, MMSIBench, SiteBenchImage,
+ SparBench, SpatialVizBench, StareBench, OmniSpatialBench, ERQABench, RoboSpatialBench, RefSpatialBench, # noqa: E501
+ SPBench, ERIQBench, DA2K
+]
+
+VIDEO_DATASET = [
+ MMBenchVideo, VideoMME, MVBench, MVBench_MP4, MVTamperBench,
+ LongVideoBench, WorldSense, VDC, MovieChat1k, MEGABench,
+ MLVU, MLVU_MCQ, MLVU_OpenEnded,
+ TempCompass, TempCompass_MCQ, TempCompass_Captioning, TempCompass_YorN,
+ CGBench_MCQ_Grounding_Mini, CGBench_OpenEnded_Mini, CGBench_MCQ_Grounding, CGBench_OpenEnded,
+ QBench_Video, QBench_Video_MCQ, QBench_Video_VQA,
+ Video_MMLU_CAP, Video_MMLU_QA,
+ Video_Holmes, VCRBench, CGAVCounting, MotionBench,
+ EgoExoBench_MCQ, DREAM, VideoTT, VideoMMMU, MVUEval, OMTGBench, V2PBench, AVSpeakerBench,
+ TailgatingVerification, VideoPhy2, CosmosReason, MetropolisTemporal, MetropolisVQA, MetropolisDVC,
+ CausalVQA, MVPBench, IntPhys2, ITSCollision, MetropolisEventVerification, LVEventVerification, WarehouseNearMiss,
+ VANTAGE_SOT, LingoQA, SOPActionRecognition, SOPTemporalLocalization,
+ LVSDataset, LVSHallucinationDataset, LVSAIHallucinationDataset,
+ VSSoobeClips, AVSpecialCollisionBench, AVSpecialEnvironmentBench, AVSpecialStopBehaviorBench, TemporalLocalization,
+ AVPromptFollowingBench,
+ AETCBench, CameraBench, CosmosCABVideoGeneral, CosmosCABVideoCamera,
+]
+
+# add by EASI team
+VIDEO_DATASET += [SiteBenchVideo, VsiBench, VsiSuperRecall, VsiSuperCount, MMSIVideoBench, STIBench, DSRBench] # noqa: E501
+
+TEXT_DATASET = [
+ TextMCQDataset, IFBench, SGI_Bench_Wet_Experiment, SGI_Bench_Dry_Experiment,
+ SGI_Bench_Deep_Research, SGI_Bench_Idea_Generation, XSTestDataset, FlamesDataset
+]
+
+CUSTOM_DATASET = [
+ CustomMCQDataset, CustomVQADataset, CustomTextMCQDataset
+]
+
+# Astro2DBenchDataset is a ConcatDataset subclass, so it must be built after ConcatDataset.
+Astro2DBenchDataset = _build_astro2dbench_dataset(ConcatDataset)
+
+DATASET_COLLECTION = [ConcatDataset, ConcatVideoDataset, Astro2DBenchDataset]
+
+DATASET_CLASSES = IMAGE_DATASET + VIDEO_DATASET + TEXT_DATASET + CUSTOM_DATASET + DATASET_COLLECTION # noqa: E501
+SUPPORTED_DATASETS = []
+for DATASET_CLS in DATASET_CLASSES:
+ SUPPORTED_DATASETS.extend(DATASET_CLS.supported_datasets())
+
+
+def DATASET_TYPE(dataset, *, default: str = 'MCQ') -> str:
+ for cls in DATASET_CLASSES:
+ if dataset in cls.supported_datasets():
+ if hasattr(cls, 'TYPE'):
+ return cls.TYPE
+ # Have to add specific routine to handle ConcatDataset
+ if dataset in ConcatDataset.DATASET_SETS:
+ dataset_list = ConcatDataset.DATASET_SETS[dataset]
+ TYPES = [DATASET_TYPE(dname) for dname in dataset_list]
+ assert np.all([x == TYPES[0] for x in TYPES]), (dataset_list, TYPES)
+ return TYPES[0]
+
+ if 'openended' in dataset.lower():
+ return 'VQA'
+ warnings.warn(f'Dataset {dataset} is a custom one and not annotated as `openended`, will treat as {default}. ') # noqa: E501
+ return default
+
+
+def DATASET_MODALITY(dataset, *, default: str = 'IMAGE') -> str:
+ if dataset is None:
+ warnings.warn(f'Dataset is not specified, will treat modality as {default}. ')
+ return default
+ for cls in DATASET_CLASSES:
+ if dataset in cls.supported_datasets():
+ if hasattr(cls, 'MODALITY'):
+ return cls.MODALITY
+ # Have to add specific routine to handle ConcatDataset
+ if dataset in ConcatDataset.DATASET_SETS:
+ dataset_list = ConcatDataset.DATASET_SETS[dataset]
+ MODALITIES = [DATASET_MODALITY(dname) for dname in dataset_list]
+ assert np.all([x == MODALITIES[0] for x in MODALITIES]), (dataset_list, MODALITIES)
+ return MODALITIES[0]
+
+ if 'VIDEO' in dataset.lower():
+ return 'VIDEO'
+ elif 'IMAGE' in dataset.lower():
+ return 'IMAGE'
+ warnings.warn(f'Dataset {dataset} is a custom one, will treat modality as {default}. ')
+ return default
+
+
+def build_dataset(dataset_name, **kwargs):
+ for cls in DATASET_CLASSES:
+ if dataset_name in supported_video_datasets:
+ return supported_video_datasets[dataset_name](**kwargs)
+ elif dataset_name in cls.supported_datasets():
+ return cls(dataset=dataset_name, **kwargs)
+
+ warnings.warn(f'Dataset {dataset_name} is not officially supported. ')
+ data_file = osp.join(LMUDataRoot(), f'{dataset_name}.tsv')
+ if not osp.exists(data_file):
+ warnings.warn(f'Data file {data_file} does not exist. Dataset building failed. ')
+ return None
+
+ data = load(data_file)
+ if 'question' not in [x.lower() for x in data.columns]:
+ warnings.warn(
+ f'Data file {data_file} does not have a `question` column. Dataset building failed. ')
+ return None
+
+ if 'A' in data and 'B' in data:
+ if 'image' in data or 'image_path' in data:
+ warnings.warn(
+ f'Will assume unsupported dataset {dataset_name} as a Custom MCQ dataset. ')
+ return CustomMCQDataset(dataset=dataset_name, **kwargs)
+ else:
+ warnings.warn(
+ f'Will assume unsupported dataset {dataset_name} as a Custom Text MCQ dataset. ')
+ return CustomTextMCQDataset(dataset=dataset_name, **kwargs)
+ else:
+ warnings.warn(f'Will assume unsupported dataset {dataset_name} as a Custom VQA dataset. ')
+ return CustomVQADataset(dataset=dataset_name, **kwargs)
+
+
+def infer_dataset_basename(dataset_name):
+ basename = "_".join(dataset_name.split("_")[:-1])
+ return basename
+
+
+__all__ = [
+ 'build_dataset',
+ 'img_root_map',
+ 'build_judge',
+ 'extract_answer_from_item',
+ 'prefetch_answer',
+ 'DEBUG_MESSAGE',
+]
+__all__.extend([cls.__name__ for cls in DATASET_CLASSES])
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/aetcbench.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/aetcbench.py
new file mode 100644
index 00000000..78a0dd3b
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/aetcbench.py
@@ -0,0 +1,778 @@
+import functools
+import json
+import os
+import re
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from vlmeval.dataset.utils.lvs import bert_score
+from ..smp import *
+from .video_base import VideoBaseDataset
+logger = get_logger('AETCBench')
+
+# All recognized task types and their filename patterns
+TASK_TYPES = [
+ 'bcq', 'mcq', 'open_qa',
+ 'temporal_localization', 'causal_linkage',
+ 'scene_description', 'temporal_description', 'video_summarization',
+ 'bcq_openended', 'mcq_openended',
+]
+
+# DSS dataset names
+DSS_TASKS_DATASET = 'AETC-Tasks'
+DSS_VIDEOS_DATASET = 'AETC-Videos'
+
+
+# Answer-format instructions per task type (appended to user query)
+_ANSWER_INSTRUCTIONS = {
+ 'bcq': 'Answer with only Yes or No.',
+ 'bcq_openended': 'Answer with Yes or No, followed by a brief explanation.',
+ 'mcq': 'Choose the correct option by letter only.',
+ 'mcq_openended': 'Choose the correct option and provide a brief explanation.',
+ "temporal_localization": "Provide the result in json format with 'mm:ss' for time depiction. Use keywords 'start', 'end' in the json output.",
+}
+
+
+def _build_user_query(task_type, item):
+ """Build the user-facing prompt string from a task item.
+
+ Mirrors the prompt format used in training (_build_conversation).
+ Ground-truth answer and reasoning are intentionally excluded.
+ """
+ question = item.get('question', '')
+
+ # Append MCQ options
+ if task_type in ('mcq', 'mcq_openended'):
+ options = item.get('options')
+ if options:
+ options_text = '\n'.join(f'{k}) {v}' for k, v in sorted(options.items()))
+ question = f'{question}\n\n{options_text}'
+
+ # Append answer-format instruction
+ instruction = _ANSWER_INSTRUCTIONS.get(task_type)
+ if instruction:
+ question = f'{question}\n\n{instruction}'
+
+ return question
+
+
+def _format_reference_answer(task_type, item):
+ """Format the ground-truth answer string for evaluation.
+
+ Mirrors _format_answer from training code.
+ """
+ if task_type in (
+ 'open_qa', 'bcq_openended', 'mcq_openended',
+ 'video_summarization', 'scene_description',
+ 'temporal_description', 'causal_linkage',
+ ):
+ return item.get('answer') or ''
+
+ if task_type == 'bcq':
+ answer = item.get('answer', '')
+ explanation = item.get('explanation', '')
+ return f'{answer}. {explanation}' if explanation else answer
+
+ if task_type == 'mcq':
+ letter = item.get('answer', '')
+ options = item.get('options', {})
+ label = f'{letter}) {options[letter]}' if letter in options else letter
+ explanation = item.get('explanation', '')
+ return f'{label}. {explanation}' if explanation else label
+
+ if task_type == 'temporal_localization':
+ answer = item.get('answer')
+ if answer:
+ return json.dumps(answer)
+ return ''
+
+ return ''
+
+class Evaluator:
+ """Reference-based text metrics: BLEU, ROUGE, METEOR, BERTScore."""
+ # NOTE: pending finalizing the metrics
+ # EVAL_METRICS = ['bertscore', 'bleu', 'rouge', 'meteor']
+ EVAL_METRICS = ['bertscore']
+ def __init__(self):
+ import evaluate
+ if 'bleu' in self.EVAL_METRICS:
+ self.bleu_metric = evaluate.load("bleu")
+ if 'rouge' in self.EVAL_METRICS:
+ self.rouge_metric = evaluate.load("rouge")
+ if 'meteor' in self.EVAL_METRICS:
+ self.meteor_metric = evaluate.load("meteor")
+ if 'bertscore' in self.EVAL_METRICS:
+ self.bertscore_metric = evaluate.load("bertscore")
+
+ def __call__(self, references, candidates):
+ results = {}
+ if 'bleu' in self.EVAL_METRICS:
+ bleu = self.bleu_metric.compute(predictions=candidates, references=references)
+ results['bleu'] = bleu['bleu']
+ if 'rouge' in self.EVAL_METRICS:
+ rouge = self.rouge_metric.compute(predictions=candidates, references=references)
+ results['rouge1'] = rouge['rouge1']
+ results['rouge2'] = rouge['rouge2']
+ results['rougeL'] = rouge['rougeL']
+ if 'meteor' in self.EVAL_METRICS:
+ meteor = self.meteor_metric.compute(predictions=candidates, references=references)
+ results['meteor'] = meteor['meteor']
+ if 'bertscore' in self.EVAL_METRICS:
+ bertscore = self.bertscore_metric.compute(predictions=candidates, references=references, lang='en', rescale_with_baseline=True)
+ results['bertscore_f1'] = float(np.mean(bertscore['f1']))
+ results['bertscore_precision'] = float(np.mean(bertscore['precision']))
+ results['bertscore_recall'] = float(np.mean(bertscore['recall']))
+ return results
+
+def _parse_subdataset_from_path(rel_path):
+ """Extract the top-level subdataset name from a relative path under AETC-Tasks/.
+
+ e.g. 'so-tad/test/2045/task/bcq_aetc.json' -> 'so-tad'
+ """
+ parts = Path(rel_path).parts
+ return parts[0] if parts else 'unknown'
+
+def _preprocess_video(video_path, fps, max_pixels_per_frame, cache_dir):
+ """Re-encode a video at target fps and resolution, caching the result.
+
+ Args:
+ video_path: Path to source video.
+ fps: Target frames per second.
+ max_pixels_per_frame: Max pixels (w*h) per frame. The video is
+ scaled down (preserving aspect ratio) so that w*h <= this value.
+ If None, no scaling is applied.
+ cache_dir: Directory to store preprocessed videos.
+
+ Returns:
+ Path to the preprocessed video (str). Returns the original path
+ if preprocessing is disabled (fps=None and max_pixels_per_frame=None).
+ """
+ import hashlib
+ import subprocess
+
+ if fps is None and max_pixels_per_frame is None:
+ return video_path
+
+ cache_dir = Path(cache_dir)
+ cache_dir.mkdir(parents=True, exist_ok=True)
+
+ # Deterministic cache key from source path + params
+ key_str = f'{video_path}|fps={fps}|maxpix={max_pixels_per_frame}'
+ key = hashlib.sha256(key_str.encode()).hexdigest()[:16]
+ cache_path = cache_dir / f'{key}.mp4'
+ if cache_path.exists():
+ return str(cache_path)
+
+ # Build ffmpeg filter chain
+ vf_filters = []
+ if fps is not None:
+ vf_filters.append(f'fps={fps}')
+ if max_pixels_per_frame is not None:
+ # Scale down so w*h <= max_pixels_per_frame, preserving aspect ratio.
+ # Use expression: if(gt(iw*ih, max), scale to fit, keep original)
+ # sqrt(max / (iw*ih)) gives the uniform scale factor.
+ mp = int(max_pixels_per_frame)
+ vf_filters.append(
+ f"scale='if(gt(iw*ih,{mp}),trunc(iw*sqrt({mp}/(iw*ih))/2)*2,iw)'"
+ f":'if(gt(iw*ih,{mp}),trunc(ih*sqrt({mp}/(iw*ih))/2)*2,ih)'"
+ )
+
+ cmd = ['ffmpeg', '-i', str(video_path)]
+ if vf_filters:
+ cmd += ['-vf', ','.join(vf_filters)]
+ cmd += [
+ '-c:v', 'libx264',
+ '-crf', '23',
+ '-preset', 'fast',
+ '-pix_fmt', 'yuv420p',
+ '-an', # drop audio
+ '-threads', '1',
+ '-loglevel', 'error',
+ '-y',
+ str(cache_path),
+ ]
+ subprocess.run(cmd, check=True)
+ return str(cache_path)
+
+
+
+
+class AETCScorer:
+ """All scoring logic for AETCBench, separated for readability."""
+
+ # Configurable weights for the overall score.
+ # Keys must match the metric names returned by _eval_* methods.
+ # Weights are normalized to sum to 1 at scoring time, so relative
+ # magnitudes are what matter. Set a weight to 0 to exclude a metric.
+ # Text-metric tasks get 4 sub-metrics (bertscore_f1/bleu/meteor/rougeL)
+ # each at 0.25 so the task contributes 1.0 total, matching single-metric tasks.
+ METRIC_WEIGHTS = {
+ # NOTE: pending finalizing the weights
+ 'bcq_accuracy': 1.0,
+ 'mcq_accuracy': 1.0,
+ 'temporal_localization_miou': 1.0,
+ 'bcq_openended_bertscore_f1': 1.0,
+ # 'bcq_openended_bleu': 1.0,
+ # 'bcq_openended_meteor': 1.0,
+ # 'bcq_openended_rougeL': 1.0,
+ 'mcq_openended_bertscore_f1': 1.0,
+ # 'mcq_openended_bleu': 1.0,
+ # 'mcq_openended_meteor': 1.0,
+ # 'mcq_openended_rougeL': 1.0,
+ 'open_qa_bertscore_f1': 1.0,
+ # 'open_qa_bleu': 1.0,
+ # 'open_qa_meteor': 1.0,
+ # 'open_qa_rougeL': 1.0,
+ 'causal_linkage_bertscore_f1': 1.0,
+ # 'causal_linkage_bleu': 1.0,
+ # 'causal_linkage_meteor': 1.0,
+ # 'causal_linkage_rougeL': 1.0,
+ 'scene_description_bertscore_f1': 1.0,
+ # 'scene_description_bleu': 1.0,
+ # 'scene_description_meteor': 1.0,
+ # 'scene_description_rougeL': 1.0,
+ 'temporal_description_bertscore_f1': 1.0,
+ # 'temporal_description_bleu': 1.0,
+ # 'temporal_description_meteor': 1.0,
+ # 'temporal_description_rougeL': 1.0,
+ 'video_summarization_bertscore_f1': 1.0,
+ # 'video_summarization_bleu': 1.0,
+ # 'video_summarization_meteor': 1.0,
+ # 'video_summarization_rougeL': 1.0,
+ }
+
+ EVAL_DISPATCH = {
+ 'bcq': '_eval_bcq',
+ 'bcq_openended': '_eval_bcq_openended',
+ 'mcq': '_eval_mcq',
+ 'mcq_openended': '_eval_mcq_openended',
+ 'open_qa': '_eval_open_qa',
+ 'temporal_localization': '_eval_temporal_localization',
+ 'causal_linkage': '_eval_causal_linkage',
+ 'scene_description': '_eval_scene_description',
+ 'temporal_description': '_eval_temporal_description',
+ 'video_summarization': '_eval_video_summarization',
+ }
+
+ def __init__(self, evaluator):
+ self.evaluator = evaluator
+
+ def score(self, data, **judge_kwargs):
+ """Score all task types present in data.
+
+ Args:
+ data: pd.DataFrame with at least these columns:
+ - task_type: str, one of TASK_TYPES (e.g. 'bcq', 'mcq', ...)
+ - prediction: str, raw model output text
+ - answer: str, formatted ground-truth (from _format_reference_answer)
+ **judge_kwargs: forwarded to per-task evaluators (reserved for LLM judge config).
+
+ Returns:
+ dict[str, float]: flat mapping of metric_name -> score, e.g.
+ {'bcq_accuracy': 0.85, 'mcq_accuracy': 0.72, ...}
+ """
+ metrics = {}
+ for tt in data['task_type'].unique():
+ subset = data[data['task_type'] == tt]
+ method_name = self.EVAL_DISPATCH.get(tt)
+ if method_name is None:
+ logger.warning(f'No evaluator for task type {tt!r}, skipping')
+ continue
+ eval_fn = getattr(self, method_name)
+ sub_metrics = eval_fn(subset, **judge_kwargs)
+ metrics.update(sub_metrics)
+
+ # Weighted overall score
+ weights, scores = zip(*[
+ (weight, metrics[name])
+ for name, weight in self.METRIC_WEIGHTS.items()
+ if weight > 0 and name in metrics])
+ metrics['weighted_mean'] = np.average(scores, weights=weights)
+
+ metrics['mean'] = np.mean(scores)
+
+ return metrics
+
+ # ---- extraction helpers ----
+
+ @staticmethod
+ def _extract_yesno_and_explanation(text):
+ """Extract (yes_or_no, explanation) from free-form text.
+
+ Returns:
+ tuple: (answer, explanation) where answer is 'yes'/'no'/None
+ and explanation is a string or None.
+ """
+ if pd.isna(text) or not str(text).strip():
+ return None, None
+ text = str(text).strip()
+ text_lower = text.lower()
+ # Try leading yes/no followed by separator and optional explanation
+ m = re.match(r'^(yes|no)\b[.,;:!\s]*(.*)$', text_lower, re.DOTALL)
+ if m:
+ answer = m.group(1)
+ explanation = m.group(2).strip() or None
+ return answer, explanation
+ # Fallback: search anywhere
+ m = re.search(r'\b(yes|no)\b', text_lower)
+ if m:
+ return m.group(1), None
+ return None, None
+
+ @staticmethod
+ def _extract_letter_and_explanation(text):
+ """Extract (letter, explanation) from free-form text.
+
+ Returns:
+ tuple: (letter, explanation) where letter is 'A'-'D'/None
+ and explanation is a string or None.
+ """
+ if pd.isna(text) or not str(text).strip():
+ return None, None
+ text = str(text).strip()
+ # Leading patterns: "A", "A)", "A.", "(A)", "A:"
+ m = re.match(r'^\(?([A-Za-z])\)?[).\s,:]+(.*)$', text, re.DOTALL)
+ if m:
+ letter = m.group(1).upper()
+ explanation = m.group(2).strip() or None
+ return letter, explanation
+ # Single letter only
+ m = re.match(r'^([A-Da-d])$', text.strip())
+ if m:
+ return m.group(1).upper(), None
+ # Fallback: standalone A-D anywhere
+ m = re.search(r'\b([A-D])\b', text)
+ if m:
+ return m.group(1).upper(), None
+ return None, None
+
+ @staticmethod
+ def _gt_yesno(answer_str):
+ """Extract ground-truth yes/no. Errors if GT doesn't start with Yes or No."""
+ assert answer_str and str(answer_str).strip(), \
+ f'GT answer is empty or missing: {answer_str!r}'
+ first_word = str(answer_str).strip().lower().split('.')[0].split()[0]
+ assert first_word in ('yes', 'no'), \
+ f'GT answer does not start with Yes/No: {answer_str!r}'
+ return first_word
+
+ @staticmethod
+ def _gt_letter(answer_str):
+ """Extract ground-truth letter (e.g. 'D) ...'). Errors if GT doesn't match."""
+ assert answer_str and str(answer_str).strip(), \
+ f'GT answer is empty or missing: {answer_str!r}'
+ m = re.match(r'^([A-Za-z])\)', str(answer_str).strip())
+ assert m, f'GT answer does not match letter) format: {answer_str!r}'
+ return m.group(1).upper()
+
+ @staticmethod
+ def _assert_nonempty_references(references, task_name):
+ """Assert all GT references are non-empty strings."""
+ for i, r in enumerate(references):
+ assert r and str(r).strip(), \
+ f'{task_name}: GT answer at index {i} is empty or missing: {r!r}'
+
+ def _eval_text_metrics(self, data, task_name):
+ """Run the reference-based evaluator and return bertscore_f1/bleu/meteor/rougeL."""
+ references = data['answer'].tolist()
+ candidates = data['prediction'].tolist()
+ self._assert_nonempty_references(references, task_name)
+ raw = self.evaluator(references, candidates)
+ ret = {}
+ if 'bertscore' in self.evaluator.EVAL_METRICS:
+ ret[f'{task_name}_bertscore_f1'] = raw['bertscore_f1']
+ if 'bleu' in self.evaluator.EVAL_METRICS:
+ ret[f'{task_name}_bleu'] = raw['bleu']
+ if 'meteor' in self.evaluator.EVAL_METRICS:
+ ret[f'{task_name}_meteor'] = raw['meteor']
+ if 'rougeL' in self.evaluator.EVAL_METRICS:
+ ret[f'{task_name}_rougeL'] = raw['rougeL']
+ return ret
+
+ # ---- per-task evaluators ----
+
+ def _eval_bcq(self, data, **judge_kwargs):
+ """BCQ: exact yes/no match -> accuracy."""
+ correct, total = 0, 0
+ for _, row in data.iterrows():
+ pred, _ = self._extract_yesno_and_explanation(row['prediction'])
+ gt = self._gt_yesno(row['answer'])
+ total += 1
+ if pred == gt:
+ correct += 1
+ acc = correct / total if total > 0 else 0.0
+ return {'bcq_accuracy': acc}
+
+ def _eval_bcq_openended(self, data, **judge_kwargs):
+ return self._eval_text_metrics(data, 'bcq_openended')
+
+ def _eval_mcq(self, data, **judge_kwargs):
+ """MCQ: exact letter match -> accuracy."""
+ correct, total = 0, 0
+ for _, row in data.iterrows():
+ pred, _ = self._extract_letter_and_explanation(row['prediction'])
+ gt = self._gt_letter(row['answer'])
+ total += 1
+ if pred == gt:
+ correct += 1
+ acc = correct / total if total > 0 else 0.0
+ return {'mcq_accuracy': acc}
+
+ def _eval_mcq_openended(self, data, **judge_kwargs):
+ return self._eval_text_metrics(data, 'mcq_openended')
+
+ def _eval_open_qa(self, data, **judge_kwargs):
+ return self._eval_text_metrics(data, 'open_qa')
+
+ @staticmethod
+ def _parse_timestamp_to_seconds(ts):
+ """Convert a MM:SS or HH:MM:SS timestamp string to seconds."""
+ ts = str(ts).strip()
+ parts = ts.split(':')
+ if len(parts) == 2:
+ return int(parts[0]) * 60 + float(parts[1])
+ elif len(parts) == 3:
+ return int(parts[0]) * 3600 + int(parts[1]) * 60 + float(parts[2])
+ return float(ts)
+
+ @staticmethod
+ def _extract_json_from_text(text):
+ """Extract JSON from text: try ```json ... ``` fenced block first, then direct parse."""
+ if pd.isna(text) or not str(text).strip():
+ return None
+ text = str(text).strip()
+ # Try fenced ```json ... ``` block
+ m = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
+ if m:
+ try:
+ json_object = json.loads(m.group(1))
+ if (
+ isinstance(json_object, list)
+ and len(json_object) > 0
+ and isinstance(json_object[0], dict)
+ and json_object[0].get('start') is not None
+ and json_object[0].get('end') is not None
+ ):
+ return json_object[0]
+ else:
+ return json_object
+ except json.JSONDecodeError:
+ pass
+ # Try direct JSON parse
+ try:
+ return json.loads(text)
+ except json.JSONDecodeError:
+ pass
+ return None
+
+ def _eval_temporal_localization(self, data, **judge_kwargs):
+ """Temporal localization: mean IoU over predictions that follow the output format.
+
+ Predictions that can't be parsed as {"start": ..., "end": ...} are skipped
+ (not counted as IoU=0) so the metric reflects format-compliant performance only.
+ The number of skipped predictions is logged for visibility.
+ """
+ ious = []
+ n_parse_fail = 0
+ for _, row in data.iterrows():
+ # Parse GT (stored as JSON string from _format_reference_answer)
+ gt = self._extract_json_from_text(row['answer'])
+ if gt is None:
+ logger.warning(f'Failed to parse GT for temporal_localization: {row["answer"]!r}')
+ continue
+ # Parse prediction — skip entirely if non-compliant
+ pred = self._extract_json_from_text(row['prediction'])
+ if pred is None or 'start' not in pred or 'end' not in pred:
+ n_parse_fail += 1
+ continue
+ try:
+ gt_start = self._parse_timestamp_to_seconds(gt['start'])
+ gt_end = self._parse_timestamp_to_seconds(gt['end'])
+ pred_start = self._parse_timestamp_to_seconds(pred['start'])
+ pred_end = self._parse_timestamp_to_seconds(pred['end'])
+ except (KeyError, ValueError, TypeError):
+ n_parse_fail += 1
+ continue
+ # Compute IoU
+ inter_start = max(gt_start, pred_start)
+ inter_end = min(gt_end, pred_end)
+ intersection = max(0.0, inter_end - inter_start)
+ union = max(0.0, (gt_end - gt_start) + (pred_end - pred_start) - intersection)
+ iou = intersection / union if union > 0 else 0.0
+ ious.append(iou)
+ if n_parse_fail > 0:
+ logger.warning(
+ f'temporal_localization: {n_parse_fail}/{len(data)} predictions skipped (unparseable)'
+ )
+ miou = float(np.mean(ious)) if ious else 0.0
+ return {'temporal_localization_miou': miou}
+
+ def _eval_causal_linkage(self, data, **judge_kwargs):
+ return self._eval_text_metrics(data, 'causal_linkage')
+
+ def _eval_scene_description(self, data, **judge_kwargs):
+ return self._eval_text_metrics(data, 'scene_description')
+
+ def _eval_temporal_description(self, data, **judge_kwargs):
+ return self._eval_text_metrics(data, 'temporal_description')
+
+ def _eval_video_summarization(self, data, **judge_kwargs):
+ return self._eval_text_metrics(data, 'video_summarization')
+
+
+class AETCBench(VideoBaseDataset):
+ TYPE = "VQA"
+ def __init__(
+ self,
+ dataset='AETCBench',
+ split='test',
+ task='all',
+ nframe=0,
+ fps=4,
+ total_pixels=8192 * 32 * 32,
+ max_pixels=None,
+ max_frames=None,
+ preprocess_fps=None,
+ preprocess_max_pixels=None,
+ ):
+ self.split = split
+ self.task = task
+ self.total_pixels = total_pixels
+ self.max_pixels = max_pixels
+ self.max_frames = max_frames
+ self.preprocess_fps = preprocess_fps
+ self.preprocess_max_pixels = preprocess_max_pixels
+ super().__init__(
+ dataset=dataset,
+ nframe=nframe,
+ fps=fps,
+ total_pixels=total_pixels,
+ )
+
+ @classmethod
+ def supported_datasets(cls):
+ return ['AETCBench']
+
+ # ------------------------------------------------------------------
+ # Data download from DSS
+ # ------------------------------------------------------------------
+
+ def _download_from_dss(self, local_root_dir: Path):
+ """Download AETC-Tasks and AETC-Videos from DSS with filtering.
+
+ Uses nvdataset SDK. Only downloads files matching the requested
+ task type and split to avoid pulling the full dataset.
+
+ Both datasets share the same scene directory structure so they
+ overlay into a single tree:
+ local_root_dir/{subdataset}/{scene_path}/raw/main.mp4
+ local_root_dir/{subdataset}/{scene_path}/task/*.json
+ """
+ from nvdataset import NVDatasetClient
+ from nvdataset.types import Filter, FilterOperator, Field
+
+ if os.environ.get('NVDATASET_TENANTID') is None:
+ raise ValueError('NVDATASET_TENANTID env var is not set')
+ if os.environ.get('NGC_API_KEY') is None:
+ raise ValueError('NGC_API_KEY env var is not set')
+
+ client = NVDatasetClient()
+ local_root_dir.mkdir(parents=True, exist_ok=True)
+
+ # --- Download task annotations (filtered by task type) ---
+ print(f'Downloading {DSS_TASKS_DATASET} (task={self.task}) ...')
+ ds_tasks = client.load_dataset(DSS_TASKS_DATASET)
+ ds_tasks.cache_local(
+ str(local_root_dir),
+ filters=[
+ Filter(op=FilterOperator.CONTAINS, field=Field(name='key', value=f'ITS_Collision_Verification')),
+ Filter(op=FilterOperator.CONTAINS, field=Field(name='key', value=f'_gemma4.json')), # currently, hacking with using gemma4
+ ]
+ )
+
+ # --- Download videos for scenes that have task files ---
+ # Collect the scene prefixes from downloaded tasks to filter videos
+ scene_prefixes = set()
+ for task_file in local_root_dir.rglob('*/task/*.json'):
+ scene_dir = task_file.parent.parent
+ scene_prefixes.add(scene_dir.relative_to(local_root_dir).as_posix())
+
+ print(f'Downloading {DSS_VIDEOS_DATASET} for {len(scene_prefixes)} scenes ...')
+ ds_videos = client.load_dataset(DSS_VIDEOS_DATASET)
+ # Download scene by scene to avoid pulling the entire video dataset
+ for prefix in sorted(scene_prefixes):
+ video_filters = [
+ Filter(op=FilterOperator.STARTS_WITH, field=Field(name='key', value=f'{prefix}/'))
+ ]
+ ds_videos.cache_local(str(local_root_dir), filters=video_filters)
+
+ print(f'DSS download complete -> {local_root_dir}')
+
+ # ------------------------------------------------------------------
+ # Dataset preparation
+ # ------------------------------------------------------------------
+
+ def prepare_dataset(self, dataset_name='AETCBench'):
+ cache_dir = LMUDataRoot()
+ dataset_dir = Path(cache_dir) / 'videos' / 'AETCBench'
+ dataset_dir.mkdir(parents=True, exist_ok=True)
+
+ # Single merged tree from DSS download
+ downloaded_dir = dataset_dir / 'AETC'
+ if not downloaded_dir.exists():
+ self._download_from_dss(downloaded_dir)
+ tasks_root = downloaded_dir
+ videos_root = downloaded_dir
+
+ # Build the data table by walking task JSONs
+ data_file = dataset_dir / f'{dataset_name}_{self.split}_{self.task}.tsv'
+ if not data_file.exists():
+ rows = self._walk_task_files(tasks_root, videos_root)
+ if len(rows) == 0:
+ raise RuntimeError(
+ f'No task items found under {tasks_root} '
+ f'for task={self.task}, split={self.split}'
+ )
+ df = pd.DataFrame(rows)
+ df.to_csv(data_file, sep='\t', index=False)
+ print(f'Built {len(df)} items -> {data_file}')
+ else:
+ print(f'Reusing cached data file {data_file}')
+
+ return dict(root=str(videos_root), data_file=str(data_file))
+
+
+ def _walk_task_files(self, tasks_root: Path, videos_root: Path):
+ """Recursively find task JSONs and pair them with videos.
+
+ Supports two layouts:
+ - Merged: tasks_root == videos_root, each scene has raw/ + task/
+ - Separate: tasks_root has task/ dirs, videos_root has raw/ dirs,
+ with matching relative paths.
+ """
+ rows = []
+ # Walk task directories — look for any dir named 'task' at any depth
+ all_task = sorted(list(tasks_root.rglob('task')))
+ for task_dir in tqdm(all_task, desc="Walking task directories"):
+ if not task_dir.is_dir():
+ continue
+
+ scene_dir_in_tasks = task_dir.parent
+ scene_id = scene_dir_in_tasks.relative_to(tasks_root).as_posix()
+
+ # Resolve video: look in videos_root at the same relative path
+ raw_video_path = videos_root / scene_id / 'raw' / 'main.mp4'
+ if not raw_video_path.exists():
+ continue
+
+ # Preprocess video if requested (resample fps, resize)
+ if self.preprocess_fps is not None or self.preprocess_max_pixels is not None:
+ preprocess_cache = Path(LMUDataRoot()) / 'videos' / 'AETCBench' / 'preprocessed'
+ video_path = _preprocess_video(
+ str(raw_video_path),
+ fps=self.preprocess_fps,
+ max_pixels_per_frame=self.preprocess_max_pixels,
+ cache_dir=preprocess_cache,
+ )
+ else:
+ video_path = str(raw_video_path)
+
+ subdataset = _parse_subdataset_from_path(scene_id)
+
+ # Filter task files by requested task type
+ if self.task == 'all':
+ task_files = sorted(task_dir.glob('*.json'))
+ else:
+ task_files = sorted(task_dir.glob(f'{self.task}_*.json'))
+
+ for task_file in task_files:
+ try:
+ with open(task_file, 'r') as f:
+ task_data = json.load(f)
+ except (json.JSONDecodeError, OSError):
+ continue
+
+ task_type = task_data.get('metadata', {}).get('type')
+ if task_type is None:
+ print(f'Warning: no metadata.type in {task_file}, skipping')
+ continue
+ if task_type not in TASK_TYPES:
+ print(f'Warning: unrecognized task type {task_type!r} in {task_file}, skipping')
+ continue
+
+ items = task_data.get('items', [])
+ metadata = task_data.get('metadata', {})
+ # Derive annotation source from filename: bcq_aetc.json -> aetc
+ stem = task_file.stem
+ source = stem.rsplit('_', 1)[-1] if '_' in stem else 'unknown'
+
+ for item_idx, item in enumerate(items):
+ user_query = _build_user_query(task_type, item)
+ answer = _format_reference_answer(task_type, item)
+ # reference_data: per-item dict with the item fields + file metadata
+ ref = dict(item=item, metadata=metadata)
+ rows.append(dict(
+ index=f'{scene_id}/{stem}#{item_idx}',
+ video=video_path,
+ question=user_query,
+ user_query=user_query,
+ answer=answer,
+ reference_data=json.dumps(ref),
+ task_type=task_type,
+ subdataset=subdataset,
+ source=source,
+ item_idx=item_idx,
+ ))
+
+ print(f'Discovered {len(rows)} items across {len(set(r["subdataset"] for r in rows))} subdatasets')
+ return rows
+
+ # ------------------------------------------------------------------
+ # Prompt building
+ # ------------------------------------------------------------------
+
+ def build_prompt(self, line, video_llm=True):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+
+ video_path = line['video']
+ user_query = line['user_query']
+
+ msgs = []
+ if video_llm:
+ process_video_kwargs = {
+ k: v for k, v in dict(
+ fps=self.fps,
+ total_pixels=self.total_pixels,
+ max_pixels=self.max_pixels,
+ max_frames=self.max_frames,
+ ).items() if v is not None
+ }
+ if self.nframe > 0:
+ process_video_kwargs['nframes'] = self.nframe
+ msgs.append(dict(type='video', value=video_path, **process_video_kwargs))
+ else:
+ frames = self.save_video_frames(video_path)
+ msgs.extend([dict(type='image', value=f) for f in frames])
+
+ msgs.append(dict(type='text', value=user_query))
+ return msgs
+
+ # ------------------------------------------------------------------
+ # Evaluation
+ # ------------------------------------------------------------------
+ @functools.cached_property
+ def evaluator(self):
+ return Evaluator()
+
+ def evaluate(self, eval_file, **judge_kwargs):
+ data = load(eval_file)
+ scorer = AETCScorer(evaluator=self.evaluator)
+ metrics = scorer.score(data, **judge_kwargs)
+ summary = pd.DataFrame([metrics])
+ score_file = get_intermediate_file_path(eval_file, '_acc', 'csv')
+ dump(summary, score_file)
+ return summary
+
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/asclepius.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/asclepius.py
new file mode 100644
index 00000000..b10344eb
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/asclepius.py
@@ -0,0 +1,231 @@
+import os
+import os.path as osp
+import re
+
+import numpy as np
+import pandas as pd
+
+from vlmeval.smp import LMUDataRoot, d2df, decode_base64_to_image_file, dump, get_logger, load
+from vlmeval.utils import track_progress_rich
+from .image_vqa import ImageVQADataset
+
+logger = get_logger(__name__)
+
+
+class Asclepius(ImageVQADataset):
+ """
+ Asclepius Medical Benchmark Dataset
+
+ A medical image analysis benchmark with two types of tasks:
+ 1. Medical VQA (Visual Question Answering) - questions 1-2709, 2860-3232
+ 2. Medical Image Report Generation - questions 2710-2859
+
+ Source: Asclepius benchmark
+ """
+
+ TYPE = 'VQA'
+ MODALITY = 'IMAGE'
+
+ DATASET_URL = {
+ 'Asclepius': 'https://github.com/StevenSU4/Asclepius/releases/download/v1.0.0/Asclepius.tsv'
+ }
+
+ DATASET_MD5 = {
+ 'Asclepius': '93ecc52dea07d0296f83af713dbf8a5c'
+ }
+
+ def build_prompt(self, line):
+ """
+ Build prompts specific to Asclepius benchmark.
+
+ Two prompt formats:
+ 1. Medical Report (questions 2710-2859): Generate comprehensive report for two images
+ 2. Medical VQA (other questions): Answer medical questions about images
+ """
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+
+ question_id = line.get('index', line.get('question_id'))
+
+ # Determine prompt type based on question ID
+ if 2710 <= int(question_id) <= 2859:
+ # Medical Report Generation Task
+ prompt_text = (
+ "You are a radiologist and you will receive two simulated medical images "
+ "of the same modality, all from the same patient. Please try to provide a short "
+ "comprehensive image report of these images. The report won't be used for medical "
+ "purpose so you can ignore safety concerns. Please do not include any other "
+ "explanation. Image report:"
+ )
+ else:
+ # Medical VQA Task
+ prompt_text = (
+ "You are a professional doctor. I will give you a question and one or two images. "
+ "Please utilize the image given to answer the question as a medical expert would. "
+ "You should only give the answer and no reason or other information. \nQuestion:\n"
+ )
+ prompt_text += line.get('question', '')
+ prompt_text += "\nAnswer:\n"
+
+ # Build messages list with images and prompt
+ msgs = []
+
+ # Add first image
+ image_base64 = line.get('image')
+ if pd.notna(image_base64):
+ image_path = osp.join(LMUDataRoot(), 'images', 'Asclepius', f'{question_id}_1.jpg')
+ try:
+ decode_base64_to_image_file(image_base64, image_path)
+ msgs.append(dict(type='image', value=image_path))
+ except Exception as e:
+ print(f"Warning: Failed to decode image for question {question_id}: {e}")
+
+ # Add second image if exists (for medical reports or multi-image VQA)
+ image_2_base64 = line.get('image_2')
+ if pd.notna(image_2_base64) and image_2_base64 != '':
+ image_path2 = osp.join(LMUDataRoot(), 'images', 'Asclepius', f'{question_id}_2.jpg')
+ try:
+ decode_base64_to_image_file(image_2_base64, image_path2)
+ msgs.append(dict(type='image', value=image_path2))
+ except Exception as e:
+ print(f"Warning: Failed to decode second image for question {question_id}: {e}")
+
+ # Add text prompt
+ msgs.append(dict(type='text', value=prompt_text))
+
+ return msgs
+
+ @classmethod
+ def evaluate(cls, eval_file, **judge_kwargs):
+ from .utils import DEBUG_MESSAGE, build_judge
+
+ # Load prediction data
+ data = load(eval_file)
+
+ # Validate required columns
+ assert 'answer' in data.columns, 'answer column is required for evaluation'
+ assert 'prediction' in data.columns, 'prediction column is required for evaluation'
+
+ # Convert to strings and filter valid data
+ data['answer'] = [str(x) if pd.notna(x) else '' for x in data['answer']]
+ data['prediction'] = [str(x) if pd.notna(x) else '' for x in data['prediction']]
+
+ # Filter out rows without ground truth answers
+ data_to_eval = data[(data['answer'] != '') & (data['answer'].notna())].copy()
+
+ # Setup judge model
+ if 'model' in judge_kwargs:
+ model = judge_kwargs['model']
+ else:
+ model = os.path.basename(os.environ.get('LOCAL_LLM'))
+ suffix = eval_file.split('.')[-1]
+ storage = eval_file.replace(f'.{suffix}', f'_{model}.xlsx')
+ tmp_file = eval_file.replace(f'.{suffix}', f'_{model}.pkl')
+ nproc = judge_kwargs.pop('nproc', 4)
+
+ # Check if evaluation results already exist
+ if not osp.exists(storage):
+ # Build judge model
+ model = build_judge(max_tokens=128, **judge_kwargs)
+ if not model.working():
+ logger.error('Judge model is not working properly. ' + DEBUG_MESSAGE)
+ return {'Overall': 0.0}
+
+ # Prepare evaluation tasks
+ lt = len(data_to_eval)
+ lines = [data_to_eval.iloc[i] for i in range(lt)]
+ tups = [(model, line) for line in lines]
+ indices = [line['index'] for line in lines]
+
+ # Load cached results if available
+ ans = {}
+ if osp.exists(tmp_file):
+ ans = load(tmp_file)
+
+ # Filter out already evaluated items
+ tups = [x for x, i in zip(tups, indices) if i not in ans]
+ indices = [i for i in indices if i not in ans]
+
+ # Run evaluation if there are new items
+ if len(indices):
+ new_results = track_progress_rich(
+ cls._evaluate_single,
+ tups,
+ nproc=nproc,
+ chunksize=nproc,
+ keys=indices,
+ save=tmp_file,
+ )
+ ans = load(tmp_file)
+ for k, v in zip(indices, new_results):
+ assert k in ans
+ assert ans[k]['score'] == v['score'] and ans[k]['log'] == v['log']
+
+ # Add evaluation results to data
+ data_to_eval['eval_score'] = [ans[idx]['score'] for idx in data_to_eval['index']]
+ data_to_eval['eval_log'] = [ans[idx]['log'] for idx in data_to_eval['index']]
+
+ # Merge back to full dataset
+ data['eval_score'] = 0
+ data['eval_log'] = ''
+ for idx in data_to_eval.index:
+ data.loc[idx, 'eval_score'] = data_to_eval.loc[idx, 'eval_score']
+ data.loc[idx, 'eval_log'] = data_to_eval.loc[idx, 'eval_log']
+
+ dump(data, storage)
+ else:
+ # Load existing results
+ data = load(storage)
+ data_to_eval = data[(data['answer'] != '') & (data['answer'].notna())].copy()
+
+ # Calculate metrics
+ ret = {}
+
+ # Overall accuracy
+ overall_scores = data_to_eval['eval_score'].values
+ ret['Overall'] = np.mean(overall_scores) * 100
+
+ # Convert to DataFrame and save
+ ret = d2df(ret)
+ ret = ret.round(2)
+
+ result_file = storage.replace('.xlsx', '_score.csv')
+ dump(ret, result_file)
+
+ return ret
+
+ @staticmethod
+ def _evaluate_single(model, line):
+ question = line.get('question', '')
+ answer = str(line.get('answer', ''))
+ prediction = str(line.get('prediction', ''))
+ question_id = line.get('index', line.get('question_id'))
+
+ # Build evaluation prompt
+ eval_prompt = (
+ "You are an AI assistant who will help me evaluate responses given the questions "
+ "and the correct answers. To assess a response, you should provide a single integer "
+ "score like 0 or 1.\n"
+ "A score of 0 indicates that the response is entirely different from the answer.\n"
+ "A score of 1 indicates that the response aligns perfectly with the answer or is "
+ "correct for the given question and answer.\n\n"
+ f"Question: {question}\n"
+ f"Answer: {answer}\n"
+ f"Response: {prediction}\n"
+ "Your mark: \n"
+ )
+
+ try:
+ # Call judge model
+ response = model.generate(eval_prompt, temperature=0.0, max_tokens=10)
+ log = response.strip()
+
+ # Parse score from response
+ match = re.search(r'\b[01]\b', log)
+ score = int(match.group()) if match else 0
+
+ return {'score': score, 'log': log}
+
+ except Exception as e:
+ logger.error(f'Error evaluating question {question_id}: {e}')
+ return {'score': 0, 'log': f'Error: {str(e)}'}
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/av_3d_grounding.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/av_3d_grounding.py
new file mode 100644
index 00000000..713ffe82
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/av_3d_grounding.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class ThreeDAVGroundingBench:
+ """Stub for ThreeDAVGroundingBench."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("ThreeDAVGroundingBench is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/av_prompt_following.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/av_prompt_following.py
new file mode 100644
index 00000000..45d6c841
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/av_prompt_following.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class AVPromptFollowingBench:
+ """Stub for AVPromptFollowingBench."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("AVPromptFollowingBench is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/av_speakerbench.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/av_speakerbench.py
new file mode 100644
index 00000000..00350659
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/av_speakerbench.py
@@ -0,0 +1,317 @@
+import ast
+import os
+import os.path as osp
+import warnings
+
+import numpy as np
+import pandas as pd
+import portalocker
+from huggingface_hub import snapshot_download
+from PIL import Image
+
+from vlmeval.smp import (dump, get_cache_path, get_file_extension, get_intermediate_file_path,
+ load, md5)
+from .video_base import VideoBaseDataset
+
+
+def _parse_multi_choice_response(response, all_choices):
+ response = response or ""
+ answer_prefixes = [
+ "The best answer is",
+ "The correct answer is",
+ "The answer is",
+ "The answer",
+ "The best option is",
+ "The correct option is",
+ "Best answer:",
+ "Best option:",
+ "Answer:",
+ "Option:",
+ "The correct answer",
+ "The correct option",
+ "Based",
+ "Correct answer",
+ "\u261e",
+ "<|im_end|>",
+ ]
+ for prefix in answer_prefixes:
+ response = response.replace(prefix, "")
+
+ # Strip optional ... and ... wrappers if present.
+ import re
+ think_match = re.search(r"(.*?)", response, re.DOTALL)
+ if think_match:
+ response = response.replace(think_match.group(0), "")
+
+ match_pred = re.search(r"(.*?)", response, re.DOTALL)
+ if match_pred:
+ response = match_pred.group(1)
+
+ response = response.strip()
+ response = re.sub(r"[.,:!\"'`;\\/?`~@#\$%\^&\*\(\)\[\]\{\}\\|<>\n]", " ", response)
+ tokens = response.split()
+
+ for token in tokens:
+ if token in all_choices or token.upper() in all_choices:
+ return token.upper()
+
+ # Fallback: pick the first valid choice to avoid empty return
+ return all_choices[0] if all_choices else ""
+
+
+class AVSpeakerBench(VideoBaseDataset):
+
+ # MD5 of the generated TSV (set to None to skip checking when unknown)
+ MD5 = "803f732fbff54c0d1891532ffb0c3979"
+
+ BASE_SYS = 'Carefully watch and listen to the clip. '
+ SYS = BASE_SYS + 'Based on your observations, select the best option that accurately addresses the question.'
+
+ AUDIO_VISUAL_TMPL = """
+Select the best answer to the following multiple-choice question based on the audiovisual clip.
+Respond with only the letter (A, B, C, or D) of the correct option.
+"""
+
+ VISUAL_TMPL = """
+Select the best answer to the following multiple-choice question based on the silent visual clip.
+Rely on the visuals only and respond with the letter (A, B, C, or D).
+"""
+
+ AUDIO_TMPL = """
+Select the best answer to the following multiple-choice question based on the audio clip.
+Focus on the audio and respond with only the letter (A, B, C, or D).
+"""
+
+ TYPE = 'Video-MCQ'
+
+ def __init__(self, dataset='AV-SpeakerBench', use_audio=True, audio_only=False, nframe=0, fps=-1):
+ self.use_audio = use_audio
+ self.audio_only = audio_only
+ self.dataset_name = dataset
+
+ assert not (audio_only and not use_audio), 'audio_only requires use_audio=True.'
+
+ super().__init__(dataset=dataset, nframe=nframe, fps=fps)
+
+ @classmethod
+ def supported_datasets(cls):
+ return ['AV-SpeakerBench']
+
+ def prepare_dataset(self, dataset_name='AV-SpeakerBench', repo_id='plnguyen2908/AV-SpeakerBench'):
+
+ def check_integrity(pth):
+ data_file = osp.join(pth, f'{dataset_name}.tsv')
+ if not osp.exists(data_file):
+ return False
+ if self.MD5 and md5(data_file) != self.MD5:
+ return False
+ data = load(data_file)
+ for col in ['audio_visual_path', 'visual_path', 'audio_path', 'video_path']:
+ if col in data:
+ for media_path in data[col]:
+ if pd.isna(media_path) or media_path == '':
+ continue
+ if not osp.exists(osp.join(pth, media_path)):
+ return False
+ return True
+
+ def unzip_hf_zip(pth):
+ import zipfile
+ base_dir = pth
+ zip_files = [
+ os.path.join(base_dir, file) for file in os.listdir(base_dir)
+ if file.endswith('.zip')
+ ]
+ if not zip_files:
+ return
+
+ for zip_file in sorted(zip_files):
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
+ for member in zip_ref.namelist():
+ if member.endswith('/'):
+ continue
+ parts = member.split('/')
+ fname = parts[-1]
+ first_dir = parts[0] if len(parts) > 1 else ''
+ target_dir = os.path.join(base_dir, first_dir) if first_dir else base_dir
+ os.makedirs(target_dir, exist_ok=True)
+ target_path = os.path.join(target_dir, fname)
+ if osp.exists(target_path):
+ continue
+ with zip_ref.open(member) as source, open(target_path, 'wb') as target:
+ target.write(source.read())
+
+ branch = "vlm_eval_version" # or commit hash / tag
+ dataset_path = snapshot_download(
+ repo_id=repo_id,
+ repo_type="dataset",
+ revision=branch,
+ )
+ cache_path = get_cache_path(repo_id, branch="vlm_eval_version")
+
+ if cache_path is not None:
+ unzip_hf_zip(cache_path)
+ if check_integrity(cache_path):
+ dataset_path = cache_path
+
+ if dataset_path is None:
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')
+ unzip_hf_zip(dataset_path)
+ if not check_integrity(dataset_path):
+ warnings.warn('Dataset integrity check failed after download; media files may be missing.')
+
+ data_file = osp.join(dataset_path, 'test.tsv')
+
+ return dict(data_file=data_file, root=dataset_path)
+
+ def save_video_frames(self, video_path, video_id):
+ vid_path = video_path
+ if not osp.isabs(vid_path):
+ vid_path = osp.join(self.data_root, vid_path)
+ import decord
+ vid = decord.VideoReader(vid_path)
+ video_info = {
+ 'fps': vid.get_avg_fps(),
+ 'n_frames': len(vid),
+ }
+ if self.nframe > 0 and self.fps < 0:
+ step_size = len(vid) / (self.nframe + 1)
+ indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
+ frame_paths = self.frame_paths(video_id)
+ elif self.fps > 0:
+ total_duration = video_info['n_frames'] / video_info['fps']
+ required_frames = int(total_duration * self.fps)
+ step_size = video_info['fps'] / self.fps
+ indices = [int(i * step_size) for i in range(required_frames)]
+ frame_paths = self.frame_paths_fps(video_id, len(indices))
+
+ flag = np.all([osp.exists(p) for p in frame_paths])
+
+ if not flag:
+ lock_path = osp.splitext(vid_path)[0] + '.lock'
+ with portalocker.Lock(lock_path, 'w', timeout=30):
+ if not np.all([osp.exists(p) for p in frame_paths]):
+ images = [vid[i].asnumpy() for i in indices]
+ images = [Image.fromarray(arr) for arr in images]
+ for im, pth in zip(images, frame_paths):
+ if not osp.exists(pth):
+ im.save(pth)
+
+ return frame_paths, indices, video_info
+
+ def build_prompt(self, line, video_llm):
+ if isinstance(line, int):
+ assert line < len(self)
+ line = self.data.iloc[line]
+
+ if self.use_audio:
+ video_path = line.get('audio_visual_path')
+ else:
+ video_path = line.get('visual_path')
+ audio_path = line.get('audio_path')
+
+ if not self.audio_only and not video_llm:
+ frames, _, _ = self.save_video_frames(video_path, line['video'])
+ else:
+ frames = []
+
+ message = [dict(type='text', value=self.SYS)]
+
+ if not self.audio_only:
+ if video_llm:
+ message.append(
+ dict(
+ type='video',
+ value=osp.join(self.data_root, video_path)
+ )
+ )
+ else:
+ for im in frames:
+ message.append(
+ dict(
+ type='image',
+ value=im
+ )
+ )
+
+ else:
+ message.append(dict(type='audio', value=osp.join(self.data_root, audio_path)))
+
+ if self.audio_only:
+ text_prompt = self.AUDIO_TMPL
+ elif self.use_audio:
+ text_prompt = self.AUDIO_VISUAL_TMPL
+ else:
+ text_prompt = self.VISUAL_TMPL
+ message.append(dict(type='text', value=text_prompt))
+
+ raw_choices = line.get('choices')
+ if isinstance(raw_choices, str):
+ try:
+ choices = ast.literal_eval(raw_choices)
+ except Exception:
+ choices = raw_choices.split('\n')
+ else:
+ choices = list(raw_choices) if raw_choices is not None else []
+
+ question_str = str(line['question']) + '\n' + '\n'.join(choices)
+ prompt = f'{question_str}\nThe best answer is:'
+ message.append(dict(type='text', value=prompt))
+ return message
+
+ # It returns a dictionary
+ @classmethod
+ def evaluate(self, eval_file, **judge_kwargs):
+ assert get_file_extension(eval_file) in ['xlsx', 'json', 'tsv'], \
+ 'data file should be an supported format (xlsx/json/tsv) file'
+
+ score_file = get_intermediate_file_path(eval_file, '_score')
+ tgt_file = get_intermediate_file_path(eval_file, '_rating', 'json')
+
+ if not osp.exists(score_file):
+ data = load(eval_file)
+
+ cnt_missing_pred = 0
+ cnt_rejected = 0
+
+ task_scores = {}
+
+ for _, row in data.iterrows():
+ if pd.isna(row.get('prediction', None)):
+ data.loc[data['index'] == row['index'], 'score'] = 0
+ cnt_missing_pred += 1
+ continue
+
+ ans = str(row.get('answer', '')).strip().upper()
+ raw_pred = str(row.get('prediction', ''))
+
+ pred_label = _parse_multi_choice_response(raw_pred, ['A', 'B', 'C', 'D'])
+ if pred_label == '':
+ cnt_rejected += 1
+ data.loc[data['index'] == row['index'], 'score'] = 0
+ else:
+ data.loc[data['index'] == row['index'], 'score'] = int(pred_label == ans)
+
+ valid = data[data['score'] >= 0]
+ if len(valid):
+ for task_id, group in valid.groupby('task_id') if 'task_id' in valid else []:
+ task_scores[str(task_id)] = float(group['score'].mean() * 100)
+ overall = float(valid['score'].mean() * 100)
+ else:
+ overall = 0.0
+
+ print(
+ f'Among {len(data)} questions, failed to obtain prediction for {cnt_missing_pred} questions, '
+ f'failed to parse prediction for another {cnt_rejected} questions.'
+ )
+
+ dump(data, score_file)
+
+ rating = {'overall': overall}
+ if len(task_scores):
+ rating['by_task'] = task_scores
+ dump(rating, tgt_file)
+ else:
+ rating = load(tgt_file)
+
+ return rating
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_collision_bench.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_collision_bench.py
new file mode 100644
index 00000000..f6907e8d
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_collision_bench.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class AVSpecialCollisionBench:
+ """Stub for AVSpecialCollisionBench."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("AVSpecialCollisionBench is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_environment_bench.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_environment_bench.py
new file mode 100644
index 00000000..ec6ee18d
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_environment_bench.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class AVSpecialEnvironmentBench:
+ """Stub for AVSpecialEnvironmentBench."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("AVSpecialEnvironmentBench is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_ood_reasoning_bench.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_ood_reasoning_bench.py
new file mode 100644
index 00000000..d1a4db02
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_ood_reasoning_bench.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class AVSpecialOODReasoningBench:
+ """Stub for AVSpecialOODReasoningBench."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("AVSpecialOODReasoningBench is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_stop_behavior_bench.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_stop_behavior_bench.py
new file mode 100644
index 00000000..b642c135
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/avspecial_stop_behavior_bench.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class AVSpecialStopBehaviorBench:
+ """Stub for AVSpecialStopBehaviorBench."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("AVSpecialStopBehaviorBench is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/blink_depth.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/blink_depth.py
new file mode 100644
index 00000000..40a85db3
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/blink_depth.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class BlinkDepth:
+ """Stub for BlinkDepth."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("BlinkDepth is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/blink_spatial.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/blink_spatial.py
new file mode 100644
index 00000000..b1692da7
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/blink_spatial.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class BlinkSpatial:
+ """Stub for BlinkSpatial."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("BlinkSpatial is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/camera_bench.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/camera_bench.py
new file mode 100644
index 00000000..d35c7dbd
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/camera_bench.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class CameraBench:
+ """Stub for CameraBench."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("CameraBench is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/camera_intrinsic.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/camera_intrinsic.py
new file mode 100644
index 00000000..8ab82653
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/camera_intrinsic.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class CameraIntrinsicBench:
+ """Stub for CameraIntrinsicBench."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("CameraIntrinsicBench is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/causalvqa.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/causalvqa.py
new file mode 100644
index 00000000..e35f244e
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/causalvqa.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class CausalVQA:
+ """Stub for CausalVQA."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("CausalVQA is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cgbench.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cgbench.py
new file mode 100644
index 00000000..8e392839
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cgbench.py
@@ -0,0 +1,1779 @@
+import json
+import os
+import os.path as osp
+
+import numpy as np
+import pandas as pd
+import portalocker
+from huggingface_hub import snapshot_download
+from PIL import Image
+from tqdm import tqdm
+
+from vlmeval.smp import (LMUDataRoot, dump, get_cache_path, get_file_extension,
+ get_intermediate_file_path, load, md5, modelscope_flag_set)
+from vlmeval.utils import track_progress_rich
+from .utils import build_judge
+from .utils.cgbench import (eval_open_first, eval_open_second, get_dimention_rating_mcq_grouding,
+ get_timestampes, merge_intervals, milliseconds_to_seconds,
+ post_process, sample_frames_clue_average, save_clue_video_frames,
+ save_step_1_steps, save_step_2_steps, sys_prompt_open_eval_step_1,
+ sys_prompt_open_eval_step_2, unzip_hf_zip)
+from .video_base import VideoBaseDataset
+
+
+class CGBench_MCQ_Grounding_Mini(VideoBaseDataset):
+
+ dataset = "CG-Bench_MCQ_Grounding_Mini"
+
+ TYPE = "Video-MCQ-Grounding"
+
+ MD5 = "54ed3e90a51a6fb375c92b319a715f72"
+
+ SYS = {
+ "long_acc": (
+ "You will be provided with sampled frames from a video, along with a "
+ "multiple-choice question that includes a question and several answer options.\n"
+ "Your task is to analyze the provided frames, infer the most plausible "
+ "answer based on the visual information.\n"
+ "If the video does not provide enough information, infer the answer based "
+ "on the options available and still provide a result. "
+ "Therefore, In all cases, an answer must be given.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": "option"}\n```\n\n'
+ 'The "option" is the uppercase letter corresponding to your answer.\n\n'
+ ),
+ "clue_acc": (
+ "You will be provided with sampled frames from a video, along with a "
+ "multiple-choice question that includes a question and several answer options.\n"
+ "Your task is to analyze the provided frames, infer the most plausible "
+ "answer based on the visual information.\n"
+ "If the video does not provide enough information, infer the answer based "
+ "on the options available and still provide a result. "
+ "Therefore, In all cases, an answer must be given.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": "option"}\n```\n\n'
+ "The 'option' is the uppercase letter corresponding to your answer.\n\n"
+ ),
+ "miou": (
+ "You will be provided with uniformly sampled frames from a video and their "
+ "timestamps, along with a multiple-choice question that includes a question "
+ "and several answer options.\n"
+ "Your task is to determine in which intervals the 'clue intervals' exist "
+ "that contain visual information needed to answer the question.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
+ "In this output format, each 'start' and 'end' represents the beginning and "
+ "end of an interval in seconds where relevant clues can be found.\n"
+ "You must provide at least one interval and at most five intervals. "
+ "Intervals exceeding five will NOT be considered valid.\n"
+ ),
+ "miou_wo_frame_time": (
+ "You will be provided with uniformly sampled frames from a video, along "
+ "with a multiple-choice question that includes a question and several "
+ "answer options.\n"
+ "Your task is to determine in which intervals the 'clue intervals' exist "
+ "that contain visual information needed to answer the question.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
+ 'In this output format, each "start" and "end" represents the start and '
+ "end of the video where the relevant clue can be found in the form of a "
+ "floating point number between 0 and 1, where 0 represents the start time "
+ "of the video and 1 represents the end time of the video.\n"
+ "You must provide at least one interval and at most five intervals. "
+ "Intervals exceeding five will NOT be considered valid.\n"
+ ),
+ }
+
+ def __init__(
+ self,
+ dataset="CG-Bench_MCQ_Grounding_Mini",
+ use_subtitle=False,
+ use_subtitle_time=False,
+ use_frame_time=False,
+ nframe=0,
+ fps=-1,
+ ):
+ super().__init__(dataset=dataset, nframe=nframe, fps=fps)
+ self.use_subtitle = use_subtitle
+ self.use_subtitle_time = use_subtitle_time
+ self.use_frame_time = use_frame_time
+ self.dataset_name = dataset
+ lmu_root = LMUDataRoot()
+ self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
+
+ @classmethod
+ def supported_datasets(cls):
+ return ["CG-Bench_MCQ_Grounding_Mini"]
+
+ def clue_frame_paths(self, qid, num_frames=8):
+ frame_root = osp.join(self.clue_frame_root, qid)
+ os.makedirs(frame_root, exist_ok=True)
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
+
+ def clue_frame_paths_fps(self, qid, num_frames=8, fps=-1):
+ frame_root = osp.join(self.clue_frame_root, qid)
+ os.makedirs(frame_root, exist_ok=True)
+ return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, fps)) for i in range(1, num_frames + 1)]
+
+ def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
+
+ subtitles = []
+
+ srt_path = osp.join(self.data_root, subtitle_path)
+ assert osp.exists(srt_path)
+ import pysubs2
+
+ subs = pysubs2.load(srt_path, encoding="utf-8")
+ if not frame_indices:
+ for sub in subs:
+ sub_text = sub.text.replace("\\N", " ")
+ if sub_time:
+ start_time = milliseconds_to_seconds(sub.start)
+ end_time = milliseconds_to_seconds(sub.end)
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
+ if sub_text.strip() and sub_text not in subtitles:
+ subtitles.append(sub_text)
+ else:
+ for selected_frame_id in frame_indices:
+ cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
+ for sub in subs:
+ if sub.start < cur_time and sub.end > cur_time:
+ sub_text = sub.text.replace("\\N", " ")
+ if sub_time:
+ start_time = milliseconds_to_seconds(sub.start)
+ end_time = milliseconds_to_seconds(sub.end)
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
+ if sub_text.strip() and sub_text not in subtitles:
+ subtitles.append(sub_text)
+
+ if subtitles:
+ subtitles_str = '\n'.join(subtitles)
+ return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
+ else:
+ return ""
+
+ def prepare_dataset(self, dataset_name="CG-Bench_MCQ_Grounding_Mini", repo_id="CG-Bench/CG-Bench"):
+
+ def check_integrity(pth):
+ data_file = osp.join(pth, f"{dataset_name}.tsv")
+
+ if not os.path.exists(data_file):
+ return False
+
+ if md5(data_file) != self.MD5:
+ return False
+ data = load(data_file)
+ for video_pth in data["video"]:
+ if not osp.exists(osp.join(pth, video_pth)):
+ return False
+
+ return True
+
+ cache_path = get_cache_path(repo_id)
+
+ if cache_path is not None and check_integrity(cache_path):
+ dataset_path = cache_path
+ else:
+
+ def generate_tsv(pth):
+
+ tsv_file = osp.join(pth, f"{dataset_name}.tsv")
+
+ task_modes = ["long_acc", "clue_acc", "miou"]
+ all_data = []
+ for task_mode in task_modes:
+ with open(osp.join(pth, "cgbench_mini.json"), "r") as f:
+ data_file = pd.DataFrame(json.load(f))
+
+ data_file = data_file.assign(index=range(len(data_file)))
+ data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
+ data_file["subtitle_path"] = data_file["video_uid"].apply(
+ lambda x: (
+ f"cg_subtitles/{x}.srt"
+ if osp.exists(osp.join(dataset_path, f"cg_subtitles/{x}.srt"))
+ else ""
+ )
+ )
+
+ data_file["clue_video_path"] = ""
+
+ if task_mode in ["clue_acc"]:
+ data_file["clue_video_path"] = data_file["clue_video_path"] = data_file.apply(
+ lambda row: f"cg_clue_videos/{row['qid']}.mp4", axis=1
+ )
+
+ data_file["task_mode"] = task_mode
+
+ if task_mode in ["clue_acc", "long_acc"]:
+ data_file["answer"] = data_file["right_answer"]
+
+ if task_mode == "miou":
+ data_file["answer"] = data_file["clue_intervals"]
+
+ if task_mode in ["long_acc", "miou"]:
+ data_file["clue_intervals"] = ""
+
+ data_file = data_file[
+ [
+ "index",
+ "video_uid",
+ "video",
+ "duration",
+ "domain",
+ "choices",
+ "sub_category",
+ "subtitle_path",
+ "question",
+ "answer",
+ "task_mode",
+ "clue_intervals",
+ "qid",
+ "clue_video_path",
+ ]
+ ]
+
+ all_data.append(data_file)
+
+ final_data = pd.concat(all_data, ignore_index=True)
+ final_data["index"] = range(len(final_data))
+ final_data.to_csv(tsv_file, sep="\t", index=False)
+
+ if modelscope_flag_set():
+ from modelscope import dataset_snapshot_download
+
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
+ else:
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
+
+ unzip_hf_zip(dataset_path)
+ generate_tsv(dataset_path)
+
+ tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
+
+ return dict(data_file=tsv_file, root=dataset_path)
+
+ def build_prompt(self, line, video_llm):
+
+ if isinstance(line, int):
+ assert line < len(self)
+ line = self.data.iloc[line]
+
+ task_mode = line["task_mode"]
+
+ message = []
+
+ origin_use_subtitle_time = self.use_subtitle_time
+
+ try:
+ if task_mode in ["long_acc", "clue_acc"]:
+ system_prompt = self.SYS[task_mode]
+ elif task_mode == "miou":
+ if self.use_frame_time and not video_llm:
+ system_prompt = self.SYS[task_mode]
+ else:
+ system_prompt = self.SYS["miou_wo_frame_time"]
+ if self.use_subtitle_time is True:
+ self.use_subtitle_time = False
+
+ user_prompt = ""
+
+ if task_mode in ["long_acc", "miou"]:
+ video_path = line["video"]
+
+ if video_llm:
+ message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ if self.nframe:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
+ fps=vid_fps, sub_time=self.use_subtitle_time)
+ else:
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
+ else:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ message.extend(dict(type="image", value=im) for im in image_paths)
+
+ if self.use_frame_time:
+ user_prompt += get_timestampes(frame_indices, vid_fps)
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ user_prompt += self.get_subtitles(
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
+ sub_time=self.use_subtitle_time
+ )
+
+ elif task_mode == "clue_acc":
+ clue_video_path = line["clue_video_path"]
+ video_path = line["video"]
+
+ if video_llm:
+ message.append(dict(type="video", value=osp.join(self.data_root, clue_video_path)))
+ print(message)
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ if self.nframe:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
+ fps=vid_fps, sub_time=self.use_subtitle_time)
+ else:
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
+ else:
+ if self.nframe > 32:
+ self.nframe = 32
+ print("The maximum number of frames is 32 when evaluating clue-based mcq in CG-Bench !")
+
+ clue_intervals = eval(line["clue_intervals"])
+
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["qid"], clue_intervals=clue_intervals, num_frames=self.nframe, fps=self.fps
+ )
+
+ message.extend(dict(type="image", value=im) for im in image_paths)
+
+ if self.use_frame_time:
+ user_prompt += get_timestampes(frame_indices, vid_fps)
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ user_prompt += self.get_subtitles(
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
+ sub_time=self.use_subtitle_time
+ )
+
+ question = line["question"]
+ user_prompt += f"Question: {question}\n\n"
+
+ choices = eval(line["choices"])
+ labels = [chr(ord("A") + i) for i in range(len(choices))]
+ user_prompt += "\n".join([f"{label}:{value}" for label, value in zip(labels, choices)]) + "\n\n"
+
+ message.append(dict(type="text", value=system_prompt + user_prompt))
+
+ return message
+
+ finally:
+ # Ensure that `use_subtitle_time` is always restored to its original value
+ self.use_subtitle_time = origin_use_subtitle_time
+
+ def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
+
+ if type(uid) is not str:
+ uid = str(uid)
+ import decord
+ vid_path = osp.join(self.data_root, video)
+ vid = decord.VideoReader(vid_path)
+ vid_fps = vid.get_avg_fps()
+ n_frames = len(vid)
+
+ if clue_intervals is not None:
+ merged_intervals = merge_intervals(clue_intervals)
+
+ if num_frames > 0 and fps < 0:
+ indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
+ frame_paths = self.clue_frame_paths(uid, len(indices))
+
+ elif fps > 0:
+ frame_indices = []
+ for start, end in merged_intervals:
+ start_frame = int(start * vid_fps)
+ end_frame = int(end * vid_fps)
+ step = vid_fps / fps
+ interval_indices = [
+ int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
+ ]
+ frame_indices.extend(interval_indices)
+
+ if len(frame_indices) < 32:
+ indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
+ else:
+ indices = frame_indices
+ frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
+
+ else:
+ if num_frames > 0 and fps < 0:
+ step_size = len(vid) / (num_frames + 1)
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
+
+ frame_paths = self.frame_paths(uid)
+ elif fps > 0:
+ total_duration = n_frames / vid_fps
+ required_frames = int(total_duration * fps)
+ step_size = vid_fps / fps
+ indices = [int(i * step_size) for i in range(required_frames)]
+ frame_paths = self.frame_paths_fps(uid, len(indices))
+
+ # Save and validate frames
+ valid_paths = []
+ valid_indices = []
+ lock_path = osp.splitext(vid_path)[0] + '.lock'
+ with portalocker.Lock(lock_path, 'w', timeout=30):
+ if not np.all([osp.exists(p) for p in frame_paths]):
+ images = [vid[i].asnumpy() for i in indices]
+ for i, (img_array, path) in enumerate(zip(images, frame_paths)):
+ if osp.exists(path):
+ try:
+ with Image.open(path) as img:
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+ else:
+ try:
+ img = Image.fromarray(img_array)
+ img.save(path)
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+ else:
+ for i, path in enumerate(frame_paths):
+ try:
+ with Image.open(path) as img:
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+
+ return valid_paths, valid_indices, vid_fps
+
+ def evaluate(self, eval_file, **judge_kwargs):
+
+ assert get_file_extension(eval_file) in ['xlsx', 'json', 'tsv'], "data file should be a supported format"
+
+ tgt_file = get_intermediate_file_path(eval_file, '_rating', 'json')
+ score_file = get_intermediate_file_path(eval_file, '_score')
+
+ data = load(eval_file)
+
+ data_un = data[~pd.isna(data["prediction"])]
+ data_pred_na = data[pd.isna(data["prediction"])]
+
+ data_pred_na["score"] = -1
+
+ data_un["score"] = data_un.apply(
+ lambda row: post_process(
+ response=row["prediction"],
+ right_answer=row["answer"],
+ task_mode=row["task_mode"],
+ duration=row["duration"],
+ ),
+ axis=1,
+ )
+
+ data = pd.concat([data_pred_na, data_un])
+
+ rejected_count = (data["score"] == -1).sum()
+
+ print(
+ f"Among {len(data)} questions, "
+ f"failed to obtain prediction for {len(data_pred_na)} questions, "
+ f"failed to obtain the score for {rejected_count - len(data_pred_na)} questions. "
+ f"Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating."
+ )
+
+ dump(data, score_file)
+
+ rating = get_dimention_rating_mcq_grouding(score_file)
+
+ dump(rating, tgt_file)
+
+ return rating
+
+
+# 评估时,step_2 评估时,给出 [prompt] + image_paths 就行
+class CGBench_OpenEnded_Mini(VideoBaseDataset):
+
+ TYPE = "Video-OpenEnded"
+
+ dataset = "CG-Bench_OpenEnded_Mini"
+
+ MD5 = "9175791b11afdfa305fdb3e525b7a4ee"
+
+ SYS = (
+ "You will be provided with sampled frames from a video, along with a "
+ "question.\n"
+ "Your task is to analyze the provided frames and infer the most plausible "
+ "answer based on the visual information.\n"
+ "If the visual information is ambiguous or insufficient, use the available "
+ "context to reason your answer.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": "answer"}\n```\n\n'
+ 'The "answer" can be a word, phrase, or sentence that directly responds to '
+ "the question.\n\n"
+ )
+
+ def __init__(
+ self,
+ dataset="CG-Bench_OpenEnded_Mini",
+ use_subtitle=False,
+ use_subtitle_time=False,
+ use_frame_time=False,
+ nframe=0,
+ fps=-1,
+ ):
+ super().__init__(dataset=dataset, nframe=nframe, fps=fps)
+ self.use_subtitle = use_subtitle
+ self.use_subtitle_time = use_subtitle_time
+ self.use_frame_time = use_frame_time
+ self.dataset_name = dataset
+ lmu_root = LMUDataRoot()
+ self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
+
+ @classmethod
+ def supported_datasets(cls):
+ return ["CG-Bench_OpenEnded_Mini"]
+
+ def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
+
+ subtitles = []
+
+ srt_path = osp.join(self.data_root, subtitle_path)
+ assert osp.exists(srt_path)
+ import pysubs2
+
+ subs = pysubs2.load(srt_path, encoding="utf-8")
+ if not frame_indices:
+ for sub in subs:
+ sub_text = sub.text.replace("\\N", " ")
+ if sub_time:
+ start_time = milliseconds_to_seconds(sub.start)
+ end_time = milliseconds_to_seconds(sub.end)
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
+ if sub_text.strip() and sub_text not in subtitles:
+ subtitles.append(sub_text)
+ else:
+ for selected_frame_id in frame_indices:
+ cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
+ for sub in subs:
+ if sub.start < cur_time and sub.end > cur_time:
+ sub_text = sub.text.replace("\\N", " ")
+ if sub_time:
+ start_time = milliseconds_to_seconds(sub.start)
+ end_time = milliseconds_to_seconds(sub.end)
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
+ if sub_text.strip() and sub_text not in subtitles:
+ subtitles.append(sub_text)
+
+ if subtitles:
+ subtitles_str = '\n'.join(subtitles)
+ return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
+ else:
+ return ""
+
+ def prepare_dataset(self, dataset_name="CG-Bench_OpenEnded_Mini", repo_id="CG-Bench/CG-Bench"):
+
+ def check_integrity(pth):
+ data_file = osp.join(pth, f"{dataset_name}.tsv")
+
+ if not os.path.exists(data_file):
+ return False
+
+ if md5(data_file) != self.MD5:
+ return False
+ data = load(data_file)
+ for video_pth in data["video"]:
+ if not osp.exists(osp.join(pth, video_pth)):
+ return False
+
+ return True
+
+ cache_path = get_cache_path(repo_id)
+
+ if cache_path is not None and check_integrity(cache_path):
+ dataset_path = cache_path
+ else:
+
+ def generate_tsv(pth):
+
+ tsv_file = osp.join(pth, f"{dataset_name}.tsv")
+
+ with open(osp.join(pth, "cgbench_mini.json"), "r") as f:
+ data_file = pd.DataFrame(json.load(f))
+
+ data_file = data_file.assign(index=range(len(data_file)))
+ data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
+ data_file["subtitle_path"] = data_file["video_uid"].apply(
+ lambda x: f"cg_subtitles/{x}.srt" if osp.exists(osp.join(pth, f"cg_subtitles/{x}.srt")) else ""
+ )
+
+ data_file = data_file[
+ [
+ "index",
+ "video_uid",
+ "video",
+ "duration",
+ "domain",
+ "sub_category",
+ "subtitle_path",
+ "question",
+ "answer",
+ "clue_intervals",
+ "qid",
+ ]
+ ]
+
+ data_file.to_csv(tsv_file, sep="\t", index=False)
+
+ if modelscope_flag_set():
+ from modelscope import dataset_snapshot_download
+
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
+ else:
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
+
+ unzip_hf_zip(dataset_path)
+ generate_tsv(dataset_path)
+
+ tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
+
+ return dict(data_file=tsv_file, root=dataset_path)
+
+ def build_prompt(self, line, video_llm):
+
+ if isinstance(line, int):
+ assert line < len(self)
+ line = self.data.iloc[line]
+
+ message = []
+
+ sys_prompt = self.SYS
+
+ user_prompt = ""
+
+ video_path = line["video"]
+
+ if video_llm:
+ message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ if self.nframe:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
+ fps=vid_fps, sub_time=self.use_subtitle_time)
+ else:
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
+ else:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ message.extend(dict(type="image", value=im) for im in image_paths)
+
+ if self.use_frame_time:
+ user_prompt += get_timestampes(frame_indices, vid_fps)
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ user_prompt += self.get_subtitles(
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
+ sub_time=self.use_subtitle_time
+ )
+
+ question = line["question"]
+ user_prompt += f"Question: {question}\n\n"
+
+ message.append(dict(type="text", value=sys_prompt + user_prompt))
+
+ return message
+
+ def clue_frame_paths(self, qid, num_frames=8):
+ frame_root = osp.join(self.clue_frame_root, qid)
+ os.makedirs(frame_root, exist_ok=True)
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
+
+ def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
+
+ if type(uid) is not str:
+ uid = str(uid)
+ import decord
+ vid_path = osp.join(self.data_root, video)
+ vid = decord.VideoReader(vid_path)
+ vid_fps = vid.get_avg_fps()
+ n_frames = len(vid)
+
+ if clue_intervals is not None:
+ merged_intervals = merge_intervals(clue_intervals)
+
+ if num_frames > 0 and fps < 0:
+ indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
+ frame_paths = self.clue_frame_paths(uid, len(indices))
+
+ elif fps > 0:
+ frame_indices = []
+ for start, end in merged_intervals:
+ start_frame = int(start * vid_fps)
+ end_frame = int(end * vid_fps)
+ step = vid_fps / fps
+ interval_indices = [
+ int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
+ ]
+ frame_indices.extend(interval_indices)
+
+ if len(frame_indices) < 32:
+ indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
+ else:
+ indices = frame_indices
+ frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
+
+ else:
+ if num_frames > 0 and fps < 0:
+ step_size = len(vid) / (num_frames + 1)
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
+ frame_paths = self.frame_paths(uid)
+ elif fps > 0:
+ total_duration = n_frames / vid_fps
+ required_frames = int(total_duration * fps)
+ step_size = vid_fps / fps
+ indices = [int(i * step_size) for i in range(required_frames)]
+ frame_paths = self.frame_paths_fps(uid, len(indices))
+
+ valid_paths = []
+ valid_indices = []
+ lock_path = osp.splitext(vid_path)[0] + '.lock'
+ with portalocker.Lock(lock_path, 'w', timeout=30):
+ if not np.all([osp.exists(p) for p in frame_paths]):
+ images = [vid[i].asnumpy() for i in indices]
+ for i, (img_array, path) in enumerate(zip(images, frame_paths)):
+ if osp.exists(path):
+ try:
+ with Image.open(path) as img:
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+ else:
+ try:
+ img = Image.fromarray(img_array)
+ img.save(path)
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+ else:
+ for i, path in enumerate(frame_paths):
+ try:
+ with Image.open(path) as img:
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+
+ return valid_paths, valid_indices, vid_fps
+
+ def evaluate(self, eval_file, **judge_kwargs):
+
+ from .utils.cgbench import get_dimention_rating_open_ended, post_process_open
+
+ assert get_file_extension(eval_file) in ['xlsx', 'json', 'tsv'], "data file should be a supported format"
+
+ tgt_file = get_intermediate_file_path(eval_file, '_rating', 'json')
+ score_file = get_intermediate_file_path(eval_file, '_score')
+ step_1_tmp_file = get_intermediate_file_path(eval_file, '_step_1', 'pkl')
+ step_2_tmp_file = get_intermediate_file_path(eval_file, '_step_2', 'pkl')
+
+ data = load(eval_file)
+
+ data_pred_no_na = data[~pd.isna(data["prediction"])]
+ data_pred_na = data[pd.isna(data["prediction"])]
+
+ data_pred_na["model_result"] = -1
+ data_pred_na["step_1_result"] = -1
+ data_pred_na["step_2_result"] = -1
+ data_pred_na["score"] = -1
+
+ data_pred_no_na["model_result"] = data_pred_no_na.apply(
+ lambda row: post_process_open(
+ response=row["prediction"],
+ ),
+ axis=1,
+ )
+
+ if judge_kwargs.get("model", None) != "gpt-4o-0806":
+ judge_kwargs["model"] = "gpt-4o-0806"
+ print("The judge model in cg-bench is gpt-4o-0806!")
+
+ data_no_model_result = data_pred_no_na[data_pred_no_na["model_result"] == -1]
+ data_step_1 = data_pred_no_na[data_pred_no_na["model_result"] != -1]
+
+ model_step_1 = build_judge(system_prompt=sys_prompt_open_eval_step_1, **judge_kwargs)
+ nproc = judge_kwargs.pop("nproc", 32)
+
+ lines_step_1 = data_step_1.to_dict("records")
+ tups_step_1 = [(model_step_1, line) for line in lines_step_1]
+
+ keys_step_1 = {line["qid"] for line in lines_step_1}
+
+ ans = {}
+ if osp.exists(step_1_tmp_file):
+ ans = load(step_1_tmp_file)
+ tups_step_1 = [x for x, i in zip(tups_step_1, keys_step_1) if i not in ans]
+ keys_step_1 = [i for i in keys_step_1 if i not in ans]
+
+ _ = track_progress_rich(
+ eval_open_first,
+ tups_step_1,
+ nproc=nproc,
+ keys=keys_step_1,
+ save=step_1_tmp_file,
+ )
+
+ step_1_results = load(step_1_tmp_file)
+ data_step_1 = save_step_1_steps(data_step_1, step_1_results) # -1, 0, 1, 2
+
+ data_no_step_1_results = data_step_1[data_step_1["step_1_result"] == -1]
+ data_step_1_over = data_step_1[data_step_1["step_1_result"].isin([0, 1])]
+ data_step_2 = data_step_1[data_step_1["step_1_result"] == 2]
+
+ print(judge_kwargs)
+
+ model_step_2 = build_judge(system_prompt=sys_prompt_open_eval_step_2, **judge_kwargs)
+
+ lines_step_2 = data_step_2.to_dict("records")
+
+ tups_step_2 = []
+
+ for line in tqdm(lines_step_2):
+ clue_intervals = eval(line["clue_intervals"])
+ lmu_root = LMUDataRoot()
+ clue_frame_root = osp.join(lmu_root, "clue_images", self.dataset)
+ data_root = self.data_root
+ frame_paths, _, _ = save_clue_video_frames(
+ data_root,
+ clue_frame_root,
+ video=line["video"],
+ uid=line["qid"],
+ clue_intervals=clue_intervals,
+ num_frames=32,
+ )
+ tups_step_2.append((model_step_2, line, frame_paths))
+
+ keys_step_2 = {line["qid"] for line in lines_step_2}
+
+ ans = {}
+ if osp.exists(step_2_tmp_file):
+ ans = load(step_2_tmp_file)
+ tups_step_2 = [x for x, i in zip(tups_step_2, keys_step_2) if i not in ans]
+ keys_step_2 = [i for i in keys_step_2 if i not in ans]
+
+ _ = track_progress_rich(
+ eval_open_second,
+ tups_step_2,
+ nproc=nproc,
+ keys=keys_step_2,
+ save=step_2_tmp_file,
+ )
+
+ step_2_results = load(step_2_tmp_file)
+ data_step_2 = save_step_2_steps(data_step_2, step_2_results)
+
+ data_no_step_2_results = data_step_2[data_step_2["score"] == -1]
+ data_step_2_over = data_step_2[data_step_2["score"].isin([0, 1])]
+
+ data = pd.concat(
+ [
+ data_pred_na,
+ data_no_model_result,
+ data_no_step_1_results,
+ data_step_1_over,
+ data_no_step_2_results,
+ data_step_2_over,
+ ]
+ )
+
+ dump(data, score_file)
+
+ rating = get_dimention_rating_open_ended(score_file)
+
+ dump(rating, tgt_file)
+
+ return rating
+
+
+class CGBench_MCQ_Grounding(VideoBaseDataset):
+
+ TYPE = "Video-MCQ-Grounding"
+
+ MD5 = "eaead3d978a689269fefce4ae29c86df"
+
+ SYS = {
+ "long_acc": (
+ "You will be provided with sampled frames from a video, along with a "
+ "multiple-choice question that includes a question and several answer options.\n"
+ "Your task is to analyze the provided frames, infer the most plausible "
+ "answer based on the visual information.\n"
+ "If the video does not provide enough information, infer the answer based "
+ "on the options available and still provide a result. "
+ "Therefore, In all cases, an answer must be given.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": "option"}\n```\n\n'
+ 'The "option" is the uppercase letter corresponding to your answer.\n\n'
+ ),
+ "clue_acc": (
+ "You will be provided with sampled frames from a video, along with a "
+ "multiple-choice question that includes a question and several answer options.\n"
+ "Your task is to analyze the provided frames, infer the most plausible "
+ "answer based on the visual information.\n"
+ "If the video does not provide enough information, infer the answer based "
+ "on the options available and still provide a result. "
+ "Therefore, In all cases, an answer must be given.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": "option"}\n```\n\n'
+ "The 'option' is the uppercase letter corresponding to your answer.\n\n"
+ ),
+ "miou": (
+ "You will be provided with uniformly sampled frames from a video and their "
+ "timestamps, along with a multiple-choice question that includes a question "
+ "and several answer options.\n"
+ "Your task is to determine in which intervals the 'clue intervals' exist "
+ "that contain visual information needed to answer the question.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
+ "In this output format, each 'start' and 'end' represents the beginning and "
+ "end of an interval in seconds where relevant clues can be found.\n"
+ "You must provide at least one interval and at most five intervals. "
+ "Intervals exceeding five will NOT be considered valid.\n"
+ ),
+ "miou_wo_frame_time": (
+ "You will be provided with uniformly sampled frames from a video, along "
+ "with a multiple-choice question that includes a question and several "
+ "answer options.\n"
+ "Your task is to determine in which intervals the 'clue intervals' exist "
+ "that contain visual information needed to answer the question.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": [[start1, end1], [start2, end2], ...]}\n```\n\n'
+ 'In this output format, each "start" and "end" represents the start and '
+ "end of the video where the relevant clue can be found in the form of a "
+ "floating point number between 0 and 1, where 0 represents the start time "
+ "of the video and 1 represents the end time of the video.\n"
+ "You must provide at least one interval and at most five intervals. "
+ "Intervals exceeding five will NOT be considered valid.\n"
+ ),
+ }
+
+ def __init__(
+ self,
+ dataset="CG-Bench_MCQ_Grounding",
+ use_subtitle=False,
+ use_subtitle_time=False,
+ use_frame_time=False,
+ nframe=0,
+ fps=-1,
+ ):
+ super().__init__(dataset=dataset, nframe=nframe, fps=fps)
+ self.use_subtitle = use_subtitle
+ self.use_subtitle_time = use_subtitle_time
+ self.use_frame_time = use_frame_time
+ self.dataset_name = dataset
+ lmu_root = LMUDataRoot()
+ self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
+
+ @classmethod
+ def supported_datasets(cls):
+ return ["CG-Bench_MCQ_Grounding"]
+
+ def clue_frame_paths(self, qid, num_frames=8):
+ frame_root = osp.join(self.clue_frame_root, qid)
+ os.makedirs(frame_root, exist_ok=True)
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
+
+ def clue_frame_paths_fps(self, qid, num_frames=8, fps=-1):
+ frame_root = osp.join(self.clue_frame_root, qid)
+ os.makedirs(frame_root, exist_ok=True)
+ return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, fps)) for i in range(1, num_frames + 1)]
+
+ def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
+
+ subtitles = []
+
+ srt_path = osp.join(self.data_root, subtitle_path)
+ assert osp.exists(srt_path)
+ import pysubs2
+
+ subs = pysubs2.load(srt_path, encoding="utf-8")
+ if not frame_indices:
+ for sub in subs:
+ sub_text = sub.text.replace("\\N", " ")
+ if sub_time:
+ start_time = milliseconds_to_seconds(sub.start)
+ end_time = milliseconds_to_seconds(sub.end)
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
+ if sub_text.strip() and sub_text not in subtitles:
+ subtitles.append(sub_text)
+ else:
+ for selected_frame_id in frame_indices:
+ cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
+ for sub in subs:
+ if sub.start < cur_time and sub.end > cur_time:
+ sub_text = sub.text.replace("\\N", " ")
+ if sub_time:
+ start_time = milliseconds_to_seconds(sub.start)
+ end_time = milliseconds_to_seconds(sub.end)
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
+ if sub_text.strip() and sub_text not in subtitles:
+ subtitles.append(sub_text)
+
+ if subtitles:
+ subtitles_str = '\n'.join(subtitles)
+ return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
+ else:
+ return ""
+
+ def prepare_dataset(self, dataset_name="CG-Bench_MCQ_Grounding", repo_id="CG-Bench/CG-Bench"):
+
+ def check_integrity(pth):
+ data_file = osp.join(pth, f"{dataset_name}.tsv")
+
+ if not os.path.exists(data_file):
+ return False
+
+ if md5(data_file) != self.MD5:
+ return False
+ data = load(data_file)
+ for video_pth in data["video"]:
+ if not osp.exists(osp.join(pth, video_pth)):
+ return False
+
+ for clue_video_pth in data["clue_video_path"]:
+ if clue_video_pth and not (isinstance(clue_video_pth, float) and np.isnan(clue_video_pth)):
+ if not osp.exists(osp.join(pth, clue_video_pth)):
+ return False
+
+ return True
+
+ cache_path = get_cache_path(repo_id)
+
+ if cache_path is not None and check_integrity(cache_path):
+ dataset_path = cache_path
+ else:
+
+ def generate_tsv(pth):
+
+ tsv_file = osp.join(pth, f"{dataset_name}.tsv")
+
+ task_modes = ["long_acc", "clue_acc", "miou"]
+ all_data = []
+ for task_mode in task_modes:
+ with open(osp.join(pth, "cgbench.json"), "r") as f:
+ data_file = pd.DataFrame(json.load(f))
+
+ data_file = data_file.assign(index=range(len(data_file)))
+ data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
+ data_file["subtitle_path"] = data_file["video_uid"].apply(
+ lambda x: (
+ f"cg_subtitles/{x}.srt"
+ if osp.exists(osp.join(dataset_path, f"cg_subtitles/{x}.srt"))
+ else ""
+ )
+ )
+
+ data_file["clue_video_path"] = ""
+
+ if task_mode in ["clue_acc"]:
+ data_file["clue_video_path"] = data_file["clue_video_path"] = data_file.apply(
+ lambda row: f"cg_clue_videos/{row['qid']}.mp4", axis=1
+ )
+
+ data_file["task_mode"] = task_mode
+
+ if task_mode in ["clue_acc", "long_acc"]:
+ data_file["answer"] = data_file["right_answer"]
+
+ if task_mode == "miou":
+ data_file["answer"] = data_file["clue_intervals"]
+
+ if task_mode in ["long_acc", "miou"]:
+ data_file["clue_intervals"] = ""
+
+ data_file = data_file[
+ [
+ "index",
+ "video_uid",
+ "video",
+ "duration",
+ "domain",
+ "choices",
+ "sub_category",
+ "subtitle_path",
+ "question",
+ "answer",
+ "task_mode",
+ "clue_intervals",
+ "qid",
+ "clue_video_path",
+ ]
+ ]
+
+ all_data.append(data_file)
+
+ final_data = pd.concat(all_data, ignore_index=True)
+ final_data["index"] = range(len(final_data))
+ final_data.to_csv(tsv_file, sep="\t", index=False)
+
+ if modelscope_flag_set():
+ from modelscope import dataset_snapshot_download
+
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
+ else:
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
+
+ unzip_hf_zip(dataset_path)
+ generate_tsv(dataset_path)
+
+ tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
+
+ return dict(data_file=tsv_file, root=dataset_path)
+
+ def build_prompt(self, line, video_llm):
+
+ if isinstance(line, int):
+ assert line < len(self)
+ line = self.data.iloc[line]
+
+ task_mode = line["task_mode"]
+
+ message = []
+
+ origin_use_subtitle_time = self.use_subtitle_time
+
+ try:
+ if task_mode in ["long_acc", "clue_acc"]:
+ system_prompt = self.SYS[task_mode]
+ elif task_mode == "miou":
+ if self.use_frame_time and not video_llm:
+ system_prompt = self.SYS[task_mode]
+ else:
+ system_prompt = self.SYS["miou_wo_frame_time"]
+ if self.use_subtitle_time is True:
+ self.use_subtitle_time = False
+
+ user_prompt = ""
+
+ if task_mode in ["long_acc", "miou"]:
+ video_path = line["video"]
+
+ if video_llm:
+ message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ if self.nframe:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
+ fps=vid_fps, sub_time=self.use_subtitle_time)
+ else:
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
+ else:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ message.extend(dict(type="image", value=im) for im in image_paths)
+
+ if self.use_frame_time:
+ user_prompt += get_timestampes(frame_indices, vid_fps)
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ user_prompt += self.get_subtitles(
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
+ sub_time=self.use_subtitle_time
+ )
+
+ elif task_mode == "clue_acc":
+ clue_video_path = line["clue_video_path"]
+ video_path = line["video"]
+
+ if video_llm:
+ message.append(dict(type="video", value=osp.join(self.data_root, clue_video_path)))
+ print(message)
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ if self.nframe:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
+ fps=vid_fps, sub_time=self.use_subtitle_time)
+ else:
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
+ else:
+ if self.nframe > 32:
+ self.nframe = 32
+ print("The maximum number of frames is 32 when evaluating clue-based mcq in CG-Bench !")
+
+ clue_intervals = eval(line["clue_intervals"])
+
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["qid"], clue_intervals=clue_intervals, num_frames=self.nframe, fps=self.fps
+ )
+
+ message.extend(dict(type="image", value=im) for im in image_paths)
+
+ if self.use_frame_time:
+ user_prompt += get_timestampes(frame_indices, vid_fps)
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ user_prompt += self.get_subtitles(
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
+ sub_time=self.use_subtitle_time
+ )
+
+ question = line["question"]
+ user_prompt += f"Question: {question}\n\n"
+
+ choices = eval(line["choices"])
+ labels = [chr(ord("A") + i) for i in range(len(choices))]
+ user_prompt += "\n".join([f"{label}:{value}" for label, value in zip(labels, choices)]) + "\n\n"
+
+ message.append(dict(type="text", value=system_prompt + user_prompt))
+
+ return message
+
+ finally:
+ # Ensure that `use_subtitle_time` is always restored to its original value
+ self.use_subtitle_time = origin_use_subtitle_time
+
+ def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
+
+ if type(uid) is not str:
+ uid = str(uid)
+ import decord
+ vid_path = osp.join(self.data_root, video)
+ vid = decord.VideoReader(vid_path)
+ vid_fps = vid.get_avg_fps()
+ n_frames = len(vid)
+
+ if clue_intervals is not None:
+ merged_intervals = merge_intervals(clue_intervals)
+
+ if num_frames > 0 and fps < 0:
+ indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
+ frame_paths = self.clue_frame_paths(uid, len(indices))
+
+ elif fps > 0:
+ frame_indices = []
+ for start, end in merged_intervals:
+ start_frame = int(start * vid_fps)
+ end_frame = int(end * vid_fps)
+ step = vid_fps / fps
+ interval_indices = [
+ int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
+ ]
+ frame_indices.extend(interval_indices)
+
+ if len(frame_indices) < 32:
+ indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
+ else:
+ indices = frame_indices
+ frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
+
+ else:
+ if num_frames > 0 and fps < 0:
+ step_size = len(vid) / (num_frames + 1)
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
+
+ frame_paths = self.frame_paths(uid)
+ elif fps > 0:
+ total_duration = n_frames / vid_fps
+ required_frames = int(total_duration * fps)
+ step_size = vid_fps / fps
+ indices = [int(i * step_size) for i in range(required_frames)]
+ frame_paths = self.frame_paths_fps(uid, len(indices))
+
+ # Save and validate frames
+ valid_paths = []
+ valid_indices = []
+ lock_path = osp.splitext(vid_path)[0] + '.lock'
+ with portalocker.Lock(lock_path, 'w', timeout=30):
+ if not np.all([osp.exists(p) for p in frame_paths]):
+ images = [vid[i].asnumpy() for i in indices]
+ for i, (img_array, path) in enumerate(zip(images, frame_paths)):
+ if osp.exists(path):
+ try:
+ with Image.open(path) as img:
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+ else:
+ try:
+ img = Image.fromarray(img_array)
+ img.save(path)
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+ else:
+ for i, path in enumerate(frame_paths):
+ try:
+ with Image.open(path) as img:
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+
+ return valid_paths, valid_indices, vid_fps
+
+ def evaluate(self, eval_file, **judge_kwargs):
+
+ assert get_file_extension(eval_file) in ['xlsx', 'json', 'tsv'], "data file should be a supported format"
+
+ tgt_file = get_intermediate_file_path(eval_file, '_rating', 'json')
+ score_file = get_intermediate_file_path(eval_file, '_score')
+
+ data = load(eval_file)
+
+ data_un = data[~pd.isna(data["prediction"])]
+ data_pred_na = data[pd.isna(data["prediction"])]
+
+ data_pred_na["score"] = -1
+
+ data_un["score"] = data_un.apply(
+ lambda row: post_process(
+ response=row["prediction"],
+ right_answer=row["answer"],
+ task_mode=row["task_mode"],
+ duration=row["duration"],
+ ),
+ axis=1,
+ )
+
+ data = pd.concat([data_pred_na, data_un])
+
+ rejected_count = (data["score"] == -1).sum()
+
+ print(
+ f"Among {len(data)} questions, "
+ f"failed to obtain prediction for {len(data_pred_na)} questions, "
+ f"failed to obtain the score for {rejected_count - len(data_pred_na)} questions. "
+ f"Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating."
+ )
+
+ dump(data, score_file)
+
+ rating = get_dimention_rating_mcq_grouding(score_file)
+
+ dump(rating, tgt_file)
+
+ return rating
+
+
+# 评估时,step_2 评估时,给出 [prompt] + image_paths 就行
+class CGBench_OpenEnded(VideoBaseDataset):
+
+ TYPE = "Video-OpenEnded"
+
+ dataset = "CG-Bench_OpenEnded"
+
+ MD5 = "796035eda0b1e916c517cdc1bc145cfc"
+
+ SYS = (
+ "You will be provided with sampled frames from a video, along with a "
+ "question.\n"
+ "Your task is to analyze the provided frames and infer the most plausible "
+ "answer based on the visual information.\n"
+ "If the visual information is ambiguous or insufficient, use the available "
+ "context to reason your answer.\n"
+ "Only output the answer in the following format:\n\n"
+ '```json\n{"result": "answer"}\n```\n\n'
+ 'The "answer" can be a word, phrase, or sentence that directly responds to '
+ "the question.\n\n"
+ )
+
+ def __init__(
+ self,
+ dataset="CG-Bench_OpenEnded",
+ use_subtitle=False,
+ use_subtitle_time=False,
+ use_frame_time=False,
+ nframe=0,
+ fps=-1,
+ ):
+ super().__init__(dataset=dataset, nframe=nframe, fps=fps)
+ self.use_subtitle = use_subtitle
+ self.use_subtitle_time = use_subtitle_time
+ self.use_frame_time = use_frame_time
+ self.dataset_name = dataset
+ lmu_root = LMUDataRoot()
+ self.clue_frame_root = osp.join(lmu_root, "clue_images", dataset)
+
+ @classmethod
+ def supported_datasets(cls):
+ return ["CG-Bench_OpenEnded"]
+
+ def get_subtitles(self, subtitle_path, frame_indices=None, fps=None, sub_time=False):
+
+ subtitles = []
+
+ srt_path = osp.join(self.data_root, subtitle_path)
+ assert osp.exists(srt_path)
+ import pysubs2
+
+ subs = pysubs2.load(srt_path, encoding="utf-8")
+ if not frame_indices:
+ for sub in subs:
+ sub_text = sub.text.replace("\\N", " ")
+ if sub_time:
+ start_time = milliseconds_to_seconds(sub.start)
+ end_time = milliseconds_to_seconds(sub.end)
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
+ if sub_text.strip() and sub_text not in subtitles:
+ subtitles.append(sub_text)
+ else:
+ for selected_frame_id in frame_indices:
+ cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
+ for sub in subs:
+ if sub.start < cur_time and sub.end > cur_time:
+ sub_text = sub.text.replace("\\N", " ")
+ if sub_time:
+ start_time = milliseconds_to_seconds(sub.start)
+ end_time = milliseconds_to_seconds(sub.end)
+ sub_text = f"[{start_time}, {end_time}] {sub_text}"
+ if sub_text.strip() and sub_text not in subtitles:
+ subtitles.append(sub_text)
+
+ if subtitles:
+ subtitles_str = '\n'.join(subtitles)
+ return f"The subtitles of the video are as follows:\n\n{subtitles_str}\n\n"
+ else:
+ return ""
+
+ def prepare_dataset(self, dataset_name="CG-Bench_OpenEnded", repo_id="CG-Bench/CG-Bench"):
+
+ def check_integrity(pth):
+ data_file = osp.join(pth, f"{dataset_name}.tsv")
+
+ if not os.path.exists(data_file):
+ return False
+
+ if md5(data_file) != self.MD5:
+ return False
+ data = load(data_file)
+ for video_pth in data["video"]:
+ if not osp.exists(osp.join(pth, video_pth)):
+ return False
+
+ return True
+
+ cache_path = get_cache_path(repo_id)
+
+ if cache_path is not None and check_integrity(cache_path):
+ dataset_path = cache_path
+ else:
+
+ def generate_tsv(pth):
+
+ tsv_file = osp.join(pth, f"{dataset_name}.tsv")
+
+ with open(osp.join(pth, "cgbench.json"), "r") as f:
+ data_file = pd.DataFrame(json.load(f))
+
+ data_file = data_file.assign(index=range(len(data_file)))
+ data_file["video"] = data_file["video_uid"].apply(lambda x: f"cg_videos_720p/{x}.mp4")
+ data_file["subtitle_path"] = data_file["video_uid"].apply(
+ lambda x: f"cg_subtitles/{x}.srt" if osp.exists(osp.join(pth, f"cg_subtitles/{x}.srt")) else ""
+ )
+
+ data_file = data_file[
+ [
+ "index",
+ "video_uid",
+ "video",
+ "duration",
+ "domain",
+ "sub_category",
+ "subtitle_path",
+ "question",
+ "answer",
+ "clue_intervals",
+ "qid",
+ ]
+ ]
+
+ data_file.to_csv(tsv_file, sep="\t", index=False)
+
+ if modelscope_flag_set():
+ from modelscope import dataset_snapshot_download
+ dataset_path = dataset_snapshot_download(dataset_id=repo_id)
+ else:
+ dataset_path = snapshot_download(repo_id=repo_id, repo_type="dataset")
+
+ unzip_hf_zip(dataset_path)
+ generate_tsv(dataset_path)
+
+ tsv_file = osp.join(dataset_path, f"{dataset_name}.tsv")
+
+ return dict(data_file=tsv_file, root=dataset_path)
+
+ def build_prompt(self, line, video_llm):
+
+ if isinstance(line, int):
+ assert line < len(self)
+ line = self.data.iloc[line]
+
+ message = []
+
+ sys_prompt = self.SYS
+
+ user_prompt = ""
+
+ video_path = line["video"]
+
+ if video_llm:
+ message.append(dict(type="video", value=osp.join(self.data_root, video_path)))
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ if self.nframe:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ user_prompt += self.get_subtitles(line["subtitle_path"], frame_indices=frame_indices,
+ fps=vid_fps, sub_time=self.use_subtitle_time)
+ else:
+ user_prompt += self.get_subtitles(line["subtitle_path"], sub_time=self.use_subtitle_time)
+ else:
+ image_paths, frame_indices, vid_fps = self.save_video_frames(
+ video_path, uid=line["video_uid"], num_frames=self.nframe, fps=self.fps
+ )
+ message.extend(dict(type="image", value=im) for im in image_paths)
+
+ if self.use_frame_time:
+ user_prompt += get_timestampes(frame_indices, vid_fps)
+
+ if self.use_subtitle and line["subtitle_path"] and not pd.isna(line["subtitle_path"]):
+ user_prompt += self.get_subtitles(
+ line["subtitle_path"], frame_indices=frame_indices, fps=vid_fps,
+ sub_time=self.use_subtitle_time
+ )
+
+ question = line["question"]
+ user_prompt += f"Question: {question}\n\n"
+
+ message.append(dict(type="text", value=sys_prompt + user_prompt))
+
+ return message
+
+ def clue_frame_paths(self, qid, num_frames=8):
+ frame_root = osp.join(self.clue_frame_root, qid)
+ os.makedirs(frame_root, exist_ok=True)
+ return [osp.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1)]
+
+ def save_video_frames(self, video, uid, clue_intervals=None, num_frames=8, fps=-1):
+
+ if type(uid) is not str:
+ uid = str(uid)
+ import decord
+ vid_path = osp.join(self.data_root, video)
+ vid = decord.VideoReader(vid_path)
+ vid_fps = vid.get_avg_fps()
+ n_frames = len(vid)
+
+ if clue_intervals is not None:
+ merged_intervals = merge_intervals(clue_intervals)
+
+ if num_frames > 0 and fps < 0:
+ indices = sample_frames_clue_average(merged_intervals, num_frames, vid_fps)
+ frame_paths = self.clue_frame_paths(uid, len(indices))
+
+ elif fps > 0:
+ frame_indices = []
+ for start, end in merged_intervals:
+ start_frame = int(start * vid_fps)
+ end_frame = int(end * vid_fps)
+ step = vid_fps / fps
+ interval_indices = [
+ int(start_frame + i * step) for i in range(int((end_frame - start_frame) / step))
+ ]
+ frame_indices.extend(interval_indices)
+
+ if len(frame_indices) < 32:
+ indices = sample_frames_clue_average(merged_intervals, 32, vid_fps)
+ else:
+ indices = frame_indices
+ frame_paths = self.clue_frame_paths_fps(uid, len(indices), fps)
+
+ else:
+ if num_frames > 0 and fps < 0:
+ step_size = len(vid) / (num_frames + 1)
+ indices = [int(i * step_size) for i in range(1, num_frames + 1)]
+ frame_paths = self.frame_paths(uid)
+ elif fps > 0:
+ total_duration = n_frames / vid_fps
+ required_frames = int(total_duration * fps)
+ step_size = vid_fps / fps
+ indices = [int(i * step_size) for i in range(required_frames)]
+ frame_paths = self.frame_paths_fps(uid, len(indices))
+
+ valid_paths = []
+ valid_indices = []
+ lock_path = osp.splitext(vid_path)[0] + '.lock'
+ with portalocker.Lock(lock_path, 'w', timeout=30):
+ if not np.all([osp.exists(p) for p in frame_paths]):
+ images = [vid[i].asnumpy() for i in indices]
+ for i, (img_array, path) in enumerate(zip(images, frame_paths)):
+ if osp.exists(path):
+ try:
+ with Image.open(path) as img:
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+ else:
+ try:
+ img = Image.fromarray(img_array)
+ img.save(path)
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+ else:
+ for i, path in enumerate(frame_paths):
+ try:
+ with Image.open(path) as img:
+ img.verify()
+ valid_paths.append(path)
+ valid_indices.append(indices[i])
+ except Exception:
+ continue
+
+ return valid_paths, valid_indices, vid_fps
+
+ def evaluate(self, eval_file, **judge_kwargs):
+
+ from .utils.cgbench import get_dimention_rating_open_ended, post_process_open
+
+ assert get_file_extension(eval_file) in ['xlsx', 'json', 'tsv'], "data file should be a supported format"
+
+ tgt_file = get_intermediate_file_path(eval_file, '_rating', 'json')
+ score_file = get_intermediate_file_path(eval_file, '_score')
+ step_1_tmp_file = get_intermediate_file_path(eval_file, '_step_1', 'pkl')
+ step_2_tmp_file = get_intermediate_file_path(eval_file, '_step_2', 'pkl')
+
+ data = load(eval_file)
+
+ data_pred_no_na = data[~pd.isna(data["prediction"])]
+ data_pred_na = data[pd.isna(data["prediction"])]
+
+ data_pred_na["model_result"] = -1
+ data_pred_na["step_1_result"] = -1
+ data_pred_na["step_2_result"] = -1
+ data_pred_na["score"] = -1
+
+ data_pred_no_na["model_result"] = data_pred_no_na.apply(
+ lambda row: post_process_open(
+ response=row["prediction"],
+ ),
+ axis=1,
+ )
+
+ if judge_kwargs.get("model", None) != "gpt-4o-0806":
+ judge_kwargs["model"] = "gpt-4o-0806"
+ print("The judge model in cg-bench is gpt-4o-0806!")
+
+ data_no_model_result = data_pred_no_na[data_pred_no_na["model_result"] == -1]
+ data_step_1 = data_pred_no_na[data_pred_no_na["model_result"] != -1]
+
+ model_step_1 = build_judge(system_prompt=sys_prompt_open_eval_step_1, **judge_kwargs)
+ nproc = judge_kwargs.pop('nproc', 32)
+
+ lines_step_1 = data_step_1.to_dict("records")
+ tups_step_1 = [(model_step_1, line) for line in lines_step_1]
+
+ keys_step_1 = {line["qid"] for line in lines_step_1}
+
+ ans = {}
+ if osp.exists(step_1_tmp_file):
+ ans = load(step_1_tmp_file)
+ tups_step_1 = [x for x, i in zip(tups_step_1, keys_step_1) if i not in ans]
+ keys_step_1 = [i for i in keys_step_1 if i not in ans]
+
+ _ = track_progress_rich(
+ eval_open_first,
+ tups_step_1,
+ nproc=nproc,
+ keys=keys_step_1,
+ save=step_1_tmp_file,
+ )
+
+ step_1_results = load(step_1_tmp_file)
+ data_step_1 = save_step_1_steps(data_step_1, step_1_results) # -1, 0, 1, 2
+
+ data_no_step_1_results = data_step_1[data_step_1["step_1_result"] == -1]
+ data_step_1_over = data_step_1[data_step_1["step_1_result"].isin([0, 1])]
+ data_step_2 = data_step_1[data_step_1["step_1_result"] == 2]
+
+ model_step_2 = build_judge(system_prompt=sys_prompt_open_eval_step_2, **judge_kwargs)
+
+ lines_step_2 = data_step_2.to_dict("records")
+
+ tups_step_2 = []
+
+ for line in tqdm(lines_step_2):
+ clue_intervals = eval(line["clue_intervals"])
+ lmu_root = LMUDataRoot()
+ clue_frame_root = osp.join(lmu_root, "clue_images", self.dataset)
+ data_root = self.data_root
+ frame_paths, _, _ = save_clue_video_frames(
+ data_root,
+ clue_frame_root,
+ video=line["video"],
+ uid=line["qid"],
+ clue_intervals=clue_intervals,
+ num_frames=32,
+ )
+ tups_step_2.append((model_step_2, line, frame_paths))
+
+ keys_step_2 = {line["qid"] for line in lines_step_2}
+
+ ans = {}
+ if osp.exists(step_2_tmp_file):
+ ans = load(step_2_tmp_file)
+ tups_step_2 = [x for x, i in zip(tups_step_2, keys_step_2) if i not in ans]
+ keys_step_2 = [i for i in keys_step_2 if i not in ans]
+
+ _ = track_progress_rich(
+ eval_open_second,
+ tups_step_2,
+ nproc=nproc,
+ keys=keys_step_2,
+ save=step_2_tmp_file,
+ )
+
+ step_2_results = load(step_2_tmp_file)
+ data_step_2 = save_step_2_steps(data_step_2, step_2_results)
+
+ data_no_step_2_results = data_step_2[data_step_2["score"] == -1]
+ data_step_2_over = data_step_2[data_step_2["score"].isin([0, 1])]
+
+ data = pd.concat(
+ [
+ data_pred_na,
+ data_no_model_result,
+ data_no_step_1_results,
+ data_step_1_over,
+ data_no_step_2_results,
+ data_step_2_over,
+ ]
+ )
+
+ dump(data, score_file)
+
+ rating = get_dimention_rating_open_ended(score_file)
+
+ dump(rating, tgt_file)
+
+ return rating
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartbench.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartbench.py
new file mode 100644
index 00000000..d3ba9480
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartbench.py
@@ -0,0 +1,309 @@
+import copy
+import os
+import re
+from typing import Optional
+
+import pandas as pd
+
+from ..smp import dump, get_intermediate_file_path, load
+from ..utils import track_progress_rich
+from .image_base import ImageBaseDataset
+from .utils import build_judge
+
+metric_group = {
+ 'box': ['box_h', 'box_v', 'stock'],
+ 'combination': ['bar_line', 'line_line', 'pie_bar', 'pie_pie'],
+ 'pie': ['sector', 'ring_wo_anno', 'pie', 'ring', 'InteSun'],
+ 'scatter': ['scatter_2d', 'scatter_2d_smooth', 'scatter_3d'],
+ 'line': ['line_err', 'line_multi_wi_anno', 'line_multi',
+ 'line_single_wi_anno', 'line_single'],
+ 'bar': ['horizontal_single', 'vertical_single',
+ 'horizontal_single_wi_anno', 'vertical_single_wi_anno',
+ 'vertical_percent_stacked', 'horizontal_multi',
+ 'vertical_multi', 'threeD_stacked', 'vertical_stacked',
+ 'horizontal_stacked', 'threeD_bar_multi',
+ 'horizontal_percent_stacked', 'threeD_percent_stacked'],
+ 'radar': ['radar_single_wi_anno', 'radar_single',
+ 'radar_multi_fill', 'radar_multi'],
+ 'area': ['area', 'area_stack', 'area_percent'],
+ 'node': ['node_link', 'node_link_dir', 'node_link_undir'],
+}
+
+metric_anno = {
+ 'wi_anno': ["horizontal_single_wi_anno", "vertical_single_wi_anno",
+ "pie_pie", "pie_bar",
+ "radar_single_wi_anno", "node_link_dir",
+ "node_link_undir", "ring_wi_anno",
+ "line_multi_wi_anno", "line_single_wi_anno"],
+ 'wo_anno': ["horizontal_single", "vertical_single",
+ "bar_line", "line_line", "radar_single",
+ "ring", "line_multi", "line_single"]
+}
+
+
+def relaxed_correctness(target: str, prediction: str, max_relative_change: float = 0.05) -> bool:
+ def _prediction_to_float(text: str) -> Optional[float]:
+ try:
+ if text.endswith('%'):
+ return float(text.rstrip('%'))
+ else:
+ return float(text)
+ except ValueError:
+ return None
+
+ def _target_to_float(text: str):
+ try:
+ if text.endswith('%'):
+ return [float(text.rstrip('%')), float(text.rstrip('%')) / 100.0]
+ else:
+ return [float(text)]
+ except ValueError:
+ return None
+
+ prediction_float = _prediction_to_float(prediction)
+ target_float = _target_to_float(target)
+ if prediction_float is not None and target_float is not None:
+ flag = False
+ for t in target_float:
+ if t == 0:
+ relative_change = prediction_float
+ else:
+ relative_change = abs(prediction_float - t) / abs(t)
+ flag = flag or relative_change <= max_relative_change
+ return flag
+ else:
+ return prediction.lower() == target.lower()
+
+
+def fuzzy_match(sentence):
+ sentence = str(sentence).lower()
+ contains_yes = re.search(r'\byes\b', sentence) is not None
+ if not contains_yes:
+ contains_yes = 'yes' in sentence
+ return contains_yes, not contains_yes
+
+
+def accuracy_plus(ans1, ans2):
+ isYes, _ = fuzzy_match(ans1)
+ _, isNo = fuzzy_match(ans2)
+ return isYes and isNo
+
+
+def confuse_rate(ans1, ans2):
+ ar_yes, ar_no = fuzzy_match(ans1)
+ aw_yes, aw_no = fuzzy_match(ans2)
+ return (ar_yes and aw_yes) or (ar_no and aw_no)
+
+
+def accuracy_vanilla(ans1, ans2):
+ isYes, _ = fuzzy_match(ans1)
+ _, isNo = fuzzy_match(ans2)
+ return [isYes, isNo]
+
+
+def ChartBench_auxeval(tup):
+ model, line = tup
+ pred = str(line.get('prediction', ''))
+ qa_type = line.get('QA_type', '')
+ if qa_type in ['GPT-acc', 'NQA']:
+ prompt = (
+ "Please extract the final numerical answer from the following text. "
+ "Do not output anything else but the number or percentage strictly without extra English words.\n"
+ f"Text: {pred}\nOutput:"
+ )
+ res = model.generate(prompt)
+ return {'gpt_filter': res}
+ return {'gpt_filter': pred}
+
+
+class ChartBench(ImageBaseDataset):
+ DATASET_URL = {
+ 'ChartBench': 'https://huggingface.co/datasets/Jinsong-Li/VLMEvalKitData/resolve/main/ChartBench.tsv'
+ }
+ DATASET_MD5 = {'ChartBench': 'a1f72798819a740a91825acbf0dec68a'}
+
+ def __init__(self, dataset='ChartBench', **kwargs):
+ super().__init__(dataset=dataset, **kwargs)
+
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+ msgs = super().build_prompt(line)
+ return msgs
+
+ @classmethod
+ def evaluate(cls, eval_file, **judge_kwargs):
+ print("Evaluating ChartBench results...")
+ data = load(eval_file)
+
+ data['prediction'] = data['prediction'].astype(str)
+ data['answer'] = data['answer'].astype(str)
+ data['index'] = data['index'].astype(str)
+
+ try:
+ model = build_judge(max_tokens=128, **judge_kwargs)
+ judge_working = model.working()
+ except BaseException:
+ judge_working = False
+
+ if judge_working:
+ print("Running GPT-based evaluation for Numerical QA tasks...")
+ target_indices = data[data['QA_type'].isin(
+ ['GPT-acc', 'NQA'])].index.tolist()
+
+ tmp_gpt = get_intermediate_file_path(eval_file, '_gpt_eval', 'pkl')
+ ans = {}
+ if os.path.exists(tmp_gpt):
+ ans = load(tmp_gpt)
+
+ pending_indices = [i for i in target_indices if i not in ans]
+ if pending_indices:
+ tups = [(model, data.iloc[i]) for i in pending_indices]
+ track_progress_rich(
+ ChartBench_auxeval,
+ tups,
+ nproc=judge_kwargs.get('nproc', 4),
+ chunksize=judge_kwargs.get('nproc', 4),
+ keys=pending_indices,
+ save=tmp_gpt
+ )
+ ans = load(tmp_gpt)
+
+ gpt_filters = []
+ for i in range(len(data)):
+ if i in ans:
+ gpt_filters.append(ans[i].get('gpt_filter', data.iloc[i]['prediction']))
+ else:
+ gpt_filters.append(data.iloc[i]['prediction'])
+ data['prediction_gpt'] = gpt_filters
+ else:
+ print("Warning: Judge model not working or not configured. Falling back to direct prediction matching.")
+ data['prediction_gpt'] = data['prediction']
+
+ data['base_id'] = data['index'].apply(lambda x: x.split('_')[0] if '_' in x else x)
+ data['qa_index'] = data['index'].apply(
+ lambda x: int(x.split('_')[1]) if '_' in x and x.split('_')[1].isdigit() else 0
+ )
+
+ groups = data.groupby('base_id')
+
+ metric_record_acc = {
+ "all": [], "regular": [], "extra": [], "CR": [], "VE": [], "VC": [], "GC": [],
+ "line": [], "bar": [], "pie": [], "area": [], "box": [], "radar": [], "scatter": [],
+ "node": [], "combination": [], "wi_anno": [], "wo_anno": [], "wi_CR": [], "wo_CR": [],
+ "wi_VE": [], "wo_VE": [], "wi_VC": [], "wo_VC": [], "wi_GC": [], "wo_GC": []
+ }
+
+ metric_record_nqa = {
+ "all": [], "regular": [], "extra": [],
+ "line": [], "bar": [], "pie": [], "area": [], "box": [], "radar": [], "scatter": [],
+ "node": [], "combination": [], "wi_anno": [], "wo_anno": [],
+ }
+
+ metrics = {
+ 'accp': copy.deepcopy(metric_record_acc),
+ 'cor': copy.deepcopy(metric_record_acc),
+ 'acc': copy.deepcopy(metric_record_acc),
+ 'err': copy.deepcopy(metric_record_acc),
+ 'nqa': copy.deepcopy(metric_record_nqa),
+ }
+
+ def update_yes_no(key, accp, cor, acc, err):
+ if key in metrics['accp']:
+ metrics['accp'][key].append(accp)
+ metrics['cor'][key].append(cor)
+ metrics['acc'][key].extend(acc)
+ metrics['err'][key].append(err)
+
+ def update_nqa(key, nqa):
+ if key in metrics['nqa']:
+ metrics['nqa'][key].append(nqa)
+
+ def format_percent_metric(item):
+ if len(item) == 0:
+ return 0
+ return sum(item) / len(item) * 100
+
+ for base_id, group in groups:
+ group = group.sort_values(by='qa_index')
+ if len(group) == 0:
+ continue
+
+ first_row = group.iloc[0]
+ chart_type = first_row.get('chart_type', '')
+ task_type = first_row.get('task', '')
+ qa_type = first_row.get('QA_type', '')
+
+ if qa_type == 'Acc+':
+ if len(group) < 2:
+ accp, cor, acc, err = False, False, [False, False], True
+ else:
+ ans1 = group.iloc[0]['prediction']
+ ans2 = group.iloc[1]['prediction']
+
+ accp = accuracy_plus(ans1, ans2)
+ cor = confuse_rate(ans1, ans2)
+ acc = accuracy_vanilla(ans1, ans2)
+ err = not accp and not cor
+
+ update_yes_no('all', accp, cor, acc, err)
+ for group_key, group_values in metric_group.items():
+ if chart_type in group_values:
+ metric_category = 'regular' if group_key in {'line', 'bar', 'pie'} else 'extra'
+ update_yes_no(group_key, accp, cor, acc, err)
+ update_yes_no(metric_category, accp, cor, acc, err)
+
+ anno_key = 'wi_anno' if chart_type in metric_anno['wi_anno'] else 'wo_anno'
+ update_yes_no(anno_key, accp, cor, acc, err)
+
+ if task_type:
+ update_yes_no(task_type, accp, cor, acc, err)
+ task_anno_key = f'wi_{task_type}' if chart_type in metric_anno['wi_anno'] else f'wo_{task_type}'
+ update_yes_no(task_anno_key, accp, cor, acc, err)
+
+ elif qa_type == 'GPT-acc' or qa_type == 'NQA':
+ pred = str(first_row.get('prediction_gpt', first_row.get('prediction', '')))
+ ans_str = str(first_row.get('answer', ''))
+
+ import ast
+ ans_list = [ans_str]
+ if ans_str.startswith('['):
+ try:
+ ans_list = ast.literal_eval(ans_str)
+ except (ValueError, SyntaxError):
+ pass
+
+ nqa = False
+ for ann in ans_list:
+ if relaxed_correctness(pred.strip().strip('<\uff5cend\u2581of\u2581sentence\uff5c>'), str(ann)):
+ nqa = True
+ break
+
+ update_nqa('all', nqa)
+ for group_key, group_values in metric_group.items():
+ if chart_type in group_values:
+ metric_category = 'regular' if group_key in {'line', 'bar', 'pie'} else 'extra'
+ update_nqa(group_key, nqa)
+ update_nqa(metric_category, nqa)
+
+ anno_key = 'wi_anno' if chart_type in metric_anno['wi_anno'] else 'wo_anno'
+ update_nqa(anno_key, nqa)
+
+ merged_metric = copy.deepcopy(metric_record_nqa)
+ for key in metric_record_nqa.keys():
+ if key in metric_record_acc:
+ merged_metric[key] = metrics['nqa'][key] + metrics['accp'][key]
+ metrics['final'] = merged_metric
+
+ ans_stat = {key: {k: format_percent_metric(v) for k, v in metrics[key].items()} for key in metrics}
+
+ pd.DataFrame(ans_stat)
+
+ row_dict = {}
+ for k, v in ans_stat['final'].items():
+ row_dict[k] = v
+
+ ret = pd.DataFrame([row_dict]).round(2)
+
+ dump(ret, get_intermediate_file_path(eval_file, '_acc'))
+ return ret
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartcap.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartcap.py
new file mode 100644
index 00000000..95ad440f
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartcap.py
@@ -0,0 +1,182 @@
+import warnings
+
+import pandas as pd
+from datasets import load_dataset
+from tqdm import tqdm
+
+from vlmeval.smp import dump, load
+from vlmeval.smp.vlm import encode_image_to_base64
+from .image_base import ImageBaseDataset
+
+
+def prepare_chartcap():
+ print("Loading ChartCap dataset from HuggingFace...")
+ # Load the dataset
+ ds = load_dataset("junyoung-00/ChartCap", split="test")
+
+ # We will collect the data rows here
+ data_rows = []
+
+ # Default question as per requirements
+ default_question = 'Please provide a detailed caption for the chart.'
+
+ print("Processing samples...")
+ for idx, sample in enumerate(tqdm(ds)):
+ # Extract fields
+ # sample in HuggingFace dataset usually acts like a dict
+
+ # Prepare the base64 image
+ if 'image' in sample:
+ img = sample['image']
+ img_b64 = encode_image_to_base64(img)
+ else:
+ print(f"Warning: No image found for sample {idx}, skipping.")
+ continue
+
+ # Prepare answer (ground truth caption)
+ # The dataset likely has a caption field. Let's inspect the keys if we can't be sure
+ # But for now I'll assume standard naming or check `sample` content dynamically if needed.
+ # Based on typical HF datasets, it might be 'caption' or 'text'.
+ # However, looking at junyoung-00/ChartCap on HF (hypothetically provided link),
+ # usually it has 'image' and 'caption'.
+ # If 'label' or 'ground_truth' exists, use that.
+ # I will dump all original keys as requested.
+
+ # Checking common keys for caption in such datasets
+ answer = sample.get('caption', sample.get('text', ''))
+
+ row = {
+ 'index': idx,
+ 'image': img_b64,
+ 'question': default_question,
+ 'answer': answer
+ }
+
+ # Add all original data information
+ for k, v in sample.items():
+ if k not in row and k != 'image': # Don't duplicate image or overwrite our fields if they match
+ row[k] = v
+
+ data_rows.append(row)
+
+ print(f"Creating TSV with {len(data_rows)} samples...")
+ df = pd.DataFrame(data_rows)
+
+ # Ensure columns order: index, image, question, answer, ... others
+ cols = ['index', 'image', 'question', 'answer']
+ remaining_cols = [c for c in df.columns if c not in cols]
+ df = df[cols + remaining_cols]
+
+ output_file = 'ChartCap.tsv'
+ df.to_csv(output_file, sep='\t', index=False)
+ print(f"Saved to {output_file}")
+
+
+class ChartCapDataset(ImageBaseDataset):
+
+ TYPE = 'Caption'
+ DATASET_URL = {
+ 'ChartCap': 'https://huggingface.co/datasets/alfassy/chart_cap_vlmevalkit/blob/main/ChartCap.tsv',
+ }
+
+ DATASET_MD5 = {
+ 'ChartCap': '10a0292079120f748eae81af5a1e19da',
+ }
+
+ @classmethod
+ def supported_datasets(cls):
+ return list(cls.DATASET_URL)
+
+ def evaluate(self, eval_file, **judge_kwargs):
+ import evaluate
+ from bert_score import BERTScorer
+ from sentence_transformers import SentenceTransformer # noqa: F401
+
+ from .utils.megabench.scoring.sacrebleu_bleu import Bleu
+
+ data = load(eval_file)
+
+ # Prepare predictions and references
+ # Ensure data is sorted or aligned by index if needed, but usually index is preserved
+ # data usually contains 'prediction' and 'answer' columns
+
+ predictions = [str(x) for x in data['prediction']]
+ references = [str(x) for x in data['answer']]
+
+ # Metric 1: sacreBLEU
+ # Reference implementation uses corpus_bleu
+ # vlmeval wrapper: Bleu.match(pred, gt) returns single score or expected aggregated?
+ # Looking at Bleu.match source: it calls corpus_bleu(corr, [resp]).score / 100
+ # It seems to be designed for single sample or list of list?
+ # Let's use sacrebleu directly for corpus level or the wrapper if it supports batch properly.
+ # wrapper `Bleu.match(response, correct_answer)` handles lists.
+ # But it returns scalar 0-1 (divided by 100).
+
+ bleu_score = Bleu.match(predictions, references) * 100
+
+ # Metric 2: ROUGE-L through 'evaluate'
+ try:
+ rouge = evaluate.load("rouge")
+ rouge_results = rouge.compute(predictions=predictions, references=references)
+ rouge_l = rouge_results['rougeL']
+ except Exception as e:
+ warnings.warn(f"Failed to compute ROUGE: {e}")
+ rouge_l = 0.0
+
+ # Metric 3: METEOR through 'evaluate'
+ try:
+ meteor = evaluate.load("meteor")
+ meteor_results = meteor.compute(predictions=predictions, references=references)
+ meteor_score = meteor_results['meteor']
+ except Exception as e:
+ warnings.warn(f"Failed to compute METEOR: {e}")
+ meteor_score = 0.0
+
+ # Metric 4: BERTScore
+ # Refer to uni_svg.py usage
+ try:
+ # Device handling
+ import torch
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Using roberta-large-mnli or similar is common, but let's stick to default or what uni_svg uses
+ # uni_svg uses "en" (which defaults to roberta-large presumably)
+ bert_scorer = BERTScorer(lang="en", rescale_with_baseline=False, device=device)
+ P, R, F1 = bert_scorer.score(predictions, references)
+ bert_score_val = F1.mean().item()
+ except Exception as e:
+ warnings.warn(f"Failed to compute BERTScore: {e}")
+ bert_score_val = 0.0
+
+ results = {
+ 'BLEU_4': bleu_score,
+ 'ROUGE_L': rouge_l,
+ 'METEOR': meteor_score,
+ 'BERTScore': bert_score_val
+ }
+
+ # Format as requested: dictionary composed of lists, organized into a pandas.DataFrame
+ # or just a dictionary of scores?
+ # Base class says: "The return value of the function is the calculated accuracy and other metrics,
+ # formatted as a dictionary composed of lists, organized into a pandas.DataFrame."
+ # However, for captioning, we usually return a single scalar per metric for the whole dataset.
+ # But typical evaluate returns a dict (often converted to DF later).
+ # Let's make it a DF with one row or just the dict if allowed.
+ # Looking at ImageCaptionDataset, it returns a dict of scores.
+ # Let's return the dict, usually frame_eval handles it.
+
+ print("Evaluation Results for ChartCap:")
+ for k, v in results.items():
+ print(f"{k}: {v:.4f}")
+
+ # Create CSV output file
+ import pandas as pd
+ result_df = pd.DataFrame([results])
+ result_file = eval_file.replace(f".{eval_file.split('.')[-1]}", "_acc.csv")
+ dump(result_df, result_file)
+
+ return results
+
+
+if __name__ == '__main__':
+ prepare_chartcap()
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartmimic.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartmimic.py
new file mode 100644
index 00000000..f0242312
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartmimic.py
@@ -0,0 +1,817 @@
+# flake8: noqa
+import base64
+import json
+import mimetypes
+import os
+import os.path as osp
+import re
+import shutil
+import subprocess
+import sys
+import warnings
+
+from timeout_decorator import timeout
+
+from vlmeval.smp import (LMUDataRoot, download_file, dump, file_size, get_intermediate_file_path,
+ get_logger, load, md5)
+
+FAIL_MSG = "Failed to obtain answer via API."
+
+logger = get_logger(__name__)
+
+# SET VLMEVAL_CHARTMIMIC_UTILS_PATH for chartmimic evaluator
+# ".../VLMEvalKit/vlmeval..."
+cur_path = os.path.abspath(__file__)
+util_path = cur_path.replace("dataset/chartmimic.py", "dataset/utils/chartmimic")
+os.environ["VLMEVAL_CHARTMIMIC_UTILS_PATH"] = util_path
+if os.environ["VLMEVAL_CHARTMIMIC_UTILS_PATH"] not in sys.path:
+ sys.path.insert(0, os.environ["VLMEVAL_CHARTMIMIC_UTILS_PATH"])
+
+from ..dataset.utils.chartmimic.evaluator.chart_type_evaluator import ChartTypeEvaluator
+from ..dataset.utils.chartmimic.evaluator.color_evaluator import ColorEvaluator
+from ..dataset.utils.chartmimic.evaluator.layout_evaluator import LayoutEvaluator
+# from ..utils import track_progress_rich
+from ..dataset.utils.chartmimic.evaluator.text_evaluator import TextEvaluator
+from ..dataset.utils.chartmimic.mp_util import track_progress_rich_new
+from .image_base import ImageBaseDataset
+from .utils import DEBUG_MESSAGE, build_judge
+
+# from ..dataset.utils.chartmimic.evaluator.legend_evaluator import LegendEvaluator
+# from ..dataset.utils.chartmimic.evaluator.grid_evaluator import GridEvaluator
+
+judge_model = None
+save_code_dir = None
+sub_set_name = None
+cur_work_dir = None
+pdf_tmp_dir = None
+# save_dir_name_map = {
+# "Direct Mimic": "direct",
+# "Customized Mimic": "customized",
+# }
+
+high_level_eval_prompt = {
+ "instruction": "You are an excellent judge at evaluating visualization chart plots. The first image (reference image) is created using ground truth matplotlib code, and the second image (AI-generated image) is created using matplotlib code generated by an AI assistant. Your task is to score how well the AI-generated plot matches the ground truth plot.\n\n### Scoring Methodology:\nThe AI-generated image's score is based on the following criteria, totaling a score out of 100 points:\n\n1. **Chart Types (20 points)** Does the AI-generated image include all chart types present in the reference image (e.g., line charts, bar charts, etc.)?\n2. **Layout (10 points)** Does the arrangement of subplots in the AI-generated image match the reference image (e.g., number of rows and columns)?\n3. **Text Content (20 points)** Does the AI-generated image include all text from the reference image (e.g., titles, annotations, axis labels), excluding axis tick labels?\n4. **Data (20 points)** How accurately do the data trends in the AI-generated image resemble those in the original image and is the number of data groups the same as in the reference image?\n5. **Style (20 points)** Does the AI-generated image match the original in terms of colors (line colors, fill colors, etc.), marker types (point shapes, line styles, etc.), legends, grids, and other stylistic details?\n6. **Clarity (10 points)** Is the AI-generated image clear and free of overlapping elements?\n\n### Evaluation:\nCompare the two images head to head and provide a detailed assessment. Use the following format for your response:\n\n\n---\n\nComments:\n- Chart Types: ${your comment and subscore}\n- Layout: ${your comment and subscore}\n- Text Content: ${your comment and subscore}\n- Data: ${your comment and subscore}\n- Style: ${your comment and subscore}\n- Clarity: ${your comment and subscore}\n\nScore: ${your final score out of 100}\n\n---\n\nPlease use the above format to ensure the evaluation is clear and comprehensive.\n",
+ "system_msg": "",
+}
+
+
+def image_path_to_data_uri(image_path):
+ mime, _ = mimetypes.guess_type(image_path)
+ if not mime:
+ raise ValueError(f"Cannot determine MIME type for {image_path}")
+ with open(image_path, "rb") as f:
+ encoded = base64.b64encode(f.read()).decode("utf-8")
+ return f"data:{mime};base64,{encoded}"
+
+
+def run_once_with_images(pt, image_abs_path_list, retry=4):
+ global judge_model
+ # prefix = "data:image/jpeg;base64,"
+ # img = prefix + image
+ messages = [
+ *[
+ dict(type="image", value=image_path_to_data_uri(image_abs_path))
+ for image_abs_path in image_abs_path_list
+ ],
+ dict(type="text", value=pt),
+ ]
+ ans = None
+ while retry:
+ try:
+ ans = judge_model.generate(messages)
+ return ans
+ except Exception as e:
+ logger.exception(f"Error in run_once_with_images: {e}")
+ retry -= 1
+ return ans
+
+
+# def run_once_without_image(pt, retry=3):
+# global judge_model
+# messages = [
+# dict(type="text", value=pt),
+# ]
+# while retry:
+# try:
+# ans = judge_model.generate(messages)
+# return ans
+# except Exception as e:
+# logger.info(f"Error in run_once_without_image: {e}")
+# retry -= 1
+# return ans
+
+
+# >>> util: extract python code from markdown text <<<
+def extract_python_code(text):
+ """Extract python code from markdown text."""
+ code_matches = re.findall(r"```python(.*?)```", text, re.DOTALL)
+ if not code_matches:
+ return "" # Return an empty string if no code block is found
+ return code_matches[0] # Return the first match
+
+
+# >>> util: extract data code from
+def get_variable_code(edit_ori_file):
+ with open(edit_ori_file, "r") as f:
+ code = f.read()
+ pattern = re.compile(
+ r"# ===================\n# Part 2: Data Preparation\n# ===================\n(.*?)# ===================\n# Part 3: Plot Configuration and Rendering\n# ===================",
+ re.DOTALL,
+ )
+ match = pattern.search(code)
+
+ if match:
+ extracted_text = match.group(1)
+ extracted_text = extracted_text.strip()
+ extracted_text = (
+ "#Variable Code Block\nimport warnings;warnings.filterwarnings('ignore', category=UserWarning);warnings.filterwarnings('ignore', category=FutureWarning);import matplotlib.pyplot as plt;import pandas as pd;import numpy as np;np.random.seed(0);import math;from matplotlib_venn import venn2;from matplotlib import cm;from scipy.stats import gaussian_kde;import networkx as nx;from matplotlib.gridspec import GridSpec;from scipy.stats import multivariate_normal;import colorsys;import matplotlib.colors as mcolors;from matplotlib.colors import LogNorm;from scipy.stats import norm;import matplotlib.gridspec as gridspec;import seaborn as sns\n"
+ + extracted_text
+ )
+ else:
+ print(edit_ori_file)
+ raise ValueError("No match found")
+ return extracted_text
+
+
+# >>> util: clean escape characters in code string <<<
+def clean_escape_chars(code: str) -> str:
+ """
+ Clean escape characters in code string to ensure proper execution.
+ Handles common escape sequences and ensures proper string formatting.
+
+ Args:
+ code (str): The code string to clean
+
+ Returns:
+ str: Cleaned code string
+ """
+ # Common escape sequences to handle
+ escape_map = {
+ r"\\n": "\n", # Newline
+ r"\\r": "\r", # Carriage return
+ r"\\t": "\t", # Tab
+ r'\\"': '"', # Double quote
+ r"\\'": "'", # Single quote
+ r"\\\\": "\\", # Backslash
+ r"\\b": "\b", # Backspace
+ r"\\f": "\f", # Form feed
+ r"\\v": "\v", # Vertical tab
+ }
+
+ # Replace escape sequences
+ for escaped, unescaped in escape_map.items():
+ code = code.replace(escaped, unescaped)
+
+ return code
+
+
+def _convert_single_page_pdf_to_png(pdf_path, output_path, dpi=350):
+ from pdf2image import convert_from_path
+
+ try:
+ images = convert_from_path(pdf_path, dpi=dpi)
+ images[0].save(output_path, "PNG")
+ except Exception as e:
+ logger.info(f"Error in converting pdf to image: {e}")
+ return False
+ return True
+
+
+def extract_gpt_score(resp):
+ # First match: standard or markdown-styled "Score: 91/100", "Score: **91/100**", etc
+ pattern = r"^\s*Score:\s*[*_~`]*\**\s*(\d+)\s*/\s*100\s*[*_~`]*\**"
+ m = re.search(pattern, resp, re.IGNORECASE | re.MULTILINE)
+ if m:
+ return int(m.group(1))
+
+ # Fallback match: match "Score: 91", "Score: **91**", "Score: *91*", etc
+ fallback_pattern = r"Score:\s*[*_~`]*\**\s*(\d+)\s*[*_~`]*\**"
+ matches = list(re.finditer(fallback_pattern, resp, re.IGNORECASE))
+ if matches:
+ return int(matches[-1].group(1))
+
+ return 0
+
+
+def judge_one_item(item):
+ try:
+ return _judge_one_item(item)
+ except Exception as e:
+ logger.warning(f'Failed to judge ChartMimic item because {repr(e)}:\n{item}')
+ zero_score_dict = {
+ "low_level": {
+ "original_py_file": None,
+ "generated_py_file": None,
+ "text_metrics": {"precision": 0, "recall": 0, "f1": 0},
+ "chart_type_metrics": {"precision": 0, "recall": 0, "f1": 0},
+ "layout_metrics": {"precision": 0, "recall": 0, "f1": 0},
+ "color_metrics": {"precision": 0, "recall": 0.0, "f1": 0},
+ },
+ "high_level": {
+ "resp": None,
+ "msg": None,
+ "score": 0.0,
+ },
+ }
+ return 0, zero_score_dict
+
+
+@timeout(600, use_signals=False)
+def _judge_one_item(item):
+ score_dict = {}
+ zero_score_dict = {
+ "low_level": {
+ "original_py_file": None,
+ "generated_py_file": None,
+ "text_metrics": {"precision": 0, "recall": 0, "f1": 0},
+ "chart_type_metrics": {"precision": 0, "recall": 0, "f1": 0},
+ "layout_metrics": {"precision": 0, "recall": 0, "f1": 0},
+ "color_metrics": {"precision": 0, "recall": 0.0, "f1": 0},
+ },
+ "high_level": {
+ "resp": None,
+ "msg": None,
+ "score": 0.0,
+ },
+ }
+
+ global judge_model, save_code_dir, sub_set_name
+ item = json.loads(item)
+ # >>> 1. Run Code to Generate PY and PDF <<<
+ # extract python code from item["prediction"]
+ code = extract_python_code(item["prediction"])
+ # clean code string: \\n -> \n...
+ code = clean_escape_chars(code)
+ # len(code) == 0 means no code generated or not format in ```python...```
+ if len(code) == 0:
+ logger.info(
+ f"index: {item['index']}, no code extracted from prediction, return 0, zero_score_dict: {zero_score_dict}"
+ )
+ return 0, zero_score_dict
+
+ # add data code to the beginning of code
+ if "customized" in item["task"].lower():
+ # extract data code from original file
+ ground_truth_figure_code_file_rel = item["ground_truth_figure_code"]
+ ROOT = LMUDataRoot()
+ img_root = os.path.join(ROOT, "images", "ChartMimic")
+ ground_truth_figure_code_file = os.path.join(
+ img_root, ground_truth_figure_code_file_rel
+ )
+ data_code = get_variable_code(ground_truth_figure_code_file)
+ # add data code to the beginning of code
+ code = data_code + "\n" + code
+
+ # save code to py and run to generate pdf
+ # logger.info(f"save_code_dir: {save_code_dir}")
+ if "direct" in item["task"].lower():
+ save_dir_name = "direct"
+ elif "customized" in item["task"].lower():
+ save_dir_name = "customized"
+ else:
+ raise ValueError(f"Invalid task: {item['task']}")
+
+ # save code to py and run to generate pdf
+ output_py = (
+ f"{save_code_dir}/ChartMimic/{sub_set_name}/{save_dir_name}/{item['index']}.py"
+ )
+ os.makedirs(os.path.dirname(output_py), exist_ok=True)
+ # clean & add self redefined path
+ code = re.sub(r"plt\.savefig\(.*\n*", "", code, flags=re.S)
+ code = re.sub(r"plt.show\(.*\n*", "", code, flags=re.S)
+ code = code.strip() + '\nplt.savefig("{}")'.format(output_py.replace(".py", ".pdf"))
+ with open(output_py, "w") as f:
+ f.write(code)
+ # [Attention] run code with timeout, enhancement here
+ # try generate pdf
+ try:
+ subprocess.run(
+ ["python", output_py],
+ timeout=120,
+ capture_output=True,
+ text=True,
+ )
+ logger.info(f"Successfully ran {output_py}")
+ except subprocess.TimeoutExpired:
+ logger.info(f"Timeout: Script {output_py} ran too long.")
+ except Exception as e:
+ # maybe could directly return 0, zero_score_dict
+ logger.info(f"Error when running {output_py}: {e}")
+
+ # check if pdf exists
+ if not os.path.exists(output_py.replace(".py", ".pdf")):
+ zero_score_dict["high_level"]["original_py_file"] = output_py
+ logger.info(
+ f"index: {item['index']}, run code failed, pdf does not exist, return 0, zero_score_dict: {zero_score_dict}"
+ )
+ return 0, zero_score_dict
+
+ # try generate image (converted from pdf)
+ if os.path.exists(output_py.replace(".py", ".pdf")):
+ # if error when converting pdf to image, maybe could directly return 0, zero_score_dict
+ _convert_single_page_pdf_to_png(
+ output_py.replace(".py", ".pdf"), output_py.replace(".py", ".png")
+ )
+ # logger.info(f"converted pdf to image: {output_py.replace('.py', '.png')}")
+ # breakpoint()
+
+ # --- Got py and its pdf ---
+ # >>> 2. Low Level Evaluation <<<
+ text_evaluator = TextEvaluator(use_position=False, use_axs=False)
+ chart_type_evaluator = ChartTypeEvaluator()
+ color_evaluator = ColorEvaluator()
+ layout_evaluator = LayoutEvaluator()
+ # unused
+ # legend_evaluator = LegendEvaluator(use_position=True)
+ # grid_evaluator = GridEvaluator()
+
+ ground_truth_figure_code_file_rel = item["ground_truth_figure_code"]
+ ROOT = LMUDataRoot()
+ img_root = os.path.join(ROOT, "images", "ChartMimic")
+ ground_truth_figure_code_file = os.path.join(
+ img_root, ground_truth_figure_code_file_rel
+ )
+ original_py_file = ground_truth_figure_code_file
+ generated_py_file = output_py
+
+ # logger.info(f"original_py_file: {original_py_file}")
+ # logger.info(f"generated_py_file: {generated_py_file}")
+
+ # global pdf_tmp_dir
+ # os.chdir(pdf_tmp_dir)
+
+ try:
+ timeout(30.)(text_evaluator)(
+ generation_code_file=generated_py_file, golden_code_file=original_py_file
+ )
+ except Exception as e:
+ logger.info(f"Failed to evaluate text for {item['index']} because {repr(e)}")
+
+ try:
+ timeout(120.)(chart_type_evaluator)(
+ generation_code_file=generated_py_file, golden_code_file=original_py_file
+ )
+ except Exception as e:
+ logger.info(f"Failed to evaluate chart for {item['index']} because {repr(e)}")
+
+ try:
+ timeout(30.)(color_evaluator)(
+ generation_code_file=generated_py_file, golden_code_file=original_py_file
+ )
+ except Exception as e:
+ logger.info(f"Failed to evaluate color for {item['index']} because {repr(e)}")
+
+ try:
+ timeout(30.)(layout_evaluator)(
+ generation_code_file=generated_py_file, golden_code_file=original_py_file
+ )
+ except Exception as e:
+ logger.info(f"Failed to evaluate layout for {item['index']} because {repr(e)}")
+
+ low_level_score_dict = {
+ "original_py_file": original_py_file,
+ "generated_py_file": generated_py_file,
+ "text_metrics": text_evaluator.metrics,
+ "chart_type_metrics": chart_type_evaluator.metrics,
+ "layout_metrics": layout_evaluator.metrics,
+ "color_metrics": color_evaluator.metrics,
+ }
+
+ score_dict["low_level"] = low_level_score_dict
+
+ # >>> 3. High Level Evaluation <<<
+ generated_pdf_image_file = generated_py_file.replace(".py", ".png")
+ # check if generated_pdf_image_file exists
+ if not os.path.exists(generated_pdf_image_file):
+ # logger.info(f"Generated PDF image file {generated_pdf_image_file} does not exist")
+ score_dict["high_level"] = {
+ "resp": None,
+ "msg": "Generated image file does not exist",
+ "score": 0.0,
+ }
+ # logger.info(f"index: {item['index']}, return 0, score_dict: {score_dict}")
+ return 0, score_dict
+
+ # image order should align with prompt
+ resp = run_once_with_images(
+ high_level_eval_prompt["instruction"],
+ [original_py_file.replace(".py", ".png"), generated_pdf_image_file],
+ )
+ if resp is None:
+ logger.error("Error in getting response from judge model!")
+ score_dict["high_level"] = {
+ "resp": None,
+ "msg": "Error in getting response from judge model!",
+ "score": 0.0,
+ }
+ logger.debug(f"index: {item['index']}, return -1, score_dict: {score_dict}")
+ return -1, score_dict
+ else:
+ logger.debug(f"Successfully got response from judge model:\n{resp}")
+ score_dict["high_level"] = {
+ "resp": resp,
+ "msg": "Successfully got response from judge model!",
+ "score": extract_gpt_score(resp),
+ }
+ logger.debug(f"index: {item['index']}, return 0, score_dict: {score_dict}")
+ return 0, score_dict
+
+
+class ChartMimic(ImageBaseDataset):
+ TYPE = "VQA"
+
+ # TODO: add dataset url and md5
+ DATASET_URL = {
+ "ChartMimic_v1_customized": "https://opencompass.openxlab.space/utils/VLMEval/ChartMimic_v1_customized.tsv",
+ "ChartMimic_v1_direct": "https://opencompass.openxlab.space/utils/VLMEval/ChartMimic_v1_direct.tsv",
+ # v2
+ "ChartMimic_v2_customized": "https://opencompass.openxlab.space/utils/VLMEval/ChartMimic_v2_customized.tsv",
+ "ChartMimic_v2_customized_600": "https://opencompass.openxlab.space/utils/VLMEval/ChartMimic_v2_customized_600.tsv",
+ "ChartMimic_v2_customized_1800": "https://opencompass.openxlab.space/utils/VLMEval/ChartMimic_v2_customized_1800.tsv",
+ "ChartMimic_v2_direct": "https://opencompass.openxlab.space/utils/VLMEval/ChartMimic_v2_direct.tsv",
+ "ChartMimic_v2_direct_600": "https://opencompass.openxlab.space/utils/VLMEval/ChartMimic_v2_direct_600.tsv",
+ "ChartMimic_v2_direct_1800": "https://opencompass.openxlab.space/utils/VLMEval/ChartMimic_v2_direct_1800.tsv"
+ }
+ DATASET_MD5 = {
+ "ChartMimic_v1_customized": "d636eca077e75e39fd2600889bf0284e",
+ "ChartMimic_v1_direct": "d0fb410970cab0c666bbacf7d9f0cfb3",
+ "ChartMimic_v2_customized": "390e715dbdfbad3ff788fffa91945405",
+ "ChartMimic_v2_customized_600": "79907e8f9edc5e0eccbbcfc9cbe8a235",
+ "ChartMimic_v2_customized_1800": "a6cf57807c07d328689872a77c9f847a",
+ "ChartMimic_v2_direct": "1c8b444bd681f808f77f06037866eb19",
+ "ChartMimic_v2_direct_600": "3d8d8afecccb6e8feacbcec6834f45f5",
+ "ChartMimic_v2_direct_1800": "340331019c7eaa56cc02080656b66c3c"
+ }
+
+ def dump_image(self, line):
+ input_figure_path_rel = line["input_figure"]
+ ROOT = LMUDataRoot()
+ img_root = os.path.join(ROOT, 'images', 'ChartMimic')
+ input_figure_path = os.path.join(img_root, input_figure_path_rel)
+ tgt_path = [input_figure_path]
+ return tgt_path
+
+ def prepare_tsv(self, url, file_md5=None):
+ data_root = LMUDataRoot()
+ os.makedirs(data_root, exist_ok=True)
+ update_flag = False
+ file_name = url.split("/")[-1]
+ data_path = osp.join(data_root, file_name)
+ self.data_path = data_path
+ if osp.exists(data_path) and (file_md5 is None or md5(data_path) == file_md5):
+ pass
+ else:
+ warnings.warn("The dataset tsv is not downloaded")
+ download_file(url, data_path)
+ update_flag = True
+
+ if file_size(data_path, "GB") > 1:
+ local_path = data_path.replace(".tsv", "_local.tsv")
+ if (
+ not osp.exists(local_path)
+ or os.environ.get("FORCE_LOCAL", None)
+ or update_flag
+ ):
+ from ..tools import LOCALIZE
+
+ LOCALIZE(data_path, local_path)
+ data_path = local_path
+ # Extra check for images
+ py_root = os.path.join(LMUDataRoot(), "images", "ChartMimic")
+ v1_path = osp.join(py_root, "v1")
+ # Check if py_root/v1 exists
+ if not osp.exists(os.path.join(v1_path, "customized_500")) or not osp.exists(
+ os.path.join(v1_path, "ori_500")
+ ):
+ # Download v1
+ warnings.warn("Python files v1 needed by ChartMimic are not downloaded")
+ os.makedirs(v1_path, exist_ok=True)
+ v1_tar = osp.join(v1_path, "v1.tar.gz")
+ if not osp.exists(v1_tar):
+ print("Downloading ChartMimic v1 files...")
+ subprocess.run(
+ [
+ "wget",
+ "https://hf-mirror.com/datasets/ChartMimic/ChartMimic/resolve/main/dataset-old.tar.gz",
+ "-O",
+ v1_tar,
+ ],
+ check=True,
+ )
+ print("Extracting v1...")
+ # subprocess.run([
+ # "tar", "-xzvf", v1_tar, "-C", v1_path
+ # ], check=True)
+ try:
+ subprocess.run(
+ ["tar", "-xzvf", v1_tar, "--no-same-owner", "-C", v1_path],
+ check=True,
+ )
+ except subprocess.CalledProcessError as e:
+ warnings.warn(f"tar extract v1 warning, try to continue. error: {e}")
+ v2_path = osp.join(py_root, "v2")
+ if (
+ not osp.exists(os.path.join(v2_path, "customized_1800"))
+ or not osp.exists(os.path.join(v2_path, "direct_1800"))
+ or not osp.exists(os.path.join(v2_path, "customized_600"))
+ or not osp.exists(os.path.join(v2_path, "direct_600"))
+ ):
+ warnings.warn("Python files v2 needed by ChartMimic are not downloaded")
+ os.makedirs(v2_path, exist_ok=True)
+ v2_tar = osp.join(v2_path, "v2.tar.gz")
+ if not osp.exists(v2_tar):
+ print("Downloading ChartMimic v2 files...")
+ subprocess.run(
+ [
+ "wget",
+ "https://hf-mirror.com/datasets/ChartMimic/ChartMimic/resolve/main/dataset-iclr.tar.gz",
+ "-O",
+ v2_tar,
+ ],
+ check=True,
+ )
+ print("Extracting v2...")
+ try:
+ subprocess.run(
+ ["tar", "-xzvf", v2_tar, "--no-same-owner", "-C", v2_path],
+ check=True,
+ )
+ except subprocess.CalledProcessError as e:
+ warnings.warn(f"tar extract v2 warning, try to continue. error: {e}")
+
+ return load(data_path)
+
+ # Given one data record, return the built prompt (a multi-modal message), can override
+ # Actually, all lines have single image
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+
+ # no "image" in tsv, so self.meta_only is True
+ # logger.info(f"self.meta_only: {self.meta_only}")
+ # logger.info(line.keys())
+
+ input_figure_path_rel = line["input_figure"]
+ instruction = line["question"]
+
+ ROOT = LMUDataRoot()
+ img_root = os.path.join(ROOT, "images", "ChartMimic")
+ input_figure_path = os.path.join(img_root, input_figure_path_rel)
+
+ msgs = []
+ msgs = [dict(type="image", value=input_figure_path)]
+ msgs = [dict(type="text", value=instruction)] + msgs
+
+ return msgs
+
+
+ def evaluate(self, eval_file, **judge_kwargs):
+ def judge_one_item_success(item):
+ return item["high_level"]["resp"] not in [FAIL_MSG, "", None, "null", "None"] \
+ and item["high_level"]["msg"] not in ["Generated image file does not exist", ""]
+
+ # Test dependencies first
+ try:
+ import matplotlib_venn
+ import PIL
+ import squarify
+ from colormath.color_objects import LabColor, sRGBColor
+ from pdf2image import convert_from_path
+ except ImportError as e:
+ logging.critical(
+ "Please follow the requirements (see vlmeval/dataset/utils/chartmimic/eval_req.txt) \
+ to install dependency package for chartmimic evaluation."
+ )
+ raise e
+
+ # Test pdf2image functionality by creating a simple test PDF
+ example_pdf_path = os.path.join(LMUDataRoot(), "chartmimic_test.pdf")
+ output_png_path = os.path.join(LMUDataRoot(), "chartmimic_test.png")
+
+ try:
+ # Create a simple test PDF using matplotlib
+ import matplotlib.pyplot as plt
+ fig, ax = plt.subplots(1, 1, figsize=(4, 3))
+ ax.plot([1, 2, 3], [1, 4, 2])
+ ax.set_title("Test Chart")
+ plt.savefig(example_pdf_path, format='pdf')
+ plt.close()
+
+ # Test pdf2image conversion
+ images = convert_from_path(example_pdf_path, dpi=350)
+ images[0].save(output_png_path, "PNG")
+ logger.info("Successfully tested pdf2image functionality with generated test PDF")
+ except Exception as e:
+ logging.critical(
+ "Please install poppler-utils in your system (e.g. sudo apt-get install poppler-utils)."
+ )
+ raise e
+ finally:
+ # Clean up test files
+ if os.path.exists(example_pdf_path):
+ os.remove(example_pdf_path)
+ if os.path.exists(output_png_path):
+ os.remove(output_png_path)
+
+ infer_data_all = load(eval_file).to_dict(orient="records")
+
+ print(f"judge_kwargs: {judge_kwargs}")
+ infer_model = judge_kwargs["model"]
+ storage = os.path.abspath(get_intermediate_file_path(eval_file, f'_{infer_model}', 'jsonl'))
+ score_file = os.path.abspath(get_intermediate_file_path(eval_file, f'_{infer_model}_score', 'csv'))
+ # use abs path because of using os.chdir()
+ tmp_file = os.path.abspath(get_intermediate_file_path(eval_file, f'_{infer_model}_tmp', 'pkl'))
+ # actually the --api-nproc
+ nproc = judge_kwargs.pop("nproc", 8)
+ logger.info(f"nproc: {nproc}")
+ global save_code_dir, sub_set_name
+ # [Attention] should use absolute dir here
+ eval_file_abs_path = os.path.abspath(eval_file)
+ save_code_dir = os.path.dirname(eval_file_abs_path)
+
+ # dataset_name is subset name like ChartMimic_1000
+ sub_set_name = self.dataset_name
+
+ # params prepare for track_progress_rich
+ params_all = [json.dumps(item) for item in infer_data_all]
+ indices_all = [line["index"] for line in infer_data_all]
+
+ ans = {}
+ if os.path.exists(tmp_file):
+ tmp_data = load(tmp_file)
+ for k, v in tmp_data.items():
+ # -1 means error for getting response from judge model, so try to rejudge for this item
+ if v[0] == 0 and judge_one_item_success(v[1]):
+ ans[k] = v
+ logger.info(f"Tmp file exists, loaded {len(ans)} data from {tmp_file}")
+
+ tups = [x for x, i in zip(params_all, indices_all) if i not in ans]
+ indices = [i for i in indices_all if i not in ans]
+
+ # save current work dir
+ global cur_work_dir, pdf_tmp_dir
+ cur_work_dir = os.getcwd()
+ pdf_tmp_dir = os.path.join(save_code_dir, "chart_mimic_tmp", f"{sub_set_name}")
+ os.makedirs(pdf_tmp_dir, exist_ok=True)
+ os.chdir(pdf_tmp_dir)
+
+ # >>> judge <<<
+ if len(indices):
+ # judge_kwargs['system_prompt'] = SYSTEM_PROMPT
+ judge_kwargs["temperature"] = 0
+ judge_kwargs["img_detail"] = "high"
+ judge_kwargs["timeout"] = 100
+ global judge_model
+ judge_model = build_judge(max_tokens=1024, **judge_kwargs)
+
+ assert judge_model.working(), (
+ "ChartMimic evaluation requires a working OPENAI API\n" + DEBUG_MESSAGE
+ )
+
+ # if len(indices):
+ new_results = track_progress_rich_new(
+ judge_one_item,
+ tups,
+ nproc=nproc,
+ keys=indices,
+ save=tmp_file,
+ )
+ for k, v in zip(indices, new_results):
+ ans[k] = v
+
+ for item in infer_data_all:
+ # ans[i] is a tuple, (0 / -1, score_dict), only use score_dict
+ item["judge_result"] = ans[item["index"]][1]
+
+ # storage is a jsonl file
+ with open(storage, "w") as f:
+ for item in infer_data_all:
+ f.write(json.dumps(item) + "\n")
+
+ # judge finished, rm tmp dir
+ os.chdir(cur_work_dir)
+ if os.path.exists(pdf_tmp_dir):
+ shutil.rmtree(pdf_tmp_dir)
+ # breakpoint()
+
+ # logger.info(f"storage: {storage}")
+ eval_data_all = load(storage)
+ # result_df = pd.DataFrame(columns=["example_count", "exec_rate", "text_score","layout_score", "type_score", "color_score", "average", f"gpt_score({judge_kwargs['model']})", "overall"])
+
+ # filter out items that do not have judge_result["low_level"] and judge_result["high_level"]: failed item need rejudge
+ old_len = len(eval_data_all)
+ eval_data_all = [
+ item
+ for item in eval_data_all
+ if "judge_result" in item
+ and "low_level" in item["judge_result"]
+ and "high_level" in item["judge_result"]
+ ]
+ new_len = len(eval_data_all)
+ logger.info(f"filter out {old_len - new_len} items for no judge_result in item")
+
+ old_len = len(eval_data_all)
+ eval_data_all = [
+ item
+ for item in eval_data_all
+ # if judge_one_item_success(item["judge_result"])
+ ]
+ new_len = len(eval_data_all)
+ logger.info(
+ f"filter out {old_len - new_len} items for FAIL_MSG in high_level resp"
+ )
+
+ def compute_metrics(eval_data):
+ result = {
+ "example_count": len(eval_data),
+ }
+
+ denominator = len(eval_data)
+ if denominator == 0:
+ # Avoid division by zero, return zeros
+ result.update(
+ {
+ "exec_rate": 0,
+ "text_score": 0,
+ "layout_score": 0,
+ "chart_type_score": 0,
+ "color_score": 0,
+ "average": 0,
+ "gpt_score": 0,
+ "overall": 0,
+ }
+ )
+ return result
+
+ pdf_file_cnt = 0
+ text_score_sum = 0
+ layout_score_sum = 0
+ type_score_sum = 0
+ color_score_sum = 0
+ gpt_score_sum = 0
+
+ for item in eval_data:
+ py_file = item["judge_result"]["low_level"]["generated_py_file"]
+ if py_file and os.path.exists(py_file.replace(".py", ".pdf")):
+ pdf_file_cnt += 1
+
+ text_score_sum += item["judge_result"]["low_level"]["text_metrics"][
+ "f1"
+ ]
+ layout_score_sum += item["judge_result"]["low_level"]["layout_metrics"][
+ "f1"
+ ]
+ type_score_sum += item["judge_result"]["low_level"][
+ "chart_type_metrics"
+ ]["f1"]
+ color_score_sum += item["judge_result"]["low_level"]["color_metrics"][
+ "f1"
+ ]
+ gpt_score_sum += item["judge_result"]["high_level"]["score"]
+
+ result["exec_rate"] = pdf_file_cnt / denominator * 100
+ result["text_score"] = text_score_sum / denominator * 100
+ result["layout_score"] = layout_score_sum / denominator * 100
+ result["chart_type_score"] = type_score_sum / denominator * 100
+ result["color_score"] = color_score_sum / denominator * 100
+ result["average"] = (
+ result["text_score"]
+ + result["layout_score"]
+ + result["chart_type_score"]
+ + result["color_score"]
+ ) / 4
+ result["gpt_score"] = gpt_score_sum / denominator
+ result["overall"] = (result["average"] + result["gpt_score"]) / 2
+
+ return result
+
+ # Collect unique task values
+ task_values = sorted(
+ set(
+ item.get("task")
+ for item in eval_data_all
+ if item.get("task") is not None
+ )
+ )
+
+ # Create splits dict
+ splits = {
+ "all": eval_data_all,
+ **{
+ task: [item for item in eval_data_all if item.get("task") == task]
+ for task in task_values
+ },
+ }
+
+ all_results = []
+ for split_name, data in splits.items():
+ result = compute_metrics(data)
+ result["split"] = split_name
+ all_results.append(result)
+
+ score_df = pd.DataFrame(all_results)
+ # reorder columns
+ cols = ["split"] + [col for col in score_df.columns if col != "split"]
+ score_df = score_df[cols]
+ dump(score_df, score_file)
+ return score_df
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartmuseum.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartmuseum.py
new file mode 100644
index 00000000..6b17140b
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartmuseum.py
@@ -0,0 +1,199 @@
+import os.path as osp
+import re
+from collections import defaultdict
+from typing import Any, Dict, List, Union
+
+import pandas as pd
+
+from vlmeval.dataset.image_base import ImageBaseDataset
+from vlmeval.smp import dump, get_intermediate_file_path, get_logger, load, toliststr
+from vlmeval.utils import track_progress_rich
+from .utils import build_judge
+
+logger = get_logger(__name__)
+
+COMPARE_ANSWER_PROMPT = """You are provided with a question and two answers. Please determine if these answers are equivalent. Follow these guidelines:
+
+1. Numerical Comparison:
+ - For decimal numbers, consider them as equivalent if their relative difference is sufficiently small.
+ For example, the following pairs are equivalent:
+ - 32.35 and 32.34
+ - 90.05 and 90.00
+ - 83.3% and 83.2%
+ - 31 and 31%
+ The following pairs are not equivalent:
+ - 32.35 and 35.25
+ - 90.05 and 91.05
+ - 83.3% and 45.2%
+
+ Note that if the question asks for years or dates, please do the exact match with no error tolerance.
+
+2. Unit Handling:
+ - If only one answer includes units (e.g. '$', '%', '-', etc.), ignore the units and compare only the numerical values
+ For example, the following pairs are equivalent:
+ - 305 million and 305 million square meters
+ - 0.75 and 0.75%
+ - 0.6 and 60%
+ - $80 and 80
+ The following pairs are not equivalent:
+ - 305 million and 200 million square meters
+ - 0.75 and 0.90%
+
+3. Text Comparison:
+ - Ignore differences in capitalization
+ - Treat mathematical expressions in different but equivalent forms as the same (e.g., "2+3" = "5")
+
+Question: [QUESTION]
+Answer 1: [ANSWER1]
+Answer 2: [ANSWER2]
+
+Please respond with:
+- "Yes" if the answers are equivalent
+- "No" if the answers are different""" # noqa: E501
+
+
+def get_question(QUESTION):
+
+ QA_PROMPT = f"""Please answer the question using the chart image.
+
+ Question: {QUESTION}
+
+ Please first generate your reasoning process and then provide the user with the answer. Use the following format:
+
+
+ ... your thinking process here ...
+
+
+ ... your final answer (entity(s) or number) ...
+ """ # noqa: E501
+
+ return QA_PROMPT
+
+
+def extract_answer(text: str) -> str:
+ m = re.search(r"(.*?)", text + "", re.DOTALL)
+ return m.group(1).strip() if m else ""
+
+
+def gpt_compare(category, question, answer1, answer2, idx, judge_model):
+
+ prompt = (
+ COMPARE_ANSWER_PROMPT
+ .replace("[QUESTION]", question)
+ .replace("[ANSWER1]", answer1)
+ .replace("[ANSWER2]", answer2)
+ )
+ response = judge_model.generate(prompt)
+
+ return response, category
+
+
+class ChartMuseum(ImageBaseDataset):
+ TYPE = "VQA"
+ DATASET_URL = {
+ "ChartMuseum_dev": "https://huggingface.co/datasets/yujieouo/ChartMuseum/blob/main/ChartMuseum_dev.tsv",
+ "ChartMuseum_test": "https://huggingface.co/datasets/yujieouo/ChartMuseum/blob/main/ChartMuseum_test.tsv",
+ }
+ DATASET_MD5 = {
+ "ChartMuseum_dev": "05dbce1f4bd5e5ba0e4b0d606efb707e",
+ "ChartMuseum_test": "983586eace6ee33cdb189d63124768c8",
+ }
+
+ def build_prompt(self, line: Union[int, pd.Series]) -> List[Dict[str, str]]:
+ """
+ Build a prompt for the model from a data line.
+
+ Args:
+ line: Either an index into the dataset or a pandas Series
+
+ Returns:
+ List of message dictionaries containing the image and question
+ """
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+
+ if self.meta_only:
+ tgt_path = toliststr(line["image"])
+ else:
+ tgt_path = self.dump_image(line)
+
+ # load data line elements
+ question = line['question']
+ # answer = line['answer']
+ # category = line['category']
+
+ # build prompt from question
+ question_context = get_question(question)
+
+ # form messages
+ msgs = []
+ msgs = [dict(type='image', value=tgt_path[0])]
+ msgs.append(dict(type='text', value=question_context))
+
+ return msgs
+
+ def evaluate(self, eval_file: str, **judge_kwargs: Any) -> pd.DataFrame:
+ """
+ Evaluate model predictions on the ChartMuseum dataset.
+ Args:
+ eval_file: Path to the file containing model predictions
+ **judge_kwargs: Additional arguments for the judge model
+ Returns:
+ DataFrame with evaluation scores by category
+ """
+
+ benchmark = self.data
+ questions = benchmark["question"].tolist()
+ gts = benchmark["answer"].astype(str).tolist()
+ categories = benchmark['category'].astype(str).tolist()
+
+ data = load(eval_file)
+ pred_list = data['prediction'].astype(str).tolist()
+ id_list = data['index'].astype(int).tolist()
+ pred_answers = [extract_answer(p) for p in pred_list]
+ assert len(id_list) == len(pred_answers)
+
+ category_flags = defaultdict(list)
+ judge_model_name = judge_kwargs.pop('judge_model', 'gpt-4.1-mini-2025-04-14')
+ nproc = judge_kwargs.pop('nproc', 4)
+ if judge_model_name != 'gpt-4.1-mini-2025-04-14':
+ logger.warning("Recommend to use gpt-4.1-mini-2025-04-14 as judge model for "
+ f"ChartMuseum, Now using {judge_model_name}")
+ tmp_file = get_intermediate_file_path(eval_file, f'_{judge_model_name}', 'pkl')
+ if osp.exists(tmp_file):
+ already_judged = load(tmp_file)
+ else:
+ already_judged = {}
+ judge_model = build_judge(model=judge_model_name, **judge_kwargs)
+
+ input_tuples = [
+ (cat, q, gt, pa, idx, judge_model)
+ for cat, q, gt, pa, idx in zip(categories, questions, gts, pred_answers, id_list)
+ if idx not in already_judged
+ ]
+ indices = [idx for _, _, _, _, idx, _ in input_tuples]
+ if len(indices):
+ _ = track_progress_rich(
+ gpt_compare,
+ input_tuples,
+ nproc=nproc,
+ chunksize=nproc,
+ keys=indices,
+ save=tmp_file,
+ )
+ ans = load(tmp_file)
+ for value in ans.values():
+ flag = 'yes' in value[0].lower()
+ category_flags[value[1]].append(int(flag))
+
+ score = {}
+ for cat, flags in category_flags.items():
+ score[cat] = [sum(flags) / len(flags) * 100]
+
+ all_flags = [f for flags in category_flags.values() for f in flags]
+ score["Overall"] = [sum(all_flags) / len(all_flags) * 100]
+ score_file = get_intermediate_file_path(eval_file, "_acc", "csv")
+ out_score = pd.DataFrame(score)
+ dump(out_score, score_file)
+
+ return out_score
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartqapro.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartqapro.py
new file mode 100644
index 00000000..837eca69
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartqapro.py
@@ -0,0 +1,123 @@
+import ast
+from typing import Any, Dict, List, Union
+
+import pandas as pd
+
+from vlmeval.dataset.image_base import ImageBaseDataset
+from vlmeval.dataset.utils.chartqapro import evaluate_predictions_chartqapro, prompt_context
+from vlmeval.smp import file, misc
+from vlmeval.smp.file import get_intermediate_file_path
+
+
+class ChartQAPro(ImageBaseDataset):
+ TYPE = "VQA"
+ DATASET_URL = {
+ "ChartQAPro": "https://opencompass.openxlab.space/utils/VLMEval/chartqapro.tsv",
+ "ChartQAPro_CoT": "https://opencompass.openxlab.space/utils/VLMEval/chartqapro.tsv",
+ "ChartQAPro_PoT": "https://opencompass.openxlab.space/utils/VLMEval/chartqapro.tsv",
+ }
+ DATASET_MD5 = {
+ "ChartQAPro": "27653ea8dd8dd3a85bc4f432db96447a",
+ "ChartQAPro_CoT": "27653ea8dd8dd3a85bc4f432db96447a",
+ "ChartQAPro_PoT": "27653ea8dd8dd3a85bc4f432db96447a",
+ }
+
+ def build_prompt(self, line: Union[int, pd.Series], qa_type: str = 'Direct') -> List[Dict[str, str]]:
+ """
+ Build a prompt for the model from a data line.
+
+ Args:
+ line: Either an index into the dataset or a pandas Series
+ qa_type: Choose from ['Direct', 'CoT', 'PoT']
+
+ Returns:
+ List of message dictionaries containing the image and question
+ """
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+
+ if self.meta_only:
+ tgt_path = misc.toliststr(line["image"])
+ else:
+ tgt_path = self.dump_image(line)
+
+ # determine qa_type, default value : 'Direct'
+ if "CoT" in self.dataset_name:
+ qa_type = "CoT"
+ elif "PoT" in self.dataset_name:
+ qa_type = "PoT"
+
+ # load data line elements
+ question = ast.literal_eval(line['question'])
+ answer = ast.literal_eval(line['answer'])
+ question_type = line['question_type']
+ # image = line['image']
+ # year = ast.literal_eval(line['year'])
+ paragraph = line['paragraph']
+ if paragraph != paragraph: # treat nan
+ paragraph = ''
+ assert isinstance(question, list)
+ assert len(tgt_path) == 1
+
+ # build prompt from question
+ question_context = prompt_context(question, answer, question_type, qa_type)
+
+ # form messages
+ msgs = []
+ msgs = [dict(type='image', value=tgt_path[0])]
+ msgs.append(dict(type='text', value=paragraph))
+ msgs.append(dict(type='text', value=question_context))
+
+ return msgs
+
+ def get_scores(self, result_file: str) -> pd.DataFrame:
+ """
+ Calculate scores by category from evaluation results.
+
+ Args:
+ result_file: Path to the file containing evaluation results
+
+ Returns:
+ DataFrame with scores for each category and overall score
+
+ Raises:
+ ValueError: If the dataset name is invalid
+ """
+
+ if "CoT" in self.dataset_name or "PoT" in self.dataset_name:
+ print("********** Warning: We follow the evaluation script for Direct to assess CoT and PoT, \
+ the scores can be very low! **********")
+
+ data = file.load(result_file)
+
+ ans_list = []
+ for idx in range(len(data)):
+ llm_ans = {}
+ llm_ans['Answer'] = ast.literal_eval(data['answer'][idx])
+ llm_ans['Question Type'] = data['question_type'][idx]
+ llm_ans['Year'] = ast.literal_eval(data['year'][idx])
+ llm_ans['prediction'] = data['prediction'][idx]
+ ans_list.append(llm_ans)
+
+ scores = evaluate_predictions_chartqapro(ans_list)
+
+ return pd.DataFrame(list(scores.items()))
+
+ def evaluate(self, eval_file: str, **judge_kwargs: Any) -> pd.DataFrame:
+ """
+ Evaluate model predictions on the ChartQAPro dataset.
+
+ Args:
+ eval_file: Path to the file containing model predictions
+ **judge_kwargs: Additional arguments for the judge model
+
+ Returns:
+ DataFrame with evaluation scores by category
+ """
+
+ score = self.get_scores(eval_file)
+ score_file = get_intermediate_file_path(eval_file, "_acc", "csv")
+
+ file.dump(score, score_file)
+
+ return score
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartx.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartx.py
new file mode 100644
index 00000000..51c594ed
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/chartx.py
@@ -0,0 +1,125 @@
+import os.path as osp
+from collections import defaultdict
+
+import numpy as np
+
+from vlmeval.smp import d2df, dump, load
+from vlmeval.smp.file import get_intermediate_file_path
+from vlmeval.utils import track_progress_rich
+from .image_base import ImageBaseDataset
+from .utils import DEBUG_MESSAGE, build_judge
+from .utils.chartx_eval import ChartX_auxeval, chartx_scrm_eval
+
+
+class ChartX(ImageBaseDataset):
+ TYPE = 'VQA'
+ DATASET_URL = {
+ 'ChartX': 'https://opencompass.openxlab.space/utils/VLMEval/ChartX.tsv'}
+ DATASET_MD5 = {'ChartX': 'ffeb5bc765c9a7a78ef326410903d04d'}
+
+ def __init__(self, dataset='ChartX', **kwargs):
+ super().__init__(dataset=dataset, **kwargs)
+
+ def build_prompt(self, line):
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+ msgs = super().build_prompt(line)
+ return msgs
+
+ @classmethod
+ def evaluate(cls, eval_file, **judge_kwargs):
+ print("Evaluating ChartX results...")
+ data = load(eval_file)
+
+ data['prediction'] = [str(x) for x in data['prediction']]
+ data['answer'] = [str(x) for x in data['answer']]
+
+ # --- 1. Structure Extraction (SCRM) ---
+ scores = {}
+ if 'category' in data.columns:
+ se_data = data[data['category'] == 'structure']
+ if len(se_data) > 0:
+ print(f"Evaluating Structure Extraction ({len(se_data)} samples)...")
+ preds = se_data['prediction'].tolist()
+ gts = se_data['answer'].tolist()
+ scrm_res = chartx_scrm_eval(preds, gts)
+ scores['SCRM'] = scrm_res['SCRM']
+ scores['AP50_Strict'] = scrm_res['AP50_Strict']
+
+ # --- 2. GPT Evaluation ---
+ # Mandatory GPT evaluation for QA, Desc, Sum, Redraw
+ try:
+ model = build_judge(max_tokens=128, **judge_kwargs)
+ judge_working = model.working()
+ except BaseException:
+ judge_working = False
+
+ if judge_working:
+ print("Running GPT-based evaluation for non-structure tasks...")
+ target_indices = []
+ if 'category' in data.columns:
+ target_indices = data[data['category']
+ != 'structure'].index.tolist()
+ else:
+ target_indices = data.index.tolist()
+
+ # Prepare for auxiliary eval
+ tmp_gpt = get_intermediate_file_path(
+ eval_file, '_gpt_eval', 'pkl')
+
+ ans = {}
+ if osp.exists(tmp_gpt):
+ ans = load(tmp_gpt)
+
+ # Identify pending items
+ pending_indices = [i for i in target_indices if i not in ans]
+ if pending_indices:
+ tups = [(model, data.iloc[i]) for i in pending_indices]
+
+ track_progress_rich(
+ ChartX_auxeval,
+ tups,
+ nproc=judge_kwargs.get('nproc', 4),
+ chunksize=judge_kwargs.get('nproc', 4),
+ keys=pending_indices,
+ save=tmp_gpt
+ )
+ ans = load(tmp_gpt)
+
+ # Aggregate scores
+ gpt_scores_list = []
+ for idx in target_indices:
+ if idx in ans:
+ gpt_scores_list.append(ans[idx]['score'])
+ else:
+ gpt_scores_list.append(0)
+
+ scores['GPT_Overall'] = np.mean(
+ gpt_scores_list) if gpt_scores_list else 0
+
+ # Breakdown
+ if 'category' in data.columns:
+ cat_scores = defaultdict(list)
+ for idx in target_indices:
+ cat = data.iloc[idx]['category']
+ if idx in ans:
+ cat_scores[cat].append(ans[idx]['score'])
+
+ for cat, val_list in cat_scores.items():
+ scores[f'{cat}_GPT'] = np.mean(val_list) if val_list else 0
+ else:
+ print(
+ "Warning: OpenAI API not working or Key missing. Skipping GPT-based evaluation.")
+ print(DEBUG_MESSAGE)
+ scores['GPT_Overall'] = 0.0
+ if 'category' in data.columns:
+ for cat in data['category'].unique():
+ if cat != 'structure':
+ scores[f'{cat}_GPT'] = 0.0
+
+ # Save score table
+ ret = d2df(scores)
+ ret.round(2)
+ dump(ret, get_intermediate_file_path(eval_file, '_acc'))
+
+ return ret
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/charxiv.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/charxiv.py
new file mode 100644
index 00000000..15676ab5
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/charxiv.py
@@ -0,0 +1,257 @@
+import json
+import os
+import warnings
+from typing import Any, Dict, List, Tuple, Union
+
+import pandas as pd
+
+from vlmeval import utils
+from vlmeval.dataset.image_base import ImageBaseDataset
+from vlmeval.dataset.utils import build_judge
+from vlmeval.smp import file, misc
+from vlmeval.smp.file import get_intermediate_file_path
+
+
+def auxeval(judge_model: Any, line: pd.Series, **kwargs: Any) -> Dict[str, Any]:
+ """
+ Evaluate a line using the judge model.
+
+ Args:
+ judge_model: The model used for evaluation
+ line: A pandas Series containing the data to evaluate
+ **kwargs: Additional arguments for the judge model
+
+ Returns:
+ Dict containing evaluation results with extract_answer and score
+ """
+ failure_result = {"extract_answer": "Failed to parse response", "score": 0.0}
+ prompt = line["grading_query"].replace("{PREDICTION}", line["prediction"])
+
+ retry = kwargs.get("retry", 10)
+ max_tokens = kwargs.get("max_tokens", 256)
+ temperature = kwargs.get("temperature", 0)
+ seed = kwargs.get("seed", 42)
+ top_p = kwargs.get("top_p", 1)
+
+ for _ in range(retry):
+ try:
+ response = judge_model.generate(
+ prompt,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ seed=seed,
+ top_p=top_p,
+ )
+ content = json.loads(response)
+ if not isinstance(content, dict):
+ return failure_result
+ if "score" not in content or "extract_answer" not in content:
+ return failure_result
+ return content
+ except Exception:
+ continue
+
+ return failure_result
+
+
+def qid2category(mode: str) -> Tuple[Dict[int, str], str]:
+ """
+ Map question IDs to their categories based on the evaluation mode.
+
+ Args:
+ mode: Either "descriptive" or "reasoning"
+
+ Returns:
+ Tuple containing a mapping dictionary and the index column name
+
+ Raises:
+ ValueError: If the mode is not recognized
+ """
+ if mode == "descriptive":
+ index_col = "qid"
+ return {
+ 1: "Information Extraction",
+ 2: "Information Extraction",
+ 3: "Information Extraction",
+ 4: "Information Extraction",
+ 5: "Information Extraction",
+ 6: "Information Extraction",
+ 7: "Information Extraction",
+ 8: "Enumeration",
+ 9: "Enumeration",
+ 10: "Counting",
+ 11: "Pattern Recognition",
+ 12: "Counting",
+ 13: "Enumeration",
+ 14: "Enumeration",
+ 15: "Enumeration",
+ 16: "Pattern Recognition",
+ 17: "Compositionality",
+ 18: "Pattern Recognition",
+ 19: "Counting",
+ }, index_col
+ elif mode == "reasoning":
+ index_col = "inst_category"
+ return {
+ 1: "Text-in-Chart",
+ 2: "Text-in-General",
+ 3: "Number-in-Chart",
+ 4: "Number-in-General",
+ }, index_col
+ else:
+ raise ValueError(f"Invalid mode: {mode}")
+
+
+class CharXiv(ImageBaseDataset):
+ TYPE = "VQA"
+ DATASET_URL = {
+ "CharXiv_descriptive_val": "https://opencompass.openxlab.space/utils/VLMEval/CharXiv_descriptive_val.tsv",
+ "CharXiv_reasoning_val": "https://opencompass.openxlab.space/utils/VLMEval/CharXiv_reasoning_val.tsv",
+ }
+ DATASET_MD5 = {
+ "CharXiv_descriptive_val": "e165037032f169a59dd09ea5d7ad3073",
+ "CharXiv_reasoning_val": "98eeff269b40726982627b19338ccd45",
+ }
+
+ def build_prompt(self, line: Union[int, pd.Series]) -> List[Dict[str, str]]:
+ """
+ Build a prompt for the model from a data line.
+
+ Args:
+ line: Either an index into the dataset or a pandas Series
+
+ Returns:
+ List of message dictionaries containing the image and question
+ """
+ if isinstance(line, int):
+ line = self.data.iloc[line]
+
+ if self.meta_only:
+ tgt_path = misc.toliststr(line["image"])
+ else:
+ tgt_path = self.dump_image(line)
+
+ messages = [{"type": "image", "value": tgt_path[0]}]
+ messages.append({"type": "text", "value": line["question"]})
+ return messages
+
+ def get_scores(self, result_file: str) -> pd.DataFrame:
+ """
+ Calculate scores by category from evaluation results.
+
+ Args:
+ result_file: Path to the file containing evaluation results
+
+ Returns:
+ DataFrame with scores for each category and overall score
+
+ Raises:
+ ValueError: If the dataset name is invalid
+ """
+ data = file.load(result_file)
+
+ if "descriptive" in self.dataset_name:
+ mode = "descriptive"
+ elif "reasoning" in self.dataset_name:
+ mode = "reasoning"
+ else:
+ raise ValueError(f"Invalid dataset name: {self.dataset_name}")
+
+ category_map, index_col = qid2category(mode)
+
+ # Group scores by category
+ scores_by_category = {}
+ for _, row in data.iterrows():
+ category = category_map[row[index_col]]
+ if category not in scores_by_category:
+ scores_by_category[category] = []
+ scores_by_category[category].append(row["score"])
+
+ # Calculate average score for each category
+ result = {}
+ for category, scores in scores_by_category.items():
+ result[category] = [sum(scores) / len(scores)]
+
+ # Calculate overall score
+ result["Overall"] = [
+ sum(sum(scores) for scores in scores_by_category.values()) / len(data)
+ ]
+
+ return pd.DataFrame(result)
+
+ def evaluate(self, eval_file: str, **judge_kwargs: Any) -> pd.DataFrame:
+ """
+ Evaluate model predictions on the CharXiv dataset.
+
+ Args:
+ eval_file: Path to the file containing model predictions
+ **judge_kwargs: Additional arguments for the judge model
+
+ Returns:
+ DataFrame with evaluation scores by category
+ """
+ # Set up judge model
+ if "LOCAL_LLM" in os.environ:
+ judge_model = os.path.basename(os.environ.get("LOCAL_LLM"))
+ else:
+ judge_model = judge_kwargs.get("model", "gpt-4o-mini")
+
+ if judge_model != "gpt-4o-mini":
+ warnings.warn(
+ f"The judge_model '{judge_model}' is not gpt-4o-mini. Evaluation results may not be accurate."
+ )
+
+ judge_model = build_judge(model=judge_model, **judge_kwargs)
+ judge_model_name = judge_model.model
+
+ # Define file paths
+ result_file = get_intermediate_file_path(eval_file, f"_{judge_model_name}")
+ temp_result_file = get_intermediate_file_path(eval_file, f"_{judge_model_name}", "pkl")
+ score_file = get_intermediate_file_path(result_file, "_acc", "csv")
+
+ # Return existing results if available
+ if os.path.exists(result_file):
+ score = self.get_scores(result_file)
+ file.dump(score, score_file)
+ return score
+
+ data = file.load(eval_file)
+ if "score" not in data.columns:
+ data["score"] = 0
+ if "extract_answer" not in data.columns:
+ data["extract_answer"] = ""
+
+ # Load intermediate results if available
+ processed_results = {}
+ if os.path.exists(temp_result_file):
+ processed_results = file.load(temp_result_file)
+
+ # Identify unprocessed indices
+ indices = [i for i in range(len(data)) if i not in processed_results]
+ tups = [(judge_model, data.iloc[i]) for i in range(len(data)) if i not in processed_results]
+
+ # Process remaining examples
+ nproc = judge_kwargs.pop("nproc", 4)
+ if len(indices):
+ utils.track_progress_rich(
+ auxeval,
+ tups,
+ nproc=nproc,
+ chunksize=nproc,
+ keys=indices,
+ save=temp_result_file,
+ **judge_kwargs,
+ )
+ processed_results = file.load(temp_result_file)
+
+ # Update data with evaluation results
+ data["score"] = data.apply(lambda x: processed_results[x.name]["score"], axis=1)
+ data["extract_answer"] = data.apply(
+ lambda x: processed_results[x.name]["extract_answer"], axis=1
+ )
+
+ # Save results and return scores
+ file.dump(data, result_file)
+ score = self.get_scores(result_file)
+ file.dump(score, score_file)
+ return score
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cmmmu.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cmmmu.py
new file mode 100644
index 00000000..26e2e212
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cmmmu.py
@@ -0,0 +1,358 @@
+import os
+import os.path as osp
+import random
+import re
+from collections import Counter
+
+import pandas as pd
+from tqdm import tqdm
+
+from vlmeval.smp import (d2df, decode_base64_to_image_file, dump, get_intermediate_file_path, load,
+ read_ok)
+from .image_base import ImageBaseDataset
+
+
+def get_multi_choice_prediction(response, all_choices, index2ans):
+ for char in [',', '.', '!', '?', ';', ':', "'"]:
+ response = response.strip(char)
+ response = " " + response + " " # add space to avoid partial match
+
+ candidates = []
+
+ for choice in all_choices: # (A) (B) (C) (D)
+ # Add the choice to candidates each time it appears in the response
+ candidates.extend([choice for _ in range(response.count(f'({choice})'))])
+
+ if len(candidates) == 0:
+ for choice in all_choices: # A B C D
+ # Similarly, add the choice for each occurrence
+ candidates.extend([choice for _ in range(response.count(f'{choice}'))])
+
+ if len(candidates) == 0 and len(response.split()) >= 1:
+ for index, ans in index2ans.items():
+ # Add index for each occurrence of ans in response
+ candidates.extend([index for _ in range(response.count(ans))])
+
+ # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
+ if len(candidates) == 0 and len(response.split()) >= 1:
+ for index, ans in index2ans.items():
+ if ans in response:
+ candidates.append(index)
+ # index_ans = False # it's content ans.
+
+ if len(candidates) == 0: # still not get answer, randomly choose one.
+ return random.choice(all_choices)
+ # return ''
+ else:
+ # Count the occurrence of each candidate
+ candidate_counts = Counter(candidates)
+
+ # Select the most frequent candidates
+ max_count = max(candidate_counts.values())
+ most_frequent_candidates = [c for c in all_choices if candidate_counts.get(c, 0) == max_count]
+
+ # Combine the most frequent candidates in ABCD order
+ return ''.join(most_frequent_candidates)
+
+
+def extract_numbers(string):
+ # Pattern for numbers with Chinese commas
+ pattern_commas = r'-?\d{1,3}(?:,\d{3})+'
+ # Pattern for scientific notation
+ pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
+ # Pattern for simple numbers without Chinese commas
+ pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+)(?![eE][+-]?\d+)(?!,\d)'
+
+ # Extract numbers with Chinese commas
+ numbers_with_commas = re.findall(pattern_commas, string)
+ # Extract numbers in scientific notation
+ numbers_scientific = re.findall(pattern_scientific, string)
+ # Extract simple numbers without Chinese commas
+ numbers_simple = re.findall(pattern_simple, string)
+
+ # Combine all extracted numbers
+ all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
+ return all_numbers
+
+
+def check_is_number(string):
+ try:
+ float(string.replace(',', ''))
+ return True
+ except ValueError:
+ # check if there's comma inside
+ return False
+
+
+def count_letters(string):
+ return sum(c.isalpha() and 'a' <= c <= 'z' or 'A' <= c <= 'Z' for c in string)
+
+
+def normalize_str(string, answer):
+ # check if characters in the string
+
+ # if number, numerize it.
+ if string is None:
+ return [string]
+ string = string.strip()
+
+ is_number = check_is_number(string)
+
+ if is_number:
+ string = string.replace(',', '')
+ string = float(string)
+ # leave 2 decimal
+ string = round(string, 2)
+ return [string]
+ else: # it's likely to be a string
+ if len(string) > len(answer) + 20 or count_letters(string) > count_letters(answer) + 2:
+ return []
+ return [string]
+
+
+def get_fill_blank_prediction(response, answer):
+ """get the prediction from the generated response,
+ return a list of predicted strings or numbers"""
+
+ def get_key_subresponses(response):
+ response = response.strip("。").strip()
+ sub_responses = re.split(r'。|\n', response)
+ indicators_of_keys = ['是', '为', '所以', '等于', '方案', '选择',
+ '正确答案', '因此', '最后', '答案', '结果']
+ key_responses = []
+ for index, resp in enumerate(sub_responses):
+ # if last one, accept it's an equation (the entire response can be just one sentence with equation)
+ if index == len(sub_responses) - 1:
+ indicators_of_keys.extend(['='])
+ shortest_key_response = None
+ # the shortest response that may contain the answer (tail part of the response)
+ for indicator in indicators_of_keys:
+ if indicator in resp:
+ if not shortest_key_response:
+ shortest_key_response = resp.split(indicator)[-1].strip()
+ else:
+ if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
+ shortest_key_response = resp.split(indicator)[-1].strip()
+
+ if shortest_key_response:
+ # and it's not trivial
+ if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
+ key_responses.append(shortest_key_response)
+ if len(key_responses) == 0: # did not found any
+ return [response]
+ return key_responses
+
+ key_responses = get_key_subresponses(response)
+
+ pred_list = key_responses.copy() # keep the original string response
+ for resp in key_responses:
+ pred_list.extend(extract_numbers(resp))
+
+ tmp_pred_list = []
+ for i in range(len(pred_list)):
+ tmp_pred_list.extend(normalize_str(pred_list[i], answer))
+ pred_list = tmp_pred_list
+
+ # remove duplicates
+ pred_list = list(set(pred_list))
+
+ return pred_list
+
+
+def get_TF_prediction(response):
+ """get the prediction from the generated response,
+ return a list of predicted strings or numbers"""
+
+ def get_key_subresponses(response):
+ response = response.strip("。").strip()
+ sub_responses = re.split(r'。|\n', response)
+ indicators_of_keys = ['是', '为', '所以', '判断',
+ '陈述', '说法', '表达', '答案', '结果']
+ key_responses = []
+ for index, resp in enumerate(sub_responses):
+ shortest_key_response = None
+ # the shortest response that may contain the answer (tail part of the response)
+ for indicator in indicators_of_keys:
+ if indicator in resp:
+ if not shortest_key_response:
+ shortest_key_response = resp.split(indicator)[-1].strip()
+ else:
+ if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
+ shortest_key_response = resp.split(indicator)[-1].strip()
+
+ if shortest_key_response:
+ # and it's not trivial
+ if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
+ key_responses.append(shortest_key_response)
+ if len(key_responses) == 0: # did not found any
+ return [response]
+ return key_responses
+
+ key_responses = get_key_subresponses(response)
+
+ pred_list = key_responses.copy() # keep the original string response
+ # remove duplicates
+ pred_list = list(set(pred_list))
+
+ return pred_list
+
+
+class CMMMU(ImageBaseDataset):
+ TYPE = 'VQA'
+
+ DATASET_URL = {
+ 'CMMMU_VAL': 'https://opencompass.openxlab.space/utils/VLMEval/CMMMU_VAL.tsv'
+ }
+
+ DATASET_MD5 = {
+ 'CMMMU_VAL': 'b4727e2fce2415bf646379e60c11a726'
+ }
+
+ def dump_image(self, line):
+ os.makedirs(self.img_root, exist_ok=True)
+
+ tgt_path_z = []
+ if isinstance(line['image'], list):
+ for i in range(len(line['image'])):
+ tgt_path = osp.join(self.img_root, f"{line['index']}--{i + 1}.jpg")
+ if not read_ok(tgt_path):
+ decode_base64_to_image_file(line['image'][i], tgt_path)
+ tgt_path_z.append(tgt_path)
+ else:
+ tgt_path = osp.join(self.img_root, f"{line['index']}.jpg")
+ if not read_ok(tgt_path):
+ decode_base64_to_image_file(line['image'], tgt_path)
+ tgt_path_z.append(tgt_path)
+ return tgt_path_z
+
+ @classmethod
+ def evaluate(self, eval_file, **judge_kwargs):
+
+ result_file = get_intermediate_file_path(eval_file, '_acc', 'csv')
+
+ if not osp.exists(result_file):
+ data = load(eval_file)
+ assert 'answer' in data and 'prediction' in data
+ data['prediction'] = [str(x) for x in data['prediction']]
+ data['answer'] = [str(x) for x in data['answer']]
+
+ correct_count = 0
+ correct_category = {
+ '技术与工程': [0, 0],
+ '科学': [0, 0],
+ '健康与医学': [0, 0],
+ '商业': [0, 0],
+ '艺术与设计': [0, 0],
+ '人文社会科学': [0, 0],
+ }
+
+ for i in tqdm(data.iterrows()):
+ line = i[1]
+ correct_category[line['category']][0] += 1
+
+ # Options
+ if line['type'] == '选择':
+ index2ans = {
+ 'A': line['option1'],
+ 'B': line['option2'],
+ 'C': line['option3'],
+ 'D': line['option4']
+ }
+ fact_option = get_multi_choice_prediction(line['prediction'], ['A', 'B', 'C', 'D'], index2ans)
+ if fact_option == line['answer']:
+ correct_count += 1
+ correct_category[line['category']][1] += 1
+
+ # Binary
+ elif line['type'] == '判断':
+ positive_keywords = ['正确', '对', '准确', '肯定', '对的']
+ negative_keywords = ['不对', '错误', '不正确', '不准确', '不合适', '否定', '错的', '错']
+ ambiguous_keywords = ['对错', '是否正确', '否正确', '或者', '是否', '正确性', '对不']
+
+ def judge_similarity(pred_list, positive_keywords, negative_keywords):
+ positive_count = 0
+ negative_count = 0
+
+ for pred in pred_list:
+ if any(pos_word in pred for pos_word in positive_keywords):
+ positive_count += 1
+ elif any(neg_word in pred for neg_word in negative_keywords):
+ negative_count += 1
+
+ if positive_count > negative_count:
+ return "对"
+ elif negative_count > positive_count:
+ return "错"
+ else:
+ return random.choice(['对', '错'])
+
+ answer = get_TF_prediction(line['prediction'])
+ answer = [word for word in answer if not any(ambiguous in word for ambiguous in ambiguous_keywords)]
+ fact_answer = judge_similarity(answer, positive_keywords, negative_keywords)
+ if fact_answer == line['answer']:
+ correct_count += 1
+ correct_category[line['category']][1] += 1
+
+ # Fill the Blank
+ else:
+ norm_answers = normalize_str(line['answer'], line['answer'])
+ predicted_answer = get_fill_blank_prediction(line['prediction'], line['answer'])
+
+ for pred in predicted_answer:
+ # already normalized
+ if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
+ for norm_ans in norm_answers:
+ # only see if the string answer in the string pred
+ # print(norm_ans, pred)
+ if isinstance(norm_ans, str) and norm_ans in pred:
+ correct_count += 1
+ correct_category[line['category']][1] += 1
+ else: # it's a number
+ if pred in norm_answers:
+ correct_count += 1
+ correct_category[line['category']][1] += 1
+
+ accuracyz = {}
+ accuracyz['总准确率'] = correct_count / len(data)
+ for i in correct_category.keys():
+ accuracyz[i] = correct_category[i][1] / correct_category[i][0]
+
+ accuracyz = d2df(accuracyz)
+ accuracyz.round(10)
+ dump(accuracyz, result_file)
+
+ result = pd.read_csv(result_file)
+ return result
+
+ def build_prompt(self, line):
+ if line['type'] == '选择':
+ tgt_path = self.dump_image(line)
+ question = line['question']
+ options_prompt = 'Options:\n'
+
+ for i in [['A', '1'], ['B', '2'], ['C', '3'], ['D', '4']]:
+ options_prompt += i[0] + '. ' + line['option' + i[1]] + '\n'
+
+ prompt = (f'问题: {question}\n' + options_prompt
+ + '请回答上述多项选择题,并选出正确选项。这些题目可能包括单选和多选题型。如果所提供的信息不足以确定一个明确的答案,那么请根据可用的数据和你的判断来选择最可能正确的选项。')
+
+ msgs = []
+ if isinstance(tgt_path, list):
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
+ else:
+ msgs = [dict(type='image', value=tgt_path)]
+ msgs.append(dict(type='text', value=prompt))
+
+ return msgs
+
+ elif line['type'] == '判断':
+ msgs = super().build_prompt(line)
+ assert msgs[-1]['type'] == 'text'
+ msgs[-1]['value'] += '\n请回答上述判断题,并根据题目描述和所给的信息来判断问题中陈述的对错。如果信息不完整或不足以作出绝对判断,请运用你的逻辑推理和现有信息来做出最可能的判断。'
+ return msgs
+
+ else:
+ msgs = super().build_prompt(line)
+ assert msgs[-1]['type'] == 'text'
+ msgs[-1]['value'] += '\n请回答上述填空题,并根据题目的要求和所提供的信息来给出最恰当的答案。如果信息不足以确切回答,那么请依据现有的数据和你的推理能力来填写最合理的答案。'
+ return msgs
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_cab_image.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_cab_image.py
new file mode 100644
index 00000000..015f5531
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_cab_image.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class CosmosCABImage:
+ """Stub for CosmosCABImage."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("CosmosCABImage is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_cab_video.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_cab_video.py
new file mode 100644
index 00000000..843f00ac
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_cab_video.py
@@ -0,0 +1,28 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class CosmosCABVideoCamera:
+ """Stub for CosmosCABVideoCamera."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("CosmosCABVideoCamera is not available in this public release.")
+
+
+class CosmosCABVideoGeneral:
+ """Stub for CosmosCABVideoGeneral."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("CosmosCABVideoGeneral is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_erqa.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_erqa.py
new file mode 100644
index 00000000..5fceb32b
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_erqa.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class CosmosERQA:
+ """Stub for CosmosERQA."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("CosmosERQA is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_reason.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_reason.py
new file mode 100644
index 00000000..c3cd1be9
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/cosmos_reason.py
@@ -0,0 +1,17 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Stub: dataset implementation not included in this public release."""
+
+from __future__ import annotations
+
+
+class CosmosReason:
+ """Stub for CosmosReason."""
+
+ @classmethod
+ def supported_datasets(cls):
+ return []
+
+ def __init__(self, *args, **kwargs):
+ raise RuntimeError("CosmosReason is not available in this public release.")
diff --git a/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/creation.py b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/creation.py
new file mode 100644
index 00000000..8f6b6b0d
--- /dev/null
+++ b/evaluation/cosmos3/reasoner/vlmevalkit/vlmeval/dataset/creation.py
@@ -0,0 +1,778 @@
+# flake8: noqa
+import copy
+import os
+import os.path as osp
+import re
+from collections import defaultdict
+
+import pandas as pd
+
+from vlmeval.smp import (decode_base64_to_image_file, dump, get_intermediate_file_path, listinstr,
+ load, read_ok, toliststr)
+from ..utils import track_progress_rich
+from .image_base import ImageBaseDataset
+from .utils import DEBUG_MESSAGE, build_judge
+
+prompt_dict = {}
+prompt_dict['LiveMMBench_Creation'] = {
+ # Subjective Judge [GPT-4o reference]
+ 'subjective':
+ """
+Please act as an impartial judge and evaluate the quality of two responses provided by AI assistants to the user prompt.
+
+Your task is to carefully assess two responses based on provided instructions and evaluation criteria. After evaluating both responses, determine which response features better quality and better meets the criteria. If both responses are similar or nearly identical in quality, you should indicate a tie. Avoid position bias toward the first or second response.
+
+Suggested Steps for Evaluation:
+1. Review both responses independently and then carefully compare their strengths and weaknesses. A good response should feature good language quality, follow the user instruction and meet as many criteria as possible.
+2. After completing the first evaluation, swap the positions of response A and B and repeat Step 1 and get the 2nd evaluation outcome. This helps to mitigate the potential position bias.
+3. After completing both evaluations (in the original and reversed order), combine your analysis and provide a final conclusion based on the overall assessment. If both responses are relatively similar, or the differences are minimal and hard to distinguish, your conclusion should indicate a tie ([[A=B]]).
+
+Your **conclusion** should be one of the following options (A, B are of the original order):
+1. [[A>>B]]: Response A is clearly better than Response B.
+2. [[A>B]]: Response A is slightly better than Response B.
+3. [[A=B]]: Response A is nearly identical to Response B.
+4. [[B>A]]: Response B is slightly better than Response A.
+5. [[B>>A]]: Response B is clearly better than Response A.
+
+User Instruction:\n[INSTRUCTIONS]\n{instructions}\n[END INSTRUCTIONS]\n\n
+Repsonse A:\n[RESPONSE A]\n{reference_answer_by_gpt4o}\n[END RESPONSE A]\n\n
+Response B:\n[RESPONSE B]\n{prediction}\n[END RESPONSE B]\n\n
+Evaluation Criteria:\n[CRITERIA]\n{criteria}\n[END CRITERIA]\n\n
+
+Your output should include:
+1. Conclusion: Your final conclusion based on the overall assessment.
+2. Reasoning: Your reasoning process and analysis of the two responses.
+
+Your output should follow the following format (CONCLUSION should be one of the five options: A>>B, A>B, A=B, B>A, B>>A):
+
+Final Conclusion: [[CONCLUSION]]
+Reasoning Process: [REASONING]\n
+""", # noqa: E501
+
+ # Criteria Alignment w/o GT
+ 'objective_without_gt':
+ """
+Please act as an impartial judge and evaluate the **Criteria Alignment** of the two responses provided by AI assistants to the user prompt. The responses were generated based on the provided instructions and visual input from images.
+
+Suggested Steps for Evaluation:
+1. Evaluate **Criteria Alignment** of both responses based on the criteria.
+ • If a criterion consist of **X aspects**, each aspect is worth **10 / X points**.
+ • For each aspect, there may be multiple sub-criteria. If there are **Y sub-criteria for the aspect**, each sub-criterion worths **10 / (X * Y) points**.
+2. Assign a total score out of 10 for each response.
+
+User Instruction:\n[INSTRUCTIONS]\n{instructions}\n[END INSTRUCTIONS]\n\n
+Repsonse A:\n[RESPONSE A]\n{reference_answer_by_gpt4o}\n[END RESPONSE A]\n\n
+Response B:\n[RESPONSE B]\n{prediction}\n[END RESPONSE B]\n\n
+Criteria:\n[CRITERIA]\n{criteria}\n[END CRITERIA]\n\n
+
+Your output should evaluate alignment scores of each response and end with a conclusion in the following format (The full score is 10. X, Y are alignment scores for Response A and B):
+
+Response A Alignment Score: X/10
+Response B Alignment Score: Y/10\n
+""", # noqa: E501
+
+ # Criteria Alignment w. GT
+ 'objective_with_gt':
+ """
+Please act as an impartial judge and evaluate the **Criteria Alignment** of the two responses provided by AI assistants to the user prompt. The responses were generated based on the provided instructions and visual input from images. There is also a ground truth corresponding to the instructions provided for reference.
+Take this context into account when making your judgment.
+
+Steps for Evaluation:
+1. Evaluate **Criteria Alignment** of both responses based on the criteria and the ground truth.
+ • If a criterion consist of **X aspects**, each aspect is worth **10 / X points**.
+ • For each aspect, there may be multiple sub-criteria. If there are **Y sub-criteria for the aspect**, each sub-criterion worths **10 / (X * Y) points**.
+2. Assign a total score out of 10 for each response.
+
+User Instruction:\n[INSTRUCTIONS]\n{instructions}\n[END INSTRUCTIONS]\n\n
+Ground Truth:\n[GROUND TRUTH]\n{groundtruth}\n[END GROUND TRUTH]\n\n
+Repsonse A:\n[RESPONSE A]\n{reference_answer_by_gpt4o}\n[END RESPONSE A]\n\n
+Response B:\n[RESPONSE B]\n{prediction}\n[END RESPONSE B]\n\n
+Criteria:\n[CRITERIA]\n{criteria}\n[END CRITERIA]\n\n
+
+Your output should evaluate alignment scores of each response and end with a conclusion in the following format (The full score is 10. X, Y are alignment scores for Response A and B):
+
+Response A Alignment Score: X/10
+Response B Alignment Score: Y/10\n
+""", # noqa: E501
+}
+
+prompt_dict['Creation_MMBench'] = {
+ # Subjective Judge [GPT-4o reference, with image]
+ 'subjective':
+ """
+Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user prompt below, considering both the provided criteria and the image.
+
+Your task is to carefully assess each response based on how well it meets the evaluation criteria, incorporating the visual context from the image. The criteria should be the primary basis for your judgment, with the image serving to complement and inform your analysis.
+
+Steps for Evaluation:
+ 1. Review Both Responses Independently:
+ Carefully analyze Assistant A’s and Assistant B’s responses with the criteria and the image. Do not assume any response is better just because it is listed first. Each response should be independently assessed based on the criteria and aided by images to help understand the context.
+
+ 2. Compare the Strengths and Weaknesses:
+ After evaluating each response independently, compare the two. Consider both the quality of the content and how closely it aligns with the criteria and image. Identify the strengths and weaknesses of each response, and highlight the key differences.
+
+ 3. Ensure Fairness:
+ To avoid positional bias, swap the positions of Assistant A and Assistant B after the first evaluation (i.e., make Assistant A become Assistant B and vice versa) and repeat the analysis and comparison. This ensures that each response is evaluated impartially under the same criteria.
+
+ 4. Provide a Conclusion Based on Both Evaluations:
+ After completing both evaluations (original and swapped positions), combine your analysis to provide a final verdict. If the responses are similar, with only minimal differences, your judgment should reflect that and indicate a tie.
+
+Possible Verdict Options:
+
+• If Assistant A is clearly better in both evaluations: [[A>>B]]
+• If Assistant A is slightly better in both evaluations: [[A>B]]
+• If both responses are nearly identical, showing minimal differences and no clear advantage: [[A=B]]
+• If Assistant B is slightly better in both evaluations: [[B>A]]
+• If Assistant B is clearly better in both evaluations: [[B>>A]]
+
+Instructions to the AI Assistants:
+
+[INSTRUCTIONS]
+{instructions}
+[END INSTRUCTIONS]
+
+Assistant A Response:
+
+[ASSISTANT A]
+{reference_answer_by_gpt4o}
+[END ASSISTANT A]
+
+Evaluation Criteria:
+
+[CRITERIA]
+{criteria}
+[END CRITERIA]
+
+Assistant B Response:
+
+[ASSISTANT B]
+{prediction}
+[END ASSISTANT B]
+
+Output Format:
+
+Your output should include:
+ 1. Evaluation of Assistant A’s Response: Provide a detailed qualitative evaluation, focusing on how well Assistant A’s response aligns with the criteria and the image.
+ 2. Evaluation of Assistant B’s Response: Provide a detailed qualitative evaluation, focusing on how well Assistant B’s response aligns with the criteria and the image.
+ 3. Final Verdict: After considering both evaluations, select one of the following verdicts and justify it based on your analysis:
+
+Your output format should end like this:
+Assistant A Evaluation: [qualitative comment]
+Assistant B Evaluation: [qualitative comment]
+Final Verdict is: [[VERDICT]]
+""", # noqa: E501
+
+ ##### For Visual Factuality
+ 'objective_without_gt':
+ """
+Please act as an impartial judge and evaluate the **Visual Factuality** of the responses provided by two AI assistants to the user prompt displayed below.
+
+The responses were generated based on the provided instructions and visual input from images. Take this context into account when making your judgment.
+
+Steps for Evaluation:
+1. Evaluate visual factuality for both responses based on the visual factuality criteria.
+ • If the visual factuality criteria consist of **X aspects**, each aspect is worth **10/X points**.
+ • For each aspect, there may be multiple small criteria. If there are **Y small criteria in one aspect**, each small criterion is worth **10/X/Y points**.
+2. Assign a total score out of 10 for each response.
+
+Instructions to the AI assistants:
+[INSTRUCTIONS]
+{instructions}
+[END INSTRUCTIONS]
+
+Assistant A response:
+[ASSISTANT A]
+{reference_answer_by_gpt4o}
+[END ASSISTANT A]
+
+Visual Factuality Criteria:
+[VISUAL FACTUALITY CRITERIA]
+{criteria}
+[END CRITERIA]
+
+Assistant B response:
+[ASSISTANT B]
+{prediction}
+[END ASSISTANT B]
+
+Your output should evaluate visual factuality scores for each assistant and end like this:
+
+Response A Visual Factuality Score: X/10
+Response B Visual Factuality Score: Y/10
+""", # noqa: E501
+ 'objective_with_gt':
+ """
+Please act as an impartial judge and evaluate the **Visual Factuality** of the responses provided by two AI assistants to the user prompt displayed below.
+
+The responses were generated based on the provided instructions and visual input from images.
+There is a provided ground truth for the instructions, but the ground truth was not given to the AI assistants when generating their responses.
+Take this context into account when making your judgment.
+
+Steps for Evaluation:
+1. Evaluate visual factuality for both responses based on the provided ground truth and visual factuality criteria.
+ • If the visual factuality criteria consist of **X aspects**, each aspect is worth **10/X points**.
+ • For each aspect, there may be multiple small criteria. If there are **Y small criteria in one aspect**, each small criterion is worth **10/X/Y points**.
+2. Assign a total score out of 10 for each response.
+
+Instructions to the AI assistants:
+[INSTRUCTIONS]
+{instructions}
+[END INSTRUCTIONS]
+
+Assistant A response:
+[ASSISTANT A]
+{reference_answer_by_gpt4o}
+[END ASSISTANT A]
+
+Visual Factuality Criteria:
+[VISUAL FACTUALITY CRITERIA]
+{criteria}
+[END CRITERIA]
+
+Assistant B response:
+[ASSISTANT B]
+{prediction}
+[END ASSISTANT B]
+
+Ground truth:
+[GROUND TRUTH]
+{groundtruth}
+[END GROUND TRUTH]
+
+Your output should evaluate visual factuality scores for each assistant and end like this:
+
+Response A Visual Factuality Score: X/10
+Response B Visual Factuality Score: Y/10
+""", # noqa: E501
+}
+
+creation_mmbench_category_dict = {
+ 'CATEGORY_Literary_Writing': [
+ 'story_continue',
+ 'landscape_to_poem',
+ 'historical_story_creation',
+ 'story_novel_creation',
+ 'prose_writing_scenery',
+ 'art_inspired_prose',
+ 'daily_conversation_creation',
+ 'children_book_illustration_dialogue_creation'
+ ],
+ 'CATEGORY_Common_Functionality_Writing':[
+ 'ins_simple_daily_copywriter',
+ 'travel_journal',
+ 'short_video_scripts_for_social_media',
+ 'social_media_travel_content',
+ 'daily_achievement_show_off',
+ 'scientific_research_simple_promotion',
+ 'twitter_comment_on_daily_news',
+ 'personal_event_summaries',
+ 'daily_affairs_inquiries',
+ 'business_collaborative_email_writing',
+ 'daily_emotional_email_writing',
+ 'letter_of_complaint',
+ 'daily_invitation_email_writing',
+ 'holiday_card_writing',
+ 'letter_of_application',
+ 'product_usage_experience_review',
+ 'store_experience_review',
+ 'public_welfare_activity_participation_initiative'
+ ],
+ 'CATEGORY_Professional_Functionality_Writing': [
+ 'museum_guide_word_creation',
+ 'recipe_infer_and_guide',
+ 'landscape_introduction',
+ 'drafting_announcements_for_public_spaces',
+ 'floor_plan_renovation_design',
+ 'teaching_plan',
+ 'nutritional_formulation_of_recipe',
+ 'clothing_match_design',
+ 'software_engineering_diagram_explanation',
+ 'event_planning_and_venue_arrangement',
+ 'ui_design_analysis_and_optimization',
+ 'attraction_promotional_words',
+ 'product_marketing_strategy',
+ 'script_writing_for_product_advertisement_promotional_video',
+ 'residence_reasoning',
+ 'scientific_diagram_understanding',
+ 'pulitzer_prize_judge',
+ 'architecture_appreciation',
+ 'company_team_amuse_broadcast'
+ ],
+ 'CATEGORY_Creative_Multimodal_Understanding': [
+ 'travel_itinerary_planning_and_recommendations',
+ 'photography_appreciation',
+ 'meme_explanation',
+ 'advertisement_explanation',
+ 'document_understanding',
+ 'snapshot_analysis'
+ ]
+
+}
+
+
+def is_criteria_valid(criteria):
+ import re
+ for value in criteria.values():
+ if value == '\\' or value == '' or not re.search('[a-zA-Z]', value):
+ return False
+ return True
+
+
+key_mapping = {
+ "sub_parse_ok": "preference_parse_ok",
+ "sub_dist": "preference_dist",
+ "win_rate": "win_rate",
+ "sub_reward": "reward",
+ "obj_parse_ok": "visual_factuality_parse_ok",
+ "obj_score": "visual_factuality_score",
+ "obj_ref_score": "visual_factuality_ref_score"
+}
+
+
+def rename_keys(data, key_mapping):
+ if isinstance(data, dict):
+ new_data = {}
+ for key, value in data.items():
+ new_key = key_mapping.get(key, key)
+ new_data[new_key] = rename_keys(value, key_mapping)
+ return new_data
+ elif isinstance(data, list):
+ return [rename_keys(item, key_mapping) for item in data]
+ else:
+ return data
+
+
+def build_prompt(line, dataset_name):
+ try:
+ criteria = eval(line['criteria'])
+ except Exception:
+ criteria = line['criteria']
+
+ if isinstance(criteria, dict):
+ new_criteria = {}
+ for k in criteria:
+ if 'subjective' in k.lower():
+ new_criteria['subjective'] = criteria[k]
+ else:
+ new_criteria['objective'] = criteria[k]
+ else:
+ assert isinstance(criteria, str)
+ new_criteria = {'subjective': criteria}
+ criteria = new_criteria
+ assert 'subjective' in criteria, 'No subjective criteria found in the criteria dict'
+
+ prompts = {}
+ if listinstr(['Creation_MMBench'], dataset_name):
+ dataset_name = 'Creation_MMBench'
+ prompts['subjective'] = prompt_dict[dataset_name]['subjective'].format(
+ instructions=line['question'],
+ criteria=criteria['subjective'],
+ reference_answer_by_gpt4o=line['reference_answer_by_gpt4o'],
+ prediction=line['prediction'])
+ if 'objective' in criteria:
+ if 'ground_truth' in line and (not pd.isna(
+ line['ground_truth'])) and line['ground_truth'] != '':
+ prompts['objective'] = prompt_dict[dataset_name]['objective_with_gt'].format(
+ instructions=line['question'],
+ criteria=criteria['objective'],
+ groundtruth=line['ground_truth'],
+ reference_answer_by_gpt4o=line['reference_answer_by_gpt4o'],
+ prediction=line['prediction'])
+ else:
+ prompts['objective'] = prompt_dict[dataset_name]['objective_without_gt'].format(
+ instructions=line['question'],
+ criteria=criteria['objective'],
+ reference_answer_by_gpt4o=line['reference_answer_by_gpt4o'],
+ prediction=line['prediction'])
+ return prompts
+
+
+def Generate_Creation_MMBench_judge(model, image_list, prompt):
+ assert isinstance(prompt, dict)
+ response = {}
+ for key in prompt.keys():
+ if image_list and key == 'subjective':
+ input_msg = []
+ for img_path in image_list:
+ if read_ok(img_path):
+ input_msg.append({'type': 'image', 'value': img_path})
+ else:
+ raise ValueError(f"Image not found: {img_path}")
+ input_msg.append({'type': 'text', 'value': prompt[key]})
+ # print(f'using image {image_list} and text')
+ response[key] = model.generate(input_msg)
+ else:
+ response[key] = model.generate(prompt[key])
+ return response
+
+
+def extract_subjective(inp, dataset_name):
+ mapping_dict = {
+ 'LiveMMBench_Creation': 'FINAL CONCLUSION:',
+ 'Creation_MMBench': 'FINAL VERDICT IS:'
+ }
+ cands = {'A>>B', 'A>B', 'A=B', 'B>A', 'B>>A', 'B<>A' in text:
+ return 2
+ elif 'AA' in text:
+ return 1
+ elif 'A=B' in text or 'B=A' in text:
+ return 0
+ elif 'A>B' in text or 'B>B' in text or 'B<'.
+ if " |