diff --git a/pom.xml b/pom.xml
index 6826743..9c0272d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1,6 +1,6 @@
-
4.0.0
@@ -38,7 +38,8 @@
Reinforcement Learning Algorithms
- Classical RL algorithms implemented in Java, including Q-Learn, R-Learn, SARSA, Actor-Critic
+ Classical RL algorithms implemented in Java, including Q-Learn, R-Learn, SARSA, Actor-Critic
+
https://github.com/chen0040/java-reinforcement-learning
@@ -235,7 +236,6 @@
-
org.codehaus.mojo
findbugs-maven-plugin
@@ -361,8 +361,6 @@
-
-
@@ -382,14 +380,13 @@
org.apache.maven.plugins
maven-resources-plugin
+ 2.3
UTF-8
-
-
@@ -488,23 +485,12 @@
-
-
-
- org.projectlombok
- lombok
- provided
- ${lombok.version}
-
-
com.alibaba
fastjson
1.2.41
-
-
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java
index 7de7f9b..f591257 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java
@@ -1,8 +1,8 @@
package com.github.chen0040.rl.actionselection;
-import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.models.QModel;
import com.github.chen0040.rl.models.UtilityModel;
+import com.github.chen0040.rl.utils.IndexValue;
import java.util.HashMap;
import java.util.Map;
@@ -14,58 +14,50 @@
*/
public abstract class AbstractActionSelectionStrategy implements ActionSelectionStrategy {
+ Map attributes = new HashMap<>();
private String prototype;
- protected Map attributes = new HashMap();
- public String getPrototype(){
- return prototype;
+ @SuppressWarnings("Used-by-user")
+ public AbstractActionSelectionStrategy() {
+ this.prototype = this.getClass().getCanonicalName();
}
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
- return new IndexValue();
+ @SuppressWarnings("Used-by-user")
+ public AbstractActionSelectionStrategy(final HashMap attributes) {
+ this.attributes = attributes;
+ if (attributes.containsKey("prototype")) {
+ this.prototype = attributes.get("prototype");
+ }
}
- public IndexValue selectAction(int stateId, UtilityModel model, Set actionsAtState) {
- return new IndexValue();
+ @Override
+ public String getPrototype() {
+ return this.prototype;
}
- public AbstractActionSelectionStrategy(){
- prototype = this.getClass().getCanonicalName();
+ @Override
+ public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) {
+ return new IndexValue();
}
-
- public AbstractActionSelectionStrategy(HashMap attributes){
- this.attributes = attributes;
- if(attributes.containsKey("prototype")){
- this.prototype = attributes.get("prototype");
- }
+ @Override
+ public IndexValue selectAction(final int stateId, final UtilityModel model, final Set actionsAtState) {
+ return new IndexValue();
}
- public Map getAttributes(){
- return attributes;
+ @Override
+ public Map getAttributes() {
+ return this.attributes;
}
@Override
- public boolean equals(Object obj) {
- ActionSelectionStrategy rhs = (ActionSelectionStrategy)obj;
- if(!prototype.equalsIgnoreCase(rhs.getPrototype())) return false;
- for(Map.Entry entry : rhs.getAttributes().entrySet()) {
- if(!attributes.containsKey(entry.getKey())) {
- return false;
- }
- if(!attributes.get(entry.getKey()).equals(entry.getValue())){
- return false;
- }
- }
- for(Map.Entry entry : attributes.entrySet()) {
- if(!rhs.getAttributes().containsKey(entry.getKey())) {
- return false;
- }
- if(!rhs.getAttributes().get(entry.getKey()).equals(entry.getValue())){
- return false;
- }
- }
- return true;
+ public boolean equals(final Object obj) {
+ // Two strategies are equal when their prototypes match (ignoring case) and their
+ // attribute maps hold the same key/value pairs.
+ if (this == obj) {
+ return true;
+ }
+ // Guard the cast: equals must answer false, not throw, for null or foreign types.
+ if (!(obj instanceof ActionSelectionStrategy)) {
+ return false;
+ }
+ final ActionSelectionStrategy rhs = (ActionSelectionStrategy) obj;
+ final String rhsPrototype = rhs.getPrototype();
+ // prototype stays null when the attribute-map constructor receives no "prototype" key.
+ final boolean samePrototype = this.prototype == null
+ ? rhsPrototype == null
+ : this.prototype.equalsIgnoreCase(rhsPrototype);
+ // Map.equals covers both directions at once: same key set and an equal value per key.
+ return samePrototype && this.attributes.equals(rhs.getAttributes());
}
@Override
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java
index ce92be0..cee0a7e 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java
@@ -1,5 +1,6 @@
package com.github.chen0040.rl.actionselection;
+import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
@@ -7,53 +8,54 @@
/**
* Created by xschen on 9/27/2015 0027.
*/
-public class ActionSelectionStrategyFactory {
- public static ActionSelectionStrategy deserialize(String conf){
- String[] comps = conf.split(";");
-
- HashMap attributes = new HashMap();
- for(int i=0; i < comps.length; ++i){
- String comp = comps[i];
- String[] field = comp.split("=");
- if(field.length < 2) continue;
- String fieldname = field[0].trim();
- String fieldvalue = field[1].trim();
-
- attributes.put(fieldname, fieldvalue);
+public class ActionSelectionStrategyFactory implements Serializable {
+
+ public static ActionSelectionStrategy deserialize(final String conf) {
+ // Parses "key=value;key=value" text (or a bare class name) and builds the strategy
+ // named by the "prototype" attribute; returns null when no known strategy matches.
+ final String[] comps = conf.split(";");
+
+ final HashMap attributes = new HashMap<>();
+ for (final String comp : comps) {
+ final String[] field = comp.split("=");
+ if (field.length < 2) {
+ continue;
+ }
+
+ attributes.put(field[0].trim(), field[1].trim());
}
- if(attributes.isEmpty()){
+ if (attributes.isEmpty()) {
attributes.put("prototype", conf);
}
- String prototype = attributes.get("prototype");
- if(prototype.equals(GreedyActionSelectionStrategy.class.getCanonicalName())){
+ // May be null when conf holds key=value pairs but no "prototype" key; comparing
+ // constant-first lets that case reach the null return below instead of throwing.
+ final String prototype = attributes.get("prototype");
+ if (GreedyActionSelectionStrategy.class.getCanonicalName().equals(prototype)) {
return new GreedyActionSelectionStrategy();
- } else if(prototype.equals(SoftMaxActionSelectionStrategy.class.getCanonicalName())){
+ } else if (SoftMaxActionSelectionStrategy.class.getCanonicalName().equals(prototype)) {
return new SoftMaxActionSelectionStrategy();
- } else if(prototype.equals(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName())){
+ } else if (EpsilonGreedyActionSelectionStrategy.class.getCanonicalName().equals(prototype)) {
return new EpsilonGreedyActionSelectionStrategy(attributes);
- } else if(prototype.equals(GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName())){
+ } else if (GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName().equals(prototype)) {
return new GibbsSoftMaxActionSelectionStrategy();
}
return null;
}
- public static String serialize(ActionSelectionStrategy strategy){
- Map attributes = strategy.getAttributes();
+ public static String serialize(final ActionSelectionStrategy strategy) {
+ final Map attributes = strategy.getAttributes();
attributes.put("prototype", strategy.getPrototype());
- StringBuilder sb = new StringBuilder();
+ final StringBuilder sb = new StringBuilder();
boolean first = true;
- for(Map.Entry entry : attributes.entrySet()){
- if(first){
+ for (final Map.Entry entry : attributes.entrySet()) {
+ if (first) {
first = false;
- }
- else{
+ } else {
sb.append(";");
}
- sb.append(entry.getKey()+"="+entry.getValue());
+ sb.append(entry.getKey()).append("=").append(entry.getValue());
}
return sb.toString();
}
+
+
}
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java
index 5f7db9a..5ac3d36 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java
@@ -1,7 +1,7 @@
package com.github.chen0040.rl.actionselection;
-import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.models.QModel;
+import com.github.chen0040.rl.utils.IndexValue;
import java.util.*;
@@ -10,69 +10,62 @@
* Created by xschen on 9/27/2015 0027.
*/
public class EpsilonGreedyActionSelectionStrategy extends AbstractActionSelectionStrategy {
- public static final String EPSILON = "epsilon";
+ private static final String EPSILON = "epsilon";
private Random random = new Random();
- @Override
- public Object clone(){
- EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy();
- clone.copy(this);
- return clone;
- }
-
- public void copy(EpsilonGreedyActionSelectionStrategy rhs){
- random = rhs.random;
- for(Map.Entry entry : rhs.attributes.entrySet()){
- attributes.put(entry.getKey(), entry.getValue());
- }
+ /**
+ * Creates a strategy with the default exploration rate of 0.1.
+ */
+ @SuppressWarnings("Used-by-user")
+ public EpsilonGreedyActionSelectionStrategy() {
+ // Store the default directly: epsilon() only reads the attribute, and the
+ // epsilon(double) setter this constructor used to call is removed by this patch.
+ this.attributes.put(EpsilonGreedyActionSelectionStrategy.EPSILON, Double.toString(0.1));
}
- @Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof EpsilonGreedyActionSelectionStrategy){
- EpsilonGreedyActionSelectionStrategy rhs = (EpsilonGreedyActionSelectionStrategy)obj;
- if(epsilon() != rhs.epsilon()) return false;
- // if(!random.equals(rhs.random)) return false;
- return true;
- }
- return false;
+ @SuppressWarnings("Used-by-user")
+ public EpsilonGreedyActionSelectionStrategy(final HashMap attributes) {
+ super(attributes);
}
- private double epsilon(){
- return Double.parseDouble(attributes.get(EPSILON));
+ /**
+ * Creates a strategy that draws from the supplied random source and uses the
+ * default exploration rate of 0.1.
+ *
+ * @param random random number generator used when selecting actions
+ */
+ @SuppressWarnings("Used-by-user")
+ public EpsilonGreedyActionSelectionStrategy(final Random random) {
+ this.random = random;
+ // Store the default directly: epsilon() only reads the attribute, and the
+ // epsilon(double) setter this constructor used to call is removed by this patch.
+ this.attributes.put(EpsilonGreedyActionSelectionStrategy.EPSILON, Double.toString(0.1));
}
- public EpsilonGreedyActionSelectionStrategy(){
- epsilon(0.1);
+ @Override
+ public Object clone() {
+ final EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy();
+ clone.copy(this);
+ return clone;
}
- public EpsilonGreedyActionSelectionStrategy(HashMap attributes){
- super(attributes);
+ public void copy(final EpsilonGreedyActionSelectionStrategy rhs) {
+ this.random = rhs.random;
+ for (final Map.Entry entry : rhs.attributes.entrySet()) {
+ this.attributes.put(entry.getKey(), entry.getValue());
+ }
}
- private void epsilon(double value){
- attributes.put(EPSILON, "" + value);
+ @Override
+ public boolean equals(final Object obj) {
+ return obj instanceof EpsilonGreedyActionSelectionStrategy && this.epsilon() == ((EpsilonGreedyActionSelectionStrategy) obj).epsilon();
}
- public EpsilonGreedyActionSelectionStrategy(Random random){
- this.random = random;
- epsilon(0.1);
+ private double epsilon() {
+ return Double.parseDouble(this.attributes.get(EpsilonGreedyActionSelectionStrategy.EPSILON));
}
@Override
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
- if(random.nextDouble() < 1- epsilon()){
+ public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) {
+ if (this.random.nextDouble() < 1 - this.epsilon()) {
return model.actionWithMaxQAtState(stateId, actionsAtState);
- }else{
- int actionId;
- if(actionsAtState != null && !actionsAtState.isEmpty()) {
- List actions = new ArrayList<>(actionsAtState);
- actionId = actions.get(random.nextInt(actions.size()));
+ } else {
+ final int actionId;
+ if (actionsAtState != null && !actionsAtState.isEmpty()) {
+ final List actions = new ArrayList<>(actionsAtState);
+ actionId = actions.get(this.random.nextInt(actions.size()));
} else {
- actionId = random.nextInt(model.getActionCount());
+ actionId = this.random.nextInt(model.getActionCount());
}
- double Q = model.getQ(stateId, actionId);
+ final double Q = model.getQ(stateId, actionId);
return new IndexValue(actionId, Q);
}
}
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java
index 8b2d8d2..283748c 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java
@@ -1,12 +1,13 @@
package com.github.chen0040.rl.actionselection;
-import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.models.QModel;
+import com.github.chen0040.rl.utils.IndexValue;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Set;
+import java.util.stream.IntStream;
/**
@@ -14,52 +15,49 @@
*/
public class GibbsSoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy {
- private Random random = null;
- public GibbsSoftMaxActionSelectionStrategy(){
- random = new Random();
+ private final Random random;
+
+ @SuppressWarnings("Used-by-user")
+ public GibbsSoftMaxActionSelectionStrategy() {
+ this.random = new Random();
}
- public GibbsSoftMaxActionSelectionStrategy(Random random){
+ @SuppressWarnings("Used-by-user")
+ public GibbsSoftMaxActionSelectionStrategy(final Random random) {
this.random = random;
}
@Override
public Object clone() {
- GibbsSoftMaxActionSelectionStrategy clone = new GibbsSoftMaxActionSelectionStrategy();
- return clone;
+ return new GibbsSoftMaxActionSelectionStrategy();
}
@Override
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
- List actions = new ArrayList();
- if(actionsAtState == null){
- for(int i=0; i < model.getActionCount(); ++i){
- actions.add(i);
- }
- }else{
- for(Integer actionId : actionsAtState){
- actions.add(actionId);
- }
+ public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) {
+ final List actions = new ArrayList();
+ if (actionsAtState == null) {
+ IntStream.range(0, model.getActionCount()).forEach(actions::add);
+ } else {
+ actions.addAll(actionsAtState);
}
double sum = 0;
- List plist = new ArrayList();
- for(int i=0; i < actions.size(); ++i){
- int actionId = actions.get(i);
- double p = Math.exp(model.getQ(stateId, actionId));
+ final List plist = new ArrayList();
+ for (final int actionId : actions) {
+ final double p = Math.exp(model.getQ(stateId, actionId));
sum += p;
plist.add(sum);
}
- IndexValue iv = new IndexValue();
+ final IndexValue iv = new IndexValue();
iv.setIndex(-1);
iv.setValue(Double.NEGATIVE_INFINITY);
- double r = sum * random.nextDouble();
- for(int i=0; i < actions.size(); ++i){
+ final double r = sum * this.random.nextDouble();
+ for (int i = 0; i < actions.size(); ++i) {
- if(plist.get(i) >= r){
- int actionId = actions.get(i);
+ if (plist.get(i) >= r) {
+ final int actionId = actions.get(i);
iv.setValue(model.getQ(stateId, actionId));
iv.setIndex(actionId);
break;
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java
index 6b0f350..e735fd3 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java
@@ -1,7 +1,7 @@
package com.github.chen0040.rl.actionselection;
-import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.models.QModel;
+import com.github.chen0040.rl.utils.IndexValue;
import java.util.Set;
@@ -11,18 +11,17 @@
*/
public class GreedyActionSelectionStrategy extends AbstractActionSelectionStrategy {
@Override
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
+ public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) {
return model.actionWithMaxQAtState(stateId, actionsAtState);
}
@Override
- public Object clone(){
- GreedyActionSelectionStrategy clone = new GreedyActionSelectionStrategy();
- return clone;
+ public Object clone() {
+ return new GreedyActionSelectionStrategy();
}
@Override
- public boolean equals(Object obj){
- return obj != null && obj instanceof GreedyActionSelectionStrategy;
+ public boolean equals(final Object obj) {
+ return obj instanceof GreedyActionSelectionStrategy;
}
}
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java
index f9735b9..c0b89f5 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java
@@ -1,7 +1,7 @@
package com.github.chen0040.rl.actionselection;
-import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.models.QModel;
+import com.github.chen0040.rl.utils.IndexValue;
import java.util.Random;
import java.util.Set;
@@ -13,27 +13,26 @@
public class SoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy {
private Random random = new Random();
- @Override
- public Object clone(){
- SoftMaxActionSelectionStrategy clone = new SoftMaxActionSelectionStrategy(random);
- return clone;
- }
+ public SoftMaxActionSelectionStrategy() {
- @Override
- public boolean equals(Object obj){
- return obj != null && obj instanceof SoftMaxActionSelectionStrategy;
}
- public SoftMaxActionSelectionStrategy(){
+ public SoftMaxActionSelectionStrategy(final Random random) {
+ this.random = random;
+ }
+ @Override
+ public Object clone() {
+ return new SoftMaxActionSelectionStrategy(this.random);
}
- public SoftMaxActionSelectionStrategy(Random random){
- this.random = random;
+ @Override
+ public boolean equals(final Object obj) {
+ return obj instanceof SoftMaxActionSelectionStrategy;
}
@Override
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
- return model.actionWithSoftMaxQAtState(stateId, actionsAtState, random);
+ public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) {
+ return model.actionWithSoftMaxQAtState(stateId, actionsAtState, this.random);
}
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java
index 6e34874..864668f 100644
--- a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java
+++ b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java
@@ -3,9 +3,7 @@
import com.github.chen0040.rl.utils.Vec;
import java.io.Serializable;
-import java.util.Random;
import java.util.Set;
-import java.util.function.Function;
/**
@@ -17,85 +15,85 @@ public class ActorCriticAgent implements Serializable {
private int prevState;
private int prevAction;
- public void enableEligibilityTrace(double lambda){
- ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(learner);
- acll.setLambda(lambda);
- learner = acll;
+ @SuppressWarnings("Used-by-user")
+ public ActorCriticAgent(final int stateCount, final int actionCount) {
+ this.learner = new ActorCriticLearner(stateCount, actionCount);
}
- public void start(int stateId){
- currentState = stateId;
- prevAction = -1;
- prevState = -1;
- }
+ public ActorCriticAgent() {
- public ActorCriticLearner getLearner(){
- return learner;
}
- public void setLearner(ActorCriticLearner learner){
+ public ActorCriticAgent(final ActorCriticLearner learner) {
this.learner = learner;
}
- public ActorCriticAgent(int stateCount, int actionCount){
- learner = new ActorCriticLearner(stateCount, actionCount);
+ @SuppressWarnings("Used-by-user")
+ public void enableEligibilityTrace(final double lambda) {
+ final ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(this.learner);
+ acll.setLambda(lambda);
+ this.learner = acll;
}
- public ActorCriticAgent(){
+ @SuppressWarnings("Used-by-user")
+ public void start(final int stateId) {
+ this.currentState = stateId;
+ this.prevAction = -1;
+ this.prevState = -1;
+ }
+ @SuppressWarnings("Used-by-user")
+ public ActorCriticLearner getLearner() {
+ return this.learner;
}
- public ActorCriticAgent(ActorCriticLearner learner){
+ public void setLearner(final ActorCriticLearner learner) {
this.learner = learner;
}
- public ActorCriticAgent makeCopy(){
- ActorCriticAgent clone = new ActorCriticAgent();
+ public ActorCriticAgent makeCopy() {
+ final ActorCriticAgent clone = new ActorCriticAgent();
clone.copy(this);
return clone;
}
- public void copy(ActorCriticAgent rhs){
- learner = (ActorCriticLearner)rhs.learner.makeCopy();
- prevAction = rhs.prevAction;
- prevState = rhs.prevState;
- currentState = rhs.currentState;
+ public void copy(final ActorCriticAgent rhs) {
+ this.learner = (ActorCriticLearner) rhs.learner.makeCopy();
+ this.prevAction = rhs.prevAction;
+ this.prevState = rhs.prevState;
+ this.currentState = rhs.currentState;
}
@Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof ActorCriticAgent){
- ActorCriticAgent rhs = (ActorCriticAgent)obj;
- return learner.equals(rhs.learner) && prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState;
+ public boolean equals(final Object obj) {
+ if (obj instanceof ActorCriticAgent) {
+ final ActorCriticAgent rhs = (ActorCriticAgent) obj;
+ return this.learner.equals(rhs.learner) && this.prevAction == rhs.prevAction && this.prevState == rhs.prevState && this.currentState == rhs.currentState;
}
return false;
}
- public int selectAction(Set actionsAtState){
- return learner.selectAction(currentState, actionsAtState);
+ public int selectAction(final Set actionsAtState) {
+ return this.learner.selectAction(this.currentState, actionsAtState);
}
- public int selectAction(){
- return learner.selectAction(currentState);
+ public int selectAction() {
+ return this.learner.selectAction(this.currentState);
}
- public void update(int actionTaken, int newState, double immediateReward, final Vec V){
- update(actionTaken, newState, null, immediateReward, V);
+ public void update(final int actionTaken, final int newState, final double immediateReward, final Vec V) {
+ this.update(actionTaken, newState, null, immediateReward, V);
}
- public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward, final Vec V){
+ public void update(final int actionTaken, final int newState, final Set actionsAtNewState, final double immediateReward, final Vec V) {
- learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward, new Function() {
- public Double apply(Integer stateId) {
- return V.get(stateId);
- }
- });
+ this.learner.update(this.currentState, actionTaken, newState, actionsAtNewState, immediateReward, V::get);
- prevAction = actionTaken;
- prevState = currentState;
+ this.prevAction = actionTaken;
+ this.prevState = this.currentState;
- currentState = newState;
+ this.currentState = newState;
}
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java
index afdb314..1f85622 100644
--- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java
+++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java
@@ -3,14 +3,13 @@
import com.github.chen0040.rl.utils.IndexValue;
import java.io.Serializable;
-import java.util.Random;
import java.util.Set;
/**
* Created by xschen on 9/27/2015 0027.
*/
-public class QAgent implements Serializable{
+public class QAgent implements Serializable {
private QLearner learner;
private int currentState;
private int prevState;
@@ -18,94 +17,97 @@ public class QAgent implements Serializable{
/** action taken at prevState */
private int prevAction;
- public int getCurrentState(){
- return currentState;
+ public QAgent(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) {
+ this.learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ);
}
- public int getPrevState(){
- return prevState;
+ public QAgent(final QLearner learner) {
+ this.learner = learner;
}
- public int getPrevAction(){
- return prevAction;
+ public QAgent(final int stateCount, final int actionCount) {
+ this.learner = new QLearner(stateCount, actionCount);
}
- public void start(int currentState){
- this.currentState = currentState;
- this.prevAction = -1;
- this.prevState = -1;
- }
+ public QAgent() {
- public IndexValue selectAction(){
- return learner.selectAction(currentState);
}
- public IndexValue selectAction(Set actionsAtState){
- return learner.selectAction(currentState, actionsAtState);
+ @SuppressWarnings("Used-by-user")
+ public int getCurrentState() {
+ return this.currentState;
}
- public void update(int actionTaken, int newState, double immediateReward){
- update(actionTaken, newState, null, immediateReward);
+ @SuppressWarnings("Used-by-user")
+ public int getPrevState() {
+ return this.prevState;
}
- public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){
-
- learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward);
-
- prevState = currentState;
- prevAction = actionTaken;
-
- currentState = newState;
+ @SuppressWarnings("Used-by-user")
+ public int getPrevAction() {
+ return this.prevAction;
}
- public void enableEligibilityTrace(double lambda){
- QLambdaLearner acll = new QLambdaLearner(learner);
- acll.setLambda(lambda);
- learner = acll;
+ public void start(final int currentState) {
+ this.currentState = currentState;
+ this.prevAction = -1;
+ this.prevState = -1;
}
- public QLearner getLearner(){
- return learner;
+ public IndexValue selectAction() {
+ return this.learner.selectAction(this.currentState);
}
- public void setLearner(QLearner learner){
- this.learner = learner;
+ public IndexValue selectAction(final Set actionsAtState) {
+ return this.learner.selectAction(this.currentState, actionsAtState);
}
- public QAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){
- learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ);
+ public void update(final int actionTaken, final int newState, final double immediateReward) {
+ this.update(actionTaken, newState, null, immediateReward);
}
- public QAgent(QLearner learner){
- this.learner = learner;
+ public void update(final int actionTaken, final int newState, final Set actionsAtNewState, final double immediateReward) {
+
+ this.learner.update(this.currentState, actionTaken, newState, actionsAtNewState, immediateReward);
+
+ this.prevState = this.currentState;
+ this.prevAction = actionTaken;
+
+ this.currentState = newState;
}
- public QAgent(int stateCount, int actionCount){
- learner = new QLearner(stateCount, actionCount);
+ public void enableEligibilityTrace(final double lambda) {
+ final QLambdaLearner acll = new QLambdaLearner(this.learner);
+ acll.setLambda(lambda);
+ this.learner = acll;
}
- public QAgent(){
+ public QLearner getLearner() {
+ return this.learner;
+ }
+ public void setLearner(final QLearner learner) {
+ this.learner = learner;
}
- public QAgent makeCopy(){
- QAgent clone = new QAgent();
+ public QAgent makeCopy() {
+ final QAgent clone = new QAgent();
clone.copy(this);
return clone;
}
- public void copy(QAgent rhs){
- learner.copy(rhs.learner);
- prevAction = rhs.prevAction;
- prevState = rhs.prevState;
- currentState = rhs.currentState;
+ public void copy(final QAgent rhs) {
+ // Overwrites this agent's learner and state bookkeeping with copies of rhs's.
+ // makeCopy() calls this on an agent built with the empty constructor, whose learner
+ // is still null, so take a fresh copy of rhs's learner rather than copying into ours
+ // (same approach as ActorCriticAgent.copy).
+ // NOTE(review): assumes QLearner.makeCopy() returns a QLearner; the cast is a
+ // no-op in that case - confirm against QLearner.
+ this.learner = rhs.learner == null ? null : (QLearner) rhs.learner.makeCopy();
+ this.prevAction = rhs.prevAction;
+ this.prevState = rhs.prevState;
+ this.currentState = rhs.currentState;
}
@Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof QAgent){
- QAgent rhs = (QAgent)obj;
- return prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState && learner.equals(rhs.learner);
+ public boolean equals(final Object obj) {
+ if (obj instanceof QAgent) {
+ final QAgent rhs = (QAgent) obj;
+ return this.prevAction == rhs.prevAction && this.prevState == rhs.prevState && this.currentState == rhs.currentState && this.learner.equals(rhs.learner);
}
return false;
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java
index 875ef3a..30ffd2e 100644
--- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java
@@ -1,6 +1,7 @@
package com.github.chen0040.rl.learning.qlearn;
+import com.github.chen0040.rl.models.DefaultValues;
import com.github.chen0040.rl.models.EligibilityTraceUpdateMode;
import com.github.chen0040.rl.utils.Matrix;
@@ -11,125 +12,129 @@
* Created by xschen on 9/28/2015 0028.
*/
public class QLambdaLearner extends QLearner {
- private double lambda = 0.9;
+ private double lambda = DefaultValues.LAMBDA;
private Matrix e;
private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;
+ @SuppressWarnings("Used-by-user")
+ public QLambdaLearner(final QLearner learner) {
+ this.copy(learner);
+ this.e = new Matrix(this.model.getStateCount(), this.model.getActionCount());
+ }
+
+ private QLambdaLearner() {
+ super();
+ }
+
+ @SuppressWarnings("Used-by-user")
+ public QLambdaLearner(final int stateCount, final int actionCount) {
+ super(stateCount, actionCount);
+ this.e = new Matrix(stateCount, actionCount);
+ }
+
+ @SuppressWarnings("Used-by-user")
+ public QLambdaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) {
+ super(stateCount, actionCount, alpha, gamma, initialQ);
+ this.e = new Matrix(stateCount, actionCount);
+ }
+
+ @SuppressWarnings("Used-by-user")
public EligibilityTraceUpdateMode getTraceUpdateMode() {
- return traceUpdateMode;
+ return this.traceUpdateMode;
}
- public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) {
+ @SuppressWarnings("Used-by-user")
+ public void setTraceUpdateMode(final EligibilityTraceUpdateMode traceUpdateMode) {
this.traceUpdateMode = traceUpdateMode;
}
- public double getLambda(){
- return lambda;
+ @SuppressWarnings("Used-by-user")
+ public double getLambda() {
+ return this.lambda;
}
- public void setLambda(double lambda){
+ @SuppressWarnings("Used-by-user")
+ public void setLambda(final double lambda) {
this.lambda = lambda;
}
- public QLambdaLearner makeCopy(){
- QLambdaLearner clone = new QLambdaLearner();
+ @Override
+ public QLambdaLearner makeCopy() {
+ final QLambdaLearner clone = new QLambdaLearner();
clone.copy(this);
return clone;
}
@Override
- public void copy(QLearner rhs){
+ public void copy(final QLearner rhs) {
super.copy(rhs);
- QLambdaLearner rhs2 = (QLambdaLearner)rhs;
- lambda = rhs2.lambda;
- e = rhs2.e.makeCopy();
- traceUpdateMode = rhs2.traceUpdateMode;
- }
-
- public QLambdaLearner(QLearner learner){
- copy(learner);
- e = new Matrix(model.getStateCount(), model.getActionCount());
+ // The QLambdaLearner(QLearner) constructor passes a plain QLearner through here, so
+ // copy the trace-specific fields only when the source actually carries them; an
+ // unconditional cast would throw ClassCastException in that constructor.
+ if (rhs instanceof QLambdaLearner) {
+ final QLambdaLearner rhs2 = (QLambdaLearner) rhs;
+ this.lambda = rhs2.lambda;
+ this.e = rhs2.e.makeCopy();
+ this.traceUpdateMode = rhs2.traceUpdateMode;
+ }
}
@Override
- public boolean equals(Object obj){
- if(!super.equals(obj)){
+ public boolean equals(final Object obj) {
+ if (!super.equals(obj)) {
return false;
}
- if(obj instanceof QLambdaLearner){
- QLambdaLearner rhs = (QLambdaLearner)obj;
- return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode;
+ if (obj instanceof QLambdaLearner) {
+ final QLambdaLearner rhs = (QLambdaLearner) obj;
+ return rhs.lambda == this.lambda && this.e.equals(rhs.e) && this.traceUpdateMode == rhs.traceUpdateMode;
}
return false;
}
- public QLambdaLearner(){
- super();
- }
-
- public QLambdaLearner(int stateCount, int actionCount){
- super(stateCount, actionCount);
- e = new Matrix(stateCount, actionCount);
- }
-
- public QLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){
- super(stateCount, actionCount, alpha, gamma, initialQ);
- e = new Matrix(stateCount, actionCount);
+ @SuppressWarnings("Used-by-user")
+ public Matrix getEligibility() {
+ return this.e;
}
- public Matrix getEligibility()
- {
- return e;
- }
-
- public void setEligibility(Matrix e){
+ @SuppressWarnings("Used-by-user")
+ public void setEligibility(final Matrix e) {
this.e = e;
}
@Override
- public void update(int currentStateId, int currentActionId, int nextStateId, Set actionsAtNextStateId, double immediateReward)
- {
+ public void update(final int currentStateId, final int currentActionId, final int nextStateId, final Set actionsAtNextStateId, final double immediateReward) {
// old_value is $Q_t(s_t, a_t)$
- double oldQ = model.getQ(currentStateId, currentActionId);
+ double oldQ = this.model.getQ(currentStateId, currentActionId);
// learning_rate;
- double alpha = model.getAlpha(currentStateId, currentActionId);
+ final double alpha = this.model.getAlpha(currentStateId, currentActionId);
// discount_rate;
- double gamma = model.getGamma();
+ final double gamma = this.model.getGamma();
// estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
- double maxQ = maxQAtState(nextStateId, actionsAtNextStateId);
-
- double td_error = immediateReward + gamma * maxQ - oldQ;
+ final double maxQ = this.maxQAtState(nextStateId, actionsAtNextStateId);
- int stateCount = model.getStateCount();
- int actionCount = model.getActionCount();
+ final double td_error = immediateReward + gamma * maxQ - oldQ;
- e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1);
+ final int stateCount = this.model.getStateCount();
+ final int actionCount = this.model.getActionCount();
+ this.e.set(currentStateId, currentActionId, this.e.get(currentStateId, currentActionId) + 1);
- for(int stateId = 0; stateId < stateCount; ++stateId){
- for(int actionId = 0; actionId < actionCount; ++actionId){
- oldQ = model.getQ(stateId, actionId);
- double newQ = oldQ + alpha * td_error * e.get(stateId, actionId);
+ for (int stateId = 0; stateId < stateCount; ++stateId) {
+ for (int actionId = 0; actionId < actionCount; ++actionId) {
+ oldQ = this.model.getQ(stateId, actionId);
+ final double newQ = oldQ + alpha * td_error * this.e.get(stateId, actionId);
// new_value is $Q_{t+1}(s_t, a_t)$
- model.setQ(currentStateId, currentActionId, newQ);
+ this.model.setQ(currentStateId, currentActionId, newQ);
if (actionId != currentActionId) {
- e.set(currentStateId, actionId, 0);
+ this.e.set(currentStateId, actionId, 0);
} else {
- e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda);
+ this.e.set(stateId, actionId, this.e.get(stateId, actionId) * gamma * this.lambda);
}
}
}
-
-
-
}
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java
index 865abc5..7970237 100644
--- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java
@@ -2,7 +2,6 @@
import com.alibaba.fastjson.JSON;
-import com.alibaba.fastjson.annotation.JSONField;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategy;
@@ -12,131 +11,133 @@
import com.github.chen0040.rl.utils.IndexValue;
import java.io.Serializable;
-import java.util.Random;
import java.util.Set;
+import static com.github.chen0040.rl.models.DefaultValues.*;
+
/**
- * Created by xschen on 9/27/2015 0027.
- * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm
- * Q is known as the quality of state-action combination, note that it is different from utility of a state
+ * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Q-Learning, which is an off-policy TD
+ * control algorithm Q is known as the quality of state-action combination, note that it is different from utility of a
+ * state
*/
-public class QLearner implements Serializable,Cloneable {
+public class QLearner implements Serializable, Cloneable {
protected QModel model;
private ActionSelectionStrategy actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
- public QLearner makeCopy(){
- QLearner clone = new QLearner();
- clone.copy(this);
- return clone;
- }
+ public QLearner() {
- public String toJson() {
- return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
}
- public static QLearner fromJson(String json){
- return JSON.parseObject(json, QLearner.class);
+ public QLearner(final int stateCount, final int actionCount) {
+ this(stateCount, actionCount, ALPHA, GAMMA, INITIAL_Q);
}
- public void copy(QLearner rhs){
- model = rhs.model.makeCopy();
- actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone();
+ public QLearner(final QModel model, final ActionSelectionStrategy actionSelectionStrategy) {
+ this.model = model;
+ this.actionSelectionStrategy = actionSelectionStrategy;
}
- @Override
- public boolean equals(Object obj){
- if(obj !=null && obj instanceof QLearner){
- QLearner rhs = (QLearner)obj;
- if(!model.equals(rhs.model)) return false;
- return actionSelectionStrategy.equals(rhs.actionSelectionStrategy);
- }
- return false;
+ public QLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) {
+ this.model = new QModel(stateCount, actionCount, initialQ);
+ this.model.setAlpha(alpha);
+ this.model.setGamma(gamma);
+ this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
}
- public QModel getModel() {
- return model;
+ @SuppressWarnings("Used-by-user")
+ public static QLearner fromJson(final String json) {
+ return JSON.parseObject(json, QLearner.class);
}
- public void setModel(QModel model) {
- this.model = model;
+ public QLearner makeCopy() {
+ final QLearner clone = new QLearner();
+ clone.copy(this);
+ return clone;
}
-
- public String getActionSelection() {
- return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
+ @SuppressWarnings("Used-by-user")
+ public String toJson() {
+ return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
}
- public void setActionSelection(String conf) {
- this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
+ public void copy(final QLearner rhs) {
+ this.model = rhs.model.makeCopy();
+ this.actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone();
}
- public QLearner(){
-
+ @Override
+ public boolean equals(final Object obj) {
+ if (obj instanceof QLearner) {
+ final QLearner rhs = (QLearner) obj;
+ if (!this.model.equals(rhs.model)) {
+ return false;
+ }
+ return this.actionSelectionStrategy.equals(rhs.actionSelectionStrategy);
+ }
+ return false;
}
- public QLearner(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1, 0.7, 0.1);
+ public QModel getModel() {
+ return this.model;
}
- public QLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy){
+ public void setModel(final QModel model) {
this.model = model;
- this.actionSelectionStrategy = actionSelectionStrategy;
}
- public QLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ)
- {
- model = new QModel(stateCount, actionCount, initialQ);
- model.setAlpha(alpha);
- model.setGamma(gamma);
- actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
+ @SuppressWarnings("Used-by-user")
+ public String getActionSelection() {
+ return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy);
}
+ @SuppressWarnings("Used-by-user")
+ public void setActionSelection(final String conf) {
+ this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
+ }
- protected double maxQAtState(int stateId, Set actionsAtState){
- IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState);
- double maxQ = iv.getValue();
- return maxQ;
+ double maxQAtState(final int stateId, final Set actionsAtState) {
+ return this.model.actionWithMaxQAtState(stateId, actionsAtState).getValue();
}
- public IndexValue selectAction(int stateId, Set actionsAtState){
- return actionSelectionStrategy.selectAction(stateId, model, actionsAtState);
+ @SuppressWarnings("Used-by-user")
+ public IndexValue selectAction(final int stateId, final Set actionsAtState) {
+ return this.actionSelectionStrategy.selectAction(stateId, this.model, actionsAtState);
}
- public IndexValue selectAction(int stateId){
- return selectAction(stateId, null);
+ @SuppressWarnings("Used-by-user")
+ public IndexValue selectAction(final int stateId) {
+ return this.selectAction(stateId, null);
}
- public void update(int stateId, int actionId, int nextStateId, double immediateReward){
- update(stateId, actionId, nextStateId, null, immediateReward);
+ public void update(final int stateId, final int actionId, final int nextStateId, final double immediateReward) {
+ this.update(stateId, actionId, nextStateId, null, immediateReward);
}
- public void update(int stateId, int actionId, int nextStateId, Set actionsAtNextStateId, double immediateReward)
- {
+ public void update(final int stateId, final int actionId, final int nextStateId, final Set actionsAtNextStateId, final double immediateReward) {
// old_value is $Q_t(s_t, a_t)$
- double oldQ = model.getQ(stateId, actionId);
+ final double oldQ = this.model.getQ(stateId, actionId);
// learning_rate;
- double alpha = model.getAlpha(stateId, actionId);
+ final double alpha = this.model.getAlpha(stateId, actionId);
// discount_rate;
- double gamma = model.getGamma();
+ final double gamma = this.model.getGamma();
// estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
- double maxQ = maxQAtState(nextStateId, actionsAtNextStateId);
+ final double maxQ = this.maxQAtState(nextStateId, actionsAtNextStateId);
// learned_value = immediate_reward + gamma * estimate_of_optimal_future_value
// old_value = oldQ
// temporal_difference = learned_value - old_value
// new_value = old_value + learning_rate * temporal_difference
- double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ);
+ final double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ);
// new_value is $Q_{t+1}(s_t, a_t)$
- model.setQ(stateId, actionId, newQ);
+ this.model.setQ(stateId, actionId, newQ);
}
-
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java
index f26f20a..d7ef30c 100644
--- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java
+++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java
@@ -3,99 +3,95 @@
import com.github.chen0040.rl.utils.IndexValue;
import java.io.Serializable;
-import java.util.Random;
import java.util.Set;
/**
* Created by xschen on 9/27/2015 0027.
*/
-public class RAgent implements Serializable{
+public class RAgent implements Serializable {
private RLearner learner;
private int currentState;
private int currentAction;
private double currentValue;
- public int getCurrentState(){
- return currentState;
+ public RAgent() {
+
+ }
+
+ public RAgent(final int stateCount, final int actionCount, final double alpha, final double beta, final double rho, final double initialQ) {
+ this.learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ);
+ }
+
+ public RAgent(final int stateCount, final int actionCount) {
+ this.learner = new RLearner(stateCount, actionCount);
}
- public int getCurrentAction(){
- return currentAction;
+ @SuppressWarnings("Used-by-user")
+ public int getCurrentState() {
+ return this.currentState;
}
- public void start(int currentState){
+ @SuppressWarnings("Used-by-user")
+ public int getCurrentAction() {
+ return this.currentAction;
+ }
+
+ public void start(final int currentState) {
this.currentState = currentState;
}
- public RAgent makeCopy(){
- RAgent clone = new RAgent();
+ public RAgent makeCopy() {
+ final RAgent clone = new RAgent();
clone.copy(this);
return clone;
}
- public void copy(RAgent rhs){
- currentState = rhs.currentState;
- currentAction = rhs.currentAction;
- learner.copy(rhs.learner);
+ public void copy(final RAgent rhs) {
+ this.currentState = rhs.currentState;
+ this.currentAction = rhs.currentAction;
+ this.learner.copy(rhs.learner);
}
@Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof RAgent){
- RAgent rhs = (RAgent)obj;
- if(!learner.equals(rhs.learner)) return false;
- if(currentAction != rhs.currentAction) return false;
- return currentState == rhs.currentState;
+ public boolean equals(final Object obj) {
+ if (obj instanceof RAgent) {
+ final RAgent rhs = (RAgent) obj;
+ return this.learner.equals(rhs.learner) && this.currentAction == rhs.currentAction && this.currentState == rhs.currentState;
}
return false;
}
- public IndexValue selectAction(){
- return selectAction(null);
+ public IndexValue selectAction() {
+ return this.selectAction(null);
}
- public IndexValue selectAction(Set actionsAtState){
-
- if(currentAction==-1){
- IndexValue iv = learner.selectAction(currentState, actionsAtState);
- currentAction = iv.getIndex();
- currentValue = iv.getValue();
+ public IndexValue selectAction(final Set actionsAtState) {
+ if (this.currentAction == -1) {
+ final IndexValue iv = this.learner.selectAction(this.currentState, actionsAtState);
+ this.currentAction = iv.getIndex();
+ this.currentValue = iv.getValue();
}
- return new IndexValue(currentAction, currentValue);
+ return new IndexValue(this.currentAction, this.currentValue);
}
- public void update(int newState, double immediateReward){
- update(newState, null, immediateReward);
+ public void update(final int newState, final double immediateReward) {
+ this.update(newState, null, immediateReward);
}
- public void update(int newState, Set actionsAtState, double immediateReward){
- if(currentAction != -1) {
- learner.update(currentState, currentAction, newState, actionsAtState, immediateReward);
- currentState = newState;
- currentAction = -1;
+ public void update(final int newState, final Set actionsAtState, final double immediateReward) {
+ if (this.currentAction != -1) {
+ this.learner.update(this.currentState, this.currentAction, newState, actionsAtState, immediateReward);
+ this.currentState = newState;
+ this.currentAction = -1;
}
}
- public RAgent(){
-
+ public RLearner getLearner() {
+ return this.learner;
}
-
-
- public RLearner getLearner(){
- return learner;
- }
-
- public void setLearner(RLearner learner){
+ public void setLearner(final RLearner learner) {
this.learner = learner;
}
-
- public RAgent(int stateCount, int actionCount, double alpha, double beta, double rho, double initialQ){
- learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ);
- }
-
- public RAgent(int stateCount, int actionCount){
- learner = new RLearner(stateCount, actionCount);
- }
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java
index 910d53f..bc520e7 100644
--- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java
@@ -7,135 +7,141 @@
import com.github.chen0040.rl.actionselection.ActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory;
import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy;
+import com.github.chen0040.rl.models.DefaultValues;
import com.github.chen0040.rl.models.QModel;
import com.github.chen0040.rl.utils.IndexValue;
-import lombok.Getter;
import java.io.Serializable;
import java.util.Set;
+import static com.github.chen0040.rl.models.DefaultValues.ALPHA;
+
/**
* Created by xschen on 9/27/2015 0027.
*/
-public class RLearner implements Serializable, Cloneable{
+public class RLearner implements Serializable, Cloneable {
private QModel model;
private ActionSelectionStrategy actionSelectionStrategy;
private double rho;
private double beta;
- public String toJson() {
- return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
+ public RLearner() {
+
+ }
+
+ public RLearner(final int stateCount, final int actionCount) {
+ this(stateCount, actionCount, ALPHA, DefaultValues.BETA, DefaultValues.RHO, DefaultValues.INITIAL_Q);
+ }
+
+ public RLearner(final int state_count, final int action_count, final double alpha, final double beta, final double rho, final double initial_Q) {
+ this.model = new QModel(state_count, action_count, initial_Q);
+ this.model.setAlpha(alpha);
+
+ this.rho = rho;
+ this.beta = beta;
+
+ this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
}
- public static RLearner fromJson(String json){
+ @SuppressWarnings("Used-by-user")
+ public static RLearner fromJson(final String json) {
return JSON.parseObject(json, RLearner.class);
}
- public RLearner makeCopy(){
- RLearner clone = new RLearner();
+ @SuppressWarnings("Used-by-user")
+ public String toJson() {
+ return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
+ }
+
+ public RLearner makeCopy() {
+ final RLearner clone = new RLearner();
clone.copy(this);
return clone;
}
- public void copy(RLearner rhs){
- model = rhs.model.makeCopy();
- actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy)rhs.actionSelectionStrategy).clone();
- rho = rhs.rho;
- beta = rhs.beta;
+ public void copy(final RLearner rhs) {
+ this.model = rhs.model.makeCopy();
+ this.actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone();
+ this.rho = rhs.rho;
+ this.beta = rhs.beta;
}
@Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof RLearner){
- RLearner rhs = (RLearner)obj;
- if(!model.equals(rhs.model)) return false;
- if(!actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) return false;
- if(rho != rhs.rho) return false;
- return beta == rhs.beta;
+ public boolean equals(final Object obj) {
+ if (obj instanceof RLearner) {
+ final RLearner rhs = (RLearner) obj;
+ if (!this.model.equals(rhs.model)) {
+ return false;
+ }
+ if (!this.actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) {
+ return false;
+ }
+ if (this.rho != rhs.rho) {
+ return false;
+ }
+ return this.beta == rhs.beta;
}
return false;
}
- public RLearner(){
-
- }
-
+ @SuppressWarnings("Used-by-user")
public double getRho() {
- return rho;
+ return this.rho;
}
- public void setRho(double rho) {
+ @SuppressWarnings("Used-by-user")
+ public void setRho(final double rho) {
this.rho = rho;
}
+ @SuppressWarnings("Used-by-user")
public double getBeta() {
- return beta;
+ return this.beta;
}
- public void setBeta(double beta) {
+ @SuppressWarnings("Used-by-user")
+ public void setBeta(final double beta) {
this.beta = beta;
}
- public QModel getModel(){
- return model;
+ public QModel getModel() {
+ return this.model;
}
- public void setModel(QModel model){
+ public void setModel(final QModel model) {
this.model = model;
}
- public String getActionSelection(){
- return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
+ @SuppressWarnings("Used-by-user")
+ public String getActionSelection() {
+ return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy);
}
- public void setActionSelection(String conf){
+ @SuppressWarnings("Used-by-user")
+ public void setActionSelection(final String conf) {
this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
}
- public RLearner(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1, 0.1, 0.7, 0.1);
- }
-
- public RLearner(int state_count, int action_count, double alpha, double beta, double rho, double initial_Q)
- {
- model = new QModel(state_count, action_count, initial_Q);
- model.setAlpha(alpha);
-
- this.rho = rho;
- this.beta = beta;
-
- actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
- }
-
- private double maxQAtState(int stateId, Set actionsAtState){
- IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState);
- double maxQ = iv.getValue();
- return maxQ;
+ private double maxQAtState(final int stateId, final Set actionsAtState) {
+ return this.model.actionWithMaxQAtState(stateId, actionsAtState).getValue();
}
- public void update(int currentState, int actionTaken, int newState, Set actionsAtNextStateId, double immediate_reward)
- {
- double oldQ = model.getQ(currentState, actionTaken);
-
- double alpha = model.getAlpha(currentState, actionTaken); // learning rate;
-
- double maxQ = maxQAtState(newState, actionsAtNextStateId);
-
- double newQ = oldQ + alpha * (immediate_reward - rho + maxQ - oldQ);
-
- double maxQAtCurrentState = maxQAtState(currentState, null);
- if (newQ == maxQAtCurrentState)
- {
- rho = rho + beta * (immediate_reward - rho + maxQ - maxQAtCurrentState);
+ public void update(final int currentState, final int actionTaken, final int newState, final Set actionsAtNextStateId, final double immediate_reward) {
+ final double oldQ = this.model.getQ(currentState, actionTaken);
+ final double alpha = this.model.getAlpha(currentState, actionTaken); // learning rate;
+ final double maxQ = this.maxQAtState(newState, actionsAtNextStateId);
+ final double newQ = oldQ + alpha * (immediate_reward - this.rho + maxQ - oldQ);
+ final double maxQAtCurrentState = this.maxQAtState(currentState, null);
+ if (newQ == maxQAtCurrentState) {
+ this.rho += this.beta * (immediate_reward - this.rho + maxQ - maxQAtCurrentState);
}
-
- model.setQ(currentState, actionTaken, newQ);
+ this.model.setQ(currentState, actionTaken, newQ);
}
- public IndexValue selectAction(int stateId, Set actionsAtState){
- return actionSelectionStrategy.selectAction(stateId, model, actionsAtState);
+ public IndexValue selectAction(final int stateId, final Set actionsAtState) {
+ return this.actionSelectionStrategy.selectAction(stateId, this.model, actionsAtState);
}
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java
index c4c8f27..80335fd 100644
--- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java
+++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java
@@ -3,15 +3,14 @@
import com.github.chen0040.rl.utils.IndexValue;
import java.io.Serializable;
-import java.util.Random;
import java.util.Set;
/**
- * Created by xschen on 9/27/2015 0027.
- * Implement temporal-difference learning Sarsa, which is an on-policy TD control algorithm
+ * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Sarsa, which is an on-policy TD control
+ * algorithm
*/
-public class SarsaAgent implements Serializable{
+public class SarsaAgent implements Serializable {
private SarsaLearner learner;
private int currentState;
private int currentAction;
@@ -19,111 +18,118 @@ public class SarsaAgent implements Serializable{
private int prevState;
private int prevAction;
- public int getCurrentState(){
- return currentState;
+ public SarsaAgent(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) {
+ this.learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ);
}
- public int getCurrentAction(){
- return currentAction;
+ public SarsaAgent(final int stateCount, final int actionCount) {
+ this.learner = new SarsaLearner(stateCount, actionCount);
}
- public int getPrevState() { return prevState; }
+ public SarsaAgent(final SarsaLearner learner) {
+ this.learner = learner;
+ }
- public int getPrevAction() { return prevAction; }
+ public SarsaAgent() {
- public void start(int currentState){
- this.currentState = currentState;
- this.prevState = -1;
- this.prevAction = -1;
}
- public IndexValue selectAction(){
- return selectAction(null);
+ @SuppressWarnings("Used-by-user")
+ public int getCurrentState() {
+ return this.currentState;
}
- public IndexValue selectAction(Set actionsAtState){
- if(currentAction == -1){
- IndexValue iv = learner.selectAction(currentState, actionsAtState);
- currentAction = iv.getIndex();
- currentValue = iv.getValue();
- }
+ @SuppressWarnings("Used-by-user")
+ public int getCurrentAction() {
+ return this.currentAction;
+ }
- return new IndexValue(currentAction, currentValue);
+ @SuppressWarnings("Used-by-user")
+ public int getPrevState() {
+ return this.prevState;
}
- public void update(int actionTaken, int newState, double immediateReward){
- update(actionTaken, newState, null, immediateReward);
+ @SuppressWarnings("Used-by-user")
+ public int getPrevAction() {
+ return this.prevAction;
}
- public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){
+ public void start(final int currentState) {
+ this.currentState = currentState;
+ this.prevState = -1;
+ this.prevAction = -1;
+ }
- IndexValue iv = learner.selectAction(currentState, actionsAtNewState);
- int futureAction = iv.getIndex();
+ public IndexValue selectAction() {
+ return this.selectAction(null);
+ }
- learner.update(currentState, actionTaken, newState, futureAction, immediateReward);
+ public IndexValue selectAction(final Set actionsAtState) {
+ if (this.currentAction == -1) {
+ final IndexValue iv = this.learner.selectAction(this.currentState, actionsAtState);
+ this.currentAction = iv.getIndex();
+ this.currentValue = iv.getValue();
+ }
- prevState = this.currentState;
- this.prevAction = actionTaken;
+ return new IndexValue(this.currentAction, this.currentValue);
+ }
- currentAction = futureAction;
- currentState = newState;
+ public void update(final int actionTaken, final int newState, final double immediateReward) {
+ this.update(actionTaken, newState, null, immediateReward);
}
+ public void update(final int actionTaken, final int newState, final Set actionsAtNewState, final double immediateReward) {
+ final IndexValue iv = this.learner.selectAction(this.currentState, actionsAtNewState);
+ final int futureAction = iv.getIndex();
- public SarsaLearner getLearner(){
- return learner;
- }
+ this.learner.update(this.currentState, actionTaken, newState, futureAction, immediateReward);
- public void setLearner(SarsaLearner learner){
- this.learner = learner;
- }
+ this.prevState = this.currentState;
+ this.prevAction = actionTaken;
- public SarsaAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){
- learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ);
+ this.currentAction = futureAction;
+ this.currentState = newState;
}
- public SarsaAgent(int stateCount, int actionCount){
- learner = new SarsaLearner(stateCount, actionCount);
+ public SarsaLearner getLearner() {
+ return this.learner;
}
- public SarsaAgent(SarsaLearner learner){
+ public void setLearner(final SarsaLearner learner) {
this.learner = learner;
}
- public SarsaAgent(){
-
- }
-
- public void enableEligibilityTrace(double lambda){
- SarsaLambdaLearner acll = new SarsaLambdaLearner(learner);
+ @SuppressWarnings("Used-by-user")
+ public void enableEligibilityTrace(final double lambda) {
+ final SarsaLambdaLearner acll = new SarsaLambdaLearner(this.learner);
acll.setLambda(lambda);
- learner = acll;
+ this.learner = acll;
}
- public SarsaAgent makeCopy(){
- SarsaAgent clone = new SarsaAgent();
+ public SarsaAgent makeCopy() {
+ final SarsaAgent clone = new SarsaAgent();
clone.copy(this);
return clone;
}
- public void copy(SarsaAgent rhs){
- learner.copy(rhs.learner);
- currentAction = rhs.currentAction;
- currentState = rhs.currentState;
- prevAction = rhs.prevAction;
- prevState = rhs.prevState;
+ public void copy(final SarsaAgent rhs) {
+ this.learner.copy(rhs.learner);
+ this.currentAction = rhs.currentAction;
+ this.currentState = rhs.currentState;
+ this.prevAction = rhs.prevAction;
+ this.prevState = rhs.prevState;
}
@Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof SarsaAgent){
- SarsaAgent rhs = (SarsaAgent)obj;
- return prevAction == rhs.prevAction
- && prevState == rhs.prevState
- && currentAction == rhs.currentAction
- && currentState == rhs.currentState
- && learner.equals(rhs.learner);
+ public boolean equals(final Object obj) {
+ if (obj instanceof SarsaAgent) {
+ final SarsaAgent rhs = (SarsaAgent) obj;
+ return this.prevAction == rhs.prevAction
+ && this.prevState == rhs.prevState
+ && this.currentAction == rhs.currentAction
+ && this.currentState == rhs.currentState
+ && this.learner.equals(rhs.learner);
}
return false;
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java
index e51543e..6298ac8 100644
--- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java
@@ -1,6 +1,7 @@
package com.github.chen0040.rl.learning.sarsa;
+import com.github.chen0040.rl.models.DefaultValues;
import com.github.chen0040.rl.models.EligibilityTraceUpdateMode;
import com.github.chen0040.rl.utils.Matrix;
@@ -9,119 +10,123 @@
* Created by xschen on 9/28/2015 0028.
*/
public class SarsaLambdaLearner extends SarsaLearner {
- private double lambda = 0.9;
+ private double lambda = DefaultValues.LAMBDA;
private Matrix e;
private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;
+ public SarsaLambdaLearner(final int stateCount, final int actionCount) {
+ super(stateCount, actionCount);
+ this.e = new Matrix(stateCount, actionCount);
+ }
+
+ public SarsaLambdaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) {
+ super(stateCount, actionCount, alpha, gamma, initialQ);
+ this.e = new Matrix(stateCount, actionCount);
+ }
+
+ @SuppressWarnings("Used-by-user")
+ public SarsaLambdaLearner(final SarsaLearner learner) {
+ this.copy(learner);
+ this.e = new Matrix(this.model.getStateCount(), this.model.getActionCount());
+ }
+
+ private SarsaLambdaLearner() {
+
+ }
+
+ @SuppressWarnings("Used-by-user")
public EligibilityTraceUpdateMode getTraceUpdateMode() {
- return traceUpdateMode;
+ return this.traceUpdateMode;
}
- public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) {
+ @SuppressWarnings("Used-by-user")
+ public void setTraceUpdateMode(final EligibilityTraceUpdateMode traceUpdateMode) {
this.traceUpdateMode = traceUpdateMode;
}
- public double getLambda(){
- return lambda;
+ @SuppressWarnings("Used-by-user")
+ public double getLambda() {
+ return this.lambda;
}
- public void setLambda(double lambda){
+ public void setLambda(final double lambda) {
this.lambda = lambda;
}
@Override
- public Object clone(){
- SarsaLambdaLearner clone = new SarsaLambdaLearner();
+ public SarsaLambdaLearner clone() {
+ final SarsaLambdaLearner clone = new SarsaLambdaLearner();
clone.copy(this);
return clone;
}
@Override
- public void copy(SarsaLearner rhs){
+ public void copy(final SarsaLearner rhs) {
super.copy(rhs);
- SarsaLambdaLearner rhs2 = (SarsaLambdaLearner)rhs;
- lambda = rhs2.lambda;
- e = rhs2.e.makeCopy();
- traceUpdateMode = rhs2.traceUpdateMode;
+ final SarsaLambdaLearner rhs2 = (SarsaLambdaLearner) rhs;
+ this.lambda = rhs2.lambda;
+ this.e = rhs2.e.makeCopy();
+ this.traceUpdateMode = rhs2.traceUpdateMode;
}
@Override
- public boolean equals(Object obj){
- if(!super.equals(obj)){
+ public boolean equals(final Object obj) {
+ if (!super.equals(obj)) {
return false;
}
- if(obj instanceof SarsaLambdaLearner){
- SarsaLambdaLearner rhs = (SarsaLambdaLearner)obj;
- return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode;
+ if (obj instanceof SarsaLambdaLearner) {
+ final SarsaLambdaLearner rhs = (SarsaLambdaLearner) obj;
+ return rhs.lambda == this.lambda && this.e.equals(rhs.e) && this.traceUpdateMode == rhs.traceUpdateMode;
}
return false;
}
- public SarsaLambdaLearner(){
- super();
- }
-
- public SarsaLambdaLearner(int stateCount, int actionCount){
- super(stateCount, actionCount);
- e = new Matrix(stateCount, actionCount);
- }
-
- public SarsaLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){
- super(stateCount, actionCount, alpha, gamma, initialQ);
- e = new Matrix(stateCount, actionCount);
- }
-
- public SarsaLambdaLearner(SarsaLearner learner){
- copy(learner);
- e = new Matrix(model.getStateCount(), model.getActionCount());
- }
-
- public Matrix getEligibility()
- {
- return e;
+ @SuppressWarnings("Used-by-user")
+ public Matrix getEligibility() {
+ return this.e;
}
- public void setEligibility(Matrix e){
+ @SuppressWarnings("Used-by-user")
+ public void setEligibility(final Matrix e) {
this.e = e;
}
@Override
- public void update(int currentStateId, int currentActionId, int nextStateId, int nextActionId, double immediateReward)
- {
+ public void update(final int currentStateId, final int currentActionId, final int nextStateId, final int nextActionId, final double immediateReward) {
// old_value is $Q_t(s_t, a_t)$
- double oldQ = model.getQ(currentStateId, currentActionId);
+ double oldQ = this.model.getQ(currentStateId, currentActionId);
// learning_rate;
- double alpha = model.getAlpha(currentStateId, currentActionId);
+ final double alpha = this.model.getAlpha(currentStateId, currentActionId);
// discount_rate;
- double gamma = model.getGamma();
+ final double gamma = this.model.getGamma();
// estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
- double nextQ = model.getQ(nextStateId, nextActionId);
+ final double nextQ = this.model.getQ(nextStateId, nextActionId);
- double td_error = immediateReward + gamma * nextQ - oldQ;
+ final double td_error = immediateReward + gamma * nextQ - oldQ;
- int stateCount = model.getStateCount();
- int actionCount = model.getActionCount();
+ final int stateCount = this.model.getStateCount();
+ final int actionCount = this.model.getActionCount();
- e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1);
+ this.e.set(currentStateId, currentActionId, this.e.get(currentStateId, currentActionId) + 1);
- for(int stateId = 0; stateId < stateCount; ++stateId){
- for(int actionId = 0; actionId < actionCount; ++actionId){
- oldQ = model.getQ(stateId, actionId);
+ for (int stateId = 0; stateId < stateCount; ++stateId) {
+ for (int actionId = 0; actionId < actionCount; ++actionId) {
+ oldQ = this.model.getQ(stateId, actionId);
- double newQ = oldQ + alpha * td_error * e.get(stateId, actionId);
+ final double newQ = oldQ + alpha * td_error * this.e.get(stateId, actionId);
- model.setQ(stateId, actionId, newQ);
+ this.model.setQ(stateId, actionId, newQ);
if (actionId != currentActionId) {
- e.set(currentStateId, actionId, 0);
+ this.e.set(currentStateId, actionId, 0);
} else {
- e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda);
+ this.e.set(stateId, actionId, this.e.get(stateId, actionId) * gamma * this.lambda);
}
}
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java
index 7fef780..3ce748a 100644
--- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java
@@ -11,150 +11,125 @@
import com.github.chen0040.rl.utils.IndexValue;
import java.io.Serializable;
-import java.util.Random;
import java.util.Set;
+import static com.github.chen0040.rl.models.DefaultValues.*;
+
/**
- * Created by xschen on 9/27/2015 0027.
- * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm
- * Q is known as the quality of state-action combination, note that it is different from utility of a state
+ * Created by xschen on 9/27/2015 0027. Implements temporal-difference learning SARSA, which is an on-policy TD
+ * control algorithm. Q is known as the quality of a state-action combination; note that it is different from the
+ * utility of a state.
*/
-public class SarsaLearner implements Serializable,Cloneable {
+public class SarsaLearner implements Serializable, Cloneable {
protected QModel model;
private ActionSelectionStrategy actionSelectionStrategy;
- public String toJson() {
- return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
- }
+ @SuppressWarnings("Used-by-user")
+ public SarsaLearner() {
- public static SarsaLearner fromJson(String json){
- return JSON.parseObject(json, SarsaLearner.class);
}
- public SarsaLearner makeCopy(){
- SarsaLearner clone = new SarsaLearner();
- clone.copy(this);
- return clone;
+ @SuppressWarnings("Used-by-user")
+ public SarsaLearner(final int stateCount, final int actionCount) {
+ this(stateCount, actionCount, ALPHA, GAMMA, INITIAL_Q);
}
- public void copy(SarsaLearner rhs){
- model = rhs.model.makeCopy();
- actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone();
+ @SuppressWarnings("Used-by-user")
+ public SarsaLearner(final QModel model, final ActionSelectionStrategy actionSelectionStrategy) {
+ this.model = model;
+ this.actionSelectionStrategy = actionSelectionStrategy;
}
- @Override
- public boolean equals(Object obj){
- if(obj !=null && obj instanceof SarsaLearner){
- SarsaLearner rhs = (SarsaLearner)obj;
- if(!model.equals(rhs.model)) return false;
- return actionSelectionStrategy.equals(rhs.actionSelectionStrategy);
- }
- return false;
+ @SuppressWarnings("Used-by-user")
+ public SarsaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) {
+ this.model = new QModel(stateCount, actionCount, initialQ);
+ this.model.setAlpha(alpha);
+ this.model.setGamma(gamma);
+ this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
}
- public QModel getModel() {
- return model;
+ @SuppressWarnings("Used-by-user")
+ public static SarsaLearner fromJson(final String json) {
+ return JSON.parseObject(json, SarsaLearner.class);
}
- public void setModel(QModel model) {
- this.model = model;
+ @SuppressWarnings("Used-by-user")
+ public String toJson() {
+ return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
}
- public String getActionSelection() {
- return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
+ public SarsaLearner makeCopy() {
+ final SarsaLearner clone = new SarsaLearner();
+ clone.copy(this);
+ return clone;
}
- public void setActionSelection(String conf) {
- this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
+ public void copy(final SarsaLearner rhs) {
+ this.model = rhs.model.makeCopy();
+ this.actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone();
}
- public SarsaLearner(){
-
+ @Override
+ public boolean equals(final Object obj) { // equal when obj is a SarsaLearner whose model and action-selection strategy are both equal to this one's
+ if (obj instanceof SarsaLearner) {
+ final SarsaLearner rhs = (SarsaLearner) obj;
+ if (this.model == null ? rhs.model != null : !this.model.equals(rhs.model)) { // null-safe: the no-arg constructor leaves model unset
+ return false;
+ }
+ return this.actionSelectionStrategy == null ? rhs.actionSelectionStrategy == null : this.actionSelectionStrategy.equals(rhs.actionSelectionStrategy); // null-safe for the same reason
+ }
+ return false;
}
- public SarsaLearner(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1, 0.7, 0.1);
+ public QModel getModel() {
+ return this.model;
}
- public SarsaLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy){
+ public void setModel(final QModel model) {
this.model = model;
- this.actionSelectionStrategy = actionSelectionStrategy;
}
- public SarsaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ)
- {
- model = new QModel(stateCount, actionCount, initialQ);
- model.setAlpha(alpha);
- model.setGamma(gamma);
- actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
+ @SuppressWarnings("Used-by-user")
+ public String getActionSelection() {
+ return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy);
}
- public static void main(String[] args){
- int stateCount = 100;
- int actionCount = 10;
-
- SarsaLearner learner = new SarsaLearner(stateCount, actionCount);
-
- double reward = 0; // reward gained by transiting from prevState to currentState
- Random random = new Random();
- int currentStateId = random.nextInt(stateCount);
- int currentActionId = learner.selectAction(currentStateId).getIndex();
-
- for(int time=0; time < 1000; ++time){
-
- System.out.println("Controller does action-"+currentActionId);
-
- int newStateId = random.nextInt(actionCount);
- reward = random.nextDouble();
-
- System.out.println("Now the new state is " + newStateId);
- System.out.println("Controller receives Reward = " + reward);
-
- int futureActionId = learner.selectAction(newStateId).getIndex();
-
- System.out.println("Controller is expected to do action-"+futureActionId);
-
- learner.update(currentStateId, currentActionId, newStateId, futureActionId, reward);
-
- currentStateId = newStateId;
- currentActionId = futureActionId;
- }
+ @SuppressWarnings("Used-by-user")
+ public void setActionSelection(final String conf) {
+ this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
}
-
- public IndexValue selectAction(int stateId, Set actionsAtState){
- return actionSelectionStrategy.selectAction(stateId, model, actionsAtState);
+ public IndexValue selectAction(final int stateId, final Set actionsAtState) {
+ return this.actionSelectionStrategy.selectAction(stateId, this.model, actionsAtState);
}
- public IndexValue selectAction(int stateId){
- return selectAction(stateId, null);
+ public IndexValue selectAction(final int stateId) {
+ return this.selectAction(stateId, null);
}
- public void update(int stateId, int actionId, int nextStateId, int nextActionId, double immediateReward)
- {
+ public void update(final int stateId, final int actionId, final int nextStateId, final int nextActionId, final double immediateReward) { // one TD update of Q(stateId, actionId) from the observed reward and Q(nextStateId, nextActionId); writes the new value back into the model
// old_value is $Q_t(s_t, a_t)$
- double oldQ = model.getQ(stateId, actionId);
+ final double oldQ = this.model.getQ(stateId, actionId);
// learning_rate;
- double alpha = model.getAlpha(stateId, actionId);
+ final double alpha = this.model.getAlpha(stateId, actionId); // learning rate is looked up per (state, action) pair
// discount_rate;
- double gamma = model.getGamma();
+ final double gamma = this.model.getGamma();
// estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
- double nextQ = model.getQ(nextStateId, nextActionId);
+ final double nextQ = this.model.getQ(nextStateId, nextActionId); // Q of the next action supplied by the caller, not a max over actions, despite the comment above
// learned_value = immediate_reward + gamma * estimate_of_optimal_future_value
// old_value = oldQ
// temporal_difference = learned_value - old_value
// new_value = old_value + learning_rate * temporal_difference
- double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ);
+ final double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ);
// new_value is $Q_{t+1}(s_t, a_t)$
- model.setQ(stateId, actionId, newQ);
+ this.model.setQ(stateId, actionId, newQ);
}
-
}
diff --git a/src/main/java/com/github/chen0040/rl/models/DefaultValues.java b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java
new file mode 100644
index 0000000..879ea64
--- /dev/null
+++ b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java
@@ -0,0 +1,11 @@
+package com.github.chen0040.rl.models;
+
+public enum DefaultValues { // holder for shared default hyper-parameters; the enum declares no constants, so no instance of it exists
+ ;
+ public static final double GAMMA = 0.9; // default discount factor (QModel.gamma, SarsaLearner two-arg constructor); NOTE(review): the literals this patch replaces there were 0.7 -- confirm the change to 0.9 is intended
+ public static final double ALPHA = 0.1; // default learning rate (QModel's alpha matrix, SarsaLearner two-arg constructor)
+ public static final double INITIAL_Q = 0.1; // default initial Q value (QModel and SarsaLearner two-arg constructors)
+ public static final double LAMBDA = 0.9; // presumably the eligibility-trace decay default -- usage is not visible here, confirm against the lambda learners
+ public static final double BETA = 0.1; // usage is not visible here -- TODO confirm which learner reads it
+ public static final double RHO = 0.7; // usage is not visible here -- TODO confirm which learner reads it
+}
diff --git a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java
index e25380f..bd891f0 100644
--- a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java
+++ b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java
@@ -4,6 +4,5 @@
* Created by xschen on 9/28/2015 0028.
*/
public enum EligibilityTraceUpdateMode {
- ReplaceTrace,
- AccumulateTrace
+ ReplaceTrace
}
diff --git a/src/main/java/com/github/chen0040/rl/models/QModel.java b/src/main/java/com/github/chen0040/rl/models/QModel.java
index 2d314a1..d48e1c1 100644
--- a/src/main/java/com/github/chen0040/rl/models/QModel.java
+++ b/src/main/java/com/github/chen0040/rl/models/QModel.java
@@ -4,149 +4,147 @@
import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.utils.Matrix;
import com.github.chen0040.rl.utils.Vec;
-import lombok.Getter;
-import lombok.Setter;
import java.util.*;
/**
- * @author xschen
- * 9/27/2015 0027.
- * Q is known as the quality of state-action combination, note that it is different from utility of a state
+ * @author xschen 9/27/2015 0027. Q is known as the quality of state-action combination, note that it is different from
+ * utility of a state
*/
-@Getter
-@Setter
public class QModel {
/**
- * Q value for (state_id, action_id) pair
- * Q is known as the quality of state-action combination, note that it is different from utility of a state
- */
+ * Q value for (state_id, action_id) pair Q is known as the quality of state-action combination, note that it is
+ * different from utility of a state
+ */
+
private Matrix Q;
/**
- * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id)
- */
+ * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id)
+ */
+
private Matrix alphaMatrix;
/**
* discount factor
*/
- private double gamma = 0.7;
+ private double gamma = DefaultValues.GAMMA;
private int stateCount;
private int actionCount;
- public QModel(int stateCount, int actionCount, double initialQ){
+ public QModel(final int stateCount, final int actionCount, final double initialQ) {
this.stateCount = stateCount;
this.actionCount = actionCount;
- Q = new Matrix(stateCount,actionCount);
- alphaMatrix = new Matrix(stateCount, actionCount);
- Q.setAll(initialQ);
- alphaMatrix.setAll(0.1);
+ this.Q = new Matrix(stateCount, actionCount);
+ this.alphaMatrix = new Matrix(stateCount, actionCount);
+ this.Q.setAll(initialQ);
+ this.alphaMatrix.setAll(DefaultValues.ALPHA);
}
- public QModel(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1);
+ public QModel(final int stateCount, final int actionCount) {
+ this(stateCount, actionCount, DefaultValues.INITIAL_Q);
}
- public QModel(){
+ public QModel() {
}
@Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof QModel){
- QModel rhs2 = (QModel)rhs;
-
-
- if(gamma != rhs2.gamma) return false;
-
-
- if(stateCount != rhs2.stateCount || actionCount != rhs2.actionCount) return false;
-
- if((Q!=null && rhs2.Q==null) || (Q==null && rhs2.Q !=null)) return false;
- if((alphaMatrix !=null && rhs2.alphaMatrix ==null) || (alphaMatrix ==null && rhs2.alphaMatrix !=null)) return false;
-
- return !((Q != null && !Q.equals(rhs2.Q)) || (alphaMatrix != null && !alphaMatrix.equals(rhs2.alphaMatrix)));
-
+ public boolean equals(final Object rhs) {
+ if (rhs instanceof QModel) {
+ final QModel rhs2 = (QModel) rhs;
+ return this.gamma == rhs2.gamma &&
+ this.stateCount == rhs2.stateCount &&
+ this.actionCount == rhs2.actionCount &&
+ (this.Q == null || rhs2.Q != null) &&
+ (this.Q != null || rhs2.Q == null) &&
+ (this.alphaMatrix == null || rhs2.alphaMatrix != null) &&
+ (this.alphaMatrix != null || rhs2.alphaMatrix == null) &&
+ !((this.Q != null && !this.Q.equals(rhs2.Q)) || (this.alphaMatrix != null && !this.alphaMatrix.equals(rhs2.alphaMatrix)));
}
return false;
}
- public QModel makeCopy(){
- QModel clone = new QModel();
+ public QModel makeCopy() {
+ final QModel clone = new QModel();
clone.copy(this);
return clone;
}
- public void copy(QModel rhs){
- gamma = rhs.gamma;
- stateCount = rhs.stateCount;
- actionCount = rhs.actionCount;
- Q = rhs.Q==null ? null : rhs.Q.makeCopy();
- alphaMatrix = rhs.alphaMatrix == null ? null : rhs.alphaMatrix.makeCopy();
+ public void copy(final QModel rhs) {
+ this.gamma = rhs.gamma;
+ this.stateCount = rhs.stateCount;
+ this.actionCount = rhs.actionCount;
+ this.Q = rhs.Q == null ? null : rhs.Q.makeCopy();
+ this.alphaMatrix = rhs.alphaMatrix == null ? null : rhs.alphaMatrix.makeCopy();
}
- public double getQ(int stateId, int actionId){
- return Q.get(stateId, actionId);
+ public double getQ(final int stateId, final int actionId) {
+ assert this.Q != null;
+ return this.Q.get(stateId, actionId);
}
- public void setQ(int stateId, int actionId, double Qij){
- Q.set(stateId, actionId, Qij);
+ public void setQ(final int stateId, final int actionId, final double Qij) {
+ assert this.Q != null;
+ this.Q.set(stateId, actionId, Qij);
}
- public double getAlpha(int stateId, int actionId){
- return alphaMatrix.get(stateId, actionId);
+ public double getAlpha(final int stateId, final int actionId) {
+ assert this.alphaMatrix != null;
+ return this.alphaMatrix.get(stateId, actionId);
}
- public void setAlpha(double defaultAlpha) {
+ public void setAlpha(final double defaultAlpha) {
+ assert this.alphaMatrix != null;
this.alphaMatrix.setAll(defaultAlpha);
}
- public IndexValue actionWithMaxQAtState(int stateId, Set actionsAtState){
- Vec rowVector = Q.rowAt(stateId);
+ public IndexValue actionWithMaxQAtState(final int stateId, final Set actionsAtState) {
+ assert this.Q != null;
+ final Vec rowVector = this.Q.rowAt(stateId);
return rowVector.indexWithMaxValue(actionsAtState);
}
- private void reset(double initialQ){
- Q.setAll(initialQ);
+ private void reset(final double initialQ) {
+ assert this.Q != null;
+ this.Q.setAll(initialQ);
}
- public IndexValue actionWithSoftMaxQAtState(int stateId,Set actionsAtState, Random random) {
- Vec rowVector = Q.rowAt(stateId);
+ public IndexValue actionWithSoftMaxQAtState(final int stateId, final Set actionsAtState, final Random random) {
+ Set atState = actionsAtState;
+ assert this.Q != null;
+ final Vec rowVector = this.Q.rowAt(stateId);
double sum = 0;
- if(actionsAtState==null){
- actionsAtState = new HashSet<>();
- for(int i=0; i < actionCount; ++i){
- actionsAtState.add(i);
+ if (atState == null) {
+ atState = new HashSet<>();
+ for (int i = 0; i < this.actionCount; ++i) {
+ atState.add(i);
}
}
- List actions = new ArrayList<>();
- for(Integer actionId : actionsAtState){
- actions.add(actionId);
- }
+ final List actions = new ArrayList<>(atState);
- double[] acc = new double[actions.size()];
- for(int i=0; i < actions.size(); ++i){
+ final double[] acc = new double[actions.size()];
+ for (int i = 0; i < actions.size(); ++i) {
sum += rowVector.get(actions.get(i));
acc[i] = sum;
}
- double r = random.nextDouble() * sum;
+ final double r = random.nextDouble() * sum;
- IndexValue result = new IndexValue();
- for(int i=0; i < actions.size(); ++i){
- if(acc[i] >= r){
- int actionId = actions.get(i);
+ final IndexValue result = new IndexValue();
+ for (int i = 0; i < actions.size(); ++i) {
+ if (acc[i] >= r) {
+ final int actionId = actions.get(i);
result.setIndex(actionId);
result.setValue(rowVector.get(actionId));
break;
@@ -155,4 +153,38 @@ public IndexValue actionWithSoftMaxQAtState(int stateId,Set actionsAtSt
return result;
}
+
+
+ // Bean accessors written out by hand now that the class-level Lombok @Getter/@Setter are removed.
+ // The setters for Q and alphaMatrix are kept as well: @Setter used to provide them, and without
+ // them an instance built through the no-arg constructor can never receive its matrices.
+
+ /** Returns the Q-value matrix; null on an instance created with the no-arg constructor and not yet populated. */
+ public Matrix getQ() {
+ return this.Q;
+ }
+
+ /** Replaces the Q-value matrix. */
+ public void setQ(final Matrix Q) {
+ this.Q = Q;
+ }
+
+ /** Returns the per-(state, action) learning-rate matrix; null under the same conditions as the Q matrix. */
+ public Matrix getAlphaMatrix() {
+ return this.alphaMatrix;
+ }
+
+ /** Replaces the per-(state, action) learning-rate matrix. */
+ public void setAlphaMatrix(final Matrix alphaMatrix) {
+ this.alphaMatrix = alphaMatrix;
+ }
+
+ /** Returns the discount factor. */
+ public double getGamma() {
+ return this.gamma;
+ }
+
+ /** Sets the discount factor. */
+ public void setGamma(final double gamma) {
+ this.gamma = gamma;
+ }
+
+ /** Returns the stored state count. */
+ public int getStateCount() {
+ return this.stateCount;
+ }
+
+ /** Sets the stored state count; the matrices are not resized. */
+ public void setStateCount(final int stateCount) {
+ this.stateCount = stateCount;
+ }
+
+ /** Returns the stored action count. */
+ public int getActionCount() {
+ return this.actionCount;
+ }
+
+ /** Sets the stored action count; the matrices are not resized. */
+ public void setActionCount(final int actionCount) {
+ this.actionCount = actionCount;
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java
index cff1859..4496b62 100644
--- a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java
+++ b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java
@@ -1,91 +1,79 @@
package com.github.chen0040.rl.models;
import com.github.chen0040.rl.utils.Vec;
-import lombok.Getter;
-import lombok.Setter;
import java.io.Serializable;
/**
- * @author xschen
- * 9/27/2015 0027.
- * Utility value of a state $U(s)$ is the expected long term reward of state $s$ given the sequence of reward and the optimal policy
- * Utility value $U(s)$ at state $s$ can be obtained by the Bellman equation
- * Bellman Equtation states that $U(s) = R(s) + \gamma * max_a \sum_{s'} T(s,a,s')U(s')$
- * where s' is the possible transitioned state given that action $a$ is applied at state $s$
- * where $T(s,a,s')$ is the transition probability of $s \rightarrow s'$ given that action $a$ is applied at state $s$
- * where $\sum_{s'} T(s,a,s')U(s')$ is the expected long term reward given that action $a$ is applied at state $s$
- * where $max_a \sum_{s'} T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$ is applied at state $s$
+ * @author xschen 9/27/2015 0027. Utility value of a state $U(s)$ is the expected long term reward of state $s$ given
+ * the sequence of reward and the optimal policy Utility value $U(s)$ at state $s$ can be obtained by the
+ * Bellman equation. The Bellman equation states that $U(s) = R(s) + \gamma * max_a \sum_{s'} T(s,a,s')U(s')$ where
+ * s' is the possible transitioned state given that action $a$ is applied at state $s$ where $T(s,a,s')$ is the
+ * transition probability of $s \rightarrow s'$ given that action $a$ is applied at state $s$ where $\sum_{s'}
+ * T(s,a,s')U(s')$ is the expected long term reward given that action $a$ is applied at state $s$ where $max_a
+ * \sum_{s'} T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$
+ * is applied at state $s$
*/
-@Getter
-@Setter
public class UtilityModel implements Serializable {
+
private Vec U;
private int stateCount;
private int actionCount;
- public void setU(Vec U){
- this.U = U;
- }
-
- public Vec getU() {
- return U;
+ @SuppressWarnings("Used-by-user")
+ public UtilityModel(final int stateCount, final int actionCount, final double initialU) {
+ this.stateCount = stateCount;
+ this.actionCount = actionCount;
+ this.U = new Vec(stateCount);
+ this.U.setAll(initialU);
}
- public double getU(int stateId){
- return U.get(stateId);
+ private UtilityModel() {
}
public int getStateCount() {
- return stateCount;
+ return this.stateCount;
}
- public int getActionCount() {
- return actionCount;
- }
-
- public UtilityModel(int stateCount, int actionCount, double initialU){
+ public void setStateCount(final int stateCount) {
this.stateCount = stateCount;
- this.actionCount = actionCount;
- U = new Vec(stateCount);
- U.setAll(initialU);
}
- public UtilityModel(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1);
+ public int getActionCount() {
+ return this.actionCount;
}
- public UtilityModel(){
-
+ public void setActionCount(final int actionCount) {
+ this.actionCount = actionCount;
}
- public void copy(UtilityModel rhs){
- U = rhs.U==null ? null : rhs.U.makeCopy();
- actionCount = rhs.actionCount;
- stateCount = rhs.stateCount;
+ public void copy(final UtilityModel rhs) {
+ this.U = rhs.U == null ? null : rhs.U.makeCopy();
+ this.actionCount = rhs.actionCount;
+ this.stateCount = rhs.stateCount;
}
- public UtilityModel makeCopy(){
- UtilityModel clone = new UtilityModel();
+ public UtilityModel makeCopy() {
+ final UtilityModel clone = new UtilityModel();
clone.copy(this);
return clone;
}
@Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof UtilityModel){
- UtilityModel rhs2 = (UtilityModel)rhs;
- if(actionCount != rhs2.actionCount || stateCount != rhs2.stateCount) return false;
-
- if((U==null && rhs2.U!=null) && (U!=null && rhs2.U ==null)) return false;
- return !(U != null && !U.equals(rhs2.U));
+ public boolean equals(final Object rhs) { // equal when rhs is a UtilityModel with the same counts and an equal (or equally absent) utility vector
+ if (rhs instanceof UtilityModel) {
+ final UtilityModel rhs2 = (UtilityModel) rhs;
+ return this.actionCount == rhs2.actionCount &&
+ this.stateCount == rhs2.stateCount &&
+ (this.U == null ? rhs2.U == null : this.U.equals(rhs2.U)); // a null U matches only a null U, so a.equals(b) and b.equals(a) agree
}
return false;
}
- public void reset(double initialU){
- U.setAll(initialU);
+ public void reset(final double initialU) {
+ assert this.U != null;
+ this.U.setAll(initialU);
}
}
diff --git a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java
index e840bc1..946bd56 100644
--- a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java
+++ b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java
@@ -3,12 +3,13 @@
/**
* Created by xschen on 10/11/2015 0011.
*/
-public class DoubleUtils {
- public static boolean equals(double a1, double a2){
- return Math.abs(a1-a2) < 1e-10;
- }
+public enum DoubleUtils { // static helper holder; the enum declares no constants, so no instance of it exists
+ ;
+
+ public static final double TOLERANCE = 0.0000000001; // 1e-10, the same threshold as the literal this replaces
- public static boolean isZero(double a){
- return a < 1e-20;
+ public static boolean equals(final double a1, final double a2) { // true when a1 and a2 differ by strictly less than TOLERANCE
+ return Math.abs(a1 - a2) < DoubleUtils.TOLERANCE; // a NaN difference makes the comparison false, so NaN operands are never equal
}
+
}
diff --git a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java
index 66c2bf6..6c3d6ae 100644
--- a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java
+++ b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java
@@ -1,46 +1,55 @@
package com.github.chen0040.rl.utils;
-import lombok.Getter;
-import lombok.Setter;
-
-
/**
* Created by xschen on 6/5/2017.
*/
-@Getter
-@Setter
public class IndexValue {
- private int index;
- private double value;
-
- public IndexValue(){
-
- }
-
- public IndexValue(int index, double value){
- this.index = index;
- this.value = value;
- }
-
- public IndexValue makeCopy(){
- IndexValue clone = new IndexValue();
- clone.setValue(value);
- clone.setIndex(index);
- return clone;
- }
-
- @Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof IndexValue){
- IndexValue rhs2 = (IndexValue)rhs;
- return index == rhs2.index && value == rhs2.value;
- }
- return false;
- }
-
- public boolean isValid(){
- return index != -1;
- }
-
+ private int index;
+ private double value;
+
+ public IndexValue() {
+
+ }
+
+ public IndexValue(final int index, final double value) {
+ this.index = index;
+ this.value = value;
+ }
+
+ public IndexValue makeCopy() {
+ final IndexValue clone = new IndexValue();
+ clone.setValue(this.value);
+ clone.setIndex(this.index);
+ return clone;
+ }
+
+ @Override
+ public boolean equals(final Object rhs) {
+ if (rhs instanceof IndexValue) {
+ final IndexValue rhs2 = (IndexValue) rhs;
+ return this.index == rhs2.index && this.value == rhs2.value;
+ }
+ return false;
+ }
+
+ public boolean isValid() { // true unless index is -1; stays public, as it was before this patch, so callers outside this package keep compiling
+ return this.index != -1;
+ }
+
+ public int getIndex() {
+ return this.index;
+ }
+
+ public void setIndex(final int index) {
+ this.index = index;
+ }
+
+ public double getValue() {
+ return this.value;
+ }
+
+ public void setValue(final double value) {
+ this.value = value;
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/utils/Matrix.java b/src/main/java/com/github/chen0040/rl/utils/Matrix.java
index cd42bd5..b20b86f 100644
--- a/src/main/java/com/github/chen0040/rl/utils/Matrix.java
+++ b/src/main/java/com/github/chen0040/rl/utils/Matrix.java
@@ -1,83 +1,41 @@
package com.github.chen0040.rl.utils;
-import com.alibaba.fastjson.annotation.JSONField;
-import lombok.Getter;
-import lombok.Setter;
-
import java.io.Serializable;
-import java.util.ArrayList;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
/**
* Created by xschen on 9/27/2015 0027.
*/
-@Getter
-@Setter
public class Matrix implements Serializable {
- private Map rows = new HashMap<>();
+ private final Map rows = new HashMap<>();
private int rowCount;
private int columnCount;
private double defaultValue;
- public Matrix(){
-
- }
-
- public Matrix(double[][] A){
- for(int i = 0; i < A.length; ++i){
- double[] B = A[i];
- for(int j=0; j < B.length; ++j){
- set(i, j, B[j]);
- }
- }
- }
-
- public void setRow(int rowIndex, Vec rowVector){
- rowVector.setId(rowIndex);
- rows.put(rowIndex, rowVector);
- }
-
-
- public static Matrix identity(int dimension){
- Matrix m = new Matrix(dimension, dimension);
- for(int i=0; i < m.getRowCount(); ++i){
- m.set(i, i, 1);
- }
- return m;
+ public Matrix(final int rowCount, final int columnCount) {
+ this.rowCount = rowCount;
+ this.columnCount = columnCount;
+ this.defaultValue = 0;
}
@Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof Matrix){
- Matrix rhs2 = (Matrix)rhs;
- if(rowCount != rhs2.rowCount || columnCount != rhs2.columnCount){
+ public boolean equals(final Object rhs) {
+ if (rhs instanceof Matrix) {
+ final Matrix rhs2 = (Matrix) rhs;
+ if (this.rowCount != rhs2.rowCount || this.columnCount != rhs2.columnCount) {
return false;
}
- if(defaultValue == rhs2.defaultValue) {
- for (Integer index : rows.keySet()) {
- if (!rhs2.rows.containsKey(index)) return false;
- if (!rows.get(index).equals(rhs2.rows.get(index))) {
- System.out.println("failed!");
- return false;
- }
- }
-
- for (Integer index : rhs2.rows.keySet()) {
- if (!rows.containsKey(index)) return false;
- if (!rhs2.rows.get(index).equals(rows.get(index))) {
- System.out.println("failed! 22");
- return false;
- }
- }
+ if (this.defaultValue == rhs2.defaultValue) {
+ return this.rows.keySet().stream().noneMatch(index -> !rhs2.rows.containsKey(index) || !this.rows.get(index).equals(rhs2.rows.get(index))) &&
+ rhs2.rows.keySet().stream().noneMatch(index -> !this.rows.containsKey(index) || !rhs2.rows.get(index).equals(this.rows.get(index)));
} else {
- for(int i=0; i < rowCount; ++i) {
- for(int j=0; j < columnCount; ++j) {
- if(this.get(i, j) != rhs2.get(i, j)){
+ for (int i = 0; i < this.rowCount; ++i) {
+ for (int j = 0; j < this.columnCount; ++j) {
+ if (this.get(i, j) != rhs2.get(i, j)) {
return false;
}
}
@@ -90,154 +48,63 @@ public boolean equals(Object rhs){
return false;
}
- public Matrix makeCopy(){
- Matrix clone = new Matrix(rowCount, columnCount);
+ public Matrix makeCopy() {
+ final Matrix clone = new Matrix(this.rowCount, this.columnCount);
clone.copy(this);
return clone;
}
- public void copy(Matrix rhs){
- rowCount = rhs.rowCount;
- columnCount = rhs.columnCount;
- defaultValue = rhs.defaultValue;
+ private void copy(final Matrix rhs) {
+ this.rowCount = rhs.rowCount;
+ this.columnCount = rhs.columnCount;
+ this.defaultValue = rhs.defaultValue;
- rows.clear();
+ this.rows.clear();
- for(Map.Entry entry : rhs.rows.entrySet()){
- rows.put(entry.getKey(), entry.getValue().makeCopy());
- }
+ rhs.rows.forEach((key, value) -> this.rows.put(key, value.makeCopy()));
}
-
-
- public void set(int rowIndex, int columnIndex, double value){
- Vec row = rowAt(rowIndex);
+ public void set(final int rowIndex, final int columnIndex, final double value) {
+ final Vec row = this.rowAt(rowIndex);
row.set(columnIndex, value);
- if(rowIndex >= rowCount) { rowCount = rowIndex+1; }
- if(columnIndex >= columnCount) { columnCount = columnIndex + 1; }
- }
-
-
-
- public Matrix(int rowCount, int columnCount){
- this.rowCount = rowCount;
- this.columnCount = columnCount;
- this.defaultValue = 0;
+ if (rowIndex >= this.rowCount) {
+ this.rowCount = rowIndex + 1;
+ }
+ if (columnIndex >= this.columnCount) {
+ this.columnCount = columnIndex + 1;
+ }
}
- public Vec rowAt(int rowIndex){
- Vec row = rows.get(rowIndex);
- if(row == null){
- row = new Vec(columnCount);
- row.setAll(defaultValue);
+ public Vec rowAt(final int rowIndex) { // returns the row vector for rowIndex, creating it on first access
+ Vec row = this.rows.get(rowIndex);
+ if (row == null) {
+ row = new Vec(this.columnCount); // new row sized to the current column count
+ row.setAll(this.defaultValue); // and filled with the matrix default value
row.setId(rowIndex);
- rows.put(rowIndex, row);
+ this.rows.put(rowIndex, row); // the new row is stored, so even a read through get() adds an entry to rows
}
return row;
}
- public void setAll(double value){
- defaultValue = value;
- for(Vec row : rows.values()){
+ public void setAll(final double value) {
+ this.defaultValue = value;
+ for (final Vec row : this.rows.values()) {
row.setAll(value);
}
}
- public double get(int rowIndex, int columnIndex) {
- Vec row= rowAt(rowIndex);
+ public double get(final int rowIndex, final int columnIndex) {
+ final Vec row = this.rowAt(rowIndex);
return row.get(columnIndex);
}
- public List columnVectors()
- {
- Matrix A = this;
- int n = A.getColumnCount();
- int rowCount = A.getRowCount();
-
- List Acols = new ArrayList();
-
- for (int c = 0; c < n; ++c)
- {
- Vec Acol = new Vec(rowCount);
- Acol.setAll(defaultValue);
- Acol.setId(c);
-
- for (int r = 0; r < rowCount; ++r)
- {
- Acol.set(r, A.get(r, c));
- }
- Acols.add(Acol);
- }
- return Acols;
- }
-
- public Matrix multiply(Matrix rhs)
- {
- if(this.getColumnCount() != rhs.getRowCount()){
- System.err.println("A.columnCount must be equal to B.rowCount in multiplication");
- return null;
- }
-
- Vec row1;
- Vec col2;
-
- Matrix result = new Matrix(getRowCount(), rhs.getColumnCount());
- result.setAll(defaultValue);
-
- List rhsColumns = rhs.columnVectors();
-
- for (Map.Entry entry : rows.entrySet())
- {
- int r1 = entry.getKey();
- row1 = entry.getValue();
- for (int c2 = 0; c2 < rhsColumns.size(); ++c2)
- {
- col2 = rhsColumns.get(c2);
- result.set(r1, c2, row1.multiply(col2));
- }
- }
-
- return result;
- }
-
- @JSONField(serialize = false)
- public boolean isSymmetric(){
- if (getRowCount() != getColumnCount()) return false;
- for (Map.Entry rowEntry : rows.entrySet())
- {
- int row = rowEntry.getKey();
- Vec rowVec = rowEntry.getValue();
-
- for (Integer col : rowVec.getData().keySet())
- {
- if (row == col.intValue()) continue;
- if(DoubleUtils.equals(rowVec.get(col), this.get(col, row))){
- return false;
- }
- }
- }
-
- return true;
+ public int getRowCount() {
+ return this.rowCount;
}
- public Vec multiply(Vec rhs)
- {
- if(this.getColumnCount() != rhs.getDimension()){
- System.err.println("columnCount must be equal to the size of the vector for multiplication");
- }
-
- Vec row1;
- Vec result = new Vec(getRowCount());
- for (Map.Entry entry : rows.entrySet())
- {
- row1 = entry.getValue();
- result.set(entry.getKey(), row1.multiply(rhs));
- }
- return result;
+ public int getColumnCount() {
+ return this.columnCount;
}
-
-
-
}
diff --git a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java b/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java
deleted file mode 100644
index e43c28b..0000000
--- a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java
+++ /dev/null
@@ -1,29 +0,0 @@
-package com.github.chen0040.rl.utils;
-
-import java.util.List;
-
-
-/**
- * Created by xschen on 10/11/2015 0011.
- */
-public class MatrixUtils {
- /**
- * Convert a list of column vectors into a matrix
- */
- public static Matrix matrixFromColumnVectors(List R)
- {
- int n = R.size();
- int m = R.get(0).getDimension();
-
- Matrix T = new Matrix(m, n);
- for (int c = 0; c < n; ++c)
- {
- Vec Rcol = R.get(c);
- for (int r : Rcol.getData().keySet())
- {
- T.set(r, c, Rcol.get(r));
- }
- }
- return T;
- }
-}
diff --git a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java b/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java
deleted file mode 100644
index b4895ea..0000000
--- a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java
+++ /dev/null
@@ -1,56 +0,0 @@
-package com.github.chen0040.rl.utils;
-
-/**
- * Created by xschen on 10/11/2015 0011.
- */
-public class TupleTwo {
- private T1 item1;
- private T2 item2;
-
- public TupleTwo(T1 item1, T2 item2){
- this.item1 = item1;
- this.item2 = item2;
- }
-
- public T1 getItem1() {
- return item1;
- }
-
- public void setItem1(T1 item1) {
- this.item1 = item1;
- }
-
- public T2 getItem2() {
- return item2;
- }
-
- public void setItem2(T2 item2) {
- this.item2 = item2;
- }
-
- public static TupleTwo create(U1 item1, U2 item2){
- return new TupleTwo(item1, item2);
- }
-
-
- @Override public boolean equals(Object o) {
- if (this == o)
- return true;
- if (o == null || getClass() != o.getClass())
- return false;
-
- TupleTwo, ?> tupleTwo = (TupleTwo, ?>) o;
-
- if (item1 != null ? !item1.equals(tupleTwo.item1) : tupleTwo.item1 != null)
- return false;
- return item2 != null ? item2.equals(tupleTwo.item2) : tupleTwo.item2 == null;
-
- }
-
-
- @Override public int hashCode() {
- int result = item1 != null ? item1.hashCode() : 0;
- result = 31 * result + (item2 != null ? item2.hashCode() : 0);
- return result;
- }
-}
diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java
index 4699d0e..ca007cf 100644
--- a/src/main/java/com/github/chen0040/rl/utils/Vec.java
+++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java
@@ -1,131 +1,108 @@
package com.github.chen0040.rl.utils;
-import lombok.Getter;
-import lombok.Setter;
-
import java.io.Serializable;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.stream.IntStream;
/**
* Created by xschen on 9/27/2015 0027.
*/
-@Getter
-@Setter
public class Vec implements Serializable {
- private Map data = new HashMap();
+ private final Map data = new HashMap<>();
private int dimension;
private double defaultValue;
private int id = -1;
- public Vec(){
+ public Vec() {
}
- public Vec(double[] v){
- for(int i=0; i < v.length; ++i){
- set(i, v[i]);
- }
+ /**
+ * Builds a vector from a dense array: element {@code i} of {@code v} becomes entry {@code i}.
+ * Elements equal to the default value (0 for a freshly constructed vector) are not stored explicitly.
+ *
+ * @param v the dense values; the resulting dimension equals {@code v.length}
+ */
+ public Vec(final double[] v) {
+ // set() returns early for values equal to the default and therefore never grows the
+ // dimension for them; fix the dimension up front so trailing zeros are still counted.
+ this.dimension = v.length;
+ IntStream.range(0, v.length).forEach(i -> this.set(i, v[i]));
}
- public Vec(int dimension){
+ public Vec(final int dimension) {
this.dimension = dimension;
- defaultValue = 0;
+ this.defaultValue = 0;
}
- public Vec(int dimension, Map data){
+ public Vec(final int dimension, final Map data) {
this.dimension = dimension;
- defaultValue = 0;
+ this.defaultValue = 0;
- for(Map.Entry entry : data.entrySet()){
- set(entry.getKey(), entry.getValue());
- }
+ data.forEach(this::set);
}
- public Vec makeCopy(){
- Vec clone = new Vec(dimension);
+ public Vec makeCopy() {
+ final Vec clone = new Vec(this.dimension);
clone.copy(this);
return clone;
}
- public void copy(Vec rhs){
- defaultValue = rhs.defaultValue;
- dimension = rhs.dimension;
- id = rhs.id;
+ public void copy(final Vec rhs) {
+ this.defaultValue = rhs.defaultValue;
+ this.dimension = rhs.dimension;
+ this.id = rhs.id;
- data.clear();
- for(Map.Entry entry : rhs.data.entrySet()){
- data.put(entry.getKey(), entry.getValue());
- }
+ this.data.clear();
+ rhs.data.forEach(this.data::put);
}
- public void set(int i, double value){
- if(value == defaultValue) return;
+ /**
+ * Stores {@code value} at index {@code i}.
+ * A value equal to the default is not stored; any explicit entry at {@code i} is removed
+ * instead, so that {@code get(i)} falls back to the default. Storing a non-default value
+ * at an index beyond the current dimension grows the dimension to {@code i + 1}.
+ *
+ * @param i the index to write
+ * @param value the value to store
+ */
+ public void set(final int i, final double value) {
+ if (value == this.defaultValue) {
+ // Without this removal an earlier non-default entry would survive and get(i)
+ // would keep returning the stale value.
+ this.data.remove(i);
+ return;
+ }
- data.put(i, value);
- if(i >= dimension){
- dimension = i+1;
+ this.data.put(i, value);
+ if (i >= this.dimension) {
+ this.dimension = i + 1;
}
}
- public double get(int i){
- return data.getOrDefault(i, defaultValue);
+ public double get(final int i) {
+ return this.data.getOrDefault(i, this.defaultValue);
}
+ /**
+ * Compares this vector with another object.
+ * Two vectors are equal when they have the same dimension, the same number of explicitly
+ * stored entries, and every stored index of this vector is stored in the other with a value
+ * that is equal under {@code DoubleUtils.equals}. If the default values differ, the vectors
+ * are equal only when every index in {@code [0, dimension)} is explicitly stored, because
+ * otherwise some position would fall back to two different defaults.
+ *
+ * @param rhs the object to compare against
+ * @return {@code true} if {@code rhs} is a {@code Vec} satisfying the conditions above
+ */
@Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof Vec){
- Vec rhs2 = (Vec)rhs;
- if(dimension != rhs2.dimension){
+ public boolean equals(final Object rhs) {
+ if (rhs instanceof Vec) {
+ final Vec rhs2 = (Vec) rhs;
+ if (this.dimension != rhs2.dimension || this.data.size() != rhs2.data.size()) {
return false;
}
-
- if(data.size() != rhs2.data.size()){
- return false;
- }
-
- for(Integer index : data.keySet()){
- if(!rhs2.data.containsKey(index)) return false;
- if(!DoubleUtils.equals(data.get(index), rhs2.data.get(index))){
+ for (final Integer index : this.data.keySet()) {
+ if (!rhs2.data.containsKey(index) || !DoubleUtils.equals(this.data.get(index), rhs2.data.get(index))) {
return false;
}
}
-
- if(defaultValue != rhs2.defaultValue){
- for(int i=0; i < dimension; ++i){
- if(data.containsKey(i)){
- return false;
- }
- }
+ if (this.defaultValue != rhs2.defaultValue) {
+ // Differing defaults only matter at positions that are not stored explicitly.
+ return IntStream.range(0, this.dimension).allMatch(this.data::containsKey);
}
-
return true;
}
return false;
}
- public void setAll(double value){
- defaultValue = value;
- for(Integer index : data.keySet()){
- data.put(index, defaultValue);
- }
+ public void setAll(final double value) {
+ this.defaultValue = value;
+ this.data.keySet().forEach(index -> this.data.put(index, this.defaultValue));
}
- public IndexValue indexWithMaxValue(Set indices){
- if(indices == null){
- return indexWithMaxValue();
- }else{
- IndexValue iv = new IndexValue();
+ public IndexValue indexWithMaxValue(final Set indices) {
+ if (indices == null) {
+ return this.indexWithMaxValue();
+ } else {
+ final IndexValue iv = new IndexValue();
iv.setIndex(-1);
iv.setValue(Double.NEGATIVE_INFINITY);
- for(Integer index : indices){
- double value = data.getOrDefault(index, Double.NEGATIVE_INFINITY);
- if(value > iv.getValue()){
+ for (final Integer index : indices) {
+ final double value = this.data.getOrDefault(index, Double.NEGATIVE_INFINITY);
+ if (value > iv.getValue()) {
iv.setIndex(index);
iv.setValue(value);
}
@@ -134,29 +111,31 @@ public IndexValue indexWithMaxValue(Set indices){
}
}
- public IndexValue indexWithMaxValue(){
- IndexValue iv = new IndexValue();
+ private IndexValue indexWithMaxValue() {
+ final IndexValue iv = new IndexValue();
iv.setIndex(-1);
iv.setValue(Double.NEGATIVE_INFINITY);
- for(Map.Entry entry : data.entrySet()){
- if(entry.getKey() >= dimension) continue;
+ for (final Map.Entry entry : this.data.entrySet()) {
+ if (entry.getKey() >= this.dimension) {
+ continue;
+ }
- double value = entry.getValue();
- if(value > iv.getValue()){
+ final double value = entry.getValue();
+ if (value > iv.getValue()) {
iv.setValue(value);
iv.setIndex(entry.getKey());
}
}
- if(!iv.isValid()){
- iv.setValue(defaultValue);
- } else{
- if(iv.getValue() < defaultValue){
- for(int i=0; i < dimension; ++i){
- if(!data.containsKey(i)){
- iv.setValue(defaultValue);
+ if (!iv.isValid()) {
+ iv.setValue(this.defaultValue);
+ } else {
+ if (iv.getValue() < this.defaultValue) {
+ for (int i = 0; i < this.dimension; ++i) {
+ if (!this.data.containsKey(i)) {
+ iv.setValue(this.defaultValue);
iv.setIndex(i);
break;
}
@@ -168,182 +147,7 @@ public IndexValue indexWithMaxValue(){
}
-
- public Vec projectOrthogonal(Iterable vlist) {
- Vec b = this;
- for(Vec v : vlist)
- {
- b = b.minus(b.projectAlong(v));
- }
-
- return b;
- }
-
- public Vec projectOrthogonal(List vlist, Map alpha) {
- Vec b = this;
- for(int i = 0; i < vlist.size(); ++i)
- {
- Vec v = vlist.get(i);
- double norm_a = v.multiply(v);
-
- if (DoubleUtils.isZero(norm_a)) {
- return new Vec(dimension);
- }
- double sigma = multiply(v) / norm_a;
- Vec v_parallel = v.multiply(sigma);
-
- alpha.put(i, sigma);
-
- b = b.minus(v_parallel);
- }
-
- return b;
- }
-
- public Vec projectAlong(Vec rhs)
- {
- double norm_a = rhs.multiply(rhs);
-
- if (DoubleUtils.isZero(norm_a)) {
- return new Vec(dimension);
- }
- double sigma = multiply(rhs) / norm_a;
- return rhs.multiply(sigma);
- }
-
- public Vec multiply(double rhs){
- Vec clone = (Vec)this.makeCopy();
- for(Integer i : data.keySet()){
- clone.data.put(i, rhs * data.get(i));
- }
- return clone;
- }
-
- public double multiply(Vec rhs)
- {
- double productSum = 0;
- if(defaultValue == 0) {
- for (Map.Entry entry : data.entrySet()) {
- productSum += entry.getValue() * rhs.get(entry.getKey());
- }
- } else {
- for(int i=0; i < dimension; ++i){
- productSum += get(i) * rhs.get(i);
- }
- }
-
- return productSum;
- }
-
- public Vec pow(double scalar)
- {
- Vec result = new Vec(dimension);
- for (Map.Entry entry : data.entrySet())
- {
- result.data.put(entry.getKey(), Math.pow(entry.getValue(), scalar));
- }
- return result;
- }
-
- public Vec add(Vec rhs)
- {
- Vec result = new Vec(dimension);
- int index;
- for (Map.Entry entry : data.entrySet()) {
- index = entry.getKey();
- result.data.put(index, entry.getValue() + rhs.data.get(index));
- }
- for(Map.Entry entry : rhs.data.entrySet()){
- index = entry.getKey();
- if(result.data.containsKey(index)) continue;
- result.data.put(index, entry.getValue() + data.get(index));
- }
-
- return result;
- }
-
- public Vec minus(Vec rhs)
- {
- Vec result = new Vec(dimension);
- int index;
- for (Map.Entry entry : data.entrySet()) {
- index = entry.getKey();
- result.data.put(index, entry.getValue() - rhs.data.get(index));
- }
- for(Map.Entry entry : rhs.data.entrySet()){
- index = entry.getKey();
- if(result.data.containsKey(index)) continue;
- result.data.put(index, data.get(index) - entry.getValue());
- }
-
- return result;
- }
-
- public double sum(){
- double sum = 0;
-
- for(Map.Entry entry : data.entrySet()){
- sum += entry.getValue();
- }
- sum += defaultValue * (dimension - data.size());
-
- return sum;
- }
-
- public boolean isZero(){
- return DoubleUtils.isZero(sum());
- }
-
- public double norm(int level)
- {
- if (level == 1)
- {
- double sum = 0;
- for (Double val : data.values())
- {
- sum += Math.abs(val);
- }
- if(!DoubleUtils.isZero(defaultValue)) {
- sum += Math.abs(defaultValue) * (dimension - data.size());
- }
- return sum;
- }
- else if (level == 2)
- {
- double sum = multiply(this);
- if(!DoubleUtils.isZero(defaultValue)){
- sum += (dimension - data.size()) * (defaultValue * defaultValue);
- }
- return Math.sqrt(sum);
- }
- else
- {
- double sum = 0;
- for (Double val : this.data.values())
- {
- sum += Math.pow(Math.abs(val), level);
- }
- if(!DoubleUtils.isZero(defaultValue)) {
- sum += Math.pow(Math.abs(defaultValue), level) * (dimension - data.size());
- }
- return Math.pow(sum, 1.0 / level);
- }
- }
-
- public Vec normalize()
- {
- double norm = norm(2); // L2 norm is the cartesian distance
- if (DoubleUtils.isZero(norm))
- {
- return new Vec(dimension);
- }
- Vec clone = new Vec(dimension);
- clone.setAll(defaultValue / norm);
-
- for (Integer k : data.keySet())
- {
- clone.data.put(k, data.get(k) / norm);
- }
- return clone;
+ void setId(final int id) {
+ this.id = id;
}
}
diff --git a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java b/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java
deleted file mode 100644
index 2bbfbaa..0000000
--- a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java
+++ /dev/null
@@ -1,39 +0,0 @@
-package com.github.chen0040.rl.utils;
-
-import java.util.ArrayList;
-import java.util.List;
-
-
-/**
- * Created by xschen on 10/11/2015 0011.
- */
-public class VectorUtils {
- public static List removeZeroVectors(Iterable vlist)
- {
- List vstarlist = new ArrayList();
- for (Vec v : vlist)
- {
- if (!v.isZero())
- {
- vstarlist.add(v);
- }
- }
-
- return vstarlist;
- }
-
- public static TupleTwo, List> normalize(Iterable vlist)
- {
- List norms = new ArrayList();
- List vstarlist = new ArrayList();
- for (Vec v : vlist)
- {
- norms.add(v.norm(2));
- vstarlist.add(v.normalize());
- }
-
- return TupleTwo.create(vstarlist, norms);
- }
-
-
-}