Commit d691945

doesn't train properly
1 parent d799a1a commit d691945

File tree: 5 files changed (+125, -41 lines): agent.py, network.py, normalized_env.py, train.py, utils.py


agent.py

+34, -8

@@ -31,28 +31,37 @@ def __init__(self, s_dim, a_dim, n_agents, **kwargs):
         hard_update(self.policy, self.policy_target)
         hard_update(self.critic, self.critic_target)
 
-        self.random_process = OrnsteinUhlenbeckProcess(size=self.a_dim, theta=self.config.ou_theta, mu=self.config.ou_mu, sigma=self.config.ou_sigma)
+        self.random_process = OrnsteinUhlenbeckProcess(size=self.a_dim,
+                                                       theta=self.config.ou_theta,
+                                                       mu=self.config.ou_mu,
+                                                       sigma=self.config.ou_sigma)
         self.replay_buffer = list()
         self.epsilon = 1.
         self.depsilon = self.epsilon / self.config.epsilon_decay
 
         self.c_loss = None
         self.a_loss = None
+        self.action_log = list()
 
     def choose_action(self, obs, noisy=True):
         obs = torch.Tensor([obs]).to(self.device)
 
         action = self.policy(obs).cpu().detach().numpy()[0]
+        self.action_log.append(action)
+
         if noisy:
             for agent_idx in range(self.n_agents):
-                action[agent_idx] += max(self.epsilon, 0.001) * self.random_process.sample()
+                pass
+                # action[agent_idx] += self.epsilon * self.random_process.sample()
             self.epsilon -= self.depsilon
+            self.epsilon = max(self.epsilon, 0.001)
         np.clip(action, -1., 1.)
 
         return action
 
     def reset(self):
         self.random_process.reset_states()
+        self.action_log.clear()
 
     def prep_train(self):
         self.policy.train()

@@ -66,7 +75,6 @@ def prep_eval(self):
         self.policy_target.eval()
         self.critic_target.eval()
 
-
     def random_action(self):
         return np.random.uniform(low=-1, high=1, size=(self.n_agents, 2))
 

@@ -85,9 +93,11 @@ def get_batches(self):
         next_state_batches = np.array([_[3] for _ in experiences])
         done_batches = np.array([_[4] for _ in experiences])
 
+
         return state_batches, action_batches, reward_batches, next_state_batches, done_batches
 
     def train(self):
+
         state_batches, action_batches, reward_batches, next_state_batches, done_batches = self.get_batches()
 
         state_batches = torch.Tensor(state_batches).to(self.device)

@@ -97,23 +107,36 @@ def train(self):
         done_batches = torch.Tensor((done_batches == False) * 1).view(-1, self.n_agents, 1).to(self.device)
 
         target_next_actions = self.policy_target.forward(next_state_batches).detach()
-        target_next_q = self.critic_target.forward(next_state_batches, target_next_actions).detach()
-
+        target_next_q = self.critic_target.forward(next_state_batches, target_next_actions)
         main_q = self.critic(state_batches, action_batches)
 
+        '''
+        How to concat each agent's Q value?
+        '''
+        # target_next_q = target_next_q
+        # main_q = main_q.mean(dim=1)
+
+
+        '''
+        Reward Norm
+        '''
+        # reward_batches = (reward_batches - reward_batches.mean(dim=0)) / reward_batches.std(dim=0) / 1024
+
         # Critic Loss
         self.critic.zero_grad()
         baselines = reward_batches + done_batches * self.config.gamma * target_next_q
-        loss_critic = torch.nn.MSELoss()(main_q, baselines.cuda())
+        loss_critic = torch.nn.MSELoss()(main_q, baselines.detach())
         loss_critic.backward()
+        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
         self.critic_optimizer.step()
 
-        # TODO Edit Actor Loss
         # Actor Loss
         self.policy.zero_grad()
         clear_action_batches = self.policy.forward(state_batches)
         loss_actor = (-self.critic.forward(state_batches, clear_action_batches)).mean()
+        loss_actor += (clear_action_batches ** 2).mean() * 1e-3
         loss_actor.backward()
+        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
         self.policy_optimizer.step()
 
         # This is for logging

@@ -124,4 +147,7 @@ def train(self):
         soft_update(self.critic, self.critic_target, self.config.tau)
 
     def get_loss(self):
-        return self.c_loss, self.a_loss
+        return self.c_loss, self.a_loss
+
+    def get_action_std(self):
+        return np.array(self.action_log).std(axis=-1).mean()
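Note: train() ends by calling soft_update on both target networks with self.config.tau, and __init__ syncs them once with hard_update. Those helpers are not part of this diff, so purely as a reference, here is a minimal sketch of what hard_update / soft_update conventionally do in DDPG-style code (argument order chosen to match the call sites above; the repo's own utils may differ):

    import torch.nn as nn

    def hard_update(source: nn.Module, target: nn.Module):
        # target <- source, parameter by parameter
        for src_p, tgt_p in zip(source.parameters(), target.parameters()):
            tgt_p.data.copy_(src_p.data)

    def soft_update(source: nn.Module, target: nn.Module, tau: float):
        # Polyak averaging: target <- tau * source + (1 - tau) * target
        for src_p, tgt_p in zip(source.parameters(), target.parameters()):
            tgt_p.data.copy_(tau * src_p.data + (1. - tau) * tgt_p.data)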

network.py

+19, -15

@@ -1,6 +1,10 @@
 import torch
 import torch.nn as nn
 
+from utils import weight_init
+
+HIDDEN_DIM = 200
+
 
 class Actor(nn.Module):
     def __init__(self, s_dim, a_dim, n_agents):

@@ -11,19 +15,19 @@ def __init__(self, s_dim, a_dim, n_agents):
         self.n_agents = n_agents
 
         # input (batch, s_dim) output (batch, 300)
-        self.prev_dense = DenseNet(s_dim, 200, 200, output_activation=None, norm_in=True)
+
+        self.prev_dense = DenseNet(s_dim, HIDDEN_DIM // 2, HIDDEN_DIM, output_activation=None, norm_in=True)
         # input (num_agents, batch, 200) output (num_agents, batch, num_agents * 2)\
-        self.comm_net = LSTMNet(200, n_agents, num_layers=1)
+        self.comm_net = LSTMNet(HIDDEN_DIM, HIDDEN_DIM, num_layers=1)
         # input (batch, 2) output (batch, a_dim)
-        self.post_dense = DenseNet(6, 32, a_dim, output_activation=nn.Tanh)
+        self.post_dense = DenseNet(HIDDEN_DIM * 2, HIDDEN_DIM, a_dim, output_activation=nn.Tanh)
 
     def forward(self, x):
-
         x = x.view(-1, self.s_dim)
         x = self.prev_dense(x)
-        x = x.reshape(-1, self.n_agents, 200)
+        x = x.reshape(-1, self.n_agents, HIDDEN_DIM)
         x = self.comm_net(x)
-        x = x.reshape(-1, 6)
+        x = x.reshape(-1, HIDDEN_DIM * 2)
         x = self.post_dense(x)
         x = x.view(-1, self.n_agents, self.a_dim)
         return x

@@ -38,19 +42,19 @@ def __init__(self, s_dim, a_dim, n_agents):
         self.n_agents = n_agents
 
         # input (batch, s_dim) output (batch, 300)
-        self.prev_dense = DenseNet((s_dim + a_dim), 200, 200, output_activation=None, norm_in=True)
+        self.prev_dense = DenseNet((s_dim + a_dim), HIDDEN_DIM // 2, HIDDEN_DIM, output_activation=None, norm_in=True)
         # input (num_agents, batch, 200) output (num_agents, batch, num_agents * 2)\
-        self.comm_net = LSTMNet(200, n_agents, num_layers=1)
+        self.comm_net = LSTMNet(HIDDEN_DIM, HIDDEN_DIM, num_layers=1)
         # input (batch, 2) output (batch, a_dim)
-        self.post_dense = DenseNet(6, 32, 1, output_activation=None)
+        self.post_dense = DenseNet(HIDDEN_DIM * 2, HIDDEN_DIM, 1, output_activation=None)
 
     def forward(self, x_n, a_n):
         x = torch.cat((x_n, a_n), dim=-1)
         x = x.view(-1, (self.s_dim + self.a_dim))
         x = self.prev_dense(x)
-        x = x.reshape(-1, self.n_agents, 200)
+        x = x.reshape(-1, self.n_agents, HIDDEN_DIM)
         x = self.comm_net(x)
-        x = x.reshape(-1, 6)
+        x = x.reshape(-1, HIDDEN_DIM * 2)
         x = self.post_dense(x)
         x = x.view(-1, self.n_agents, 1)
         return x

@@ -65,10 +69,10 @@ def __init__(self, s_dim, hidden_dim, a_dim, norm_in=False, hidden_activation=nn
         if self._norm_in:
             self.norm1 = nn.BatchNorm1d(s_dim)
 
-        self.dense1 = nn.Linear(s_dim, hidden_dim)
-        self.dense2 = nn.Linear(hidden_dim, hidden_dim)
-        self.dense3 = nn.Linear(hidden_dim, hidden_dim)
-        self.dense4 = nn.Linear(hidden_dim, a_dim)
+        self.dense1 = nn.Linear(s_dim, hidden_dim // 2)
+        self.dense2 = nn.Linear(hidden_dim // 2, hidden_dim)
+        self.dense3 = nn.Linear(hidden_dim, hidden_dim // 2)
+        self.dense4 = nn.Linear(hidden_dim // 2, a_dim)
 
         if hidden_activation:
             self.hidden_activation = hidden_activation()
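Note: with HIDDEN_DIM = 200, the reworked DenseNet becomes a bottleneck MLP: s_dim -> hidden_dim // 2 -> hidden_dim -> hidden_dim // 2 -> a_dim. Since Actor.prev_dense passes hidden_dim = HIDDEN_DIM // 2 and output size HIDDEN_DIM, the widths for s_dim = 18 come out as 18 -> 50 -> 100 -> 50 -> 200. A standalone shape check of those widths, written with plain nn.Linear and an assumed ReLU in place of the repo's DenseNet/hidden_activation:

    import torch
    import torch.nn as nn

    HIDDEN_DIM = 200
    s_dim, hidden_dim, out_dim = 18, HIDDEN_DIM // 2, HIDDEN_DIM   # as in Actor.prev_dense

    prev_dense = nn.Sequential(
        nn.Linear(s_dim, hidden_dim // 2), nn.ReLU(),        # dense1: 18 -> 50
        nn.Linear(hidden_dim // 2, hidden_dim), nn.ReLU(),   # dense2: 50 -> 100
        nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(),   # dense3: 100 -> 50
        nn.Linear(hidden_dim // 2, out_dim),                 # dense4: 50 -> 200
    )

    x = torch.randn(4, s_dim)
    print(prev_dense(x).shape)   # torch.Size([4, 200])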

normalized_env.py

+36

@@ -0,0 +1,36 @@
+'''
+Implemented by ghliu
+https://github.com/ghliu/pytorch-ddpg/blob/master/normalized_env.py
+'''
+
+import gym
+import numpy as np
+
+# https://github.com/openai/gym/blob/master/gym/core.py
+class ActionNormalizedEnv(gym.ActionWrapper):
+    """ Wrap action """
+    def __init__(self, env):
+        super(ActionNormalizedEnv, self).__init__(env=env)
+        self.action_high = 1.
+        self.action_low = -1.
+
+    def action(self, action):
+        act_k = (self.action_high - self.action_low) / 2.
+        act_b = (self.action_high + self.action_low) / 2.
+        return act_k * action + act_b
+
+    def reverse_action(self, action):
+        act_k_inv = 2. / (self.action_high - self.action_low)
+        act_b = (self.action_high + self.action_low) / 2.
+        return act_k_inv * (action - act_b)
+
+class ObsNormalizedEnv(gym.ObservationWrapper):
+    """ Wrap action """
+    def __init__(self, env):
+        super(ObsNormalizedEnv, self).__init__(env=env)
+        self.action_high = 1.
+        self.action_low = -1.
+
+    def observation(self, observation):
+        obs = np.array([[observation[0][2] - observation[0][0], observation[0][3] - observation[0][1]]])
+        return obs
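Note: ActionNormalizedEnv.action() is the standard affine map from a policy output in [-1, 1] onto [action_low, action_high]; with the hard-coded bounds of -1 and 1 it reduces to the identity. A small numeric check of the formula on its own, no gym environment required:

    import numpy as np

    action_high, action_low = 1., -1.          # hard-coded in the wrapper
    act_k = (action_high - action_low) / 2.    # scale
    act_b = (action_high + action_low) / 2.    # offset

    policy_out = np.array([-1., 0., 0.5, 1.])
    print(act_k * policy_out + act_b)          # [-1.  0.  0.5  1.] -- identity for bounds of +/-1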

train.py

+31, -17

@@ -2,18 +2,25 @@
 import argparse, datetime
 from tensorboardX import SummaryWriter
 import numpy as np
+import torch
 
 from agent import BiCNet
+from normalized_env import ActionNormalizedEnv, ObsNormalizedEnv
 
 def main(args):
 
     env = make_env('simple_spread')
+    # env = make_env('simple')
+    env = ActionNormalizedEnv(env)
+    # env = ObsNormalizedEnv(env)
+
     kwargs = dict()
     kwargs['config'] = args
-
+    torch.manual_seed(args.seed)
 
     if args.tensorboard:
         writer = SummaryWriter(log_dir='runs/'+args.log_dir)
+    #model = BiCNet(18, 2, 3, **kwargs)
     model = BiCNet(18, 2, 3, **kwargs)
 
     episode = 0

@@ -22,45 +29,52 @@ def main(args):
     while episode < args.max_episodes:
 
         state = env.reset()
+
         episode += 1
         step = 0
         accum_reward = 0
+        prev_reward = np.zeros((3), dtype=np.float)
 
         while True:
 
             # action = agent.random_action()
             action = model.choose_action(state, noisy=True)
 
             next_state, reward, done, info = env.step(action)
-
             step += 1
             total_step += 1
             accum_reward += sum(reward)
             state = next_state
+            reward = np.array(reward)
 
-            if args.render:
+            if args.render and episode % 100 == 0:
                 env.render(mode='rgb_array')
+            model.memory(state, action, reward - prev_reward, next_state, done)
 
-            model.memory(state, action, reward, next_state, done)
-
+            prev_reward = reward
            if len(model.replay_buffer) >= args.batch_size and total_step % args.steps_per_update == 0:
                 model.prep_train()
                 model.train()
                 model.prep_eval()
 
             if args.episode_length < step or (True in done):
                 c_loss, a_loss = model.get_loss()
-
+                action_std = model.get_action_std()
                 print("[Episode %05d] reward %6.4f eps %.4f" % (episode, accum_reward, model.epsilon), end='')
                 if args.tensorboard:
                     writer.add_scalar(tag='agent/reward', global_step=episode, scalar_value=accum_reward.item())
                     writer.add_scalar(tag='agent/epsilon', global_step=episode, scalar_value=model.epsilon)
                     if c_loss and a_loss:
                         writer.add_scalars('agent/loss', global_step=episode, tag_scalar_dict={'actor':a_loss, 'critic':c_loss})
+                    if action_std:
+                        writer.add_scalar(tag='agent/action_std', global_step=episode, scalar_value=action_std)
                 if c_loss and a_loss:
-                    print(" a_loss %3.2f c_loss %3.2f" % (a_loss, c_loss))
-                else:
-                    print()
+                    print(" a_loss %3.2f c_loss %3.2f" % (a_loss, c_loss), end='')
+                if action_std:
+                    print(" action_std %3.2f" % (action_std), end='')
+
+
+                print()
                 env.reset()
                 model.reset()
                 break

@@ -75,22 +89,22 @@ def main(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--max_episodes', default=50000, type=int)
+    parser.add_argument('--max_episodes', default=1000000, type=int)
     parser.add_argument('--episode_length', default=25, type=int)
-    parser.add_argument('--memory_length', default=int(1e6), type=int)
+    parser.add_argument('--memory_length', default=int(1e5), type=int)
     parser.add_argument("--steps_per_update", default=100, type=int)
     parser.add_argument('--tau', default=0.01, type=float)
-    parser.add_argument('--gamma', default=0.95, type=float)
+    parser.add_argument('--gamma', default=0.99, type=float)
     parser.add_argument('--use_cuda', default=True, type=bool)
     parser.add_argument('--seed', default=777, type=int)
-    parser.add_argument('--a_lr', default=0.05, type=float)
-    parser.add_argument('--c_lr', default=0.05, type=float)
-    parser.add_argument('--batch_size', default=1024, type=int)
-    parser.add_argument('--render', default=False, type=bool)
+    parser.add_argument('--a_lr', default=0.001, type=float)
+    parser.add_argument('--c_lr', default=0.001, type=float)
+    parser.add_argument('--batch_size', default=512, type=int)
+    parser.add_argument('--render', action="store_true")
     parser.add_argument('--ou_theta', default=0.15, type=float)
     parser.add_argument('--ou_mu', default=0.0, type=float)
     parser.add_argument('--ou_sigma', default=0.2, type=float)
-    parser.add_argument('--epsilon_decay', default=600000, type=int)
+    parser.add_argument('--epsilon_decay', default=1000000, type=int)
     parser.add_argument('--reward_coef', default=1, type=float)
     parser.add_argument('--tensorboard', action="store_true")
     parser.add_argument("--save_interval", default=1000, type=int)
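Note: the --render flag switches from type=bool to action="store_true". argparse applies type=bool to the raw command-line string, and bool('False') is truthy, so the old flag could not actually be turned off; store_true is the idiomatic fix. A quick demonstration:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument('--old_render', default=False, type=bool)   # the old pattern
    p.add_argument('--render', action='store_true')            # the new pattern

    print(p.parse_args(['--old_render', 'False']).old_render)  # True -- bool('False') is truthy
    print(p.parse_args([]).render)                             # False
    print(p.parse_args(['--render']).render)                   # True

The same caveat applies to the remaining type=bool flag, --use_cuda.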

utils.py

+5, -1

@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-
+import torch.nn as nn
 
 def to_torch(np_array):
     return torch.from_numpy(np_array)

@@ -21,3 +21,7 @@ def fanin_init(size, fanin=None):
     v = 1. / np.sqrt(fanin)
     return torch.Tensor(size).uniform_(-v, v)
 
+def weight_init(m):
+    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
+        m.weight.data.fill_(0.)
+        m.bias.data.fill_(0.)
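Note: weight_init zeroes the weights and biases of any Linear or Conv2d module it is handed. network.py now imports it, but no call site appears in this diff; the usual wiring would be nn.Module.apply, roughly as in this assumed usage sketch:

    import torch.nn as nn
    from utils import weight_init

    net = nn.Sequential(nn.Linear(18, 100), nn.ReLU(), nn.Linear(100, 2))
    net.apply(weight_init)                     # visits every sub-module recursively
    print(net[0].weight.abs().sum().item())    # 0.0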
