diff --git a/src/main/scala/cognite/spark/v1/DefaultSource.scala b/src/main/scala/cognite/spark/v1/DefaultSource.scala
index f95830397..d38f11d8d 100644
--- a/src/main/scala/cognite/spark/v1/DefaultSource.scala
+++ b/src/main/scala/cognite/spark/v1/DefaultSource.scala
@@ -483,7 +483,8 @@ object DefaultSource {
       tracingParent = extractTracingHeadersKernel(parameters),
       useSharedThrottle = toBoolean(parameters, "useSharedThrottle", defaultValue = false),
       serverSideFilterNullValuesOnNonSchemaRawQueries =
-        toBoolean(parameters, "filterNullFieldsOnNonSchemaRawQueries", defaultValue = false)
+        toBoolean(parameters, "filterNullFieldsOnNonSchemaRawQueries", defaultValue = false),
+      maxOutstandingRawInsertRequests = toPositiveInt(parameters, "maxOutstandingRawInsertRequests")
     )
   }
 
diff --git a/src/main/scala/cognite/spark/v1/RawTableRelation.scala b/src/main/scala/cognite/spark/v1/RawTableRelation.scala
index 127153b06..54446c739 100644
--- a/src/main/scala/cognite/spark/v1/RawTableRelation.scala
+++ b/src/main/scala/cognite/spark/v1/RawTableRelation.scala
@@ -285,10 +285,32 @@ class RawTableRelation(
     val (columnNames, dfWithUnRenamedKeyColumns) =
       prepareForInsert(df.drop(lastUpdatedTimeColName))
     dfWithUnRenamedKeyColumns.foreachPartition((rows: Iterator[Row]) => {
-      val batches = rows.grouped(batchSize).toVector
-      batches
-        .parTraverse_(postRows(columnNames, _))
-        .unsafeRunSync()
+      config.maxOutstandingRawInsertRequests match {
+        case Some(maxOutstandingRawInsertRequests) =>
+          // We first group by batch size of a write, and then group that by the number of allowed parallel
+          // outstanding requests to avoid queueing up too many requests towards the RAW API (and this potentially
+          // leading to an OutOfMemory)
+          // Note: This is a suboptimal fix, as if one of the requests in a batch is slow, we will not
+          // start on the next batch (this limitation used to be per partition). Instead, we should
+          // have a cats.effect.std.Semaphore permit with X number of outstanding requests
+          // or cats.effect.std.Backpressure.
+          // Both groupings are kept as lazy iterators so only one window of batches is in memory at a time.
+          rows
+            .grouped(batchSize)
+            .grouped(maxOutstandingRawInsertRequests)
+            .foreach { batch =>
+              batch.toVector
+                .parTraverse_(postRows(columnNames, _))
+                .unsafeRunSync()
+            }
+        case None =>
+          // Same behavior as before, which is prone to OutOfMemory if the RAW API calls are too slow
+          // to finish
+          val batches = rows.grouped(batchSize).toVector
+          batches
+            .parTraverse_(postRows(columnNames, _))
+            .unsafeRunSync()
+      }
     })
   }
 
