Lin-Aurora/MaskPack-SFT

MaskPack-SFT

Mask-aware packing and memory diagnostics for supervised fine-tuning.

MaskPack-SFT helps prepare packed SFT batches without losing assistant-only labels or allowing attention across unrelated examples. It also includes a small memory advisor for checking whether a LoRA, QLoRA, or full fine-tuning configuration is likely to fit on a target GPU.

Features

  • Best-fit packing for tokenized SFT examples.
  • Assistant-only loss masks preserved after packing.
  • Block-diagonal causal attention masks to prevent cross-example attention.
  • Position IDs and segment IDs for packed sequences.
  • Practical GPU memory estimates with tuning recommendations.
  • FSDP, DeepSpeed, and ZeRO memory-planning presets.
  • JSONL CLI for profiling, packing, mask verification, and utilization checks.
  • Optional PyTorch tensor collator for trainer integrations.
  • Lightweight TRL integration hook with no hard dependency on PyTorch or TRL.
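
The packing metadata fits together roughly as follows. This is a minimal sketch, not the package's implementation. The helper names `packed_metadata` and `block_causal_mask` are hypothetical, and the sketch assumes that positions restart at 0 for each packed example and that segment IDs number examples from 0 in packing order.

```python
from typing import List, Tuple


def packed_metadata(lengths: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: position and segment IDs for examples packed back to back."""
    position_ids: List[int] = []
    segment_ids: List[int] = []
    for segment, length in enumerate(lengths):
        position_ids.extend(range(length))  # positions restart for every example
        segment_ids.extend([segment] * length)  # one ID per source example
    return position_ids, segment_ids


def block_causal_mask(segment_ids: List[int]) -> List[List[int]]:
    """mask[i][j] == 1 iff query i may attend to key j: same segment and j <= i."""
    n = len(segment_ids)
    return [
        [1 if j <= i and segment_ids[i] == segment_ids[j] else 0 for j in range(n)]
        for i in range(n)
    ]


# Lengths 4 and 3, matching the JSONL examples in Input Format.
pos, seg = packed_metadata([4, 3])
print(pos)  # [0, 1, 2, 3, 0, 1, 2]
print(seg)  # [0, 0, 0, 0, 1, 1, 1]
for row in block_causal_mask(seg):
    print(row)
```

Each row is the usual causal triangle restricted to its own example, so the full matrix is block-diagonal with lower-triangular blocks.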

Installation

pip install -e ".[dev]"

The core package uses plain Python data structures. Framework integrations can convert the returned batches to tensors where needed.

Quick Start

Estimate GPU memory for a 7B QLoRA-style run:

maskpack profile \
  --parameters-b 7 \
  --seq-length 2048 \
  --micro-batch-size 1 \
  --precision nf4 \
  --trainable-fraction 0.002 \
  --gradient-checkpointing \
  --chunked-loss \
  --gpu-memory-gb 24
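
A back-of-envelope calculation helps sanity-check the order of magnitude of the static part of that estimate. This is not the formula `maskpack profile` uses. It assumes about 0.5 bytes per frozen parameter for NF4 (ignoring quantization constants) and 16 bytes per trainable adapter parameter (fp32 weights, fp32 gradients, and two fp32 Adam moments). It also ignores activations and logits, the terms that gradient checkpointing and a chunked loss reduce, as well as CUDA context and allocator overhead. The function name is hypothetical.

```python
def rough_static_memory_gb(
    parameters_b: float,
    trainable_fraction: float,
    bytes_per_frozen_param: float = 0.5,  # NF4: ~4 bits per weight (assumption)
    bytes_per_trainable_param: float = 16.0,  # fp32 weight + grad + Adam m and v
) -> float:
    """Hypothetical estimate of weights plus trainable state, in decimal GB."""
    total_params = parameters_b * 1e9
    frozen_bytes = total_params * bytes_per_frozen_param
    trainable_bytes = total_params * trainable_fraction * bytes_per_trainable_param
    return (frozen_bytes + trainable_bytes) / 1e9


estimate = rough_static_memory_gb(parameters_b=7, trainable_fraction=0.002)
print(f"{estimate:.3f} GB")  # 3.724 GB before activations, logits, and overhead
```

On a 24 GB card, the remaining ~20 GB is the budget that activations, logits, and runtime overhead must fit into.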

Estimate a sharded run:

maskpack profile \
  --parameters-b 13 \
  --seq-length 4096 \
  --micro-batch-size 1 \
  --distributed-preset zero3 \
  --world-size 4 \
  --gpu-memory-gb 80
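
For the sharded case, most of the saving comes from model states. As a rough cross-check, and not necessarily the accounting `maskpack profile` uses, the ZeRO paper's model for mixed-precision Adam counts 2 bytes of fp16 parameters, 2 bytes of fp16 gradients, and K = 12 bytes of fp32 optimizer state per parameter. Each higher stage shards more of these across workers. The function name and the integer `stage` argument are illustrative, and this sketch makes no claim about how the CLI's preset names map onto stages. Activations are excluded.

```python
def zero_model_state_gb(parameters_b: float, world_size: int, stage: int, k: int = 12) -> float:
    """Per-GPU model-state memory (decimal GB) following the ZeRO paper's accounting."""
    psi = parameters_b * 1e9
    if stage == 0:  # plain data parallel: everything replicated
        total = (2 + 2 + k) * psi
    elif stage == 1:  # optimizer states sharded
        total = (2 + 2) * psi + k * psi / world_size
    elif stage == 2:  # optimizer states and gradients sharded
        total = 2 * psi + (2 + k) * psi / world_size
    elif stage == 3:  # parameters sharded as well
        total = (2 + 2 + k) * psi / world_size
    else:
        raise ValueError(f"unknown ZeRO stage: {stage}")
    return total / 1e9


for stage in (0, 2, 3):
    print(stage, zero_model_state_gb(parameters_b=13, world_size=4, stage=stage))
# 0 208.0
# 2 71.5
# 3 52.0
```

Under this accounting, stage 3 leaves about 28 GB per 80 GB GPU for activations and overhead. Stage 2 leaves about 8.5 GB.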

Pack tokenized SFT examples:

maskpack pack examples/tokenized_sft.jsonl work/packed.jsonl --max-length 8
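
Best-fit packing places each example into the open pack with the least remaining space that still fits it, and opens a new pack only when none fits. The following is a minimal sketch of that heuristic over example lengths. The name `best_fit_pack` is hypothetical. Raising on examples longer than `max_length` is an assumption about the behavior without `--truncate`. The `sort=True` path orders longest first (best-fit decreasing), but the sort direction of the CLI's `--sort` is not stated.

```python
from typing import List, Optional


def best_fit_pack(lengths: List[int], max_length: int, sort: bool = False) -> List[List[int]]:
    """Return packs as lists of example indices, using the best-fit heuristic."""
    order = list(range(len(lengths)))
    if sort:  # longest first, i.e. best-fit decreasing
        order.sort(key=lambda i: lengths[i], reverse=True)
    packs: List[List[int]] = []
    remaining: List[int] = []  # free tokens left in each pack
    for i in order:
        length = lengths[i]
        if length > max_length:
            raise ValueError(f"example {i} has {length} tokens > max_length={max_length}")
        best: Optional[int] = None
        for p, space in enumerate(remaining):
            # Pick the tightest pack that still has room for this example.
            if length <= space and (best is None or space < remaining[best]):
                best = p
        if best is None:
            packs.append([i])
            remaining.append(max_length - length)
        else:
            packs[best].append(i)
            remaining[best] -= length
    return packs


print(best_fit_pack([4, 3], max_length=8))  # [[0, 1]]
print(best_fit_pack([5, 4, 3, 2], max_length=8))  # [[0, 2], [1, 3]]
```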

Check packing utilization:

maskpack bench examples/tokenized_sft.jsonl --max-length 8
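
Token utilization can be read as the share of token slots that hold real tokens rather than padding. The following sketch uses that definition, which is an assumption since `maskpack bench` may compute it differently. It uses the two example lengths from the Input Format section. The function name is hypothetical.

```python
from typing import List


def token_utilization(pack_lengths: List[int], max_length: int) -> float:
    """Real tokens divided by the slots in fixed-length rows of size max_length."""
    return sum(pack_lengths) / (len(pack_lengths) * max_length)


print(token_utilization([7], max_length=8))  # 0.875: both examples share one row
print(token_utilization([4, 3], max_length=8))  # 0.4375: one padded row per example
```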

Verify an exported attention mask:

maskpack verify-mask examples/good_mask.json
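
A mask check of this kind compares each entry with what a block-diagonal causal mask implies. The JSON layout below, with `segment_ids` and a square `attention_mask` matrix, is a hypothetical stand-in. The actual schema of `examples/good_mask.json` is not documented here, and `find_mask_violations` is an illustrative name.

```python
import json
from typing import List, Tuple


def find_mask_violations(mask: List[List[int]], segment_ids: List[int]) -> List[Tuple[int, int]]:
    """Return (query, key) positions where mask differs from a block-diagonal causal mask."""
    n = len(segment_ids)
    if len(mask) != n or any(len(row) != n for row in mask):
        raise ValueError("mask must be square and match segment_ids length")
    violations = []
    for i in range(n):
        for j in range(n):
            expected = 1 if j <= i and segment_ids[i] == segment_ids[j] else 0
            if mask[i][j] != expected:
                violations.append((i, j))
    return violations


# Token 2 (segment 1) wrongly attends to token 0 (segment 0).
doc = json.loads('{"segment_ids": [0, 0, 1], "attention_mask": [[1,0,0],[1,1,0],[1,0,1]]}')
print(find_mask_violations(doc["attention_mask"], doc["segment_ids"]))  # [(2, 0)]
```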

Input Format

maskpack pack reads JSONL. Each line must include token IDs and an assistant-token mask computed after chat-template formatting:

{"source_id":"chat-1","input_ids":[10,11,12,13],"assistant_mask":[0,0,1,1]}
{"source_id":"chat-2","input_ids":[20,21,22],"assistant_mask":[0,1,1]}

Either assistant_mask or loss_mask is accepted as the mask field. Positions marked 1 contribute to the language-modeling loss; positions marked 0 are assigned the label -100.
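
The label rule can be sketched per JSONL line. The helper name is hypothetical, and preferring `assistant_mask` when both fields are present is an assumption.

```python
import json
from typing import List


def labels_from_line(line: str) -> List[int]:
    """Copy supervised token IDs into labels; every other position becomes -100."""
    record = json.loads(line)
    mask = record.get("assistant_mask", record.get("loss_mask"))
    if mask is None:
        raise KeyError("expected an assistant_mask or loss_mask field")
    input_ids = record["input_ids"]
    if len(mask) != len(input_ids):
        raise ValueError("mask and input_ids must have the same length")
    return [tok if m == 1 else -100 for tok, m in zip(input_ids, mask)]


print(labels_from_line('{"source_id":"chat-1","input_ids":[10,11,12,13],"assistant_mask":[0,0,1,1]}'))
# [-100, -100, 12, 13]
print(labels_from_line('{"input_ids":[20,21,22],"loss_mask":[0,1,1]}'))
# [-100, 21, 22]
```

The value -100 matches the default `ignore_index` of PyTorch's cross-entropy loss, so those positions are skipped when the loss is computed.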

Python API

from maskpack_sft import MaskAwareCollator, PackedSFTDataset, SFTExample

examples = [
    SFTExample(input_ids=[10, 11, 12, 13], assistant_mask=[0, 0, 1, 1]),
    SFTExample(input_ids=[20, 21, 22], assistant_mask=[0, 1, 1]),
]

packed = PackedSFTDataset(examples, max_length=8)
batch = MaskAwareCollator(pad_token_id=0)([packed[0]])

assert batch["labels"] == [[-100, -100, 12, 13, -100, 21, 22]]
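
The `-100` at position 4 matters once the loss applies the usual next-token shift. The following assumes a Hugging Face-style causal LM loss that scores the logits at position t against `labels[t + 1]`. Under that assumption, the supervised (context token, target) pairs for the packed row above are:

```python
input_ids = [10, 11, 12, 13, 20, 21, 22]
labels = [-100, -100, 12, 13, -100, 21, 22]

# Logits at position t are scored against labels[t + 1]; -100 targets are ignored.
pairs = [
    (input_ids[t], labels[t + 1])
    for t in range(len(labels) - 1)
    if labels[t + 1] != -100
]
print(pairs)  # [(11, 12), (12, 13), (20, 21), (21, 22)]
```

No pair crosses from chat-1 into chat-2. Token 13 is never trained to predict 20, because that target is -100.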

TRL Integration

from maskpack_sft.integrations import build_trl_collator

# tokenizer: an already-loaded Hugging Face tokenizer from your training script
collator = build_trl_collator(as_tensors=True, pad_token_id=tokenizer.pad_token_id)

The integration helper returns the same mask-aware collator used by the core package. Training code can wrap it to convert input_ids, labels, attention_mask, position_ids, and segment_ids to tensors.
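
The following is a minimal sketch of such a wrapper. It assumes the collator returns a dict whose values for those five keys are nested lists. `with_tensor_output` is a hypothetical name. By default it converts with `torch.tensor` (imported lazily), and any other converter can be passed, for example in tests.

```python
from typing import Any, Callable, Dict, List, Optional

TENSOR_KEYS = ("input_ids", "labels", "attention_mask", "position_ids", "segment_ids")


def with_tensor_output(
    collator: Callable[[List[Any]], Dict[str, Any]],
    to_tensor: Optional[Callable[[Any], Any]] = None,
) -> Callable[[List[Any]], Dict[str, Any]]:
    """Wrap a list-based collator so the model-facing fields become tensors."""
    if to_tensor is None:
        import torch  # lazy import keeps the wrapper usable without PyTorch

        def to_tensor(value: Any) -> Any:
            return torch.tensor(value, dtype=torch.long)

    def collate(features: List[Any]) -> Dict[str, Any]:
        batch = collator(features)
        # Convert only the known tensor fields; pass anything else through unchanged.
        return {k: to_tensor(v) if k in TENSOR_KEYS else v for k, v in batch.items()}

    return collate


def stand_in_collator(features: List[Any]) -> Dict[str, Any]:
    # Hypothetical stand-in for the real collator, returning plain lists.
    return {"input_ids": [[1, 2]], "labels": [[-100, 2]], "source_id": ["chat-1"]}


collate = with_tensor_output(stand_in_collator, to_tensor=tuple)
print(collate([]))  # {'input_ids': ([1, 2],), 'labels': ([-100, 2],), 'source_id': ['chat-1']}
```

Some attention implementations expect a boolean or floating-point mask rather than integers. In that case, `to_tensor` can dispatch on the key.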

For a runnable script, see examples/trl_sfttrainer_example.py.

CLI Reference

maskpack profile --parameters-b 7 --seq-length 2048 --micro-batch-size 1
maskpack pack INPUT.jsonl OUTPUT.jsonl --max-length 4096
maskpack bench INPUT.jsonl --max-length 4096
maskpack verify-mask MASK.json

Useful options:

  • --truncate: clip examples longer than --max-length.
  • --sort: sort by length before best-fit packing.
  • --json: emit profile output as JSON.
  • --gradient-checkpointing: include checkpointing in the memory estimate.
  • --chunked-loss: estimate memory for a chunked cross-entropy (CE) or Cut Cross-Entropy (CCE)-style loss path.
  • --distributed-preset: use data_parallel, fsdp_zero2, fsdp_zero3, zero2, or zero3.
  • --world-size: set the number of distributed workers for memory estimates.
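
Truncation has to clip the token IDs and the mask together, and it can remove every supervised token. The sketch below assumes `--truncate` keeps the first `max_length` tokens. The helper name is hypothetical.

```python
from typing import List, Tuple


def truncate_example(
    input_ids: List[int], assistant_mask: List[int], max_length: int
) -> Tuple[List[int], List[int]]:
    """Keep the first max_length tokens, clipping IDs and mask in lockstep."""
    if len(input_ids) != len(assistant_mask):
        raise ValueError("input_ids and assistant_mask must have the same length")
    return input_ids[:max_length], assistant_mask[:max_length]


ids, mask = truncate_example([1, 2, 3, 4, 5, 6], [0, 0, 0, 0, 1, 1], max_length=4)
print(ids, mask)  # [1, 2, 3, 4] [0, 0, 0, 0]
print(sum(mask))  # 0: the assistant reply was cut off, so this example adds no loss
```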

Development

python -m pytest

The test suite covers packing behavior, assistant-only labels, block-diagonal causal masks, long-example truncation, memory advice, and CLI mask verification.

Benchmarking

The benchmarks/ directory contains a starter benchmark matrix. A typical comparison should report:

  • token utilization;
  • examples/sec and tokens/sec;
  • peak GPU memory;
  • train and eval loss;
  • mask verification results;
  • OOM or distributed consistency failures.

benchmarks/github_issue_benchmark.md maps public issue reports to concrete benchmark cases.

License

MIT
