diff --git a/serl_launcher/serl_launcher/networks/reward_classifier.py b/serl_launcher/serl_launcher/networks/reward_classifier.py index b2109ff1..6cdf3c93 100644 --- a/serl_launcher/serl_launcher/networks/reward_classifier.py +++ b/serl_launcher/serl_launcher/networks/reward_classifier.py @@ -12,6 +12,7 @@ from serl_launcher.common.encoding import EncodingWrapper from flax.core.frozen_dict import freeze, unfreeze + class BinaryClassifier(nn.Module): encoder_def: nn.Module hidden_dim: int = 256