From 44bab624aaae36a3f6e66591bdc2b2e2afdcc998 Mon Sep 17 00:00:00 2001 From: Lin Zhou <87341375+linzhou-db@users.noreply.github.com> Date: Tue, 23 May 2023 15:40:14 -0700 Subject: [PATCH] Add an optional `expirationTimestamp` field for pre signed urls (#312) * actual url print * tmp * fix * fix * remove debug * pick conserative threshold * fix * fix * Option * fi * synchronized * refactor code * fix * add comment * resolve comments * refresh --- PROTOCOL.md | 16 +- .../sharing/server/CloudFileSigner.scala | 31 +- .../scala/io/delta/sharing/server/model.scala | 4 + .../internal/DeltaSharedTableLoader.scala | 21 +- .../server/DeltaSharingServiceSuite.scala | 26 ++ .../spark/DeltaSharingProfileProvider.scala | 5 +- .../sharing/spark/DeltaSharingSource.scala | 295 +++++++++++++----- .../spark/RemoteDeltaCDFRelation.scala | 68 +++- .../sharing/spark/RemoteDeltaFileIndex.scala | 9 +- .../delta/sharing/spark/RemoteDeltaLog.scala | 29 +- .../scala/io/delta/sharing/spark/model.scala | 12 +- .../delta/sharing/PreSignedUrlCache.scala | 68 ++-- .../spark/DeltaSharingRestClientSuite.scala | 29 ++ .../sharing/CachedTableManagerSuite.scala | 99 +++++- 14 files changed, 570 insertions(+), 142 deletions(-) diff --git a/PROTOCOL.md b/PROTOCOL.md index 25f13adff..7b1a7f27b 100644 --- a/PROTOCOL.md +++ b/PROTOCOL.md @@ -2414,6 +2414,7 @@ size | Long | The size of this file in bytes. | Required stats | String | Contains statistics (e.g., count, min/max values for columns) about the data in this file. This field may be missing. A file may or may not have stats. This is a serialized JSON string which can be deserialized to a [Statistics Struct](#per-file-statistics). A client can decide whether to use stats or drop it. | Optional version | Long | The table version of the file, returned when querying a table data with a version or timestamp parameter. 
| Optional timestamp | Long | The unix timestamp corresponding to the table version of the file, in milliseconds, returned when querying a table data with a version or timestamp parameter. | Optional +expirationTimestamp | Long | The unix timestamp corresponding to the expiration of the url, in milliseconds, returned when the server supports the feature. | Optional Example (for illustration purposes; each JSON object must be a single line in the response): @@ -2426,7 +2427,8 @@ Example (for illustration purposes; each JSON object must be a single line in th "partitionValues": { "date": "2021-04-28" }, - "stats": "{\"numRecords\":1,\"minValues\":{\"eventTime\":\"2021-04-28T23:33:48.719Z\"},\"maxValues\":{\"eventTime\":\"2021-04-28T23:33:48.719Z\"},\"nullCount\":{\"eventTime\":0}}" + "stats": "{\"numRecords\":1,\"minValues\":{\"eventTime\":\"2021-04-28T23:33:48.719Z\"},\"maxValues\":{\"eventTime\":\"2021-04-28T23:33:48.719Z\"},\"nullCount\":{\"eventTime\":0}}", + "expirationTimestamp": 1652140800000 } } ``` @@ -2443,6 +2445,7 @@ size | Long | The size of this file in bytes. | Required timestamp | Long | The timestamp of the file in milliseconds from epoch. | Required version | Int32 | The table version of this file. | Required stats | String | Contains statistics (e.g., count, min/max values for columns) about the data in this file. This field may be missing. A file may or may not have stats. This is a serialized JSON string which can be deserialized to a [Statistics Struct](#per-file-statistics). A client can decide whether to use stats or drop it. | Optional +expirationTimestamp | Long | The unix timestamp corresponding to the expiration of the url, in milliseconds, returned when the server supports the feature. 
| Optional Example (for illustration purposes; each JSON object must be a single line in the response): @@ -2457,7 +2460,8 @@ Example (for illustration purposes; each JSON object must be a single line in th }, "timestamp": 1652140800000, "version": 1, - "stats": "{\"numRecords\":1,\"minValues\":{\"eventTime\":\"2021-04-28T23:33:48.719Z\"},\"maxValues\":{\"eventTime\":\"2021-04-28T23:33:48.719Z\"},\"nullCount\":{\"eventTime\":0}}" + "stats": "{\"numRecords\":1,\"minValues\":{\"eventTime\":\"2021-04-28T23:33:48.719Z\"},\"maxValues\":{\"eventTime\":\"2021-04-28T23:33:48.719Z\"},\"nullCount\":{\"eventTime\":0}}", + "expirationTimestamp": 1652144400000 } } ``` @@ -2471,6 +2475,7 @@ partitionValues | Map | A map from partition column to value for size | Long | The size of this file in bytes. | Required timestamp | Long | The timestamp of the file in milliseconds from epoch. | Required version | Int32 | The table version of this file. | Required +expirationTimestamp | Long | The unix timestamp corresponding to the expiration of the url, in milliseconds, returned when the server supports the feature. | Optional Example (for illustration purposes; each JSON object must be a single line in the response): @@ -2484,7 +2489,8 @@ Example (for illustration purposes; each JSON object must be a single line in th "date": "2021-04-28" }, "timestamp": 1652140800000, - "version": 1 + "version": 1, + "expirationTimestamp": 1652144400000 } } ``` @@ -2498,6 +2504,7 @@ partitionValues | Map | A map from partition column to value for size | Long | The size of this file in bytes. | Required timestamp | Long | The timestamp of the file in milliseconds from epoch. | Required version | Int32 | The table version of this file. | Required +expirationTimestamp | Long | The unix timestamp corresponding to the expiration of the url, in milliseconds, returned when the server supports the feature. 
| Optional Example (for illustration purposes; each JSON object must be a single line in the response): @@ -2511,7 +2518,8 @@ Example (for illustration purposes; each JSON object must be a single line in th "date": "2021-04-28" }, "timestamp": 1652140800000, - "version": 1 + "version": 1, + "expirationTimestamp": 1652144400000 } } ``` diff --git a/server/src/main/scala/io/delta/sharing/server/CloudFileSigner.scala b/server/src/main/scala/io/delta/sharing/server/CloudFileSigner.scala index 553636d25..22031838d 100644 --- a/server/src/main/scala/io/delta/sharing/server/CloudFileSigner.scala +++ b/server/src/main/scala/io/delta/sharing/server/CloudFileSigner.scala @@ -37,9 +37,15 @@ import org.apache.hadoop.fs.azurebfs.services.AuthType import org.apache.hadoop.fs.s3a.DefaultS3ClientFactory import org.apache.hadoop.util.ReflectionUtils +/** + * @param url The signed url. + * @param expirationTimestamp The expiration timestamp in millis of the signed url, a minimum + * between the timeout of the url and of the token. 
+ */ +case class PreSignedUrl(url: String, expirationTimestamp: Long) trait CloudFileSigner { - def sign(path: Path): String + def sign(path: Path): PreSignedUrl } class S3FileSigner( @@ -50,7 +56,7 @@ class S3FileSigner( private val s3Client = ReflectionUtils.newInstance(classOf[DefaultS3ClientFactory], conf) .createS3Client(name) - override def sign(path: Path): String = { + override def sign(path: Path): PreSignedUrl = { val absPath = path.toUri val bucketName = absPath.getHost val objectKey = absPath.getPath.stripPrefix("/") @@ -60,7 +66,10 @@ class S3FileSigner( val request = new GeneratePresignedUrlRequest(bucketName, objectKey) .withMethod(HttpMethod.GET) .withExpiration(expiration) - s3Client.generatePresignedUrl(request).toString + PreSignedUrl( + s3Client.generatePresignedUrl(request).toString, + System.currentTimeMillis() + SECONDS.toMillis(preSignedUrlTimeoutSeconds) + ) } } @@ -102,7 +111,7 @@ class AzureFileSigner( sharedAccessPolicy } - override def sign(path: Path): String = { + override def sign(path: Path): PreSignedUrl = { val containerRef = blobClient.getContainerReference(container) val objectKey = objectKeyExtractor(path) assert(objectKey.nonEmpty, s"cannot get object key from $path") @@ -116,7 +125,10 @@ class AzureFileSigner( SharedAccessProtocols.HTTPS_ONLY ) val sasTokenCredentials = new StorageCredentialsSharedAccessSignature(sasToken) - sasTokenCredentials.transformUri(blobRef.getUri).toString + PreSignedUrl( + sasTokenCredentials.transformUri(blobRef.getUri).toString, + System.currentTimeMillis() + SECONDS.toMillis(preSignedUrlTimeoutSeconds) + ) } } @@ -202,13 +214,16 @@ class GCSFileSigner( private val storage = StorageOptions.newBuilder.build.getService - override def sign(path: Path): String = { + override def sign(path: Path): PreSignedUrl = { val (bucketName, objectName) = GCSFileSigner.getBucketAndObjectNames(path) assert(objectName.nonEmpty, s"cannot get object key from $path") val blobInfo = 
BlobInfo.newBuilder(BlobId.of(bucketName, objectName)).build - storage.signUrl( + PreSignedUrl( + storage.signUrl( blobInfo, preSignedUrlTimeoutSeconds, SECONDS, Storage.SignUrlOption.withV4Signature()) - .toString + .toString, + System.currentTimeMillis() + SECONDS.toMillis(preSignedUrlTimeoutSeconds) + ) } } diff --git a/server/src/main/scala/io/delta/sharing/server/model.scala b/server/src/main/scala/io/delta/sharing/server/model.scala index de60db3c6..83beb2e24 100644 --- a/server/src/main/scala/io/delta/sharing/server/model.scala +++ b/server/src/main/scala/io/delta/sharing/server/model.scala @@ -78,6 +78,7 @@ case class AddFile( size: Long, @JsonRawValue stats: String = null, + expirationTimestamp: java.lang.Long = null, timestamp: java.lang.Long = null, version: java.lang.Long = null) extends Action { @@ -90,6 +91,7 @@ case class AddFileForCDF( @JsonInclude(JsonInclude.Include.ALWAYS) partitionValues: Map[String, String], size: Long, + expirationTimestamp: java.lang.Long = null, version: Long, timestamp: Long, @JsonRawValue @@ -104,6 +106,7 @@ case class AddCDCFile( @JsonInclude(JsonInclude.Include.ALWAYS) partitionValues: Map[String, String], size: Long, + expirationTimestamp: java.lang.Long = null, timestamp: Long, version: Long) extends Action { @@ -117,6 +120,7 @@ case class RemoveFile( @JsonInclude(JsonInclude.Include.ALWAYS) partitionValues: Map[String, String], size: Long, + expirationTimestamp: java.lang.Long = null, timestamp: Long, version: Long) extends Action { diff --git a/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala b/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala index a6a556eb4..52eac3d0a 100644 --- a/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala +++ b/server/src/main/scala/io/delta/standalone/internal/DeltaSharedTableLoader.scala @@ -271,7 +271,9 @@ class DeltaSharedTable( filteredFiles.map { addFile => val cloudPath = 
absolutePath(deltaLog.dataPath, addFile.path) val signedUrl = fileSigner.sign(cloudPath) - val modelAddFile = model.AddFile(url = signedUrl, + val modelAddFile = model.AddFile( + url = signedUrl.url, + expirationTimestamp = signedUrl.expirationTimestamp, id = Hashing.md5().hashString(addFile.path, UTF_8).toString, partitionValues = addFile.partitionValues, size = addFile.size, @@ -317,8 +319,10 @@ class DeltaSharedTable( val ts = timestampsByVersion.get(v).orNull versionActions.foreach { case a: AddFile if a.dataChange => + val signedUrl = fileSigner.sign(absolutePath(deltaLog.dataPath, a.path)) val modelAddFile = model.AddFileForCDF( - url = fileSigner.sign(absolutePath(deltaLog.dataPath, a.path)), + url = signedUrl.url, + expirationTimestamp = signedUrl.expirationTimestamp, id = Hashing.md5().hashString(a.path, UTF_8).toString, partitionValues = a.partitionValues, size = a.size, @@ -328,8 +332,10 @@ class DeltaSharedTable( ) actions.append(modelAddFile.wrap) case r: RemoveFile if r.dataChange => + val signedUrl = fileSigner.sign(absolutePath(deltaLog.dataPath, r.path)) val modelRemoveFile = model.RemoveFile( - url = fileSigner.sign(absolutePath(deltaLog.dataPath, r.path)), + url = signedUrl.url, + expirationTimestamp = signedUrl.expirationTimestamp, id = Hashing.md5().hashString(r.path, UTF_8).toString, partitionValues = r.partitionValues, size = r.size.get, @@ -417,7 +423,8 @@ class DeltaSharedTable( val cloudPath = absolutePath(deltaLog.dataPath, addCDCFile.path) val signedUrl = fileSigner.sign(cloudPath) val modelCDCFile = model.AddCDCFile( - url = signedUrl, + url = signedUrl.url, + expirationTimestamp = signedUrl.expirationTimestamp, id = Hashing.md5().hashString(addCDCFile.path, UTF_8).toString, partitionValues = addCDCFile.partitionValues, size = addCDCFile.size, @@ -433,7 +440,8 @@ class DeltaSharedTable( val cloudPath = absolutePath(deltaLog.dataPath, addFile.path) val signedUrl = fileSigner.sign(cloudPath) val modelAddFile = model.AddFileForCDF( - url = 
signedUrl, + url = signedUrl.url, + expirationTimestamp = signedUrl.expirationTimestamp, id = Hashing.md5().hashString(addFile.path, UTF_8).toString, partitionValues = addFile.partitionValues, size = addFile.size, @@ -450,7 +458,8 @@ class DeltaSharedTable( val cloudPath = absolutePath(deltaLog.dataPath, removeFile.path) val signedUrl = fileSigner.sign(cloudPath) val modelRemoveFile = model.RemoveFile( - url = signedUrl, + url = signedUrl.url, + expirationTimestamp = signedUrl.expirationTimestamp, id = Hashing.md5().hashString(removeFile.path, UTF_8).toString, partitionValues = removeFile.partitionValues, size = removeFile.size.get, diff --git a/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala b/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala index 433fed16c..2c78fcf55 100644 --- a/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala +++ b/server/src/test/scala/io/delta/sharing/server/DeltaSharingServiceSuite.scala @@ -527,6 +527,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { val expectedFiles = Seq( AddFile( url = actualFiles(0).url, + expirationTimestamp = actualFiles(0).expirationTimestamp, id = "061cb3683a467066995f8cdaabd8667d", partitionValues = Map.empty, size = 781, @@ -534,12 +535,14 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { ), AddFile( url = actualFiles(1).url, + expirationTimestamp = actualFiles(1).expirationTimestamp, id = "e268cbf70dbaa6143e7e9fa3e2d3b00e", partitionValues = Map.empty, size = 781, stats = """{"numRecords":1,"minValues":{"eventTime":"2021-04-28T06:32:02.070Z","date":"2021-04-28"},"maxValues":{"eventTime":"2021-04-28T06:32:02.070Z","date":"2021-04-28"},"nullCount":{"eventTime":0,"date":0}}""" ) ) + assert(actualFiles.count(_.expirationTimestamp != null) == 2) assert(expectedFiles == actualFiles.toList) verifyPreSignedUrl(actualFiles(0).url, 781) verifyPreSignedUrl(actualFiles(1).url, 781) @@ 
-617,6 +620,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { val expectedFiles = Seq( AddFile( url = actualFiles(0).url, + expirationTimestamp = actualFiles(0).expirationTimestamp, id = "9f1a49539c5cffe1ea7f9e055d5c003c", partitionValues = Map("date" -> "2021-04-28"), size = 573, @@ -624,12 +628,14 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { ), AddFile( url = actualFiles(1).url, + expirationTimestamp = actualFiles(1).expirationTimestamp, id = "cd2209b32f5ed5305922dd50f5908a75", partitionValues = Map("date" -> "2021-04-28"), size = 573, stats = """{"numRecords":1,"minValues":{"eventTime":"2021-04-28T23:33:48.719Z"},"maxValues":{"eventTime":"2021-04-28T23:33:48.719Z"},"nullCount":{"eventTime":0}}""" ) ) + assert(actualFiles.count(_.expirationTimestamp != null) == 2) assert(expectedFiles == actualFiles.toList) verifyPreSignedUrl(actualFiles(0).url, 573) verifyPreSignedUrl(actualFiles(1).url, 573) @@ -727,6 +733,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { val expectedFiles = Seq( AddFile( url = actualFiles(0).url, + expirationTimestamp = actualFiles(0).expirationTimestamp, id = "db213271abffec6fd6c7fc2aad9d4b3f", partitionValues = Map("date" -> "2021-04-28"), size = 778, @@ -734,6 +741,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { ), AddFile( url = actualFiles(1).url, + expirationTimestamp = actualFiles(1).expirationTimestamp, id = "f1f8be229d8b18eb6d6a34255f2d7089", partitionValues = Map("date" -> "2021-04-28"), size = 778, @@ -741,12 +749,14 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { ), AddFile( url = actualFiles(2).url, + expirationTimestamp = actualFiles(2).expirationTimestamp, id = "a892a55d770ee70b34ffb2ebf7dc2fd0", partitionValues = Map("date" -> "2021-04-28"), size = 573, stats = 
"""{"numRecords":1,"minValues":{"eventTime":"2021-04-28T23:35:53.156Z"},"maxValues":{"eventTime":"2021-04-28T23:35:53.156Z"},"nullCount":{"eventTime":0}}""" ) ) + assert(actualFiles.count(_.expirationTimestamp != null) == 3) assert(expectedFiles == actualFiles.toList) verifyPreSignedUrl(actualFiles(0).url, 778) verifyPreSignedUrl(actualFiles(1).url, 778) @@ -806,6 +816,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { val expectedFiles = Seq( AddFile( url = actualFiles(0).url, + expirationTimestamp = actualFiles(0).expirationTimestamp, id = "60d0cf57f3e4367db154aa2c36152a1f", partitionValues = Map.empty, size = 1030, @@ -815,6 +826,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { ), AddFile( url = actualFiles(1).url, + expirationTimestamp = actualFiles(1).expirationTimestamp, id = "d7ed708546dd70fdff9191b3e3d6448b", partitionValues = Map.empty, size = 1030, @@ -824,6 +836,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { ), AddFile( url = actualFiles(2).url, + expirationTimestamp = actualFiles(2).expirationTimestamp, id = "a6dc5694a4ebcc9a067b19c348526ad6", partitionValues = Map.empty, size = 1030, @@ -832,6 +845,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { timestamp = 1651272635000L ) ) + assert(actualFiles.count(_.expirationTimestamp != null) == 3) assert(expectedFiles == actualFiles.toList) verifyPreSignedUrl(actualFiles(0).url, 1030) verifyPreSignedUrl(actualFiles(1).url, 1030) @@ -998,6 +1012,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { val expectedFiles = Seq( AddFile( url = actualFiles(0).url, + expirationTimestamp = actualFiles(0).expirationTimestamp, id = "60d0cf57f3e4367db154aa2c36152a1f", partitionValues = Map.empty, size = 1030, @@ -1007,6 +1022,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { ), AddFile( url = actualFiles(1).url, + expirationTimestamp = 
actualFiles(1).expirationTimestamp, id = "d7ed708546dd70fdff9191b3e3d6448b", partitionValues = Map.empty, size = 1030, @@ -1016,6 +1032,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { ), AddFile( url = actualFiles(2).url, + expirationTimestamp = actualFiles(2).expirationTimestamp, id = "a6dc5694a4ebcc9a067b19c348526ad6", partitionValues = Map.empty, size = 1030, @@ -1024,6 +1041,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { timestamp = 1651272635000L ) ) + assert(actualFiles.count(_.expirationTimestamp != null) == 3) assert(expectedFiles == actualFiles.toList) verifyPreSignedUrl(actualFiles(0).url, 1030) verifyPreSignedUrl(actualFiles(1).url, 1030) @@ -1574,6 +1592,8 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { assert(addFile.version == version) assert(addFile.timestamp == timestamp) verifyPreSignedUrl(addFile.url, size.toInt) + val timeToExpiration = addFile.expirationTimestamp - System.currentTimeMillis() + assert(timeToExpiration < 60 * 60 * 1000 && timeToExpiration > 50 * 60 * 1000) } private def verifyAddCDCFile( @@ -1589,6 +1609,8 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { assert(addCDCFile.version == version) assert(addCDCFile.timestamp == timestamp) verifyPreSignedUrl(addCDCFile.url, size.toInt) + val timeToExpiration = addCDCFile.expirationTimestamp - System.currentTimeMillis() + assert(timeToExpiration < 60 * 60 * 1000 && timeToExpiration > 50 * 60 * 1000) } private def verifyRemove( @@ -1604,6 +1626,8 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { assert(removeFile.version == version) assert(removeFile.timestamp == timestamp) verifyPreSignedUrl(removeFile.url, size.toInt) + val timeToExpiration = removeFile.expirationTimestamp - System.currentTimeMillis() + assert(timeToExpiration < 60 * 60 * 1000 && timeToExpiration > 50 * 60 * 1000) } integrationTest("table_data_loss_with_checkpoint - 
/shares/{share}/schemas/{schema}/tables/{table}/query") { @@ -1917,6 +1941,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { val expectedFiles = Seq( AddFile( url = actualFiles(0).url, + expirationTimestamp = actualFiles(0).expirationTimestamp, id = "84f5f9e4de01e99837f77bfc2b7215b0", partitionValues = Map("c2" -> "foo bar"), size = 568, @@ -1948,6 +1973,7 @@ class DeltaSharingServiceSuite extends FunSuite with BeforeAndAfterAll { val expectedFiles = Seq( AddFile( url = actualFiles(0).url, + expirationTimestamp = actualFiles(0).expirationTimestamp, id = "84f5f9e4de01e99837f77bfc2b7215b0", partitionValues = Map("c2" -> "foo bar"), size = 568, diff --git a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingProfileProvider.scala b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingProfileProvider.scala index 62cc89527..ae76b4fa0 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingProfileProvider.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingProfileProvider.scala @@ -47,7 +47,10 @@ trait DeltaSharingProfileProvider { def getCustomTablePath(tablePath: String): String = tablePath - def getCustomRefresher(refresher: () => Map[String, String]): () => Map[String, String] = { + // Map[String, String] is the id to url map. + // Long is the minimum url expiration time for all the urls. 
+ def getCustomRefresher(refresher: () => (Map[String, String], Option[Long])): () => + (Map[String, String], Option[Long]) = { refresher } } diff --git a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala index e7854b67d..bc4cf9854 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/DeltaSharingSource.scala @@ -132,9 +132,6 @@ case class DeltaSharingSource( private val tableId = initSnapshot.metadata.id - private val refreshPresignedUrls = spark.sessionState.conf.getConfString( - "spark.delta.sharing.source.refreshPresignedUrls.enabled", "true").toBoolean - // Records until which offset the delta sharing source has been processing the table files. private var previousOffset: DeltaSharingSourceOffset = null @@ -142,7 +139,17 @@ case class DeltaSharingSource( // If not empty, will advance the offset and fetch data from this list based on the read limit. // If empty, will try to load all possible new data files through delta sharing rpc to this list, // sorted by version and id. - private var sortedFetchedFiles: Seq[IndexedFile] = Seq.empty + // GLOBAL variable which should be protected by synchronized + @volatile private var sortedFetchedFiles: Seq[IndexedFile] = Seq.empty + // The latest timestamp in millisecond, records the time of the last rpc sent to the server to + // fetch the pre-signed urls. + // This is used to track whether the pre-signed urls stored in sortedFetchedFiles are going to + // expire and need a refresh. + // GLOBAL variable which should be protected by synchronized + private var lastQueryTableTimestamp: Long = -1 + // The minimum url expiration timestamp for urls returned from a rpc. 
+ // GLOBAL variable which should be protected by synchronized + private var minUrlExpirationTimestamp: Option[Long] = None private var lastGetVersionTimestamp: Long = -1 private var latestTableVersion: Long = -1 @@ -154,12 +161,7 @@ case class DeltaSharingSource( // The latest function used to fetch presigned urls for the delta sharing table, record it in // a variable to be used by the CachedTableManager to refresh the presigned urls if the query // runs for a long time. - private var latestRefreshFunc = () => { Map.empty[String, String] } - // The latest timestamp in millisecond, records the time of the last rpc sent to the server to - // fetch the pre-signed urls. - // This is used to track whether the pre-signed urls stored in sortedFetchedFiles are going to - // expire and need a refresh. - private var lastQueryTableTimestamp: Long = -1 + private var latestRefreshFunc = () => { (Map.empty[String, String], None: Option[Long]) } // Check the latest table version from the delta sharing server through the client.getTableVersion // RPC. 
Adding a minimum interval of QUERY_TABLE_VERSION_INTERVAL_MILLIS between two consecutive @@ -180,8 +182,21 @@ case class DeltaSharingSource( f1.id < f2.id } - private def appendToSortedFetchedFiles(indexedFile: IndexedFile): Unit = { - sortedFetchedFiles = sortedFetchedFiles :+ indexedFile + private def appendToSortedFetchedFiles( + indexedFile: IndexedFile, + urlExpirationTimestamp: java.lang.Long = null + ): Unit = { + synchronized { + sortedFetchedFiles = sortedFetchedFiles :+ indexedFile + if (urlExpirationTimestamp != null) { + minUrlExpirationTimestamp = if (minUrlExpirationTimestamp.isDefined && + minUrlExpirationTimestamp.get < urlExpirationTimestamp) { + minUrlExpirationTimestamp + } else { + Some(urlExpirationTimestamp) + } + } + } } /** @@ -225,6 +240,110 @@ case class DeltaSharingSource( } } + private def resetGlobalTimestamp(): Unit = { + synchronized { + lastQueryTableTimestamp = System.currentTimeMillis() + minUrlExpirationTimestamp = None + } + } + + // Validate the minimum url expiration timestamp, and set it to None if it's invalid. + // It's considered valid only when it gives enough time for the client to read data out of the + // pre-signed url. If not valid, we will use the spark config to decide the refresh schedule, and + // if the url expired before that, we'll leverage the driver log to debug why the expiration + // timestamp is invalid. + private def validateMinUrlExpirationTimestamp(inputTimestamp: Option[Long] = None): Unit = { + synchronized { + if (inputTimestamp.isDefined) { + minUrlExpirationTimestamp = inputTimestamp + } + if (!CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) { + // reset to None to indicate that it's not a valid url expiration timestamp. + minUrlExpirationTimestamp = None + } + } + } + + // Pop a list of file actions from the local sortedFetchedFiles, until the given endOffset, to + // be processed by the micro batch. 
+ // (fileActions, lastQueryTableTimestamp, minUrlExpirationTimestamp) are returned together within + // a single synchronized wrap, to avoid using old urls with refreshed timestamps when a refresh + // happens after this function and before register(). + private def popSortedFetchedFiles( + endOffset: DeltaSharingSourceOffset): (Seq[IndexedFile], Long, Option[Long]) = { + synchronized { + val fileActions = sortedFetchedFiles.takeWhile { + case IndexedFile(version, index, _, _, _, _) => + version < endOffset.tableVersion || + (version == endOffset.tableVersion && index <= endOffset.index) + } + sortedFetchedFiles = sortedFetchedFiles.drop(fileActions.size) + (fileActions, lastQueryTableTimestamp, minUrlExpirationTimestamp) + } + } + + // Function to be called in latestRefreshFunc, to refresh the pre-signed urls in + // sortedFetchedFiles. newIdToUrl contains the refreshed urls. + private def refreshSortedFetchedFiles( + newIdToUrl: Map[String, String], + queryTimestamp: Long, + newMinUrlExpiration: Option[Long] + ): Unit = { + synchronized { + logInfo(s"Refreshing sortedFetchedFiles(size: ${sortedFetchedFiles.size}) with newIdToUrl(" + + s"size: ${newIdToUrl.size}).") + lastQueryTableTimestamp = queryTimestamp + minUrlExpirationTimestamp = newMinUrlExpiration + if (!CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) { + // reset to None to indicate that it's not a valid url expiration timestamp. 
+ minUrlExpirationTimestamp = None + } + var numUrlsRefreshed = 0 + sortedFetchedFiles = sortedFetchedFiles.map { indexedFile => + IndexedFile( + version = indexedFile.version, + index = indexedFile.index, + add = if (indexedFile.add == null) { + null + } else { + numUrlsRefreshed += 1 + val newUrl = newIdToUrl.getOrElse( + indexedFile.add.id, + throw new IllegalStateException(s"cannot find url for id ${indexedFile.add.id} " + + s"when refreshing table ${deltaLog.path}") + ) + indexedFile.add.copy(url = newUrl) + }, + remove = if (indexedFile.remove == null) { + null + } else { + numUrlsRefreshed += 1 + val newUrl = newIdToUrl.getOrElse( + indexedFile.remove.id, + throw new IllegalStateException(s"cannot find url for id ${indexedFile.remove.id} " + + s"when refreshing table ${deltaLog.path}") + ) + indexedFile.remove.copy(url = newUrl) + }, + cdc = if (indexedFile.cdc == null) { + null + } else { + numUrlsRefreshed += 1 + val newUrl = newIdToUrl.getOrElse( + indexedFile.cdc.id, + throw new IllegalStateException(s"cannot find url for id ${indexedFile.cdc.id} " + + s"when refreshing table ${deltaLog.path}") + ) + indexedFile.cdc.copy(url = newUrl) + }, + isLast = indexedFile.isLast + ) + } + logInfo(s"Refreshed ${numUrlsRefreshed} urls in sortedFetchedFiles(size: " + + s"${sortedFetchedFiles.size}).") + } + } + /** * Fetch the table changes from delta sharing server starting from (fromVersion, fromIndex), and * store them in sortedFetchedFiles. @@ -246,7 +365,7 @@ case class DeltaSharingSource( fromIndex: Long, isStartingVersion: Boolean, endingVersionForQuery: Long): Unit = { - lastQueryTableTimestamp = System.currentTimeMillis() + resetGlobalTimestamp() if (isStartingVersion) { // If isStartingVersion is true, it means to fetch the snapshot at the fromVersion, which may // include table changes from previous versions. 
@@ -254,11 +373,26 @@ case class DeltaSharingSource( deltaLog.table, Nil, None, Some(fromVersion), None, None ) latestRefreshFunc = () => { - deltaLog.client.getFiles( + val queryTimestamp = System.currentTimeMillis() + val files = deltaLog.client.getFiles( deltaLog.table, Nil, None, Some(fromVersion), None, None - ).files.map { f => + ).files + var minUrlExpiration: Option[Long] = None + val idToUrl = files.map { f => + if (f.expirationTimestamp != null) { + minUrlExpiration = if (minUrlExpiration.isDefined && + minUrlExpiration.get < f.expirationTimestamp) { + minUrlExpiration + } else { + Some(f.expirationTimestamp) + } + } f.id -> f.url }.toMap + + refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) + + (idToUrl, minUrlExpiration) } val numFiles = tableFiles.files.size @@ -275,9 +409,13 @@ case class DeltaSharingSource( file.size, fromVersion, file.timestamp, - file.stats + file.stats, + file.expirationTimestamp ), - isLast = (index + 1 == numFiles))) + isLast = (index + 1 == numFiles) + ), + file.expirationTimestamp + ) // For files with index <= fromIndex, skip them, otherwise an exception will be thrown. 
case _ => () } @@ -288,27 +426,44 @@ case class DeltaSharingSource( deltaLog.table, fromVersion, Some(endingVersionForQuery) ) latestRefreshFunc = () => { - deltaLog.client.getFiles( + val queryTimestamp = System.currentTimeMillis() + val addFiles = deltaLog.client.getFiles( deltaLog.table, fromVersion, Some(endingVersionForQuery) - ).addFiles.map { a => + ).addFiles + var minUrlExpiration: Option[Long] = None + val idToUrl = addFiles.map { a => + if (a.expirationTimestamp != null) { + minUrlExpiration = if (minUrlExpiration.isDefined && + minUrlExpiration.get < a.expirationTimestamp) { + minUrlExpiration + } else { + Some(a.expirationTimestamp) + } + } a.id -> a.url }.toMap + + refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) + + (idToUrl, minUrlExpiration) } val allAddFiles = validateCommitAndFilterAddFiles(tableFiles).groupBy(a => a.version) for (v <- fromVersion to endingVersionForQuery) { - val vAddFiles = allAddFiles.getOrElse(v, ArrayBuffer[AddFileForCDF]()) val numFiles = vAddFiles.size appendToSortedFetchedFiles(IndexedFile(v, -1, add = null, isLast = (numFiles == 0))) vAddFiles.sortWith(fileActionCompareFunc).zipWithIndex.foreach { case (add, index) if (v > fromVersion || (v == fromVersion && index > fromIndex)) => appendToSortedFetchedFiles( - IndexedFile(add.version, index, add, isLast = (index + 1 == numFiles))) + IndexedFile(add.version, index, add, isLast = (index + 1 == numFiles)), + add.expirationTimestamp + ) // For files with v <= fromVersion, skip them, otherwise an exception will be thrown. 
case _ => () } } } + validateMinUrlExpirationTimestamp() } /** @@ -328,7 +483,7 @@ case class DeltaSharingSource( fromVersion: Long, fromIndex: Long, endingVersionForQuery: Long): Unit = { - lastQueryTableTimestamp = System.currentTimeMillis() + resetGlobalTimestamp() val tableFiles = deltaLog.client.getCDFFiles( deltaLog.table, Map( @@ -338,6 +493,7 @@ case class DeltaSharingSource( true ) latestRefreshFunc = () => { + val queryTimestamp = System.currentTimeMillis() val d = deltaLog.client.getCDFFiles( deltaLog.table, Map( @@ -346,7 +502,17 @@ case class DeltaSharingSource( ), true ) - DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles) + + val idToUrl = DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles) + + val minUrlExpiration = DeltaSharingCDFReader.getMinUrlExpiration( + d.addFiles, d.cdfFiles, d.removeFiles) + refreshSortedFetchedFiles(idToUrl, queryTimestamp, minUrlExpiration) + + ( + idToUrl, + minUrlExpiration + ) } (Seq(tableFiles.metadata) ++ tableFiles.additionalMetadatas).foreach { m => @@ -355,6 +521,12 @@ case class DeltaSharingSource( throw DeltaSharingErrors.schemaChangedException(schema, schemaToCheck) } } + val cdfUrlExpirationTimestamp = DeltaSharingCDFReader.getMinUrlExpiration( + tableFiles.addFiles, + tableFiles.cdfFiles, + tableFiles.removeFiles + ) + validateMinUrlExpirationTimestamp(cdfUrlExpirationTimestamp) val perVersionAddFiles = tableFiles.addFiles.groupBy(f => f.version) val perVersionCdfFiles = tableFiles.cdfFiles.groupBy(f => f.version) @@ -493,57 +665,7 @@ case class DeltaSharingSource( endOffset: DeltaSharingSourceOffset): DataFrame = { maybeGetFileChanges(startVersion, startIndex, isStartingVersion) - if (refreshPresignedUrls && - (CachedTableManager.INSTANCE.preSignedUrlExpirationMs + lastQueryTableTimestamp - - System.currentTimeMillis() < CachedTableManager.INSTANCE.refreshThresholdMs)) { - // force a refresh if needed. 
- lastQueryTableTimestamp = System.currentTimeMillis() - val newIdToUrl = latestRefreshFunc() - sortedFetchedFiles = sortedFetchedFiles.map { indexedFile => - IndexedFile( - version = indexedFile.version, - index = indexedFile.index, - add = if (indexedFile.add == null) { - null - } else { - val newUrl = newIdToUrl.getOrElse( - indexedFile.add.id, - throw new IllegalStateException(s"cannot find url for id ${indexedFile.add.id} " + - s"when refreshing table ${deltaLog.path}") - ) - indexedFile.add.copy(url = newUrl) - }, - remove = if (indexedFile.remove == null) { - null - } else { - val newUrl = newIdToUrl.getOrElse( - indexedFile.remove.id, - throw new IllegalStateException(s"cannot find url for id ${indexedFile.remove.id} " + - s"when refreshing table ${deltaLog.path}") - ) - indexedFile.remove.copy(url = newUrl) - }, - cdc = if (indexedFile.cdc == null) { - null - } else { - val newUrl = newIdToUrl.getOrElse( - indexedFile.cdc.id, - throw new IllegalStateException(s"cannot find url for id ${indexedFile.cdc.id} " + - s"when refreshing table ${deltaLog.path}") - ) - indexedFile.cdc.copy(url = newUrl) - }, - isLast = indexedFile.isLast - ) - } - } - - val fileActions = sortedFetchedFiles.takeWhile { - case IndexedFile(version, index, _, _, _, _) => - version < endOffset.tableVersion || - (version == endOffset.tableVersion && index <= endOffset.index) - } - sortedFetchedFiles = sortedFetchedFiles.drop(fileActions.size) + val (fileActions, lastQueryTimestamp, urlExpirationTimestamp) = popSortedFetchedFiles(endOffset) // Proceed the offset as the files before the endOffset are processed. 
previousOffset = endOffset @@ -553,10 +675,10 @@ case class DeltaSharingSource( val filteredActions = fileActions.filter{ indexedFile => indexedFile.getFileAction != null } if (options.readChangeFeed) { - return createCDFDataFrame(filteredActions) + return createCDFDataFrame(filteredActions, lastQueryTimestamp, urlExpirationTimestamp) } - createDataFrame(filteredActions) + createDataFrame(filteredActions, lastQueryTimestamp, urlExpirationTimestamp) } /** @@ -564,7 +686,10 @@ case class DeltaSharingSource( * Only AddFile actions will be used to create the DataFrame. * @param indexedFiles actions list from which to generate the DataFrame. */ - private def createDataFrame(indexedFiles: Seq[IndexedFile]): DataFrame = { + private def createDataFrame( + indexedFiles: Seq[IndexedFile], + lastQueryTimestamp: Long, + urlExpirationTimestamp: Option[Long]): DataFrame = { val addFilesList = indexedFiles.map { indexedFile => // add won't be null at this step as addFile is the only interested file when // options.readChangeFeed is false, which is when this function is called. @@ -584,9 +709,15 @@ case class DeltaSharingSource( idToUrl, Seq(new WeakReference(fileIndex)), params.profileProvider, - latestRefreshFunc + latestRefreshFunc, + if (urlExpirationTimestamp.isDefined) { + urlExpirationTimestamp.get + } else { + lastQueryTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs + } ) + val relation = HadoopFsRelation( fileIndex, partitionSchema = initSnapshot.partitionSchema, @@ -603,7 +734,10 @@ case class DeltaSharingSource( * table. * @param indexedFiles actions list from which to generate the DataFrame. 
*/ - private def createCDFDataFrame(indexedFiles: Seq[IndexedFile]): DataFrame = { + private def createCDFDataFrame( + indexedFiles: Seq[IndexedFile], + lastQueryTimestamp: Long, + urlExpirationTimestamp: Option[Long]): DataFrame = { val addFiles = ArrayBuffer[AddFileForCDF]() val cdfFiles = ArrayBuffer[AddCDCFile]() val removeFiles = ArrayBuffer[RemoveFile]() @@ -625,7 +759,8 @@ case class DeltaSharingSource( schema, isStreaming = true, latestRefreshFunc, - lastQueryTableTimestamp + lastQueryTimestamp, + urlExpirationTimestamp ) } diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala index 994cbeb16..399097bed 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaCDFRelation.scala @@ -57,8 +57,18 @@ case class RemoteDeltaCDFRelation( false, () => { val d = client.getCDFFiles(table, cdfOptions, false) - DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles) - }).rdd + ( + DeltaSharingCDFReader.getIdToUrl(d.addFiles, d.cdfFiles, d.removeFiles), + DeltaSharingCDFReader.getMinUrlExpiration(d.addFiles, d.cdfFiles, d.removeFiles) + ) + }, + System.currentTimeMillis(), + DeltaSharingCDFReader.getMinUrlExpiration( + deltaTabelFiles.addFiles, + deltaTabelFiles.cdfFiles, + deltaTabelFiles.removeFiles + ) + ).rdd } } @@ -71,8 +81,9 @@ object DeltaSharingCDFReader { removeFiles: Seq[RemoveFile], schema: StructType, isStreaming: Boolean, - refresher: () => Map[String, String], - lastQueryTableTimestamp: Long = System.currentTimeMillis() + refresher: () => (Map[String, String], Option[Long]), + lastQueryTableTimestamp: Long, + expirationTimestamp: Option[Long] ): DataFrame = { val dfs = ListBuffer[DataFrame]() val refs = ListBuffer[WeakReference[AnyRef]]() @@ -95,7 +106,11 @@ object DeltaSharingCDFReader { refs, params.profileProvider, refresher, - 
lastQueryTableTimestamp + if (expirationTimestamp.isDefined) { + expirationTimestamp.get + } else { + lastQueryTableTimestamp + CachedTableManager.INSTANCE.preSignedUrlExpirationMs + } ) dfs.reduce((df1, df2) => df1.unionAll(df2)) @@ -111,6 +126,49 @@ object DeltaSharingCDFReader { removeFiles.map(r => r.id -> r.url).toMap } + // Get the minimum url expiration time across all the cdf files returned from the server. + def getMinUrlExpiration( + addFiles: Seq[AddFileForCDF], + cdfFiles: Seq[AddCDCFile], + removeFiles: Seq[RemoveFile] + ): Option[Long] = { + var minUrlExpiration: Option[Long] = None + addFiles.foreach { a => + if (a.expirationTimestamp != null) { + minUrlExpiration = if ( + minUrlExpiration.isDefined && minUrlExpiration.get < a.expirationTimestamp) { + minUrlExpiration + } else { + Some(a.expirationTimestamp) + } + } + } + cdfFiles.foreach { c => + if (c.expirationTimestamp != null) { + minUrlExpiration = if ( + minUrlExpiration.isDefined && minUrlExpiration.get < c.expirationTimestamp) { + minUrlExpiration + } else { + Some(c.expirationTimestamp) + } + } + } + removeFiles.foreach { r => + if (r.expirationTimestamp != null) { + minUrlExpiration = if ( + minUrlExpiration.isDefined && minUrlExpiration.get < r.expirationTimestamp) { + minUrlExpiration + } else { + Some(r.expirationTimestamp) + } + } + } + if (!CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpiration)) { + minUrlExpiration = None + } + minUrlExpiration + } + private def quoteIdentifier(part: String): String = s"`${part.replace("`", "``")}`" /** diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala index 10464459a..befdc5d90 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaFileIndex.scala @@ -184,7 +184,8 @@ private[sharing] case class RemoteDeltaCDFAddFileIndex( // 
makePartitionDirectories and partitionSchema. So that partitionFilters can be correctly // applied. val updatedFiles = addFiles.map { a => - AddFileForCDF(a.url, a.id, a.getPartitionValuesInDF, a.size, a.version, a.timestamp, a.stats) + AddFileForCDF(a.url, a.id, a.getPartitionValuesInDF, a.size, a.version, a.timestamp, a.stats, + a.expirationTimestamp) } val columnFilter = getColumnFilter(partitionFilters) val implicits = params.spark.implicits @@ -208,7 +209,8 @@ private[sharing] case class RemoteDeltaCDCFileIndex( // makePartitionDirectories and partitionSchema. So that partitionFilters can be correctly // applied. val updatedFiles = cdfFiles.map { c => - AddCDCFile(c.url, c.id, c.getPartitionValuesInDF, c.size, c.version, c.timestamp) + AddCDCFile(c.url, c.id, c.getPartitionValuesInDF, c.size, c.version, c.timestamp, + c.expirationTimestamp) } val columnFilter = getColumnFilter(partitionFilters) val implicits = params.spark.implicits @@ -231,7 +233,8 @@ private[sharing] case class RemoteDeltaCDFRemoveFileIndex( // makePartitionDirectories and partitionSchema. So that partitionFilters can be correctly // applied. 
val updatedFiles = removeFiles.map { r => - RemoveFile(r.url, r.id, r.getPartitionValuesInDF, r.size, r.version, r.timestamp) + RemoveFile(r.url, r.id, r.getPartitionValuesInDF, r.size, r.version, r.timestamp, + r.expirationTimestamp) } val columnFilter = getColumnFilter(partitionFilters) val implicits = params.spark.implicits diff --git a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala index 8634eb018..f1a8f29fb 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/RemoteDeltaLog.scala @@ -285,7 +285,16 @@ class RemoteSnapshot( val tableFiles = client.getFiles( table, predicates, limitHint, versionAsOf, timestampAsOf, jsonPredicateHints ) + var minUrlExpirationTimestamp: Option[Long] = None val idToUrl = tableFiles.files.map { file => + if (file.expirationTimestamp != null) { + minUrlExpirationTimestamp = if (minUrlExpirationTimestamp.isDefined && + minUrlExpirationTimestamp.get < file.expirationTimestamp) { + minUrlExpirationTimestamp + } else { + Some(file.expirationTimestamp) + } + } file.id -> file.url }.toMap CachedTableManager.INSTANCE @@ -295,10 +304,26 @@ class RemoteSnapshot( Seq(new WeakReference(fileIndex)), fileIndex.params.profileProvider, () => { - client.getFiles(table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints) - .files.map { add => + val files = client.getFiles( + table, Nil, None, versionAsOf, timestampAsOf, jsonPredicateHints).files + var minUrlExpiration: Option[Long] = None + val idToUrl = files.map { add => + if (add.expirationTimestamp != null) { + minUrlExpiration = if (minUrlExpiration.isDefined + && minUrlExpiration.get < add.expirationTimestamp) { + minUrlExpiration + } else { + Some(add.expirationTimestamp) + } + } add.id -> add.url }.toMap + (idToUrl, minUrlExpiration) + }, + if (CachedTableManager.INSTANCE.isValidUrlExpirationTime(minUrlExpirationTimestamp)) { + 
minUrlExpirationTimestamp.get + } else { + System.currentTimeMillis() + CachedTableManager.INSTANCE.preSignedUrlExpirationMs } ) checkProtocolNotChange(tableFiles.protocol) diff --git a/spark/src/main/scala/io/delta/sharing/spark/model.scala b/spark/src/main/scala/io/delta/sharing/spark/model.scala index 304ee95df..3092e5659 100644 --- a/spark/src/main/scala/io/delta/sharing/spark/model.scala +++ b/spark/src/main/scala/io/delta/sharing/spark/model.scala @@ -135,7 +135,8 @@ private[sharing] case class AddFile( @JsonRawValue stats: String = null, version: java.lang.Long = null, - timestamp: java.lang.Long = null) extends FileAction(url, id, partitionValues, size) { + timestamp: java.lang.Long = null, + expirationTimestamp: java.lang.Long = null) extends FileAction(url, id, partitionValues, size) { override def wrap: SingleAction = SingleAction(file = this) } @@ -149,7 +150,8 @@ private[sharing] case class AddFileForCDF( version: Long, timestamp: Long, @JsonRawValue - stats: String = null) extends FileAction(url, id, partitionValues, size) { + stats: String = null, + expirationTimestamp: java.lang.Long = null) extends FileAction(url, id, partitionValues, size) { override def wrap: SingleAction = SingleAction(add = this) @@ -170,7 +172,8 @@ private[sharing] case class AddCDCFile( override val partitionValues: Map[String, String], override val size: Long, version: Long, - timestamp: Long) extends FileAction(url, id, partitionValues, size) { + timestamp: Long, + expirationTimestamp: java.lang.Long = null) extends FileAction(url, id, partitionValues, size) { override def wrap: SingleAction = SingleAction(cdf = this) @@ -190,7 +193,8 @@ private[sharing] case class RemoveFile( override val partitionValues: Map[String, String], override val size: Long, version: Long, - timestamp: Long) extends FileAction(url, id, partitionValues, size) { + timestamp: Long, + expirationTimestamp: java.lang.Long = null) extends FileAction(url, id, partitionValues, size) { override def wrap: 
SingleAction = SingleAction(remove = this) diff --git a/spark/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala b/spark/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala index 2e52128cb..a3230bfb9 100644 --- a/spark/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala +++ b/spark/src/main/scala/org/apache/spark/delta/sharing/PreSignedUrlCache.scala @@ -35,14 +35,15 @@ import io.delta.sharing.spark.DeltaSharingProfileProvider * remove the cached table from our cache. * @param lastAccess When the table was accessed last time. We will remove old tables that are not * accessed after `expireAfterAccessMs` milliseconds. - * @param refresher the function to generate a new file id to pre sign url map. + * @param refresher the function to generate a new file id to pre sign url map, as long as the new + * expiration timestamp of the urls. */ class CachedTable( val expiration: Long, val idToUrl: Map[String, String], val refs: Seq[WeakReference[AnyRef]], @volatile var lastAccess: Long, - val refresher: () => Map[String, String]) + val refresher: () => (Map[String, String], Option[Long])) class CachedTableManager( val preSignedUrlExpirationMs: Long, @@ -63,6 +64,20 @@ class CachedTableManager( thread } + def isValidUrlExpirationTime(expiration: Option[Long]): Boolean = { + // refreshThresholdMs is the buffer time for the refresh RPC. + // It could also help the client from keeping refreshing endlessly. 
+ val isValid = expiration.isDefined && ( + expiration.get > (System.currentTimeMillis() + refreshThresholdMs)) + if (!isValid && expiration.isDefined) { + val currentTs = System.currentTimeMillis() + logWarning(s"Invalid url expiration timestamp(${expiration}, " + + s"${new java.util.Date(expiration.get)}), refreshThresholdMs:$refreshThresholdMs, " + + s"current timestamp(${currentTs}, ${new java.util.Date(currentTs)}).") + } + isValid + } + def refresh(): Unit = { import scala.collection.JavaConverters._ val snapshot = cache.entrySet().asScala.toArray @@ -81,9 +96,14 @@ class CachedTableManager( logInfo(s"Updating pre signed urls for $tablePath (expiration time: " + s"${new java.util.Date(cachedTable.expiration)})") try { + val (idToUrl, expOpt) = cachedTable.refresher() val newTable = new CachedTable( - preSignedUrlExpirationMs + System.currentTimeMillis(), - cachedTable.refresher(), + if (isValidUrlExpirationTime(expOpt)) { + expOpt.get + } else { + preSignedUrlExpirationMs + System.currentTimeMillis() + }, + idToUrl, cachedTable.refs, cachedTable.lastAccess, cachedTable.refresher @@ -133,36 +153,36 @@ class CachedTableManager( * still needed. When all the weak references return null, we will remove the pre * signed url cache of this table form the cache. * @param profileProvider a profile Provider that can provide customized refresher function. - * @param refresher A function to re-generate pre signed urls for the table. - * @param lastQueryTableTimestamp A timestamp to indicate the last time the idToUrl mapping is - * generated, to refresh the urls in time based on it. + * @param refresher A function to re-generate pre signed urls for the table. + * @param expirationTimestamp Optional, If set, it's a timestamp to indicate the expiration + * timestamp of the idToUrl. 
*/ def register( tablePath: String, idToUrl: Map[String, String], refs: Seq[WeakReference[AnyRef]], profileProvider: DeltaSharingProfileProvider, - refresher: () => Map[String, String], - lastQueryTableTimestamp: Long = System.currentTimeMillis()): Unit = { + refresher: () => (Map[String, String], Option[Long]), + expirationTimestamp: Long = System.currentTimeMillis() + preSignedUrlExpirationMs + ): Unit = { val customTablePath = profileProvider.getCustomTablePath(tablePath) val customRefresher = profileProvider.getCustomRefresher(refresher) - val cachedTable = new CachedTable( - if (preSignedUrlExpirationMs + lastQueryTableTimestamp - System.currentTimeMillis() < - refreshThresholdMs) { - // If there is a refresh, start counting from now. - preSignedUrlExpirationMs + System.currentTimeMillis() - } else { - // Otherwise, start counting from lastQueryTableTimestamp. - preSignedUrlExpirationMs + lastQueryTableTimestamp - }, - idToUrl = if (preSignedUrlExpirationMs + lastQueryTableTimestamp - System.currentTimeMillis() - < refreshThresholdMs) { - // force a refresh upon register - customRefresher() + val (resolvedIdToUrl, resolvedExpiration) = + if (expirationTimestamp - System.currentTimeMillis() < refreshThresholdMs) { + val (refreshedIdToUrl, expOpt) = customRefresher() + if (isValidUrlExpirationTime(expOpt)) { + (refreshedIdToUrl, expOpt.get) + } else { + (refreshedIdToUrl, System.currentTimeMillis() + preSignedUrlExpirationMs) + } } else { - idToUrl - }, + (idToUrl, expirationTimestamp) + } + + val cachedTable = new CachedTable( + resolvedExpiration, + idToUrl = resolvedIdToUrl, refs, System.currentTimeMillis(), customRefresher diff --git a/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingRestClientSuite.scala b/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingRestClientSuite.scala index df25dfc4f..8e4dadde6 100644 --- a/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingRestClientSuite.scala +++ 
b/spark/src/test/scala/io/delta/sharing/spark/DeltaSharingRestClientSuite.scala @@ -164,6 +164,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { val expectedFiles = Seq( AddFile( url = tableFiles.files(0).url, + expirationTimestamp = tableFiles.files(0).expirationTimestamp, id = "9f1a49539c5cffe1ea7f9e055d5c003c", partitionValues = Map("date" -> "2021-04-28"), size = 573, @@ -171,6 +172,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFile( url = tableFiles.files(1).url, + expirationTimestamp = tableFiles.files(1).expirationTimestamp, id = "cd2209b32f5ed5305922dd50f5908a75", partitionValues = Map("date" -> "2021-04-28"), size = 573, @@ -178,6 +180,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ) ) assert(expectedFiles == tableFiles.files.toList) + assert(tableFiles.files(0).expirationTimestamp > System.currentTimeMillis()) } finally { client.close() } @@ -198,6 +201,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { val expectedFiles = Seq( AddFile( url = tableFiles.files(0).url, + expirationTimestamp = tableFiles.files(0).expirationTimestamp, id = "60d0cf57f3e4367db154aa2c36152a1f", partitionValues = Map.empty, size = 1030, @@ -207,6 +211,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFile( url = tableFiles.files(1).url, + expirationTimestamp = tableFiles.files(1).expirationTimestamp, id = "d7ed708546dd70fdff9191b3e3d6448b", partitionValues = Map.empty, size = 1030, @@ -216,6 +221,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFile( url = tableFiles.files(2).url, + expirationTimestamp = tableFiles.files(2).expirationTimestamp, id = "a6dc5694a4ebcc9a067b19c348526ad6", partitionValues = Map.empty, size = 1030, @@ -225,6 +231,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ) ) assert(expectedFiles == tableFiles.files.toList) + 
assert(tableFiles.files(0).expirationTimestamp > System.currentTimeMillis()) } finally { client.close() } @@ -301,6 +308,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { val expectedAddFiles = Seq( AddFileForCDF( url = tableFiles.addFiles(0).url, + expirationTimestamp = tableFiles.addFiles(0).expirationTimestamp, id = "60d0cf57f3e4367db154aa2c36152a1f", partitionValues = Map.empty, size = 1030, @@ -310,6 +318,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFileForCDF( url = tableFiles.addFiles(1).url, + expirationTimestamp = tableFiles.addFiles(1).expirationTimestamp, id = "a6dc5694a4ebcc9a067b19c348526ad6", partitionValues = Map.empty, size = 1030, @@ -319,6 +328,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFileForCDF( url = tableFiles.addFiles(2).url, + expirationTimestamp = tableFiles.addFiles(2).expirationTimestamp, id = "d7ed708546dd70fdff9191b3e3d6448b", partitionValues = Map.empty, size = 1030, @@ -328,6 +338,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFileForCDF( url = tableFiles.addFiles(3).url, + expirationTimestamp = tableFiles.addFiles(3).expirationTimestamp, id = "b875623be22c1fa1dfdeb0480fae6117", partitionValues = Map.empty, size = 1247, @@ -337,11 +348,13 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ) ) assert(expectedAddFiles == tableFiles.addFiles.toList) + assert(tableFiles.addFiles(0).expirationTimestamp > System.currentTimeMillis()) assert(tableFiles.removeFiles.size == 2) val expectedRemoveFiles = Seq( RemoveFile( url = tableFiles.removeFiles(0).url, + expirationTimestamp = tableFiles.removeFiles(0).expirationTimestamp, id = "d7ed708546dd70fdff9191b3e3d6448b", partitionValues = Map.empty, size = 1030, @@ -350,6 +363,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), RemoveFile( url = tableFiles.removeFiles(1).url, + expirationTimestamp = 
tableFiles.removeFiles(1).expirationTimestamp, id = "a6dc5694a4ebcc9a067b19c348526ad6", partitionValues = Map.empty, size = 1030, @@ -358,6 +372,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ) ) assert(expectedRemoveFiles == tableFiles.removeFiles.toList) + assert(tableFiles.removeFiles(0).expirationTimestamp > System.currentTimeMillis()) assert(tableFiles.additionalMetadatas.size == 2) val v4Metadata = Metadata( @@ -390,6 +405,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { val expectedAddFiles = Seq( AddFileForCDF( url = tableFiles.addFiles(0).url, + expirationTimestamp = tableFiles.addFiles(0).expirationTimestamp, id = "60d0cf57f3e4367db154aa2c36152a1f", partitionValues = Map.empty, size = 1030, @@ -399,6 +415,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFileForCDF( url = tableFiles.addFiles(1).url, + expirationTimestamp = tableFiles.addFiles(1).expirationTimestamp, id = "a6dc5694a4ebcc9a067b19c348526ad6", partitionValues = Map.empty, size = 1030, @@ -408,6 +425,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFileForCDF( url = tableFiles.addFiles(2).url, + expirationTimestamp = tableFiles.addFiles(2).expirationTimestamp, id = "d7ed708546dd70fdff9191b3e3d6448b", partitionValues = Map.empty, size = 1030, @@ -433,6 +451,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { val expectedAddFiles = Seq( AddFileForCDF( url = tableFiles.addFiles(0).url, + expirationTimestamp = tableFiles.addFiles(0).expirationTimestamp, id = "b875623be22c1fa1dfdeb0480fae6117", partitionValues = Map.empty, size = 1247, @@ -447,6 +466,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { val expectedRemoveFiles = Seq( RemoveFile( url = tableFiles.removeFiles(0).url, + expirationTimestamp = tableFiles.removeFiles(0).expirationTimestamp, id = "d7ed708546dd70fdff9191b3e3d6448b", partitionValues = Map.empty, size = 
1030, @@ -455,6 +475,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), RemoveFile( url = tableFiles.removeFiles(1).url, + expirationTimestamp = tableFiles.removeFiles(1).expirationTimestamp, id = "a6dc5694a4ebcc9a067b19c348526ad6", partitionValues = Map.empty, size = 1030, @@ -550,6 +571,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { val expectedCdfFiles = Seq( AddCDCFile( url = tableFiles.cdfFiles(0).url, + expirationTimestamp = tableFiles.cdfFiles(0).expirationTimestamp, id = "6521ba910108d4b54d27beaa9fc2373f", partitionValues = Map.empty, size = 1301, @@ -558,6 +580,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddCDCFile( url = tableFiles.cdfFiles(1).url, + expirationTimestamp = tableFiles.cdfFiles(1).expirationTimestamp, id = "2508998dce55bd726369e53761c4bc3f", partitionValues = Map.empty, size = 1416, @@ -566,10 +589,13 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ) ) assert(expectedCdfFiles == tableFiles.cdfFiles.toList) + assert(tableFiles.cdfFiles(1).expirationTimestamp > System.currentTimeMillis()) + assert(tableFiles.addFiles.size == 3) val expectedAddFiles = Seq( AddFileForCDF( url = tableFiles.addFiles(0).url, + expirationTimestamp = tableFiles.addFiles(0).expirationTimestamp, id = "60d0cf57f3e4367db154aa2c36152a1f", partitionValues = Map.empty, size = 1030, @@ -579,6 +605,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFileForCDF( url = tableFiles.addFiles(1).url, + expirationTimestamp = tableFiles.addFiles(1).expirationTimestamp, id = "a6dc5694a4ebcc9a067b19c348526ad6", partitionValues = Map.empty, size = 1030, @@ -588,6 +615,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ), AddFileForCDF( url = tableFiles.addFiles(2).url, + expirationTimestamp = tableFiles.addFiles(2).expirationTimestamp, id = "d7ed708546dd70fdff9191b3e3d6448b", partitionValues = Map.empty, size = 
1030, @@ -597,6 +625,7 @@ class DeltaSharingRestClientSuite extends DeltaSharingIntegrationTest { ) ) assert(expectedAddFiles == tableFiles.addFiles.toList) + assert(tableFiles.addFiles(0).expirationTimestamp > System.currentTimeMillis()) } finally { client.close() } diff --git a/spark/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala b/spark/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala index 51c9be90c..8c0ad4013 100644 --- a/spark/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala +++ b/spark/src/test/scala/org/apache/spark/delta/sharing/CachedTableManagerSuite.scala @@ -42,7 +42,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(ref)), provider, () => { - Map("id1" -> "url1", "id2" -> "url2") + (Map("id1" -> "url1", "id2" -> "url2"), None) }) assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), "id1")._1 == "url1") @@ -55,7 +55,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(ref)), provider, () => { - Map("id1" -> "url3", "id2" -> "url4") + (Map("id1" -> "url3", "id2" -> "url4"), None) }) // We should get the new urls eventually eventually(timeout(10.seconds)) { @@ -71,7 +71,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(new AnyRef)), provider, () => { - Map("id1" -> "url3", "id2" -> "url4") + (Map("id1" -> "url3", "id2" -> "url4"), None) }) // We should remove the cached table eventually eventually(timeout(10.seconds)) { @@ -88,7 +88,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(ref)), provider, () => { - Map("id1" -> "url3", "id2" -> "url4") + (Map("id1" -> "url3", "id2" -> "url4"), None) }, -1 ) @@ -102,6 +102,95 @@ class CachedTableManagerSuite extends SparkFunSuite { } } + test("refresh based on url expiration") { + val manager = new CachedTableManager( + preSignedUrlExpirationMs = 6000, + refreshCheckIntervalMs = 1000, + 
refreshThresholdMs = 1000, + expireAfterAccessMs = 60000 + ) + try { + val ref = new AnyRef + val provider = new TestDeltaSharingProfileProvider + var refreshTime = 0 + manager.register( + "test-table-path", + Map("id1" -> "url1", "id2" -> "url2"), + Seq(new WeakReference(ref)), + provider, + () => { + refreshTime += 1 + ( + Map("id1" -> ("url" + refreshTime.toString), "id2" -> "url4"), + Some(System.currentTimeMillis() + 1900) + ) + }, + System.currentTimeMillis() + 1900 + ) + // We should refresh at least 5 times within 10 seconds based on + // (System.currentTimeMillis() + 1900). + eventually(timeout(10.seconds)) { + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), + "id1")._1 == "url5") + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path"), + "id2")._1 == "url4") + } + + var refreshTime2 = 0 + manager.register( + "test-table-path2", + Map("id1" -> "url1", "id2" -> "url2"), + Seq(new WeakReference(ref)), + provider, + () => { + refreshTime2 += 1 + ( + Map("id1" -> ("url" + refreshTime2.toString), "id2" -> "url4"), + Some(System.currentTimeMillis() + 4900) + ) + }, + System.currentTimeMillis() + 4900 + ) + // We should refresh 2 times within 10 seconds based on (System.currentTimeMillis() + 4900). + eventually(timeout(10.seconds)) { + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path2"), + "id1")._1 == "url2") + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path2"), + "id2")._1 == "url4") + } + + var refreshTime3 = 0 + manager.register( + "test-table-path3", + Map("id1" -> "url1", "id2" -> "url2"), + Seq(new WeakReference(ref)), + provider, + () => { + refreshTime3 += 1 + ( + Map("id1" -> ("url" + refreshTime3.toString), "id2" -> "url4"), + Some(System.currentTimeMillis() - 4900) + ) + }, + System.currentTimeMillis() + 6000 + ) + // We should refresh 1 time within 10 seconds based on (preSignedUrlExpirationMs = 6000). 
+ try { + eventually(timeout(10.seconds)) { + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path3"), + "id1")._1 == "url2") + assert(manager.getPreSignedUrl(provider.getCustomTablePath("test-table-path3"), + "id2")._1 == "url4") + } + } catch { + case e: Throwable => + assert(e.getMessage.contains("did not equal")) + } + } finally { + manager.stop() + } + } + test("expireAfterAccessMs") { val manager = new CachedTableManager( preSignedUrlExpirationMs = 10, @@ -119,7 +208,7 @@ class CachedTableManagerSuite extends SparkFunSuite { Seq(new WeakReference(ref)), provider, () => { - Map("id1" -> "url1", "id2" -> "url2") + (Map("id1" -> "url1", "id2" -> "url2"), None) }) Thread.sleep(1000) // We should remove the cached table when it's not accessed