Skip to content

kenanking/nanoViT

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

10 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

nanoViT

一个轻量级的 Vision Transformer (ViT) 实现,支持 CIFAR-10、CIFAR-100 和 Tiny-ImageNet-200 数据集的图像分类任务。本项目使用 PyTorch 从零开始实现 ViT 架构,并支持从 HuggingFace 预训练模型进行迁移学习。

如果你想更深入地了解 Vision Transformer 的原理、架构设计和实现细节,可以阅读:Vision Transformer —— 图像识别中的 Transformer 架构

本项目的理念和代码结构深受 Andrej Karpathy 的 nanoGPT 项目启发。

环境准备

方式一:使用 uv(推荐)

uv 是一个快速的 Python 包管理工具,安装速度比传统 pip 快 10-100 倍。

# 创建虚拟环境
uv venv --python 3.13 --seed

# 激活虚拟环境
source .venv/bin/activate

# 安装项目依赖
uv pip install -r requirements.txt

方式二:使用 conda

如果你更熟悉 conda 生态系统,也可以使用 conda 进行环境管理。

# 创建新的 conda 环境
conda create -n nanovit python=3.13 -y

# 激活环境
conda activate nanovit

# 安装项目依赖
pip install -r requirements.txt

数据集准备

本项目支持三个常用的图像分类数据集,每个数据集都有不同的特点和应用场景。

1. CIFAR-10

数据集简介:CIFAR-10 是一个经典的计算机视觉数据集,包含 60,000 张 32×32 彩色图像,分为 10 个类别(飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船、卡车),每个类别 6,000 张图像。其中训练集 50,000 张,测试集 10,000 张。

下载与准备

# 创建数据目录
mkdir -p data

# 下载 CIFAR-10 数据集
cd data
wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

# 解压数据集
tar -zxvf cifar-10-python.tar.gz

解压后的目录结构:

data/cifar-10-batches-py/
├── data_batch_1
├── data_batch_2
├── data_batch_3
├── data_batch_4
├── data_batch_5
├── test_batch
└── batches.meta

2. CIFAR-100

数据集简介:CIFAR-100 是 CIFAR-10 的扩展版本,同样包含 60,000 张 32×32 彩色图像,但分为 100 个细粒度类别(如苹果、橙子、梨等),每个类别 600 张图像。100 个类别还被分组为 20 个粗粒度类别。这个数据集比 CIFAR-10 更具挑战性。

下载与准备

cd data
wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz

# 解压数据集
tar -zxvf cifar-100-python.tar.gz

解压后的目录结构:

data/cifar-100-python/
├── train
├── test
└── meta

3. Tiny-ImageNet-200

数据集简介:Tiny-ImageNet-200 是 ImageNet 数据集的缩小版本,由斯坦福大学 CS231n 课程创建。包含 200 个类别,每个类别 500 张训练图像、50 张验证图像和 50 张测试图像。图像尺寸为 64×64,比 CIFAR 更大,更接近真实场景。

下载与准备

cd data
wget http://cs231n.stanford.edu/tiny-imagenet-200.zip

# 解压数据集
unzip tiny-imagenet-200.zip

注意:Tiny-ImageNet-200 需要额外的预处理步骤,将原始图像转换为 .npy 格式以加快训练速度。请运行以下脚本进行预处理:

# 预处理脚本
python scripts/preprocess_tiny_imagenet.py --data_dir data/tiny-imagenet-200

预处理后的目录结构:

data/tiny-imagenet-200/
├── train_data.npy
├── train_labels.npy
├── test_data.npy
└── test_labels.npy

训练

从头开始训练一个 ViT 模型(不使用预训练权重):

# 在 CIFAR-10 上从头训练(使用较小的模型配置)
python train.py --init_from=scratch --data=cifar10 --data_root=data/cifar-10-batches-py \
  --num_classes=10 --dim=384 --depth=12 --heads=6 --mlp_dim=1536 \
  --learning_rate=0.001 --max_steps=10000

# 在 CIFAR-100 上从头训练
python train.py --init_from=scratch --data=cifar100 --data_root=data/cifar-100-python \
  --num_classes=100 --dim=384 --depth=12 --heads=6 --mlp_dim=1536 \
  --learning_rate=0.001 --max_steps=15000

# 在 Tiny-ImageNet-200 上从头训练
python train.py --init_from=scratch --data=tiny-imagenet --data_root=data/tiny-imagenet-200 \
  --num_classes=200 --dim=512 --depth=12 --heads=8 --mlp_dim=2048 \
  --learning_rate=0.001 --max_steps=20000

微调

使用预训练的 ViT 模型进行微调(推荐,效果更好且速度更快):

提示:预训练模型会从 HuggingFace 自动下载,可能需要配置代理才能成功下载。

# 在 CIFAR-10 上微调 ViT-Base 模型(推荐)
python train.py --init_from=ViT-B_16 --data=cifar10 --data_root=data/cifar-10-batches-py \
  --num_classes=10 --learning_rate=0.03 --max_steps=5000 --warmup_steps=250

# 在 CIFAR-100 上微调
python train.py --init_from=ViT-B_16 --data=cifar100 --data_root=data/cifar-100-python \
  --num_classes=100 --learning_rate=0.03 --max_steps=10000 --warmup_steps=500

# 在 Tiny-ImageNet-200 上微调
python train.py --init_from=ViT-B_16 --data=tiny-imagenet --data_root=data/tiny-imagenet-200 \
  --num_classes=200 --learning_rate=0.03 --max_steps=20000 --warmup_steps=1000

