@@ -99,68 +99,69 @@ lines of code*!

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
-    LazyTensorStorage, SamplerWithoutReplacement
+    LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

-env = GymEnv("Pendulum-v1")
+env = GymEnv("Pendulum-v1")
model = TensorDictModule(
-    nn.Sequential(
-        nn.Linear(3, 128), nn.Tanh(),
-        nn.Linear(128, 128), nn.Tanh(),
-        nn.Linear(128, 128), nn.Tanh(),
-        nn.Linear(128, 2),
-        NormalParamExtractor()
-    ),
-    in_keys=["observation"],
-    out_keys=["loc", "scale"]
+    nn.Sequential(
+        nn.Linear(3, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 2),
+        NormalParamExtractor()
+    ),
+    in_keys=["observation"],
+    out_keys=["loc", "scale"]
)
critic = ValueOperator(
-    nn.Sequential(
-        nn.Linear(3, 128), nn.Tanh(),
-        nn.Linear(128, 128), nn.Tanh(),
-        nn.Linear(128, 128), nn.Tanh(),
-        nn.Linear(128, 1),
-    ),
-    in_keys=["observation"],
+    nn.Sequential(
+        nn.Linear(3, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 1),
+    ),
+    in_keys=["observation"],
)
actor = ProbabilisticActor(
-    model,
-    in_keys=["loc", "scale"],
-    distribution_class=TanhNormal,
-    distribution_kwargs={"min": -1.0, "max": 1.0},
-    return_log_prob=True
-)
+    model,
+    in_keys=["loc", "scale"],
+    distribution_class=TanhNormal,
+    distribution_kwargs={"low": -1.0, "high": 1.0},
+    return_log_prob=True
+)
buffer = TensorDictReplayBuffer(
-    LazyTensorStorage(1000),
-    SamplerWithoutReplacement()
-)
+    storage=LazyTensorStorage(1000),
+    sampler=SamplerWithoutReplacement(),
+    batch_size=50,
+)
collector = SyncDataCollector(
-    env,
-    actor,
-    frames_per_batch=1000,
-    total_frames=1_000_000
-)
-loss_fn = ClipPPOLoss(actor, critic, gamma=0.99)
+    env,
+    actor,
+    frames_per_batch=1000,
+    total_frames=1_000_000,
+)
+loss_fn = ClipPPOLoss(actor, critic)
+adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
-adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True)
+
for data in collector:  # collect data
-    for epoch in range(10):
-        adv_fn(data)  # compute advantage
-        buffer.extend(data.view(-1))
-        for i in range(20):  # consume data
-            sample = buffer.sample(50)  # mini-batch
-            loss_vals = loss_fn(sample)
-            loss_val = sum(
-                value for key, value in loss_vals.items() if
-                key.startswith("loss")
-            )
-            loss_val.backward()
-            optim.step()
-            optim.zero_grad()
-    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
+    for epoch in range(10):
+        adv_fn(data)  # compute advantage
+        buffer.extend(data)
+        for sample in buffer:  # consume data
+            loss_vals = loss_fn(sample)
+            loss_val = sum(
+                value for key, value in loss_vals.items() if
+                key.startswith("loss")
+            )
+            loss_val.backward()
+            optim.step()
+            optim.zero_grad()
+    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
```
</details>
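The updated example consumes the buffer by iterating over it instead of calling `buffer.sample(50)` inside a counted loop. The sketch below isolates that pattern under the same buffer settings as the diff (a 1000-slot `LazyTensorStorage`, a `SamplerWithoutReplacement`, `batch_size=50`); the `fake_batch` tensordict with random `observation`/`action` tensors is a placeholder for real collector output, not what the training loop actually stores.

```python
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import (
    LazyTensorStorage,
    SamplerWithoutReplacement,
    TensorDictReplayBuffer,
)

# Buffer configured as in the updated example: 1000-slot storage,
# sampling without replacement, 50-frame mini-batches.
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    batch_size=50,
)

# Stand-in for one 1000-frame collector batch (random data, illustrative only).
fake_batch = TensorDict(
    {"observation": torch.randn(1000, 3), "action": torch.randn(1000, 1)},
    batch_size=[1000],
)
buffer.extend(fake_batch)

# Iterating the buffer draws `batch_size` frames at a time until the
# without-replacement sampler runs out of stored data.
n_minibatches = 0
for sample in buffer:
    n_minibatches += 1
print(n_minibatches)  # expected: 20 (1000 frames / 50 per mini-batch)
```

With `SamplerWithoutReplacement`, one `for sample in buffer:` loop corresponds to one sweep over the stored frames, which is why the explicit `for i in range(20)` / `buffer.sample(50)` bookkeeping in the old version is no longer needed.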