agent.py
import torch
from torch import optim
from torch.distributions.categorical import Categorical


class PPOTrainer:
    """Runs the PPO policy and value updates for an actor-critic network."""

    def __init__(self,
                 actor_critic,
                 ppo_clip_val=0.2,
                 target_kl_div=0.01,
                 max_policy_train_iters=80,
                 value_train_iters=80,
                 policy_lr=3e-4,
                 value_lr=1e-2):
        self.ac = actor_critic
        self.ppo_clip_val = ppo_clip_val
        self.target_kl_div = target_kl_div
        self.max_policy_train_iters = max_policy_train_iters
        self.value_train_iters = value_train_iters

        # Separate optimizers: the policy optimizer updates the shared trunk
        # plus the policy head; the value optimizer updates the shared trunk
        # plus the value head.
        policy_params = list(self.ac.shared_layers.parameters()) + \
            list(self.ac.policy_layers.parameters())
        self.policy_optim = optim.Adam(policy_params, lr=policy_lr)

        value_params = list(self.ac.shared_layers.parameters()) + \
            list(self.ac.value_layers.parameters())
        self.value_optim = optim.Adam(value_params, lr=value_lr)

    def train_policy(self, obs, acts, old_log_probs, gaes):
        """Maximize the PPO clipped surrogate objective, with KL early stopping."""
        for _ in range(self.max_policy_train_iters):
            self.policy_optim.zero_grad()

            # Log-probabilities of the taken actions under the current policy.
            new_logits = self.ac.policy(obs)
            new_logits = Categorical(logits=new_logits)
            new_log_probs = new_logits.log_prob(acts)

            # Probability ratio pi_new(a|s) / pi_old(a|s) and its clipped version.
            policy_ratio = torch.exp(new_log_probs - old_log_probs)
            clipped_ratio = policy_ratio.clamp(
                1 - self.ppo_clip_val, 1 + self.ppo_clip_val)

            # Clipped surrogate loss: take the pessimistic (minimum) objective.
            clipped_loss = clipped_ratio * gaes
            full_loss = policy_ratio * gaes
            policy_loss = -torch.min(full_loss, clipped_loss).mean()

            policy_loss.backward()
            self.policy_optim.step()

            # Stop early once the approximate KL to the old policy grows too large.
            kl_div = (old_log_probs - new_log_probs).mean()
            if kl_div >= self.target_kl_div:
                break

    def train_value(self, obs, returns):
        """Fit the value head to the discounted returns with an MSE loss."""
        for _ in range(self.value_train_iters):
            self.value_optim.zero_grad()

            values = self.ac.value(obs)
            value_loss = (returns - values) ** 2
            value_loss = value_loss.mean()

            value_loss.backward()
            self.value_optim.step()
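
PPOTrainer only assumes that its actor_critic argument exposes shared_layers, policy_layers and value_layers attributes, plus policy(obs) and value(obs) helpers. The class below is a minimal sketch of that interface, written only for illustration; its name, layer sizes and architecture are assumptions, not taken from this repository.

from torch import nn


class ActorCriticSketch(nn.Module):
    """Illustrative actor-critic module matching the interface PPOTrainer uses."""

    def __init__(self, obs_dim, n_actions, hidden_size=64):
        super().__init__()
        self.shared_layers = nn.Sequential(
            nn.Linear(obs_dim, hidden_size),
            nn.Tanh())
        self.policy_layers = nn.Linear(hidden_size, n_actions)  # action logits
        self.value_layers = nn.Linear(hidden_size, 1)            # state value

    def policy(self, obs):
        return self.policy_layers(self.shared_layers(obs))

    def value(self, obs):
        return self.value_layers(self.shared_layers(obs))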
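
train_policy consumes advantage estimates (gaes) and train_value consumes discounted returns, but neither quantity is computed in this file. Below is one common way to produce them from a single rollout using Generalized Advantage Estimation; the helper names and the gamma/decay defaults are assumptions for illustration.

import numpy as np


def discount_rewards(rewards, gamma=0.99):
    # Discounted return-to-go: G_t = r_t + gamma * G_{t+1}.
    returns = [float(rewards[-1])]
    for i in reversed(range(len(rewards) - 1)):
        returns.append(float(rewards[i]) + gamma * returns[-1])
    return np.array(returns[::-1])


def calculate_gaes(rewards, values, gamma=0.99, decay=0.97):
    # Generalized Advantage Estimation (Schulman et al., 2016):
    #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #   A_t     = delta_t + gamma * decay * A_{t+1}
    next_values = np.concatenate([values[1:], [0.0]])
    deltas = [r + gamma * nv - v
              for r, v, nv in zip(rewards, values, next_values)]
    gaes = [deltas[-1]]
    for i in reversed(range(len(deltas) - 1)):
        gaes.append(deltas[i] + decay * gamma * gaes[-1])
    return np.array(gaes[::-1])

Assuming the rollout arrays (obs, acts, old_log_probs, gaes, returns) have been converted to torch tensors, one PPO update would then look roughly like:

trainer = PPOTrainer(ActorCriticSketch(obs_dim=4, n_actions=2))
trainer.train_policy(obs, acts, old_log_probs, gaes)
trainer.train_value(obs, returns)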