Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hard abort on resource closure if any part of the stream remains open #853

Open
wants to merge 4 commits into
base: series/0.9
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 129 additions & 102 deletions core/src/main/scala/org/http4s/jdkhttpclient/JdkWSClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import java.net.http.{WebSocket => JWebSocket}
import java.nio.ByteBuffer
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import scala.concurrent.duration.DurationInt

/** A `WSClient` wrapper for the JDK 11+ websocket client. It will reply to Pongs with Pings even in
* "low-level" mode. Custom (non-GET) HTTP methods are ignored.
Expand All @@ -47,115 +48,141 @@ object JdkWSClient {
jdkHttpClient: HttpClient
)(implicit F: Async[F]): WSClient[F] =
WSClient(respondToPings = false) { req =>
Dispatcher.sequential.flatMap { dispatcher =>
Resource
.make {
for {
wsBuilder <- F.delay {
val builder = jdkHttpClient.newWebSocketBuilder()
val (subprotocols, hs) = req.headers.headers.partitionEither {
case Header.Raw(ci"Sec-WebSocket-Protocol", p) => Left(p)
case h => Right(h)
Dispatcher
.sequential(
await = true
GrafBlutwurst marked this conversation as resolved.
Show resolved Hide resolved
)
.flatMap { dispatcher =>
Resource
.make {
for {
wsBuilder <- F.delay {
val builder = jdkHttpClient.newWebSocketBuilder()
val (subprotocols, hs) = req.headers.headers.partitionEither {
case Header.Raw(ci"Sec-WebSocket-Protocol", p) => Left(p)
case h => Right(h)
}
hs.foreach { h => builder.header(h.name.toString, h.value); () }
subprotocols match {
case head :: tail => builder.subprotocols(head, tail: _*)
case Nil =>
}
builder
}
hs.foreach { h => builder.header(h.name.toString, h.value); () }
subprotocols match {
case head :: tail => builder.subprotocols(head, tail: _*)
case Nil =>
queue <- Queue.unbounded[F, Either[Throwable, WSFrame]]
closedDef <- Deferred[F, Unit]
handleReceive =
(wsf: Either[Throwable, WSFrame]) =>
dispatcher.unsafeToCompletableFuture(
queue.offer(wsf) *> (wsf match {
case Left(_) | Right(_: WSFrame.Close) => closedDef.complete(()).void
case _ => F.unit
})
)
wsListener = new JWebSocket.Listener {
override def onOpen(webSocket: JWebSocket): Unit = ()
override def onClose(webSocket: JWebSocket, statusCode: Int, reason: String)
: CompletionStage[_] =
// The output side of this connection will be closed when the returned CompletionStage completes.
// Therefore, we return a never completing CompletionStage, so we can control when the output will
// be closed (as it is allowed to continue sending frames (as few as possible) after a close frame
// has been received).
handleReceive(WSFrame.Close(statusCode, reason).asRight)
.thenCompose[Nothing](_ => new CompletableFuture[Nothing])
override def onText(webSocket: JWebSocket, data: CharSequence, last: Boolean)
: CompletionStage[_] =
handleReceive(WSFrame.Text(data.toString, last).asRight)
override def onBinary(webSocket: JWebSocket, data: ByteBuffer, last: Boolean)
: CompletionStage[_] =
handleReceive(WSFrame.Binary(ByteVector(data), last).asRight)
override def onPing(webSocket: JWebSocket, message: ByteBuffer)
: CompletionStage[_] =
handleReceive(WSFrame.Ping(ByteVector(message)).asRight)
override def onPong(webSocket: JWebSocket, message: ByteBuffer)
: CompletionStage[_] =
handleReceive(WSFrame.Pong(ByteVector(message)).asRight)
override def onError(webSocket: JWebSocket, error: Throwable): Unit = {
handleReceive(error.asLeft); ()
}
}
builder
}
queue <- Queue.unbounded[F, Either[Throwable, WSFrame]]
closedDef <- Deferred[F, Unit]
handleReceive =
(wsf: Either[Throwable, WSFrame]) =>
dispatcher.unsafeToCompletableFuture(
queue.offer(wsf) *> (wsf match {
case Left(_) | Right(_: WSFrame.Close) => closedDef.complete(()).void
case _ => F.unit
})
webSocket <- F.fromCompletableFuture(
F.delay(wsBuilder.buildAsync(URI.create(req.uri.renderString), wsListener))
)
sendSem <- Semaphore[F](1L)
} yield (webSocket, queue, closedDef, sendSem)
} { case (webSocket, queue, closedDef, sendSem) =>
def awaitClose(step: Int): F[Unit] =
closedDef.tryGet.flatMap {
case Some(_) => F.unit
case None =>
if (step < 10) F.sleep(100.millis) *> awaitClose(step + 1) else F.unit
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder what a sensible amount of time to wait here is? should this be configurable? overloaded apply with a duration parameter to retain backwards compat? In practice this alleviates the need for actually using abort quite a bit which is nice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused by this, is this equivalent to closedDef.get.timeout(1.second) ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's what I was looking for. For some reason it wasn't in my code completion 🤦 . thanks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah, you probably just needed import cats.effect.syntax.all._ to get that :)

Sorry, another follow-up question to this: do we have to do this timeout? I am wondering if the JDK WebSocket implementation itself has an internal timeout, and we are just forgetting to listen for an error event or something. see the docs:

Unless the CompletableFuture returned from this method completes with IllegalArgumentException, or the method throws NullPointerException, the output will be closed.

If not already closed, the input remains open until a Close message received, or abort is invoked, or an error occurs.

https://docs.oracle.com/en/java/javase/17/docs/api/java.net.http/java/net/http/WebSocket.html#sendClose(int,java.lang.String)

}

val cleanupF = for {
isOutputOpen <- F.delay(!webSocket.isOutputClosed)
closeOutput = sendSem.permit.use { _ =>
F.fromCompletableFuture(
F.delay(webSocket.sendClose(JWebSocket.NORMAL_CLOSURE, ""))
)
wsListener = new JWebSocket.Listener {
override def onOpen(webSocket: JWebSocket): Unit = ()
override def onClose(webSocket: JWebSocket, statusCode: Int, reason: String)
: CompletionStage[_] =
// The output side of this connection will be closed when the returned CompletionStage completes.
// Therefore, we return a never completing CompletionStage, so we can control when the output will
// be closed (as it is allowed to continue sending frames (as few as possible) after a close frame
// has been received).
handleReceive(WSFrame.Close(statusCode, reason).asRight)
.thenCompose[Nothing](_ => new CompletableFuture[Nothing])
override def onText(webSocket: JWebSocket, data: CharSequence, last: Boolean)
: CompletionStage[_] =
handleReceive(WSFrame.Text(data.toString, last).asRight)
override def onBinary(webSocket: JWebSocket, data: ByteBuffer, last: Boolean)
: CompletionStage[_] =
handleReceive(WSFrame.Binary(ByteVector(data), last).asRight)
override def onPing(webSocket: JWebSocket, message: ByteBuffer)
: CompletionStage[_] =
handleReceive(WSFrame.Ping(ByteVector(message)).asRight)
override def onPong(webSocket: JWebSocket, message: ByteBuffer)
: CompletionStage[_] =
handleReceive(WSFrame.Pong(ByteVector(message)).asRight)
override def onError(webSocket: JWebSocket, error: Throwable): Unit = {
handleReceive(error.asLeft); ()
}
_ <-
closeOutput
.whenA(isOutputOpen)
.recover { case e: IOException if e.getMessage == "closed output" => () }
.onError { case e: IOException =>
for {
errs <- Stream
.repeatEval(queue.tryTake)
.unNoneTerminate
.collect { case Left(e) => e }
.compile
.toList
_ <- F.raiseError[Unit](CompositeFailure.fromList(errs) match {
case Some(cf) => cf
case None => e
})
} yield ()
}

isInputOpen <- F.delay {
!webSocket.isInputClosed
}

_ <- if (isInputOpen) awaitClose(0) else F.unit

} yield ()

val ensureResourceReleaseF = F.delay {
if (!webSocket.isOutputClosed || !webSocket.isInputClosed) webSocket.abort else ()
}
webSocket <- F.fromCompletableFuture(
F.delay(wsBuilder.buildAsync(URI.create(req.uri.renderString), wsListener))
)
sendSem <- Semaphore[F](1L)
} yield (webSocket, queue, closedDef, sendSem)
} { case (webSocket, queue, _, _) =>
for {
isOutputOpen <- F.delay(!webSocket.isOutputClosed)
closeOutput = F.fromCompletableFuture(
F.delay(webSocket.sendClose(JWebSocket.NORMAL_CLOSURE, ""))
)
_ <-
closeOutput
.whenA(isOutputOpen)
.recover { case e: IOException if e.getMessage == "closed output" => () }
.onError { case e: IOException =>
for {
errs <- Stream
.repeatEval(queue.tryTake)
.unNoneTerminate
.collect { case Left(e) => e }
.compile
.toList
_ <- F.raiseError[Unit](CompositeFailure.fromList(errs) match {
case Some(cf) => cf
case None => e
})
} yield ()
}
} yield ()
}
.map { case (webSocket, queue, closedDef, sendSem) =>
// sending will throw if done in parallel
val rawSend = (wsf: WSFrame) =>
F.fromCompletableFuture(F.delay(wsf match {
case WSFrame.Text(text, last) => webSocket.sendText(text, last)
case WSFrame.Binary(data, last) => webSocket.sendBinary(data.toByteBuffer, last)
case WSFrame.Ping(data) => webSocket.sendPing(data.toByteBuffer)
case WSFrame.Pong(data) => webSocket.sendPong(data.toByteBuffer)
case WSFrame.Close(statusCode, reason) => webSocket.sendClose(statusCode, reason)
}))
.void
new WSConnection[F] {
override def send(wsf: WSFrame) =
sendSem.permit.use(_ => rawSend(wsf))
override def sendMany[G[_]: Foldable, A <: WSFrame](wsfs: G[A]) =
sendSem.permit.use(_ => wsfs.traverse_(rawSend))
override def receive = closedDef.tryGet.flatMap {
case None => F.delay(webSocket.request(1)) *> queue.take.rethrow.map(_.some)
case Some(()) => none[WSFrame].pure[F]

F.guarantee(cleanupF, ensureResourceReleaseF)
}
.map { case (webSocket, queue, closedDef, sendSem) =>
// sending will throw if done in parallel
val rawSend = (wsf: WSFrame) =>
F.fromCompletableFuture(F.delay(wsf match {
case WSFrame.Text(text, last) => webSocket.sendText(text, last)
case WSFrame.Binary(data, last) => webSocket.sendBinary(data.toByteBuffer, last)
case WSFrame.Ping(data) => webSocket.sendPing(data.toByteBuffer)
case WSFrame.Pong(data) => webSocket.sendPong(data.toByteBuffer)
case WSFrame.Close(statusCode, reason) => webSocket.sendClose(statusCode, reason)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I have seen is this line failing because the output was already closed. We don't send close frames manually so the only way this should be invoked is by connectHighLevel which mirrors the close frame. Maybe we should guard this with an if, checking if the output is still open? I don't like that much though because low level usage would then have implicit restrictions.

Also as a side note, the mirroring of the closure frame in connectHighLevel actually may fail if the server emits a close code that is not applicable for a client to send.

}))
.void
new WSConnection[F] {
override def send(wsf: WSFrame) =
sendSem.permit.use(_ => rawSend(wsf))
override def sendMany[G[_]: Foldable, A <: WSFrame](wsfs: G[A]) =
sendSem.permit.use(_ => wsfs.traverse_(rawSend))
override def receive = closedDef.tryGet.flatMap {
case None => F.delay(webSocket.request(1)) *> queue.take.rethrow.map(_.some)
case Some(()) => none[WSFrame].pure[F]
}
override def subprotocol =
webSocket.getSubprotocol.some.filter(_.nonEmpty)
}
override def subprotocol =
webSocket.getSubprotocol.some.filter(_.nonEmpty)
}
}
}
}
}

/** A `WSClient` wrapping the default `HttpClient`. */
Expand Down