diff --git a/core/src/main/scala/com.snowplowanalytics.snowplow.collector.core/HttpServer.scala b/core/src/main/scala/com.snowplowanalytics.snowplow.collector.core/HttpServer.scala index 55d4f173e..21c2d0da2 100644 --- a/core/src/main/scala/com.snowplowanalytics.snowplow.collector.core/HttpServer.scala +++ b/core/src/main/scala/com.snowplowanalytics.snowplow.collector.core/HttpServer.scala @@ -26,6 +26,8 @@ import org.typelevel.log4cats.slf4j.Slf4jLogger import java.net.InetSocketAddress import javax.net.ssl.SSLContext +import org.http4s.Response +import org.http4s.Status object HttpServer { @@ -40,12 +42,46 @@ object HttpServer { networking: Config.Networking, metricsConfig: Config.Metrics, debugHttp: Config.Debug.Http + )( + mkServer: ((HttpApp[F], Int, Boolean, Config.Networking) => Resource[F, Server]) ): Resource[F, Server] = for { withMetricsMiddleware <- createMetricsMiddleware(routes, metricsConfig) - server <- buildBlazeServer[F](withMetricsMiddleware, healthRoutes, port, secure, hsts, networking, debugHttp) + httpApp <- Resource.pure(httpApp(withMetricsMiddleware, healthRoutes, hsts, networking, debugHttp)) + server <- mkServer(httpApp, port, secure, networking) } yield server + def buildBlazeServer[F[_]: Async]( + httpApp: HttpApp[F], + port: Int, + secure: Boolean, + networking: Config.Networking + ): Resource[F, Server] = + Resource.eval(Logger[F].info("Building blaze server")) >> + BlazeServerBuilder[F] + .bindSocketAddress(new InetSocketAddress(port)) + .withHttpApp(httpApp) + .withIdleTimeout(networking.idleTimeout) + .withMaxConnections(networking.maxConnections) + .withResponseHeaderTimeout(networking.responseHeaderTimeout) + .withLengthLimits( + maxRequestLineLen = networking.maxRequestLineLength, + maxHeadersLen = networking.maxHeadersLength + ) + .cond(secure, _.withSslContext(SSLContext.getDefault)) + .resource + + def httpApp[F[_]: Async]( + routes: HttpRoutes[F], + healthRoutes: HttpRoutes[F], + hsts: Config.HSTS, + networking: Config.Networking, + debugHttp: Config.Debug.Http + ): HttpApp[F] = hstsApp( + hsts, + loggerMiddleware(timeoutMiddleware(routes, networking), debugHttp) <+> loggerMiddleware(healthRoutes, debugHttp) + ) + private def createMetricsMiddleware[F[_]: Async]( routes: HttpRoutes[F], metricsConfig: Config.Metrics @@ -81,36 +117,10 @@ object HttpServer { } else routes private def timeoutMiddleware[F[_]: Async](routes: HttpRoutes[F], networking: Config.Networking): HttpRoutes[F] = - Timeout.httpRoutes[F](timeout = networking.responseHeaderTimeout)(routes) - - private def buildBlazeServer[F[_]: Async]( - routes: HttpRoutes[F], - healthRoutes: HttpRoutes[F], - port: Int, - secure: Boolean, - hsts: Config.HSTS, - networking: Config.Networking, - debugHttp: Config.Debug.Http - ): Resource[F, Server] = - Resource.eval(Logger[F].info("Building blaze server")) >> - BlazeServerBuilder[F] - .bindSocketAddress(new InetSocketAddress(port)) - .withHttpApp( - hstsApp( - hsts, - loggerMiddleware(timeoutMiddleware(routes, networking), debugHttp) - <+> loggerMiddleware(healthRoutes, debugHttp) - ) - ) - .withIdleTimeout(networking.idleTimeout) - .withMaxConnections(networking.maxConnections) - .withResponseHeaderTimeout(networking.responseHeaderTimeout) - .withLengthLimits( - maxRequestLineLen = networking.maxRequestLineLength, - maxHeadersLen = networking.maxHeadersLength - ) - .cond(secure, _.withSslContext(SSLContext.getDefault)) - .resource + Timeout.httpRoutes[F](timeout = networking.responseHeaderTimeout)(routes).collect { + case Response(Status.ServiceUnavailable, httpVersion, headers, body, attributes) => + Response[F](Status.RequestTimeout, httpVersion, headers, body, attributes) + } implicit class ConditionalAction[A](item: A) { def cond(cond: Boolean, action: A => A): A = diff --git a/core/src/main/scala/com.snowplowanalytics.snowplow.collector.core/Run.scala b/core/src/main/scala/com.snowplowanalytics.snowplow.collector.core/Run.scala index 33f6b6a67..5eba83582 100644 --- a/core/src/main/scala/com.snowplowanalytics.snowplow.collector.core/Run.scala +++ b/core/src/main/scala/com.snowplowanalytics.snowplow.collector.core/Run.scala @@ -109,7 +109,7 @@ object Run { config.networking, config.monitoring.metrics, config.debug.http - ) + )(HttpServer.buildBlazeServer) _ <- withGracefulShutdown(config.preTerminationPeriod)(httpServer) httpClient <- BlazeClientBuilder[F].resource } yield httpClient diff --git a/core/src/test/scala/com.snowplowanalytics.snowplow.collector.core/HttpServerSpec.scala b/core/src/test/scala/com.snowplowanalytics.snowplow.collector.core/HttpServerSpec.scala new file mode 100644 index 000000000..5215edb30 --- /dev/null +++ b/core/src/test/scala/com.snowplowanalytics.snowplow.collector.core/HttpServerSpec.scala @@ -0,0 +1,52 @@ +package com.snowplowanalytics.snowplow.collector.core + +import org.specs2.mutable.Specification +import cats.effect.IO + +import org.http4s.client.Client +import org.http4s._ +import org.http4s.dsl.io._ +import org.http4s.implicits._ +import scala.concurrent.duration._ +import cats.effect.testing.specs2._ + +class HttpServerSpecification extends Specification with CatsEffect { + + "HttpServer" should { + "manage request timeout" should { + "timeout threshold is configured" in { + val config = + TestUtils + .testConfig + .copy(networking = TestUtils.testConfig.networking.copy(responseHeaderTimeout = 100.millis)) + val routes = HttpRoutes.of[IO] { + case _ -> Root / "fast" => + Ok("Fast") + case _ -> Root / "never" => + IO.never[Response[IO]] + } + val healthRoutes = HttpRoutes.of[IO] { + case _ -> Root / "health" => + Ok("ok") + } + val httpApp = HttpServer.httpApp( + routes, + healthRoutes, + config.hsts, + config.networking, + config.debug.http + ) + val client: Client[IO] = Client.fromHttpApp(httpApp) + val request: Request[IO] = Request(method = Method.GET, uri = uri"/never") + val res: IO[String] = client.expect[String](request) + + res + .attempt + .map(_ must beLeft[Throwable].which { + case org.http4s.client.UnexpectedStatus(Status.RequestTimeout, _, _) => true + case _ => false + }) + } + } + } +}