Codebase for mitigating distribution shift in machine learning-augmented hybrid simulation (MLHS) using a tangent-space regularized (TR) algorithm, based on the paper [1] by Jiaxi Zhao and Qianxiao Li.
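- Environment setup (run from the root of the cloned repository, since it installs from requirements.txt): if your machine is not set up for GPU (CUDA) computing, comment out the CUDA-related packages in requirements.txt before installing.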
conda create -n TR python=3.9.18
conda activate TR
python -m pip install -r requirements.txt
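If you need to comment out the CUDA-related packages, the edit to requirements.txt can also be scripted before the pip install step. The following is a minimal sketch that assumes (a heuristic, not something the repository specifies) that every CUDA-specific requirement contains the substring "cuda". The helper names `comment_out_cuda` and `patch_requirements` are hypothetical. Check the patched file by hand before installing.

```python
from pathlib import Path


def comment_out_cuda(text: str, keyword: str = "cuda") -> str:
    """Prefix every active requirement line that mentions `keyword` with '# '.

    Heuristic assumption: CUDA-specific packages contain "cuda" in their
    name. Lines that are already comments are left unchanged.
    """
    out = []
    for line in text.splitlines():
        stripped = line.strip()
        if stripped and not stripped.startswith("#") and keyword in stripped.lower():
            out.append("# " + line)
        else:
            out.append(line)
    # Preserve a trailing newline if the original text had one.
    return "\n".join(out) + ("\n" if text.endswith("\n") else "")


def patch_requirements(path: Path = Path("requirements.txt")) -> None:
    """Rewrite the requirements file in place with CUDA lines commented out."""
    path.write_text(comment_out_cuda(path.read_text()))
```

Call `patch_requirements()` from the repository root, then run the pip install command. Matching is case-insensitive, so a line such as `CuPy-CUDA12x` would also be commented out.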
- Clone the repository.
git clone https://github.com/jiaxi98/TR.git
cd TR
In the following, we describe the full procedure to reproduce the experiments in the paper. All estimated execution times are measured on 4 GPUs (NVIDIA GeForce RTX 3090). Checkpoints for all network models are provided, so the plots in the demo.ipynb notebook can be generated directly without training the models.
Data generation (estimated execution time: 1 hour):
mkdir -p ../data/NS ../data/RD
bash generate_data.sh
Model training (estimated execution time: 24 hours):
mkdir -p ../models/NS ../models/RD
bash train.sh
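If you prefer to prepare the output directories from Python (for example from inside a notebook), the `mkdir` commands of the data-generation and training steps can be mirrored as below. This is only a sketch: it assumes the commands are run from the repository root `TR/`, so that `../data` and `../models` sit next to it, and the helpers `required_dirs` and `make_dirs` are hypothetical, not part of the repository.

```python
from pathlib import Path

# Sub-directories used by the data-generation and training steps,
# placed next to the repository root; NS and RD are the two systems.
SYSTEMS = ("NS", "RD")
KINDS = ("data", "models")


def required_dirs(repo_root: Path) -> list[Path]:
    """Return ../data/{NS,RD} and ../models/{NS,RD} relative to repo_root."""
    parent = repo_root.resolve().parent
    return [parent / kind / system for kind in KINDS for system in SYSTEMS]


def make_dirs(repo_root: Path) -> list[Path]:
    """Create the directories (like `mkdir -p`) and return them."""
    dirs = required_dirs(repo_root)
    for d in dirs:
        d.mkdir(parents=True, exist_ok=True)
    return dirs
```

Calling `make_dirs(Path("."))` from the repository root creates all four directories; `exist_ok=True` makes it safe to rerun.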
For a quick look at all the plots, we highly recommend running the demo.ipynb notebook.
Plots and table (estimated execution time: 5 minutes):
- plots of the distribution shift phenomena: this script generates Fig. 1 and Fig. 2, which illustrate the distribution shift for the RD and NS systems
cd exp
python exp1.py
- plots of the linear dynamics experiments
python exp2.py
- plots comparing the distribution shift under different simulation parameters
python exp3.py
- plots of the comparison of TR, OLS, and the ground truth
python exp4.py
- generate the table
python exp5.py
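To regenerate all figures and the table in one go, the five scripts can be run sequentially from inside `exp/`. A minimal sketch: the driver and its helpers `plot_commands` and `run_all` are hypothetical and not shipped with the repository. It uses the current Python interpreter and stops at the first script that exits with an error.

```python
import subprocess
import sys
from pathlib import Path

# The plotting scripts listed above; exp5.py generates the table.
SCRIPTS = [f"exp{i}.py" for i in range(1, 6)]


def plot_commands(python: str = sys.executable) -> list[list[str]]:
    """Build one command per script, using the given interpreter."""
    return [[python, script] for script in SCRIPTS]


def run_all(exp_dir: Path) -> None:
    """Run every script from inside exp/ (like `cd exp`), stopping on the first failure."""
    for cmd in plot_commands():
        print("running", cmd[1])
        subprocess.run(cmd, cwd=exp_dir, check=True)
```

From the repository root, call `run_all(Path("exp"))`. Because `check=True` raises `subprocess.CalledProcessError` on a non-zero exit code, a failing script is reported immediately instead of being silently skipped.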
[1] Jiaxi Zhao and Qianxiao Li. Mitigating distribution shift in machine learning-augmented hybrid simulation.