diff --git a/README.md b/README.md index d92e813..2ba2b03 100644 --- a/README.md +++ b/README.md @@ -24,20 +24,94 @@ pip install aimx - `aimx version` - `aimx doctor` - `aimx query` +- `aimx trace` These commands explain how `aimx` works, show the `aimx` version, and report -whether native Aim is available for passthrough. `aimx query` adds a read-only -CLI for querying metric and image data from a local Aim repository. +whether native Aim is available for passthrough. -Query usage: +`--repo` is optional for owned `query` and `trace` commands and defaults to the +current directory `.`. When provided, it accepts either the repository root, +such as `data`, or the metadata directory itself, such as `data/.aim`. + +### `aimx query` — discover and summarise metrics + +Queries an Aim repository and shows a grouped table with per-metric statistics +(step count, last value, min/max with step). ```bash +# If your current working directory is the Aim repo root, --repo can be omitted +aimx query metrics "metric.name == 'loss'" + +# Rich table (default, colored in terminal) aimx query metrics "metric.name == 'loss'" --repo data -aimx query images "images" --repo data --json + +# Short run hashes are transparently expanded to full hashes +aimx query metrics "run.hash=='eca37394' and metric.name=='loss'" --repo data + +# Tab-separated plain text, suitable for awk/grep +aimx query metrics "metric.name == 'loss'" --repo data --oneline + +# Structured JSON (nested by run) +aimx query metrics "metric.name == 'loss'" --repo data --json + +# Step range filter — statistics recomputed within the window +aimx query metrics "metric.name == 'loss'" --repo data --steps 100:500 +aimx query metrics "metric.name == 'loss'" --repo data --steps :50 # first 50 steps +aimx query metrics "metric.name == 'loss'" --repo data --steps 100: # from step 100 onwards + +# Combine short hash + step range +aimx query metrics "run.hash=='eca37394' and metric.name=='loss'" --repo data --steps 100:300 + +# Images +aimx query images "images" --repo data +``` + +Output modes: `--json` (nested runs→metrics), `--oneline` / `--plain` (tab-separated), +default (rich table). Additional flags: `--steps start:end`, `--no-color`, `--verbose`. + +### `aimx trace` — plot or export a metric time series + +Fetches the full value sequence for one or more metrics and renders a curve, +table, or structured export. Multiple matching runs are overlaid on the same plot. + +```bash +# If your current working directory is the Aim repo root, --repo can be omitted +aimx trace "metric.name=='loss'" + +# Plot loss curve for a specific run — short hash transparently expanded +aimx trace "run.hash=='eca37394' and metric.name=='loss'" --repo data + +# Compare train vs val loss across all runs +aimx trace "metric.name=='loss'" --repo data + +# Step-by-step table +aimx trace "metric.name=='loss'" --repo data --table + +# CSV export +aimx trace "metric.name=='loss'" --repo data --csv > loss.csv + +# JSON with full value arrays +aimx trace "metric.name=='loss'" --repo data --json + +# Step range filter (hard constraint, applied before sampling) +aimx trace "metric.name=='loss'" --repo data --steps 100:500 +aimx trace "metric.name=='loss'" --repo data --steps :50 # first 50 steps +aimx trace "metric.name=='loss'" --repo data --steps 100: # step 100 onwards + +# Combine step filter + JSON +aimx trace "run.hash=='eca37394' and metric.name=='loss'" --repo data --steps 1:200 --json + +# Limit to first 50 points per series (density subsampling, applied after --steps) +aimx trace "metric.name=='loss'" --repo data --head 50 + +# Sample every 10th point +aimx trace "metric.name=='loss'" --repo data --every 10 ``` -`--repo` accepts either the repository root, such as `data`, or the metadata -directory itself, such as `data/.aim`. +Output modes: default (plotext chart), `--table`, `--csv`, `--json`. +Step filtering: `--steps start:end` (inclusive, open-ended sides allowed). +Sampling: `--head N`, `--tail N`, `--every K`. +Display: `--width W`, `--height H`, `--no-color`. ## What aimx delegates diff --git a/pyproject.toml b/pyproject.toml index 5a9616d..fc99a4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,11 @@ version = "0.2.0" description = "A safe CLI-first companion for native Aim" readme = "README.md" requires-python = ">=3.10,<3.13" -dependencies = [] +dependencies = [ + "numpy>=1.24", + "plotext>=5.3", + "rich>=13.7", +] [project.scripts] aimx = "aimx.__main__:main" diff --git a/src/aimx/aim_bridge/__init__.py b/src/aimx/aim_bridge/__init__.py new file mode 100644 index 0000000..9d48db4 --- /dev/null +++ b/src/aimx/aim_bridge/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/src/aimx/aim_bridge/hash_resolver.py b/src/aimx/aim_bridge/hash_resolver.py new file mode 100644 index 0000000..72d0807 --- /dev/null +++ b/src/aimx/aim_bridge/hash_resolver.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import re +from pathlib import Path + +try: + from aim import Repo +except ModuleNotFoundError: # aim not installed; errors surface at call time + Repo = None # type: ignore[assignment,misc] + +# Matches: run.hash == 'x' run.hash=='x' run.hash == "x" run.hash != 'x' +# Capture groups: (1) operator token (2) quote char (3) hash literal +_HASH_LITERAL_RE = re.compile( + r"""(run\.hash\s*(?:==|!=)\s*)(['"])([0-9a-fA-F]+)\2""" +) + +_FULL_HASH_LEN = 32 + + +def resolve_hash_prefixes(expression: str, repo_path: Path) -> str: + """Rewrite short run.hash literals in *expression* to full hashes. + + Rules: + - A literal shorter than ``_FULL_HASH_LEN`` hex chars is treated as a prefix. + - A literal of exactly ``_FULL_HASH_LEN`` chars passes through unchanged. + - Matching is case-insensitive (input normalised to lower-case). + - Ambiguous prefix → ``ValueError`` listing candidate previews. + - No matching run → ``ValueError``. + - No ``run.hash`` literal in *expression* → expression returned as-is + without querying the repository. + """ + if not _HASH_LITERAL_RE.search(expression): + return expression + + if Repo is None: + raise RuntimeError( + "`aimx` requires the Python `aim` package in the current environment." + ) + + repo = Repo(str(repo_path)) + all_hashes: list[str] = repo.list_all_runs() + + def _replace(m: re.Match[str]) -> str: + operator_token = m.group(1) + quote = m.group(2) + value = m.group(3).lower() + + if len(value) >= _FULL_HASH_LEN: + return m.group(0) + + candidates = [h for h in all_hashes if h.startswith(value)] + + if not candidates: + raise ValueError( + f"Short hash '{m.group(3)}' did not match any run in the repository." + ) + if len(candidates) > 1: + preview = ", ".join(c[:12] for c in candidates[:5]) + suffix = f" (+{len(candidates) - 5} more)" if len(candidates) > 5 else "" + raise ValueError( + f"Short hash '{m.group(3)}' is ambiguous — matches {len(candidates)} runs: " + f"{preview}{suffix}. Provide more characters." + ) + + return f"{operator_token}{quote}{candidates[0]}{quote}" + + return _HASH_LITERAL_RE.sub(_replace, expression) diff --git a/src/aimx/aim_bridge/metric_stats.py b/src/aimx/aim_bridge/metric_stats.py new file mode 100644 index 0000000..e052797 --- /dev/null +++ b/src/aimx/aim_bridge/metric_stats.py @@ -0,0 +1,270 @@ +from __future__ import annotations + +import contextlib +import datetime as dt +import io +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np + + +@dataclass(frozen=True) +class RunMeta: + hash: str + experiment: str | None + name: str | None + creation_time: float | None + + +@dataclass +class MetricSeries: + run: RunMeta + name: str + context: dict[str, Any] + values: np.ndarray + steps: np.ndarray + epochs: np.ndarray | None + + @property + def count(self) -> int: + return len(self.values) + + @property + def last(self) -> tuple[float, int]: + if len(self.values) == 0: + return (float("nan"), -1) + idx = len(self.values) - 1 + return (float(self.values[idx]), int(self.steps[idx])) + + @property + def min(self) -> tuple[float, int]: + if len(self.values) == 0: + return (float("nan"), -1) + idx = int(np.argmin(self.values)) + return (float(self.values[idx]), int(self.steps[idx])) + + @property + def max(self) -> tuple[float, int]: + if len(self.values) == 0: + return (float("nan"), -1) + idx = int(np.argmax(self.values)) + return (float(self.values[idx]), int(self.steps[idx])) + + +def _extract_run_meta(run: Any) -> RunMeta: + creation_time = getattr(run, "creation_time", None) + if creation_time is None: + created_at = getattr(run, "created_at", None) + if isinstance(created_at, dt.datetime): + if created_at.tzinfo is None: + creation_time = created_at.replace(tzinfo=dt.timezone.utc).timestamp() + else: + creation_time = created_at.timestamp() + + return RunMeta( + hash=run.hash, + experiment=getattr(run, "experiment", None), + name=getattr(run, "name", None), + creation_time=float(creation_time) if creation_time is not None else None, + ) + + +def _extract_values(metric: Any) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]: + """Extract aligned values, steps, and epochs from an Aim metric. + + Aim exposes the canonical series order and metadata via + ``metric.data.items_list()``. This preserves the distinction between + user-provided steps and tracked epochs, which may differ. + """ + try: + steps, (values, epochs, _timestamps) = metric.data.items_list() + except ValueError: + return np.array([], dtype=float), np.array([], dtype=int), None + + return ( + np.array(values, dtype=float), + np.array(steps, dtype=int), + np.array(epochs, dtype=float), + ) + + +def collect_metric_series(expression: str, repo_path: Path) -> list[MetricSeries]: + """Run the Aim query expression and return a flat list of MetricSeries. + + Short run.hash literals in *expression* are transparently expanded to full + hashes before being forwarded to Aim. Aim's own progress output is silenced + via stderr redirection. + """ + from aimx.aim_bridge.hash_resolver import resolve_hash_prefixes + + expression = resolve_hash_prefixes(expression, repo_path) + + try: + from aim import Repo + from aim.sdk.types import QueryReportMode + except ModuleNotFoundError as error: + raise RuntimeError( + "`aimx` requires the Python `aim` package in the current environment." + ) from error + + repo = Repo(str(repo_path)) + results: list[MetricSeries] = [] + + stderr_buf = io.StringIO() + with contextlib.redirect_stderr(stderr_buf): + query_result = repo.query_metrics( + expression, report_mode=QueryReportMode.DISABLED + ) + for run_collection in query_result.iter_runs(): + for metric in run_collection: + run_meta = _extract_run_meta(metric.run) + values, steps, epochs = _extract_values(metric) + results.append( + MetricSeries( + run=run_meta, + name=metric.name, + context=metric.context.to_dict(), + values=values, + steps=steps, + epochs=epochs, + ) + ) + + return results + + +def collect_image_series(expression: str, repo_path: Path) -> list[dict[str, Any]]: + """Run an image query and return a flat list of image record dicts. + + Short run.hash literals in *expression* are transparently expanded before + being forwarded to Aim. + """ + from aimx.aim_bridge.hash_resolver import resolve_hash_prefixes + + expression = resolve_hash_prefixes(expression, repo_path) + + try: + from aim import Repo + from aim.sdk.types import QueryReportMode + except ModuleNotFoundError as error: + raise RuntimeError( + "`aimx` requires the Python `aim` package in the current environment." + ) from error + + repo = Repo(str(repo_path)) + rows: list[dict[str, Any]] = [] + + stderr_buf = io.StringIO() + with contextlib.redirect_stderr(stderr_buf): + query_result = repo.query_images( + expression, report_mode=QueryReportMode.DISABLED + ) + for image in query_result.iter(): + run_meta = _extract_run_meta(image.run) + rows.append( + { + "run": run_meta, + "name": image.name, + "context": image.context.to_dict(), + } + ) + + return rows + + +def subsample(series: MetricSeries, *, head: int | None, tail: int | None, every: int | None) -> MetricSeries: + """Return a new MetricSeries with points filtered by head/tail/every.""" + n = len(series.values) + if n == 0: + return series + + indices = np.arange(n) + if head is not None: + indices = indices[:head] + if tail is not None: + indices = indices[-tail:] + if every is not None and every > 1: + indices = indices[::every] + + epochs_slice = series.epochs[indices] if series.epochs is not None else None + return MetricSeries( + run=series.run, + name=series.name, + context=series.context, + values=series.values[indices], + steps=series.steps[indices], + epochs=epochs_slice, + ) + + +def parse_step_slice(s: str) -> tuple[int | None, int | None]: + """Parse a ``start:end`` slice string into inclusive integer bounds. + + - ``"100:500"`` → ``(100, 500)`` + - ``"100:"`` → ``(100, None)`` + - ``":500"`` → ``(None, 500)`` + - ``":"`` → ``ValueError`` + - No colon → ``ValueError`` + """ + if ":" not in s: + raise ValueError( + f"--steps requires 'start:end' slice syntax (e.g. '100:500', ':500', '100:'), got: {s!r}" + ) + left, right = s.split(":", 1) + start: int | None = None + end: int | None = None + if left.strip(): + try: + start = int(left.strip()) + except ValueError: + raise ValueError(f"--steps: left bound is not an integer: {left!r}") + if right.strip(): + try: + end = int(right.strip()) + except ValueError: + raise ValueError(f"--steps: right bound is not an integer: {right!r}") + if start is None and end is None: + raise ValueError("--steps cannot be an open slice ':'; provide at least one bound.") + return start, end + + +def filter_by_step_range( + series: MetricSeries, + start: int | None, + end: int | None, +) -> MetricSeries: + """Return a new ``MetricSeries`` keeping only points where ``start <= step <= end``. + + Open-ended bounds (``None``) mean no constraint on that side. + """ + mask = np.ones(len(series.steps), dtype=bool) + if start is not None: + mask &= series.steps >= start + if end is not None: + mask &= series.steps <= end + epochs_slice = series.epochs[mask] if series.epochs is not None else None + return MetricSeries( + run=series.run, + name=series.name, + context=series.context, + values=series.values[mask], + steps=series.steps[mask], + epochs=epochs_slice, + ) + + +def group_by_run( + series_list: list[MetricSeries], +) -> list[tuple[RunMeta, list[MetricSeries]]]: + """Group a flat list of MetricSeries by run hash, preserving insertion order.""" + order: list[str] = [] + groups: dict[str, tuple[RunMeta, list[MetricSeries]]] = {} + for series in series_list: + h = series.run.hash + if h not in groups: + order.append(h) + groups[h] = (series.run, []) + groups[h][1].append(series) + return [groups[h] for h in order] diff --git a/src/aimx/cli.py b/src/aimx/cli.py index 6be2854..c479d99 100644 --- a/src/aimx/cli.py +++ b/src/aimx/cli.py @@ -5,6 +5,7 @@ from aimx.commands.doctor import render_doctor from aimx.commands.help import render_help from aimx.commands.query import run_query_command +from aimx.commands.trace import run_trace_command from aimx.commands.version import render_version from aimx.native_aim.locator import resolve_native_aim from aimx.native_aim.passthrough import run_passthrough @@ -33,6 +34,13 @@ def run_cli(args: list[str]) -> int: if result.error_message: sys.stderr.write(f"{result.error_message}\n") return result.exit_status + if command == "trace": + result = run_trace_command(route.owned_args or []) + if result.output: + sys.stdout.write(f"{result.output}\n") + if result.error_message: + sys.stderr.write(f"{result.error_message}\n") + return result.exit_status raise ValueError(f"Unsupported owned command: {command}") result = run_passthrough(route.delegated_args or [], resolution) diff --git a/src/aimx/commands/help.py b/src/aimx/commands/help.py index ca71749..6bac2ff 100644 --- a/src/aimx/commands/help.py +++ b/src/aimx/commands/help.py @@ -11,8 +11,22 @@ def render_help() -> str: " version Show the aimx version and detected native Aim version", " doctor Show native Aim availability and passthrough readiness", " query Query metrics or images from a local Aim repository", - " Usage: aimx query --repo [--json]", - " Repo paths may point at either the repo root or its .aim directory", + " Usage: aimx query [--repo ]", + " Options: --json --oneline --no-color --verbose", + " --steps start:end (e.g. --steps 100:500, :50, 100:)", + " Repo defaults to the current directory; paths may point at either", + " the repo root or its .aim directory", + " Short run hashes in the expression are transparently expanded.", + " Example: aimx query metrics \"run.hash=='eca37394'\" --repo data", + " trace Plot a metric's time-series from a local Aim repository", + " Usage: aimx trace [--repo ]", + " Options: --table --csv --json", + " --steps start:end (e.g. --steps 100:500, :50, 100:)", + " --head N --tail N --every K", + " --width W --height H --no-color", + " Repo defaults to the current directory.", + " Short run hashes in the expression are transparently expanded.", + " Example: aimx trace \"metric.name=='loss'\" --repo data --steps 100:500", "", "All other commands are delegated to native `aim`.", ] diff --git a/src/aimx/commands/query.py b/src/aimx/commands/query.py index 223c8f3..226f53a 100644 --- a/src/aimx/commands/query.py +++ b/src/aimx/commands/query.py @@ -1,6 +1,6 @@ from __future__ import annotations -import json +import sys from dataclasses import dataclass from pathlib import Path from typing import Any @@ -14,6 +14,10 @@ class QueryInvocation: expression: str repo_path: Path output_json: bool = False + plain: bool = False + no_color: bool = False + verbose: bool = False + step_slice: str | None = None def __post_init__(self) -> None: if self.target not in SUPPORTED_TARGETS: @@ -31,18 +35,6 @@ class QueryCommandResult: error_message: str | None = None -def load_aim_query_support(): - try: - from aim import Repo - from aim.sdk.types import QueryReportMode - except ModuleNotFoundError as error: - raise RuntimeError( - "`aimx query` requires the Python `aim` package in the current environment." - ) from error - - return Repo, QueryReportMode - - def normalize_repo_path(path: Path) -> Path: if not path.exists(): raise ValueError(f"Repository path does not exist: {path}") @@ -52,9 +44,10 @@ def normalize_repo_path(path: Path) -> Path: def parse_query_invocation(args: list[str]) -> QueryInvocation: - if len(args) < 4: + if len(args) < 2: raise ValueError( - "Usage: aimx query --repo [--json]" + "Usage: aimx query [--repo ] " + "[--json] [--oneline] [--no-color] [--verbose] [--steps start:end]" ) target = args[0] @@ -62,30 +55,49 @@ def parse_query_invocation(args: list[str]) -> QueryInvocation: rest = args[2:] output_json = False - repo_value: str | None = None + plain = False + no_color = False + verbose = False + step_slice: str | None = None + repo_value = "." + index = 0 while index < len(rest): token = rest[index] if token == "--json": output_json = True index += 1 - continue - if token == "--repo": + elif token in ("--oneline", "--plain"): + plain = True + index += 1 + elif token == "--no-color": + no_color = True + index += 1 + elif token == "--verbose": + verbose = True + index += 1 + elif token == "--steps": + if index + 1 >= len(rest): + raise ValueError("Missing value for --steps.") + step_slice = rest[index + 1] + index += 2 + elif token == "--repo": if index + 1 >= len(rest): raise ValueError("Missing value for --repo.") repo_value = rest[index + 1] index += 2 - continue - raise ValueError(f"Unsupported query option: {token}") - - if repo_value is None: - raise ValueError("Missing required --repo option.") + else: + raise ValueError(f"Unsupported query option: {token}") return QueryInvocation( target=target, expression=expression, repo_path=Path(repo_value), output_json=output_json, + plain=plain, + no_color=no_color, + verbose=verbose, + step_slice=step_slice, ) @@ -93,9 +105,23 @@ def run_query_command(args: list[str]) -> QueryCommandResult: try: invocation = parse_query_invocation(args) normalized_repo_path = normalize_repo_path(invocation.repo_path) - rows = collect_query_rows(invocation, normalized_repo_path) except ValueError as error: return QueryCommandResult(exit_status=2, error_message=str(error)) + + is_tty = sys.stdout.isatty() + effective_no_color = invocation.no_color or not is_tty + + header_info: dict[str, Any] = { + "target": invocation.target, + "repo": str(normalized_repo_path), + "expression": invocation.expression, + "verbose": invocation.verbose, + } + + try: + if invocation.target == "metrics": + return _run_metrics_query(invocation, normalized_repo_path, header_info, effective_no_color) + return _run_images_query(invocation, normalized_repo_path, header_info, effective_no_color) except RuntimeError as error: return QueryCommandResult(exit_status=2, error_message=str(error)) except Exception as error: @@ -103,78 +129,63 @@ def run_query_command(args: list[str]) -> QueryCommandResult: exit_status=2, error_message=f"Failed to evaluate query: {error}" ) - payload = { - "target": invocation.target, - "expression": invocation.expression, - "repo_path": str(normalized_repo_path), - "count": len(rows), - "rows": rows, - } - if invocation.output_json: - return QueryCommandResult(exit_status=0, output=json.dumps(payload)) - return QueryCommandResult(exit_status=0, output=render_text_output(payload)) +def _run_metrics_query( + invocation: QueryInvocation, + repo_path: Path, + header_info: dict[str, Any], + no_color: bool, +) -> QueryCommandResult: + from aimx.aim_bridge.metric_stats import ( + collect_metric_series, + filter_by_step_range, + group_by_run, + parse_step_slice, + ) + from aimx.rendering.query_views import ( + render_json, + render_oneline, + render_rich_table, + ) -def collect_query_rows(invocation: QueryInvocation, repo_path: Path) -> list[dict[str, Any]]: - Repo, QueryReportMode = load_aim_query_support() - repo = Repo(str(repo_path)) + series_list = collect_metric_series(invocation.expression, repo_path) - if invocation.target == "metrics": - rows: list[dict[str, Any]] = [] - results = repo.query_metrics( - invocation.expression, report_mode=QueryReportMode.DISABLED - ) - for run_collection in results.iter_runs(): - for metric in run_collection: - rows.append( - build_row( - run_id=metric.run.hash, - target="metrics", - name=metric.name, - context=metric.context.to_dict(), - ) - ) - return rows - - rows = [] - results = repo.query_images(invocation.expression, report_mode=QueryReportMode.DISABLED) - for image in results.iter(): - rows.append( - build_row( - run_id=image.run.hash, - target="images", - name=image.name, - context=image.context.to_dict(), - ) - ) - return rows - - -def build_row( - *, run_id: str, target: str, name: str, context: dict[str, Any] -) -> dict[str, Any]: - summary = f"run {run_id} {target[:-1] if target.endswith('s') else target} {name}" - return { - "run_id": run_id, - "target": target, - "name": name, - "context": context, - "summary": summary, - } + if invocation.step_slice is not None: + step_start, step_end = parse_step_slice(invocation.step_slice) + series_list = [filter_by_step_range(s, step_start, step_end) for s in series_list] + groups = group_by_run(series_list) -def render_text_output(payload: dict[str, Any]) -> str: - lines = [ - f"target: {payload['target']}", - f"repo: {payload['repo_path']}", - f"expression: {payload['expression']}", - f"matches: {payload['count']}", - ] - if payload["rows"]: - for row in payload["rows"]: - lines.append( - f"- run={row['run_id']} name={row['name']} context={json.dumps(row['context'], sort_keys=True)}" - ) - else: - lines.append("No matching results found.") - return "\n".join(lines) + if invocation.output_json: + return QueryCommandResult(exit_status=0, output=render_json(groups, header_info)) + if invocation.plain: + return QueryCommandResult(exit_status=0, output=render_oneline(groups, header_info)) + return QueryCommandResult( + exit_status=0, + output=render_rich_table(groups, header_info, no_color=no_color), + ) + + +def _run_images_query( + invocation: QueryInvocation, + repo_path: Path, + header_info: dict[str, Any], + no_color: bool, +) -> QueryCommandResult: + from aimx.aim_bridge.metric_stats import collect_image_series + from aimx.rendering.query_views import ( + render_image_json, + render_image_oneline, + render_image_rich_table, + ) + + image_rows = collect_image_series(invocation.expression, repo_path) + + if invocation.output_json: + return QueryCommandResult(exit_status=0, output=render_image_json(image_rows, header_info)) + if invocation.plain: + return QueryCommandResult(exit_status=0, output=render_image_oneline(image_rows, header_info)) + return QueryCommandResult( + exit_status=0, + output=render_image_rich_table(image_rows, header_info, no_color=no_color), + ) diff --git a/src/aimx/commands/trace.py b/src/aimx/commands/trace.py new file mode 100644 index 0000000..6501484 --- /dev/null +++ b/src/aimx/commands/trace.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +from aimx.commands.query import QueryCommandResult, normalize_repo_path + +_MODES = {"plot", "table", "csv", "json"} + + +@dataclass(frozen=True) +class TraceInvocation: + expression: str + repo_path: Path + mode: Literal["plot", "table", "csv", "json"] = "plot" + head: int | None = None + tail: int | None = None + every: int | None = None + width: int | None = None + height: int | None = None + no_color: bool = False + step_slice: str | None = None + + +def parse_trace_invocation(args: list[str]) -> TraceInvocation: + if len(args) < 1: + raise ValueError( + "Usage: aimx trace [--repo ] [--table|--csv|--json]" + " [--steps start:end] [--head N] [--tail N] [--every K]" + " [--width W] [--height H] [--no-color]" + ) + + expression = args[0] + rest = args[1:] + + mode: Literal["plot", "table", "csv", "json"] = "plot" + repo_value = "." + head: int | None = None + tail: int | None = None + every: int | None = None + width: int | None = None + height: int | None = None + no_color = False + step_slice: str | None = None + + index = 0 + while index < len(rest): + token = rest[index] + if token == "--table": + mode = "table" + index += 1 + elif token == "--csv": + mode = "csv" + index += 1 + elif token == "--json": + mode = "json" + index += 1 + elif token == "--no-color": + no_color = True + index += 1 + elif token == "--repo": + if index + 1 >= len(rest): + raise ValueError("Missing value for --repo.") + repo_value = rest[index + 1] + index += 2 + elif token == "--steps": + if index + 1 >= len(rest): + raise ValueError("Missing value for --steps.") + step_slice = rest[index + 1] + index += 2 + elif token == "--head": + if index + 1 >= len(rest): + raise ValueError("Missing value for --head.") + try: + head = int(rest[index + 1]) + except ValueError: + raise ValueError(f"--head requires an integer, got: {rest[index + 1]}") + index += 2 + elif token == "--tail": + if index + 1 >= len(rest): + raise ValueError("Missing value for --tail.") + try: + tail = int(rest[index + 1]) + except ValueError: + raise ValueError(f"--tail requires an integer, got: {rest[index + 1]}") + index += 2 + elif token == "--every": + if index + 1 >= len(rest): + raise ValueError("Missing value for --every.") + try: + every = int(rest[index + 1]) + if every < 1: + raise ValueError("--every must be >= 1.") + except ValueError as exc: + if "every" in str(exc): + raise + raise ValueError(f"--every requires a positive integer, got: {rest[index + 1]}") + index += 2 + elif token == "--width": + if index + 1 >= len(rest): + raise ValueError("Missing value for --width.") + try: + width = int(rest[index + 1]) + except ValueError: + raise ValueError(f"--width requires an integer, got: {rest[index + 1]}") + index += 2 + elif token == "--height": + if index + 1 >= len(rest): + raise ValueError("Missing value for --height.") + try: + height = int(rest[index + 1]) + except ValueError: + raise ValueError(f"--height requires an integer, got: {rest[index + 1]}") + index += 2 + else: + raise ValueError(f"Unsupported trace option: {token}") + + return TraceInvocation( + expression=expression, + repo_path=Path(repo_value), + mode=mode, + head=head, + tail=tail, + every=every, + width=width, + height=height, + no_color=no_color, + step_slice=step_slice, + ) + + +def run_trace_command(args: list[str]) -> QueryCommandResult: + try: + invocation = parse_trace_invocation(args) + normalized_repo_path = normalize_repo_path(invocation.repo_path) + except ValueError as error: + return QueryCommandResult(exit_status=2, error_message=str(error)) + + is_tty = sys.stdout.isatty() + effective_no_color = invocation.no_color or not is_tty + + try: + from aimx.aim_bridge.metric_stats import ( + collect_metric_series, + filter_by_step_range, + parse_step_slice, + subsample, + ) + from aimx.rendering.trace_views import ( + render_csv, + render_plot, + render_trace_json, + render_trace_table, + ) + + series_list = collect_metric_series(invocation.expression, normalized_repo_path) + + if not series_list: + return QueryCommandResult(exit_status=0, output="No matching metrics found.") + + # Step range filter is a hard constraint applied before density subsampling + if invocation.step_slice is not None: + step_start, step_end = parse_step_slice(invocation.step_slice) + series_list = [filter_by_step_range(s, step_start, step_end) for s in series_list] + # Drop empty series so they don't clutter plots + series_list = [s for s in series_list if s.count > 0] + + if not series_list: + return QueryCommandResult(exit_status=0, output="No data in the requested step range.") + + # Density subsampling for visualisation + needs_sample = any( + x is not None for x in (invocation.head, invocation.tail, invocation.every) + ) + if needs_sample: + series_list = [ + subsample(s, head=invocation.head, tail=invocation.tail, every=invocation.every) + for s in series_list + ] + + if invocation.mode == "json": + output = render_trace_json(series_list) + elif invocation.mode == "csv": + output = render_csv(series_list) + elif invocation.mode == "table": + output = render_trace_table(series_list, no_color=effective_no_color) + else: + output = render_plot( + series_list, + width=invocation.width, + height=invocation.height, + ) + + return QueryCommandResult(exit_status=0, output=output) + + except RuntimeError as error: + return QueryCommandResult(exit_status=2, error_message=str(error)) + except Exception as error: + return QueryCommandResult( + exit_status=2, error_message=f"Failed to evaluate trace: {error}" + ) diff --git a/src/aimx/rendering/__init__.py b/src/aimx/rendering/__init__.py new file mode 100644 index 0000000..9d48db4 --- /dev/null +++ b/src/aimx/rendering/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/src/aimx/rendering/colors.py b/src/aimx/rendering/colors.py new file mode 100644 index 0000000..fb02395 --- /dev/null +++ b/src/aimx/rendering/colors.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +RUN_HASH = "cyan bold" +EXPERIMENT = "yellow" +METRIC_NAME = "green" +CONTEXT_KEY = "dim" +CONTEXT_VAL = "white" +HEADER = "dim" +NUMBER_EMPH = "bold white" +NUMBER_DIM = "white" +RULE_LINE = "dim cyan" diff --git a/src/aimx/rendering/query_views.py b/src/aimx/rendering/query_views.py new file mode 100644 index 0000000..823817b --- /dev/null +++ b/src/aimx/rendering/query_views.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +import datetime as dt +import io +import json +import math +import shutil +from typing import Any + +from rich.console import Console +from rich.table import Table +from rich.text import Text + +from aimx.aim_bridge.metric_stats import MetricSeries, RunMeta +from aimx.rendering import colors + + +def _fmt_float(v: float) -> str: + if math.isnan(v): + return "—" + if abs(v) >= 1e5 or (abs(v) < 1e-3 and v != 0): + return f"{v:.3e}" + return f"{v:.4f}" + + +def _fmt_context(ctx: dict[str, Any]) -> str: + if not ctx: + return "" + return " ".join(f"{k}={v}" for k, v in sorted(ctx.items())) + + +def _short_hash(h: str) -> str: + return h[:8] + + +def _fmt_creation_time(ts: float | None) -> str: + if ts is None: + return "" + try: + local = dt.datetime.fromtimestamp(ts) + except (OverflowError, OSError, ValueError): + return "" + return local.strftime("%Y-%m-%d %H:%M") + + +def _run_label(run: RunMeta) -> str: + label = _short_hash(run.hash) + if run.experiment: + label = f"{label} {run.experiment}" + elif run.name: + label = f"{label} {run.name}" + return label + + +def render_rich_table( + groups: list[tuple[RunMeta, list[MetricSeries]]], + header_info: dict[str, Any], + *, + no_color: bool = False, +) -> str: + """Render query results as a rich-formatted table string. + + When ``no_color=True``, output is plain text (suitable for non-TTY). + When color is on, ANSI escape codes are embedded in the returned string. + """ + width = 120 if no_color else shutil.get_terminal_size(fallback=(120, 24)).columns + buf = io.StringIO() + console = Console( + file=buf, + no_color=no_color, + force_terminal=not no_color, + width=width, + highlight=False, + ) + + total = sum(len(ms) for _, ms in groups) + expr = header_info.get("expression", "") + repo = header_info.get("repo", "") + target = header_info.get("target", "") + + # Compact header + if header_info.get("verbose"): + console.print(f"[{colors.HEADER}]target:[/] {target} [{colors.HEADER}]repo:[/] {repo}") + console.print(f"[{colors.HEADER}]expression:[/] {expr}") + console.print( + f"[{colors.HEADER}]Repo:[/] {repo} [{colors.HEADER}]·[/] " + f"[{colors.NUMBER_EMPH}]{total}[/] [{colors.HEADER}]match{'es' if total != 1 else ''}[/] " + f"[{colors.HEADER}]·[/] [{colors.HEADER}]{target} where[/] {expr}" + ) + + for run, series_list in groups: + console.print() + label = Text() + label.append("▌ ", style=colors.RULE_LINE) + label.append(_short_hash(run.hash), style=colors.RUN_HASH) + if run.experiment: + label.append(f" {run.experiment}", style=colors.EXPERIMENT) + elif run.name: + label.append(f" {run.name}", style=colors.EXPERIMENT) + created_str = _fmt_creation_time(run.creation_time) + if created_str: + label.append(f" {created_str}", style=colors.HEADER) + console.print(label) + + table = Table( + show_header=True, + header_style=colors.HEADER, + box=None, + pad_edge=True, + show_edge=False, + padding=(0, 1), + ) + table.add_column("NAME", style=colors.METRIC_NAME, no_wrap=True) + table.add_column("CONTEXT", style=colors.CONTEXT_VAL, no_wrap=True) + table.add_column("STEPS", justify="right") + table.add_column("LAST", justify="right", style=colors.NUMBER_EMPH) + table.add_column("MIN", justify="right", style=colors.NUMBER_DIM) + table.add_column("@STEP", justify="right", style=colors.HEADER) + table.add_column("MAX", justify="right", style=colors.NUMBER_DIM) + table.add_column("@STEP", justify="right", style=colors.HEADER) + + for series in series_list: + ctx_str = _fmt_context(series.context) + last_val, last_step = series.last + min_val, min_step = series.min + max_val, max_step = series.max + table.add_row( + series.name, + ctx_str, + str(series.count), + _fmt_float(last_val), + _fmt_float(min_val), + str(min_step) if min_step >= 0 else "—", + _fmt_float(max_val), + str(max_step) if max_step >= 0 else "—", + ) + + console.print(table) + + return buf.getvalue() + + +def render_image_rich_table( + image_rows: list[dict[str, Any]], + header_info: dict[str, Any], + *, + no_color: bool = False, +) -> str: + width = 120 if no_color else shutil.get_terminal_size(fallback=(120, 24)).columns + buf = io.StringIO() + console = Console( + file=buf, + no_color=no_color, + force_terminal=not no_color, + width=width, + highlight=False, + ) + + total = len(image_rows) + expr = header_info.get("expression", "") + repo = header_info.get("repo", "") + target = header_info.get("target", "images") + + console.print( + f"[{colors.HEADER}]Repo:[/] {repo} [{colors.HEADER}]·[/] " + f"[{colors.NUMBER_EMPH}]{total}[/] [{colors.HEADER}]match{'es' if total != 1 else ''}[/] " + f"[{colors.HEADER}]·[/] [{colors.HEADER}]{target} where[/] {expr}" + ) + + table = Table( + show_header=True, + header_style=colors.HEADER, + box=None, + pad_edge=True, + show_edge=False, + padding=(0, 1), + ) + table.add_column("RUN", style=colors.RUN_HASH, no_wrap=True) + table.add_column("EXPERIMENT", style=colors.EXPERIMENT) + table.add_column("NAME", style=colors.METRIC_NAME) + table.add_column("CONTEXT", style=colors.CONTEXT_VAL) + + for row in image_rows: + run: RunMeta = row["run"] + ctx_str = _fmt_context(row["context"]) + table.add_row( + _short_hash(run.hash), + run.experiment or run.name or "", + row["name"], + ctx_str, + ) + + console.print(table) + return buf.getvalue() + + +def render_oneline( + groups: list[tuple[RunMeta, list[MetricSeries]]], + header_info: dict[str, Any], +) -> str: + """Plain single-line-per-metric output, suitable for awk/grep pipelines.""" + repo = header_info.get("repo", "") + lines: list[str] = [] + for run, series_list in groups: + h = _short_hash(run.hash) + exp = run.experiment or run.name or "" + for series in series_list: + ctx_str = _fmt_context(series.context) or "-" + last_val, _ = series.last + min_val, _ = series.min + max_val, _ = series.max + lines.append( + f"{repo}\t{h}\t{exp}\t{series.name}\t{ctx_str}" + f"\tsteps={series.count}" + f"\tlast={_fmt_float(last_val)}" + f"\tmin={_fmt_float(min_val)}" + f"\tmax={_fmt_float(max_val)}" + ) + return "\n".join(lines) + + +def render_image_oneline( + image_rows: list[dict[str, Any]], + header_info: dict[str, Any], +) -> str: + repo = header_info.get("repo", "") + lines: list[str] = [] + for row in image_rows: + run: RunMeta = row["run"] + h = _short_hash(run.hash) + exp = run.experiment or run.name or "" + ctx_str = _fmt_context(row["context"]) or "-" + lines.append(f"{repo}\t{h}\t{exp}\t{row['name']}\t{ctx_str}") + return "\n".join(lines) + + +def render_json( + groups: list[tuple[RunMeta, list[MetricSeries]]], + header_info: dict[str, Any], +) -> str: + """Nested JSON output: runs → metrics.""" + metrics_count = sum(len(ms) for _, ms in groups) + runs_json: list[dict[str, Any]] = [] + for run, series_list in groups: + metrics_json: list[dict[str, Any]] = [] + for series in series_list: + last_val, last_step = series.last + min_val, min_step = series.min + max_val, max_step = series.max + metrics_json.append( + { + "name": series.name, + "context": series.context, + "steps": series.count, + "last": {"value": _safe_float(last_val), "step": last_step}, + "min": {"value": _safe_float(min_val), "step": min_step}, + "max": {"value": _safe_float(max_val), "step": max_step}, + } + ) + runs_json.append( + { + "hash": run.hash, + "experiment": run.experiment, + "name": run.name, + "metrics": metrics_json, + } + ) + payload: dict[str, Any] = { + "target": header_info.get("target", "metrics"), + "repo": header_info.get("repo", ""), + "expression": header_info.get("expression", ""), + "runs_count": len(runs_json), + "metrics_count": metrics_count, + "runs": runs_json, + } + return json.dumps(payload) + + +def render_image_json( + image_rows: list[dict[str, Any]], + header_info: dict[str, Any], +) -> str: + rows_json: list[dict[str, Any]] = [] + for row in image_rows: + run: RunMeta = row["run"] + rows_json.append( + { + "run_hash": run.hash, + "experiment": run.experiment, + "name": row["name"], + "context": row["context"], + } + ) + payload: dict[str, Any] = { + "target": header_info.get("target", "images"), + "repo": header_info.get("repo", ""), + "expression": header_info.get("expression", ""), + "count": len(rows_json), + "rows": rows_json, + } + return json.dumps(payload) + + +def _safe_float(v: float) -> float | None: + if math.isnan(v): + return None + return v diff --git a/src/aimx/rendering/trace_views.py b/src/aimx/rendering/trace_views.py new file mode 100644 index 0000000..fd9eed8 --- /dev/null +++ b/src/aimx/rendering/trace_views.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import contextlib +import csv +import io +import json +import math +import shutil +from typing import Any + +from rich.console import Console +from rich.table import Table + +from aimx.aim_bridge.metric_stats import MetricSeries, RunMeta +from aimx.rendering import colors + + +def _short_hash(h: str) -> str: + return h[:8] + + +def _fmt_context(ctx: dict[str, Any]) -> str: + if not ctx: + return "" + return " ".join(f"{k}={v}" for k, v in sorted(ctx.items())) + + +def _series_label(series: MetricSeries) -> str: + parts = [_short_hash(series.run.hash)] + if series.run.experiment: + parts.append(series.run.experiment) + elif series.run.name: + parts.append(series.run.name) + parts.append(series.name) + ctx = _fmt_context(series.context) + if ctx: + parts.append(f"[{ctx}]") + return " · ".join(parts) + + +def render_plot( + series_list: list[MetricSeries], + *, + width: int | None = None, + height: int | None = None, +) -> str: + """Render time-series curves using plotext and return as a string.""" + import plotext as plt # noqa: PLC0415 + + plt.clt() + plt.cld() + + term_width = shutil.get_terminal_size(fallback=(120, 30)).columns + plot_width = width or term_width + plot_height = height or 25 + plt.plot_size(plot_width, plot_height) + + for series in series_list: + if series.count == 0: + continue + label = _series_label(series) + x = series.steps.tolist() + y = series.values.tolist() + plt.plot(x, y, label=label) + + if series_list: + first = series_list[0] + title = first.name if all(s.name == first.name for s in series_list) else "Metrics" + plt.title(title) + plt.xlabel("Step") + plt.ylabel("Value") + plt.theme("dark") + + output_buf = io.StringIO() + with contextlib.redirect_stdout(output_buf): + plt.show() + return output_buf.getvalue() + + +def render_trace_table( + series_list: list[MetricSeries], + *, + no_color: bool = False, +) -> str: + """Render each series as a rich table with step/epoch/value columns.""" + width = 120 if no_color else shutil.get_terminal_size(fallback=(120, 24)).columns + buf = io.StringIO() + console = Console( + file=buf, + no_color=no_color, + force_terminal=not no_color, + width=width, + highlight=False, + ) + + for series in series_list: + label = _series_label(series) + console.print(f"\n[{colors.HEADER}]{label}[/] [{colors.HEADER}]{series.count} points[/]") + + table = Table( + show_header=True, + header_style=colors.HEADER, + box=None, + pad_edge=True, + show_edge=False, + padding=(0, 1), + ) + table.add_column("STEP", justify="right") + table.add_column("EPOCH", justify="right") + table.add_column("VALUE", justify="right", style=colors.NUMBER_EMPH) + + for i in range(series.count): + step = int(series.steps[i]) + epoch = f"{series.epochs[i]:.0f}" if series.epochs is not None else "—" + val = series.values[i] + val_str = "—" if math.isnan(float(val)) else f"{float(val):.6g}" + table.add_row(str(step), epoch, val_str) + + console.print(table) + + return buf.getvalue() + + +def render_csv(series_list: list[MetricSeries]) -> str: + """Render series data as CSV: run_hash,experiment,metric,context,step,epoch,value.""" + buf = io.StringIO() + writer = csv.writer(buf) + writer.writerow(["run_hash", "experiment", "metric", "context", "step", "epoch", "value"]) + for series in series_list: + ctx_str = json.dumps(series.context, sort_keys=True) + for i in range(series.count): + step = int(series.steps[i]) + epoch = float(series.epochs[i]) if series.epochs is not None else "" + val = float(series.values[i]) + writer.writerow( + [ + series.run.hash, + series.run.experiment or series.run.name or "", + series.name, + ctx_str, + step, + epoch, + val, + ] + ) + return buf.getvalue() + + +def render_trace_json(series_list: list[MetricSeries]) -> str: + """Render series data as JSON with full value arrays.""" + result: list[dict[str, Any]] = [] + for series in series_list: + result.append( + { + "run": { + "hash": series.run.hash, + "experiment": series.run.experiment, + "name": series.run.name, + }, + "metric": series.name, + "context": series.context, + "count": series.count, + "steps": series.steps.tolist(), + "epochs": series.epochs.tolist() if series.epochs is not None else None, + "values": series.values.tolist(), + } + ) + return json.dumps(result) diff --git a/src/aimx/router.py b/src/aimx/router.py index 230a201..bb7562e 100644 --- a/src/aimx/router.py +++ b/src/aimx/router.py @@ -3,7 +3,7 @@ from dataclasses import dataclass -OWNED_COMMANDS = {"help", "--help", "-h", "version", "doctor"} +OWNED_COMMANDS = {"help", "--help", "-h", "version", "doctor", "query", "trace"} @dataclass(frozen=True) @@ -31,6 +31,13 @@ def route_args(args: list[str]) -> CommandRoute: owned_args=list(args[1:]), reason="reserved aimx query command", ) + if command == "trace": + return CommandRoute( + "owned", + owned_command="trace", + owned_args=list(args[1:]), + reason="reserved aimx trace command", + ) return CommandRoute( "passthrough", diff --git a/tests/contract/test_query_contract.py b/tests/contract/test_query_contract.py index 65efd2f..4091ae9 100644 --- a/tests/contract/test_query_contract.py +++ b/tests/contract/test_query_contract.py @@ -5,7 +5,7 @@ from aimx.__main__ import main -def test_query_metrics_json_contract_uses_stable_envelope(capfd, sample_repo_root) -> None: +def test_query_metrics_json_contract_uses_nested_runs_envelope(capfd, sample_repo_root) -> None: exit_code = main( [ "query", @@ -22,28 +22,56 @@ def test_query_metrics_json_contract_uses_stable_envelope(capfd, sample_repo_roo assert exit_code == 0 assert payload["target"] == "metrics" assert payload["expression"] == "metric.name == 'loss'" - assert payload["repo_path"] == "data" - assert payload["count"] > 0 - assert payload["rows"] - assert payload["rows"][0]["target"] == "metrics" - assert payload["rows"][0]["name"] == "loss" - assert payload["rows"][0]["run_id"] - assert "summary" in payload["rows"][0] + assert payload["repo"] == str(sample_repo_root) + assert payload["runs_count"] > 0 + assert payload["metrics_count"] > 0 + assert payload["runs"] + first_run = payload["runs"][0] + assert "hash" in first_run + assert "experiment" in first_run + assert "metrics" in first_run + first_metric = first_run["metrics"][0] + assert first_metric["name"] == "loss" + assert "steps" in first_metric + assert "last" in first_metric + assert "min" in first_metric + assert "max" in first_metric + + +def test_query_metrics_text_contract_reports_repo_count_and_metric_name( + capfd, sample_repo_root +) -> None: + exit_code = main( + ["query", "metrics", "metric.name == 'loss'", "--repo", str(sample_repo_root)] + ) + + captured = capfd.readouterr() + assert exit_code == 0 + assert "Repo:" in captured.out + assert "match" in captured.out + assert "loss" in captured.out -def test_query_metrics_text_contract_reports_target_repo_and_count( +def test_query_metrics_oneline_contract_is_tab_separated_and_contains_metric_name( capfd, sample_repo_root ) -> None: exit_code = main( - ["query", "metrics", "metric.name == 'loss'", "--repo", str(sample_repo_root)] + [ + "query", + "metrics", + "metric.name == 'loss'", + "--repo", + str(sample_repo_root), + "--oneline", + ] ) captured = capfd.readouterr() assert exit_code == 0 - assert "target: metrics" in captured.out - assert f"repo: {sample_repo_root}" in captured.out - assert "matches:" in captured.out + lines = [l for l in captured.out.splitlines() if l.strip()] + assert lines, "Expected at least one output line" assert "loss" in captured.out + assert "\t" in lines[0] def test_query_images_json_contract_uses_stable_envelope(capfd, sample_repo_root) -> None: @@ -57,7 +85,6 @@ def test_query_images_json_contract_uses_stable_envelope(capfd, sample_repo_root assert payload["target"] == "images" assert payload["expression"] == "images" assert payload["count"] > 0 - assert payload["rows"][0]["target"] == "images" assert payload["rows"][0]["name"] == "example" diff --git a/tests/contract/test_trace_contract.py b/tests/contract/test_trace_contract.py new file mode 100644 index 0000000..bead2cc --- /dev/null +++ b/tests/contract/test_trace_contract.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import csv +import io +import json + +from aimx.__main__ import main + + +def test_trace_plot_contract_produces_non_empty_output(capfd, sample_repo_root) -> None: + exit_code = main( + ["trace", "metric.name == 'loss'", "--repo", str(sample_repo_root)] + ) + + captured = capfd.readouterr() + assert exit_code == 0 + assert captured.out.strip(), "Expected plotext chart output" + + +def test_trace_table_contract_contains_step_and_value_columns(capfd, sample_repo_root) -> None: + exit_code = main( + ["trace", "metric.name == 'loss'", "--repo", str(sample_repo_root), "--table"] + ) + + captured = capfd.readouterr() + assert exit_code == 0 + assert "STEP" in captured.out + assert "VALUE" in captured.out + + +def test_trace_json_contract_has_steps_and_values_arrays(capfd, sample_repo_root) -> None: + exit_code = main( + ["trace", "metric.name == 'loss'", "--repo", str(sample_repo_root), "--json"] + ) + + captured = capfd.readouterr() + payload = json.loads(captured.out) + assert exit_code == 0 + assert isinstance(payload, list) + assert payload, "Expected at least one series" + first = payload[0] + assert first["metric"] == "loss" + assert "steps" in first + assert "values" in first + assert isinstance(first["steps"], list) + assert isinstance(first["values"], list) + + +def test_trace_csv_contract_has_correct_headers(capfd, sample_repo_root) -> None: + exit_code = main( + ["trace", "metric.name == 'loss'", "--repo", str(sample_repo_root), "--csv"] + ) + + captured = capfd.readouterr() + assert exit_code == 0 + reader = csv.DictReader(io.StringIO(captured.out)) + assert reader.fieldnames is not None + for field in ("run_hash", "metric", "step", "value"): + assert field in reader.fieldnames + + +def test_trace_invalid_repo_reports_error(capfd) -> None: + exit_code = main(["trace", "metric.name == 'loss'", "--repo", "missing-repo"]) + + captured = capfd.readouterr() + assert exit_code == 2 + assert "Repository path does not exist" in captured.err + + +def test_trace_invalid_expression_reports_error(capfd, sample_repo_root) -> None: + exit_code = main( + ["trace", "metric.name ==", "--repo", str(sample_repo_root)] + ) + + captured = capfd.readouterr() + assert exit_code == 2 + assert "Failed to evaluate trace" in captured.err diff --git a/tests/integration/test_missing_native_aim.py b/tests/integration/test_missing_native_aim.py index ed206c3..3b6559d 100644 --- a/tests/integration/test_missing_native_aim.py +++ b/tests/integration/test_missing_native_aim.py @@ -49,5 +49,5 @@ def test_query_owned_command_still_works_when_native_aim_is_missing( captured = capsys.readouterr() assert exit_code == 0 - assert "target: metrics" in captured.out - assert "matches:" in captured.out + assert "Repo:" in captured.out + assert "match" in captured.out diff --git a/tests/integration/test_query_command.py b/tests/integration/test_query_command.py index 33e7ded..1115b53 100644 --- a/tests/integration/test_query_command.py +++ b/tests/integration/test_query_command.py @@ -36,8 +36,8 @@ def test_metric_query_accepts_repo_root_and_dot_aim_paths( assert root_exit_code == 0 assert dot_aim_exit_code == 0 - assert root_payload["count"] == dot_aim_payload["count"] - assert root_payload["rows"] == dot_aim_payload["rows"] + assert root_payload["metrics_count"] == dot_aim_payload["metrics_count"] + assert root_payload["runs_count"] == dot_aim_payload["runs_count"] def test_metric_query_returns_matches_from_sample_repository( @@ -50,7 +50,71 @@ def test_metric_query_returns_matches_from_sample_repository( captured = capfd.readouterr() assert exit_code == 0 assert "loss" in captured.out - assert "matches:" in captured.out + assert "match" in captured.out + + +def test_metric_query_defaults_repo_to_current_directory( + capfd, monkeypatch, sample_repo_root +) -> None: + monkeypatch.chdir(sample_repo_root) + + exit_code = main(["query", "metrics", "metric.name == 'loss'"]) + + captured = capfd.readouterr() + assert exit_code == 0 + assert "loss" in captured.out + assert "match" in captured.out + + +def test_metric_query_oneline_mode_returns_tab_separated_rows( + capfd, sample_repo_root +) -> None: + exit_code = main( + [ + "query", + "metrics", + "metric.name == 'loss'", + "--repo", + str(sample_repo_root), + "--oneline", + ] + ) + + captured = capfd.readouterr() + assert exit_code == 0 + lines = [l for l in captured.out.splitlines() if l.strip()] + assert lines + assert "loss" in captured.out + assert "\t" in lines[0] + assert "steps=" in lines[0] + assert "last=" in lines[0] + + +def test_metric_query_json_mode_returns_nested_structure( + capfd, sample_repo_root +) -> None: + exit_code = main( + [ + "query", + "metrics", + "metric.name == 'loss'", + "--repo", + str(sample_repo_root), + "--json", + ] + ) + + captured = capfd.readouterr() + payload = json.loads(captured.out) + assert exit_code == 0 + assert payload["runs_count"] > 0 + assert payload["metrics_count"] > 0 + first_run = payload["runs"][0] + first_metric = first_run["metrics"][0] + assert first_metric["name"] == "loss" + # last/min/max should have real numeric values (not null) since the data exists + last_val = first_metric["last"]["value"] + assert last_val is not None def test_image_query_returns_matches_from_sample_repository(capfd, sample_repo_root) -> None: diff --git a/tests/integration/test_short_hash_and_steps.py b/tests/integration/test_short_hash_and_steps.py new file mode 100644 index 0000000..0be65d1 --- /dev/null +++ b/tests/integration/test_short_hash_and_steps.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +"""Integration tests for short-hash expansion and --steps filtering. + +These tests require the sample Aim repository at data/.aim (same fixture as +other integration tests). They are skipped automatically if the repo is +absent. +""" + +import json + +from aimx.__main__ import main + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _first_run_hash(sample_repo_root) -> str: + """Return the full hash of one run known to have a 'loss' metric.""" + from aim import Repo + + repo = Repo(str(sample_repo_root)) + return repo.list_all_runs()[0] + + +# --------------------------------------------------------------------------- +# Short hash: query +# --------------------------------------------------------------------------- + + +def test_query_with_short_hash_returns_same_result_as_full_hash( + capfd, sample_repo_root +) -> None: + full_hash = _first_run_hash(sample_repo_root) + short_hash = full_hash[:8] + + exit_full = main( + [ + "query", "metrics", + f"run.hash=='{full_hash}' and metric.name=='loss'", + "--repo", str(sample_repo_root), + "--json", + ] + ) + out_full = json.loads(capfd.readouterr().out) + + exit_short = main( + [ + "query", "metrics", + f"run.hash=='{short_hash}' and metric.name=='loss'", + "--repo", str(sample_repo_root), + "--json", + ] + ) + out_short = json.loads(capfd.readouterr().out) + + assert exit_full == 0 + assert exit_short == 0 + assert out_full["metrics_count"] == out_short["metrics_count"] + assert out_full["runs_count"] == out_short["runs_count"] + + +def test_query_with_short_hash_exits_zero_and_contains_metric_name( + capfd, sample_repo_root +) -> None: + full_hash = _first_run_hash(sample_repo_root) + short_hash = full_hash[:8] + + exit_code = main( + [ + "query", "metrics", + f"run.hash=='{short_hash}' and metric.name=='loss'", + "--repo", str(sample_repo_root), + ] + ) + captured = capfd.readouterr() + + assert exit_code == 0 + assert "loss" in captured.out + + +# --------------------------------------------------------------------------- +# Short hash: trace +# --------------------------------------------------------------------------- + + +def test_trace_with_short_hash_produces_output(capfd, sample_repo_root) -> None: + full_hash = _first_run_hash(sample_repo_root) + short_hash = full_hash[:8] + + exit_code = main( + [ + "trace", + f"run.hash=='{short_hash}' and metric.name=='loss'", + "--repo", str(sample_repo_root), + "--json", + ] + ) + captured = capfd.readouterr() + + assert exit_code == 0 + payload = json.loads(captured.out) + assert isinstance(payload, list) + for series in payload: + assert series["run"]["hash"] == full_hash + + +# --------------------------------------------------------------------------- +# Short hash: error paths +# --------------------------------------------------------------------------- + + +def test_query_nonexistent_short_hash_reports_error(capfd, sample_repo_root) -> None: + exit_code = main( + [ + "query", "metrics", + "run.hash=='0000000000000000' and metric.name=='loss'", + "--repo", str(sample_repo_root), + ] + ) + captured = capfd.readouterr() + + assert exit_code == 2 + assert "did not match" in captured.err + + +def test_query_ambiguous_short_hash_reports_error(capfd, sample_repo_root) -> None: + """A prefix of 1 char will almost certainly match multiple runs.""" + from aim import Repo + + repo = Repo(str(sample_repo_root)) + all_hashes = repo.list_all_runs() + # Use the first character of the first hash; if more than one run shares + # that character, the prefix is ambiguous. + first_char = all_hashes[0][0] + matches = [h for h in all_hashes if h.startswith(first_char)] + if len(matches) < 2: + import pytest + pytest.skip("No ambiguous prefix available in this repository") + + exit_code = main( + [ + "query", "metrics", + f"run.hash=='{first_char}' and metric.name=='loss'", + "--repo", str(sample_repo_root), + ] + ) + captured = capfd.readouterr() + + assert exit_code == 2 + assert "ambiguous" in captured.err + + +# --------------------------------------------------------------------------- +# --steps filter: query +# --------------------------------------------------------------------------- + + +def test_query_steps_filter_reduces_step_count(capfd, sample_repo_root) -> None: + """Steps in the filtered result should all be <= the bound.""" + exit_code = main( + [ + "query", "metrics", + "metric.name == 'loss'", + "--repo", str(sample_repo_root), + "--steps", ":50", + "--json", + ] + ) + captured = capfd.readouterr() + payload = json.loads(captured.out) + + assert exit_code == 0 + assert payload["metrics_count"] >= 0 + for run in payload["runs"]: + for metric in run["metrics"]: + assert metric["steps"] <= 50 + + +def test_query_steps_filter_closed_range_bounds_last_value( + capfd, sample_repo_root +) -> None: + """last.step must fall within the requested window (or be -1 for empty).""" + exit_code = main( + [ + "query", "metrics", + "metric.name == 'loss'", + "--repo", str(sample_repo_root), + "--steps", "50:100", + "--json", + ] + ) + captured = capfd.readouterr() + payload = json.loads(captured.out) + + assert exit_code == 0 + for run in payload["runs"]: + for metric in run["metrics"]: + last_step = metric["last"]["step"] + if last_step != -1: + assert 50 <= last_step <= 100 + + +# --------------------------------------------------------------------------- +# --steps filter: trace +# --------------------------------------------------------------------------- + + +def test_trace_steps_filter_constrains_step_values(capfd, sample_repo_root) -> None: + exit_code = main( + [ + "trace", + "metric.name == 'loss'", + "--repo", str(sample_repo_root), + "--steps", "1:50", + "--json", + ] + ) + captured = capfd.readouterr() + payload = json.loads(captured.out) + + assert exit_code == 0 + for series in payload: + for step in series["steps"]: + assert 1 <= step <= 50 + + +def test_trace_steps_open_end_filter_keeps_steps_from_start( + capfd, sample_repo_root +) -> None: + exit_code = main( + [ + "trace", + "metric.name == 'loss'", + "--repo", str(sample_repo_root), + "--steps", "50:", + "--json", + ] + ) + captured = capfd.readouterr() + payload = json.loads(captured.out) + + assert exit_code == 0 + for series in payload: + for step in series["steps"]: + assert step >= 50 + + +def test_trace_steps_filter_and_short_hash_work_together( + capfd, sample_repo_root +) -> None: + full_hash = _first_run_hash(sample_repo_root) + short_hash = full_hash[:8] + + exit_code = main( + [ + "trace", + f"run.hash=='{short_hash}' and metric.name=='loss'", + "--repo", str(sample_repo_root), + "--steps", "1:200", + "--json", + ] + ) + captured = capfd.readouterr() + payload = json.loads(captured.out) + + assert exit_code == 0 + for series in payload: + assert series["run"]["hash"] == full_hash + for step in series["steps"]: + assert 1 <= step <= 200 diff --git a/tests/integration/test_trace_command.py b/tests/integration/test_trace_command.py new file mode 100644 index 0000000..d4703c4 --- /dev/null +++ b/tests/integration/test_trace_command.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import csv +import io +import json + +from aimx.__main__ import main + + +def test_trace_plot_produces_output_containing_metric_name(capfd, sample_repo_root) -> None: + exit_code = main( + ["trace", "metric.name == 'loss'", "--repo", str(sample_repo_root)] + ) + + captured = capfd.readouterr() + assert exit_code == 0 + assert captured.out.strip(), "Expected non-empty plotext output" + + +def test_trace_defaults_repo_to_current_directory( + capfd, monkeypatch, sample_repo_root +) -> None: + monkeypatch.chdir(sample_repo_root) + + exit_code = main(["trace", "metric.name == 'loss'", "--json"]) + + captured = capfd.readouterr() + payload = json.loads(captured.out) + assert exit_code == 0 + assert payload + + +def test_trace_table_mode_contains_step_and_value_columns(capfd, sample_repo_root) -> None: + exit_code = main( + ["trace", "metric.name == 'loss'", "--repo", str(sample_repo_root), "--table"] + ) + + captured = capfd.readouterr() + assert exit_code == 0 + assert "STEP" in captured.out + assert "VALUE" in captured.out + + +def test_trace_json_mode_returns_full_value_arrays(capfd, sample_repo_root) -> None: + exit_code = main( + ["trace", "metric.name == 'loss'", "--repo", str(sample_repo_root), "--json"] + ) + + captured = capfd.readouterr() + payload = json.loads(captured.out) + assert exit_code == 0 + assert isinstance(payload, list) + for series in payload: + assert series["metric"] == "loss" + assert len(series["steps"]) == len(series["values"]) + assert len(series["values"]) > 0 + + +def test_trace_csv_mode_contains_correct_fields(capfd, sample_repo_root) -> None: + exit_code = main( + ["trace", "metric.name == 'loss'", "--repo", str(sample_repo_root), "--csv"] + ) + + captured = capfd.readouterr() + assert exit_code == 0 + reader = csv.DictReader(io.StringIO(captured.out)) + rows = list(reader) + assert rows, "Expected at least one CSV data row" + for row in rows: + assert row["metric"] == "loss" + assert row["step"].isdigit() + + +def test_trace_head_limits_to_n_points_per_series(capfd, sample_repo_root) -> None: + exit_code = main( + [ + "trace", + "metric.name == 'loss'", + "--repo", + str(sample_repo_root), + "--json", + "--head", + "5", + ] + ) + + captured = capfd.readouterr() + payload = json.loads(captured.out) + assert exit_code == 0 + for series in payload: + assert len(series["values"]) <= 5 + + +def test_trace_no_matching_expression_exits_cleanly(capfd, sample_repo_root) -> None: + exit_code = main( + [ + "trace", + "metric.name == 'nonexistent_metric_xyz'", + "--repo", + str(sample_repo_root), + ] + ) + + captured = capfd.readouterr() + assert exit_code == 0 + assert "No matching" in captured.out diff --git a/tests/unit/test_hash_resolver.py b/tests/unit/test_hash_resolver.py new file mode 100644 index 0000000..699082b --- /dev/null +++ b/tests/unit/test_hash_resolver.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from aimx.aim_bridge.hash_resolver import resolve_hash_prefixes, _FULL_HASH_LEN + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_FULL_A = "a" * _FULL_HASH_LEN +_FULL_B = "b" * _FULL_HASH_LEN +_FULL_C = "abc123" + "0" * (_FULL_HASH_LEN - 6) + + +def _repo(hashes: list[str]): + """Return a mock Aim Repo whose list_all_runs() returns *hashes*.""" + mock_repo = MagicMock() + mock_repo.list_all_runs.return_value = hashes + return mock_repo + + +def _patch_repo(hashes: list[str]): + return patch( + "aimx.aim_bridge.hash_resolver.Repo", + return_value=_repo(hashes), + ) + + +# --------------------------------------------------------------------------- +# Expression contains no run.hash → no repo call needed +# --------------------------------------------------------------------------- + + +def test_expression_without_run_hash_is_returned_unchanged(tmp_path) -> None: + expr = "metric.name == 'loss'" + assert resolve_hash_prefixes(expr, tmp_path) == expr + + +def test_expression_without_run_hash_does_not_call_repo(tmp_path) -> None: + with patch("aimx.aim_bridge.hash_resolver.Repo") as mock_cls: + resolve_hash_prefixes("metric.name == 'loss'", tmp_path) + mock_cls.assert_not_called() + + +# --------------------------------------------------------------------------- +# Full-length hash passes through unchanged +# --------------------------------------------------------------------------- + + +def test_full_length_hash_is_not_rewritten(tmp_path) -> None: + expr = f"run.hash == '{_FULL_A}'" + with _patch_repo([_FULL_A, _FULL_B]): + result = resolve_hash_prefixes(expr, tmp_path) + assert result == expr + + +# --------------------------------------------------------------------------- +# Unique prefix → expanded to full hash +# --------------------------------------------------------------------------- + + +def test_unique_prefix_is_expanded_to_full_hash(tmp_path) -> None: + prefix = "abc123" + expr = f"run.hash == '{prefix}'" + with _patch_repo([_FULL_C, _FULL_B]): + result = resolve_hash_prefixes(expr, tmp_path) + assert _FULL_C in result + # The short literal should now be replaced by the full hash + assert f"'{prefix}'" not in result + + +def test_prefix_with_spaces_around_operator_is_expanded(tmp_path) -> None: + prefix = "abc123" + expr = f"run.hash == '{prefix}'" + with _patch_repo([_FULL_C]): + result = resolve_hash_prefixes(expr, tmp_path) + assert _FULL_C in result + + +def test_prefix_without_spaces_around_operator_is_expanded(tmp_path) -> None: + prefix = "abc123" + expr = f"run.hash=='{prefix}'" + with _patch_repo([_FULL_C]): + result = resolve_hash_prefixes(expr, tmp_path) + assert _FULL_C in result + + +# --------------------------------------------------------------------------- +# Quote style +# --------------------------------------------------------------------------- + + +def test_double_quoted_prefix_is_expanded(tmp_path) -> None: + prefix = "abc123" + expr = f'run.hash == "{prefix}"' + with _patch_repo([_FULL_C]): + result = resolve_hash_prefixes(expr, tmp_path) + assert _FULL_C in result + assert '"' in result + + +def test_single_quoted_prefix_is_expanded(tmp_path) -> None: + prefix = "abc123" + expr = f"run.hash == '{prefix}'" + with _patch_repo([_FULL_C]): + result = resolve_hash_prefixes(expr, tmp_path) + assert _FULL_C in result + assert "'" in result + + +# --------------------------------------------------------------------------- +# != operator +# --------------------------------------------------------------------------- + + +def test_not_equal_operator_is_supported(tmp_path) -> None: + prefix = "abc123" + expr = f"run.hash != '{prefix}'" + with _patch_repo([_FULL_C]): + result = resolve_hash_prefixes(expr, tmp_path) + assert _FULL_C in result + assert "!=" in result + + +# --------------------------------------------------------------------------- +# Case-insensitive prefix matching +# --------------------------------------------------------------------------- + + +def test_uppercase_prefix_matches_lowercase_hash(tmp_path) -> None: + full = "abc123" + "0" * (_FULL_HASH_LEN - 6) + expr = f"run.hash == 'ABC123'" + with _patch_repo([full]): + result = resolve_hash_prefixes(expr, tmp_path) + assert full in result + + +# --------------------------------------------------------------------------- +# Error: no match +# --------------------------------------------------------------------------- + + +def test_unmatched_prefix_raises_value_error(tmp_path) -> None: + expr = "run.hash == 'deadbeef'" + with _patch_repo([_FULL_A, _FULL_B]): + with pytest.raises(ValueError, match="did not match"): + resolve_hash_prefixes(expr, tmp_path) + + +def test_error_message_contains_original_short_hash(tmp_path) -> None: + expr = "run.hash == 'deadbeef'" + with _patch_repo([_FULL_A]): + with pytest.raises(ValueError, match="deadbeef"): + resolve_hash_prefixes(expr, tmp_path) + + +# --------------------------------------------------------------------------- +# Error: ambiguous +# --------------------------------------------------------------------------- + + +def test_ambiguous_prefix_raises_value_error(tmp_path) -> None: + full1 = "abc" + "1" * (_FULL_HASH_LEN - 3) + full2 = "abc" + "2" * (_FULL_HASH_LEN - 3) + expr = "run.hash == 'abc'" + with _patch_repo([full1, full2]): + with pytest.raises(ValueError, match="ambiguous"): + resolve_hash_prefixes(expr, tmp_path) + + +def test_ambiguous_error_lists_candidates(tmp_path) -> None: + full1 = "abc" + "1" * (_FULL_HASH_LEN - 3) + full2 = "abc" + "2" * (_FULL_HASH_LEN - 3) + expr = "run.hash == 'abc'" + with _patch_repo([full1, full2]): + with pytest.raises(ValueError, match="abc"): + resolve_hash_prefixes(expr, tmp_path) + + +# --------------------------------------------------------------------------- +# Multiple run.hash literals in one expression +# --------------------------------------------------------------------------- + + +def test_multiple_hash_literals_are_all_resolved(tmp_path) -> None: + full1 = "aaaa" + "0" * (_FULL_HASH_LEN - 4) + full2 = "bbbb" + "0" * (_FULL_HASH_LEN - 4) + expr = "run.hash == 'aaaa' or run.hash == 'bbbb'" + with _patch_repo([full1, full2]): + result = resolve_hash_prefixes(expr, tmp_path) + assert full1 in result + assert full2 in result + assert "'aaaa'" not in result + assert "'bbbb'" not in result diff --git a/tests/unit/test_metric_stats.py b/tests/unit/test_metric_stats.py new file mode 100644 index 0000000..8bfae1f --- /dev/null +++ b/tests/unit/test_metric_stats.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import datetime as dt +import math + +import numpy as np +import pytest + +from aimx.aim_bridge.metric_stats import ( + MetricSeries, + RunMeta, + _extract_run_meta, + _extract_values, + group_by_run, + subsample, +) + + +def _make_run(hash_: str = "abc123", experiment: str | None = "exp") -> RunMeta: + return RunMeta(hash=hash_, experiment=experiment, name=None, creation_time=None) + + +def _make_series( + run: RunMeta | None = None, + name: str = "loss", + values: list[float] | None = None, + steps: list[int] | None = None, +) -> MetricSeries: + if run is None: + run = _make_run() + vals = np.array(values if values is not None else [2.0, 1.5, 1.0, 0.5], dtype=float) + stps = np.array(steps if steps is not None else list(range(len(vals))), dtype=int) + return MetricSeries(run=run, name=name, context={}, values=vals, steps=stps, epochs=None) + + +class TestMetricSeriesStats: + def test_count_matches_values_length(self) -> None: + s = _make_series(values=[1.0, 2.0, 3.0]) + assert s.count == 3 + + def test_last_returns_final_value_and_step(self) -> None: + s = _make_series(values=[1.0, 0.5], steps=[10, 20]) + val, step = s.last + assert val == pytest.approx(0.5) + assert step == 20 + + def test_min_returns_minimum_value_and_its_step(self) -> None: + s = _make_series(values=[3.0, 1.0, 2.0], steps=[0, 1, 2]) + val, step = s.min + assert val == pytest.approx(1.0) + assert step == 1 + + def test_max_returns_maximum_value_and_its_step(self) -> None: + s = _make_series(values=[3.0, 1.0, 2.0], steps=[0, 1, 2]) + val, step = s.max + assert val == pytest.approx(3.0) + assert step == 0 + + def test_empty_series_returns_nan_and_minus_one(self) -> None: + s = _make_series(values=[], steps=[]) + last_val, last_step = s.last + min_val, min_step = s.min + max_val, max_step = s.max + assert math.isnan(last_val) + assert last_step == -1 + assert math.isnan(min_val) + assert min_step == -1 + assert math.isnan(max_val) + assert max_step == -1 + + +class _FakeRun: + def __init__( + self, + *, + hash: str = "abc123", + experiment: str | None = "exp", + name: str | None = None, + creation_time: float | None = None, + created_at: dt.datetime | None = None, + ) -> None: + self.hash = hash + self.experiment = experiment + self.name = name + if creation_time is not None: + self.creation_time = creation_time + if created_at is not None: + self.created_at = created_at + + +class TestExtractRunMeta: + def test_prefers_creation_time_timestamp(self) -> None: + run = _FakeRun(creation_time=1744532960.888126) + + meta = _extract_run_meta(run) + + assert meta.creation_time == pytest.approx(1744532960.888126) + + def test_falls_back_to_created_at_datetime(self) -> None: + run = _FakeRun(created_at=dt.datetime(2025, 4, 13, 8, 29, 20, 888126)) + + meta = _extract_run_meta(run) + + assert meta.creation_time == pytest.approx(1744532960.888126) + + +class _FakeMetricData: + def __init__( + self, + steps: list[int] | None = None, + values: list[float] | None = None, + epochs: list[float] | None = None, + *, + raise_value_error: bool = False, + ) -> None: + self._steps = steps or [] + self._values = values or [] + self._epochs = epochs or [] + self._raise_value_error = raise_value_error + + def items_list(self) -> tuple[list[int], list[list[float]]]: + if self._raise_value_error: + raise ValueError("no data") + return self._steps, [self._values, self._epochs, [0.0] * len(self._steps)] + + +class _FakeMetric: + def __init__(self, data: _FakeMetricData) -> None: + self.data = data + + +class TestExtractValues: + def test_preserves_distinct_steps_and_epochs(self) -> None: + metric = _FakeMetric( + _FakeMetricData( + steps=[10, 20, 30], + values=[0.1, 0.2, 0.3], + epochs=[1.0, 1.0, 2.0], + ) + ) + + values, steps, epochs = _extract_values(metric) + + assert values.tolist() == pytest.approx([0.1, 0.2, 0.3]) + assert steps.tolist() == [10, 20, 30] + assert epochs is not None + assert epochs.tolist() == pytest.approx([1.0, 1.0, 2.0]) + + def test_empty_metric_returns_empty_arrays(self) -> None: + metric = _FakeMetric(_FakeMetricData(raise_value_error=True)) + + values, steps, epochs = _extract_values(metric) + + assert values.tolist() == [] + assert steps.tolist() == [] + assert epochs is None + + +class TestGroupByRun: + def test_single_run_produces_one_group(self) -> None: + run = _make_run("aaa") + series_list = [_make_series(run=run, name="loss"), _make_series(run=run, name="lr")] + groups = group_by_run(series_list) + assert len(groups) == 1 + assert groups[0][0].hash == "aaa" + assert len(groups[0][1]) == 2 + + def test_multiple_runs_produce_separate_groups(self) -> None: + run_a = _make_run("aaa") + run_b = _make_run("bbb") + series_list = [ + _make_series(run=run_a, name="loss"), + _make_series(run=run_b, name="loss"), + _make_series(run=run_a, name="lr"), + ] + groups = group_by_run(series_list) + assert len(groups) == 2 + hashes = [g[0].hash for g in groups] + assert hashes == ["aaa", "bbb"] + assert len(groups[0][1]) == 2 # loss + lr for run_a + assert len(groups[1][1]) == 1 # only loss for run_b + + def test_empty_list_returns_empty_groups(self) -> None: + assert group_by_run([]) == [] + + def test_insertion_order_is_preserved(self) -> None: + runs = [_make_run(f"run{i}") for i in range(5)] + series_list = [_make_series(run=r, name="loss") for r in runs] + groups = group_by_run(series_list) + assert [g[0].hash for g in groups] == [f"run{i}" for i in range(5)] + + +class TestSubsample: + def test_head_keeps_first_n_points(self) -> None: + s = _make_series(values=[1.0, 2.0, 3.0, 4.0, 5.0]) + result = subsample(s, head=3, tail=None, every=None) + assert result.count == 3 + assert result.values.tolist() == pytest.approx([1.0, 2.0, 3.0]) + + def test_tail_keeps_last_n_points(self) -> None: + s = _make_series(values=[1.0, 2.0, 3.0, 4.0, 5.0]) + result = subsample(s, head=None, tail=2, every=None) + assert result.count == 2 + assert result.values.tolist() == pytest.approx([4.0, 5.0]) + + def test_every_k_keeps_every_kth_point(self) -> None: + s = _make_series(values=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) + result = subsample(s, head=None, tail=None, every=2) + assert result.count == 3 + assert result.values.tolist() == pytest.approx([0.0, 2.0, 4.0]) + + def test_empty_series_is_returned_unchanged(self) -> None: + s = _make_series(values=[], steps=[]) + result = subsample(s, head=5, tail=None, every=None) + assert result.count == 0 + + def test_steps_are_sliced_consistently_with_values(self) -> None: + s = _make_series(values=[10.0, 20.0, 30.0, 40.0], steps=[100, 200, 300, 400]) + result = subsample(s, head=2, tail=None, every=None) + assert result.steps.tolist() == [100, 200] diff --git a/tests/unit/test_owned_commands.py b/tests/unit/test_owned_commands.py index 78329ee..af5a2a0 100644 --- a/tests/unit/test_owned_commands.py +++ b/tests/unit/test_owned_commands.py @@ -12,6 +12,7 @@ def test_render_help_lists_owned_commands_and_passthrough_boundary() -> None: assert "version" in help_text assert "doctor" in help_text assert "query" in help_text + assert "trace" in help_text assert "delegated to native `aim`" in help_text diff --git a/tests/unit/test_query_helpers.py b/tests/unit/test_query_helpers.py index 509c504..13f016f 100644 --- a/tests/unit/test_query_helpers.py +++ b/tests/unit/test_query_helpers.py @@ -4,7 +4,11 @@ import pytest -from aimx.commands.query import QueryInvocation, normalize_repo_path +from aimx.commands.query import ( + QueryInvocation, + normalize_repo_path, + parse_query_invocation, +) def test_normalize_repo_path_keeps_repo_root(tmp_path: Path) -> None: @@ -37,5 +41,75 @@ def test_query_invocation_rejects_unsupported_target() -> None: target="artifacts", expression="metric.name == 'loss'", repo_path=Path("data"), - output_json=False, ) + + +def test_parse_query_invocation_defaults() -> None: + inv = parse_query_invocation(["metrics", "metric.name == 'loss'"]) + assert inv.target == "metrics" + assert inv.expression == "metric.name == 'loss'" + assert inv.repo_path == Path(".") + assert not inv.output_json + assert not inv.plain + assert not inv.no_color + assert not inv.verbose + + +def test_parse_query_invocation_json_flag() -> None: + inv = parse_query_invocation(["metrics", "metric.name == 'loss'", "--repo", "data", "--json"]) + assert inv.output_json is True + + +def test_parse_query_invocation_oneline_flag() -> None: + inv = parse_query_invocation(["metrics", "metric.name == 'loss'", "--repo", "data", "--oneline"]) + assert inv.plain is True + + +def test_parse_query_invocation_plain_flag_alias() -> None: + inv = parse_query_invocation(["metrics", "metric.name == 'loss'", "--repo", "data", "--plain"]) + assert inv.plain is True + + +def test_parse_query_invocation_no_color_flag() -> None: + inv = parse_query_invocation(["metrics", "metric.name == 'loss'", "--repo", "data", "--no-color"]) + assert inv.no_color is True + + +def test_parse_query_invocation_verbose_flag() -> None: + inv = parse_query_invocation(["metrics", "metric.name == 'loss'", "--repo", "data", "--verbose"]) + assert inv.verbose is True + + +def test_parse_query_invocation_explicit_repo_overrides_default() -> None: + inv = parse_query_invocation(["metrics", "metric.name == 'loss'", "--repo", "data"]) + assert inv.repo_path == Path("data") + + +def test_parse_query_invocation_rejects_unknown_flag() -> None: + with pytest.raises(ValueError, match="Unsupported query option"): + parse_query_invocation(["metrics", "loss", "--repo", "data", "--bogus"]) + + +def test_parse_query_invocation_steps_closed_range() -> None: + inv = parse_query_invocation(["metrics", "metric.name=='loss'", "--repo", "data", "--steps", "100:500"]) + assert inv.step_slice == "100:500" + + +def test_parse_query_invocation_steps_open_end() -> None: + inv = parse_query_invocation(["metrics", "metric.name=='loss'", "--repo", "data", "--steps", "100:"]) + assert inv.step_slice == "100:" + + +def test_parse_query_invocation_steps_open_start() -> None: + inv = parse_query_invocation(["metrics", "metric.name=='loss'", "--repo", "data", "--steps", ":500"]) + assert inv.step_slice == ":500" + + +def test_parse_query_invocation_steps_missing_value_raises() -> None: + with pytest.raises(ValueError, match="Missing value for --steps"): + parse_query_invocation(["metrics", "loss", "--repo", "data", "--steps"]) + + +def test_parse_query_invocation_steps_defaults_to_none() -> None: + inv = parse_query_invocation(["metrics", "metric.name=='loss'", "--repo", "data"]) + assert inv.step_slice is None diff --git a/tests/unit/test_router.py b/tests/unit/test_router.py index b460f79..c18d747 100644 --- a/tests/unit/test_router.py +++ b/tests/unit/test_router.py @@ -31,6 +31,13 @@ def test_query_command_routes_to_owned_handler() -> None: assert route.owned_command == "query" +def test_trace_command_routes_to_owned_handler() -> None: + route = route_args(["trace", "metric.name == 'loss'", "--repo", "data"]) + + assert route.route_kind == "owned" + assert route.owned_command == "trace" + + def test_unknown_command_routes_to_passthrough_without_reordering() -> None: args = ["runs", "ls", "--json"] diff --git a/tests/unit/test_step_filter.py b/tests/unit/test_step_filter.py new file mode 100644 index 0000000..c7ad897 --- /dev/null +++ b/tests/unit/test_step_filter.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from aimx.aim_bridge.metric_stats import ( + MetricSeries, + RunMeta, + filter_by_step_range, + parse_step_slice, +) + + +def _make_series( + steps: list[int] | None = None, + values: list[float] | None = None, + epochs: list[float] | None = None, +) -> MetricSeries: + run = RunMeta(hash="a" * 32, experiment="exp", name=None, creation_time=None) + steps = steps if steps is not None else list(range(1, 6)) + values = values if values is not None else [float(i) for i in steps] + epochs_arr = np.array(epochs, dtype=float) if epochs is not None else None + return MetricSeries( + run=run, + name="loss", + context={}, + values=np.array(values, dtype=float), + steps=np.array(steps, dtype=int), + epochs=epochs_arr, + ) + + +# --------------------------------------------------------------------------- +# parse_step_slice +# --------------------------------------------------------------------------- + + +class TestParseStepSlice: + def test_full_range(self) -> None: + assert parse_step_slice("100:500") == (100, 500) + + def test_open_end(self) -> None: + assert parse_step_slice("100:") == (100, None) + + def test_open_start(self) -> None: + assert parse_step_slice(":500") == (None, 500) + + def test_no_colon_raises(self) -> None: + with pytest.raises(ValueError, match="slice syntax"): + parse_step_slice("100") + + def test_fully_open_raises(self) -> None: + with pytest.raises(ValueError, match="open slice"): + parse_step_slice(":") + + def test_non_integer_left_raises(self) -> None: + with pytest.raises(ValueError, match="left bound"): + parse_step_slice("abc:500") + + def test_non_integer_right_raises(self) -> None: + with pytest.raises(ValueError, match="right bound"): + parse_step_slice("100:xyz") + + def test_zero_start_is_valid(self) -> None: + assert parse_step_slice("0:10") == (0, 10) + + def test_whitespace_is_tolerated(self) -> None: + assert parse_step_slice(" 10 : 50 ") == (10, 50) + + +# --------------------------------------------------------------------------- +# filter_by_step_range +# --------------------------------------------------------------------------- + + +class TestFilterByStepRange: + def test_closed_range_keeps_correct_points(self) -> None: + s = _make_series(steps=[1, 2, 3, 4, 5]) + result = filter_by_step_range(s, 2, 4) + assert result.steps.tolist() == [2, 3, 4] + assert result.values.tolist() == pytest.approx([2.0, 3.0, 4.0]) + + def test_inclusive_lower_bound(self) -> None: + s = _make_series(steps=[1, 2, 3]) + result = filter_by_step_range(s, 1, None) + assert 1 in result.steps.tolist() + + def test_inclusive_upper_bound(self) -> None: + s = _make_series(steps=[1, 2, 3]) + result = filter_by_step_range(s, None, 3) + assert 3 in result.steps.tolist() + + def test_open_start_keeps_from_beginning(self) -> None: + s = _make_series(steps=[1, 2, 3, 4, 5]) + result = filter_by_step_range(s, None, 3) + assert result.steps.tolist() == [1, 2, 3] + + def test_open_end_keeps_to_end(self) -> None: + s = _make_series(steps=[1, 2, 3, 4, 5]) + result = filter_by_step_range(s, 3, None) + assert result.steps.tolist() == [3, 4, 5] + + def test_range_outside_data_returns_empty(self) -> None: + s = _make_series(steps=[1, 2, 3]) + result = filter_by_step_range(s, 100, 200) + assert result.count == 0 + + def test_epochs_sliced_consistently(self) -> None: + s = _make_series(steps=[1, 2, 3, 4, 5], epochs=[10.0, 20.0, 30.0, 40.0, 50.0]) + result = filter_by_step_range(s, 2, 4) + assert result.epochs is not None + assert result.epochs.tolist() == pytest.approx([20.0, 30.0, 40.0]) + + def test_series_without_epochs_stays_none(self) -> None: + s = _make_series(steps=[1, 2, 3]) + result = filter_by_step_range(s, 1, 2) + assert result.epochs is None + + def test_no_bounds_returns_all_points(self) -> None: + s = _make_series(steps=[1, 2, 3, 4]) + result = filter_by_step_range(s, None, None) + assert result.count == 4 + + def test_empty_series_returns_empty(self) -> None: + s = _make_series(steps=[], values=[]) + result = filter_by_step_range(s, 1, 10) + assert result.count == 0 diff --git a/tests/unit/test_trace_helpers.py b/tests/unit/test_trace_helpers.py new file mode 100644 index 0000000..19aad6a --- /dev/null +++ b/tests/unit/test_trace_helpers.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from aimx.commands.trace import TraceInvocation, parse_trace_invocation + + +def test_parse_trace_defaults() -> None: + inv = parse_trace_invocation(["metric.name=='loss'"]) + assert inv.expression == "metric.name=='loss'" + assert inv.repo_path == Path(".") + assert inv.mode == "plot" + assert inv.head is None + assert inv.tail is None + assert inv.every is None + assert inv.width is None + assert inv.height is None + assert not inv.no_color + + +def test_parse_trace_table_mode() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--table"]) + assert inv.mode == "table" + + +def test_parse_trace_csv_mode() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--csv"]) + assert inv.mode == "csv" + + +def test_parse_trace_json_mode() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--json"]) + assert inv.mode == "json" + + +def test_parse_trace_head_flag() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--head", "50"]) + assert inv.head == 50 + + +def test_parse_trace_tail_flag() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--tail", "10"]) + assert inv.tail == 10 + + +def test_parse_trace_every_flag() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--every", "5"]) + assert inv.every == 5 + + +def test_parse_trace_width_and_height() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--width", "100", "--height", "30"]) + assert inv.width == 100 + assert inv.height == 30 + + +def test_parse_trace_no_color_flag() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--no-color"]) + assert inv.no_color is True + + +def test_parse_trace_explicit_repo_overrides_default() -> None: + inv = parse_trace_invocation(["metric.name=='loss'", "--repo", "data"]) + assert inv.repo_path == Path("data") + + +def test_parse_trace_rejects_unknown_flag() -> None: + with pytest.raises(ValueError, match="Unsupported trace option"): + parse_trace_invocation(["expr", "--repo", "data", "--bogus"]) + + +def test_parse_trace_rejects_missing_expression() -> None: + with pytest.raises(ValueError, match="Usage"): + parse_trace_invocation([]) + + +def test_parse_trace_rejects_non_integer_head() -> None: + with pytest.raises(ValueError, match="--head requires an integer"): + parse_trace_invocation(["expr", "--repo", "data", "--head", "abc"]) + + +def test_parse_trace_rejects_every_less_than_one() -> None: + with pytest.raises(ValueError, match="--every must be >= 1"): + parse_trace_invocation(["expr", "--repo", "data", "--every", "0"]) + + +def test_parse_trace_steps_closed_range() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--steps", "50:100"]) + assert inv.step_slice == "50:100" + + +def test_parse_trace_steps_open_end() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--steps", "50:"]) + assert inv.step_slice == "50:" + + +def test_parse_trace_steps_open_start() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data", "--steps", ":100"]) + assert inv.step_slice == ":100" + + +def test_parse_trace_steps_missing_value_raises() -> None: + with pytest.raises(ValueError, match="Missing value for --steps"): + parse_trace_invocation(["expr", "--repo", "data", "--steps"]) + + +def test_parse_trace_steps_defaults_to_none() -> None: + inv = parse_trace_invocation(["expr", "--repo", "data"]) + assert inv.step_slice is None diff --git a/uv.lock b/uv.lock index 0dfc4c2..ca1a7e0 100644 --- a/uv.lock +++ b/uv.lock @@ -98,6 +98,11 @@ wheels = [ name = "aimx" version = "0.2.0" source = { editable = "." } +dependencies = [ + { name = "numpy" }, + { name = "plotext" }, + { name = "rich" }, +] [package.dev-dependencies] dev = [ @@ -106,6 +111,11 @@ dev = [ ] [package.metadata] +requires-dist = [ + { name = "numpy", specifier = ">=1.24" }, + { name = "plotext", specifier = ">=5.3" }, + { name = "rich", specifier = ">=13.7" }, +] [package.metadata.requires-dev] dev = [ @@ -530,6 +540,18 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509 }, ] +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321 }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -571,6 +593,15 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f", size = 13906 }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, +] + [[package]] name = "numpy" version = "2.2.6" @@ -670,6 +701,15 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/bc/60/5382c03e1970de634027cee8e1b7d39776b778b81812aaf45b694dfe9e28/pillow-12.2.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:bfa9c230d2fe991bed5318a5f119bd6780cda2915cca595393649fc118ab895e", size = 7080946 }, ] +[[package]] +name = "plotext" +version = "5.3.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/d7/f75f397af966fe252d0d34ffd3cae765317fce2134f925f95e7d6725d1ce/plotext-5.3.2.tar.gz", hash = "sha256:52d1e932e67c177bf357a3f0fe6ce14d1a96f7f7d5679d7b455b929df517068e", size = 61967 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f6/1e/12fe7c40cd2099a1f454518754ed229b01beaf3bbb343127f0cc13ce6c22/plotext-5.3.2-py3-none-any.whl", hash = "sha256:394362349c1ddbf319548cfac17ca65e6d5dfc03200c40dfdc0503b3e95a2283", size = 64047 }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -867,6 +907,19 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1a/c0/3848f4006f7e164ee20833ca984067e4b3fc99fe7f1dfa88b4927e681299/restrictedpython-8.1-py3-none-any.whl", hash = "sha256:4769449c6cdb10f2071649ba386902befff0eff2a8fd6217989fa7b16aeae926", size = 27651 }, ] +[[package]] +name = "rich" +version = "15.0.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c0/8f/0722ca900cc807c13a6a0c696dacf35430f72e0ec571c4275d2371fca3e9/rich-15.0.0.tar.gz", hash = "sha256:edd07a4824c6b40189fb7ac9bc4c52536e9780fbbfbddf6f1e2502c31b068c36", size = 230680 } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654 }, +] + [[package]] name = "s3transfer" version = "0.16.0"