一个时空预测入门框架: 时空预测, 视频序列预测, 雷达回波外推...
python>=3.12, cuda>=12.x, pytorch>=2.6.0
- pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128
- pip install -r requirements.txt
- ConvLSTM*(NIPS 2015)
- PredRNN(V1:NIPS 2017, V2:IEEE 2022)
- PhyDNet(CVPR 2020)
- UNet(CVPR 2015)
- SmaAtUNet(PRL 2021)
- SimVP(IncepU:CVPR 2022, gSTA:IEEE 2022, TAU:CVPR 2023)
- STLight(WACV 2024)
- 下载(bash): sh "Dataset\Download\download_moving_mnist.sh"
- 模型配置与性能指标:
| Model | Model Config |
|---|---|
| PredRNN-V1 | input_channels=1, num_hidden_channels=[128,128,128,128], input_frames=10, output_frames=10, patch_size=4, kernel_size=5, reverse_scheduled_sampling=True |
| PredRNN-V2 | input_channels=1, num_hidden_channels=[128,128,128,128], input_frames=10, output_frames=10, patch_size=4, kernel_size=5, reverse_scheduled_sampling=True |
| UNet | in_channels=1, out_channels=1, in_frames=10, out_frames=10, bilinear=True |
| SmaAtUNet | in_channels=1, out_channels=1, in_frames=10, out_frames=10, num_kernel=2, reduction_ratio=16 |
| SimVP-V1 | input_shape=[10,1,64,64], translator_type=IncepU, hid_channels_S=64, hid_channels_T=512, layers_S=4, layers_T=8 |
| SimVP-V2 | input_shape=[10,1,64,64], translator_type=gSTA, hid_channels_S=64, hid_channels_T=512, layers_S=4, layers_T=8 |
| TAU | input_shape=[10,1,64,64], translator_type=TAU, hid_channels_S=64, hid_channels_T=512, layers_S=4, layers_T=8 |
| STLight | in_channels=10, out_channels=10, hid_channels=1024, layers=16, patch_size=2 |
| Model | Params | MSE ↓ | MAE ↓ | RMSE ↓ | PSNR ↑ | SSIM ↑ |
|---|---|---|---|---|---|---|
| PredRNN-V1 | 23.84M | 25.4224 | 76.9728 | 4.9976 | 22.9966 | 0.9251 |
| PredRNN-V2 | 23.86M | 25.6829 | 77.4449 | 5.0228 | 22.9554 | 0.9264 |
| UNet | 17.27M | 50.973 | 127.98 | 7.121 | 19.7168 | 0.8576 |
| SmaAtUNet | 4.03M | 55.1125 | 137.735 | 7.407 | 19.3423 | 0.841 |
| SimVP-V1 | 57.95M | 32.6546 | 89.7017 | 5.6823 | 21.7713 | 0.9133 |
| SimVP-V2 | 46.77M | 27.2356 | 78.2131 | 5.1791 | 22.6811 | 0.9285 |
| TAU | 44.66M | 26.4949 | 76.8922 | 5.1083 | 22.8022 | 0.9304 |
| STLight | 17.89M | 23.1482 | 70.9686 | 4.7676 | 23.5775 | 0.9355 |
注: ↓表示越小越好,↑表示越大越好;所有模型仅训练了200个epoch,SimVP类型和UNet类型均未触发早停机制
- 下载(bash): sh "Dataset\Download\download_taxibj.sh"
- 模型配置与性能指标:
| Model | Model Config |
|---|---|
| PredRNN-V1 | input_channels=2, num_hidden_channels=[128,128,128,128], input_frames=4, output_frames=4, patch_size=2, kernel_size=5, reverse_scheduled_sampling=False |
| PredRNN-V2 | input_channels=2, num_hidden_channels=[128,128,128,128], input_frames=4, output_frames=4, patch_size=2, kernel_size=5, reverse_scheduled_sampling=True |
| UNet | in_channels=2, out_channels=2, in_frames=4, out_frames=4, bilinear=True |
| SmaAtUNet | in_channels=2, out_channels=2, in_frames=4, out_frames=4, num_kernel=2, reduction_ratio=16 |
| SimVP-V1 | input_shape=[4,2,32,32], translator_type=IncepU, hid_channels_S=32, hid_channels_T=256, layers_S=2, layers_T=8 |
| SimVP-V2 | input_shape=[4,2,32,32], translator_type=gSTA, hid_channels_S=32, hid_channels_T=256, layers_S=2, layers_T=8 |
| TAU | input_shape=[4,2,32,32], translator_type=TAU, hid_channels_S=32, hid_channels_T=256, layers_S=2, layers_T=8 |
| STLight | in_channels=8, out_channels=8, hid_channels=256, layers=16, patch_size=1 |
| Model | Params | MSE ↓ | MAE ↓ | RMSE ↓ | PSNR ↑ | SSIM ↑ |
|---|---|---|---|---|---|---|
| PredRNN-V1 | 23.66M | 0.3276 | 15.1447 | 0.5344 | 39.6018 | 0.9772 |
| PredRNN-V2 | 23.67M | 0.3654 | 15.29 | 0.5453 | 39.5338 | 0.9764 |
| UNet | 17.27M | 0.3518 | 15.7073 | 0.5444 | 39.3579 | 0.9766 |
| SmaAtUNet | 4.03M | 0.3798 | 16.3223 | 0.5657 | 39.0235 | 0.9736 |
| SimVP-V1 | 13.79M | 0.3229 | 15.3546 | 0.5342 | 39.5185 | 0.9766 |
| SimVP-V2 | 6.08M | 0.3026 | 14.8174 | 0.5182 | 39.7826 | 0.9790 |
| TAU | 5.66M | 0.3003 | 15.0326 | 0.5198 | 39.7168 | 0.9787 |
| STLight | 1.32M | 0.3338 | 15.319 | 0.5328 | 39.5448 | 0.9774 |
注: ↓表示越小越好,↑表示越大越好
- 下载(bash): sh "Dataset\Download\download_sevir_vil.sh"
- 前置条件: 运行download_sevir_vil.sh脚本之前, 请确保已安装AWS CLI: https://docs.aws.amazon.com/zh_cn/cli/latest/userguide/getting-started-install.html
- SEVIR数据处理: https://github.com/ziheng1027/SEVIR/tree/main
- 模型配置与性能指标:
| Model | Model Config |
|---|---|
| PredRNN-V1 | input_channels=1, num_hidden_channels=[128,128,128,128], input_frames=7, output_frames=6, patch_size=8, kernel_size=5, reverse_scheduled_sampling=False |
| PredRNN-V2 | input_channels=1, num_hidden_channels=[128,128,128,128], input_frames=7, output_frames=6, patch_size=8, kernel_size=5, reverse_scheduled_sampling=True |
| UNet | in_channels=1, out_channels=1, in_frames=7, out_frames=6, bilinear=True |
| SmaAtUNet | in_channels=1, out_channels=1, in_frames=7, out_frames=6, num_kernel=2, reduction_ratio=16 |
| SimVP-V1 | input_shape=[7,1,128,128], translator_type=IncepU, hid_channels_S=64, hid_channels_T=256, layers_S=4, layers_T=6 |
| SimVP-V2 | input_shape=[7,1,128,128], translator_type=gSTA, hid_channels_S=64, hid_channels_T=256, layers_S=4, layers_T=6 |
| TAU | input_shape=[7,1,128,128], translator_type=TAU, hid_channels_S=64, hid_channels_T=256, layers_S=4, layers_T=6 |
| STLight | in_channels=7, out_channels=6, hid_channels=512, layers=16, patch_size=4 |
| Model | Params | MSE ↓ | MAE ↓ | Precision ↑ | Recall ↑ | F1 ↑ | Accuracy ↑ | FAR ↓ | CSI ↑ | HSS ↑ |
|---|---|---|---|---|---|---|---|---|---|---|
| PredRNN-V1 | 28.85M | 102.78 | 618.525 | 0.7282 | 0.7421 | 0.7299 | 0.9462 | 0.0337 | 0.5885 | 0.6983 |
| PredRNN-V2 | 28.87M | 81.5994 | 535.471 | 0.7976 | 0.7048 | 0.741 | 0.9521 | 0.0227 | 0.6028 | 0.7132 |
| UNet | 17.27M | 74.7094 | 502.168 | 0.8193 | 0.7129 | 0.7547 | 0.955 | 0.021 | 0.62 | 0.7286 |
| SmaAtUNet | 4.03M | 75.0876 | 502.901 | 0.8153 | 0.7178 | 0.756 | 0.955 | 0.0216 | 0.6214 | 0.7298 |
| SimVP-V1 | 11.76M | 73.25 | 497.939 | 0.814 | 0.728 | 0.7612 | 0.9558 | 0.0224 | 0.6282 | 0.7353 |
| SimVP-V2 | 7.28M | 73.6464 | 488.54 | 0.8194 | 0.7214 | 0.7606 | 0.9558 | 0.021 | 0.6268 | 0.7349 |
| TAU | 6.80M | 74.4695 | 499.356 | 0.8048 | 0.7327 | 0.761 | 0.955 | 0.0233 | 0.627 | 0.7347 |
| STLight | 4.79M | 76.2394 | 516.638 | 0.8084 | 0.721 | 0.7545 | 0.9546 | 0.0228 | 0.6202 | 0.728 |
注: ↓表示越小越好,↑表示越大越好
- 配置: 模型&数据集&评估指标:
- 模型配置文件(以数据集名称组织)
- 数据集配置文件
- 指标配置文件(不同数据集可选择不同指标)
- 训练: 在main脚本中指定model_name, dataset_name, mode进行训练
- mode=train: 训练模型, 训练好的最佳模型保存在Output/Checkpoint目录, 训练日志记录在Output/Log目录, Loss曲线保存在Output/Visualization目录
- 测试: 打开Config中的模型配置文件, 将训练好的模型路径填入model_path, 然后在main脚本中指定model_name, dataset_name, mode进行测试
- mode=test: 测试模型性能, 可指定save_interval采样保存样本数据, 样本数据保存在Output/Sample目录, 测试指标记录在Output/Log目录
- 可视化: 在main脚本中指定model_name, dataset_name, mode进行可视化
- mode=visualize: 可视化Output/Sample目录下的样本数据
- 参数说明:
- patience: 早停耐心值, 当超过多少个epoch没有提升valid loss时停止训练
- resume_from: 断点续训, 意外中断训练时, 将当前最新的模型路径填入可以重新接着训练
- model_path: 模型路径, 用于测试(mode=test)
- save_interval: test模式下样本保存的间隔


