"""
core/physical/action_validator.py

Validates that the user performed a physical step correctly and safely
before Execra advances to the next step.

Supports three task domains out of the box:
  - cooking   : knife/cutting-board proximity & body-safety checks
  - hardware  : screwdriver orientation check
  - form      : pen hand + paper proximity check

The public surface is intentionally small:
    validator = ActionValidator()
    result = validator.validate(step, detections, hand_results, depth_map)
"""

from __future__ import annotations

import math
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

__all__ = [
    "ActionValidator",
    "BoundingBox",
    "Detection",
    "HandResult",
    "ValidationResult",
]


# ──────────────────────────────────────────────────────────────────────────────
# Data models
# ──────────────────────────────────────────────────────────────────────────────


@dataclass
class BoundingBox:
    """Normalised [0, 1] bounding box: (x_min, y_min, x_max, y_max)."""

    x_min: float
    y_min: float
    x_max: float
    y_max: float

    # ── helpers ───────────────────────────────────────────────────────────────

    @property
    def centre(self) -> Tuple[float, float]:
        """Return the ``(x, y)`` midpoint of the box."""
        return ((self.x_min + self.x_max) / 2, (self.y_min + self.y_max) / 2)

    @property
    def area(self) -> float:
        """Return the box area; inverted (min > max) extents count as zero."""
        return max(0.0, self.x_max - self.x_min) * max(0.0, self.y_max - self.y_min)

    def distance_to(self, other: "BoundingBox") -> float:
        """Euclidean distance between the two centres."""
        cx1, cy1 = self.centre
        cx2, cy2 = other.centre
        return math.hypot(cx1 - cx2, cy1 - cy2)

    def overlaps(self, other: "BoundingBox", threshold: float = 0.0) -> bool:
        """True when the intersection-over-union > *threshold*.

        Boxes that merely touch (IoU == 0) do not overlap at the default
        threshold, and two zero-area boxes never overlap.
        """
        inter_x = max(0, min(self.x_max, other.x_max) - max(self.x_min, other.x_min))
        inter_y = max(0, min(self.y_max, other.y_max) - max(self.y_min, other.y_min))
        inter = inter_x * inter_y
        union = self.area + other.area - inter
        if union <= 0:
            return False
        return (inter / union) > threshold


@dataclass
class Detection:
    """A single object detection from a CV model."""

    label: str
    confidence: float
    bbox: BoundingBox
    depth: Optional[float] = None  # metres, if available
    attributes: Dict[str, Any] = field(default_factory=dict)


@dataclass
class HandResult:
    """Pose result for one detected hand."""

    hand_label: str  # "Left" | "Right"
    confidence: float
    bbox: BoundingBox
    landmarks: Optional[List[Tuple[float, float]]] = None  # 21 points, normalised
    holding: Optional[str] = None  # label of the held object, if resolved


@dataclass
class ValidationResult:
    """Return value of :meth:`ActionValidator.validate`."""

    is_valid: bool
    confidence: float  # overall confidence [0, 1]
    issues: List[str] = field(default_factory=list)
    corrections: List[str] = field(default_factory=list)

    # convenience ──────────────────────────────────────────────────────────────

    def merge(self, other: "ValidationResult") -> "ValidationResult":
        """Combine two results (AND semantics: both must be valid).

        The merged confidence is the weaker of the two; issues and
        corrections are concatenated (``self`` first). Neither input is
        mutated.
        """
        return ValidationResult(
            is_valid=self.is_valid and other.is_valid,
            confidence=min(self.confidence, other.confidence),
            issues=self.issues + other.issues,
            corrections=self.corrections + other.corrections,
        )


# ──────────────────────────────────────────────────────────────────────────────
# Task-domain keyword sets
# ──────────────────────────────────────────────────────────────────────────────

# NOTE(review): these match whole words only, so inflected forms such as
# "slicing" or "screws" are not recognised and fall through to "unknown".
_COOKING_KEYWORDS = re.compile(
    r"\b(cut|slice|chop|dice|mince|peel|carve|trim|halve|quarter)\b", re.IGNORECASE
)
_HARDWARE_KEYWORDS = re.compile(
    r"\b(screw|fasten|assemble|tighten|install|attach|mount|drive)\b", re.IGNORECASE
)
_FORM_KEYWORDS = re.compile(
    r"\b(fill|write|sign|complete|enter|record|note|annotate)\b", re.IGNORECASE
)


# ──────────────────────────────────────────────────────────────────────────────
# Internal helpers
# ──────────────────────────────────────────────────────────────────────────────


def _find_by_label(
    detections: Optional[List[Detection]], *labels: str
) -> List[Detection]:
    """Return detections whose label matches any of *labels* (case-insensitive).

    ``None`` is treated like an empty detection list so that a CV stage
    reporting "nothing found" as ``None`` does not crash validation.
    """
    lowered = {lbl.lower() for lbl in labels}
    return [d for d in (detections or ()) if d.label.lower() in lowered]


