agent_model2.py
import os
import sys

# Add the project root (resolved relative to the current working directory)
# to sys.path so that the game, model and helper packages can be imported.
snake_rl_path = os.path.abspath(os.path.join(os.path.dirname("run_model2")))
sys.path.append(snake_rl_path)
import torch  # PyTorch, used for tensors and the Q-network
import random
import numpy as np
from collections import deque  # double-ended queue used as the replay memory

from game.snake_without_growing import SnakeGame, Direction, Point
from model.model import Linear_QNet, QTrainer
from helper.plot import plot

# Constants
MAX_MEMORY = 100_000  # Maximum memory capacity for the agent's experience storage
BATCH_SIZE = 10_000   # Number of experiences used for training in each batch
LR = 0.001            # Learning rate for the neural network training

class Agent:
    def __init__(self):
        self.n_games = 0  # Counter for the number of games played
        self.epsilon = 0  # Exploration rate for making random moves
        self.gamma = 0.9  # Discount rate for considering future rewards
        self.memory = deque(maxlen=MAX_MEMORY)  # Storage for the agent's experiences
        self.model = Linear_QNet(8, 256, 3)  # Neural network model instantiation
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)  # Trainer object for model training

    def get_state(self, game):
        """Returns the current state of the game as an 8-element array:
        the snake's current direction (left, right, up, down) followed by
        the food's position relative to the snake's head.
        """
        # Current direction of the snake
        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            dir_l, dir_r, dir_u, dir_d,
            # Food's position relative to the head
            game.food.x < game.head.x,  # food is to the left
            game.food.x > game.head.x,  # food is to the right
            game.food.y < game.head.y,  # food is above
            game.food.y > game.head.y   # food is below
        ]
        return np.array(state, dtype=int)  # Convert the boolean flags to a 0/1 array
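
    # Illustrative example (not from the original file): with the snake moving
    # RIGHT and the food to the right of and below the head (pygame-style
    # coordinates, where y grows downward), get_state() would return:
    #   [dir_l, dir_r, dir_u, dir_d, food_l, food_r, food_u, food_d]
    #   [  0,     1,     0,     0,     0,      1,      0,      1  ]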

    def remember(self, state, action, reward, next_state, done):
        """Stores the agent's experience in memory.

        Args:
            state (list): current state of the game
            action (list): action taken by the agent
            reward (int): reward received by the agent
            next_state (list): next state of the game
            done (bool): whether the game is over or not
        """
        # Store the experience in memory
        self.memory.append((state, action, reward, next_state, done))

    def train_long_memory(self):
        """Trains the model based on the agent's experiences stored in memory."""
        # Train on a batch from the stored experiences
        mini_sample = random.sample(self.memory, BATCH_SIZE) if len(self.memory) > BATCH_SIZE else self.memory
        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Trains the model based on the agent's immediate experience.

        Args:
            state (list): current state of the game
            action (list): action taken by the agent
            reward (int): reward received by the agent
            next_state (list): next state of the game
            done (bool): whether the game is over or not
        """
        # Train the model immediately after the agent takes an action
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        """Returns the action to be taken by the agent, chosen with an epsilon-greedy policy.

        Args:
            state (list): current state of the game

        Returns:
            list: a length-3 list with a single 1 marking the chosen move
        """
        # Exploration rate decays linearly with the number of games played
        self.epsilon = 80 - self.n_games
        final_move = [0, 0, 0]
        if random.randint(0, 200) < self.epsilon:
            # Explore: choose a random action
            move = random.randint(0, 2)
            final_move[move] = 1
        else:
            # Exploit: choose the action with the highest predicted Q-value
            state0 = torch.tensor(state, dtype=torch.float)
            prediction = self.model(state0)
            move = torch.argmax(prediction).item()
            final_move[move] = 1
        return final_move  # Return the chosen action