# source: https://github.com/ikostrikov/implicit_q_learning
# https://arxiv.org/abs/2110.06169

from typing import NamedTuple
from dataclasses import dataclass
import random
import time

import d4rl
import gym
import numpy as np
import pyrallis
from tqdm import tqdm
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from tensorflow_probability.substrates import jax as tfp
from tensorboardX import SummaryWriter

tfd = tfp.distributions
tfb = tfp.bijectors


@dataclass
class TrainArgs:
    # Experiment
    exp_name: str = "iql_jax"
    gym_id: str = "halfcheetah-medium-expert-v2"
    seed: int = 1
    log_dir: str = "runs"
    # IQL
    total_iterations: int = int(1e6)
    gamma: float = 0.99
    actor_lr: float = 3e-4
    value_lr: float = 3e-4
    critic_lr: float = 3e-4
    batch_size: int = 256
    expectile: float = 0.7
    temperature: float = 3.0
    polyak: float = 0.005
    eval_freq: int = int(5e3)
    eval_episodes: int = 10
    log_freq: int = 1000

    def __post_init__(self):
        self.exp_name = f"{self.exp_name}__{self.gym_id}"

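# Note: pyrallis exposes the fields above as command-line flags, so a run can be
# launched e.g. as `python iql_jax.py --gym_id halfcheetah-medium-v2 --seed 0`
# (the script filename here is illustrative).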

def make_env(env_id, seed):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env
    return thunk


def layer_init(scale=jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


class ValueNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray):
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(1, kernel_init=layer_init())(x)
        # Squeeze to shape (batch,) so V(s) broadcasts correctly against
        # (batch,)-shaped rewards, masks, and log-probs downstream.
        return jnp.squeeze(x, -1)


class CriticNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        x = jnp.concatenate([x, a], -1)
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(1, kernel_init=layer_init())(x)
        # Squeeze to shape (batch,) for the same reason as ValueNetwork.
        return jnp.squeeze(x, -1)


class DoubleCriticNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        critic1 = CriticNetwork()(x, a)
        critic2 = CriticNetwork()(x, a)
        return critic1, critic2


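# EXP_ADV_MAX caps the exponentiated advantage weights used in the policy extraction
# step; LOG_STD_MIN/MAX bound the policy's learned log standard deviation for
# numerical stability.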
EXP_ADV_MAX = 100.0
LOG_STD_MAX = 2.0
LOG_STD_MIN = -10.0


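# The policy is a Gaussian with a state-dependent mean and a single learned,
# state-independent log_std. `temperature` scales the standard deviation, so
# temperature=0.0 collapses the distribution to its mean at evaluation time.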
class Actor(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray, temperature: float = 1.0):
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        mean = nn.Dense(self.action_dim, kernel_init=layer_init())(x)
        log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        log_std = jnp.clip(log_std, LOG_STD_MIN, LOG_STD_MAX)
        dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std) * temperature)
        return dist


class TargetTrainState(TrainState):
    target_params: flax.core.FrozenDict


class Batch(NamedTuple):
    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    masks: np.ndarray
    next_observations: np.ndarray


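# In-memory buffer over the static D4RL dataset; masks = 1 - terminals disables
# bootstrapping on terminal transitions, and actions are clipped slightly inside
# [-1, 1] before training.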
class Dataset:
    def __init__(self):
        self.size = None
        self.observations = None
        self.actions = None
        self.rewards = None
        self.masks = None
        self.next_observations = None

    def load(self, env, eps=1e-5):
        self.env = env
        dataset = d4rl.qlearning_dataset(env)
        lim = 1 - eps  # Clip actions to [-1 + eps, 1 - eps]
        dataset["actions"] = np.clip(dataset["actions"], -lim, lim)
        self.size = len(dataset["observations"])
        self.observations = dataset["observations"].astype(np.float32)
        self.actions = dataset["actions"].astype(np.float32)
        self.rewards = dataset["rewards"].astype(np.float32)
        self.masks = 1.0 - dataset["terminals"].astype(np.float32)
        self.next_observations = dataset["next_observations"].astype(np.float32)

    def sample(self, batch_size):
        idx = np.random.randint(self.size, size=batch_size)
        data = (
            self.observations[idx],
            self.actions[idx],
            self.rewards[idx],
            self.masks[idx],
            self.next_observations[idx],
        )
        return Batch(*data)


if __name__ == "__main__":
    # Logging setup
    args = pyrallis.parse(config_class=TrainArgs)
    print(vars(args))
    run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
    writer = SummaryWriter(f"{args.log_dir}/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # Seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, actor_key, critic_key, value_key = jax.random.split(key, 4)

    # Eval env setup
    env = make_env(args.gym_id, args.seed)()
    assert isinstance(env.action_space, gym.spaces.Box), "only continuous action space is supported"
    observation = env.observation_space.sample()[np.newaxis]
    action = env.action_space.sample()[np.newaxis]

    # Agent setup
    actor = Actor(action_dim=np.prod(env.action_space.shape))
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, observation),
        tx=optax.adam(learning_rate=args.actor_lr)
    )
    vf = ValueNetwork()
    vf_state = TrainState.create(
        apply_fn=vf.apply,
        params=vf.init(value_key, observation),
        tx=optax.adam(learning_rate=args.value_lr)
    )
    qf = DoubleCriticNetwork()
    qf_state = TargetTrainState.create(
        apply_fn=qf.apply,
        params=qf.init(critic_key, observation, action),
        target_params=qf.init(critic_key, observation, action),
        tx=optax.adam(learning_rate=args.critic_lr)
    )

    # Dataset setup
    dataset = Dataset()
    dataset.load(env)
    start_time = time.time()

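    # Expectile regression loss from the IQL paper: L2_tau(u) = |tau - 1(u < 0)| * u^2.
    # With expectile > 0.5, positive errors (Q above the current V estimate) are weighted
    # more heavily, pushing V toward an upper expectile of the Q distribution.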
    def asymmetric_l2_loss(diff, expectile=0.8):
        weight = jnp.where(diff > 0, expectile, (1 - expectile))
        return weight * (diff**2)

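    # Value update: V(s) regresses onto min(Q1, Q2) from the target critic under the
    # asymmetric expectile loss, so the value estimate is learned without ever
    # querying actions outside the dataset.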
    def update_vf(vf_state, qf_state, batch):
        q1, q2 = qf.apply(qf_state.target_params, batch.observations, batch.actions)
        q = jnp.minimum(q1, q2)

        def vf_loss_fn(params):
            v = vf.apply(params, batch.observations)
            vf_loss = asymmetric_l2_loss(q - v, args.expectile).mean()
            return vf_loss, {
                "vf_loss": vf_loss,
                "v": v.mean(),
            }

        (vf_loss, info), grads = jax.value_and_grad(vf_loss_fn, has_aux=True)(vf_state.params)
        vf_state = vf_state.apply_gradients(grads=grads)
        return vf_state, info

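    # Policy extraction via advantage-weighted regression: maximize
    # E[exp(temperature * (Q(s, a) - V(s))) * log pi(a|s)] over dataset actions, with
    # the exponentiated advantage clipped at EXP_ADV_MAX for stability.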
    def update_actor(actor_state, vf_state, qf_state, batch):
        v = vf.apply(vf_state.params, batch.observations)
        q1, q2 = qf.apply(qf_state.target_params, batch.observations, batch.actions)
        q = jnp.minimum(q1, q2)
        exp_adv = jnp.exp((q - v) * args.temperature)
        exp_adv = jnp.minimum(exp_adv, EXP_ADV_MAX)

        def actor_loss_fn(params):
            dist = actor.apply(params, batch.observations)
            log_probs = dist.log_prob(batch.actions)
            actor_loss = -(exp_adv * log_probs).mean()
            return actor_loss, {
                "actor_loss": actor_loss,
                "adv": q - v,
            }

        (actor_loss, info), grads = jax.value_and_grad(actor_loss_fn, has_aux=True)(actor_state.params)
        actor_state = actor_state.apply_gradients(grads=grads)
        return actor_state, info

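    # Critic update: both Q heads regress onto the one-step TD target
    # r + gamma * mask * V(s'), using V(s') instead of a sampled next action so the
    # bootstrap never evaluates out-of-distribution actions.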
    def update_qf(vf_state, qf_state, batch):
        next_v = vf.apply(vf_state.params, batch.next_observations)
        target_q = batch.rewards + args.gamma * batch.masks * next_v

        def qf_loss_fn(params):
            q1, q2 = qf.apply(params, batch.observations, batch.actions)
            qf_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean()
            return qf_loss, {
                "qf_loss": qf_loss,
                "q1": q1.mean(),
                "q2": q2.mean(),
            }

        (qf_loss, info), grads = jax.value_and_grad(qf_loss_fn, has_aux=True)(qf_state.params)
        qf_state = qf_state.apply_gradients(grads=grads)
        return qf_state, info

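    # Soft (Polyak) update of the target critic:
    # target <- polyak * online + (1 - polyak) * target, with polyak = 0.005.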
    def update_target(qf_state):
        new_target_params = jax.tree_map(
            lambda p, tp: p * args.polyak + tp * (1 - args.polyak), qf_state.params,
            qf_state.target_params)
        return qf_state.replace(target_params=new_target_params)

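    # One jitted training step: update V first, extract the policy with the refreshed V,
    # then update the critics and their target network.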
    @jax.jit
    def update(actor_state, vf_state, qf_state, batch):
        vf_state, vf_info = update_vf(vf_state, qf_state, batch)
        actor_state, actor_info = update_actor(actor_state, vf_state, qf_state, batch)
        qf_state, qf_info = update_qf(vf_state, qf_state, batch)
        qf_state = update_target(qf_state)
        return actor_state, vf_state, qf_state, {
            **vf_info, **actor_info, **qf_info
        }

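    # Action selection for evaluation: with temperature=0.0 the Gaussian collapses to
    # its mean, so sampling returns the deterministic mean action, clipped to [-1, 1].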
    @jax.jit
    def get_action(rng, actor_state, observation, temperature=1.0):
        dist = actor.apply(actor_state.params, observation, temperature)
        rng, key = jax.random.split(rng)
        action = dist.sample(seed=key)
        return rng, jnp.clip(action, -1, 1)

    # Main loop
    for global_step in tqdm(range(args.total_iterations), desc="Training", unit="iter"):

        # Batch update
        batch = dataset.sample(batch_size=args.batch_size)
        actor_state, vf_state, qf_state, update_info = update(
            actor_state, vf_state, qf_state, batch
        )

        # Evaluation
        if global_step % args.eval_freq == 0:
            env.seed(args.seed)
            stats = {"return": [], "length": []}
            for _ in range(args.eval_episodes):
                obs, done = env.reset(), False
                while not done:
                    key, action = get_action(key, actor_state, obs, temperature=0.0)
                    action = np.asarray(action)
                    obs, reward, done, info = env.step(action)
                for k in stats.keys():
                    # RecordEpisodeStatistics stores the episodic return/length under "r"/"l"
                    stats[k].append(info["episode"][k[0]])
            for k, v in stats.items():
                writer.add_scalar(f"charts/episodic_{k}", np.mean(v), global_step)
                if k == "return":
                    normalized_score = env.get_normalized_score(np.mean(v)) * 100
                    writer.add_scalar("charts/normalized_score", normalized_score, global_step)
            writer.flush()

        # Logging
        if global_step % args.log_freq == 0:
            for k, v in update_info.items():
                if v.ndim == 0:
                    writer.add_scalar(f"losses/{k}", v, global_step)
                else:
                    writer.add_histogram(f"losses/{k}", v, global_step)
            writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
            writer.flush()

    env.close()
    writer.close()