[jvm-packages] Fix partition related issue #9491

Open · wants to merge 2 commits into master
@@ -127,11 +127,9 @@ object PreXGBoost extends PreXGBoostProvider {
         val group = est match {
           case regressor: XGBoostRegressor =>
-            // get group column, if group is not defined, default to lit(-1)
-            Some(
-              if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
-                defaultGroupColumn
-              } else col(regressor.getGroupCol)
-            )
+            // get the group column; if group is not defined, default to None
+            if (!regressor.isDefined(regressor.groupCol) || regressor.getGroupCol.isEmpty) {
+              None
+            } else Some(col(regressor.getGroupCol))
           case _ => None

         }
@@ -144,7 +142,7 @@ object PreXGBoost extends PreXGBoostProvider {
           })

         (PackedParams(col(est.getLabelCol), col(featuresName), weight, baseMargin, group,
-          est.getNumWorkers, est.needDeterministicRepartitioning), evalSets, xgbInput)
+          est.getNumWorkers), evalSets, xgbInput)

       case _ => throw new RuntimeException("Unsupporting " + estimator)
     }
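
Note on the change above: the group column is now propagated only when a ranking groupCol is actually set; an unset or empty groupCol yields None instead of the old lit(-1) placeholder, so downstream code skips group handling entirely. A minimal sketch of the new resolution logic in isolation (the helper resolveGroup and its parameters are illustrative, not part of this PR):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col

// Sketch: resolve an optional group column the way the new branch does.
// `isGroupDefined` and `groupColName` stand in for the estimator's param accessors.
def resolveGroup(isGroupDefined: Boolean, groupColName: String): Option[Column] =
  if (!isGroupDefined || groupColName.isEmpty) None // no ranking group configured
  else Some(col(groupColName))
```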
@@ -28,11 +28,6 @@ private[scala] sealed trait XGBoostEstimatorCommon extends GeneralParams with Le
     with HasBaseMarginCol with HasLeafPredictionCol with HasContribPredictionCol with HasFeaturesCol
     with HasLabelCol with HasFeaturesCols with HasHandleInvalid {

-  def needDeterministicRepartitioning: Boolean = {
-    isDefined(checkpointPath) && getCheckpointPath != null && getCheckpointPath.nonEmpty &&
-      isDefined(checkpointInterval) && getCheckpointInterval > 0
-  }
-
   /**
    * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
    * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
@@ -72,33 +72,25 @@ object DataUtils extends Serializable {

   private def attachPartitionKey(
       row: Row,
-      deterministicPartition: Boolean,
       numWorkers: Int,
-      xgbLp: XGBLabeledPoint): (Int, XGBLabeledPoint) = {
-    if (deterministicPartition) {
-      (math.abs(row.hashCode() % numWorkers), xgbLp)
+      xgbLp: XGBLabeledPoint,
+      group: Option[Int]): (Int, XGBLabeledPoint) = {
+    // If a group exists, we must use it as the key so that all instances of the
+    // group land in the same partition.
+    if (group.isDefined) {
+      (group.get % numWorkers, xgbLp)
     } else {
-      (1, xgbLp)
+      // With no group, the row hash serves as the repartition key.
+      (math.abs(row.hashCode() % numWorkers), xgbLp)
     }
   }

   private def repartitionRDDs(
-      deterministicPartition: Boolean,
       numWorkers: Int,
       arrayOfRDDs: Array[RDD[(Int, XGBLabeledPoint)]]): Array[RDD[XGBLabeledPoint]] = {
-    if (deterministicPartition) {
       arrayOfRDDs.map {rdd => rdd.partitionBy(new HashPartitioner(numWorkers))}.map {
         rdd => rdd.map(_._2)
       }
-    } else {
-      arrayOfRDDs.map(rdd => {
-        if (rdd.getNumPartitions != numWorkers) {
-          rdd.map(_._2).repartition(numWorkers)
-        } else {
-          rdd.map(_._2)
-        }
-      })
-    }
   }

   /** Packed parameters used by [[convertDataFrameToXGBLabeledPointRDDs]] */
@@ -107,8 +99,7 @@ object DataUtils extends Serializable {
       weight: Column,
       baseMargin: Column,
       group: Option[Column],
-      numWorkers: Int,
-      deterministicPartition: Boolean)
+      numWorkers: Int)

   /**
    * convertDataFrameToXGBLabeledPointRDDs converts DataFrames to an array of RDD[XGBLabeledPoint]
@@ -122,8 +113,7 @@ object DataUtils extends Serializable {
       dataFrames: DataFrame*): Array[RDD[XGBLabeledPoint]] = {

     packedParams match {
-      case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers,
-        deterministicPartition) =>
+      case j @ PackedParams(labelCol, featuresCol, weight, baseMargin, group, numWorkers) =>
         val selectedColumns = group.map(groupCol => Seq(labelCol.cast(FloatType),
           featuresCol,
           weight.cast(FloatType),
@@ -141,18 +131,18 @@ object DataUtils extends Serializable {
               case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
             }
             val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
-            attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+            attachPartitionKey(row, numWorkers, xgbLp, Some(group))
           case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
             val (size, indices, values) = features match {
               case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
               case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
             }
             val xgbLp = XGBLabeledPoint(label, size, indices, values, weight,
               baseMargin = baseMargin)
-            attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
+            attachPartitionKey(row, numWorkers, xgbLp, None)
         }
       }
-      repartitionRDDs(deterministicPartition, numWorkers, arrayOfRDDs)
+      repartitionRDDs(numWorkers, arrayOfRDDs)

       case _ => throw new IllegalArgumentException("Wrong PackedParams") // never reach here
     }
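
Note on the two helpers above: the key produced by attachPartitionKey feeds the HashPartitioner in repartitionRDDs, which gives two properties: every instance of a query group lands in the same partition, and the placement of ungrouped rows is deterministic in the row contents. A small self-contained sketch of the collocation property (values and names are illustrative; it assumes non-negative group ids, matching the diff's group.get % numWorkers):

```scala
import org.apache.spark.HashPartitioner

// Sketch: with key = groupId % numWorkers, all rows of a group carry the same
// key, and HashPartitioner maps equal keys to the same partition index.
val numWorkers = 4
val partitioner = new HashPartitioner(numWorkers)

// (groupId, row) pairs for two query groups:
val rows = Seq((7, "a"), (7, "b"), (7, "c"), (8, "d"), (8, "e"))

// Group the rows by their computed partition: each group id resolves to
// exactly one partition index, so its rows are never split across workers.
val byPartition = rows.groupBy { case (g, _) => partitioner.getPartition(g % numWorkers) }
```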
@@ -25,26 +25,6 @@ import org.apache.spark.sql.functions._

 class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite with PerTest {

-  test("perform deterministic partitioning when checkpointInternal and" +
-    " checkpointPath is set (Classifier)") {
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-    val xgbClassifier = new XGBoostClassifier(paramMap)
-    assert(xgbClassifier.needDeterministicRepartitioning)
-  }
-
-  test("perform deterministic partitioning when checkpointInternal and" +
-    " checkpointPath is set (Regressor)") {
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-    val xgbRegressor = new XGBoostRegressor(paramMap)
-    assert(xgbRegressor.needDeterministicRepartitioning)
-  }
-
   test("deterministic partitioning takes effect with various parts of data") {
     val trainingDF = buildDataFrame(Classification.train)
     // the test idea is that, we apply a chain of repartitions over trainingDFs but they
@@ -62,8 +42,7 @@ class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite
         lit(1.0),
         lit(Float.NaN),
         None,
-        numWorkers,
-        deterministicPartition = true),
+        numWorkers),
       df
     ).head)
     val resultsMaps = transformedRDDs.map(rdd => rdd.mapPartitionsWithIndex {
@@ -97,8 +76,7 @@ class DeterministicPartitioningSuite extends AnyFunSuite with TmpFolderPerSuite
         lit(1.0),
         lit(Float.NaN),
         None,
-        10,
-        deterministicPartition = true), df
+        10), df
     ).head

     val partitionsSizes = dfRepartitioned
@@ -65,6 +65,9 @@ class FeatureSizeValidatingSuite extends AnyFunSuite with PerTest {
         (id, lp.label, lp.features)
       }.toDF("id", "label", "features")
     val xgb = new XGBoostClassifier(paramMap)
-    xgb.fit(repartitioned)
+    val exception = intercept[Exception] {
+      xgb.fit(repartitioned)
+    }
+    assert(exception.getMessage.contains("ml.dmlc.xgboost4j.java.XGBoostError"))
   }
 }
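
Note on the updated test: fit is now expected to fail when feature sizes disagree across partitions, with the native error wrapped in the thrown exception's message. A minimal sketch of the ScalaTest intercept-and-assert pattern used above (the thrown message here is fabricated for illustration):

```scala
import org.scalatest.funsuite.AnyFunSuite

class InterceptPatternExample extends AnyFunSuite {
  test("a failing fit-style call surfaces the wrapped native error") {
    // Stand-in for xgb.fit(repartitioned) failing inside the JVM wrapper.
    val exception = intercept[Exception] {
      throw new RuntimeException("caused by: ml.dmlc.xgboost4j.java.XGBoostError: ...")
    }
    assert(exception.getMessage.contains("ml.dmlc.xgboost4j.java.XGBoostError"))
  }
}
```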