rl_agent.py
from __future__ import annotations

from dataclasses import asdict
from typing import TYPE_CHECKING, Dict, Optional, Union

import numpy as np
from gym import spaces

# These classes are instantiated at runtime (not only referenced in type
# annotations), so they must be imported unconditionally rather than under
# TYPE_CHECKING.
from rosnav_rl.model.stable_baselines3 import StableBaselinesModel
from rosnav_rl.reward.reward_function import RewardFunction
from rosnav_rl.spaces.space_manager.base_space_manager import BaseSpaceManager

if TYPE_CHECKING:
    from rosnav_rl.cfg import AgentCfg
    from rosnav_rl.states import AgentStateContainer
    from rosnav_rl.utils.type_aliases import ObservationDict

from .model import RL_Model


class RL_Agent:
    """
    RL_Agent is a reinforcement learning agent that integrates a model, a reward
    function, a space manager, and an agent state container to interact with an
    environment.

    Attributes:
        _name (str): The name of the agent.
        _model (Union[RL_Model, StableBaselinesModel]): The reinforcement learning model used by the agent.
        _reward_function (Optional[RewardFunction]): The function used to calculate rewards.
        _space_manager (BaseSpaceManager): Manages the action and observation spaces.
        _agent_state_container (AgentStateContainer): Container for the agent state
            (action and observation space).

    Methods:
        __init__(agent_cfg: AgentCfg, agent_state_container: AgentStateContainer):
            Initializes the RL_Agent with the given configuration and agent state container.
        initialize_model(*args, **kwargs):
            Sets up the underlying model.
        load_model(*args, **kwargs):
            Loads the model if it has not been initialized yet.
        train(*args, **kwargs):
            Trains the model.
        get_action(observation: ObservationDict) -> np.ndarray:
            Returns the action for a given observation by encoding the observation,
            getting the action from the model, and decoding the action.

    Properties:
        config (Dict[str, dict]): Configuration of the agent, including model,
            space, reward, and state containers.
        observation_space (spaces.Dict): The observation space managed by the space manager.
        action_space (Union[spaces.Discrete, spaces.Box]): The action space managed
            by the space manager.
        agent_state_container (AgentStateContainer): The agent state container managed
            by the space manager.
    """

    _name: str
    _model: Union[RL_Model, StableBaselinesModel]
    _reward_function: Optional[RewardFunction] = None
    _space_manager: BaseSpaceManager
    _agent_state_container: AgentStateContainer

    def __init__(
        self,
        agent_cfg: AgentCfg,
        agent_state_container: AgentStateContainer,
    ):
        """
        Initialize the reinforcement learning agent.

        Args:
            agent_cfg (AgentCfg): Configuration for the agent.
            agent_state_container (AgentStateContainer): Container for the agent state.

        Attributes:
            _name (str): Name of the agent.
            _agent_state_container (AgentStateContainer): Container for the agent state.
            _model (StableBaselinesModel): The framework-specific RL model used by the agent.
            _space_manager (BaseSpaceManager): Manages the action and observation spaces.
            _reward_function (RewardFunction, optional): The reward function used by the
                agent, if specified in the configuration.
        """
        self._name = agent_cfg.name
        self._agent_state_container = agent_state_container
        self._model = StableBaselinesModel(
            rl_agent=self,
            algorithm_cfg=agent_cfg.framework.algorithm,
        )
        self._space_manager = BaseSpaceManager(
            action_space_kwargs={"is_discrete": agent_cfg.action_space.is_discrete},
            agent_state_container=self._agent_state_container,
            observation_space_list=self.model.observation_space_list,
            observation_space_kwargs=self.model.observation_space_kwargs,
        )

        if agent_cfg.reward is not None:
            self._reward_function = RewardFunction(
                function_dict=agent_cfg.reward.reward_function_dict,
                unit_kwargs=agent_cfg.reward.reward_unit_kwargs,
                verbose=agent_cfg.reward.verbose,
            )

    def initialize_model(self, *args, **kwargs):
        """
        Set up the underlying model.

        Args:
            *args: Variable length argument list passed to the model's setup method.
            **kwargs: Arbitrary keyword arguments passed to the model's setup method.
        """
        self.model.setup_model(*args, **kwargs)

    def load_model(self, *args, **kwargs):
        """
        Load the model if it has not been initialized yet.

        Args:
            *args: Variable length argument list passed to the model's load method.
            **kwargs: Arbitrary keyword arguments passed to the model's load method.
        """
        if not self.model.is_model_initialized:
            self.model.load(*args, **kwargs)

    # def get_reward(self, observation: ObservationDict) -> float:
    #     """
    #     Calculate and return the reward based on the given observation.
    #
    #     Args:
    #         observation (ObservationDict): The current observation containing relevant state information.
    #
    #     Returns:
    #         float: The calculated reward based on the observation and the current simulation state.
    #     """
    #     return self._reward_function.get_reward(
    #         observation, simulation_state_container=self._simulation_state_container
    #     )

    def train(self, *args, **kwargs):
        """
        Train the model.

        Args:
            *args: Variable length argument list passed to the model's train method.
            **kwargs: Arbitrary keyword arguments passed to the model's train method.
        """
        self.model.train(*args, **kwargs)

    def get_action(self, observation: ObservationDict, *args, **kwargs) -> np.ndarray:
        """
        Get an action from the model based on the given observation.

        Args:
            observation (ObservationDict): The observation data used by the model to determine the action.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            np.ndarray: The action determined by the model.
        """
        # Pass extra positionals before the keyword argument so they cannot
        # collide with ``observation`` in the model's signature.
        return self.model.get_action(*args, observation=observation, **kwargs)

    @property
    def config(self) -> Dict[str, dict]:
        config_dict = {
            "model": self.model.config,
            "space": self._space_manager.config,
            "agent_state_container": asdict(self.agent_state_container),
            # "simulation_state_container": asdict(self._simulation_state_container),
        }
        if self._reward_function is not None:
            config_dict["reward"] = self._reward_function.config
        return config_dict

    @property
    def model(self) -> StableBaselinesModel:
        if self._model is None:
            raise ValueError("'RL_Model' not initialized.")
        return self._model

    @property
    def reward_function(self) -> Optional[RewardFunction]:
        return self._reward_function

    @property
    def space_manager(self) -> BaseSpaceManager:
        if self._space_manager is None:
            raise ValueError("'SpaceManager' not initialized.")
        return self._space_manager

    @property
    def observation_space(self) -> spaces.Dict:
        return self._space_manager.observation_space

    @property
    def action_space(self) -> Union[spaces.Discrete, spaces.Box]:
        return self._space_manager.action_space

    @property
    def agent_state_container(self) -> AgentStateContainer:
        return self._space_manager.agent_state_container

    @property
    def name(self) -> str:
        return self._name
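

# ---------------------------------------------------------------------------
# Hedged usage sketch (illustration only, not part of the library API): how an
# RL_Agent might be constructed, trained, and queried. The ``AgentCfg`` and
# ``AgentStateContainer`` construction below is hypothetical -- their
# constructors and fields depend on the rosnav_rl project configuration -- and
# the ``total_timesteps`` kwarg assumes a Stable-Baselines3-style train method.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from rosnav_rl.cfg import AgentCfg
    from rosnav_rl.states import AgentStateContainer

    # Hypothetical construction; real field names and required arguments
    # may differ in the actual project.
    agent_cfg = AgentCfg(name="demo_agent")
    state_container = AgentStateContainer()

    agent = RL_Agent(
        agent_cfg=agent_cfg,
        agent_state_container=state_container,
    )

    # Set up the framework-specific model, then train it.
    agent.initialize_model()
    agent.train(total_timesteps=100_000)  # assumed SB3-style kwarg

    # Inference: query an action for an ObservationDict produced by the
    # environment; observation encoding and action decoding are handled by
    # the space manager.
    observation = {}  # placeholder ObservationDict
    action = agent.get_action(observation)
    print("Sampled action:", action)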