diff --git a/methods/__init__.py b/methods/__init__.py
index e943507..3c29138 100644
--- a/methods/__init__.py
+++ b/methods/__init__.py
@@ -12,6 +12,7 @@
from .mapcoder import MapCoder_HumanEval, MapCoder_MBPP
from .self_consistency import SelfConsistency
from .mav import MAV_GPQA, MAV_HumanEval, MAV_Main, MAV_MATH, MAV_MMLU
+from .aflow import AFlow_MATH
method2class = {
"vanilla": MAS,
@@ -38,7 +39,8 @@
"mav_humaneval": MAV_HumanEval,
"mav_main": MAV_Main,
"mav_math": MAV_MATH,
- "mav_mmlu": MAV_MMLU
+ "mav_mmlu": MAV_MMLU,
+ "aflow_math": AFlow_MATH
}
def get_method_class(method_name, dataset_name=None):
diff --git a/methods/aflow/__init__.py b/methods/aflow/__init__.py
new file mode 100644
index 0000000..e26512d
--- /dev/null
+++ b/methods/aflow/__init__.py
@@ -0,0 +1 @@
+from .aflow_math import AFlow_MATH
\ No newline at end of file
diff --git a/methods/aflow/aflow_math.py b/methods/aflow/aflow_math.py
new file mode 100644
index 0000000..0bc07c6
--- /dev/null
+++ b/methods/aflow/aflow_math.py
@@ -0,0 +1,708 @@
+import copy
+import importlib
+import datetime,os,json,asyncio,re,random
+import pandas as pd
+import numpy as np
+import shutil
+import time
+import warnings
+
+from termcolor import colored
+from pathlib import Path
+from collections import defaultdict
+from openai import AsyncOpenAI
+from tenacity import retry, wait_exponential, stop_after_attempt
+from tqdm.asyncio import tqdm_asyncio
+from typing import Dict,Any,Tuple
+from pydantic_core import to_jsonable_python
+
+from .prompt import *
+from .evaluate import evaluate_math
+from ..mas_base import MAS
+from ..utils import handle_retry_error
+
+warnings.filterwarnings("ignore", category=SyntaxWarning, message="invalid escape sequence")
+
+class AFlow_MATH(MAS):
+ def __init__(self, general_config, method_config_name="config"):
+ """
+ Configure the AFlow optimizer for the math task.
+
+ Reads the optimization settings from the method config, prepares the
+ results directory (seeding it from the bundled initial workflow when the
+ directory does not exist yet) and builds the AsyncOpenAI client.
+
+ general_config: general run configuration; must contain 'test_dataset_name'.
+ method_config_name: name of the method config handed to the MAS base class.
+ """
+ # Only an explicit None falls back to "config_main"; the default stays "config".
+ method_config_name = "config_main" if method_config_name is None else method_config_name
+ super().__init__(general_config, method_config_name)
+
+ self.dataset_name = general_config['test_dataset_name']
+ self.model_name_optimize = self.method_config.get('optimize_meta_model_name','gpt-4o')
+ self.model_name_execute = self.method_config.get('optimize_execute_model_name','gpt-4o-mini-2024-07-18')
+ # The following settings are mandatory: a missing key raises KeyError.
+ self.sample = self.method_config['sample']
+ self.max_rounds = self.method_config['max_rounds']
+ self.validation_rounds = self.method_config['validation_rounds']
+ self.earlystop = self.method_config['earlystop']
+ # Directory of this package, expressed relative to the current working directory.
+ self.root_path = str(os.path.relpath(Path(__file__).parent, start=os.getcwd()))
+ # Results are kept per dataset and per (optimizer model, executor model) pair.
+ self.results_path = f"results/{self.dataset_name}/aflow/{self.model_name_optimize}/{self.model_name_execute}"
+ self.top_scores = []
+ self.round = 1
+ self.graph = None
+
+ # Operator names looked up in template/operator.json during optimization.
+ self.operators:list = ["Custom", "ScEnsemble", "Programmer"]
+ self.type = "math"
+
+ results_path = Path(self.results_path)
+ # Results directory missing: create it, add empty placeholder files and copy
+ # the initial math workflow into it.
+ # NOTE(review): the original indentation is not visible here; the copy loop
+ # presumably belongs to this `if` because `graph_path` is only bound inside it.
+ if not results_path.exists():
+ graph_path = Path(self.root_path) / "initial_workflows" / "math"
+ results_path.mkdir(parents=True, exist_ok=True)
+ exp_path = os.path.join(self.results_path, "processed_experience.json")
+ res_path = os.path.join(self.results_path, "results.json")
+ # Opening in 'w' mode and writing nothing leaves zero-byte files behind.
+ with open(exp_path, 'w') as f:
+ pass
+ with open(res_path, 'w') as f:
+ pass
+ for item in graph_path.iterdir():
+ dest = results_path / item.name
+ if item.is_dir():
+ shutil.copytree(item, dest, dirs_exist_ok=True)
+ else:
+ shutil.copy2(item, dest)
+
+ self.optimized_round = 1
+ self.inference_flag = True
+
+ # One endpoint is picked at random from the optimizer model's "model_list".
+ model_dict = random.choice(self.model_api_config[self.model_name_optimize]["model_list"])
+ model_url, api_key = model_dict['model_url'], model_dict['api_key']
+ self.asyncllm = AsyncOpenAI(base_url=model_url, api_key=api_key)
+
+ # Up to 5 attempts with an exponential wait clamped to 4..10; once the attempts
+ # are exhausted the outcome is delegated to handle_retry_error (defined elsewhere).
+ @retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(5), retry_error_callback=handle_retry_error)
+ async def async_call_llm(self, prompt=None, system_prompt=None, messages=None, model_name=None, temperature=None):
+ """
+ Send one chat-completion request through the async client and return the text.
+
+ prompt: user prompt; required when `messages` is not given.
+ system_prompt: optional system message placed before the user prompt.
+ messages: ready-made message list; when given, prompt/system_prompt are ignored.
+ model_name: model to query; defaults to self.model_name.
+ temperature: sampling temperature; defaults to self.model_temperature.
+
+ Returns the content of the first choice.
+ Raises ValueError when that content is not a string.
+ """
+
+ model_name = model_name if model_name is not None else self.model_name
+
+ if messages is None:
+ assert prompt is not None, "'prompt' must be provided if 'messages' is not provided."
+ if system_prompt is not None:
+ messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
+ else:
+ messages = [{"role": "user", "content": prompt}]
+
+ model_temperature = temperature if temperature is not None else self.model_temperature
+
+ request_dict = {
+ "model": model_name,
+ "messages": messages,
+ "max_tokens": self.model_max_tokens,
+ "timeout": self.model_timeout
+ }
+ if "o1" not in model_name: # OpenAI's o1 models do not support temperature
+ request_dict["temperature"] = model_temperature
+
+ completion = await self.asyncllm.chat.completions.create(**request_dict)
+ response, num_prompt_tokens, num_completion_tokens = completion.choices[0].message.content, completion.usage.prompt_tokens, completion.usage.completion_tokens
+
+ # Token usage is only recorded for a usable (string) response; anything else
+ # (e.g. None) is rejected below with ValueError.
+ if isinstance(response, str):
+ if model_name not in self.token_stats:
+ self.token_stats[model_name] = {"num_llm_calls": 1, "prompt_tokens": num_prompt_tokens, "completion_tokens": num_completion_tokens}
+ else:
+ self.token_stats[model_name]["num_llm_calls"] += 1
+ self.token_stats[model_name]["prompt_tokens"] += num_prompt_tokens
+ self.token_stats[model_name]["completion_tokens"] += num_completion_tokens
+ else:
+ raise ValueError(f"Invalid response from LLM: {response}")
+
+ return response
+
+    def inference(self, query):
+        """
+        Answer a single query with the optimized workflow.
+
+        query: Query to be passed to the MAS
+
+        Imports the `Workflow` class from the "best_workflow" folder under the
+        results path, instantiates it with this object as its environment and
+        runs it to completion. Raises NotImplementedError when no optimized
+        workflow has been saved yet.
+        """
+        self.inference_flag = True
+
+        # Guard clause: optimization must have produced a best workflow first.
+        if not (Path(self.results_path) / "best_workflow").exists():
+            raise NotImplementedError("Best_workflow path does not exist!")
+
+        graph_module_name = f"results.{self.dataset_name}.aflow.{self.model_name_optimize}.{self.model_name_execute}.best_workflow.graph"
+        workflow_module = importlib.import_module(graph_module_name, package=__package__)
+        self.graph = getattr(workflow_module, "Workflow")
+
+        workflow = self.graph(name="Optimized", env=self)
+        return asyncio.run(workflow(problem=query))
+
+ def optimizing(self,val_dataset):
+ """
+ Run the workflow optimization loop on the validation set.
+
+ Returns immediately when a "best_workflow" folder already exists under the
+ results path. Otherwise performs up to `max_rounds` optimization rounds,
+ copies the best-scoring round to "best_workflow" after each round, stops
+ early on convergence when `earlystop` is set, and finally writes the token
+ statistics to api_token.json.
+
+ val_dataset: validation problems forwarded to `_optimize_graph`.
+ """
+ self.inference_flag = False
+
+ optimized_path = Path(self.results_path) / "best_workflow"
+ if optimized_path.exists():
+ print(colored("The optimal graph already exists!\n","red"))
+ return
+
+ print(colored("Start optimizing ...\n","yellow"))
+ for i in range(self.max_rounds):
+ # A fresh event loop is installed for every round.
+ # NOTE(review): these loops are never closed in this method — confirm intended.
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ retry_count = 0
+ # With max_retries = 1 each round gets exactly one attempt.
+ max_retries = 1
+ while retry_count < max_retries:
+ try:
+ print(colored(f"{i+1} round of optimization...\n","light_cyan"))
+ score = loop.run_until_complete(self._optimize_graph(val_dataset))
+ break
+ except Exception as e:
+ retry_count += 1
+ print(f"Optimization failed: {e}")
+ # Out of attempts: the round is recorded with no score.
+ if retry_count == max_retries:
+ score = None
+ # Back off before the next attempt (5 seconds per failed attempt).
+ wait_time = 5 * retry_count
+ time.sleep(wait_time)
+
+ if retry_count < max_retries:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+ # After the increment self.round is the number of the round that was just generated.
+ self.round += 1
+ print(f"Score for round {self.round}: {score}")
+ self.save_optimized_graph()
+ converged, convergence_round, final_round = self.check_convergence(top_k=3)
+ if self.earlystop and converged:
+ print(f"Convergence detected, occurred in round {convergence_round}, final round is {final_round}")
+ self.print_results()
+ break
+
+
+ print(colored("Optimization complete!","green"))
+ print(colored(f"\n>> Optimization token stats: {self.get_token_stats()}","light_yellow"))
+ token_path = os.path.join(self.results_path,"api_token.json")
+ os.makedirs(os.path.dirname(token_path), exist_ok=True)
+ with open(token_path,"w") as f:
+ json.dump(self.get_token_stats(), f, indent=4)
+
+    async def _optimize_graph(self,val_dataset):
+        """
+        Generate, store and score one new workflow round.
+
+        On the very first round the seed workflow (round_1) is evaluated first.
+        Then a parent round is sampled, the optimizer model is asked for a
+        modification in XML form, and the request is repeated until the proposed
+        modification has not been tried on that parent before. The new workflow
+        is written to round_{self.round + 1}, evaluated on `val_dataset`, and
+        the outcome is stored in that round's experience.json.
+
+        val_dataset: validation problems passed to `evaluate_graph`.
+
+        Returns the average validation score of the new round.
+        Raises ValueError when the parent graph has no `class Workflow` or the
+        optimizer response lacks one of the required XML fields.
+        """
+        validation_n = self.validation_rounds
+        graph_path = self.results_path
+        result_path = os.path.join(graph_path, "results.json")
+
+        # Previously recorded scores; a missing, empty or invalid file means "none yet".
+        data = []
+        if os.path.exists(result_path):
+            with open(result_path, "r") as json_file:
+                try:
+                    data = json.load(json_file)
+                except json.JSONDecodeError:
+                    data = []
+
+        module_prefix = f"results.{self.dataset_name}.aflow.{self.model_name_optimize}.{self.model_name_execute}"
+
+        if self.round == 1:
+            # Score the seed workflow so that there is a parent to sample from.
+            directory = os.path.join(graph_path, f"round_{self.round}")
+            os.makedirs(directory, exist_ok=True)
+
+            importlib.invalidate_caches()
+            module = importlib.import_module(f"{module_prefix}.round_{self.round}.graph", package=__package__)
+            self.graph = getattr(module, "Workflow")
+            avg_score = await self.evaluate_graph(directory, validation_n, data, val_dataset, True)
+
+        # The operator catalogue does not change between attempts: read it once.
+        with open(os.path.join(graph_path, "template/operator.json"), "r") as f:
+            operator_data = json.load(f)
+        operators_description = ""
+        for index, operator in enumerate(self.operators, start=1):
+            matched_data = operator_data[operator]
+            desc = matched_data["description"]
+            interface = matched_data["interface"]
+            operators_description += f"{index}. {operator}: {desc}, with interface {interface}).\n"
+
+        # Fields the optimizer must return, each wrapped in <name>...</name>.
+        names = ["modification", "graph", "prompt"]
+        types = {name: str for name in names}
+        example_str = "\n".join(f"<{name}>content</{name}>" for name in names)
+
+        while True:
+            directory = os.path.join(graph_path, f"round_{self.round+1}")
+            os.makedirs(directory, exist_ok=True)
+
+            # parent <- SelectParent(results)
+            top_rounds = self.get_top_rounds()
+            sample, _ = self.select_round(top_rounds)
+
+            prompt, graph_load = self.read_graph_files(sample["round"], graph_path)
+            graph = re.findall(r"class Workflow:.+", graph_load, re.DOTALL)
+            if not graph:
+                raise ValueError(f"No 'class Workflow' found in graph.py of round {sample['round']}")
+
+            # context <- LoadContext(parent, experiences)
+            processed_experience = self.load_experience()
+            experience = self.format_experience(processed_experience, sample["round"])
+            log_data = self.load_log(sample["round"])
+
+            graph_input = WORKFLOW_INPUT.format(
+                experience=experience,
+                score=sample["score"],
+                graph=graph[0],
+                prompt=prompt,
+                operator_description=operators_description,
+                type=self.type,
+                log=log_data,
+            )
+            graph_system = WORKFLOW_OPTIMIZE_PROMPT.format(type=self.type)
+            graph_optimize_prompt = graph_input + WORKFLOW_CUSTOM_USE + graph_system
+            instructions = graph_optimize_prompt + f"\n# Response format (must be strictly followed) (do not include any other formats except for the given XML format):\n{example_str}"
+            response = await self.async_call_llm(prompt=instructions, model_name=self.model_name_optimize)
+            if not isinstance(response, str):
+                raise ValueError(f"Invalid response from the optimizer model: {response}")
+
+            response = self.xml_extract(response, names, types)
+            missing = [name for name in names if name not in response]
+            if missing:
+                # Fail with a clear message instead of a KeyError further down.
+                raise ValueError(f"Optimizer response is missing required fields: {missing}")
+
+            # Accept the proposal only if this modification is new for the parent;
+            # otherwise ask the optimizer again.
+            if self.check_modification(processed_experience, response["modification"], sample["round"]):
+                break
+
+        # Save the graph and evaluate
+        graph = WORKFLOW_TEMPLATE.format(graph=response["graph"], round=self.round + 1, dataset=self.dataset_name)
+
+        with open(os.path.join(directory, "graph.py"), "w", encoding="utf-8") as file:
+            file.write(graph)
+
+        with open(os.path.join(directory, "prompt.py"), "w", encoding="utf-8") as file:
+            file.write(response["prompt"])
+
+        with open(os.path.join(directory, "__init__.py"), "w", encoding="utf-8") as file:
+            file.write("")
+
+        experience = {
+            "father node": sample["round"],
+            "modification": response["modification"],
+            "before": sample["score"],
+            "after": None,
+            "succeed": None,
+        }
+
+        # The module files were created after the interpreter started, so the
+        # import system's directory caches have to be refreshed before importing.
+        importlib.invalidate_caches()
+        module = importlib.import_module(f"{module_prefix}.round_{self.round+1}.graph", package=__package__)
+        self.graph = getattr(module, "Workflow")
+        avg_score = await self.evaluate_graph(directory, validation_n, data, val_dataset)
+
+        experience["after"] = avg_score
+        experience["succeed"] = bool(avg_score > experience["before"])
+
+        with open(os.path.join(directory, "experience.json"), "w", encoding="utf-8") as fout:
+            json.dump(experience, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
+        return avg_score
+
+ async def evaluate_graph(self, directory, validation_n, data, val_dataset,initial=False):
+ """
+ Evaluate the currently loaded workflow class on the validation set.
+
+ directory: round folder handed to `evaluate_math` as its log location.
+ validation_n: number of evaluation passes over `val_dataset`.
+ data: list of result records; it is appended to in place and written to results.json.
+ val_dataset: problems to evaluate.
+ initial: True when scoring the current round itself, False when scoring
+ the round that is being generated (self.round + 1).
+
+ Returns the mean of the per-pass average scores.
+ """
+ sum_score = 0
+ # Upper bound on the number of problems evaluated concurrently.
+ max_workers = 50
+ # NOTE(review): the original indentation is not visible here; presumably every
+ # statement up to `sum_score += average_score` runs once per pass — confirm.
+ for _ in range(validation_n):
+ graph = self.graph(name=self.dataset_name+f"/round_{self.round}", env=self)
+ semaphore = asyncio.Semaphore(max_workers)
+ async def sem_evaluate(problem):
+ async with semaphore:
+ return await evaluate_math(problem, graph, directory)
+ tasks = [sem_evaluate(problem) for problem in val_dataset]
+ results = await tqdm_asyncio.gather(*tasks, desc=f"Evaluating {self.type} problems", total=len(val_dataset))
+ columns = ["question", "prediction", "expected_output", "score"]
+ df = pd.DataFrame(results, columns=columns)
+ average_score = df["score"].mean()
+
+
+ # Scores of a freshly generated workflow are recorded under the next round number.
+ cur_round = self.round + 1 if initial is False else self.round
+ now = datetime.datetime.now()
+ new_data = {"round": cur_round, "score": average_score,"time": now}
+ data.append(new_data)
+
+ result_path = os.path.join(self.results_path, "results.json")
+ folder_path = Path(result_path).parent
+ if not folder_path.exists():
+ folder_path.mkdir(parents=True, exist_ok=True)
+
+ # The complete record list is rewritten to results.json.
+ with open(result_path, "w", encoding="utf-8") as fout:
+ json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
+ sum_score += average_score
+
+ return sum_score / validation_n
+
+ def check_convergence(self, top_k=3, z=0, consecutive_rounds=5):
+ """
+ Decide from results.json whether the top-k average score has stopped changing.
+
+ top_k: number of best rounds averaged at each step.
+ z: multiplier applied to the standard error of the change.
+ consecutive_rounds: number of consecutive steps within tolerance required.
+
+ Returns (converged, start, end). `start` and `end` are positions in the
+ list of rounds sorted by round number, or None when not converged.
+ Also stores the loaded records in self.data and the per-round score
+ lists in self.rounds.
+ """
+ result_file = os.path.join(self.results_path, "results.json")
+ with open(result_file, "r") as file:
+ self.data = json.load(file)
+ # Group all recorded scores by round number.
+ rounds = {}
+ for entry in self.data:
+ round_number = entry["round"]
+ score = entry["score"]
+ if round_number not in rounds:
+ rounds[round_number] = []
+ rounds[round_number].append(score)
+ self.rounds = rounds
+ sorted_rounds = sorted(self.rounds.items(), key=lambda x: x[0])
+ avg_scores = []
+ stds = []
+ for round_number, scores in sorted_rounds:
+ avg_scores.append(np.mean(scores))
+ stds.append(np.std(scores))
+ # If total rounds are not enough to calculate top_k+1 rounds, return not converged
+ if len(avg_scores) < top_k + 1:
+ return False, None, None
+ convergence_count = 0 # Convergence counter
+ previous_y = None # Y value of the previous round (average of top_k scores)
+ sigma_y_previous = None # Standard error of Y value from previous round
+ for i in range(len(avg_scores)):
+ # Dynamically select top_k from current round and all previous rounds
+ top_k_indices = np.argsort(avg_scores[: i + 1])[::-1][:top_k] # Select top k indices by descending average score
+ top_k_scores = [avg_scores[j] for j in top_k_indices] # Get list of top k scores
+ top_k_stds = [
+ stds[j] for j in top_k_indices
+ ] # Get list of standard deviations corresponding to top k scores
+ # Calculate mean of top k scores for current round, i.e., y_current
+ y_current = np.mean(top_k_scores)
+ # Calculate standard error of y_current (sigma_y_current), representing score dispersion
+ sigma_y_current = np.sqrt(np.sum([s**2 for s in top_k_stds]) / (top_k**2))
+ # If not the first round, calculate change in Y (Delta_Y) and corresponding standard error
+ if previous_y is not None:
+ # Calculate Y difference between current round and previous round
+ delta_y = y_current - previous_y
+ # Calculate standard error of Y difference (sigma_Delta_Y)
+ sigma_delta_y = np.sqrt(sigma_y_current**2 + sigma_y_previous**2)
+ # Check if Y change is within acceptable confidence interval, i.e., convergence condition
+ # With the default z=0 the right-hand side is 0, so only an exactly unchanged Y counts.
+ if abs(delta_y) <= z * sigma_delta_y:
+ convergence_count += 1
+ # If consecutive converged rounds reach set value, return convergence information
+ if convergence_count >= consecutive_rounds:
+ return True, i - consecutive_rounds + 1, i
+ else:
+ # If change is large, reset convergence counter
+ convergence_count = 0
+ # Update Y value and standard error for previous round
+ previous_y = y_current
+ sigma_y_previous = sigma_y_current
+ # If convergence condition not met, return not converged
+ return False, None, None
+
+ def get_top_rounds(self):
+ """
+ Return the candidate parent rounds read from results.json.
+
+ Scores are averaged per round. Round 1, when present, is always placed
+ first; the remaining rounds follow in descending order of average score
+ until roughly `self.sample` entries have been collected.
+
+ Returns a list of {"round": int, "score": float} dicts.
+ Raises FileNotFoundError when results.json is missing and ValueError
+ when it cannot be parsed.
+ """
+ rounds_dir = self.results_path
+ result_file = os.path.join(rounds_dir, "results.json")
+ self.top_scores = []
+ if not Path(result_file).exists():
+ raise FileNotFoundError(f"json_file: {result_file} not exist, return []")
+ with open(result_file, "r", encoding="utf-8") as fin:
+ try:
+ data = json.load(fin)
+ except Exception:
+ raise ValueError(f"read json file: {result_file} failed")
+ df = pd.DataFrame(data)
+
+ # A round evaluated several times is represented by its mean score.
+ scores_per_round = df.groupby("round")["score"].mean().to_dict()
+
+ for round_number, average_score in scores_per_round.items():
+ self.top_scores.append({"round": int(round_number), "score": average_score})
+
+ self.top_scores.sort(key=lambda x: x["score"], reverse=True)
+
+ unique_rounds = set()
+ unique_top_scores = []
+
+ # The seed round is kept as a candidate regardless of its rank.
+ first_round = next((item for item in self.top_scores if item["round"] == 1), None)
+ if first_round:
+ unique_top_scores.append(first_round)
+ unique_rounds.add(1)
+
+ # NOTE(review): the original indentation is not visible here, so whether the
+ # size check below runs on every iteration or only after an append cannot be
+ # told from this view — confirm against the repository.
+ for item in self.top_scores:
+ if item["round"] not in unique_rounds:
+ unique_top_scores.append(item)
+ unique_rounds.add(item["round"])
+
+ if len(unique_top_scores) >= self.sample:
+ break
+
+ return unique_top_scores
+
+    def select_round(self, items,alpha=0.2, lambda_=0.3):
+        """
+        Randomly pick one round, favouring higher scores.
+
+        items: list of {"round": ..., "score": ...} dicts.
+        alpha: sharpness of the softmax over the scores (scaled by 100).
+        lambda_: weight of the uniform distribution mixed into the softmax.
+
+        Returns (selected item, items sorted by descending score).
+        Raises ValueError when there is nothing to choose from.
+        """
+        if not items:
+            raise ValueError("Item list is empty.")
+
+        sorted_items = list(items)
+        sorted_items.sort(key=lambda entry: entry["score"], reverse=True)
+
+        scaled = np.array([entry["score"] * 100 for entry in sorted_items], dtype=np.float64)
+        count = len(scaled)
+        if count == 0:
+            raise ValueError("Score list is empty.")
+
+        # Exploration part: every candidate gets the same share.
+        flat = np.full(count, 1.0 / count, dtype=np.float64)
+
+        # Exploitation part: softmax over the scores, shifted by the maximum.
+        weights = np.exp(alpha * (scaled - np.max(scaled)))
+        weight_total = np.sum(weights)
+        if weight_total == 0:
+            raise ValueError("Sum of exponential weights is 0, cannot normalize.")
+        softmax = weights / weight_total
+
+        mixed_prob = lambda_ * flat + (1 - lambda_) * softmax
+
+        # Renormalize only when rounding pushed the total away from 1.
+        total_prob = np.sum(mixed_prob)
+        if not np.isclose(total_prob, 1.0):
+            mixed_prob = mixed_prob / total_prob
+
+        print(f"\nMixed probability distribution: {mixed_prob}")
+        print(f"\nSorted rounds: {sorted_items}")
+
+        selected_index = np.random.choice(len(sorted_items), p=mixed_prob)
+        print(f"\nSelected index: {selected_index}, Selected item: {sorted_items[selected_index]}")
+
+        return sorted_items[selected_index], sorted_items
+
+    def read_graph_files(self, round_number: int, workflows_path: str):
+        """
+        Load the prompt and graph sources of one round.
+
+        round_number: round whose files are read.
+        workflows_path: directory containing the round_N folders.
+
+        Returns (content of prompt.py, content of graph.py).
+        Any error is printed and then re-raised unchanged.
+        """
+        def _read_source(file_name):
+            # Read one file of the round folder as UTF-8 text.
+            source_path = os.path.join(workflows_path, f"round_{round_number}", file_name)
+            with open(source_path, "r", encoding="utf-8") as handle:
+                return handle.read()
+
+        try:
+            prompt_content = _read_source("prompt.py")
+            graph_content = _read_source("graph.py")
+        except FileNotFoundError as e:
+            print(f"Error: File not found for round {round_number}: {e}")
+            raise
+        except Exception as e:
+            print(f"Error loading prompt for round {round_number}: {e}")
+            raise
+        return prompt_content, graph_content
+
+ def load_experience(self):
+ """
+ Collect the experience.json files of all round folders, grouped by parent round.
+
+ Each entry maps a parent ("father node") to its score before modification
+ and to the successful and failed modifications derived from it, keyed by
+ the number of the round that tried them. The result is also written to
+ processed_experience.json in the results directory.
+
+ Returns the grouped experience as a plain dict.
+ """
+ rounds_dir = os.path.normpath(self.results_path)
+ experience_data = defaultdict(lambda: {"score": None, "success": {}, "failure": {}})
+
+ for round_dir in os.listdir(rounds_dir):
+ if os.path.isdir(os.path.join(rounds_dir, round_dir)) and round_dir.startswith("round_"):
+ round_path = os.path.join(rounds_dir, round_dir)
+ # A problem with one round folder is printed and does not stop the scan.
+ try:
+ round_number = int(round_dir.split("_")[1])
+ json_file_path = os.path.join(round_path, "experience.json")
+ if os.path.exists(json_file_path):
+ # NOTE(review): this second existence check repeats the one just above.
+ if not Path(json_file_path).exists():
+ raise FileNotFoundError(f"json_file: {json_file_path} not exist, return []")
+ with open(json_file_path, "r", encoding="utf-8") as fin:
+ try:
+ data = json.load(fin)
+ except Exception:
+ raise ValueError(f"read json file: {json_file_path} failed")
+
+ father_node = data["father node"]
+
+ # The parent's score is taken from the first experience file seen for it.
+ if experience_data[father_node]["score"] is None:
+ experience_data[father_node]["score"] = data["before"]
+
+ if data["succeed"]:
+ experience_data[father_node]["success"][round_number] = {
+ "modification": data["modification"],
+ "score": data["after"],
+ }
+ else:
+ experience_data[father_node]["failure"][round_number] = {
+ "modification": data["modification"],
+ "score": data["after"],
+ }
+ except Exception as e:
+ print(f"Error processing {round_dir}: {str(e)}")
+
+ # Convert to a plain dict so that later .get() lookups do not create entries.
+ experience_data = dict(experience_data)
+
+ output_path = os.path.join(rounds_dir, "processed_experience.json")
+ with open(output_path, "w", encoding="utf-8") as outfile:
+ json.dump(experience_data, outfile, indent=4, ensure_ascii=False)
+
+ print(f"Processed experience data saved to {output_path}")
+ return experience_data
+
+    def format_experience(self, processed_experience, sample_round):
+        """
+        Render the recorded experience of one parent round as prompt text.
+
+        processed_experience: mapping produced by `load_experience`.
+        sample_round: parent round whose experience is rendered.
+
+        Returns the formatted text, or a "no experience" sentence when nothing
+        is recorded for that round.
+        """
+        experience_data = processed_experience.get(sample_round)
+        if not experience_data:
+            return f"No experience data found for round {sample_round}."
+
+        pieces = [
+            f"Original Score: {experience_data['score']}\n",
+            "These are some conclusions drawn from experience:\n\n",
+        ]
+        # Failed attempts are listed with the score they reached ...
+        pieces.extend(
+            f"-Absolutely prohibit {value['modification']} (Score: {value['score']})\n"
+            for value in experience_data["failure"].values()
+        )
+        # ... successful ones are listed as well so that they are not repeated.
+        pieces.extend(
+            f"-Absolutely prohibit {value['modification']} \n"
+            for value in experience_data["success"].values()
+        )
+        pieces.append("\n\nNote: Take into account past failures and avoid repeating the same mistakes, as these failures indicate that these approaches are ineffective. You must fundamentally change your way of thinking, rather than simply using more advanced Python syntax like for, if, else, etc., or modifying the prompt.")
+        return "".join(pieces)
+
+    def load_log(self, cur_round):
+        """
+        Return up to three randomly chosen mismatch-log entries of a round as text.
+
+        cur_round: round whose log.json is read.
+
+        Returns the selected entries as pretty-printed JSON separated by blank
+        lines, or "" when the log file is missing, empty or holds no entries.
+        Raises ValueError when a non-empty log file cannot be parsed.
+        """
+        log_file = os.path.join(self.results_path, f"round_{cur_round}/log.json")
+        if not os.path.exists(log_file):
+            return ""
+
+        try:
+            with open(log_file, "r", encoding="utf-8") as fin:
+                raw = fin.read()
+            # The seed workflow ships a zero-byte log.json: that simply means
+            # "nothing logged yet" and must not abort the optimization round.
+            if not raw.strip():
+                return ""
+            data = json.loads(raw)
+        except ValueError as err:  # JSONDecodeError and UnicodeDecodeError
+            raise ValueError(f"read json file: {log_file} failed") from err
+
+        if isinstance(data, dict):
+            data = [data]
+        elif not isinstance(data, list):
+            data = list(data)
+
+        if not data:
+            return ""
+
+        random_samples = random.sample(data, min(3, len(data)))
+        return "".join(
+            json.dumps(sample, indent=4, ensure_ascii=False) + "\n\n"
+            for sample in random_samples
+        )
+
+    def check_modification(self, processed_experience, modification, sample_round):
+        """
+        Tell whether a proposed modification is new for the given parent round.
+
+        processed_experience: mapping produced by `load_experience`.
+        modification: modification text proposed by the optimizer.
+        sample_round: parent round the modification was derived from.
+
+        Returns False when the same modification text is already recorded for
+        that parent (as a failure or a success), True otherwise.
+        """
+        experience_data = processed_experience.get(sample_round)
+        if not experience_data:
+            return True
+
+        # Failures are checked before successes.
+        for outcome in ("failure", "success"):
+            for record in experience_data[outcome].values():
+                if record["modification"] == modification:
+                    return False
+        return True
+
+    def print_results(self):
+        """
+        Print average score and standard deviation for all rounds.
+
+        Reads results.json from the results directory (creating it with an
+        empty list when it does not exist), groups the scores by round and
+        prints one line per round in ascending round order. The loaded
+        records, the grouped scores and the computed statistics are kept in
+        self.data, self.rounds, self.avg_scores and self.stds.
+
+        Returns the list of records read from results.json.
+        """
+        rounds_dir = os.path.normpath(self.results_path)
+        result_file = os.path.join(rounds_dir, "results.json")
+        # Ensure directory exists
+        os.makedirs(rounds_dir, exist_ok=True)
+        # If file doesn't exist, create a new one with an empty list
+        if not os.path.exists(result_file):
+            with open(result_file, "w") as file:
+                json.dump([], file)
+
+        with open(result_file, "r") as file:
+            try:
+                self.data = json.load(file)
+            except json.JSONDecodeError:
+                # A zero-byte placeholder file means that nothing was recorded yet.
+                self.data = []
+
+        rounds = {}
+        for entry in self.data:
+            rounds.setdefault(entry["round"], []).append(entry["score"])
+        self.rounds = rounds
+
+        sorted_rounds = sorted(rounds.items(), key=lambda x: x[0])
+        self.avg_scores = [np.mean(scores) for _, scores in sorted_rounds]
+        self.stds = [np.std(scores) for _, scores in sorted_rounds]
+
+        for (round_number, _), avg_score, std in zip(sorted_rounds, self.avg_scores, self.stds):
+            print(f"Round {round_number}: Average Score = {avg_score:.4f}, Standard Deviation = {std:.4f}")
+
+        return self.data
+
+ def extract(self,response):
+ """
+ Pull the payload between [CONTENT] and [/CONTENT] out of a response and decode it as JSON.
+
+ response: raw text returned by the model.
+
+ Returns the decoded JSON value (decoded with strict=False).
+ """
+ TAG = "CONTENT"
+ req_key=f"[/{TAG}]"
+
+ def re_extract_content(cont,pattern):
+ # Return the first non-empty match of `pattern`, stripped; when nothing
+ # matches the input itself is returned stripped.
+ matches = re.findall(pattern, cont, re.DOTALL)
+ for match in matches:
+ if match:
+ cont = match
+ break
+ return cont.strip()
+ raw_content = copy.deepcopy(response)
+ pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
+ new_content = re_extract_content(raw_content, pattern)
+ # NOTE(review): the original indentation is not visible here, so which `if`
+ # the `else` below belongs to cannot be told from this view — confirm.
+ if not new_content.startswith("{"):
+ # TODO find a more general pattern
+ # # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation
+ print(f"extract_content try another pattern: {pattern}")
+ if req_key not in new_content:
+ raw_content = copy.deepcopy(new_content + "\n" + req_key)
+ # # pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]"
+ new_content = re_extract_content(raw_content, pattern)
+ else:
+ if req_key in new_content:
+ idx = new_content.find(req_key)
+ new_content = new_content[:idx]
+ new_content = new_content.strip()
+ # strict=False lets the decoder accept control characters inside strings.
+ return json.JSONDecoder(strict=False).decode(new_content,_w=json.decoder.WHITESPACE.match)
+
+    @staticmethod
+    def xml_extract(context: str,field_names :list,field_types) -> Dict[str, Any]:
+        """
+        Extract `<name>value</name>` fields from a text and convert them by type.
+
+        context: text to search, typically a model response.
+        field_names: tag names to look for; the first occurrence of each is used.
+        field_types: mapping from tag name to str, int, bool, list or dict.
+
+        Returns a dict with one entry per tag that was found and has a known
+        type. Values that cannot be converted fall back to 0, [] or {}.
+        """
+        # Local import: only needed for the list/dict conversions below.
+        import ast
+
+        extracted_data: Dict[str, Any] = {}
+
+        for field_name in field_names:
+            tag = re.escape(field_name)
+            match = re.search(rf"<{tag}>(.*?)</{tag}>", context, re.DOTALL)
+            if not match:
+                continue
+
+            raw_value = match.group(1).strip()
+            field_type = field_types.get(field_name)
+
+            if field_type == str:
+                extracted_data[field_name] = raw_value
+            elif field_type == int:
+                try:
+                    extracted_data[field_name] = int(raw_value)
+                except ValueError:
+                    extracted_data[field_name] = 0
+            elif field_type == bool:
+                extracted_data[field_name] = raw_value.lower() in ("true", "yes", "1", "on")
+            elif field_type in (list, dict):
+                # SECURITY: the text comes from a model, so it is parsed as a
+                # Python literal with ast.literal_eval instead of being executed
+                # with eval().
+                try:
+                    parsed = ast.literal_eval(raw_value)
+                    if not isinstance(parsed, field_type):
+                        raise ValueError
+                except Exception:
+                    parsed = field_type()
+                extracted_data[field_name] = parsed
+
+        return extracted_data
+
+    def validate_response(self, response: str) -> Tuple[bool, dict]:
+        """
+        Validate if the response contains all required fields in XML format
+
+        response: text to scan for `<name>value</name>` pairs.
+
+        Returns (True, {name: stripped value}) on success and (False, None)
+        when anything goes wrong, including an expected field that is unknown
+        to the model definition.
+        """
+        try:
+            pattern = r"<(\w+)>(.*?)</\1>"
+            matches = re.findall(pattern, response, re.DOTALL)
+
+            found_fields = {name: value.strip() for name, value in matches}
+
+            # NOTE(review): `_get_field_names` and `model` are not defined in this
+            # class as shown; presumably they come from the base class. If they do
+            # not exist, the AttributeError is caught below and the result is
+            # always (False, None) — confirm.
+            for field_name in self._get_field_names():
+                # Lookup only: a KeyError here marks the response as invalid.
+                _ = self.model.model_fields[field_name]
+
+            return True, found_fields
+        except Exception:
+            return False, None
+
+    def save_optimized_graph(self):
+        """
+        Copy the best-scoring candidate round into the "best_workflow" folder.
+
+        The best round is the first entry of the score-sorted candidate list
+        returned by `select_round`; its number is stored in
+        self.optimized_round. Raises FileNotFoundError when the folder of that
+        round does not exist.
+        """
+        # Only the sorted list is used; the randomly selected item is ignored.
+        _, ranked = self.select_round(self.get_top_rounds())
+        results_root = Path(self.results_path)
+        self.optimized_round = ranked[0]["round"]
+
+        source_round = results_root / f"round_{self.optimized_round}"
+        if not source_round.exists():
+            raise FileNotFoundError(f"The source folder {source_round} does not exist.")
+        shutil.copytree(source_round, results_root / "best_workflow", dirs_exist_ok=True)
+
\ No newline at end of file
diff --git a/methods/aflow/configs/config.yaml b/methods/aflow/configs/config.yaml
new file mode 100644
index 0000000..51f1e19
--- /dev/null
+++ b/methods/aflow/configs/config.yaml
@@ -0,0 +1,6 @@
+sample: 4
+max_rounds: 20
+validation_rounds: 3
+earlystop: True
+optimize_meta_model_name: "claude-3-5-sonnet-20241022"
+optimize_execute_model_name: "gpt-4o-mini-2024-07-18"
\ No newline at end of file
diff --git a/methods/aflow/evaluate.py b/methods/aflow/evaluate.py
new file mode 100644
index 0000000..a6bf9d1
--- /dev/null
+++ b/methods/aflow/evaluate.py
@@ -0,0 +1,242 @@
+import threading
+import regex,re,json,inspect,time
+from termcolor import colored
+from typing import Any, List,Tuple,Callable,Dict,Optional
+from math import isclose
+from pathlib import Path
+from pydantic_core import to_jsonable_python
+from sympy import N, simplify
+from sympy.parsing.latex import parse_latex
+from sympy.parsing.sympy_parser import parse_expr
+from .initial_workflows.math.template.sanitize import sanitize
+
+def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4):
+    """
+    Serialize `data` as JSON into `json_file`, creating the parent folder if needed.
+
+    Non-ASCII characters are written unescaped; values the json module cannot
+    encode are converted with pydantic's `to_jsonable_python`.
+    """
+    target_dir = Path(json_file).parent
+    if not target_dir.exists():
+        target_dir.mkdir(parents=True, exist_ok=True)
+
+    with open(json_file, "w", encoding=encoding) as handle:
+        json.dump(
+            data,
+            handle,
+            ensure_ascii=False,
+            indent=indent,
+            default=to_jsonable_python,
+        )
+
+def extract_model_answer(text: str) -> str:
+    """
+    Return the final answer contained in a solution text.
+
+    The content of the last \\boxed{...} expression is preferred (one level of
+    nested braces is supported). Without any boxed expression the last
+    sentence of the text is returned, or "" for an empty text.
+    """
+    pattern = r"\\boxed{((?:[^{}]|{[^{}]*})*)}"
+    boxed_matches = re.findall(pattern, text, re.DOTALL)
+    if boxed_matches:
+        return boxed_matches[-1].strip()
+
+    # Split at sentence-ending punctuation followed by whitespace; the
+    # lookbehind keeps enumerations such as "3. " from ending a sentence.
+    sentence_end_pattern = r"(?<!\d)[.!?]\s+"
+    sentences = re.split(sentence_end_pattern, text)
+    sentences = [s.strip() for s in sentences if s.strip()]
+    return sentences[-1] if sentences else ""
+
+
+async def evaluate_math(problem: dict, graph: Callable, log_path: str) -> Tuple[str, str, str, float]:
+    """
+    Run the workflow on one math problem and score its answer.
+
+    problem: dict with the question under "query" and the reference solution
+        under "solution".
+    graph: awaitable workflow called with the question text.
+    log_path: folder whose log.json receives an entry for every wrong answer.
+
+    Returns (question, raw model output, reference solution, score) with a
+    score of 1 for a correct answer and 0 otherwise. When the workflow or the
+    scoring raises, the error text takes the place of the output and the
+    score is 0.0.
+    """
+    input_text = problem["query"]
+    expected_output = problem["solution"]
+
+    try:
+        output = await graph(input_text)
+        expected_answer = extract_model_answer(expected_output)
+        predicted_answer = extract_model_answer(output)
+
+        uni_score = 1 if math_equal(predicted_answer, expected_answer) else 0
+
+        if uni_score == 0:
+            # Keep the failing case so the optimizer can be shown real mistakes.
+            log_mismatch(
+                input_text,
+                expected_output,
+                output,
+                predicted_answer,
+                extract_answer_code=get_function_code(extract_model_answer),
+                log_path=log_path
+            )
+        return input_text, output, expected_output, uni_score
+
+    except Exception as e:
+        print(colored(f"Maximum retries reached. Skipping this sample. Error: {e}","light_red"))
+        return input_text, str(e), expected_output, 0.0
+def math_equal(prediction: Any, reference: Any) -> bool:
+    """
+    Tell whether a predicted answer matches the reference answer.
+
+    Three checks are tried in order: identical text, numeric closeness
+    (absolute tolerance 1e-3) when both values parse as numbers, and symbolic
+    equality. A check that raises counts as "no match" for that check.
+    """
+    if str(prediction) == str(reference):
+        return True
+
+    try:
+        if is_digit(prediction) and is_digit(reference):
+            prediction = parse_digits(prediction)
+            reference = parse_digits(reference)
+            return isclose(prediction, reference, abs_tol=1e-3)
+    except Exception:
+        # Best effort: fall through to the symbolic comparison.
+        pass
+
+    try:
+        return symbolic_equal(prediction, reference)
+    except Exception:
+        pass
+    return False
+
+def is_digit(num):
+    """Return True when `parse_digits` can turn `num` into a number."""
+    parsed = parse_digits(num)
+    return parsed is not None
+
+def parse_digits(num):
+    """
+    Convert a number-like value to float.
+
+    Thousands separators (",") are removed first. A trailing percent sign,
+    plain ("50%") or LaTeX-escaped ("50\\%"), is accepted and the value is
+    divided by 100.
+
+    Returns the float, or None when the value is not a number.
+    """
+    num = str(num).replace(",", "")
+    try:
+        return float(num)
+    except ValueError:
+        pass
+
+    if num.endswith("%"):
+        num = num[:-1]
+        # A LaTeX percentage leaves a backslash behind once "%" is removed;
+        # drop it so that the conversion below can succeed.
+        if num.endswith("\\"):
+            num = num[:-1]
+        try:
+            return float(num) / 100
+        except ValueError:
+            pass
+    return None
+
+def get_function_code(func):
+    """Return the source code of `func`, or "no code" when it cannot be retrieved."""
+    try:
+        return inspect.getsource(func)
+    except OSError:
+        return "no code"
+
+def symbolic_equal(a, b):
+    """
+    Compare two answers symbolically.
+
+    Each value is parsed as LaTeX first and as a SymPy expression second; a
+    value that neither parser accepts is used unchanged. The values are equal
+    when their difference simplifies to 0 or when their numeric evaluations
+    agree within an absolute tolerance of 1e-3.
+    """
+    def _parse(s):
+        # Try each parser in turn and fall back to the raw value.
+        for f in [parse_latex, parse_expr]:
+            try:
+                return f(s)
+            except Exception:
+                pass
+        return s
+
+    a = _parse(a)
+    b = _parse(b)
+
+    try:
+        if simplify(a - b) == 0:
+            return True
+    except Exception:
+        # Not subtractable or not simplifiable: try the numeric comparison.
+        pass
+
+    try:
+        if isclose(N(a), N(b), abs_tol=1e-3):
+            return True
+    except Exception:
+        pass
+    return False
+def log_mismatch(problem: str,expected_output: Any,prediction: str,extracted_output: Any,extract_answer_code: str = "None",log_path=None):
+    """
+    Append one wrong-answer record to log.json inside `log_path`.
+
+    An existing log is loaded first (an unreadable one is treated as empty),
+    the new record is appended and the whole list is written back.
+    """
+    log_file = Path(log_path) / "log.json"
+
+    entries = []
+    if log_file.exists():
+        with log_file.open("r", encoding="utf-8") as f:
+            try:
+                entries = json.load(f)
+            except json.JSONDecodeError:
+                entries = []
+
+    entries.append({
+        "question":problem,
+        "right_answer": expected_output,
+        "model_output": prediction,
+        "extracted_output": extracted_output,
+        "extract_answer_code": extract_answer_code,
+    })
+    write_json_file(log_file, entries, encoding="utf-8", indent=4)
+
+async def evaluate_mbpp(data: dict, graph: Callable,log_path:str) -> Tuple[str, str, str, float, float]:
+    """
+    Run the workflow on one MBPP task and score the generated code.
+
+    data: task dict with the keys "prompt", "code", "test" and "entry_point".
+    graph: awaitable workflow called with the prompt and the entry point.
+    log_path: folder whose log.json receives an entry for every failed task.
+
+    Returns (prompt, prediction, expected output text, score); the score is
+    1.0 when the tests pass and 0.0 otherwise, also when an error occurs.
+    """
+    input_text = data["prompt"]
+    expected_output = "\nCorrect Solution:\ndef " + data["code"]
+
+    try:
+        # Ask the workflow for a solution and run the task's tests against it.
+        prediction = await graph(input_text, data["entry_point"])
+        outcome = check_solution(prediction, data["test"], data["entry_point"])
+
+        # From here on the expected output also carries the test details.
+        expected_output = outcome[1] + "\nCorrect Solution:" + data["code"]
+
+        score = 0.0
+        if outcome[0] == "PASS":
+            score = 1.0
+
+        if score == 0:
+            log_mismatch(input_text, expected_output, prediction, score,log_path=log_path)
+
+        return input_text, prediction, expected_output, score
+
+    except Exception as e:
+        print(colored(f"Maximum retries reached. Skipping this sample. Error: {e}","light_red"))
+        return input_text, str(e), expected_output, 0.0
+
+def check_solution(solution, test, entry_point):
+    """
+    Execute a generated solution against its test code.
+
+    solution: generated source code; it is sanitized for `entry_point` first.
+    test: source code that defines a `check` callable.
+    entry_point: name the solution has to define.
+
+    Returns ("PASS", message) when `check` completes within 15 seconds and
+    ("FAIL", message) otherwise. Failures other than timeouts are also
+    appended to error.log in the current working directory.
+    """
+    solution = sanitize(code=solution, entrypoint=entry_point)
+    try:
+        global_dict = {
+            "math": __import__("math"),
+            "hashlib": __import__("hashlib"),
+            "re": __import__("re"),
+            "List": List,
+            "Dict": Dict,
+            "Tuple": Tuple,
+            "Optional": Optional,
+            "Any": Any,
+        }
+
+        # SECURITY: model-generated code is executed in-process here; run the
+        # evaluation in a sandboxed environment.
+        exec(solution, global_dict)
+
+        if entry_point not in global_dict:
+            raise ValueError(f"Function {entry_point} is not defined in the solution.")
+
+        exec(test, global_dict)
+
+        check = global_dict["check"]
+
+        result = run_with_timeout(check, 15)
+
+        if result is None:
+            result = ("PASS", "The solution passed all test cases.")
+
+    except Exception as e:
+        # Timeouts are told apart from ordinary failures; the message check keeps
+        # this working when the timeout is signalled with a plain Exception.
+        if isinstance(e, TimeoutError) or str(e) == "Function execution timed out":
+            result = (
+                "FAIL",
+                "Execution timed out. Please check if your solution contains infinite loops or overly time-consuming operations.",
+            )
+        else:
+            error_message = f"Error: {str(e)}.\n Solution: {solution}.\n Test: {test}"
+            result = ("FAIL", error_message)
+
+            with open("error.log", "a", encoding="utf-8") as log_file:
+                log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {error_message}\n")
+
+    return result
+
+def run_with_timeout(func, timeout):
+    """
+    Call `func()` in a worker thread and wait at most `timeout` seconds for it.
+
+    Returns the value returned by `func`.
+    Raises TimeoutError when `func` does not finish in time, and re-raises any
+    exception raised by `func` itself.
+    """
+    result = []
+    stop_event = threading.Event()
+
+    def target():
+        try:
+            result.append(func())
+        except Exception as e:
+            result.append(e)
+        finally:
+            stop_event.set()
+
+    # Daemon thread: a call that never returns cannot be stopped, but it must not
+    # keep the interpreter alive after the timeout has been reported.
+    thread = threading.Thread(target=target, daemon=True)
+    thread.start()
+
+    if not stop_event.wait(timeout):
+        raise TimeoutError("Function execution timed out")
+
+    if not result:
+        return None
+    if isinstance(result[0], Exception):
+        raise result[0]
+    return result[0]
\ No newline at end of file
diff --git a/methods/aflow/initial_workflows/math/__init__.py b/methods/aflow/initial_workflows/math/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/methods/aflow/initial_workflows/math/round_1/__init__.py b/methods/aflow/initial_workflows/math/round_1/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/methods/aflow/initial_workflows/math/round_1/graph.py b/methods/aflow/initial_workflows/math/round_1/graph.py
new file mode 100644
index 0000000..77b2619
--- /dev/null
+++ b/methods/aflow/initial_workflows/math/round_1/graph.py
@@ -0,0 +1,19 @@
+from typing import Literal
+from ..template import operator
+from ..round_1 import prompt as prompt_custom
+
+DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"]
class Workflow:
    """Round-1 baseline workflow: one ``Custom`` operator call per problem."""

    def __init__(self, name: str, env) -> None:
        # ``env`` is the LLM handle that every operator of this workflow uses.
        self.llm = env
        self.name = name
        self.custom = operator.Custom(self.llm)

    async def __call__(self, problem: str):
        """
        Send the problem to the Custom operator with an empty instruction
        and return the text stored under its 'response' key.
        """
        reply = await self.custom(input=problem, instruction="")
        return reply['response']