# 使用更大的 ViT-Large 模型微调
python train.py --init_from=ViT-L_16 --data=cifar10 --data_root=data/cifar-10-batches-py \
  --num_classes=10 --data_batch_size=16 --gradient_accumulation_steps=16

参数

输出与评估

  • out_dir: 输出目录,用于保存模型检查点(默认:"out"
  • eval_every: 每隔多少步进行一次评估(默认:100

模型参数

  • image_size: 输入图像尺寸(默认:224
  • patch_size: 图像块大小(默认:16
  • num_classes: 分类类别数量(CIFAR-10: 10, CIFAR-100: 100, Tiny-ImageNet: 200
  • dim: 模型嵌入维度(默认:768
  • depth: Transformer 层数(默认:12
  • heads: 多头注意力的头数(默认:12
  • mlp_dim: MLP 隐藏层维度(默认:3072
  • emb_dropout: 嵌入层 dropout 率(默认:0.1
  • dropout: 模型内部 dropout 率(默认:0.1

模型初始化

  • init_from: 模型初始化方式
    • "scratch": 从头训练,随机初始化模型权重
    • "resumed": 从检查点恢复训练,继续之前中断的训练过程
    • "ViT-B_16": 使用 ViT-Base 预训练模型(patch size 16)
    • "ViT-B_32": 使用 ViT-Base 预训练模型(patch size 32)
    • "ViT-L_16": 使用 ViT-Large 预训练模型
    • "ViT-L_32": 使用 ViT-Large 预训练模型(patch size 32)
    • "ViT-H_14": 使用 ViT-Huge 预训练模型

从检查点恢复训练

当训练过程因意外中断或需要继续训练时,可以使用 init_from=resumed 从保存的检查点恢复。训练脚本会自动:

  1. out_dir 目录加载 ckpt.pth 检查点文件
  2. 恢复模型权重、优化器状态、训练步数和最佳评估损失
  3. 从上次中断的地方继续训练

检查点会在每次评估时,如果当前模型的测试损失优于历史最佳值,自动保存到 {out_dir}/ckpt.pth

恢复训练示例

# 假设之前的训练保存在默认的 out/ 目录
python train.py --init_from=resumed

# 如果使用了自定义输出目录,需要指定相同的 out_dir
python train.py --init_from=resumed --out_dir=out-cifar100

数据参数

  • data: 数据集名称(可选:"cifar10", "cifar100", "tiny-imagenet"
  • data_root: 数据集根目录路径
  • workers: 数据加载器的工作进程数(默认:16
  • data_batch_size: 每个 batch 的样本数量(默认:64
  • gradient_accumulation_steps: 梯度累积步数,用于模拟更大的 batch size(默认:8
    • 实际有效 batch size = data_batch_size × gradient_accumulation_steps

WandB 日志

Weights & Biases (WandB) 是一个强大的机器学习实验跟踪平台,可以自动记录训练过程中的指标、可视化训练曲线、管理模型版本等。

# 安装后首次使用需要登录(只需一次)
wandb login

# 按提示输入你的 API Key(可在 https://wandb.ai/settings#api 获取)

如果不想使用 WandB,可以通过 --wandb_log=False 禁用。

  • wandb_log: 是否启用 Weights & Biases 日志记录(默认:True
  • wandb_project: WandB 项目名称(默认:"tiny-imagenet-vit"
  • wandb_run_name: WandB 运行名称(默认:"vit_base_16_224_cifar10_finetune"

优化器参数

  • learning_rate: 学习率(从头训练推荐:0.001,微调推荐:0.03
  • momentum: SGD 动量(默认:0.9
  • weight_decay: 权重衰减系数(默认:0.0
  • grad_clip: 梯度裁剪阈值(默认:1.0,设为 0 则禁用)

学习率调度

  • decay_lr: 是否启用学习率衰减(默认:True
  • max_steps: 最大训练步数(默认:10000
  • warmup_steps: 学习率预热步数(默认:500
  • final_lr: 最终学习率(默认:1e-5

系统参数

  • device: 训练设备(默认:"cuda",CPU 训练使用 "cpu"
  • dtype: 数据类型(可选:"float32", "bfloat16", "float16"
    • 默认:如果支持 "bfloat16" 则使用,否则使用 "float16"
  • compile: 是否使用 torch.compile 编译模型以提升性能(默认:True,需要 PyTorch 2.0+)

示例:自定义参数训练

# 完全自定义的训练配置
python train.py \
  --init_from=scratch \
  --data=cifar10 \
  --data_root=data/cifar-10-batches-py \
  --num_classes=10 \
  --image_size=224 \
  --patch_size=16 \
  --dim=512 \
  --depth=8 \
  --heads=8 \
  --mlp_dim=2048 \
  --learning_rate=0.001 \
  --data_batch_size=128 \
  --gradient_accumulation_steps=2 \
  --max_steps=5000 \
  --warmup_steps=200 \
  --out_dir=out-custom \
  --wandb_log=False \
  --compile=True

项目结构

nanoViT/
├── model.py              # ViT 模型实现
├── datasets.py           # 数据集加载器
├── train.py              # 训练脚本
├── configurator.py       # 命令行参数配置器(借鉴自 nanoGPT)
├── requirements.txt      # 项目依赖
├── scripts/              # 实用脚本
│   └── preprocess_tiny_imagenet.py   # Tiny-ImageNet 预处理
└── data/                 # 数据集目录

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages