Saving the pm.Approximator object when using VI #671

Open
igrahek opened this issue Feb 26, 2025 · 2 comments

@igrahek

igrahek commented Feb 26, 2025

When fitting larger models with HSSM, I follow a workflow of saving the samples object, then loading it and attaching it to the model:

import arviz as az
import hssm

cav_data = hssm.load_data("cavanagh_theta")  # Load the data
cav_model = hssm.HSSM(data=cav_data, model="angle")  # Specify the model
samples = cav_model.vi(niter=20, method="fullrank_advi")  # Run VI
samples.to_netcdf("../output/ModelTest")  # Save the inference data

samples = az.InferenceData.from_netcdf("../output/ModelTest")  # Load the inference data
cav_model = hssm.HSSM(data=cav_data, model="angle")  # Re-specify the model
cav_model._inference_obj = samples  # Attach the loaded inference data to the model

However, I am not sure how to save the pm.Approximator object in a similar way so that I can check convergence after the model has finished. What would be a good format to save and load this object? I guess it can be attached to the model the same way as the inference object?

If the only use of pm.Approximator is to check the loss, then maybe there is no need to save it, and having the samples is enough? The loss plot could simply be saved after model fitting with:

import matplotlib.pyplot as plt

plt.plot(cav_model.vi_approx.hist)  # Loss at each VI iteration
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.savefig("../output/ModelTest_convergence.png")

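If only the loss trace is needed, one option is to save the raw `hist` values alongside (or instead of) the image, so the convergence plot can be regenerated or zoomed later without pickling the whole approximation. A minimal sketch, assuming `hist` is the per-iteration loss sequence (as used in the plot above) and using plain NumPy; the helper names `save_loss_history` and `load_loss_history` are hypothetical, not part of HSSM or PyMC.

```python
import numpy as np


def save_loss_history(hist, path):
    """Write a VI loss trace to a .npy file and return it as a float array.

    `hist` is assumed to be the per-iteration loss sequence, e.g. `vi_approx.hist`.
    Use a `path` ending in ".npy": np.save appends that suffix otherwise, and
    load_loss_history would then not find the file under the original name.
    """
    losses = np.asarray(hist, dtype=float)
    np.save(path, losses)
    return losses


def load_loss_history(path):
    """Read a loss trace previously written by save_loss_history."""
    return np.load(path)
```

For example, `save_loss_history(cav_model.vi_approx.hist, "../output/ModelTest_loss.npy")` right after fitting, and later `plt.plot(load_loss_history("../output/ModelTest_loss.npy"))` to redraw the curve.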
@gpagnier

I've been able to use cloudpickle to save and load the pm.Approximator object:

import cloudpickle
import matplotlib.pyplot as plt

# Save the approximation object to disk
with open('cav_model.ApproxObjectSaved.pkl', 'wb') as f:
    cloudpickle.dump(cav_model.vi_approx, f)

# Load it at a later date
with open('cav_model.ApproxObjectSaved.pkl', 'rb') as f:
    approxObjectLoaded = cloudpickle.load(f)

# Plot the loss history from the loaded approximation object
plt.plot(approxObjectLoaded.hist)

Heads up: the cloudpickle version (and probably the versions of other packages, though cloudpickle was the one that threw an error for me) needs to match between saving and loading. I will update the VI tutorial soon to reflect this functionality.
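Because a version mismatch only surfaces as an error at load time, one option is to record the relevant package versions next to the pickled object and compare them before unpickling it. A minimal sketch, assuming the standard-library `pickle` and `importlib.metadata` modules plus `cloudpickle` when it is installed (it falls back to plain `pickle` otherwise); the helpers `save_with_versions`/`load_with_versions` and the `TRACKED_PACKAGES` list are hypothetical, not part of HSSM or PyMC.

```python
import pickle
import platform
import warnings
from importlib import metadata

try:
    import cloudpickle  # serializer used in the thread for the approximation object
    _dumps = cloudpickle.dumps
except ImportError:  # fallback so this sketch still runs without cloudpickle
    _dumps = pickle.dumps

# Hypothetical choice of packages whose versions are recorded at save time.
TRACKED_PACKAGES = ("cloudpickle", "pymc", "pytensor", "hssm")


def current_versions(packages=TRACKED_PACKAGES):
    """Return {name: installed version or None}, plus the Python version."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def save_with_versions(obj, path, packages=TRACKED_PACKAGES):
    """Pickle `obj` inside an envelope that also stores package versions."""
    envelope = {"versions": current_versions(packages), "payload": _dumps(obj)}
    with open(path, "wb") as f:
        pickle.dump(envelope, f)


def load_with_versions(path):
    """Warn about version mismatches, then unpickle the stored object."""
    with open(path, "rb") as f:
        # The envelope holds only built-in types (dict, str, bytes, None),
        # so reading it does not depend on cloudpickle/pymc versions.
        envelope = pickle.load(f)
    saved = envelope["versions"]
    now = current_versions(tuple(k for k in saved if k != "python"))
    mismatched = {k: (saved[k], now[k]) for k in saved if saved[k] != now[k]}
    if mismatched:
        warnings.warn(f"Version mismatch (saved, current): {mismatched}")
    # cloudpickle output is standard pickle data, so pickle.loads can read it.
    return pickle.loads(envelope["payload"])
```

Usage would mirror the snippet above: `save_with_versions(cav_model.vi_approx, "cav_model.ApproxObjectSaved.pkl")`, then later `approx = load_with_versions("cav_model.ApproxObjectSaved.pkl")` and `plt.plot(approx.hist)`. If the payload still fails to load, the warning printed just before it names the versions to reinstall. As with any pickle, only load files you trust.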

@igrahek

igrahek commented Feb 26, 2025

Amazing, thank you so much, Guillaume!
