KL div wrong way around? #5
Hi, thanks for the comment. I believe you are correct that there's an inconsistency between the paper and the code. As for the implications, my initial guess is that it shouldn't make much of a difference. However, the current hyperparameters may not work as-is and may need to be retuned for the change (which could take a lot of time to get the best results). If I have time, I will rerun the experiments to see whether I can reproduce the results with the KL the other way around, and if not, how big the difference is.

I found the paper "Revisiting Design Choices in Proximal Policy Optimization", which investigates forward vs. reverse KL for PPO and seems to suggest that there tends not to be a big difference in practice (and where there is one, the reverse KL, which is the one in my code rather than in my paper, appears to have a slight advantage). If it turns out that the reverse KL actually plays a big part in the performance of POLA-DiCE (i.e., if even with extensive hyperparameter tuning the coin game results are worse with the forward KL), it might be an interesting follow-up research project to investigate exactly why (why the difference is magnified in the multi-agent setting compared to, e.g., the linked paper) and what the broader implications might be for multi-agent RL.
Closing, as the fixes have been implemented and the experiments rerun (see the updated README).
`POLA/jax_files/POLA_dice_jax.py`, line 500 in 6b07e89:
The KL divergence looks like it's the wrong way around. Typically you want the expectation to be under the target distribution. Certainly in the paper it's written as KL(old||new), which would be `-old * (log(new) - log(old))`. Not sure what the implications are if it is indeed the wrong way around.
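To make the two directions concrete, here is a minimal sketch in plain Python (not the repository's JAX code), assuming a categorical policy over a small discrete action set. The logits, and the helper names `log_softmax` and `kl_categorical`, are hypothetical and chosen only for illustration. KL(old||new) takes the expectation under the old policy and matches the issue's expression `-old * (log(new) - log(old))` summed over actions; KL(new||old) is the opposite direction, which the issue reports at line 500.

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over a list of logits.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def kl_categorical(log_p, log_q):
    # KL(p || q) = sum_a p(a) * (log p(a) - log q(a)); the expectation is under p.
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

# Hypothetical logits for a 4-action policy before and after an update.
old_log_probs = log_softmax([1.0, 0.5, -0.5, 0.0])
new_log_probs = log_softmax([2.0, 0.0, -1.0, 0.0])

# KL(old || new), the direction written in the paper: expectation under old.
kl_old_new = kl_categorical(old_log_probs, new_log_probs)
# KL(new || old), the opposite direction: expectation under new.
kl_new_old = kl_categorical(new_log_probs, old_log_probs)

print(f"KL(old||new) = {kl_old_new:.4f}")
print(f"KL(new||old) = {kl_new_old:.4f}")
```

Both values are non-negative and equal zero only when the two policies match, but in general they differ, so a penalty coefficient tuned for one direction need not carry over to the other.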