Skip to content

Add Azure AI Rerank support #129848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/129848.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129848
summary: "[ML] Add Azure AI Rerank support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ static TransportVersion def(int id) {
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58);
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59);
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_60);
public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_ADDED_8_19 = def(8_841_0_61);

public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -322,6 +324,7 @@ static TransportVersion def(int id) {
public static final TransportVersion CLUSTER_STATE_PROJECTS_SETTINGS = def(9_108_0_00);
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_109_00_0);
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION = def(9_110_0_00);
public static final TransportVersion ML_INFERENCE_AZURE_AI_STUDIO_ADDED = def(9_111_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down Expand Up @@ -388,7 +391,7 @@ static TransportVersion def(int id) {
* Reference to the minimum transport version that can be used with CCS.
* This should be the transport version used by the previous minor release.
*/
// NOTE(review): this diff shows MINIMUM_CCS_VERSION moving from INITIAL_ELASTICSEARCH_9_0_3
// down to INITIAL_ELASTICSEARCH_9_0_2, which is unrelated to adding Azure AI rerank support.
// Per the javadoc above, this must track the previous minor release — confirm this is not an
// accidental revert (e.g. from an outdated merge base).
public static final TransportVersion MINIMUM_CCS_VERSION = INITIAL_ELASTICSEARCH_9_0_3;
public static final TransportVersion MINIMUM_CCS_VERSION = INITIAL_ELASTICSEARCH_9_0_2;

/**
* Sorted list of all versions defined in this class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionTaskSettings;
Expand Down Expand Up @@ -305,6 +307,17 @@ private static void addAzureAiStudioNamedWriteables(List<NamedWriteableRegistry.
AzureAiStudioChatCompletionTaskSettings::new
)
);

// Register wire (de)serialization entries for the new Azure AI Studio rerank
// service settings and task settings, mirroring the embeddings/completion entries above.
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
AzureAiStudioRerankServiceSettings.NAME,
AzureAiStudioRerankServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, AzureAiStudioRerankTaskSettings.NAME, AzureAiStudioRerankTaskSettings::new)
);
}

private static void addAzureOpenAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
public class AzureAiStudioConstants {
public static final String EMBEDDINGS_URI_PATH = "/v1/embeddings";
public static final String COMPLETIONS_URI_PATH = "/v1/chat/completions";
public static final String RERANK_URI_PATH = "/v1/rerank";

// common service settings fields
public static final String TARGET_FIELD = "target";
Expand All @@ -22,6 +23,10 @@ public class AzureAiStudioConstants {
public static final String DIMENSIONS_FIELD = "dimensions";
public static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";

// rerank request body fields
// NOTE(review): the "// rerank task settings fields" header appears again further down
// (above RETURN_DOCUMENTS_FIELD / TOP_N_FIELD); "documents" and "query" look like request
// payload keys rather than task settings — confirm and keep only one header per group.
public static final String DOCUMENTS_FIELD = "documents";
public static final String QUERY_FIELD = "query";

// embeddings task settings fields
public static final String USER_FIELD = "user";

Expand All @@ -35,5 +40,9 @@ public class AzureAiStudioConstants {
public static final Double MIN_TEMPERATURE_TOP_P = 0.0;
public static final Double MAX_TEMPERATURE_TOP_P = 2.0;

// rerank task settings fields
public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
public static final String TOP_N_FIELD = "top_n";

private AzureAiStudioConstants() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ public final class AzureAiStudioProviderCapabilities {
// these providers have chat completion inference (all providers at the moment)
public static final List<AzureAiStudioProvider> chatCompletionProviders = List.of(AzureAiStudioProvider.values());

// these providers allow token ("pay as you go") embeddings endpoints
public static final List<AzureAiStudioProvider> rerankProviders = List.of(AzureAiStudioProvider.COHERE);

// these providers allow token ("pay as you go") embeddings endpoints
public static final List<AzureAiStudioProvider> tokenRerankProviders = List.of(AzureAiStudioProvider.COHERE);

// these providers allow realtime rerank endpoints (none at the moment)
public static final List<AzureAiStudioProvider> realtimeRerankProviders = List.of();

// these providers allow token ("pay as you go") embeddings endpoints
public static final List<AzureAiStudioProvider> tokenEmbeddingsProviders = List.of(
AzureAiStudioProvider.OPENAI,
Expand Down Expand Up @@ -54,6 +63,9 @@ public static boolean providerAllowsTaskType(AzureAiStudioProvider provider, Tas
case TEXT_EMBEDDING -> {
return embeddingProviders.contains(provider);
}
case RERANK -> {
return rerankProviders.contains(provider);
}
default -> {
return false;
}
Expand All @@ -76,6 +88,11 @@ public static boolean providerAllowsEndpointTypeForTask(
? tokenEmbeddingsProviders.contains(provider)
: realtimeEmbeddingsProviders.contains(provider);
}
case RERANK -> {
return (endpointType == AzureAiStudioEndpointType.TOKEN)
? tokenRerankProviders.contains(provider)
: realtimeRerankProviders.contains(provider);
}
default -> {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.azureaistudio;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
import org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRerankRequest;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.response.AzureAiStudioRerankResponseEntity;
import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler;

import java.util.function.Supplier;

/**
 * Builds and dispatches rerank requests for an Azure AI Studio deployment.
 * Converts generic {@link InferenceInputs} into an {@link AzureAiStudioRerankRequest}
 * and executes it through the shared retrying request sender.
 */
public class AzureAiStudioRerankRequestManager extends AzureAiStudioRequestManager {

    private static final Logger logger = LogManager.getLogger(AzureAiStudioRerankRequestManager.class);

    // Shared, stateless response handler — safe to reuse across requests.
    private static final ResponseHandler HANDLER = createRerankHandler();

    private final AzureAiStudioRerankModel model;

    public AzureAiStudioRerankRequestManager(AzureAiStudioRerankModel model, ThreadPool threadPool) {
        super(threadPool, model);
        this.model = model;
    }

    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestRerankFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        var inputs = QueryAndDocsInputs.of(inferenceInputs);
        var rerankRequest = new AzureAiStudioRerankRequest(
            model,
            inputs.getQuery(),
            inputs.getChunks(),
            inputs.getReturnDocuments(),
            inputs.getTopN()
        );

        execute(new ExecutableInferenceRequest(requestSender, logger, rerankRequest, HANDLER, hasRequestRerankFunction, listener));
    }

    // This currently covers response handling for Azure AI Studio rerank responses,
    // which share the error format handled by AzureMistralOpenAiExternalResponseHandler.
    private static ResponseHandler createRerankHandler() {
        return new AzureMistralOpenAiExternalResponseHandler(
            "azure ai studio rerank",
            new AzureAiStudioRerankResponseEntity(),
            ErrorMessageResponseEntity::fromResponse,
            true
        );
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

Expand Down Expand Up @@ -74,7 +75,7 @@ public class AzureAiStudioService extends SenderService {
static final String NAME = "azureaistudio";

private static final String SERVICE_NAME = "Azure AI Studio";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.RERANK);

private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
InputType.INGEST,
Expand Down Expand Up @@ -307,6 +308,24 @@ private static AzureAiStudioModel createModel(
return completionModel;
}

if (taskType == TaskType.RERANK) {
var rerankModel = new AzureAiStudioRerankModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings,
context
);
checkProviderAndEndpointTypeForTask(
TaskType.RERANK,
rerankModel.getServiceSettings().provider(),
rerankModel.getServiceSettings().endpointType()
);
return rerankModel;
}

throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioChatCompletionRequestManager;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioRerankRequestManager;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;

import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -49,4 +51,12 @@ public ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio embeddings");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}

/**
 * Creates an executable rerank action. Request-time task settings are first applied
 * on top of the configured model (via {@code AzureAiStudioRerankModel.of}) so that
 * per-request overrides such as return_documents/top_n take effect.
 */
@Override
public ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings) {
var overriddenModel = AzureAiStudioRerankModel.of(rerankModel, taskSettings);
var requestManager = new AzureAiStudioRerankRequestManager(overriddenModel, serviceComponents.threadPool());
var errorMessage = constructFailedToSendRequestMessage("Azure AI Studio rerank");
return new SenderExecutableAction(sender, requestManager, errorMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;

import java.util.Map;

/**
 * Visitor that builds an {@link ExecutableAction} for each supported Azure AI Studio
 * model type. Each overload receives the model plus request-time task settings that
 * may override the model's configured settings.
 */
public interface AzureAiStudioActionVisitor {
ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings);

ExecutableAction create(AzureAiStudioChatCompletionModel completionModel, Map<String, Object> taskSettings);

ExecutableAction create(AzureAiStudioRerankModel rerankModel, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.azureaistudio.request;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;

import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

/**
 * HTTP request for the Azure AI Studio rerank endpoint.
 * Serializes the query, the documents to rerank, and optional return_documents/top_n
 * parameters into a JSON POST body, with auth headers taken from the model's settings.
 */
public class AzureAiStudioRerankRequest extends AzureAiStudioRequest {
    private final String query;
    // The documents to be reranked against the query.
    private final List<String> input;
    private final Boolean returnDocuments;
    private final Integer topN;
    private final AzureAiStudioRerankModel rerankModel;

    /**
     * @param model the configured rerank model (never null)
     * @param query the rerank query (never null — required by the rerank API)
     * @param input the documents to rerank (never null)
     * @param returnDocuments whether document text should be echoed in the response; may be null
     * @param topN maximum number of results to return; may be null
     */
    public AzureAiStudioRerankRequest(
        AzureAiStudioRerankModel model,
        String query,
        List<String> input,
        @Nullable Boolean returnDocuments,
        @Nullable Integer topN
    ) {
        super(model);
        this.rerankModel = Objects.requireNonNull(model);
        // Fail fast on a missing query, consistent with the null check on the documents list.
        this.query = Objects.requireNonNull(query);
        this.input = Objects.requireNonNull(input);
        this.returnDocuments = returnDocuments;
        this.topN = topN;
    }

    @Override
    public HttpRequest createHttpRequest() {
        var httpPost = new HttpPost(this.uri);

        var byteEntity = new ByteArrayEntity(Strings.toString(createRequestEntity()).getBytes(StandardCharsets.UTF_8));
        httpPost.setEntity(byteEntity);

        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
        setAuthHeader(httpPost, rerankModel);

        return new HttpRequest(httpPost, getInferenceEntityId());
    }

    @Override
    public Request truncate() {
        // Not applicable for rerank, only used in text embedding requests
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        // Not applicable for rerank, only used in text embedding requests
        return null;
    }

    // Builds the JSON body, folding in the model's task settings (e.g. defaults for
    // return_documents / top_n) alongside the per-request values.
    private AzureAiStudioRerankRequestEntity createRequestEntity() {
        var taskSettings = rerankModel.getTaskSettings();
        return new AzureAiStudioRerankRequestEntity(query, input, returnDocuments, topN, taskSettings);
    }
}
Loading
Loading