def _dominant_hand(
    hand_results: Optional[List[HandResult]], preferred: str = "right"
) -> Optional[HandResult]:
    """
    Return the most confident hand labelled *preferred* (case-insensitive,
    default ``"right"``), falling back to whichever hand is most confident if
    no such hand is detected. Returns ``None`` when there are no hands at all
    (an empty list or ``None``).
    """
    hands = list(hand_results or ())
    wanted = preferred.lower()
    matching = [h for h in hands if h.hand_label.lower() == wanted]
    candidates = matching or hands
    if not candidates:
        return None
    return max(candidates, key=lambda h: h.confidence)


def _average_depth(
    depth_map: Optional[List[List[float]]], bbox: BoundingBox
) -> Optional[float]:
    """Sample the mean depth inside *bbox* from a 2-D depth map (H×W, metres).

    The box is clipped to the map before sampling. Returns ``None`` when no
    map is given, the map is empty, the box lies outside the map, or every
    sampled cell is missing (``None``) or non-finite (NaN / inf).
    """
    if depth_map is None:
        return None
    h = len(depth_map)
    w = len(depth_map[0]) if h else 0
    if h == 0 or w == 0:
        return None
    # A box entirely above / left of the frame has nothing to sample.
    if bbox.x_max < 0 or bbox.y_max < 0:
        return None
    # Clamp to the map: detectors routinely overshoot the frame slightly, and
    # a negative index would silently wrap around to the opposite edge.
    r0 = max(0, int(bbox.y_min * h))
    r1 = min(h, max(r0 + 1, int(bbox.y_max * h)))
    c0 = max(0, int(bbox.x_min * w))
    c1 = min(w, max(c0 + 1, int(bbox.x_max * w)))
    samples = [
        value
        for row in depth_map[r0:r1]
        for value in row[c0:c1]
        # Skip invalid readings so a single NaN does not poison the mean.
        if value is not None and math.isfinite(value)
    ]
    return sum(samples) / len(samples) if samples else None


# ──────────────────────────────────────────────────────────────────────────────
# Domain validators
# ──────────────────────────────────────────────────────────────────────────────


class _CookingValidator:
    """
    Validates cutting steps:
      1. A knife must be detected.
      2. The knife must be near a cutting board (proximity ≤ MAX_KNIFE_BOARD_DIST).
      3. The knife must NOT be near the user's torso/body
         (proximity ≥ MIN_KNIFE_BODY_DIST).
    """

    MAX_KNIFE_BOARD_DIST: float = 0.35  # normalised image units
    MIN_KNIFE_BODY_DIST: float = 0.30

    # Labels that map to "torso / body region"
    # NOTE(review): "hand" and "wrist" are included, so a detector that emits
    # hand boxes will flag a knife that is simply being held — confirm intent.
    BODY_LABELS = ("person", "torso", "body", "hand", "wrist")
    BOARD_LABELS = ("cutting board", "chopping board", "cutting_board", "chopping_board")
    KNIFE_LABELS = ("knife", "cleaver", "blade")

    def validate(
        self,
        detections: List[Detection],
        hand_results: List[HandResult],
        depth_map: Optional[List[List[float]]],
    ) -> ValidationResult:
        """Check knife/board/body geometry for a cutting step.

        *hand_results* and *depth_map* are accepted for a uniform validator
        signature but are not used by the cooking rules.
        """
        issues: List[str] = []
        corrections: List[str] = []

        knives = _find_by_label(detections, *self.KNIFE_LABELS)
        boards = _find_by_label(detections, *self.BOARD_LABELS)
        bodies = _find_by_label(detections, *self.BODY_LABELS)

        # ── knife presence ────────────────────────────────────────────────────
        if not knives:
            issues.append("No knife detected in the scene.")
            corrections.append(
                "Please ensure the knife is visible to the camera before continuing."
            )
            return ValidationResult(
                is_valid=False, confidence=0.9, issues=issues, corrections=corrections
            )

        knife = max(knives, key=lambda d: d.confidence)

        # ── cutting board presence ────────────────────────────────────────────
        if not boards:
            issues.append("No cutting board detected.")
            corrections.append(
                "Place the item you are cutting on a cutting board and position "
                "both the board and knife within the camera's view."
            )
            return ValidationResult(
                is_valid=False, confidence=0.85, issues=issues, corrections=corrections
            )

        board = min(boards, key=lambda b: knife.bbox.distance_to(b.bbox))
        knife_board_dist = knife.bbox.distance_to(board.bbox)

        # ── knife near board ──────────────────────────────────────────────────
        if knife_board_dist > self.MAX_KNIFE_BOARD_DIST:
            issues.append(
                f"Knife is too far from the cutting board "
                f"(distance={knife_board_dist:.2f}, threshold={self.MAX_KNIFE_BOARD_DIST})."
            )
            corrections.append(
                "Move the cutting action directly over the cutting board."
            )

        # ── knife not near body ───────────────────────────────────────────────
        if bodies:
            nearest_body = min(bodies, key=lambda b: knife.bbox.distance_to(b.bbox))
            knife_body_dist = knife.bbox.distance_to(nearest_body.bbox)
            if knife_body_dist < self.MIN_KNIFE_BODY_DIST:
                issues.append(
                    f"Knife is dangerously close to the body "
                    f"(distance={knife_body_dist:.2f}, minimum safe={self.MIN_KNIFE_BODY_DIST})."
                )
                corrections.append(
                    "Keep the knife pointed away from your body and fingers. "
                    "Use a stable grip on the food and curl your fingertips inward."
                )

        # A clean pass is only as trustworthy as both detections it relies on.
        confidence = knife.confidence * board.confidence if not issues else 0.3
        return ValidationResult(
            is_valid=not issues,
            confidence=min(1.0, confidence),
            issues=issues,
            corrections=corrections,
        )


