From 4703cc6de56296e90aeedfdc56851db465cc0800 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 14 Mar 2018 18:50:14 +0800 Subject: [PATCH 1/2] init pr --- build.sbt | 2 +- .../feature/OneHotEncoderEstimator.scala | 34 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala diff --git a/build.sbt b/build.sbt index 825dd034..cdc7a2bd 100644 --- a/build.sbt +++ b/build.sbt @@ -14,7 +14,7 @@ sparkPackageName := "databricks/spark-sql-perf" // All Spark Packages need a license licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")) -sparkVersion := "2.2.0" +sparkVersion := "2.3.0" sparkComponents ++= Seq("sql", "hive", "mllib") diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala new file mode 100644 index 00000000..e94ccdf2 --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala @@ -0,0 +1,34 @@ +package com.databricks.spark.sql.perf.mllib.feature + +import org.apache.spark.ml +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.PipelineStage +import org.apache.spark.sql._ + +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ +import com.databricks.spark.sql.perf.mllib.data.DataGenerator +import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining} + +/** Object for testing OneHotEncoder performance */ +object OneHotEncoderEstimator extends BenchmarkAlgorithm with TestFromTraining with UnaryTransformer { + + override def trainingDataSet(ctx: MLBenchContext): DataFrame = { + import ctx.params._ + import ctx.sqlContext.implicits._ + + DataGenerator.generateMixedFeatures( + ctx.sqlContext, + numExamples, + ctx.seed(), + numPartitions, + Array.fill(1)(featureArity.get) + ).rdd.map { case Row(vec: Vector) => + vec(0) // extract the single generated double value for each row + }.toDF(inputCol) + } + + override def getPipelineStage(ctx: MLBenchContext): PipelineStage = { + new ml.feature.OneHotEncoderEstimator() + .setInputCols(Array(inputCol)) + } +} From 77be6dc9e9d4f453a413b5fbc0403c6e62d20c63 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 15 Mar 2018 10:15:43 +0800 Subject: [PATCH 2/2] minor update --- .../spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala index e94ccdf2..312b5ab7 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala @@ -9,7 +9,7 @@ import com.databricks.spark.sql.perf.mllib.OptionImplicits._ import com.databricks.spark.sql.perf.mllib.data.DataGenerator import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining} -/** Object for testing OneHotEncoder performance */ +/** Object for testing OneHotEncoderEstimator performance */ object OneHotEncoderEstimator extends BenchmarkAlgorithm with TestFromTraining with UnaryTransformer { override def trainingDataSet(ctx: MLBenchContext): DataFrame = {