Mask-aware packing and memory diagnostics for supervised fine-tuning.
MaskPack-SFT helps prepare packed SFT batches without losing assistant-only labels or allowing attention across unrelated examples. It also includes a small memory advisor for checking whether a LoRA, QLoRA, or full fine-tuning configuration is likely to fit on a target GPU.
- Best-fit packing for tokenized SFT examples.
- Assistant-only loss masks preserved after packing.
- Block-diagonal causal attention masks to prevent cross-example attention.
- Position IDs and segment IDs for packed sequences.
- Practical GPU memory estimates with tuning recommendations.
- FSDP, DeepSpeed, and ZeRO memory-planning presets.
- JSONL CLI for profiling, packing, mask verification, and utilization checks.
- Optional PyTorch tensor collator for trainer integrations.
- Lightweight TRL integration hook with no hard dependency on PyTorch or TRL.
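
To make the block-diagonal causal mask, position IDs, and segment IDs listed above concrete, here is a minimal plain-Python sketch. The function name `build_packed_metadata` is hypothetical and not part of the package. The zero-based segment numbering and the restart of position IDs for each example are assumptions; the package's own conventions may differ.

```python
def build_packed_metadata(lengths):
    """Return (segment_ids, position_ids, attention_mask) for examples
    packed back to back. Illustrative only; not the maskpack_sft API."""
    segment_ids = []
    position_ids = []
    for seg, n in enumerate(lengths):
        segment_ids.extend([seg] * n)   # which original example owns each token
        position_ids.extend(range(n))   # assumption: positions restart per example

    total = len(segment_ids)
    # Token i may attend to token j only if j is not in the future (causal)
    # and both tokens come from the same example (block-diagonal).
    attention_mask = [
        [1 if j <= i and segment_ids[i] == segment_ids[j] else 0
         for j in range(total)]
        for i in range(total)
    ]
    return segment_ids, position_ids, attention_mask


# Two examples of lengths 4 and 3 packed into one 7-token sequence.
segment_ids, position_ids, mask = build_packed_metadata([4, 3])
```

In this sketch the first token of the second example (index 4) can attend only to itself, so no information leaks from the preceding example.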
pip install -e .[dev]

The core package uses plain Python data structures. Framework integrations can convert the returned batches to tensors where needed.
Estimate GPU memory for a 7B QLoRA-style run:
maskpack profile \
  --parameters-b 7 \
  --seq-length 2048 \
  --micro-batch-size 1 \
  --precision nf4 \
  --trainable-fraction 0.002 \
  --gradient-checkpointing \
  --chunked-loss \
  --gpu-memory-gb 24

Estimate a sharded run:
maskpack profile \
  --parameters-b 13 \
  --seq-length 4096 \
  --micro-batch-size 1 \
  --distributed-preset zero3 \
  --world-size 4 \
  --gpu-memory-gb 80

Pack tokenized SFT examples:
maskpack pack examples/tokenized_sft.jsonl work/packed.jsonl --max-length 8

Check packing utilization:
maskpack bench examples/tokenized_sft.jsonl --max-length 8

Verify an exported attention mask:
maskpack verify-mask examples/good_mask.json

maskpack pack reads JSONL. Each line must include token IDs and an
assistant-token mask computed after chat-template formatting:
{"source_id":"chat-1","input_ids":[10,11,12,13],"assistant_mask":[0,0,1,1]}
{"source_id":"chat-2","input_ids":[20,21,22],"assistant_mask":[0,1,1]}

assistant_mask and loss_mask are both accepted as the field name. Positions marked 1
contribute to the language-modeling loss; positions marked 0 are assigned label -100.
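
The mapping from mask to labels described above can be sketched in a few lines. The helper name `labels_from_mask` is hypothetical and not part of the package; the value -100 matches the default `ignore_index` of PyTorch's cross-entropy loss, which is why masked positions drop out of the loss.

```python
IGNORE_INDEX = -100  # default ignore_index for PyTorch's CrossEntropyLoss


def labels_from_mask(input_ids, assistant_mask):
    """Copy token IDs where the mask is 1; use IGNORE_INDEX elsewhere.
    Illustrative only; not the maskpack_sft API."""
    if len(input_ids) != len(assistant_mask):
        raise ValueError("input_ids and assistant_mask must have equal length")
    return [
        tok if keep == 1 else IGNORE_INDEX
        for tok, keep in zip(input_ids, assistant_mask)
    ]


# The first JSONL record shown above.
labels = labels_from_mask([10, 11, 12, 13], [0, 0, 1, 1])
```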
from maskpack_sft import MaskAwareCollator, PackedSFTDataset, SFTExample
examples = [
    SFTExample(input_ids=[10, 11, 12, 13], assistant_mask=[0, 0, 1, 1]),
    SFTExample(input_ids=[20, 21, 22], assistant_mask=[0, 1, 1]),
]
packed = PackedSFTDataset(examples, max_length=8)
batch = MaskAwareCollator(pad_token_id=0)([packed[0]])
assert batch["labels"] == [[-100, -100, 12, 13, -100, 21, 22]]

from maskpack_sft.integrations import build_trl_collator
collator = build_trl_collator(as_tensors=True, pad_token_id=tokenizer.pad_token_id)

The integration helper returns the same mask-aware collator used by the core
package. Training code can wrap it to convert input_ids, labels,
attention_mask, position_ids, and segment_ids to tensors.
For a runnable script, see examples/trl_sfttrainer_example.py.
maskpack profile --parameters-b 7 --seq-length 2048 --micro-batch-size 1
maskpack pack INPUT.jsonl OUTPUT.jsonl --max-length 4096
maskpack bench INPUT.jsonl --max-length 4096
maskpack verify-mask MASK.json

Useful options:
- --truncate: clip examples longer than --max-length.
- --sort: sort by length before best-fit packing.
- --json: emit JSON for profile.
- --gradient-checkpointing: include checkpointing in the memory estimate.
- --chunked-loss: estimate a chunked CE or CCE-style loss path.
- --distributed-preset: use data_parallel, fsdp_zero2, fsdp_zero3, zero2, or zero3.
- --world-size: set the number of distributed workers for memory estimates.
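
To illustrate how --sort and --truncate interact with best-fit packing, here is a minimal sketch of the textbook best-fit heuristic operating on example lengths. The function `best_fit_pack` is hypothetical and not the package's implementation. The descending sort order is an assumption (the option description does not state a direction), as is the choice to raise an error for over-long examples when truncation is off.

```python
def best_fit_pack(lengths, max_length, sort=False, truncate=False):
    """Group example indices into bins of at most max_length tokens.
    Illustrative best-fit heuristic; not the maskpack_sft implementation."""
    order = list(range(len(lengths)))
    if sort:
        # Assumption: longest examples first (best-fit decreasing).
        order.sort(key=lambda i: lengths[i], reverse=True)

    bins = []       # each bin is a list of example indices
    remaining = []  # free token slots left in each bin
    for i in order:
        n = lengths[i]
        if n > max_length:
            if not truncate:
                raise ValueError(f"example {i} exceeds max_length")
            n = max_length  # clip the example to the bin size

        # Best fit: choose the bin that fits with the least space left over.
        best = None
        for b, free in enumerate(remaining):
            if free >= n and (best is None or free < remaining[best]):
                best = b

        if best is None:
            bins.append([i])
            remaining.append(max_length - n)
        else:
            bins[best].append(i)
            remaining[best] -= n
    return bins


unsorted_bins = best_fit_pack([4, 3, 5, 2, 6], max_length=8)
sorted_bins = best_fit_pack([4, 3, 5, 2, 6], max_length=8, sort=True)
```

Both calls use three bins for this input, but sorting first fills two of them completely, which tends to help on longer, more varied datasets.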
python -m pytest

The test suite covers packing behavior, assistant-only labels, block-diagonal causal masks, long-example truncation, memory advice, and CLI mask verification.
The benchmarks/ directory contains a starter benchmark matrix. A typical
comparison should report:
- token utilization;
- examples/sec and tokens/sec;
- peak GPU memory;
- train and eval loss;
- mask verification results;
- OOM or distributed consistency failures.
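
As a reference point for the first metric in the list above, token utilization can be defined as the fraction of available token slots that hold real (non-padding) tokens. The helper below is a hypothetical sketch under that definition; the figure reported by maskpack bench may be computed differently.

```python
def token_utilization(packed_lengths, max_length):
    """Fraction of token slots occupied by real tokens across packed sequences.
    Illustrative definition; not taken from the maskpack_sft source."""
    if not packed_lengths:
        return 0.0
    return sum(packed_lengths) / (len(packed_lengths) * max_length)


# Three packed sequences holding 7, 8, and 6 real tokens with max_length 8.
utilization = token_utilization([7, 8, 6], max_length=8)
```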
benchmarks/github_issue_benchmark.md maps public issue reports to concrete
benchmark cases.
MIT