Add statistics logging for params and activations #5492
Signed-off-by: Philip Monk <pmonk@nvidia.com>
f62799e to 34f71bc
/claude review
        prefix, local_expert_index, suffix = sequential_match.groups()
        return f"{prefix}{int(local_expert_index) + expert_offset}{suffix}"

    return local_name


def _get_local_expert_offset(num_experts: int) -> int:
    expert_group = parallel_state.get_expert_model_parallel_group(check_initialized=False)
    if expert_group is None:
Per the repo guidelines on process groups in megatron/core: this function introduces three new direct reads from parallel_state (get_expert_model_parallel_group, get_expert_model_parallel_world_size, get_expert_model_parallel_rank). The preference is to accept a ProcessGroupCollection or explicit torch.distributed.ProcessGroup from the caller and thread it through.
The caller (_canonical_param_name → _global_expert_param_name → _get_local_expert_offset) could receive the expert parallel size and rank from the PerParameterStatRegistry constructor (which already has the model chunks and could accept a ProcessGroupCollection), avoiding the global reads here.
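A minimal sketch of what threading the expert-parallel information through could look like. It assumes experts are split evenly and contiguously across expert-parallel ranks, which this hunk does not confirm. `ExpertParallelInfo`, `local_expert_offset`, `global_expert_param_name`, and the `local_experts` naming pattern are hypothetical stand-ins, not the PR's code. In the real implementation the size and rank would presumably be resolved once from the process group supplied by the caller (for example with `torch.distributed.get_world_size(group)` and `torch.distributed.get_rank(group)`), rather than read from `parallel_state` inside the helper.

```python
import re
from dataclasses import dataclass


@dataclass(frozen=True)
class ExpertParallelInfo:
    """Hypothetical container: EP size and rank, resolved once by the caller."""

    size: int = 1
    rank: int = 0


def local_expert_offset(num_experts: int, ep: ExpertParallelInfo) -> int:
    # Assumption: each EP rank owns a contiguous, equal-sized slice of experts.
    if num_experts % ep.size != 0:
        raise ValueError("num_experts must be divisible by the expert-parallel size")
    return ep.rank * (num_experts // ep.size)


# Assumed naming pattern for sequential experts; the PR's regex may differ.
_LOCAL_EXPERT = re.compile(r"^(.*\.local_experts\.)(\d+)(\..*)$")


def global_expert_param_name(
    local_name: str, num_experts: int, ep: ExpertParallelInfo
) -> str:
    match = _LOCAL_EXPERT.match(local_name)
    if match is None:
        # Not an expert parameter, so the local name is already the global one.
        return local_name
    prefix, local_index, suffix = match.groups()
    offset = local_expert_offset(num_experts, ep)
    return f"{prefix}{int(local_index) + offset}{suffix}"
```

With this shape, the helper has no hidden dependency on global state and can be unit-tested without initializing `torch.distributed`.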
grad = optimizer._get_grad_for_grad_norm(param)
if not optimizer._include_param_in_grad_norm(param, grad):
    continue
names.append(name)
grads.append(grad.detach())
Bug: reduce_groups=(None,) passes group=None to torch.distributed.all_reduce, which reduces over the default (WORLD) group. This will multiply all raw moments by world_size — every DP replica holds the same (already-reduced) gradient, and every PP/TP rank's contribution gets summed globally.
Compare with the base MegatronOptimizer.get_raw_moment_buckets_for_grad_norm, which carefully selects the data-parallel group (for DTensor) and get_grad_stats_parallel_group() (model-parallel group). The LayerWiseOptimizer should assemble equivalent reduce groups from its chained optimizers rather than falling back to WORLD.
If (None,) was intended as "no reduction needed", use an empty tuple () instead — reduce_raw_moments_by_param iterates bucket.reduce_groups and will skip the all-reduce entirely when the tuple is empty.
I struggled with this one a bit and am not sure I got it right in the end. None as the world group is used based on the similar logic in get_grad_norm and count_zeros.
However, I notice that in their helper functions, if the tensors are DTensors, they also reduce across the data-parallel group. That seems like a double-count to me, since reducing across None should already include those devices, I believe.
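The concern in this thread can be illustrated with a toy simulation. This is a sketch under stated assumptions, not Megatron's code: it uses plain Python dictionaries instead of `torch.distributed`, a 2x2 data-parallel by tensor-parallel layout, and a hypothetical `reduce_by_groups` helper that performs one sum-reduction per entry in `reduce_groups` and none when the tuple is empty. It assumes data-parallel replicas hold identical, already-reduced gradients.

```python
from itertools import product

DP, TP = 2, 2  # toy layout: 2 data-parallel replicas x 2 tensor-parallel shards
RANKS = list(product(range(DP), range(TP)))  # rank = (dp_rank, tp_rank)

# Each TP shard holds a different slice of the gradient. DP replicas hold
# identical copies because the gradient was already reduced across DP.
shard_grads = {0: [1.0, 2.0], 1: [3.0, 4.0]}  # keyed by tp_rank
local_sum_sq = {(dp, tp): sum(g * g for g in shard_grads[tp]) for dp, tp in RANKS}


def world_group(rank):
    return RANKS


def tp_group(rank):
    return [(rank[0], tp) for tp in range(TP)]


def dp_group(rank):
    return [(dp, rank[1]) for dp in range(DP)]


def reduce_by_groups(values, reduce_groups):
    """Toy all-reduce(sum): one pass per group, no pass for an empty tuple."""
    result = dict(values)
    for group_of in reduce_groups:
        result = {r: sum(result[peer] for peer in group_of(r)) for r in result}
    return result


true_sum_sq = sum(g * g for grads in shard_grads.values() for g in grads)

no_reduce = reduce_by_groups(local_sum_sq, ())
model_parallel_only = reduce_by_groups(local_sum_sq, (tp_group,))
world = reduce_by_groups(local_sum_sq, (world_group,))
world_then_dp = reduce_by_groups(local_sum_sq, (world_group, dp_group))
```

In this toy layout, reducing over the tensor-parallel group alone recovers the true sum of squares on every rank, reducing over the world group counts each shard once per data-parallel replica, and reducing over the world group followed by the data-parallel group counts it again. An empty tuple leaves each rank with only its local partial sum.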
What does this PR do?
This introduces per-parameter and per-activation-tensor statistics, logged to a jsonl file for later processing.
We log statistics for parameters, wgrads, activations, and dgrads. These form two natural groups.
Parameters and wgrads have the size of the model, are calculated after the forward and backward passes have completed, and are generally fast enough to run on every step. We store statistics for each of these at the granularity of a parameter, for example decoder.layers.3.mlp.linear_fc2.weight.
Activations and dgrads are summarized over the size of the activations times the number of points in the model where we store statistics, and are calculated by hooks during the forward and backward passes (and all-reduced afterward). We store statistics at the beginning and end of most linear layers, for example decoder.layers.3.mlp.linear_fc2/input0. This incurs a nontrivial cost: I measured up to 30% of the step time. Consider setting --activation-log-interval to 100 steps to reduce the impact.
For all of these, we're most interested in the mean, variance/stddev, L2 norm, RMS norm, and kurtosis. We want to be able to compute these statistics for arbitrary combinations, for example across all the parameters in a logical layer. These statistics don't necessarily aggregate well directly, but they're all derivable from the first four raw moments (sum, sum of squares, sum of cubes, and sum of fourth powers), so we compute and log those.
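To make the aggregation argument concrete, here is a minimal sketch of how the summary statistics could be recovered from raw moments. The `RawMoments` class and its field names are hypothetical and not taken from the PR. It assumes the element count is logged alongside the four sums, and it uses the population (divide-by-n) variance and non-excess kurtosis; the PR may use different conventions.

```python
import math
from dataclasses import dataclass


@dataclass
class RawMoments:
    n: int      # element count (assumed to be logged with the sums)
    s1: float   # sum of x
    s2: float   # sum of x^2
    s3: float   # sum of x^3
    s4: float   # sum of x^4

    @classmethod
    def of(cls, xs):
        xs = list(xs)
        return cls(
            len(xs),
            sum(xs),
            sum(x**2 for x in xs),
            sum(x**3 for x in xs),
            sum(x**4 for x in xs),
        )

    def __add__(self, other):
        # Raw moments aggregate by plain addition, which is why they are
        # convenient to log per tensor and combine later.
        return RawMoments(
            self.n + other.n,
            self.s1 + other.s1,
            self.s2 + other.s2,
            self.s3 + other.s3,
            self.s4 + other.s4,
        )

    def stats(self):
        mean = self.s1 / self.n
        e2, e3, e4 = self.s2 / self.n, self.s3 / self.n, self.s4 / self.n
        # Clamp tiny negative values caused by floating-point cancellation.
        var = max(e2 - mean**2, 0.0)
        # Fourth central moment expanded in terms of raw moments.
        m4 = e4 - 4 * mean * e3 + 6 * mean**2 * e2 - 3 * mean**4
        return {
            "mean": mean,
            "variance": var,
            "stddev": math.sqrt(var),
            "l2_norm": math.sqrt(self.s2),
            "rms": math.sqrt(e2),
            "kurtosis": m4 / var**2 if var > 0 else float("nan"),
        }
```

For example, `(RawMoments.of(a) + RawMoments.of(b)).stats()` gives the statistics of the concatenation of `a` and `b` without revisiting the underlying tensors. Note that the subtraction-based formulas can lose precision when the mean is large relative to the spread.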
One performance consideration: computing these statistics across many tensors can be slow without a multi_tensor-style kernel that fuses the four statistics. I've opened a PR in TransformerEngine that implements such a kernel, and a plain PyTorch fallback is used if that kernel can't be found.
For reference, these are the step times I observed on a 1B dense model with and without the TE kernel.
I have a simple Streamlit dashboard for visualizing the logged data, but I have not included it in this PR. It could perhaps be added to the tools/ directory, but this PR is already pretty large.
Contribution process
Pre-checks