class _HardwareValidator:
    """
    Validates assembly/screw-driving steps:
      1. A screwdriver must be detected.
      2. The screwdriver must be held roughly vertically
         (the bounding-box aspect ratio is tall, not wide).
      3. Depth-map check: the tip must be at a similar depth to the workpiece.
    """

    # If bbox_height / bbox_width > this ratio, orientation is "vertical"
    VERTICAL_RATIO_THRESHOLD: float = 1.3
    TOOL_LABELS = (
        "screwdriver",
        "drill",
        "wrench",
        "spanner",
        "power tool",
        "power_tool",
    )
    WORKPIECE_LABELS = ("pcb", "circuit board", "device", "component", "workpiece", "object")
    # Maximum depth delta (metres) between tool tip and workpiece
    MAX_DEPTH_DELTA: float = 0.20

    def validate(
        self,
        detections: List[Detection],
        hand_results: List[HandResult],
        depth_map: Optional[List[List[float]]],
    ) -> ValidationResult:
        """Check tool presence, orientation and (optionally) depth alignment.

        *hand_results* is accepted for a uniform validator signature but is
        not used. The depth rule runs only when *depth_map* is provided and
        both the tool and a workpiece yield a usable depth sample.
        """
        issues: List[str] = []
        corrections: List[str] = []

        tools = _find_by_label(detections, *self.TOOL_LABELS)

        # ── tool presence ─────────────────────────────────────────────────────
        if not tools:
            issues.append("No screwdriver or assembly tool detected.")
            corrections.append(
                "Hold the screwdriver so that it is clearly visible to the camera."
            )
            return ValidationResult(
                is_valid=False, confidence=0.90, issues=issues, corrections=corrections
            )

        tool = max(tools, key=lambda d: d.confidence)
        bbox = tool.bbox
        # Floor both extents so a degenerate box cannot divide by zero.
        height = max(1e-6, bbox.y_max - bbox.y_min)
        width = max(1e-6, bbox.x_max - bbox.x_min)
        aspect = height / width

        # ── orientation check ─────────────────────────────────────────────────
        if aspect < self.VERTICAL_RATIO_THRESHOLD:
            issues.append(
                f"Screwdriver appears to be held horizontally "
                f"(aspect ratio height/width={aspect:.2f}, "
                f"required ≥ {self.VERTICAL_RATIO_THRESHOLD})."
            )
            corrections.append(
                "Rotate the screwdriver so the shaft points straight down "
                "into the screw head. A vertical orientation ensures proper torque "
                "and avoids stripping the screw."
            )

        # ── depth alignment (optional, requires depth_map) ────────────────────
        if depth_map is not None:
            workpieces = _find_by_label(detections, *self.WORKPIECE_LABELS)
            # NOTE: this is the mean over the whole tool box, used as a proxy
            # for the tip depth.
            tool_depth = _average_depth(depth_map, bbox)
            if workpieces and tool_depth is not None:
                wp = min(workpieces, key=lambda w: tool.bbox.distance_to(w.bbox))
                wp_depth = _average_depth(depth_map, wp.bbox)
                if wp_depth is not None:
                    delta = abs(tool_depth - wp_depth)
                    if delta > self.MAX_DEPTH_DELTA:
                        issues.append(
                            f"Tool tip depth ({tool_depth:.2f} m) is misaligned with "
                            f"workpiece depth ({wp_depth:.2f} m) — Δ={delta:.2f} m."
                        )
                        corrections.append(
                            "Position the screwdriver tip directly above the screw "
                            "before applying pressure."
                        )

        confidence = tool.confidence if not issues else 0.35
        return ValidationResult(
            is_valid=not issues,
            confidence=min(1.0, confidence),
            issues=issues,
            corrections=corrections,
        )


