Skip to content

Commit

Permalink
Optionally include IConsumerContext and CancellationToken in consumer…
Browse files Browse the repository at this point in the history
… method invocation

Signed-off-by: Richard Pringle <[email protected]>
  • Loading branch information
Richard Pringle committed Apr 1, 2024
1 parent 331f68d commit 80983a4
Show file tree
Hide file tree
Showing 18 changed files with 591 additions and 284 deletions.
5 changes: 3 additions & 2 deletions docs/intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,13 @@ The `SomeConsumer` needs to be registered in the DI container. The SMB runtime w

> When `.WithConsumer<TConsumer>()` is not declared, then a default consumer of type `IConsumer<TMessage>` will be assumed (since v2.0.0).

Alternatively, if you do not want to implement the `IConsumer<SomeMessage>`, then you can provide the method name (2) or a delegate that calls the consumer method (3):
Alternatively, if you do not want to implement the `IConsumer<SomeMessage>`, then you can provide the method name (2) or a delegate that calls the consumer method (3).
`IConsumerContext` and/or `CancellationToken` can optionally be included as parameters to be populated on invocation when taking this approach:

```cs
public class SomeConsumer
{
public async Task MyHandleMethod(SomeMessage msg)
public async Task MyHandleMethod(SomeMessage msg, IConsumerContext consumerContext, CancellationToken cancellationToken)
{
// handle the msg
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,53 @@ public T Do<T>(Action<T> builder) where T : AbstractConsumerBuilder

static internal void SetupConsumerOnHandleMethod(IMessageTypeConsumerInvokerSettings invoker, string methodName = null)
{
static bool ParameterMatch(IMessageTypeConsumerInvokerSettings invoker, MethodInfo methodInfo)
{
var parameters = new List<Type>(methodInfo.GetParameters().Select(x => x.ParameterType));

var requiredParameters = new[] { invoker.MessageType };
foreach (var parameter in requiredParameters)
{
if (!parameters.Remove(parameter))
{
return false;
}
}

var allowedParameters = new[] { typeof(IConsumerContext), typeof(CancellationToken) };
foreach (var parameter in allowedParameters)
{
parameters.Remove(parameter);
}

if (parameters.Count != 0)
{
return false;
};

// ensure the method returns a Task or Task<T>
if (!typeof(Task).IsAssignableFrom(methodInfo.ReturnType))
{
return false;
}

return true;
}

if (invoker == null) throw new ArgumentNullException(nameof(invoker));

methodName ??= nameof(IConsumer<object>.OnHandle);

/// See <see cref="IConsumer{TMessage}.OnHandle(TMessage)"/> and <see cref="IRequestHandler{TRequest, TResponse}.OnHandle(TRequest)"/>

var consumerOnHandleMethod = invoker.ConsumerType.GetMethod(methodName, new[] { invoker.MessageType });
if (consumerOnHandleMethod == null)
{
throw new ConfigurationMessageBusException($"Consumer type {invoker.ConsumerType} validation error: the method {methodName} with parameters of type {invoker.MessageType} was not found.");
}
var consumerOnHandleMethod = invoker.ConsumerType.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.Where(x => x.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase) && ParameterMatch(invoker, x))
.OrderByDescending(x => x.GetParameters().Length)
.FirstOrDefault();

// ensure the method returns a Task or Task<T>
if (!typeof(Task).IsAssignableFrom(consumerOnHandleMethod.ReturnType))
if (consumerOnHandleMethod == null)
{
throw new ConfigurationMessageBusException($"Consumer type {invoker.ConsumerType} validation error: the response type of method {methodName} must return {typeof(Task)}");
throw new ConfigurationMessageBusException($"Consumer type {invoker.ConsumerType} validation error: no suitable method candidate with name {methodName} can be found");
}

invoker.ConsumerMethodInfo = consumerOnHandleMethod;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public ConsumerBuilder<T> WithConsumer<TConsumer>()
where TConsumer : class, IConsumer<T>
{
ConsumerSettings.ConsumerType = typeof(TConsumer);
ConsumerSettings.ConsumerMethod = (consumer, message) => ((IConsumer<T>)consumer).OnHandle((T)message);
ConsumerSettings.ConsumerMethod = (consumer, message, _, _) => ((IConsumer<T>)consumer).OnHandle((T)message);

ConsumerSettings.Invokers.Add(ConsumerSettings);

Expand All @@ -58,7 +58,7 @@ public ConsumerBuilder<T> WithConsumer<TConsumer, TMessage>()

var invoker = new MessageTypeConsumerInvokerSettings(ConsumerSettings, messageType: typeof(TMessage), consumerType: typeof(TConsumer))
{
ConsumerMethod = (consumer, message) => ((IConsumer<TMessage>)consumer).OnHandle((TMessage)message)
ConsumerMethod = (consumer, message, _, _) => ((IConsumer<TMessage>)consumer).OnHandle((TMessage)message)
};
ConsumerSettings.Invokers.Add(invoker);

Expand All @@ -71,7 +71,7 @@ public ConsumerBuilder<T> WithConsumer<TConsumer, TMessage>()
/// </summary>
/// <typeparam name="TConsumer"></typeparam>
/// <returns></returns>
public ConsumerBuilder<T> WithConsumer(Type derivedConsumerType, Type derivedMessageType)
public ConsumerBuilder<T> WithConsumer(Type derivedConsumerType, Type derivedMessageType, string methodName = null)
{
AssertInvokerUnique(derivedConsumerType, derivedMessageType);

Expand All @@ -81,7 +81,7 @@ public ConsumerBuilder<T> WithConsumer(Type derivedConsumerType, Type derivedMes
}

var invoker = new MessageTypeConsumerInvokerSettings(ConsumerSettings, messageType: derivedMessageType, consumerType: derivedConsumerType);
SetupConsumerOnHandleMethod(invoker);
SetupConsumerOnHandleMethod(invoker, methodName);
ConsumerSettings.Invokers.Add(invoker);

return this;
Expand All @@ -99,7 +99,7 @@ public ConsumerBuilder<T> WithConsumer<TConsumer>(Func<TConsumer, T, Task> consu
if (consumerMethod == null) throw new ArgumentNullException(nameof(consumerMethod));

ConsumerSettings.ConsumerType = typeof(TConsumer);
ConsumerSettings.ConsumerMethod = (consumer, message) => consumerMethod((TConsumer)consumer, (T)message);
ConsumerSettings.ConsumerMethod = (consumer, message, _, _) => consumerMethod((TConsumer)consumer, (T)message);

ConsumerSettings.Invokers.Add(ConsumerSettings);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public HandlerBuilder<TRequest, TResponse> WithHandler<THandler>()
where THandler : IRequestHandler<TRequest, TResponse>
{
ConsumerSettings.ConsumerType = typeof(THandler);
ConsumerSettings.ConsumerMethod = (consumer, message) => ((THandler)consumer).OnHandle((TRequest)message);
ConsumerSettings.ConsumerMethod = (consumer, message, _, _) => ((THandler)consumer).OnHandle((TRequest)message);

ConsumerSettings.Invokers.Add(ConsumerSettings);

Expand All @@ -152,7 +152,7 @@ public HandlerBuilder<TRequest, TResponse> WithHandler<THandler, TDerivedRequest

var invoker = new MessageTypeConsumerInvokerSettings(ConsumerSettings, messageType: typeof(TDerivedRequest), consumerType: typeof(THandler))
{
ConsumerMethod = (consumer, message) => ((IRequestHandler<TDerivedRequest, TResponse>)consumer).OnHandle((TDerivedRequest)message)
ConsumerMethod = (consumer, message, _, _) => ((IRequestHandler<TDerivedRequest, TResponse>)consumer).OnHandle((TDerivedRequest)message)
};
ConsumerSettings.Invokers.Add(invoker);

Expand All @@ -178,7 +178,7 @@ public HandlerBuilder<TRequest> WithHandler<THandler>()
where THandler : IRequestHandler<TRequest>
{
ConsumerSettings.ConsumerType = typeof(THandler);
ConsumerSettings.ConsumerMethod = (consumer, message) => ((THandler)consumer).OnHandle((TRequest)message);
ConsumerSettings.ConsumerMethod = (consumer, message, _, _) => ((THandler)consumer).OnHandle((TRequest)message);

ConsumerSettings.Invokers.Add(ConsumerSettings);

Expand All @@ -200,7 +200,7 @@ public HandlerBuilder<TRequest> WithHandler<THandler, TDerivedRequest>()

var invoker = new MessageTypeConsumerInvokerSettings(ConsumerSettings, messageType: typeof(TDerivedRequest), consumerType: typeof(THandler))
{
ConsumerMethod = (consumer, message) => ((IRequestHandler<TDerivedRequest>)consumer).OnHandle((TDerivedRequest)message)
ConsumerMethod = (consumer, message, _, _) => ((IRequestHandler<TDerivedRequest>)consumer).OnHandle((TDerivedRequest)message)
};
ConsumerSettings.Invokers.Add(invoker);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ private void CalculateResponseType()
/// <inheritdoc/>
public Type ConsumerType { get; set; }
/// <inheritdoc/>
public Func<object, object, Task> ConsumerMethod { get; set; }
public Func<object, object, IConsumerContext, CancellationToken, Task> ConsumerMethod { get; set; }
/// <inheritdoc/>
public MethodInfo ConsumerMethodInfo { get; set; }
/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public interface IMessageTypeConsumerInvokerSettings
/// <summary>
/// The delegate to the consumer method responsible for accepting messages.
/// </summary>
Func<object, object, Task> ConsumerMethod { get; set; }
Func<object, object, IConsumerContext, CancellationToken, Task> ConsumerMethod { get; set; }
/// <summary>
/// The consumer method.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public class MessageTypeConsumerInvokerSettings : IMessageTypeConsumerInvokerSet
/// <inheritdoc/>
public Type ConsumerType { get; }
/// <inheritdoc/>
public Func<object, object, Task> ConsumerMethod { get; set; }
public Func<object, object, IConsumerContext, CancellationToken, Task> ConsumerMethod { get; set; }
/// <inheritdoc/>
public MethodInfo ConsumerMethodInfo { get; set; }

Expand Down
4 changes: 2 additions & 2 deletions src/SlimMessageBus.Host/Consumer/Context/ConsumerContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public class ConsumerContext(IDictionary<string, object> properties = null) : IC

public IReadOnlyDictionary<string, object> Headers { get; set; }

public CancellationToken CancellationToken { get; set; }
public CancellationToken CancellationToken { get; set; } = default;

public IMessageBus Bus { get; set; }

Expand All @@ -24,4 +24,4 @@ public IDictionary<string, object> Properties
public object Consumer { get; set; }

public IMessageTypeConsumerInvokerSettings ConsumerInvoker { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ public async Task<object> ExecuteConsumer(object message, IConsumerContext consu
}

// the consumer just subscribes to the message
var task = consumerInvoker.ConsumerMethod(consumerContext.Consumer, message);
var task = consumerInvoker.ConsumerMethod(consumerContext.Consumer, message, consumerContext, consumerContext.CancellationToken);
await task.ConfigureAwait(false);

if (responseType != null && responseType != typeof(Void))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public void Run(MessageBusSettings settings)
.SelectMany(x => x.Invokers).ToList();
foreach (var consumerInvoker in consumerInvokers.Where(x => x.ConsumerMethod == null && x.ConsumerMethodInfo != null))
{
consumerInvoker.ConsumerMethod = ReflectionUtils.GenerateMethodCallToFunc<Func<object, object, Task>>(consumerInvoker.ConsumerMethodInfo, consumerInvoker.ConsumerType, typeof(Task), consumerInvoker.MessageType);
consumerInvoker.ConsumerMethod = ReflectionUtils.GenerateMethodCallToFunc<Func<object, object, IConsumerContext, CancellationToken, Task>>(consumerInvoker.ConsumerMethodInfo, consumerInvoker.MessageType);
}
}
}
99 changes: 98 additions & 1 deletion src/SlimMessageBus.Host/Helpers/ReflectionUtils.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
namespace SlimMessageBus.Host;
namespace SlimMessageBus.Host;

using System.Diagnostics;
using System.Linq.Expressions;

public static class ReflectionUtils
Expand Down Expand Up @@ -40,6 +41,102 @@ public static T GenerateMethodCallToFunc<T>(MethodInfo method, Type instanceType
return Expression.Lambda<T>(typedMethodResultExpr, new[] { objInstanceExpr }.Concat(objArguments)).Compile();
}

/// <summary>
/// Creates a delegate for the specified method wrapping both required and optional parameters.
///
/// The first parameter in the delegate is the instance to invoke the method against and must be supplied as an object.
/// Subsequent parameters that are supplied as objects and are typed (with index) in argumentTypes are required.
/// Any further parameters are typed and optional.
///
/// The target method can accept the parameters in any order. As such, types are explicit and cannot be duplicated.
/// </summary>
/// <typeparam name="TDelegate">Method facade</typeparam>
/// <param name="methodInfo">Target method to invoke</param>
/// <param name="argumentTypes">Required types (indexed 1.. in delegate)</param>
/// <returns></returns>
/// <example>
/// GenerateMethodCallToFunc<Func<object, object, IConsumerContext, CancellationToken, Task>>(methodInfo, typeof(SampleMessage));
///
/// Initial object is the instance to invoke the method on (type determined by methodInfo.DeclaringType)
/// SampleMessage is required as a parameter defined by methodInfo
/// IConsumerContext and CancellationToken are optional parameters as defined by methodInfo. If they exist, they will be populated otherwise ignored.
///
/// methodInfo must:
/// * be for an instance (static not supported in current implementation)
/// * contain at least a parameter of type SampleMessage
/// * optionally require parameters of type IConsumerContext and CancellationToken
/// * require no other parameters
/// * return a Task (as specified by the delegate)
/// </example>
/// <exception cref="ArgumentNullException"><see cref="methodInfo"/> is required</exception>
/// <exception cref="ArgumentException">Target invocation requires unsupplied parameter</exception>
/// <exception cref="ArgumentException">Required parameter(s) missing from target invocation</exception>
public static TDelegate GenerateMethodCallToFunc<TDelegate>(MethodInfo methodInfo, params Type[] argumentTypes)
where TDelegate : Delegate
{
if (methodInfo == null)
{
throw new ArgumentNullException(nameof(methodInfo));
}

var delegateSignature = typeof(TDelegate).GetMethod("Invoke")!;
var delegateReturn = delegateSignature.ReturnType;

Debug.Assert(delegateSignature.ReturnType == methodInfo.ReturnType);

var instanceParameter = Expression.Parameter(typeof(object), "instance");
var optionalTypes = delegateSignature.GetParameters()
.Skip(argumentTypes.Length + 1)
.Select(p => p.ParameterType);

var parameters = argumentTypes.Select(
(type, index) =>
new
{
Expression = Expression.Parameter(typeof(object), $"arg{index}"),
Required = true,
Type = type
})
.Union(
optionalTypes.Select(
(type, index) =>
new
{
Expression = Expression.Parameter(type, $"optArg{index}"),
Required = false,
Type = type
}))
.ToDictionary(x => x.Type, x => x);

var allParameters = parameters.Select(x => x.Value.Expression).ToList();

var argumentExpressions = methodInfo.GetParameters().Select(
p =>
{
if (parameters.TryGetValue(p.ParameterType, out var arg) && parameters.Remove(p.ParameterType))
{
return Expression.Convert(arg.Expression, p.ParameterType);
}

throw new ArgumentException($"Target invocation requires unsupplied parameter {p.ParameterType.AssemblyQualifiedName}");
}).ToList();

var missing = parameters.Values.Where(x => x.Required).Select(x => $"'{x.Type.AssemblyQualifiedName}'").ToList();
if (missing.Count > 0)
{
throw new ArgumentException($"Required parameter(s) missing from target invocation ({string.Join(", ", missing)})");
}

var callExpression = Expression.Call(
Expression.Convert(instanceParameter, methodInfo.DeclaringType!),
methodInfo,
argumentExpressions);

var lambda = Expression.Lambda<TDelegate>(callExpression, new[] { instanceParameter }.Concat(allParameters));

return lambda.Compile();
}

public static T GenerateGenericMethodCallToFunc<T>(MethodInfo genericMethod, Type[] genericTypeArguments, Type instanceType, Type returnType, params Type[] argumentTypes)
{
var method = genericMethod.MakeGenericMethod(genericTypeArguments);
Expand Down
Loading

0 comments on commit 80983a4

Please sign in to comment.