
Commit b2d36f6

imatiach-msft authored and srowen committed
[SPARK-19591][ML][MLLIB] Add sample weights to decision trees
This is PR apache#16722, updated to the latest master.

## What changes were proposed in this pull request?

This patch adds support for sample weights to DecisionTreeRegressor and DecisionTreeClassifier.

Note: this patch does not add support for sample weights to RandomForest. As discussed in the JIRA, we would like to add sample weights into the bagging process. This patch is large enough as is, and there are some additional considerations to be made for random forests. Since the machinery introduced here needs to be present regardless, I have opted to leave random forests for a follow-up PR.

## How was this patch tested?

The algorithms are tested to ensure that:

1. Arbitrary scaling of constant weights has no effect.
2. Outliers with small weights do not affect the learned model.
3. Oversampling and weighting are equivalent.

Unit tests are also added for the smaller components.

## Summary of changes

- Impurity aggregators now store weighted sufficient statistics. They also store the raw (unweighted) count, since it is needed to enforce minInstancesPerNode.
- This patch maintains the meaning of minInstancesPerNode: the parameter still corresponds to raw, unweighted counts. It also adds a new parameter, minWeightFractionPerNode, which requires that each node contain at least minWeightFractionPerNode * weightedNumExamples total weight.
- findSplitsForContinuousFeatures is modified to use weighted sums. Unit tests are added.
- TreePoint is modified to hold a sample weight.
- BaggedPoint is modified from:

  ```scala
  private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
    extends Serializable
  ```

  to:

  ```scala
  private[spark] class BaggedPoint[Datum](
      val datum: Datum,
      val subsampleCounts: Array[Int],
      val sampleWeight: Double) extends Serializable
  ```

  We do not simply multiply the counts by the weight and store the product, because we need both the raw counts and the weight in order to enforce both minInstancesPerNode and minWeightFractionPerNode.

**Note**: many of the changed files differ simply because Instance is used in place of LabeledPoint.

Closes apache#21632 from imatiach-msft/ilmat/sample-weights.

Authored-by: Ilya Matiach <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
1 parent 3699763 commit b2d36f6

31 files changed: +743 -280 lines
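The interplay between the two node-size parameters is the heart of the patch: minInstancesPerNode keeps its raw-count meaning, while minWeightFractionPerNode bounds the weighted mass that each node must hold. Before the per-file diffs, here is a minimal, self-contained sketch of that dual check; all names are illustrative, not the patch's actual ImpurityAggregator or split-validation code:

```scala
// Sketch only: why each node tracks BOTH a raw count and a weighted count.
// minInstancesPerNode is checked against the raw count, while the new
// minWeightFractionPerNode is checked against the weighted count.
object MinWeightSketch {
  final case class NodeStats(rawCount: Long, weightedCount: Double)

  def splitIsValid(
      left: NodeStats,
      right: NodeStats,
      minInstancesPerNode: Int,
      minWeightFractionPerNode: Double,
      weightedNumExamples: Double): Boolean = {
    // The fraction parameter is turned into an absolute weight threshold.
    val minWeightPerNode = minWeightFractionPerNode * weightedNumExamples
    left.rawCount >= minInstancesPerNode && right.rawCount >= minInstancesPerNode &&
      left.weightedCount >= minWeightPerNode && right.weightedCount >= minWeightPerNode
  }

  def main(args: Array[String]): Unit = {
    val totalWeight = 100.0
    // Five low-weight outliers pass the raw-count check but fail the weight
    // check, so a split isolating them is rejected.
    val outliers = NodeStats(rawCount = 5, weightedCount = 0.01)
    val rest = NodeStats(rawCount = 95, weightedCount = 99.99)
    println(splitIsValid(outliers, rest, 1, 0.05, totalWeight)) // false
    println(splitIsValid(outliers, rest, 1, 0.0, totalWeight))  // true
  }
}
```

This is also why BaggedPoint keeps subsampleCounts and sampleWeight as separate fields rather than premultiplying them: both quantities are needed at split time.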

mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala

+1 -1

@@ -52,7 +52,7 @@ object TestingUtils {
   /**
    * Private helper function for comparing two values using absolute tolerance.
    */
-  private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+  private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
     // Special case for NaNs
     if (x.isNaN && y.isNaN) {
       return true
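The only change here widens visibility from private to private[ml], so ML test suites elsewhere in the package can reuse the comparator. As a standalone sketch of its semantics: the hunk above shows only the NaN special case, so the absolute-difference branch below is an assumption about the unshown body:

```scala
// Standalone sketch; the abs-difference branch is assumed, not shown in the hunk.
object AbsTolSketch {
  def absoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
    if (x.isNaN && y.isNaN) true // two NaNs compare as "equal" in tests
    else math.abs(x - y) < eps   // otherwise: plain absolute tolerance
  }

  def main(args: Array[String]): Unit = {
    assert(absoluteErrorComparison(1.0 + 1e-8, 1.0, 1e-5))        // within eps
    assert(!absoluteErrorComparison(1.01, 1.0, 1e-5))             // outside eps
    assert(absoluteErrorComparison(Double.NaN, Double.NaN, 1e-5)) // NaN case
  }
}
```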

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

+25 -5

@@ -77,17 +77,37 @@ abstract class Classifier[
    * @note Throws `SparkException` if any label is a non-integer or is negative
    */
   protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
-    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
-      s" $numClasses, but requires numClasses > 0.")
+    validateNumClasses(numClasses)
     dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) =>
-        require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
-          s" dataset with invalid label $label. Labels must be integers in range" +
-          s" [0, $numClasses).")
+        validateLabel(label, numClasses)
         LabeledPoint(label, features)
     }
   }
 
+  /**
+   * Validates that number of classes is greater than zero.
+   *
+   * @param numClasses Number of classes label can take.
+   */
+  protected def validateNumClasses(numClasses: Int): Unit = {
+    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
+      s" $numClasses, but requires numClasses > 0.")
+  }
+
+  /**
+   * Validates the label on the classifier is a valid integer in the range [0, numClasses).
+   *
+   * @param label The label to validate.
+   * @param numClasses Number of classes label can take. Labels must be integers in the range
+   *                   [0, numClasses).
+   */
+  protected def validateLabel(label: Double, numClasses: Int): Unit = {
+    require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
+      s" dataset with invalid label $label. Labels must be integers in range" +
+      s" [0, $numClasses).")
+  }
+
   /**
    * Get the number of classes. This looks in column metadata first, and if that is missing,
    * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
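The extracted validateLabel also swaps the integrality test from `label % 1 == 0` to `label.toLong == label`. A REPL-style sketch of the predicate (illustrative only, not Spark API):

```scala
// A label is valid iff it is an integer-valued Double in [0, numClasses).
def isValidLabel(label: Double, numClasses: Int): Boolean =
  label.toLong == label && label >= 0 && label < numClasses

isValidLabel(2.0, 3)  // true
isValidLabel(2.5, 3)  // false: not integer-valued
isValidLabel(3.0, 3)  // false: outside [0, numClasses)
isValidLabel(-1.0, 3) // false: negative
```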

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

+34 -12

@@ -22,19 +22,22 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
-
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -66,6 +69,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
   def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
 
   /** @group setParam */
+  @Since("3.0.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
   @Since("1.4.0")
   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
 
@@ -97,29 +103,44 @@ class DecisionTreeClassifier @Since("1.4.0") (
   @Since("1.6.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(
       dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
-    instr.logNumClasses(numClasses)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
         ".train() called with non-matching numClasses and thresholds.length." +
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
-
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+    validateNumClasses(numClasses)
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val instances =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          validateLabel(label, numClasses)
+          Instance(label, weight, features)
+      }
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
-
+    instr.logNumClasses(numClasses)
     instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
       probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
       cacheNodeIds, checkpointInterval, impurity, seed)
 
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+    val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
@@ -128,13 +149,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
   /** (private[ml]) Train a decision tree on an RDD */
   private[ml] def train(data: RDD[LabeledPoint],
       oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
+    val instances = data.map(_.toInstance)
     instr.logPipelineStage(this)
-    instr.logDataset(data)
+    instr.logDataset(instances)
     instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
       cacheNodeIds, checkpointInterval, impurity, seed)
-
-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 0L, instr = Some(instr), parentUID = Some(uid))
+    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+      featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))
 
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
@@ -180,6 +201,7 @@ class DecisionTreeClassificationModel private[ml] (
 
   /**
    * Construct a decision tree classification model.
+   *
    * @param rootNode Root node of tree, with other nodes attached.
   */
   private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
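A hedged usage sketch of the two setters this diff adds; the dataset and column names ("weight", trainingData) are hypothetical:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setWeightCol("weight")            // rows carry a Double sample weight
  .setMinWeightFractionPerNode(0.05) // each child must hold >= 5% of total weight

// val model = dt.fit(trainingData)  // trainingData: DataFrame with the columns above
```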

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

+6 -5

@@ -21,20 +21,21 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
 
 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
@@ -130,7 +131,7 @@ class RandomForestClassifier @Since("1.4.0") (
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
 
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+    val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
 
@@ -139,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") (
       minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
 
     val trees = RandomForest
-      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+      .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
       .map(_.asInstanceOf[DecisionTreeClassificationModel])
 
     val numFeatures = trees.head.numFeatures
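Note that extractLabeledPoints(dataset, numClasses).map(_.toInstance) assigns every instance the default weight of 1.0, so the random-forest classifier (and, further below, the regressor) keeps its pre-patch behavior: per the commit message, pushing real sample weights into the bagging process is deferred to a follow-up PR.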

mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala

+9

@@ -37,4 +37,13 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features:
   override def toString: String = {
     s"($label,$features)"
   }
+
+  private[spark] def toInstance(weight: Double): Instance = {
+    Instance(label, weight, features)
+  }
+
+  private[spark] def toInstance: Instance = {
+    Instance(label, 1.0, features)
+  }
+
 }
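A REPL-style sketch of the two new conversions; both are private[spark], so they are callable only from Spark's own packages and tests:

```scala
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors

val lp = LabeledPoint(1.0, Vectors.dense(0.5, -0.2))
val unit = lp.toInstance       // Instance(1.0, 1.0, [0.5,-0.2]): default weight 1.0
val heavy = lp.toInstance(3.0) // Instance(1.0, 3.0, [0.5,-0.2]): explicit weight
```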

mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala

+30 -7

@@ -23,9 +23,10 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
 import org.apache.spark.ml.tree.impl.RandomForest
@@ -34,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 
 /**
@@ -65,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
   def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
 
   /** @group setParam */
+  @Since("3.0.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
   @Since("1.4.0")
   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
 
@@ -100,18 +105,33 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
   @Since("2.0.0")
   def setVarianceCol(value: String): this.type = set(varianceCol, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(
       dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val instances =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          Instance(label, weight, features)
+      }
     val strategy = getOldStrategy(categoricalFeatures)
 
     instr.logPipelineStage(this)
-    instr.logDataset(oldDataset)
+    instr.logDataset(instances)
     instr.logParams(this, params: _*)
 
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+    val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     trees.head.asInstanceOf[DecisionTreeRegressionModel]
@@ -126,8 +146,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
     instr.logDataset(data)
     instr.logParams(this, params: _*)
 
-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
-      seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+    val instances = data.map(_.toInstance)
+    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+      featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     trees.head.asInstanceOf[DecisionTreeRegressionModel]
   }
@@ -155,6 +176,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
  * <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
  * Decision tree (Wikipedia)</a> model for regression.
  * It supports both continuous and categorical features.
+ *
  * @param rootNode Root of the decision tree
  */
 @Since("1.4.0")
@@ -173,6 +195,7 @@ class DecisionTreeRegressionModel private[ml] (
 
   /**
    * Construct a decision tree regression model.
+   *
    * @param rootNode Root node of tree, with other nodes attached.
    */
   private[ml] def this(rootNode: Node, numFeatures: Int) =
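A standalone sketch of the weight-extraction pattern used in train() above: when no weight column is configured, every row receives a literal 1.0. The column names and toy data here are illustrative only:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit}

object WeightColumnSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("weights").getOrCreate()
    import spark.implicits._

    val df = Seq((0.0, 2.0), (1.0, 0.5)).toDF("label", "weight")
    val weightColName: Option[String] = Some("weight") // None => unweighted
    val w = weightColName.map(col).getOrElse(lit(1.0)) // fall back to weight 1.0
    df.select(col("label"), w.as("weight")).show()
    spark.stop()
  }
}
```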

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

+5 -7

@@ -22,7 +22,6 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -32,10 +31,8 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
 
 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
@@ -119,18 +116,19 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
       dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+
+    val instances = extractLabeledPoints(dataset).map(_.toInstance)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
 
     instr.logPipelineStage(this)
-    instr.logDataset(dataset)
+    instr.logDataset(instances)
     instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
       featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
       minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)
 
     val trees = RandomForest
-      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+      .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
       .map(_.asInstanceOf[DecisionTreeRegressionModel])
 
     val numFeatures = trees.head.numFeatures
