diff --git a/README.md b/README.md index 2fbb0d3..ff851e9 100644 --- a/README.md +++ b/README.md @@ -232,6 +232,21 @@ aimx trace "metric.name == 'loss'" --repo data --every 10 Output modes: default plot, `--table`, `--csv`, `--json`. Display controls: `--width W`, `--height H`, `--no-color`. +### Trace distributions + +`aimx trace distribution` fetches tracked Aim distribution sequences and renders +their tensor payload (histogram weights) per step in a terminal table, CSV, or +JSON. + +```bash +# Show distribution tensors in a readable table +aimx trace distribution "distribution.name == 'weights'" --repo data + +# Export distribution histograms for scripting +aimx trace distribution "distribution.name == 'weights'" --repo data --csv +aimx trace distribution "distribution.name == 'weights'" --repo data --json +``` + ### Common query options - Output: `--json`, `--oneline` / `--plain`, or the default rich terminal view. diff --git a/src/aimx/aim_bridge/metric_stats.py b/src/aimx/aim_bridge/metric_stats.py index 6dc6864..df165e3 100644 --- a/src/aimx/aim_bridge/metric_stats.py +++ b/src/aimx/aim_bridge/metric_stats.py @@ -53,6 +53,26 @@ def max(self) -> tuple[float, int]: return (float(self.values[idx]), int(self.steps[idx])) +@dataclass(frozen=True) +class DistributionPoint: + step: int + epoch: float | None + weights: np.ndarray + bin_edges: np.ndarray + + +@dataclass +class DistributionSeries: + run: RunMeta + name: str + context: dict[str, Any] + points: list[DistributionPoint] + + @property + def count(self) -> int: + return len(self.points) + + def _extract_run_meta(run: Any) -> RunMeta: creation_time = getattr(run, "creation_time", None) if creation_time is None: @@ -233,6 +253,62 @@ def _accessor() -> Any: return rows +def collect_distribution_series(expression: str, repo_path: Path) -> list[DistributionSeries]: + """Run an Aim distribution query and return flat ``DistributionSeries`` records.""" + from aimx.aim_bridge.hash_resolver import resolve_hash_prefixes + + expression = resolve_hash_prefixes(expression, repo_path) + + try: + from aim import Repo + from aim.sdk.types import QueryReportMode + except ModuleNotFoundError as error: + raise RuntimeError( + "`aimx` requires the Python `aim` package in the current environment." + ) from error + + repo = Repo(str(repo_path)) + results: list[DistributionSeries] = [] + + stderr_buf = io.StringIO() + with contextlib.redirect_stderr(stderr_buf): + query_result = repo.query_distributions( + expression, report_mode=QueryReportMode.DISABLED + ) + for run_collection in query_result.iter_runs(): + for distribution in run_collection: + run_meta = _extract_run_meta(distribution.run) + try: + steps, (values, epochs, _timestamps) = distribution.data.items_list() + except ValueError: + steps, values, epochs = [], [], [] + + points: list[DistributionPoint] = [] + for idx, value in enumerate(values): + step_value = int(steps[idx]) + epoch_value = float(epochs[idx]) if idx < len(epochs) else None + weights, bin_edges = value.to_np_histogram() + points.append( + DistributionPoint( + step=step_value, + epoch=epoch_value, + weights=np.array(weights, dtype=float), + bin_edges=np.array(bin_edges, dtype=float), + ) + ) + + results.append( + DistributionSeries( + run=run_meta, + name=distribution.name, + context=distribution.context.to_dict(), + points=points, + ) + ) + + return results + + def subsample(series: MetricSeries, *, head: int | None, tail: int | None, every: int | None) -> MetricSeries: """Return a new MetricSeries with points filtered by head/tail/every.""" n = len(series.values) @@ -258,6 +334,48 @@ def subsample(series: MetricSeries, *, head: int | None, tail: int | None, every ) +def filter_distribution_by_step_range( + series: DistributionSeries, + start: int | None, + end: int | None, +) -> DistributionSeries: + """Return a new ``DistributionSeries`` filtered by inclusive step bounds.""" + points = series.points + if start is not None: + points = [point for point in points if point.step >= start] + if end is not None: + points = [point for point in points if point.step <= end] + return DistributionSeries( + run=series.run, + name=series.name, + context=series.context, + points=points, + ) + + +def subsample_distribution( + series: DistributionSeries, + *, + head: int | None, + tail: int | None, + every: int | None, +) -> DistributionSeries: + """Return a new ``DistributionSeries`` filtered by head/tail/every.""" + points = series.points + if head is not None: + points = points[:head] + if tail is not None: + points = points[-tail:] + if every is not None and every > 1: + points = points[::every] + return DistributionSeries( + run=series.run, + name=series.name, + context=series.context, + points=points, + ) + + def parse_epoch_slice(s: str) -> tuple[float | None, float | None]: """Parse a ``start:end`` slice string into inclusive float bounds for epoch filtering. diff --git a/src/aimx/commands/help.py b/src/aimx/commands/help.py index 309368f..971e6b2 100644 --- a/src/aimx/commands/help.py +++ b/src/aimx/commands/help.py @@ -26,6 +26,7 @@ def render_help() -> str: " Example: aimx query params \"run.experiment=='cloud-segmentation'\" --repo data --param hparam.lr", " trace Plot a metric's time-series from a local Aim repository", " Usage: aimx trace [--repo ]", + " Usage: aimx trace distribution [--repo ]", " Options: --table --csv --json", " --steps start:end (e.g. --steps 100:500, :50, 100:)", " --head N --tail N --every K", @@ -33,6 +34,7 @@ def render_help() -> str: " Repo defaults to the current directory.", " Short run hashes in the expression are transparently expanded.", " Example: aimx trace \"metric.name=='loss'\" --repo data --steps 100:500", + " Example: aimx trace distribution \"distribution.name=='weights'\" --repo data --json", "", "All other commands are delegated to native `aim`.", ] diff --git a/src/aimx/commands/trace.py b/src/aimx/commands/trace.py index 6501484..129480d 100644 --- a/src/aimx/commands/trace.py +++ b/src/aimx/commands/trace.py @@ -12,6 +12,7 @@ @dataclass(frozen=True) class TraceInvocation: + target: Literal["metrics", "distribution"] expression: str repo_path: Path mode: Literal["plot", "table", "csv", "json"] = "plot" @@ -27,13 +28,26 @@ class TraceInvocation: def parse_trace_invocation(args: list[str]) -> TraceInvocation: if len(args) < 1: raise ValueError( - "Usage: aimx trace [--repo ] [--table|--csv|--json]" + "Usage: aimx trace [distribution] [--repo ] [--table|--csv|--json]" " [--steps start:end] [--head N] [--tail N] [--every K]" " [--width W] [--height H] [--no-color]" ) - expression = args[0] - rest = args[1:] + target: Literal["metrics", "distribution"] = "metrics" + expression: str | None = None + rest = args + if args[0] == "distribution": + target = "distribution" + if len(args) < 2: + raise ValueError( + "Usage: aimx trace distribution [--repo ] [--table|--csv|--json]" + " [--steps start:end] [--head N] [--tail N] [--every K] [--no-color]" + ) + expression = args[1] + rest = args[2:] + else: + expression = args[0] + rest = args[1:] mode: Literal["plot", "table", "csv", "json"] = "plot" repo_value = "." @@ -118,6 +132,7 @@ def parse_trace_invocation(args: list[str]) -> TraceInvocation: raise ValueError(f"Unsupported trace option: {token}") return TraceInvocation( + target=target, expression=expression, repo_path=Path(repo_value), mode=mode, @@ -142,57 +157,100 @@ def run_trace_command(args: list[str]) -> QueryCommandResult: effective_no_color = invocation.no_color or not is_tty try: - from aimx.aim_bridge.metric_stats import ( - collect_metric_series, - filter_by_step_range, - parse_step_slice, - subsample, - ) + from aimx.aim_bridge.metric_stats import parse_step_slice from aimx.rendering.trace_views import ( + render_distribution_csv, + render_distribution_json, + render_distribution_table, render_csv, + render_trace_table, render_plot, render_trace_json, - render_trace_table, ) - series_list = collect_metric_series(invocation.expression, normalized_repo_path) + if invocation.target == "distribution": + from aimx.aim_bridge.metric_stats import ( + collect_distribution_series, + filter_distribution_by_step_range, + subsample_distribution, + ) - if not series_list: - return QueryCommandResult(exit_status=0, output="No matching metrics found.") + series_list = collect_distribution_series(invocation.expression, normalized_repo_path) + if not series_list: + return QueryCommandResult(exit_status=0, output="No matching distributions found.") - # Step range filter is a hard constraint applied before density subsampling - if invocation.step_slice is not None: - step_start, step_end = parse_step_slice(invocation.step_slice) - series_list = [filter_by_step_range(s, step_start, step_end) for s in series_list] - # Drop empty series so they don't clutter plots - series_list = [s for s in series_list if s.count > 0] + if invocation.step_slice is not None: + step_start, step_end = parse_step_slice(invocation.step_slice) + series_list = [ + filter_distribution_by_step_range(s, step_start, step_end) for s in series_list + ] + series_list = [s for s in series_list if s.count > 0] - if not series_list: - return QueryCommandResult(exit_status=0, output="No data in the requested step range.") + if not series_list: + return QueryCommandResult(exit_status=0, output="No data in the requested step range.") - # Density subsampling for visualisation - needs_sample = any( - x is not None for x in (invocation.head, invocation.tail, invocation.every) - ) - if needs_sample: - series_list = [ - subsample(s, head=invocation.head, tail=invocation.tail, every=invocation.every) - for s in series_list - ] - - if invocation.mode == "json": - output = render_trace_json(series_list) - elif invocation.mode == "csv": - output = render_csv(series_list) - elif invocation.mode == "table": - output = render_trace_table(series_list, no_color=effective_no_color) + needs_sample = any( + x is not None for x in (invocation.head, invocation.tail, invocation.every) + ) + if needs_sample: + series_list = [ + subsample_distribution( + s, head=invocation.head, tail=invocation.tail, every=invocation.every + ) + for s in series_list + ] + + if invocation.mode == "json": + output = render_distribution_json(series_list) + elif invocation.mode == "csv": + output = render_distribution_csv(series_list) + else: + output = render_distribution_table(series_list, no_color=effective_no_color) else: - output = render_plot( - series_list, - width=invocation.width, - height=invocation.height, + from aimx.aim_bridge.metric_stats import ( + collect_metric_series, + filter_by_step_range, + subsample, ) + series_list = collect_metric_series(invocation.expression, normalized_repo_path) + + if not series_list: + return QueryCommandResult(exit_status=0, output="No matching metrics found.") + + # Step range filter is a hard constraint applied before density subsampling + if invocation.step_slice is not None: + step_start, step_end = parse_step_slice(invocation.step_slice) + series_list = [filter_by_step_range(s, step_start, step_end) for s in series_list] + # Drop empty series so they don't clutter plots + series_list = [s for s in series_list if s.count > 0] + + if not series_list: + return QueryCommandResult(exit_status=0, output="No data in the requested step range.") + + # Density subsampling for visualisation + needs_sample = any( + x is not None for x in (invocation.head, invocation.tail, invocation.every) + ) + if needs_sample: + series_list = [ + subsample(s, head=invocation.head, tail=invocation.tail, every=invocation.every) + for s in series_list + ] + + if invocation.mode == "json": + output = render_trace_json(series_list) + elif invocation.mode == "csv": + output = render_csv(series_list) + elif invocation.mode == "table": + output = render_trace_table(series_list, no_color=effective_no_color) + else: + output = render_plot( + series_list, + width=invocation.width, + height=invocation.height, + ) + return QueryCommandResult(exit_status=0, output=output) except RuntimeError as error: diff --git a/src/aimx/rendering/trace_views.py b/src/aimx/rendering/trace_views.py index fd9eed8..8c2f508 100644 --- a/src/aimx/rendering/trace_views.py +++ b/src/aimx/rendering/trace_views.py @@ -11,7 +11,7 @@ from rich.console import Console from rich.table import Table -from aimx.aim_bridge.metric_stats import MetricSeries, RunMeta +from aimx.aim_bridge.metric_stats import DistributionSeries, MetricSeries, RunMeta from aimx.rendering import colors @@ -38,6 +38,19 @@ def _series_label(series: MetricSeries) -> str: return " · ".join(parts) +def _distribution_series_label(series: DistributionSeries) -> str: + parts = [_short_hash(series.run.hash)] + if series.run.experiment: + parts.append(series.run.experiment) + elif series.run.name: + parts.append(series.run.name) + parts.append(series.name) + ctx = _fmt_context(series.context) + if ctx: + parts.append(f"[{ctx}]") + return " · ".join(parts) + + def render_plot( series_list: list[MetricSeries], *, @@ -166,3 +179,105 @@ def render_trace_json(series_list: list[MetricSeries]) -> str: } ) return json.dumps(result) + + +def _format_tensor(values: list[float], *, limit: int = 12) -> str: + if len(values) <= limit: + return "[" + ", ".join(f"{v:.6g}" for v in values) + "]" + head = ", ".join(f"{v:.6g}" for v in values[:limit]) + return f"[{head}, …] ({len(values)} bins)" + + +def render_distribution_table( + series_list: list[DistributionSeries], + *, + no_color: bool = False, +) -> str: + """Render distribution series as a step-indexed tensor table.""" + width = 120 if no_color else shutil.get_terminal_size(fallback=(120, 24)).columns + buf = io.StringIO() + console = Console( + file=buf, + no_color=no_color, + force_terminal=not no_color, + width=width, + highlight=False, + ) + + for series in series_list: + label = _distribution_series_label(series) + console.print(f"\n[{colors.HEADER}]{label}[/] [{colors.HEADER}]{series.count} points[/]") + + table = Table( + show_header=True, + header_style=colors.HEADER, + box=None, + pad_edge=True, + show_edge=False, + padding=(0, 1), + ) + table.add_column("STEP", justify="right") + table.add_column("EPOCH", justify="right") + table.add_column("TENSOR", justify="left", style=colors.NUMBER_EMPH) + + for point in series.points: + epoch = f"{point.epoch:.6g}" if point.epoch is not None else "—" + weights = point.weights.tolist() + table.add_row(str(point.step), epoch, _format_tensor(weights)) + + console.print(table) + + return buf.getvalue() + + +def render_distribution_csv(series_list: list[DistributionSeries]) -> str: + """Render distribution rows as CSV.""" + buf = io.StringIO() + writer = csv.writer(buf) + writer.writerow( + ["run_hash", "experiment", "distribution", "context", "step", "epoch", "bin_edges", "weights"] + ) + for series in series_list: + ctx_str = json.dumps(series.context, sort_keys=True) + for point in series.points: + writer.writerow( + [ + series.run.hash, + series.run.experiment or series.run.name or "", + series.name, + ctx_str, + point.step, + point.epoch if point.epoch is not None else "", + json.dumps(point.bin_edges.tolist()), + json.dumps(point.weights.tolist()), + ] + ) + return buf.getvalue() + + +def render_distribution_json(series_list: list[DistributionSeries]) -> str: + """Render distribution rows as JSON.""" + result: list[dict[str, Any]] = [] + for series in series_list: + result.append( + { + "run": { + "hash": series.run.hash, + "experiment": series.run.experiment, + "name": series.run.name, + }, + "distribution": series.name, + "context": series.context, + "count": series.count, + "points": [ + { + "step": point.step, + "epoch": point.epoch, + "bin_edges": point.bin_edges.tolist(), + "weights": point.weights.tolist(), + } + for point in series.points + ], + } + ) + return json.dumps(result) diff --git a/tests/unit/test_trace_distribution_views.py b/tests/unit/test_trace_distribution_views.py new file mode 100644 index 0000000..26d8811 --- /dev/null +++ b/tests/unit/test_trace_distribution_views.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import csv +import io +import json + +import numpy as np + +from aimx.aim_bridge.metric_stats import DistributionPoint, DistributionSeries, RunMeta +from aimx.rendering.trace_views import ( + render_distribution_csv, + render_distribution_json, + render_distribution_table, +) + + +def _sample_distribution_series() -> list[DistributionSeries]: + run = RunMeta(hash="1234567890abcdef", experiment="exp-a", name=None, creation_time=None) + return [ + DistributionSeries( + run=run, + name="weights", + context={"subset": "train"}, + points=[ + DistributionPoint( + step=10, + epoch=1.0, + bin_edges=np.array([0.0, 1.0, 2.0]), + weights=np.array([3.0, 5.0]), + ), + DistributionPoint( + step=20, + epoch=2.0, + bin_edges=np.array([0.0, 1.0, 2.0]), + weights=np.array([2.0, 4.0]), + ), + ], + ) + ] + + +def test_render_distribution_table_includes_tensor_column() -> None: + output = render_distribution_table(_sample_distribution_series(), no_color=True) + + assert "TENSOR" in output + assert "weights" in output + assert "[3, 5]" in output + + +def test_render_distribution_csv_contains_bin_edges_and_weights() -> None: + output = render_distribution_csv(_sample_distribution_series()) + + reader = csv.DictReader(io.StringIO(output)) + rows = list(reader) + assert rows + assert rows[0]["distribution"] == "weights" + assert rows[0]["bin_edges"] == "[0.0, 1.0, 2.0]" + assert rows[0]["weights"] == "[3.0, 5.0]" + + +def test_render_distribution_json_contains_points() -> None: + output = render_distribution_json(_sample_distribution_series()) + payload = json.loads(output) + + assert payload + first = payload[0] + assert first["distribution"] == "weights" + assert first["count"] == 2 + assert first["points"][0]["step"] == 10 + assert first["points"][0]["weights"] == [3.0, 5.0] diff --git a/tests/unit/test_trace_helpers.py b/tests/unit/test_trace_helpers.py index 19aad6a..8afa427 100644 --- a/tests/unit/test_trace_helpers.py +++ b/tests/unit/test_trace_helpers.py @@ -9,6 +9,7 @@ def test_parse_trace_defaults() -> None: inv = parse_trace_invocation(["metric.name=='loss'"]) + assert inv.target == "metrics" assert inv.expression == "metric.name=='loss'" assert inv.repo_path == Path(".") assert inv.mode == "plot" @@ -66,6 +67,18 @@ def test_parse_trace_explicit_repo_overrides_default() -> None: assert inv.repo_path == Path("data") +def test_parse_trace_distribution_target() -> None: + inv = parse_trace_invocation(["distribution", "distribution.name=='weights'", "--repo", "data"]) + assert inv.target == "distribution" + assert inv.expression == "distribution.name=='weights'" + assert inv.repo_path == Path("data") + + +def test_parse_trace_distribution_requires_expression() -> None: + with pytest.raises(ValueError, match="trace distribution"): + parse_trace_invocation(["distribution"]) + + def test_parse_trace_rejects_unknown_flag() -> None: with pytest.raises(ValueError, match="Unsupported trace option"): parse_trace_invocation(["expr", "--repo", "data", "--bogus"])