diff --git a/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala b/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala index b187106de..2135beb0d 100644 --- a/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala +++ b/server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala @@ -17,6 +17,7 @@ package io.delta.sharing.server import java.io.{ByteArrayOutputStream, File, FileNotFoundException} +import java.lang.reflect.Method import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.AccessDeniedException import java.security.MessageDigest @@ -31,7 +32,9 @@ import com.linecorp.armeria.common.auth.OAuth2Token import com.linecorp.armeria.internal.server.ResponseConversionUtil import com.linecorp.armeria.server.{Server, ServiceRequestContext} import com.linecorp.armeria.server.annotation.{ConsumesJson, Default, ExceptionHandler, ExceptionHandlerFunction, Get, Head, Param, Post, ProducesJson} +import com.linecorp.armeria.server.annotation.decorator.CorsDecorator import com.linecorp.armeria.server.auth.AuthService +import com.linecorp.armeria.server.cors.CorsService import io.delta.standalone.internal.DeltaCDFErrors import io.delta.standalone.internal.DeltaCDFIllegalArgumentException import io.delta.standalone.internal.DeltaDataSource @@ -42,7 +45,7 @@ import org.slf4j.LoggerFactory import scalapb.json4s.Printer import io.delta.sharing.server.config.ServerConfig -import io.delta.sharing.server.model.SingleAction +import io.delta.sharing.server.model.{ AddCDCFile, AddFile, AddFileForCDF, RemoveFile, SingleAction } import io.delta.sharing.server.protocol._ import io.delta.sharing.server.util.JsonUtils @@ -62,9 +65,9 @@ class DeltaSharingServiceExceptionHandler extends ExceptionHandlerFunction { private val logger = LoggerFactory.getLogger(classOf[DeltaSharingServiceExceptionHandler]) override def handleException( - ctx: ServiceRequestContext, - req: HttpRequest, - cause: Throwable): HttpResponse = { + ctx: ServiceRequestContext, + req: HttpRequest, + cause: Throwable): HttpResponse = { cause match { // Handle exceptions caused by incorrect requests case _: DeltaSharingNoSuchElementException => @@ -128,7 +131,7 @@ class DeltaSharingServiceExceptionHandler extends ExceptionHandlerFunction { // // valid json but may not be incorect field type case (_: scalapb.json4s.JsonFormatException | - // invalid json + // invalid json _: com.fasterxml.jackson.databind.JsonMappingException) => HttpResponse.of( HttpStatus.BAD_REQUEST, @@ -190,24 +193,31 @@ class DeltaSharingService(serverConfig: ServerConfig) { @Get("/shares") @ProducesJson def listShares( - @Param("maxResults") @Default("500") maxResults: Int, - @Param("pageToken") @Nullable pageToken: String): ListSharesResponse = processRequest { + serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, + @Param("maxResults") @Default("500") maxResults: Int, + @Param("pageToken") @Nullable pageToken: String): ListSharesResponse = processRequest { val (shares, nextPageToken) = sharedTableManager.listShares(Option(pageToken), Some(maxResults)) ListSharesResponse(shares, nextPageToken) } @Get("/shares/{share}") @ProducesJson - def getShare(@Param("share") share: String): GetShareResponse = processRequest { + def getShare( + serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, + @Param("share") share: String): GetShareResponse = processRequest { GetShareResponse(share = Some(sharedTableManager.getShare(share))) } @Get("/shares/{share}/schemas") @ProducesJson def listSchemas( - @Param("share") share: String, - @Param("maxResults") @Default("500") maxResults: Int, - @Param("pageToken") @Nullable pageToken: String): ListSchemasResponse = processRequest { + serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, + @Param("share") share: String, + @Param("maxResults") @Default("500") maxResults: Int, + @Param("pageToken") @Nullable pageToken: String): ListSchemasResponse = processRequest { val (schemas, nextPageToken) = sharedTableManager.listSchemas(share, Option(pageToken), Some(maxResults)) ListSchemasResponse(schemas, nextPageToken) @@ -216,10 +226,12 @@ class DeltaSharingService(serverConfig: ServerConfig) { @Get("/shares/{share}/schemas/{schema}/tables") @ProducesJson def listTables( - @Param("share") share: String, - @Param("schema") schema: String, - @Param("maxResults") @Default("500") maxResults: Int, - @Param("pageToken") @Nullable pageToken: String): ListTablesResponse = processRequest { + serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, + @Param("share") share: String, + @Param("schema") schema: String, + @Param("maxResults") @Default("500") maxResults: Int, + @Param("pageToken") @Nullable pageToken: String): ListTablesResponse = processRequest { val (tables, nextPageToken) = sharedTableManager.listTables(share, schema, Option(pageToken), Some(maxResults)) ListTablesResponse(tables, nextPageToken) @@ -228,22 +240,22 @@ class DeltaSharingService(serverConfig: ServerConfig) { @Get("/shares/{share}/all-tables") @ProducesJson def listAllTables( - @Param("share") share: String, - @Param("maxResults") @Default("500") maxResults: Int, - @Param("pageToken") @Nullable pageToken: String): ListAllTablesResponse = processRequest { + serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, + @Param("share") share: String, + @Param("maxResults") @Default("500") maxResults: Int, + @Param("pageToken") @Nullable pageToken: String): ListAllTablesResponse = processRequest { val (tables, nextPageToken) = sharedTableManager.listAllTables(share, Option(pageToken), Some(maxResults)) ListAllTablesResponse(tables, nextPageToken) } - private def createHeadersBuilderForTableVersion(version: Long): ResponseHeadersBuilder = { - ResponseHeaders.builder(200).set(DELTA_TABLE_VERSION_HEADER, version.toString) - } - // TODO: deprecate HEAD request in favor of the GET request @Head("/shares/{share}/schemas/{schema}/tables/{table}") @Get("/shares/{share}/schemas/{schema}/tables/{table}/version") def getTableVersion( + serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, @Param("share") share: String, @Param("schema") schema: String, @Param("table") table: String, @@ -261,18 +273,31 @@ class DeltaSharingService(serverConfig: ServerConfig) { if (startingTimestamp != null && version < tableConfig.startVersion) { throw new DeltaSharingIllegalArgumentException( s"You can only query table data since version ${tableConfig.startVersion}." + - s"The provided timestamp($startingTimestamp) corresponds to $version." + s"The provided timestamp($startingTimestamp) corresponds to $version." ) } - val headers = createHeadersBuilderForTableVersion(version).build() + val headersBuilder = ResponseHeaders.builder(HttpStatus.OK.code) + + val corsService = CorsService.builder(serverConfig.getHost) + val setCorsResponseHeadersMethod = corsService.getClass + .getDeclaredMethod("setCorsResponseHeaders", + classOf[ServiceRequestContext], classOf[HttpRequest], classOf[ResponseHeadersBuilder]) + setCorsResponseHeadersMethod.setAccessible(true) + setCorsResponseHeadersMethod + .invoke(corsService, serviceRequestContext, httpRequest, headersBuilder) + + val headers = headersBuilder.set(DELTA_TABLE_VERSION_HEADER, version.toString) + .build() HttpResponse.of(headers) } @Get("/shares/{share}/schemas/{schema}/tables/{table}/metadata") def getMetadata( - @Param("share") share: String, - @Param("schema") schema: String, - @Param("table") table: String): HttpResponse = processRequest { + serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, + @Param("share") share: String, + @Param("schema") schema: String, + @Param("table") table: String): HttpResponse = processRequest { import scala.collection.JavaConverters._ val tableConfig = sharedTableManager.getTable(share, schema, table) val (v, actions) = deltaSharedTableLoader.loadTable(tableConfig).query( @@ -283,16 +308,18 @@ class DeltaSharingService(serverConfig: ServerConfig) { version = None, timestamp = None, startingVersion = None) - streamingOutput(Some(v), actions) + streamingOutput(serviceRequestContext, httpRequest, Some(v), actions) } @Post("/shares/{share}/schemas/{schema}/tables/{table}/query") @ConsumesJson def listFiles( - @Param("share") share: String, - @Param("schema") schema: String, - @Param("table") table: String, - request: QueryTableRequest): HttpResponse = processRequest { + serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, + @Param("share") share: String, + @Param("schema") schema: String, + @Param("table") table: String, + request: QueryTableRequest): HttpResponse = processRequest { val numVersionParams = Seq(request.version, request.timestamp, request.startingVersion) .filter(_.isDefined).size if (numVersionParams > 1) { @@ -337,12 +364,14 @@ class DeltaSharingService(serverConfig: ServerConfig) { } logger.info(s"Took ${System.currentTimeMillis - start} ms to load the table " + s"and sign ${actions.length - 2} urls for table $share/$schema/$table") - streamingOutput(Some(version), actions) + streamingOutput(serviceRequestContext, httpRequest, Some(version), actions) } @Get("/shares/{share}/schemas/{schema}/tables/{table}/changes") @ConsumesJson def listCdfFiles( + serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, @Param("share") share: String, @Param("schema") schema: String, @Param("table") table: String, @@ -351,7 +380,7 @@ class DeltaSharingService(serverConfig: ServerConfig) { @Param("startingTimestamp") @Nullable startingTimestamp: String, @Param("endingTimestamp") @Nullable endingTimestamp: String, @Param("includeHistoricalMetadata") @Nullable includeHistoricalMetadata: String - ): HttpResponse = processRequest { + ): HttpResponse = processRequest { val start = System.currentTimeMillis val tableConfig = sharedTableManager.getTable(share, schema, table) if (!tableConfig.cdfEnabled) { @@ -370,19 +399,55 @@ class DeltaSharingService(serverConfig: ServerConfig) { ) logger.info(s"Took ${System.currentTimeMillis - start} ms to load the table cdf " + s"and sign ${actions.length - 2} urls for table $share/$schema/$table") - streamingOutput(Some(v), actions) + streamingOutput(serviceRequestContext, httpRequest, Some(v), actions) } - private def streamingOutput(version: Option[Long], actions: Seq[SingleAction]): HttpResponse = { - val headers = if (version.isDefined) { - createHeadersBuilderForTableVersion(version.get) - .set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE) - .build() - } else { - ResponseHeaders.builder(200) - .set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE) - .build() + private def streamingOutput(serviceRequestContext: ServiceRequestContext, + httpRequest: HttpRequest, + version: Option[Long], + actions: Seq[SingleAction]): HttpResponse = { + val headersBuilder = ResponseHeaders.builder(HttpStatus.OK.code) + .set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE); + if (actions.nonEmpty) { + val urls: Seq[String] = actions.map(((e: SingleAction) => { + val a = e.unwrap + a.getClass match { + case v if v == classOf[AddFile] => a.asInstanceOf[AddFile].url + case v if v == classOf[AddFileForCDF] => a.asInstanceOf[AddFileForCDF].url + case v if v == classOf[AddCDCFile] => a.asInstanceOf[AddCDCFile].url + case v if v == classOf[RemoveFile] => a.asInstanceOf[RemoveFile].url + case _ => null + } + }): (SingleAction => String)).filter(_ != null) + val corsUrls = (serverConfig.getHost +: urls) + val corsService = CorsService.builder(corsUrls: _*) + + /* From CorsService, the private method setCorsResponseHeaders is used to set the headers + via the builder and would be called if the library were properly utilized with static + service end-points. Since the short-lived URLs are dyanmic and change, the CORS headers + will change and need to be dynamic each time. Note this does not change the headers on + the Blob Storage that still need to allow the origin. The code exposes the method and + calls the method to allow the Delta Sharing Server to properly send back the expected + headers. For more information please see: + https://advancedweb.hu/how-to-solve-cors-problems-when-redirecting-to-s3-signed-urls/ + and + https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + + NOTE: The setCorsResponseHeadersMethod can be made into a field on the class to then + have only the method invoked during each call of streamingOutput. + */ + val setCorsResponseHeadersMethod = corsService.getClass + .getDeclaredMethod("setCorsResponseHeaders", + classOf[ServiceRequestContext], classOf[HttpRequest], classOf[ResponseHeadersBuilder]) + setCorsResponseHeadersMethod.setAccessible(true) + + setCorsResponseHeadersMethod + .invoke(corsService, serviceRequestContext, httpRequest, headersBuilder) } + + if (version.isDefined) headersBuilder.set(DELTA_TABLE_VERSION_HEADER, version.get.toString) + val headers = headersBuilder.build() + ResponseConversionUtil.streamingFrom( actions.asJava.stream(), headers, @@ -393,7 +458,7 @@ class DeltaSharingService(serverConfig: ServerConfig) { out.write('\n') HttpData.wrap(out.toByteArray) }, - ServiceRequestContext.current().blockingTaskExecutor()) + serviceRequestContext.blockingTaskExecutor()) } } @@ -476,10 +541,10 @@ object DeltaSharingService { } private def checkCDFOptionsValidity( - startingVersion: Option[String], - endingVersion: Option[String], - startingTimestamp: Option[String], - endingTimestamp: Option[String]): Unit = { + startingVersion: Option[String], + endingVersion: Option[String], + startingTimestamp: Option[String], + endingTimestamp: Option[String]): Unit = { // check if we have both version and timestamp parameters if (startingVersion.isDefined && startingTimestamp.isDefined) { throw DeltaCDFErrors.multipleCDFBoundary("starting") @@ -510,16 +575,16 @@ object DeltaSharingService { } private[server] def getCdfOptionsMap( - startingVersion: Option[String], - endingVersion: Option[String], - startingTimestamp: Option[String], - endingTimestamp: Option[String]): Map[String, String] = { + startingVersion: Option[String], + endingVersion: Option[String], + startingTimestamp: Option[String], + endingTimestamp: Option[String]): Map[String, String] = { checkCDFOptionsValidity(startingVersion, endingVersion, startingTimestamp, endingTimestamp) (startingVersion.map(DeltaDataSource.CDF_START_VERSION_KEY -> _) ++ - endingVersion.map(DeltaDataSource.CDF_END_VERSION_KEY -> _) ++ - startingTimestamp.map(DeltaDataSource.CDF_START_TIMESTAMP_KEY -> _) ++ - endingTimestamp.map(DeltaDataSource.CDF_END_TIMESTAMP_KEY -> _)).toMap + endingVersion.map(DeltaDataSource.CDF_END_VERSION_KEY -> _) ++ + startingTimestamp.map(DeltaDataSource.CDF_START_TIMESTAMP_KEY -> _) ++ + endingTimestamp.map(DeltaDataSource.CDF_END_TIMESTAMP_KEY -> _)).toMap } def main(args: Array[String]): Unit = {