class _FormFillingValidator:
    """
    Validates form-filling / signing steps:
      1. A pen must be detected in the dominant hand's region.
      2. The pen must be near a paper/document (not near a phone/tablet).

    Parameters
    ----------
    dominant_hand:
        Hand label (``"right"`` or ``"left"``, case-insensitive) that is
        treated as the writing hand. Defaults to ``"right"``.
    """

    PEN_LABELS = ("pen", "pencil", "marker", "stylus")
    PAPER_LABELS = ("paper", "document", "form", "notebook", "sheet")
    PHONE_LABELS = ("phone", "smartphone", "mobile", "tablet", "screen", "phone screen")

    MAX_PEN_HAND_DIST: float = 0.15  # normalised image units
    MAX_PEN_PAPER_DIST: float = 0.30
    MIN_PEN_PHONE_DIST: float = 0.25

    def __init__(self, dominant_hand: str = "right") -> None:
        self._dominant_hand = dominant_hand

    def validate(
        self,
        detections: List[Detection],
        hand_results: List[HandResult],
        depth_map: Optional[List[List[float]]],
    ) -> ValidationResult:
        """Check pen presence and its position relative to hand, paper, phone.

        The hand rule is skipped when no hand is detected at all.
        *depth_map* is accepted for a uniform validator signature but is not
        used.
        """
        issues: List[str] = []
        corrections: List[str] = []

        pens = _find_by_label(detections, *self.PEN_LABELS)
        papers = _find_by_label(detections, *self.PAPER_LABELS)
        phones = _find_by_label(detections, *self.PHONE_LABELS)

        # ── pen presence ──────────────────────────────────────────────────────
        if not pens:
            issues.append("No pen or writing instrument detected.")
            corrections.append(
                "Pick up a pen and hold it so it is visible to the camera."
            )
            return ValidationResult(
                is_valid=False, confidence=0.90, issues=issues, corrections=corrections
            )

        pen = max(pens, key=lambda d: d.confidence)

        # ── pen in dominant hand ──────────────────────────────────────────────
        dominant = _dominant_hand(hand_results, self._dominant_hand)
        if dominant is not None:
            pen_hand_dist = pen.bbox.distance_to(dominant.bbox)
            if pen_hand_dist > self.MAX_PEN_HAND_DIST:
                issues.append(
                    f"Pen does not appear to be in the dominant hand "
                    f"(distance to dominant hand bbox={pen_hand_dist:.2f})."
                )
                corrections.append(
                    "Hold the pen in your dominant (writing) hand."
                )

        # ── pen near paper ────────────────────────────────────────────────────
        if not papers:
            issues.append("No paper or document detected near the pen.")
            corrections.append(
                "Place the form or document flat on the desk within the camera's view "
                "before writing."
            )
        else:
            nearest_paper = min(papers, key=lambda p: pen.bbox.distance_to(p.bbox))
            pen_paper_dist = pen.bbox.distance_to(nearest_paper.bbox)
            if pen_paper_dist > self.MAX_PEN_PAPER_DIST:
                issues.append(
                    f"Pen is too far from the paper "
                    f"(distance={pen_paper_dist:.2f}, threshold={self.MAX_PEN_PAPER_DIST})."
                )
                corrections.append(
                    "Move the pen so that it is positioned over the form before writing."
                )

        # ── pen not near phone (digital device) ───────────────────────────────
        if phones:
            nearest_phone = min(phones, key=lambda ph: pen.bbox.distance_to(ph.bbox))
            pen_phone_dist = pen.bbox.distance_to(nearest_phone.bbox)
            if pen_phone_dist < self.MIN_PEN_PHONE_DIST:
                issues.append(
                    f"Pen is near a phone/screen rather than paper "
                    f"(distance={pen_phone_dist:.2f})."
                )
                corrections.append(
                    "Write on the physical paper form, not on a phone or tablet screen."
                )

        confidence = pen.confidence if not issues else 0.30
        return ValidationResult(
            is_valid=not issues,
            confidence=min(1.0, confidence),
            issues=issues,
            corrections=corrections,
        )


# ──────────────────────────────────────────────────────────────────────────────
# Public API
# ──────────────────────────────────────────────────────────────────────────────


class ActionValidator:
    """
    Entry point for physical-action validation.

    Usage::

        validator = ActionValidator()
        result = validator.validate(
            step="Chop the onions",
            detections=[...],      # List[Detection]
            hand_results=[...],    # List[HandResult]
            depth_map=None,        # Optional 2-D list (H x W), metres
        )
        if not result.is_valid:
            dispatch_guidance(result.corrections)

    Parameters
    ----------
    dominant_hand:
        Writing hand used by the form-filling rules (``"right"`` or
        ``"left"``, case-insensitive). Defaults to ``"right"``.
    """

    def __init__(self, dominant_hand: str = "right") -> None:
        self._cooking = _CookingValidator()
        self._hardware = _HardwareValidator()
        self._form = _FormFillingValidator(dominant_hand=dominant_hand)
        # Domain name (as returned by _classify_step) → validator.
        self._validators = {
            "cooking": self._cooking,
            "hardware": self._hardware,
            "form": self._form,
        }

    # ── public ────────────────────────────────────────────────────────────────

    def validate(
        self,
        step: str,
        detections: List[Detection],
        hand_results: List[HandResult],
        depth_map: Optional[List[List[float]]] = None,
    ) -> ValidationResult:
        """
        Validate *step* against the current CV observations.

        Parameters
        ----------
        step:
            Natural-language description of the current task step,
            e.g. ``"Chop the carrots"`` or ``"Drive the M3 screw"``.
        detections:
            Object detections from the current frame. ``None`` is treated
            as "nothing detected".
        hand_results:
            Hand-pose results from the current frame. ``None`` is treated
            as "no hands detected".
        depth_map:
            Optional H×W depth map (metres). When provided, depth-aware
            checks are enabled.

        Returns
        -------
        ValidationResult
            ``is_valid=True`` if the step was performed correctly.
            ``is_valid=False`` means at least one safety / correctness rule
            failed; ``corrections`` carries actionable guidance for the user.
            Steps that match no known domain pass optimistically with
            confidence 0.50.
        """
        validator = self._validators.get(self._classify_step(step))
        if validator is not None:
            return validator.validate(detections or [], hand_results or [], depth_map)

        # Unknown domain → optimistic pass, low confidence
        return ValidationResult(
            is_valid=True,
            confidence=0.50,
            issues=[],
            corrections=[],
        )

    # ── private ───────────────────────────────────────────────────────────────

    @staticmethod
    def _classify_step(step: str) -> str:
        """Map a step description to one of the known task domains.

        Domains are tried in a fixed order (cooking, hardware, form) and the
        first keyword hit wins. An empty or missing description is
        ``"unknown"``.
        """
        if not step:
            return "unknown"
        if _COOKING_KEYWORDS.search(step):
            return "cooking"
        if _HARDWARE_KEYWORDS.search(step):
            return "hardware"
        if _FORM_KEYWORDS.search(step):
            return "form"
        return "unknown"
