Skip to content

Commit

Permalink
Added header to HTTP credential provider (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavr12 authored Aug 12, 2024
1 parent b0d1a0d commit 231b974
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 9 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,18 @@ spark = SparkSession\

# try spark sql commands
```

## Module Configurations

### HTTP Credentials Provider

The HTTP credentials provider provides an option to include additional headers on requests sent to the HTTP service (e.g., for authentication).

These can be configured with `credentials-provider.http.headers`. This config entry is formatted as a comma-separated list of header names and values, where each entry is in the format `header-name:header-value`.

For instance, `header1:value1,header2:value2`.
If a header name or value should contain a comma, these can be escaped by doubling them (`,,` translates to a single comma in the literal header name or value, and is not treated as a separator).

E.g.: setting this config property to `"x-api-key: xyz,,123, Authorization: key,,,,123"` results in 2 headers:
- `x-api-key`: with value `xyz,123`
- `Authorization`: with value `key,,123`
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multimaps;
import com.google.inject.Inject;
import io.airlift.http.client.FullJsonResponseHandler.JsonResponse;
import io.airlift.http.client.HttpClient;
Expand All @@ -28,6 +30,7 @@
import jakarta.ws.rs.core.UriBuilder;

import java.net.URI;
import java.util.Map;
import java.util.Optional;

import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler;
Expand All @@ -40,6 +43,7 @@ public class HttpCredentialsProvider
private final HttpClient httpClient;
private final JsonCodec<Credentials> jsonCodec;
private final URI httpCredentialsProviderEndpoint;
private final Map<String, String> httpHeaders;

@Inject
public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient, HttpCredentialsProviderConfig config, ObjectMapper objectMapper, Class<? extends Identity> identityClass)
Expand All @@ -49,17 +53,18 @@ public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient
this.httpCredentialsProviderEndpoint = config.getEndpoint();
ObjectMapper adjustedObjectMapper = objectMapper.registerModule(new SimpleModule().addAbstractTypeMapping(Identity.class, identityClass));
this.jsonCodec = new JsonCodecFactory(() -> adjustedObjectMapper).jsonCodec(Credentials.class);
this.httpHeaders = ImmutableMap.copyOf(config.getHttpHeaders());
}

@Override
public Optional<Credentials> credentials(String emulatedAccessKey, Optional<String> session)
{
UriBuilder uriBuilder = UriBuilder.fromUri(httpCredentialsProviderEndpoint).path(emulatedAccessKey);
session.ifPresent(sessionToken -> uriBuilder.queryParam("sessionToken", sessionToken));
Request request = prepareGet()
.setUri(uriBuilder.build())
.build();
JsonResponse<Credentials> response = httpClient.execute(request, createFullJsonResponseHandler(jsonCodec));
Request.Builder requestBuilder = prepareGet()
.addHeaders(Multimaps.forMap(httpHeaders))
.setUri(uriBuilder.build());
JsonResponse<Credentials> response = httpClient.execute(requestBuilder.build(), createFullJsonResponseHandler(jsonCodec));
if (response.getStatusCode() == HttpStatus.NOT_FOUND.code() || !response.hasValue()) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,19 @@
*/
package io.trino.aws.proxy.server.credentials.http;

import com.google.common.base.Splitter;
import io.airlift.configuration.Config;
import jakarta.validation.constraints.NotNull;

import java.net.URI;
import java.util.Map;

import static com.google.common.collect.ImmutableMap.toImmutableMap;

public class HttpCredentialsProviderConfig
{
private URI endpoint;
private Map<String, String> httpHeaders = Map.of();

@NotNull
public URI getEndpoint()
Expand All @@ -34,4 +39,27 @@ public HttpCredentialsProviderConfig setEndpoint(String endpoint)
this.endpoint = URI.create(endpoint);
return this;
}

public Map<String, String> getHttpHeaders()
{
return httpHeaders;
}

@Config("credentials-provider.http.headers")
public HttpCredentialsProviderConfig setHttpHeaders(String httpHeadersList)
{
try {
this.httpHeaders = Splitter.on(",").trimResults().omitEmptyStrings()
.splitToStream(httpHeadersList.replaceAll(",,", "\r"))
.map(item -> item.replace("\r", ","))
.map(s -> s.split(":", 2))
.collect(toImmutableMap(
a -> a[0].trim(),
a -> a[1].trim()));
}
catch (IndexOutOfBoundsException e) {
throw new IllegalArgumentException("Invalid HTTP header list: " + httpHeadersList);
}
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.aws.proxy.server.credentials.http;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
Expand Down Expand Up @@ -50,13 +51,16 @@ public class TestHttpCredentialsProvider
public static class Filter
implements BuilderFilter
{
private static String httpEndpointUri;

@Override
public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder)
{
TestingHttpServer httpCredentialsServer;
try {
httpCredentialsServer = createTestingHttpCredentialsServer();
httpCredentialsServer.start();
httpEndpointUri = httpCredentialsServer.getBaseUrl().toString();
}
catch (Exception e) {
throw new RuntimeException("Failed to start test http credentials provider server", e);
Expand All @@ -65,7 +69,8 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil
.addModule(new HttpCredentialsModule())
.addModule(binder -> bindIdentityType(binder, TestingIdentity.class))
.withProperty("credentials-provider.type", HTTP_CREDENTIALS_PROVIDER_IDENTIFIER)
.withProperty("credentials-provider.http.endpoint", httpCredentialsServer.getBaseUrl().toString());
.withProperty("credentials-provider.http.endpoint", httpEndpointUri)
.withProperty("credentials-provider.http.headers", "Authorization: auth, Content-Type: application/json");
}
}

Expand Down Expand Up @@ -139,6 +144,10 @@ private static class HttpCredentialsServlet
protected void doGet(HttpServletRequest request, HttpServletResponse response)
throws IOException
{
if (Strings.isNullOrEmpty(request.getHeader("Authorization")) || Strings.isNullOrEmpty("Content-Type")) {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
return;
}
Optional<String> sessionToken = Optional.ofNullable(request.getParameter("sessionToken"));
String emulatedAccessKey = request.getPathInfo().substring(1);
String credentialsIdentifier = "";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,60 @@
import com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.Map;

import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class TestHttpCredentialsProviderConfig
{
@Test
public void testExplicitPropertyMappings()
throws IOException
{
Map<String, String> properties = ImmutableMap.of(
"credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users");
"credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users",
"credentials-provider.http.headers", "x-api-key: xyz123, Content-Type: application/json");
HttpCredentialsProviderConfig expected = new HttpCredentialsProviderConfig()
.setEndpoint("http://usersvc:9000/api/v1/users");
.setEndpoint("http://usersvc:9000/api/v1/users")
.setHttpHeaders("x-api-key: xyz123, Content-Type: application/json");
assertFullMapping(properties, expected);
}

@Test
public void testValidHttpHeaderVariation1()
{
HttpCredentialsProviderConfig config = new HttpCredentialsProviderConfig()
.setEndpoint("http://usersvc:9000/api/v1/users")
.setHttpHeaders("x-api-key: Authorization: xyz123");
Map<String, String> httpHeaders = config.getHttpHeaders();
assertThat(httpHeaders.get("x-api-key")).isEqualTo("Authorization: xyz123");
}

@Test
public void testValidHttpHeaderVariation2()
{
HttpCredentialsProviderConfig config = new HttpCredentialsProviderConfig()
.setEndpoint("http://usersvc:9000/api/v1/users")
.setHttpHeaders("x-api-key: xyz,,123, Authorization: key,,,,123");
Map<String, String> httpHeaders = config.getHttpHeaders();
assertThat(httpHeaders.get("x-api-key")).isEqualTo("xyz,123");
assertThat(httpHeaders.get("Authorization")).isEqualTo("key,,123");
}

@Test
public void testIncorrectHttpHeader1()
{
assertThrows(IllegalArgumentException.class, () -> new HttpCredentialsProviderConfig()
.setEndpoint("http://usersvc:9000/api/v1/users")
.setHttpHeaders("malformed-header"));
}

@Test
public void testIncorrectHttpHeader2()
{
assertThrows(IllegalArgumentException.class, () -> new HttpCredentialsProviderConfig()
.setEndpoint("http://usersvc:9000/api/v1/users")
.setHttpHeaders("x-api-key: xyz,,,123, Authorization: key123"));
}
}

0 comments on commit 231b974

Please sign in to comment.