diff --git a/src/main/scala/cognite/spark/v1/RelationConfig.scala b/src/main/scala/cognite/spark/v1/RelationConfig.scala
index 307dc1193..38e0553c6 100644
--- a/src/main/scala/cognite/spark/v1/RelationConfig.scala
+++ b/src/main/scala/cognite/spark/v1/RelationConfig.scala
@@ -30,6 +30,7 @@ final case class RelationConfig(
     initialRetryDelayMillis: Int,
     useSharedThrottle: Boolean,
     serverSideFilterNullValuesOnNonSchemaRawQueries: Boolean,
+    maxOutstandingRawInsertRequests: Option[Int]
 ) {
 
   /** Desired number of Spark partitions ~= partitions / parallelismPerPartition */
diff --git a/src/test/scala/cognite/spark/v1/RawTableRelationSupportingTestData.scala b/src/test/scala/cognite/spark/v1/RawTableRelationSupportingTestData.scala
new file mode 100644
index 000000000..56e5934f3
--- /dev/null
+++ b/src/test/scala/cognite/spark/v1/RawTableRelationSupportingTestData.scala
@@ -0,0 +1,70 @@
+package cognite.spark.v1
+
+object RawTableRelationSupportingTestData {
+
+  val starships = Seq(
+    ("key01", "Millennium Falcon", "YT-1300 light freighter", "Light freighter"),
+    ("key02", "X-wing", "T-65 X-wing starfighter", "Starfighter"),
+    ("key03", "TIE Fighter", "Twin Ion Engine/Ln Starfighter", "Starfighter"),
+    ("key04", "Star Destroyer", "Imperial I-class Star Destroyer", "Capital ship"),
+    ("key05", "Slave 1", "Firespray-31-class patrol and attack craft", "Patrol craft"),
+    ("key06", "A-wing", "RZ-1 A-wing interceptor", "Interceptor"),
+    ("key07", "B-wing", "A/SF-01 B-wing starfighter", "Assault starfighter"),
+    ("key08", "Y-wing", "BTL Y-wing starfighter", "Assault starfighter"),
+    ("key09", "Executor", "Executor-class Star Dreadnought", "Star Dreadnought"),
+    ("key10", "Rebel transport", "GR-75 medium transport", "Medium transport"),
+    ("key11", "Naboo Royal Starship", "J-type 327 Nubian", "Yacht"),
+    ("key12", "ARC-170", "Aggressive ReConnaissance-170 starfighter", "Starfighter"),
+    ("key13", "Eta-2 Actis", "Eta-2 Actis-class light interceptor", "Interceptor"),
+    ("key14", "Venator-class Star Destroyer", "Venator-class", "Capital ship"),
+    ("key15", "Naboo N-1 Starfighter", "N-1", "Starfighter"),
+    ("key16", "Jedi Interceptor", "Eta-2 Actis-class interceptor", "Interceptor"),
+    ("key17", "Sith Infiltrator", "Scimitar", "Starfighter"),
+    ("key18", "V-wing", "Alpha-3 Nimbus-class V-wing starfighter", "Starfighter"),
+    ("key19", "Delta-7 Aethersprite", "Delta-7 Aethersprite-class light interceptor", "Interceptor"),
+    ("key20", "Imperial Shuttle", "Lambda-class T-4a shuttle", "Shuttle"),
+    ("key21", "Tantive IV", "CR90 corvette", "Corvette"),
+    ("key22", "Slave II", "Firespray-31-class patrol and attack craft", "Patrol craft"),
+    ("key23", "TIE Bomber", "TIE/sa bomber", "Bomber"),
+    ("key24", "Imperial Star Destroyer", "Imperial I-class Star Destroyer", "Capital ship"),
+    ("key25", "Sith Speeder", "FC-20 speeder bike", "Speeder"),
+    ("key26", "Speeder Bike", "74-Z speeder bike", "Speeder"),
+    ("key27", "Solar Sailer", "Punworcca 116-class interstellar sloop", "Sloop"),
+    ("key28", "Geonosian Starfighter", "Nantex-class territorial defense starfighter", "Starfighter"),
+    ("key29", "Hound's Tooth", "YT-2000 light freighter", "Light freighter"),
+    ("key30", "Scimitar", "Sith Infiltrator", "Starfighter"),
+    ("key31", "Tie Interceptor", "TIE/in interceptor", "Starfighter"),
+    ("key32", "Naboo Royal Cruiser", "J-type diplomatic barge", "Yacht"),
+    ("key33", "X-34 Landspeeder", "X-34", "Landspeeder"),
+    ("key34", "Snowspeeder", "T-47 airspeeder", "Airspeeder"),
+    ("key35", "The Ghost", "VCX-100 light freighter", "Light freighter"),
+    ("key36", "Phantom", "VCX-series auxiliary starfighter", "Auxiliary starfighter"),
+    ("key37", "Outrider", "YT-2400 light freighter", "Light freighter"),
+    ("key38", "Razor Crest", "ST-70 Assault Ship", "Assault ship"),
+    ("key39", "Naboo Yacht", "J-type star skiff", "Yacht"),
+    ("key40", "U-wing", "UT-60D", "Transport"),
+    ("key41", "TIE Advanced x1", "TIE Advanced x1", "Starfighter"),
+    ("key42", "J-type 327 Nubian", "J-type 327", "Yacht"),
+    ("key43", "Naboo Royal Starship", "J-type 327 Nubian", "Yacht"),
+    ("key44", "Naboo N-1 Starfighter", "N-1", "Starfighter"),
+    ("key45", "Sith Infiltrator", "Scimitar", "Starfighter"),
+    ("key46", "Havoc Marauder", "Omicron-class attack shuttle", "Attack shuttle"),
+    ("key47", "Luthen's Ship", "Fondor Haulcraft", "Haulcraft"),
+    ("key48", "Tantive IV", "CR90 corvette", "Corvette"),
+    ("key49", "Millennium Falcon", "YT-1300 light freighter", "Light freighter"),
+    ("key50", "Jedi Starfighter", "Delta-7 Aethersprite-class light interceptor", "Interceptor"),
+    ("key51", "Vulture Droid", "Variable Geometry Self-Propelled Battle Droid", "Starfighter"),
+    ("key52", "Tri-Fighter", "Droid Tri-Fighter", "Starfighter"),
+    ("key53", "Hyena Bomber", "Baktoid Armor Workshop", "Starfighter"),
+    ("key54", "Droid Gunship", "Heavy Missile Platform", "Gunship"),
+    ("key55", "Malevolence", "Subjugator-class heavy cruiser", "Heavy cruiser"),
+    ("key56", "Invisible Hand", "Providence-class Dreadnought", "Dreadnought"),
+    ("key57", "Malevolence", "Subjugator-class heavy cruiser", "Heavy cruiser"),
+    ("key58", "Invisible Hand", "Providence-class Dreadnought", "Dreadnought"),
+    ("key59", "Droid Control Ship", "Lucrehulk-class battleship", "Battleship"),
+    ("key60", "Venator-class Star Destroyer", "Venator-class", "Capital ship"),
+  )
+
+  val starshipsMap: Map[String, Map[String, String]] =
+    starships.map { s => (s._1 -> Map("name" -> s._2, "model" -> s._3, "class" -> s._4)) }.toMap
+}
diff --git a/src/test/scala/cognite/spark/v1/RawTableRelationTest.scala b/src/test/scala/cognite/spark/v1/RawTableRelationTest.scala
index 23e139bb9..f761fb3d6 100644
--- a/src/test/scala/cognite/spark/v1/RawTableRelationTest.scala
+++ b/src/test/scala/cognite/spark/v1/RawTableRelationTest.scala
@@ -173,6 +173,8 @@ class RawTableRelationTest
           RawRow(i.toString, Map("i" -> Json.fromString("exist")))
         )),
       TestTable("raw-write-test", Seq.empty), // used for writes
