starting transfer to spaces #10


Draft · wants to merge 4 commits into base: master
1 change: 1 addition & 0 deletions Cargo.toml
@@ -15,6 +15,7 @@ rand = "0.7"
fnv = "1.0.7"
ndarray-rand = "0.11.0"
ndarray-stats = "0.3.0"
spaces = "5.0.0"

[features]
default = []
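The new dependency is the `spaces` crate; everything else in this diff hangs off its `Space` trait and the associated `Value` type. A minimal sketch of that pattern (the `Action` alias mirrors the one introduced later in this PR; the `log_action` helper is purely illustrative):

```rust
use spaces::Space;

// Name a state/action type through the space that produces it,
// instead of hard-coding `usize`.
type Action<S> = <S as Space>::Value;

fn log_action<S: Space>(action: &Action<S>)
where
    S::Value: std::fmt::Debug,
{
    println!("chose {:?}", action);
}
```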
31 changes: 18 additions & 13 deletions examples/fortress.rs
@@ -2,6 +2,8 @@ use agent::*;
use env::Fortress;
use rust_rl::network::nn::NeuralNetwork;
use rust_rl::rl::{agent, env, training};
use spaces::discrete::Ordinal;
use spaces::*;
use std::io;
use training::{utils, Trainer};

@@ -62,21 +64,24 @@ pub fn main() {
);
}

fn get_agents(agent_nums: Vec<usize>) -> Result<Vec<Box<dyn Agent>>, String> {
let mut res: Vec<Box<dyn Agent>> = vec![];
fn get_agents(
agent_nums: Vec<usize>,
) -> Result<Vec<Box<dyn Agent<ProductSpace<Interval>, Ordinal>>>, String> {
let mut res: Vec<Box<dyn Agent<ProductSpace<Interval>, Ordinal>>> = vec![];
let batch_size = 16;
for agent_num in agent_nums {
let new_agent: Result<Box<dyn Agent>, String> = match agent_num {
1 => Ok(Box::new(DQLAgent::new(
1.,
batch_size,
new(0.001, batch_size),
))),
2 => Ok(Box::new(QLAgent::new(1., 6 * 6))),
3 => Ok(Box::new(RandomAgent::new())),
4 => Ok(Box::new(HumanPlayer::new())),
_ => Err("Only implemented agents 1-4!".to_string()),
};
let new_agent: Result<Box<dyn Agent<ProductSpace<Interval>, Ordinal>>, String> =
match agent_num {
1 => Ok(Box::new(DQLAgent::new(
1.,
batch_size,
new(0.001, batch_size),
))),
2 => Ok(Box::new(QLAgent::new(1., 6 * 6))),
3 => Ok(Box::new(RandomAgent::new())),
4 => Ok(Box::new(HumanPlayer::new())),
_ => Err("Only implemented agents 1-4!".to_string()),
};
res.push(new_agent?);
}
Ok(res)
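Both examples now repeat `Box<dyn Agent<ProductSpace<Interval>, Ordinal>>` in three places; a local type alias would keep the signatures short without changing behaviour. A possible tidy-up (not part of this diff), sketched with the same imports the example already uses:

```rust
use rust_rl::rl::agent::Agent;
use spaces::discrete::Ordinal;
use spaces::*;

// One alias per example keeps the trait-object type in a single place.
type BoxedAgent = Box<dyn Agent<ProductSpace<Interval>, Ordinal>>;

fn get_agents(agent_nums: Vec<usize>) -> Result<Vec<BoxedAgent>, String> {
    let mut res: Vec<BoxedAgent> = vec![];
    for agent_num in agent_nums {
        // agent construction stays exactly as in the diff above
        let _ = agent_num;
    }
    Ok(res)
}
```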
31 changes: 18 additions & 13 deletions examples/tictactoe.rs
@@ -2,6 +2,8 @@ use agent::*;
use env::TicTacToe;
use rust_rl::network::nn::NeuralNetwork;
use rust_rl::rl::{agent, env, training};
use spaces::discrete::Ordinal;
use spaces::*;
use std::io;
use training::{utils, Trainer};

@@ -56,21 +58,24 @@ pub fn main() {
);
}

fn get_agents(agent_nums: Vec<usize>) -> Result<Vec<Box<dyn Agent>>, String> {
let mut res: Vec<Box<dyn Agent>> = vec![];
fn get_agents(
agent_nums: Vec<usize>,
) -> Result<Vec<Box<dyn Agent<ProductSpace<Interval>, Ordinal>>>, String> {
let mut res: Vec<Box<dyn Agent<ProductSpace<Interval>, Ordinal>>> = vec![];
let batch_size = 16;
for agent_num in agent_nums {
let new_agent: Result<Box<dyn Agent>, String> = match agent_num {
1 => Ok(Box::new(DQLAgent::new(
1.,
batch_size,
new(0.001, batch_size),
))),
2 => Ok(Box::new(QLAgent::new(1., 3 * 3))),
3 => Ok(Box::new(RandomAgent::new())),
4 => Ok(Box::new(HumanPlayer::new())),
_ => Err("Only implemented agents 1-4!".to_string()),
};
let new_agent: Result<Box<dyn Agent<ProductSpace<Interval>, Ordinal>>, String> =
match agent_num {
1 => Ok(Box::new(DQLAgent::new(
1.,
batch_size,
new(0.001, batch_size),
))),
2 => Ok(Box::new(QLAgent::new(1., 3 * 3))),
3 => Ok(Box::new(RandomAgent::new())),
4 => Ok(Box::new(HumanPlayer::new())),
_ => Err("Only implemented agents 1-4!".to_string()),
};
res.push(new_agent?);
}
Ok(res)
12 changes: 10 additions & 2 deletions src/rl/agent/agent_trait.rs
@@ -1,15 +1,23 @@
use ndarray::{Array1, Array2};
use spaces::Space;

pub type State<S> = <S as Space>::Value;
pub type Action<S> = <S as Space>::Value;

/// A trait including all functions required to train an agent.
pub trait Agent {
pub trait Agent<S, A>
where
S: Space,
A: Space,
{
/// Returns a simple string identifying the specific agent type.
fn get_id(&self) -> String;

/// Expect the agent to return a single usize value corresponding to a (legal) action he picked.
///
/// The concrete encoding of actions as a usize value has to be looked up in the documentation of the specific environment.
/// Advanced agents shouldn't need knowledge about the used encoding.
fn get_move(&mut self, env: Array2<f32>, actions: Array1<bool>, reward: f32) -> usize;
fn get_move(&mut self, env: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A>;

/// Informs the agent that the current epoch has finished and tells him about his final result.
///
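Returning `&Action<A>` (rather than a `usize` by value) means the chosen action has to be owned by something that outlives the call, in practice the agent itself. A minimal self-contained sketch of that ownership pattern; the `LastAction` type is hypothetical and not part of this crate:

```rust
use spaces::Space;

// To hand out a reference tied to `&mut self`, the value must be stored
// in the agent between calls.
struct LastAction<A: Space> {
    chosen: A::Value,
}

impl<A: Space> LastAction<A> {
    fn pick(&mut self, chosen: A::Value) -> &A::Value {
        self.chosen = chosen;
        &self.chosen
    }
}
```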
13 changes: 10 additions & 3 deletions src/rl/agent/dql_agent.rs
@@ -2,6 +2,12 @@ use crate::network::nn::NeuralNetwork;
use crate::rl::agent::Agent;
use crate::rl::algorithms::DQlearning;
use ndarray::{Array1, Array2};
use spaces::Space;
use spaces::discrete::NonNegativeIntegers;


pub type State<S> = <S as Space>::Value;
pub type Action<S> = <S as Space>::Value;

/// An agent using Deep-Q-Learning, based on a small neural network.
pub struct DQLAgent {
@@ -19,7 +25,7 @@ impl DQLAgent {
}
}

impl Agent for DQLAgent {
impl<S: Space, A: Space> Agent<S, A> for DQLAgent {
fn get_id(&self) -> String {
"dqlearning agent".to_string()
}
@@ -29,8 +35,9 @@ impl Agent for DQLAgent {
self.dqlearning.finish_round(result, final_state);
}

fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> usize {
self.dqlearning.get_move(board, actions, reward)
fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A> {
let res = self.dqlearning.get_move(board, actions, reward);
res
}

fn get_learning_rate(&self) -> f32 {
8 changes: 6 additions & 2 deletions src/rl/agent/human_player.rs
@@ -1,7 +1,11 @@
use crate::rl::agent::agent_trait::Agent;
use ndarray::{Array1, Array2};
use spaces::Space;
use std::io;

pub type State<S> = <S as Space>::Value;
pub type Action<S> = <S as Space>::Value;

/// An agent which just shows the user the current environment and lets the user decide about each action.
#[derive(Default)]
pub struct HumanPlayer {}
@@ -13,12 +17,12 @@ impl HumanPlayer {
}
}

impl Agent for HumanPlayer {
impl<S: Space, A: Space> Agent<S, A> for HumanPlayer {
fn get_id(&self) -> String {
"human player".to_string()
}

fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, _: f32) -> usize {
fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, _: f32) -> &Action<A> {
let (n, m) = (board.shape()[0], board.shape()[1]);
for i in 0..n {
for j in 0..m {
11 changes: 8 additions & 3 deletions src/rl/agent/ql_agent.rs
@@ -2,6 +2,10 @@ use crate::rl::algorithms::Qlearning;
use ndarray::{Array1, Array2};

use crate::rl::agent::Agent;
use spaces::Space;

pub type State<S> = <S as Space>::Value;
pub type Action<S> = <S as Space>::Value;

/// An agent working on a classical q-table.
pub struct QLAgent {
@@ -19,7 +23,7 @@ impl QLAgent {
}
}

impl Agent for QLAgent {
impl<S: Space, A: Space> Agent<S, A> for QLAgent {
fn get_id(&self) -> String {
"qlearning agent".to_string()
}
@@ -29,8 +33,9 @@ impl Agent for QLAgent {
self.qlearning.finish_round(result, final_state);
}

fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> usize {
self.qlearning.get_move(board, actions, reward)
fn get_move(&mut self, board: Array2<f32>, actions: Array1<bool>, reward: f32) -> &Action<A> {
let res = self.qlearning.get_move(board, actions, reward);
res
}

fn set_learning_rate(&mut self, lr: f32) -> Result<(), String> {
11 changes: 8 additions & 3 deletions src/rl/agent/random_agent.rs
@@ -1,6 +1,10 @@
use crate::rl::agent::agent_trait::Agent;
use crate::rl::algorithms::utils;
use ndarray::{Array1, Array2};
use spaces::Space;

pub type State<S> = <S as Space>::Value;
pub type Action<S> = <S as Space>::Value;

/// An agent who acts randomly.
///
@@ -16,13 +20,14 @@ impl RandomAgent {
}
}

impl Agent for RandomAgent {
impl<S: Space, A: Space> Agent<S, A> for RandomAgent {
fn get_id(&self) -> String {
"random agent".to_string()
}

fn get_move(&mut self, _: Array2<f32>, actions: Array1<bool>, _: f32) -> usize {
utils::get_random_true_entry(actions)
fn get_move(&mut self, _: Array2<f32>, actions: Array1<bool>, _: f32) -> &Action<A> {
let res = utils::get_random_true_entry(actions);
res
}

fn finish_round(&mut self, _single_res: i32, _final_state: Array2<f32>) {}
15 changes: 13 additions & 2 deletions src/rl/env/env_trait.rs
@@ -1,7 +1,18 @@
use ndarray::{Array1, Array2};
use ndarray::Array2;
use ndarray::Array1;
use spaces::Space;

pub type State<S> = <S as Space>::Value;
pub type Action<S> = <S as Space>::Value;

/// This trait defines all functions on which agents and other users might depend.
pub trait Environment {
/// State representation
type StateSpace: Space;

/// Action space representation
type ActionSpace: Space;

/// The central function which causes the environment to pass various information to the agent.
///
/// The Array2 encodes the environment (the board).
@@ -13,7 +24,7 @@ pub trait Environment {
///
/// If the action is allowed for the currently active agent then update the environment and return true.
/// Otherwise do nothing and return false. The same agent can then try a new move.
fn take_action(&mut self, action: usize) -> bool;
fn take_action(&mut self, action: &Action<Self::ActionSpace>) -> bool;
/// Shows the current environment state in a graphical way.
///
/// The representation is environment specific and might be either by terminal, or in an extra window.
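With `StateSpace`/`ActionSpace` as associated types, downstream code can stay generic over the environment and still name the concrete action type. A short sketch; the `rust_rl::rl::env::Environment` import path is assumed from the crate layout used in the examples:

```rust
use rust_rl::rl::env::Environment;
use spaces::Space;

// Apply an action to any environment, spelling the action type via the
// environment's own action space instead of assuming `usize`.
fn try_action<E: Environment>(
    env: &mut E,
    action: &<E::ActionSpace as Space>::Value,
) -> bool {
    env.take_action(action)
}
```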
7 changes: 6 additions & 1 deletion src/rl/env/fortress.rs
@@ -1,5 +1,7 @@
use crate::rl::env::env_trait::Environment;
use ndarray::{Array, Array1, Array2};
use spaces::discrete::NonNegativeIntegers;
use spaces::*;
use std::cmp::Ordering;

static NEIGHBOURS_LIST: [&[usize]; 6 * 6] = [
@@ -54,6 +56,8 @@ pub struct Fortress {
}

impl Environment for Fortress {
type StateSpace = ProductSpace<Interval>;
type ActionSpace = NonNegativeIntegers;
fn step(&self) -> (Array2<f32>, Array1<bool>, f32, bool) {
if !self.active {
eprintln!("Warning, calling step() after done = true!");
@@ -97,7 +101,8 @@ impl Environment for Fortress {
println!();
}

fn take_action(&mut self, pos: usize) -> bool {
fn take_action(&mut self, pos: &u64) -> bool {
let pos = *pos as usize;
let player_val = if self.first_player_turn { 1 } else { -1 };

// check that field is not controlled by enemy, no enemy building on field, no own building on max lv (3) already exists
7 changes: 6 additions & 1 deletion src/rl/env/tictactoe.rs
@@ -1,5 +1,7 @@
use crate::rl::env::env_trait::Environment;
use ndarray::{Array, Array1, Array2};
use spaces::discrete::NonNegativeIntegers;
use spaces::*;

static BITMASKS: [&[u16]; 9] = [
//TODO FIX: BROKEN
@@ -39,6 +41,8 @@ impl Default for TicTacToe {
}

impl Environment for TicTacToe {
type StateSpace = ProductSpace<Interval>;
type ActionSpace = NonNegativeIntegers;
fn step(&self) -> (Array2<f32>, Array1<bool>, f32, bool) {
// storing current position into ndarray
let position = board_as_arr(self.player1, self.player2)
@@ -78,7 +82,8 @@ impl Environment for TicTacToe {
}
}

fn take_action(&mut self, pos: usize) -> bool {
fn take_action(&mut self, pos: &u64) -> bool {
let pos = *pos as usize;
if pos > 8 {
return false;
}
18 changes: 13 additions & 5 deletions src/rl/training/trainer.rs
@@ -1,16 +1,24 @@
use crate::rl::agent::Agent;
use crate::rl::env::Environment;
use ndarray::Array2;
use spaces::Space;
/// A trainer works on a given environment and a set of agents.
pub struct Trainer {
env: Box<dyn Environment>,
pub struct Trainer<S, A>
where
S: Space,
A: Space,
{
env: Box<dyn Environment<StateSpace = S, ActionSpace = A>>,
res: Vec<(u32, u32, u32)>,
agents: Vec<Box<dyn Agent>>,
agents: Vec<Box<dyn Agent<S, A>>>,
}

impl Trainer {
impl<S: Space, A: Space> Trainer<S, A> {
/// We construct a Trainer by passing a single environment and one or more (possibly different) agents.
pub fn new(env: Box<dyn Environment>, agents: Vec<Box<dyn Agent>>) -> Result<Self, String> {
pub fn new(
env: Box<dyn Environment<StateSpace = S, ActionSpace = A>>,
agents: Vec<Box<dyn Agent<S, A>>>,
) -> Result<Self, String> {
if agents.is_empty() {
return Err("At least one agent required!".to_string());
}
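A small usage sketch of the now-generic `Trainer`: the generic parameters are fixed entirely by the environment's associated spaces, so a helper like the one below would let callers avoid spelling them out. The import paths are assumed from the crate layout; only the `Trainer::new` signature comes from this diff:

```rust
use rust_rl::rl::agent::Agent;
use rust_rl::rl::env::Environment;
use rust_rl::rl::training::Trainer;
use spaces::Space;

// The state/action spaces S and A are taken from the boxed environment,
// and the agents must be generic over the same two spaces.
fn trainer_for<S: Space, A: Space>(
    env: Box<dyn Environment<StateSpace = S, ActionSpace = A>>,
    agents: Vec<Box<dyn Agent<S, A>>>,
) -> Result<Trainer<S, A>, String> {
    Trainer::new(env, agents)
}
```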