PSquare-Lab/DVA-PFN

DVA-PFN

This repository contains the official code for our paper, "Decoupled-Value Attention for Prior-Data Fitted Networks: GP Inference for Physical Equations" (arXiv:2509.20950).

If you use this repository in your work, please cite our paper:

@misc{sharma2025decoupledvalueattentionpriordatafitted,
      title        = {Decoupled-Value Attention for Prior-Data Fitted Networks: GP Inference for Physical Equations}, 
      author       = {Kaustubh Sharma and Simardeep Singh and Parikshit Pareek},
      year         = {2025},
      eprint       = {2509.20950},
      archivePrefix= {arXiv},
      primaryClass = {cs.LG},
      url          = {https://arxiv.org/abs/2509.20950}, 
}

Installation

To set up the project environment, run the following command:

pip install -r requirements.txt

File Structure

  • train_DVA/: Scripts for training models with our DVA mechanism on synthetic data of varying dimensions (1D, 2D, 5D, 10D).
  • Softmax_Models/: Training scripts for models using standard softmax attention.
  • Robustness_codes/: Scripts to evaluate model stability and consistency by training with multiple random initializations.
  • Optuna/: Contains a script for automated hyperparameter optimization of the Transformer model using Optuna. You can replace the Transformer with a CNN model to use the same code for CNN optimization.
  • MSE_Eval/: Evaluation scripts that benchmark trained models against a classical Gaussian Process (GP) baseline using mean squared error (MSE), plus a utility for counting model parameters.
  • Kernel_att_exp/: A plotting utility to visually compare the training loss curves from different experiments (e.g., DVA vs. softmax attention).
  • 64D_voltage/: Code and models specific to our 64D power flow equation approximation task.
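The train_DVA/ scripts train on synthetic data. Since the paper targets GP inference, one plausible setup, common in PFN training, is to draw each synthetic regression task from a Gaussian-process prior. The sketch below generates one such task under that assumption; `sample_gp_dataset`, the RBF kernel, the lengthscale of 0.5, the noise level, and the unit-cube inputs are all illustrative choices, not values taken from the repository.

```python
import numpy as np


def rbf_kernel(a, b, lengthscale=0.5, outputscale=1.0):
    """Squared-exponential (RBF) kernel matrix between the rows of a and b."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return outputscale * np.exp(-0.5 * sq_dists / lengthscale**2)


def sample_gp_dataset(n_points, n_features, noise_std=0.1, rng=None):
    """Draw one synthetic regression task (X, y) from a zero-mean GP prior.

    Hypothetical helper: the repository's actual prior (kernel, hyperparameter
    distributions, input range) may differ.
    """
    rng = np.random.default_rng() if rng is None else rng
    X = rng.uniform(0.0, 1.0, size=(n_points, n_features))
    K = rbf_kernel(X, X) + noise_std**2 * np.eye(n_points)
    # Sample y ~ N(0, K) through the Cholesky factor; the jitter guards
    # against round-off making K numerically indefinite.
    L = np.linalg.cholesky(K + 1e-8 * np.eye(n_points))
    y = L @ rng.standard_normal(n_points)
    return X, y


rng = np.random.default_rng(0)
X, y = sample_gp_dataset(n_points=64, n_features=10, rng=rng)
print(X.shape, y.shape)  # (64, 10) (64,)
```

Varying `n_features` (1, 2, 5, 10) would give tasks matching the dimensionalities listed for train_DVA/.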

How to Use

The workflow is designed to be modular. You'll typically proceed through training, optional hyperparameter tuning, robustness testing, and final evaluation.

  1. Training a New Model. To train a new model, run the script in train_DVA/ (DVA) or Softmax_Models/ (softmax attention) that matches your input dimensionality. For example, to train a DVA model on 10D data, run from the repository root:
python train_DVA/train_10D.py

You can switch between the Transformer and CNN architectures within the same training script by replacing the model class instantiation. When you do, adjust the hyperparameters to suit the new architecture.
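As a sketch of that swap, assuming nothing about the repository's actual class names, a small registry can keep each architecture's hyperparameters next to its class so that switching becomes a one-word change. `TransformerModel`, `CNNModel`, `MODEL_CONFIGS`, `build_model`, and every hyperparameter value below are hypothetical placeholders; in the real scripts the classes would be PyTorch modules.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the repository's model classes; in the real
# scripts these would be torch.nn.Module subclasses.
@dataclass
class TransformerModel:
    n_features: int
    d_model: int = 128
    n_layers: int = 4
    n_heads: int = 4


@dataclass
class CNNModel:
    n_features: int
    channels: int = 64
    kernel_size: int = 3


# Each architecture is paired with its own (hypothetical) hyperparameters,
# so changing the architecture never reuses settings meant for the other one.
MODEL_CONFIGS = {
    "transformer": (TransformerModel, {"d_model": 256, "n_layers": 6, "n_heads": 8}),
    "cnn": (CNNModel, {"channels": 128, "kernel_size": 5}),
}


def build_model(arch, n_features):
    """Instantiate the requested architecture with its matching hyperparameters."""
    cls, kwargs = MODEL_CONFIGS[arch]
    return cls(n_features=n_features, **kwargs)


model = build_model("cnn", n_features=10)
print(model)  # CNNModel(n_features=10, channels=128, kernel_size=5)
```

Keeping the settings in the registry rather than inline means the training loop itself does not need to change when the architecture does.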

  2. Hyperparameter Tuning. To find the optimal hyperparameters for the Transformer model, use the Optuna script:
python Optuna/optuna_tr_10d.py

This will run multiple trials, and the final output will provide the best configuration for your model.

  3. Robustness Testing. The Robustness_codes/ directory contains scripts that check the stability of your model's performance. For example, to perform five independent training runs of the DVA model on 10D data:
python Robustness_codes/train_DVA_10D_rob.py

Each run will save a separate model checkpoint, allowing you to analyze the variability of results.
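The pattern behind these robustness runs can be sketched as a loop over fixed seeds that summarizes the spread of a final metric. `train_once` is a hypothetical placeholder that returns a fake MSE; a real version would seed every random number generator in use, train the model, save its own checkpoint (the file name in the comment is illustrative), and return a validation metric.

```python
import random
import statistics


def train_once(seed):
    """Hypothetical stand-in for one full training run.

    A real run would seed torch/numpy, train the DVA model, save a checkpoint
    such as f"dva_10d_seed{seed}.pth", and return its validation MSE.
    Here a seeded placeholder value keeps the loop runnable and reproducible.
    """
    rng = random.Random(seed)
    return 0.05 + 0.01 * rng.random()


seeds = [0, 1, 2, 3, 4]  # five independent initializations
mses = [train_once(s) for s in seeds]
print(
    f"MSE over {len(seeds)} runs: "
    f"mean={statistics.mean(mses):.4f}, std={statistics.stdev(mses):.4f}"
)
```

Reporting the mean together with the standard deviation across seeds helps distinguish real differences between model variants from run-to-run noise.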

  4. Evaluating a Trained Model. Once a model is trained, use the MSE_Eval/mse_trans.py script to benchmark it against a Gaussian Process baseline. The script compares performance across different numbers of training points:
python MSE_Eval/mse_trans.py \
    --model_path <path/to/your/trained_model.pth> \
    --bucket_path <path/to/your/bucket_limits.pth> \
    --n_features 10 \
    --out_dir my_evaluation_results/
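To make the GP baseline concrete, the sketch below computes the exact GP posterior mean and its test MSE for several training-set sizes, in the spirit of the comparison described above. It assumes a unit-variance RBF kernel with a fixed lengthscale and noise level (a tuned baseline would usually fit these by maximizing the marginal likelihood) and uses a hypothetical `target` function; the PFN's predictions on the same test inputs would be scored with the same MSE.

```python
import numpy as np


def rbf_kernel(a, b, lengthscale=0.5):
    """Unit-variance squared-exponential kernel between the rows of a and b."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dists / lengthscale**2)


def gp_posterior_mean(X_train, y_train, X_test, noise_var=1e-2):
    """Exact GP regression mean: k(X_test, X_train) (K + noise_var * I)^{-1} y."""
    K = rbf_kernel(X_train, X_train) + noise_var * np.eye(len(X_train))
    alpha = np.linalg.solve(K, y_train)
    return rbf_kernel(X_test, X_train) @ alpha


def target(X):
    """Hypothetical ground-truth function used only for this illustration."""
    return np.sin(X.sum(axis=1))


rng = np.random.default_rng(0)
n_features = 10
X_test = rng.uniform(0.0, 1.0, size=(200, n_features))

results = {}
for n_train in [10, 50, 100, 200]:
    X_train = rng.uniform(0.0, 1.0, size=(n_train, n_features))
    pred = gp_posterior_mean(X_train, target(X_train), X_test)
    results[n_train] = float(np.mean((pred - target(X_test)) ** 2))
    print(f"n_train={n_train:4d}  GP MSE={results[n_train]:.4f}")
```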

  5. Auditing Model Complexity. To audit the number of parameters in your trained models, use the MSE_Eval/model_param.py script. It recursively searches for all .pth files and logs their parameter counts:
python MSE_Eval/model_param.py
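A minimal sketch of such an audit, assuming each checkpoint can be reduced to a mapping from tensor names to shapes; `find_checkpoints` and `count_parameters` are hypothetical helpers, not functions from model_param.py. With PyTorch, a checkpoint saved as a state dict could supply the shapes via `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`; files holding a whole pickled model or other tensors (such as bucket limits) would need different handling. Note also that a state dict includes buffers (for example, normalization running statistics), so its element count can exceed the trainable-parameter count.

```python
import math
from pathlib import Path


def find_checkpoints(root="."):
    """Recursively collect every .pth file under `root`, sorted by path."""
    return sorted(Path(root).rglob("*.pth"))


def count_parameters(shapes):
    """Total number of scalar entries in a mapping of tensor name -> shape tuple."""
    # math.prod(()) == 1, so a 0-d (scalar) tensor correctly counts as one entry.
    return sum(math.prod(shape) for shape in shapes.values())


# Hypothetical shapes for a tiny two-layer network (not a real repo checkpoint).
example = {
    "encoder.weight": (128, 10),
    "encoder.bias": (128,),
    "decoder.weight": (1, 128),
    "decoder.bias": (1,),
}
print(count_parameters(example))  # 1537
```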

  6. Analyzing Experiment Logs. The Kernel_att_exp/ directory contains a utility that combines multiple training logs into a single plot, making it easy to compare different model variants:
python Kernel_att_exp/kernelatt_experiment.py \
    --rbf_kernelattn_log path/to/log1.log \
    --rbf_dva_log path/to/log2.log \
    --title "My Model Comparison" \
    --outdir comparison_plots
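The log format is not documented here, so the sketch below assumes a hypothetical line format such as `Epoch 12 | train_loss=0.0431`; `LOSS_PATTERN` would need to match whatever the training scripts actually write. `parse_losses` extracts one curve per log, and `plot_logs` (also hypothetical) overlays the curves with Matplotlib, which is imported only when plotting so the parser runs without it.

```python
import re

# Hypothetical log line format, e.g. "Epoch 12 | train_loss=0.0431".
# Adjust this pattern to whatever the training scripts actually write.
LOSS_PATTERN = re.compile(r"Epoch\s+(\d+).*?loss[=:]\s*([-+0-9.eE]+)")


def parse_losses(lines):
    """Return (epochs, losses) for every line matching LOSS_PATTERN."""
    epochs, losses = [], []
    for line in lines:
        match = LOSS_PATTERN.search(line)
        if match:
            epochs.append(int(match.group(1)))
            losses.append(float(match.group(2)))
    return epochs, losses


def plot_logs(named_logs, title, out_path):
    """Overlay one loss curve per log file (label -> path) on a single figure."""
    import matplotlib

    matplotlib.use("Agg")  # render to a file without needing a display
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    for label, path in named_logs.items():
        with open(path) as f:
            epochs, losses = parse_losses(f)
        ax.plot(epochs, losses, label=label)
    ax.set(xlabel="epoch", ylabel="loss", title=title)
    ax.legend()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)


sample = ["Epoch 1 | train_loss=0.52", "validation done", "Epoch 2 | train_loss=0.31"]
print(parse_losses(sample))  # ([1, 2], [0.52, 0.31])
```

Keeping parsing separate from plotting makes it easy to check the regular expression on a few sample lines before generating figures.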

Note: Work in progress 🚧.
