Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix request backpressure #10142

Merged
merged 6 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ public final class PipeliningServerHandler extends ChannelInboundHandlerAdapter
*/
private boolean reading = false;
/**
* {@code true} iff we want to read more data.
* {@code true} iff {@code ctx.read()} has been called already.
*/
private boolean moreRequested = false;
private boolean readCalled = false;
/**
* {@code true} iff this handler has been removed.
*/
Expand Down Expand Up @@ -151,23 +151,28 @@ private static boolean hasBody(HttpRequest request) {
}

/**
* Set whether we need more input, i.e. another call to {@link #channelRead}. This is usally a
* {@link ChannelHandlerContext#read()} call, but it's coalesced until
* {@link #channelReadComplete}.
*
* @param needMore {@code true} iff we need more input
* Call {@code ctx.read()} if necessary.
*/
private void setNeedMore(boolean needMore) {
boolean oldMoreRequested = moreRequested;
moreRequested = needMore;
if (!oldMoreRequested && !reading && needMore) {
private void refreshNeedMore() {
// if readCalled is true, ctx.read() is already called and we haven't seen the associated readComplete yet.

// needMore is false if there is downstream backpressure.

// requestHandler itself (i.e. non-streaming request processing) does not have
// backpressure. For this, check whether there is a request that has been fully read but
// has no response yet. If there is, apply backpressure.
if (!readCalled && outboundQueue.size() <= 1 && inboundHandler.needMore()) {
readCalled = true;
ctx.read();
}
}

@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
this.ctx = ctx;
// we take control of reading now.
ctx.channel().config().setAutoRead(false);
refreshNeedMore();
}

@Override
Expand Down Expand Up @@ -195,13 +200,13 @@ public void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg)
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
inboundHandler.readComplete();
reading = false;
// only unset readCalled now. This ensures no read call is done before channelReadComplete
readCalled = false;
if (flushPending) {
ctx.flush();
flushPending = false;
}
if (moreRequested) {
ctx.read();
}
refreshNeedMore();
}

@Override
Expand Down Expand Up @@ -267,6 +272,7 @@ private void writeSome() {
if (next != null && next.handler != null) {
outboundQueue.poll();
outboundHandler = next.handler;
refreshNeedMore();
} else {
return;
}
Expand All @@ -286,7 +292,15 @@ private void writeSome() {
/**
* An inbound handler is responsible for all incoming messages.
*/
private abstract static class InboundHandler {
private abstract class InboundHandler {
/**
* @return {@code true} iff this handler can process more data. This is usually {@code true},
* except for streaming requests when there is downstream backpressure.
*/
boolean needMore() {
return true;
}

/**
* @see #channelRead
*/
Expand Down Expand Up @@ -448,7 +462,6 @@ void read(Object message) {
sink.tryEmitComplete();
inboundHandler = baseInboundHandler;
}
setNeedMore(requested > 0);
}

@Override
Expand All @@ -459,6 +472,11 @@ void handleUpstreamError(Throwable cause) {
}
}

@Override
boolean needMore() {
return requested > 0;
}

private void request(long n) {
EventLoop eventLoop = ctx.channel().eventLoop();
if (!eventLoop.inEventLoop()) {
Expand All @@ -472,20 +490,27 @@ private void request(long n) {
newRequested = Long.MAX_VALUE;
}
requested = newRequested;
setNeedMore(newRequested > 0);
refreshNeedMore();
}

Flux<HttpContent> flux() {
return sink.asFlux()
.doOnRequest(this::request)
.doOnCancel(this::releaseQueue);
.doOnCancel(this::cancel);
}

void closeIfNoSubscriber() {
EventLoop eventLoop = ctx.channel().eventLoop();
if (!eventLoop.inEventLoop()) {
eventLoop.execute(this::closeIfNoSubscriber);
return;
}

if (sink.currentSubscriberCount() == 0) {
releaseQueue();
if (inboundHandler == this) {
inboundHandler = droppingInboundHandler;
refreshNeedMore();
}
}
}
Expand All @@ -499,6 +524,20 @@ private void releaseQueue() {
c.release();
}
}

private void cancel() {
EventLoop eventLoop = ctx.channel().eventLoop();
if (!eventLoop.inEventLoop()) {
eventLoop.execute(this::cancel);
return;
}

if (inboundHandler == this) {
inboundHandler = droppingInboundHandler;
refreshNeedMore();
}
releaseQueue();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ private void writeResponse(ChannelHandlerContext ctx,
} catch (NoSuchElementException ignored) {
}

// websocket needs auto read for now
ctx.channel().config().setAutoRead(true);
} catch (Throwable e) {
if (LOG.isErrorEnabled()) {
LOG.error("Error opening WebSocket: {}", e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ class EmbeddedTestUtil {

static void connect(EmbeddedChannel server, EmbeddedChannel client) {
new ConnectionDirection(server, client).register()
new ConnectionDirection(client, server).register()
def csDir = new ConnectionDirection(client, server)
csDir.register()
// PipeliningServerHandler fires a read() before this method is called, so we don't see it.
csDir.readPending = true
}

private static class ConnectionDirection {
Expand All @@ -40,7 +43,7 @@ class EmbeddedTestUtil {
}

private void forwardLater(Object msg) {
if (readPending || dest.config().isAutoRead()) {
if (readPending || dest.config().isAutoRead() || msg == FLUSH) {
dest.eventLoop().execute(() -> forwardNow(msg))
readPending = false
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import io.netty.handler.codec.http.HttpRequest
import io.netty.handler.codec.http.HttpResponse
import io.netty.handler.codec.http.HttpResponseStatus
import io.netty.handler.codec.http.HttpVersion
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription
import io.netty.handler.codec.http.LastHttpContent
import reactor.core.publisher.Flux
import reactor.core.publisher.Sinks
Expand Down Expand Up @@ -291,6 +293,67 @@ class PipeliningServerHandlerSpec extends Specification {
completeOnCancel << [true, false]
}

def 'read backpressure for streaming requests'() {
given:
def mon = new MonitorHandler()
Subscription subscription = null
def ch = new EmbeddedChannel(mon, new PipeliningServerHandler(new RequestHandler() {
@Override
void accept(ChannelHandlerContext ctx, HttpRequest request, PipeliningServerHandler.OutboundAccess outboundAccess) {
((StreamedHttpRequest) request).subscribe(new Subscriber<HttpContent>() {
@Override
void onSubscribe(Subscription s) {
subscription = s
}

@Override
void onNext(HttpContent httpContent) {
httpContent.release()
}

@Override
void onError(Throwable t) {
t.printStackTrace()
}

@Override
void onComplete() {
outboundAccess.writeFull(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NO_CONTENT))
}
})
}

@Override
void handleUnboundError(Throwable cause) {
cause.printStackTrace()
}
}))

expect:
mon.read == 1
mon.flush == 0

when:
def req = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/")
req.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
ch.writeInbound(req)
then:
// no read call until request
mon.read == 1

when:
subscription.request(1)
then:
mon.read == 2

when:
ch.writeInbound(new DefaultLastHttpContent(Unpooled.wrappedBuffer("foo".getBytes(StandardCharsets.UTF_8))))
then:
// read call for the next request
mon.read == 3
ch.checkException()
}

def 'empty streaming response while in queue'() {
given:
def resp = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)
Expand Down
Loading