
Pickling a trained model (NNX) #4247

Open
eugene opened this issue Oct 3, 2024 · 10 comments

Comments

@eugene

eugene commented Oct 3, 2024

I train small models, and while prototyping and testing I wish to store trained models in a simple way (also keeping the model configuration and training stats inside the model object as dicts/arrays). Ideally I want to deal with a single object and be able to simply:

with open('my-model.pk', 'wb') as file:
    pickle.dump(model, file)

and later:

with open('my-model.pk', 'rb') as file:
    model = pickle.load(file)

Similar to torch.save.

When I naively try to do the above I get:

----> 2     pickle.dump(model, file)
AttributeError: Can't pickle local object 'variance_scaling.<locals>.init'

Splitting into graphdef and state results in the following error:

cannot pickle 'PyTreeDef' object

While I am aware there is orbax and it might save the state, I really wish it were possible to avoid that dependency and keep things simple. Is there a trick or workaround I can use to achieve the desired functionality?
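The first error comes from the stdlib pickle refusing to serialize functions defined inside other functions (which is exactly what variance_scaling.<locals>.init is). A minimal stdlib-only reproduction of that failure mode, with make_init as a made-up stand-in (no Flax involved):

```python
import pickle

def make_init():
    # a local function, analogous to variance_scaling.<locals>.init
    def init():
        return 0
    return init

try:
    pickle.dumps(make_init())
except AttributeError as e:
    # stdlib pickle serializes functions by module-level reference only,
    # so a function defined inside another function cannot be pickled
    print(e)
```

cloudpickle works around this by serializing the function's code object by value instead of by reference.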

@cgarciae
Collaborator

cgarciae commented Oct 3, 2024

Hey @eugene, can you try using cloudpickle? It usually handles these cases better.

@eugene
Author

eugene commented Oct 3, 2024

Hey @cgarciae, thanks for the swift reply!

It might work, but it's still an extra dependency. It feels like all-in-one save/load functionality is so fundamental that it should be part of NNX itself, e.g. nnx.save(...) / nnx.load(...). For now I guess I will stick to saving attributes individually and combining them again upon loading:

Saving:

with open('model.pickle', 'wb') as file:
    pickle.dump({
        'opts':   model.opts, 
        'stats':  model.stats,
        'state':  nnx.state(model)
    }, file)

Loading:

with open('model.pickle', 'rb') as file:
    model_dict = pickle.load(file)

model = Model(..., rngs=nnx.Rngs(0))
model.opts = model_dict['opts']
model.stats = model_dict['stats']
nnx.update(model, model_dict['state'])
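The workaround above only requires the payload to be a plain container of picklable leaves. A stdlib-only sketch of the same round-trip pattern, with made-up values standing in for model.opts, model.stats, and nnx.state(model):

```python
import io
import pickle

# hypothetical stand-ins for model.opts, model.stats, and nnx.state(model)
payload = {
    'opts':  {'hidden': 128, 'lr': 1e-3},
    'stats': {'loss': [0.9, 0.5, 0.3]},
    'state': {'w': [1.0, 2.0], 'b': [0.0]},
}

buf = io.BytesIO()          # in place of a real file on disk
pickle.dump(payload, buf)
buf.seek(0)
restored = pickle.load(buf)

# plain containers of data round-trip cleanly, because no local
# functions or graphdef objects end up in the pickled payload
assert restored == payload
```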

@cgarciae
Collaborator

cgarciae commented Oct 4, 2024

The only problem with such a simple method (which we could add) is that it's only good for local single-host setups. In general it's recommended to use orbax for checkpointing. Here is a simple NNX example: 08_save_load_checkpoints.py

@eugene
Author

eugene commented Oct 4, 2024

@cgarciae I get your point. But even with orbax, the attributes of the model are still stored separately. Let's take a step back. The very premise of Flax NNX (and this is the first paragraph of the introduction to the documentation, emphasis mine):

Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.

So, as you have seen, I tried expressing my model as a regular Python object, and Flax behaved unexpectedly when I tried to save it.

Also, while the multi-host scenario is important, I would bet that the vast majority of users (especially researchers like myself) work on a single host. While prototyping and benchmarking, saving a quick model (self-contained in a regular Python object with arbitrary attributes) makes a lot of sense, and I would love to have that supported out of the box.

... and the more advanced multi-host scenario can be handled by orbax.

@cgarciae
Collaborator

cgarciae commented Oct 4, 2024

@eugene I agree we want to be friendly with pickle / cloudpickle. I think we can commit to making NNX compatible with cloudpickle; I wouldn't recommend plain pickle for these types of tasks, as it cannot handle lambdas and similar objects.

I'll add a simple test for cloudpickle to begin.

@eugene
Author

eugene commented Oct 4, 2024

@cgarciae that's a step in the right direction, but I would love to hear your deeper considerations. Requiring cloudpickle, a separate dependency, just to save what is supposed to be your everyday "regular Python object" seems convoluted. Why wouldn't you want nnx.save() and nnx.load() methods mimicking the API design of PyTorch (and, by extension, honoring the Flax promise cited above, where those users should feel at home)?

@cgarciae
Collaborator

cgarciae commented Oct 4, 2024

I've created #4253 adding support for cloudpickle.

For simple use cases maybe we could add a save_state / load_state API:

# save
nnx.save_state(nnx.state(model), 'model.ckp')
# load
model = Model(...)
nnx.update(model, nnx.load_state('model.ckp'))
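A rough sketch of what such helpers could look like if they simply wrapped the stdlib pickle (hypothetical code, not the actual flax.nnx API; the names mirror the proposal above):

```python
import pickle
from pathlib import Path

def save_state(state, path):
    """Serialize a pure-data state container to disk with pickle."""
    Path(path).write_bytes(pickle.dumps(state))

def load_state(path):
    """Load a previously saved state container from disk."""
    return pickle.loads(Path(path).read_bytes())
```

Since nnx.state(model) holds only data (no graphdef, no local functions), plain pickle could suffice for this narrow single-host use case; the multi-host scenario would remain orbax territory.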

@kelechi-c

For simple use cases maybe we could add a save_state / load_state API:

# save
nnx.save_state(nnx.state(model), 'model.ckp')
# load
model = Model(...)
nnx.update(model, nnx.load_state('model.ckp'))

Is the nnx.save_state API available now, @cgarciae?

@cgarciae
Collaborator

@kelechi-c The JAX team is cooking something similar, so I'm just going to wait.

@kelechi-c

kelechi-c commented Nov 23, 2024

@cgarciae Thanks! And thanks for all the work on NNX 🫡
