Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support protobuf #957

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/Dockerfile.python3.6
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ RUN set -ex \
&& pip install retrying \
&& pip install mock \
&& pip install pytest -U \
&& pip install pytest-mock \
&& pip install pylint

# Install protobuf
Expand Down
8 changes: 8 additions & 0 deletions frontend/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ configure(javaProjects()) {
}

jacocoTestCoverageVerification {
// Exclude the machine-generated protobuf stubs from the coverage class set so
// they do not count against the coverage verification thresholds.
afterEvaluate {
// NOTE(review): plain assignment to classDirectories is deprecated in newer
// Gradle versions (use classDirectories.setFrom(...)) — confirm the Gradle
// version this build targets.
classDirectories = files(classDirectories.files.collect {
fileTree(dir: it, exclude: [
'com/amazonaws/ml/mms/protobuf/codegen/*',
])
})
}

violationRules {
rule {
limit {
Expand Down
21 changes: 21 additions & 0 deletions frontend/server/build.gradle
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
// Build plugins: the protobuf plugin runs protoc during the build to generate
// Java sources from .proto files; idea integrates those generated sources
// into IntelliJ projects.
plugins {
id "com.google.protobuf" version "0.8.10"
id "java"
id "idea"
}

dependencies {
compile "io.netty:netty-all:${netty_version}"
compile project(":modelarchive")
compile "commons-cli:commons-cli:${commons_cli_version}"
compile "software.amazon.ai:mms-plugins-sdk:${mms_server_sdk_version}"
// NOTE(review): 3.13.0 is hard-coded here and again for the protoc artifact
// in the protobuf block; consider a shared version property (like the other
// dependencies above) so the runtime and compiler cannot drift apart.
compile "com.google.protobuf:protobuf-java:3.13.0"
testCompile "org.testng:testng:${testng_version}"
}

Expand All @@ -22,6 +29,20 @@ jar {
exclude "META-INF//NOTICE*"
}

protobuf {
// Configure the protoc executable
protoc {
// Download from repositories. Keep this version in lock-step with the
// protobuf-java runtime dependency (3.13.0) declared in dependencies.
artifact = 'com.google.protobuf:protoc:3.13.0'
}
}

// Register protoc's output directory with IntelliJ so the generated protobuf
// classes are indexed as project sources.
idea {
module {
generatedSourceDirs += file('build/generated/source/proto')
}
}

test.doFirst {
systemProperty "mmsConfigFile", 'src/test/resources/config.properties'
systemProperty "METRICS_LOCATION","build/logs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ protected void handleRequest(
String[] segments)
throws ModelException {

if (isApiDescription(segments)) {
if (decoder != null && isApiDescription(segments)) {
String path = decoder.path();
if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method()))
|| (segments.length == 2 && segments[1].equals("api-description"))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

import com.amazonaws.ml.mms.archive.ModelException;
import com.amazonaws.ml.mms.archive.ModelNotFoundException;
import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.NettyUtils;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.QueryStringDecoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -32,6 +34,7 @@ public class HttpRequestHandler extends SimpleChannelInboundHandler<FullHttpRequ

private static final Logger logger = LoggerFactory.getLogger(HttpRequestHandler.class);
HttpRequestHandlerChain handlerChain;

/** Creates a new {@code HttpRequestHandler} instance. */
public HttpRequestHandler() {}

Expand All @@ -42,47 +45,89 @@ public HttpRequestHandler(HttpRequestHandlerChain chain) {
/** {@inheritDoc} */
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) {
    // Set to true once the request is identified as protobuf; every catch
    // block below reads it so that errors are answered in the same wire
    // format the client used.
    boolean proto = false;
    try {
        NettyUtils.requestReceived(ctx.channel(), req);
        if (!req.decoderResult().isSuccess()) {
            throw new BadRequestException("Invalid HTTP message.");
        }

        CharSequence contentType = HttpUtil.getMimeType(req);
        // NOTE(review): toLowerCase() uses the default locale; consider
        // toLowerCase(Locale.ROOT) for locale-independent matching.
        if (contentType != null
                && contentType
                        .toString()
                        .toLowerCase()
                        .contains(ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF)) {
            // Protobuf requests carry their routing information in the
            // payload rather than the URI, so no query-string decoding here.
            proto = true;
            handlerChain.handleRequest(ctx, req, null, null);
        } else {
            QueryStringDecoder decoder = new QueryStringDecoder(req.uri());
            String[] segments = decoder.path().split("/");
            handlerChain.handleRequest(ctx, req, decoder, segments);
        }
    } catch (ResourceNotFoundException | ModelNotFoundException e) {
        logger.trace("", e);
        sendErrorResponse(ctx, HttpResponseStatus.NOT_FOUND, e, proto);
    } catch (BadRequestException | ModelException e) {
        logger.trace("", e);
        sendErrorResponse(ctx, HttpResponseStatus.BAD_REQUEST, e, proto);
    } catch (ConflictStatusException e) {
        logger.trace("", e);
        sendErrorResponse(ctx, HttpResponseStatus.CONFLICT, e, proto);
    } catch (RequestTimeoutException e) {
        logger.trace("", e);
        sendErrorResponse(ctx, HttpResponseStatus.REQUEST_TIMEOUT, e, proto);
    } catch (MethodNotAllowedException e) {
        logger.trace("", e);
        sendErrorResponse(ctx, HttpResponseStatus.METHOD_NOT_ALLOWED, e, proto);
    } catch (ServiceUnavailableException e) {
        logger.trace("", e);
        sendErrorResponse(ctx, HttpResponseStatus.SERVICE_UNAVAILABLE, e, proto);
    } catch (OutOfMemoryError e) {
        // NOTE(review): OOM is logged at trace here (matching the original);
        // confirm whether error-level logging was intended.
        logger.trace("", e);
        sendErrorResponse(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, e, proto);
    } catch (Throwable t) {
        logger.error("", t);
        sendErrorResponse(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, t, proto);
    }
}

/**
 * Replies with an error in the wire format the client requested.
 *
 * @param ctx channel context the response is written to
 * @param status HTTP status to report
 * @param t the failure being reported
 * @param proto true to serialize the error as protobuf, false for the default format
 */
private static void sendErrorResponse(
        ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t, boolean proto) {
    if (proto) {
        NettyUtils.sendErrorProto(ctx, status, t);
    } else {
        NettyUtils.sendError(ctx, status, t);
    }
}

/** {@inheritDoc} */
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
logger.error("", cause);
logger.error("exceptionCaught:", cause);
if (cause instanceof OutOfMemoryError) {
NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, cause);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import com.amazonaws.ml.mms.archive.ModelException;
import com.amazonaws.ml.mms.archive.ModelNotFoundException;
import com.amazonaws.ml.mms.protobuf.codegen.InferenceRequest;
import com.amazonaws.ml.mms.protobuf.codegen.RequestInput;
import com.amazonaws.ml.mms.servingsdk.impl.ModelServerContext;
import com.amazonaws.ml.mms.servingsdk.impl.ModelServerRequest;
import com.amazonaws.ml.mms.servingsdk.impl.ModelServerResponse;
Expand Down Expand Up @@ -49,32 +51,30 @@ private void run(
FullHttpRequest req,
FullHttpResponse rsp,
QueryStringDecoder decoder,
String method)
RequestInput input)
throws IOException {
switch (method) {
ModelServerRequest modelServerRequest;
if (decoder == null) {
modelServerRequest = new ModelServerRequest(req, input);
} else {
modelServerRequest = new ModelServerRequest(req, decoder);
}
switch (req.method().toString()) {
case "GET":
endpoint.doGet(
new ModelServerRequest(req, decoder),
new ModelServerResponse(rsp),
new ModelServerContext());
modelServerRequest, new ModelServerResponse(rsp), new ModelServerContext());
break;
case "PUT":
endpoint.doPut(
new ModelServerRequest(req, decoder),
new ModelServerResponse(rsp),
new ModelServerContext());
modelServerRequest, new ModelServerResponse(rsp), new ModelServerContext());
break;
case "DELETE":
endpoint.doDelete(
new ModelServerRequest(req, decoder),
new ModelServerResponse(rsp),
new ModelServerContext());
modelServerRequest, new ModelServerResponse(rsp), new ModelServerContext());
break;
case "POST":
endpoint.doPost(
new ModelServerRequest(req, decoder),
new ModelServerResponse(rsp),
new ModelServerContext());
modelServerRequest, new ModelServerResponse(rsp), new ModelServerContext());
break;
default:
throw new ServiceUnavailableException("Invalid HTTP method received");
Expand All @@ -94,7 +94,7 @@ protected void handleCustomEndpoint(
new DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false);
try {
run(endpoint, req, rsp, decoder, req.method().toString());
run(endpoint, req, rsp, decoder, null);
NettyUtils.sendHttpResponse(ctx, rsp, true);
logger.info(
"Running \"{}\" endpoint took {} ms",
Expand All @@ -106,6 +106,7 @@ protected void handleCustomEndpoint(
} catch (OutOfMemoryError oom) {
NettyUtils.sendError(
ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, oom, "Out of memory");
logger.error("Out of memory while running the custom endpoint.", oom);
} catch (IOException ioe) {
NettyUtils.sendError(
ctx,
Expand All @@ -124,4 +125,40 @@ protected void handleCustomEndpoint(
};
ModelManager.getInstance().submitTask(r);
}

/**
 * Dispatches a protobuf-encoded custom-endpoint request to the plugin endpoint
 * registered for the request's custom command, running it off the I/O thread.
 * Errors are reported back to the client as protobuf error responses.
 */
protected void handleCustomEndpoint(
        ChannelHandlerContext ctx, FullHttpRequest req, InferenceRequest inferenceRequest) {
    String command = inferenceRequest.getCustomCommand();
    ModelServerEndpoint endpoint = endpointMap.get(command);
    Runnable task =
            () -> {
                long begin = System.currentTimeMillis();
                FullHttpResponse rsp =
                        new DefaultFullHttpResponse(
                                HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false);
                try {
                    run(endpoint, req, rsp, null, inferenceRequest.getRequest());
                    NettyUtils.sendHttpResponse(ctx, rsp, true);
                    logger.info(
                            "Running \"{}\" endpoint took {} ms",
                            command,
                            System.currentTimeMillis() - begin);
                } catch (ModelServerEndpointException me) {
                    NettyUtils.sendErrorProto(
                            ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me);
                    logger.error("Error thrown by the model endpoint plugin.", me);
                } catch (OutOfMemoryError oom) {
                    NettyUtils.sendErrorProto(
                            ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, oom, "Out of memory");
                    logger.error("Out of memory while running the custom endpoint.", oom);
                } catch (IOException ioe) {
                    NettyUtils.sendErrorProto(
                            ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, ioe);
                    logger.error("I/O error while running the custom endpoint.", ioe);
                } catch (Throwable e) {
                    NettyUtils.sendErrorProto(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e);
                    logger.error("Unknown exception", e);
                }
            };
    ModelManager.getInstance().submitTask(task);
}
}
Loading