Skip to content

Commit

Permalink
Merge pull request Remora#310 from Itamaram/async-token-store
Browse files Browse the repository at this point in the history
Allow fetching tokens asynchrnously
  • Loading branch information
Nihlus authored Aug 22, 2023
2 parents 0cf86d2 + 300da01 commit b07282d
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 43 deletions.
8 changes: 4 additions & 4 deletions Backend/Remora.Discord.Gateway/DiscordGatewayClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public class DiscordGatewayClient : IDisposable

private readonly IDiscordRestGatewayAPI _gatewayAPI;
private readonly DiscordGatewayClientOptions _gatewayOptions;
private readonly ITokenStore _tokenStore;
private readonly IAsyncTokenStore _tokenStore;
private readonly Random _random;

private readonly IResponderDispatchService _responderDispatch;
Expand Down Expand Up @@ -130,7 +130,7 @@ public DiscordGatewayClient
IDiscordRestGatewayAPI gatewayAPI,
IPayloadTransportService transportService,
IOptions<DiscordGatewayClientOptions> gatewayOptions,
ITokenStore tokenStore,
IAsyncTokenStore tokenStore,
Random random,
ILogger<DiscordGatewayClient> log,
IServiceProvider services,
Expand Down Expand Up @@ -705,7 +705,7 @@ private async Task<Result> CreateNewSessionAsync(CancellationToken ct = default)
(
new Identify
(
_tokenStore.Token,
await _tokenStore.GetTokenAsync(ct),
_gatewayOptions.ConnectionProperties,
false,
_gatewayOptions.LargeThreshold,
Expand Down Expand Up @@ -811,7 +811,7 @@ private async Task<Result> ResumeExistingSessionAsync(CancellationToken ct = def
(
new Resume
(
_tokenStore.Token,
await _tokenStore.GetTokenAsync(ct),
_sessionInformation.SessionID,
_sessionInformation.SequenceNumber
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,27 @@ public static IServiceCollection AddDiscordGateway
Action<IHttpClientBuilder>? buildClient = null
)
{
serviceCollection
.AddDiscordRest
(
s => (tokenFactory(s), DiscordTokenType.Bot),
buildClient
);
serviceCollection.AddSingleton<IAsyncTokenStore>
(
ctx => new StaticTokenStore(tokenFactory(ctx), DiscordTokenType.Bot)
);

return serviceCollection.AddDiscordGateway(buildClient);
}

/// <summary>
/// Adds services required by the Discord Gateway system.
/// </summary>
/// <param name="serviceCollection">The service collection.</param>
/// <param name="buildClient">Extra options to configure the rest client.</param>
/// <returns>The service collection, with the services added.</returns>
public static IServiceCollection AddDiscordGateway
(
this IServiceCollection serviceCollection,
Action<IHttpClientBuilder>? buildClient = null
)
{
serviceCollection.AddDiscordRest(buildClient);

serviceCollection.TryAddSingleton<Random>();
serviceCollection.TryAddSingleton<IResponderDispatchService, ResponderDispatchService>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ public static IServiceCollection AddDiscordRest
Func<IServiceProvider, (string Token, DiscordTokenType TokenType)> tokenFactory,
Action<IHttpClientBuilder>? buildClient = null
)
{
serviceCollection.AddSingleton<IAsyncTokenStore>(ctx =>
{
var (token, type) = tokenFactory(ctx);
return new StaticTokenStore(token, type);
});

return serviceCollection.AddDiscordRest(buildClient);
}

/// <summary>
/// Adds the services required for Discord's REST API.
/// </summary>
/// <param name="serviceCollection">The service collection.</param>
/// <param name="buildClient">Extra client building operations.</param>
/// <returns>The service collection, with the services added.</returns>
public static IServiceCollection AddDiscordRest
(
this IServiceCollection serviceCollection,
Action<IHttpClientBuilder>? buildClient = null
)
{
serviceCollection.AddMemoryCache();

Expand All @@ -76,13 +97,7 @@ public static IServiceCollection AddDiscordRest

serviceCollection.ConfigureDiscordJsonConverters();

serviceCollection
.AddSingleton<ITokenStore>(serviceProvider =>
{
var (token, tokenType) = tokenFactory(serviceProvider);
return new TokenStore(token, tokenType);
})
.AddTransient<TokenAuthorizationHandler>();
serviceCollection.AddTransient<TokenAuthorizationHandler>();

serviceCollection.TryAddTransient<IDiscordRestAuditLogAPI>(s => new DiscordRestAuditLogAPI
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,25 @@ namespace Remora.Discord.Rest.Handlers;
/// </summary>
internal class TokenAuthorizationHandler : DelegatingHandler
{
private readonly ITokenStore _tokenStore;
private readonly IAsyncTokenStore _tokenStore;

/// <summary>
/// Initializes a new instance of the <see cref="TokenAuthorizationHandler"/> class.
/// </summary>
/// <param name="tokenStore">The token store.</param>
public TokenAuthorizationHandler(ITokenStore tokenStore)
public TokenAuthorizationHandler(IAsyncTokenStore tokenStore)
{
_tokenStore = tokenStore;
}

/// <inheritdoc />
protected override Task<HttpResponseMessage> SendAsync
protected override async Task<HttpResponseMessage> SendAsync
(
HttpRequestMessage request,
CancellationToken cancellationToken
)
{
var token = _tokenStore.Token;
var token = await _tokenStore.GetTokenAsync(cancellationToken);
var tokenType = _tokenStore.TokenType;

if (string.IsNullOrWhiteSpace(token))
Expand All @@ -67,13 +67,13 @@ CancellationToken cancellationToken
if (request.Properties.ContainsKey(Constants.SkipAuthorizationPropertyName))
#endif
{
return base.SendAsync(request, cancellationToken);
return await base.SendAsync(request, cancellationToken);
}

AddTokenToPollyContext(request, token);
AddAuthorizationHeader(request, token, tokenType);

return base.SendAsync(request, cancellationToken);
return await base.SendAsync(request, cancellationToken);
}

private static void AddAuthorizationHeader
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// ITokenStore.cs
// IAsyncTokenStore.cs
//
// Author:
// Jarl Gullberg <[email protected]>
Expand All @@ -20,6 +20,8 @@
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//

using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;

namespace Remora.Discord.Rest;
Expand All @@ -28,12 +30,14 @@ namespace Remora.Discord.Rest;
/// Represents a storage class for a single token.
/// </summary>
[PublicAPI]
public interface ITokenStore
public interface IAsyncTokenStore
{
/// <summary>
/// Gets the token.
/// </summary>
string Token { get; }
/// <param name="cancellationToken">A cancellation token to cancel operation.</param>
/// <returns>The token's value.</returns>
ValueTask<string> GetTokenAsync(CancellationToken cancellationToken);

/// <summary>
/// Gets the type of the token.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// TokenStore.cs
// StaticTokenStore.cs
//
// Author:
// Jarl Gullberg <[email protected]>
Expand All @@ -20,30 +20,34 @@
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//

using System.Threading;
using System.Threading.Tasks;
using JetBrains.Annotations;

namespace Remora.Discord.Rest;

/// <summary>
/// Represents a storage class for a single token.
/// Represents a storage class for a static token.
/// </summary>
[PublicAPI]
public class TokenStore : ITokenStore
public class StaticTokenStore : IAsyncTokenStore
{
private readonly string _token;

/// <inheritdoc />
public string Token { get; }
public ValueTask<string> GetTokenAsync(CancellationToken cancellationToken) => new(_token);

/// <inheritdoc />
public DiscordTokenType TokenType { get; }

/// <summary>
/// Initializes a new instance of the <see cref="TokenStore"/> class.
/// Initializes a new instance of the <see cref="StaticTokenStore"/> class.
/// </summary>
/// <param name="token">The token to store.</param>
/// <param name="tokenType">The type of token to store.</param>
public TokenStore(string token, DiscordTokenType tokenType)
public StaticTokenStore(string token, DiscordTokenType tokenType)
{
this.Token = token;
_token = token;
this.TokenType = tokenType;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// TokenStoreTests.cs
// StaticTokenStoreTests.cs
//
// Author:
// Jarl Gullberg <[email protected]>
Expand All @@ -20,6 +20,7 @@
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//

using System.Threading.Tasks;
using Remora.Discord.Rest;
using Xunit;

Expand All @@ -29,41 +30,41 @@
namespace Remora.Discord.Tests.Tests.Core;

/// <summary>
/// Tests the <see cref="TokenStore"/> class.
/// Tests the <see cref="StaticTokenStore"/> class.
/// </summary>
public class TokenStoreTests
public class StaticTokenStoreTests
{
/// <summary>
/// Tests the <see cref="TokenStore.Token"/> property.
/// Tests the <see cref="Token"/> property.
/// </summary>
public class Token
{
[Fact]
public void ReturnsCorrectValue()
public async Task ReturnsCorrectValue()
{
var tokenStore = new TokenStore("Hello world!", DiscordTokenType.Bearer);
var tokenStore = new StaticTokenStore("Hello world!", DiscordTokenType.Bearer);

Assert.Equal("Hello world!", tokenStore.Token);
Assert.Equal("Hello world!", await tokenStore.GetTokenAsync(default));
}
}

/// <summary>
/// Tests the <see cref="TokenStore.TokenType"/> property.
/// Tests the <see cref="TokenType"/> property.
/// </summary>
public class TokenType
{
[Fact]
public void ReturnsCorrectValueForBotTokenType()
{
var tokenStore = new TokenStore("Hello world!", DiscordTokenType.Bot);
var tokenStore = new StaticTokenStore("Hello world!", DiscordTokenType.Bot);

Assert.Equal(DiscordTokenType.Bot, tokenStore.TokenType);
}

[Fact]
public void ReturnsCorrectValueForBearerTokenType()
{
var tokenStore = new TokenStore("Hello world!", DiscordTokenType.Bearer);
var tokenStore = new StaticTokenStore("Hello world!", DiscordTokenType.Bearer);

Assert.Equal(DiscordTokenType.Bearer, tokenStore.TokenType);
}
Expand Down

0 comments on commit b07282d

Please sign in to comment.