Skip to content

Commit

Permalink
Fix type erasure warnings in files not scheduled to be removed by oth…
Browse files Browse the repository at this point in the history
…er commits
  • Loading branch information
Alex Shelkovnykov authored and ashelkovnykov committed May 9, 2020
1 parent 3d9c1e2 commit 58ab125
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,6 @@ class GameTransformer(val sc: SparkContext, implicit val logger: Logger) extends
randomEffectTypes: Set[REType],
featureShards: Set[FeatureShardId]): RDD[(UniqueSampleId, GameDatum)] = {

val parallelism = sc.getConf.get("spark.default.parallelism", s"${sc.getExecutorStorageStatus.length * 3}").toInt
val partitioner = new LongHashPartitioner(parallelism)
val idTagSet = randomEffectTypes ++
get(validationEvaluators).map(MultiEvaluatorType.getMultiEvaluatorIdTags).getOrElse(Seq())
val gameDataset = GameConverters
Expand All @@ -220,7 +218,6 @@ class GameTransformer(val sc: SparkContext, implicit val logger: Logger) extends
idTagSet,
isResponseRequired = false,
getOrDefault(inputColumnNames))
.partitionBy(partitioner)
.setName("Game dataset with UIDs for scoring")
.persist(StorageLevel.DISK_ONLY)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package com.linkedin.photon.ml.cli.game.scoring
import org.apache.commons.cli.MissingArgumentException
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators, Params}
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel

Expand Down Expand Up @@ -53,6 +53,11 @@ object GameScoringDriver extends GameDriver {
// Parameters
//

val scoringPartitions: Param[Int] = ParamUtils.createParam(
"scoring partitions",
"Number of partitions to use for the data being scored",
ParamValidators.gt[Int](0.0))

val modelId: Param[String] = ParamUtils.createParam(
"model id",
"ID to tag scores with.")
Expand Down Expand Up @@ -102,6 +107,7 @@ object GameScoringDriver extends GameDriver {
setDefault(overrideOutputDirectory, false)
setDefault(dataValidation, DataValidationType.VALIDATE_DISABLED)
setDefault(logDataAndModelStats, false)
setDefault(scoringPartitions, 1)
setDefault(spillScoresToDisk, false)
setDefault(logLevel, PhotonLogger.LogLevelInfo)
setDefault(applicationName, DEFAULT_APPLICATION_NAME)
Expand Down Expand Up @@ -206,8 +212,6 @@ object GameScoringDriver extends GameDriver {
featureShardIdToIndexMapLoaderMapOpt: Option[Map[FeatureShardId, IndexMapLoader]])
: (DataFrame, Map[FeatureShardId, IndexMapLoader]) = {

val parallelism = sc.getConf.get("spark.default.parallelism", s"${sc.getExecutorStorageStatus.length * 3}").toInt

// Handle date range input
val dateRangeOpt = IOUtils.resolveRange(get(inputDataDateRange), get(inputDataDaysRange), getOrDefault(timeZone))
val recordsPaths = pathsForDateRange(getRequiredParam(inputDataDirectories), dateRangeOpt)
Expand All @@ -218,7 +222,7 @@ object GameScoringDriver extends GameDriver {
recordsPaths.map(_.toString),
featureShardIdToIndexMapLoaderMapOpt,
getRequiredParam(featureShardConfigurations),
parallelism)
getOrDefault(scoringPartitions))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,28 +120,36 @@ object Utils {
* Fetch the java map from an Avro map field.
*
* @param record The Avro generic record
* @param key The field key
* @param field The field key
* @return A java map of String -> Object
*/
def getMapAvro(
record: GenericRecord,
key: String,
field: String,
isNullOK: Boolean = false): Map[String, JObject] = {

type T = java.util.Map[Any, JObject] // to avoid type erasure warning
record.get(key) match {
case map: T => map.asScala.map {
case (k, value) => (k.toString, value match {
// Need to convert Utf8 values to String here, because otherwise we get schema casting errors and misleading
// equivalence failures downstream.
case s@(_: Utf8 | _: JString) => s.toString
case x@(_: Number | _: JBoolean) => x
case _ => null
})
}.filter(_._2 != null).toMap

case obj: JObject => throw new IllegalArgumentException(s"$obj is not map type.")
case _ => if (isNullOK) null else throw new IllegalArgumentException(s"field $key is null")
val map = record.get(field).asInstanceOf[java.util.Map[Any, JObject]]

if (map == null && isNullOK) {
null
} else if (map == null) {
throw new IllegalArgumentException(s"field '$field' is null")
} else {
map
.asScala
.flatMap { case (key, value) =>

val keyString = key.toString

value match {
// Need to convert Utf8 values to String here, because otherwise we get schema casting errors and misleading
// equivalence failures downstream.
case s@(_: Utf8 | _: JString) => Some((keyString, s.toString))
case x@(_: Number | _: JBoolean) => Some((keyString, x))
case _ => None
}
}
.toMap
}
}

Expand Down Expand Up @@ -291,7 +299,7 @@ object Utils {
@throws(classOf[IllegalArgumentException])
def getKeyFromMapOrElse[T](map: Map[String, Any], key: String, elseBranch: Either[String, T]): T = {
map.get(key) match {
case Some(x: T) => x // type erasure warning here
case Some(x) => x.asInstanceOf[T]
case _ =>
elseBranch match {
case Left(errorMsg) => throw new IllegalArgumentException(errorMsg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class ModelDataScores(override val scoresRdd: RDD[(UniqueSampleId, ScoredGameDat
case (Some(thisScore), Some(thatScore)) => op(thisScore, thatScore)
case (Some(thisScore), None) => op(thisScore, thisScore.copy(score = MathConst.DEFAULT_SCORE))
case (None, Some(thatScore)) => op(thatScore.copy(score = MathConst.DEFAULT_SCORE), thatScore)
// Only included so that Scala doesn't throw a compiler warning. Obviously, this case can never happen
case (None, None) => throw new UnsupportedOperationException("No scores to merge")
}
})

Expand Down

0 comments on commit 58ab125

Please sign in to comment.