Commit e82a69f

[Doc] Fix README example (#2398)

Author: Vincent Moens
1 parent 25e8bd2 commit e82a69f

File tree: 1 file changed, +49 -48 lines

README.md

Lines changed: 49 additions & 48 deletions

````diff
@@ -99,68 +99,69 @@ lines of code*!
 
 from torchrl.collectors import SyncDataCollector
 from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
-    LazyTensorStorage, SamplerWithoutReplacement
+    LazyTensorStorage, SamplerWithoutReplacement
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE
 
-env = GymEnv("Pendulum-v1")
+env = GymEnv("Pendulum-v1")
 model = TensorDictModule(
-    nn.Sequential(
-        nn.Linear(3, 128), nn.Tanh(),
-        nn.Linear(128, 128), nn.Tanh(),
-        nn.Linear(128, 128), nn.Tanh(),
-        nn.Linear(128, 2),
-        NormalParamExtractor()
-    ),
-    in_keys=["observation"],
-    out_keys=["loc", "scale"]
+    nn.Sequential(
+        nn.Linear(3, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 2),
+        NormalParamExtractor()
+    ),
+    in_keys=["observation"],
+    out_keys=["loc", "scale"]
 )
 critic = ValueOperator(
-    nn.Sequential(
-        nn.Linear(3, 128), nn.Tanh(),
-        nn.Linear(128, 128), nn.Tanh(),
-        nn.Linear(128, 128), nn.Tanh(),
-        nn.Linear(128, 1),
-    ),
-    in_keys=["observation"],
+    nn.Sequential(
+        nn.Linear(3, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 1),
+    ),
+    in_keys=["observation"],
 )
 actor = ProbabilisticActor(
-    model,
-    in_keys=["loc", "scale"],
-    distribution_class=TanhNormal,
-    distribution_kwargs={"min": -1.0, "max": 1.0},
-    return_log_prob=True
-    )
+    model,
+    in_keys=["loc", "scale"],
+    distribution_class=TanhNormal,
+    distribution_kwargs={"low": -1.0, "high": 1.0},
+    return_log_prob=True
+    )
 buffer = TensorDictReplayBuffer(
-    LazyTensorStorage(1000),
-    SamplerWithoutReplacement()
-)
+    storage=LazyTensorStorage(1000),
+    sampler=SamplerWithoutReplacement(),
+    batch_size=50,
+)
 collector = SyncDataCollector(
-    env,
-    actor,
-    frames_per_batch=1000,
-    total_frames=1_000_000
-)
-loss_fn = ClipPPOLoss(actor, critic, gamma=0.99)
+    env,
+    actor,
+    frames_per_batch=1000,
+    total_frames=1_000_000,
+)
+loss_fn = ClipPPOLoss(actor, critic)
+adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
 optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
-adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True)
+
 for data in collector:  # collect data
-    for epoch in range(10):
-        adv_fn(data)  # compute advantage
-        buffer.extend(data.view(-1))
-        for i in range(20):  # consume data
-            sample = buffer.sample(50)  # mini-batch
-            loss_vals = loss_fn(sample)
-            loss_val = sum(
-                value for key, value in loss_vals.items() if
-                key.startswith("loss")
-            )
-            loss_val.backward()
-            optim.step()
-            optim.zero_grad()
-    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
+    for epoch in range(10):
+        adv_fn(data)  # compute advantage
+        buffer.extend(data)
+        for sample in buffer:  # consume data
+            loss_vals = loss_fn(sample)
+            loss_val = sum(
+                value for key, value in loss_vals.items() if
+                key.startswith("loss")
+            )
+            loss_val.backward()
+            optim.step()
+            optim.zero_grad()
+    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
 ```
 </details>
 
````
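Beyond the `min`/`max` → `low`/`high` rename, the substantive change for users is how the replay buffer is built and consumed: storage and sampler are now passed as keyword arguments, the mini-batch size is fixed at construction with `batch_size=50`, and the training loop iterates over the buffer (`for sample in buffer`) instead of calling `buffer.sample(50)` a fixed number of times. The standalone sketch below (not part of the commit; the dummy `observation` tensor and the printed shapes are illustrative assumptions) shows that pattern in isolation:

```python
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement

# Buffer configured as in the updated README: keyword arguments and a fixed batch size.
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    batch_size=50,
)

# Fill it with a dummy batch of 100 "transitions" (illustrative data, not from the commit).
data = TensorDict({"observation": torch.randn(100, 3)}, batch_size=[100])
buffer.extend(data)

# With SamplerWithoutReplacement, iterating the buffer yields batch_size-sized
# mini-batches until every stored element has been sampled once (two batches of 50 here).
for sample in buffer:
    print(sample.shape)  # torch.Size([50])
```

This is why the updated loop no longer needs the explicit `range(20)` / `buffer.sample(50)` pair: with `batch_size=50`, one pass over the buffer consumes the freshly collected frames in the same 50-sample mini-batches.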
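The other rename surfaced by this diff is in the `TanhNormal` bounds: the README now passes `distribution_kwargs={"low": -1.0, "high": 1.0}` where it previously used `min`/`max`. Below is a minimal sketch of the renamed keywords, assuming a torchrl release recent enough to accept them (the loc/scale tensors and the assertion are illustrative, not from the commit):

```python
import torch
from torchrl.modules import TanhNormal

# Bounds are spelled `low`/`high` here; older releases used `min`/`max`.
dist = TanhNormal(loc=torch.zeros(3), scale=torch.ones(3), low=-1.0, high=1.0)
action = dist.sample()
assert ((action >= -1.0) & (action <= 1.0)).all()  # samples respect the bounds
```

`ProbabilisticActor` forwards `distribution_kwargs` to `distribution_class` unchanged, so the README example only needed the key rename.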