Skip to content

Formalize test inference with enriched metrics and auto-analysis#106

Draft
forklady42 wants to merge 11 commits into
mainfrom
test/formalize-inference
Draft

Formalize test inference with enriched metrics and auto-analysis#106
forklady42 wants to merge 11 commits into
mainfrom
test/formalize-inference

Conversation

@forklady42

Copy link
Copy Markdown
Collaborator

Summary

  • Enrich metrics.csv from 3 columns (rank,index,nmae) to 10: adds loss, max_pred, max_target, mean_pred, mean_target, num_electrons, duration_ms — all computed per-sample over spatial dims
  • Flexible checkpoint resolution in test.py: checks ckpt_file > last.ckpt > best.ckpt > glob fallback, replacing the hardcoded last.ckpt
  • New summarize.py module: computes NMAE stats (mean/median/P95/P99/max), threshold counts, generates histogram + CDF plots, and optionally logs to W&B (image, table, histogram, scalar stats)
  • Auto-chain analysis after trainer.test(): summary + distribution plots always run; saturation and tail analysis run when applicable

Test plan

  • All 25 tests on main pass (uv run pytest)
  • Pre-commit (ruff lint + format) passes on all changed files
  • Run test inference on a checkpoint and verify metrics.csv has all 10 columns
  • Verify summary.txt and nmae_distribution.png are generated in log_dir
  • Verify analyze_saturation works on the enriched CSV (no more missing column errors)
  • Verify W&B logging works when wandb_mode: online

🤖 Generated with Claude Code

