td_learning.py
import numpy as np


class SARSA:
    """
    Implementation of the SARSA on-policy TD control.

    Parameters
    ----------
    mdp
        MDP the policy is applied on
    policy
        Policy for which the Q function is estimated
    lr : float
        Learning rate of the SARSA update rule

    Methods
    -------
    play_episodes(n)
        Iterates n episodes of the SARSA TD control. In each episode the state-action
        value function Q of the policy is updated.
    """

    def __init__(self, mdp, policy, lr=0.1):
        self.mdp = mdp
        self.policy = policy
        self.lr = lr
        self.episode_reward = None

    def play_episodes(self, n=100):
"""
Iterates n episodes of the SARSA TD control. In each episode the state-action value
function Q of the policy is updated
Parameters
----------
n : int
Number of episodes to play
"""
if self.episode_reward is None:
self.episode_reward = np.empty(n)
i0 = 0
else:
i0 = self.episode_reward.shape[0]
self.episode_reward = np.concatenate((self.episode_reward, np.empty(n)))
        for i in range(n):
            # Start each episode from a uniformly random state
            self.mdp.set_state(np.random.choice(self.mdp.states))
            state = self.mdp.state
            if state in self.mdp.goal:
                # Episode starts in a terminal state: record NaN instead of a reward
                self.episode_reward[i + i0] = np.nan
                continue
            action = self.policy.select_action()
            total_reward = 0
            while True:
                state_next, reward = self.mdp.step(action, transition=True)
                total_reward += reward
                if state_next in self.mdp.goal:
                    # Terminal transition: bootstrap with Q(terminal, .) = 0
                    self.policy.Q[state][action] += self.lr * (
                        reward - self.policy.Q[state][action])
                    break
                action_next = self.policy.select_action()
                # SARSA update: Q(s, a) += lr * (r + gamma * Q(s', a') - Q(s, a))
                self.policy.Q[state][action] += self.lr * (
                    reward + self.policy.gamma * self.policy.Q[state_next][action_next] -
                    self.policy.Q[state][action])
                state = state_next
                action = action_next
            self.episode_reward[i + i0] = total_reward
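
# A minimal usage sketch for SARSA. The MDP and policy objects are assumed to come from
# elsewhere in this repo and only need the interface used above (mdp.states, mdp.goal,
# mdp.set_state, mdp.step(action, transition=True); policy.select_action(), policy.Q,
# policy.gamma). The names GridWorld and EpsilonGreedyPolicy are hypothetical placeholders;
# a self-contained toy demo is sketched at the bottom of this file.
#
#     mdp = GridWorld()
#     policy = EpsilonGreedyPolicy(mdp, epsilon=0.1, gamma=0.9)
#     sarsa = SARSA(mdp, policy, lr=0.1)
#     sarsa.play_episodes(n=500)
#     print(np.nanmean(sarsa.episode_reward))  # NaN entries (episodes started in a goal) are skipped
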

class QLearning:
    """
    Implementation of the Q-Learning off-policy TD control.

    Parameters
    ----------
    mdp
        MDP the policy is applied on
    policy
        Policy for which the Q function is estimated
    lr : float
        Learning rate of the Q-Learning update rule

    Methods
    -------
    play_episodes(n)
        Iterates n episodes of the Q-Learning TD control. In each episode the state-action
        value function Q of the policy is updated.
    """

    def __init__(self, mdp, policy, lr=0.1):
        self.mdp = mdp
        self.policy = policy
        self.lr = lr
        self.episode_reward = None

    def play_episodes(self, n=100):
"""
Iterates n episodes of the Q-Learning TD control. In each episode the state-action value
function Q of the policy is updated
Parameters
----------
n : int
Number of episodes to play
"""
if self.episode_reward is None:
self.episode_reward = np.empty(n)
i0 = 0
else:
i0 = self.episode_reward.shape[0]
self.episode_reward = np.concatenate((self.episode_reward, np.empty(n)))
        for i in range(n):
            # Start each episode from a uniformly random state
            self.mdp.set_state(np.random.choice(self.mdp.states))
            state = self.mdp.state
            if state in self.mdp.goal:
                # Episode starts in a terminal state: record NaN instead of a reward
                self.episode_reward[i + i0] = np.nan
                continue
            total_reward = 0
            while True:
                action = self.policy.select_action()
                state_next, reward = self.mdp.step(action, transition=True)
                total_reward += reward
                if state_next in self.mdp.goal:
                    # Terminal transition: bootstrap with Q(terminal, .) = 0
                    self.policy.Q[state][action] += self.lr * (
                        reward - self.policy.Q[state][action])
                    break
                # Q-Learning update: Q(s, a) += lr * (r + gamma * max_a' Q(s', a') - Q(s, a)),
                # where the max runs over the actions allowed in the *next* state
                self.policy.Q[state][action] += self.lr * (
                    reward +
                    self.policy.gamma * np.max([self.policy.Q[state_next][a]
                                                for a in state_next.allowed_actions]) -
                    self.policy.Q[state][action])
                state = state_next
            self.episode_reward[i + i0] = total_reward
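
# A self-contained toy demo (a sketch, not part of the original module): a tiny 1-D
# corridor MDP and an epsilon-greedy policy, built only from the interface the classes
# above rely on. All names below (_State, _Corridor, _EpsilonGreedy) are illustrative
# assumptions, not classes from this repo.
if __name__ == "__main__":

    class _State:
        """Corridor cell; actions -1 (left) and +1 (right) are always allowed."""
        def __init__(self, idx):
            self.idx = idx
            self.allowed_actions = [-1, +1]

    class _Corridor:
        """States 0..n-1 in a line; reaching either end terminates the episode."""
        def __init__(self, n=7):
            self.states = [_State(i) for i in range(n)]
            self.goal = {self.states[0], self.states[-1]}
            self.state = self.states[n // 2]

        def set_state(self, state):
            self.state = state

        def step(self, action, transition=True):
            next_idx = min(max(self.state.idx + action, 0), len(self.states) - 1)
            state_next = self.states[next_idx]
            # Reward of +1 only for reaching the right end of the corridor
            reward = 1.0 if next_idx == len(self.states) - 1 else 0.0
            if transition:
                self.state = state_next
            return state_next, reward

    class _EpsilonGreedy:
        """Epsilon-greedy policy over a tabular Q stored as a dict of dicts."""
        def __init__(self, mdp, epsilon=0.1, gamma=0.9):
            self.mdp = mdp
            self.epsilon = epsilon
            self.gamma = gamma
            self.Q = {s: {a: 0.0 for a in s.allowed_actions} for s in mdp.states}

        def select_action(self):
            state = self.mdp.state
            if np.random.rand() < self.epsilon:
                return np.random.choice(state.allowed_actions)
            return max(self.Q[state], key=self.Q[state].get)

    for Control in (SARSA, QLearning):
        mdp = _Corridor()
        control = Control(mdp, _EpsilonGreedy(mdp), lr=0.1)
        control.play_episodes(n=500)
        # np.nanmean skips episodes that started in a terminal state
        print(f"{Control.__name__} mean episode reward:",
              np.nanmean(control.episode_reward))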