
Commit 0c5c677

Committed Apr 4, 2023
Add iql_jax implementation
1 parent f6ba4bd commit 0c5c677

15 files changed: 365 additions, 14 deletions

‎LICENSE

+4 −1

@@ -21,7 +21,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 
 --------------------------------------------------------------------------------
-Code in `flexrl/src/algorithms/ppo` and `flexrl/src/algorithms/ppo_multidiscrete` are adapted from https://github.com/vwxyzjn/cleanrl
+Code in `flexrl/src/online/ppo.py`, `flexrl/src/online/ppo_multidiscrete.py`,
+`flexrl/src/online/ppo_atari.py`, `flexrl/src/online/dqn.py`,
+`flexrl/src/online/dqn_atari.py`, `flexrl/src/online/sac.py`,
+are adapted from https://github.com/vwxyzjn/cleanrl
 
 MIT License

‎README.md

+6 −1

@@ -2,7 +2,7 @@
 
 FlexRL is a deep online/offline reinforcement learning library inspired and adapted from [CleanRL](https://github.com/vwxyzjn/cleanrl) and [CORL](https://github.com/tinkoff-ai/CORL) that provides single-file implementations of algorithms that aren't necessarily covered by these libraries. FlexRL introduces the following features:
 - Consistent style across online and offline algorithms
-- Easy configuration with [Pyrallis](https://github.com/eladrich/pyrallis) and progress bar
+- Easy configuration with [Pyrallis](https://github.com/eladrich/pyrallis) and [tqdm](https://github.com/tqdm/tqdm) progress bar
 - A few custom environments under `gym` API
 
 ## Quick Start
@@ -35,6 +35,7 @@ python ppo.py --config_path=some_config.yaml
 | | | [qr_dqn_atari.py](src/flexrl/online/qr_dqn_atari.py) |
 | | Soft Actor-Critic (SAC) | [sac.py](src/flexrl/online/sac.py) |
 | Offline | Implicit Q-Learning (IQL) | [iql.py](src/flexrl/offline/iql.py) |
+| | | [iql_jax.py](src/flexrl/offline/iql_jax.py) |
 | | In-Sample Actor-Critic (InAC) | [inac.py](src/flexrl/offline/inac.py) |
 
 ### Extra Requirements
@@ -51,6 +52,10 @@ ale-import-roms roms/
 
 To use MuJoCo envs (for both online training and offline evaluation), you need to install MuJoCo first. See [mujoco-py](https://github.com/openai/mujoco-py) for instructions.
 
+#### JAX with CUDA Support
+
+To use JAX with CUDA support, you need to install the NVIDIA driver first. See [JAX Installation](https://github.com/google/jax#installation) for instructions.
+
 ### References
 
 - [1] S. Huang, R. F. J. Dossa, C. Ye, and J. Braga, “CleanRL: High-quality Single-file Implementations of Deep Reinforcement Learning Algorithms.” arXiv, Nov. 16, 2021. Accessed: Nov. 21, 2022. [Online]. Available: http://arxiv.org/abs/2111.08819
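
Note: the new README section only covers driver installation and points to the JAX install page. As a quick sanity check that a CUDA-enabled jaxlib is actually being picked up, something like the following can be run (a hedged sketch, not part of this commit):

import jax

# With a CUDA-enabled jaxlib installed, the device list should include a GPU;
# otherwise JAX silently falls back to CPU.
print(jax.devices())
print(jax.default_backend())  # "gpu" on a working CUDA setup, "cpu" otherwise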

‎setup.py

+5 −1

@@ -3,9 +3,13 @@
 requires = [
     "torch==1.13.0",
     "stable-baselines3==1.1.0",
-    "tensorboard==2.10.1",
+    "tensorboardX==2.6",
     "opencv-python==4.6.0.66",
     "gym[mujoco_py, classic_control]==0.23.1",
+    "jax==0.4.6",
+    "flax==0.6.7",
+    "optax==0.1.4",
+    "tensorflow-probability==0.19.0",
     "ale-py==0.7.4",
     "mujoco-py==2.1.2.14",
     "tqdm==4.64.0",

‎src/flexrl/offline/inac.py

+1 −1

@@ -15,7 +15,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 
 @dataclass
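
Note: this import swap, repeated across the offline and online scripts below, pairs with the tensorboard → tensorboardX dependency change in setup.py above. For the calls these scripts make (add_text, add_scalar, add_histogram, flush, close), the tensorboardX writer is drop-in compatible. A minimal sketch with a hypothetical log directory, assuming tensorboardX 2.6 as pinned above:

from tensorboardX import SummaryWriter

writer = SummaryWriter("runs/demo")  # hypothetical log directory; the scripts build theirs from exp_name, seed, and time
writer.add_text("hyperparameters", "|param|value|\n|-|-|\n|seed|1|")
writer.add_scalar("losses/qf_loss", 0.5, global_step=1)  # same call signature as torch.utils.tensorboard
writer.flush()
writer.close()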

‎src/flexrl/offline/iql.py

+1 −1

@@ -15,7 +15,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 
 @dataclass
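
Note: the new iql_jax.py below is configured the same way as the other single-file scripts: a dataclass of arguments parsed by Pyrallis, so every field can be overridden from the command line or a YAML file. A minimal sketch of that pattern (DemoArgs and its fields are illustrative stand-ins, not from the repo):

from dataclasses import dataclass

import pyrallis


@dataclass
class DemoArgs:  # illustrative stand-in for TrainArgs
    gym_id: str = "halfcheetah-medium-expert-v2"
    seed: int = 1


if __name__ == "__main__":
    # Overridable as: python demo.py --gym_id hopper-medium-v2 --seed 3
    # or, per the README quick start, with --config_path=some_config.yaml
    args = pyrallis.parse(config_class=DemoArgs)
    print(args)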

‎src/flexrl/offline/iql_jax.py

+331 (new file)

# source: https://github.com/ikostrikov/implicit_q_learning
# https://arxiv.org/abs/2110.06169

from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
from dataclasses import dataclass
import random
import time

import d4rl
import gym
import numpy as np
import pyrallis
from tqdm import tqdm
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from tensorflow_probability.substrates import jax as tfp
from tensorboardX import SummaryWriter
tfd = tfp.distributions
tfb = tfp.bijectors


@dataclass
class TrainArgs:
    # Experiment
    exp_name: str = "iql_jax"
    gym_id: str = "halfcheetah-medium-expert-v2"
    seed: int = 1
    log_dir: str = "runs"
    # IQL
    total_iterations: int = int(1e6)
    gamma: float = 0.99
    actor_lr: float = 3e-4
    value_lr: float = 3e-4
    critic_lr: float = 3e-4
    batch_size: int = 256
    expectile: float = 0.7
    temperature: float = 3.0
    polyak: float = 0.005
    eval_freq: int = int(5e3)
    eval_episodes: int = 10
    log_freq: int = 1000

    def __post_init__(self):
        self.exp_name = f"{self.exp_name}__{self.gym_id}"


def make_env(env_id, seed):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env
    return thunk


def layer_init(scale=jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


class ValueNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray):
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(1, kernel_init=layer_init())(x)
        return x


class CriticNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        x = jnp.concatenate([x, a], -1)
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(1, kernel_init=layer_init())(x)
        return x


class DoubleCriticNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        critic1 = CriticNetwork()(x, a)
        critic2 = CriticNetwork()(x, a)
        return critic1, critic2


EXP_ADV_MAX = 100.0
LOG_STD_MAX = 2.0
LOG_STD_MIN = -10.0


class Actor(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, x: jnp.ndarray, temperature: float = 1.0):
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        x = nn.Dense(256, kernel_init=layer_init())(x)
        x = nn.relu(x)
        mean = nn.Dense(self.action_dim, kernel_init=layer_init())(x)
        log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim, ))
        log_std = jnp.clip(log_std, LOG_STD_MIN, LOG_STD_MAX)
        dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std) * temperature)
        return dist


class TargetTrainState(TrainState):
    target_params: flax.core.FrozenDict


class Batch(NamedTuple):
    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    masks: np.ndarray
    next_observations: np.ndarray


class Dataset:
    def __init__(self):
        self.size = None
        self.observations = None
        self.actions = None
        self.rewards = None
        self.masks = None
        self.next_observations = None

    def load(self, env, eps=1e-5):
        self.env = env
        dataset = d4rl.qlearning_dataset(env)
        lim = 1 - eps  # Clip to eps
        dataset["actions"] = np.clip(dataset["actions"], -lim, lim)
        self.size = len(dataset["observations"])
        self.observations = dataset["observations"].astype(np.float32)
        self.actions = dataset["actions"].astype(np.float32)
        self.rewards = dataset["rewards"].astype(np.float32)
        self.masks = 1.0 - dataset["terminals"].astype(np.float32)
        self.next_observations = dataset["next_observations"].astype(np.float32)

    def sample(self, batch_size):
        idx = np.random.randint(self.size, size=batch_size)
        data = (
            self.observations[idx],
            self.actions[idx],
            self.rewards[idx],
            self.masks[idx],
            self.next_observations[idx],
        )
        return Batch(*data)


if __name__ == "__main__":
    # Logging setup
    args = pyrallis.parse(config_class=TrainArgs)
    print(vars(args))
    run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
    writer = SummaryWriter(f"{args.log_dir}/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # Seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, actor_key, critic_key, value_key = jax.random.split(key, 4)

    # Eval env setup
    env = make_env(args.gym_id, args.seed)()
    assert isinstance(env.action_space, gym.spaces.Box), "only continuous action space is supported"
    observation = env.observation_space.sample()[np.newaxis]
    action = env.action_space.sample()[np.newaxis]

    # Agent setup
    actor = Actor(action_dim=np.prod(env.action_space.shape))
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, observation),
        tx=optax.adam(learning_rate=args.actor_lr)
    )
    vf = ValueNetwork()
    vf_state = TrainState.create(
        apply_fn=vf.apply,
        params=vf.init(value_key, observation),
        tx=optax.adam(learning_rate=args.value_lr)
    )
    qf = DoubleCriticNetwork()
    qf_state = TargetTrainState.create(
        apply_fn=qf.apply,
        params=qf.init(critic_key, observation, action),
        target_params=qf.init(critic_key, observation, action),
        tx=optax.adam(learning_rate=args.critic_lr)
    )

    # Dataset setup
    dataset = Dataset()
    dataset.load(env)
    start_time = time.time()

    def asymmetric_l2_loss(diff, expectile=0.8):
        weight = jnp.where(diff > 0, expectile, (1 - expectile))
        return weight * (diff**2)

    def update_vf(vf_state, qf_state, batch):
        q1, q2 = qf.apply(qf_state.target_params, batch.observations, batch.actions)
        q = jnp.minimum(q1, q2)

        def vf_loss_fn(params):
            v = vf.apply(params, batch.observations)
            vf_loss = asymmetric_l2_loss(q - v, args.expectile).mean()
            return vf_loss, {
                "vf_loss": vf_loss,
                "v": v.mean(),
            }

        (vf_loss, info), grads = jax.value_and_grad(vf_loss_fn, has_aux=True)(vf_state.params)
        vf_state = vf_state.apply_gradients(grads=grads)
        return vf_state, info

    def update_actor(actor_state, vf_state, qf_state, batch):
        v = vf.apply(vf_state.params, batch.observations)
        q1, q2 = qf.apply(qf_state.target_params, batch.observations, batch.actions)
        q = jnp.minimum(q1, q2)
        exp_adv = jnp.exp((q - v) * args.temperature)
        exp_adv = jnp.minimum(exp_adv, EXP_ADV_MAX)

        def actor_loss_fn(params):
            dist = actor.apply(params, batch.observations)
            log_probs = dist.log_prob(batch.actions)
            actor_loss = -(exp_adv * log_probs).mean()
            return actor_loss, {
                "actor_loss": actor_loss,
                "adv": q - v,
            }

        (actor_loss, info), grads = jax.value_and_grad(actor_loss_fn, has_aux=True)(actor_state.params)
        actor_state = actor_state.apply_gradients(grads=grads)
        return actor_state, info

    def update_qf(vf_state, qf_state, batch):
        next_v = vf.apply(vf_state.params, batch.next_observations)
        target_q = batch.rewards + args.gamma * batch.masks * next_v

        def qf_loss_fn(params):
            q1, q2 = qf.apply(params, batch.observations, batch.actions)
            qf_loss = ((q1 - target_q)**2 + (q2 - target_q)**2).mean()
            return qf_loss, {
                "qf_loss": qf_loss,
                "q1": q1.mean(),
                "q2": q2.mean(),
            }

        (qf_loss, info), grads = jax.value_and_grad(qf_loss_fn, has_aux=True)(qf_state.params)
        qf_state = qf_state.apply_gradients(grads=grads)
        return qf_state, info

    def update_target(qf_state):
        new_target_params = jax.tree_map(
            lambda p, tp: p * args.polyak + tp * (1 - args.polyak), qf_state.params,
            qf_state.target_params)
        return qf_state.replace(target_params=new_target_params)

    @jax.jit
    def update(actor_state, vf_state, qf_state, batch):
        vf_state, vf_info = update_vf(vf_state, qf_state, batch)
        actor_state, actor_info = update_actor(actor_state, vf_state, qf_state, batch)
        qf_state, qf_info = update_qf(vf_state, qf_state, batch)
        qf_state = update_target(qf_state)
        return actor_state, vf_state, qf_state, {
            **vf_info, **actor_info, **qf_info
        }

    @jax.jit
    def get_action(rng, actor_state, observation, temperature=1.0):
        dist = actor.apply(actor_state.params, observation, temperature)
        rng, key = jax.random.split(rng)
        action = dist.sample(seed=key)
        return rng, jnp.clip(action, -1, 1)

    # Main loop
    for global_step in tqdm(range(args.total_iterations), desc="Training", unit="iter"):

        # Batch update
        batch = dataset.sample(batch_size=args.batch_size)
        actor_state, vf_state, qf_state, update_info = update(
            actor_state, vf_state, qf_state, batch
        )

        # Evaluation
        if global_step % args.eval_freq == 0:
            env.seed(args.seed)
            stats = {"return": [], "length": []}
            for _ in range(args.eval_episodes):
                obs, done = env.reset(), False
                while not done:
                    key, action = get_action(key, actor_state, obs, temperature=0.0)
                    action = np.asarray(action)
                    obs, reward, done, info = env.step(action)
                for k in stats.keys():
                    stats[k].append(info["episode"][k[0]])
            for k, v in stats.items():
                writer.add_scalar(f"charts/episodic_{k}", np.mean(v), global_step)
                if k == "return":
                    normalized_score = env.get_normalized_score(np.mean(v)) * 100
                    writer.add_scalar("charts/normalized_score", normalized_score, global_step)
            writer.flush()

        # Logging
        if global_step % args.log_freq == 0:
            for k, v in update_info.items():
                if v.ndim == 0:
                    writer.add_scalar(f"losses/{k}", v, global_step)
                else:
                    writer.add_histogram(f"losses/{k}", v, global_step)
            writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
            writer.flush()

    env.close()
    writer.close()
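
Note: two details in the file above are easy to miss. First, asymmetric_l2_loss implements the expectile regression from the IQL paper: with the configured expectile of 0.7 (TrainArgs.expectile), residuals where q exceeds v are weighted 0.7 and the opposite sign 0.3, so the value network tracks an upper expectile of the target Q-values. A standalone numeric sketch, not part of the commit:

import jax.numpy as jnp


def asymmetric_l2_loss(diff, expectile=0.7):
    # Same form as in iql_jax.py: upweight positive residuals, downweight negative ones.
    weight = jnp.where(diff > 0, expectile, 1 - expectile)
    return weight * (diff ** 2)


diffs = jnp.array([-1.0, 1.0])      # residuals q - v of equal magnitude
print(asymmetric_l2_loss(diffs))    # [0.3 0.7]: errors where q exceeds v are penalized more

Second, update_target applies Polyak averaging with coefficient polyak = 0.005 (target ← 0.005 · online + 0.995 · target after every step), and both the value and actor losses read the critic through those target parameters.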

‎src/flexrl/online/dqn.py

+1 −1

@@ -14,7 +14,7 @@
 import torch.nn.functional as F
 import torch.optim as optim
 from stable_baselines3.common.buffers import ReplayBuffer
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 
 @dataclass

‎src/flexrl/online/dqn_atari.py

+1 −1

@@ -21,7 +21,7 @@
     NoopResetEnv,
 )
 from stable_baselines3.common.buffers import ReplayBuffer
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 
 @dataclass

‎src/flexrl/online/ppo.py

+1 −1

@@ -13,7 +13,7 @@
 import torch.nn as nn
 import torch.optim as optim
 from torch.distributions.categorical import Categorical
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 
 @dataclass

‎src/flexrl/online/ppo_atari.py

+1 −1

@@ -13,7 +13,7 @@
 import torch.nn as nn
 import torch.optim as optim
 from torch.distributions.categorical import Categorical
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 from stable_baselines3.common.atari_wrappers import (  # isort:skip
     ClipRewardEnv,

‎src/flexrl/online/ppo_multidiscrete.py

+1 −1

@@ -13,7 +13,7 @@
 import torch.nn as nn
 import torch.optim as optim
 from torch.distributions.categorical import Categorical
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 
 @dataclass

‎src/flexrl/online/qr_dqn.py

+1 −1

@@ -15,7 +15,7 @@
 import torch.nn.functional as F
 import torch.optim as optim
 from stable_baselines3.common.buffers import ReplayBuffer
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 
 @dataclass

‎src/flexrl/online/qr_dqn_atari.py

+1 −1

@@ -22,7 +22,7 @@
     NoopResetEnv,
 )
 from stable_baselines3.common.buffers import ReplayBuffer
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 
 @dataclass

‎src/flexrl/online/sac.py

+1 −1

@@ -14,7 +14,7 @@
 import torch.nn.functional as F
 import torch.optim as optim
 from stable_baselines3.common.buffers import ReplayBuffer
-from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
 
 
 @dataclass

‎tests/test_d4rl.py

+9 −1

@@ -9,9 +9,17 @@ def test_iql_d4rl():
     )
 
 
+def test_iql_jax_d4rl():
+    subprocess.run(
+        "python src/flexrl/offline/iql_jax.py --gym_id halfcheetah-medium-expert-v2 --total_iterations 200 --eval_freq 100 --eval_episodes 1 --log_freq 100",
+        shell=True,
+        check=True,
+    )
+
+
 def test_inac_d4rl():
     subprocess.run(
         "python src/flexrl/offline/inac.py --gym_id halfcheetah-medium-expert-v2 --total_iterations 200 --eval_freq 100 --eval_episodes 1 --log_freq 100",
         shell=True,
         check=True,
-    )
+    )
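
Note: the new smoke test mirrors the existing D4RL tests: it shells out to the script with a tiny iteration budget so it finishes quickly. Assuming a standard pytest setup (the test runner configuration is not part of this diff), it could be run on its own with something like:

# Hypothetical one-off runner for just the new test; equivalent to invoking pytest
# from the shell with the node id tests/test_d4rl.py::test_iql_jax_d4rl.
import pytest

raise SystemExit(pytest.main(["tests/test_d4rl.py::test_iql_jax_d4rl", "-q"]))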
