Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement decompression in PipeliningServerHandler #10155

Merged
merged 9 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
Expand Down Expand Up @@ -72,11 +71,6 @@
import io.netty.util.AsciiString;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLPeerUnverifiedException;
import java.io.Closeable;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
Expand All @@ -90,6 +84,10 @@
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLPeerUnverifiedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Helper class that manages the {@link ChannelPipeline} of incoming HTTP connections.
Expand Down Expand Up @@ -601,7 +599,6 @@ private void insertMicronautHandlers() {

SmartHttpContentCompressor contentCompressor = new SmartHttpContentCompressor(embeddedServices.getHttpCompressionStrategy());
pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_COMPRESSOR, contentCompressor);
pipeline.addLast(ChannelPipelineCustomizer.HANDLER_HTTP_DECOMPRESSOR, new HttpContentDecompressor());

Optional<NettyServerWebSocketUpgradeHandler> webSocketUpgradeHandler = embeddedServices.getWebSocketUpgradeHandler(server);
if (webSocketUpgradeHandler.isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,22 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.EventLoop;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.compression.Brotli;
import io.netty.handler.codec.compression.BrotliDecoder;
import io.netty.handler.codec.compression.SnappyFrameDecoder;
import io.netty.handler.codec.compression.ZlibCodecFactory;
import io.netty.handler.codec.compression.ZlibWrapper;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
Expand Down Expand Up @@ -328,18 +336,81 @@ void read(Object message) {
HttpRequest request = (HttpRequest) message;
OutboundAccess outboundAccess = new OutboundAccess();
outboundQueue.add(outboundAccess);
if (request instanceof FullHttpRequest full) {
requestHandler.accept(ctx, full, outboundAccess);

HttpHeaders headers = request.headers();
String contentEncoding = getContentEncoding(headers);
EmbeddedChannel decompressionChannel;
if (contentEncoding == null) {
decompressionChannel = null;
} else if (HttpHeaderValues.GZIP.contentEqualsIgnoreCase(contentEncoding) ||
HttpHeaderValues.X_GZIP.contentEqualsIgnoreCase(contentEncoding)) {
decompressionChannel = new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP));
} else if (HttpHeaderValues.DEFLATE.contentEqualsIgnoreCase(contentEncoding) ||
HttpHeaderValues.X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) {
decompressionChannel = new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(ZlibWrapper.ZLIB_OR_NONE));
} else if (Brotli.isAvailable() && HttpHeaderValues.BR.contentEqualsIgnoreCase(contentEncoding)) {
decompressionChannel = new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), new BrotliDecoder());
} else if (HttpHeaderValues.SNAPPY.contentEqualsIgnoreCase(contentEncoding)) {
decompressionChannel = new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
ctx.channel().config(), new SnappyFrameDecoder());
} else {
decompressionChannel = null;
}
if (decompressionChannel != null) {
headers.remove(HttpHeaderNames.CONTENT_LENGTH);
headers.remove(HttpHeaderNames.CONTENT_ENCODING);
headers.add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
}

boolean full = request instanceof FullHttpRequest;
if (full && decompressionChannel == null) {
requestHandler.accept(ctx, request, outboundAccess);
} else if (!hasBody(request)) {
inboundHandler = droppingInboundHandler;
if (message instanceof HttpContent) {
inboundHandler.read(message);
}
if (decompressionChannel != null) {
yawkat marked this conversation as resolved.
Show resolved Hide resolved
decompressionChannel.finish();
}
requestHandler.accept(ctx, new EmptyHttpRequest(request), outboundAccess);
} else {
optimisticBufferingInboundHandler.init(request, outboundAccess);
inboundHandler = optimisticBufferingInboundHandler;
if (decompressionChannel == null) {
inboundHandler = optimisticBufferingInboundHandler;
} else {
inboundHandler = new DecompressingInboundHandler(decompressionChannel, optimisticBufferingInboundHandler);
}
if (full) {
inboundHandler.read(new DefaultLastHttpContent(((FullHttpRequest) request).content()));
}
}
}

private static String getContentEncoding(HttpHeaders headers) {
// from io.netty.handler.codec.http.HttpContentDecoder

// Determine the content encoding.
String contentEncoding = headers.get(HttpHeaderNames.CONTENT_ENCODING);
if (contentEncoding != null) {
contentEncoding = contentEncoding.trim();
} else {
String transferEncoding = headers.get(HttpHeaderNames.TRANSFER_ENCODING);
if (transferEncoding != null) {
int idx = transferEncoding.indexOf(",");
if (idx != -1) {
contentEncoding = transferEncoding.substring(0, idx).trim();
} else {
contentEncoding = transferEncoding.trim();
}
} else {
contentEncoding = null;
}
}
return contentEncoding;
}

