diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index 56e994be86eb4..43954a5af8b91 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -12,10 +12,12 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import java.util.Locale; import java.util.Objects; import java.util.function.Function; @@ -34,17 +36,23 @@ public abstract class BaseResponseHandler implements ResponseHandler { public static final String SERVER_ERROR_OBJECT = "Received an error response"; public static final String BAD_REQUEST = "Received a bad request status code"; public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code"; + protected static final String ERROR_TYPE = "error"; + protected static final String STREAM_ERROR = "stream_error"; protected final String requestType; protected final ResponseParser parseFunction; private final Function errorParseFunction; private final boolean canHandleStreamingResponses; - public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function errorParseFunction) { + protected BaseResponseHandler( + String requestType, + ResponseParser parseFunction, + Function errorParseFunction + ) { this(requestType, parseFunction, errorParseFunction, 
false); } - public BaseResponseHandler( + protected BaseResponseHandler( String requestType, ResponseParser parseFunction, Function errorParseFunction, @@ -109,19 +117,230 @@ private void checkForErrorObject(Request request, HttpResult result) { } protected Exception buildError(String message, Request request, HttpResult result) { - var errorEntityMsg = errorParseFunction.apply(result); - return buildError(message, request, result, errorEntityMsg); + var errorResponse = errorParseFunction.apply(result); + return buildError(message, request, result, errorResponse); } protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { var responseStatusCode = result.response().getStatusLine().getStatusCode(); return new ElasticsearchStatusException( - errorMessage(message, request, result, errorResponse, responseStatusCode), + errorMessage(message, request, errorResponse, responseStatusCode), toRestStatus(responseStatusCode) ); } - protected String errorMessage(String message, Request request, HttpResult result, ErrorResponse errorResponse, int statusCode) { + /** + * Builds an error for a streaming request with a custom error type. + * This method is used when an error response is received from the external service. + * Only streaming requests support this format. 
+ * + * @param message the error message to include in the exception + * @param request the request that caused the error + * @param result the HTTP result containing the error response + * @param errorResponse the parsed error response from the HTTP result + * @param errorResponseClass the class of the expected error response type + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildChatCompletionError( + String message, + Request request, + HttpResult result, + ErrorResponse errorResponse, + Class errorResponseClass + ) { + assert request.isStreaming() : "Only streaming requests support this format"; + var statusCode = result.response().getStatusLine().getStatusCode(); + var errorMessage = errorMessage(message, request, errorResponse, statusCode); + var restStatus = toRestStatus(statusCode); + + return buildChatCompletionError(errorResponse, errorMessage, restStatus, errorResponseClass); + } + + /** + * Builds a {@link UnifiedChatCompletionException} for a streaming request. + * This method is used when an error response is received from the external service. + * Only streaming requests should use this method. 
+ * + * @param errorResponse the error response parsed from the HTTP result + * @param errorMessage the error message to include in the exception + * @param restStatus the REST status code of the response + * @param errorResponseClass the class of the expected error response type + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus, + Class errorResponseClass + ) { + if (errorResponseClass.isInstance(errorResponse)) { + return buildProviderSpecificChatCompletionError(errorResponse, errorMessage, restStatus); + } else { + return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus); + } + } + + /** + * Builds a custom {@link UnifiedChatCompletionException} for a streaming request. + * This method is called when a specific error response is found in the HTTP result. + * It must be implemented by subclasses to handle specific error response formats. + * Only streaming requests should use this method. + * + * @param errorResponse the error response parsed from the HTTP result + * @param errorMessage the error message to include in the exception + * @param restStatus the REST status code of the response + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + throw new UnsupportedOperationException( + "Custom error handling is not implemented. Please override buildProviderSpecificChatCompletionError method." + ); + } + + /** + * Builds a default {@link UnifiedChatCompletionException} for a streaming request. + * This method is used when an error response is received but no specific error handling is implemented. 
+ * Only streaming requests should use this method. + * + * @param errorResponse the error response parsed from the HTTP result + * @param errorMessage the error message to include in the exception + * @param restStatus the REST status code of the response + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildDefaultChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + return new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } + + /** + * Builds a mid-stream error for a streaming request. + * This method is used when an error occurs while processing a streaming response. + * It must be implemented by subclasses to handle specific error response formats. + * Only streaming requests should use this method. + * + * @param inferenceEntityId the ID of the inference entity + * @param message the error message + * @param e the exception that caused the error, can be null + * @return a {@link UnifiedChatCompletionException} representing the mid-stream error + */ + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + throw new UnsupportedOperationException( + "Mid-stream error handling is not implemented. Please override buildMidStreamChatCompletionError method." + ); + } + + /** + * Builds a mid-stream error for a streaming request with a custom error type. + * This method is used when an error occurs while processing a streaming response and allows for custom error handling. + * Only streaming requests should use this method. 
+ * + * @param inferenceEntityId the ID of the inference entity + * @param message the error message + * @param e the exception that caused the error, can be null + * @param errorResponseClass the class of the expected error response type + * @return a {@link UnifiedChatCompletionException} representing the mid-stream error + */ + protected UnifiedChatCompletionException buildMidStreamChatCompletionError( + String inferenceEntityId, + String message, + Exception e, + Class errorResponseClass + ) { + // Extract the error response from the message using the provided method + var errorResponse = extractMidStreamChatCompletionErrorResponse(message); + // Check if the error response matches the expected type + if (errorResponseClass.isInstance(errorResponse)) { + // If it matches, we can build a custom mid-stream error exception + return buildProviderSpecificMidStreamChatCompletionError(inferenceEntityId, errorResponse); + } else if (e != null) { + // If the error response does not match, we can still return an exception based on the original throwable + return UnifiedChatCompletionException.fromThrowable(e); + } else { + // If no specific error response is found, we return a default mid-stream error + return buildDefaultMidStreamChatCompletionError(inferenceEntityId, errorResponse); + } + } + + /** + * Builds a custom mid-stream {@link UnifiedChatCompletionException} for a streaming request. + * This method is called when a specific error response is found in the message. + * It must be implemented by subclasses to handle specific error response formats. + * Only streaming requests should use this method. 
+ * + * @param inferenceEntityId the ID of the inference entity + * @param errorResponse the error response parsed from the message + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + throw new UnsupportedOperationException( + "Mid-stream error handling is not implemented for this response handler. " + + "Please override buildProviderSpecificMidStreamChatCompletionError method." + ); + } + + /** + * Builds a default mid-stream error for a streaming request. + * This method is used when no specific error response is found in the message. + * Only streaming requests should use this method. + * + * @param inferenceEntityId the ID of the inference entity + * @param errorResponse the error response extracted from the message + * @return a {@link UnifiedChatCompletionException} representing the default mid-stream error + */ + protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), + createErrorType(errorResponse), + STREAM_ERROR + ); + } + + /** + * Extracts the mid-stream error response from the message. + * This method is used to parse the error response from a streaming message. + * It must be implemented by subclasses to handle specific error response formats. + * Only streaming requests should use this method. 
+ * + * @param message the message containing the error response + * @return an {@link ErrorResponse} object representing the mid-stream error + */ + protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { + throw new UnsupportedOperationException( + "Mid-stream error extraction is not implemented. Please override extractMidStreamChatCompletionErrorResponse method." + ); + } + + /** + * Creates a string representation of the error type based on the provided ErrorResponse. + * This method is used to generate a human-readable error type for logging or exception messages. + * + * @param errorResponse the ErrorResponse object + * @return a string representing the error type + */ + protected static String createErrorType(ErrorResponse errorResponse) { + return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; + } + + protected String errorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) { return (errorResponse == null || errorResponse.errorStructureFound() == false || Strings.isNullOrEmpty(errorResponse.getErrorMessage())) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java index ad40d43b3af3b..346b6a4f9026f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -25,6 +25,10 @@ import static org.elasticsearch.core.Strings.format; +/** + * Handles streaming chat completion responses and error parsing for Elastic Inference Service endpoints. 
+ * This handler is designed to work with the unified Elastic Inference Service chat completion API. + */ public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { super(requestType, parseFunction, true); @@ -34,53 +38,128 @@ public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String reques public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); // EIS uses the unified API spec - var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e)); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor( + (m, e) -> buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(openAiProcessor); return new StreamingUnifiedChatCompletionResults(openAiProcessor); } + /** + * Builds an error for the Elastic Inference Service. + * This method is called when an error response is received from the service. + * + * @param message The error message to include in the exception. + * @param request The request that caused the error. + * @param result The HTTP result containing the error response. + * @param errorResponse The parsed error response from the service. + * @return An instance of {@link Exception} representing the error. 
+ */ @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var restStatus = toRestStatus(responseStatusCode); - return new UnifiedChatCompletionException( - restStatus, - errorMessage(message, request, result, errorResponse, responseStatusCode), - "error", - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse, ErrorResponse.class); + } + + /** + * Builds a custom {@link UnifiedChatCompletionException} for the Elastic Inference Service. + * This method is called when an error response is received from the service. + * + * @param errorResponse The error response received from the service. + * @param errorMessage The error message to include in the exception. + * @param restStatus The HTTP status of the error response. + * @param errorResponseClass The class of the error response. + * @return An instance of {@link UnifiedChatCompletionException} with details from the error response. 
+ */ + @Override + protected UnifiedChatCompletionException buildChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus, + Class errorResponseClass + ) { + return new UnifiedChatCompletionException(restStatus, errorMessage, ERROR_TYPE, restStatus.name().toLowerCase(Locale.ROOT)); } - private static Exception buildMidStreamError(Request request, String message, Exception e) { - var errorResponse = ElasticInferenceServiceErrorResponseEntity.fromString(message); + /** + * Builds a mid-stream error for the Elastic Inference Service. + * This method is called when an error occurs during the streaming process. + * + * @param inferenceEntityId The ID of the inference entity. + * @param message The error message received from the service. + * @param e The exception that occurred, if any. + * @return An instance of {@link UnifiedChatCompletionException} representing the mid-stream error. + */ + @Override + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + var errorResponse = extractMidStreamChatCompletionErrorResponse(message); + // Check if the error response contains a specific structure if (errorResponse.errorStructureFound()) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. 
Error message: [%s]", - SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), - errorResponse.getErrorMessage() - ), - "error", - "stream_error" - ); + return buildProviderSpecificMidStreamChatCompletionError(inferenceEntityId, errorResponse); } else if (e != null) { return UnifiedChatCompletionException.fromThrowable(e); } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - "error", - "stream_error" - ); + return buildDefaultMidStreamChatCompletionError(inferenceEntityId, errorResponse); } } + + /** + * Extracts the error response from the message. This method is specific to the Elastic Inference Service + * and should parse the message according to its error response format. + * + * @param message The message containing the error response. + * @return An instance of {@link ErrorResponse} parsed from the message. + */ + @Override + protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { + return ElasticInferenceServiceErrorResponseEntity.fromString(message); + } + + /** + * Builds a custom mid-stream {@link UnifiedChatCompletionException} for the Elastic Inference Service. + * This method is called when a specific error response structure is found in the message. + * + * @param inferenceEntityId The ID of the inference entity. + * @param errorResponse The error response parsed from the message. + * @return An instance of {@link UnifiedChatCompletionException} with details from the error response. + */ + @Override + protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. 
Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + errorResponse.getErrorMessage() + ), + ERROR_TYPE, + STREAM_ERROR + ); + } + + /** + * Builds a default mid-stream {@link UnifiedChatCompletionException} for the Elastic Inference Service. + * This method is called when specific error response structure is NOT found in the message. + * + * @param inferenceEntityId The ID of the inference entity. + * @param errorResponse The error response parsed from the message. + * @return An instance of {@link UnifiedChatCompletionException} with a generic error message. + */ + @Override + protected UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), + ERROR_TYPE, + STREAM_ERROR + ); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java index 8c355c9f67f18..29c5c910abf75 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java @@ -10,8 +10,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import 
org.elasticsearch.xcontent.ParseField; @@ -29,13 +27,16 @@ import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; import java.nio.charset.StandardCharsets; -import java.util.Locale; import java.util.Objects; import java.util.Optional; import java.util.concurrent.Flow; import static org.elasticsearch.core.Strings.format; +/** + * Handles streaming chat completion responses and error parsing for Google Vertex AI inference endpoints. + * This handler is designed to work with the unified Google Vertex AI chat completion API. + */ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVertexAiResponseHandler { private static final String ERROR_FIELD = "error"; @@ -54,7 +55,9 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher buildMidStreamError(request, m, e)); + var googleVertexAiProcessor = new GoogleVertexAiUnifiedStreamingProcessor( + (m, e) -> buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(googleVertexAiProcessor); @@ -62,57 +65,57 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher, Void> ERROR_PARSER = new ConstructingObjectParser<>( "google_vertex_ai_error_wrapper", true, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java index 8dffd612db5c8..d3cbffb10f203 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -41,55 +41,54 @@ public 
HuggingFaceChatCompletionResponseHandler(String requestType, ResponsePars } @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof HuggingFaceErrorResponseEntity - ? new UnifiedChatCompletionException( - restStatus, - errorMessage, - HUGGING_FACE_ERROR, - restStatus.name().toLowerCase(Locale.ROOT) - ) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse, HuggingFaceErrorResponseEntity.class); } @Override - protected Exception buildMidStreamError(Request request, String message, Exception e) { - var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message); - if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. 
Error message: [%s]", - SERVER_ERROR_OBJECT, - request.getInferenceEntityId(), - errorResponse.getErrorMessage() - ), - HUGGING_FACE_ERROR, - extractErrorCode(streamingHuggingFaceErrorResponseEntity) - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - createErrorType(errorResponse), - "stream_error" - ); - } + protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + return new UnifiedChatCompletionException(restStatus, errorMessage, HUGGING_FACE_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); + } + + /** + * Builds an error for mid-stream responses from Hugging Face. + * This method is called when an error response is received during streaming operations. + * + * @param inferenceEntityId The ID of the inference entity that made the request. + * @param message The error message to include in the exception. + * @param e The exception that occurred. + * @return An instance of {@link UnifiedChatCompletionException} representing the error. + */ + @Override + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + return buildMidStreamChatCompletionError(inferenceEntityId, message, e, StreamingHuggingFaceErrorResponseEntity.class); + } + + @Override + protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. 
Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + errorResponse.getErrorMessage() + ), + HUGGING_FACE_ERROR, + extractErrorCode((StreamingHuggingFaceErrorResponseEntity) errorResponse) + ); + } + + @Override + protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { + return StreamingHuggingFaceErrorResponseEntity.fromString(message); } private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java index a9d6df687fe99..bb7ee509fa3dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.mistral; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; @@ -30,22 +31,16 @@ public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseP } @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var 
restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof MistralErrorResponse - ? new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse, MistralErrorResponse.class); + } + + @Override + protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + return new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index e1a0117c7bcca..7e779645c77d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse; -import java.util.Locale; import java.util.concurrent.Flow; import java.util.function.Function; @@ -45,64 +44,95 @@ public 
OpenAiUnifiedChatCompletionResponseHandler( @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); - var openAiProcessor = new OpenAiUnifiedStreamingProcessor((m, e) -> buildMidStreamError(request, m, e)); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor( + (m, e) -> buildMidStreamChatCompletionError(request.getInferenceEntityId(), m, e) + ); flow.subscribe(serverSentEventProcessor); serverSentEventProcessor.subscribe(openAiProcessor); return new StreamingUnifiedChatCompletionResults(openAiProcessor); } @Override - protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { - assert request.isStreaming() : "Only streaming requests support this format"; - var responseStatusCode = result.response().getStatusLine().getStatusCode(); - if (request.isStreaming()) { - var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); - var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof StreamingErrorResponse oer - ? new UnifiedChatCompletionException(restStatus, errorMessage, oer.type(), oer.code(), oer.param()) - : new UnifiedChatCompletionException( - restStatus, - errorMessage, - createErrorType(errorResponse), - restStatus.name().toLowerCase(Locale.ROOT) - ); - } else { - return super.buildError(message, request, result, errorResponse); - } + protected UnifiedChatCompletionException buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + return buildChatCompletionError(message, request, result, errorResponse, StreamingErrorResponse.class); } - protected static String createErrorType(ErrorResponse errorResponse) { - return errorResponse != null ? 
errorResponse.getClass().getSimpleName() : "unknown"; + /** + * Builds a custom {@link UnifiedChatCompletionException} for OpenAI inference endpoints. + * This method is called when an error response is received. + * + * @param errorResponse the parsed error response from the service + * @param errorMessage the error message received + * @param restStatus the HTTP status code of the error + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + @Override + protected UnifiedChatCompletionException buildProviderSpecificChatCompletionError( + ErrorResponse errorResponse, + String errorMessage, + RestStatus restStatus + ) { + var streamingError = (StreamingErrorResponse) errorResponse; + return new UnifiedChatCompletionException( + restStatus, + errorMessage, + streamingError.type(), + streamingError.code(), + streamingError.param() + ); + } + + /** + * Builds a custom mid-stream {@link UnifiedChatCompletionException} for OpenAI inference endpoints. + * This method is called when an error response is received during streaming. + * + * @param inferenceEntityId the ID of the inference entity + * @param message the error message received during streaming + * @param e the exception that occurred + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + @Override + public UnifiedChatCompletionException buildMidStreamChatCompletionError(String inferenceEntityId, String message, Exception e) { + // Use the custom type StreamingErrorResponse for mid-stream errors + return buildMidStreamChatCompletionError(inferenceEntityId, message, e, StreamingErrorResponse.class); } - protected Exception buildMidStreamError(Request request, String message, Exception e) { - return buildMidStreamError(request.getInferenceEntityId(), message, e); + /** + * Extracts the mid-stream error response from the message. 
+ * + * @param message the message containing the error response + * @return the extracted {@link ErrorResponse} + */ + @Override + protected ErrorResponse extractMidStreamChatCompletionErrorResponse(String message) { + return StreamingErrorResponse.fromString(message); } - public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) { - var errorResponse = StreamingErrorResponse.fromString(message); - if (errorResponse instanceof StreamingErrorResponse oer) { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format( - "%s for request from inference entity id [%s]. Error message: [%s]", - SERVER_ERROR_OBJECT, - inferenceEntityId, - errorResponse.getErrorMessage() - ), - oer.type(), - oer.code(), - oer.param() - ); - } else if (e != null) { - return UnifiedChatCompletionException.fromThrowable(e); - } else { - return new UnifiedChatCompletionException( - RestStatus.INTERNAL_SERVER_ERROR, - format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId), - createErrorType(errorResponse), - "stream_error" - ); - } + /** + * Builds a custom mid-stream {@link UnifiedChatCompletionException} for OpenAI inference endpoints. + * This method is called when an error response is received during streaming. + * + * @param inferenceEntityId the ID of the inference entity + * @param errorResponse the parsed error response from the service + * @return an instance of {@link UnifiedChatCompletionException} with details from the error response + */ + @Override + protected UnifiedChatCompletionException buildProviderSpecificMidStreamChatCompletionError( + String inferenceEntityId, + ErrorResponse errorResponse + ) { + var streamingError = (StreamingErrorResponse) errorResponse; + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. 
Error message: [%s]", + SERVER_ERROR_OBJECT, + inferenceEntityId, + streamingError.getErrorMessage() + ), + streamingError.type(), + streamingError.code(), + streamingError.param() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java index 64b42f00d2d5b..d7e506d690610 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiCompletionPayload.java @@ -22,7 +22,6 @@ import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; @@ -40,6 +39,11 @@ import java.util.Map; import java.util.stream.Stream; +/** + * Handles chat completion requests and responses for OpenAI models in SageMaker. + * This class implements the SageMakerStreamSchemaPayload interface to provide + * the necessary methods for handling OpenAI chat completions. 
+ */ public class OpenAiCompletionPayload implements SageMakerStreamSchemaPayload { private static final XContent jsonXContent = JsonXContent.jsonXContent; @@ -50,7 +54,7 @@ public class OpenAiCompletionPayload implements SageMakerStreamSchemaPayload { private static final String USER_FIELD = "user"; private static final String USER_ROLE = "user"; private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; - private static final ResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + private static final OpenAiUnifiedChatCompletionResponseHandler ERROR_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( "sagemaker openai chat completion", ((request, result) -> { assert false : "do not call this"; @@ -88,12 +92,12 @@ public StreamingUnifiedChatCompletionResults.Results chatCompletionResponseBody( var serverSentEvents = serverSentEvents(response); var results = serverSentEvents.flatMap(event -> { if ("error".equals(event.type())) { - throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), null); + throw ERROR_HANDLER.buildMidStreamChatCompletionError(model.getInferenceEntityId(), event.data(), null); } else { try { return OpenAiUnifiedStreamingProcessor.parse(parserConfig, event); } catch (Exception e) { - throw OpenAiUnifiedChatCompletionResponseHandler.buildMidStreamError(model.getInferenceEntityId(), event.data(), e); + throw ERROR_HANDLER.buildMidStreamChatCompletionError(model.getInferenceEntityId(), event.data(), e); } } })