Formalize test inference with enriched metrics and auto-analysis#106
Formalize test inference with enriched metrics and auto-analysis#106forklady42 wants to merge 11 commits into
Conversation
Enrich metrics.csv from 3 columns (rank, index, nmae) to 10 columns adding loss, max_pred, max_target, mean_pred, mean_target, num_electrons, and duration_ms. Add flexible checkpoint resolution (ckpt_file > last.ckpt > best.ckpt > glob fallback) and automatic post-test summary statistics with distribution plots. This unblocks analyze_saturation.py which already expected max_pred/max_target columns. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Compute max_pred, max_target, mean_pred, mean_target, and num_electrons per-sample by reducing over spatial dimensions only (keeping the batch dimension). Previously these were batch-level scalars that happened to be correct only with batch_size=1. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The glob fallback picks the latest epoch by lexicographic sort, not the lowest val_loss. Fix the docstring and comment to match. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
No need for a separate normmae_fn when the only loss function is NormMAE — both compute the same thing. Uses loss_fn for both the nmae and loss columns in metrics.csv (they'll be identical for now). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When wandb_mode != "disabled", log to W&B after test inference: - Distribution PNG as wandb.Image - Per-sample metrics as wandb.Table for interactive filtering - Native histogram for the overview panel - Scalar summary stats (mean, median, P95, P99, max) W&B is wired into the Trainer so Lightning's built-in test_loss metric also appears in the dashboard. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
analyze_metrics does not create its output directory. The test entrypoint now mkdir's saturation_dir before calling it. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
@claude review |
|
Claude finished @forklady42's task in 2m 32s —— View job Code Review
Overall this is a solid improvement — the enriched metrics, flexible checkpoint resolution, and auto-analysis pipeline are all well-structured. A few bugs and one breaking config change need attention before merging. Critical Issues1.
from electrai.scripts.analyze.analyze_saturation import analyze_metrics # line 145
from electrai.scripts.analyze.analyze_tail import main as tail_main # line 159But only 2. Breaking config field rename:
The MP configs ( 3. Missing
Minor Issues4.
5. This mutates the global matplotlib backend when 6. "nmae": loss.detach().cpu(),
"loss": loss.detach().cpu(),Both store the exact same tensor since 7. duration_per_sample = outputs["duration_ms"] / n_samplesThe elapsed time covers the full batch, then gets divided evenly. The column name 8. The old code computed Nits
|
- Add __init__.py to src/electrai/scripts/ and src/electrai/scripts/analyze/
so setuptools auto-discovery includes them
- Guard analyze_saturation/analyze_tail imports with try/except ImportError
so missing modules degrade gracefully
- Use getattr fallback for precision (model_precision -> precision -> 32)
so MP configs that use `precision:` still work
- Create tail_dir before invoking tail_main
- Restore devices="auto" so multi-GPU test inference works
- Move mpl.use("Agg") out of module scope into plot_distribution
- Rename duration_ms to avg_duration_ms (it's the batch average per sample)
- Drop redundant loss column (always identical to nmae since loss_fn is
hardcoded NormMAE); remove now-dead loss-handling in summarize.py
- Remove P50 from _PERCENTILES (already shown as Median)
- Update _resolve_checkpoint docstring to mention the ckpt_path-as-file
shortcut
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
These modules existed in the working tree but were never committed. Re-add them and drop the try/except ImportError guards in test.py that were a workaround for their absence. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Resolves conflicts in test.py (keep `import logging` and the model_precision/precision getattr fallback alongside main's `num_nodes` multi-node detection) and lightning.py (use y_cpu — the already-batched/concatenated target — for max_target/mean_target/ num_electrons so the enriched per-sample stats also work when the dataloader returns y as a list of variable-shape tensors). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
@claude re-review |
|
Claude finished @forklady42's task in 1m 42s —— View job Re-Review
Good progress — all critical and most minor issues from the first review have been addressed. One new bug surfaced in the batch-size handling, and one minor issue remains. What Was Fixed
Remaining Issue
New Bug
for i, idx in enumerate(indices):
f.write(f"...{outputs['nmae'][i].item()}...") # IndexError when i > 0
Two fix options: Option A — repeat the batch-average value (quick, approximate): # in on_test_batch_end, before the write loop
nmae_val = outputs["nmae"].item()
# then in the loop:
f.write(f"...{nmae_val},...")Option B — compute true per-sample NMAE in # in test_step, compute loss per sample by iterating or using reduction='none'
spatial_dims = tuple(range(1, preds.ndim))
per_sample_nmae = self.loss_fn_per_sample(preds, y_cpu) # requires a per-sample variantOption A is simpler and consistent with the existing design (the batch-average NMAE from The rest of the nits from the first review (P50 duplication in |
- Move mpl.use("Agg") out of module scope in analyze_saturation.py and
analyze_tail.py — call it inside the public entry functions
(analyze_metrics, analyze_voxels, main) so importing the modules no
longer mutates the global matplotlib backend.
- Fix latent nmae IndexError for batch_size > 1 in on_test_batch_end.
Since loss is the batch-averaged scalar from NormMAE, extract it once
with .item() and broadcast to every sample row instead of indexing
outputs['nmae'][i] (which always has shape [1]).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
metrics.csvfrom 3 columns (rank,index,nmae) to 10: addsloss,max_pred,max_target,mean_pred,mean_target,num_electrons,duration_ms— all computed per-sample over spatial dimstest.py: checksckpt_file>last.ckpt>best.ckpt> glob fallback, replacing the hardcodedlast.ckptsummarize.pymodule: computes NMAE stats (mean/median/P95/P99/max), threshold counts, generates histogram + CDF plots, and optionally logs to W&B (image, table, histogram, scalar stats)trainer.test(): summary + distribution plots always run; saturation and tail analysis run when applicableTest plan
mainpass (uv run pytest)metrics.csvhas all 10 columnssummary.txtandnmae_distribution.pngare generated inlog_diranalyze_saturationworks on the enriched CSV (no more missing column errors)wandb_mode: online🤖 Generated with Claude Code