+
+
\ No newline at end of file
diff --git a/methods/aflow/initial_workflows/math/round_1/log.json b/methods/aflow/initial_workflows/math/round_1/log.json
new file mode 100644
index 0000000..e69de29
diff --git a/methods/aflow/initial_workflows/math/round_1/prompt.py b/methods/aflow/initial_workflows/math/round_1/prompt.py
new file mode 100644
index 0000000..310e53d
--- /dev/null
+++ b/methods/aflow/initial_workflows/math/round_1/prompt.py
@@ -0,0 +1,6 @@
# Seed prompt for the round-1 workflow. The name follows the
# `prompt_custom.XXX_PROMPT` convention; the round-1 graph itself calls the
# Custom operator with an empty instruction.
XXX_PROMPT = """

Solve it.

"""
+
diff --git a/methods/aflow/initial_workflows/math/template/__init__.py b/methods/aflow/initial_workflows/math/template/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/methods/aflow/initial_workflows/math/template/op_prompt.py b/methods/aflow/initial_workflows/math/template/op_prompt.py
new file mode 100644
index 0000000..af13cc7
--- /dev/null
+++ b/methods/aflow/initial_workflows/math/template/op_prompt.py
@@ -0,0 +1,29 @@
# Prompt template for the self-consistency ensemble operator.
# Placeholders filled via str.format: {problem} and {solutions}.
SC_ENSEMBLE_PROMPT = """
Given the question described as follows: {problem}
Several solutions have been generated to address the given question. They are as follows:
{solutions}

Carefully evaluate these solutions and identify the answer that appears most frequently across them. This consistency in answers is crucial for determining the most reliable solution.

In the "thought" field, provide a detailed explanation of your thought process. In the "solution_letter" field, output only the single letter ID (A, B, C, etc.) corresponding to the most consistent solution. Do not include any additional text or explanation in the "solution_letter" field.
"""
+
# Prompt template for the code-writing operator.
# Placeholders filled via str.format: {problem}, {analysis} and {feedback}.
# The model is asked to wrap its program in <code>...</code> so that the
# caller can pull the program out of the reply with a regular expression.
PYTHON_CODE_VERIFIER_PROMPT = """
You are a professional Python programmer. Your task is to write complete, self-contained code based on a given mathematical problem and output the answer. The code should include all necessary imports and dependencies, and be ready to run without additional setup or environment configuration.

Problem description: {problem}
Other analysis: {analysis}
{feedback}

Your code should:
1. Implement the calculation steps described in the problem.
2. Define a function named `solve` that performs the calculation and returns the result. The `solve` function should not require any input parameters; instead, it should obtain all necessary inputs from within the function or from globally defined variables.
3. `solve` function return the final calculation result.

Please ensure your code is efficient, well-commented, and follows Python best practices. The output should be limited to basic data types such as strings, integers, and floats. It is prohibited to transmit images or other file formats. The code output is intended for a text-based language model.

Wrap your final code solution in <code> and </code>. For example:
<code>
Your function code here
</code>
"""
\ No newline at end of file
diff --git a/methods/aflow/initial_workflows/math/template/operator.json b/methods/aflow/initial_workflows/math/template/operator.json
new file mode 100644
index 0000000..1a57ccd
--- /dev/null
+++ b/methods/aflow/initial_workflows/math/template/operator.json
@@ -0,0 +1,14 @@
+{
+ "Custom": {
+ "description": "Generates anything based on customized input and instruction.",
+ "interface": "custom(input: str, instruction: str) -> dict with key 'response' of type str"
+ },
+ "ScEnsemble": {
+ "description": "Uses self-consistency to select the solution that appears most frequently in the solution list, improve the selection to enhance the choice of the best solution.",
+ "interface": "sc_ensemble(solutions: List[str], problem: str) -> dict with key 'response' of type str"
+ },
+ "Programmer": {
+ "description": "Automatically writes, executes Python code, and returns the solution based on the provided problem description and analysis. The `output` only contains the final answer. If you want to see the detailed solution process, it's recommended to retrieve the `code`.",
+ "interface": "programmer(problem: str, analysis: str = 'None') -> dict with keys 'code' and 'output' of type str"
+ }
+}
diff --git a/methods/aflow/initial_workflows/math/template/operator.py b/methods/aflow/initial_workflows/math/template/operator.py
new file mode 100644
index 0000000..421bcca
--- /dev/null
+++ b/methods/aflow/initial_workflows/math/template/operator.py
@@ -0,0 +1,212 @@
+import concurrent
+import sys
+import re
+import traceback
+from typing import List
+from tenacity import retry, stop_after_attempt, wait_fixed
+
+from .op_prompt import *
+from .sanitize import *
+import asyncio
+
class Operator:
    """Base class for workflow operators: stores the LLM handle and a display name."""

    def __init__(self, llm, name: str):
        self.llm = llm
        self.name = name

    def __call__(self, *args, **kwargs):
        """Subclasses implement the actual operation; the base class has none."""
        raise NotImplementedError
