-
Notifications
You must be signed in to change notification settings - Fork 1
Add aimx trace distribution with tensor (histogram) output modes
#15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,6 +53,26 @@ def max(self) -> tuple[float, int]: | |
| return (float(self.values[idx]), int(self.steps[idx])) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class DistributionPoint: | ||
| step: int | ||
| epoch: float | None | ||
| weights: np.ndarray | ||
| bin_edges: np.ndarray | ||
|
|
||
|
|
||
| @dataclass | ||
| class DistributionSeries: | ||
| run: RunMeta | ||
| name: str | ||
| context: dict[str, Any] | ||
| points: list[DistributionPoint] | ||
|
|
||
| @property | ||
| def count(self) -> int: | ||
| return len(self.points) | ||
|
|
||
|
|
||
| def _extract_run_meta(run: Any) -> RunMeta: | ||
| creation_time = getattr(run, "creation_time", None) | ||
| if creation_time is None: | ||
|
|
@@ -233,6 +253,62 @@ def _accessor() -> Any: | |
| return rows | ||
|
|
||
|
|
||
| def collect_distribution_series(expression: str, repo_path: Path) -> list[DistributionSeries]: | ||
| """Run an Aim distribution query and return flat ``DistributionSeries`` records.""" | ||
| from aimx.aim_bridge.hash_resolver import resolve_hash_prefixes | ||
|
|
||
| expression = resolve_hash_prefixes(expression, repo_path) | ||
|
|
||
| try: | ||
| from aim import Repo | ||
| from aim.sdk.types import QueryReportMode | ||
| except ModuleNotFoundError as error: | ||
| raise RuntimeError( | ||
| "`aimx` requires the Python `aim` package in the current environment." | ||
| ) from error | ||
|
|
||
| repo = Repo(str(repo_path)) | ||
| results: list[DistributionSeries] = [] | ||
|
|
||
| stderr_buf = io.StringIO() | ||
| with contextlib.redirect_stderr(stderr_buf): | ||
| query_result = repo.query_distributions( | ||
| expression, report_mode=QueryReportMode.DISABLED | ||
| ) | ||
| for run_collection in query_result.iter_runs(): | ||
| for distribution in run_collection: | ||
| run_meta = _extract_run_meta(distribution.run) | ||
| try: | ||
| steps, (values, epochs, _timestamps) = distribution.data.items_list() | ||
| except ValueError: | ||
| steps, values, epochs = [], [], [] | ||
|
|
||
| points: list[DistributionPoint] = [] | ||
| for idx, value in enumerate(values): | ||
| step_value = int(steps[idx]) | ||
| epoch_value = float(epochs[idx]) if idx < len(epochs) else None | ||
| weights, bin_edges = value.to_np_histogram() | ||
| points.append( | ||
| DistributionPoint( | ||
| step=step_value, | ||
| epoch=epoch_value, | ||
| weights=np.array(weights, dtype=float), | ||
| bin_edges=np.array(bin_edges, dtype=float), | ||
| ) | ||
| ) | ||
|
|
||
| results.append( | ||
| DistributionSeries( | ||
| run=run_meta, | ||
| name=distribution.name, | ||
| context=distribution.context.to_dict(), | ||
| points=points, | ||
| ) | ||
| ) | ||
|
|
||
| return results | ||
|
|
||
|
|
||
| def subsample(series: MetricSeries, *, head: int | None, tail: int | None, every: int | None) -> MetricSeries: | ||
| """Return a new MetricSeries with points filtered by head/tail/every.""" | ||
| n = len(series.values) | ||
|
|
@@ -258,6 +334,48 @@ def subsample(series: MetricSeries, *, head: int | None, tail: int | None, every | |
| ) | ||
|
|
||
|
|
||
| def filter_distribution_by_step_range( | ||
| series: DistributionSeries, | ||
| start: int | None, | ||
| end: int | None, | ||
| ) -> DistributionSeries: | ||
| """Return a new ``DistributionSeries`` filtered by inclusive step bounds.""" | ||
| points = series.points | ||
| if start is not None: | ||
| points = [point for point in points if point.step >= start] | ||
| if end is not None: | ||
| points = [point for point in points if point.step <= end] | ||
|
Comment on lines
+344
to
+347
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The step range filtering can be optimized to a single pass. This avoids creating an intermediate list when both if start is not None or end is not None:
points = [
point for point in points
if (start is None or point.step >= start) and (end is None or point.step <= end)
] |
||
| return DistributionSeries( | ||
| run=series.run, | ||
| name=series.name, | ||
| context=series.context, | ||
| points=points, | ||
| ) | ||
|
|
||
|
|
||
| def subsample_distribution( | ||
| series: DistributionSeries, | ||
| *, | ||
| head: int | None, | ||
| tail: int | None, | ||
| every: int | None, | ||
| ) -> DistributionSeries: | ||
| """Return a new ``DistributionSeries`` filtered by head/tail/every.""" | ||
| points = series.points | ||
| if head is not None: | ||
| points = points[:head] | ||
| if tail is not None: | ||
| points = points[-tail:] | ||
| if every is not None and every > 1: | ||
| points = points[::every] | ||
| return DistributionSeries( | ||
| run=series.run, | ||
| name=series.name, | ||
| context=series.context, | ||
| points=points, | ||
| ) | ||
|
|
||
|
|
||
| def parse_epoch_slice(s: str) -> tuple[float | None, float | None]: | ||
| """Parse a ``start:end`` slice string into inclusive float bounds for epoch filtering. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,13 +26,15 @@ def render_help() -> str: | |
| " Example: aimx query params \"run.experiment=='cloud-segmentation'\" --repo data --param hparam.lr", | ||
| " trace Plot a metric's time-series from a local Aim repository", | ||
| " Usage: aimx trace <expression> [--repo <path>]", | ||
| " Usage: aimx trace distribution <expression> [--repo <path>]", | ||
| " Options: --table --csv --json", | ||
| " --steps start:end (e.g. --steps 100:500, :50, 100:)", | ||
| " --head N --tail N --every K", | ||
| " --width W --height H --no-color", | ||
| " Repo defaults to the current directory.", | ||
| " Short run hashes in the expression are transparently expanded.", | ||
| " Example: aimx trace \"metric.name=='loss'\" --repo data --steps 100:500", | ||
| " Example: aimx trace distribution \"distribution.name=='weights'\" --repo data --json", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This example advertises Useful? React with 👍 / 👎. |
||
| "", | ||
| "All other commands are delegated to native `aim`.", | ||
| ] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
|
|
||
| @dataclass(frozen=True) | ||
| class TraceInvocation: | ||
| target: Literal["metrics", "distribution"] | ||
| expression: str | ||
| repo_path: Path | ||
| mode: Literal["plot", "table", "csv", "json"] = "plot" | ||
|
|
@@ -27,13 +28,26 @@ class TraceInvocation: | |
| def parse_trace_invocation(args: list[str]) -> TraceInvocation: | ||
| if len(args) < 1: | ||
| raise ValueError( | ||
| "Usage: aimx trace <expression> [--repo <path>] [--table|--csv|--json]" | ||
| "Usage: aimx trace [distribution] <expression> [--repo <path>] [--table|--csv|--json]" | ||
| " [--steps start:end] [--head N] [--tail N] [--every K]" | ||
| " [--width W] [--height H] [--no-color]" | ||
| ) | ||
|
|
||
| expression = args[0] | ||
| rest = args[1:] | ||
| target: Literal["metrics", "distribution"] = "metrics" | ||
| expression: str | None = None | ||
| rest = args | ||
| if args[0] == "distribution": | ||
| target = "distribution" | ||
| if len(args) < 2: | ||
| raise ValueError( | ||
| "Usage: aimx trace distribution <expression> [--repo <path>] [--table|--csv|--json]" | ||
| " [--steps start:end] [--head N] [--tail N] [--every K] [--no-color]" | ||
| ) | ||
| expression = args[1] | ||
| rest = args[2:] | ||
| else: | ||
| expression = args[0] | ||
| rest = args[1:] | ||
|
|
||
| mode: Literal["plot", "table", "csv", "json"] = "plot" | ||
| repo_value = "." | ||
|
|
@@ -118,6 +132,7 @@ def parse_trace_invocation(args: list[str]) -> TraceInvocation: | |
| raise ValueError(f"Unsupported trace option: {token}") | ||
|
|
||
| return TraceInvocation( | ||
| target=target, | ||
| expression=expression, | ||
| repo_path=Path(repo_value), | ||
| mode=mode, | ||
|
|
@@ -142,57 +157,100 @@ def run_trace_command(args: list[str]) -> QueryCommandResult: | |
| effective_no_color = invocation.no_color or not is_tty | ||
|
|
||
| try: | ||
| from aimx.aim_bridge.metric_stats import ( | ||
| collect_metric_series, | ||
| filter_by_step_range, | ||
| parse_step_slice, | ||
| subsample, | ||
| ) | ||
| from aimx.aim_bridge.metric_stats import parse_step_slice | ||
| from aimx.rendering.trace_views import ( | ||
| render_distribution_csv, | ||
| render_distribution_json, | ||
| render_distribution_table, | ||
| render_csv, | ||
| render_trace_table, | ||
| render_plot, | ||
| render_trace_json, | ||
| render_trace_table, | ||
| ) | ||
|
|
||
| series_list = collect_metric_series(invocation.expression, normalized_repo_path) | ||
| if invocation.target == "distribution": | ||
| from aimx.aim_bridge.metric_stats import ( | ||
| collect_distribution_series, | ||
| filter_distribution_by_step_range, | ||
| subsample_distribution, | ||
| ) | ||
|
|
||
| if not series_list: | ||
| return QueryCommandResult(exit_status=0, output="No matching metrics found.") | ||
| series_list = collect_distribution_series(invocation.expression, normalized_repo_path) | ||
| if not series_list: | ||
| return QueryCommandResult(exit_status=0, output="No matching distributions found.") | ||
|
|
||
| # Step range filter is a hard constraint applied before density subsampling | ||
| if invocation.step_slice is not None: | ||
| step_start, step_end = parse_step_slice(invocation.step_slice) | ||
| series_list = [filter_by_step_range(s, step_start, step_end) for s in series_list] | ||
| # Drop empty series so they don't clutter plots | ||
| series_list = [s for s in series_list if s.count > 0] | ||
| if invocation.step_slice is not None: | ||
| step_start, step_end = parse_step_slice(invocation.step_slice) | ||
| series_list = [ | ||
| filter_distribution_by_step_range(s, step_start, step_end) for s in series_list | ||
| ] | ||
| series_list = [s for s in series_list if s.count > 0] | ||
|
|
||
| if not series_list: | ||
| return QueryCommandResult(exit_status=0, output="No data in the requested step range.") | ||
| if not series_list: | ||
| return QueryCommandResult(exit_status=0, output="No data in the requested step range.") | ||
|
|
||
| # Density subsampling for visualisation | ||
| needs_sample = any( | ||
| x is not None for x in (invocation.head, invocation.tail, invocation.every) | ||
| ) | ||
| if needs_sample: | ||
| series_list = [ | ||
| subsample(s, head=invocation.head, tail=invocation.tail, every=invocation.every) | ||
| for s in series_list | ||
| ] | ||
|
|
||
| if invocation.mode == "json": | ||
| output = render_trace_json(series_list) | ||
| elif invocation.mode == "csv": | ||
| output = render_csv(series_list) | ||
| elif invocation.mode == "table": | ||
| output = render_trace_table(series_list, no_color=effective_no_color) | ||
| needs_sample = any( | ||
| x is not None for x in (invocation.head, invocation.tail, invocation.every) | ||
| ) | ||
| if needs_sample: | ||
| series_list = [ | ||
| subsample_distribution( | ||
| s, head=invocation.head, tail=invocation.tail, every=invocation.every | ||
| ) | ||
| for s in series_list | ||
| ] | ||
|
|
||
| if invocation.mode == "json": | ||
| output = render_distribution_json(series_list) | ||
| elif invocation.mode == "csv": | ||
| output = render_distribution_csv(series_list) | ||
| else: | ||
| output = render_distribution_table(series_list, no_color=effective_no_color) | ||
| else: | ||
| output = render_plot( | ||
| series_list, | ||
| width=invocation.width, | ||
| height=invocation.height, | ||
| from aimx.aim_bridge.metric_stats import ( | ||
| collect_metric_series, | ||
| filter_by_step_range, | ||
| subsample, | ||
| ) | ||
|
|
||
| series_list = collect_metric_series(invocation.expression, normalized_repo_path) | ||
|
|
||
| if not series_list: | ||
| return QueryCommandResult(exit_status=0, output="No matching metrics found.") | ||
|
|
||
| # Step range filter is a hard constraint applied before density subsampling | ||
| if invocation.step_slice is not None: | ||
| step_start, step_end = parse_step_slice(invocation.step_slice) | ||
| series_list = [filter_by_step_range(s, step_start, step_end) for s in series_list] | ||
| # Drop empty series so they don't clutter plots | ||
| series_list = [s for s in series_list if s.count > 0] | ||
|
|
||
| if not series_list: | ||
| return QueryCommandResult(exit_status=0, output="No data in the requested step range.") | ||
|
|
||
| # Density subsampling for visualisation | ||
| needs_sample = any( | ||
| x is not None for x in (invocation.head, invocation.tail, invocation.every) | ||
| ) | ||
| if needs_sample: | ||
| series_list = [ | ||
| subsample(s, head=invocation.head, tail=invocation.tail, every=invocation.every) | ||
| for s in series_list | ||
| ] | ||
|
|
||
| if invocation.mode == "json": | ||
| output = render_trace_json(series_list) | ||
| elif invocation.mode == "csv": | ||
| output = render_csv(series_list) | ||
| elif invocation.mode == "table": | ||
| output = render_trace_table(series_list, no_color=effective_no_color) | ||
| else: | ||
| output = render_plot( | ||
| series_list, | ||
| width=invocation.width, | ||
| height=invocation.height, | ||
| ) | ||
|
Comment on lines
+171
to
+252
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is significant code duplication between the Consider refactoring this into a unified pipeline. You could define a set of functions (collect, filter, subsample, render) based on the |
||
|
|
||
| return QueryCommandResult(exit_status=0, output=output) | ||
|
|
||
| except RuntimeError as error: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
collect_distribution_seriesforwards the user expression directly toRepo.query_distributions, but Aim’s distribution sequence variable isdistributions(seeaim.sdk.sequences.distribution_sequence.Distributions.sequence_name()), while this change’s own examples/tests usedistribution.name == .... In that common case, Aim treatsdistributionas undefined and the query silently yields no matches, soaimx trace distributionreports “No matching distributions found” even when data exists. Please rewrite/alias the singular form before callingquery_distributions(or reject it with a clear error) so documented expressions actually return results.Useful? React with 👍 / 👎.