diff --git a/ci/Dockerfile.python3.6 b/ci/Dockerfile.python3.6 index 76c45a314..f77a07e7d 100644 --- a/ci/Dockerfile.python3.6 +++ b/ci/Dockerfile.python3.6 @@ -190,6 +190,7 @@ RUN set -ex \ && pip install retrying \ && pip install mock \ && pip install pytest -U \ + && pip install pytest-mock \ && pip install pylint # Install protobuf diff --git a/frontend/build.gradle b/frontend/build.gradle index cc06f55fd..e6274389d 100644 --- a/frontend/build.gradle +++ b/frontend/build.gradle @@ -45,6 +45,14 @@ configure(javaProjects()) { } jacocoTestCoverageVerification { + afterEvaluate { + classDirectories = files(classDirectories.files.collect { + fileTree(dir: it, exclude: [ + 'com/amazonaws/ml/mms/protobuf/codegen/*', + ]) + }) + } + violationRules { rule { limit { diff --git a/frontend/server/build.gradle b/frontend/server/build.gradle index 3ad36d395..e96cc42f4 100644 --- a/frontend/server/build.gradle +++ b/frontend/server/build.gradle @@ -1,8 +1,15 @@ +plugins { + id "com.google.protobuf" version "0.8.10" + id "java" + id "idea" +} + dependencies { compile "io.netty:netty-all:${netty_version}" compile project(":modelarchive") compile "commons-cli:commons-cli:${commons_cli_version}" compile "software.amazon.ai:mms-plugins-sdk:${mms_server_sdk_version}" + compile "com.google.protobuf:protobuf-java:3.13.0" testCompile "org.testng:testng:${testng_version}" } @@ -22,6 +29,20 @@ jar { exclude "META-INF//NOTICE*" } +protobuf { + // Configure the protoc executable + protoc { + // Download from repositories + artifact = 'com.google.protobuf:protoc:3.13.0' + } +} + +idea { + module { + generatedSourceDirs += file('build/generated/source/proto') + } +} + test.doFirst { systemProperty "mmsConfigFile", 'src/test/resources/config.properties' systemProperty "METRICS_LOCATION","build/logs" diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ApiDescriptionRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ApiDescriptionRequestHandler.java index f0d36ad70..850f65a60 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ApiDescriptionRequestHandler.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/ApiDescriptionRequestHandler.java @@ -25,7 +25,7 @@ protected void handleRequest( String[] segments) throws ModelException { - if (isApiDescription(segments)) { + if (decoder != null && isApiDescription(segments)) { String path = decoder.path(); if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method())) || (segments.length == 2 && segments[1].equals("api-description"))) { diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java index ce3d28c1b..8f2969c28 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java @@ -14,11 +14,13 @@ import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.archive.ModelNotFoundException; +import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.NettyUtils; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.QueryStringDecoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,6 +34,7 @@ public class HttpRequestHandler extends SimpleChannelInboundHandler { + Long start = System.currentTimeMillis(); + FullHttpResponse rsp = + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false); + try { + run(endpoint, req, rsp, null, inferenceRequest.getRequest()); + NettyUtils.sendHttpResponse(ctx, rsp, true); + logger.info( + "Running \"{}\" endpoint took {} ms", + inferenceRequest.getCustomCommand(), + System.currentTimeMillis() - start); + } catch (ModelServerEndpointException me) { + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me); + logger.error("Error thrown by the model endpoint plugin.", me); + } catch (OutOfMemoryError oom) { + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, oom, "Out of memory"); + logger.error("Out of memory while running the custom endpoint.", oom); + } catch (IOException ioe) { + NettyUtils.sendErrorProto( + ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, ioe); + logger.error("I/O error while running the custom endpoint.", ioe); + } catch (Throwable e) { + NettyUtils.sendErrorProto(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e); + logger.error("Unknown exception", e); + } + }; + ModelManager.getInstance().submitTask(r); + } } diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java index 8189a0425..f0c0b93bb 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java @@ -15,6 +15,7 @@ import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.archive.ModelNotFoundException; import com.amazonaws.ml.mms.openapi.OpenApiUtils; +import com.amazonaws.ml.mms.protobuf.codegen.InferenceRequest; import com.amazonaws.ml.mms.util.NettyUtils; import com.amazonaws.ml.mms.util.messages.InputParameter; import com.amazonaws.ml.mms.util.messages.RequestInput; @@ -22,6 +23,7 @@ import com.amazonaws.ml.mms.wlm.Job; import com.amazonaws.ml.mms.wlm.Model; import com.amazonaws.ml.mms.wlm.ModelManager; +import com.google.protobuf.InvalidProtocolBufferException; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderValues; @@ -51,39 +53,6 @@ public InferenceRequestHandler(Map ep) { endpointMap = ep; } - @Override - protected void handleRequest( - ChannelHandlerContext ctx, - FullHttpRequest req, - QueryStringDecoder decoder, - String[] segments) - throws ModelException { - if (isInferenceReq(segments)) { - if (endpointMap.getOrDefault(segments[1], null) != null) { - handleCustomEndpoint(ctx, req, segments, decoder); - } else { - switch (segments[1]) { - case "ping": - ModelManager.getInstance().workerStatus(ctx); - break; - case "models": - case "invocations": - validatePredictionsEndpoint(segments); - handleInvocations(ctx, req, decoder, segments); - break; - case "predictions": - handlePredictions(ctx, req, segments); - break; - default: - handleLegacyPredict(ctx, req, decoder, segments); - break; - } - } - } else { - chain.handleRequest(ctx, req, decoder, segments); - } - } - private boolean isInferenceReq(String[] segments) { return segments.length == 0 || segments[1].equals("ping") @@ -116,6 +85,63 @@ private void handlePredictions( predict(ctx, req, null, segments[2]); } + @Override + protected void handleRequest( + ChannelHandlerContext ctx, + FullHttpRequest req, + QueryStringDecoder decoder, + String[] segments) + throws ModelNotFoundException, ModelException { + if (decoder == null) { + try { + InferenceRequest inferenceRequest = + InferenceRequest.parseFrom(req.content().nioBuffer()); + + switch (inferenceRequest.getCommandValue()) { + case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.ping_VALUE: + ModelManager.getInstance().workerStatus(ctx, true); + break; + case com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands.predictions_VALUE: + handlePredictions(ctx, inferenceRequest, req.method()); + break; + default: + if (endpointMap.getOrDefault(inferenceRequest.getCustomCommand(), null) + != null) { + handleCustomEndpoint(ctx, req, inferenceRequest); + } else { + chain.handleRequest(ctx, req, null, segments); + } + break; + } + } catch (InvalidProtocolBufferException e) { + chain.handleRequest(ctx, req, null, segments); + } + } else if (isInferenceReq(segments)) { + if (endpointMap.getOrDefault(segments[1], null) != null) { + handleCustomEndpoint(ctx, req, segments, decoder); + } else { + switch (segments[1]) { + case "ping": + ModelManager.getInstance().workerStatus(ctx, false); + break; + case "models": + case "invocations": + validatePredictionsEndpoint(segments); + handleInvocations(ctx, req, decoder, segments); + break; + case "predictions": + handlePredictions(ctx, req, segments); + break; + default: + handleLegacyPredict(ctx, req, decoder, segments); + break; + } + } + } else { + chain.handleRequest(ctx, req, decoder, segments); + } + } + private void handleInvocations( ChannelHandlerContext ctx, FullHttpRequest req, @@ -147,6 +173,29 @@ private void handleLegacyPredict( predict(ctx, req, decoder, segments[1]); } + private void handlePredictions( + ChannelHandlerContext ctx, InferenceRequest inferenceRequest, HttpMethod method) + throws ModelNotFoundException { + String modelName = inferenceRequest.getModelName(); + if (modelName.isEmpty()) { + if (ModelManager.getInstance().getStartupModels().size() == 1) { + modelName = ModelManager.getInstance().getStartupModels().iterator().next(); + } + } + + RequestInput input = new RequestInput(NettyUtils.getRequestId(ctx.channel())); + input.setProto(true); + com.amazonaws.ml.mms.protobuf.codegen.RequestInput protoInput = + inferenceRequest.getRequest(); + input.setHeaders(protoInput.getHeadersMap()); + for (com.amazonaws.ml.mms.protobuf.codegen.InputParameter parameter : + protoInput.getParametersList()) { + input.addParameter( + new InputParameter(parameter.getName(), parameter.getValue().toByteArray())); + } + predict(ctx, modelName, input, method); + } + private void predict( ChannelHandlerContext ctx, FullHttpRequest req, @@ -154,11 +203,17 @@ private void predict( String modelName) throws ModelNotFoundException, BadRequestException { RequestInput input = parseRequest(ctx, req, decoder); + predict(ctx, modelName, input, req.method()); + } + + private void predict( + ChannelHandlerContext ctx, String modelName, RequestInput input, HttpMethod method) + throws ModelNotFoundException, BadRequestException { if (modelName == null) { throw new BadRequestException("Parameter model_name is required."); } - if (HttpMethod.OPTIONS.equals(req.method())) { + if (HttpMethod.OPTIONS.equals(method)) { ModelManager modelManager = ModelManager.getInstance(); Model model = modelManager.getModels().get(modelName); if (model == null) { diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java index 9423a1eee..d8fa571f8 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java @@ -13,11 +13,14 @@ package com.amazonaws.ml.mms.servingsdk.impl; +import com.amazonaws.ml.mms.protobuf.codegen.InputParameter; +import com.amazonaws.ml.mms.protobuf.codegen.RequestInput; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.QueryStringDecoder; import java.io.ByteArrayInputStream; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import software.amazon.ai.mms.servingsdk.http.Request; @@ -25,14 +28,28 @@ public class ModelServerRequest implements Request { private FullHttpRequest req; private QueryStringDecoder decoder; + private RequestInput input; + private Map> parameterMap; public ModelServerRequest(FullHttpRequest r, QueryStringDecoder d) { req = r; decoder = d; + this.input = null; + parameterMap = null; + } + + public ModelServerRequest(FullHttpRequest r, RequestInput input) { + req = r; + decoder = null; + this.input = input; + parameterMap = null; } @Override public List getHeaderNames() { + if (decoder == null) { + return new ArrayList<>(input.getHeadersMap().keySet()); + } return new ArrayList<>(req.headers().names()); } @@ -43,12 +60,28 @@ public String getRequestURI() { @Override public Map> getParameterMap() { - return decoder.parameters(); + if (parameterMap == null) { + if (decoder == null) { + parameterMap = new HashMap<>(); + for (InputParameter parameter : input.getParametersList()) { + List values = + parameterMap.computeIfAbsent( + parameter.getName(), r -> new ArrayList<>()); + values.add(parameter.getValue().toString()); + } + } else { + parameterMap = decoder.parameters(); + } + } + return parameterMap; } @Override public List getParameter(String k) { - return decoder.parameters().get(k); + if (parameterMap == null) { + getParameterMap(); + } + return parameterMap.get(k); } @Override @@ -58,6 +91,9 @@ public String getContentType() { @Override public ByteArrayInputStream getInputStream() { + if (decoder == null) { + return new ByteArrayInputStream(input.toByteArray()); + } return new ByteArrayInputStream(req.content().array()); } } diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java index 3f4031c9a..4d0dd3b78 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java @@ -99,6 +99,8 @@ public final class ConfigManager { public static final String MODEL_LOGGER = "MODEL_LOG"; public static final String MODEL_SERVER_METRICS_LOGGER = "MMS_METRICS"; + public static final String HTTP_CONTENT_TYPE_PROTOBUF = "application/x-protobuf"; + private Pattern blacklistPattern; private Properties prop; private static Pattern pattern = Pattern.compile("\\$\\$([^$]+[^$])\\$\\$"); diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java index c2caf9a19..b5b560abd 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java @@ -138,6 +138,34 @@ public static void sendError( sendJsonResponse(ctx, error, status); } + public static void sendErrorProto( + ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t) { + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse errorResponse = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.newBuilder() + .setCode(status.code()) + .setType(t.getClass().getSimpleName()) + .setMessage(t.getMessage()) + .build(); + FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false); + resp.headers().set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + resp.content().writeBytes(errorResponse.toByteArray()); + sendHttpResponse(ctx, resp, true); + } + + public static void sendErrorProto( + ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t, String msg) { + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse errorResponse = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.newBuilder() + .setCode(status.code()) + .setType(t.getClass().getSimpleName()) + .setMessage(msg) + .build(); + FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false); + resp.headers().set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + resp.content().writeBytes(errorResponse.toByteArray()); + sendHttpResponse(ctx, resp, true); + } + /** * Send HTTP response to client. * diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/RequestInput.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/RequestInput.java index f3a889914..8563ebc10 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/RequestInput.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/RequestInput.java @@ -23,11 +23,13 @@ public class RequestInput { private String requestId; private Map headers; private List parameters; + private boolean proto; public RequestInput(String requestId) { this.requestId = requestId; headers = new HashMap<>(); parameters = new ArrayList<>(); + proto = false; } public String getRequestId() { @@ -62,6 +64,14 @@ public void addParameter(InputParameter modelInput) { parameters.add(modelInput); } + public void setProto(boolean isProto) { + this.proto = isProto; + } + + public boolean isProto() { + return this.proto; + } + public String getStringParameter(String key) { for (InputParameter param : parameters) { if (key.equals(param.getName())) { diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java index b669106c6..ea69096cd 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java @@ -13,9 +13,12 @@ package com.amazonaws.ml.mms.wlm; import com.amazonaws.ml.mms.http.InternalServerException; +import com.amazonaws.ml.mms.protobuf.codegen.Predictions; +import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.NettyUtils; import com.amazonaws.ml.mms.util.messages.RequestInput; import com.amazonaws.ml.mms.util.messages.WorkerCommands; +import com.google.protobuf.ByteString; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse; @@ -37,6 +40,7 @@ public class Job { private RequestInput input; private long begin; private long scheduled; + private boolean proto; public Job( ChannelHandlerContext ctx, String modelName, WorkerCommands cmd, RequestInput input) { @@ -44,6 +48,7 @@ public Job( this.modelName = modelName; this.cmd = cmd; this.input = input; + this.proto = input.isProto(); begin = System.currentTimeMillis(); scheduled = begin; @@ -69,6 +74,14 @@ public RequestInput getPayload() { return input; } + public void setProto(boolean isProto) { + this.proto = isProto; + } + + public boolean isProto() { + return this.proto; + } + public void setScheduled() { scheduled = System.currentTimeMillis(); } @@ -85,15 +98,30 @@ public void response( : HttpResponseStatus.valueOf(statusCode, statusPhrase); FullHttpResponse resp = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, false); - if (contentType != null && contentType.length() > 0) { - resp.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType); - } - if (responseHeaders != null) { - for (Map.Entry e : responseHeaders.entrySet()) { - resp.headers().set(e.getKey(), e.getValue()); + if (proto) { + Predictions predictions = + Predictions.newBuilder() + .setRequestId(getJobId()) + .setStatusCode(statusCode) + .setReasonPhrase(statusPhrase) + .setContentType(contentType.toString()) + .putAllHeaders(responseHeaders) + .setResp(ByteString.copyFrom(body)) + .build(); + resp.headers() + .set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + resp.content().writeBytes(predictions.toByteArray()); + } else { + if (contentType != null && contentType.length() > 0) { + resp.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType); + } + if (responseHeaders != null) { + for (Map.Entry e : responseHeaders.entrySet()) { + resp.headers().set(e.getKey(), e.getValue()); + } } + resp.content().writeBytes(body); } - resp.content().writeBytes(body); /* * We can load the models based on the configuration file.Since this Job is @@ -119,7 +147,11 @@ public void sendError(HttpResponseStatus status, String error) { * by external clients. */ if (ctx != null) { - NettyUtils.sendError(ctx, status, new InternalServerException(error)); + if (proto) { + NettyUtils.sendErrorProto(ctx, status, new InternalServerException(error)); + } else { + NettyUtils.sendError(ctx, status, new InternalServerException(error)); + } } logger.debug( diff --git a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/ModelManager.java b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/ModelManager.java index cc478bc10..9dccf37aa 100644 --- a/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/ModelManager.java +++ b/frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/ModelManager.java @@ -21,7 +21,11 @@ import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.NettyUtils; import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; import java.io.IOException; import java.util.HashSet; import java.util.List; @@ -218,7 +222,7 @@ public boolean addJob(Job job) throws ModelNotFoundException { return model.addJob(job); } - public void workerStatus(final ChannelHandlerContext ctx) { + public void workerStatus(final ChannelHandlerContext ctx, boolean isProto) { Runnable r = () -> { String response = "Healthy"; @@ -237,8 +241,24 @@ public void workerStatus(final ChannelHandlerContext ctx) { // TODO: Check if its OK to send other 2xx errors to ALB for "Partial Healthy" // and "Unhealthy" - NettyUtils.sendJsonResponse( - ctx, new StatusResponse(response), HttpResponseStatus.OK); + if (isProto) { + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse statusResponse = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.newBuilder() + .setMessage(response) + .build(); + FullHttpResponse resp = + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false); + resp.headers() + .set( + HttpHeaderNames.CONTENT_TYPE, + ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + resp.content().writeBytes(statusResponse.toByteArray()); + NettyUtils.sendHttpResponse(ctx, resp, true); + } else { + NettyUtils.sendJsonResponse( + ctx, new StatusResponse(response), HttpResponseStatus.OK); + } }; wlm.scheduleAsync(r); } diff --git a/frontend/server/src/main/proto/inference.proto b/frontend/server/src/main/proto/inference.proto new file mode 100644 index 000000000..06c8f6449 --- /dev/null +++ b/frontend/server/src/main/proto/inference.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package com.amazonaws.ml.mms.protobuf.codegen; + +option optimize_for = SPEED; +option java_package = "com.amazonaws.ml.mms.protobuf.codegen"; +option java_multiple_files = true; + +message InferenceRequest { + string modelName = 1; + WorkerCommands command = 2; + string customCommand = 3; + RequestInput request = 4; +} + +message Predictions { + string requestId = 1; + int32 statusCode = 2; + string reasonPhrase = 3; + string contentType = 4; + map headers = 5; + bytes resp = 6; +} + +message StatusResponse { + int32 code = 1; + string type = 2; + string message = 3; +} + +enum WorkerCommands { + ping = 0; + predictions = 1; +} + +message RequestInput { + string requestId = 1; + map headers = 2; + repeated InputParameter parameters = 3; +} + +message InputParameter { + string name = 1; + bytes value = 2; + string contentType = 3; +} \ No newline at end of file diff --git a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java index 8aca550d6..69bbae9fd 100644 --- a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java +++ b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java @@ -19,11 +19,17 @@ import com.amazonaws.ml.mms.metrics.Dimension; import com.amazonaws.ml.mms.metrics.Metric; import com.amazonaws.ml.mms.metrics.MetricManager; +import com.amazonaws.ml.mms.protobuf.codegen.InferenceRequest; +import com.amazonaws.ml.mms.protobuf.codegen.InputParameter; +import com.amazonaws.ml.mms.protobuf.codegen.RequestInput; +import com.amazonaws.ml.mms.protobuf.codegen.WorkerCommands; import com.amazonaws.ml.mms.servingsdk.impl.PluginsManager; import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.Connector; import com.amazonaws.ml.mms.util.JsonUtils; import com.google.gson.JsonParseException; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; @@ -60,14 +66,20 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.io.ObjectOutputStream; import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; +import java.util.LinkedList; import java.util.List; import java.util.Properties; +import java.util.Random; import java.util.Scanner; import java.util.concurrent.CountDownLatch; import org.apache.commons.io.IOUtils; +import org.apache.commons.io.output.ByteArrayOutputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; @@ -87,6 +99,7 @@ public class ModelServerTest { CountDownLatch latch; HttpResponseStatus httpStatus; String result; + ByteBuffer resultBuf; HttpHeaders headers; private String listInferenceApisResult; private String listManagementApisResult; @@ -150,6 +163,7 @@ public void test() Assert.assertNotNull(channel, "Failed to connect to inference port."); Assert.assertNotNull(managementChannel, "Failed to connect to management port."); testPing(channel); + testPingProto(channel); testRoot(channel, listInferenceApisResult); testRoot(managementChannel, listManagementApisResult); @@ -167,6 +181,7 @@ public void test() testPredictions(channel); testPredictionsBinary(channel); testPredictionsJson(channel); + testPredictionsProto(); testInvocationsJson(channel); testInvocationsMultipart(channel); testModelsInvokeJson(channel); @@ -195,7 +210,7 @@ public void test() testInvalidPredictionsUri(); testInvalidDescribeModel(); testPredictionsModelNotFound(); - + testPredictionsModelNotFoundProto(); testInvalidManagementUri(); testInvalidModelsMethod(); testInvalidModelMethod(); @@ -237,6 +252,25 @@ private void testPing(Channel channel) throws InterruptedException { Assert.assertTrue(headers.contains("x-request-id")); } + private void testPingProto(Channel channel) + throws InterruptedException, InvalidProtocolBufferException { + resultBuf = null; + latch = new CountDownLatch(1); + InferenceRequest inferenceRequest = + InferenceRequest.newBuilder().setCommand(WorkerCommands.ping).build(); + DefaultFullHttpRequest req = + new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/ping"); + req.headers().add("Content-Type", ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + req.content().writeBytes(inferenceRequest.toByteArray()); + channel.writeAndFlush(req); + latch.await(); + + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse resp = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.parseFrom(resultBuf); + Assert.assertEquals(resp.getMessage(), "Healthy"); + Assert.assertTrue(headers.contains("x-request-id")); + } + private void testApiDescription(Channel channel, String expected) throws InterruptedException { result = null; latch = new CountDownLatch(1); @@ -430,6 +464,168 @@ private void testPredictionsBinary(Channel channel) throws InterruptedException Assert.assertEquals(result, "OK"); } + private void testPredictionsProto() throws InterruptedException, IOException { + Logger logger = LoggerFactory.getLogger(ModelServerTest.class); + + float[] featureVec = { + 0.8241127f, 0.77719664f, 0.47123995f, 0.27323001f, 0.24874457f, 0.77869387f, + 0.50711921f, 0.10696663f, 0.60663805f, 0.76063525f, 0.96358908f, 0.71026102f, + 0.57714464f, 0.58250422f, 0.91595038f, 0.24119576f, 0.58981158f, 0.67119473f, + 0.94832165f, 0.91711728f, 0.0323646f, 0.07007003f, 0.89158581f, 0.01916486f, + 0.5647568f, 0.99879008f, 0.58311515f, 0.87001143f, 0.50620349f, 0.65268692f, + 0.83657373f, 0.31589474f, 0.70910797f, 0.62886395f, 0.03498501f, 0.36503007f, + 0.94178899f, 0.21739391f, 0.29688258f, 0.34630696f, 0.30494259f, 0.04302086f, + 0.3578226f, 0.04361075f, 0.91962488f, 0.24961093f, 0.0124245f, 0.31004002f, + 0.61543447f, 0.34500444f, 0.30441186f, 0.44085924f, 0.67489625f, 0.03938287f, + 0.89307169f, 0.22283647f, 0.44441515f, 0.82044036f, 0.37541783f, 0.25868981f, + 0.46510721f, 0.51640271f, 0.40917042f, 0.65912921f, 0.72228879f, 0.42611241f, + 0.71283259f, 0.37417586f, 0.786403f, 0.6912011f, 0.4338622f, 0.29868897f, + 0.0342538f, 0.16938266f, 0.90234809f, 0.3051922f, 0.92377579f, 0.97883088f, + 0.2028601f, 0.50478822f, 0.84762944f, 0.11011502f, 0.70006246f, 0.34329564f, + 0.49022718f, 0.8569296f, 0.75698334f, 0.84864789f, 0.9477985f, 0.46994381f, + 0.05319027f, 0.07369953f, 0.08497094f, 0.54536333f, 0.87922514f, 0.97857665f, + 0.06930542f, 0.27101086f, 0.03069235f, 0.13432096f, 0.96021588f, 0.9484153f, + 0.75365465f, 0.76216408f, 0.43294879f, 0.41034781f, 0.01088872f, 0.29060839f, + 0.94462721f, 0.83999491f, 0.4364634f, 0.63611379f, 0.32102346f, 0.10418961f, + 0.2776194f, 0.73166493f, 0.76387601f, 0.83429646f, 0.94348065f, 0.85956626f, + 0.81160069f, 0.1650624f, 0.79505978f, 0.67288331f, 0.3204887f, 0.89388283f, + 0.85290859f, 0.11308228f, 0.81252801f, 0.87276483f, 0.76737167f, 0.16166891f, + 0.78767838f, 0.79160494f, 0.80843258f, 0.39723985f, 0.47062281f, 0.96028728f, + 0.55309858f, 0.05378428f, 0.3619188f, 0.69888766f, 0.76134346f, 0.60911425f, + 0.85562674f, 0.58098788f, 0.5438003f, 0.61229528f, 0.14350196f, 0.75286178f, + 0.88131248f, 0.69132185f, 0.12576858f, 0.23459534f, 0.26883056f, 0.98129534f, + 0.74060036f, 0.9607236f, 0.99617814f, 0.75829678f, 0.06310486f, 0.55572225f, + 0.72709395f, 0.77374732f, 0.81625695f, 0.13475297f, 0.89352917f, 0.19805313f, + 0.34789188f, 0.08422005f, 0.67733949f, 0.94300965f, 0.22116594f, 0.10948816f, + 0.50651639f, 0.40402931f, 0.46181863f, 0.14743327f, 0.33300708f, 0.87358395f, + 0.79312213f, 0.54662338f, 0.83890467f, 0.87690315f, 0.24570711f, 0.01534696f, + 0.11803501f, 0.21333099f, 0.75169896f, 0.42758898f, 0.80780874f, 0.57331851f, + 0.96341639f, 0.52078203f, 0.22610806f, 0.83348684f, 0.76036637f, 0.99407179f, + 0.96098997f, 0.2451298f, 0.41848766f, 0.01584927f, 0.28213452f, 0.04494721f, + 0.16963578f, 0.68096619f, 0.39404686f, 0.7621266f, 0.02721071f, 0.5481559f, + 0.59972178f, 0.61725009f, 0.76405802f, 0.83030081f, 0.87232659f, 0.16119207f, + 0.51143718f, 0.13040968f, 0.57453206f, 0.63200166f, 0.27077547f, 0.72281371f, + 0.44055048f, 0.51538986f, 0.29096202f, 0.99726975f, 0.50958807f, 0.87792484f, + 0.03956957f, 0.42187308f, 0.87694541f, 0.88974026f, 0.65590356f, 0.35029236f, + 0.18853136f, 0.50500502f, 0.95545852f, 0.94636341f, 0.84731837f, 0.13936297f, + 0.32537976f, 0.41430316f, 0.18574781f, 0.97574309f, 0.26483325f, 0.79840404f, + 0.74069621f, 0.98526361f, 0.63957011f, 0.30924823f, 0.20429374f, 0.09850504f, + 0.77676228f, 0.40561045f, 0.71999222f, 0.42545573f, 0.78092917f, 0.74532941f, + 0.52263514f, 0.01771433f, 0.15041333f, 0.41157879f, 0.15047035f, 0.66149007f, + 0.95970903f, 0.97348663f, 0.30155038f, 0.06596597f, 0.3317747f, 0.09346482f, + 0.71672818f, 0.13279156f, 0.19758743f, 0.20143709f, 0.84517665f, 0.767672f, + 0.21471986f, 0.75663108f, 0.35878468f, 0.58943601f, 0.98005496f, 0.30451585f, + 0.34754926f, 0.3298018f, 0.36859658f, 0.52568727f, 0.45107675f, 0.27778918f, + 0.4825746f, 0.6521011f, 0.16924284f, 0.54550222f, 0.33862934f, 0.88247624f, + 0.97012639f, 0.64496125f, 0.09514454f, 0.90497989f, 0.82705286f, 0.5232794f, + 0.80558394f, 0.86949601f, 0.78825486f, 0.23086437f, 0.64405503f, 0.02989425f, + 0.61423185f, 0.45341492f, 0.52462891f, 0.93029992f, 0.74040612f, 0.45227326f, + 0.35339424f, 0.30661544f, 0.70083487f, 0.68725394f, 0.2036894f, 0.85478822f, + 0.13176267f, 0.10494695f, 0.17226407f, 0.88662847f, 0.42744141f, 0.44540842f, + 0.94161152f, 0.46699513f, 0.36795051f, 0.0234292f, 0.68830582f, 0.33571055f, + 0.93930267f, 0.76513689f, 0.69002036f, 0.11983312f, 0.05524331f, 0.28743821f, + 0.53563344f, 0.00152629f, 0.50295284f, 0.24351331f, 0.6770774f, 0.42484211f, + 0.10956752f, 0.01239354f, 0.57630947f, 0.16575461f, 0.7870273f, 0.64387019f, + 0.65514058f, 0.62808722f, 0.29263556f, 0.8159863f, 0.18642033f + }; + List instances = new LinkedList<>(); + Random rand = new Random(); + for (int i = 0; i < 50; i++) { + float[] data = new float[featureVec.length]; + for (int j = 0; j < featureVec.length; j++) { + data[j] = featureVec[rand.nextInt(featureVec.length)]; + } + instances.add(data); + } + + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(bos); + oos.writeObject(instances); + byte[] bytes = bos.toByteArray(); + byte[] byteString = ByteString.copyFrom(bytes).toByteArray(); + + InputParameter parameter = + InputParameter.newBuilder().setValue(ByteString.copyFrom(bytes)).build(); + InferenceRequest inferenceRequest = + InferenceRequest.newBuilder() + .setCommand(WorkerCommands.predictions) + .setModelName("test") + .setRequest(RequestInput.newBuilder().addParameters(parameter).build()) + .build(); + logger.info( + "2D random float size=" + + featureVec.length * 50 + + ", byteString size=" + + byteString.length + + ", bytes size=" + + bytes.length + + ", parameter size=" + + parameter.toByteArray().length + + ", protobuf size=" + + inferenceRequest.toByteArray().length); + oos.close(); + + List instances1 = new LinkedList<>(); + for (int i = 0; i < 50; i++) { + instances1.add(featureVec); + } + + ByteArrayOutputStream bos1 = new ByteArrayOutputStream(); + ObjectOutputStream oos1 = new ObjectOutputStream(bos1); + oos1.writeObject(instances1); + byte[] bytes1 = bos1.toByteArray(); + byte[] byteString1 = ByteString.copyFrom(bytes1).toByteArray(); + + InputParameter parameter1 = + InputParameter.newBuilder().setValue(ByteString.copyFrom(bytes1)).build(); + InferenceRequest inferenceRequest1 = + InferenceRequest.newBuilder() + .setCommand(WorkerCommands.predictions) + .setModelName("test") + .setRequest(RequestInput.newBuilder().addParameters(parameter1).build()) + .build(); + logger.info( + "2D repeated float size=" + + featureVec.length * 50 + + ", byteString size=" + + byteString1.length + + ", bytes size=" + + bytes1.length + + ", parameter size=" + + parameter1.toByteArray().length + + ", protobuf size=" + + inferenceRequest1.toByteArray().length); + oos1.close(); + + ByteBuffer fBuffer = ByteBuffer.allocate(Float.BYTES * featureVec.length * 50); + fBuffer.order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < 50; i++) { + for (float feature : featureVec) { + fBuffer.putFloat(feature); + } + } + byte[] bytes2 = fBuffer.array(); + InputParameter parameter2 = + InputParameter.newBuilder().setValue(ByteString.copyFrom(bytes2)).build(); + InferenceRequest inferenceRequest2 = + InferenceRequest.newBuilder() + .setCommand(WorkerCommands.predictions) + .setModelName("test") + .setRequest(RequestInput.newBuilder().addParameters(parameter2).build()) + .build(); + logger.info( + "1D repeated float size=" + + featureVec.length * 50 + + ", fBuffer size=" + + fBuffer.array().length + + ", bytes size=" + + bytes2.length + + ", parameter size=" + + parameter2.toByteArray().length + + ", protobuf size=" + + inferenceRequest2.toByteArray().length); + } + private void testInvocationsJson(Channel channel) throws InterruptedException { result = null; latch = new CountDownLatch(1); @@ -778,6 +974,44 @@ private void testPredictionsModelNotFound() throws InterruptedException { Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel"); } + private void testPredictionsModelNotFoundProto() + throws InterruptedException, InvalidProtocolBufferException { + Channel channel = connect(false); + Assert.assertNotNull(channel); + + resultBuf = null; + latch = new CountDownLatch(1); + + InputParameter parameter = + InputParameter.newBuilder() + .setName("data") + .setValue(ByteString.copyFrom("test", CharsetUtil.UTF_8)) + .build(); + InferenceRequest inferenceRequest = + InferenceRequest.newBuilder() + .setCommand(WorkerCommands.predictions) + .setModelName("InvalidModel") + .setRequest(RequestInput.newBuilder().addParameters(parameter).build()) + .build(); + + DefaultFullHttpRequest req = + new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + req.content().writeBytes(inferenceRequest.toByteArray()); + HttpUtil.setContentLength(req, req.content().readableBytes()); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF); + channel.writeAndFlush(req).sync(); + channel.closeFuture().sync(); + + latch.await(); + + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse resp = + com.amazonaws.ml.mms.protobuf.codegen.StatusResponse.parseFrom(resultBuf); + + Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code()); + Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel"); + channel.close(); + } + private void testInvalidManagementUri() throws InterruptedException { Channel channel = connect(true); Assert.assertNotNull(channel); @@ -1361,8 +1595,16 @@ private class TestHandler extends SimpleChannelInboundHandler @Override public void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) { + CharSequence contentType = HttpUtil.getMimeType(msg); + if (contentType != null + && contentType + .toString() + .equalsIgnoreCase(ConfigManager.HTTP_CONTENT_TYPE_PROTOBUF)) { + resultBuf = msg.content().nioBuffer(); + } else { + result = msg.content().toString(StandardCharsets.UTF_8); + } httpStatus = msg.status(); - result = msg.content().toString(StandardCharsets.UTF_8); headers = msg.headers(); latch.countDown(); } diff --git a/frontend/tools/conf/findbugs-exclude.xml b/frontend/tools/conf/findbugs-exclude.xml index d2c4821a7..f8b8f8dc7 100644 --- a/frontend/tools/conf/findbugs-exclude.xml +++ b/frontend/tools/conf/findbugs-exclude.xml @@ -10,5 +10,7 @@ - + + + diff --git a/mms/tests/unit_tests/test_beckend_metric.py b/mms/tests/unit_tests/test_beckend_metric.py index 4ef2e6ad0..6732973bb 100644 --- a/mms/tests/unit_tests/test_beckend_metric.py +++ b/mms/tests/unit_tests/test_beckend_metric.py @@ -29,6 +29,7 @@ def test_metrics(caplog): Test if metric classes methods behave as expected Also checks global metric service methods """ + caplog.set_level(logging.INFO) # Create a batch of request ids request_ids = {0: 'abcd', 1: "xyz", 2: "qwerty", 3: "hjshfj"} all_req_ids = ','.join(request_ids.values()) diff --git a/mms/tests/unit_tests/test_worker_service.py b/mms/tests/unit_tests/test_worker_service.py index 575f13d73..bafba308c 100644 --- a/mms/tests/unit_tests/test_worker_service.py +++ b/mms/tests/unit_tests/test_worker_service.py @@ -50,6 +50,7 @@ def test_valid_req(self, service): class TestEmitMetrics: def test_emit_metrics(self, caplog): + caplog.set_level(logging.INFO) metrics = {'test_emit_metrics': True} emit_metrics(metrics) assert "[METRICS]" in caplog.text