+
class Custom(Operator):
    """Operator that sends ``instruction + input`` to the LLM and returns the raw reply."""

    def __init__(self, llm, name: str = "Custom"):
        super().__init__(llm, name)

    async def __call__(self, input, instruction):
        """Return ``{"response": <LLM reply>}`` for the concatenated prompt."""
        call_kwargs = {"prompt": instruction + input}
        # Outside inference mode the call names the execution model explicitly.
        if not self.llm.inference_flag:
            call_kwargs["model_name"] = self.llm.model_name_execute
        reply = await self.llm.async_call_llm(**call_kwargs)
        return {"response": reply}
+
def run_code(code):
    """Execute generated Python source and return the result of its ``solve()``.

    The source is first scanned for imports of a small deny-list of modules
    (process/system access and plotting libraries). The scan walks the parsed
    syntax tree, so only real ``import`` / ``from ... import`` statements
    count: text inside comments or strings (e.g. "from system") is ignored,
    and multi-module statements such as ``import math, os`` are caught.

    SECURITY NOTE: this still ``exec``s model-generated code in the current
    process. The import scan is a best-effort filter, not a sandbox (it does
    not see ``__import__`` or ``importlib`` calls).

    Args:
        code: Python source expected to define a zero-argument ``solve``.

    Returns:
        A ``(status, message)`` tuple. ``status`` is ``"Success"`` with the
        stringified return value of ``solve()``, or ``"Error"`` with a
        description (prohibited import, missing ``solve``, or the exception
        text followed by its traceback).
    """
    import ast  # local import: only needed for the import scan below

    disallowed_imports = [
        "os", "sys", "subprocess", "multiprocessing",
        "matplotlib", "seaborn", "plotly", "bokeh", "ggplot",
        "pylab", "tkinter", "PyQt5", "wx", "pyglet"
    ]

    try:
        # Collect the top-level package name of every absolute import.
        imported_roots = set()
        for node in ast.walk(ast.parse(code)):
            if isinstance(node, ast.Import):
                imported_roots.update(alias.name.split(".")[0] for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
                imported_roots.add(node.module.split(".")[0])

        # Report in deny-list order so the message is deterministic.
        for lib in disallowed_imports:
            if lib in imported_roots:
                return "Error", f"Prohibited import: {lib} and graphing functionalities"

        # Run the code in a fresh global namespace.
        global_namespace = {}
        exec(code, global_namespace)

        # The generated code is expected to define a function named 'solve'.
        solve = global_namespace.get("solve")
        if callable(solve):
            return "Success", str(solve())
        return "Error", "Function 'solve' not found"
    except Exception as e:
        return "Error", f"Execution error: {str(e)}\n{traceback.format_exc()}"
+
+
class Programmer(Operator):
    """Operator that asks the LLM for a ``solve()`` program, runs it and returns the output."""

    def __init__(self, llm, name: str = "Programmer"):
        super().__init__(llm, name)

    @staticmethod
    def _kill_workers(executor):
        """Forcefully stop the pool's worker processes so hung code cannot keep running."""
        kill = getattr(executor, "kill_workers", None)  # public API on newer Pythons
        if callable(kill):
            kill()
            return
        # Older versions offer no public way to stop a running task, so fall
        # back to the executor's internal pid -> Process mapping.
        for proc in list((getattr(executor, "_processes", None) or {}).values()):
            proc.kill()

    async def exec_code(self, code, timeout=30):
        """
        Run ``code`` through ``run_code`` in a separate process.

        Args:
            code: Python source to execute.
            timeout: Seconds to wait before giving up.

        Returns:
            The ``(status, message)`` tuple from ``run_code``, or
            ``("Error", ...)`` on timeout or any other failure.
        """
        loop = asyncio.get_running_loop()
        # Deliberately not a `with` block: leaving it calls shutdown(wait=True),
        # which blocks until the worker finishes and so defeats the timeout
        # whenever the generated code loops forever.
        executor = concurrent.futures.ProcessPoolExecutor(max_workers=1)
        try:
            # Submit run_code to the process pool and wait for it, bounded by timeout
            future = loop.run_in_executor(executor, run_code, code)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            # The task is still running inside the worker; kill the worker.
            self._kill_workers(executor)
            return "Error", "Code execution timed out"
        except Exception as e:
            return "Error", f"Unknown error: {str(e)}"
        finally:
            executor.shutdown(wait=False, cancel_futures=True)

    def _extract_code(self, response: str) -> str:
        """Pull the program out of an LLM reply; return "" if none is found."""
        # Preferred form: the program wrapped in <code>...</code>.
        match = re.search(r"<code>\s*(.*?)\s*</code>", response, re.DOTALL)
        if match:
            code = match.group(1).strip()
            # Drop markdown fences the model may have put inside the tags.
            code = re.sub(r"^```(?:\w+)?\n?|```$", "", code, flags=re.MULTILINE).strip()
            if code:
                return code
        # Fallback: the model ignored the tags and answered with fenced blocks.
        return self._extract_code_from_markdown(response).strip()

    async def code_generate(self, problem, analysis, feedback):
        """
        Ask the LLM for a program that solves ``problem``.

        Returns:
            ``{"code": <source>}`` (the source may be empty when the reply
            contained no code), or ``{"error": <message>}`` if the call or
            the extraction raised.
        """
        prompt = PYTHON_CODE_VERIFIER_PROMPT.format(
            problem=problem,
            analysis=analysis,
            feedback=feedback
        )
        try:
            if self.llm.inference_flag:
                response = await self.llm.async_call_llm(prompt=prompt)
            else:
                response = await self.llm.async_call_llm(prompt=prompt,model_name=self.llm.model_name_execute)
            return {"code": self._extract_code(response)}
        except Exception as e:
            return {"error": str(e)}

    def _extract_code_from_markdown(self, text: str) -> str:
        """
        Extract code from markdown code blocks in the response.

        Args:
            text: The text containing possible markdown code blocks

        Returns:
            The extracted code as a string, or empty string if no code blocks found
        """
        # Look for Python code blocks (```python ... ```)
        python_pattern = r"```python\s*([\s\S]*?)\s*```"
        python_matches = re.findall(python_pattern, text)

        if python_matches:
            # Join all Python code blocks
            return "\n\n".join(python_matches)

        # If no Python blocks found, look for generic code blocks (``` ... ```)
        generic_pattern = r"```\s*([\s\S]*?)\s*```"
        generic_matches = re.findall(generic_pattern, text)

        if generic_matches:
            # Join all generic code blocks
            return "\n\n".join(generic_matches)

        # No code blocks found
        return ""

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
    async def __call__(self, problem: str, analysis: str = "None"):
        """
        Generate code and execute it, feeding errors back to the model.

        Up to three generate/execute rounds are made per call; the whole call
        is additionally retried by tenacity (3 attempts, 2 s apart) if it raises.

        Returns:
            ``{"code": ..., "output": ...}`` where ``output`` is the result of
            ``solve()``, the last error message, or "No code generated".
        """
        code = None
        output = None
        feedback = ""
        for i in range(3):
            code_response = await self.code_generate(problem, analysis, feedback)
            code = code_response.get("code")
            if not code:
                return {"code": code, "output": "No code generated"}
            status, output = await self.exec_code(code)
            if status == "Success":
                return {"code": code, "output": output}
            print(f"Execution error on attempt {i + 1}, error message: {output}")
            # Show the model its previous attempt and the resulting error.
            feedback = (
                f"\nThe result of the error from the code you wrote in the previous round:\n"
                f"Code: {code}\n\nStatus: {status}, {output}"
            )
        return {"code": code, "output": output}
+
+
class ScEnsemble(Operator):
    """
    Self-consistency ensemble: asks the LLM which candidate solution is the
    most consistent one and returns that candidate.

    Paper: Self-Consistency Improves Chain of Thought Reasoning in Language Models
    Link: https://arxiv.org/abs/2203.11171
    Paper: Universal Self-Consistency for Large Language Model Generation
    Link: https://arxiv.org/abs/2311.17311
    """

    # Characters tolerated around the chosen letter (e.g. "B.", "(B)", "**B**").
    _LETTER_PADDING = " \t\r\n.:;,()[]<>*`'\""

    def __init__(self, llm, name: str = "ScEnsemble"):
        super().__init__(llm, name)

    async def __call__(self, solutions: List[str], problem: str):
        """
        Pick the most consistent solution.

        Args:
            solutions: Candidate solutions; labelled A, B, C, ... in the prompt.
            problem: The question the candidates try to answer.

        Returns:
            ``{"response": <chosen solution>}``. If the reply cannot be mapped
            to one of the labels, the first solution is returned instead of
            failing.

        Raises:
            ValueError: If ``solutions`` is empty.
        """
        if not solutions:
            raise ValueError("ScEnsemble requires at least one candidate solution")

        answer_mapping = {}
        labelled = []
        for index, solution in enumerate(solutions):
            letter = chr(65 + index)
            answer_mapping[letter] = index
            labelled.append(f"{letter}: \n{str(solution)}\n\n\n")

        prompt = SC_ENSEMBLE_PROMPT.format(problem=problem, solutions="".join(labelled))
        prompt = prompt + (
            "\n# Response format (must be strictly followed) (do not include any other formats except for the given XML format):\n"
            "<solution_letter>The letter of most consistent solution.</solution_letter>"
        )

        if self.llm.inference_flag:
            response = await self.llm.async_call_llm(prompt=prompt)
        else:
            response = await self.llm.async_call_llm(prompt=prompt,model_name=self.llm.model_name_execute)

        # Parse every <tag>...</tag> pair of the reply into a dict.
        text = response if isinstance(response, str) else ""
        fields = {
            tag: body.strip()
            for tag, body in re.findall(r"<(\w+)>(.*?)</\1>", text, re.DOTALL)
        }

        # Prefer the tagged field; accept a bare-letter reply as a second chance.
        for candidate in (fields.get("solution_letter", ""), text):
            letter = candidate.strip(self._LETTER_PADDING).upper()
            if letter in answer_mapping:
                return {"response": solutions[answer_mapping[letter]]}

        print("ScEnsemble could not parse a solution letter; falling back to the first solution")
        return {"response": solutions[0]}
\ No newline at end of file
diff --git a/methods/aflow/initial_workflows/math/template/sanitize.py b/methods/aflow/initial_workflows/math/template/sanitize.py
new file mode 100644
index 0000000..811771d
--- /dev/null
+++ b/methods/aflow/initial_workflows/math/template/sanitize.py
@@ -0,0 +1,177 @@
+
+import ast
+import traceback
+from enum import Enum
+from typing import Dict, Generator, List, Optional, Set, Tuple
+
+import tree_sitter_python
+from tree_sitter import Language, Node, Parser
+
+
class NodeType(Enum):
    # Node ``type`` strings that the helpers in this module compare against
    # when inspecting the parsed Python syntax tree.
    CLASS = "class_definition"
    FUNCTION = "function_definition"
    # Both import forms are grouped; membership is tested with ``in``.
    IMPORT = ["import_statement", "import_from_statement"]
    IDENTIFIER = "identifier"
    ATTRIBUTE = "attribute"
    RETURN = "return_statement"
    EXPRESSION = "expression_statement"
    ASSIGNMENT = "assignment"
+
+
def traverse_tree(node: Node) -> Generator[Node, None, None]:
    """
    Traverse the tree structure starting from the given node.

    Each node is yielded before the cursor descends into its children, and
    siblings are visited in order. The walk uses a cursor instead of recursion.

    :param node: The root node to start the traversal from.
    :return: A generator object that yields nodes in the tree.
    """
    cursor = node.walk()
    depth = 0

    # True once the subtree below the cursor's current node has been handled.
    visited_children = False
    while True:
        if not visited_children:
            yield cursor.node
            if not cursor.goto_first_child():
                # No children: this node is finished, move sideways/up next.
                # NOTE(review): depth is incremented here (on reaching a leaf)
                # and decremented when climbing — confirm this is the intended
                # bookkeeping for the stop condition below.
                depth += 1
                visited_children = True
        elif cursor.goto_next_sibling():
            # A sibling exists: it still has to be yielded and descended into.
            visited_children = False
        elif not cursor.goto_parent() or depth == 0:
            # Cannot climb any further (or the counter hit zero): done.
            break
        else:
            depth -= 1
+
+
def syntax_check(code, verbose=False):
    """Return True when ``code`` parses as Python, False otherwise.

    A ``SyntaxError`` or ``MemoryError`` raised by the parser counts as
    "does not parse"; with ``verbose`` set, its traceback is printed.
    """
    try:
        ast.parse(code)
    except (SyntaxError, MemoryError):
        if verbose:
            traceback.print_exc()
        return False
    else:
        return True
+
+
def code_extract(text: str) -> str:
    """Return the run of consecutive lines of ``text`` that forms the largest valid program.

    Every window of at least two consecutive lines is tried. Among the
    windows that parse, the one with the most non-blank lines wins, and the
    earliest such window is kept on ties. If no window parses (or the text
    has a single line), the first line is returned.
    """
    rows = text.split("\n")
    total = len(rows)
    best_start, best_stop, best_size = 0, 0, 0

    for start in range(total):
        for stop in range(start + 1, total):
            window = rows[start : stop + 1]
            if not syntax_check("\n".join(window)):
                continue
            size = len([row for row in window if row.strip()])
            if size > best_size:
                best_start, best_stop, best_size = start, stop, size

    return "\n".join(rows[best_start : best_stop + 1])
+
+
def get_definition_name(node: Node) -> str:
    """Return the text of the first direct identifier child of ``node`` (None if there is none)."""
    identifier_texts = (
        child.text.decode("utf8")
        for child in node.children
        if child.type == NodeType.IDENTIFIER.value
    )
    return next(identifier_texts, None)
+
+
def has_return_statement(node: Node) -> bool:
    """Tell whether any node in the subtree rooted at ``node`` is a return statement."""
    return any(
        descendant.type == NodeType.RETURN.value
        for descendant in traverse_tree(node)
    )
+
+
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
    """Map each definition name to the set of identifier names found beneath its node.

    Identifier children are recorded and not descended into; every other
    child is searched further.
    """

    def collect_identifiers(root: Node) -> Set[str]:
        found = set()
        pending = [root]
        while pending:
            current = pending.pop()
            for child in current.children:
                if child.type == NodeType.IDENTIFIER.value:
                    found.add(child.text.decode("utf8"))
                else:
                    pending.append(child)
        return found

    return {name: collect_identifiers(node) for name, node in nodes}
+
+
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Return every name reachable from ``entrypoint`` in ``call_graph``.

    Args:
        entrypoint: Name to start the search from.
        call_graph: Mapping from a name to the names it refers to. Names
            without an entry are treated as having no outgoing edges.

    Returns:
        The set of reachable names; always contains ``entrypoint`` itself.
    """
    from collections import deque  # O(1) pops from the left, unlike list.pop(0)

    pending = deque([entrypoint])
    visited = {entrypoint}
    while pending:
        current = pending.popleft()
        for neighbour in call_graph.get(current, ()):
            if neighbour not in visited:
                visited.add(neighbour)
                pending.append(neighbour)
    return visited
+
+
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """
    Reduce Python source to its imports and first-seen top-level definitions.

    The text is first narrowed with ``code_extract`` and then parsed. From the
    top level of the tree the function keeps every import statement plus the
    first class, function (only if it contains a return statement) or
    assignment seen for each name. With an ``entrypoint``, only definitions
    reachable from it through identifier references are emitted.

    :param code: The input Python code as a string.
    :param entrypoint: Optional name of a definition to use as the root of the dependency search.
    :return: The kept imports followed by the kept definitions, one per line group.
    """
    source = code_extract(code)
    source_bytes = bytes(source, "utf8")
    parser = Parser(Language(tree_sitter_python.language()))
    root = parser.parse(source_bytes).root_node

    # One shared pool: a name claimed by a class, function or variable is never reused.
    seen_names = set()
    imports = []
    definitions = []

    for top in root.children:
        kind = top.type
        if kind in NodeType.IMPORT.value:
            imports.append(top)
            continue

        if kind == NodeType.CLASS.value or kind == NodeType.FUNCTION.value:
            target = top
        elif kind == NodeType.EXPRESSION.value and top.children[0].type == NodeType.ASSIGNMENT.value:
            target = top.children[0]
        else:
            continue

        name = get_definition_name(target)
        if name in seen_names:
            continue
        # Functions without any return statement are dropped.
        if kind == NodeType.FUNCTION.value and not has_return_statement(top):
            continue
        definitions.append((name, target))
        seen_names.add(name)

    reachable = None
    if entrypoint:
        reachable = get_function_dependency(entrypoint, get_deps(definitions))

    pieces = [source_bytes[node.start_byte : node.end_byte] for node in imports]
    for name, node in definitions:
        if entrypoint and name not in reachable:
            continue
        pieces.append(source_bytes[node.start_byte : node.end_byte])

    return b"\n".join(pieces).decode("utf8")
diff --git a/methods/aflow/prompt.py b/methods/aflow/prompt.py
new file mode 100644
index 0000000..dd85a6f
--- /dev/null
+++ b/methods/aflow/prompt.py
@@ -0,0 +1,57 @@
# Task description given to the model that rewrites a workflow graph and its prompts.
# Placeholder: {type} (the kind of problem the workflow solves).
WORKFLOW_OPTIMIZE_PROMPT = """You are building a Graph and corresponding Prompt to jointly solve {type} problems.
Referring to the given graph and prompt, which forms a basic example of a {type} solution approach,
please reconstruct and optimize them. You can add, modify, or delete nodes, parameters, or prompts. Include your
single modification in XML tags in your reply. Ensure they are complete and correct to avoid runtime failures. When
optimizing, you can incorporate critical thinking methods like review, revise, ensemble (generating multiple answers through different/similar prompts, then voting/integrating/checking the majority to obtain a final answer), selfAsk, etc. Consider
Python's loops (for, while, list comprehensions), conditional statements (if-elif-else, ternary operators),
or machine learning techniques (e.g., linear regression, decision trees, neural networks, clustering). The graph
complexity should not exceed 10. Use logical and control flow (IF-ELSE, loops) for a more enhanced graphical
representation.Ensure that all the prompts required by the current graph from prompt_custom are included.Exclude any other prompts.
Output the modified graph and all the necessary Prompts in prompt_custom (if needed).
The prompt you need to generate is only the one used in `prompt_custom.XXX` within Custom. Other methods already have built-in prompts and are prohibited from being generated. Only generate those needed for use in `prompt_custom`; please remove any unused prompts in prompt_custom.
the generated prompt must not contain any placeholders.
Considering information loss, complex graphs may yield better results, but insufficient information transmission can omit the solution. It's crucial to include necessary context during the process."""
+
+
# Per-round input for the workflow optimizer: the sample being improved,
# wrapped in XML tags so each part is clearly delimited for the model.
# Placeholders filled via str.format: {experience}, {score}, {graph},
# {prompt}, {operator_description} and {log}.
WORKFLOW_INPUT = """
Here is a graph and the corresponding prompt (prompt only related to the custom method) that performed excellently in a previous iteration (maximum score is 1). You must make further optimizations and improvements based on this graph. The modified graph must differ from the provided example, and the specific differences should be noted within the <modification>xxx</modification> section.\n
<sample>
    <experience>{experience}</experience>
    <modification>(such as:add /delete /modify/ ...)</modification>
    <score>{score}</score>
    <graph>{graph}</graph>
    <prompt>{prompt}</prompt>(only prompt_custom)
    <operator_description>{operator_description}</operator_description>
</sample>
Below are the logs of some results with the aforementioned Graph that performed well but encountered errors, which can be used as references for optimization:
{log}

First, provide optimization ideas. **Only one detail point can be modified at a time**, and no more than 5 lines of code may be changed per modification—extensive modifications are strictly prohibited to maintain project focus!
When introducing new functionalities in the graph, please make sure to import the necessary libraries or modules yourself, except for operator, prompt_custom, create_llm_instance, and CostManage, which have already been automatically imported.
**Under no circumstances should Graph output None for any field.**
Use custom methods to restrict your output format, rather than using code (outside of the code, the system will extract answers based on certain rules and score them).
It is very important to format the Graph output answers, you can refer to the standard answer format in the log.
"""
+
# Usage notes for the `custom` operator, written for the optimizing model.
# NOTE(review): the braces in the example snippet are literal text, so this
# string is presumably concatenated rather than passed through str.format — confirm at the call site.
WORKFLOW_CUSTOM_USE = """Here's an example of using the `custom` method in graph:
```
# You can write your own prompt in prompt_custom and then use it in the Custom method in the graph
response = await self.custom(input=problem, instruction=prompt_custom.XXX_PROMPT)
# You can also concatenate previously generated string results in the input to provide more comprehensive contextual information.
# response = await self.custom(input=problem+f"xxx:{xxx}, xxx:{xxx}", instruction=prompt_custom.XXX_PROMPT)
# The output from the Custom method can be placed anywhere you need it, as shown in the example below
solution = await self.generate(problem=f"question:{problem}, xxx:{response['response']}")
```
Note: In custom, the input and instruction are directly concatenated(instruction+input), and placeholders are not supported. Please ensure to add comments and handle the concatenation externally.\n

**Introducing multiple operators at appropriate points can enhance performance. If you find that some provided operators are not yet used in the graph, try incorporating them.**
"""
+
# Skeleton of a generated graph module.
# Placeholders: {round} selects the round package that the prompt module is
# imported from; {graph} is replaced by the workflow source code.
WORKFLOW_TEMPLATE = """from typing import Literal
from ..template import operator as operator
from ..round_{round} import prompt as prompt_custom

DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"]

{graph}
"""