Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,21 @@ aimx trace "metric.name == 'loss'" --repo data --every 10
Output modes: default plot, `--table`, `--csv`, `--json`.
Display controls: `--width W`, `--height H`, `--no-color`.

### Trace distributions

`aimx trace distribution` fetches tracked Aim distribution sequences and renders
their tensor payload (histogram weights) per step in a terminal table, CSV, or
JSON.

```bash
# Show distribution tensors in a readable table
aimx trace distribution "distributions.name == 'weights'" --repo data

# Export distribution histograms for scripting
aimx trace distribution "distributions.name == 'weights'" --repo data --csv
aimx trace distribution "distributions.name == 'weights'" --repo data --json
```

### Common query options

- Output: `--json`, `--oneline` / `--plain`, or the default rich terminal view.
Expand Down
118 changes: 118 additions & 0 deletions src/aimx/aim_bridge/metric_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ def max(self) -> tuple[float, int]:
return (float(self.values[idx]), int(self.steps[idx]))


@dataclass(frozen=True)
class DistributionPoint:
    """One tracked distribution sample: the histogram recorded at a single step."""

    # NOTE(review): frozen=True also generates __eq__/__hash__ over every field;
    # with ndarray fields, comparing or hashing two points will likely raise.
    # Consider eq=False if instances are ever compared or used as keys - confirm.

    # Step index the sample was tracked at.
    step: int
    # Epoch recorded alongside the step, or None when no epoch is available.
    epoch: float | None
    # Histogram weights as a float array (first item of to_np_histogram()).
    weights: np.ndarray
    # Histogram bin edges as a float array (second item of to_np_histogram()).
    bin_edges: np.ndarray


@dataclass
class DistributionSeries:
    """All collected points of one distribution sequence for a single run."""

    # Metadata of the run the sequence belongs to (built by _extract_run_meta).
    run: RunMeta
    # Name of the distribution sequence as reported by the query result.
    name: str
    # Sequence context converted to a plain dict.
    context: dict[str, Any]
    # Points in the order they were read from the sequence.
    points: list[DistributionPoint]

    @property
    def count(self) -> int:
        """Return the number of points currently held by the series."""
        return len(self.points)


def _extract_run_meta(run: Any) -> RunMeta:
creation_time = getattr(run, "creation_time", None)
if creation_time is None:
Expand Down Expand Up @@ -233,6 +253,62 @@ def _accessor() -> Any:
return rows


def _alias_distribution_variable(expression: str) -> str:
    """Rewrite the singular ``distribution`` query variable to ``distributions``.

    Aim binds the distribution sequence under the plural name, so an expression
    such as ``distribution.name == 'weights'`` would reference an undefined
    variable and match nothing.  Only a bare ``distribution`` identifier that is
    followed by attribute access is rewritten; quoted string literals and names
    that merely contain the word (``distributions``, ``run.distribution``) are
    left untouched.
    """
    import re  # local import: the file-level import block is not visible here

    pattern = re.compile(
        r"""
        (?P<literal>'(?:\\.|[^'\\])*'|"(?:\\.|[^"\\])*")  # quoted literal: keep verbatim
        |
        (?<![\w.])distribution(?=\s*\.)                    # bare variable before ".attr"
        """,
        re.VERBOSE,
    )
    return pattern.sub(
        lambda match: match.group("literal") or "distributions",
        expression,
    )


