Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 78 additions & 6 deletions resources_servers/instruction_following/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Literal
from typing import Any, Dict, List, Literal

from fastapi import FastAPI
from verifiable_instructions import instructions_registry
Expand Down Expand Up @@ -45,9 +45,49 @@ class InstructionFollowingVerifyRequest(InstructionFollowingRunRequest, BaseVeri
pass


def _get_loose_perturbations(text: str) -> list:
    """Build the candidate texts that IFEval loose-mode checking tries.

    Four line-level variants are derived from ``text``: unchanged, first line
    dropped, last line dropped, and both first and last line dropped. Each
    variant is emitted immediately followed by a copy of itself with every
    ``*`` removed. Candidates that are empty or whitespace-only are left out.

    A text without any newline produces empty strings for all three
    line-dropping variants, so only the original (and its star-free copy)
    can survive the filter.
    """

    def drop_first_line(s: str) -> str:
        # partition() leaves the tail empty when there is no newline.
        return s.partition("\n")[2]

    def drop_last_line(s: str) -> str:
        # rpartition() leaves the head empty when there is no newline.
        return s.rpartition("\n")[0]

    headless = drop_first_line(text)
    line_variants = (
        text,
        headless,
        drop_last_line(text),
        drop_last_line(headless),
    )

    candidates = []
    for variant in line_variants:
        for candidate in (variant, variant.replace("*", "")):
            if candidate.strip():
                candidates.append(candidate)
    return candidates


def _check_following_loose(instruction, text: str) -> bool:
    """Evaluate ``instruction`` against ``text`` in loose mode.

    Strategies are tried in this order:

    1. ``instruction.check_following_loose(text)`` when the instruction has
       such an attribute.
    2. ``instruction.check_following(text, mode="loose")``.
    3. If step 2 raises ``TypeError``, plain ``check_following`` is run on each
       text from ``_get_loose_perturbations`` (at most 8 candidates) and the
       check passes as soon as one candidate is accepted.
    """
    if not hasattr(instruction, "check_following_loose"):
        try:
            return instruction.check_following(text, mode="loose")
        except TypeError:
            # Treated as "check_following takes no mode argument": fall back
            # to trying every perturbed text with the plain strict check.
            for candidate in _get_loose_perturbations(text):
                if instruction.check_following(candidate):
                    return True
            return False
    return instruction.check_following_loose(text)


class InstructionFollowingVerifyResponse(BaseVerifyResponse):
follow_all_instructions: bool
follow_instruction_list: List[bool]
follow_all_instructions_loose: bool
follow_instruction_list_loose: List[bool]
kwargs: List
instruction_id_list: List
prompt: str
Expand Down Expand Up @@ -102,6 +142,7 @@ async def verify(self, body: InstructionFollowingVerifyRequest) -> InstructionFo
instruction_list = body.instruction_id_list
kwargs_list = body.kwargs
is_following_list = []
is_following_list_loose = []

for instruction_id, kwargs in zip(instruction_list, kwargs_list):
try:
Expand All @@ -119,16 +160,15 @@ async def verify(self, body: InstructionFollowingVerifyRequest) -> InstructionFo
# Build the instruction description with the provided kwargs
instruction.build_description(**filtered_kwargs)

# Check if the response follows the instruction
if instruction.check_following(final_response_text):
is_following_list.append(True)
else:
is_following_list.append(False)
# Check strict and loose from the same instruction instance
is_following_list.append(instruction.check_following(final_response_text))
is_following_list_loose.append(_check_following_loose(instruction, final_response_text))

except Exception as e:
# If there's an error processing the instruction, mark as failed
print(f"Error processing instruction {instruction_id}: {e}")
is_following_list.append(False)
is_following_list_loose.append(False)

# Calculate overall success
reward_mode = getattr(body, "grading_mode", "binary")
Expand All @@ -144,8 +184,40 @@ async def verify(self, body: InstructionFollowingVerifyRequest) -> InstructionFo
reward=float(reward),
follow_all_instructions=all(is_following_list),
follow_instruction_list=is_following_list,
follow_all_instructions_loose=all(is_following_list_loose),
follow_instruction_list_loose=is_following_list_loose,
)

