Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Option for Kaiming He Layer Initialization #53

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rllte/xplore/reward/disagreement.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class Disagreement(BaseReward):
batch_size (int): The batch size for training.
update_proportion (float): The proportion of the training data used for updating the forward dynamics models.
encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].
weight_init (str): The weight initialization method from ['default', 'orthogonal'].
weight_init (str): The weight initialization method from ['default', 'orthogonal', 'kaiming he'].

Returns:
Instance of Disagreement.
Expand Down
2 changes: 1 addition & 1 deletion rllte/xplore/reward/e3b.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class E3B(BaseReward):
batch_size (int): The batch size for training.
update_proportion (float): The proportion of the training data used for updating the forward dynamics models.
encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].
weight_init (str): The weight initialization method from ['default', 'orthogonal'].
weight_init (str): The weight initialization method from ['default', 'orthogonal', 'kaiming he'].

Returns:
Instance of E3B.
Expand Down
2 changes: 1 addition & 1 deletion rllte/xplore/reward/icm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ICM(BaseReward):
batch_size (int): The batch size for training.
update_proportion (float): The proportion of the training data used for updating the forward dynamics models.
encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].
weight_init (str): The weight initialization method from ['default', 'orthogonal'].
weight_init (str): The weight initialization method from ['default', 'orthogonal', 'kaiming he'].

Returns:
Instance of ICM.
Expand Down
10 changes: 10 additions & 0 deletions rllte/xplore/reward/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def orthogonal_layer_init(layer, std=np.sqrt(2), bias_const=0.0):
th.nn.init.constant_(layer.bias, bias_const)
return layer

def kaiming_he_init(layer):
th.nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
if layer.bias is not None:
th.nn.init.zeros_(layer.bias)
return layer

def default_layer_init(layer):
stdv = 1. / math.sqrt(layer.weight.size(1))
layer.weight.data.uniform_(-stdv, stdv)
Expand All @@ -49,6 +55,8 @@ class ObservationEncoder(nn.Module):
Args:
obs_shape (Tuple): The data shape of observations.
latent_dim (int): The dimension of encoding vectors.
encoder_model (str): The network architecture of the encoder from ['mnih', 'espeholt']. Defaults to 'mnih'
weight_init (str): The weight initialization method from ['default', 'orthogonal', 'kaiming he']. Defaults to 'default'

Returns:
Encoder instance.
Expand All @@ -59,6 +67,8 @@ def __init__(self, obs_shape: Tuple, latent_dim: int, encoder_model:str = "mnih"

if weight_init == "orthogonal":
init_ = orthogonal_layer_init
elif weight_init == "kaiming he":
init_ = kaiming_he_init
elif weight_init == "default":
init_ = default_layer_init
else:
Expand Down
2 changes: 2 additions & 0 deletions rllte/xplore/reward/ngu.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class NGU(Fabric):
sm (float): The kernel maximum similarity.
mrs (float): The maximum reward scaling.
update_proportion (float): The proportion of the training data used for updating the forward dynamics models.
encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].
weight_init (str): The weight initialization method from ['default', 'orthogonal', 'kaiming he'].

Returns:
Instance of NGU.
Expand Down
2 changes: 1 addition & 1 deletion rllte/xplore/reward/pseudo_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class PseudoCounts(BaseReward):
sm (float): The kernel maximum similarity.
update_proportion (float): The proportion of the training data used for updating the forward dynamics models.
encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].
weight_init (str): The weight initialization method from ['default', 'orthogonal'].
weight_init (str): The weight initialization method from ['default', 'orthogonal', 'kaiming he'].

Returns:
Instance of PseudoCounts.
Expand Down
2 changes: 1 addition & 1 deletion rllte/xplore/reward/re3.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class RE3(BaseReward):
k (int): Use the k-th neighbors.
average_entropy (bool): Use the average of entropy estimation.
encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].
weight_init (str): The weight initialization method from ['default', 'orthogonal'].
weight_init (str): The weight initialization method from ['default', 'orthogonal', 'kaiming he'].

Returns:
Instance of RE3.
Expand Down
2 changes: 1 addition & 1 deletion rllte/xplore/reward/ride.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class RIDE(BaseReward):
sm (float): The kernel maximum similarity.
update_proportion (float): The proportion of the training data used for updating the forward dynamics models.
encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].
weight_init (str): The weight initialization method from ['default', 'orthogonal'].
weight_init (str): The weight initialization method from ['default', 'orthogonal', 'kaiming he'].

Returns:
Instance of RIDE.
Expand Down
2 changes: 1 addition & 1 deletion rllte/xplore/reward/rnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class RND(BaseReward):
batch_size (int): The batch size for training.
update_proportion (float): The proportion of the training data used for updating the forward dynamics models.
encoder_model (str): The network architecture of the encoder from ['mnih', 'pathak'].
weight_init (str): The weight initialization method from ['default', 'orthogonal'].
weight_init (str): The weight initialization method from ['default', 'orthogonal', 'kaiming he'].

Returns:
Instance of RND.
Expand Down