Skip to content

[ML] Removing secure string wrapper for custom service #128698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;

Expand All @@ -23,7 +24,11 @@ public class JsonUtils {
public static <T> String toJson(T value, String field) {
try {
XContentBuilder builder = JsonXContent.contentBuilder();
builder.value(value);
if (value instanceof SecureString secureString) {
builder.value(secureString.toString());
} else {
builder.value(value);
}
return Strings.toString(builder);
} catch (Exception e) {
throw new IllegalStateException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;

import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -652,7 +651,7 @@ public static void validateMapValues(
}
}

public static Map<String, SerializableSecureString> convertMapStringsToSecureString(
public static Map<String, SecureString> convertMapStringsToSecureString(
Map<String, ?> map,
String settingName,
ValidationException validationException
Expand All @@ -661,11 +660,11 @@ public static Map<String, SerializableSecureString> convertMapStringsToSecureStr
return Map.of();
}

validateMapStringValues(map, settingName, validationException, true);
var validatedMap = validateMapStringValues(map, settingName, validationException, true);

return map.entrySet()
return validatedMap.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> new SerializableSecureString((String) e.getValue())));
.collect(Collectors.toMap(Map.Entry::getKey, e -> new SecureString(e.getValue().toCharArray())));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;

import java.io.IOException;
import java.util.HashMap;
Expand Down Expand Up @@ -48,22 +48,22 @@ public static CustomSecretSettings fromMap(@Nullable Map<String, Object> map) {
return new CustomSecretSettings(secureStringMap);
}

private final Map<String, SerializableSecureString> secretParameters;
private final Map<String, SecureString> secretParameters;

@Override
public SecretSettings newSecretSettings(Map<String, Object> newSecrets) {
return fromMap(new HashMap<>(newSecrets));
}

public CustomSecretSettings(@Nullable Map<String, SerializableSecureString> secretParameters) {
public CustomSecretSettings(@Nullable Map<String, SecureString> secretParameters) {
this.secretParameters = Objects.requireNonNullElse(secretParameters, Map.of());
}

public CustomSecretSettings(StreamInput in) throws IOException {
secretParameters = in.readImmutableMap(SerializableSecureString::new);
secretParameters = in.readImmutableMap(StreamInput::readSecureString);
}

public Map<String, SerializableSecureString> getSecretParameters() {
public Map<String, SecureString> getSecretParameters() {
return secretParameters;
}

Expand All @@ -74,7 +74,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject(SECRET_PARAMETERS);
{
for (var entry : secretParameters.entrySet()) {
builder.field(entry.getKey(), entry.getValue());
builder.field(entry.getKey(), entry.getValue().toString());
}
}
builder.endObject();
Expand All @@ -95,7 +95,7 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeMap(secretParameters, (streamOutput, v) -> { v.writeTo(streamOutput); });
out.writeMap(secretParameters, StreamOutput::writeSecureString);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.custom.CustomModel;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;

import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -70,8 +69,6 @@ private static void addStringParams(Map<String, String> stringParams, Map<String
for (var entry : paramsToAdd.entrySet()) {
if (entry.getValue() instanceof String str) {
stringParams.put(entry.getKey(), str);
} else if (entry.getValue() instanceof SerializableSecureString serializableSecureString) {
stringParams.put(entry.getKey(), serializableSecureString.getSecureString().toString());
} else if (entry.getValue() instanceof SecureString secureString) {
stringParams.put(entry.getKey(), secureString.toString());
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentFragment;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -53,16 +52,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
assertThat(toJson(1.1f, "field"), is("1.1"));
assertThat(toJson(true, "field"), is("true"));
assertThat(toJson(false, "field"), is("false"));
assertThat(toJson(new SerializableSecureString("api_key"), "field"), is("\"api_key\""));
assertThat(toJson(new SecureString("api_key".toCharArray()), "field"), is("\"api_key\""));
}

public void testToJson_ThrowsException_WhenUnableToSerialize() {
var exception = expectThrows(IllegalStateException.class, () -> toJson(new SecureString("string".toCharArray()), "field"));
var exception = expectThrows(IllegalStateException.class, () -> toJson(new Object(), "field"));
assertThat(
exception.getMessage(),
is(
"Failed to serialize value as JSON, field: field, error: "
+ "cannot write xcontent for unknown value of type class org.elasticsearch.common.settings.SecureString"
+ "cannot write xcontent for unknown value of type class java.lang.Object"
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Booleans;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;

import java.util.EnumSet;
import java.util.HashMap;
Expand Down Expand Up @@ -1116,7 +1116,7 @@ public void testConvertMapStringsToSecureString() {
var validation = new ValidationException();
assertThat(
convertMapStringsToSecureString(Map.of("key", "value", "key2", "abc"), "setting", validation),
is(Map.of("key", new SerializableSecureString("value"), "key2", new SerializableSecureString("abc")))
is(Map.of("key", new SecureString("value".toCharArray()), "key2", new SecureString("abc".toCharArray())))
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.custom;

import org.apache.http.HttpHeaders;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
Expand All @@ -17,7 +18,6 @@
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
import org.hamcrest.MatcherAssert;

import java.util.HashMap;
Expand All @@ -30,7 +30,7 @@ public class CustomModelTests extends ESTestCase {
private static final String taskSettingsValue = "test_taskSettings_value";

private static final String secretSettingsKey = "test_secret_key";
private static final SerializableSecureString secretSettingsValue = new SerializableSecureString("test_secret_value");
private static final SecureString secretSettingsValue = new SecureString("test_secret_value".toCharArray());
private static final String url = "http://www.abc.com";

public void testOverride_DoesNotModifiedFields_TaskSettingsIsEmpty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
Expand All @@ -19,7 +20,6 @@
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;
import org.junit.After;
import org.junit.Before;

Expand Down Expand Up @@ -73,7 +73,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() {
TaskType.RERANK,
serviceSettings,
new CustomTaskSettings(Map.of("url", "^")),
new CustomSecretSettings(Map.of("api_key", new SerializableSecureString("my-secret-key")))
new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
);

var listener = new PlainActionFuture<InferenceServiceResults>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;

import java.io.IOException;
import java.util.HashMap;
Expand All @@ -28,10 +28,10 @@

public class CustomSecretSettingsTests extends AbstractBWCWireSerializationTestCase<CustomSecretSettings> {
public static CustomSecretSettings createRandom() {
Map<String, SerializableSecureString> secretParameters = randomMap(
Map<String, SecureString> secretParameters = randomMap(
0,
5,
() -> tuple(randomAlphaOfLength(5), new SerializableSecureString(randomAlphaOfLength(5)))
() -> tuple(randomAlphaOfLength(5), new SecureString(randomAlphaOfLength(5).toCharArray()))
);

return new CustomSecretSettings(secretParameters);
Expand All @@ -44,7 +44,7 @@ public void testFromMap() {

assertThat(
CustomSecretSettings.fromMap(secretParameters),
is(new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value"))))
is(new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))))
);
}

Expand All @@ -59,7 +59,7 @@ public void testFromMap_RemovesNullValues() {

assertThat(
CustomSecretSettings.fromMap(modifiableMap(Map.of(CustomSecretSettings.SECRET_PARAMETERS, mapWithNulls))),
is(new CustomSecretSettings(Map.of("value", new SerializableSecureString("abc"))))
is(new CustomSecretSettings(Map.of("value", new SecureString("abc".toCharArray()))))
);
}

Expand Down Expand Up @@ -87,7 +87,7 @@ public void testFromMap_DefaultsToEmptyMap_WhenSecretParametersField_DoesNotExis
}

public void testXContent() throws IOException {
var entity = new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value")));
var entity = new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray())));

XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.custom;

import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.InferenceServiceResults;
Expand Down Expand Up @@ -35,7 +36,6 @@
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.settings.SerializableSecureString;

import java.io.IOException;
import java.util.EnumSet;
Expand Down Expand Up @@ -123,7 +123,7 @@ private static CustomModel assertCommonModelFields(Model model) {
assertThat(customModel.getTaskSettings().getParameters(), is(Map.of("test_key", "test_value")));
assertThat(
customModel.getSecretSettings().getSecretParameters(),
is(Map.of("test_key", new SerializableSecureString("test_value")))
is(Map.of("test_key", new SecureString("test_value".toCharArray())))
);

return customModel;
Expand Down Expand Up @@ -249,7 +249,7 @@ private static CustomModel createInternalEmbeddingModel(
new ErrorResponseParser("$.error.message", inferenceId)
),
new CustomTaskSettings(Map.of("key", "test_value")),
new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value")))
new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray())))
);
}

Expand All @@ -271,7 +271,7 @@ private static CustomModel createCustomModel(TaskType taskType, CustomResponsePa
new ErrorResponseParser("$.error.message", inferenceId)
),
new CustomTaskSettings(Map.of("key", "test_value")),
new CustomSecretSettings(Map.of("test_key", new SerializableSecureString("test_value")))
new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray())))
);
}

Expand Down
Loading