MemScope 是一个面向大模型训练场景的显存分析工具,目标不是只给出一个“峰值显存”数字,而是尽可能把训练过程里的显存变化变成可追踪、可解释、可复现的事件序列。它既支持基于配置的静态显存估算,也支持在真实训练过程中做 runtime tracing,记录 step 边界、模块前向、反向梯度以及关键显存峰值。
这个仓库最初是围绕 Transformer/LLaMA 类模型的显存分析做原型实现,后来我把它进一步集成进了 FlagScale + Megatron-LM 的训练流程里。这样它就不再只是一个独立的小工具,而是可以直接插入真实分布式训练任务,在多卡环境下输出每个 rank 的显存报告。
目前已经支持在 Megatron/FlagScale 训练中记录: 训练 step 的显存边界、模块级 forward 事件、output tensor grad、parameter grad、runtime peak,以及 static 与 runtime 的对比结果。输出格式为 JSON 和 Markdown,后续也会补充更直观的 HTML 可视化。
我目前主要在下面这个环境里验证:
Ubuntu 远程服务器 8 × NVIDIA A100-SXM4-80GB 容器镜像:nvcr.io/nvidia/pytorch:25.12-py3 训练框架:FlagScale + Megatron-LM 模型配置:Qwen 风格 GPT 训练配置,支持 TP / PP / DP 组合并行
和传统只看 nvidia-smi 或只抓一个 profiler trace 不一样,MemScope 更关注“训练中的显存事件到底是谁触发的”。
它可以回答这样一些问题: 一个 step 里显存是在 forward 还是 backward 到达峰值 哪些模块在前向过程中产生了明显的显存增量 哪些参数梯度在 backward 里是真正的大头 当前 rank 上的局部模型在 TP / PP 配置下呈现出什么 shape 静态估算和真实 runtime 之间差了多少 如果你在调大模型训练里的 OOM、显存不均衡、局部 rank 异常高峰,或者只是想更清楚地理解 Megatron 下每个 rank 实际发生了什么,这类信息会比单纯的峰值数字更有帮助。
我这版集成方式是基于 FlagScale 的 runner 机制完成的。FlagScale 会把 train.system / train.model / train.data 里的配置 flatten 成命令行参数,然后交给 flagscale/train/train_gpt.py。因此,MemScope 参数是通过 extra_args_provider 注册进 Megatron argparse 的,而不是去直接修改 Megatron 原生 arguments 入口。(在..<你自己的FlagScale根目录>/FlagScale/flagscale/train/train_gpt.py)
我是这么做的: 在 train_gpt.py 顶部增加 import 加上:
import argparse新增参数注册函数 放在 if name == "main": 之前:
def add_memscope_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="memscope")
group.add_argument(
"--memscope",
action="store_true",
help="Enable MemScope runtime tracing during training.",
)
group.add_argument(
"--memscope-outdir",
type=str,
default="memscope_outputs",
help="Output directory for MemScope reports.",
)
group.add_argument(
"--memscope-top-k",
type=int,
default=20,
help="Top-K events kept in MemScope runtime report.",
)
group.add_argument(
"--memscope-hook-modules",
action="store_true",
help="Enable module forward hooks in MemScope.",
)
group.add_argument(
"--memscope-hook-output-grads",
action="store_true",
help="Enable output tensor grad hooks in MemScope.",
)
group.add_argument(
"--memscope-hook-param-grads",
action="store_true",
help="Enable parameter grad hooks in MemScope.",
)
group.add_argument(
"--memscope-enable-profiler",
action="store_true",
help="Enable torch profiler in MemScope.",
)
group.add_argument(
"--memscope-enable-memory-snapshot",
action="store_true",
help="Enable CUDA allocator memory snapshot in MemScope.",
)
group.add_argument(
"--memscope-profiler-ranks",
nargs="+",
type=int,
default=[0],
help="Global ranks on which MemScope profiler is enabled.",
)
group.add_argument(
"--memscope-snapshot-ranks",
nargs="+",
type=int,
default=[0],
help="Global ranks on which MemScope snapshot is enabled.",
)
group.add_argument(
"--memscope-sync-on-step-boundaries",
action="store_true",
help="Synchronize CUDA at MemScope step boundaries.",
)
group.add_argument(
"--memscope-sync-on-module-hooks",
action="store_true",
help="Synchronize CUDA at module hooks. Slower but more accurate.",
)
return parser合并已有的 add_modelopt_args 继续加一个组合函数:
def extra_args_provider_with_memscope(parser: argparse.ArgumentParser):
parser = add_memscope_args(parser)
# 让 hook 系列默认开启:由于 FlagScale bool flatten 只会在 true 时传参,
# 所以这里通过 set_defaults 提供默认值最稳妥
parser.set_defaults(
memscope_hook_modules=True,
memscope_hook_output_grads=True,
memscope_hook_param_grads=True,
memscope_sync_on_step_boundaries=True,
memscope_sync_on_module_hooks=False,
)
if has_nvidia_modelopt:
parser = add_modelopt_args(parser)
return parser修改 pretrain(...) 调用 原来是:
extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,改成:
extra_args_provider=extra_args_provider_with_memscope,使用时主要需要做三件事:
第一,把 memscope 仓库放进训练环境里,保证训练进程可以 import 到 memscope.integrations.megatron.runtime。最简单的办法就是把仓库放在 FlagScale workspace 下,或者用 editable install。
第二,在 flagscale/train/train_gpt.py 中通过 extra_args_provider 注册 --memscope 及相关参数,例如输出目录、top-k、是否开启 profiler、是否开启 memory snapshot 等。
第三,在 flagscale/train/train.py 的训练主循环里,在 train() 和 train_step() 中插入 runtime 生命周期与 step 边界回调。这样训练开始时会 attach hooks,训练过程中会记录事件,结束时会生成每个 rank 的报告。
这是运行时采集核心。 顶部 import 在 flagscale/train/train.py import 区加:
try:
from memscope.integrations.megatron.runtime import MemScopeMegatronRuntime
HAVE_MEMSCOPE = True
except ImportError:
MemScopeMegatronRuntime = None
HAVE_MEMSCOPE = False修改 train_step(...) 签名 原来:
def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config):改成:
def train_step(
forward_step_func,
data_iterator,
model,
optimizer,
opt_param_scheduler,
config,
memscope_runtime=None,
iteration=None,
):在 forward_backward_func 调用前 找到:
# Forward pass.
forward_backward_func = get_forward_backward_func()下面插入:
if memscope_runtime is not None:
memscope_runtime.on_train_step_start(
step=iteration if iteration is not None else 0,
batch=None,
notes="flagscale train_step start",
)在 forward_backward_func(...) 返回后插入 在:
losses_reduced = forward_backward_func(
...
)后面加:
if memscope_runtime is not None:
memscope_runtime.on_forward_backward_end(
outputs=losses_reduced,
notes="flagscale forward_backward_func complete",
)在 optimizer.step() 前后加 找到:
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()改成:
if memscope_runtime is not None:
memscope_runtime.on_optimizer_step_start(
notes="optimizer.step about to run"
)
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
if memscope_runtime is not None:
memscope_runtime.on_optimizer_step_end(
notes="optimizer.step complete"
)在 early exit 前加 找到:
if should_exit:
return {}, True, should_checkpoint, should_exit, exit_code, None, None改成:
if should_exit:
if memscope_runtime is not None:
memscope_runtime.on_train_step_end(
outputs=None,
notes="train_step early exit",
)
return {}, True, should_checkpoint, should_exit, exit_code, None, None在 pipeline last stage return 前加 找到:
return (
loss_reduced,
skipped_iter,
should_checkpoint,
should_exit,
exit_code,
grad_norm,
num_zeros_in_grad,
)改成:
if memscope_runtime is not None:
memscope_runtime.on_train_step_end(
outputs=loss_reduced,
notes="train_step end (last pipeline stage)",
)
return (
loss_reduced,
skipped_iter,
should_checkpoint,
should_exit,
exit_code,
grad_norm,
num_zeros_in_grad,
)在非 last stage return 前加 找到最后:
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad改成:
if memscope_runtime is not None:
memscope_runtime.on_train_step_end(
outputs=None,
notes="train_step end",
)
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad训练开始时初始化 在 train() 里,找到:
# Turn on training mode which enables dropout.
for model_module in model:
model_module.train()下面加:
memscope_runtime = None
if getattr(args, "memscope", False):
if HAVE_MEMSCOPE:
try:
memscope_runtime = MemScopeMegatronRuntime.from_megatron_args(args)
memscope_runtime.attach_model(model)
memscope_runtime.start()
print_rank_0(
f"> MemScope enabled. Reports will be written to {args.memscope_outdir}"
)
except Exception as e:
print_rank_0(f"> WARNING: failed to initialize MemScope: {e}")
memscope_runtime = None
else:
print_rank_0(
"> WARNING: args.memscope is enabled but memscope package is not importable."
)调用 train_step 时传入 runtime 找到:
) = train_step(
forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config
)改成:
) = train_step(
forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config,
memscope_runtime=memscope_runtime,
iteration=iteration,
)在 train() 正常结束前 finalize 在最后:
return iteration, num_floating_point_operations_so_far前面加:
if memscope_runtime is not None:
try:
memscope_runtime.finalize()
print_rank_0("> MemScope finalize complete.")
except Exception as e:
print_rank_0(f"> WARNING: MemScope finalize failed: {e}")在 should_exit 分支也 finalize 找到:
if should_exit:
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
one_logger_utils.finish()
sys.exit(exit_code)改成:
if should_exit:
if memscope_runtime is not None:
try:
memscope_runtime.finalize()
except Exception:
pass
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
one_logger_utils.finish()
sys.exit(exit_code)配置文件里只需要在 train.system 下(也就是例如14b.yaml的system: 下)增加类似下面的字段:
memscope: true
memscope_outdir: ${experiment.exp_dir}/memscope
memscope_top_k: 20
memscope_hook_modules: true
memscope_hook_output_grads: true
memscope_hook_param_grads: true
memscope_enable_profiler: false
memscope_enable_memory_snapshot: false
memscope_profiler_ranks: [0]
memscope_snapshot_ranks: [0]
memscope_sync_on_step_boundaries: true
memscope_sync_on_module_hooks: false训练完成后,报告会输出到类似下面的目录:
output_xxx/memscope/rank00000/runtime_report.json
output_xxx/memscope/rank00000/runtime_report.md
如果是多卡训练,每个 rank 都会各自输出一个目录。
runtime report 里最重要的部分有四个:runtime_trace、peak、top_events、comparisons。
runtime_trace 是完整的事件时间线,包含 step、module forward、tensor grad 等事件。 peak 是当前 rank 的显存峰值快照。 top_events 是按显存占用排序的关键事件摘要,通常很适合直接定位“谁最重”。 comparisons 会给出 static estimate 和 runtime peak 的差异,方便判断静态模型是否过于保守,或者 runtime 是否出现了额外开销。
MemScope 在运行时会输出一份 runtime_report.json,里面是每个 step 边界、模块 forward、以及梯度 hook 的显存快照。为了更方便排查“台阶”和“峰值”,我们提供了一个离线的 HTML 可视化页面:把 JSON 丢进去就能生成一个自包含的报告网页(不依赖后端服务)。
运行结束后(默认路径类似 memscope_outputs/rank00000/runtime_report.json),执行:
python -m memscope.cli visualize-runtime \
--report memscope_outputs/rank00000/runtime_report.json不传 --out 的话,会在同目录生成同名的 .html 文件,比如:
memscope_outputs/rank00000/runtime_report.html
你也可以指定输出路径:
python -m memscope.cli visualize-runtime \
--report memscope_outputs/rank00000/runtime_report.json \
--out /tmp/memscope_runtime.html然后直接用浏览器打开这个 HTML 文件即可。
HTML 报告主要包含几块内容:
概览卡片:runtime 峰值(allocated / reserved)、静态估算峰值、以及静态 vs runtime 的差值(用来快速判断静态估算是否偏保守)。
Step-level memory timeline: 按 step 边界事件画出的两条曲线, allocated_after(PyTorch 当前实际持有的显存), reserved_after(caching allocator 预留的显存,包含缓存与碎片), 这块特别适合看类似 “step0 optimizer 之后基线抬升” 这种台阶现象。
Phase summary: 按 phase 聚合后的最大 allocated 水位(帮助确认峰值更偏 forward / backward / optimizer)。
Top events: 按 mem_allocated_after 排序的事件列表(支持搜索过滤),用于快速定位“峰值发生在什么模块/什么 phase/哪个 step 附近”。
Module aggregate: 对每个 module 的事件做聚合,包含 max_allocated_after 和 max_delta_allocated。当你想找“哪个模块 forward 后净增量最大”时,这张表比 top events 更直观。
补充说明:Step-level timeline 展示的是“打点时刻的显存水位”。如果某个模块内部有短生命周期的临时显存尖峰(分配后在模块结束前释放),仅靠模块前后 hook 可能看不到完整尖峰;这种场景可以配合 profiler 或 memory snapshot 做更细粒度的归因。
如果你对大模型训练显存分析、Megatron 训练调试或者 profiling 工具感兴趣,欢迎交流或提 issue。
.png)
.png)