diff --git a/envs/repl_env/runner.py b/envs/repl_env/runner.py index 226b26657..bf37d0e7b 100644 --- a/envs/repl_env/runner.py +++ b/envs/repl_env/runner.py @@ -243,10 +243,18 @@ def _default_answer( ] try: response = self._chat(final_prompt, model) - # Try to extract FINAL(...) from the response - match = re.search(r"FINAL\((.*?)\)", response, re.DOTALL) - if match: - return match.group(1).strip() + # Try to extract FINAL(...) from the response. Paren-counting handles + # nested parens and mid-sentence FINAL without regex flag trade-offs. + idx = response.find("FINAL(") + if idx != -1: + depth, start = 0, idx + len("FINAL") + for i, ch in enumerate(response[idx + len("FINAL"):], start=idx + len("FINAL")): + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + return response[start + 1 : i].strip() # If no FINAL pattern, return the raw response as best-effort return response.strip() if response.strip() else None except Exception: diff --git a/envs/repl_env/server/repl_environment.py b/envs/repl_env/server/repl_environment.py index f2e6f5d98..8ba645a44 100644 --- a/envs/repl_env/server/repl_environment.py +++ b/envs/repl_env/server/repl_environment.py @@ -536,10 +536,18 @@ def _extract_final_answer(self, stdout: str) -> Optional[str]: Returns: Final answer string or None if not found """ - # Pattern 1: RLM-style FINAL(answer) - final_match = re.search(r"FINAL\((.*?)\)", stdout, re.DOTALL) - if final_match: - return final_match.group(1).strip() + # Pattern 1: RLM-style FINAL(answer). Paren-counting handles nested + # parens (e.g. FINAL(f(x))) and multi-line values without regex flag trade-offs. + idx = stdout.find("FINAL(") + if idx != -1: + depth, start = 0, idx + len("FINAL") + for i, ch in enumerate(stdout[idx + len("FINAL"):], start=idx + len("FINAL")): + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + return stdout[start + 1 : i].strip() # Pattern 2: RLM-style FINAL_VAR(variable_name) final_var_match = re.search(r"FINAL_VAR\((\w+)\)", stdout) diff --git a/tests/envs/test_repl_env.py b/tests/envs/test_repl_env.py index 4811119fb..b8cb8311e 100644 --- a/tests/envs/test_repl_env.py +++ b/tests/envs/test_repl_env.py @@ -229,6 +229,32 @@ def test_final_pattern_basic(self): assert obs.done assert obs.metadata["final_answer"] == "42" + @pytest.mark.parametrize( + "code, expected", + [ + # Nested function calls inside FINAL(...). + ("print('FINAL(f(x))')", "f(x)"), + # Tuple as the final answer. + ("print('FINAL((1, 2, 3))')", "(1, 2, 3)"), + # Math expression with multiple nested parens (e2b_repl_example). + ( + "print('FINAL(2^(2^(2^(2))) = 65536)')", + "2^(2^(2^(2))) = 65536", + ), + # Dict containing a tuple value. + ("print(\"FINAL({'a': (1, 2)})\")", "{'a': (1, 2)}"), + # Output after FINAL must not bleed into the extracted answer. + ("print('FINAL(42)\\nresult: (ok)')", "42"), + ], + ) + def test_final_pattern_nested_parentheses(self, code, expected): + """FINAL(...) extraction must handle nested parentheses (rlm #75).""" + env = REPLEnvironment() + env.reset() + obs = env.step(REPLAction(code=code)) + assert obs.done + assert obs.metadata["final_answer"] == expected + def test_final_var_pattern(self): """Test FINAL_VAR() pattern.""" env = REPLEnvironment()