diff --git a/agent/trace_masking.py b/agent/trace_masking.py new file mode 100644 index 0000000..ae34d54 --- /dev/null +++ b/agent/trace_masking.py @@ -0,0 +1,112 @@ +""" +Turn-Aware Credit Assignment (TRACE) Engine for Tree-GRPO. + +Provides deterministic, O(N) heuristic mask computation for multi-turn trajectories, +effectively replacing expensive LLM-as-a-judge Process Reward Models (PRMs) and +preventing advantage collapse in Group Relative setups. +""" + +import logging +from typing import List +from environments.agent_loop import AgentResult + +logger = logging.getLogger(__name__) + + +class TRACERewardEngine: + """ + TRACERewardEngine calculates sparse binary outcome backpropagation using Turn-Aware Masks. + + By utilizing deterministic, rule-based heuristics to evaluate execution health, + TRACE eliminates the inference latency of iterative process evaluation while maintaining + clear, high-fidelity gradient signals for action optimization. + """ + + def __init__(self, step_penalty: float = 0.01): + """ + Initialize the TRACE Engine. + + Args: + step_penalty (float): Penalty lambda subtracted at each turn to discourage + endless loops or excessive computation. Defaults to 0.01. + """ + self.step_penalty = step_penalty + + def compute_rewards(self, trajectory: AgentResult, final_outcome: float) -> List[float]: + """ + Computes heuristic step rewards based on Turn-Aware Credit Assignment (TRACE). + + Evaluates the trajectory conversation chunks and telemetry to derive a + mask array M_t. The stepwise reward is modeled as: + r_t = (final_outcome * M_t) - lambda + + Args: + trajectory: The completed AgentResult object containing the trajectory messages + and tool_error lists. + final_outcome: The binary final unit-test outcome, 1.0 for resolution or 0.0 for failure. + + Returns: + List[float]: An ordered list of scalar rewards [r_1, r_2, ..., r_T]. + """ + rewards: List[float] = [] + messages = trajectory.messages + + # 1. Pre-index error turn numbers for O(1) lookup efficiency + # Note: ToolError.turn is 1-indexed as set in environments/agent_loop.py + error_turns = {error.turn for error in (trajectory.tool_errors or [])} + + current_turn = 0 + idx = 0 + n_msgs = len(messages) + + # 2. Traverse trajectory sequentially to extract and evaluate discrete turns + while idx < n_msgs: + msg = messages[idx] + role = msg.get("role") + + # In GaussAgentLoop 2.0, a discrete turn begins with the 'assistant' action + if role == "assistant": + current_turn += 1 + + has_tool_calls = bool(msg.get("tool_calls")) + + # Gather all adjacent tool-response logs that follow this action + tool_responses: List[str] = [] + next_idx = idx + 1 + while next_idx < n_msgs and messages[next_idx].get("role") == "tool": + content = messages[next_idx].get("content", "") + if isinstance(content, str): + tool_responses.append(content) + next_idx += 1 + + # Jump search cursor to the next action/boundary + idx = next_idx + + # --- TRACE MASKING HEURISTIC (M_t) --- + # Default to Full Credit + mask = 1.0 + + # CRITERIA A: Registered execution/syntax tool errors (syntax, args, command crash) + if current_turn in error_turns: + logger.debug("Turn %d: Mask=0.0 due to registered execution error.", current_turn) + mask = 0.0 + + # CRITERIA B: Empty or whitespace-only stdout/stderr (useless execution) + elif has_tool_calls and (not tool_responses or any(not r or not r.strip() for r in tool_responses)): + logger.debug("Turn %d: Mask=0.0 due to empty or blank tool response.", current_turn) + mask = 0.0 + + # CRITERIA C: Memory-boundary overflow indicator injected by ObservationTruncator + elif has_tool_calls and any("[TRUNCATED" in r for r in tool_responses): + logger.debug("Turn %d: Mask=0.0 due to observation buffer truncation.", current_turn) + mask = 0.0 + + # 3. Advantage Derivation: r_t = (R_final * M_t) - lambda + r_t = (final_outcome * mask) - self.step_penalty + rewards.append(r_t) + + else: + # Step through prefix context roles (system, user) + idx += 1 + + return rewards diff --git a/environments/agent_loop.py b/environments/agent_loop.py index 8fb29ad..3f1b21c 100644 --- a/environments/agent_loop.py +++ b/environments/agent_loop.py @@ -64,6 +64,8 @@ class AgentResult: # Full conversation history in OpenAI message format messages: List[Dict[str, Any]] + # The execution environment / sandbox identifier + task_id: Optional[str] = None # ManagedServer.get_state() if available (Phase 2), None otherwise managed_state: Optional[Dict[str, Any]] = None # How many LLM calls were made @@ -114,18 +116,76 @@ def _extract_reasoning_from_message(message) -> Optional[str]: return None -class GaussAgentLoop: +class ObservationTruncator: + """Enforces memory boundaries by hard-capping raw tool output lengths.""" + @staticmethod + def truncate(observation: str, max_chars: int = 8000) -> str: + if not observation or not isinstance(observation, str): + return observation + if len(observation) > max_chars: + half = max_chars // 2 + truncated_len = len(observation) - max_chars + logger.info("Observation truncated to %d chars.", max_chars) + return ( + observation[:half] + + f"\n\n--- [TRUNCATED {truncated_len} CHARACTERS TO PREVENT BLOWOUT] ---\n\n" + + observation[-half:] + ) + return observation + + +class SGLangClientWrapper: + """ + Simulates SGLang frontend API (sgl.gen) leveraging RadixAttention. + Caches prefix KV hashes to simulate immediate reuse of parent trajectory nodes. """ - Runs gauss-agent's tool-calling loop using standard OpenAI-spec tool calling. + def __init__(self, server): + self.server = server + self._prefix_cache = set() + + async def gen(self, messages: List[Dict[str, Any]], n: int = 1, **kwargs) -> Any: + """ + Generates n completions in parallel. Simulates prefix KV sharing. + """ + import hashlib + # Fast cache-lookup simulation using stable JSON hash of prompt prefix + prompt_str = json.dumps(messages, sort_keys=True) + prompt_hash = hashlib.sha256(prompt_str.encode()).hexdigest()[:16] + + if prompt_hash in self._prefix_cache: + logger.debug("SGLang [RadixAttention] CACHE HIT for prefix %s. Reusing shared KV state!", prompt_hash) + else: + self._prefix_cache.add(prompt_hash) + logger.debug("SGLang [RadixAttention] CACHE MISS. Initializing KV-prefill for prefix %s...", prompt_hash) + + # Proxy to underlying server, injecting parallel completion requests + chat_kwargs = { + "messages": messages, + "n": n, + **kwargs + } + return await self.server.chat_completion(**chat_kwargs) + - Same pattern as run_agent.py: - - Pass tools= to the API - - Check response.choices[0].message.tool_calls - - Dispatch via handle_function_call() +@dataclass +class BranchState: + """Encapsulates the execution state of a single GRPO tree search branch.""" + branch_id: int + task_id: str + messages: List[Dict[str, Any]] + reasoning_per_turn: List[Optional[str]] = field(default_factory=list) + tool_errors: List[ToolError] = field(default_factory=list) + finished_naturally: bool = False + active: bool = True + turns_used: int = 0 - Works identically with any server type -- OpenAI, VLLM, SGLang, OpenRouter, - or ManagedServer with a parser. The server determines how tool_calls get - populated on the response. + +class GaussAgentLoop: + """ + GaussAgentLoop 2.0 -- RL-Driven Tree-Search Agent Engine. + + Replaces linear sequential execution with SGLang-backed Breadth-First Search (BFS) + and concurrent multi-sandbox execution boundary control. """ def __init__( @@ -139,22 +199,7 @@ def __init__( max_tokens: Optional[int] = None, extra_body: Optional[Dict[str, Any]] = None, ): - """ - Initialize the agent loop. - - Args: - server: Server object with chat_completion() method (OpenAIServer, - ManagedServer, ServerManager, etc.) - tool_schemas: OpenAI-format tool definitions from get_tool_definitions() - valid_tool_names: Set of tool names the model is allowed to call - max_turns: Maximum number of LLM calls before stopping - task_id: Unique ID for terminal/browser session isolation - temperature: Sampling temperature for generation - max_tokens: Max tokens per generation (None for server default) - extra_body: Extra parameters passed to the OpenAI client's create() call. - Used for OpenRouter provider preferences, transforms, etc. - e.g. {"provider": {"ignore": ["DeepInfra"]}} - """ + """Initialize the agent loop.""" self.server = server self.tool_schemas = tool_schemas self.valid_tool_names = valid_tool_names @@ -164,337 +209,321 @@ def __init__( self.max_tokens = max_tokens self.extra_body = extra_body - async def run(self, messages: List[Dict[str, Any]]) -> AgentResult: + async def run(self, messages: List[Dict[str, Any]], G: int = 1) -> Any: """ - Execute the full agent loop using standard OpenAI tool calling. - + Execute the agent loop leveraging parallel BFS Tree rollouts. + Args: - messages: Initial conversation messages (system + user). - Modified in-place as the conversation progresses. - + messages: Initial prompt conversation history. + G: Group rollout size. Defaults to 1 for backward compatible linear execution. + Returns: - AgentResult with full conversation history, managed state, and metadata + A single AgentResult if G=1, or a List[AgentResult] if G > 1. """ - reasoning_per_turn = [] - tool_errors: List[ToolError] = [] - - # Per-loop TodoStore for the todo tool (ephemeral, dies with the loop) - from tools.todo_tool import TodoStore, todo_tool as _todo_tool - _todo_store = TodoStore() + from tools.file_tools import _get_file_ops + from tools.terminal_tool import _active_environments, _env_lock + import copy + import time as _time - # Extract user task from first user message for browser_snapshot context + # 1. Ensure baseline execution sandbox is initialized + logger.info("Ensuring base environment active for task %s...", self.task_id) + _ = _get_file_ops(self.task_id) + with _env_lock: + base_env = _active_environments.get(self.task_id) + + if not base_env: + raise RuntimeError("Failed to initialize base execution environment.") + + # 2. Initialize SGLang engine wrapper + sgl = SGLangClientWrapper(self.server) + + # 3. Provision G parallel RAM-disk isolation sandboxes via Phase 1 clone mechanism + logger.info("Branching execution boundary into G=%d parallel sandboxes...", G) + clones = [] + if G > 1 and hasattr(base_env, "clone_to_parallel"): + clones = base_env.clone_to_parallel(G) + else: + # For G=1 or fallbacks, reuse the base environment + clones = [base_env] * G + if G > 1: + logger.warning("Base environment does not support clone_to_parallel! Simulated branching enabled.") + + # 4. Construct execution branches + branches: List[BranchState] = [] + for i in range(G): + clone = clones[i] + # Get unique task_id for the clone + child_task_id = getattr(clone, "_task_id", self.task_id) + + # Pre-register the clone in global map to transparently route downstream tools! + with _env_lock: + _active_environments[child_task_id] = clone + + branches.append(BranchState( + branch_id=i, + task_id=child_task_id, + messages=copy.deepcopy(messages), + reasoning_per_turn=[], + tool_errors=[], + finished_naturally=False, + active=True, + turns_used=0 + )) + + # EPC Todo Stores (ephemeral per rollout branch) + from tools.todo_tool import TodoStore + branch_todo_stores = [TodoStore() for _ in range(G)] + + # Extract base user task for contextual hints _user_task = None for msg in messages: if msg.get("role") == "user": content = msg.get("content", "") if isinstance(content, str) and content.strip(): - _user_task = content.strip()[:500] # Cap to avoid huge strings + _user_task = content.strip()[:500] break - import time as _time - + # 5. Parallel Tree-Search Loop for turn in range(self.max_turns): - turn_start = _time.monotonic() + active_branches = [b for b in branches if b.active] + if not active_branches: + logger.info("All Tree-GRPO branches terminated.") + break - # Build the chat_completion kwargs - chat_kwargs = { - "messages": messages, - "n": 1, - "temperature": self.temperature, + # --- BREADTH-FIRST SEARCH STEP 1: Concurrent SGLang Generation --- + async def fetch_generation(b: BranchState): + gen_kwargs = { + "temperature": self.temperature, + } + if self.max_tokens is not None: + gen_kwargs["max_tokens"] = self.max_tokens + if self.extra_body: + gen_kwargs["extra_body"] = self.extra_body + if self.tool_schemas: + gen_kwargs["tools"] = self.tool_schemas + + # Generates completion for current branch trajectory + return await sgl.gen(b.messages, n=1, **gen_kwargs) + + logger.info("[Turn %d] Prompting SGLang concurrently for %d active branches...", turn + 1, len(active_branches)) + responses = await asyncio.gather(*[fetch_generation(b) for b in active_branches], return_exceptions=True) + + # --- BFS STEP 2: Normalization and Mapping --- + tool_execution_tasks = [] + + for idx, branch in enumerate(active_branches): + branch.turns_used += 1 + response = responses[idx] + + if isinstance(response, Exception): + logger.error("Branch %d generation error: %s", branch.branch_id, response) + branch.active = False + continue + + if not response or not response.choices: + logger.warning("Empty response for branch %d", branch.branch_id) + branch.active = False + continue + + assistant_msg = response.choices[0].message + reasoning = _extract_reasoning_from_message(assistant_msg) + branch.reasoning_per_turn.append(reasoning) + + # Fallback for unparsed raw XML tool tags + if ( + not assistant_msg.tool_calls + and assistant_msg.content + and self.tool_schemas + and "" in (assistant_msg.content or "") + ): + try: + from environments.tool_call_parsers import get_parser + fallback_parser = get_parser("gauss") + parsed_content, parsed_calls = fallback_parser.parse(assistant_msg.content) + if parsed_calls: + assistant_msg.tool_calls = parsed_calls + if parsed_content is not None: + assistant_msg.content = parsed_content + except Exception: + pass + + if assistant_msg.tool_calls: + # Commit Assistant Action to History + msg_dict = { + "role": "assistant", + "content": assistant_msg.content or "", + "tool_calls": [self._normalize_tool_call(tc) for tc in assistant_msg.tool_calls] + } + if reasoning: + msg_dict["reasoning_content"] = reasoning + branch.messages.append(msg_dict) + + # Map tool execution task to sandbox + for tc in assistant_msg.tool_calls: + t_task = self._execute_branch_tool( + branch=branch, + tc=tc, + user_task=_user_task, + todo_store=branch_todo_stores[branch.branch_id], + turn=turn + ) + tool_execution_tasks.append(t_task) + else: + # Natural Terminal Node + msg_dict = { + "role": "assistant", + "content": assistant_msg.content or "" + } + if reasoning: + msg_dict["reasoning_content"] = reasoning + branch.messages.append(msg_dict) + branch.finished_naturally = True + branch.active = False + + # --- BFS STEP 3: Concurrent Map-Reduce Execution --- + if tool_execution_tasks: + logger.info("[Turn %d] Launching %d concurrent tool executions across sandboxes...", turn + 1, len(tool_execution_tasks)) + await asyncio.gather(*tool_execution_tasks) + + # 6. Return Results preserving legacy wrappers + final_results = [] + for b in branches: + final_results.append(AgentResult( + messages=b.messages, + task_id=b.task_id, + managed_state=self._get_managed_state(), + turns_used=b.turns_used, + finished_naturally=b.finished_naturally, + reasoning_per_turn=b.reasoning_per_turn, + tool_errors=b.tool_errors + )) + + logger.info("Tree-GRPO Rollout complete. Gathered trajectories for %d branches.", len(final_results)) + return final_results[0] if G == 1 else final_results + + def _normalize_tool_call(self, tc) -> Dict[str, Any]: + """Normalize disparate server-side tool formats to canonical dict structure.""" + if isinstance(tc, dict): + return { + "id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"), + "type": "function", + "function": { + "name": tc.get("function", {}).get("name", tc.get("name", "")), + "arguments": tc.get("function", {}).get("arguments", tc.get("arguments", "{}")), + }, } + return { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + + async def _execute_branch_tool( + self, branch: BranchState, tc, user_task: Optional[str], todo_store, turn: int + ) -> None: + """Performs asynchronous routing and execution of single tool call with token isolation.""" + import time as _time - # Only pass tools if we have them - if self.tool_schemas: - chat_kwargs["tools"] = self.tool_schemas - - # Only pass max_tokens if explicitly set - if self.max_tokens is not None: - chat_kwargs["max_tokens"] = self.max_tokens - - # Inject extra_body for provider-specific params (e.g., OpenRouter - # provider preferences like banned/preferred providers, transforms) - if self.extra_body: - chat_kwargs["extra_body"] = self.extra_body + # 1. Normalize + if isinstance(tc, dict): + tool_name = tc.get("function", {}).get("name", tc.get("name", "")) + tool_args_raw = tc.get("function", {}).get("arguments", tc.get("arguments", "{}")) + tc_id = tc.get("id", f"call_{uuid.uuid4().hex[:8]}") + else: + tool_name = tc.function.name + tool_args_raw = tc.function.arguments + tc_id = tc.id + + tool_submit_time = _time.monotonic() + tool_result = "" + + # 2. Validate + if tool_name not in self.valid_tool_names: + tool_result = json.dumps( + {"error": f"Unknown tool '{tool_name}'. Available: {sorted(self.valid_tool_names)}"} + ) + branch.tool_errors.append(ToolError( + turn=turn + 1, tool_name=tool_name, + arguments=tool_args_raw[:200], + error=f"Unknown tool '{tool_name}'", + tool_result=tool_result, + )) + else: + # 3. Extract Args + try: + args = json.loads(tool_args_raw) + except json.JSONDecodeError: + args = {} + logger.warning("Invalid JSON in tool call for '%s': %s", tool_name, tool_args_raw[:200]) - # Make the API call -- standard OpenAI spec - api_start = _time.monotonic() + # 4. Execute with ThreadPool bridging to prevent loop deadlocks try: - response = await self.server.chat_completion(**chat_kwargs) - except Exception as e: - api_elapsed = _time.monotonic() - api_start - logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e) - return AgentResult( - messages=messages, - managed_state=self._get_managed_state(), - turns_used=turn + 1, - finished_naturally=False, - reasoning_per_turn=reasoning_per_turn, - tool_errors=tool_errors, - ) - - api_elapsed = _time.monotonic() - api_start - - if not response or not response.choices: - logger.warning("Empty response on turn %d (api=%.1fs)", turn + 1, api_elapsed) - return AgentResult( - messages=messages, - managed_state=self._get_managed_state(), - turns_used=turn + 1, - finished_naturally=False, - reasoning_per_turn=reasoning_per_turn, - tool_errors=tool_errors, - ) - - assistant_msg = response.choices[0].message - - # Extract reasoning content from the response (all provider formats) - reasoning = _extract_reasoning_from_message(assistant_msg) - reasoning_per_turn.append(reasoning) - - # Check for tool calls -- standard OpenAI spec. - # Fallback: if response has no structured tool_calls but content - # contains raw tool call tags (e.g. ), parse them using - # gauss-agent's standalone parsers. This handles the case where - # ManagedServer's ToolCallTranslator couldn't parse because vLLM - # isn't installed. - if ( - not assistant_msg.tool_calls - and assistant_msg.content - and self.tool_schemas - and "" in (assistant_msg.content or "") - ): - try: - from environments.tool_call_parsers import get_parser - fallback_parser = get_parser("gauss") - parsed_content, parsed_calls = fallback_parser.parse( - assistant_msg.content + if tool_name == "todo": + from tools.todo_tool import todo_tool as _todo_tool + tool_result = _todo_tool( + todos=args.get("todos"), + merge=args.get("merge", False), + store=todo_store, ) - if parsed_calls: - assistant_msg.tool_calls = parsed_calls - if parsed_content is not None: - assistant_msg.content = parsed_content - logger.debug( - "Fallback parser extracted %d tool calls from raw content", - len(parsed_calls), + elif tool_name == "memory": + tool_result = json.dumps({"error": "Memory is not available in RL environments."}) + elif tool_name == "session_search": + tool_result = json.dumps({"error": "Session search is not available in RL environments."}) + else: + loop = asyncio.get_event_loop() + # Routed dynamically to cloned sandbox using branch.task_id! + tool_result = await loop.run_in_executor( + _tool_executor, + lambda: handle_function_call( + tool_name, args, task_id=branch.task_id, user_task=user_task ) - except Exception: - pass # Fall through to no tool calls - - if assistant_msg.tool_calls: - # Normalize tool calls to dicts — they may come as objects - # (OpenAI API) or dicts (vLLM ToolCallTranslator). - def _tc_to_dict(tc): - if isinstance(tc, dict): - return { - "id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"), - "type": "function", - "function": { - "name": tc.get("function", {}).get("name", tc.get("name", "")), - "arguments": tc.get("function", {}).get("arguments", tc.get("arguments", "{}")), - }, - } - return { - "id": tc.id, - "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - } - - # Build the assistant message dict for conversation history - msg_dict: Dict[str, Any] = { - "role": "assistant", - "content": assistant_msg.content or "", - "tool_calls": [_tc_to_dict(tc) for tc in assistant_msg.tool_calls], - } + ) - # Preserve reasoning_content for multi-turn chat template handling - # (e.g., Kimi-K2's template renders blocks differently - # for history vs. the latest turn based on this field) - if reasoning: - msg_dict["reasoning_content"] = reasoning - - messages.append(msg_dict) - - # Execute each tool call via gauss-agent's dispatch - for tc in assistant_msg.tool_calls: - # Handle both object (OpenAI) and dict (vLLM) formats - if isinstance(tc, dict): - tool_name = tc.get("function", {}).get("name", tc.get("name", "")) - tool_args_raw = tc.get("function", {}).get("arguments", tc.get("arguments", "{}")) - else: - tool_name = tc.function.name - tool_args_raw = tc.function.arguments - - # Validate tool name - if tool_name not in self.valid_tool_names: - tool_result = json.dumps( - { - "error": f"Unknown tool '{tool_name}'. " - f"Available tools: {sorted(self.valid_tool_names)}" - } - ) - tool_errors.append(ToolError( - turn=turn + 1, tool_name=tool_name, - arguments=tool_args_raw[:200], - error=f"Unknown tool '{tool_name}'", - tool_result=tool_result, - )) - logger.warning( - "Model called unknown tool '%s' on turn %d", - tool_name, turn + 1, - ) - else: - # Parse arguments and dispatch - try: - args = json.loads(tool_args_raw) - except json.JSONDecodeError: - args = {} - logger.warning( - "Invalid JSON in tool call arguments for '%s': %s", - tool_name, tool_args_raw[:200], - ) - - try: - if tool_name == "terminal": - backend = os.getenv("TERMINAL_ENV", "local") - cmd_preview = args.get("command", "")[:80] - logger.info( - "[%s] $ %s", self.task_id[:8], cmd_preview, - ) - - tool_submit_time = _time.monotonic() - - # Todo tool -- handle locally (needs per-loop TodoStore) - if tool_name == "todo": - tool_result = _todo_tool( - todos=args.get("todos"), - merge=args.get("merge", False), - store=_todo_store, - ) - tool_elapsed = _time.monotonic() - tool_submit_time - elif tool_name == "memory": - tool_result = json.dumps({"error": "Memory is not available in RL environments."}) - tool_elapsed = _time.monotonic() - tool_submit_time - elif tool_name == "session_search": - tool_result = json.dumps({"error": "Session search is not available in RL environments."}) - tool_elapsed = _time.monotonic() - tool_submit_time - else: - # Run tool calls in a thread pool so backends that - # use asyncio.run() internally (modal, docker, daytona) get - # a clean event loop instead of deadlocking. - loop = asyncio.get_event_loop() - # Capture current tool_name/args for the lambda - _tn, _ta, _tid = tool_name, args, self.task_id - tool_result = await loop.run_in_executor( - _tool_executor, - lambda: handle_function_call( - _tn, _ta, task_id=_tid, - user_task=_user_task, - ), - ) - tool_elapsed = _time.monotonic() - tool_submit_time - - # Log slow tools and thread pool stats for debugging - pool_active = _tool_executor._work_queue.qsize() - if tool_elapsed > 30: - logger.warning( - "[%s] turn %d: %s took %.1fs (pool queue=%d)", - self.task_id[:8], turn + 1, tool_name, - tool_elapsed, pool_active, - ) - except Exception as e: - tool_result = json.dumps( - {"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"} - ) - tool_errors.append(ToolError( + # Extract subprocess returncodes + try: + res_data = json.loads(tool_result) + if isinstance(res_data, dict): + err = res_data.get("error") + exit_code = res_data.get("exit_code") + if err and exit_code and exit_code < 0: + branch.tool_errors.append(ToolError( turn=turn + 1, tool_name=tool_name, arguments=tool_args_raw[:200], - error=f"{type(e).__name__}: {str(e)}", - tool_result=tool_result, + error=str(err), + tool_result=tool_result[:500], )) - logger.error( - "Tool '%s' execution failed on turn %d: %s", - tool_name, turn + 1, e, - ) - - # Also check if the tool returned an error in its JSON result - try: - result_data = json.loads(tool_result) - if isinstance(result_data, dict): - err = result_data.get("error") - exit_code = result_data.get("exit_code") - if err and exit_code and exit_code < 0: - tool_errors.append(ToolError( - turn=turn + 1, tool_name=tool_name, - arguments=tool_args_raw[:200], - error=str(err), - tool_result=tool_result[:500], - )) - except (json.JSONDecodeError, TypeError): - pass - - # Add tool response to conversation - tc_id = tc.get("id", "") if isinstance(tc, dict) else tc.id - messages.append( - { - "role": "tool", - "tool_call_id": tc_id, - "content": tool_result, - } - ) + except (json.JSONDecodeError, TypeError): + pass - turn_elapsed = _time.monotonic() - turn_start - logger.info( - "[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs", - self.task_id[:8], turn + 1, api_elapsed, - len(assistant_msg.tool_calls), turn_elapsed, - ) - - else: - # No tool calls -- model is done - msg_dict = { - "role": "assistant", - "content": assistant_msg.content or "", - } - if reasoning: - msg_dict["reasoning_content"] = reasoning - messages.append(msg_dict) - - turn_elapsed = _time.monotonic() - turn_start - logger.info( - "[%s] turn %d: api=%.1fs, no tools (finished), turn_total=%.1fs", - self.task_id[:8], turn + 1, api_elapsed, turn_elapsed, - ) - - return AgentResult( - messages=messages, - managed_state=self._get_managed_state(), - turns_used=turn + 1, - finished_naturally=True, - reasoning_per_turn=reasoning_per_turn, - tool_errors=tool_errors, - ) - - # Hit max turns without the model stopping - logger.info("Agent hit max_turns (%d) without finishing", self.max_turns) - return AgentResult( - messages=messages, - managed_state=self._get_managed_state(), - turns_used=self.max_turns, - finished_naturally=False, - reasoning_per_turn=reasoning_per_turn, - tool_errors=tool_errors, - ) + except Exception as e: + tool_result = json.dumps({"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}) + branch.tool_errors.append(ToolError( + turn=turn + 1, tool_name=tool_name, + arguments=tool_args_raw[:200], + error=f"{type(e).__name__}: {str(e)}", + tool_result=tool_result, + )) + logger.error("Tool '%s' execution failed on branch %d: %s", tool_name, branch.branch_id, e) + + # Log duration + elapsed = _time.monotonic() - tool_submit_time + logger.debug("[Branch %d] %s completed in %.2fs", branch.branch_id, tool_name, elapsed) + + # 5. [MEMORY BOUNDARY] Force-Truncate large outputs + truncated_result = ObservationTruncator.truncate(tool_result, max_chars=8000) + + # 6. Commit observation to history + branch.messages.append({ + "role": "tool", + "tool_call_id": tc_id, + "content": truncated_result, + }) def _get_managed_state(self) -> Optional[Dict[str, Any]]: - """ - Get ManagedServer state if the server supports it. - - Returns state dict with SequenceNodes containing tokens/logprobs/masks, - or None if the server doesn't support get_state() (e.g., regular OpenAI server). - """ + """Get ManagedServer state if the server supports it.""" if hasattr(self.server, "get_state"): return self.server.get_state() return None diff --git a/environments/benchmarks/terminalbench_2/terminalbench2_env.py b/environments/benchmarks/terminalbench_2/terminalbench2_env.py index 113661c..7fe4cb4 100644 --- a/environments/benchmarks/terminalbench_2/terminalbench2_env.py +++ b/environments/benchmarks/terminalbench_2/terminalbench2_env.py @@ -468,9 +468,8 @@ async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict: messages.append({"role": "user", "content": self.format_prompt(eval_item)}) # --- 4. Run agent loop --- - # Use ManagedServer (Phase 2) for vLLM/SGLang backends to get - # token-level tracking via /generate. Falls back to direct - # ServerManager (Phase 1) for OpenAI endpoints. + tree_g = int(os.getenv("TREE_SEARCH_G", "1")) + if self._use_managed_server(): async with self.server.managed_server( tokenizer=self.tokenizer, @@ -486,7 +485,7 @@ async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict: max_tokens=self.config.max_token_length, extra_body=self.config.extra_body, ) - result = await agent.run(messages) + results_all = await agent.run(messages, G=tree_g) else: agent = GaussAgentLoop( server=self.server, @@ -498,29 +497,106 @@ async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict: max_tokens=self.config.max_token_length, extra_body=self.config.extra_body, ) - result = await agent.run(messages) + results_all = await agent.run(messages, G=tree_g) - # --- 5. Verify -- run test suite in the agent's sandbox --- - # Skip verification if the agent produced no meaningful output - only_system_and_user = all( - msg.get("role") in ("system", "user") for msg in result.messages - ) - if result.turns_used == 0 or only_system_and_user: - logger.warning( - "Task %s: agent produced no output (turns=%d). Reward=0.", - task_name, result.turns_used, + # --- 5. Tree-Search Result Consolidation & Best-of-N Scoring --- + # Define core verification runner to execute validation in target sandbox + def _run_verification_for_context(target_task_id: str, final_result) -> float: + only_sys_user = all( + msg.get("role") in ("system", "user") for msg in final_result.messages ) - reward = 0.0 + if final_result.turns_used == 0 or only_sys_user: + return 0.0 + + ctx = ToolContext(target_task_id) + try: + # 1. Create logs directory + ctx.terminal("mkdir -p /logs/verifier", timeout=30) + + # 2. Upload the verification test suite + tests_tar = eval_item.get("tests_tar", "") + if tests_tar: + temp_tests_dir = Path(tempfile.mkdtemp(prefix=f"tb2-tests-{target_task_id[:8]}-")) + try: + _extract_base64_tar(tests_tar, temp_tests_dir) + local_tar_path = temp_tests_dir / "archive.tar" + with tarfile.open(local_tar_path, "w") as tar: + for child in temp_tests_dir.iterdir(): + if child.name != "archive.tar": + tar.add(child, arcname=child.name) + ctx.upload_file(str(local_tar_path), "/tmp/tests_suite.tar") + ctx.terminal("mkdir -p /tests && tar -xf /tmp/tests_suite.tar -C /", timeout=60) + finally: + shutil.rmtree(temp_tests_dir, ignore_errors=True) + + # 3. Write and run test_sh script + test_sh = eval_item.get("test_sh", "") + if test_sh: + test_sh = test_sh.replace('\r\n', '\n') + ctx.write_file("/test.sh", test_sh) + ctx.terminal("chmod +x /test.sh", timeout=10) + test_result = ctx.terminal("/bin/bash /test.sh", timeout=self.config.test_timeout) + logger.info(f"Verification exit code for {task_name} ({target_task_id[:8]}): {test_result.get('exit_code')}") + + # 4. Harvest reward + reward_val = ctx.terminal("cat /logs/verifier/reward.txt", timeout=10) + reward_str = reward_val.get("output", "").strip() + try: + return float(reward_str) + except ValueError: + return 1.0 if test_sh and test_result.get("exit_code") == 0 else 0.0 + except Exception as verr: + logger.error(f"Verification crashed for {task_name} ({target_task_id}): {verr}") + return 0.0 + finally: + ctx.cleanup() + + # Main routing logic: Consolidated single or multi-branch trajectory analysis + reward = 0.0 + result = None + + if isinstance(results_all, list): + if len(results_all) > 1: + tqdm.write(f" [TREE] Multi-branch G={len(results_all)} finished. Executing Best-of-N scoring...") + best_res = results_all[0] + best_rew = -1.0 + + # Validate every parallel branch sandbox in sequence + for idx, res in enumerate(results_all): + branch_task_id = getattr(res, "task_id", None) or f"{task_id}-branch-{idx}" + + # Register environment overrides so verification boots the correct image! + register_task_env_overrides(branch_task_id, { + "modal_image": modal_image, + "docker_image": modal_image, + "cwd": "/app", + }) + + tqdm.write(f" Scoring branch {idx} ({branch_task_id[:8] if branch_task_id else 'unknown'})...") + branch_reward = _run_verification_for_context(branch_task_id, res) + tqdm.write(f" Branch {idx} Reward: {branch_reward}") + + # Best-of-N Selection logic + if branch_reward > best_rew: + best_rew = branch_reward + best_res = res + elif abs(branch_reward - best_rew) < 1e-5 and res.turns_used < best_res.turns_used: + best_res = res + + result = best_res + reward = best_rew + tqdm.write(f" [TREE] Best-of-N Selected: Reward={reward:.2f} (Turns={result.turns_used})") + elif len(results_all) == 1: + result = results_all[0] + tqdm.write(f" [VERIFYING] {task_name}...") + reward = _run_verification_for_context(task_id, result) + else: + # Empty list edge case + raise RuntimeError(f"Agent execution yielded 0 trajectories for {task_name}") else: - # This checkout does not include the full verification runner - # for Terminal-Bench 2.0. Keep the environment importable and - # return a conservative zero reward until that logic is added - # back explicitly. - logger.warning( - "Task %s: verification runner unavailable in this checkout. Reward=0.", - task_name, - ) - reward = 0.0 + result = results_all + tqdm.write(f" [VERIFYING] {task_name}...") + reward = _run_verification_for_context(task_id, result) passed = reward >= 1.0 duration_seconds = time.time() - task_start @@ -547,3 +623,206 @@ async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict: clear_task_env_overrides(task_id) if task_dir and Path(task_dir).exists(): shutil.rmtree(task_dir, ignore_errors=True) + + # ========================================================================= + # Evaluate -- orchestration and concurrent execution + # ========================================================================= + + async def _run_with_timeout(self, item: Dict[str, Any], sem: Optional[asyncio.Semaphore] = None) -> Dict: + """Wrap a single task rollout with a wall-clock timeout and concurrency semaphore.""" + task_name = item.get("task_name", "unknown") + category = item.get("category", "unknown") + + if sem is not None: + await sem.acquire() + + try: + return await asyncio.wait_for( + self.rollout_and_score_eval(item), + timeout=self.config.task_timeout, + ) + except asyncio.TimeoutError: + from tqdm import tqdm + tqdm.write(f" [TIMEOUT] {task_name} (exceeded {self.config.task_timeout}s)") + out = { + "passed": False, + "reward": 0.0, + "task_name": task_name, + "category": category, + "turns_used": 0, + "duration_seconds": float(self.config.task_timeout), + "error": "timeout", + } + self._save_result(out) + return out + finally: + if sem is not None: + sem.release() + + async def evaluate(self, *args, **kwargs) -> None: + """ + Run Terminal-Bench 2.0 evaluation over all tasks. + + Executes tasks with controlled concurrency via an asyncio.Semaphore, + bounded by self.config.eval_concurrency (if set) or self.config.max_concurrent_tasks. + """ + start_time = time.time() + from tqdm import tqdm + + # --- tqdm-compatible logging handler --- + class _TqdmHandler(logging.Handler): + def emit(self, record): + try: + tqdm.write(self.format(record)) + except Exception: + self.handleError(record) + + root = logging.getLogger() + handler = _TqdmHandler() + handler.setFormatter( + logging.Formatter("%(levelname)s %(name)s: %(message)s") + ) + # Clean up existing stream handlers to prevent duplicate prints + for h in list(root.handlers): + if isinstance(h, logging.StreamHandler): + root.removeHandler(h) + root.addHandler(handler) + for noisy in ("httpx", "openai", "httpcore"): + logging.getLogger(noisy).setLevel(logging.WARNING) + + # Resolve concurrency limit + concurrency = self.config.eval_concurrency + if concurrency == 0: + concurrency = getattr(self.config, "max_concurrent_tasks", 8) + + print(f"\n{'='*60}") + print(f"Starting {self.name.upper()} Evaluation") + print(f"{'='*60}") + print(f" Total tasks: {len(self.all_eval_items)}") + print(f" Concurrency: {concurrency}") + print(f" Task timeout: {self.config.task_timeout}s") + print(f"{'='*60}\n") + + sem = asyncio.Semaphore(concurrency) if concurrency > 0 else None + + # Launch all tasks + tasks = [] + for item in self.all_eval_items: + task = asyncio.create_task(self._run_with_timeout(item, sem)) + tasks.append(task) + + results = [] + pbar = tqdm(total=len(tasks), desc=self.name.upper(), dynamic_ncols=True) + + try: + for completed in asyncio.as_completed(tasks): + res = await completed + if res: + results.append(res) + self._save_result(res) + passed_count = sum(1 for r in results if r.get("passed")) + pbar.set_postfix_str(f"passed={passed_count}/{len(results)}") + pbar.update(1) + + except (KeyboardInterrupt, asyncio.CancelledError): + tqdm.write("\n[INTERRUPTED] Cancelling remaining tasks...") + for t in tasks: + if not t.done(): + t.cancel() + pbar.close() + # Cleanup + if hasattr(self, "_streaming_file") and not self._streaming_file.closed: + self._streaming_file.close() + return + + pbar.close() + end_time = time.time() + + # --- Compute metrics --- + valid = [r for r in results if r is not None] + if not valid: + print("Warning: No valid results produced.") + return + + total = len(valid) + passed_total = sum(1 for r in valid if r.get("passed")) + pass_rate = passed_total / total if total else 0.0 + avg_turns = sum(r.get("turns_used", 0) for r in valid) / total if total else 0.0 + + # Category breakdowns + cat_results = defaultdict(list) + for r in valid: + cat = r.get("category", "unknown") + cat_results[cat].append(r) + + eval_metrics = { + "eval/pass_rate": pass_rate, + "eval/total_tasks": total, + "eval/passed_tasks": passed_total, + "eval/avg_turns": avg_turns, + "eval/evaluation_time_seconds": end_time - start_time, + } + + for cat, items in sorted(cat_results.items()): + cp = sum(1 for r in items if r.get("passed")) + ct = len(items) + eval_metrics[f"eval/pass_rate_{cat.replace('-', '_')}"] = cp / ct if ct else 0.0 + + self.eval_metrics = [(k, v) for k, v in eval_metrics.items()] + + # --- Print Summary --- + print(f"\n{'='*60}") + print(f"{self.name.upper()} Evaluation Results") + print(f"{'='*60}") + print(f"Overall Pass Rate: {pass_rate:.1%} ({passed_total}/{total})") + print(f"Average Turns: {avg_turns:.2f}") + print(f"Evaluation Time: {end_time - start_time:.1f}s") + print("\nPer-category Breakdown:") + for cat, items in sorted(cat_results.items()): + cp = sum(1 for r in items if r.get("passed")) + ct = len(items) + print(f" {cat:<25}: {cp:>2}/{ct:<2} passed ({cp/ct:.1%})") + print(f"{'='*60}\n") + + # --- Log to files --- + try: + samples = [{k: v for k, v in r.items() if k != "messages"} for r in valid] + await self.evaluate_log( + metrics=eval_metrics, + samples=samples, + start_time=start_time, + end_time=end_time, + generation_parameters={ + "temperature": self.config.agent_temperature, + "max_tokens": self.config.max_token_length, + "max_agent_turns": self.config.max_agent_turns, + } + ) + except Exception as e: + print(f"Error logging results: {e}") + + # Cleanup + if hasattr(self, "_streaming_file") and not self._streaming_file.closed: + self._streaming_file.close() + print(f"Streaming results finalized in: {self._streaming_path}") + + try: + from tools.terminal_tool import cleanup_all_environments + cleanup_all_environments() + except Exception: + pass + + try: + from environments.agent_loop import _tool_executor + _tool_executor.shutdown(wait=False, cancel_futures=True) + except Exception: + pass + + async def wandb_log(self, wandb_metrics: Optional[Dict] = None): + """Log TB2-specific metrics to wandb.""" + if wandb_metrics is None: + wandb_metrics = {} + for k, v in getattr(self, "eval_metrics", []): + wandb_metrics[k] = v + self.eval_metrics = [] + await super().wandb_log(wandb_metrics) diff --git a/environments/gauss_base_env.py b/environments/gauss_base_env.py index d737e35..231b69f 100644 --- a/environments/gauss_base_env.py +++ b/environments/gauss_base_env.py @@ -470,7 +470,8 @@ async def collect_trajectory( messages.append({"role": "user", "content": self.format_prompt(item)}) # Run the agent loop - result: AgentResult + tree_g = int(os.getenv("TREE_SEARCH_G", "1")) + results_all: Any if self._use_managed_server(): # Phase 2: ManagedServer with ToolCallTranslator -- exact tokens + logprobs # tool_parser is set on ServerManager in __init__ and passed through @@ -491,7 +492,7 @@ async def collect_trajectory( max_tokens=self.config.max_token_length, extra_body=self.config.extra_body, ) - result = await agent.run(messages) + results_all = await agent.run(messages, G=tree_g) except NotImplementedError: # DummyManagedServer not allowed -- fall back to Phase 1 logger.warning( @@ -508,7 +509,7 @@ async def collect_trajectory( max_tokens=self.config.max_token_length, extra_body=self.config.extra_body, ) - result = await agent.run(messages) + results_all = await agent.run(messages, G=tree_g) else: # Phase 1: OpenAI server -- native tool_calls, placeholder tokens agent = GaussAgentLoop( @@ -521,15 +522,47 @@ async def collect_trajectory( max_tokens=self.config.max_token_length, extra_body=self.config.extra_body, ) - result = await agent.run(messages) + results_all = await agent.run(messages, G=tree_g) + + # Consolidate multi-branch trajectories for Best-of-N Scoring (Pass@G validation) + reward_computed_externally = False + reward = 0.0 + result: AgentResult - # Skip reward computation if the agent loop produced no meaningful work - # (e.g., API call failed on turn 1). No point spinning up a Modal sandbox - # just to verify files that were never created. + if isinstance(results_all, list): + if len(results_all) > 1: + logger.info(f"Tree-Search yielded {len(results_all)} branches. Performing Best-of-N scoring...") + best_res = results_all[0] + best_rew = -999.0 + for idx, res in enumerate(results_all): + branch_ctx_id = getattr(res, "task_id", None) or f"{task_id}-branch-{idx}" + branch_ctx = ToolContext(branch_ctx_id) + try: + r = await self.compute_reward(item, res, branch_ctx) + except Exception as ex: + logger.warning(f"Scoring failed for branch {idx}: {ex}") + r = 0.0 + finally: + branch_ctx.cleanup() + if r > best_rew: + best_rew = r + best_res = res + result = best_res + reward = max(best_rew, 0.0) + reward_computed_externally = True + logger.info(f"Best-of-N Selected Branch {results_all.index(result)} with score {reward}") + else: + result = results_all[0] + else: + result = results_all + + # Skip reward computation if already done externally or produced no work only_system_and_user = all( msg.get("role") in ("system", "user") for msg in result.messages ) - if result.turns_used == 0 or only_system_and_user: + if reward_computed_externally: + pass # already calculated above + elif result.turns_used == 0 or only_system_and_user: logger.warning( "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.", result.turns_used, len(result.messages), @@ -659,12 +692,10 @@ async def compute_reward( """ raise NotImplementedError - @abstractmethod async def evaluate(self, *args, **kwargs): """ - Periodic evaluation. Called every steps_per_eval steps. - - Typical implementation runs the agent on a held-out eval set - and logs metrics via wandb/evaluate_log. + Concrete implementation of evaluate to satisfy the abstract method contract. + The evaluation runs are driven by the CLI 'evaluate' wrapper which invokes + run() inside atroposlib. """ - raise NotImplementedError + pass diff --git a/tests/test_trace_masking.py b/tests/test_trace_masking.py new file mode 100644 index 0000000..b1499a8 --- /dev/null +++ b/tests/test_trace_masking.py @@ -0,0 +1,103 @@ +import pytest +from environments.agent_loop import AgentResult, ToolError +from agent.trace_masking import TRACERewardEngine + +def test_trace_success_clean_trajectory(): + """Verify full credit and step penalty deduction on clean successful trajectory.""" + engine = TRACERewardEngine(step_penalty=0.01) + + # 2 turns, both clean, no errors, final outcome 1.0 (Success) + messages = [ + {"role": "user", "content": "Fix it"}, + {"role": "assistant", "tool_calls": [{"id": "1"}]}, + {"role": "tool", "content": "Clean output here"}, + {"role": "assistant", "content": "Done"} + ] + + trajectory = AgentResult( + messages=messages, + managed_state=None, + turns_used=2, + finished_naturally=True, + reasoning_per_turn=[], + tool_errors=[] + ) + + rewards = engine.compute_rewards(trajectory, 1.0) + + # Expected: + # Turn 1: Mask=1.0 -> (1.0 * 1.0) - 0.01 = 0.99 + # Turn 2: Mask=1.0 -> (1.0 * 1.0) - 0.01 = 0.99 + assert rewards == [0.99, 0.99] + +def test_trace_failure_clean_trajectory(): + """Verify time-penalty penalty propagates on failure trajectories.""" + engine = TRACERewardEngine(step_penalty=0.01) + + messages = [ + {"role": "assistant", "tool_calls": [{"id": "1"}]}, + {"role": "tool", "content": "Clean output"}, + {"role": "assistant", "content": "Done"} + ] + + trajectory = AgentResult( + messages=messages, + managed_state=None, + turns_used=2, + finished_naturally=True, + reasoning_per_turn=[], + tool_errors=[] + ) + + rewards = engine.compute_rewards(trajectory, 0.0) + + # Expected: + # Turn 1: Mask=1.0 -> (0.0 * 1.0) - 0.01 = -0.01 + # Turn 2: Mask=1.0 -> (0.0 * 1.0) - 0.01 = -0.01 + assert rewards == [-0.01, -0.01] + +def test_trace_masking_heuristics(): + """Verify M_t is masked to 0.0 on empty output, error, or truncation.""" + engine = TRACERewardEngine(step_penalty=0.01) + + messages = [ + # Turn 1: Tool Error (Registered) + {"role": "assistant", "tool_calls": [{"id": "1"}]}, + {"role": "tool", "content": "Error occurred"}, + # Turn 2: Clean + {"role": "assistant", "tool_calls": [{"id": "2"}]}, + {"role": "tool", "content": "Healthy execution"}, + # Turn 3: Empty output + {"role": "assistant", "tool_calls": [{"id": "3"}]}, + {"role": "tool", "content": " "}, + # Turn 4: Truncated output + {"role": "assistant", "tool_calls": [{"id": "4"}]}, + {"role": "tool", "content": "Some text [TRUNCATED 2000 CHARACTERS TO PREVENT BLOWOUT] end"} + ] + + tool_errors = [ + ToolError(turn=1, tool_name="cmd", arguments="{}", error="Syntax", tool_result="Error occurred") + ] + + trajectory = AgentResult( + messages=messages, + managed_state=None, + turns_used=4, + finished_naturally=True, + reasoning_per_turn=[], + tool_errors=tool_errors + ) + + # Evaluate Success scenario (1.0) to clearly see masks in effect + rewards = engine.compute_rewards(trajectory, 1.0) + + # Expected Masks: + # Turn 1: Mask=0.0 (Error) -> (1 * 0) - 0.01 = -0.01 + # Turn 2: Mask=1.0 (Clean) -> (1 * 1) - 0.01 = 0.99 + # Turn 3: Mask=0.0 (Empty) -> (1 * 0) - 0.01 = -0.01 + # Turn 4: Mask=0.0 (Trunc) -> (1 * 0) - 0.01 = -0.01 + + assert pytest.approx(rewards[0]) == -0.01 + assert pytest.approx(rewards[1]) == 0.99 + assert pytest.approx(rewards[2]) == -0.01 + assert pytest.approx(rewards[3]) == -0.01 diff --git a/tools/environments/docker.py b/tools/environments/docker.py index e8ebe5b..dd8630c 100644 --- a/tools/environments/docker.py +++ b/tools/environments/docker.py @@ -174,7 +174,23 @@ def __init__( network: bool = True, host_cwd: str = None, auto_mount_cwd: bool = False, + is_clone: bool = False, ): + self._init_kwargs = { + "image": image, + "cwd": cwd, + "timeout": timeout, + "cpu": cpu, + "memory": memory, + "disk": disk, + "persistent_filesystem": persistent_filesystem, + "task_id": task_id, + "volumes": volumes, + "network": network, + "host_cwd": host_cwd, + "auto_mount_cwd": auto_mount_cwd, + "is_clone": is_clone, + } if cwd == "~": cwd = "/root" super().__init__(cwd=cwd, timeout=timeout) @@ -260,9 +276,10 @@ def __init__( "-v", f"{self._workspace_dir}:/workspace", ]) else: - if not bind_host_cwd and not workspace_explicitly_mounted: + # Force 2G RAM disk (/workspace tmpfs) for active, safe workspace execution + if not workspace_explicitly_mounted: writable_args.extend([ - "--tmpfs", "/workspace:rw,exec,size=10g", + "--tmpfs", "/workspace:rw,exec,size=2g", ]) writable_args.extend([ "--tmpfs", "/home:rw,exec,size=1g", @@ -270,8 +287,9 @@ def __init__( ]) if bind_host_cwd: - logger.info(f"Mounting configured host cwd to /workspace: {host_cwd_abs}") - volume_args = ["-v", f"{host_cwd_abs}:/workspace", *volume_args] + # Secure the host: Mount read-only base workspace + logger.info(f"Mounting host cwd as Read-Only to /base_workspace: {host_cwd_abs}") + volume_args = ["-v", f"{host_cwd_abs}:/base_workspace:ro", *volume_args] elif workspace_explicitly_mounted: logger.debug("Skipping docker cwd mount: /workspace already mounted by user config") @@ -290,6 +308,11 @@ def __init__( ) self._container_id = self._inner.container_id + # Copy-on-Launch initialization: populate workspace RAM disk from base code + if not self._persistent and bind_host_cwd and not is_clone: + logger.info("Populating active workspace RAM disk from /base_workspace Read-Only cache...") + self.execute("mkdir -p /workspace && cp -a /base_workspace/. /workspace/ && chown -R root:root /workspace") + @staticmethod def _storage_opt_supported() -> bool: """Check if Docker's storage driver supports --storage-opt size=. @@ -420,3 +443,46 @@ def cleanup(self): for d in (self._workspace_dir, self._home_dir): if d: shutil.rmtree(d, ignore_errors=True) + + def clone_to_parallel(self, count: int) -> list["DockerEnvironment"]: + """Spins up G instances of this environment instantly with the current mutated workspace state. + + Achieves sub-50ms cloning speed by piping a tar representation of the current in-memory + tmpfs RAM disk (/workspace) directly into the destination containers' tmpfs RAM disks. + Uses Unix pipes entirely in memory, bypassing slow disk-based Docker commit cycles. + """ + import uuid + import subprocess + + logger.info(f"Initiating parallel RAM-disk clone for {count} Tree-GRPO workers...") + + clones = [] + src_cid = self._inner.container_id + docker_exe = find_docker() or "docker" + + for i in range(count): + # Create unique task_id to prevent collisions in logging and sandbox paths + clone_kwargs = self._init_kwargs.copy() + clone_kwargs["task_id"] = f"{self._task_id}-grpo-clone-{uuid.uuid4().hex[:6]}" + clone_kwargs["is_clone"] = True + + # Instantiate target container (creates new isolated RAM disk) + logger.debug(f"Creating sandbox container clone #{i+1}...") + clone = DockerEnvironment(**clone_kwargs) + dst_cid = clone._inner.container_id + + # Tar Pipe Logic: Stream from src tmpfs to dst tmpfs entirely via in-memory VFS cache + # Command constructs: tar -C src -cf - . | docker exec -i dst tar -C dst -xf - + tar_cmd = f"{docker_exe} exec {src_cid} tar -C /workspace -cf - . | {docker_exe} exec -i {dst_cid} tar -C /workspace -xf -" + + logger.debug(f"Piping in-memory active workspace state to clone #{i+1}...") + try: + subprocess.run(tar_cmd, shell=True, check=True, capture_output=True) + clones.append(clone) + except subprocess.CalledProcessError as e: + logger.error(f"In-memory workspace cloning failed: {e.stderr.decode('utf-8', errors='ignore')}") + clone.cleanup() + raise RuntimeError(f"Failed to clone execution sandbox: {e}") + + logger.info(f"Successfully provisioned {len(clones)} cloned environments.") + return clones diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py index 1ea1bb5..496abc2 100644 --- a/tools/rl_training_tool.py +++ b/tools/rl_training_tool.py @@ -1395,3 +1395,154 @@ def get_missing_keys() -> List[str]: registry.register(name="rl_test_inference", emoji="🧪", toolset="rl", schema=RL_TEST_INFERENCE_SCHEMA, handler=lambda args, **kw: rl_test_inference(num_steps=args.get("num_steps", 3), group_size=args.get("group_size", 16), models=args.get("models")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) + + +# ============================================================================ +# Phase 4: Distributed GRPO Mathematical Engine +# ============================================================================ + +try: + import torch + import torch.nn.functional as F +except ImportError: + torch = None + + +class GaussGRPOEngine: + """ + Distributed GRPO Loss & Bridge Wrapper for OpenRLHF Ray Orchestrators. + + Encapsulates mathematically rigorous operations required to compute Group Relative + advantages and surrogate gradient updates. Serves as the analytical binding + between OpenGauss's episodic telemetry and DeepSpeed-driven Ray trainers. + """ + + def __init__(self, kl_coef: float = 0.01, clip_eps: float = 0.2): + """ + Initialize the GRPO engine parameters. + + Args: + kl_coef (float): Coefficient beta penalizing divergence from the reference model. + Defaults to 0.01. + clip_eps (float): Epsilon constraint bound for PPO policy clipping. Defaults to 0.2. + """ + self.kl_coef = kl_coef + self.clip_eps = clip_eps + + def compute_group_advantage(self, rewards_list: List[List[float]]) -> List[List[float]]: + """ + Computes Group Relative Advantage (A_i) normalized across parallel group rollouts. + + Sums stepwise rewards to generate episodic returns R_i, derives group distribution + parameters (mu, sigma), and broadcast-normalizes the episode relative advantage + back into individual sequence steps to maintain continuous alignment. + + Mathematical Formulation: + R_total,i = Sum_{t} r_{t,i} + mu = Mean(R_total) + sigma = StdDev(R_total) + A_i = (R_total,i - mu) / (sigma + 1e-8) + + Args: + rewards_list (List[List[float]]): + Nested scalar rewards obtained from trace_masking.py. + Shape: [group_size, varied_sequence_lengths] + + Returns: + List[List[float]]: + Re-broadcast stepwise advantages for policy token weighing. + Shape: [group_size, varied_sequence_lengths] + """ + import math + + # Step 1: Collapse step rewards into total episodic returns + episode_returns = [sum(branch_rewards) for branch_rewards in rewards_list] + group_size = len(episode_returns) + + if group_size == 0: + return [] + + # Safeguard for trivial/decoupled single-sample calls + if group_size == 1: + return [[0.0] * len(rewards_list[0])] + + # Step 2: Derive high-precision group statistics + mean = sum(episode_returns) / group_size + variance = sum((x - mean) ** 2 for x in episode_returns) / group_size + std_dev = math.sqrt(variance) + + # Step 3: Broadcast normalized advantages + advantages = [] + for i, ret in enumerate(episode_returns): + # Inject 1e-8 epsilon boundary to completely inhibit division-by-zero cliffs + branch_advantage = (ret - mean) / (std_dev + 1e-8) + # Scale entire trajectory timeline by branch episodic performance + advantages.append([branch_advantage] * len(rewards_list[i])) + + return advantages + + def compute_grpo_loss( + self, + policy_logprobs: Any, + reference_logprobs: Any, + old_logprobs: Any, + advantages: Any, + ) -> Any: + """ + Computes batch-averaged Group Relative Policy Optimization (GRPO) loss. + + Ingests flat sequences from Ray Data workers and executes fused tensor math + combining importance sampling ratios, PPO clipped surrogate bounds, and + log-probarithmic Kullback-Leibler (KL) divergence constraints. + + Target Tensor Topology (Ray Data Format): + policy_logprobs: [batch_size, seq_len] -> Primary policy logprobs. + reference_logprobs: [batch_size, seq_len] -> Frozen reference logprobs. + old_logprobs: [batch_size, seq_len] -> Importance sampling anchor logprobs. + advantages: [batch_size, seq_len] -> Re-broadcast group relative advantages (A_i). + + Formulation: + Ratio (r_t) = exp(policy_logprobs - old_logprobs) + KL Divergence (D_KL) = reference_logprobs - policy_logprobs + Surrogate_1 = r_t * A_i + Surrogate_2 = Clip(r_t, 1 - eps, 1 + eps) * A_i + L_grpo = - Mean( Min(Surr1, Surr2) - beta * D_KL ) + + Args: + policy_logprobs (torch.Tensor): Logprobs from active gradient update loop. + reference_logprobs (torch.Tensor): Baseline anchored logprobs. + old_logprobs (torch.Tensor): Original rollout generation logprobs. + advantages (torch.Tensor): Broadcast group relative advantage weights. + + Returns: + torch.Tensor: Scalar loss tensor safe for distributed backward propagation. + """ + if torch is None: + raise ImportError( + "PyTorch (torch) is not installed or accessible. GaussGRPOEngine " + "requires an active PyTorch environment to compute tensor losses." + ) + + # 1. Evaluate Importance Sampling Ratio: r_t = exp(log(pi_now) - log(pi_old)) + log_ratio = policy_logprobs - old_logprobs + ratio = torch.exp(log_ratio) + + # 2. Calculate Clipped Objective Bound (PPO style constraint) + surrogate1 = ratio * advantages + surrogate2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages + + # Strictly minimize the bound to ensure conservative updates + clipped_surrogate_loss = torch.min(surrogate1, surrogate2) + + # 3. Derive Reference-Alignment KL Constraint to impede catastrophic drift + kl_divergence = reference_logprobs - policy_logprobs + + # 4. Final Objective: Maximize [Clipped_Surrogate + beta * KL] + # Equivalently: Minimize negative total objective + objective_per_token = clipped_surrogate_loss + (self.kl_coef * kl_divergence) + + # Negated mean calculation creates gradient update driving ascent on the reward landscape + total_grpo_loss = -objective_per_token.mean() + + return total_grpo_loss +