def collect_distribution_series(expression: str, repo_path: Path) -> list[DistributionSeries]:
    """Run an Aim distribution query and return flat ``DistributionSeries`` records.

    Args:
        expression: Aim query expression.  Short run hashes are expanded via
            ``resolve_hash_prefixes`` and the singular ``distribution`` variable
            is accepted as an alias for Aim's ``distributions``.
        repo_path: Location of the Aim repository to query.

    Returns:
        One ``DistributionSeries`` per matching sequence, in query order.  A
        sequence whose data cannot be unpacked yields a series with no points.

    Raises:
        RuntimeError: If the ``aim`` package is not importable.
    """
    from aimx.aim_bridge.hash_resolver import resolve_hash_prefixes

    expression = resolve_hash_prefixes(expression, repo_path)
    expression = _alias_distribution_variable(expression)

    try:
        from aim import Repo
        from aim.sdk.types import QueryReportMode
    except ModuleNotFoundError as error:
        raise RuntimeError(
            "`aimx` requires the Python `aim` package in the current environment."
        ) from error

    repo = Repo(str(repo_path))
    results: list[DistributionSeries] = []

    # Anything Aim writes to stderr is captured and discarded.  The query is
    # consumed while iterating, so the iteration stays inside the redirect too.
    stderr_buf = io.StringIO()
    with contextlib.redirect_stderr(stderr_buf):
        query_result = repo.query_distributions(
            expression, report_mode=QueryReportMode.DISABLED
        )
        for run_collection in query_result.iter_runs():
            for distribution in run_collection:
                run_meta = _extract_run_meta(distribution.run)
                try:
                    steps, (values, epochs, _timestamps) = distribution.data.items_list()
                except ValueError:
                    # Data did not unpack into the expected shape: keep the
                    # series, but with no points.
                    steps, values, epochs = [], [], []

                points: list[DistributionPoint] = []
                for idx, value in enumerate(values):
                    # The epoch column may be shorter than the values or hold
                    # None entries; both cases map to a missing epoch.
                    raw_epoch = epochs[idx] if idx < len(epochs) else None
                    weights, bin_edges = value.to_np_histogram()
                    points.append(
                        DistributionPoint(
                            step=int(steps[idx]),
                            epoch=None if raw_epoch is None else float(raw_epoch),
                            weights=np.array(weights, dtype=float),
                            bin_edges=np.array(bin_edges, dtype=float),
                        )
                    )

                results.append(
                    DistributionSeries(
                        run=run_meta,
                        name=distribution.name,
                        context=distribution.context.to_dict(),
                        points=points,
                    )
                )

    return results


def subsample(series: MetricSeries, *, head: int | None, tail: int | None, every: int | None) -> MetricSeries:
"""Return a new MetricSeries with points filtered by head/tail/every."""
n = len(series.values)
Expand All @@ -258,6 +334,48 @@ def subsample(series: MetricSeries, *, head: int | None, tail: int | None, every
)


def filter_distribution_by_step_range(
    series: DistributionSeries,
    start: int | None,
    end: int | None,
) -> DistributionSeries:
    """Build a copy of ``series`` keeping only points whose step is within bounds.

    Both bounds are inclusive, and a bound of ``None`` leaves that side open.
    When neither bound is given, the existing ``points`` list is reused as-is.
    """

    def _within(step: int) -> bool:
        # A missing bound never excludes a point.
        return (start is None or step >= start) and (end is None or step <= end)

    unbounded = start is None and end is None
    kept = series.points if unbounded else [p for p in series.points if _within(p.step)]
    return DistributionSeries(
        run=series.run,
        name=series.name,
        context=series.context,
        points=kept,
    )


def subsample_distribution(
    series: DistributionSeries,
    *,
    head: int | None,
    tail: int | None,
    every: int | None,
) -> DistributionSeries:
    """Return a new ``DistributionSeries`` filtered by head/tail/every.

    The filters are applied in order: keep the first ``head`` points, then the
    last ``tail`` of what remains, then every ``every``-th point.  A filter set
    to ``None`` is skipped, and ``every`` values of 1 or less leave the points
    unchanged.  The run, name and context are carried over untouched.
    """
    points = series.points
    if head is not None:
        points = points[:head]
    if tail is not None:
        # ``points[-0:]`` is the whole list, so a zero tail needs its own branch
        # to yield no points (matching how ``head=0`` behaves).
        points = points[-tail:] if tail != 0 else []
    if every is not None and every > 1:
        points = points[::every]
    return DistributionSeries(
        run=series.run,
        name=series.name,
        context=series.context,
        points=points,
    )


