
ziheng1027/STF


STF

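A beginner-friendly framework for spatio-temporal forecasting: spatio-temporal prediction, video sequence prediction, radar echo extrapolation...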

Requirements

python>=3.12, cuda>=12.x, pytorch>=2.6.0

  1. pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128
  2. pip install -r requirements.txt
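
After installing, a quick check along the following lines can confirm that the interpreter and the PyTorch build meet the minimums listed above. This is a minimal sketch: the helpers `parse_version`, `meets_minimum`, and `check_environment` are illustrative names and not part of STF. The only PyTorch attributes it relies on are `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()`.

```python
import importlib.util
import re
import sys


def parse_version(text):
    """Turn '2.6.0+cu128' into (2, 6, 0); stops at the first non-numeric part."""
    parts = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(version_text, minimum):
    """Return True if version_text is at least the `minimum` tuple."""
    return parse_version(version_text) >= minimum


def check_environment():
    """Collect pass/fail flags for the requirements stated in this README."""
    report = {"python>=3.12": sys.version_info[:2] >= (3, 12)}
    if importlib.util.find_spec("torch") is None:
        report["pytorch>=2.6.0"] = False
        return report
    import torch  # imported lazily so the script still runs without PyTorch

    report["pytorch>=2.6.0"] = meets_minimum(torch.__version__, (2, 6, 0))
    cuda = torch.version.cuda  # None on CPU-only builds
    report["cuda>=12.x"] = cuda is not None and meets_minimum(cuda, (12,))
    report["gpu visible"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'OK' if ok else 'MISSING'}")
```

A CPU-only PyTorch build reports `cuda>=12.x` as missing, which usually means the `--index-url` in step 1 was not used.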

Models

  • ConvLSTM*(NIPS 2015)
  • PredRNN(V1:NIPS 2017, V2:IEEE 2022)
  • PhyDNet(CVPR 2020)
  • UNet(MICCAI 2015)
  • SmaAtUNet(PRL 2021)
  • SimVP(IncepU:CVPR 2022, gSTA:IEEE 2022, TAU:CVPR 2023)
  • STLight(WACV 2024)

Datasets

MovingMNIST:

  • Download (bash): sh "Dataset\Download\download_moving_mnist.sh"
  • Model configurations and performance metrics:

Model configuration

| Model | Model Config |
| --- | --- |
| PredRNN-V1 | input_channels=1, num_hidden_channels=[128,128,128,128], input_frames=10, output_frames=10, patch_size=4, kernel_size=5, reverse_scheduled_sampling=True |
| PredRNN-V2 | input_channels=1, num_hidden_channels=[128,128,128,128], input_frames=10, output_frames=10, patch_size=4, kernel_size=5, reverse_scheduled_sampling=True |
| UNet | in_channels=1, out_channels=1, in_frames=10, out_frames=10, bilinear=True |
| SmaAtUNet | in_channels=1, out_channels=1, in_frames=10, out_frames=10, num_kernel=2, reduction_ratio=16 |
| SimVP-V1 | input_shape=[10,1,64,64], translator_type=IncepU, hid_channels_S=64, hid_channels_T=512, layers_S=4, layers_T=8 |
| SimVP-V2 | input_shape=[10,1,64,64], translator_type=gSTA, hid_channels_S=64, hid_channels_T=512, layers_S=4, layers_T=8 |
| TAU | input_shape=[10,1,64,64], translator_type=TAU, hid_channels_S=64, hid_channels_T=512, layers_S=4, layers_T=8 |
| STLight | in_channels=10, out_channels=10, hid_channels=1024, layers=16, patch_size=2 |

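Performance metrics

| Model | Params | MSE ↓ | MAE ↓ | RMSE ↓ | PSNR ↑ | SSIM ↑ |
| --- | --- | --- | --- | --- | --- | --- |
| PredRNN-V1 | 23.84M | 25.4224 | 76.9728 | 4.9976 | 22.9966 | 0.9251 |
| PredRNN-V2 | 23.86M | 25.6829 | 77.4449 | 5.0228 | 22.9554 | 0.9264 |
| UNet | 17.27M | 50.973 | 127.98 | 7.121 | 19.7168 | 0.8576 |
| SmaAtUNet | 4.03M | 55.1125 | 137.735 | 7.407 | 19.3423 | 0.841 |
| SimVP-V1 | 57.95M | 32.6546 | 89.7017 | 5.6823 | 21.7713 | 0.9133 |
| SimVP-V2 | 46.77M | 27.2356 | 78.2131 | 5.1791 | 22.6811 | 0.9285 |
| TAU | 44.66M | 26.4949 | 76.8922 | 5.1083 | 22.8022 | 0.9304 |
| STLight | 17.89M | 23.1482 | 70.9686 | 4.7676 | 23.5775 | 0.9355 |

Note: ↓ means lower is better, ↑ means higher is better. All models were trained for only 200 epochs; neither the SimVP-type nor the UNet-type models triggered early stopping.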
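
The RMSE column is not simply the square root of the MSE column (for PredRNN-V1, the square root of 25.4224 is about 5.04, while the table reports 4.9976). That is what one would expect if each metric is computed per frame and then averaged. The sketch below illustrates one such convention. It is an assumption, not STF's confirmed implementation: errors are summed over the pixels of each frame, PSNR uses a hypothetical `peak` of 1.0, and the function name `frame_metrics` is illustrative. Plain lists keep the example dependency-free; real code would operate on NumPy or PyTorch arrays.

```python
import math
from statistics import fmean


def frame_metrics(pred_frames, true_frames, peak=1.0):
    """Average per-frame metrics over a sequence of flattened frames.

    Assumed convention (not confirmed by the STF source): MSE and MAE are
    summed over the pixels of each frame, RMSE is the square root of each
    frame's summed squared error, PSNR uses each frame's mean squared error,
    and every metric is then averaged over frames.
    """
    if len(pred_frames) != len(true_frames):
        raise ValueError("pred_frames and true_frames must have the same length")
    mse, mae, rmse, psnr = [], [], [], []
    for pred, true in zip(pred_frames, true_frames):
        diffs = [p - t for p, t in zip(pred, true)]
        sq_sum = sum(d * d for d in diffs)
        mse.append(sq_sum)
        mae.append(sum(abs(d) for d in diffs))
        rmse.append(math.sqrt(sq_sum))
        mean_sq = sq_sum / len(diffs)
        # A perfect frame has infinite PSNR; guard the log against zero.
        psnr.append(math.inf if mean_sq == 0 else 10 * math.log10(peak**2 / mean_sq))
    return {"MSE": fmean(mse), "MAE": fmean(mae), "RMSE": fmean(rmse), "PSNR": fmean(psnr)}


# Two tiny "frames" of two pixels each.
pred = [[0.0, 0.0], [3.0, 0.0]]
true = [[0.0, 1.0], [0.0, 4.0]]
result = frame_metrics(pred, true)
print(result["MSE"], result["RMSE"], math.sqrt(result["MSE"]))
```

In this toy case the averaged RMSE is smaller than the square root of the averaged MSE, the same direction of gap seen in the table.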

SimVP Visualization

TaxiBJ:

  • Download (bash): sh "Dataset\Download\download_taxibj.sh"
  • Model configurations and performance metrics:

Model configuration

| Model | Model Config |
| --- | --- |
| PredRNN-V1 | input_channels=2, num_hidden_channels=[128,128,128,128], input_frames=4, output_frames=4, patch_size=2, kernel_size=5, reverse_scheduled_sampling=False |
| PredRNN-V2 | input_channels=2, num_hidden_channels=[128,128,128,128], input_frames=4, output_frames=4, patch_size=2, kernel_size=5, reverse_scheduled_sampling=True |
| UNet | in_channels=2, out_channels=2, in_frames=4, out_frames=4, bilinear=True |
| SmaAtUNet | in_channels=2, out_channels=2, in_frames=4, out_frames=4, num_kernel=2, reduction_ratio=16 |
| SimVP-V1 | input_shape=[4,2,32,32], translator_type=IncepU, hid_channels_S=32, hid_channels_T=256, layers_S=2, layers_T=8 |
| SimVP-V2 | input_shape=[4,2,32,32], translator_type=gSTA, hid_channels_S=32, hid_channels_T=256, layers_S=2, layers_T=8 |
| TAU | input_shape=[4,2,32,32], translator_type=TAU, hid_channels_S=32, hid_channels_T=256, layers_S=2, layers_T=8 |
| STLight | in_channels=8, out_channels=8, hid_channels=256, layers=16, patch_size=1 |

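Performance metrics

| Model | Params | MSE ↓ | MAE ↓ | RMSE ↓ | PSNR ↑ | SSIM ↑ |
| --- | --- | --- | --- | --- | --- | --- |
| PredRNN-V1 | 23.66M | 0.3276 | 15.1447 | 0.5344 | 39.6018 | 0.9772 |
| PredRNN-V2 | 23.67M | 0.3654 | 15.29 | 0.5453 | 39.5338 | 0.9764 |
| UNet | 17.27M | 0.3518 | 15.7073 | 0.5444 | 39.3579 | 0.9766 |
| SmaAtUNet | 4.03M | 0.3798 | 16.3223 | 0.5657 | 39.0235 | 0.9736 |
| SimVP-V1 | 13.79M | 0.3229 | 15.3546 | 0.5342 | 39.5185 | 0.9766 |
| SimVP-V2 | 6.08M | 0.3026 | 14.8174 | 0.5182 | 39.7826 | 0.9790 |
| TAU | 5.66M | 0.3003 | 15.0326 | 0.5198 | 39.7168 | 0.9787 |
| STLight | 1.32M | 0.3338 | 15.319 | 0.5328 | 39.5448 | 0.9774 |

Note: ↓ means lower is better, ↑ means higher is better.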

SimVP Visualization

SEVIR:

Model configuration

| Model | Model Config |
| --- | --- |
| PredRNN-V1 | input_channels=1, num_hidden_channels=[128,128,128,128], input_frames=7, output_frames=6, patch_size=8, kernel_size=5, reverse_scheduled_sampling=False |
| PredRNN-V2 | input_channels=1, num_hidden_channels=[128,128,128,128], input_frames=7, output_frames=6, patch_size=8, kernel_size=5, reverse_scheduled_sampling=True |
| UNet | in_channels=1, out_channels=1, in_frames=7, out_frames=6, bilinear=True |
| SmaAtUNet | in_channels=1, out_channels=1, in_frames=7, out_frames=6, num_kernel=2, reduction_ratio=16 |
| SimVP-V1 | input_shape=[7,1,128,128], translator_type=IncepU, hid_channels_S=64, hid_channels_T=256, layers_S=4, layers_T=6 |
| SimVP-V2 | input_shape=[7,1,128,128], translator_type=gSTA, hid_channels_S=64, hid_channels_T=256, layers_S=4, layers_T=6 |
| TAU | input_shape=[7,1,128,128], translator_type=TAU, hid_channels_S=64, hid_channels_T=256, layers_S=4, layers_T=6 |
| STLight | in_channels=7, out_channels=6, hid_channels=512, layers=16, patch_size=4 |
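
The configurations for the three datasets describe the same data in two ways. The SimVP-style rows give input_shape as [frames, channels, height, width], while the STLight rows appear to fold time into the channel axis (10 = 10×1 for MovingMNIST, 8 = 4×2 for TaxiBJ, and 7 in / 6 out for SEVIR). The sketch below reproduces that arithmetic from the listed values. It also shows the patch folding commonly used by PredRNN implementations, where each patch_size×patch_size block is moved into the channel axis; whether STF reshapes patches exactly this way is an assumption, and both helper names are illustrative.

```python
def stacked_channels(frames, channels):
    """Channels seen by a model that stacks frames along the channel axis."""
    return frames * channels


def patched_shape(channels, height, width, patch_size):
    """Shape after folding each patch_size x patch_size block into channels.

    Assumption: this mirrors the usual PredRNN-style patch reshaping; it has
    not been checked against the STF source.
    """
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    return (channels * patch_size**2, height // patch_size, width // patch_size)


# (input frames, output frames, channels, height, width, PredRNN patch_size),
# copied from the configuration tables in this README.
DATASETS = {
    "MovingMNIST": (10, 10, 1, 64, 64, 4),
    "TaxiBJ": (4, 4, 2, 32, 32, 2),
    "SEVIR": (7, 6, 1, 128, 128, 8),
}

for name, (t_in, t_out, c, h, w, patch) in DATASETS.items():
    print(
        name,
        "| STLight in/out channels:", stacked_channels(t_in, c), stacked_channels(t_out, c),
        "| PredRNN patched frame:", patched_shape(c, h, w, patch),
    )
```

Under this assumption, the listed patch sizes all reduce a frame to a 16×16 grid, which would keep the recurrent cells at the same spatial resolution across the three datasets.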

Performance metrics

| Model | Params | MSE ↓ | MAE ↓ | Precision ↑ | Recall ↑ | F1 ↑ | Accuracy ↑ | FAR ↓ | CSI ↑ | HSS ↑ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| PredRNN-V1 | 28.85M | 102.78 | 618.525 | 0.7282 | 0.7421 | 0.7299 | 0.9462 | 0.0337 | 0.5885 | 0.6983 |
| PredRNN-V2 | 28.87M | 81.5994 | 535.471 | 0.7976 | 0.7048 | 0.741 | 0.9521 | 0.0227 | 0.6028 | 0.7132 |
| UNet | 17.27M | 74.7094 | 502.168 | 0.8193 | 0.7129 | 0.7547 | 0.955 | 0.021 | 0.62 | 0.7286 |
| SmaAtUNet | 4.03M | 75.0876 | 502.901 | 0.8153 | 0.7178 | 0.756 | 0.955 | 0.0216 | 0.6214 | 0.7298 |
| SimVP-V1 | 11.76M | 73.25 | 497.939 | 0.814 | 0.728 | 0.7612 | 0.9558 | 0.0224 | 0.6282 | 0.7353 |
| SimVP-V2 | 7.28M | 73.6464 | 488.54 | 0.8194 | 0.7214 | 0.7606 | 0.9558 | 0.021 | 0.6268 | 0.7349 |
| TAU | 6.80M | 74.4695 | 499.356 | 0.8048 | 0.7327 | 0.761 | 0.955 | 0.0233 | 0.627 | 0.7347 |
| STLight | 4.79M | 76.2394 | 516.638 | 0.8084 | 0.721 | 0.7545 | 0.9546 | 0.0228 | 0.6202 | 0.728 |

Note: ↓ means lower is better, ↑ means higher is better.
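
The categorical scores in the SEVIR table are all functions of a 2×2 contingency table (hits, false alarms, misses, correct negatives) obtained by thresholding the predicted and observed fields. The sketch below uses the standard textbook definitions. Two points are assumptions rather than facts about STF: the `threshold` argument is hypothetical (the thresholds used for the table are not stated here), and FAR is taken as the false alarm rate FP/(FP+TN), because the reported FAR values (about 0.02 to 0.03) are far smaller than 1 − Precision (about 0.2), which is what the false alarm ratio FP/(TP+FP) would give.

```python
def confusion_counts(pred, true, threshold):
    """Count TP, FP, FN, TN after binarising both fields at `threshold`."""
    tp = fp = fn = tn = 0
    for p, t in zip(pred, true):
        hit_pred, hit_true = p >= threshold, t >= threshold
        if hit_pred and hit_true:
            tp += 1
        elif hit_pred:
            fp += 1
        elif hit_true:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn


def categorical_scores(tp, fp, fn, tn):
    """Standard contingency-table scores; zero denominators yield 0.0."""

    def ratio(num, den):
        return num / den if den else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    hss_den = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    return {
        "Precision": precision,
        "Recall": recall,
        "F1": ratio(2 * precision * recall, precision + recall),
        "Accuracy": ratio(tp + tn, tp + fp + fn + tn),
        "FAR": ratio(fp, fp + tn),  # assumed: false alarm rate, not FP/(TP+FP)
        "CSI": ratio(tp, tp + fp + fn),
        "HSS": ratio(2 * (tp * tn - fp * fn), hss_den),
    }


scores = categorical_scores(tp=50, fp=10, fn=20, tn=920)
for name, value in scores.items():
    print(f"{name}: {value:.4f}")
```

Scores computed from pooled counts like this will generally differ slightly from scores averaged per sample or per threshold, so small mismatches against the table are expected.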

SimVP Visualization

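Usage

  1. Configure the model, dataset, and evaluation metrics:
    • Model config files (organized by dataset name)
    • Dataset config files
    • Metric config files (different datasets can use different metrics)
  2. Train: set model_name, dataset_name, and mode in the main script to start training
    • mode=train: trains the model. The best checkpoint is saved to Output/Checkpoint, training logs are written to Output/Log, and loss curves are saved to Output/Visualization
  3. Test: open the model config file under Config, set model_path to the path of the trained model, then set model_name, dataset_name, and mode in the main script
    • mode=test: evaluates the model. Set save_interval to save sample data at regular intervals; samples are saved to Output/Sample and test metrics are logged to Output/Log
  4. Visualize: set model_name, dataset_name, and mode in the main script
    • mode=visualize: visualizes the sample data under Output/Sample
  5. Parameters:
    • patience: early-stopping patience; training stops once the validation loss has not improved for this many epochs
    • resume_from: resume training from a checkpoint; if training is interrupted unexpectedly, set this to the path of the latest checkpoint to continue training
    • model_path: path to the trained model, used for testing (mode=test)
    • save_interval: interval between saved samples in test mode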
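
The patience and save_interval parameters can be illustrated with a small sketch. This is not STF's code: the class `EarlyStopping` and the function `should_save_sample` are hypothetical, stopping is assumed to trigger once the count of non-improving epochs reaches patience, and save_interval is assumed to mean "keep every save_interval-th test batch".

```python
class EarlyStopping:
    """Stop when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, valid_loss):
        """Record one epoch; return True if training should stop."""
        if valid_loss < self.best_loss:
            # A real training loop would also save the best checkpoint here.
            self.best_loss = valid_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


def should_save_sample(batch_index, save_interval):
    """Assumed meaning of save_interval: keep every save_interval-th test batch."""
    return batch_index % save_interval == 0


# Validation losses for five epochs; improvement stalls after the second one.
stopper = EarlyStopping(patience=2)
losses = [1.0, 0.8, 0.9, 0.85, 0.7]
stopped_at = next(
    (epoch for epoch, loss in enumerate(losses) if stopper.step(loss)), None
)
print("stopped at epoch index:", stopped_at, "best loss:", stopper.best_loss)
```

With a small patience, training can stop just before a later improvement (the 0.7 at the end of the list is never reached), which is the usual trade-off when choosing this value.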