@Override
Expand All @@ -360,7 +431,6 @@ private final class OptimisticBufferingInboundHandler extends InboundHandler {

void init(HttpRequest request, OutboundAccess outboundAccess) {
assert buffer.isEmpty();
assert !(request instanceof HttpContent);
this.request = request;
this.outboundAccess = outboundAccess;
}
Expand Down Expand Up @@ -429,7 +499,11 @@ private void devolveToStreaming() {
this.request = null;
this.outboundAccess = null;

inboundHandler = streamingInboundHandler;
if (inboundHandler == this) {
inboundHandler = streamingInboundHandler;
} else {
((DecompressingInboundHandler) inboundHandler).delegate = streamingInboundHandler;
}
Flux<HttpContent> flux = streamingInboundHandler.flux();
if (HttpUtil.is100ContinueExpected(request)) {
flux = flux.doOnSubscribe(s -> outboundAccess.writeContinue());
Expand Down Expand Up @@ -507,11 +581,7 @@ void closeIfNoSubscriber() {
}

if (sink.currentSubscriberCount() == 0) {
releaseQueue();
if (inboundHandler == this) {
inboundHandler = droppingInboundHandler;
refreshNeedMore();
}
cancelImpl();
}
}

Expand All @@ -532,14 +602,77 @@ private void cancel() {
return;
}

cancelImpl();
}

private void cancelImpl() {
if (inboundHandler == this) {
inboundHandler = droppingInboundHandler;
refreshNeedMore();
} else if (inboundHandler instanceof DecompressingInboundHandler dec && dec.delegate == this) {
dec.dispose();
inboundHandler = droppingInboundHandler;
refreshNeedMore();
}
releaseQueue();
}
}

private class DecompressingInboundHandler extends InboundHandler {
private final EmbeddedChannel channel;
private InboundHandler delegate;

public DecompressingInboundHandler(EmbeddedChannel channel, InboundHandler delegate) {
this.channel = channel;
this.delegate = delegate;
}

@Override
void read(Object message) {
ByteBuf compressed = ((HttpContent) message).content();
if (!compressed.isReadable()) {
delegate.read(message);
return;
}

channel.writeInbound(compressed);
boolean last = message instanceof LastHttpContent;
if (last) {
channel.finish();
}

while (true) {
ByteBuf decompressed = channel.readInbound();
if (decompressed == null) {
break;
}
if (!decompressed.isReadable()) {
decompressed.release();
continue;
}
delegate.read(new DefaultHttpContent(decompressed));
}

if (last) {
delegate.read(LastHttpContent.EMPTY_LAST_CONTENT);
}
}

void dispose() {
channel.finishAndReleaseAll();
}

@Override
void readComplete() {
delegate.readComplete();
}

@Override
void handleUpstreamError(Throwable cause) {
delegate.handleUpstreamError(cause);
}
}

/**
* Handler that drops all incoming content.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package io.micronaut.http.server.netty

import io.micronaut.context.ApplicationContext
import io.micronaut.context.annotation.Requires
import io.micronaut.http.HttpRequest
import io.micronaut.http.annotation.Body
import io.micronaut.http.annotation.Controller
import io.micronaut.http.annotation.Post
import io.micronaut.http.client.HttpClient
import io.micronaut.runtime.server.EmbeddedServer
import io.netty.buffer.ByteBuf
import io.netty.buffer.ByteBufUtil
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandler
import io.netty.channel.embedded.EmbeddedChannel
import io.netty.handler.codec.compression.SnappyFrameEncoder
import io.netty.handler.codec.compression.ZlibCodecFactory
import io.netty.handler.codec.compression.ZlibWrapper
import io.netty.handler.codec.http.HttpHeaderNames
import io.netty.handler.codec.http.HttpHeaderValues
import spock.lang.Specification

import java.util.concurrent.ThreadLocalRandom

class DecompressionSpec extends Specification {
def decompression(ChannelHandler compressor, CharSequence contentEncoding) {
given:
def ctx = ApplicationContext.run(['spec.name': 'DecompressionSpec'])
def server = ctx.getBean(EmbeddedServer).start()
def client = ctx.createBean(HttpClient, server.URI).toBlocking()

def compChannel = new EmbeddedChannel(compressor)
byte[] uncompressed = new byte[1024]
ThreadLocalRandom.current().nextBytes(uncompressed)
compChannel.writeOutbound(Unpooled.copiedBuffer(uncompressed))
compChannel.finish()
ByteBuf compressed = Unpooled.buffer()
while (true) {
ByteBuf o = compChannel.readOutbound()
if (o == null) {
break
}
compressed.writeBytes(o)
o.release()
}

when:
client.exchange(HttpRequest.POST("/decompress", ByteBufUtil.getBytes(compressed)).header(HttpHeaderNames.CONTENT_ENCODING, contentEncoding))
then:
ctx.getBean(Ctrl).data == uncompressed

cleanup:
client.close()
server.stop()
ctx.close()

where:
contentEncoding | compressor
HttpHeaderValues.GZIP | ZlibCodecFactory.newZlibEncoder(ZlibWrapper.GZIP)
HttpHeaderValues.X_GZIP | ZlibCodecFactory.newZlibEncoder(ZlibWrapper.GZIP)
HttpHeaderValues.DEFLATE | ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE)
HttpHeaderValues.X_DEFLATE | ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE)
HttpHeaderValues.SNAPPY | new SnappyFrameEncoder()
}

@Requires(property = "spec.name", value = "DecompressionSpec")
@Controller
static class Ctrl {
byte[] data

@Post("/decompress")
void receive(@Body byte[] data) {
this.data = data
}
}
}
Loading
Loading