Skip to content

Commit d8e97b7

Browse files
Adding configurable inference service (#127939) (#128644)
* Inference changes
* Custom service fixes
* Update docs/changelog/127939.yaml
* Cleaning up from failed merge
* Fixing changelog
* [CI] Auto commit changes from spotless
* Fixing test
* Adding feature flag
* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <[email protected]>

(cherry picked from commit 9db1837)

# Conflicts:
#   server/src/main/java/org/elasticsearch/TransportVersions.java
#   test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java
1 parent 38e16d4 commit d8e97b7

File tree

65 files changed

+5523
-163
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

65 files changed

+5523
-163
lines changed

docs/changelog/127939.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127939
2+
summary: Add Custom inference service
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ static TransportVersion def(int id) {
228228
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
229229
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
230230
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_38);
231+
public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED_8_19 = def(8_841_0_39);
231232

232233
/*
233234
* STOP! READ THIS FIRST! No, really,

test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
*/
1818
public enum FeatureFlag {
1919
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
20-
SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
20+
SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
21+
INFERENCE_CUSTOM_SERVICE_ENABLED("es.inference_custom_service_feature_flag_enabled=true", Version.fromString("8.19.0"), null);
2122

2223
public final String systemProperty;
2324
public final Version from;

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.util.concurrent.ThreadContext;
1515
import org.elasticsearch.core.TimeValue;
1616
import org.elasticsearch.test.cluster.ElasticsearchCluster;
17+
import org.elasticsearch.test.cluster.FeatureFlag;
1718
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
1819
import org.elasticsearch.test.rest.ESRestTestCase;
1920
import org.junit.ClassRule;
@@ -46,6 +47,7 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase {
4647
// This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin
4748
.plugin("inference-service-test")
4849
.user("x_pack_rest_user", "x-pack-test-password")
50+
.feature(FeatureFlag.INFERENCE_CUSTOM_SERVICE_ENABLED)
4951
.build();
5052

5153
// The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.core.Nullable;
2121
import org.elasticsearch.inference.TaskType;
2222
import org.elasticsearch.test.cluster.ElasticsearchCluster;
23+
import org.elasticsearch.test.cluster.FeatureFlag;
2324
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
2425
import org.elasticsearch.test.rest.ESRestTestCase;
2526
import org.elasticsearch.xcontent.XContentBuilder;
@@ -50,6 +51,7 @@ public class InferenceBaseRestTest extends ESRestTestCase {
5051
.setting("xpack.security.enabled", "true")
5152
.plugin("inference-service-test")
5253
.user("x_pack_rest_user", "x-pack-test-password")
54+
.feature(FeatureFlag.INFERENCE_CUSTOM_SERVICE_ENABLED)
5355
.build();
5456

5557
@ClassRule

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2525

2626
public void testGetServicesWithoutTaskType() throws IOException {
2727
List<Object> services = getAllServices();
28-
assertThat(services.size(), equalTo(22));
28+
assertThat(services.size(), equalTo(23));
2929

3030
var providers = providers(services);
3131

@@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
3939
"azureaistudio",
4040
"azureopenai",
4141
"cohere",
42+
"custom",
4243
"deepseek",
4344
"elastic",
4445
"elasticsearch",
@@ -70,7 +71,7 @@ private Iterable<String> providers(List<Object> services) {
7071

7172
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
7273
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
73-
assertThat(services.size(), equalTo(16));
74+
assertThat(services.size(), equalTo(17));
7475

7576
var providers = providers(services);
7677

@@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
8384
"azureaistudio",
8485
"azureopenai",
8586
"cohere",
87+
"custom",
8688
"elasticsearch",
8789
"googleaistudio",
8890
"googlevertexai",
@@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
101103

102104
public void testGetServicesWithRerankTaskType() throws IOException {
103105
List<Object> services = getServices(TaskType.RERANK);
104-
assertThat(services.size(), equalTo(8));
106+
assertThat(services.size(), equalTo(9));
105107

106108
var providers = providers(services);
107109

@@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
111113
List.of(
112114
"alibabacloud-ai-search",
113115
"cohere",
116+
"custom",
114117
"elasticsearch",
115118
"googlevertexai",
116119
"jinaai",
@@ -124,7 +127,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
124127

125128
public void testGetServicesWithCompletionTaskType() throws IOException {
126129
List<Object> services = getServices(TaskType.COMPLETION);
127-
assertThat(services.size(), equalTo(12));
130+
assertThat(services.size(), equalTo(13));
128131

129132
var providers = providers(services);
130133

@@ -138,6 +141,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
138141
"azureaistudio",
139142
"azureopenai",
140143
"cohere",
144+
"custom",
141145
"deepseek",
142146
"googleaistudio",
143147
"openai",
@@ -173,7 +177,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
173177

174178
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
175179
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
176-
assertThat(services.size(), equalTo(6));
180+
assertThat(services.size(), equalTo(7));
177181

178182
var providers = providers(services);
179183

@@ -182,6 +186,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
182186
containsInAnyOrder(
183187
List.of(
184188
"alibabacloud-ai-search",
189+
"custom",
185190
"elastic",
186191
"elasticsearch",
187192
"hugging_face",
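x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/CustomServiceFeatureFlag.java

Lines changed: 21 additions & 0 deletions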
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference;
9+
10+
import org.elasticsearch.common.util.FeatureFlag;
11+
12+
public class CustomServiceFeatureFlag {
13+
/**
14+
* {@link org.elasticsearch.xpack.inference.services.custom.CustomService} feature flag. When the feature is complete,
15+
* this flag will be removed.
16+
* Enable feature via JVM option: `-Des.inference_custom_service_feature_flag_enabled=true`.
17+
*/
18+
public static final FeatureFlag CUSTOM_SERVICE_FEATURE_FLAG = new FeatureFlag("inference_custom_service");
19+
20+
private CustomServiceFeatureFlag() {}
21+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@
5959
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
6060
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
6161
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
62+
import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
63+
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
64+
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
65+
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
66+
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
67+
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
68+
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
69+
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
70+
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
6271
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
6372
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6473
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
@@ -108,6 +117,8 @@
108117
import java.util.ArrayList;
109118
import java.util.List;
110119

120+
import static org.elasticsearch.xpack.inference.CustomServiceFeatureFlag.CUSTOM_SERVICE_FEATURE_FLAG;
121+
111122
public class InferenceNamedWriteablesProvider {
112123

113124
private InferenceNamedWriteablesProvider() {}
@@ -158,6 +169,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
158169
addAlibabaCloudSearchNamedWriteables(namedWriteables);
159170
addJinaAINamedWriteables(namedWriteables);
160171
addVoyageAINamedWriteables(namedWriteables);
172+
addCustomNamedWriteables(namedWriteables);
161173

162174
addUnifiedNamedWriteables(namedWriteables);
163175

@@ -169,6 +181,42 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
169181
return namedWriteables;
170182
}
171183

184+
private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
185+
if (CUSTOM_SERVICE_FEATURE_FLAG.isEnabled() == false) {
186+
return;
187+
}
188+
189+
namedWriteables.add(
190+
new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new)
191+
);
192+
193+
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new));
194+
195+
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new));
196+
197+
namedWriteables.add(
198+
new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new)
199+
);
200+
201+
namedWriteables.add(
202+
new NamedWriteableRegistry.Entry(
203+
CustomResponseParser.class,
204+
SparseEmbeddingResponseParser.NAME,
205+
SparseEmbeddingResponseParser::new
206+
)
207+
);
208+
209+
namedWriteables.add(
210+
new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new)
211+
);
212+
213+
namedWriteables.add(new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new));
214+
215+
namedWriteables.add(
216+
new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new)
217+
);
218+
}
219+
172220
private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
173221
var writeables = UnifiedCompletionRequest.getNamedWriteables();
174222
namedWriteables.addAll(writeables);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
122122
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
123123
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
124+
import org.elasticsearch.xpack.inference.services.custom.CustomService;
124125
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
125126
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
126127
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
@@ -150,8 +151,10 @@
150151
import java.util.Set;
151152
import java.util.function.Predicate;
152153
import java.util.function.Supplier;
154+
import java.util.stream.Stream;
153155

154156
import static java.util.Collections.singletonList;
157+
import static org.elasticsearch.xpack.inference.CustomServiceFeatureFlag.CUSTOM_SERVICE_FEATURE_FLAG;
155158
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
156159
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
157160

@@ -381,7 +384,11 @@ public void loadExtensions(ExtensionLoader loader) {
381384
}
382385

383386
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
384-
return List.of(
387+
List<InferenceServiceExtension.Factory> conditionalServices = CUSTOM_SERVICE_FEATURE_FLAG.isEnabled()
388+
? List.of(context -> new CustomService(httpFactory.get(), serviceComponents.get()))
389+
: List.of();
390+
391+
List<InferenceServiceExtension.Factory> availableServices = List.of(
385392
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
386393
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
387394
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
@@ -400,6 +407,8 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
400407
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
401408
ElasticsearchInternalService::new
402409
);
410+
411+
return Stream.concat(availableServices.stream(), conditionalServices.stream()).toList();
403412
}
404413

405414
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler {
3636
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";
3737

3838
protected final String requestType;
39-
private final ResponseParser parseFunction;
39+
protected final ResponseParser parseFunction;
4040
private final Function<HttpResult, ErrorResponse> errorParseFunction;
4141
private final boolean canHandleStreamingResponses;
4242

0 commit comments

Comments
 (0)