diff --git a/src/main/scala/cognite/spark/v1/AssetsRelation.scala b/src/main/scala/cognite/spark/v1/AssetsRelation.scala
index 70eafab5b..89bc378c1 100644
--- a/src/main/scala/cognite/spark/v1/AssetsRelation.scala
+++ b/src/main/scala/cognite/spark/v1/AssetsRelation.scala
@@ -26,7 +26,10 @@ class AssetsRelation(config: RelationConfig, subtreeIds: Option[List[CogniteId]]
   override def getStreams(sparkFilters: Array[Filter])(
       client: GenericClient[IO]): Seq[Stream[IO, AssetsReadSchema]] = {
     val (ids, filters) =
-      pushdownToFilters(sparkFilters, assetsFilterFromMap, AssetsFilter(assetSubtreeIds = subtreeIds))
+      pushdownToFilters(
+        sparkFilters,
+        f => assetsFilterFromMap(f.fieldValues),
+        AssetsFilter(assetSubtreeIds = subtreeIds))
     executeFilter(client.assets, filters, ids, config.partitions, config.limitPerPartition)
       .map(
         _.map(
diff --git a/src/main/scala/cognite/spark/v1/DataSetsRelation.scala b/src/main/scala/cognite/spark/v1/DataSetsRelation.scala
index 4c1f6d364..215e38caa 100644
--- a/src/main/scala/cognite/spark/v1/DataSetsRelation.scala
+++ b/src/main/scala/cognite/spark/v1/DataSetsRelation.scala
@@ -26,7 +26,8 @@ class DataSetsRelation(config: RelationConfig)(val sqlContext: SQLContext)
 
   override def getStreams(sparkFilters: Array[Filter])(
       client: GenericClient[IO]): Seq[fs2.Stream[IO, DataSet]] = {
-    val (ids, filters) = pushdownToFilters(sparkFilters, dataSetFilterFromMap, DataSetFilter())
+    val (ids, filters) =
+      pushdownToFilters(sparkFilters, f => dataSetFilterFromMap(f.fieldValues), DataSetFilter())
     Seq(executeFilterOnePartition(client.dataSets, filters, ids, config.limitPerPartition))
   }
 
diff --git a/src/main/scala/cognite/spark/v1/EventsRelation.scala b/src/main/scala/cognite/spark/v1/EventsRelation.scala
index cac5867a4..10a34043e 100644
--- a/src/main/scala/cognite/spark/v1/EventsRelation.scala
+++ b/src/main/scala/cognite/spark/v1/EventsRelation.scala
@@ -20,7 +20,8 @@ class EventsRelation(config: RelationConfig)(val sqlContext: SQLContext)
   import cognite.spark.compiletime.macros.StructTypeEncoderMacro._
 
   override def getStreams(sparkFilters: Array[Filter])(
       client: GenericClient[IO]): Seq[Stream[IO, Event]] = {
-    val (ids, filters) = pushdownToFilters(sparkFilters, eventsFilterFromMap, EventsFilter())
+    val (ids, filters) =
+      pushdownToFilters(sparkFilters, f => eventsFilterFromMap(f.fieldValues), EventsFilter())
     executeFilter(client.events, filters, ids, config.partitions, config.limitPerPartition)
   }
 
diff --git a/src/main/scala/cognite/spark/v1/FilesRelation.scala b/src/main/scala/cognite/spark/v1/FilesRelation.scala
index 737e9e213..c8d27bfbe 100644
--- a/src/main/scala/cognite/spark/v1/FilesRelation.scala
+++ b/src/main/scala/cognite/spark/v1/FilesRelation.scala
@@ -55,7 +55,8 @@ class FilesRelation(config: RelationConfig)(val sqlContext: SQLContext)
 
   override def getStreams(sparkFilters: Array[Filter])(
       client: GenericClient[IO]): Seq[Stream[IO, FilesReadSchema]] = {
-    val (ids, filters) = pushdownToFilters(sparkFilters, filesFilterFromMap, FilesFilter())
+    val (ids, filters) =
+      pushdownToFilters(sparkFilters, f => filesFilterFromMap(f.fieldValues), FilesFilter())
     executeFilter(client.files, filters, ids, config.partitions, config.limitPerPartition).map(
       _.map(
         _.into[FilesReadSchema]
diff --git a/src/main/scala/cognite/spark/v1/NumericDataPointsRelation.scala b/src/main/scala/cognite/spark/v1/NumericDataPointsRelation.scala
index 0c56b91ba..97617febf 100644
--- a/src/main/scala/cognite/spark/v1/NumericDataPointsRelation.scala
+++ b/src/main/scala/cognite/spark/v1/NumericDataPointsRelation.scala
@@ -4,8 +4,8 @@ import cats.data.Validated.{Invalid, Valid}
 import cats.effect.IO
 import cats.implicits._
 import cognite.spark.v1.PushdownUtilities.{
-  getIdFromMap,
-  pushdownToParameters,
+  getIdFromAndFilter,
+  pushdownToSimpleOr,
   toPushdownFilterExpression
 }
 import cognite.spark.compiletime.macros.SparkSchemaHelper.{asRow, fromRow, structType}
@@ -183,8 +183,8 @@ class NumericDataPointsRelationV1(config: RelationConfig)(sqlContext: SQLContext
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
     val pushdownFilterExpression = toPushdownFilterExpression(filters)
     val timestampLimits = filtersToTimestampLimits(filters, "timestamp")
-    val filtersAsMaps = pushdownToParameters(pushdownFilterExpression)
-    val ids = filtersAsMaps.flatMap(getIdFromMap).distinct
+    val filtersAsMaps = pushdownToSimpleOr(pushdownFilterExpression).filters
+    val ids = filtersAsMaps.flatMap(getIdFromAndFilter).distinct
 
     // Notify users that they need to supply one or more ids/externalIds when reading data points
     if (ids.isEmpty) {
diff --git a/src/main/scala/cognite/spark/v1/PushdownUtilities.scala b/src/main/scala/cognite/spark/v1/PushdownUtilities.scala
index ae6c84fa0..6da10ca31 100644
--- a/src/main/scala/cognite/spark/v1/PushdownUtilities.scala
+++ b/src/main/scala/cognite/spark/v1/PushdownUtilities.scala
@@ -34,27 +34,51 @@ final case class DeleteItemByCogniteId(
   }
 }
 
+final case class SimpleAndEqualsFilter(fieldValues: Map[String, String]) {
+  def and(other: SimpleAndEqualsFilter): SimpleAndEqualsFilter =
+    SimpleAndEqualsFilter(fieldValues ++ other.fieldValues)
+}
+
+object SimpleAndEqualsFilter {
+  def singleton(tuple: (String, String)): SimpleAndEqualsFilter =
+    SimpleAndEqualsFilter(Map(tuple))
+  def singleton(key: String, value: String): SimpleAndEqualsFilter =
+    singleton((key, value))
+
+}
+
+final case class SimpleOrFilter(filters: Seq[SimpleAndEqualsFilter]) {
+  def isJustTrue: Boolean = filters.isEmpty
+}
+
+object SimpleOrFilter {
+  def alwaysTrue: SimpleOrFilter = SimpleOrFilter(Seq.empty)
+  def singleton(filter: SimpleAndEqualsFilter): SimpleOrFilter =
+    SimpleOrFilter(Seq(filter))
+}
+
 sealed trait PushdownExpression
 final case class PushdownFilter(fieldName: String, value: String) extends PushdownExpression
 final case class PushdownAnd(left: PushdownExpression, right: PushdownExpression)
     extends PushdownExpression
-final case class PushdownFilters(filters: Seq[PushdownExpression]) extends PushdownExpression
+final case class PushdownUnion(filters: Seq[PushdownExpression]) extends PushdownExpression
 final case class NoPushdown() extends PushdownExpression
 
 object PushdownUtilities {
   def pushdownToFilters[F](
       sparkFilters: Array[Filter],
-      mapping: Map[String, String] => F,
+      mapping: SimpleAndEqualsFilter => F,
       allFilter: F): (Vector[CogniteId], Vector[F]) = {
     val pushdownFilterExpression = toPushdownFilterExpression(sparkFilters)
-    val filtersAsMaps = pushdownToParameters(pushdownFilterExpression).toVector
+    val filtersAsMaps = pushdownToSimpleOr(pushdownFilterExpression).filters.toVector
     val (idFilterMaps, filterMaps) =
-      filtersAsMaps.partition(m => m.contains("id") || m.contains("externalId"))
+      filtersAsMaps.partition(m => m.fieldValues.contains("id") || m.fieldValues.contains("externalId"))
     val ids = idFilterMaps.map(
       m =>
-        m.get("id")
+        m.fieldValues
+          .get("id")
           .map(id => CogniteInternalId(id.toLong))
-          .getOrElse(CogniteExternalId(m("externalId"))))
+          .getOrElse(CogniteExternalId(m.fieldValues("externalId"))))
     val filters = filterMaps.map(mapping)
     val shouldGetAll = filters.contains(allFilter) || (filters.isEmpty && ids.isEmpty)
     if (shouldGetAll) {
@@ -64,27 +88,30 @@ object PushdownUtilities {
     }
   }
 
-  def pushdownToParameters(p: PushdownExpression): Seq[Map[String, String]] =
+  def pushdownToSimpleOr(p: PushdownExpression): SimpleOrFilter =
     p match {
       case PushdownAnd(left, right) =>
-        handleAnd(pushdownToParameters(left), pushdownToParameters(right))
-      case PushdownFilter(field, value) => Seq(Map[String, String](field -> value))
-      case PushdownFilters(filters) => filters.flatMap(pushdownToParameters)
-      case NoPushdown() => Seq()
+        handleAnd(pushdownToSimpleOr(left), pushdownToSimpleOr(right))
+      case PushdownFilter(field, value) =>
+        SimpleOrFilter.singleton(
+          SimpleAndEqualsFilter.singleton(field -> value)
+        )
+      case PushdownUnion(filters) =>
+        SimpleOrFilter(filters.flatMap(pushdownToSimpleOr(_).filters))
+      case NoPushdown() => SimpleOrFilter.alwaysTrue
     }
 
-  def handleAnd(
-      left: Seq[Map[String, String]],
-      right: Seq[Map[String, String]]): Seq[Map[String, String]] =
-    if (left.isEmpty) {
+  def handleAnd(left: SimpleOrFilter, right: SimpleOrFilter): SimpleOrFilter =
+    if (left.isJustTrue) {
       right
-    } else if (right.isEmpty) {
+    } else if (right.isJustTrue) {
       left
-    } else {
-      for {
-        l <- left
-        r <- right
-      } yield l ++ r
+    } else { // try each left-right item combination
+      val filters = for {
+        l <- left.filters
+        r <- right.filters
+      } yield l.and(r)
+      SimpleOrFilter(filters)
     }
 
   def toPushdownFilterExpression(filters: Array[Filter]): PushdownExpression =
@@ -134,7 +161,7 @@ object PushdownUtilities {
       case StringStartsWith(colName, value) => PushdownFilter(colName + "Prefix", value)
       case In(colName, values) =>
-        PushdownFilters(
+        PushdownUnion(
           // X in (null, Y) will result in `NULL`, which is treated like false.
           // X AND NULL is NULL (like with false)
           // true OR NULL is true (like with false)
@@ -146,7 +173,7 @@ object PushdownUtilities {
             .toIndexedSeq
         )
       case And(f1, f2) => PushdownAnd(getFilter(f1), getFilter(f2))
-      case Or(f1, f2) => PushdownFilters(Seq(getFilter(f1), getFilter(f2)))
+      case Or(f1, f2) => PushdownUnion(Seq(getFilter(f1), getFilter(f2)))
       case _ => NoPushdown()
     }
   }
@@ -158,7 +185,7 @@ object PushdownUtilities {
       case PushdownAnd(left, right) =>
        shouldGetAll(left, fieldsWithPushdownFilter) && shouldGetAll(right, fieldsWithPushdownFilter)
      case PushdownFilter(field, _) => !fieldsWithPushdownFilter.contains(field)
-     case PushdownFilters(filters) =>
+     case PushdownUnion(filters) =>
        filters
          .map(shouldGetAll(_, fieldsWithPushdownFilter))
          .exists(identity)
@@ -261,6 +288,9 @@ object PushdownUtilities {
       .map(id => CogniteInternalId(id.toLong))
       .orElse(m.get("externalId").map(CogniteExternalId(_)))
 
+  def getIdFromAndFilter(f: SimpleAndEqualsFilter): Option[CogniteId] =
+    getIdFromMap(f.fieldValues)
+
   def mergeStreams[T, F[_]: Concurrent](streams: Seq[Stream[F, T]]): Stream[F, T] =
     streams.reduceOption(_.merge(_)).getOrElse(Stream.empty)
 
diff --git a/src/main/scala/cognite/spark/v1/RelationshipsRelation.scala b/src/main/scala/cognite/spark/v1/RelationshipsRelation.scala
index 84abe2089..f4912e0c8 100644
--- a/src/main/scala/cognite/spark/v1/RelationshipsRelation.scala
+++ b/src/main/scala/cognite/spark/v1/RelationshipsRelation.scala
@@ -28,7 +28,10 @@ class RelationshipsRelation(config: RelationConfig)(val sqlContext: SQLContext)
   override def getStreams(sparkFilters: Array[Filter])(
       client: GenericClient[IO]): Seq[Stream[IO, RelationshipsReadSchema]] = {
     val (ids, filters) =
-      pushdownToFilters(sparkFilters, relationshipsFilterFromMap, RelationshipsFilter())
+      pushdownToFilters(
+        sparkFilters,
+        f => relationshipsFilterFromMap(f.fieldValues),
+        RelationshipsFilter())
 
     // TODO: support parallel retrival using partitions
     Seq(
diff --git a/src/main/scala/cognite/spark/v1/StringDataPointsRelation.scala b/src/main/scala/cognite/spark/v1/StringDataPointsRelation.scala
index 4ec1b2f66..887f507a0 100644
--- a/src/main/scala/cognite/spark/v1/StringDataPointsRelation.scala
+++ b/src/main/scala/cognite/spark/v1/StringDataPointsRelation.scala
@@ -2,8 +2,8 @@ package cognite.spark.v1
 
 import cats.effect.IO
 import cognite.spark.v1.PushdownUtilities.{
-  getIdFromMap,
-  pushdownToParameters,
+  getIdFromAndFilter,
+  pushdownToSimpleOr,
   toPushdownFilterExpression
 }
 import cognite.spark.compiletime.macros.SparkSchemaHelper.{asRow, fromRow, structType}
@@ -75,8 +75,8 @@ class StringDataPointsRelationV1(config: RelationConfig)(override val sqlContext
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
     val pushdownFilterExpression = toPushdownFilterExpression(filters)
-    val filtersAsMaps = pushdownToParameters(pushdownFilterExpression)
-    val ids = filtersAsMaps.flatMap(getIdFromMap).distinct
+    val filtersAsMaps = pushdownToSimpleOr(pushdownFilterExpression).filters
+    val ids = filtersAsMaps.flatMap(getIdFromAndFilter).distinct
 
     // Notify users that they need to supply one or more ids/externalIds when reading data points
     if (ids.isEmpty) {
diff --git a/src/main/scala/cognite/spark/v1/TimeSeriesRelation.scala b/src/main/scala/cognite/spark/v1/TimeSeriesRelation.scala
index 79bd5a27a..683ca2ca0 100644
--- a/src/main/scala/cognite/spark/v1/TimeSeriesRelation.scala
+++ b/src/main/scala/cognite/spark/v1/TimeSeriesRelation.scala
@@ -72,7 +72,8 @@ class TimeSeriesRelation(config: RelationConfig)(val sqlContext: SQLContext)
 
   override def getStreams(sparkFilters: Array[Filter])(
       client: GenericClient[IO]): Seq[Stream[IO, TimeSeries]] = {
-    val (ids, filters) = pushdownToFilters(sparkFilters, timeSeriesFilterFromMap, TimeSeriesFilter())
+    val (ids, filters) =
+      pushdownToFilters(sparkFilters, f => timeSeriesFilterFromMap(f.fieldValues), TimeSeriesFilter())
     executeFilter(client.timeSeries, filters, ids, config.partitions, config.limitPerPartition)
   }
 
diff --git a/src/test/scala/cognite/spark/v1/PushdownUtilitiesTest.scala b/src/test/scala/cognite/spark/v1/PushdownUtilitiesTest.scala
index 1fe6c4687..42076b571 100644
--- a/src/test/scala/cognite/spark/v1/PushdownUtilitiesTest.scala
+++ b/src/test/scala/cognite/spark/v1/PushdownUtilitiesTest.scala
@@ -2,37 +2,37 @@ package cognite.spark.v1
 
 import cognite.spark.v1.PushdownUtilities._
 import org.scalatest.{FlatSpec, Matchers, ParallelTestExecution}
 
-class PushdownUtilitiesTest extends FlatSpec with ParallelTestExecution with Matchers with SparkTest {
+class PushdownUtilitiesTest extends FlatSpec with ParallelTestExecution with Matchers {
   it should "create one request for 1x1 and expression" in {
     val pushdownExpression = PushdownAnd(PushdownFilter("id", "123"), PushdownFilter("type", "abc"))
-    val params = pushdownToParameters(pushdownExpression)
+    val params = pushdownToSimpleOr(pushdownExpression)
 
-    assert(params.length == 1)
+    assert(params.filters.length == 1)
   }
 
   it should "create two requests for 1+1 or expression" in {
     val pushdownExpression =
-      PushdownFilters(Seq(PushdownFilter("id", "123"), PushdownFilter("type", "abc")))
-    val params = pushdownToParameters(pushdownExpression)
+      PushdownUnion(Seq(PushdownFilter("id", "123"), PushdownFilter("type", "abc")))
+    val params = pushdownToSimpleOr(pushdownExpression)
 
-    assert(params.length == 2)
+    assert(params.filters.length == 2)
   }
 
   it should "create 9 requests for 3x3 and or expression" in {
-    val left = PushdownFilters(
+    val left = PushdownUnion(
       Seq(
         PushdownFilter("id", "123"),
         PushdownFilter("type", "abc"),
         PushdownFilter("description", "test")))
-    val right = PushdownFilters(
+    val right = PushdownUnion(
       Seq(
         PushdownFilter("id", "456"),
         PushdownFilter("type", "def"),
         PushdownFilter("description", "test2")))
     val pushdownExpression = PushdownAnd(left, right)
-    val params = pushdownToParameters(pushdownExpression)
+    val params = pushdownToSimpleOr(pushdownExpression)
 
-    assert(params.length == 9)
+    assert(params.filters.length == 9)
   }
 }
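
For reference, a minimal sketch of how the new pushdown types introduced above compose (the field names and values here are purely illustrative): AND distributes over OR, so an `(a OR b) AND c` expression flattens into one SimpleAndEqualsFilter per OR branch, each of which becomes one filter request.

    import cognite.spark.v1.PushdownUtilities.pushdownToSimpleOr
    import cognite.spark.v1.{PushdownAnd, PushdownFilter, PushdownUnion, SimpleAndEqualsFilter, SimpleOrFilter}

    // (type = "abc" OR type = "def") AND source = "sap"
    val expression = PushdownAnd(
      PushdownUnion(Seq(PushdownFilter("type", "abc"), PushdownFilter("type", "def"))),
      PushdownFilter("source", "sap"))

    // handleAnd crosses the AND-groups of both sides, yielding two groups here
    val simplified: SimpleOrFilter = pushdownToSimpleOr(expression)
    assert(
      simplified.filters == Seq(
        SimpleAndEqualsFilter(Map("type" -> "abc", "source" -> "sap")),
        SimpleAndEqualsFilter(Map("type" -> "def", "source" -> "sap"))))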