+      TestTable("raw-write-no-throttling", Seq.empty),
+      TestTable("raw-write-with-throttling", Seq.empty),
       TestTable("MegaColumnTable", Seq(
         RawRow("rowkey",
           (1 to 384).map(i => (i.toString -> Json.fromString("value"))).toMap
@@ -348,6 +350,60 @@ class RawTableRelationTest
     collectToSet[JavaLong](unRenamed2.select($"___lastUpdatedTime")) should equal(Set(111, 222))
   }
 
+  it should "insert data correctly when no throttling is present, but multiple queries" taggedAs (WriteTest) in {
+    val ships = RawTableRelationSupportingTestData.starships.toDF("key", "name", "model", "class")
+    ships.createTempView("ships2")
+
+    val destinationDataframe = spark.read
+      .format(DefaultSource.sparkFormatString)
+      .useOIDCWrite
+      .option("type", "raw")
+      .option("database", testData.dbName)
+      .option("table", "raw-write-no-throttling")
+      .option("inferSchema", false)
+      .option("batchSize", "5")
+      .schema(ships.schema)
+      .load()
+    destinationDataframe.createTempView("destinationTableNoThrottling")
+
+    spark.sql("select key, name, model, class from ships2")
+      .select(destinationDataframe.columns.map(col).toIndexedSeq: _*)
+      .write
+      .insertInto("destinationTableNoThrottling")
+
+    val verification: DataFrame = rawRead("raw-write-no-throttling", testData.dbName)
+    verification.count() should equal(ships.count())
+    verification.collect().foreach ( row => verifyRow(row, Array("name", "model", "class").toIndexedSeq, RawTableRelationSupportingTestData.starshipsMap) )
+  }
+
+
+  it should "insert data correctly when throttling of outstanding requests is set, and has multiple queries" taggedAs (WriteTest) in {
+    val ships = RawTableRelationSupportingTestData.starships.toDF("key", "name", "model", "class")
+    ships.createTempView("ships")
+
+    val destinationDataframe = spark.read
+      .format(DefaultSource.sparkFormatString)
+      .useOIDCWrite
+      .option("type", "raw")
+      .option("database", testData.dbName)
+      .option("table", "raw-write-with-throttling")
+      .option("inferSchema", false)
+      .option("batchSize", "5")
+      .option("maxOutstandingRawInsertRequests", "2")
+      .schema(ships.schema)
+      .load()
+    destinationDataframe.createTempView("destinationTableWithThrottling")
+
+    spark.sql("select key, name, model, class from ships")
+      .select(destinationDataframe.columns.map(col).toIndexedSeq: _*)
+      .write
+      .insertInto("destinationTableWithThrottling")
+
+    val verification: DataFrame = rawRead("raw-write-with-throttling", testData.dbName)
+    verification.count() should equal(ships.count())
+    verification.collect().foreach(row => verifyRow(row, Array("name", "model", "class").toIndexedSeq, RawTableRelationSupportingTestData.starshipsMap))
+  }
+
   it should "read nested StructType" in {
     val schema = dfWithSimpleNestedStruct.schema
     schema.fieldNames should contain("nested")
@@ -982,4 +1038,10 @@ class RawTableRelationTest
     }
   }
 
+  private def verifyRow(row: Row, columns: Seq[String], expected: Map[String, Map[String, String]]): Unit = {
+    val rowKey = row.getAs[String]("key")
+    columns.foreach { column =>
+      row.getAs[String](column) should equal(expected(rowKey)(column))
+    }
+  }
 }
diff --git a/src/test/scala/cognite/spark/v1/SparkTest.scala b/src/test/scala/cognite/spark/v1/SparkTest.scala
index f792b365b..fe6fa037d 100644
--- a/src/test/scala/cognite/spark/v1/SparkTest.scala
+++ b/src/test/scala/cognite/spark/v1/SparkTest.scala
@@ -274,7 +274,8 @@ trait SparkTest {
       enableSinglePartitionDeleteAssetHierarchy = false,
       tracingParent = new Kernel(Map.empty),
       useSharedThrottle = false,
-      serverSideFilterNullValuesOnNonSchemaRawQueries = false
+      serverSideFilterNullValuesOnNonSchemaRawQueries = false,
+      maxOutstandingRawInsertRequests = None
     )
 
   private def getCounterSafe(metricsNamespace: String, resource: String): Option[Long] = {