Use case of multi-GPU sharding, nnx.jit, save/load and performance #4575
jecampagne asked this question in Q&A · Unanswered · Replies: 0 comments
Hello,
I have a toy example of a model using Yang Song's denoising score-matching sampling, which I am writing under different conditions to gain experience with the Flax NNX & Orbax libraries. To discuss it and ask a few questions, I have set up this notebook on Colab, just to read the code.
First question: is the `key_scorenet` key necessary or not?
Second, in `train_step`, I am not sure I have done it correctly. I observe that the sharding of `perturbed_x` is the same as that of `x` (the data), but `random_t`, the second argument of the model call, looks different: horizontal slices, with "GPU0" separated from "GPU1". So I wonder whether this is correct.
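Here is a small sketch of what I mean (it runs on any number of devices; with two GPUs the mesh would split the batch axis across GPU0/GPU1). An elementwise update such as `x + noise` preserves the sharding of `x`, so `perturbed_x` showing the same layout as `x` seems expected; `random_t` is 1-D (one scalar per batch element), so the visualizer draws it as horizontal slices, which looks different but would still be the same batch-axis sharding:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Shard the data along the batch axis of a 1-D device mesh.
mesh = Mesh(jax.devices(), axis_names=("batch",))
x = jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P("batch", None)))
random_t = jax.device_put(jnp.ones((8,)), NamedSharding(mesh, P("batch")))

# Elementwise op: output sharding is inherited from x.
perturbed_x = x + 0.1 * jnp.ones_like(x)  # stands in for x + sigma * noise
assert perturbed_x.sharding.is_equivalent_to(x.sharding, x.ndim)

# jax.debug.visualize_array_sharding(x)         # needs the `rich` package
# jax.debug.visualize_array_sharding(random_t)
```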
Third, `loss` is a scalar, so I do not know whether it is the mean of the loss over all the devices (model replicas)?
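My understanding, sketched below with a toy loss (the function and shapes are illustrative, not my actual model): under `jit`, `jnp.mean` over a batch-sharded per-example loss reduces across all devices, so the returned scalar is the mean over the global batch, replicated on every device.

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss_fn(pred, target):
    # Per-example squared error, then global mean across the whole
    # (possibly device-sharded) batch.
    per_example = jnp.mean((pred - target) ** 2, axis=-1)
    return jnp.mean(per_example)

loss = loss_fn(jnp.ones((8, 4)), jnp.zeros((8, 4)))
print(float(loss))  # 1.0
```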
Also, do I really have to use the two `state_restored` statements successively to get the model loaded?
After that, the sampling looks ok.
Thanks for your attention.