+ +Run with: + pytest tests/test_action_validator.py -v +""" + +from __future__ import annotations + +import pytest + +from core.physical.action_validator import ( + ActionValidator, + BoundingBox, + Detection, + HandResult, + ValidationResult, + _CookingValidator, + _FormFillingValidator, + _HardwareValidator, +) + + +# ────────────────────────────────────────────────────────────────────────────── +# Helpers / fixtures +# ────────────────────────────────────────────────────────────────────────────── + + +def bbox(x_min=0.1, y_min=0.1, x_max=0.3, y_max=0.4) -> BoundingBox: + return BoundingBox(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max) + + +def det(label: str, x_min=0.1, y_min=0.1, x_max=0.3, y_max=0.4, confidence=0.95) -> Detection: + return Detection(label=label, confidence=confidence, bbox=bbox(x_min, y_min, x_max, y_max)) + + +def hand(label="Right", x_min=0.05, y_min=0.1, x_max=0.25, y_max=0.4, confidence=0.92) -> HandResult: + return HandResult( + hand_label=label, + confidence=confidence, + bbox=bbox(x_min, y_min, x_max, y_max), + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# BoundingBox unit tests +# ────────────────────────────────────────────────────────────────────────────── + + +class TestBoundingBox: + def test_centre(self): + b = BoundingBox(0.0, 0.0, 0.4, 0.6) + assert b.centre == pytest.approx((0.2, 0.3)) + + def test_area(self): + b = BoundingBox(0.0, 0.0, 0.5, 0.5) + assert b.area == pytest.approx(0.25) + + def test_distance_to_same_box_is_zero(self): + b = BoundingBox(0.1, 0.1, 0.3, 0.3) + assert b.distance_to(b) == pytest.approx(0.0) + + def test_distance_to_adjacent_box(self): + b1 = BoundingBox(0.0, 0.0, 0.2, 0.2) # centre (0.1, 0.1) + b2 = BoundingBox(0.6, 0.0, 0.8, 0.2) # centre (0.7, 0.1) + assert b1.distance_to(b2) == pytest.approx(0.6, abs=1e-6) + + def test_overlaps_identical_boxes(self): + b = BoundingBox(0.0, 0.0, 0.5, 0.5) + assert b.overlaps(b, threshold=0.0) + + def 
test_no_overlap(self): + b1 = BoundingBox(0.0, 0.0, 0.3, 0.3) + b2 = BoundingBox(0.7, 0.7, 1.0, 1.0) + assert not b1.overlaps(b2) + + +# ────────────────────────────────────────────────────────────────────────────── +# ActionValidator._classify_step +# ────────────────────────────────────────────────────────────────────────────── + + +class TestClassifyStep: + validator = ActionValidator() + + @pytest.mark.parametrize("step", [ + "Chop the onions", + "Slice the bread thinly", + "Dice the tomatoes", + "Mince the garlic", + "Peel the potatoes", + ]) + def test_cooking_steps(self, step): + assert self.validator._classify_step(step) == "cooking" + + @pytest.mark.parametrize("step", [ + "Screw in the bolt", + "Fasten the panel", + "Assemble the gearbox", + "Tighten the M4 screw", + "Install the heatsink", + ]) + def test_hardware_steps(self, step): + assert self.validator._classify_step(step) == "hardware" + + @pytest.mark.parametrize("step", [ + "Fill in the application form", + "Write your name on line 3", + "Sign at the bottom of the page", + "Complete the address field", + "Enter your date of birth", + ]) + def test_form_steps(self, step): + assert self.validator._classify_step(step) == "form" + + def test_unknown_step(self): + assert self.validator._classify_step("Walk to the kitchen") == "unknown" + + +# ────────────────────────────────────────────────────────────────────────────── +# Cooking validator +# ────────────────────────────────────────────────────────────────────────────── + + +class TestCookingValidator: + v = _CookingValidator() + step = "Chop the onions" + + # ── happy path ──────────────────────────────────────────────────────────── + + def test_valid_scene(self): + """Knife near board, far from body → valid.""" + detections = [ + det("knife", x_min=0.3, y_min=0.4, x_max=0.5, y_max=0.6), + det("cutting board", x_min=0.25, y_min=0.35, x_max=0.55, y_max=0.65), + det("person", x_min=0.0, y_min=0.0, x_max=0.15, y_max=0.25), + ] + result = 
self.v.validate(detections, [hand()], None) + assert result.is_valid + assert result.confidence > 0.5 + assert result.issues == [] + assert result.corrections == [] + + # ── knife absent ────────────────────────────────────────────────────────── + + def test_no_knife_detected(self): + detections = [det("cutting board")] + result = self.v.validate(detections, [], None) + assert not result.is_valid + assert any("knife" in i.lower() for i in result.issues) + assert len(result.corrections) >= 1 + + # ── cutting board absent ────────────────────────────────────────────────── + + def test_no_cutting_board(self): + detections = [det("knife")] + result = self.v.validate(detections, [], None) + assert not result.is_valid + assert any("cutting board" in i.lower() for i in result.issues) + + # ── knife too far from board ─────────────────────────────────────────────── + + def test_knife_far_from_board(self): + """Knife in top-left, board in bottom-right → distance ~0.85 > threshold.""" + detections = [ + det("knife", x_min=0.0, y_min=0.0, x_max=0.15, y_max=0.15), + det("cutting board", x_min=0.8, y_min=0.8, x_max=1.0, y_max=1.0), + ] + result = self.v.validate(detections, [], None) + assert not result.is_valid + assert any("far" in i.lower() or "cutting board" in i.lower() for i in result.issues) + assert any("cutting board" in c.lower() for c in result.corrections) + + # ── knife near body ─────────────────────────────────────────────────────── + + def test_knife_near_body(self): + """Knife overlapping person bbox → safety violation.""" + detections = [ + det("knife", x_min=0.3, y_min=0.3, x_max=0.5, y_max=0.6), + det("cutting board", x_min=0.29, y_min=0.29, x_max=0.55, y_max=0.65), + # person bbox nearly identical to knife → distance ≈ 0 + det("person", x_min=0.3, y_min=0.3, x_max=0.5, y_max=0.6), + ] + result = self.v.validate(detections, [], None) + assert not result.is_valid + assert any("body" in i.lower() or "dangerously" in i.lower() for i in result.issues) + assert 
any("body" in c.lower() or "fingers" in c.lower() for c in result.corrections) + + # ── multiple issues ─────────────────────────────────────────────────────── + + def test_knife_far_from_board_and_near_body(self): + """Both distance rules violated simultaneously → two issues.""" + detections = [ + det("knife", x_min=0.05, y_min=0.05, x_max=0.20, y_max=0.20), + det("cutting board", x_min=0.80, y_min=0.80, x_max=1.00, y_max=1.00), + det("person", x_min=0.05, y_min=0.05, x_max=0.20, y_max=0.20), + ] + result = self.v.validate(detections, [], None) + assert not result.is_valid + assert len(result.issues) >= 2 + + # ── label variations ───────────────────────────────────────────────────── + + def test_chopping_board_label(self): + """'chopping board' should be treated the same as 'cutting board'.""" + detections = [ + det("knife", x_min=0.3, y_min=0.4, x_max=0.5, y_max=0.6), + det("chopping board", x_min=0.25, y_min=0.35, x_max=0.55, y_max=0.65), + ] + result = self.v.validate(detections, [], None) + assert result.is_valid + + def test_cleaver_label(self): + """'cleaver' should be treated as a knife.""" + detections = [ + det("cleaver", x_min=0.3, y_min=0.4, x_max=0.5, y_max=0.6), + det("cutting board", x_min=0.25, y_min=0.35, x_max=0.55, y_max=0.65), + ] + result = self.v.validate(detections, [], None) + assert result.is_valid + + +# ────────────────────────────────────────────────────────────────────────────── +# Hardware / assembly validator +# ────────────────────────────────────────────────────────────────────────────── + + +class TestHardwareValidator: + v = _HardwareValidator() + + # ── happy path ──────────────────────────────────────────────────────────── + + def test_vertical_screwdriver_valid(self): + """Tall bounding box → vertical orientation → valid.""" + detections = [ + # height = 0.6, width = 0.1 → aspect = 6.0 ≥ threshold + det("screwdriver", x_min=0.4, y_min=0.1, x_max=0.5, y_max=0.7, confidence=0.95), + ] + result = self.v.validate(detections, 
[hand()], None) + assert result.is_valid + assert result.confidence > 0.5 + + # ── no tool ─────────────────────────────────────────────────────────────── + + def test_no_screwdriver_detected(self): + result = self.v.validate([], [], None) + assert not result.is_valid + assert any("screwdriver" in i.lower() or "tool" in i.lower() for i in result.issues) + + # ── horizontal (wrong) orientation ──────────────────────────────────────── + + def test_horizontal_screwdriver_invalid(self): + """Wide bounding box → horizontal orientation → invalid.""" + detections = [ + # height = 0.1, width = 0.6 → aspect ≈ 0.17 < threshold + det("screwdriver", x_min=0.1, y_min=0.4, x_max=0.7, y_max=0.5, confidence=0.90), + ] + result = self.v.validate(detections, [], None) + assert not result.is_valid + assert any("horizontal" in i.lower() for i in result.issues) + assert any("vertical" in c.lower() for c in result.corrections) + + # ── borderline aspect ratio ─────────────────────────────────────────────── + + def test_borderline_orientation(self): + """Aspect ratio exactly at threshold should pass.""" + threshold = _HardwareValidator.VERTICAL_RATIO_THRESHOLD + # height / width = threshold → should be valid + width = 0.2 + height = width * threshold + detections = [ + det("screwdriver", + x_min=0.3, y_min=0.2, + x_max=0.3 + width, y_max=0.2 + height, + confidence=0.88), + ] + result = self.v.validate(detections, [], None) + # Aspect exactly equals threshold, so is_valid depends on strict vs ≥ + # Our implementation uses "<", so equal should be valid + assert result.is_valid + + # ── depth-map alignment ─────────────────────────────────────────────────── + + def test_depth_alignment_correct(self): + """Tool and workpiece at similar depth → no depth issue.""" + detections = [ + det("screwdriver", x_min=0.4, y_min=0.1, x_max=0.5, y_max=0.7), + det("pcb", x_min=0.35, y_min=0.60, x_max=0.55, y_max=0.80), + ] + # Uniform depth map at 0.5 m + depth_map = [[0.5] * 100 for _ in range(100)] + 
result = self.v.validate(detections, [], depth_map) + assert result.is_valid + + def test_depth_alignment_misaligned(self): + """Tool much closer than workpiece → depth violation.""" + detections = [ + det("screwdriver", x_min=0.4, y_min=0.1, x_max=0.5, y_max=0.7), + det("pcb", x_min=0.35, y_min=0.60, x_max=0.55, y_max=0.80), + ] + # Build a depth map where tool region = 0.20 m, workpiece region = 0.90 m + depth_map = [[0.20] * 100 for _ in range(100)] + for r in range(60, 80): + for c in range(35, 55): + depth_map[r][c] = 0.90 + result = self.v.validate(detections, [], depth_map) + assert not result.is_valid + assert any("depth" in i.lower() or "misaligned" in i.lower() for i in result.issues) + + # ── alternative tool labels ─────────────────────────────────────────────── + + def test_drill_label_accepted(self): + detections = [det("drill", x_min=0.4, y_min=0.1, x_max=0.5, y_max=0.8)] + result = self.v.validate(detections, [], None) + assert result.is_valid + + +# ────────────────────────────────────────────────────────────────────────────── +# Form-filling validator +# ────────────────────────────────────────────────────────────────────────────── + + +class TestFormFillingValidator: + v = _FormFillingValidator() + + # ── happy path ──────────────────────────────────────────────────────────── + + def test_pen_in_hand_on_paper(self): + """Pen near dominant hand, near paper, far from phone → valid.""" + right_hand = hand("Right", x_min=0.30, y_min=0.40, x_max=0.50, y_max=0.60) + detections = [ + det("pen", x_min=0.32, y_min=0.42, x_max=0.48, y_max=0.58), + det("paper", x_min=0.20, y_min=0.30, x_max=0.70, y_max=0.80), + ] + result = self.v.validate(detections, [right_hand], None) + assert result.is_valid + assert result.issues == [] + + # ── no pen ──────────────────────────────────────────────────────────────── + + def test_no_pen_detected(self): + detections = [det("paper")] + result = self.v.validate(detections, [], None) + assert not result.is_valid + assert 
any("pen" in i.lower() or "writing" in i.lower() for i in result.issues) + + # ── no paper ───────────────────────────────────────────────────────────── + + def test_no_paper_detected(self): + detections = [det("pen")] + result = self.v.validate(detections, [hand()], None) + assert not result.is_valid + assert any("paper" in i.lower() or "document" in i.lower() for i in result.issues) + + # ── pen far from paper ──────────────────────────────────────────────────── + + def test_pen_far_from_paper(self): + """Pen top-left, paper bottom-right → exceeds distance threshold.""" + detections = [ + det("pen", x_min=0.0, y_min=0.0, x_max=0.1, y_max=0.1), + det("paper", x_min=0.8, y_min=0.8, x_max=1.0, y_max=1.0), + ] + result = self.v.validate(detections, [], None) + assert not result.is_valid + assert any("far" in i.lower() for i in result.issues) + + # ── pen near phone (wrong surface) ─────────────────────────────────────── + + def test_pen_near_phone_invalid(self): + """Writing on a phone screen instead of paper → flagged.""" + detections = [ + det("pen", x_min=0.30, y_min=0.40, x_max=0.50, y_max=0.60), + det("paper", x_min=0.20, y_min=0.30, x_max=0.70, y_max=0.80), + det("phone", x_min=0.32, y_min=0.42, x_max=0.48, y_max=0.58), + ] + result = self.v.validate(detections, [], None) + assert not result.is_valid + assert any("phone" in i.lower() or "screen" in i.lower() for i in result.issues) + assert any("paper" in c.lower() or "form" in c.lower() for c in result.corrections) + + # ── wrong hand ──────────────────────────────────────────────────────────── + + def test_pen_not_in_dominant_hand(self): + """Pen is far from the right hand → issue raised.""" + right_hand = hand("Right", x_min=0.70, y_min=0.70, x_max=0.90, y_max=0.90) + detections = [ + det("pen", x_min=0.05, y_min=0.05, x_max=0.20, y_max=0.20), + det("paper", x_min=0.02, y_min=0.02, x_max=0.25, y_max=0.25), + ] + result = self.v.validate(detections, [right_hand], None) + assert not result.is_valid + assert 
any("dominant" in i.lower() or "hand" in i.lower() for i in result.issues) + + # ── label variants ──────────────────────────────────────────────────────── + + def test_pencil_label_accepted(self): + right_hand = hand("Right", x_min=0.30, y_min=0.40, x_max=0.50, y_max=0.60) + detections = [ + det("pencil", x_min=0.32, y_min=0.42, x_max=0.48, y_max=0.58), + det("document", x_min=0.20, y_min=0.30, x_max=0.70, y_max=0.80), + ] + result = self.v.validate(detections, [right_hand], None) + assert result.is_valid + + def test_marker_label_accepted(self): + right_hand = hand("Right", x_min=0.30, y_min=0.40, x_max=0.50, y_max=0.60) + detections = [ + det("marker", x_min=0.32, y_min=0.42, x_max=0.48, y_max=0.58), + det("notebook", x_min=0.20, y_min=0.30, x_max=0.70, y_max=0.80), + ] + result = self.v.validate(detections, [right_hand], None) + assert result.is_valid + + +# ────────────────────────────────────────────────────────────────────────────── +# ActionValidator (top-level dispatch) +# ────────────────────────────────────────────────────────────────────────────── + + +class TestActionValidator: + validator = ActionValidator() + + def test_cooking_step_dispatched(self): + detections = [ + det("knife", x_min=0.3, y_min=0.4, x_max=0.5, y_max=0.6), + det("cutting board", x_min=0.25, y_min=0.35, x_max=0.55, y_max=0.65), + ] + result = self.validator.validate("Slice the tomatoes", detections, [hand()], None) + assert isinstance(result, ValidationResult) + assert result.is_valid + + def test_hardware_step_dispatched(self): + detections = [det("screwdriver", x_min=0.4, y_min=0.1, x_max=0.5, y_max=0.7)] + result = self.validator.validate("Drive the bolt in", detections, [], None) + assert isinstance(result, ValidationResult) + assert result.is_valid + + def test_form_step_dispatched(self): + right_hand = hand("Right", x_min=0.30, y_min=0.40, x_max=0.50, y_max=0.60) + detections = [ + det("pen", x_min=0.32, y_min=0.42, x_max=0.48, y_max=0.58), + det("paper", x_min=0.20, 
y_min=0.30, x_max=0.70, y_max=0.80), + ] + result = self.validator.validate("Fill in your name", detections, [right_hand], None) + assert isinstance(result, ValidationResult) + assert result.is_valid + + def test_unknown_step_returns_optimistic(self): + result = self.validator.validate("Walk to the fridge", [], [], None) + assert result.is_valid + assert result.confidence == pytest.approx(0.50) + assert result.issues == [] + + def test_invalid_cooking_step_includes_corrections(self): + # Only a knife, no board → invalid + result = self.validator.validate("Chop the onion", [det("knife")], [], None) + assert not result.is_valid + assert len(result.corrections) >= 1 + + def test_invalid_hardware_step_includes_corrections(self): + # Horizontal screwdriver + detections = [det("screwdriver", x_min=0.1, y_min=0.4, x_max=0.7, y_max=0.5)] + result = self.validator.validate("Tighten the screw", detections, [], None) + assert not result.is_valid + assert len(result.corrections) >= 1 + + def test_invalid_form_step_includes_corrections(self): + result = self.validator.validate("Sign the form", [], [], None) + assert not result.is_valid + assert len(result.corrections) >= 1 + + # ── ValidationResult.merge ──────────────────────────────────────────────── + + def test_validation_result_merge_both_valid(self): + r1 = ValidationResult(is_valid=True, confidence=0.9) + r2 = ValidationResult(is_valid=True, confidence=0.8) + merged = r1.merge(r2) + assert merged.is_valid + assert merged.confidence == pytest.approx(0.8) + + def test_validation_result_merge_one_invalid(self): + r1 = ValidationResult(is_valid=True, confidence=0.9, issues=[], corrections=[]) + r2 = ValidationResult(is_valid=False, confidence=0.4, + issues=["bad"], corrections=["fix it"]) + merged = r1.merge(r2) + assert not merged.is_valid + assert merged.confidence == pytest.approx(0.4) + assert "bad" in merged.issues + assert "fix it" in merged.corrections + + def test_validation_result_merge_issues_concatenated(self): + 
r1 = ValidationResult(is_valid=False, confidence=0.5, + issues=["issue A"], corrections=["fix A"]) + r2 = ValidationResult(is_valid=False, confidence=0.3, + issues=["issue B"], corrections=["fix B"]) + merged = r1.merge(r2) + assert merged.issues == ["issue A", "issue B"] + assert merged.corrections == ["fix A", "fix B"]