一个轻量级的 Vision Transformer (ViT) 实现,支持 CIFAR-10、CIFAR-100 和 Tiny-ImageNet-200 数据集的图像分类任务。本项目使用 PyTorch 从零开始实现 ViT 架构,并支持从 HuggingFace 预训练模型进行迁移学习。
如果你想更深入地了解 Vision Transformer 的原理、架构设计和实现细节,可以阅读:Vision Transformer —— 图像识别中的 Transformer 架构
本项目的理念和代码结构深受 Andrej Karpathy 的 nanoGPT 项目启发。
uv 是一个快速的 Python 包管理工具,安装速度比传统 pip 快 10-100 倍。
# 创建虚拟环境
uv venv --python 3.13 --seed
# 激活虚拟环境
source .venv/bin/activate
# 安装项目依赖
uv pip install -r requirements.txt如果你更熟悉 conda 生态系统,也可以使用 conda 进行环境管理。
# 创建新的 conda 环境
conda create -n nanovit python=3.13 -y
# 激活环境
conda activate nanovit
# 安装项目依赖
pip install -r requirements.txt本项目支持三个常用的图像分类数据集,每个数据集都有不同的特点和应用场景。
数据集简介:CIFAR-10 是一个经典的计算机视觉数据集,包含 60,000 张 32×32 彩色图像,分为 10 个类别(飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船、卡车),每个类别 6,000 张图像。其中训练集 50,000 张,测试集 10,000 张。
下载与准备:
# 创建数据目录
mkdir -p data
# 下载 CIFAR-10 数据集
cd data
wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
# 解压数据集
tar -zxvf cifar-10-python.tar.gz解压后的目录结构:
data/cifar-10-batches-py/
├── data_batch_1
├── data_batch_2
├── data_batch_3
├── data_batch_4
├── data_batch_5
├── test_batch
└── batches.meta
数据集简介:CIFAR-100 是 CIFAR-10 的扩展版本,同样包含 60,000 张 32×32 彩色图像,但分为 100 个细粒度类别(如苹果、橙子、梨等),每个类别 600 张图像。100 个类别还被分组为 20 个粗粒度类别。这个数据集比 CIFAR-10 更具挑战性。
下载与准备:
cd data
wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
# 解压数据集
tar -zxvf cifar-100-python.tar.gz解压后的目录结构:
data/cifar-100-python/
├── train
├── test
└── meta
数据集简介:Tiny-ImageNet-200 是 ImageNet 数据集的缩小版本,由斯坦福大学 CS231n 课程创建。包含 200 个类别,每个类别 500 张训练图像、50 张验证图像和 50 张测试图像。图像尺寸为 64×64,比 CIFAR 更大,更接近真实场景。
下载与准备:
cd data
wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
# 解压数据集
unzip tiny-imagenet-200.zip注意:Tiny-ImageNet-200 需要额外的预处理步骤,将原始图像转换为 .npy 格式以加快训练速度。请运行以下脚本进行预处理:
# 预处理脚本
python scripts/preprocess_tiny_imagenet.py --data_dir data/tiny-imagenet-200预处理后的目录结构:
data/tiny-imagenet-200/
├── train_data.npy
├── train_labels.npy
├── test_data.npy
└── test_labels.npy
从头开始训练一个 ViT 模型(不使用预训练权重):
# 在 CIFAR-10 上从头训练(使用较小的模型配置)
python train.py --init_from=scratch --data=cifar10 --data_root=data/cifar-10-batches-py \
--num_classes=10 --dim=384 --depth=12 --heads=6 --mlp_dim=1536 \
--learning_rate=0.001 --max_steps=10000
# 在 CIFAR-100 上从头训练
python train.py --init_from=scratch --data=cifar100 --data_root=data/cifar-100-python \
--num_classes=100 --dim=384 --depth=12 --heads=6 --mlp_dim=1536 \
--learning_rate=0.001 --max_steps=15000
# 在 Tiny-ImageNet-200 上从头训练
python train.py --init_from=scratch --data=tiny-imagenet --data_root=data/tiny-imagenet-200 \
--num_classes=200 --dim=512 --depth=12 --heads=8 --mlp_dim=2048 \
--learning_rate=0.001 --max_steps=20000使用预训练的 ViT 模型进行微调(推荐,效果更好且速度更快):
提示:预训练模型会从 HuggingFace 自动下载,可能需要配置代理才能成功下载。
# 在 CIFAR-10 上微调 ViT-Base 模型(推荐)
python train.py --init_from=ViT-B_16 --data=cifar10 --data_root=data/cifar-10-batches-py \
--num_classes=10 --learning_rate=0.03 --max_steps=5000 --warmup_steps=250
# 在 CIFAR-100 上微调
python train.py --init_from=ViT-B_16 --data=cifar100 --data_root=data/cifar-100-python \
--num_classes=100 --learning_rate=0.03 --max_steps=10000 --warmup_steps=500
# 在 Tiny-ImageNet-200 上微调
python train.py --init_from=ViT-B_16 --data=tiny-imagenet --data_root=data/tiny-imagenet-200 \
--num_classes=200 --learning_rate=0.03 --max_steps=20000 --warmup_steps=1000
# 使用更大的 ViT-Large 模型微调
python train.py --init_from=ViT-L_16 --data=cifar10 --data_root=data/cifar-10-batches-py \
--num_classes=10 --data_batch_size=16 --gradient_accumulation_steps=16out_dir: 输出目录,用于保存模型检查点(默认:"out")eval_every: 每隔多少步进行一次评估(默认:100)
image_size: 输入图像尺寸(默认:224)patch_size: 图像块大小(默认:16)num_classes: 分类类别数量(CIFAR-10:10, CIFAR-100:100, Tiny-ImageNet:200)dim: 模型嵌入维度(默认:768)depth: Transformer 层数(默认:12)heads: 多头注意力的头数(默认:12)mlp_dim: MLP 隐藏层维度(默认:3072)emb_dropout: 嵌入层 dropout 率(默认:0.1)dropout: 模型内部 dropout 率(默认:0.1)
init_from: 模型初始化方式"scratch": 从头训练,随机初始化模型权重"resumed": 从检查点恢复训练,继续之前中断的训练过程"ViT-B_16": 使用 ViT-Base 预训练模型(patch size 16)"ViT-B_32": 使用 ViT-Base 预训练模型(patch size 32)"ViT-L_16": 使用 ViT-Large 预训练模型"ViT-L_32": 使用 ViT-Large 预训练模型(patch size 32)"ViT-H_14": 使用 ViT-Huge 预训练模型
从检查点恢复训练:
当训练过程因意外中断或需要继续训练时,可以使用 init_from=resumed 从保存的检查点恢复。训练脚本会自动:
- 从
out_dir目录加载ckpt.pth检查点文件 - 恢复模型权重、优化器状态、训练步数和最佳评估损失
- 从上次中断的地方继续训练
检查点会在每次评估时,如果当前模型的测试损失优于历史最佳值,自动保存到 {out_dir}/ckpt.pth。
恢复训练示例:
# 假设之前的训练保存在默认的 out/ 目录
python train.py --init_from=resumed
# 如果使用了自定义输出目录,需要指定相同的 out_dir
python train.py --init_from=resumed --out_dir=out-cifar100data: 数据集名称(可选:"cifar10","cifar100","tiny-imagenet")data_root: 数据集根目录路径workers: 数据加载器的工作进程数(默认:16)data_batch_size: 每个 batch 的样本数量(默认:64)gradient_accumulation_steps: 梯度累积步数,用于模拟更大的 batch size(默认:8)- 实际有效 batch size =
data_batch_size × gradient_accumulation_steps
- 实际有效 batch size =
Weights & Biases (WandB) 是一个强大的机器学习实验跟踪平台,可以自动记录训练过程中的指标、可视化训练曲线、管理模型版本等。
# 安装后首次使用需要登录(只需一次)
wandb login
# 按提示输入你的 API Key(可在 https://wandb.ai/settings#api 获取)如果不想使用 WandB,可以通过 --wandb_log=False 禁用。
wandb_log: 是否启用 Weights & Biases 日志记录(默认:True)wandb_project: WandB 项目名称(默认:"tiny-imagenet-vit")wandb_run_name: WandB 运行名称(默认:"vit_base_16_224_cifar10_finetune")
learning_rate: 学习率(从头训练推荐:0.001,微调推荐:0.03)momentum: SGD 动量(默认:0.9)weight_decay: 权重衰减系数(默认:0.0)grad_clip: 梯度裁剪阈值(默认:1.0,设为0则禁用)
decay_lr: 是否启用学习率衰减(默认:True)max_steps: 最大训练步数(默认:10000)warmup_steps: 学习率预热步数(默认:500)final_lr: 最终学习率(默认:1e-5)
device: 训练设备(默认:"cuda",CPU 训练使用"cpu")dtype: 数据类型(可选:"float32","bfloat16","float16")- 默认:如果支持
"bfloat16"则使用,否则使用"float16"
- 默认:如果支持
compile: 是否使用torch.compile编译模型以提升性能(默认:True,需要 PyTorch 2.0+)
# 完全自定义的训练配置
python train.py \
--init_from=scratch \
--data=cifar10 \
--data_root=data/cifar-10-batches-py \
--num_classes=10 \
--image_size=224 \
--patch_size=16 \
--dim=512 \
--depth=8 \
--heads=8 \
--mlp_dim=2048 \
--learning_rate=0.001 \
--data_batch_size=128 \
--gradient_accumulation_steps=2 \
--max_steps=5000 \
--warmup_steps=200 \
--out_dir=out-custom \
--wandb_log=False \
--compile=TruenanoViT/
├── model.py # ViT 模型实现
├── datasets.py # 数据集加载器
├── train.py # 训练脚本
├── configurator.py # 命令行参数配置器(借鉴自 nanoGPT)
├── requirements.txt # 项目依赖
├── scripts/ # 实用脚本
│ └── preprocess_tiny_imagenet.py # Tiny-ImageNet 预处理
└── data/ # 数据集目录