diff --git a/Cargo.toml b/Cargo.toml
index 57cbf39..6c9775c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,6 +15,7 @@ rand = "0.7"
 fnv = "1.0.7"
 ndarray-rand = "0.11.0"
 ndarray-stats = "0.3.0"
+spaces = "5.0.0"
 
 [features]
 default = []
diff --git a/examples/fortress.rs b/examples/fortress.rs
index 0bd57b1..4f3c844 100644
--- a/examples/fortress.rs
+++ b/examples/fortress.rs
@@ -2,6 +2,8 @@ use agent::*;
 use env::Fortress;
 use rust_rl::network::nn::NeuralNetwork;
 use rust_rl::rl::{agent, env, training};
+use spaces::discrete::Ordinal;
+use spaces::*;
 use std::io;
 use training::{utils, Trainer};
 
@@ -62,21 +64,24 @@ pub fn main() {
     );
 }
 
-fn get_agents(agent_nums: Vec<usize>) -> Result<Vec<Box<dyn Agent>>, String> {
-    let mut res: Vec<Box<dyn Agent>> = vec![];
+fn get_agents(
+    agent_nums: Vec<usize>,
+) -> Result<Vec<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>>, String> {
+    let mut res: Vec<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>> = vec![];
     let batch_size = 16;
     for agent_num in agent_nums {
-        let new_agent: Result<Box<dyn Agent>, String> = match agent_num {
-            1 => Ok(Box::new(DQLAgent::new(
-                1.,
-                batch_size,
-                new(0.001, batch_size),
-            ))),
-            2 => Ok(Box::new(QLAgent::new(1., 6 * 6))),
-            3 => Ok(Box::new(RandomAgent::new())),
-            4 => Ok(Box::new(HumanPlayer::new())),
-            _ => Err("Only implemented agents 1-4!".to_string()),
-        };
+        let new_agent: Result<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>, String> =
+            match agent_num {
+                1 => Ok(Box::new(DQLAgent::new(
+                    1.,
+                    batch_size,
+                    new(0.001, batch_size),
+                ))),
+                2 => Ok(Box::new(QLAgent::new(1., 6 * 6))),
+                3 => Ok(Box::new(RandomAgent::new())),
+                4 => Ok(Box::new(HumanPlayer::new())),
+                _ => Err("Only implemented agents 1-4!".to_string()),
+            };
         res.push(new_agent?);
     }
     Ok(res)
diff --git a/examples/tictactoe.rs b/examples/tictactoe.rs
index d0025cf..68f5d5f 100644
--- a/examples/tictactoe.rs
+++ b/examples/tictactoe.rs
@@ -2,6 +2,8 @@ use agent::*;
 use env::TicTacToe;
 use rust_rl::network::nn::NeuralNetwork;
 use rust_rl::rl::{agent, env, training};
+use spaces::discrete::Ordinal;
+use spaces::*;
 use std::io;
 use training::{utils, Trainer};
 
@@ -56,21 +58,24 @@ pub fn main() {
     );
 }
 
-fn get_agents(agent_nums: Vec<usize>) -> Result<Vec<Box<dyn Agent>>, String> {
-    let mut res: Vec<Box<dyn Agent>> = vec![];
+fn get_agents(
+    agent_nums: Vec<usize>,
+) -> Result<Vec<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>>, String> {
+    let mut res: Vec<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>> = vec![];
     let batch_size = 16;
     for agent_num in agent_nums {
-        let new_agent: Result<Box<dyn Agent>, String> = match agent_num {
-            1 => Ok(Box::new(DQLAgent::new(
-                1.,
-                batch_size,
-                new(0.001, batch_size),
-            ))),
-            2 => Ok(Box::new(QLAgent::new(1., 3 * 3))),
-            3 => Ok(Box::new(RandomAgent::new())),
-            4 => Ok(Box::new(HumanPlayer::new())),
-            _ => Err("Only implemented agents 1-4!".to_string()),
-        };
+        let new_agent: Result<Box<dyn Agent<ProductSpace<Ordinal>, Ordinal>>, String> =
+            match agent_num {
+                1 => Ok(Box::new(DQLAgent::new(
+                    1.,
+                    batch_size,
+                    new(0.001, batch_size),
+                ))),
+                2 => Ok(Box::new(QLAgent::new(1., 3 * 3))),
+                3 => Ok(Box::new(RandomAgent::new())),
+                4 => Ok(Box::new(HumanPlayer::new())),
+                _ => Err("Only implemented agents 1-4!".to_string()),
+            };
         res.push(new_agent?);
     }
     Ok(res)
diff --git a/src/rl/agent/agent_trait.rs b/src/rl/agent/agent_trait.rs
index d91ea4c..7ba0efe 100644
--- a/src/rl/agent/agent_trait.rs
+++ b/src/rl/agent/agent_trait.rs
@@ -1,7 +1,15 @@
 use ndarray::{Array1, Array2};
+use spaces::Space;
+
+pub type State<S> = <S as Space>::Value;
+pub type Action<A> = <A as Space>::Value;
 
 /// A trait including all functions required to train them.
-pub trait Agent {
+pub trait Agent<S, A>
+where
+    S: Space,
+    A: Space,
+{
     /// Returns a simple string identifying the specific agent type.
     fn get_id(&self) -> String;
 
@@ -9,7 +17,7 @@ pub trait Agent {
     ///
     /// The concrete encoding of actions as usize value has to be looked up in the documentation of the specific environment.
     /// Advanced agents shouldn't need knowledge about the used encoding.
-    fn get_move(&mut self, env: Array2<f32>, actions: Array1<bool>, reward: f32) -> usize;
+    fn get_move(&mut self, env: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A>;
 
     /// Informs the agent that the current epoch has finished and tells him about his final result.
     ///
diff --git a/src/rl/agent/dql_agent.rs b/src/rl/agent/dql_agent.rs
index ecb86c9..80d76ac 100644
--- a/src/rl/agent/dql_agent.rs
+++ b/src/rl/agent/dql_agent.rs
@@ -2,6 +2,12 @@ use crate::network::nn::NeuralNetwork;
 use crate::rl::agent::Agent;
 use crate::rl::algorithms::DQlearning;
 use ndarray::{Array1, Array2};
+use spaces::Space;
+use spaces::discrete::NonNegativeIntegers;
+
+
+pub type State<S> = <S as Space>::Value;
+pub type Action<A> = <A as Space>::Value;
 
 /// An agent using Deep-Q-Learning, based on a small neural network.
 pub struct DQLAgent {
@@ -19,7 +25,7 @@ impl DQLAgent {
     }
 }
 
-impl Agent for DQLAgent {
+impl<S: Space, A: Space> Agent<S, A> for DQLAgent {
     fn get_id(&self) -> String {
         "dqlearning agent".to_string()
     }
@@ -29,8 +35,9 @@ impl Agent for DQLAgent {
         self.dqlearning.finish_round(result, final_state);
     }
 
-    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> usize {
-        self.dqlearning.get_move(board, actions, reward)
+    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A> {
+        let res = self.dqlearning.get_move(board, actions, reward);
+        res
     }
 
     fn get_learning_rate(&self) -> f32 {
diff --git a/src/rl/agent/human_player.rs b/src/rl/agent/human_player.rs
index 14d99b0..2571222 100644
--- a/src/rl/agent/human_player.rs
+++ b/src/rl/agent/human_player.rs
@@ -1,7 +1,11 @@
 use crate::rl::agent::agent_trait::Agent;
 use ndarray::{Array1, Array2};
+use spaces::Space;
 use std::io;
 
+pub type State<S> = <S as Space>::Value;
+pub type Action<A> = <A as Space>::Value;
+
 /// An agent which just shows the user the current environment and lets the user decide about each action.
 #[derive(Default)]
 pub struct HumanPlayer {}
@@ -13,12 +17,12 @@ impl HumanPlayer {
     }
 }
 
-impl Agent for HumanPlayer {
+impl<S: Space, A: Space> Agent<S, A> for HumanPlayer {
     fn get_id(&self) -> String {
         "human player".to_string()
     }
 
-    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, _: f32) -> usize {
+    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, _: f32) -> &Action<A> {
         let (n, m) = (board.shape()[0], board.shape()[1]);
         for i in 0..n {
             for j in 0..m {
diff --git a/src/rl/agent/ql_agent.rs b/src/rl/agent/ql_agent.rs
index 1fc41ad..5c441a1 100644
--- a/src/rl/agent/ql_agent.rs
+++ b/src/rl/agent/ql_agent.rs
@@ -2,6 +2,10 @@ use crate::rl::algorithms::Qlearning;
 use ndarray::{Array1, Array2};
 
 use crate::rl::agent::Agent;
+use spaces::Space;
+
+pub type State<S> = <S as Space>::Value;
+pub type Action<A> = <A as Space>::Value;
 
 /// An agent working on a classical q-table.
 pub struct QLAgent {
@@ -19,7 +23,7 @@ impl QLAgent {
     }
 }
 
-impl Agent for QLAgent {
+impl<S: Space, A: Space> Agent<S, A> for QLAgent {
     fn get_id(&self) -> String {
         "qlearning agent".to_string()
     }
@@ -29,8 +33,9 @@ impl Agent for QLAgent {
         self.qlearning.finish_round(result, final_state);
     }
 
-    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> usize {
-        self.qlearning.get_move(board, actions, reward)
+    fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A> {
+        let res = self.qlearning.get_move(board, actions, reward);
+        res
     }
 
     fn set_learning_rate(&mut self, lr: f32) -> Result<(), String> {
diff --git a/src/rl/agent/random_agent.rs b/src/rl/agent/random_agent.rs
index 7f73286..f29f9bd 100644
--- a/src/rl/agent/random_agent.rs
+++ b/src/rl/agent/random_agent.rs
@@ -1,6 +1,10 @@
 use crate::rl::agent::agent_trait::Agent;
 use crate::rl::algorithms::utils;
 use ndarray::{Array1, Array2};
+use spaces::Space;
+
+pub type State<S> = <S as Space>::Value;
+pub type Action<A> = <A as Space>::Value;
 
 /// An agent who acts randomly.
 ///
@@ -16,13 +20,14 @@ impl RandomAgent {
     }
 }
 
-impl Agent for RandomAgent {
+impl<S: Space, A: Space> Agent<S, A> for RandomAgent {
     fn get_id(&self) -> String {
         "random agent".to_string()
     }
 
-    fn get_move(&mut self, _: Array2<f32>, actions: Array1<bool>, _: f32) -> usize {
-        utils::get_random_true_entry(actions)
+    fn get_move(&mut self, _: Array2<f32>, actions: Array1<bool>, _: f32) -> &Action<A> {
+        let res = utils::get_random_true_entry(actions);
+        res
     }
 
     fn finish_round(&mut self, _single_res: i32, _final_state: Array2<f32>) {}
diff --git a/src/rl/env/env_trait.rs b/src/rl/env/env_trait.rs
index 34cbe2a..1cb3953 100644
--- a/src/rl/env/env_trait.rs
+++ b/src/rl/env/env_trait.rs
@@ -1,7 +1,18 @@
-use ndarray::{Array1, Array2};
+use ndarray::Array2;
+use ndarray::Array1;
+use spaces::Space;
+
+pub type State<S> = <S as Space>::Value;
+pub type Action<A> = <A as Space>::Value;
 
 /// This trait defines all functions on which agents and other user might depend.
 pub trait Environment {
+    /// State representation
+    type StateSpace: Space;
+
+    /// Action space representation
+    type ActionSpace: Space;
+
     /// The central function which causes the environment to pass various information to the agent.
     ///
     /// The Array2 encodes the environment (the board).
@@ -13,7 +24,7 @@ pub trait Environment {
     ///
     /// If the action is allowed for the currently active agent then update the environment and return true.
    /// Otherwise do nothing and return false. The same agent can then try a new move.
-    fn take_action(&mut self, action: usize) -> bool;
+    fn take_action(&mut self, action: &Action<Self::ActionSpace>) -> bool;
     /// Shows the current envrionment state in a graphical way.
     ///
     /// The representation is environment specific and might be either by terminal, or in an extra window.
diff --git a/src/rl/env/fortress.rs b/src/rl/env/fortress.rs
index ff49f87..7e1b3b1 100644
--- a/src/rl/env/fortress.rs
+++ b/src/rl/env/fortress.rs
@@ -1,5 +1,7 @@
 use crate::rl::env::env_trait::Environment;
 use ndarray::{Array, Array1, Array2};
+use spaces::discrete::NonNegativeIntegers;
+use spaces::*;
 use std::cmp::Ordering;
 
 static NEIGHBOURS_LIST: [&[usize]; 6 * 6] = [
@@ -54,6 +56,8 @@ pub struct Fortress {
 }
 
 impl Environment for Fortress {
+    type StateSpace = ProductSpace<NonNegativeIntegers>;
+    type ActionSpace = NonNegativeIntegers;
     fn step(&self) -> (Array2<f32>, Array1<bool>, f32, bool) {
         if !self.active {
             eprintln!("Warning, calling step() after done = true!");
@@ -97,7 +101,8 @@ impl Environment for Fortress {
         println!();
     }
 
-    fn take_action(&mut self, pos: usize) -> bool {
+    fn take_action(&mut self, pos: &u64) -> bool {
+        let pos = *pos as usize;
         let player_val = if self.first_player_turn { 1 } else { -1 };
 
         // check that field is not controlled by enemy, no enemy building on field, no own building on max lv (3) already exists
diff --git a/src/rl/env/tictactoe.rs b/src/rl/env/tictactoe.rs
index 1f8c67f..e64dce3 100644
--- a/src/rl/env/tictactoe.rs
+++ b/src/rl/env/tictactoe.rs
@@ -1,5 +1,7 @@
 use crate::rl::env::env_trait::Environment;
 use ndarray::{Array, Array1, Array2};
+use spaces::discrete::NonNegativeIntegers;
+use spaces::*;
 
 static BITMASKS: [&[u16]; 9] = [
     //TODO FIX: BROKEN
@@ -39,6 +41,8 @@ impl Default for TicTacToe {
 }
 
 impl Environment for TicTacToe {
+    type StateSpace = ProductSpace<NonNegativeIntegers>;
+    type ActionSpace = NonNegativeIntegers;
     fn step(&self) -> (Array2<f32>, Array1<bool>, f32, bool) {
         // storing current position into ndarray
         let position = board_as_arr(self.player1, self.player2)
@@ -78,7 +82,8 @@ impl Environment for TicTacToe {
         }
     }
 
-    fn take_action(&mut self, pos: usize) -> bool {
+    fn take_action(&mut self, pos: &u64) -> bool {
+        let pos = *pos as usize;
         if pos > 8 {
             return false;
         }
diff --git a/src/rl/training/trainer.rs b/src/rl/training/trainer.rs
index a3bae1a..561ade2 100644
--- a/src/rl/training/trainer.rs
+++ b/src/rl/training/trainer.rs
@@ -1,16 +1,24 @@
 use crate::rl::agent::Agent;
 use crate::rl::env::Environment;
 use ndarray::Array2;
+use spaces::Space;
 
 /// A trainer works on a given environment and a set of agents.
-pub struct Trainer {
-    env: Box<dyn Environment>,
+pub struct Trainer<S, A>
+where
+    S: Space,
+    A: Space,
+{
+    env: Box<dyn Environment<StateSpace = S, ActionSpace = A>>,
     res: Vec<(u32, u32, u32)>,
-    agents: Vec<Box<dyn Agent>>,
+    agents: Vec<Box<dyn Agent<S, A>>>,
 }
-impl Trainer {
+impl<S: Space, A: Space> Trainer<S, A> {
     /// We construct a Trainer by passing a single environment and one or more (possibly different) agents.
-    pub fn new(env: Box<dyn Environment>, agents: Vec<Box<dyn Agent>>) -> Result<Self, String> {
+    pub fn new(
+        env: Box<dyn Environment<StateSpace = S, ActionSpace = A>>,
+        agents: Vec<Box<dyn Agent<S, A>>>,
+    ) -> Result<Self, String> {
         if agents.is_empty() {
             return Err("At least one agent required!".to_string());
         }
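
Reviewer sketch, not part of the patch: the central signature change is that get_move now hands out a reference (&Action<A>, i.e. &<A as Space>::Value) instead of a plain usize. The snippet below is a minimal, self-contained illustration of that pattern under the new bounds, assuming the generic parameters reconstructed above are accurate. AgentSketch and FixedAgent are made-up stand-ins for this illustration only; they are not types from this crate, and only the one changed method is mirrored.

use ndarray::{Array1, Array2};
use spaces::discrete::Ordinal;
use spaces::{ProductSpace, Space};

// Same alias shape as introduced in the patch.
pub type Action<A> = <A as Space>::Value;

// Trimmed-down stand-in for the reworked Agent trait: only the method whose
// signature changed is shown.
trait AgentSketch<S: Space, A: Space> {
    fn get_move(&mut self, env: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A>;
}

// Hypothetical agent that always proposes the same action. It owns the value
// so that get_move can return a reference borrowed from &mut self.
struct FixedAgent {
    action: Action<Ordinal>,
}

// Concrete spaces, mirroring the Agent<ProductSpace<Ordinal>, Ordinal> trait
// objects built in the examples above.
impl AgentSketch<ProductSpace<Ordinal>, Ordinal> for FixedAgent {
    fn get_move(&mut self, _env: Array2<f32>, _actions: Array1<bool>, _reward: f32) -> &Action<Ordinal> {
        &self.action
    }
}

fn main() {
    let mut agent = FixedAgent { action: 4 };
    let board = Array2::<f32>::zeros((3, 3));
    let legal = Array1::from_elem(9, true);
    // The caller copies the value out; the next call re-borrows the agent.
    let chosen = *agent.get_move(board, legal, 0.0);
    assert_eq!(chosen, 4);
}

The indirection also shows why agents under this design need to own the value they return: a reference cannot be handed out to a temporary computed inside get_move, which is relevant to the `let res = ...; res` bodies in the patched agent implementations.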