Skip to content

Commit

Permalink
prevent fast network update from learner, configuratble discount, upd…
Browse files Browse the repository at this point in the history
…ate agentlace version

Signed-off-by: youliangtan <[email protected]>
  • Loading branch information
youliangtan committed Jun 6, 2024
1 parent 9be150b commit 21d7c8e
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.")
flags.DEFINE_integer("training_starts", 300, "Training starts after this step.")
flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.")
flags.DEFINE_integer("steps_per_update", 50, "Number of steps per update the server.")

flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")
Expand Down
2 changes: 1 addition & 1 deletion examples/async_cable_route_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@

flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.")
flags.DEFINE_integer("training_starts", 300, "Training starts after this step.")
flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.")
flags.DEFINE_integer("steps_per_update", 50, "Number of steps per update the server.")

flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")
Expand Down
2 changes: 1 addition & 1 deletion examples/async_drq_sim/async_drq_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.")
flags.DEFINE_integer("training_starts", 300, "Training starts after this step.")
flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.")
flags.DEFINE_integer("steps_per_update", 50, "Number of steps per update the server.")

flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")
Expand Down
2 changes: 1 addition & 1 deletion examples/async_pcb_insert_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.")
flags.DEFINE_integer("training_starts", 300, "Training starts after this step.")
flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.")
flags.DEFINE_integer("steps_per_update", 50, "Number of steps per update the server.")

flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")
Expand Down
2 changes: 1 addition & 1 deletion examples/async_peg_insert_drq/async_drq_randomized.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.")
flags.DEFINE_integer("training_starts", 300, "Training starts after this step.")
flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.")
flags.DEFINE_integer("steps_per_update", 50, "Number of steps per update the server.")

flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")
Expand Down
2 changes: 1 addition & 1 deletion examples/async_sac_state_sim/async_sac_state_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.")
flags.DEFINE_integer("training_starts", 300, "Training starts after this step.")
flags.DEFINE_integer("steps_per_update", 10, "Number of steps per update the server.")
flags.DEFINE_integer("steps_per_update", 50, "Number of steps per update the server.")

flags.DEFINE_integer("log_period", 10, "Logging period.")
flags.DEFINE_integer("eval_period", 2000, "Evaluation period.")
Expand Down
16 changes: 11 additions & 5 deletions serl_launcher/serl_launcher/utils/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def make_bc_agent(
)


def make_sac_agent(seed, sample_obs, sample_action):
def make_sac_agent(seed, sample_obs, sample_action, discount=0.99):
return SACAgent.create_states(
jax.random.PRNGKey(seed),
sample_obs,
Expand All @@ -69,15 +69,20 @@ def make_sac_agent(seed, sample_obs, sample_action):
"hidden_dims": [256, 256],
},
temperature_init=1e-2,
discount=0.99,
discount=discount,
backup_entropy=False,
critic_ensemble_size=10,
critic_subsample_size=2,
)


def make_drq_agent(
seed, sample_obs, sample_action, image_keys=("image",), encoder_type="small"
seed,
sample_obs,
sample_action,
image_keys=("image",),
encoder_type="small",
discount=0.96,
):
agent = DrQAgent.create_drq(
jax.random.PRNGKey(seed),
Expand All @@ -103,7 +108,7 @@ def make_drq_agent(
"hidden_dims": [256, 256],
},
temperature_init=1e-2,
discount=0.96, # 0.99
discount=discount,
backup_entropy=False,
critic_ensemble_size=10,
critic_subsample_size=2,
Expand All @@ -119,6 +124,7 @@ def make_vice_agent(
image_keys=("image",),
vice_image_keys=("image",),
encoder_type="small",
discount=0.96,
):
agent = VICEAgent.create_vice(
jax.random.PRNGKey(seed),
Expand Down Expand Up @@ -154,7 +160,7 @@ def make_vice_agent(
"hidden_dims": [256, 256],
},
temperature_init=1e-2,
discount=0.96, # 0.99
discount=discount,
backup_entropy=False,
critic_ensemble_size=10,
critic_subsample_size=2,
Expand Down
2 changes: 1 addition & 1 deletion serl_launcher/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"typing_extensions",
"opencv-python",
"lz4",
"agentlace@git+https://github.com/youliangtan/agentlace.git@892d1557264d7bb1d5df04b37638c850c9d36f35",
"agentlace@git+https://github.com/youliangtan/agentlace.git@b9be677d5d20772fca98c8be44777ecb7111bc59",
],
packages=find_packages(),
zip_safe=False,
Expand Down

0 comments on commit 21d7c8e

Please sign in to comment.