def compute_metrics(self, tasks: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
    """Compute the four IFEval accuracy metrics over all verify responses.

    Args:
        tasks: ``tasks[i]`` is the list of rollout dicts for task ``i``. Each
            dict is read for ``follow_instruction_list`` and
            ``follow_instruction_list_loose``. A key that is missing or set
            to ``None`` is treated as an empty list.

    Returns:
        A dict of percentages (0-100) with the keys
        ``prompt_strict_accuracy``, ``instruction_strict_accuracy``,
        ``prompt_loose_accuracy`` and ``instruction_loose_accuracy``.
        Prompt-level values average one score per rollout (1.0 only when the
        list is non-empty and all true). Instruction-level values average
        every individual flag pooled across rollouts. A metric with no
        samples is reported as 0.0.
    """
    prompt_strict: List[float] = []
    instruction_strict: List[float] = []
    prompt_loose: List[float] = []
    instruction_loose: List[float] = []

    for task_rollouts in tasks:
        for rollout in task_rollouts:
            # `or []` also covers a key that is present but None; iterating
            # None below would otherwise raise TypeError.
            strict_flags = rollout.get("follow_instruction_list") or []
            loose_flags = rollout.get("follow_instruction_list_loose") or []

            # all([]) is True, so an empty list is scored 0.0 explicitly.
            prompt_strict.append(float(all(strict_flags)) if strict_flags else 0.0)
            prompt_loose.append(float(all(loose_flags)) if loose_flags else 0.0)
            instruction_strict.extend(float(flag) for flag in strict_flags)
            instruction_loose.extend(float(flag) for flag in loose_flags)

    def _mean(values: List[float]) -> float:
        # Guard the division so an empty sample set yields 0.0.
        return sum(values) / len(values) if values else 0.0

    return {
        "prompt_strict_accuracy": _mean(prompt_strict) * 100.0,
        "instruction_strict_accuracy": _mean(instruction_strict) * 100.0,
        "prompt_loose_accuracy": _mean(prompt_loose) * 100.0,
        "instruction_loose_accuracy": _mean(instruction_loose) * 100.0,
    }


# When this module is executed as a script, invoke the server class's
# run_webserver() entry point.
if __name__ == "__main__":
    InstructionFollowingResourcesServer.run_webserver()
43 changes: 43 additions & 0 deletions resources_servers/instruction_following/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,46 @@ def test_fractional_reward_half(self):
grading_mode="fraction",
)
self._run_verify_test(real_request, False, 0.5, [True, False])

def test_loose_fields_present(self):
    """The verify response carries both loose-mode fields with the right types.

    A single-instruction request must yield a boolean
    ``follow_all_instructions_loose`` and a one-element
    ``follow_instruction_list_loose`` list.
    """
    request = self._create_real_request(
        instruction_ids=["punctuation:no_comma"],
        prompt="The output should not contain any commas.",
        kwargs=[{}],
        response_content="Hello world without commas",
    )
    response = asyncio.run(self._create_server().verify(request))

    loose_flags = response.follow_instruction_list_loose
    assert isinstance(response.follow_all_instructions_loose, bool)
    assert isinstance(loose_flags, list)
    assert len(loose_flags) == 1

def test_loose_geq_strict(self):
    """Loose grading is never harsher than strict grading, per instruction.

    For every instruction in the response, the loose flag must be at least
    the strict flag (booleans compared as ``False < True``).
    """
    request = self._create_real_request(
        instruction_ids=["punctuation:no_comma"],
        prompt="No commas please.",
        kwargs=[{}],
        response_content="Hello, world",
    )
    response = asyncio.run(self._create_server().verify(request))

    flag_pairs = zip(
        response.follow_instruction_list,
        response.follow_instruction_list_loose,
    )
    assert all(loose >= strict for strict, loose in flag_pairs)

def test_compute_metrics_four_keys(self):
    """compute_metrics returns exactly the four IFEval metrics with expected values.

    Two tasks with one rollout each are used: the first fails one of two
    instructions in strict mode but passes both in loose mode, the second
    passes its single instruction in both modes.
    """
    first_task = [{"follow_instruction_list": [True, False], "follow_instruction_list_loose": [True, True]}]
    second_task = [{"follow_instruction_list": [True], "follow_instruction_list_loose": [True]}]

    metrics = self._create_server().compute_metrics([first_task, second_task])

    expected_keys = {
        "prompt_strict_accuracy",
        "instruction_strict_accuracy",
        "prompt_loose_accuracy",
        "instruction_loose_accuracy",
    }
    assert set(metrics) == expected_keys

    # Strict: 1 of 2 prompts fully followed; 2 of 3 instructions followed.
    assert metrics["prompt_strict_accuracy"] == 50.0
    assert abs(metrics["instruction_strict_accuracy"] - 200 / 3) < 1e-9
    # Loose: every prompt and every instruction is followed.
    assert metrics["prompt_loose_accuracy"] == 100.0
    assert metrics["instruction_loose_accuracy"] == 100.0