Skip to content

[8.19] Add Mistral AI Chat Completion support to Inference Plugin (#128538) #128947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
5 changes: 5 additions & 0 deletions docs/changelog/128538.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 128538
summary: "Added Mistral Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REGEX_MATCH_WITH_CASE_INSENSITIVITY_8_19 = def(8_841_0_44);
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION_8_19 = def(8_841_0_45);
public static final TransportVersion SEARCH_SOURCE_EXCLUDE_VECTORS_PARAM_8_19 = def(8_841_0_46);

public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_47);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ public record UnifiedCompletionRequest(
* {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name.
*/
private static final String MAX_TOKENS_PARAM = "max_tokens_field";
/**
* Indicates whether to include the `stream_options` field in the JSON output.
* Some providers do not support this field. In such cases, this parameter should be set to "false",
* and the `stream_options` field will be excluded from the output.
* For providers that do support stream options, this parameter is left unset (default behavior),
* which implicitly includes the `stream_options` field in the output.
*/
public static final String INCLUDE_STREAM_OPTIONS_PARAM = "include_stream_options";

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
Expand All @@ -91,6 +99,23 @@ public static Params withMaxTokens(String modelId, Params params) {
);
}

/**
 * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
 * - Key: {@link #MODEL_ID_PARAM}, Value: modelId
 * - Key: {@link #MAX_TOKENS_PARAM}, Value: {@link #MAX_TOKENS_FIELD}
 * - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false"
 *
 * Same as {@code withMaxTokens}, but additionally sets {@link #INCLUDE_STREAM_OPTIONS_PARAM} to "false"
 * so serializers omit the {@code stream_options} field, for providers that do not support it.
 */
public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) {
    return new DelegatingMapParams(
        Map.ofEntries(
            Map.entry(MODEL_ID_PARAM, modelId),
            Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD),
            Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString())
        ),
        params
    );
}

/**
* Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
* - Key: {@link #MODEL_FIELD}, Value: modelId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(13));
assertThat(services.size(), equalTo(14));

var providers = providers(services);

Expand All @@ -154,15 +154,16 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"openai",
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker"
"amazon_sagemaker",
"mistral"
).toArray()
)
);
}

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));

var providers = providers(services);

Expand All @@ -176,7 +177,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
"streaming_completion_test_service",
"hugging_face",
"amazon_sagemaker",
"googlevertexai"
"googlevertexai",
"mistral"
).toArray()
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
Expand Down Expand Up @@ -266,6 +267,13 @@ private static void addMistralNamedWriteables(List<NamedWriteableRegistry.Entry>
MistralEmbeddingsServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
MistralChatCompletionServiceSettings.NAME,
MistralChatCompletionServiceSettings::new
)
);

// note - no task settings for Mistral embeddings...
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
* A pattern is emerging in how external providers provide error responses.
*
* At a minimum, these return:
* <pre><code>
* {
 * "error": {
* "message": "(error message)"
* }
* }
*
* </code></pre>
* Others may return additional information such as error codes specific to the service.
*
* This currently covers error handling for Azure AI Studio, however this pattern
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.streaming;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;

import java.util.Objects;
import java.util.Optional;

/**
 * Represents an error response from a streaming inference service.
 * This class extends {@link ErrorResponse} and provides additional fields
 * specific to streaming errors, such as code, param, and type.
 * An example error response for a streaming service might look like:
 * <pre><code>
 * {
 *     "error": {
 *         "message": "Invalid input",
 *         "code": "400",
 *         "param": "input",
 *         "type": "invalid_request_error"
 *     }
 * }
 * </code></pre>
 * TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication.
 */