def parse_epoch_slice(s: str) -> tuple[float | None, float | None]:
"""Parse a ``start:end`` slice string into inclusive float bounds for epoch filtering.

Expand Down
2 changes: 2 additions & 0 deletions src/aimx/commands/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ def render_help() -> str:
" Example: aimx query params \"run.experiment=='cloud-segmentation'\" --repo data --param hparam.lr",
" trace Plot a metric's time-series from a local Aim repository",
" Usage: aimx trace <expression> [--repo <path>]",
" Usage: aimx trace distribution <expression> [--repo <path>]",
" Options: --table --csv --json",
" --steps start:end (e.g. --steps 100:500, :50, 100:)",
" --head N --tail N --every K",
" --width W --height H --no-color",
" Repo defaults to the current directory.",
" Short run hashes in the expression are transparently expanded.",
" Example: aimx trace \"metric.name=='loss'\" --repo data --steps 100:500",
" Example: aimx trace distribution \"distribution.name=='weights'\" --repo data --json",

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Use distributions in trace-distribution expressions

This example advertises distribution.name, but Aim’s Repo.query_distributions(...) binds the sequence variable as distributions (plural). When users copy this command, the query variable is undefined and the new trace path reports no matches, making the feature look broken even when data exists. Update the expression to distributions.name == 'weights' (and align related docs/tests) or add alias rewriting before executing the query.

Useful? React with 👍 / 👎.

"",
"All other commands are delegated to native `aim`.",
]
Expand Down
140 changes: 99 additions & 41 deletions src/aimx/commands/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

@dataclass(frozen=True)
class TraceInvocation:
target: Literal["metrics", "distribution"]
expression: str
repo_path: Path
mode: Literal["plot", "table", "csv", "json"] = "plot"
Expand All @@ -27,13 +28,26 @@ class TraceInvocation:
def parse_trace_invocation(args: list[str]) -> TraceInvocation:
if len(args) < 1:
raise ValueError(
"Usage: aimx trace <expression> [--repo <path>] [--table|--csv|--json]"
"Usage: aimx trace [distribution] <expression> [--repo <path>] [--table|--csv|--json]"
" [--steps start:end] [--head N] [--tail N] [--every K]"
" [--width W] [--height H] [--no-color]"
)

expression = args[0]
rest = args[1:]
target: Literal["metrics", "distribution"] = "metrics"
expression: str | None = None
rest = args
if args[0] == "distribution":
target = "distribution"
if len(args) < 2:
raise ValueError(
"Usage: aimx trace distribution <expression> [--repo <path>] [--table|--csv|--json]"
" [--steps start:end] [--head N] [--tail N] [--every K] [--no-color]"
)
expression = args[1]
rest = args[2:]
else:
expression = args[0]
rest = args[1:]

mode: Literal["plot", "table", "csv", "json"] = "plot"
repo_value = "."
Expand Down Expand Up @@ -118,6 +132,7 @@ def parse_trace_invocation(args: list[str]) -> TraceInvocation:
raise ValueError(f"Unsupported trace option: {token}")

return TraceInvocation(
target=target,
expression=expression,
repo_path=Path(repo_value),
mode=mode,
Expand All @@ -142,57 +157,100 @@ def run_trace_command(args: list[str]) -> QueryCommandResult:
effective_no_color = invocation.no_color or not is_tty

try:
from aimx.aim_bridge.metric_stats import (
collect_metric_series,
filter_by_step_range,
parse_step_slice,
subsample,
)
from aimx.aim_bridge.metric_stats import parse_step_slice
from aimx.rendering.trace_views import (
render_distribution_csv,
render_distribution_json,
render_distribution_table,
render_csv,
render_trace_table,
render_plot,
render_trace_json,
render_trace_table,
)

series_list = collect_metric_series(invocation.expression, normalized_repo_path)
if invocation.target == "distribution":
from aimx.aim_bridge.metric_stats import (
collect_distribution_series,
filter_distribution_by_step_range,
subsample_distribution,
)

if not series_list:
return QueryCommandResult(exit_status=0, output="No matching metrics found.")
series_list = collect_distribution_series(invocation.expression, normalized_repo_path)
if not series_list:
return QueryCommandResult(exit_status=0, output="No matching distributions found.")

# Step range filter is a hard constraint applied before density subsampling
if invocation.step_slice is not None:
step_start, step_end = parse_step_slice(invocation.step_slice)
series_list = [filter_by_step_range(s, step_start, step_end) for s in series_list]
# Drop empty series so they don't clutter plots
series_list = [s for s in series_list if s.count > 0]
if invocation.step_slice is not None:
step_start, step_end = parse_step_slice(invocation.step_slice)
series_list = [
filter_distribution_by_step_range(s, step_start, step_end) for s in series_list
]
series_list = [s for s in series_list if s.count > 0]

if not series_list:
return QueryCommandResult(exit_status=0, output="No data in the requested step range.")
if not series_list:
return QueryCommandResult(exit_status=0, output="No data in the requested step range.")

# Density subsampling for visualisation
needs_sample = any(
x is not None for x in (invocation.head, invocation.tail, invocation.every)
)
if needs_sample:
series_list = [
subsample(s, head=invocation.head, tail=invocation.tail, every=invocation.every)
for s in series_list
]

if invocation.mode == "json":
output = render_trace_json(series_list)
elif invocation.mode == "csv":
output = render_csv(series_list)
elif invocation.mode == "table":
output = render_trace_table(series_list, no_color=effective_no_color)
needs_sample = any(
x is not None for x in (invocation.head, invocation.tail, invocation.every)
)
if needs_sample:
series_list = [
subsample_distribution(
s, head=invocation.head, tail=invocation.tail, every=invocation.every
)
for s in series_list
]

if invocation.mode == "json":
output = render_distribution_json(series_list)
elif invocation.mode == "csv":
output = render_distribution_csv(series_list)
else:
output = render_distribution_table(series_list, no_color=effective_no_color)
else:
output = render_plot(
series_list,
width=invocation.width,
height=invocation.height,
from aimx.aim_bridge.metric_stats import (
collect_metric_series,
filter_by_step_range,
subsample,
)

series_list = collect_metric_series(invocation.expression, normalized_repo_path)

if not series_list:
return QueryCommandResult(exit_status=0, output="No matching metrics found.")

# Step range filter is a hard constraint applied before density subsampling
if invocation.step_slice is not None:
step_start, step_end = parse_step_slice(invocation.step_slice)
series_list = [filter_by_step_range(s, step_start, step_end) for s in series_list]
# Drop empty series so they don't clutter plots
series_list = [s for s in series_list if s.count > 0]

if not series_list:
return QueryCommandResult(exit_status=0, output="No data in the requested step range.")

# Density subsampling for visualisation
needs_sample = any(
x is not None for x in (invocation.head, invocation.tail, invocation.every)
)
if needs_sample:
series_list = [
subsample(s, head=invocation.head, tail=invocation.tail, every=invocation.every)
for s in series_list
]

if invocation.mode == "json":
output = render_trace_json(series_list)
elif invocation.mode == "csv":
output = render_csv(series_list)
elif invocation.mode == "table":
output = render_trace_table(series_list, no_color=effective_no_color)
else:
output = render_plot(
series_list,
width=invocation.width,
height=invocation.height,
)
Comment on lines +171 to +252

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is significant code duplication between the distribution and metrics execution paths in run_trace_command. Both paths follow the same sequence: collect series, filter by step range, subsample, and render based on the mode.

Consider refactoring this into a unified pipeline. You could define a set of functions (collect, filter, subsample, render) based on the invocation.target and then execute the pipeline once. This would improve maintainability and ensure that improvements to the tracing logic (like new filters or sampling methods) are automatically applied to both metrics and distributions.


return QueryCommandResult(exit_status=0, output=output)

except RuntimeError as error:
Expand Down
Loading
Loading