converted mlperf gpt3 ckpt starts with a worse loss #887
Comments
@ZhiyuLi-goog Thanks again for your help with other issues. Do you see any problems with the config, or know why the loss is much higher?
I have never tried on GPU.
With attention: "dot_product", I get a similar loss as before.
Oh, could you try something like
instead of changing the base.yml?
I think |
I tested these out. First running
and then also adding the other relevant flags you posted one by one, and all of them start with the bad loss (7.6x). So it's not flash attention, the tokenizer (validation is pretokenized and the evaluated loss is also bad), or the config args (I tried the flags you suggested). It's probably something to do with the model weights.
I can take a look at the full logs if you have them.
Thanks. Here are the logs |
Checked the log. |
Thanks for checking |
The only thing I found that looks weird is
Could you try using weight_dtype as float32 instead of bfloat16? However, I do not expect such a big gap. |
Tried weight_dtype as float32 as well. Same problem. I'm wondering if we can send you our converted ckpt for you to load, to verify whether it's a ckpt problem?
I can give it a try on the TPU side. By the way, would it be useful to you to print the mean of each param state after conversion?
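A per-parameter mean summary of the kind suggested here could look roughly like the sketch below. It assumes the converted state is available as a nested dict of arrays. `param_means` and the sample parameter names are hypothetical, and plain nested lists stand in for real arrays so the snippet runs without JAX or NumPy.

```python
from statistics import fmean


def _flatten(values):
    """Yield every scalar from an arbitrarily nested list/tuple."""
    for v in values:
        if isinstance(v, (list, tuple)):
            yield from _flatten(v)
        else:
            yield float(v)


def param_means(tree, prefix=""):
    """Return {"path/to/param": mean} for every leaf of a nested dict."""
    means = {}
    for name, value in tree.items():
        path = f"{prefix}/{name}" if prefix else name
        if isinstance(value, dict):
            means.update(param_means(value, path))
        else:
            means[path] = fmean(_flatten(value))
    return means


# Hypothetical toy state; real parameter names and shapes will differ.
state = {
    "decoder": {
        "layer_0": {"kernel": [[1.0, 2.0], [3.0, 4.0]], "bias": [0.0, 0.0]},
    },
    "embedding": [[0.5, -0.5]],
}

for path, mean in sorted(param_means(state).items()):
    print(f"{path}: {mean:.6f}")
```

Printing the same summary for two conversions of the same source checkpoint makes it easy to spot a parameter whose statistics are clearly off.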
I'm not sure if it will be useful.
It would be easiest if you have a converted ckpt; I can directly compare it against ours. We didn't try that on GPU, so I guess something might be different there.
great |
It would be great if you could share an open gcloud bucket with us.
ok, let me do that |
Gotcha, thank you for the info! |
We have created the bucket and will share the access with you soon (I got your google email from one of your commits) |
Hello again. We have shared three ckpts. The second and third are both from the latest branch: the second was converted with scan_layers=false, and the third with scan_layers=true. Let us know if you are able to access them and if you have any questions.
Thank you @gramesh-amd I will test with your ckpt. |
I wrote a script to look at the checkpoint that we generated and compare it to the original data. Any suggestions for how we can confirm that this is what happened and debug it? |
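As an illustration of this kind of check, here is a minimal sketch of comparing a converted checkpoint against the original values. It assumes both have already been loaded into flat `{name: list_of_floats}` mappings; `compare_params`, the tolerances, and the toy parameter names are hypothetical, and this is not the script referred to above.

```python
import math


def compare_params(reference, candidate, rel_tol=1e-3, abs_tol=1e-6):
    """Compare two flat {name: list_of_floats} mappings.

    Returns a list of (name, problem) tuples; an empty list means the two
    mappings match within tolerance.
    """
    problems = []
    for name in sorted(set(reference) | set(candidate)):
        if name not in candidate:
            problems.append((name, "missing in candidate"))
        elif name not in reference:
            problems.append((name, "unexpected in candidate"))
        elif len(reference[name]) != len(candidate[name]):
            problems.append((name, "size mismatch"))
        elif not all(
            math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol)
            for r, c in zip(reference[name], candidate[name])
        ):
            problems.append((name, "values differ"))
    return problems


# Hypothetical toy data in which one parameter does not match.
original = {"wq": [0.1, -0.2, 0.3], "wo": [0.5, 0.25, -0.75]}
converted = {"wq": [0.1, -0.2, 0.3], "wo": [0.0, 0.0, 0.0]}

for name, problem in compare_params(original, converted):
    print(f"{name}: {problem}")
```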
@gabeweisz I am just wondering how you ran the script. I do expect it to:
key idea
example
I tried the script yesterday and it worked for me.
Note that the output directory is a GCS bucket that is accessible from all devices.
We ran the script in a way very similar to how you ran it - my colleague Gowtham shared what we did earlier.

When we ran it, we didn't have a shared NFS big enough for all the nodes and did not have access to a GCS bucket, so each node was writing to its own local directory. I did check, and once the script had finished, only node 0 had a checkpoint; none of the others did. Do you think this caused the issue? If so, does Orbax have a way to work around this?

Another option is that I can try to modify an on-disk checkpoint using tensorstore as in the documentation - there is no real reason why we need to load the checkpoint onto GPUs to convert it from one format to another.
We just found the place in the documentation where Orbax says that all nodes need to write to the same filesystem. That explains what went wrong for us.
Exactly, I think that is the root cause.
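One way to catch this kind of setup problem before a long conversion run is to check that every process sees the same output directory. The sketch below is only an illustration under stated assumptions: `write_marker` and `missing_markers` are hypothetical helpers, not part of Orbax or MaxText, and the multi-node case is simulated on one machine with temporary directories.

```python
import os
import tempfile


def write_marker(directory, process_index):
    """Each process drops a small marker file into the output directory."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"marker_{process_index}")
    with open(path, "w") as f:
        f.write(str(process_index))


def missing_markers(directory, process_count):
    """Return the process indices whose markers are not visible here."""
    return [
        i for i in range(process_count)
        if not os.path.exists(os.path.join(directory, f"marker_{i}"))
    ]


# Single-machine simulation: two "processes" write to the same directory,
# while a third writes to a separate one, as it would with node-local disks.
shared = tempfile.mkdtemp()
local = tempfile.mkdtemp()
write_marker(shared, 0)
write_marker(shared, 1)
write_marker(local, 2)

print(missing_markers(shared, 3))  # indices not visible from `shared`
```

In a real multi-host job, every process would need to finish writing its marker (for example behind a barrier) before any process checks the directory; a non-empty result then suggests the output path is not shared across nodes.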
Hello,
We converted the paxml checkpoint and resumed training with the following config:
The tokenizer and data splits (3.0.4, 3.0.5) were downloaded from mlperf2 bucket. I have also tried using the c4_mlperf dataset_type like this:
^ scan_layers set to true in line with how we converted the ckpt
^ starts with a very high loss and we expected something closer to 2.77
We have confirmed from the logs that training loads the right checkpoint, the correct data splits, and the tokenizer.