@@ -19,6 +19,8 @@ package org.apache.spark.ml.clustering
19
19
20
20
import java .util .Locale
21
21
22
+ import breeze .linalg .normalize
23
+ import breeze .numerics .exp
22
24
import org .apache .hadoop .fs .Path
23
25
import org .json4s .DefaultFormats
24
26
import org .json4s .JsonAST .JObject
@@ -27,23 +29,23 @@ import org.json4s.jackson.JsonMethods._
27
29
import org .apache .spark .annotation .{DeveloperApi , Since }
28
30
import org .apache .spark .internal .Logging
29
31
import org .apache .spark .ml .{Estimator , Model }
30
- import org .apache .spark .ml .linalg .{ Matrix , Vector , Vectors , VectorUDT }
32
+ import org .apache .spark .ml .linalg ._
31
33
import org .apache .spark .ml .param ._
32
34
import org .apache .spark .ml .param .shared .{HasCheckpointInterval , HasFeaturesCol , HasMaxIter , HasSeed }
33
35
import org .apache .spark .ml .util ._
34
36
import org .apache .spark .ml .util .DefaultParamsReader .Metadata
35
37
import org .apache .spark .ml .util .Instrumentation .instrumented
36
38
import org .apache .spark .mllib .clustering .{DistributedLDAModel => OldDistributedLDAModel ,
37
39
EMLDAOptimizer => OldEMLDAOptimizer , LDA => OldLDA , LDAModel => OldLDAModel ,
38
- LDAOptimizer => OldLDAOptimizer , LocalLDAModel => OldLocalLDAModel ,
40
+ LDAOptimizer => OldLDAOptimizer , LDAUtils => OldLDAUtils , LocalLDAModel => OldLocalLDAModel ,
39
41
OnlineLDAOptimizer => OldOnlineLDAOptimizer }
40
42
import org .apache .spark .mllib .linalg .{Vector => OldVector , Vectors => OldVectors }
41
43
import org .apache .spark .mllib .linalg .MatrixImplicits ._
42
44
import org .apache .spark .mllib .linalg .VectorImplicits ._
43
45
import org .apache .spark .mllib .util .MLUtils
44
46
import org .apache .spark .rdd .RDD
45
47
import org .apache .spark .sql .{DataFrame , Dataset , Row , SparkSession }
46
- import org .apache .spark .sql .functions .{col , monotonically_increasing_id , udf }
48
+ import org .apache .spark .sql .functions .{monotonically_increasing_id , udf }
47
49
import org .apache .spark .sql .types .StructType
48
50
import org .apache .spark .storage .StorageLevel
49
51
import org .apache .spark .util .PeriodicCheckpointer
@@ -457,21 +459,56 @@ abstract class LDAModel private[ml] (
457
459
*/
458
460
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
  // Run schema validation first, with logging enabled.
  transformSchema(dataset.schema, logging = true)

  val outputCol = $(topicDistributionCol)
  if (outputCol.isEmpty) {
    // No output column configured: warn and return the input unchanged as a DataFrame.
    logWarning("LDAModel.transform was called without any output columns. Set an output column" +
      " such as topicDistributionCol to produce results.")
    dataset.toDF()
  } else {
    // Build the per-document inference function once, then wrap it as a UDF.
    val inferTopics: Vector => Vector = getTopicDistributionMethod
    val topicUdf = udf(inferTopics)
    val features = DatasetUtils.columnToVector(dataset, getFeaturesCol)
    dataset.withColumn(outputCol, topicUdf(features))
  }
}
474
476
477
/**
 * Builds the per-document function behind `topicDistributions()`, in a form usable as a UDF.
 *
 * Everything the returned function needs is copied into local values first, so the
 * function only refers to those locals.
 *
 * @return a function mapping a term-count vector to a vector of length `k`: all zeros when
 *         the input has no nonzero entries, otherwise the `gamma` returned by
 *         `variationalTopicInference`, normalized with the 1-norm.
 */
private def getTopicDistributionMethod: Vector => Vector = {
  val expElogbeta = exp(OldLDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t)
  val localModel = oldLocalModel
  val alpha = localModel.docConcentration.asBreeze
  val shape = localModel.gammaShape
  val numTopics = localModel.k
  val seed = localModel.seed

  (doc: Vector) => {
    if (doc.numNonzeros != 0) {
      // Split the document into (indices, values) for the inference routine.
      val (termIds, termCounts) = doc match {
        case sparse: SparseVector => (sparse.indices.toList, sparse.values)
        case dense: DenseVector => (List.range(0, dense.size), dense.values)
        case other =>
          throw new UnsupportedOperationException(
            s"Only sparse and dense vectors are supported but got ${other.getClass}.")
      }

      val (gamma, _, _) = OldOnlineLDAOptimizer.variationalTopicInference(
        termIds,
        termCounts,
        expElogbeta,
        alpha,
        shape,
        numTopics,
        seed)
      Vectors.dense(normalize(gamma, 1.0).toArray)
    } else {
      // Nothing to infer from: return a zero vector of length k.
      Vectors.zeros(numTopics)
    }
  }
}
511
+
475
512
@ Since (" 1.6.0" )
476
513
override def transformSchema (schema : StructType ): StructType = {
477
514
validateAndTransformSchema(schema)
0 commit comments