public class StreamingErrorResponse extends ErrorResponse {
    private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
        "streaming_error",
        true,
        args -> Optional.ofNullable((StreamingErrorResponse) args[0])
    );
    private static final ConstructingObjectParser<StreamingErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
        "streaming_error",
        true,
        args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
    );

    static {
        // Declaration order must match the constructor-arg indices used in ERROR_BODY_PARSER above:
        // args[0] = message, args[1] = code, args[2] = param, args[3] = type.
        ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
        ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code"));
        ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param"));
        ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type"));

        ERROR_PARSER.declareObjectOrNull(
            ConstructingObjectParser.optionalConstructorArg(),
            ERROR_BODY_PARSER,
            null,
            new ParseField("error")
        );
    }

    /** Supplies an {@link XContentParser}; creation may throw (e.g. on malformed input). */
    @FunctionalInterface
    private interface ParserSupplier {
        XContentParser get() throws Exception;
    }

    /**
     * Shared parse path for {@link #fromResponse} and {@link #fromString}.
     * Any failure (malformed JSON, unexpected shape) is intentionally swallowed and mapped to
     * {@link ErrorResponse#UNDEFINED_ERROR}: error-body parsing must never throw and thereby
     * mask the original request failure.
     */
    private static ErrorResponse parse(ParserSupplier parserSupplier) {
        try (XContentParser parser = parserSupplier.get()) {
            return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
        } catch (Exception e) {
            // Intentionally swallowed — fall through to UNDEFINED_ERROR.
        }

        return ErrorResponse.UNDEFINED_ERROR;
    }

    /**
     * Standard error response parser. This can be overridden for those subclasses that
     * have a different error response structure.
     * @param response The error response as an HttpResult
     */
    public static ErrorResponse fromResponse(HttpResult response) {
        return parse(
            () -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())
        );
    }

    /**
     * Standard error response parser. This can be overridden for those subclasses that
     * have a different error response structure.
     * @param response The error response as a string
     */
    public static ErrorResponse fromString(String response) {
        return parse(() -> XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response));
    }

    @Nullable
    private final String code;
    @Nullable
    private final String param;
    private final String type;

    StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
        super(errorMessage);
        this.code = code;
        this.param = param;
        this.type = Objects.requireNonNull(type);
    }

    /** Provider-specific error code, e.g. "400"; may be absent. */
    @Nullable
    public String code() {
        return code;
    }

    /** Name of the offending request parameter, if the provider reported one. */
    @Nullable
    public String param() {
        return param;
    }

    /** Error type/category reported by the provider; always present (required by the parser). */
    public String type() {
        return type;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.inference.UnifiedCompletionRequest.INCLUDE_STREAM_OPTIONS_PARAM;

/**
* Represents a unified chat completion request entity.
* This class is used to convert the unified chat input into a format that can be serialized to XContent.
*/
public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {

public static final String STREAM_FIELD = "stream";
Expand Down Expand Up @@ -42,7 +48,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);

builder.field(STREAM_FIELD, stream);
if (stream) {
// If request is streamed and skip stream options parameter is not true, include stream options in the request.
if (stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true)) {
builder.startObject(STREAM_OPTIONS_FIELD);
builder.field(INCLUDE_USAGE_FIELD, true);
builder.endObject();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.mistral;

import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;

/**
 * Handles non-streaming completion responses for Mistral models, extending the OpenAI completion response handler.
 * The only behavioral difference from the parent is the error-body parser: Mistral's error format is
 * parsed via {@link MistralErrorResponse#fromResponse}, which is passed to the superclass constructor.
 */
public class MistralCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {

    /**
     * Constructs a MistralCompletionResponseHandler with the specified request type and response parser.
     * Success-path parsing and retry logic are inherited unchanged from
     * {@code OpenAiChatCompletionResponseHandler}.
     *
     * @param requestType The type of request being handled (e.g., "mistral completions"); used in error reporting.
     * @param parseFunction The function to parse the (successful) response.
     */
    public MistralCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, MistralErrorResponse::fromResponse);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

public class MistralConstants {
public static final String API_EMBEDDINGS_PATH = "https://api.mistral.ai/v1/embeddings";
public static final String API_COMPLETIONS_PATH = "https://api.mistral.ai/v1/chat/completions";

// note - there is no bounds information available from Mistral,
// so we'll use a sane default here which is the same as Cohere's
Expand All @@ -18,4 +19,8 @@ public class MistralConstants {
public static final String MODEL_FIELD = "model";
public static final String INPUT_FIELD = "input";
public static final String ENCODING_FORMAT_FIELD = "encoding_format";
public static final String MAX_TOKENS_FIELD = "max_tokens";
public static final String DETAIL_FIELD = "detail";
public static final String MSG_FIELD = "msg";
public static final String MESSAGE_FIELD = "message";
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.mistral.request.MistralEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest;
import org.elasticsearch.xpack.inference.services.mistral.response.MistralEmbeddingsResponseEntity;

import java.util.List;
Expand Down
Loading