forklady42 and others added 7 commits March 27, 2026 14:00
Enrich metrics.csv from 3 columns (rank, index, nmae) to 10 columns
adding loss, max_pred, max_target, mean_pred, mean_target,
num_electrons, and duration_ms. Add flexible checkpoint resolution
(ckpt_file > last.ckpt > best.ckpt > glob fallback) and automatic
post-test summary statistics with distribution plots. This unblocks
analyze_saturation.py which already expected max_pred/max_target
columns.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Compute max_pred, max_target, mean_pred, mean_target, and
num_electrons per-sample by reducing over spatial dimensions only
(keeping the batch dimension). Previously these were batch-level
scalars that happened to be correct only with batch_size=1.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The glob fallback picks the latest epoch by lexicographic sort,
not the lowest val_loss. Fix the docstring and comment to match.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
No need for a separate normmae_fn when the only loss function is
NormMAE — both compute the same thing. Uses loss_fn for both the
nmae and loss columns in metrics.csv (they'll be identical for now).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When wandb_mode != "disabled", log to W&B after test inference:
- Distribution PNG as wandb.Image
- Per-sample metrics as wandb.Table for interactive filtering
- Native histogram for the overview panel
- Scalar summary stats (mean, median, P95, P99, max)

W&B is wired into the Trainer so Lightning's built-in test_loss
metric also appears in the dashboard.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
analyze_metrics does not create its output directory. The test
entrypoint now mkdir's saturation_dir before calling it.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@forklady42

Copy link
Copy Markdown
Collaborator Author

@claude review

@claude

claude Bot commented May 20, 2026

Copy link
Copy Markdown

Claude finished @forklady42's task in 2m 32s —— View job


Code Review

  • Gathering context
  • Reading changed files
  • Analyzing code
  • Posting review

Overall this is a solid improvement — the enriched metrics, flexible checkpoint resolution, and auto-analysis pipeline are all well-structured. A few bugs and one breaking config change need attention before merging.


Critical Issues

1. analyze_saturation.py and analyze_tail.py don't exist

test.py unconditionally imports these:

from electrai.scripts.analyze.analyze_saturation import analyze_metrics  # line 145
from electrai.scripts.analyze.analyze_tail import main as tail_main       # line 159

But only summarize.py exists in src/electrai/scripts/analyze/. These imports will raise ImportError at runtime, and that error is not caught — only KeyError and ValueError are in the except clause (line 151). The saturation block will crash for every user running test inference. Either add these files, guard the imports with try/except ImportError, or remove the saturation/tail blocks until those modules exist.

Fix this →


2. Breaking config field rename: precisionmodel_precision

test.py:109: precision=cfg.model_precision

The MP configs (config_resunet.yaml, config_resnet.yaml) use precision: 32 — not model_precision. The QM9 config is the only one already using model_precision. Any test run against the MP configs will fail with AttributeError: 'SimpleNamespace' object has no attribute 'model_precision'. Either rename the field in all configs, or use getattr(cfg, "model_precision", getattr(cfg, "precision", 32)) as a fallback.

Fix this →


3. Missing __init__.py files

src/electrai/scripts/ and src/electrai/scripts/analyze/ have no __init__.py. With setuptools auto-discovery and no explicit [tool.setuptools.packages.find] config, these directories won't be installed as part of the package. The imports in test.py may work in a dev install (editable mode with namespace packages) but will silently break in a regular install. Add empty __init__.py files to both directories.


Minor Issues

4. tail_dir is never mkdir()'d (test.py:161)

saturation_dir.mkdir() is called on line 148, but the equivalent tail_dir.mkdir() is missing before tail_main() is invoked. If tail_main doesn't create it internally, this will fail.

5. mpl.use("Agg") at module level (summarize.py:21)

This mutates the global matplotlib backend when summarize is imported, which will silently break any interactive plotting elsewhere in the process. Move it inside plot_distribution() (before plt.subplots) or into the if __name__ == "__main__" block.

6. nmae and loss are always identical (lightning.py:111-112)

"nmae": loss.detach().cpu(),
"loss": loss.detach().cpu(),

Both store the exact same tensor since loss_fn is NormMAE. summarize.py handles this gracefully (checks np.allclose), but it wastes a column in every metrics.csv. Consider just dropping the loss column since it's always redundant, or rename to clarify.

7. duration_ms is batch average, not per-sample (lightning.py:152)

duration_per_sample = outputs["duration_ms"] / n_samples

The elapsed time covers the full batch, then gets divided evenly. The column name duration_ms implies a measured per-sample time. A comment or rename (avg_duration_ms) would prevent confusion.

8. devices=1 hardcodes single-GPU inference (test.py:108)

The old code computed num_nodes from environment variables and used devices="auto". This PR removes that, making multi-GPU test inference impossible without code changes. If this is intentional, it should be noted in the PR description.


Nits

  • summarize.py: P50 is computed via _PERCENTILES = [50, 75, 90, 95, 99] and printed as P50, but also separately displayed as Median. These will always match — consider removing 50 from _PERCENTILES.
  • _resolve_checkpoint docstring doesn't mention the "ckpt_path is itself a file" shortcut (lines 37-38), which is a non-obvious case.
  • test.py:87: config=vars(cfg)vars() works on SimpleNamespace (returns __dict__), so this is fine, but worth confirming the W&B config view won't be cluttered with internal fields.
    | Branch

forklady42 and others added 3 commits May 20, 2026 15:47
- Add __init__.py to src/electrai/scripts/ and src/electrai/scripts/analyze/
  so setuptools auto-discovery includes them
- Guard analyze_saturation/analyze_tail imports with try/except ImportError
  so missing modules degrade gracefully
- Use getattr fallback for precision (model_precision -> precision -> 32)
  so MP configs that use `precision:` still work
- Create tail_dir before invoking tail_main
- Restore devices="auto" so multi-GPU test inference works
- Move mpl.use("Agg") out of module scope into plot_distribution
- Rename duration_ms to avg_duration_ms (it's the batch average per sample)
- Drop redundant loss column (always identical to nmae since loss_fn is
  hardcoded NormMAE); remove now-dead loss-handling in summarize.py
- Remove P50 from _PERCENTILES (already shown as Median)
- Update _resolve_checkpoint docstring to mention the ckpt_path-as-file
  shortcut

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
These modules existed in the working tree but were never committed.
Re-add them and drop the try/except ImportError guards in test.py that
were a workaround for their absence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Resolves conflicts in test.py (keep `import logging` and the
model_precision/precision getattr fallback alongside main's `num_nodes`
multi-node detection) and lightning.py (use y_cpu — the
already-batched/concatenated target — for max_target/mean_target/
num_electrons so the enriched per-sample stats also work when the
dataloader returns y as a list of variable-shape tensors).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@forklady42

Copy link
Copy Markdown
Collaborator Author

@claude re-review

@claude

claude Bot commented May 21, 2026

Copy link
Copy Markdown

Claude finished @forklady42's task in 1m 42s —— View job


Re-Review

  • Gathering context
  • Reading changed files
  • Comparing against previous review findings
  • Posting updated review

Good progress — all critical and most minor issues from the first review have been addressed. One new bug surfaced in the batch-size handling, and one minor issue remains.


What Was Fixed

# Issue Status
1 analyze_saturation.py / analyze_tail.py missing ✅ Fixed — both added
2 cfg.model_precision breaking MP configs ✅ Fixed — getattr double-fallback at test.py:116
3 Missing __init__.py in scripts/ and scripts/analyze/ ✅ Fixed — both added
4 tail_dir never mkdir()'d ✅ Fixed — tail_dir.mkdir(...) at test.py:169
6 nmae and loss always identical columns ✅ Fixed — loss column removed
7 duration_ms naming ambiguity ✅ Fixed — renamed avg_duration_ms throughout
8 devices=1 hardcodes single-GPU ✅ Fixed — devices="auto" + num_nodes from env vars restored

Remaining Issue

mpl.use("Agg") still at module level in the two new files

summarize.py correctly moved this inside plot_distribution(), but analyze_saturation.py:40 and analyze_tail.py:28 still call mpl.use("Agg") at module level. Importing either of these modules will silently change the global matplotlib backend for the entire process.

Fix this →


New Bug

nmae IndexError for batch_size > 1 (lightning.py:174)

loss in test_step is the batch-averaged scalar — a 0-d tensor regardless of batch size. The on_test_batch_end unsqueeze logic (val.unsqueeze(0)) pads it to shape [1]. Then the write loop does:

for i, idx in enumerate(indices):
    f.write(f"...{outputs['nmae'][i].item()}...")  # IndexError when i > 0

outputs['nmae'] is always [1] after unsqueeze, so outputs['nmae'][1] raises IndexError for any batch with more than one sample. This won't surface if test batch size is always 1, but it's a latent bug.

Two fix options:

Option A — repeat the batch-average value (quick, approximate):

# in on_test_batch_end, before the write loop
nmae_val = outputs["nmae"].item()
# then in the loop:
f.write(f"...{nmae_val},...")

Option B — compute true per-sample NMAE in test_step (precise):

# in test_step, compute loss per sample by iterating or using reduction='none'
spatial_dims = tuple(range(1, preds.ndim))
per_sample_nmae = self.loss_fn_per_sample(preds, y_cpu)  # requires a per-sample variant

Option A is simpler and consistent with the existing design (the batch-average NMAE from test_loss logging is already an approximation per sample). Option B is more correct but requires changing NormMAE.

Fix this →


The rest of the nits from the first review (P50 duplication in _PERCENTILES, _resolve_checkpoint docstring gap, config=vars(cfg) note) are still present but non-blocking. Overall this PR is close — fix the nmae indexing bug and it's ready to merge.
| Branch

- Move mpl.use("Agg") out of module scope in analyze_saturation.py and
  analyze_tail.py — call it inside the public entry functions
  (analyze_metrics, analyze_voxels, main) so importing the modules no
  longer mutates the global matplotlib backend.
- Fix latent nmae IndexError for batch_size > 1 in on_test_batch_end.
  Since loss is the batch-averaged scalar from NormMAE, extract it once
  with .item() and broadcast to every sample row instead of indexing
  outputs['nmae'][i] (which always has shape [1]).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant