Pickling a trained model (NNX) #4247
Comments
Hey @eugene, can you try using |
Hey @cgarciae, thanks for the swift reply! It might work, but it is still an extra dependency. It feels like "wholesome" save/load functionality is so fundamental that it should be added to NNX itself, e.g.

Saving:

    import pickle

    with open('model.pickle', 'wb') as file:
        pickle.dump({
            'opts': model.opts,
            'stats': model.stats,
            'state': nnx.state(model)
        }, file)

Loading:

    with open('model.pickle', 'rb') as file:
        model_dict = pickle.load(file)
    model = Model(..., rngs=nnx.Rngs(0))
    model.opts = model_dict['opts']
    model.stats = model_dict['stats']
    nnx.update(model, model_dict['state'])
|
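For readers who want to see the pattern from the comment above in isolation, here is a minimal sketch in plain Python. `TinyModel`, `save_model` and `load_model` are hypothetical names made up for illustration, and the parameters are ordinary nested lists rather than an `nnx.State`; nothing here is Flax API. It assumes the single-host case and that everything stored is picklable.

```python
import pickle


class TinyModel:
    """Hypothetical stand-in for an NNX module; not part of Flax."""

    def __init__(self, opts):
        self.opts = dict(opts)   # model configuration
        self.stats = {}          # training statistics, filled in during training
        # Parameters kept as plain nested containers (assumption for this sketch).
        self.state = {
            "linear": {
                "kernel": [[0.0] * opts["dout"] for _ in range(opts["din"])],
                "bias": [0.0] * opts["dout"],
            }
        }


def save_model(model, path):
    """Bundle config, stats and parameters into one pickle file."""
    payload = {"opts": model.opts, "stats": model.stats, "state": model.state}
    with open(path, "wb") as file:
        pickle.dump(payload, file)


def load_model(path):
    """Rebuild the model from its saved config, then restore stats and parameters."""
    with open(path, "rb") as file:
        payload = pickle.load(file)
    model = TinyModel(payload["opts"])
    model.stats = payload["stats"]
    model.state = payload["state"]
    return model
```

Calling `save_model(model, 'model.pickle')` followed by `load_model('model.pickle')` should give back an object with the same `opts`, `stats` and `state`. The model is rebuilt from its stored config rather than pickled as an object, so only plain data ends up in the file.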
The only problem with such a simple method (which we could add) is that it's only good for local single-host setups. In general it's recommended to use |
@cgarciae I get your point. But even with
So, as you have seen, I tried expressing my model as a regular Python object, and Flax behaved unexpectedly when I tried to save it. Also, while the multi-host scenario is important, I would bet that the vast majority of users (especially researchers like myself) work on a single host. While prototyping and benchmarking, quickly saving a model (self-contained in a regular Python object with arbitrary attributes) makes a lot of sense, and I would love to have that supported out of the box. ... and the more advanced multi-host scenario can be handled by |
@eugene I agree we want to be friendly with pickle / cloudpickle. I think we can commit to making NNX compatible with I'll add a simple test for |
@cgarciae that's a step in the right direction, but I would love to hear your deeper considerations. To require |
I've created #4253 adding support for For simple use cases maybe we could add a

    # save
    nnx.save_state(nnx.state(model), 'model.ckp')
    # load
    model = Model(...)
    nnx.update(model, nnx.load_state('model.ckp'))
|
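The `save_state` / `load_state` pair floated above can be pictured with a small stand-in. This is only a sketch of the idea, not the NNX implementation: the functions below are hypothetical, operate on a plain nested dict instead of an `nnx.State`, and `update_state` is an assumed analogue of what `nnx.update` does for a real module.

```python
import pickle


def save_state(state, path):
    """Illustrative helper: pickle a nested dict of parameters."""
    with open(path, "wb") as file:
        pickle.dump(state, file)


def load_state(path):
    """Illustrative helper: read back a nested dict written by save_state."""
    with open(path, "rb") as file:
        return pickle.load(file)


def update_state(target, loaded):
    """Copy leaves from `loaded` into `target` in place, recursing through nested dicts.

    Keys present only in `target` are left untouched; a key present only in
    `loaded` raises KeyError, since the freshly built model has no slot for it.
    """
    for key, value in loaded.items():
        if key not in target:
            raise KeyError(f"unexpected key in loaded state: {key!r}")
        if isinstance(value, dict) and isinstance(target[key], dict):
            update_state(target[key], value)
        else:
            target[key] = value
```

The design point this mirrors is that the model structure comes from code (`Model(...)` in the comment above) and only the values come from the file, so a mismatch between the two shows up as an error at load time rather than as a silently different model.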
Please is the |
@kelechi-c JAX team is cooking something similar so I'm just going to wait. |
@cgarciae Thanks! And thanks for all the work on NNX 🫡 |
I train small models, and while prototyping and testing I want to store trained models in a simple way (with the model configuration and training stats kept inside the model object as dicts/arrays). Ideally I want to deal with a single object and be able to simply:
and later:
Similar to torch.save.
When I naively try to do the above I get:
Splitting into graphdef and state results in the following error:
While I am aware there is orbax and that it might save the state, I really wish it were possible to avoid that dependency and keep things simple. Is there a trick or a workaround I can use to achieve the desired functionality?