def _get_loose_perturbations(text: str) -> List[str]:
    """Return the IFEval loose-mode variants of ``text``.

    Four line-removal variants are built (the text as-is, without its first
    line, without its last line, and without both).  Each variant is followed
    by a copy with every asterisk removed.  Variants that are empty or
    whitespace-only are dropped, so the result holds at most eight strings and
    may contain repeats (for example when a variant has no asterisks).

    Args:
        text: The response text to perturb.

    Returns:
        The non-blank variants, in the order described above.
    """

    def remove_stars(s: str) -> str:
        return s.replace("*", "")

    def without_first_line(s: str) -> str:
        # A single-line string has nothing left once its only line is removed.
        idx = s.find("\n")
        return s[idx + 1 :] if idx >= 0 else ""

    def without_last_line(s: str) -> str:
        idx = s.rfind("\n")
        return s[:idx] if idx >= 0 else ""

    base = [
        text,
        without_first_line(text),
        without_last_line(text),
        without_last_line(without_first_line(text)),
    ]
    return [v for s in base for v in (s, remove_stars(s)) if v.strip()]


def _check_following_loose(instruction, text: str) -> bool:
    """Evaluate ``instruction`` against ``text`` in loose mode.

    The first strategy that applies is used:

    1. ``instruction.check_following_loose(text)`` when that attribute exists.
    2. ``instruction.check_following(text, mode="loose")`` when the call does
       not raise ``TypeError``.
    3. Otherwise the strict ``instruction.check_following`` is tried on each
       distinct loose-mode perturbation of ``text``; one pass is enough.

    Args:
        instruction: Object exposing ``check_following`` and, optionally,
            ``check_following_loose``.
        text: The response text to check.

    Returns:
        The native loose result for strategies 1 and 2; for strategy 3,
        ``True`` if any perturbation passes, else ``False`` (including when
        every perturbation is blank).
    """
    if hasattr(instruction, "check_following_loose"):
        return instruction.check_following_loose(text)

    try:
        return instruction.check_following(text, mode="loose")
    except TypeError:
        # Taken to mean "no `mode` keyword".  NOTE: a TypeError raised from
        # inside a mode-aware check_following is indistinguishable here and
        # also lands on the perturbation fallback below.
        pass

    # Perturbations repeat whenever a variant has no asterisks; checking each
    # distinct string once gives the same answer with fewer calls.
    # dict.fromkeys keeps first-seen order.
    candidates = dict.fromkeys(_get_loose_perturbations(text))
    return any(instruction.check_following(candidate) for candidate in candidates)
def compute_metrics(self, tasks: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
    """Compute the four IFEval accuracy metrics over all rollouts.

    Args:
        tasks: ``tasks[i]`` is the list of rollout dicts for task ``i``.  Each
            dict is read for ``follow_instruction_list`` and
            ``follow_instruction_list_loose``; a missing or ``None`` value is
            treated as an empty list.

    Returns:
        A dict with ``prompt_strict_accuracy``, ``instruction_strict_accuracy``,
        ``prompt_loose_accuracy`` and ``instruction_loose_accuracy``, each a
        percentage in ``[0, 100]``.  Prompt-level scores average one value per
        rollout (1.0 only when its list is non-empty and all true).
        Instruction-level scores average every individual flag.  A metric
        with no samples is reported as 0.0.
    """
    prompt_strict: List[float] = []
    instruction_strict: List[float] = []
    prompt_loose: List[float] = []
    instruction_loose: List[float] = []

    for task_rollouts in tasks:
        for rollout in task_rollouts:
            # `or []` also covers a key that is present but None, which would
            # otherwise fail when iterated below.
            strict_flags = rollout.get("follow_instruction_list") or []
            loose_flags = rollout.get("follow_instruction_list_loose") or []

            # A rollout with no instruction results counts as a prompt-level miss.
            prompt_strict.append(float(all(strict_flags)) if strict_flags else 0.0)
            prompt_loose.append(float(all(loose_flags)) if loose_flags else 0.0)
            instruction_strict.extend(float(flag) for flag in strict_flags)
            instruction_loose.extend(float(flag) for flag in loose_flags)

    def _percent(samples: List[float]) -> float:
        # Mean of the samples scaled to a percentage; 0.0 when there are none.
        return (sum(samples) / len(samples) if samples else 0.0) * 100.0

    return {
        "prompt_strict_accuracy": _percent(prompt_strict),
        "instruction_strict_accuracy": _percent(instruction_strict),
        "prompt_loose_accuracy": _percent(prompt_loose),
        "instruction_loose_accuracy": _percent(instruction_loose),
    }
def test_compute_metrics_four_keys(self):
    """compute_metrics returns exactly the four IFEval accuracies, as percentages."""
    rollouts_by_task = [
        [{"follow_instruction_list": [True, False], "follow_instruction_list_loose": [True, True]}],
        [{"follow_instruction_list": [True], "follow_instruction_list_loose": [True]}],
    ]
    metrics = self._create_server().compute_metrics(rollouts_by_task)

    expected_keys = {
        "prompt_strict_accuracy",
        "instruction_strict_accuracy",
        "prompt_loose_accuracy",
        "instruction_loose_accuracy",
    }
    assert set(metrics.keys()) == expected_keys

    # Strict: one of the two prompts passes fully; two of the three instructions pass.
    assert metrics["prompt_strict_accuracy"] == 50.0
    assert abs(metrics["instruction_strict_accuracy"] - 200 / 3) < 1e-9
    # Loose: every flag in the fixture is True.
    assert metrics["prompt_loose_accuracy"] == 100.0
    assert metrics["instruction_loose_accuracy"] == 100.0