From da1636ae9fb5c0a5548897881298173418516332 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 11:12:15 +0100 Subject: [PATCH 01/12] adding getter and setter --- .../com/github/chen0040/rl/models/QModel.java | 161 +++++---- .../chen0040/rl/models/UtilityModel.java | 97 +++--- .../github/chen0040/rl/utils/IndexValue.java | 85 ++--- .../com/github/chen0040/rl/utils/Matrix.java | 234 +++++++------ .../com/github/chen0040/rl/utils/Vec.java | 317 +++++++++--------- 5 files changed, 496 insertions(+), 398 deletions(-) diff --git a/src/main/java/com/github/chen0040/rl/models/QModel.java b/src/main/java/com/github/chen0040/rl/models/QModel.java index 2d314a1..ae36347 100644 --- a/src/main/java/com/github/chen0040/rl/models/QModel.java +++ b/src/main/java/com/github/chen0040/rl/models/QModel.java @@ -4,28 +4,23 @@ import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.utils.Matrix; import com.github.chen0040.rl.utils.Vec; -import lombok.Getter; -import lombok.Setter; import java.util.*; /** - * @author xschen - * 9/27/2015 0027. - * Q is known as the quality of state-action combination, note that it is different from utility of a state + * @author xschen 9/27/2015 0027. Q is known as the quality of state-action combination, note that it is different from + * utility of a state */ -@Getter -@Setter public class QModel { /** - * Q value for (state_id, action_id) pair - * Q is known as the quality of state-action combination, note that it is different from utility of a state - */ + * Q value for (state_id, action_id) pair Q is known as the quality of state-action combination, note that it is + * different from utility of a state + */ private Matrix Q; /** - * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id) - */ + * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id) + */ private Matrix alphaMatrix; /** @@ -36,117 +31,125 @@ public class QModel { private int stateCount; private int actionCount; - public QModel(int stateCount, int actionCount, double initialQ){ + public QModel(final int stateCount, final int actionCount, final double initialQ) { this.stateCount = stateCount; this.actionCount = actionCount; - Q = new Matrix(stateCount,actionCount); - alphaMatrix = new Matrix(stateCount, actionCount); - Q.setAll(initialQ); - alphaMatrix.setAll(0.1); + this.Q = new Matrix(stateCount, actionCount); + this.alphaMatrix = new Matrix(stateCount, actionCount); + this.Q.setAll(initialQ); + this.alphaMatrix.setAll(0.1); } - public QModel(int stateCount, int actionCount){ + public QModel(final int stateCount, final int actionCount) { this(stateCount, actionCount, 0.1); } - public QModel(){ + public QModel() { } @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof QModel){ - QModel rhs2 = (QModel)rhs; + public boolean equals(final Object rhs) { + if (rhs != null && rhs instanceof QModel) { + final QModel rhs2 = (QModel) rhs; - if(gamma != rhs2.gamma) return false; + if (this.gamma != rhs2.gamma) { + return false; + } - if(stateCount != rhs2.stateCount || actionCount != rhs2.actionCount) return false; + if (this.stateCount != rhs2.stateCount || this.actionCount != rhs2.actionCount) { + return false; + } - if((Q!=null && rhs2.Q==null) || (Q==null && rhs2.Q !=null)) return false; - if((alphaMatrix !=null && rhs2.alphaMatrix ==null) || (alphaMatrix ==null && rhs2.alphaMatrix !=null)) return false; + if ((this.Q != null && rhs2.Q == null) || (this.Q == null && rhs2.Q != null)) { + return false; + } + if ((this.alphaMatrix != null && rhs2.alphaMatrix == null) || (this.alphaMatrix == null && rhs2.alphaMatrix != null)) { + return false; + } - return !((Q != null && !Q.equals(rhs2.Q)) || (alphaMatrix != null && !alphaMatrix.equals(rhs2.alphaMatrix))); + return !((this.Q != null && !this.Q.equals(rhs2.Q)) || (this.alphaMatrix != null && !this.alphaMatrix.equals(rhs2.alphaMatrix))); } return false; } - public QModel makeCopy(){ - QModel clone = new QModel(); + public QModel makeCopy() { + final QModel clone = new QModel(); clone.copy(this); return clone; } - public void copy(QModel rhs){ - gamma = rhs.gamma; - stateCount = rhs.stateCount; - actionCount = rhs.actionCount; - Q = rhs.Q==null ? null : rhs.Q.makeCopy(); - alphaMatrix = rhs.alphaMatrix == null ? null : rhs.alphaMatrix.makeCopy(); + public void copy(final QModel rhs) { + this.gamma = rhs.gamma; + this.stateCount = rhs.stateCount; + this.actionCount = rhs.actionCount; + this.Q = rhs.Q == null ? null : rhs.Q.makeCopy(); + this.alphaMatrix = rhs.alphaMatrix == null ? null : rhs.alphaMatrix.makeCopy(); } - public double getQ(int stateId, int actionId){ - return Q.get(stateId, actionId); + public double getQ(final int stateId, final int actionId) { + return this.Q.get(stateId, actionId); } - public void setQ(int stateId, int actionId, double Qij){ - Q.set(stateId, actionId, Qij); + public void setQ(final int stateId, final int actionId, final double Qij) { + this.Q.set(stateId, actionId, Qij); } - public double getAlpha(int stateId, int actionId){ - return alphaMatrix.get(stateId, actionId); + public double getAlpha(final int stateId, final int actionId) { + return this.alphaMatrix.get(stateId, actionId); } - public void setAlpha(double defaultAlpha) { + public void setAlpha(final double defaultAlpha) { this.alphaMatrix.setAll(defaultAlpha); } - public IndexValue actionWithMaxQAtState(int stateId, Set actionsAtState){ - Vec rowVector = Q.rowAt(stateId); + public IndexValue actionWithMaxQAtState(final int stateId, final Set actionsAtState) { + final Vec rowVector = this.Q.rowAt(stateId); return rowVector.indexWithMaxValue(actionsAtState); } - private void reset(double initialQ){ - Q.setAll(initialQ); + private void reset(final double initialQ) { + this.Q.setAll(initialQ); } - public IndexValue actionWithSoftMaxQAtState(int stateId,Set actionsAtState, Random random) { - Vec rowVector = Q.rowAt(stateId); + public IndexValue actionWithSoftMaxQAtState(final int stateId, Set actionsAtState, final Random random) { + final Vec rowVector = this.Q.rowAt(stateId); double sum = 0; - if(actionsAtState==null){ + if (actionsAtState == null) { actionsAtState = new HashSet<>(); - for(int i=0; i < actionCount; ++i){ + for (int i = 0; i < this.actionCount; ++i) { actionsAtState.add(i); } } - List actions = new ArrayList<>(); - for(Integer actionId : actionsAtState){ + final List actions = new ArrayList<>(); + for (final Integer actionId : actionsAtState) { actions.add(actionId); } - double[] acc = new double[actions.size()]; - for(int i=0; i < actions.size(); ++i){ + final double[] acc = new double[actions.size()]; + for (int i = 0; i < actions.size(); ++i) { sum += rowVector.get(actions.get(i)); acc[i] = sum; } - double r = random.nextDouble() * sum; + final double r = random.nextDouble() * sum; - IndexValue result = new IndexValue(); - for(int i=0; i < actions.size(); ++i){ - if(acc[i] >= r){ - int actionId = actions.get(i); + final IndexValue result = new IndexValue(); + for (int i = 0; i < actions.size(); ++i) { + if (acc[i] >= r) { + final int actionId = actions.get(i); result.setIndex(actionId); result.setValue(rowVector.get(actionId)); break; @@ -155,4 +158,44 @@ public IndexValue actionWithSoftMaxQAtState(int stateId,Set actionsAtSt return result; } + + public Matrix getQ() { + return this.Q; + } + + public void setQ(final Matrix q) { + this.Q = q; + } + + public Matrix getAlphaMatrix() { + return this.alphaMatrix; + } + + public void setAlphaMatrix(final Matrix alphaMatrix) { + this.alphaMatrix = alphaMatrix; + } + + public double getGamma() { + return this.gamma; + } + + public void setGamma(final double gamma) { + this.gamma = gamma; + } + + public int getStateCount() { + return this.stateCount; + } + + public void setStateCount(final int stateCount) { + this.stateCount = stateCount; + } + + public int getActionCount() { + return this.actionCount; + } + + public void setActionCount(final int actionCount) { + this.actionCount = actionCount; + } } diff --git a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java index cff1859..83f7f80 100644 --- a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java +++ b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java @@ -1,91 +1,98 @@ package com.github.chen0040.rl.models; import com.github.chen0040.rl.utils.Vec; -import lombok.Getter; -import lombok.Setter; import java.io.Serializable; /** - * @author xschen - * 9/27/2015 0027. - * Utility value of a state $U(s)$ is the expected long term reward of state $s$ given the sequence of reward and the optimal policy - * Utility value $U(s)$ at state $s$ can be obtained by the Bellman equation - * Bellman Equtation states that $U(s) = R(s) + \gamma * max_a \sum_{s'} T(s,a,s')U(s')$ - * where s' is the possible transitioned state given that action $a$ is applied at state $s$ - * where $T(s,a,s')$ is the transition probability of $s \rightarrow s'$ given that action $a$ is applied at state $s$ - * where $\sum_{s'} T(s,a,s')U(s')$ is the expected long term reward given that action $a$ is applied at state $s$ - * where $max_a \sum_{s'} T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$ is applied at state $s$ + * @author xschen 9/27/2015 0027. Utility value of a state $U(s)$ is the expected long term reward of state $s$ given + * the sequence of reward and the optimal policy Utility value $U(s)$ at state $s$ can be obtained by the + * Bellman equation Bellman Equtation states that $U(s) = R(s) + \gamma * max_a \sum_{s'} T(s,a,s')U(s')$ where + * s' is the possible transitioned state given that action $a$ is applied at state $s$ where $T(s,a,s')$ is the + * transition probability of $s \rightarrow s'$ given that action $a$ is applied at state $s$ where $\sum_{s'} + * T(s,a,s')U(s')$ is the expected long term reward given that action $a$ is applied at state $s$ where $max_a + * \sum_{s'} T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$ + * is applied at state $s$ */ -@Getter -@Setter public class UtilityModel implements Serializable { private Vec U; private int stateCount; private int actionCount; - public void setU(Vec U){ - this.U = U; + public UtilityModel(final int stateCount, final int actionCount, final double initialU) { + this.stateCount = stateCount; + this.actionCount = actionCount; + this.U = new Vec(stateCount); + this.U.setAll(initialU); + } + + public UtilityModel(final int stateCount, final int actionCount) { + this(stateCount, actionCount, 0.1); + } + + public UtilityModel() { + } public Vec getU() { - return U; + return this.U; } - public double getU(int stateId){ - return U.get(stateId); + public void setU(final Vec U) { + this.U = U; } - public int getStateCount() { - return stateCount; + public double getU(final int stateId) { + return this.U.get(stateId); } - public int getActionCount() { - return actionCount; + public int getStateCount() { + return this.stateCount; } - public UtilityModel(int stateCount, int actionCount, double initialU){ + public void setStateCount(final int stateCount) { this.stateCount = stateCount; - this.actionCount = actionCount; - U = new Vec(stateCount); - U.setAll(initialU); } - public UtilityModel(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1); + public int getActionCount() { + return this.actionCount; } - public UtilityModel(){ - + public void setActionCount(final int actionCount) { + this.actionCount = actionCount; } - public void copy(UtilityModel rhs){ - U = rhs.U==null ? null : rhs.U.makeCopy(); - actionCount = rhs.actionCount; - stateCount = rhs.stateCount; + public void copy(final UtilityModel rhs) { + this.U = rhs.U == null ? null : rhs.U.makeCopy(); + this.actionCount = rhs.actionCount; + this.stateCount = rhs.stateCount; } - public UtilityModel makeCopy(){ - UtilityModel clone = new UtilityModel(); + public UtilityModel makeCopy() { + final UtilityModel clone = new UtilityModel(); clone.copy(this); return clone; } @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof UtilityModel){ - UtilityModel rhs2 = (UtilityModel)rhs; - if(actionCount != rhs2.actionCount || stateCount != rhs2.stateCount) return false; - - if((U==null && rhs2.U!=null) && (U!=null && rhs2.U ==null)) return false; - return !(U != null && !U.equals(rhs2.U)); + public boolean equals(final Object rhs) { + if (rhs != null && rhs instanceof UtilityModel) { + final UtilityModel rhs2 = (UtilityModel) rhs; + if (this.actionCount != rhs2.actionCount || this.stateCount != rhs2.stateCount) { + return false; + } + + if ((this.U == null && rhs2.U != null) && (this.U != null && rhs2.U == null)) { + return false; + } + return !(this.U != null && !this.U.equals(rhs2.U)); } return false; } - public void reset(double initialU){ - U.setAll(initialU); + public void reset(final double initialU) { + this.U.setAll(initialU); } } diff --git a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java index 66c2bf6..41c400b 100644 --- a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java +++ b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java @@ -1,46 +1,55 @@ package com.github.chen0040.rl.utils; -import lombok.Getter; -import lombok.Setter; - - /** * Created by xschen on 6/5/2017. */ -@Getter -@Setter public class IndexValue { - private int index; - private double value; - - public IndexValue(){ - - } - - public IndexValue(int index, double value){ - this.index = index; - this.value = value; - } - - public IndexValue makeCopy(){ - IndexValue clone = new IndexValue(); - clone.setValue(value); - clone.setIndex(index); - return clone; - } - - @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof IndexValue){ - IndexValue rhs2 = (IndexValue)rhs; - return index == rhs2.index && value == rhs2.value; - } - return false; - } - - public boolean isValid(){ - return index != -1; - } - + private int index; + private double value; + + public IndexValue() { + + } + + public IndexValue(final int index, final double value) { + this.index = index; + this.value = value; + } + + public IndexValue makeCopy() { + final IndexValue clone = new IndexValue(); + clone.setValue(this.value); + clone.setIndex(this.index); + return clone; + } + + @Override + public boolean equals(final Object rhs) { + if (rhs != null && rhs instanceof IndexValue) { + final IndexValue rhs2 = (IndexValue) rhs; + return this.index == rhs2.index && this.value == rhs2.value; + } + return false; + } + + public boolean isValid() { + return this.index != -1; + } + + public int getIndex() { + return this.index; + } + + public void setIndex(final int index) { + this.index = index; + } + + public double getValue() { + return this.value; + } + + public void setValue(final double value) { + this.value = value; + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/Matrix.java b/src/main/java/com/github/chen0040/rl/utils/Matrix.java index cd42bd5..5cb6c17 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Matrix.java +++ b/src/main/java/com/github/chen0040/rl/utils/Matrix.java @@ -1,8 +1,6 @@ package com.github.chen0040.rl.utils; import com.alibaba.fastjson.annotation.JSONField; -import lombok.Getter; -import lombok.Setter; import java.io.Serializable; import java.util.ArrayList; @@ -14,70 +12,77 @@ /** * Created by xschen on 9/27/2015 0027. */ -@Getter -@Setter public class Matrix implements Serializable { private Map rows = new HashMap<>(); private int rowCount; private int columnCount; private double defaultValue; - public Matrix(){ + public Matrix() { } - public Matrix(double[][] A){ - for(int i = 0; i < A.length; ++i){ - double[] B = A[i]; - for(int j=0; j < B.length; ++j){ - set(i, j, B[j]); + public Matrix(final double[][] A) { + for (int i = 0; i < A.length; ++i) { + final double[] B = A[i]; + for (int j = 0; j < B.length; ++j) { + this.set(i, j, B[j]); } } } - public void setRow(int rowIndex, Vec rowVector){ - rowVector.setId(rowIndex); - rows.put(rowIndex, rowVector); + public Matrix(final int rowCount, final int columnCount) { + this.rowCount = rowCount; + this.columnCount = columnCount; + this.defaultValue = 0; } - - public static Matrix identity(int dimension){ - Matrix m = new Matrix(dimension, dimension); - for(int i=0; i < m.getRowCount(); ++i){ + public static Matrix identity(final int dimension) { + final Matrix m = new Matrix(dimension, dimension); + for (int i = 0; i < m.getRowCount(); ++i) { m.set(i, i, 1); } return m; } + public void setRow(final int rowIndex, final Vec rowVector) { + rowVector.setId(rowIndex); + this.rows.put(rowIndex, rowVector); + } + @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof Matrix){ - Matrix rhs2 = (Matrix)rhs; - if(rowCount != rhs2.rowCount || columnCount != rhs2.columnCount){ + public boolean equals(final Object rhs) { + if (rhs != null && rhs instanceof Matrix) { + final Matrix rhs2 = (Matrix) rhs; + if (this.rowCount != rhs2.rowCount || this.columnCount != rhs2.columnCount) { return false; } - if(defaultValue == rhs2.defaultValue) { - for (Integer index : rows.keySet()) { - if (!rhs2.rows.containsKey(index)) return false; - if (!rows.get(index).equals(rhs2.rows.get(index))) { + if (this.defaultValue == rhs2.defaultValue) { + for (final Integer index : this.rows.keySet()) { + if (!rhs2.rows.containsKey(index)) { + return false; + } + if (!this.rows.get(index).equals(rhs2.rows.get(index))) { System.out.println("failed!"); return false; } } - for (Integer index : rhs2.rows.keySet()) { - if (!rows.containsKey(index)) return false; - if (!rhs2.rows.get(index).equals(rows.get(index))) { + for (final Integer index : rhs2.rows.keySet()) { + if (!this.rows.containsKey(index)) { + return false; + } + if (!rhs2.rows.get(index).equals(this.rows.get(index))) { System.out.println("failed! 22"); return false; } } } else { - for(int i=0; i < rowCount; ++i) { - for(int j=0; j < columnCount; ++j) { - if(this.get(i, j) != rhs2.get(i, j)){ + for (int i = 0; i < this.rowCount; ++i) { + for (int j = 0; j < this.columnCount; ++j) { + if (this.get(i, j) != rhs2.get(i, j)) { return false; } } @@ -90,80 +95,71 @@ public boolean equals(Object rhs){ return false; } - public Matrix makeCopy(){ - Matrix clone = new Matrix(rowCount, columnCount); + public Matrix makeCopy() { + final Matrix clone = new Matrix(this.rowCount, this.columnCount); clone.copy(this); return clone; } - public void copy(Matrix rhs){ - rowCount = rhs.rowCount; - columnCount = rhs.columnCount; - defaultValue = rhs.defaultValue; + public void copy(final Matrix rhs) { + this.rowCount = rhs.rowCount; + this.columnCount = rhs.columnCount; + this.defaultValue = rhs.defaultValue; - rows.clear(); + this.rows.clear(); - for(Map.Entry entry : rhs.rows.entrySet()){ - rows.put(entry.getKey(), entry.getValue().makeCopy()); + for (final Map.Entry entry : rhs.rows.entrySet()) { + this.rows.put(entry.getKey(), entry.getValue().makeCopy()); } } - - - public void set(int rowIndex, int columnIndex, double value){ - Vec row = rowAt(rowIndex); + public void set(final int rowIndex, final int columnIndex, final double value) { + final Vec row = this.rowAt(rowIndex); row.set(columnIndex, value); - if(rowIndex >= rowCount) { rowCount = rowIndex+1; } - if(columnIndex >= columnCount) { columnCount = columnIndex + 1; } - } - - - - public Matrix(int rowCount, int columnCount){ - this.rowCount = rowCount; - this.columnCount = columnCount; - this.defaultValue = 0; + if (rowIndex >= this.rowCount) { + this.rowCount = rowIndex + 1; + } + if (columnIndex >= this.columnCount) { + this.columnCount = columnIndex + 1; + } } - public Vec rowAt(int rowIndex){ - Vec row = rows.get(rowIndex); - if(row == null){ - row = new Vec(columnCount); - row.setAll(defaultValue); + public Vec rowAt(final int rowIndex) { + Vec row = this.rows.get(rowIndex); + if (row == null) { + row = new Vec(this.columnCount); + row.setAll(this.defaultValue); row.setId(rowIndex); - rows.put(rowIndex, row); + this.rows.put(rowIndex, row); } return row; } - public void setAll(double value){ - defaultValue = value; - for(Vec row : rows.values()){ + public void setAll(final double value) { + this.defaultValue = value; + for (final Vec row : this.rows.values()) { row.setAll(value); } } - public double get(int rowIndex, int columnIndex) { - Vec row= rowAt(rowIndex); + public double get(final int rowIndex, final int columnIndex) { + final Vec row = this.rowAt(rowIndex); return row.get(columnIndex); } - public List columnVectors() - { - Matrix A = this; - int n = A.getColumnCount(); - int rowCount = A.getRowCount(); + public List columnVectors() { + final Matrix A = this; + final int n = A.getColumnCount(); + final int rowCount = A.getRowCount(); - List Acols = new ArrayList(); + final List Acols = new ArrayList(); - for (int c = 0; c < n; ++c) - { - Vec Acol = new Vec(rowCount); - Acol.setAll(defaultValue); + for (int c = 0; c < n; ++c) { + final Vec Acol = new Vec(rowCount); + Acol.setAll(this.defaultValue); Acol.setId(c); - for (int r = 0; r < rowCount; ++r) - { + for (int r = 0; r < rowCount; ++r) { Acol.set(r, A.get(r, c)); } Acols.add(Acol); @@ -171,9 +167,8 @@ public List columnVectors() return Acols; } - public Matrix multiply(Matrix rhs) - { - if(this.getColumnCount() != rhs.getRowCount()){ + public Matrix multiply(final Matrix rhs) { + if (this.getColumnCount() != rhs.getRowCount()) { System.err.println("A.columnCount must be equal to B.rowCount in multiplication"); return null; } @@ -181,17 +176,15 @@ public Matrix multiply(Matrix rhs) Vec row1; Vec col2; - Matrix result = new Matrix(getRowCount(), rhs.getColumnCount()); - result.setAll(defaultValue); + final Matrix result = new Matrix(this.getRowCount(), rhs.getColumnCount()); + result.setAll(this.defaultValue); - List rhsColumns = rhs.columnVectors(); + final List rhsColumns = rhs.columnVectors(); - for (Map.Entry entry : rows.entrySet()) - { - int r1 = entry.getKey(); + for (final Map.Entry entry : this.rows.entrySet()) { + final int r1 = entry.getKey(); row1 = entry.getValue(); - for (int c2 = 0; c2 < rhsColumns.size(); ++c2) - { + for (int c2 = 0; c2 < rhsColumns.size(); ++c2) { col2 = rhsColumns.get(c2); result.set(r1, c2, row1.multiply(col2)); } @@ -201,18 +194,20 @@ public Matrix multiply(Matrix rhs) } @JSONField(serialize = false) - public boolean isSymmetric(){ - if (getRowCount() != getColumnCount()) return false; - - for (Map.Entry rowEntry : rows.entrySet()) - { - int row = rowEntry.getKey(); - Vec rowVec = rowEntry.getValue(); - - for (Integer col : rowVec.getData().keySet()) - { - if (row == col.intValue()) continue; - if(DoubleUtils.equals(rowVec.get(col), this.get(col, row))){ + public boolean isSymmetric() { + if (this.getRowCount() != this.getColumnCount()) { + return false; + } + + for (final Map.Entry rowEntry : this.rows.entrySet()) { + final int row = rowEntry.getKey(); + final Vec rowVec = rowEntry.getValue(); + + for (final Integer col : rowVec.getData().keySet()) { + if (row == col.intValue()) { + continue; + } + if (DoubleUtils.equals(rowVec.get(col), this.get(col, row))) { return false; } } @@ -221,16 +216,14 @@ public boolean isSymmetric(){ return true; } - public Vec multiply(Vec rhs) - { - if(this.getColumnCount() != rhs.getDimension()){ + public Vec multiply(final Vec rhs) { + if (this.getColumnCount() != rhs.getDimension()) { System.err.println("columnCount must be equal to the size of the vector for multiplication"); } Vec row1; - Vec result = new Vec(getRowCount()); - for (Map.Entry entry : rows.entrySet()) - { + final Vec result = new Vec(this.getRowCount()); + for (final Map.Entry entry : this.rows.entrySet()) { row1 = entry.getValue(); result.set(entry.getKey(), row1.multiply(rhs)); } @@ -238,6 +231,35 @@ public Vec multiply(Vec rhs) } + public Map getRows() { + return this.rows; + } + public void setRows(final Map rows) { + this.rows = rows; + } + public int getRowCount() { + return this.rowCount; + } + + public void setRowCount(final int rowCount) { + this.rowCount = rowCount; + } + + public int getColumnCount() { + return this.columnCount; + } + + public void setColumnCount(final int columnCount) { + this.columnCount = columnCount; + } + + public double getDefaultValue() { + return this.defaultValue; + } + + public void setDefaultValue(final double defaultValue) { + this.defaultValue = defaultValue; + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java index 4699d0e..e890b39 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Vec.java +++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java @@ -1,8 +1,5 @@ package com.github.chen0040.rl.utils; -import lombok.Getter; -import lombok.Setter; - import java.io.Serializable; import java.util.HashMap; import java.util.List; @@ -13,91 +10,93 @@ /** * Created by xschen on 9/27/2015 0027. */ -@Getter -@Setter public class Vec implements Serializable { private Map data = new HashMap(); private int dimension; private double defaultValue; private int id = -1; - public Vec(){ + public Vec() { } - public Vec(double[] v){ - for(int i=0; i < v.length; ++i){ - set(i, v[i]); + public Vec(final double[] v) { + for (int i = 0; i < v.length; ++i) { + this.set(i, v[i]); } } - public Vec(int dimension){ + public Vec(final int dimension) { this.dimension = dimension; - defaultValue = 0; + this.defaultValue = 0; } - public Vec(int dimension, Map data){ + public Vec(final int dimension, final Map data) { this.dimension = dimension; - defaultValue = 0; + this.defaultValue = 0; - for(Map.Entry entry : data.entrySet()){ - set(entry.getKey(), entry.getValue()); + for (final Map.Entry entry : data.entrySet()) { + this.set(entry.getKey(), entry.getValue()); } } - public Vec makeCopy(){ - Vec clone = new Vec(dimension); + public Vec makeCopy() { + final Vec clone = new Vec(this.dimension); clone.copy(this); return clone; } - public void copy(Vec rhs){ - defaultValue = rhs.defaultValue; - dimension = rhs.dimension; - id = rhs.id; + public void copy(final Vec rhs) { + this.defaultValue = rhs.defaultValue; + this.dimension = rhs.dimension; + this.id = rhs.id; - data.clear(); - for(Map.Entry entry : rhs.data.entrySet()){ - data.put(entry.getKey(), entry.getValue()); + this.data.clear(); + for (final Map.Entry entry : rhs.data.entrySet()) { + this.data.put(entry.getKey(), entry.getValue()); } } - public void set(int i, double value){ - if(value == defaultValue) return; + public void set(final int i, final double value) { + if (value == this.defaultValue) { + return; + } - data.put(i, value); - if(i >= dimension){ - dimension = i+1; + this.data.put(i, value); + if (i >= this.dimension) { + this.dimension = i + 1; } } - public double get(int i){ - return data.getOrDefault(i, defaultValue); + public double get(final int i) { + return this.data.getOrDefault(i, this.defaultValue); } @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof Vec){ - Vec rhs2 = (Vec)rhs; - if(dimension != rhs2.dimension){ + public boolean equals(final Object rhs) { + if (rhs != null && rhs instanceof Vec) { + final Vec rhs2 = (Vec) rhs; + if (this.dimension != rhs2.dimension) { return false; } - if(data.size() != rhs2.data.size()){ + if (this.data.size() != rhs2.data.size()) { return false; } - for(Integer index : data.keySet()){ - if(!rhs2.data.containsKey(index)) return false; - if(!DoubleUtils.equals(data.get(index), rhs2.data.get(index))){ + for (final Integer index : this.data.keySet()) { + if (!rhs2.data.containsKey(index)) { + return false; + } + if (!DoubleUtils.equals(this.data.get(index), rhs2.data.get(index))) { return false; } } - if(defaultValue != rhs2.defaultValue){ - for(int i=0; i < dimension; ++i){ - if(data.containsKey(i)){ + if (this.defaultValue != rhs2.defaultValue) { + for (int i = 0; i < this.dimension; ++i) { + if (this.data.containsKey(i)) { return false; } } @@ -109,23 +108,23 @@ public boolean equals(Object rhs){ return false; } - public void setAll(double value){ - defaultValue = value; - for(Integer index : data.keySet()){ - data.put(index, defaultValue); + public void setAll(final double value) { + this.defaultValue = value; + for (final Integer index : this.data.keySet()) { + this.data.put(index, this.defaultValue); } } - public IndexValue indexWithMaxValue(Set indices){ - if(indices == null){ - return indexWithMaxValue(); - }else{ - IndexValue iv = new IndexValue(); + public IndexValue indexWithMaxValue(final Set indices) { + if (indices == null) { + return this.indexWithMaxValue(); + } else { + final IndexValue iv = new IndexValue(); iv.setIndex(-1); iv.setValue(Double.NEGATIVE_INFINITY); - for(Integer index : indices){ - double value = data.getOrDefault(index, Double.NEGATIVE_INFINITY); - if(value > iv.getValue()){ + for (final Integer index : indices) { + final double value = this.data.getOrDefault(index, Double.NEGATIVE_INFINITY); + if (value > iv.getValue()) { iv.setIndex(index); iv.setValue(value); } @@ -134,29 +133,31 @@ public IndexValue indexWithMaxValue(Set indices){ } } - public IndexValue indexWithMaxValue(){ - IndexValue iv = new IndexValue(); + public IndexValue indexWithMaxValue() { + final IndexValue iv = new IndexValue(); iv.setIndex(-1); iv.setValue(Double.NEGATIVE_INFINITY); - for(Map.Entry entry : data.entrySet()){ - if(entry.getKey() >= dimension) continue; + for (final Map.Entry entry : this.data.entrySet()) { + if (entry.getKey() >= this.dimension) { + continue; + } - double value = entry.getValue(); - if(value > iv.getValue()){ + final double value = entry.getValue(); + if (value > iv.getValue()) { iv.setValue(value); iv.setIndex(entry.getKey()); } } - if(!iv.isValid()){ - iv.setValue(defaultValue); - } else{ - if(iv.getValue() < defaultValue){ - for(int i=0; i < dimension; ++i){ - if(!data.containsKey(i)){ - iv.setValue(defaultValue); + if (!iv.isValid()) { + iv.setValue(this.defaultValue); + } else { + if (iv.getValue() < this.defaultValue) { + for (int i = 0; i < this.dimension; ++i) { + if (!this.data.containsKey(i)) { + iv.setValue(this.defaultValue); iv.setIndex(i); break; } @@ -168,29 +169,26 @@ public IndexValue indexWithMaxValue(){ } - - public Vec projectOrthogonal(Iterable vlist) { + public Vec projectOrthogonal(final Iterable vlist) { Vec b = this; - for(Vec v : vlist) - { + for (final Vec v : vlist) { b = b.minus(b.projectAlong(v)); } return b; } - public Vec projectOrthogonal(List vlist, Map alpha) { + public Vec projectOrthogonal(final List vlist, final Map alpha) { Vec b = this; - for(int i = 0; i < vlist.size(); ++i) - { - Vec v = vlist.get(i); - double norm_a = v.multiply(v); + for (int i = 0; i < vlist.size(); ++i) { + final Vec v = vlist.get(i); + final double norm_a = v.multiply(v); if (DoubleUtils.isZero(norm_a)) { - return new Vec(dimension); + return new Vec(this.dimension); } - double sigma = multiply(v) / norm_a; - Vec v_parallel = v.multiply(sigma); + final double sigma = this.multiply(v) / norm_a; + final Vec v_parallel = v.multiply(sigma); alpha.put(i, sigma); @@ -200,150 +198,169 @@ public Vec projectOrthogonal(List vlist, Map alpha) { return b; } - public Vec projectAlong(Vec rhs) - { - double norm_a = rhs.multiply(rhs); + public Vec projectAlong(final Vec rhs) { + final double norm_a = rhs.multiply(rhs); if (DoubleUtils.isZero(norm_a)) { - return new Vec(dimension); + return new Vec(this.dimension); } - double sigma = multiply(rhs) / norm_a; + final double sigma = this.multiply(rhs) / norm_a; return rhs.multiply(sigma); } - public Vec multiply(double rhs){ - Vec clone = (Vec)this.makeCopy(); - for(Integer i : data.keySet()){ - clone.data.put(i, rhs * data.get(i)); + public Vec multiply(final double rhs) { + final Vec clone = this.makeCopy(); + for (final Integer i : this.data.keySet()) { + clone.data.put(i, rhs * this.data.get(i)); } return clone; } - public double multiply(Vec rhs) - { + public double multiply(final Vec rhs) { double productSum = 0; - if(defaultValue == 0) { - for (Map.Entry entry : data.entrySet()) { + if (this.defaultValue == 0) { + for (final Map.Entry entry : this.data.entrySet()) { productSum += entry.getValue() * rhs.get(entry.getKey()); } } else { - for(int i=0; i < dimension; ++i){ - productSum += get(i) * rhs.get(i); + for (int i = 0; i < this.dimension; ++i) { + productSum += this.get(i) * rhs.get(i); } } return productSum; } - public Vec pow(double scalar) - { - Vec result = new Vec(dimension); - for (Map.Entry entry : data.entrySet()) - { + public Vec pow(final double scalar) { + final Vec result = new Vec(this.dimension); + for (final Map.Entry entry : this.data.entrySet()) { result.data.put(entry.getKey(), Math.pow(entry.getValue(), scalar)); } return result; } - public Vec add(Vec rhs) - { - Vec result = new Vec(dimension); + public Vec add(final Vec rhs) { + final Vec result = new Vec(this.dimension); int index; - for (Map.Entry entry : data.entrySet()) { + for (final Map.Entry entry : this.data.entrySet()) { index = entry.getKey(); result.data.put(index, entry.getValue() + rhs.data.get(index)); } - for(Map.Entry entry : rhs.data.entrySet()){ + for (final Map.Entry entry : rhs.data.entrySet()) { index = entry.getKey(); - if(result.data.containsKey(index)) continue; - result.data.put(index, entry.getValue() + data.get(index)); + if (result.data.containsKey(index)) { + continue; + } + result.data.put(index, entry.getValue() + this.data.get(index)); } return result; } - public Vec minus(Vec rhs) - { - Vec result = new Vec(dimension); + public Vec minus(final Vec rhs) { + final Vec result = new Vec(this.dimension); int index; - for (Map.Entry entry : data.entrySet()) { + for (final Map.Entry entry : this.data.entrySet()) { index = entry.getKey(); result.data.put(index, entry.getValue() - rhs.data.get(index)); } - for(Map.Entry entry : rhs.data.entrySet()){ + for (final Map.Entry entry : rhs.data.entrySet()) { index = entry.getKey(); - if(result.data.containsKey(index)) continue; - result.data.put(index, data.get(index) - entry.getValue()); + if (result.data.containsKey(index)) { + continue; + } + result.data.put(index, this.data.get(index) - entry.getValue()); } return result; } - public double sum(){ + public double sum() { double sum = 0; - for(Map.Entry entry : data.entrySet()){ + for (final Map.Entry entry : this.data.entrySet()) { sum += entry.getValue(); } - sum += defaultValue * (dimension - data.size()); + sum += this.defaultValue * (this.dimension - this.data.size()); return sum; } - public boolean isZero(){ - return DoubleUtils.isZero(sum()); + public boolean isZero() { + return DoubleUtils.isZero(this.sum()); } - public double norm(int level) - { - if (level == 1) - { + public double norm(final int level) { + if (level == 1) { double sum = 0; - for (Double val : data.values()) - { + for (final Double val : this.data.values()) { sum += Math.abs(val); } - if(!DoubleUtils.isZero(defaultValue)) { - sum += Math.abs(defaultValue) * (dimension - data.size()); + if (!DoubleUtils.isZero(this.defaultValue)) { + sum += Math.abs(this.defaultValue) * (this.dimension - this.data.size()); } return sum; - } - else if (level == 2) - { - double sum = multiply(this); - if(!DoubleUtils.isZero(defaultValue)){ - sum += (dimension - data.size()) * (defaultValue * defaultValue); + } else if (level == 2) { + double sum = this.multiply(this); + if (!DoubleUtils.isZero(this.defaultValue)) { + sum += (this.dimension - this.data.size()) * (this.defaultValue * this.defaultValue); } return Math.sqrt(sum); - } - else - { + } else { double sum = 0; - for (Double val : this.data.values()) - { + for (final Double val : this.data.values()) { sum += Math.pow(Math.abs(val), level); } - if(!DoubleUtils.isZero(defaultValue)) { - sum += Math.pow(Math.abs(defaultValue), level) * (dimension - data.size()); + if (!DoubleUtils.isZero(this.defaultValue)) { + sum += Math.pow(Math.abs(this.defaultValue), level) * (this.dimension - this.data.size()); } return Math.pow(sum, 1.0 / level); } } - public Vec normalize() - { - double norm = norm(2); // L2 norm is the cartesian distance - if (DoubleUtils.isZero(norm)) - { - return new Vec(dimension); + public Vec normalize() { + final double norm = this.norm(2); // L2 norm is the cartesian distance + if (DoubleUtils.isZero(norm)) { + return new Vec(this.dimension); } - Vec clone = new Vec(dimension); - clone.setAll(defaultValue / norm); + final Vec clone = new Vec(this.dimension); + clone.setAll(this.defaultValue / norm); - for (Integer k : data.keySet()) - { - clone.data.put(k, data.get(k) / norm); + for (final Integer k : this.data.keySet()) { + clone.data.put(k, this.data.get(k) / norm); } return clone; } + + public Map getData() { + return this.data; + } + + public void setData(final Map data) { + this.data = data; + } + + public int getDimension() { + return this.dimension; + } + + public void setDimension(final int dimension) { + this.dimension = dimension; + } + + public double getDefaultValue() { + return this.defaultValue; + } + + public void setDefaultValue(final double defaultValue) { + this.defaultValue = defaultValue; + } + + public int getId() { + return this.id; + } + + public void setId(final int id) { + this.id = id; + } } From ae22257ac4a5ed92619271780f015d64c27143c4 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 12:00:15 +0100 Subject: [PATCH 02/12] cleanup --- .../chen0040/rl/learning/rlearn/RAgent.java | 98 +++++---- .../chen0040/rl/learning/rlearn/RLearner.java | 135 ++++++------ .../com/github/chen0040/rl/utils/Vec.java | 194 +++--------------- 3 files changed, 136 insertions(+), 291 deletions(-) diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java index f26f20a..1330349 100644 --- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java @@ -3,99 +3,93 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; /** * Created by xschen on 9/27/2015 0027. */ -public class RAgent implements Serializable{ +public class RAgent implements Serializable { private RLearner learner; private int currentState; private int currentAction; private double currentValue; - public int getCurrentState(){ - return currentState; + public RAgent() { + + } + + public RAgent(final int stateCount, final int actionCount, final double alpha, final double beta, final double rho, final double initialQ) { + this.learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ); + } + + public RAgent(final int stateCount, final int actionCount) { + this.learner = new RLearner(stateCount, actionCount); } - public int getCurrentAction(){ - return currentAction; + public int getCurrentState() { + return this.currentState; } - public void start(int currentState){ + public int getCurrentAction() { + return this.currentAction; + } + + public void start(final int currentState) { this.currentState = currentState; } - public RAgent makeCopy(){ - RAgent clone = new RAgent(); + public RAgent makeCopy() { + final RAgent clone = new RAgent(); clone.copy(this); return clone; } - public void copy(RAgent rhs){ - currentState = rhs.currentState; - currentAction = rhs.currentAction; - learner.copy(rhs.learner); + public void copy(final RAgent rhs) { + this.currentState = rhs.currentState; + this.currentAction = rhs.currentAction; + this.learner.copy(rhs.learner); } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof RAgent){ - RAgent rhs = (RAgent)obj; - if(!learner.equals(rhs.learner)) return false; - if(currentAction != rhs.currentAction) return false; - return currentState == rhs.currentState; + public boolean equals(final Object obj) { + if (obj instanceof RAgent) { + final RAgent rhs = (RAgent) obj; + return this.learner.equals(rhs.learner) && this.currentAction == rhs.currentAction && this.currentState == rhs.currentState; } return false; } - public IndexValue selectAction(){ - return selectAction(null); + public IndexValue selectAction() { + return this.selectAction(null); } - public IndexValue selectAction(Set actionsAtState){ - - if(currentAction==-1){ - IndexValue iv = learner.selectAction(currentState, actionsAtState); - currentAction = iv.getIndex(); - currentValue = iv.getValue(); + public IndexValue selectAction(final Set actionsAtState) { + if (this.currentAction == -1) { + final IndexValue iv = this.learner.selectAction(this.currentState, actionsAtState); + this.currentAction = iv.getIndex(); + this.currentValue = iv.getValue(); } - return new IndexValue(currentAction, currentValue); + return new IndexValue(this.currentAction, this.currentValue); } - public void update(int newState, double immediateReward){ - update(newState, null, immediateReward); + public void update(final int newState, final double immediateReward) { + this.update(newState, null, immediateReward); } - public void update(int newState, Set actionsAtState, double immediateReward){ - if(currentAction != -1) { - learner.update(currentState, currentAction, newState, actionsAtState, immediateReward); - currentState = newState; - currentAction = -1; + public void update(final int newState, final Set actionsAtState, final double immediateReward) { + if (this.currentAction != -1) { + this.learner.update(this.currentState, this.currentAction, newState, actionsAtState, immediateReward); + this.currentState = newState; + this.currentAction = -1; } } - public RAgent(){ - + public RLearner getLearner() { + return this.learner; } - - - public RLearner getLearner(){ - return learner; - } - - public void setLearner(RLearner learner){ + public void setLearner(final RLearner learner) { this.learner = learner; } - - public RAgent(int stateCount, int actionCount, double alpha, double beta, double rho, double initialQ){ - learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ); - } - - public RAgent(int stateCount, int actionCount){ - learner = new RLearner(stateCount, actionCount); - } } diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java index 910d53f..0bd5f50 100644 --- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java @@ -9,7 +9,6 @@ import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy; import com.github.chen0040.rl.models.QModel; import com.github.chen0040.rl.utils.IndexValue; -import lombok.Getter; import java.io.Serializable; import java.util.Set; @@ -18,124 +17,120 @@ /** * Created by xschen on 9/27/2015 0027. */ -public class RLearner implements Serializable, Cloneable{ +public class RLearner implements Serializable, Cloneable { private QModel model; private ActionSelectionStrategy actionSelectionStrategy; private double rho; private double beta; - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); + public RLearner() { + } - public static RLearner fromJson(String json){ + public RLearner(final int stateCount, final int actionCount) { + this(stateCount, actionCount, 0.1, 0.1, 0.7, 0.1); + } + + public RLearner(final int state_count, final int action_count, final double alpha, final double beta, final double rho, final double initial_Q) { + this.model = new QModel(state_count, action_count, initial_Q); + this.model.setAlpha(alpha); + + this.rho = rho; + this.beta = beta; + + this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); + } + + public static RLearner fromJson(final String json) { return JSON.parseObject(json, RLearner.class); } - public RLearner makeCopy(){ - RLearner clone = new RLearner(); + public String toJson() { + return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); + } + + public RLearner makeCopy() { + final RLearner clone = new RLearner(); clone.copy(this); return clone; } - public void copy(RLearner rhs){ - model = rhs.model.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy)rhs.actionSelectionStrategy).clone(); - rho = rhs.rho; - beta = rhs.beta; + public void copy(final RLearner rhs) { + this.model = rhs.model.makeCopy(); + this.actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); + this.rho = rhs.rho; + this.beta = rhs.beta; } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof RLearner){ - RLearner rhs = (RLearner)obj; - if(!model.equals(rhs.model)) return false; - if(!actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) return false; - if(rho != rhs.rho) return false; - return beta == rhs.beta; + public boolean equals(final Object obj) { + if (obj instanceof RLearner) { + final RLearner rhs = (RLearner) obj; + if (!this.model.equals(rhs.model)) { + return false; + } + if (!this.actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) { + return false; + } + if (this.rho != rhs.rho) { + return false; + } + return this.beta == rhs.beta; } return false; } - public RLearner(){ - - } - public double getRho() { - return rho; + return this.rho; } - public void setRho(double rho) { + public void setRho(final double rho) { this.rho = rho; } public double getBeta() { - return beta; + return this.beta; } - public void setBeta(double beta) { + public void setBeta(final double beta) { this.beta = beta; } - public QModel getModel(){ - return model; + public QModel getModel() { + return this.model; } - public void setModel(QModel model){ + public void setModel(final QModel model) { this.model = model; } - public String getActionSelection(){ - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy); } - public void setActionSelection(String conf){ + public void setActionSelection(final String conf) { this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); } - public RLearner(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1, 0.1, 0.7, 0.1); - } - - public RLearner(int state_count, int action_count, double alpha, double beta, double rho, double initial_Q) - { - model = new QModel(state_count, action_count, initial_Q); - model.setAlpha(alpha); - - this.rho = rho; - this.beta = beta; - - actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); - } - - private double maxQAtState(int stateId, Set actionsAtState){ - IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState); - double maxQ = iv.getValue(); - return maxQ; + private double maxQAtState(final int stateId, final Set actionsAtState) { + return this.model.actionWithMaxQAtState(stateId, actionsAtState).getValue(); } - public void update(int currentState, int actionTaken, int newState, Set actionsAtNextStateId, double immediate_reward) - { - double oldQ = model.getQ(currentState, actionTaken); - - double alpha = model.getAlpha(currentState, actionTaken); // learning rate; - - double maxQ = maxQAtState(newState, actionsAtNextStateId); - - double newQ = oldQ + alpha * (immediate_reward - rho + maxQ - oldQ); - - double maxQAtCurrentState = maxQAtState(currentState, null); - if (newQ == maxQAtCurrentState) - { - rho = rho + beta * (immediate_reward - rho + maxQ - maxQAtCurrentState); + public void update(final int currentState, final int actionTaken, final int newState, final Set actionsAtNextStateId, final double immediate_reward) { + final double oldQ = this.model.getQ(currentState, actionTaken); + final double alpha = this.model.getAlpha(currentState, actionTaken); // learning rate; + final double maxQ = this.maxQAtState(newState, actionsAtNextStateId); + final double newQ = oldQ + alpha * (immediate_reward - this.rho + maxQ - oldQ); + final double maxQAtCurrentState = this.maxQAtState(currentState, null); + if (newQ == maxQAtCurrentState) { + this.rho += this.beta * (immediate_reward - this.rho + maxQ - maxQAtCurrentState); } - - model.setQ(currentState, actionTaken, newQ); + this.model.setQ(currentState, actionTaken, newQ); } - public IndexValue selectAction(int stateId, Set actionsAtState){ - return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); + public IndexValue selectAction(final int stateId, final Set actionsAtState) { + return this.actionSelectionStrategy.selectAction(stateId, this.model, actionsAtState); } } diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java index e890b39..bfbf945 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Vec.java +++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java @@ -2,16 +2,16 @@ import java.io.Serializable; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.IntStream; /** * Created by xschen on 9/27/2015 0027. */ public class Vec implements Serializable { - private Map data = new HashMap(); + private final Map data = new HashMap(); private int dimension; private double defaultValue; private int id = -1; @@ -52,9 +52,7 @@ public void copy(final Vec rhs) { this.id = rhs.id; this.data.clear(); - for (final Map.Entry entry : rhs.data.entrySet()) { - this.data.put(entry.getKey(), entry.getValue()); - } + rhs.data.forEach(this.data::put); } public void set(final int i, final double value) { @@ -75,33 +73,19 @@ public double get(final int i) { @Override public boolean equals(final Object rhs) { - if (rhs != null && rhs instanceof Vec) { + if (rhs instanceof Vec) { final Vec rhs2 = (Vec) rhs; - if (this.dimension != rhs2.dimension) { + if (this.dimension != rhs2.dimension || this.data.size() != rhs2.data.size()) { return false; } - - if (this.data.size() != rhs2.data.size()) { - return false; - } - for (final Integer index : this.data.keySet()) { - if (!rhs2.data.containsKey(index)) { - return false; - } - if (!DoubleUtils.equals(this.data.get(index), rhs2.data.get(index))) { + if (!rhs2.data.containsKey(index) || !DoubleUtils.equals(this.data.get(index), rhs2.data.get(index))) { return false; } } - if (this.defaultValue != rhs2.defaultValue) { - for (int i = 0; i < this.dimension; ++i) { - if (this.data.containsKey(i)) { - return false; - } - } + return IntStream.range(0, this.dimension).noneMatch(this.data::containsKey); } - return true; } @@ -110,9 +94,7 @@ public boolean equals(final Object rhs) { public void setAll(final double value) { this.defaultValue = value; - for (final Integer index : this.data.keySet()) { - this.data.put(index, this.defaultValue); - } + this.data.keySet().forEach(index -> this.data.put(index, this.defaultValue)); } public IndexValue indexWithMaxValue(final Set indices) { @@ -133,7 +115,7 @@ public IndexValue indexWithMaxValue(final Set indices) { } } - public IndexValue indexWithMaxValue() { + private IndexValue indexWithMaxValue() { final IndexValue iv = new IndexValue(); iv.setIndex(-1); iv.setValue(Double.NEGATIVE_INFINITY); @@ -169,46 +151,7 @@ public IndexValue indexWithMaxValue() { } - public Vec projectOrthogonal(final Iterable vlist) { - Vec b = this; - for (final Vec v : vlist) { - b = b.minus(b.projectAlong(v)); - } - - return b; - } - - public Vec projectOrthogonal(final List vlist, final Map alpha) { - Vec b = this; - for (int i = 0; i < vlist.size(); ++i) { - final Vec v = vlist.get(i); - final double norm_a = v.multiply(v); - - if (DoubleUtils.isZero(norm_a)) { - return new Vec(this.dimension); - } - final double sigma = this.multiply(v) / norm_a; - final Vec v_parallel = v.multiply(sigma); - - alpha.put(i, sigma); - - b = b.minus(v_parallel); - } - - return b; - } - - public Vec projectAlong(final Vec rhs) { - final double norm_a = rhs.multiply(rhs); - - if (DoubleUtils.isZero(norm_a)) { - return new Vec(this.dimension); - } - final double sigma = this.multiply(rhs) / norm_a; - return rhs.multiply(sigma); - } - - public Vec multiply(final double rhs) { + private Vec multiply(final double rhs) { final Vec clone = this.makeCopy(); for (final Integer i : this.data.keySet()) { clone.data.put(i, rhs * this.data.get(i)); @@ -216,86 +159,24 @@ public Vec multiply(final double rhs) { return clone; } - public double multiply(final Vec rhs) { - double productSum = 0; - if (this.defaultValue == 0) { - for (final Map.Entry entry : this.data.entrySet()) { - productSum += entry.getValue() * rhs.get(entry.getKey()); - } - } else { - for (int i = 0; i < this.dimension; ++i) { - productSum += this.get(i) * rhs.get(i); - } - } - - return productSum; - } - - public Vec pow(final double scalar) { - final Vec result = new Vec(this.dimension); - for (final Map.Entry entry : this.data.entrySet()) { - result.data.put(entry.getKey(), Math.pow(entry.getValue(), scalar)); - } - return result; - } - - public Vec add(final Vec rhs) { - final Vec result = new Vec(this.dimension); - int index; - for (final Map.Entry entry : this.data.entrySet()) { - index = entry.getKey(); - result.data.put(index, entry.getValue() + rhs.data.get(index)); - } - for (final Map.Entry entry : rhs.data.entrySet()) { - index = entry.getKey(); - if (result.data.containsKey(index)) { - continue; - } - result.data.put(index, entry.getValue() + this.data.get(index)); - } - - return result; - } - - public Vec minus(final Vec rhs) { - final Vec result = new Vec(this.dimension); - int index; - for (final Map.Entry entry : this.data.entrySet()) { - index = entry.getKey(); - result.data.put(index, entry.getValue() - rhs.data.get(index)); - } - for (final Map.Entry entry : rhs.data.entrySet()) { - index = entry.getKey(); - if (result.data.containsKey(index)) { - continue; - } - result.data.put(index, this.data.get(index) - entry.getValue()); - } + double multiply(final Vec rhs) { - return result; + return this.defaultValue == 0 ? + this.data.entrySet().stream().mapToDouble(entry -> entry.getValue() * rhs.get(entry.getKey())).sum() : + IntStream.range(0, this.dimension).mapToDouble(i -> this.get(i) * rhs.get(i)).sum(); } - public double sum() { - double sum = 0; - - for (final Map.Entry entry : this.data.entrySet()) { - sum += entry.getValue(); - } - sum += this.defaultValue * (this.dimension - this.data.size()); - - return sum; + private double sum() { + return this.data.values().stream().mapToDouble(v -> v).sum() + this.defaultValue * (this.dimension - this.data.size()); } - public boolean isZero() { + boolean isZero() { return DoubleUtils.isZero(this.sum()); } - public double norm(final int level) { + double norm(final int level) { if (level == 1) { - double sum = 0; - for (final Double val : this.data.values()) { - sum += Math.abs(val); - } + double sum = this.data.values().stream().mapToDouble(Math::abs).sum(); if (!DoubleUtils.isZero(this.defaultValue)) { sum += Math.abs(this.defaultValue) * (this.dimension - this.data.size()); } @@ -307,10 +188,7 @@ public double norm(final int level) { } return Math.sqrt(sum); } else { - double sum = 0; - for (final Double val : this.data.values()) { - sum += Math.pow(Math.abs(val), level); - } + double sum = this.data.values().stream().mapToDouble(val -> Math.pow(Math.abs(val), level)).sum(); if (!DoubleUtils.isZero(this.defaultValue)) { sum += Math.pow(Math.abs(this.defaultValue), level) * (this.dimension - this.data.size()); } @@ -318,7 +196,7 @@ public double norm(final int level) { } } - public Vec normalize() { + Vec normalize() { final double norm = this.norm(2); // L2 norm is the cartesian distance if (DoubleUtils.isZero(norm)) { return new Vec(this.dimension); @@ -326,41 +204,19 @@ public Vec normalize() { final Vec clone = new Vec(this.dimension); clone.setAll(this.defaultValue / norm); - for (final Integer k : this.data.keySet()) { - clone.data.put(k, this.data.get(k) / norm); - } + this.data.keySet().forEach(k -> clone.data.put(k, this.data.get(k) / norm)); return clone; } - public Map getData() { + Map getData() { return this.data; } - public void setData(final Map data) { - this.data = data; - } - - public int getDimension() { + int getDimension() { return this.dimension; } - public void setDimension(final int dimension) { - this.dimension = dimension; - } - - public double getDefaultValue() { - return this.defaultValue; - } - - public void setDefaultValue(final double defaultValue) { - this.defaultValue = defaultValue; - } - - public int getId() { - return this.id; - } - - public void setId(final int id) { + void setId(final int id) { this.id = id; } } From af2cd298c32f513da4b8b9f0b8bba75bafb7ebe0 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 12:04:27 +0100 Subject: [PATCH 03/12] cleanup 2 --- .../github/chen0040/rl/utils/DoubleUtils.java | 13 +- .../github/chen0040/rl/utils/IndexValue.java | 4 +- .../com/github/chen0040/rl/utils/Matrix.java | 167 +----------------- .../com/github/chen0040/rl/utils/Vec.java | 14 +- 4 files changed, 24 insertions(+), 174 deletions(-) diff --git a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java index e840bc1..58ee629 100644 --- a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java +++ b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java @@ -3,12 +3,13 @@ /** * Created by xschen on 10/11/2015 0011. */ -public class DoubleUtils { - public static boolean equals(double a1, double a2){ - return Math.abs(a1-a2) < 1e-10; - } +public enum DoubleUtils { + ; + + public static final double TOLERANCE = 1e-10; - public static boolean isZero(double a){ - return a < 1e-20; + public static boolean equals(final double a1, final double a2) { + return Math.abs(a1 - a2) < DoubleUtils.TOLERANCE; } + } diff --git a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java index 41c400b..6c3d6ae 100644 --- a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java +++ b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java @@ -26,14 +26,14 @@ public IndexValue makeCopy() { @Override public boolean equals(final Object rhs) { - if (rhs != null && rhs instanceof IndexValue) { + if (rhs instanceof IndexValue) { final IndexValue rhs2 = (IndexValue) rhs; return this.index == rhs2.index && this.value == rhs2.value; } return false; } - public boolean isValid() { + boolean isValid() { return this.index != -1; } diff --git a/src/main/java/com/github/chen0040/rl/utils/Matrix.java b/src/main/java/com/github/chen0040/rl/utils/Matrix.java index 5cb6c17..b20b86f 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Matrix.java +++ b/src/main/java/com/github/chen0040/rl/utils/Matrix.java @@ -1,11 +1,7 @@ package com.github.chen0040.rl.utils; -import com.alibaba.fastjson.annotation.JSONField; - import java.io.Serializable; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; @@ -13,71 +9,28 @@ * Created by xschen on 9/27/2015 0027. */ public class Matrix implements Serializable { - private Map rows = new HashMap<>(); + private final Map rows = new HashMap<>(); private int rowCount; private int columnCount; private double defaultValue; - public Matrix() { - - } - - public Matrix(final double[][] A) { - for (int i = 0; i < A.length; ++i) { - final double[] B = A[i]; - for (int j = 0; j < B.length; ++j) { - this.set(i, j, B[j]); - } - } - } - public Matrix(final int rowCount, final int columnCount) { this.rowCount = rowCount; this.columnCount = columnCount; this.defaultValue = 0; } - public static Matrix identity(final int dimension) { - final Matrix m = new Matrix(dimension, dimension); - for (int i = 0; i < m.getRowCount(); ++i) { - m.set(i, i, 1); - } - return m; - } - - public void setRow(final int rowIndex, final Vec rowVector) { - rowVector.setId(rowIndex); - this.rows.put(rowIndex, rowVector); - } - @Override public boolean equals(final Object rhs) { - if (rhs != null && rhs instanceof Matrix) { + if (rhs instanceof Matrix) { final Matrix rhs2 = (Matrix) rhs; if (this.rowCount != rhs2.rowCount || this.columnCount != rhs2.columnCount) { return false; } if (this.defaultValue == rhs2.defaultValue) { - for (final Integer index : this.rows.keySet()) { - if (!rhs2.rows.containsKey(index)) { - return false; - } - if (!this.rows.get(index).equals(rhs2.rows.get(index))) { - System.out.println("failed!"); - return false; - } - } - - for (final Integer index : rhs2.rows.keySet()) { - if (!this.rows.containsKey(index)) { - return false; - } - if (!rhs2.rows.get(index).equals(this.rows.get(index))) { - System.out.println("failed! 22"); - return false; - } - } + return this.rows.keySet().stream().noneMatch(index -> !rhs2.rows.containsKey(index) || !this.rows.get(index).equals(rhs2.rows.get(index))) && + rhs2.rows.keySet().stream().noneMatch(index -> !this.rows.containsKey(index) || !rhs2.rows.get(index).equals(this.rows.get(index))); } else { for (int i = 0; i < this.rowCount; ++i) { @@ -101,16 +54,14 @@ public Matrix makeCopy() { return clone; } - public void copy(final Matrix rhs) { + private void copy(final Matrix rhs) { this.rowCount = rhs.rowCount; this.columnCount = rhs.columnCount; this.defaultValue = rhs.defaultValue; this.rows.clear(); - for (final Map.Entry entry : rhs.rows.entrySet()) { - this.rows.put(entry.getKey(), entry.getValue().makeCopy()); - } + rhs.rows.forEach((key, value) -> this.rows.put(key, value.makeCopy())); } public void set(final int rowIndex, final int columnIndex, final double value) { @@ -147,119 +98,13 @@ public double get(final int rowIndex, final int columnIndex) { return row.get(columnIndex); } - public List columnVectors() { - final Matrix A = this; - final int n = A.getColumnCount(); - final int rowCount = A.getRowCount(); - - final List Acols = new ArrayList(); - - for (int c = 0; c < n; ++c) { - final Vec Acol = new Vec(rowCount); - Acol.setAll(this.defaultValue); - Acol.setId(c); - - for (int r = 0; r < rowCount; ++r) { - Acol.set(r, A.get(r, c)); - } - Acols.add(Acol); - } - return Acols; - } - - public Matrix multiply(final Matrix rhs) { - if (this.getColumnCount() != rhs.getRowCount()) { - System.err.println("A.columnCount must be equal to B.rowCount in multiplication"); - return null; - } - - Vec row1; - Vec col2; - - final Matrix result = new Matrix(this.getRowCount(), rhs.getColumnCount()); - result.setAll(this.defaultValue); - - final List rhsColumns = rhs.columnVectors(); - - for (final Map.Entry entry : this.rows.entrySet()) { - final int r1 = entry.getKey(); - row1 = entry.getValue(); - for (int c2 = 0; c2 < rhsColumns.size(); ++c2) { - col2 = rhsColumns.get(c2); - result.set(r1, c2, row1.multiply(col2)); - } - } - - return result; - } - - @JSONField(serialize = false) - public boolean isSymmetric() { - if (this.getRowCount() != this.getColumnCount()) { - return false; - } - - for (final Map.Entry rowEntry : this.rows.entrySet()) { - final int row = rowEntry.getKey(); - final Vec rowVec = rowEntry.getValue(); - - for (final Integer col : rowVec.getData().keySet()) { - if (row == col.intValue()) { - continue; - } - if (DoubleUtils.equals(rowVec.get(col), this.get(col, row))) { - return false; - } - } - } - - return true; - } - - public Vec multiply(final Vec rhs) { - if (this.getColumnCount() != rhs.getDimension()) { - System.err.println("columnCount must be equal to the size of the vector for multiplication"); - } - - Vec row1; - final Vec result = new Vec(this.getRowCount()); - for (final Map.Entry entry : this.rows.entrySet()) { - row1 = entry.getValue(); - result.set(entry.getKey(), row1.multiply(rhs)); - } - return result; - } - - - public Map getRows() { - return this.rows; - } - - public void setRows(final Map rows) { - this.rows = rows; - } public int getRowCount() { return this.rowCount; } - public void setRowCount(final int rowCount) { - this.rowCount = rowCount; - } - public int getColumnCount() { return this.columnCount; } - public void setColumnCount(final int columnCount) { - this.columnCount = columnCount; - } - - public double getDefaultValue() { - return this.defaultValue; - } - - public void setDefaultValue(final double defaultValue) { - this.defaultValue = defaultValue; - } } diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java index bfbf945..ac2bf86 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Vec.java +++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java @@ -40,6 +40,10 @@ public Vec(final int dimension, final Map data) { } } + static boolean isZero(final double a) { + return a < 1e-20; + } + public Vec makeCopy() { final Vec clone = new Vec(this.dimension); clone.copy(this); @@ -171,25 +175,25 @@ private double sum() { } boolean isZero() { - return DoubleUtils.isZero(this.sum()); + return Vec.isZero(this.sum()); } double norm(final int level) { if (level == 1) { double sum = this.data.values().stream().mapToDouble(Math::abs).sum(); - if (!DoubleUtils.isZero(this.defaultValue)) { + if (!Vec.isZero(this.defaultValue)) { sum += Math.abs(this.defaultValue) * (this.dimension - this.data.size()); } return sum; } else if (level == 2) { double sum = this.multiply(this); - if (!DoubleUtils.isZero(this.defaultValue)) { + if (!Vec.isZero(this.defaultValue)) { sum += (this.dimension - this.data.size()) * (this.defaultValue * this.defaultValue); } return Math.sqrt(sum); } else { double sum = this.data.values().stream().mapToDouble(val -> Math.pow(Math.abs(val), level)).sum(); - if (!DoubleUtils.isZero(this.defaultValue)) { + if (!Vec.isZero(this.defaultValue)) { sum += Math.pow(Math.abs(this.defaultValue), level) * (this.dimension - this.data.size()); } return Math.pow(sum, 1.0 / level); @@ -198,7 +202,7 @@ boolean isZero() { Vec normalize() { final double norm = this.norm(2); // L2 norm is the cartesian distance - if (DoubleUtils.isZero(norm)) { + if (Vec.isZero(norm)) { return new Vec(this.dimension); } final Vec clone = new Vec(this.dimension); From 749880ef64a00f4517b85a8204cde01e75f84f54 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 12:06:05 +0100 Subject: [PATCH 04/12] cleanup 3 --- .../github/chen0040/rl/utils/MatrixUtils.java | 29 ----------- .../github/chen0040/rl/utils/TupleTwo.java | 50 +++++++------------ 2 files changed, 19 insertions(+), 60 deletions(-) delete mode 100644 src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java diff --git a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java b/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java deleted file mode 100644 index e43c28b..0000000 --- a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java +++ /dev/null @@ -1,29 +0,0 @@ -package com.github.chen0040.rl.utils; - -import java.util.List; - - -/** - * Created by xschen on 10/11/2015 0011. - */ -public class MatrixUtils { - /** - * Convert a list of column vectors into a matrix - */ - public static Matrix matrixFromColumnVectors(List R) - { - int n = R.size(); - int m = R.get(0).getDimension(); - - Matrix T = new Matrix(m, n); - for (int c = 0; c < n; ++c) - { - Vec Rcol = R.get(c); - for (int r : Rcol.getData().keySet()) - { - T.set(r, c, Rcol.get(r)); - } - } - return T; - } -} diff --git a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java b/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java index b4895ea..5dba744 100644 --- a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java +++ b/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java @@ -1,56 +1,44 @@ package com.github.chen0040.rl.utils; +import java.util.Objects; + /** * Created by xschen on 10/11/2015 0011. */ public class TupleTwo { - private T1 item1; - private T2 item2; - - public TupleTwo(T1 item1, T2 item2){ - this.item1 = item1; - this.item2 = item2; - } - - public T1 getItem1() { - return item1; - } + private final T1 item1; + private final T2 item2; - public void setItem1(T1 item1) { + private TupleTwo(final T1 item1, final T2 item2) { this.item1 = item1; - } - - public T2 getItem2() { - return item2; - } - - public void setItem2(T2 item2) { this.item2 = item2; } - public static TupleTwo create(U1 item1, U2 item2){ - return new TupleTwo(item1, item2); + static TupleTwo create(final U1 item1, final U2 item2) { + return new TupleTwo<>(item1, item2); } - @Override public boolean equals(Object o) { - if (this == o) + @Override + public boolean equals(final Object o) { + if (this == o) { return true; - if (o == null || getClass() != o.getClass()) + } + if (o == null || this.getClass() != o.getClass()) { return false; + } - TupleTwo tupleTwo = (TupleTwo) o; + final TupleTwo tupleTwo = (TupleTwo) o; - if (item1 != null ? !item1.equals(tupleTwo.item1) : tupleTwo.item1 != null) - return false; - return item2 != null ? item2.equals(tupleTwo.item2) : tupleTwo.item2 == null; + return Objects.equals(this.item1, tupleTwo.item1) && Objects.equals(this.item2, tupleTwo.item2); } - @Override public int hashCode() { - int result = item1 != null ? item1.hashCode() : 0; - result = 31 * result + (item2 != null ? item2.hashCode() : 0); + @Override + public int hashCode() { + int result = this.item1 != null ? this.item1.hashCode() : 0; + result = 31 * result + (this.item2 != null ? this.item2.hashCode() : 0); return result; } } From 98b5cfa53217406a18dd68f4ac50e36b25a0db94 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 12:07:54 +0100 Subject: [PATCH 05/12] cleanup 4 --- .../com/github/chen0040/rl/utils/Vec.java | 58 ++++--------------- .../github/chen0040/rl/utils/VectorUtils.java | 23 +++----- 2 files changed, 21 insertions(+), 60 deletions(-) diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java index ac2bf86..1e10555 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Vec.java +++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java @@ -11,7 +11,7 @@ * Created by xschen on 9/27/2015 0027. */ public class Vec implements Serializable { - private final Map data = new HashMap(); + private final Map data = new HashMap<>(); private int dimension; private double defaultValue; private int id = -1; @@ -21,9 +21,7 @@ public Vec() { } public Vec(final double[] v) { - for (int i = 0; i < v.length; ++i) { - this.set(i, v[i]); - } + IntStream.range(0, v.length).forEach(i -> this.set(i, v[i])); } public Vec(final int dimension) { @@ -35,13 +33,11 @@ public Vec(final int dimension, final Map data) { this.dimension = dimension; this.defaultValue = 0; - for (final Map.Entry entry : data.entrySet()) { - this.set(entry.getKey(), entry.getValue()); - } + data.forEach(this::set); } - static boolean isZero(final double a) { - return a < 1e-20; + private static boolean isZero(final double a) { + return a < DoubleUtils.TOLERANCE; } public Vec makeCopy() { @@ -155,15 +151,7 @@ private IndexValue indexWithMaxValue() { } - private Vec multiply(final double rhs) { - final Vec clone = this.makeCopy(); - for (final Integer i : this.data.keySet()) { - clone.data.put(i, rhs * this.data.get(i)); - } - return clone; - } - - double multiply(final Vec rhs) { + private double multiply(final Vec rhs) { return this.defaultValue == 0 ? this.data.entrySet().stream().mapToDouble(entry -> entry.getValue() * rhs.get(entry.getKey())).sum() : @@ -178,30 +166,16 @@ boolean isZero() { return Vec.isZero(this.sum()); } - double norm(final int level) { - if (level == 1) { - double sum = this.data.values().stream().mapToDouble(Math::abs).sum(); - if (!Vec.isZero(this.defaultValue)) { - sum += Math.abs(this.defaultValue) * (this.dimension - this.data.size()); - } - return sum; - } else if (level == 2) { - double sum = this.multiply(this); - if (!Vec.isZero(this.defaultValue)) { - sum += (this.dimension - this.data.size()) * (this.defaultValue * this.defaultValue); - } - return Math.sqrt(sum); - } else { - double sum = this.data.values().stream().mapToDouble(val -> Math.pow(Math.abs(val), level)).sum(); - if (!Vec.isZero(this.defaultValue)) { - sum += Math.pow(Math.abs(this.defaultValue), level) * (this.dimension - this.data.size()); - } - return Math.pow(sum, 1.0 / level); + double norm() { + double sum = this.multiply(this); + if (!Vec.isZero(this.defaultValue)) { + sum += (this.dimension - this.data.size()) * (this.defaultValue * this.defaultValue); } + return Math.sqrt(sum); } Vec normalize() { - final double norm = this.norm(2); // L2 norm is the cartesian distance + final double norm = this.norm(); // L2 norm is the cartesian distance if (Vec.isZero(norm)) { return new Vec(this.dimension); } @@ -212,14 +186,6 @@ Vec normalize() { return clone; } - Map getData() { - return this.data; - } - - int getDimension() { - return this.dimension; - } - void setId(final int id) { this.id = id; } diff --git a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java b/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java index 2bbfbaa..bf4f4e4 100644 --- a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java +++ b/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java @@ -8,13 +8,10 @@ * Created by xschen on 10/11/2015 0011. */ public class VectorUtils { - public static List removeZeroVectors(Iterable vlist) - { - List vstarlist = new ArrayList(); - for (Vec v : vlist) - { - if (!v.isZero()) - { + public static List removeZeroVectors(final Iterable vlist) { + final List vstarlist = new ArrayList(); + for (final Vec v : vlist) { + if (!v.isZero()) { vstarlist.add(v); } } @@ -22,13 +19,11 @@ public static List removeZeroVectors(Iterable vlist) return vstarlist; } - public static TupleTwo, List> normalize(Iterable vlist) - { - List norms = new ArrayList(); - List vstarlist = new ArrayList(); - for (Vec v : vlist) - { - norms.add(v.norm(2)); + public static TupleTwo, List> normalize(final Iterable vlist) { + final List norms = new ArrayList(); + final List vstarlist = new ArrayList(); + for (final Vec v : vlist) { + norms.add(v.norm()); vstarlist.add(v.normalize()); } From e9c6ef7c77e6e23d10baab962a5ffc86c2b8e1d5 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 12:08:41 +0100 Subject: [PATCH 06/12] cleanup 5 --- .../com/github/chen0040/rl/utils/Vec.java | 39 ------------------- .../github/chen0040/rl/utils/VectorUtils.java | 34 ---------------- 2 files changed, 73 deletions(-) delete mode 100644 src/main/java/com/github/chen0040/rl/utils/VectorUtils.java diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java index 1e10555..ca007cf 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Vec.java +++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java @@ -36,10 +36,6 @@ public Vec(final int dimension, final Map data) { data.forEach(this::set); } - private static boolean isZero(final double a) { - return a < DoubleUtils.TOLERANCE; - } - public Vec makeCopy() { final Vec clone = new Vec(this.dimension); clone.copy(this); @@ -151,41 +147,6 @@ private IndexValue indexWithMaxValue() { } - private double multiply(final Vec rhs) { - - return this.defaultValue == 0 ? - this.data.entrySet().stream().mapToDouble(entry -> entry.getValue() * rhs.get(entry.getKey())).sum() : - IntStream.range(0, this.dimension).mapToDouble(i -> this.get(i) * rhs.get(i)).sum(); - } - - private double sum() { - return this.data.values().stream().mapToDouble(v -> v).sum() + this.defaultValue * (this.dimension - this.data.size()); - } - - boolean isZero() { - return Vec.isZero(this.sum()); - } - - double norm() { - double sum = this.multiply(this); - if (!Vec.isZero(this.defaultValue)) { - sum += (this.dimension - this.data.size()) * (this.defaultValue * this.defaultValue); - } - return Math.sqrt(sum); - } - - Vec normalize() { - final double norm = this.norm(); // L2 norm is the cartesian distance - if (Vec.isZero(norm)) { - return new Vec(this.dimension); - } - final Vec clone = new Vec(this.dimension); - clone.setAll(this.defaultValue / norm); - - this.data.keySet().forEach(k -> clone.data.put(k, this.data.get(k) / norm)); - return clone; - } - void setId(final int id) { this.id = id; } diff --git a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java b/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java deleted file mode 100644 index bf4f4e4..0000000 --- a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java +++ /dev/null @@ -1,34 +0,0 @@ -package com.github.chen0040.rl.utils; - -import java.util.ArrayList; -import java.util.List; - - -/** - * Created by xschen on 10/11/2015 0011. - */ -public class VectorUtils { - public static List removeZeroVectors(final Iterable vlist) { - final List vstarlist = new ArrayList(); - for (final Vec v : vlist) { - if (!v.isZero()) { - vstarlist.add(v); - } - } - - return vstarlist; - } - - public static TupleTwo, List> normalize(final Iterable vlist) { - final List norms = new ArrayList(); - final List vstarlist = new ArrayList(); - for (final Vec v : vlist) { - norms.add(v.norm()); - vstarlist.add(v.normalize()); - } - - return TupleTwo.create(vstarlist, norms); - } - - -} From 24201ff92911ba982dd6fd55ed951a88bb2ed276 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 12:18:11 +0100 Subject: [PATCH 07/12] cleanup 6 --- pom.xml | 5 + .../rl/learning/sarsa/SarsaAgent.java | 140 +++++++++--------- .../chen0040/rl/models/DefaultValues.java | 8 + .../rl/models/EligibilityTraceUpdateMode.java | 3 +- .../com/github/chen0040/rl/models/QModel.java | 70 ++++----- .../chen0040/rl/models/UtilityModel.java | 36 ++--- .../github/chen0040/rl/utils/TupleTwo.java | 44 ------ 7 files changed, 126 insertions(+), 180 deletions(-) create mode 100644 src/main/java/com/github/chen0040/rl/models/DefaultValues.java delete mode 100644 src/main/java/com/github/chen0040/rl/utils/TupleTwo.java diff --git a/pom.xml b/pom.xml index 6826743..443ad61 100644 --- a/pom.xml +++ b/pom.xml @@ -503,6 +503,11 @@ fastjson 1.2.41 + + org.jetbrains + annotations + 13.0 + diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java index c4c8f27..80335fd 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java @@ -3,15 +3,14 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; /** - * Created by xschen on 9/27/2015 0027. - * Implement temporal-difference learning Sarsa, which is an on-policy TD control algorithm + * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Sarsa, which is an on-policy TD control + * algorithm */ -public class SarsaAgent implements Serializable{ +public class SarsaAgent implements Serializable { private SarsaLearner learner; private int currentState; private int currentAction; @@ -19,111 +18,118 @@ public class SarsaAgent implements Serializable{ private int prevState; private int prevAction; - public int getCurrentState(){ - return currentState; + public SarsaAgent(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + this.learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ); } - public int getCurrentAction(){ - return currentAction; + public SarsaAgent(final int stateCount, final int actionCount) { + this.learner = new SarsaLearner(stateCount, actionCount); } - public int getPrevState() { return prevState; } + public SarsaAgent(final SarsaLearner learner) { + this.learner = learner; + } - public int getPrevAction() { return prevAction; } + public SarsaAgent() { - public void start(int currentState){ - this.currentState = currentState; - this.prevState = -1; - this.prevAction = -1; } - public IndexValue selectAction(){ - return selectAction(null); + @SuppressWarnings("Used-by-user") + public int getCurrentState() { + return this.currentState; } - public IndexValue selectAction(Set actionsAtState){ - if(currentAction == -1){ - IndexValue iv = learner.selectAction(currentState, actionsAtState); - currentAction = iv.getIndex(); - currentValue = iv.getValue(); - } + @SuppressWarnings("Used-by-user") + public int getCurrentAction() { + return this.currentAction; + } - return new IndexValue(currentAction, currentValue); + @SuppressWarnings("Used-by-user") + public int getPrevState() { + return this.prevState; } - public void update(int actionTaken, int newState, double immediateReward){ - update(actionTaken, newState, null, immediateReward); + @SuppressWarnings("Used-by-user") + public int getPrevAction() { + return this.prevAction; } - public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){ + public void start(final int currentState) { + this.currentState = currentState; + this.prevState = -1; + this.prevAction = -1; + } - IndexValue iv = learner.selectAction(currentState, actionsAtNewState); - int futureAction = iv.getIndex(); + public IndexValue selectAction() { + return this.selectAction(null); + } - learner.update(currentState, actionTaken, newState, futureAction, immediateReward); + public IndexValue selectAction(final Set actionsAtState) { + if (this.currentAction == -1) { + final IndexValue iv = this.learner.selectAction(this.currentState, actionsAtState); + this.currentAction = iv.getIndex(); + this.currentValue = iv.getValue(); + } - prevState = this.currentState; - this.prevAction = actionTaken; + return new IndexValue(this.currentAction, this.currentValue); + } - currentAction = futureAction; - currentState = newState; + public void update(final int actionTaken, final int newState, final double immediateReward) { + this.update(actionTaken, newState, null, immediateReward); } + public void update(final int actionTaken, final int newState, final Set actionsAtNewState, final double immediateReward) { + final IndexValue iv = this.learner.selectAction(this.currentState, actionsAtNewState); + final int futureAction = iv.getIndex(); - public SarsaLearner getLearner(){ - return learner; - } + this.learner.update(this.currentState, actionTaken, newState, futureAction, immediateReward); - public void setLearner(SarsaLearner learner){ - this.learner = learner; - } + this.prevState = this.currentState; + this.prevAction = actionTaken; - public SarsaAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ); + this.currentAction = futureAction; + this.currentState = newState; } - public SarsaAgent(int stateCount, int actionCount){ - learner = new SarsaLearner(stateCount, actionCount); + public SarsaLearner getLearner() { + return this.learner; } - public SarsaAgent(SarsaLearner learner){ + public void setLearner(final SarsaLearner learner) { this.learner = learner; } - public SarsaAgent(){ - - } - - public void enableEligibilityTrace(double lambda){ - SarsaLambdaLearner acll = new SarsaLambdaLearner(learner); + @SuppressWarnings("Used-by-user") + public void enableEligibilityTrace(final double lambda) { + final SarsaLambdaLearner acll = new SarsaLambdaLearner(this.learner); acll.setLambda(lambda); - learner = acll; + this.learner = acll; } - public SarsaAgent makeCopy(){ - SarsaAgent clone = new SarsaAgent(); + public SarsaAgent makeCopy() { + final SarsaAgent clone = new SarsaAgent(); clone.copy(this); return clone; } - public void copy(SarsaAgent rhs){ - learner.copy(rhs.learner); - currentAction = rhs.currentAction; - currentState = rhs.currentState; - prevAction = rhs.prevAction; - prevState = rhs.prevState; + public void copy(final SarsaAgent rhs) { + this.learner.copy(rhs.learner); + this.currentAction = rhs.currentAction; + this.currentState = rhs.currentState; + this.prevAction = rhs.prevAction; + this.prevState = rhs.prevState; } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof SarsaAgent){ - SarsaAgent rhs = (SarsaAgent)obj; - return prevAction == rhs.prevAction - && prevState == rhs.prevState - && currentAction == rhs.currentAction - && currentState == rhs.currentState - && learner.equals(rhs.learner); + public boolean equals(final Object obj) { + if (obj instanceof SarsaAgent) { + final SarsaAgent rhs = (SarsaAgent) obj; + return this.prevAction == rhs.prevAction + && this.prevState == rhs.prevState + && this.currentAction == rhs.currentAction + && this.currentState == rhs.currentState + && this.learner.equals(rhs.learner); } return false; } diff --git a/src/main/java/com/github/chen0040/rl/models/DefaultValues.java b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java new file mode 100644 index 0000000..3cf7906 --- /dev/null +++ b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java @@ -0,0 +1,8 @@ +package com.github.chen0040.rl.models; + +public enum DefaultValues { + ; + public static final double GAMMA = 0.9; + public static final double ALPHA = 0.1; + public static final double INITIAL_Q = 0.1; +} diff --git a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java index e25380f..bd891f0 100644 --- a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java +++ b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java @@ -4,6 +4,5 @@ * Created by xschen on 9/28/2015 0028. */ public enum EligibilityTraceUpdateMode { - ReplaceTrace, - AccumulateTrace + ReplaceTrace } diff --git a/src/main/java/com/github/chen0040/rl/models/QModel.java b/src/main/java/com/github/chen0040/rl/models/QModel.java index ae36347..544b186 100644 --- a/src/main/java/com/github/chen0040/rl/models/QModel.java +++ b/src/main/java/com/github/chen0040/rl/models/QModel.java @@ -4,6 +4,7 @@ import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.utils.Matrix; import com.github.chen0040.rl.utils.Vec; +import org.jetbrains.annotations.Nullable; import java.util.*; @@ -17,16 +18,18 @@ public class QModel { * Q value for (state_id, action_id) pair Q is known as the quality of state-action combination, note that it is * different from utility of a state */ + @org.jetbrains.annotations.Nullable private Matrix Q; /** * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id) */ + @org.jetbrains.annotations.Nullable private Matrix alphaMatrix; /** * discount factor */ - private double gamma = 0.7; + private double gamma = DefaultValues.GAMMA; private int stateCount; private int actionCount; @@ -37,11 +40,11 @@ public QModel(final int stateCount, final int actionCount, final double initialQ this.Q = new Matrix(stateCount, actionCount); this.alphaMatrix = new Matrix(stateCount, actionCount); this.Q.setAll(initialQ); - this.alphaMatrix.setAll(0.1); + this.alphaMatrix.setAll(DefaultValues.ALPHA); } public QModel(final int stateCount, final int actionCount) { - this(stateCount, actionCount, 0.1); + this(stateCount, actionCount, DefaultValues.INITIAL_Q); } public QModel() { @@ -50,28 +53,16 @@ public QModel() { @Override public boolean equals(final Object rhs) { - if (rhs != null && rhs instanceof QModel) { + if (rhs instanceof QModel) { final QModel rhs2 = (QModel) rhs; - - - if (this.gamma != rhs2.gamma) { - return false; - } - - - if (this.stateCount != rhs2.stateCount || this.actionCount != rhs2.actionCount) { - return false; - } - - if ((this.Q != null && rhs2.Q == null) || (this.Q == null && rhs2.Q != null)) { - return false; - } - if ((this.alphaMatrix != null && rhs2.alphaMatrix == null) || (this.alphaMatrix == null && rhs2.alphaMatrix != null)) { - return false; - } - - return !((this.Q != null && !this.Q.equals(rhs2.Q)) || (this.alphaMatrix != null && !this.alphaMatrix.equals(rhs2.alphaMatrix))); - + return this.gamma == rhs2.gamma && + this.stateCount == rhs2.stateCount && + this.actionCount == rhs2.actionCount && + (this.Q == null || rhs2.Q != null) && + (this.Q != null || rhs2.Q == null) && + (this.alphaMatrix == null || rhs2.alphaMatrix != null) && + (this.alphaMatrix != null || rhs2.alphaMatrix == null) && + !((this.Q != null && !this.Q.equals(rhs2.Q)) || (this.alphaMatrix != null && !this.alphaMatrix.equals(rhs2.alphaMatrix))); } return false; } @@ -92,50 +83,55 @@ public void copy(final QModel rhs) { public double getQ(final int stateId, final int actionId) { + assert this.Q != null; return this.Q.get(stateId, actionId); } public void setQ(final int stateId, final int actionId, final double Qij) { + assert this.Q != null; this.Q.set(stateId, actionId, Qij); } public double getAlpha(final int stateId, final int actionId) { + assert this.alphaMatrix != null; return this.alphaMatrix.get(stateId, actionId); } public void setAlpha(final double defaultAlpha) { + assert this.alphaMatrix != null; this.alphaMatrix.setAll(defaultAlpha); } public IndexValue actionWithMaxQAtState(final int stateId, final Set actionsAtState) { + assert this.Q != null; final Vec rowVector = this.Q.rowAt(stateId); return rowVector.indexWithMaxValue(actionsAtState); } private void reset(final double initialQ) { + assert this.Q != null; this.Q.setAll(initialQ); } - public IndexValue actionWithSoftMaxQAtState(final int stateId, Set actionsAtState, final Random random) { + public IndexValue actionWithSoftMaxQAtState(final int stateId, final Set actionsAtState, final Random random) { + Set atState = actionsAtState; + assert this.Q != null; final Vec rowVector = this.Q.rowAt(stateId); double sum = 0; - if (actionsAtState == null) { - actionsAtState = new HashSet<>(); + if (atState == null) { + atState = new HashSet<>(); for (int i = 0; i < this.actionCount; ++i) { - actionsAtState.add(i); + atState.add(i); } } - final List actions = new ArrayList<>(); - for (final Integer actionId : actionsAtState) { - actions.add(actionId); - } + final List actions = new ArrayList<>(atState); final double[] acc = new double[actions.size()]; for (int i = 0; i < actions.size(); ++i) { @@ -159,22 +155,16 @@ public IndexValue actionWithSoftMaxQAtState(final int stateId, Set acti return result; } + @Nullable public Matrix getQ() { return this.Q; } - public void setQ(final Matrix q) { - this.Q = q; - } - + @Nullable public Matrix getAlphaMatrix() { return this.alphaMatrix; } - public void setAlphaMatrix(final Matrix alphaMatrix) { - this.alphaMatrix = alphaMatrix; - } - public double getGamma() { return this.gamma; } diff --git a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java index 83f7f80..0397799 100644 --- a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java +++ b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java @@ -1,6 +1,7 @@ package com.github.chen0040.rl.models; import com.github.chen0040.rl.utils.Vec; +import org.jetbrains.annotations.Nullable; import java.io.Serializable; @@ -16,10 +17,12 @@ * is applied at state $s$ */ public class UtilityModel implements Serializable { + @Nullable private Vec U; private int stateCount; private int actionCount; + @SuppressWarnings("Used-by-user") public UtilityModel(final int stateCount, final int actionCount, final double initialU) { this.stateCount = stateCount; this.actionCount = actionCount; @@ -27,24 +30,7 @@ public UtilityModel(final int stateCount, final int actionCount, final double in this.U.setAll(initialU); } - public UtilityModel(final int stateCount, final int actionCount) { - this(stateCount, actionCount, 0.1); - } - - public UtilityModel() { - - } - - public Vec getU() { - return this.U; - } - - public void setU(final Vec U) { - this.U = U; - } - - public double getU(final int stateId) { - return this.U.get(stateId); + private UtilityModel() { } public int getStateCount() { @@ -77,22 +63,18 @@ public UtilityModel makeCopy() { @Override public boolean equals(final Object rhs) { - if (rhs != null && rhs instanceof UtilityModel) { + if (rhs instanceof UtilityModel) { final UtilityModel rhs2 = (UtilityModel) rhs; - if (this.actionCount != rhs2.actionCount || this.stateCount != rhs2.stateCount) { - return false; - } - - if ((this.U == null && rhs2.U != null) && (this.U != null && rhs2.U == null)) { - return false; - } - return !(this.U != null && !this.U.equals(rhs2.U)); + return this.actionCount == rhs2.actionCount && + this.stateCount == rhs2.stateCount && + !(this.U != null && !this.U.equals(rhs2.U)); } return false; } public void reset(final double initialU) { + assert this.U != null; this.U.setAll(initialU); } } diff --git a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java b/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java deleted file mode 100644 index 5dba744..0000000 --- a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java +++ /dev/null @@ -1,44 +0,0 @@ -package com.github.chen0040.rl.utils; - -import java.util.Objects; - -/** - * Created by xschen on 10/11/2015 0011. - */ -public class TupleTwo { - private final T1 item1; - private final T2 item2; - - private TupleTwo(final T1 item1, final T2 item2) { - this.item1 = item1; - this.item2 = item2; - } - - static TupleTwo create(final U1 item1, final U2 item2) { - return new TupleTwo<>(item1, item2); - } - - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || this.getClass() != o.getClass()) { - return false; - } - - final TupleTwo tupleTwo = (TupleTwo) o; - - return Objects.equals(this.item1, tupleTwo.item1) && Objects.equals(this.item2, tupleTwo.item2); - - } - - - @Override - public int hashCode() { - int result = this.item1 != null ? this.item1.hashCode() : 0; - result = 31 * result + (this.item2 != null ? this.item2.hashCode() : 0); - return result; - } -} From 5ad65c25ac69b3d268a558af188144e54910e1da Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 12:20:43 +0100 Subject: [PATCH 08/12] cleanup 7 --- .../rl/learning/sarsa/SarsaLambdaLearner.java | 121 +++++++------ .../rl/learning/sarsa/SarsaLearner.java | 166 +++++++++--------- .../chen0040/rl/models/DefaultValues.java | 1 + 3 files changed, 148 insertions(+), 140 deletions(-) diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java index e51543e..2c2bfc5 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java @@ -1,6 +1,7 @@ package com.github.chen0040.rl.learning.sarsa; +import com.github.chen0040.rl.models.DefaultValues; import com.github.chen0040.rl.models.EligibilityTraceUpdateMode; import com.github.chen0040.rl.utils.Matrix; @@ -9,119 +10,123 @@ * Created by xschen on 9/28/2015 0028. */ public class SarsaLambdaLearner extends SarsaLearner { - private double lambda = 0.9; + private double lambda = DefaultValues.LAMDA; private Matrix e; private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; + public SarsaLambdaLearner(final int stateCount, final int actionCount) { + super(stateCount, actionCount); + this.e = new Matrix(stateCount, actionCount); + } + + public SarsaLambdaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + super(stateCount, actionCount, alpha, gamma, initialQ); + this.e = new Matrix(stateCount, actionCount); + } + + @SuppressWarnings("Used-by-user") + public SarsaLambdaLearner(final SarsaLearner learner) { + this.copy(learner); + this.e = new Matrix(this.model.getStateCount(), this.model.getActionCount()); + } + + private SarsaLambdaLearner() { + + } + + @SuppressWarnings("Used-by-user") public EligibilityTraceUpdateMode getTraceUpdateMode() { - return traceUpdateMode; + return this.traceUpdateMode; } - public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { + @SuppressWarnings("Used-by-user") + public void setTraceUpdateMode(final EligibilityTraceUpdateMode traceUpdateMode) { this.traceUpdateMode = traceUpdateMode; } - public double getLambda(){ - return lambda; + @SuppressWarnings("Used-by-user") + public double getLambda() { + return this.lambda; } - public void setLambda(double lambda){ + public void setLambda(final double lambda) { this.lambda = lambda; } @Override - public Object clone(){ - SarsaLambdaLearner clone = new SarsaLambdaLearner(); + public SarsaLambdaLearner clone() { + final SarsaLambdaLearner clone = new SarsaLambdaLearner(); clone.copy(this); return clone; } @Override - public void copy(SarsaLearner rhs){ + public void copy(final SarsaLearner rhs) { super.copy(rhs); - SarsaLambdaLearner rhs2 = (SarsaLambdaLearner)rhs; - lambda = rhs2.lambda; - e = rhs2.e.makeCopy(); - traceUpdateMode = rhs2.traceUpdateMode; + final SarsaLambdaLearner rhs2 = (SarsaLambdaLearner) rhs; + this.lambda = rhs2.lambda; + this.e = rhs2.e.makeCopy(); + this.traceUpdateMode = rhs2.traceUpdateMode; } @Override - public boolean equals(Object obj){ - if(!super.equals(obj)){ + public boolean equals(final Object obj) { + if (!super.equals(obj)) { return false; } - if(obj instanceof SarsaLambdaLearner){ - SarsaLambdaLearner rhs = (SarsaLambdaLearner)obj; - return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode; + if (obj instanceof SarsaLambdaLearner) { + final SarsaLambdaLearner rhs = (SarsaLambdaLearner) obj; + return rhs.lambda == this.lambda && this.e.equals(rhs.e) && this.traceUpdateMode == rhs.traceUpdateMode; } return false; } - public SarsaLambdaLearner(){ - super(); - } - - public SarsaLambdaLearner(int stateCount, int actionCount){ - super(stateCount, actionCount); - e = new Matrix(stateCount, actionCount); - } - - public SarsaLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - super(stateCount, actionCount, alpha, gamma, initialQ); - e = new Matrix(stateCount, actionCount); - } - - public SarsaLambdaLearner(SarsaLearner learner){ - copy(learner); - e = new Matrix(model.getStateCount(), model.getActionCount()); - } - - public Matrix getEligibility() - { - return e; + @SuppressWarnings("Used-by-user") + public Matrix getEligibility() { + return this.e; } - public void setEligibility(Matrix e){ + @SuppressWarnings("Used-by-user") + public void setEligibility(final Matrix e) { this.e = e; } @Override - public void update(int currentStateId, int currentActionId, int nextStateId, int nextActionId, double immediateReward) - { + public void update(final int currentStateId, final int currentActionId, final int nextStateId, final int nextActionId, final double immediateReward) { // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(currentStateId, currentActionId); + double oldQ = this.model.getQ(currentStateId, currentActionId); // learning_rate; - double alpha = model.getAlpha(currentStateId, currentActionId); + final double alpha = this.model.getAlpha(currentStateId, currentActionId); // discount_rate; - double gamma = model.getGamma(); + final double gamma = this.model.getGamma(); // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double nextQ = model.getQ(nextStateId, nextActionId); + final double nextQ = this.model.getQ(nextStateId, nextActionId); - double td_error = immediateReward + gamma * nextQ - oldQ; + final double td_error = immediateReward + gamma * nextQ - oldQ; - int stateCount = model.getStateCount(); - int actionCount = model.getActionCount(); + final int stateCount = this.model.getStateCount(); + final int actionCount = this.model.getActionCount(); - e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); + this.e.set(currentStateId, currentActionId, this.e.get(currentStateId, currentActionId) + 1); - for(int stateId = 0; stateId < stateCount; ++stateId){ - for(int actionId = 0; actionId < actionCount; ++actionId){ - oldQ = model.getQ(stateId, actionId); + for (int stateId = 0; stateId < stateCount; ++stateId) { + for (int actionId = 0; actionId < actionCount; ++actionId) { + oldQ = this.model.getQ(stateId, actionId); - double newQ = oldQ + alpha * td_error * e.get(stateId, actionId); + final double newQ = oldQ + alpha * td_error * this.e.get(stateId, actionId); - model.setQ(stateId, actionId, newQ); + this.model.setQ(stateId, actionId, newQ); if (actionId != currentActionId) { - e.set(currentStateId, actionId, 0); + this.e.set(currentStateId, actionId, 0); } else { - e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); + this.e.set(stateId, actionId, this.e.get(stateId, actionId) * gamma * this.lambda); } } } diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java index 7fef780..4d4fe3d 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java @@ -16,104 +16,63 @@ /** - * Created by xschen on 9/27/2015 0027. - * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm - * Q is known as the quality of state-action combination, note that it is different from utility of a state + * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Q-Learning, which is an off-policy TD + * control algorithm Q is known as the quality of state-action combination, note that it is different from utility of a + * state */ -public class SarsaLearner implements Serializable,Cloneable { +public class SarsaLearner implements Serializable, Cloneable { protected QModel model; private ActionSelectionStrategy actionSelectionStrategy; - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); - } - - public static SarsaLearner fromJson(String json){ - return JSON.parseObject(json, SarsaLearner.class); - } - - public SarsaLearner makeCopy(){ - SarsaLearner clone = new SarsaLearner(); - clone.copy(this); - return clone; - } - - public void copy(SarsaLearner rhs){ - model = rhs.model.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); - } - - @Override - public boolean equals(Object obj){ - if(obj !=null && obj instanceof SarsaLearner){ - SarsaLearner rhs = (SarsaLearner)obj; - if(!model.equals(rhs.model)) return false; - return actionSelectionStrategy.equals(rhs.actionSelectionStrategy); - } - return false; - } - - public QModel getModel() { - return model; - } - - public void setModel(QModel model) { - this.model = model; - } - - public String getActionSelection() { - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); - } - - public void setActionSelection(String conf) { - this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); - } - - public SarsaLearner(){ + public SarsaLearner() { } - public SarsaLearner(int stateCount, int actionCount){ + public SarsaLearner(final int stateCount, final int actionCount) { this(stateCount, actionCount, 0.1, 0.7, 0.1); } - public SarsaLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy){ + public SarsaLearner(final QModel model, final ActionSelectionStrategy actionSelectionStrategy) { this.model = model; this.actionSelectionStrategy = actionSelectionStrategy; } - public SarsaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) - { - model = new QModel(stateCount, actionCount, initialQ); - model.setAlpha(alpha); - model.setGamma(gamma); - actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); + public SarsaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + this.model = new QModel(stateCount, actionCount, initialQ); + this.model.setAlpha(alpha); + this.model.setGamma(gamma); + this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); } - public static void main(String[] args){ - int stateCount = 100; - int actionCount = 10; + @SuppressWarnings("Used-by-user") + public static SarsaLearner fromJson(final String json) { + return JSON.parseObject(json, SarsaLearner.class); + } - SarsaLearner learner = new SarsaLearner(stateCount, actionCount); + public static void main(final String[] args) { + final int stateCount = 100; + final int actionCount = 10; - double reward = 0; // reward gained by transiting from prevState to currentState - Random random = new Random(); + final SarsaLearner learner = new SarsaLearner(stateCount, actionCount); + + double reward; // reward gained by transiting from prevState to currentState + final Random random = new Random(); int currentStateId = random.nextInt(stateCount); int currentActionId = learner.selectAction(currentStateId).getIndex(); - for(int time=0; time < 1000; ++time){ + for (int time = 0; time < 1000; ++time) { - System.out.println("Controller does action-"+currentActionId); + System.out.println("Controller does action-" + currentActionId); - int newStateId = random.nextInt(actionCount); + final int newStateId = random.nextInt(actionCount); reward = random.nextDouble(); System.out.println("Now the new state is " + newStateId); System.out.println("Controller receives Reward = " + reward); - int futureActionId = learner.selectAction(newStateId).getIndex(); + final int futureActionId = learner.selectAction(newStateId).getIndex(); - System.out.println("Controller is expected to do action-"+futureActionId); + System.out.println("Controller is expected to do action-" + futureActionId); learner.update(currentStateId, currentActionId, newStateId, futureActionId, reward); @@ -122,39 +81,82 @@ public static void main(String[] args){ } } + @SuppressWarnings("Used-by-user") + public String toJson() { + return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); + } + + public SarsaLearner makeCopy() { + final SarsaLearner clone = new SarsaLearner(); + clone.copy(this); + return clone; + } + + public void copy(final SarsaLearner rhs) { + this.model = rhs.model.makeCopy(); + this.actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); + } + + @Override + public boolean equals(final Object obj) { + if (obj instanceof SarsaLearner) { + final SarsaLearner rhs = (SarsaLearner) obj; + if (!this.model.equals(rhs.model)) { + return false; + } + return this.actionSelectionStrategy.equals(rhs.actionSelectionStrategy); + } + return false; + } + + public QModel getModel() { + return this.model; + } - public IndexValue selectAction(int stateId, Set actionsAtState){ - return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); + public void setModel(final QModel model) { + this.model = model; } - public IndexValue selectAction(int stateId){ - return selectAction(stateId, null); + @SuppressWarnings("Used-by-user") + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy); } - public void update(int stateId, int actionId, int nextStateId, int nextActionId, double immediateReward) - { + @SuppressWarnings("Used-by-user") + public void setActionSelection(final String conf) { + this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + } + + public IndexValue selectAction(final int stateId, final Set actionsAtState) { + return this.actionSelectionStrategy.selectAction(stateId, this.model, actionsAtState); + } + + public IndexValue selectAction(final int stateId) { + return this.selectAction(stateId, null); + } + + public void update(final int stateId, final int actionId, final int nextStateId, final int nextActionId, final double immediateReward) { // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(stateId, actionId); + final double oldQ = this.model.getQ(stateId, actionId); // learning_rate; - double alpha = model.getAlpha(stateId, actionId); + final double alpha = this.model.getAlpha(stateId, actionId); // discount_rate; - double gamma = model.getGamma(); + final double gamma = this.model.getGamma(); // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double nextQ = model.getQ(nextStateId, nextActionId); + final double nextQ = this.model.getQ(nextStateId, nextActionId); // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value // old_value = oldQ // temporal_difference = learned_value - old_value // new_value = old_value + learning_rate * temporal_difference - double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ); + final double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ); // new_value is $Q_{t+1}(s_t, a_t)$ - model.setQ(stateId, actionId, newQ); + this.model.setQ(stateId, actionId, newQ); } - } diff --git a/src/main/java/com/github/chen0040/rl/models/DefaultValues.java b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java index 3cf7906..df697ea 100644 --- a/src/main/java/com/github/chen0040/rl/models/DefaultValues.java +++ b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java @@ -5,4 +5,5 @@ public enum DefaultValues { public static final double GAMMA = 0.9; public static final double ALPHA = 0.1; public static final double INITIAL_Q = 0.1; + public static final double LAMDA = 0.9; } From f4dd6d71aaa85edf91838cda52da63e180c8e896 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 12:22:14 +0100 Subject: [PATCH 09/12] cleanup 8 --- .../github/chen0040/rl/learning/rlearn/RAgent.java | 2 ++ .../chen0040/rl/learning/rlearn/RLearner.java | 13 ++++++++++++- .../github/chen0040/rl/models/DefaultValues.java | 2 ++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java index 1330349..d7ef30c 100644 --- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java @@ -27,10 +27,12 @@ public RAgent(final int stateCount, final int actionCount) { this.learner = new RLearner(stateCount, actionCount); } + @SuppressWarnings("Used-by-user") public int getCurrentState() { return this.currentState; } + @SuppressWarnings("Used-by-user") public int getCurrentAction() { return this.currentAction; } diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java index 0bd5f50..bc520e7 100644 --- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java @@ -7,12 +7,15 @@ import com.github.chen0040.rl.actionselection.ActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory; import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy; +import com.github.chen0040.rl.models.DefaultValues; import com.github.chen0040.rl.models.QModel; import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; import java.util.Set; +import static com.github.chen0040.rl.models.DefaultValues.ALPHA; + /** * Created by xschen on 9/27/2015 0027. @@ -29,7 +32,7 @@ public RLearner() { } public RLearner(final int stateCount, final int actionCount) { - this(stateCount, actionCount, 0.1, 0.1, 0.7, 0.1); + this(stateCount, actionCount, ALPHA, DefaultValues.BETA, DefaultValues.RHO, DefaultValues.INITIAL_Q); } public RLearner(final int state_count, final int action_count, final double alpha, final double beta, final double rho, final double initial_Q) { @@ -42,10 +45,12 @@ public RLearner(final int state_count, final int action_count, final double alph this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); } + @SuppressWarnings("Used-by-user") public static RLearner fromJson(final String json) { return JSON.parseObject(json, RLearner.class); } + @SuppressWarnings("Used-by-user") public String toJson() { return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); } @@ -81,18 +86,22 @@ public boolean equals(final Object obj) { return false; } + @SuppressWarnings("Used-by-user") public double getRho() { return this.rho; } + @SuppressWarnings("Used-by-user") public void setRho(final double rho) { this.rho = rho; } + @SuppressWarnings("Used-by-user") public double getBeta() { return this.beta; } + @SuppressWarnings("Used-by-user") public void setBeta(final double beta) { this.beta = beta; } @@ -106,10 +115,12 @@ public void setModel(final QModel model) { this.model = model; } + @SuppressWarnings("Used-by-user") public String getActionSelection() { return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy); } + @SuppressWarnings("Used-by-user") public void setActionSelection(final String conf) { this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); } diff --git a/src/main/java/com/github/chen0040/rl/models/DefaultValues.java b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java index df697ea..a215ae7 100644 --- a/src/main/java/com/github/chen0040/rl/models/DefaultValues.java +++ b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java @@ -6,4 +6,6 @@ public enum DefaultValues { public static final double ALPHA = 0.1; public static final double INITIAL_Q = 0.1; public static final double LAMDA = 0.9; + public static final double BETA = 0.1; + public static final double RHO = 0.7; } From 5a2122a6d62b45abd91339686a36b44403291f81 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 13:23:09 +0100 Subject: [PATCH 10/12] cleanup finished --- .../AbstractActionSelectionStrategy.java | 68 ++++----- .../ActionSelectionStrategyFactory.java | 52 +++---- .../EpsilonGreedyActionSelectionStrategy.java | 79 +++++------ .../GibbsSoftMaxActionSelectionStrategy.java | 50 ++++--- .../GreedyActionSelectionStrategy.java | 13 +- .../SoftMaxActionSelectionStrategy.java | 27 ++-- .../actorcritic/ActorCriticAgent.java | 86 ++++++----- .../chen0040/rl/learning/qlearn/QAgent.java | 108 +++++++------- .../rl/learning/qlearn/QLambdaLearner.java | 129 +++++++++-------- .../chen0040/rl/learning/qlearn/QLearner.java | 133 +++++++++--------- .../rl/learning/sarsa/SarsaLambdaLearner.java | 2 +- .../chen0040/rl/models/DefaultValues.java | 2 +- 12 files changed, 369 insertions(+), 380 deletions(-) diff --git a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java index 7de7f9b..f591257 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java @@ -1,8 +1,8 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; import com.github.chen0040.rl.models.UtilityModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.HashMap; import java.util.Map; @@ -14,58 +14,50 @@ */ public abstract class AbstractActionSelectionStrategy implements ActionSelectionStrategy { + Map attributes = new HashMap<>(); private String prototype; - protected Map attributes = new HashMap(); - public String getPrototype(){ - return prototype; + @SuppressWarnings("Used-by-user") + public AbstractActionSelectionStrategy() { + this.prototype = this.getClass().getCanonicalName(); } - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - return new IndexValue(); + @SuppressWarnings("Used-by-user") + public AbstractActionSelectionStrategy(final HashMap attributes) { + this.attributes = attributes; + if (attributes.containsKey("prototype")) { + this.prototype = attributes.get("prototype"); + } } - public IndexValue selectAction(int stateId, UtilityModel model, Set actionsAtState) { - return new IndexValue(); + @Override + public String getPrototype() { + return this.prototype; } - public AbstractActionSelectionStrategy(){ - prototype = this.getClass().getCanonicalName(); + @Override + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { + return new IndexValue(); } - - public AbstractActionSelectionStrategy(HashMap attributes){ - this.attributes = attributes; - if(attributes.containsKey("prototype")){ - this.prototype = attributes.get("prototype"); - } + @Override + public IndexValue selectAction(final int stateId, final UtilityModel model, final Set actionsAtState) { + return new IndexValue(); } - public Map getAttributes(){ - return attributes; + @Override + public Map getAttributes() { + return this.attributes; } @Override - public boolean equals(Object obj) { - ActionSelectionStrategy rhs = (ActionSelectionStrategy)obj; - if(!prototype.equalsIgnoreCase(rhs.getPrototype())) return false; - for(Map.Entry entry : rhs.getAttributes().entrySet()) { - if(!attributes.containsKey(entry.getKey())) { - return false; - } - if(!attributes.get(entry.getKey()).equals(entry.getValue())){ - return false; - } - } - for(Map.Entry entry : attributes.entrySet()) { - if(!rhs.getAttributes().containsKey(entry.getKey())) { - return false; - } - if(!rhs.getAttributes().get(entry.getKey()).equals(entry.getValue())){ - return false; - } - } - return true; + public boolean equals(final Object obj) { + final ActionSelectionStrategy rhs = (ActionSelectionStrategy) obj; + return this.prototype.equalsIgnoreCase(rhs.getPrototype()) && + rhs.getAttributes().entrySet().stream().noneMatch(entry -> !this.attributes.containsKey(entry.getKey()) || + !this.attributes.get(entry.getKey()).equals(entry.getValue())) && + this.attributes.entrySet().stream().noneMatch(entry -> !rhs.getAttributes().containsKey(entry.getKey()) || + !rhs.getAttributes().get(entry.getKey()).equals(entry.getValue())); } @Override diff --git a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java index ce92be0..2435b5f 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java @@ -7,52 +7,54 @@ /** * Created by xschen on 9/27/2015 0027. */ -public class ActionSelectionStrategyFactory { - public static ActionSelectionStrategy deserialize(String conf){ - String[] comps = conf.split(";"); - - HashMap attributes = new HashMap(); - for(int i=0; i < comps.length; ++i){ - String comp = comps[i]; - String[] field = comp.split("="); - if(field.length < 2) continue; - String fieldname = field[0].trim(); - String fieldvalue = field[1].trim(); +public enum ActionSelectionStrategyFactory { + ; + + public static ActionSelectionStrategy deserialize(final String conf) { + final String[] comps = conf.split(";"); + + final HashMap attributes = new HashMap<>(); + for (final String comp : comps) { + final String[] field = comp.split("="); + if (field.length < 2) { + continue; + } + final String fieldname = field[0].trim(); + final String fieldvalue = field[1].trim(); attributes.put(fieldname, fieldvalue); } - if(attributes.isEmpty()){ + if (attributes.isEmpty()) { attributes.put("prototype", conf); } - String prototype = attributes.get("prototype"); - if(prototype.equals(GreedyActionSelectionStrategy.class.getCanonicalName())){ + final String prototype = attributes.get("prototype"); + if (prototype.equals(GreedyActionSelectionStrategy.class.getCanonicalName())) { return new GreedyActionSelectionStrategy(); - } else if(prototype.equals(SoftMaxActionSelectionStrategy.class.getCanonicalName())){ + } else if (prototype.equals(SoftMaxActionSelectionStrategy.class.getCanonicalName())) { return new SoftMaxActionSelectionStrategy(); - } else if(prototype.equals(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName())){ + } else if (prototype.equals(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName())) { return new EpsilonGreedyActionSelectionStrategy(attributes); - } else if(prototype.equals(GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName())){ + } else if (prototype.equals(GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName())) { return new GibbsSoftMaxActionSelectionStrategy(); } return null; } - public static String serialize(ActionSelectionStrategy strategy){ - Map attributes = strategy.getAttributes(); + public static String serialize(final ActionSelectionStrategy strategy) { + final Map attributes = strategy.getAttributes(); attributes.put("prototype", strategy.getPrototype()); - StringBuilder sb = new StringBuilder(); + final StringBuilder sb = new StringBuilder(); boolean first = true; - for(Map.Entry entry : attributes.entrySet()){ - if(first){ + for (final Map.Entry entry : attributes.entrySet()) { + if (first) { first = false; - } - else{ + } else { sb.append(";"); } - sb.append(entry.getKey()+"="+entry.getValue()); + sb.append(entry.getKey()).append("=").append(entry.getValue()); } return sb.toString(); } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java index 5f7db9a..5ac3d36 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java @@ -1,7 +1,7 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.*; @@ -10,69 +10,62 @@ * Created by xschen on 9/27/2015 0027. */ public class EpsilonGreedyActionSelectionStrategy extends AbstractActionSelectionStrategy { - public static final String EPSILON = "epsilon"; + private static final String EPSILON = "epsilon"; private Random random = new Random(); - @Override - public Object clone(){ - EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy(); - clone.copy(this); - return clone; - } - - public void copy(EpsilonGreedyActionSelectionStrategy rhs){ - random = rhs.random; - for(Map.Entry entry : rhs.attributes.entrySet()){ - attributes.put(entry.getKey(), entry.getValue()); - } + @SuppressWarnings("Used-by-user") + public EpsilonGreedyActionSelectionStrategy() { + this.epsilon(); } - @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof EpsilonGreedyActionSelectionStrategy){ - EpsilonGreedyActionSelectionStrategy rhs = (EpsilonGreedyActionSelectionStrategy)obj; - if(epsilon() != rhs.epsilon()) return false; - // if(!random.equals(rhs.random)) return false; - return true; - } - return false; + @SuppressWarnings("Used-by-user") + public EpsilonGreedyActionSelectionStrategy(final HashMap attributes) { + super(attributes); } - private double epsilon(){ - return Double.parseDouble(attributes.get(EPSILON)); + @SuppressWarnings("Used-by-user") + public EpsilonGreedyActionSelectionStrategy(final Random random) { + this.random = random; + this.epsilon(); } - public EpsilonGreedyActionSelectionStrategy(){ - epsilon(0.1); + @Override + public Object clone() { + final EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy(); + clone.copy(this); + return clone; } - public EpsilonGreedyActionSelectionStrategy(HashMap attributes){ - super(attributes); + public void copy(final EpsilonGreedyActionSelectionStrategy rhs) { + this.random = rhs.random; + for (final Map.Entry entry : rhs.attributes.entrySet()) { + this.attributes.put(entry.getKey(), entry.getValue()); + } } - private void epsilon(double value){ - attributes.put(EPSILON, "" + value); + @Override + public boolean equals(final Object obj) { + return obj instanceof EpsilonGreedyActionSelectionStrategy && this.epsilon() == ((EpsilonGreedyActionSelectionStrategy) obj).epsilon(); } - public EpsilonGreedyActionSelectionStrategy(Random random){ - this.random = random; - epsilon(0.1); + private double epsilon() { + return Double.parseDouble(this.attributes.get(EpsilonGreedyActionSelectionStrategy.EPSILON)); } @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - if(random.nextDouble() < 1- epsilon()){ + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { + if (this.random.nextDouble() < 1 - this.epsilon()) { return model.actionWithMaxQAtState(stateId, actionsAtState); - }else{ - int actionId; - if(actionsAtState != null && !actionsAtState.isEmpty()) { - List actions = new ArrayList<>(actionsAtState); - actionId = actions.get(random.nextInt(actions.size())); + } else { + final int actionId; + if (actionsAtState != null && !actionsAtState.isEmpty()) { + final List actions = new ArrayList<>(actionsAtState); + actionId = actions.get(this.random.nextInt(actions.size())); } else { - actionId = random.nextInt(model.getActionCount()); + actionId = this.random.nextInt(model.getActionCount()); } - double Q = model.getQ(stateId, actionId); + final double Q = model.getQ(stateId, actionId); return new IndexValue(actionId, Q); } } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java index 8b2d8d2..283748c 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java @@ -1,12 +1,13 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.Set; +import java.util.stream.IntStream; /** @@ -14,52 +15,49 @@ */ public class GibbsSoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy { - private Random random = null; - public GibbsSoftMaxActionSelectionStrategy(){ - random = new Random(); + private final Random random; + + @SuppressWarnings("Used-by-user") + public GibbsSoftMaxActionSelectionStrategy() { + this.random = new Random(); } - public GibbsSoftMaxActionSelectionStrategy(Random random){ + @SuppressWarnings("Used-by-user") + public GibbsSoftMaxActionSelectionStrategy(final Random random) { this.random = random; } @Override public Object clone() { - GibbsSoftMaxActionSelectionStrategy clone = new GibbsSoftMaxActionSelectionStrategy(); - return clone; + return new GibbsSoftMaxActionSelectionStrategy(); } @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - List actions = new ArrayList(); - if(actionsAtState == null){ - for(int i=0; i < model.getActionCount(); ++i){ - actions.add(i); - } - }else{ - for(Integer actionId : actionsAtState){ - actions.add(actionId); - } + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { + final List actions = new ArrayList(); + if (actionsAtState == null) { + IntStream.range(0, model.getActionCount()).forEach(actions::add); + } else { + actions.addAll(actionsAtState); } double sum = 0; - List plist = new ArrayList(); - for(int i=0; i < actions.size(); ++i){ - int actionId = actions.get(i); - double p = Math.exp(model.getQ(stateId, actionId)); + final List plist = new ArrayList(); + for (final int actionId : actions) { + final double p = Math.exp(model.getQ(stateId, actionId)); sum += p; plist.add(sum); } - IndexValue iv = new IndexValue(); + final IndexValue iv = new IndexValue(); iv.setIndex(-1); iv.setValue(Double.NEGATIVE_INFINITY); - double r = sum * random.nextDouble(); - for(int i=0; i < actions.size(); ++i){ + final double r = sum * this.random.nextDouble(); + for (int i = 0; i < actions.size(); ++i) { - if(plist.get(i) >= r){ - int actionId = actions.get(i); + if (plist.get(i) >= r) { + final int actionId = actions.get(i); iv.setValue(model.getQ(stateId, actionId)); iv.setIndex(actionId); break; diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java index 6b0f350..e735fd3 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java @@ -1,7 +1,7 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.Set; @@ -11,18 +11,17 @@ */ public class GreedyActionSelectionStrategy extends AbstractActionSelectionStrategy { @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { return model.actionWithMaxQAtState(stateId, actionsAtState); } @Override - public Object clone(){ - GreedyActionSelectionStrategy clone = new GreedyActionSelectionStrategy(); - return clone; + public Object clone() { + return new GreedyActionSelectionStrategy(); } @Override - public boolean equals(Object obj){ - return obj != null && obj instanceof GreedyActionSelectionStrategy; + public boolean equals(final Object obj) { + return obj instanceof GreedyActionSelectionStrategy; } } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java index f9735b9..c0b89f5 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java @@ -1,7 +1,7 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.Random; import java.util.Set; @@ -13,27 +13,26 @@ public class SoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy { private Random random = new Random(); - @Override - public Object clone(){ - SoftMaxActionSelectionStrategy clone = new SoftMaxActionSelectionStrategy(random); - return clone; - } + public SoftMaxActionSelectionStrategy() { - @Override - public boolean equals(Object obj){ - return obj != null && obj instanceof SoftMaxActionSelectionStrategy; } - public SoftMaxActionSelectionStrategy(){ + public SoftMaxActionSelectionStrategy(final Random random) { + this.random = random; + } + @Override + public Object clone() { + return new SoftMaxActionSelectionStrategy(this.random); } - public SoftMaxActionSelectionStrategy(Random random){ - this.random = random; + @Override + public boolean equals(final Object obj) { + return obj instanceof SoftMaxActionSelectionStrategy; } @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - return model.actionWithSoftMaxQAtState(stateId, actionsAtState, random); + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { + return model.actionWithSoftMaxQAtState(stateId, actionsAtState, this.random); } } diff --git a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java index 6e34874..864668f 100644 --- a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java @@ -3,9 +3,7 @@ import com.github.chen0040.rl.utils.Vec; import java.io.Serializable; -import java.util.Random; import java.util.Set; -import java.util.function.Function; /** @@ -17,85 +15,85 @@ public class ActorCriticAgent implements Serializable { private int prevState; private int prevAction; - public void enableEligibilityTrace(double lambda){ - ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(learner); - acll.setLambda(lambda); - learner = acll; + @SuppressWarnings("Used-by-user") + public ActorCriticAgent(final int stateCount, final int actionCount) { + this.learner = new ActorCriticLearner(stateCount, actionCount); } - public void start(int stateId){ - currentState = stateId; - prevAction = -1; - prevState = -1; - } + public ActorCriticAgent() { - public ActorCriticLearner getLearner(){ - return learner; } - public void setLearner(ActorCriticLearner learner){ + public ActorCriticAgent(final ActorCriticLearner learner) { this.learner = learner; } - public ActorCriticAgent(int stateCount, int actionCount){ - learner = new ActorCriticLearner(stateCount, actionCount); + @SuppressWarnings("Used-by-user") + public void enableEligibilityTrace(final double lambda) { + final ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(this.learner); + acll.setLambda(lambda); + this.learner = acll; } - public ActorCriticAgent(){ + @SuppressWarnings("Used-by-user") + public void start(final int stateId) { + this.currentState = stateId; + this.prevAction = -1; + this.prevState = -1; + } + @SuppressWarnings("Used-by-user") + public ActorCriticLearner getLearner() { + return this.learner; } - public ActorCriticAgent(ActorCriticLearner learner){ + public void setLearner(final ActorCriticLearner learner) { this.learner = learner; } - public ActorCriticAgent makeCopy(){ - ActorCriticAgent clone = new ActorCriticAgent(); + public ActorCriticAgent makeCopy() { + final ActorCriticAgent clone = new ActorCriticAgent(); clone.copy(this); return clone; } - public void copy(ActorCriticAgent rhs){ - learner = (ActorCriticLearner)rhs.learner.makeCopy(); - prevAction = rhs.prevAction; - prevState = rhs.prevState; - currentState = rhs.currentState; + public void copy(final ActorCriticAgent rhs) { + this.learner = (ActorCriticLearner) rhs.learner.makeCopy(); + this.prevAction = rhs.prevAction; + this.prevState = rhs.prevState; + this.currentState = rhs.currentState; } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof ActorCriticAgent){ - ActorCriticAgent rhs = (ActorCriticAgent)obj; - return learner.equals(rhs.learner) && prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState; + public boolean equals(final Object obj) { + if (obj instanceof ActorCriticAgent) { + final ActorCriticAgent rhs = (ActorCriticAgent) obj; + return this.learner.equals(rhs.learner) && this.prevAction == rhs.prevAction && this.prevState == rhs.prevState && this.currentState == rhs.currentState; } return false; } - public int selectAction(Set actionsAtState){ - return learner.selectAction(currentState, actionsAtState); + public int selectAction(final Set actionsAtState) { + return this.learner.selectAction(this.currentState, actionsAtState); } - public int selectAction(){ - return learner.selectAction(currentState); + public int selectAction() { + return this.learner.selectAction(this.currentState); } - public void update(int actionTaken, int newState, double immediateReward, final Vec V){ - update(actionTaken, newState, null, immediateReward, V); + public void update(final int actionTaken, final int newState, final double immediateReward, final Vec V) { + this.update(actionTaken, newState, null, immediateReward, V); } - public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward, final Vec V){ + public void update(final int actionTaken, final int newState, final Set actionsAtNewState, final double immediateReward, final Vec V) { - learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward, new Function() { - public Double apply(Integer stateId) { - return V.get(stateId); - } - }); + this.learner.update(this.currentState, actionTaken, newState, actionsAtNewState, immediateReward, V::get); - prevAction = actionTaken; - prevState = currentState; + this.prevAction = actionTaken; + this.prevState = this.currentState; - currentState = newState; + this.currentState = newState; } } diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java index afdb314..1f85622 100644 --- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java @@ -3,14 +3,13 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; /** * Created by xschen on 9/27/2015 0027. */ -public class QAgent implements Serializable{ +public class QAgent implements Serializable { private QLearner learner; private int currentState; private int prevState; @@ -18,94 +17,97 @@ public class QAgent implements Serializable{ /** action taken at prevState */ private int prevAction; - public int getCurrentState(){ - return currentState; + public QAgent(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + this.learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ); } - public int getPrevState(){ - return prevState; + public QAgent(final QLearner learner) { + this.learner = learner; } - public int getPrevAction(){ - return prevAction; + public QAgent(final int stateCount, final int actionCount) { + this.learner = new QLearner(stateCount, actionCount); } - public void start(int currentState){ - this.currentState = currentState; - this.prevAction = -1; - this.prevState = -1; - } + public QAgent() { - public IndexValue selectAction(){ - return learner.selectAction(currentState); } - public IndexValue selectAction(Set actionsAtState){ - return learner.selectAction(currentState, actionsAtState); + @SuppressWarnings("Used-by-user") + public int getCurrentState() { + return this.currentState; } - public void update(int actionTaken, int newState, double immediateReward){ - update(actionTaken, newState, null, immediateReward); + @SuppressWarnings("Used-by-user") + public int getPrevState() { + return this.prevState; } - public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){ - - learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward); - - prevState = currentState; - prevAction = actionTaken; - - currentState = newState; + @SuppressWarnings("Used-by-user") + public int getPrevAction() { + return this.prevAction; } - public void enableEligibilityTrace(double lambda){ - QLambdaLearner acll = new QLambdaLearner(learner); - acll.setLambda(lambda); - learner = acll; + public void start(final int currentState) { + this.currentState = currentState; + this.prevAction = -1; + this.prevState = -1; } - public QLearner getLearner(){ - return learner; + public IndexValue selectAction() { + return this.learner.selectAction(this.currentState); } - public void setLearner(QLearner learner){ - this.learner = learner; + public IndexValue selectAction(final Set actionsAtState) { + return this.learner.selectAction(this.currentState, actionsAtState); } - public QAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ); + public void update(final int actionTaken, final int newState, final double immediateReward) { + this.update(actionTaken, newState, null, immediateReward); } - public QAgent(QLearner learner){ - this.learner = learner; + public void update(final int actionTaken, final int newState, final Set actionsAtNewState, final double immediateReward) { + + this.learner.update(this.currentState, actionTaken, newState, actionsAtNewState, immediateReward); + + this.prevState = this.currentState; + this.prevAction = actionTaken; + + this.currentState = newState; } - public QAgent(int stateCount, int actionCount){ - learner = new QLearner(stateCount, actionCount); + public void enableEligibilityTrace(final double lambda) { + final QLambdaLearner acll = new QLambdaLearner(this.learner); + acll.setLambda(lambda); + this.learner = acll; } - public QAgent(){ + public QLearner getLearner() { + return this.learner; + } + public void setLearner(final QLearner learner) { + this.learner = learner; } - public QAgent makeCopy(){ - QAgent clone = new QAgent(); + public QAgent makeCopy() { + final QAgent clone = new QAgent(); clone.copy(this); return clone; } - public void copy(QAgent rhs){ - learner.copy(rhs.learner); - prevAction = rhs.prevAction; - prevState = rhs.prevState; - currentState = rhs.currentState; + public void copy(final QAgent rhs) { + this.learner.copy(rhs.learner); + this.prevAction = rhs.prevAction; + this.prevState = rhs.prevState; + this.currentState = rhs.currentState; } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof QAgent){ - QAgent rhs = (QAgent)obj; - return prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState && learner.equals(rhs.learner); + public boolean equals(final Object obj) { + if (obj instanceof QAgent) { + final QAgent rhs = (QAgent) obj; + return this.prevAction == rhs.prevAction && this.prevState == rhs.prevState && this.currentState == rhs.currentState && this.learner.equals(rhs.learner); } return false; } diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java index 875ef3a..30ffd2e 100644 --- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java @@ -1,6 +1,7 @@ package com.github.chen0040.rl.learning.qlearn; +import com.github.chen0040.rl.models.DefaultValues; import com.github.chen0040.rl.models.EligibilityTraceUpdateMode; import com.github.chen0040.rl.utils.Matrix; @@ -11,125 +12,129 @@ * Created by xschen on 9/28/2015 0028. */ public class QLambdaLearner extends QLearner { - private double lambda = 0.9; + private double lambda = DefaultValues.LAMBDA; private Matrix e; private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; + @SuppressWarnings("Used-by-user") + public QLambdaLearner(final QLearner learner) { + this.copy(learner); + this.e = new Matrix(this.model.getStateCount(), this.model.getActionCount()); + } + + private QLambdaLearner() { + super(); + } + + @SuppressWarnings("Used-by-user") + public QLambdaLearner(final int stateCount, final int actionCount) { + super(stateCount, actionCount); + this.e = new Matrix(stateCount, actionCount); + } + + @SuppressWarnings("Used-by-user") + public QLambdaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + super(stateCount, actionCount, alpha, gamma, initialQ); + this.e = new Matrix(stateCount, actionCount); + } + + @SuppressWarnings("Used-by-user") public EligibilityTraceUpdateMode getTraceUpdateMode() { - return traceUpdateMode; + return this.traceUpdateMode; } - public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { + @SuppressWarnings("Used-by-user") + public void setTraceUpdateMode(final EligibilityTraceUpdateMode traceUpdateMode) { this.traceUpdateMode = traceUpdateMode; } - public double getLambda(){ - return lambda; + @SuppressWarnings("Used-by-user") + public double getLambda() { + return this.lambda; } - public void setLambda(double lambda){ + @SuppressWarnings("Used-by-user") + public void setLambda(final double lambda) { this.lambda = lambda; } - public QLambdaLearner makeCopy(){ - QLambdaLearner clone = new QLambdaLearner(); + @Override + public QLambdaLearner makeCopy() { + final QLambdaLearner clone = new QLambdaLearner(); clone.copy(this); return clone; } @Override - public void copy(QLearner rhs){ + public void copy(final QLearner rhs) { super.copy(rhs); - QLambdaLearner rhs2 = (QLambdaLearner)rhs; - lambda = rhs2.lambda; - e = rhs2.e.makeCopy(); - traceUpdateMode = rhs2.traceUpdateMode; - } - - public QLambdaLearner(QLearner learner){ - copy(learner); - e = new Matrix(model.getStateCount(), model.getActionCount()); + final QLambdaLearner rhs2 = (QLambdaLearner) rhs; + this.lambda = rhs2.lambda; + this.e = rhs2.e.makeCopy(); + this.traceUpdateMode = rhs2.traceUpdateMode; } @Override - public boolean equals(Object obj){ - if(!super.equals(obj)){ + public boolean equals(final Object obj) { + if (!super.equals(obj)) { return false; } - if(obj instanceof QLambdaLearner){ - QLambdaLearner rhs = (QLambdaLearner)obj; - return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode; + if (obj instanceof QLambdaLearner) { + final QLambdaLearner rhs = (QLambdaLearner) obj; + return rhs.lambda == this.lambda && this.e.equals(rhs.e) && this.traceUpdateMode == rhs.traceUpdateMode; } return false; } - public QLambdaLearner(){ - super(); - } - - public QLambdaLearner(int stateCount, int actionCount){ - super(stateCount, actionCount); - e = new Matrix(stateCount, actionCount); - } - - public QLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - super(stateCount, actionCount, alpha, gamma, initialQ); - e = new Matrix(stateCount, actionCount); + @SuppressWarnings("Used-by-user") + public Matrix getEligibility() { + return this.e; } - public Matrix getEligibility() - { - return e; - } - - public void setEligibility(Matrix e){ + @SuppressWarnings("Used-by-user") + public void setEligibility(final Matrix e) { this.e = e; } @Override - public void update(int currentStateId, int currentActionId, int nextStateId, Set actionsAtNextStateId, double immediateReward) - { + public void update(final int currentStateId, final int currentActionId, final int nextStateId, final Set actionsAtNextStateId, final double immediateReward) { // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(currentStateId, currentActionId); + double oldQ = this.model.getQ(currentStateId, currentActionId); // learning_rate; - double alpha = model.getAlpha(currentStateId, currentActionId); + final double alpha = this.model.getAlpha(currentStateId, currentActionId); // discount_rate; - double gamma = model.getGamma(); + final double gamma = this.model.getGamma(); // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double maxQ = maxQAtState(nextStateId, actionsAtNextStateId); - - double td_error = immediateReward + gamma * maxQ - oldQ; + final double maxQ = this.maxQAtState(nextStateId, actionsAtNextStateId); - int stateCount = model.getStateCount(); - int actionCount = model.getActionCount(); + final double td_error = immediateReward + gamma * maxQ - oldQ; - e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); + final int stateCount = this.model.getStateCount(); + final int actionCount = this.model.getActionCount(); + this.e.set(currentStateId, currentActionId, this.e.get(currentStateId, currentActionId) + 1); - for(int stateId = 0; stateId < stateCount; ++stateId){ - for(int actionId = 0; actionId < actionCount; ++actionId){ - oldQ = model.getQ(stateId, actionId); - double newQ = oldQ + alpha * td_error * e.get(stateId, actionId); + for (int stateId = 0; stateId < stateCount; ++stateId) { + for (int actionId = 0; actionId < actionCount; ++actionId) { + oldQ = this.model.getQ(stateId, actionId); + final double newQ = oldQ + alpha * td_error * this.e.get(stateId, actionId); // new_value is $Q_{t+1}(s_t, a_t)$ - model.setQ(currentStateId, currentActionId, newQ); + this.model.setQ(currentStateId, currentActionId, newQ); if (actionId != currentActionId) { - e.set(currentStateId, actionId, 0); + this.e.set(currentStateId, actionId, 0); } else { - e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); + this.e.set(stateId, actionId, this.e.get(stateId, actionId) * gamma * this.lambda); } } } - - - } } diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java index 865abc5..7970237 100644 --- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java @@ -2,7 +2,6 @@ import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.annotation.JSONField; import com.alibaba.fastjson.serializer.SerializerFeature; import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategy; @@ -12,131 +11,133 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; +import static com.github.chen0040.rl.models.DefaultValues.*; + /** - * Created by xschen on 9/27/2015 0027. - * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm - * Q is known as the quality of state-action combination, note that it is different from utility of a state + * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Q-Learning, which is an off-policy TD + * control algorithm Q is known as the quality of state-action combination, note that it is different from utility of a + * state */ -public class QLearner implements Serializable,Cloneable { +public class QLearner implements Serializable, Cloneable { protected QModel model; private ActionSelectionStrategy actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); - public QLearner makeCopy(){ - QLearner clone = new QLearner(); - clone.copy(this); - return clone; - } + public QLearner() { - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); } - public static QLearner fromJson(String json){ - return JSON.parseObject(json, QLearner.class); + public QLearner(final int stateCount, final int actionCount) { + this(stateCount, actionCount, ALPHA, GAMMA, INITIAL_Q); } - public void copy(QLearner rhs){ - model = rhs.model.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); + public QLearner(final QModel model, final ActionSelectionStrategy actionSelectionStrategy) { + this.model = model; + this.actionSelectionStrategy = actionSelectionStrategy; } - @Override - public boolean equals(Object obj){ - if(obj !=null && obj instanceof QLearner){ - QLearner rhs = (QLearner)obj; - if(!model.equals(rhs.model)) return false; - return actionSelectionStrategy.equals(rhs.actionSelectionStrategy); - } - return false; + public QLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + this.model = new QModel(stateCount, actionCount, initialQ); + this.model.setAlpha(alpha); + this.model.setGamma(gamma); + this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); } - public QModel getModel() { - return model; + @SuppressWarnings("Used-by-user") + public static QLearner fromJson(final String json) { + return JSON.parseObject(json, QLearner.class); } - public void setModel(QModel model) { - this.model = model; + public QLearner makeCopy() { + final QLearner clone = new QLearner(); + clone.copy(this); + return clone; } - - public String getActionSelection() { - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); + @SuppressWarnings("Used-by-user") + public String toJson() { + return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); } - public void setActionSelection(String conf) { - this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + public void copy(final QLearner rhs) { + this.model = rhs.model.makeCopy(); + this.actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); } - public QLearner(){ - + @Override + public boolean equals(final Object obj) { + if (obj instanceof QLearner) { + final QLearner rhs = (QLearner) obj; + if (!this.model.equals(rhs.model)) { + return false; + } + return this.actionSelectionStrategy.equals(rhs.actionSelectionStrategy); + } + return false; } - public QLearner(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1, 0.7, 0.1); + public QModel getModel() { + return this.model; } - public QLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy){ + public void setModel(final QModel model) { this.model = model; - this.actionSelectionStrategy = actionSelectionStrategy; } - public QLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) - { - model = new QModel(stateCount, actionCount, initialQ); - model.setAlpha(alpha); - model.setGamma(gamma); - actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); + @SuppressWarnings("Used-by-user") + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy); } + @SuppressWarnings("Used-by-user") + public void setActionSelection(final String conf) { + this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + } - protected double maxQAtState(int stateId, Set actionsAtState){ - IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState); - double maxQ = iv.getValue(); - return maxQ; + double maxQAtState(final int stateId, final Set actionsAtState) { + return this.model.actionWithMaxQAtState(stateId, actionsAtState).getValue(); } - public IndexValue selectAction(int stateId, Set actionsAtState){ - return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); + @SuppressWarnings("Used-by-user") + public IndexValue selectAction(final int stateId, final Set actionsAtState) { + return this.actionSelectionStrategy.selectAction(stateId, this.model, actionsAtState); } - public IndexValue selectAction(int stateId){ - return selectAction(stateId, null); + @SuppressWarnings("Used-by-user") + public IndexValue selectAction(final int stateId) { + return this.selectAction(stateId, null); } - public void update(int stateId, int actionId, int nextStateId, double immediateReward){ - update(stateId, actionId, nextStateId, null, immediateReward); + public void update(final int stateId, final int actionId, final int nextStateId, final double immediateReward) { + this.update(stateId, actionId, nextStateId, null, immediateReward); } - public void update(int stateId, int actionId, int nextStateId, Set actionsAtNextStateId, double immediateReward) - { + public void update(final int stateId, final int actionId, final int nextStateId, final Set actionsAtNextStateId, final double immediateReward) { // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(stateId, actionId); + final double oldQ = this.model.getQ(stateId, actionId); // learning_rate; - double alpha = model.getAlpha(stateId, actionId); + final double alpha = this.model.getAlpha(stateId, actionId); // discount_rate; - double gamma = model.getGamma(); + final double gamma = this.model.getGamma(); // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double maxQ = maxQAtState(nextStateId, actionsAtNextStateId); + final double maxQ = this.maxQAtState(nextStateId, actionsAtNextStateId); // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value // old_value = oldQ // temporal_difference = learned_value - old_value // new_value = old_value + learning_rate * temporal_difference - double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ); + final double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ); // new_value is $Q_{t+1}(s_t, a_t)$ - model.setQ(stateId, actionId, newQ); + this.model.setQ(stateId, actionId, newQ); } - } diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java index 2c2bfc5..6298ac8 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java @@ -10,7 +10,7 @@ * Created by xschen on 9/28/2015 0028. */ public class SarsaLambdaLearner extends SarsaLearner { - private double lambda = DefaultValues.LAMDA; + private double lambda = DefaultValues.LAMBDA; private Matrix e; private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; diff --git a/src/main/java/com/github/chen0040/rl/models/DefaultValues.java b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java index a215ae7..879ea64 100644 --- a/src/main/java/com/github/chen0040/rl/models/DefaultValues.java +++ b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java @@ -5,7 +5,7 @@ public enum DefaultValues { public static final double GAMMA = 0.9; public static final double ALPHA = 0.1; public static final double INITIAL_Q = 0.1; - public static final double LAMDA = 0.9; + public static final double LAMBDA = 0.9; public static final double BETA = 0.1; public static final double RHO = 0.7; } From b4b729fcfad9f1b138d1f83053e6cfd57ef4640e Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Sun, 10 Mar 2019 15:21:33 +0100 Subject: [PATCH 11/12] ready for pushing to github --- pom.xml | 29 +++---------- .../ActionSelectionStrategyFactory.java | 10 ++--- .../rl/learning/sarsa/SarsaLearner.java | 41 ++++--------------- .../com/github/chen0040/rl/models/QModel.java | 9 ++-- .../chen0040/rl/models/UtilityModel.java | 3 +- .../github/chen0040/rl/utils/DoubleUtils.java | 2 +- 6 files changed, 23 insertions(+), 71 deletions(-) diff --git a/pom.xml b/pom.xml index 443ad61..9c0272d 100644 --- a/pom.xml +++ b/pom.xml @@ -1,6 +1,6 @@ - 4.0.0 @@ -38,7 +38,8 @@ Reinforcement Learning Algorithms - Classical RL algorithms implemented in Java, including Q-Learn, R-Learn, SARSA, Actor-Critic + Classical RL algorithms implemented in Java, including Q-Learn, R-Learn, SARSA, Actor-Critic + https://github.com/chen0040/java-reinforcement-learning @@ -235,7 +236,6 @@ - org.codehaus.mojo findbugs-maven-plugin @@ -361,8 +361,6 @@ - - @@ -382,14 +380,13 @@ org.apache.maven.plugins maven-resources-plugin + 2.3 UTF-8 - - @@ -488,28 +485,12 @@ - - - - org.projectlombok - lombok - provided - ${lombok.version} - - com.alibaba fastjson 1.2.41 - - org.jetbrains - annotations - 13.0 - - - diff --git a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java index 2435b5f..cee0a7e 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java @@ -1,5 +1,6 @@ package com.github.chen0040.rl.actionselection; +import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -7,8 +8,7 @@ /** * Created by xschen on 9/27/2015 0027. */ -public enum ActionSelectionStrategyFactory { - ; +public class ActionSelectionStrategyFactory implements Serializable { public static ActionSelectionStrategy deserialize(final String conf) { final String[] comps = conf.split(";"); @@ -19,10 +19,8 @@ public static ActionSelectionStrategy deserialize(final String conf) { if (field.length < 2) { continue; } - final String fieldname = field[0].trim(); - final String fieldvalue = field[1].trim(); - attributes.put(fieldname, fieldvalue); + attributes.put(field[0].trim(), field[1].trim()); } if (attributes.isEmpty()) { attributes.put("prototype", conf); @@ -58,4 +56,6 @@ public static String serialize(final ActionSelectionStrategy strategy) { } return sb.toString(); } + + } diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java index 4d4fe3d..3ce748a 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java @@ -11,9 +11,10 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; +import static com.github.chen0040.rl.models.DefaultValues.*; + /** * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Q-Learning, which is an off-policy TD @@ -24,19 +25,23 @@ public class SarsaLearner implements Serializable, Cloneable { protected QModel model; private ActionSelectionStrategy actionSelectionStrategy; + @SuppressWarnings("Used-by-user") public SarsaLearner() { } + @SuppressWarnings("Used-by-user") public SarsaLearner(final int stateCount, final int actionCount) { - this(stateCount, actionCount, 0.1, 0.7, 0.1); + this(stateCount, actionCount, ALPHA, GAMMA, INITIAL_Q); } + @SuppressWarnings("Used-by-user") public SarsaLearner(final QModel model, final ActionSelectionStrategy actionSelectionStrategy) { this.model = model; this.actionSelectionStrategy = actionSelectionStrategy; } + @SuppressWarnings("Used-by-user") public SarsaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { this.model = new QModel(stateCount, actionCount, initialQ); this.model.setAlpha(alpha); @@ -49,38 +54,6 @@ public static SarsaLearner fromJson(final String json) { return JSON.parseObject(json, SarsaLearner.class); } - public static void main(final String[] args) { - final int stateCount = 100; - final int actionCount = 10; - - final SarsaLearner learner = new SarsaLearner(stateCount, actionCount); - - double reward; // reward gained by transiting from prevState to currentState - final Random random = new Random(); - int currentStateId = random.nextInt(stateCount); - int currentActionId = learner.selectAction(currentStateId).getIndex(); - - for (int time = 0; time < 1000; ++time) { - - System.out.println("Controller does action-" + currentActionId); - - final int newStateId = random.nextInt(actionCount); - reward = random.nextDouble(); - - System.out.println("Now the new state is " + newStateId); - System.out.println("Controller receives Reward = " + reward); - - final int futureActionId = learner.selectAction(newStateId).getIndex(); - - System.out.println("Controller is expected to do action-" + futureActionId); - - learner.update(currentStateId, currentActionId, newStateId, futureActionId, reward); - - currentStateId = newStateId; - currentActionId = futureActionId; - } - } - @SuppressWarnings("Used-by-user") public String toJson() { return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); diff --git a/src/main/java/com/github/chen0040/rl/models/QModel.java b/src/main/java/com/github/chen0040/rl/models/QModel.java index 544b186..d48e1c1 100644 --- a/src/main/java/com/github/chen0040/rl/models/QModel.java +++ b/src/main/java/com/github/chen0040/rl/models/QModel.java @@ -4,7 +4,6 @@ import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.utils.Matrix; import com.github.chen0040.rl.utils.Vec; -import org.jetbrains.annotations.Nullable; import java.util.*; @@ -18,12 +17,12 @@ public class QModel { * Q value for (state_id, action_id) pair Q is known as the quality of state-action combination, note that it is * different from utility of a state */ - @org.jetbrains.annotations.Nullable + private Matrix Q; /** * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id) */ - @org.jetbrains.annotations.Nullable + private Matrix alphaMatrix; /** @@ -155,12 +154,12 @@ public IndexValue actionWithSoftMaxQAtState(final int stateId, final Set Date: Sun, 10 Mar 2019 15:47:27 +0100 Subject: [PATCH 12/12] -cleaning up code -remove plugin lombrok (no-compatibility with JAVA 11) --> adding Getter and Setter, where needed -adding class DefaultValues --- pom.xml | 24 +- .../AbstractActionSelectionStrategy.java | 68 ++-- .../ActionSelectionStrategyFactory.java | 56 +-- .../EpsilonGreedyActionSelectionStrategy.java | 79 ++--- .../GibbsSoftMaxActionSelectionStrategy.java | 50 ++- .../GreedyActionSelectionStrategy.java | 13 +- .../SoftMaxActionSelectionStrategy.java | 27 +- .../actorcritic/ActorCriticAgent.java | 86 +++-- .../chen0040/rl/learning/qlearn/QAgent.java | 108 +++--- .../rl/learning/qlearn/QLambdaLearner.java | 129 +++---- .../chen0040/rl/learning/qlearn/QLearner.java | 133 ++++---- .../chen0040/rl/learning/rlearn/RAgent.java | 100 +++--- .../chen0040/rl/learning/rlearn/RLearner.java | 146 ++++---- .../rl/learning/sarsa/SarsaAgent.java | 140 ++++---- .../rl/learning/sarsa/SarsaLambdaLearner.java | 121 +++---- .../rl/learning/sarsa/SarsaLearner.java | 153 ++++----- .../chen0040/rl/models/DefaultValues.java | 11 + .../rl/models/EligibilityTraceUpdateMode.java | 3 +- .../com/github/chen0040/rl/models/QModel.java | 176 ++++++---- .../chen0040/rl/models/UtilityModel.java | 86 ++--- .../github/chen0040/rl/utils/DoubleUtils.java | 13 +- .../github/chen0040/rl/utils/IndexValue.java | 85 ++--- .../com/github/chen0040/rl/utils/Matrix.java | 225 +++--------- .../github/chen0040/rl/utils/MatrixUtils.java | 29 -- .../github/chen0040/rl/utils/TupleTwo.java | 56 --- .../com/github/chen0040/rl/utils/Vec.java | 322 ++++-------------- .../github/chen0040/rl/utils/VectorUtils.java | 39 --- 27 files changed, 1014 insertions(+), 1464 deletions(-) create mode 100644 src/main/java/com/github/chen0040/rl/models/DefaultValues.java delete mode 100644 src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java delete mode 100644 src/main/java/com/github/chen0040/rl/utils/TupleTwo.java delete mode 100644 src/main/java/com/github/chen0040/rl/utils/VectorUtils.java diff --git a/pom.xml b/pom.xml index 6826743..9c0272d 100644 --- a/pom.xml +++ b/pom.xml @@ -1,6 +1,6 @@ - 4.0.0 @@ -38,7 +38,8 @@ Reinforcement Learning Algorithms - Classical RL algorithms implemented in Java, including Q-Learn, R-Learn, SARSA, Actor-Critic + Classical RL algorithms implemented in Java, including Q-Learn, R-Learn, SARSA, Actor-Critic + https://github.com/chen0040/java-reinforcement-learning @@ -235,7 +236,6 @@ - org.codehaus.mojo findbugs-maven-plugin @@ -361,8 +361,6 @@ - - @@ -382,14 +380,13 @@ org.apache.maven.plugins maven-resources-plugin + 2.3 UTF-8 - - @@ -488,23 +485,12 @@ - - - - org.projectlombok - lombok - provided - ${lombok.version} - - com.alibaba fastjson 1.2.41 - - diff --git a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java index 7de7f9b..f591257 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java @@ -1,8 +1,8 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; import com.github.chen0040.rl.models.UtilityModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.HashMap; import java.util.Map; @@ -14,58 +14,50 @@ */ public abstract class AbstractActionSelectionStrategy implements ActionSelectionStrategy { + Map attributes = new HashMap<>(); private String prototype; - protected Map attributes = new HashMap(); - public String getPrototype(){ - return prototype; + @SuppressWarnings("Used-by-user") + public AbstractActionSelectionStrategy() { + this.prototype = this.getClass().getCanonicalName(); } - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - return new IndexValue(); + @SuppressWarnings("Used-by-user") + public AbstractActionSelectionStrategy(final HashMap attributes) { + this.attributes = attributes; + if (attributes.containsKey("prototype")) { + this.prototype = attributes.get("prototype"); + } } - public IndexValue selectAction(int stateId, UtilityModel model, Set actionsAtState) { - return new IndexValue(); + @Override + public String getPrototype() { + return this.prototype; } - public AbstractActionSelectionStrategy(){ - prototype = this.getClass().getCanonicalName(); + @Override + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { + return new IndexValue(); } - - public AbstractActionSelectionStrategy(HashMap attributes){ - this.attributes = attributes; - if(attributes.containsKey("prototype")){ - this.prototype = attributes.get("prototype"); - } + @Override + public IndexValue selectAction(final int stateId, final UtilityModel model, final Set actionsAtState) { + return new IndexValue(); } - public Map getAttributes(){ - return attributes; + @Override + public Map getAttributes() { + return this.attributes; } @Override - public boolean equals(Object obj) { - ActionSelectionStrategy rhs = (ActionSelectionStrategy)obj; - if(!prototype.equalsIgnoreCase(rhs.getPrototype())) return false; - for(Map.Entry entry : rhs.getAttributes().entrySet()) { - if(!attributes.containsKey(entry.getKey())) { - return false; - } - if(!attributes.get(entry.getKey()).equals(entry.getValue())){ - return false; - } - } - for(Map.Entry entry : attributes.entrySet()) { - if(!rhs.getAttributes().containsKey(entry.getKey())) { - return false; - } - if(!rhs.getAttributes().get(entry.getKey()).equals(entry.getValue())){ - return false; - } - } - return true; + public boolean equals(final Object obj) { + final ActionSelectionStrategy rhs = (ActionSelectionStrategy) obj; + return this.prototype.equalsIgnoreCase(rhs.getPrototype()) && + rhs.getAttributes().entrySet().stream().noneMatch(entry -> !this.attributes.containsKey(entry.getKey()) || + !this.attributes.get(entry.getKey()).equals(entry.getValue())) && + this.attributes.entrySet().stream().noneMatch(entry -> !rhs.getAttributes().containsKey(entry.getKey()) || + !rhs.getAttributes().get(entry.getKey()).equals(entry.getValue())); } @Override diff --git a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java index ce92be0..cee0a7e 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java @@ -1,5 +1,6 @@ package com.github.chen0040.rl.actionselection; +import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -7,53 +8,54 @@ /** * Created by xschen on 9/27/2015 0027. */ -public class ActionSelectionStrategyFactory { - public static ActionSelectionStrategy deserialize(String conf){ - String[] comps = conf.split(";"); - - HashMap attributes = new HashMap(); - for(int i=0; i < comps.length; ++i){ - String comp = comps[i]; - String[] field = comp.split("="); - if(field.length < 2) continue; - String fieldname = field[0].trim(); - String fieldvalue = field[1].trim(); - - attributes.put(fieldname, fieldvalue); +public class ActionSelectionStrategyFactory implements Serializable { + + public static ActionSelectionStrategy deserialize(final String conf) { + final String[] comps = conf.split(";"); + + final HashMap attributes = new HashMap<>(); + for (final String comp : comps) { + final String[] field = comp.split("="); + if (field.length < 2) { + continue; + } + + attributes.put(field[0].trim(), field[1].trim()); } - if(attributes.isEmpty()){ + if (attributes.isEmpty()) { attributes.put("prototype", conf); } - String prototype = attributes.get("prototype"); - if(prototype.equals(GreedyActionSelectionStrategy.class.getCanonicalName())){ + final String prototype = attributes.get("prototype"); + if (prototype.equals(GreedyActionSelectionStrategy.class.getCanonicalName())) { return new GreedyActionSelectionStrategy(); - } else if(prototype.equals(SoftMaxActionSelectionStrategy.class.getCanonicalName())){ + } else if (prototype.equals(SoftMaxActionSelectionStrategy.class.getCanonicalName())) { return new SoftMaxActionSelectionStrategy(); - } else if(prototype.equals(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName())){ + } else if (prototype.equals(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName())) { return new EpsilonGreedyActionSelectionStrategy(attributes); - } else if(prototype.equals(GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName())){ + } else if (prototype.equals(GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName())) { return new GibbsSoftMaxActionSelectionStrategy(); } return null; } - public static String serialize(ActionSelectionStrategy strategy){ - Map attributes = strategy.getAttributes(); + public static String serialize(final ActionSelectionStrategy strategy) { + final Map attributes = strategy.getAttributes(); attributes.put("prototype", strategy.getPrototype()); - StringBuilder sb = new StringBuilder(); + final StringBuilder sb = new StringBuilder(); boolean first = true; - for(Map.Entry entry : attributes.entrySet()){ - if(first){ + for (final Map.Entry entry : attributes.entrySet()) { + if (first) { first = false; - } - else{ + } else { sb.append(";"); } - sb.append(entry.getKey()+"="+entry.getValue()); + sb.append(entry.getKey()).append("=").append(entry.getValue()); } return sb.toString(); } + + } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java index 5f7db9a..5ac3d36 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java @@ -1,7 +1,7 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.*; @@ -10,69 +10,62 @@ * Created by xschen on 9/27/2015 0027. */ public class EpsilonGreedyActionSelectionStrategy extends AbstractActionSelectionStrategy { - public static final String EPSILON = "epsilon"; + private static final String EPSILON = "epsilon"; private Random random = new Random(); - @Override - public Object clone(){ - EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy(); - clone.copy(this); - return clone; - } - - public void copy(EpsilonGreedyActionSelectionStrategy rhs){ - random = rhs.random; - for(Map.Entry entry : rhs.attributes.entrySet()){ - attributes.put(entry.getKey(), entry.getValue()); - } + @SuppressWarnings("Used-by-user") + public EpsilonGreedyActionSelectionStrategy() { + this.epsilon(); } - @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof EpsilonGreedyActionSelectionStrategy){ - EpsilonGreedyActionSelectionStrategy rhs = (EpsilonGreedyActionSelectionStrategy)obj; - if(epsilon() != rhs.epsilon()) return false; - // if(!random.equals(rhs.random)) return false; - return true; - } - return false; + @SuppressWarnings("Used-by-user") + public EpsilonGreedyActionSelectionStrategy(final HashMap attributes) { + super(attributes); } - private double epsilon(){ - return Double.parseDouble(attributes.get(EPSILON)); + @SuppressWarnings("Used-by-user") + public EpsilonGreedyActionSelectionStrategy(final Random random) { + this.random = random; + this.epsilon(); } - public EpsilonGreedyActionSelectionStrategy(){ - epsilon(0.1); + @Override + public Object clone() { + final EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy(); + clone.copy(this); + return clone; } - public EpsilonGreedyActionSelectionStrategy(HashMap attributes){ - super(attributes); + public void copy(final EpsilonGreedyActionSelectionStrategy rhs) { + this.random = rhs.random; + for (final Map.Entry entry : rhs.attributes.entrySet()) { + this.attributes.put(entry.getKey(), entry.getValue()); + } } - private void epsilon(double value){ - attributes.put(EPSILON, "" + value); + @Override + public boolean equals(final Object obj) { + return obj instanceof EpsilonGreedyActionSelectionStrategy && this.epsilon() == ((EpsilonGreedyActionSelectionStrategy) obj).epsilon(); } - public EpsilonGreedyActionSelectionStrategy(Random random){ - this.random = random; - epsilon(0.1); + private double epsilon() { + return Double.parseDouble(this.attributes.get(EpsilonGreedyActionSelectionStrategy.EPSILON)); } @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - if(random.nextDouble() < 1- epsilon()){ + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { + if (this.random.nextDouble() < 1 - this.epsilon()) { return model.actionWithMaxQAtState(stateId, actionsAtState); - }else{ - int actionId; - if(actionsAtState != null && !actionsAtState.isEmpty()) { - List actions = new ArrayList<>(actionsAtState); - actionId = actions.get(random.nextInt(actions.size())); + } else { + final int actionId; + if (actionsAtState != null && !actionsAtState.isEmpty()) { + final List actions = new ArrayList<>(actionsAtState); + actionId = actions.get(this.random.nextInt(actions.size())); } else { - actionId = random.nextInt(model.getActionCount()); + actionId = this.random.nextInt(model.getActionCount()); } - double Q = model.getQ(stateId, actionId); + final double Q = model.getQ(stateId, actionId); return new IndexValue(actionId, Q); } } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java index 8b2d8d2..283748c 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java @@ -1,12 +1,13 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.Set; +import java.util.stream.IntStream; /** @@ -14,52 +15,49 @@ */ public class GibbsSoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy { - private Random random = null; - public GibbsSoftMaxActionSelectionStrategy(){ - random = new Random(); + private final Random random; + + @SuppressWarnings("Used-by-user") + public GibbsSoftMaxActionSelectionStrategy() { + this.random = new Random(); } - public GibbsSoftMaxActionSelectionStrategy(Random random){ + @SuppressWarnings("Used-by-user") + public GibbsSoftMaxActionSelectionStrategy(final Random random) { this.random = random; } @Override public Object clone() { - GibbsSoftMaxActionSelectionStrategy clone = new GibbsSoftMaxActionSelectionStrategy(); - return clone; + return new GibbsSoftMaxActionSelectionStrategy(); } @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - List actions = new ArrayList(); - if(actionsAtState == null){ - for(int i=0; i < model.getActionCount(); ++i){ - actions.add(i); - } - }else{ - for(Integer actionId : actionsAtState){ - actions.add(actionId); - } + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { + final List actions = new ArrayList(); + if (actionsAtState == null) { + IntStream.range(0, model.getActionCount()).forEach(actions::add); + } else { + actions.addAll(actionsAtState); } double sum = 0; - List plist = new ArrayList(); - for(int i=0; i < actions.size(); ++i){ - int actionId = actions.get(i); - double p = Math.exp(model.getQ(stateId, actionId)); + final List plist = new ArrayList(); + for (final int actionId : actions) { + final double p = Math.exp(model.getQ(stateId, actionId)); sum += p; plist.add(sum); } - IndexValue iv = new IndexValue(); + final IndexValue iv = new IndexValue(); iv.setIndex(-1); iv.setValue(Double.NEGATIVE_INFINITY); - double r = sum * random.nextDouble(); - for(int i=0; i < actions.size(); ++i){ + final double r = sum * this.random.nextDouble(); + for (int i = 0; i < actions.size(); ++i) { - if(plist.get(i) >= r){ - int actionId = actions.get(i); + if (plist.get(i) >= r) { + final int actionId = actions.get(i); iv.setValue(model.getQ(stateId, actionId)); iv.setIndex(actionId); break; diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java index 6b0f350..e735fd3 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java @@ -1,7 +1,7 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.Set; @@ -11,18 +11,17 @@ */ public class GreedyActionSelectionStrategy extends AbstractActionSelectionStrategy { @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { return model.actionWithMaxQAtState(stateId, actionsAtState); } @Override - public Object clone(){ - GreedyActionSelectionStrategy clone = new GreedyActionSelectionStrategy(); - return clone; + public Object clone() { + return new GreedyActionSelectionStrategy(); } @Override - public boolean equals(Object obj){ - return obj != null && obj instanceof GreedyActionSelectionStrategy; + public boolean equals(final Object obj) { + return obj instanceof GreedyActionSelectionStrategy; } } diff --git a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java index f9735b9..c0b89f5 100644 --- a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java +++ b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java @@ -1,7 +1,7 @@ package com.github.chen0040.rl.actionselection; -import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.models.QModel; +import com.github.chen0040.rl.utils.IndexValue; import java.util.Random; import java.util.Set; @@ -13,27 +13,26 @@ public class SoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy { private Random random = new Random(); - @Override - public Object clone(){ - SoftMaxActionSelectionStrategy clone = new SoftMaxActionSelectionStrategy(random); - return clone; - } + public SoftMaxActionSelectionStrategy() { - @Override - public boolean equals(Object obj){ - return obj != null && obj instanceof SoftMaxActionSelectionStrategy; } - public SoftMaxActionSelectionStrategy(){ + public SoftMaxActionSelectionStrategy(final Random random) { + this.random = random; + } + @Override + public Object clone() { + return new SoftMaxActionSelectionStrategy(this.random); } - public SoftMaxActionSelectionStrategy(Random random){ - this.random = random; + @Override + public boolean equals(final Object obj) { + return obj instanceof SoftMaxActionSelectionStrategy; } @Override - public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) { - return model.actionWithSoftMaxQAtState(stateId, actionsAtState, random); + public IndexValue selectAction(final int stateId, final QModel model, final Set actionsAtState) { + return model.actionWithSoftMaxQAtState(stateId, actionsAtState, this.random); } } diff --git a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java index 6e34874..864668f 100644 --- a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java @@ -3,9 +3,7 @@ import com.github.chen0040.rl.utils.Vec; import java.io.Serializable; -import java.util.Random; import java.util.Set; -import java.util.function.Function; /** @@ -17,85 +15,85 @@ public class ActorCriticAgent implements Serializable { private int prevState; private int prevAction; - public void enableEligibilityTrace(double lambda){ - ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(learner); - acll.setLambda(lambda); - learner = acll; + @SuppressWarnings("Used-by-user") + public ActorCriticAgent(final int stateCount, final int actionCount) { + this.learner = new ActorCriticLearner(stateCount, actionCount); } - public void start(int stateId){ - currentState = stateId; - prevAction = -1; - prevState = -1; - } + public ActorCriticAgent() { - public ActorCriticLearner getLearner(){ - return learner; } - public void setLearner(ActorCriticLearner learner){ + public ActorCriticAgent(final ActorCriticLearner learner) { this.learner = learner; } - public ActorCriticAgent(int stateCount, int actionCount){ - learner = new ActorCriticLearner(stateCount, actionCount); + @SuppressWarnings("Used-by-user") + public void enableEligibilityTrace(final double lambda) { + final ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(this.learner); + acll.setLambda(lambda); + this.learner = acll; } - public ActorCriticAgent(){ + @SuppressWarnings("Used-by-user") + public void start(final int stateId) { + this.currentState = stateId; + this.prevAction = -1; + this.prevState = -1; + } + @SuppressWarnings("Used-by-user") + public ActorCriticLearner getLearner() { + return this.learner; } - public ActorCriticAgent(ActorCriticLearner learner){ + public void setLearner(final ActorCriticLearner learner) { this.learner = learner; } - public ActorCriticAgent makeCopy(){ - ActorCriticAgent clone = new ActorCriticAgent(); + public ActorCriticAgent makeCopy() { + final ActorCriticAgent clone = new ActorCriticAgent(); clone.copy(this); return clone; } - public void copy(ActorCriticAgent rhs){ - learner = (ActorCriticLearner)rhs.learner.makeCopy(); - prevAction = rhs.prevAction; - prevState = rhs.prevState; - currentState = rhs.currentState; + public void copy(final ActorCriticAgent rhs) { + this.learner = (ActorCriticLearner) rhs.learner.makeCopy(); + this.prevAction = rhs.prevAction; + this.prevState = rhs.prevState; + this.currentState = rhs.currentState; } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof ActorCriticAgent){ - ActorCriticAgent rhs = (ActorCriticAgent)obj; - return learner.equals(rhs.learner) && prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState; + public boolean equals(final Object obj) { + if (obj instanceof ActorCriticAgent) { + final ActorCriticAgent rhs = (ActorCriticAgent) obj; + return this.learner.equals(rhs.learner) && this.prevAction == rhs.prevAction && this.prevState == rhs.prevState && this.currentState == rhs.currentState; } return false; } - public int selectAction(Set actionsAtState){ - return learner.selectAction(currentState, actionsAtState); + public int selectAction(final Set actionsAtState) { + return this.learner.selectAction(this.currentState, actionsAtState); } - public int selectAction(){ - return learner.selectAction(currentState); + public int selectAction() { + return this.learner.selectAction(this.currentState); } - public void update(int actionTaken, int newState, double immediateReward, final Vec V){ - update(actionTaken, newState, null, immediateReward, V); + public void update(final int actionTaken, final int newState, final double immediateReward, final Vec V) { + this.update(actionTaken, newState, null, immediateReward, V); } - public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward, final Vec V){ + public void update(final int actionTaken, final int newState, final Set actionsAtNewState, final double immediateReward, final Vec V) { - learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward, new Function() { - public Double apply(Integer stateId) { - return V.get(stateId); - } - }); + this.learner.update(this.currentState, actionTaken, newState, actionsAtNewState, immediateReward, V::get); - prevAction = actionTaken; - prevState = currentState; + this.prevAction = actionTaken; + this.prevState = this.currentState; - currentState = newState; + this.currentState = newState; } } diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java index afdb314..1f85622 100644 --- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java @@ -3,14 +3,13 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; /** * Created by xschen on 9/27/2015 0027. */ -public class QAgent implements Serializable{ +public class QAgent implements Serializable { private QLearner learner; private int currentState; private int prevState; @@ -18,94 +17,97 @@ public class QAgent implements Serializable{ /** action taken at prevState */ private int prevAction; - public int getCurrentState(){ - return currentState; + public QAgent(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + this.learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ); } - public int getPrevState(){ - return prevState; + public QAgent(final QLearner learner) { + this.learner = learner; } - public int getPrevAction(){ - return prevAction; + public QAgent(final int stateCount, final int actionCount) { + this.learner = new QLearner(stateCount, actionCount); } - public void start(int currentState){ - this.currentState = currentState; - this.prevAction = -1; - this.prevState = -1; - } + public QAgent() { - public IndexValue selectAction(){ - return learner.selectAction(currentState); } - public IndexValue selectAction(Set actionsAtState){ - return learner.selectAction(currentState, actionsAtState); + @SuppressWarnings("Used-by-user") + public int getCurrentState() { + return this.currentState; } - public void update(int actionTaken, int newState, double immediateReward){ - update(actionTaken, newState, null, immediateReward); + @SuppressWarnings("Used-by-user") + public int getPrevState() { + return this.prevState; } - public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){ - - learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward); - - prevState = currentState; - prevAction = actionTaken; - - currentState = newState; + @SuppressWarnings("Used-by-user") + public int getPrevAction() { + return this.prevAction; } - public void enableEligibilityTrace(double lambda){ - QLambdaLearner acll = new QLambdaLearner(learner); - acll.setLambda(lambda); - learner = acll; + public void start(final int currentState) { + this.currentState = currentState; + this.prevAction = -1; + this.prevState = -1; } - public QLearner getLearner(){ - return learner; + public IndexValue selectAction() { + return this.learner.selectAction(this.currentState); } - public void setLearner(QLearner learner){ - this.learner = learner; + public IndexValue selectAction(final Set actionsAtState) { + return this.learner.selectAction(this.currentState, actionsAtState); } - public QAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ); + public void update(final int actionTaken, final int newState, final double immediateReward) { + this.update(actionTaken, newState, null, immediateReward); } - public QAgent(QLearner learner){ - this.learner = learner; + public void update(final int actionTaken, final int newState, final Set actionsAtNewState, final double immediateReward) { + + this.learner.update(this.currentState, actionTaken, newState, actionsAtNewState, immediateReward); + + this.prevState = this.currentState; + this.prevAction = actionTaken; + + this.currentState = newState; } - public QAgent(int stateCount, int actionCount){ - learner = new QLearner(stateCount, actionCount); + public void enableEligibilityTrace(final double lambda) { + final QLambdaLearner acll = new QLambdaLearner(this.learner); + acll.setLambda(lambda); + this.learner = acll; } - public QAgent(){ + public QLearner getLearner() { + return this.learner; + } + public void setLearner(final QLearner learner) { + this.learner = learner; } - public QAgent makeCopy(){ - QAgent clone = new QAgent(); + public QAgent makeCopy() { + final QAgent clone = new QAgent(); clone.copy(this); return clone; } - public void copy(QAgent rhs){ - learner.copy(rhs.learner); - prevAction = rhs.prevAction; - prevState = rhs.prevState; - currentState = rhs.currentState; + public void copy(final QAgent rhs) { + this.learner.copy(rhs.learner); + this.prevAction = rhs.prevAction; + this.prevState = rhs.prevState; + this.currentState = rhs.currentState; } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof QAgent){ - QAgent rhs = (QAgent)obj; - return prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState && learner.equals(rhs.learner); + public boolean equals(final Object obj) { + if (obj instanceof QAgent) { + final QAgent rhs = (QAgent) obj; + return this.prevAction == rhs.prevAction && this.prevState == rhs.prevState && this.currentState == rhs.currentState && this.learner.equals(rhs.learner); } return false; } diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java index 875ef3a..30ffd2e 100644 --- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java @@ -1,6 +1,7 @@ package com.github.chen0040.rl.learning.qlearn; +import com.github.chen0040.rl.models.DefaultValues; import com.github.chen0040.rl.models.EligibilityTraceUpdateMode; import com.github.chen0040.rl.utils.Matrix; @@ -11,125 +12,129 @@ * Created by xschen on 9/28/2015 0028. */ public class QLambdaLearner extends QLearner { - private double lambda = 0.9; + private double lambda = DefaultValues.LAMBDA; private Matrix e; private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; + @SuppressWarnings("Used-by-user") + public QLambdaLearner(final QLearner learner) { + this.copy(learner); + this.e = new Matrix(this.model.getStateCount(), this.model.getActionCount()); + } + + private QLambdaLearner() { + super(); + } + + @SuppressWarnings("Used-by-user") + public QLambdaLearner(final int stateCount, final int actionCount) { + super(stateCount, actionCount); + this.e = new Matrix(stateCount, actionCount); + } + + @SuppressWarnings("Used-by-user") + public QLambdaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + super(stateCount, actionCount, alpha, gamma, initialQ); + this.e = new Matrix(stateCount, actionCount); + } + + @SuppressWarnings("Used-by-user") public EligibilityTraceUpdateMode getTraceUpdateMode() { - return traceUpdateMode; + return this.traceUpdateMode; } - public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { + @SuppressWarnings("Used-by-user") + public void setTraceUpdateMode(final EligibilityTraceUpdateMode traceUpdateMode) { this.traceUpdateMode = traceUpdateMode; } - public double getLambda(){ - return lambda; + @SuppressWarnings("Used-by-user") + public double getLambda() { + return this.lambda; } - public void setLambda(double lambda){ + @SuppressWarnings("Used-by-user") + public void setLambda(final double lambda) { this.lambda = lambda; } - public QLambdaLearner makeCopy(){ - QLambdaLearner clone = new QLambdaLearner(); + @Override + public QLambdaLearner makeCopy() { + final QLambdaLearner clone = new QLambdaLearner(); clone.copy(this); return clone; } @Override - public void copy(QLearner rhs){ + public void copy(final QLearner rhs) { super.copy(rhs); - QLambdaLearner rhs2 = (QLambdaLearner)rhs; - lambda = rhs2.lambda; - e = rhs2.e.makeCopy(); - traceUpdateMode = rhs2.traceUpdateMode; - } - - public QLambdaLearner(QLearner learner){ - copy(learner); - e = new Matrix(model.getStateCount(), model.getActionCount()); + final QLambdaLearner rhs2 = (QLambdaLearner) rhs; + this.lambda = rhs2.lambda; + this.e = rhs2.e.makeCopy(); + this.traceUpdateMode = rhs2.traceUpdateMode; } @Override - public boolean equals(Object obj){ - if(!super.equals(obj)){ + public boolean equals(final Object obj) { + if (!super.equals(obj)) { return false; } - if(obj instanceof QLambdaLearner){ - QLambdaLearner rhs = (QLambdaLearner)obj; - return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode; + if (obj instanceof QLambdaLearner) { + final QLambdaLearner rhs = (QLambdaLearner) obj; + return rhs.lambda == this.lambda && this.e.equals(rhs.e) && this.traceUpdateMode == rhs.traceUpdateMode; } return false; } - public QLambdaLearner(){ - super(); - } - - public QLambdaLearner(int stateCount, int actionCount){ - super(stateCount, actionCount); - e = new Matrix(stateCount, actionCount); - } - - public QLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - super(stateCount, actionCount, alpha, gamma, initialQ); - e = new Matrix(stateCount, actionCount); + @SuppressWarnings("Used-by-user") + public Matrix getEligibility() { + return this.e; } - public Matrix getEligibility() - { - return e; - } - - public void setEligibility(Matrix e){ + @SuppressWarnings("Used-by-user") + public void setEligibility(final Matrix e) { this.e = e; } @Override - public void update(int currentStateId, int currentActionId, int nextStateId, Set actionsAtNextStateId, double immediateReward) - { + public void update(final int currentStateId, final int currentActionId, final int nextStateId, final Set actionsAtNextStateId, final double immediateReward) { // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(currentStateId, currentActionId); + double oldQ = this.model.getQ(currentStateId, currentActionId); // learning_rate; - double alpha = model.getAlpha(currentStateId, currentActionId); + final double alpha = this.model.getAlpha(currentStateId, currentActionId); // discount_rate; - double gamma = model.getGamma(); + final double gamma = this.model.getGamma(); // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double maxQ = maxQAtState(nextStateId, actionsAtNextStateId); - - double td_error = immediateReward + gamma * maxQ - oldQ; + final double maxQ = this.maxQAtState(nextStateId, actionsAtNextStateId); - int stateCount = model.getStateCount(); - int actionCount = model.getActionCount(); + final double td_error = immediateReward + gamma * maxQ - oldQ; - e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); + final int stateCount = this.model.getStateCount(); + final int actionCount = this.model.getActionCount(); + this.e.set(currentStateId, currentActionId, this.e.get(currentStateId, currentActionId) + 1); - for(int stateId = 0; stateId < stateCount; ++stateId){ - for(int actionId = 0; actionId < actionCount; ++actionId){ - oldQ = model.getQ(stateId, actionId); - double newQ = oldQ + alpha * td_error * e.get(stateId, actionId); + for (int stateId = 0; stateId < stateCount; ++stateId) { + for (int actionId = 0; actionId < actionCount; ++actionId) { + oldQ = this.model.getQ(stateId, actionId); + final double newQ = oldQ + alpha * td_error * this.e.get(stateId, actionId); // new_value is $Q_{t+1}(s_t, a_t)$ - model.setQ(currentStateId, currentActionId, newQ); + this.model.setQ(currentStateId, currentActionId, newQ); if (actionId != currentActionId) { - e.set(currentStateId, actionId, 0); + this.e.set(currentStateId, actionId, 0); } else { - e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); + this.e.set(stateId, actionId, this.e.get(stateId, actionId) * gamma * this.lambda); } } } - - - } } diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java index 865abc5..7970237 100644 --- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java @@ -2,7 +2,6 @@ import com.alibaba.fastjson.JSON; -import com.alibaba.fastjson.annotation.JSONField; import com.alibaba.fastjson.serializer.SerializerFeature; import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategy; @@ -12,131 +11,133 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; +import static com.github.chen0040.rl.models.DefaultValues.*; + /** - * Created by xschen on 9/27/2015 0027. - * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm - * Q is known as the quality of state-action combination, note that it is different from utility of a state + * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Q-Learning, which is an off-policy TD + * control algorithm Q is known as the quality of state-action combination, note that it is different from utility of a + * state */ -public class QLearner implements Serializable,Cloneable { +public class QLearner implements Serializable, Cloneable { protected QModel model; private ActionSelectionStrategy actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); - public QLearner makeCopy(){ - QLearner clone = new QLearner(); - clone.copy(this); - return clone; - } + public QLearner() { - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); } - public static QLearner fromJson(String json){ - return JSON.parseObject(json, QLearner.class); + public QLearner(final int stateCount, final int actionCount) { + this(stateCount, actionCount, ALPHA, GAMMA, INITIAL_Q); } - public void copy(QLearner rhs){ - model = rhs.model.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); + public QLearner(final QModel model, final ActionSelectionStrategy actionSelectionStrategy) { + this.model = model; + this.actionSelectionStrategy = actionSelectionStrategy; } - @Override - public boolean equals(Object obj){ - if(obj !=null && obj instanceof QLearner){ - QLearner rhs = (QLearner)obj; - if(!model.equals(rhs.model)) return false; - return actionSelectionStrategy.equals(rhs.actionSelectionStrategy); - } - return false; + public QLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + this.model = new QModel(stateCount, actionCount, initialQ); + this.model.setAlpha(alpha); + this.model.setGamma(gamma); + this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); } - public QModel getModel() { - return model; + @SuppressWarnings("Used-by-user") + public static QLearner fromJson(final String json) { + return JSON.parseObject(json, QLearner.class); } - public void setModel(QModel model) { - this.model = model; + public QLearner makeCopy() { + final QLearner clone = new QLearner(); + clone.copy(this); + return clone; } - - public String getActionSelection() { - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); + @SuppressWarnings("Used-by-user") + public String toJson() { + return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); } - public void setActionSelection(String conf) { - this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + public void copy(final QLearner rhs) { + this.model = rhs.model.makeCopy(); + this.actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); } - public QLearner(){ - + @Override + public boolean equals(final Object obj) { + if (obj instanceof QLearner) { + final QLearner rhs = (QLearner) obj; + if (!this.model.equals(rhs.model)) { + return false; + } + return this.actionSelectionStrategy.equals(rhs.actionSelectionStrategy); + } + return false; } - public QLearner(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1, 0.7, 0.1); + public QModel getModel() { + return this.model; } - public QLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy){ + public void setModel(final QModel model) { this.model = model; - this.actionSelectionStrategy = actionSelectionStrategy; } - public QLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) - { - model = new QModel(stateCount, actionCount, initialQ); - model.setAlpha(alpha); - model.setGamma(gamma); - actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); + @SuppressWarnings("Used-by-user") + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy); } + @SuppressWarnings("Used-by-user") + public void setActionSelection(final String conf) { + this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + } - protected double maxQAtState(int stateId, Set actionsAtState){ - IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState); - double maxQ = iv.getValue(); - return maxQ; + double maxQAtState(final int stateId, final Set actionsAtState) { + return this.model.actionWithMaxQAtState(stateId, actionsAtState).getValue(); } - public IndexValue selectAction(int stateId, Set actionsAtState){ - return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); + @SuppressWarnings("Used-by-user") + public IndexValue selectAction(final int stateId, final Set actionsAtState) { + return this.actionSelectionStrategy.selectAction(stateId, this.model, actionsAtState); } - public IndexValue selectAction(int stateId){ - return selectAction(stateId, null); + @SuppressWarnings("Used-by-user") + public IndexValue selectAction(final int stateId) { + return this.selectAction(stateId, null); } - public void update(int stateId, int actionId, int nextStateId, double immediateReward){ - update(stateId, actionId, nextStateId, null, immediateReward); + public void update(final int stateId, final int actionId, final int nextStateId, final double immediateReward) { + this.update(stateId, actionId, nextStateId, null, immediateReward); } - public void update(int stateId, int actionId, int nextStateId, Set actionsAtNextStateId, double immediateReward) - { + public void update(final int stateId, final int actionId, final int nextStateId, final Set actionsAtNextStateId, final double immediateReward) { // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(stateId, actionId); + final double oldQ = this.model.getQ(stateId, actionId); // learning_rate; - double alpha = model.getAlpha(stateId, actionId); + final double alpha = this.model.getAlpha(stateId, actionId); // discount_rate; - double gamma = model.getGamma(); + final double gamma = this.model.getGamma(); // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double maxQ = maxQAtState(nextStateId, actionsAtNextStateId); + final double maxQ = this.maxQAtState(nextStateId, actionsAtNextStateId); // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value // old_value = oldQ // temporal_difference = learned_value - old_value // new_value = old_value + learning_rate * temporal_difference - double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ); + final double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ); // new_value is $Q_{t+1}(s_t, a_t)$ - model.setQ(stateId, actionId, newQ); + this.model.setQ(stateId, actionId, newQ); } - } diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java index f26f20a..d7ef30c 100644 --- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java @@ -3,99 +3,95 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; /** * Created by xschen on 9/27/2015 0027. */ -public class RAgent implements Serializable{ +public class RAgent implements Serializable { private RLearner learner; private int currentState; private int currentAction; private double currentValue; - public int getCurrentState(){ - return currentState; + public RAgent() { + + } + + public RAgent(final int stateCount, final int actionCount, final double alpha, final double beta, final double rho, final double initialQ) { + this.learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ); + } + + public RAgent(final int stateCount, final int actionCount) { + this.learner = new RLearner(stateCount, actionCount); } - public int getCurrentAction(){ - return currentAction; + @SuppressWarnings("Used-by-user") + public int getCurrentState() { + return this.currentState; } - public void start(int currentState){ + @SuppressWarnings("Used-by-user") + public int getCurrentAction() { + return this.currentAction; + } + + public void start(final int currentState) { this.currentState = currentState; } - public RAgent makeCopy(){ - RAgent clone = new RAgent(); + public RAgent makeCopy() { + final RAgent clone = new RAgent(); clone.copy(this); return clone; } - public void copy(RAgent rhs){ - currentState = rhs.currentState; - currentAction = rhs.currentAction; - learner.copy(rhs.learner); + public void copy(final RAgent rhs) { + this.currentState = rhs.currentState; + this.currentAction = rhs.currentAction; + this.learner.copy(rhs.learner); } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof RAgent){ - RAgent rhs = (RAgent)obj; - if(!learner.equals(rhs.learner)) return false; - if(currentAction != rhs.currentAction) return false; - return currentState == rhs.currentState; + public boolean equals(final Object obj) { + if (obj instanceof RAgent) { + final RAgent rhs = (RAgent) obj; + return this.learner.equals(rhs.learner) && this.currentAction == rhs.currentAction && this.currentState == rhs.currentState; } return false; } - public IndexValue selectAction(){ - return selectAction(null); + public IndexValue selectAction() { + return this.selectAction(null); } - public IndexValue selectAction(Set actionsAtState){ - - if(currentAction==-1){ - IndexValue iv = learner.selectAction(currentState, actionsAtState); - currentAction = iv.getIndex(); - currentValue = iv.getValue(); + public IndexValue selectAction(final Set actionsAtState) { + if (this.currentAction == -1) { + final IndexValue iv = this.learner.selectAction(this.currentState, actionsAtState); + this.currentAction = iv.getIndex(); + this.currentValue = iv.getValue(); } - return new IndexValue(currentAction, currentValue); + return new IndexValue(this.currentAction, this.currentValue); } - public void update(int newState, double immediateReward){ - update(newState, null, immediateReward); + public void update(final int newState, final double immediateReward) { + this.update(newState, null, immediateReward); } - public void update(int newState, Set actionsAtState, double immediateReward){ - if(currentAction != -1) { - learner.update(currentState, currentAction, newState, actionsAtState, immediateReward); - currentState = newState; - currentAction = -1; + public void update(final int newState, final Set actionsAtState, final double immediateReward) { + if (this.currentAction != -1) { + this.learner.update(this.currentState, this.currentAction, newState, actionsAtState, immediateReward); + this.currentState = newState; + this.currentAction = -1; } } - public RAgent(){ - + public RLearner getLearner() { + return this.learner; } - - - public RLearner getLearner(){ - return learner; - } - - public void setLearner(RLearner learner){ + public void setLearner(final RLearner learner) { this.learner = learner; } - - public RAgent(int stateCount, int actionCount, double alpha, double beta, double rho, double initialQ){ - learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ); - } - - public RAgent(int stateCount, int actionCount){ - learner = new RLearner(stateCount, actionCount); - } } diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java index 910d53f..bc520e7 100644 --- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java @@ -7,135 +7,141 @@ import com.github.chen0040.rl.actionselection.ActionSelectionStrategy; import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory; import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy; +import com.github.chen0040.rl.models.DefaultValues; import com.github.chen0040.rl.models.QModel; import com.github.chen0040.rl.utils.IndexValue; -import lombok.Getter; import java.io.Serializable; import java.util.Set; +import static com.github.chen0040.rl.models.DefaultValues.ALPHA; + /** * Created by xschen on 9/27/2015 0027. */ -public class RLearner implements Serializable, Cloneable{ +public class RLearner implements Serializable, Cloneable { private QModel model; private ActionSelectionStrategy actionSelectionStrategy; private double rho; private double beta; - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); + public RLearner() { + + } + + public RLearner(final int stateCount, final int actionCount) { + this(stateCount, actionCount, ALPHA, DefaultValues.BETA, DefaultValues.RHO, DefaultValues.INITIAL_Q); + } + + public RLearner(final int state_count, final int action_count, final double alpha, final double beta, final double rho, final double initial_Q) { + this.model = new QModel(state_count, action_count, initial_Q); + this.model.setAlpha(alpha); + + this.rho = rho; + this.beta = beta; + + this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); } - public static RLearner fromJson(String json){ + @SuppressWarnings("Used-by-user") + public static RLearner fromJson(final String json) { return JSON.parseObject(json, RLearner.class); } - public RLearner makeCopy(){ - RLearner clone = new RLearner(); + @SuppressWarnings("Used-by-user") + public String toJson() { + return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); + } + + public RLearner makeCopy() { + final RLearner clone = new RLearner(); clone.copy(this); return clone; } - public void copy(RLearner rhs){ - model = rhs.model.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy)rhs.actionSelectionStrategy).clone(); - rho = rhs.rho; - beta = rhs.beta; + public void copy(final RLearner rhs) { + this.model = rhs.model.makeCopy(); + this.actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); + this.rho = rhs.rho; + this.beta = rhs.beta; } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof RLearner){ - RLearner rhs = (RLearner)obj; - if(!model.equals(rhs.model)) return false; - if(!actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) return false; - if(rho != rhs.rho) return false; - return beta == rhs.beta; + public boolean equals(final Object obj) { + if (obj instanceof RLearner) { + final RLearner rhs = (RLearner) obj; + if (!this.model.equals(rhs.model)) { + return false; + } + if (!this.actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) { + return false; + } + if (this.rho != rhs.rho) { + return false; + } + return this.beta == rhs.beta; } return false; } - public RLearner(){ - - } - + @SuppressWarnings("Used-by-user") public double getRho() { - return rho; + return this.rho; } - public void setRho(double rho) { + @SuppressWarnings("Used-by-user") + public void setRho(final double rho) { this.rho = rho; } + @SuppressWarnings("Used-by-user") public double getBeta() { - return beta; + return this.beta; } - public void setBeta(double beta) { + @SuppressWarnings("Used-by-user") + public void setBeta(final double beta) { this.beta = beta; } - public QModel getModel(){ - return model; + public QModel getModel() { + return this.model; } - public void setModel(QModel model){ + public void setModel(final QModel model) { this.model = model; } - public String getActionSelection(){ - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); + @SuppressWarnings("Used-by-user") + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy); } - public void setActionSelection(String conf){ + @SuppressWarnings("Used-by-user") + public void setActionSelection(final String conf) { this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); } - public RLearner(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1, 0.1, 0.7, 0.1); - } - - public RLearner(int state_count, int action_count, double alpha, double beta, double rho, double initial_Q) - { - model = new QModel(state_count, action_count, initial_Q); - model.setAlpha(alpha); - - this.rho = rho; - this.beta = beta; - - actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); - } - - private double maxQAtState(int stateId, Set actionsAtState){ - IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState); - double maxQ = iv.getValue(); - return maxQ; + private double maxQAtState(final int stateId, final Set actionsAtState) { + return this.model.actionWithMaxQAtState(stateId, actionsAtState).getValue(); } - public void update(int currentState, int actionTaken, int newState, Set actionsAtNextStateId, double immediate_reward) - { - double oldQ = model.getQ(currentState, actionTaken); - - double alpha = model.getAlpha(currentState, actionTaken); // learning rate; - - double maxQ = maxQAtState(newState, actionsAtNextStateId); - - double newQ = oldQ + alpha * (immediate_reward - rho + maxQ - oldQ); - - double maxQAtCurrentState = maxQAtState(currentState, null); - if (newQ == maxQAtCurrentState) - { - rho = rho + beta * (immediate_reward - rho + maxQ - maxQAtCurrentState); + public void update(final int currentState, final int actionTaken, final int newState, final Set actionsAtNextStateId, final double immediate_reward) { + final double oldQ = this.model.getQ(currentState, actionTaken); + final double alpha = this.model.getAlpha(currentState, actionTaken); // learning rate; + final double maxQ = this.maxQAtState(newState, actionsAtNextStateId); + final double newQ = oldQ + alpha * (immediate_reward - this.rho + maxQ - oldQ); + final double maxQAtCurrentState = this.maxQAtState(currentState, null); + if (newQ == maxQAtCurrentState) { + this.rho += this.beta * (immediate_reward - this.rho + maxQ - maxQAtCurrentState); } - - model.setQ(currentState, actionTaken, newQ); + this.model.setQ(currentState, actionTaken, newQ); } - public IndexValue selectAction(int stateId, Set actionsAtState){ - return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); + public IndexValue selectAction(final int stateId, final Set actionsAtState) { + return this.actionSelectionStrategy.selectAction(stateId, this.model, actionsAtState); } } diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java index c4c8f27..80335fd 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java @@ -3,15 +3,14 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; /** - * Created by xschen on 9/27/2015 0027. - * Implement temporal-difference learning Sarsa, which is an on-policy TD control algorithm + * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Sarsa, which is an on-policy TD control + * algorithm */ -public class SarsaAgent implements Serializable{ +public class SarsaAgent implements Serializable { private SarsaLearner learner; private int currentState; private int currentAction; @@ -19,111 +18,118 @@ public class SarsaAgent implements Serializable{ private int prevState; private int prevAction; - public int getCurrentState(){ - return currentState; + public SarsaAgent(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + this.learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ); } - public int getCurrentAction(){ - return currentAction; + public SarsaAgent(final int stateCount, final int actionCount) { + this.learner = new SarsaLearner(stateCount, actionCount); } - public int getPrevState() { return prevState; } + public SarsaAgent(final SarsaLearner learner) { + this.learner = learner; + } - public int getPrevAction() { return prevAction; } + public SarsaAgent() { - public void start(int currentState){ - this.currentState = currentState; - this.prevState = -1; - this.prevAction = -1; } - public IndexValue selectAction(){ - return selectAction(null); + @SuppressWarnings("Used-by-user") + public int getCurrentState() { + return this.currentState; } - public IndexValue selectAction(Set actionsAtState){ - if(currentAction == -1){ - IndexValue iv = learner.selectAction(currentState, actionsAtState); - currentAction = iv.getIndex(); - currentValue = iv.getValue(); - } + @SuppressWarnings("Used-by-user") + public int getCurrentAction() { + return this.currentAction; + } - return new IndexValue(currentAction, currentValue); + @SuppressWarnings("Used-by-user") + public int getPrevState() { + return this.prevState; } - public void update(int actionTaken, int newState, double immediateReward){ - update(actionTaken, newState, null, immediateReward); + @SuppressWarnings("Used-by-user") + public int getPrevAction() { + return this.prevAction; } - public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){ + public void start(final int currentState) { + this.currentState = currentState; + this.prevState = -1; + this.prevAction = -1; + } - IndexValue iv = learner.selectAction(currentState, actionsAtNewState); - int futureAction = iv.getIndex(); + public IndexValue selectAction() { + return this.selectAction(null); + } - learner.update(currentState, actionTaken, newState, futureAction, immediateReward); + public IndexValue selectAction(final Set actionsAtState) { + if (this.currentAction == -1) { + final IndexValue iv = this.learner.selectAction(this.currentState, actionsAtState); + this.currentAction = iv.getIndex(); + this.currentValue = iv.getValue(); + } - prevState = this.currentState; - this.prevAction = actionTaken; + return new IndexValue(this.currentAction, this.currentValue); + } - currentAction = futureAction; - currentState = newState; + public void update(final int actionTaken, final int newState, final double immediateReward) { + this.update(actionTaken, newState, null, immediateReward); } + public void update(final int actionTaken, final int newState, final Set actionsAtNewState, final double immediateReward) { + final IndexValue iv = this.learner.selectAction(this.currentState, actionsAtNewState); + final int futureAction = iv.getIndex(); - public SarsaLearner getLearner(){ - return learner; - } + this.learner.update(this.currentState, actionTaken, newState, futureAction, immediateReward); - public void setLearner(SarsaLearner learner){ - this.learner = learner; - } + this.prevState = this.currentState; + this.prevAction = actionTaken; - public SarsaAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ); + this.currentAction = futureAction; + this.currentState = newState; } - public SarsaAgent(int stateCount, int actionCount){ - learner = new SarsaLearner(stateCount, actionCount); + public SarsaLearner getLearner() { + return this.learner; } - public SarsaAgent(SarsaLearner learner){ + public void setLearner(final SarsaLearner learner) { this.learner = learner; } - public SarsaAgent(){ - - } - - public void enableEligibilityTrace(double lambda){ - SarsaLambdaLearner acll = new SarsaLambdaLearner(learner); + @SuppressWarnings("Used-by-user") + public void enableEligibilityTrace(final double lambda) { + final SarsaLambdaLearner acll = new SarsaLambdaLearner(this.learner); acll.setLambda(lambda); - learner = acll; + this.learner = acll; } - public SarsaAgent makeCopy(){ - SarsaAgent clone = new SarsaAgent(); + public SarsaAgent makeCopy() { + final SarsaAgent clone = new SarsaAgent(); clone.copy(this); return clone; } - public void copy(SarsaAgent rhs){ - learner.copy(rhs.learner); - currentAction = rhs.currentAction; - currentState = rhs.currentState; - prevAction = rhs.prevAction; - prevState = rhs.prevState; + public void copy(final SarsaAgent rhs) { + this.learner.copy(rhs.learner); + this.currentAction = rhs.currentAction; + this.currentState = rhs.currentState; + this.prevAction = rhs.prevAction; + this.prevState = rhs.prevState; } @Override - public boolean equals(Object obj){ - if(obj != null && obj instanceof SarsaAgent){ - SarsaAgent rhs = (SarsaAgent)obj; - return prevAction == rhs.prevAction - && prevState == rhs.prevState - && currentAction == rhs.currentAction - && currentState == rhs.currentState - && learner.equals(rhs.learner); + public boolean equals(final Object obj) { + if (obj instanceof SarsaAgent) { + final SarsaAgent rhs = (SarsaAgent) obj; + return this.prevAction == rhs.prevAction + && this.prevState == rhs.prevState + && this.currentAction == rhs.currentAction + && this.currentState == rhs.currentState + && this.learner.equals(rhs.learner); } return false; } diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java index e51543e..6298ac8 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java @@ -1,6 +1,7 @@ package com.github.chen0040.rl.learning.sarsa; +import com.github.chen0040.rl.models.DefaultValues; import com.github.chen0040.rl.models.EligibilityTraceUpdateMode; import com.github.chen0040.rl.utils.Matrix; @@ -9,119 +10,123 @@ * Created by xschen on 9/28/2015 0028. */ public class SarsaLambdaLearner extends SarsaLearner { - private double lambda = 0.9; + private double lambda = DefaultValues.LAMBDA; private Matrix e; private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace; + public SarsaLambdaLearner(final int stateCount, final int actionCount) { + super(stateCount, actionCount); + this.e = new Matrix(stateCount, actionCount); + } + + public SarsaLambdaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + super(stateCount, actionCount, alpha, gamma, initialQ); + this.e = new Matrix(stateCount, actionCount); + } + + @SuppressWarnings("Used-by-user") + public SarsaLambdaLearner(final SarsaLearner learner) { + this.copy(learner); + this.e = new Matrix(this.model.getStateCount(), this.model.getActionCount()); + } + + private SarsaLambdaLearner() { + + } + + @SuppressWarnings("Used-by-user") public EligibilityTraceUpdateMode getTraceUpdateMode() { - return traceUpdateMode; + return this.traceUpdateMode; } - public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) { + @SuppressWarnings("Used-by-user") + public void setTraceUpdateMode(final EligibilityTraceUpdateMode traceUpdateMode) { this.traceUpdateMode = traceUpdateMode; } - public double getLambda(){ - return lambda; + @SuppressWarnings("Used-by-user") + public double getLambda() { + return this.lambda; } - public void setLambda(double lambda){ + public void setLambda(final double lambda) { this.lambda = lambda; } @Override - public Object clone(){ - SarsaLambdaLearner clone = new SarsaLambdaLearner(); + public SarsaLambdaLearner clone() { + final SarsaLambdaLearner clone = new SarsaLambdaLearner(); clone.copy(this); return clone; } @Override - public void copy(SarsaLearner rhs){ + public void copy(final SarsaLearner rhs) { super.copy(rhs); - SarsaLambdaLearner rhs2 = (SarsaLambdaLearner)rhs; - lambda = rhs2.lambda; - e = rhs2.e.makeCopy(); - traceUpdateMode = rhs2.traceUpdateMode; + final SarsaLambdaLearner rhs2 = (SarsaLambdaLearner) rhs; + this.lambda = rhs2.lambda; + this.e = rhs2.e.makeCopy(); + this.traceUpdateMode = rhs2.traceUpdateMode; } @Override - public boolean equals(Object obj){ - if(!super.equals(obj)){ + public boolean equals(final Object obj) { + if (!super.equals(obj)) { return false; } - if(obj instanceof SarsaLambdaLearner){ - SarsaLambdaLearner rhs = (SarsaLambdaLearner)obj; - return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode; + if (obj instanceof SarsaLambdaLearner) { + final SarsaLambdaLearner rhs = (SarsaLambdaLearner) obj; + return rhs.lambda == this.lambda && this.e.equals(rhs.e) && this.traceUpdateMode == rhs.traceUpdateMode; } return false; } - public SarsaLambdaLearner(){ - super(); - } - - public SarsaLambdaLearner(int stateCount, int actionCount){ - super(stateCount, actionCount); - e = new Matrix(stateCount, actionCount); - } - - public SarsaLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){ - super(stateCount, actionCount, alpha, gamma, initialQ); - e = new Matrix(stateCount, actionCount); - } - - public SarsaLambdaLearner(SarsaLearner learner){ - copy(learner); - e = new Matrix(model.getStateCount(), model.getActionCount()); - } - - public Matrix getEligibility() - { - return e; + @SuppressWarnings("Used-by-user") + public Matrix getEligibility() { + return this.e; } - public void setEligibility(Matrix e){ + @SuppressWarnings("Used-by-user") + public void setEligibility(final Matrix e) { this.e = e; } @Override - public void update(int currentStateId, int currentActionId, int nextStateId, int nextActionId, double immediateReward) - { + public void update(final int currentStateId, final int currentActionId, final int nextStateId, final int nextActionId, final double immediateReward) { // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(currentStateId, currentActionId); + double oldQ = this.model.getQ(currentStateId, currentActionId); // learning_rate; - double alpha = model.getAlpha(currentStateId, currentActionId); + final double alpha = this.model.getAlpha(currentStateId, currentActionId); // discount_rate; - double gamma = model.getGamma(); + final double gamma = this.model.getGamma(); // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double nextQ = model.getQ(nextStateId, nextActionId); + final double nextQ = this.model.getQ(nextStateId, nextActionId); - double td_error = immediateReward + gamma * nextQ - oldQ; + final double td_error = immediateReward + gamma * nextQ - oldQ; - int stateCount = model.getStateCount(); - int actionCount = model.getActionCount(); + final int stateCount = this.model.getStateCount(); + final int actionCount = this.model.getActionCount(); - e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1); + this.e.set(currentStateId, currentActionId, this.e.get(currentStateId, currentActionId) + 1); - for(int stateId = 0; stateId < stateCount; ++stateId){ - for(int actionId = 0; actionId < actionCount; ++actionId){ - oldQ = model.getQ(stateId, actionId); + for (int stateId = 0; stateId < stateCount; ++stateId) { + for (int actionId = 0; actionId < actionCount; ++actionId) { + oldQ = this.model.getQ(stateId, actionId); - double newQ = oldQ + alpha * td_error * e.get(stateId, actionId); + final double newQ = oldQ + alpha * td_error * this.e.get(stateId, actionId); - model.setQ(stateId, actionId, newQ); + this.model.setQ(stateId, actionId, newQ); if (actionId != currentActionId) { - e.set(currentStateId, actionId, 0); + this.e.set(currentStateId, actionId, 0); } else { - e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda); + this.e.set(stateId, actionId, this.e.get(stateId, actionId) * gamma * this.lambda); } } } diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java index 7fef780..3ce748a 100644 --- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java +++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java @@ -11,150 +11,125 @@ import com.github.chen0040.rl.utils.IndexValue; import java.io.Serializable; -import java.util.Random; import java.util.Set; +import static com.github.chen0040.rl.models.DefaultValues.*; + /** - * Created by xschen on 9/27/2015 0027. - * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm - * Q is known as the quality of state-action combination, note that it is different from utility of a state + * Created by xschen on 9/27/2015 0027. Implement temporal-difference learning Q-Learning, which is an off-policy TD + * control algorithm Q is known as the quality of state-action combination, note that it is different from utility of a + * state */ -public class SarsaLearner implements Serializable,Cloneable { +public class SarsaLearner implements Serializable, Cloneable { protected QModel model; private ActionSelectionStrategy actionSelectionStrategy; - public String toJson() { - return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); - } + @SuppressWarnings("Used-by-user") + public SarsaLearner() { - public static SarsaLearner fromJson(String json){ - return JSON.parseObject(json, SarsaLearner.class); } - public SarsaLearner makeCopy(){ - SarsaLearner clone = new SarsaLearner(); - clone.copy(this); - return clone; + @SuppressWarnings("Used-by-user") + public SarsaLearner(final int stateCount, final int actionCount) { + this(stateCount, actionCount, ALPHA, GAMMA, INITIAL_Q); } - public void copy(SarsaLearner rhs){ - model = rhs.model.makeCopy(); - actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); + @SuppressWarnings("Used-by-user") + public SarsaLearner(final QModel model, final ActionSelectionStrategy actionSelectionStrategy) { + this.model = model; + this.actionSelectionStrategy = actionSelectionStrategy; } - @Override - public boolean equals(Object obj){ - if(obj !=null && obj instanceof SarsaLearner){ - SarsaLearner rhs = (SarsaLearner)obj; - if(!model.equals(rhs.model)) return false; - return actionSelectionStrategy.equals(rhs.actionSelectionStrategy); - } - return false; + @SuppressWarnings("Used-by-user") + public SarsaLearner(final int stateCount, final int actionCount, final double alpha, final double gamma, final double initialQ) { + this.model = new QModel(stateCount, actionCount, initialQ); + this.model.setAlpha(alpha); + this.model.setGamma(gamma); + this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); } - public QModel getModel() { - return model; + @SuppressWarnings("Used-by-user") + public static SarsaLearner fromJson(final String json) { + return JSON.parseObject(json, SarsaLearner.class); } - public void setModel(QModel model) { - this.model = model; + @SuppressWarnings("Used-by-user") + public String toJson() { + return JSON.toJSONString(this, SerializerFeature.BrowserCompatible); } - public String getActionSelection() { - return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy); + public SarsaLearner makeCopy() { + final SarsaLearner clone = new SarsaLearner(); + clone.copy(this); + return clone; } - public void setActionSelection(String conf) { - this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); + public void copy(final SarsaLearner rhs) { + this.model = rhs.model.makeCopy(); + this.actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone(); } - public SarsaLearner(){ - + @Override + public boolean equals(final Object obj) { + if (obj instanceof SarsaLearner) { + final SarsaLearner rhs = (SarsaLearner) obj; + if (!this.model.equals(rhs.model)) { + return false; + } + return this.actionSelectionStrategy.equals(rhs.actionSelectionStrategy); + } + return false; } - public SarsaLearner(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1, 0.7, 0.1); + public QModel getModel() { + return this.model; } - public SarsaLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy){ + public void setModel(final QModel model) { this.model = model; - this.actionSelectionStrategy = actionSelectionStrategy; } - public SarsaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) - { - model = new QModel(stateCount, actionCount, initialQ); - model.setAlpha(alpha); - model.setGamma(gamma); - actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy(); + @SuppressWarnings("Used-by-user") + public String getActionSelection() { + return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy); } - public static void main(String[] args){ - int stateCount = 100; - int actionCount = 10; - - SarsaLearner learner = new SarsaLearner(stateCount, actionCount); - - double reward = 0; // reward gained by transiting from prevState to currentState - Random random = new Random(); - int currentStateId = random.nextInt(stateCount); - int currentActionId = learner.selectAction(currentStateId).getIndex(); - - for(int time=0; time < 1000; ++time){ - - System.out.println("Controller does action-"+currentActionId); - - int newStateId = random.nextInt(actionCount); - reward = random.nextDouble(); - - System.out.println("Now the new state is " + newStateId); - System.out.println("Controller receives Reward = " + reward); - - int futureActionId = learner.selectAction(newStateId).getIndex(); - - System.out.println("Controller is expected to do action-"+futureActionId); - - learner.update(currentStateId, currentActionId, newStateId, futureActionId, reward); - - currentStateId = newStateId; - currentActionId = futureActionId; - } + @SuppressWarnings("Used-by-user") + public void setActionSelection(final String conf) { + this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf); } - - public IndexValue selectAction(int stateId, Set actionsAtState){ - return actionSelectionStrategy.selectAction(stateId, model, actionsAtState); + public IndexValue selectAction(final int stateId, final Set actionsAtState) { + return this.actionSelectionStrategy.selectAction(stateId, this.model, actionsAtState); } - public IndexValue selectAction(int stateId){ - return selectAction(stateId, null); + public IndexValue selectAction(final int stateId) { + return this.selectAction(stateId, null); } - public void update(int stateId, int actionId, int nextStateId, int nextActionId, double immediateReward) - { + public void update(final int stateId, final int actionId, final int nextStateId, final int nextActionId, final double immediateReward) { // old_value is $Q_t(s_t, a_t)$ - double oldQ = model.getQ(stateId, actionId); + final double oldQ = this.model.getQ(stateId, actionId); // learning_rate; - double alpha = model.getAlpha(stateId, actionId); + final double alpha = this.model.getAlpha(stateId, actionId); // discount_rate; - double gamma = model.getGamma(); + final double gamma = this.model.getGamma(); // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$ - double nextQ = model.getQ(nextStateId, nextActionId); + final double nextQ = this.model.getQ(nextStateId, nextActionId); // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value // old_value = oldQ // temporal_difference = learned_value - old_value // new_value = old_value + learning_rate * temporal_difference - double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ); + final double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ); // new_value is $Q_{t+1}(s_t, a_t)$ - model.setQ(stateId, actionId, newQ); + this.model.setQ(stateId, actionId, newQ); } - } diff --git a/src/main/java/com/github/chen0040/rl/models/DefaultValues.java b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java new file mode 100644 index 0000000..879ea64 --- /dev/null +++ b/src/main/java/com/github/chen0040/rl/models/DefaultValues.java @@ -0,0 +1,11 @@ +package com.github.chen0040.rl.models; + +public enum DefaultValues { + ; + public static final double GAMMA = 0.9; + public static final double ALPHA = 0.1; + public static final double INITIAL_Q = 0.1; + public static final double LAMBDA = 0.9; + public static final double BETA = 0.1; + public static final double RHO = 0.7; +} diff --git a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java index e25380f..bd891f0 100644 --- a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java +++ b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java @@ -4,6 +4,5 @@ * Created by xschen on 9/28/2015 0028. */ public enum EligibilityTraceUpdateMode { - ReplaceTrace, - AccumulateTrace + ReplaceTrace } diff --git a/src/main/java/com/github/chen0040/rl/models/QModel.java b/src/main/java/com/github/chen0040/rl/models/QModel.java index 2d314a1..d48e1c1 100644 --- a/src/main/java/com/github/chen0040/rl/models/QModel.java +++ b/src/main/java/com/github/chen0040/rl/models/QModel.java @@ -4,149 +4,147 @@ import com.github.chen0040.rl.utils.IndexValue; import com.github.chen0040.rl.utils.Matrix; import com.github.chen0040.rl.utils.Vec; -import lombok.Getter; -import lombok.Setter; import java.util.*; /** - * @author xschen - * 9/27/2015 0027. - * Q is known as the quality of state-action combination, note that it is different from utility of a state + * @author xschen 9/27/2015 0027. Q is known as the quality of state-action combination, note that it is different from + * utility of a state */ -@Getter -@Setter public class QModel { /** - * Q value for (state_id, action_id) pair - * Q is known as the quality of state-action combination, note that it is different from utility of a state - */ + * Q value for (state_id, action_id) pair Q is known as the quality of state-action combination, note that it is + * different from utility of a state + */ + private Matrix Q; /** - * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id) - */ + * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id) + */ + private Matrix alphaMatrix; /** * discount factor */ - private double gamma = 0.7; + private double gamma = DefaultValues.GAMMA; private int stateCount; private int actionCount; - public QModel(int stateCount, int actionCount, double initialQ){ + public QModel(final int stateCount, final int actionCount, final double initialQ) { this.stateCount = stateCount; this.actionCount = actionCount; - Q = new Matrix(stateCount,actionCount); - alphaMatrix = new Matrix(stateCount, actionCount); - Q.setAll(initialQ); - alphaMatrix.setAll(0.1); + this.Q = new Matrix(stateCount, actionCount); + this.alphaMatrix = new Matrix(stateCount, actionCount); + this.Q.setAll(initialQ); + this.alphaMatrix.setAll(DefaultValues.ALPHA); } - public QModel(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1); + public QModel(final int stateCount, final int actionCount) { + this(stateCount, actionCount, DefaultValues.INITIAL_Q); } - public QModel(){ + public QModel() { } @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof QModel){ - QModel rhs2 = (QModel)rhs; - - - if(gamma != rhs2.gamma) return false; - - - if(stateCount != rhs2.stateCount || actionCount != rhs2.actionCount) return false; - - if((Q!=null && rhs2.Q==null) || (Q==null && rhs2.Q !=null)) return false; - if((alphaMatrix !=null && rhs2.alphaMatrix ==null) || (alphaMatrix ==null && rhs2.alphaMatrix !=null)) return false; - - return !((Q != null && !Q.equals(rhs2.Q)) || (alphaMatrix != null && !alphaMatrix.equals(rhs2.alphaMatrix))); - + public boolean equals(final Object rhs) { + if (rhs instanceof QModel) { + final QModel rhs2 = (QModel) rhs; + return this.gamma == rhs2.gamma && + this.stateCount == rhs2.stateCount && + this.actionCount == rhs2.actionCount && + (this.Q == null || rhs2.Q != null) && + (this.Q != null || rhs2.Q == null) && + (this.alphaMatrix == null || rhs2.alphaMatrix != null) && + (this.alphaMatrix != null || rhs2.alphaMatrix == null) && + !((this.Q != null && !this.Q.equals(rhs2.Q)) || (this.alphaMatrix != null && !this.alphaMatrix.equals(rhs2.alphaMatrix))); } return false; } - public QModel makeCopy(){ - QModel clone = new QModel(); + public QModel makeCopy() { + final QModel clone = new QModel(); clone.copy(this); return clone; } - public void copy(QModel rhs){ - gamma = rhs.gamma; - stateCount = rhs.stateCount; - actionCount = rhs.actionCount; - Q = rhs.Q==null ? null : rhs.Q.makeCopy(); - alphaMatrix = rhs.alphaMatrix == null ? null : rhs.alphaMatrix.makeCopy(); + public void copy(final QModel rhs) { + this.gamma = rhs.gamma; + this.stateCount = rhs.stateCount; + this.actionCount = rhs.actionCount; + this.Q = rhs.Q == null ? null : rhs.Q.makeCopy(); + this.alphaMatrix = rhs.alphaMatrix == null ? null : rhs.alphaMatrix.makeCopy(); } - public double getQ(int stateId, int actionId){ - return Q.get(stateId, actionId); + public double getQ(final int stateId, final int actionId) { + assert this.Q != null; + return this.Q.get(stateId, actionId); } - public void setQ(int stateId, int actionId, double Qij){ - Q.set(stateId, actionId, Qij); + public void setQ(final int stateId, final int actionId, final double Qij) { + assert this.Q != null; + this.Q.set(stateId, actionId, Qij); } - public double getAlpha(int stateId, int actionId){ - return alphaMatrix.get(stateId, actionId); + public double getAlpha(final int stateId, final int actionId) { + assert this.alphaMatrix != null; + return this.alphaMatrix.get(stateId, actionId); } - public void setAlpha(double defaultAlpha) { + public void setAlpha(final double defaultAlpha) { + assert this.alphaMatrix != null; this.alphaMatrix.setAll(defaultAlpha); } - public IndexValue actionWithMaxQAtState(int stateId, Set actionsAtState){ - Vec rowVector = Q.rowAt(stateId); + public IndexValue actionWithMaxQAtState(final int stateId, final Set actionsAtState) { + assert this.Q != null; + final Vec rowVector = this.Q.rowAt(stateId); return rowVector.indexWithMaxValue(actionsAtState); } - private void reset(double initialQ){ - Q.setAll(initialQ); + private void reset(final double initialQ) { + assert this.Q != null; + this.Q.setAll(initialQ); } - public IndexValue actionWithSoftMaxQAtState(int stateId,Set actionsAtState, Random random) { - Vec rowVector = Q.rowAt(stateId); + public IndexValue actionWithSoftMaxQAtState(final int stateId, final Set actionsAtState, final Random random) { + Set atState = actionsAtState; + assert this.Q != null; + final Vec rowVector = this.Q.rowAt(stateId); double sum = 0; - if(actionsAtState==null){ - actionsAtState = new HashSet<>(); - for(int i=0; i < actionCount; ++i){ - actionsAtState.add(i); + if (atState == null) { + atState = new HashSet<>(); + for (int i = 0; i < this.actionCount; ++i) { + atState.add(i); } } - List actions = new ArrayList<>(); - for(Integer actionId : actionsAtState){ - actions.add(actionId); - } + final List actions = new ArrayList<>(atState); - double[] acc = new double[actions.size()]; - for(int i=0; i < actions.size(); ++i){ + final double[] acc = new double[actions.size()]; + for (int i = 0; i < actions.size(); ++i) { sum += rowVector.get(actions.get(i)); acc[i] = sum; } - double r = random.nextDouble() * sum; + final double r = random.nextDouble() * sum; - IndexValue result = new IndexValue(); - for(int i=0; i < actions.size(); ++i){ - if(acc[i] >= r){ - int actionId = actions.get(i); + final IndexValue result = new IndexValue(); + for (int i = 0; i < actions.size(); ++i) { + if (acc[i] >= r) { + final int actionId = actions.get(i); result.setIndex(actionId); result.setValue(rowVector.get(actionId)); break; @@ -155,4 +153,38 @@ public IndexValue actionWithSoftMaxQAtState(int stateId,Set actionsAtSt return result; } + + + public Matrix getQ() { + return this.Q; + } + + + public Matrix getAlphaMatrix() { + return this.alphaMatrix; + } + + public double getGamma() { + return this.gamma; + } + + public void setGamma(final double gamma) { + this.gamma = gamma; + } + + public int getStateCount() { + return this.stateCount; + } + + public void setStateCount(final int stateCount) { + this.stateCount = stateCount; + } + + public int getActionCount() { + return this.actionCount; + } + + public void setActionCount(final int actionCount) { + this.actionCount = actionCount; + } } diff --git a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java index cff1859..4496b62 100644 --- a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java +++ b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java @@ -1,91 +1,79 @@ package com.github.chen0040.rl.models; import com.github.chen0040.rl.utils.Vec; -import lombok.Getter; -import lombok.Setter; import java.io.Serializable; /** - * @author xschen - * 9/27/2015 0027. - * Utility value of a state $U(s)$ is the expected long term reward of state $s$ given the sequence of reward and the optimal policy - * Utility value $U(s)$ at state $s$ can be obtained by the Bellman equation - * Bellman Equtation states that $U(s) = R(s) + \gamma * max_a \sum_{s'} T(s,a,s')U(s')$ - * where s' is the possible transitioned state given that action $a$ is applied at state $s$ - * where $T(s,a,s')$ is the transition probability of $s \rightarrow s'$ given that action $a$ is applied at state $s$ - * where $\sum_{s'} T(s,a,s')U(s')$ is the expected long term reward given that action $a$ is applied at state $s$ - * where $max_a \sum_{s'} T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$ is applied at state $s$ + * @author xschen 9/27/2015 0027. Utility value of a state $U(s)$ is the expected long term reward of state $s$ given + * the sequence of reward and the optimal policy Utility value $U(s)$ at state $s$ can be obtained by the + * Bellman equation Bellman Equtation states that $U(s) = R(s) + \gamma * max_a \sum_{s'} T(s,a,s')U(s')$ where + * s' is the possible transitioned state given that action $a$ is applied at state $s$ where $T(s,a,s')$ is the + * transition probability of $s \rightarrow s'$ given that action $a$ is applied at state $s$ where $\sum_{s'} + * T(s,a,s')U(s')$ is the expected long term reward given that action $a$ is applied at state $s$ where $max_a + * \sum_{s'} T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$ + * is applied at state $s$ */ -@Getter -@Setter public class UtilityModel implements Serializable { + private Vec U; private int stateCount; private int actionCount; - public void setU(Vec U){ - this.U = U; - } - - public Vec getU() { - return U; + @SuppressWarnings("Used-by-user") + public UtilityModel(final int stateCount, final int actionCount, final double initialU) { + this.stateCount = stateCount; + this.actionCount = actionCount; + this.U = new Vec(stateCount); + this.U.setAll(initialU); } - public double getU(int stateId){ - return U.get(stateId); + private UtilityModel() { } public int getStateCount() { - return stateCount; + return this.stateCount; } - public int getActionCount() { - return actionCount; - } - - public UtilityModel(int stateCount, int actionCount, double initialU){ + public void setStateCount(final int stateCount) { this.stateCount = stateCount; - this.actionCount = actionCount; - U = new Vec(stateCount); - U.setAll(initialU); } - public UtilityModel(int stateCount, int actionCount){ - this(stateCount, actionCount, 0.1); + public int getActionCount() { + return this.actionCount; } - public UtilityModel(){ - + public void setActionCount(final int actionCount) { + this.actionCount = actionCount; } - public void copy(UtilityModel rhs){ - U = rhs.U==null ? null : rhs.U.makeCopy(); - actionCount = rhs.actionCount; - stateCount = rhs.stateCount; + public void copy(final UtilityModel rhs) { + this.U = rhs.U == null ? null : rhs.U.makeCopy(); + this.actionCount = rhs.actionCount; + this.stateCount = rhs.stateCount; } - public UtilityModel makeCopy(){ - UtilityModel clone = new UtilityModel(); + public UtilityModel makeCopy() { + final UtilityModel clone = new UtilityModel(); clone.copy(this); return clone; } @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof UtilityModel){ - UtilityModel rhs2 = (UtilityModel)rhs; - if(actionCount != rhs2.actionCount || stateCount != rhs2.stateCount) return false; - - if((U==null && rhs2.U!=null) && (U!=null && rhs2.U ==null)) return false; - return !(U != null && !U.equals(rhs2.U)); + public boolean equals(final Object rhs) { + if (rhs instanceof UtilityModel) { + final UtilityModel rhs2 = (UtilityModel) rhs; + return this.actionCount == rhs2.actionCount && + this.stateCount == rhs2.stateCount && + !(this.U != null && !this.U.equals(rhs2.U)); } return false; } - public void reset(double initialU){ - U.setAll(initialU); + public void reset(final double initialU) { + assert this.U != null; + this.U.setAll(initialU); } } diff --git a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java index e840bc1..946bd56 100644 --- a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java +++ b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java @@ -3,12 +3,13 @@ /** * Created by xschen on 10/11/2015 0011. */ -public class DoubleUtils { - public static boolean equals(double a1, double a2){ - return Math.abs(a1-a2) < 1e-10; - } +public enum DoubleUtils { + ; + + public static final double TOLERANCE = 0.0000000001; - public static boolean isZero(double a){ - return a < 1e-20; + public static boolean equals(final double a1, final double a2) { + return Math.abs(a1 - a2) < DoubleUtils.TOLERANCE; } + } diff --git a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java index 66c2bf6..6c3d6ae 100644 --- a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java +++ b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java @@ -1,46 +1,55 @@ package com.github.chen0040.rl.utils; -import lombok.Getter; -import lombok.Setter; - - /** * Created by xschen on 6/5/2017. */ -@Getter -@Setter public class IndexValue { - private int index; - private double value; - - public IndexValue(){ - - } - - public IndexValue(int index, double value){ - this.index = index; - this.value = value; - } - - public IndexValue makeCopy(){ - IndexValue clone = new IndexValue(); - clone.setValue(value); - clone.setIndex(index); - return clone; - } - - @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof IndexValue){ - IndexValue rhs2 = (IndexValue)rhs; - return index == rhs2.index && value == rhs2.value; - } - return false; - } - - public boolean isValid(){ - return index != -1; - } - + private int index; + private double value; + + public IndexValue() { + + } + + public IndexValue(final int index, final double value) { + this.index = index; + this.value = value; + } + + public IndexValue makeCopy() { + final IndexValue clone = new IndexValue(); + clone.setValue(this.value); + clone.setIndex(this.index); + return clone; + } + + @Override + public boolean equals(final Object rhs) { + if (rhs instanceof IndexValue) { + final IndexValue rhs2 = (IndexValue) rhs; + return this.index == rhs2.index && this.value == rhs2.value; + } + return false; + } + + boolean isValid() { + return this.index != -1; + } + + public int getIndex() { + return this.index; + } + + public void setIndex(final int index) { + this.index = index; + } + + public double getValue() { + return this.value; + } + + public void setValue(final double value) { + this.value = value; + } } diff --git a/src/main/java/com/github/chen0040/rl/utils/Matrix.java b/src/main/java/com/github/chen0040/rl/utils/Matrix.java index cd42bd5..b20b86f 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Matrix.java +++ b/src/main/java/com/github/chen0040/rl/utils/Matrix.java @@ -1,83 +1,41 @@ package com.github.chen0040.rl.utils; -import com.alibaba.fastjson.annotation.JSONField; -import lombok.Getter; -import lombok.Setter; - import java.io.Serializable; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; /** * Created by xschen on 9/27/2015 0027. */ -@Getter -@Setter public class Matrix implements Serializable { - private Map rows = new HashMap<>(); + private final Map rows = new HashMap<>(); private int rowCount; private int columnCount; private double defaultValue; - public Matrix(){ - - } - - public Matrix(double[][] A){ - for(int i = 0; i < A.length; ++i){ - double[] B = A[i]; - for(int j=0; j < B.length; ++j){ - set(i, j, B[j]); - } - } - } - - public void setRow(int rowIndex, Vec rowVector){ - rowVector.setId(rowIndex); - rows.put(rowIndex, rowVector); - } - - - public static Matrix identity(int dimension){ - Matrix m = new Matrix(dimension, dimension); - for(int i=0; i < m.getRowCount(); ++i){ - m.set(i, i, 1); - } - return m; + public Matrix(final int rowCount, final int columnCount) { + this.rowCount = rowCount; + this.columnCount = columnCount; + this.defaultValue = 0; } @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof Matrix){ - Matrix rhs2 = (Matrix)rhs; - if(rowCount != rhs2.rowCount || columnCount != rhs2.columnCount){ + public boolean equals(final Object rhs) { + if (rhs instanceof Matrix) { + final Matrix rhs2 = (Matrix) rhs; + if (this.rowCount != rhs2.rowCount || this.columnCount != rhs2.columnCount) { return false; } - if(defaultValue == rhs2.defaultValue) { - for (Integer index : rows.keySet()) { - if (!rhs2.rows.containsKey(index)) return false; - if (!rows.get(index).equals(rhs2.rows.get(index))) { - System.out.println("failed!"); - return false; - } - } - - for (Integer index : rhs2.rows.keySet()) { - if (!rows.containsKey(index)) return false; - if (!rhs2.rows.get(index).equals(rows.get(index))) { - System.out.println("failed! 22"); - return false; - } - } + if (this.defaultValue == rhs2.defaultValue) { + return this.rows.keySet().stream().noneMatch(index -> !rhs2.rows.containsKey(index) || !this.rows.get(index).equals(rhs2.rows.get(index))) && + rhs2.rows.keySet().stream().noneMatch(index -> !this.rows.containsKey(index) || !rhs2.rows.get(index).equals(this.rows.get(index))); } else { - for(int i=0; i < rowCount; ++i) { - for(int j=0; j < columnCount; ++j) { - if(this.get(i, j) != rhs2.get(i, j)){ + for (int i = 0; i < this.rowCount; ++i) { + for (int j = 0; j < this.columnCount; ++j) { + if (this.get(i, j) != rhs2.get(i, j)) { return false; } } @@ -90,154 +48,63 @@ public boolean equals(Object rhs){ return false; } - public Matrix makeCopy(){ - Matrix clone = new Matrix(rowCount, columnCount); + public Matrix makeCopy() { + final Matrix clone = new Matrix(this.rowCount, this.columnCount); clone.copy(this); return clone; } - public void copy(Matrix rhs){ - rowCount = rhs.rowCount; - columnCount = rhs.columnCount; - defaultValue = rhs.defaultValue; + private void copy(final Matrix rhs) { + this.rowCount = rhs.rowCount; + this.columnCount = rhs.columnCount; + this.defaultValue = rhs.defaultValue; - rows.clear(); + this.rows.clear(); - for(Map.Entry entry : rhs.rows.entrySet()){ - rows.put(entry.getKey(), entry.getValue().makeCopy()); - } + rhs.rows.forEach((key, value) -> this.rows.put(key, value.makeCopy())); } - - - public void set(int rowIndex, int columnIndex, double value){ - Vec row = rowAt(rowIndex); + public void set(final int rowIndex, final int columnIndex, final double value) { + final Vec row = this.rowAt(rowIndex); row.set(columnIndex, value); - if(rowIndex >= rowCount) { rowCount = rowIndex+1; } - if(columnIndex >= columnCount) { columnCount = columnIndex + 1; } - } - - - - public Matrix(int rowCount, int columnCount){ - this.rowCount = rowCount; - this.columnCount = columnCount; - this.defaultValue = 0; + if (rowIndex >= this.rowCount) { + this.rowCount = rowIndex + 1; + } + if (columnIndex >= this.columnCount) { + this.columnCount = columnIndex + 1; + } } - public Vec rowAt(int rowIndex){ - Vec row = rows.get(rowIndex); - if(row == null){ - row = new Vec(columnCount); - row.setAll(defaultValue); + public Vec rowAt(final int rowIndex) { + Vec row = this.rows.get(rowIndex); + if (row == null) { + row = new Vec(this.columnCount); + row.setAll(this.defaultValue); row.setId(rowIndex); - rows.put(rowIndex, row); + this.rows.put(rowIndex, row); } return row; } - public void setAll(double value){ - defaultValue = value; - for(Vec row : rows.values()){ + public void setAll(final double value) { + this.defaultValue = value; + for (final Vec row : this.rows.values()) { row.setAll(value); } } - public double get(int rowIndex, int columnIndex) { - Vec row= rowAt(rowIndex); + public double get(final int rowIndex, final int columnIndex) { + final Vec row = this.rowAt(rowIndex); return row.get(columnIndex); } - public List columnVectors() - { - Matrix A = this; - int n = A.getColumnCount(); - int rowCount = A.getRowCount(); - - List Acols = new ArrayList(); - - for (int c = 0; c < n; ++c) - { - Vec Acol = new Vec(rowCount); - Acol.setAll(defaultValue); - Acol.setId(c); - - for (int r = 0; r < rowCount; ++r) - { - Acol.set(r, A.get(r, c)); - } - Acols.add(Acol); - } - return Acols; - } - - public Matrix multiply(Matrix rhs) - { - if(this.getColumnCount() != rhs.getRowCount()){ - System.err.println("A.columnCount must be equal to B.rowCount in multiplication"); - return null; - } - - Vec row1; - Vec col2; - - Matrix result = new Matrix(getRowCount(), rhs.getColumnCount()); - result.setAll(defaultValue); - - List rhsColumns = rhs.columnVectors(); - - for (Map.Entry entry : rows.entrySet()) - { - int r1 = entry.getKey(); - row1 = entry.getValue(); - for (int c2 = 0; c2 < rhsColumns.size(); ++c2) - { - col2 = rhsColumns.get(c2); - result.set(r1, c2, row1.multiply(col2)); - } - } - - return result; - } - - @JSONField(serialize = false) - public boolean isSymmetric(){ - if (getRowCount() != getColumnCount()) return false; - for (Map.Entry rowEntry : rows.entrySet()) - { - int row = rowEntry.getKey(); - Vec rowVec = rowEntry.getValue(); - - for (Integer col : rowVec.getData().keySet()) - { - if (row == col.intValue()) continue; - if(DoubleUtils.equals(rowVec.get(col), this.get(col, row))){ - return false; - } - } - } - - return true; + public int getRowCount() { + return this.rowCount; } - public Vec multiply(Vec rhs) - { - if(this.getColumnCount() != rhs.getDimension()){ - System.err.println("columnCount must be equal to the size of the vector for multiplication"); - } - - Vec row1; - Vec result = new Vec(getRowCount()); - for (Map.Entry entry : rows.entrySet()) - { - row1 = entry.getValue(); - result.set(entry.getKey(), row1.multiply(rhs)); - } - return result; + public int getColumnCount() { + return this.columnCount; } - - - } diff --git a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java b/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java deleted file mode 100644 index e43c28b..0000000 --- a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java +++ /dev/null @@ -1,29 +0,0 @@ -package com.github.chen0040.rl.utils; - -import java.util.List; - - -/** - * Created by xschen on 10/11/2015 0011. - */ -public class MatrixUtils { - /** - * Convert a list of column vectors into a matrix - */ - public static Matrix matrixFromColumnVectors(List R) - { - int n = R.size(); - int m = R.get(0).getDimension(); - - Matrix T = new Matrix(m, n); - for (int c = 0; c < n; ++c) - { - Vec Rcol = R.get(c); - for (int r : Rcol.getData().keySet()) - { - T.set(r, c, Rcol.get(r)); - } - } - return T; - } -} diff --git a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java b/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java deleted file mode 100644 index b4895ea..0000000 --- a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java +++ /dev/null @@ -1,56 +0,0 @@ -package com.github.chen0040.rl.utils; - -/** - * Created by xschen on 10/11/2015 0011. - */ -public class TupleTwo { - private T1 item1; - private T2 item2; - - public TupleTwo(T1 item1, T2 item2){ - this.item1 = item1; - this.item2 = item2; - } - - public T1 getItem1() { - return item1; - } - - public void setItem1(T1 item1) { - this.item1 = item1; - } - - public T2 getItem2() { - return item2; - } - - public void setItem2(T2 item2) { - this.item2 = item2; - } - - public static TupleTwo create(U1 item1, U2 item2){ - return new TupleTwo(item1, item2); - } - - - @Override public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - TupleTwo tupleTwo = (TupleTwo) o; - - if (item1 != null ? !item1.equals(tupleTwo.item1) : tupleTwo.item1 != null) - return false; - return item2 != null ? item2.equals(tupleTwo.item2) : tupleTwo.item2 == null; - - } - - - @Override public int hashCode() { - int result = item1 != null ? item1.hashCode() : 0; - result = 31 * result + (item2 != null ? item2.hashCode() : 0); - return result; - } -} diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java index 4699d0e..ca007cf 100644 --- a/src/main/java/com/github/chen0040/rl/utils/Vec.java +++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java @@ -1,131 +1,108 @@ package com.github.chen0040.rl.utils; -import lombok.Getter; -import lombok.Setter; - import java.io.Serializable; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.IntStream; /** * Created by xschen on 9/27/2015 0027. */ -@Getter -@Setter public class Vec implements Serializable { - private Map data = new HashMap(); + private final Map data = new HashMap<>(); private int dimension; private double defaultValue; private int id = -1; - public Vec(){ + public Vec() { } - public Vec(double[] v){ - for(int i=0; i < v.length; ++i){ - set(i, v[i]); - } + public Vec(final double[] v) { + IntStream.range(0, v.length).forEach(i -> this.set(i, v[i])); } - public Vec(int dimension){ + public Vec(final int dimension) { this.dimension = dimension; - defaultValue = 0; + this.defaultValue = 0; } - public Vec(int dimension, Map data){ + public Vec(final int dimension, final Map data) { this.dimension = dimension; - defaultValue = 0; + this.defaultValue = 0; - for(Map.Entry entry : data.entrySet()){ - set(entry.getKey(), entry.getValue()); - } + data.forEach(this::set); } - public Vec makeCopy(){ - Vec clone = new Vec(dimension); + public Vec makeCopy() { + final Vec clone = new Vec(this.dimension); clone.copy(this); return clone; } - public void copy(Vec rhs){ - defaultValue = rhs.defaultValue; - dimension = rhs.dimension; - id = rhs.id; + public void copy(final Vec rhs) { + this.defaultValue = rhs.defaultValue; + this.dimension = rhs.dimension; + this.id = rhs.id; - data.clear(); - for(Map.Entry entry : rhs.data.entrySet()){ - data.put(entry.getKey(), entry.getValue()); - } + this.data.clear(); + rhs.data.forEach(this.data::put); } - public void set(int i, double value){ - if(value == defaultValue) return; + public void set(final int i, final double value) { + if (value == this.defaultValue) { + return; + } - data.put(i, value); - if(i >= dimension){ - dimension = i+1; + this.data.put(i, value); + if (i >= this.dimension) { + this.dimension = i + 1; } } - public double get(int i){ - return data.getOrDefault(i, defaultValue); + public double get(final int i) { + return this.data.getOrDefault(i, this.defaultValue); } @Override - public boolean equals(Object rhs){ - if(rhs != null && rhs instanceof Vec){ - Vec rhs2 = (Vec)rhs; - if(dimension != rhs2.dimension){ + public boolean equals(final Object rhs) { + if (rhs instanceof Vec) { + final Vec rhs2 = (Vec) rhs; + if (this.dimension != rhs2.dimension || this.data.size() != rhs2.data.size()) { return false; } - - if(data.size() != rhs2.data.size()){ - return false; - } - - for(Integer index : data.keySet()){ - if(!rhs2.data.containsKey(index)) return false; - if(!DoubleUtils.equals(data.get(index), rhs2.data.get(index))){ + for (final Integer index : this.data.keySet()) { + if (!rhs2.data.containsKey(index) || !DoubleUtils.equals(this.data.get(index), rhs2.data.get(index))) { return false; } } - - if(defaultValue != rhs2.defaultValue){ - for(int i=0; i < dimension; ++i){ - if(data.containsKey(i)){ - return false; - } - } + if (this.defaultValue != rhs2.defaultValue) { + return IntStream.range(0, this.dimension).noneMatch(this.data::containsKey); } - return true; } return false; } - public void setAll(double value){ - defaultValue = value; - for(Integer index : data.keySet()){ - data.put(index, defaultValue); - } + public void setAll(final double value) { + this.defaultValue = value; + this.data.keySet().forEach(index -> this.data.put(index, this.defaultValue)); } - public IndexValue indexWithMaxValue(Set indices){ - if(indices == null){ - return indexWithMaxValue(); - }else{ - IndexValue iv = new IndexValue(); + public IndexValue indexWithMaxValue(final Set indices) { + if (indices == null) { + return this.indexWithMaxValue(); + } else { + final IndexValue iv = new IndexValue(); iv.setIndex(-1); iv.setValue(Double.NEGATIVE_INFINITY); - for(Integer index : indices){ - double value = data.getOrDefault(index, Double.NEGATIVE_INFINITY); - if(value > iv.getValue()){ + for (final Integer index : indices) { + final double value = this.data.getOrDefault(index, Double.NEGATIVE_INFINITY); + if (value > iv.getValue()) { iv.setIndex(index); iv.setValue(value); } @@ -134,29 +111,31 @@ public IndexValue indexWithMaxValue(Set indices){ } } - public IndexValue indexWithMaxValue(){ - IndexValue iv = new IndexValue(); + private IndexValue indexWithMaxValue() { + final IndexValue iv = new IndexValue(); iv.setIndex(-1); iv.setValue(Double.NEGATIVE_INFINITY); - for(Map.Entry entry : data.entrySet()){ - if(entry.getKey() >= dimension) continue; + for (final Map.Entry entry : this.data.entrySet()) { + if (entry.getKey() >= this.dimension) { + continue; + } - double value = entry.getValue(); - if(value > iv.getValue()){ + final double value = entry.getValue(); + if (value > iv.getValue()) { iv.setValue(value); iv.setIndex(entry.getKey()); } } - if(!iv.isValid()){ - iv.setValue(defaultValue); - } else{ - if(iv.getValue() < defaultValue){ - for(int i=0; i < dimension; ++i){ - if(!data.containsKey(i)){ - iv.setValue(defaultValue); + if (!iv.isValid()) { + iv.setValue(this.defaultValue); + } else { + if (iv.getValue() < this.defaultValue) { + for (int i = 0; i < this.dimension; ++i) { + if (!this.data.containsKey(i)) { + iv.setValue(this.defaultValue); iv.setIndex(i); break; } @@ -168,182 +147,7 @@ public IndexValue indexWithMaxValue(){ } - - public Vec projectOrthogonal(Iterable vlist) { - Vec b = this; - for(Vec v : vlist) - { - b = b.minus(b.projectAlong(v)); - } - - return b; - } - - public Vec projectOrthogonal(List vlist, Map alpha) { - Vec b = this; - for(int i = 0; i < vlist.size(); ++i) - { - Vec v = vlist.get(i); - double norm_a = v.multiply(v); - - if (DoubleUtils.isZero(norm_a)) { - return new Vec(dimension); - } - double sigma = multiply(v) / norm_a; - Vec v_parallel = v.multiply(sigma); - - alpha.put(i, sigma); - - b = b.minus(v_parallel); - } - - return b; - } - - public Vec projectAlong(Vec rhs) - { - double norm_a = rhs.multiply(rhs); - - if (DoubleUtils.isZero(norm_a)) { - return new Vec(dimension); - } - double sigma = multiply(rhs) / norm_a; - return rhs.multiply(sigma); - } - - public Vec multiply(double rhs){ - Vec clone = (Vec)this.makeCopy(); - for(Integer i : data.keySet()){ - clone.data.put(i, rhs * data.get(i)); - } - return clone; - } - - public double multiply(Vec rhs) - { - double productSum = 0; - if(defaultValue == 0) { - for (Map.Entry entry : data.entrySet()) { - productSum += entry.getValue() * rhs.get(entry.getKey()); - } - } else { - for(int i=0; i < dimension; ++i){ - productSum += get(i) * rhs.get(i); - } - } - - return productSum; - } - - public Vec pow(double scalar) - { - Vec result = new Vec(dimension); - for (Map.Entry entry : data.entrySet()) - { - result.data.put(entry.getKey(), Math.pow(entry.getValue(), scalar)); - } - return result; - } - - public Vec add(Vec rhs) - { - Vec result = new Vec(dimension); - int index; - for (Map.Entry entry : data.entrySet()) { - index = entry.getKey(); - result.data.put(index, entry.getValue() + rhs.data.get(index)); - } - for(Map.Entry entry : rhs.data.entrySet()){ - index = entry.getKey(); - if(result.data.containsKey(index)) continue; - result.data.put(index, entry.getValue() + data.get(index)); - } - - return result; - } - - public Vec minus(Vec rhs) - { - Vec result = new Vec(dimension); - int index; - for (Map.Entry entry : data.entrySet()) { - index = entry.getKey(); - result.data.put(index, entry.getValue() - rhs.data.get(index)); - } - for(Map.Entry entry : rhs.data.entrySet()){ - index = entry.getKey(); - if(result.data.containsKey(index)) continue; - result.data.put(index, data.get(index) - entry.getValue()); - } - - return result; - } - - public double sum(){ - double sum = 0; - - for(Map.Entry entry : data.entrySet()){ - sum += entry.getValue(); - } - sum += defaultValue * (dimension - data.size()); - - return sum; - } - - public boolean isZero(){ - return DoubleUtils.isZero(sum()); - } - - public double norm(int level) - { - if (level == 1) - { - double sum = 0; - for (Double val : data.values()) - { - sum += Math.abs(val); - } - if(!DoubleUtils.isZero(defaultValue)) { - sum += Math.abs(defaultValue) * (dimension - data.size()); - } - return sum; - } - else if (level == 2) - { - double sum = multiply(this); - if(!DoubleUtils.isZero(defaultValue)){ - sum += (dimension - data.size()) * (defaultValue * defaultValue); - } - return Math.sqrt(sum); - } - else - { - double sum = 0; - for (Double val : this.data.values()) - { - sum += Math.pow(Math.abs(val), level); - } - if(!DoubleUtils.isZero(defaultValue)) { - sum += Math.pow(Math.abs(defaultValue), level) * (dimension - data.size()); - } - return Math.pow(sum, 1.0 / level); - } - } - - public Vec normalize() - { - double norm = norm(2); // L2 norm is the cartesian distance - if (DoubleUtils.isZero(norm)) - { - return new Vec(dimension); - } - Vec clone = new Vec(dimension); - clone.setAll(defaultValue / norm); - - for (Integer k : data.keySet()) - { - clone.data.put(k, data.get(k) / norm); - } - return clone; + void setId(final int id) { + this.id = id; } } diff --git a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java b/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java deleted file mode 100644 index 2bbfbaa..0000000 --- a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java +++ /dev/null @@ -1,39 +0,0 @@ -package com.github.chen0040.rl.utils; - -import java.util.ArrayList; -import java.util.List; - - -/** - * Created by xschen on 10/11/2015 0011. - */ -public class VectorUtils { - public static List removeZeroVectors(Iterable vlist) - { - List vstarlist = new ArrayList(); - for (Vec v : vlist) - { - if (!v.isZero()) - { - vstarlist.add(v); - } - } - - return vstarlist; - } - - public static TupleTwo, List> normalize(Iterable vlist) - { - List norms = new ArrayList(); - List vstarlist = new ArrayList(); - for (Vec v : vlist) - { - norms.add(v.norm(2)); - vstarlist.add(v.normalize()); - } - - return TupleTwo.create(vstarlist, norms); - } - - -}