Skip to content

Commit

Permalink
Multiple source gen fixes (#39)
Browse files Browse the repository at this point in the history
* Fix code generation for methods with default arguments

* Fix code generation for interfaces inheriting from other interfaces (#36)

- Update the code generator to properly recognize and process inherited
  interfaces.
- Add unit test to validate the changes.

* use unified interface and run formatter

* Fix code generation for methods with default arguments (#35)

* move default argument tests to shared project

* Fix code generation for interfaces with overloaded methods (#38)

* Fix code generation for interfaces inheriting from other interfaces

- Update the code generator to properly recognize and process inherited
  interfaces.
- Add unit test to validate the changes.

* Fix code generation for interfaces with overloaded methods

---------

Co-authored-by: Karsten Heimrich <[email protected]>
Co-authored-by: Karsten Heimrich <[email protected]>
  • Loading branch information
3 people authored Jan 23, 2025
1 parent f5c5713 commit d06658b
Show file tree
Hide file tree
Showing 14 changed files with 433 additions and 51 deletions.
78 changes: 72 additions & 6 deletions src/MockMe.Generator/Extensions/MethodSymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ public static string GetParametersWithOriginalTypesAndModifiers(this IMethodSymb
GetParametersWithTypesAndModifiers(method);

public static string GetParametersWithArgTypesAndModifiers(this IMethodSymbol method) =>
GetParametersWithTypesAndModifiers(method, "Arg<", ">");
GetParametersWithTypesAndModifiers(method, "Arg<", ">", true);

private static string GetParametersWithTypesAndModifiers(
this IMethodSymbol method,
string? typePrefix = null,
string? typePostfix = null
string? typePostfix = null,
bool wrapInArg = false
)
{
if (method.Parameters.Length == 0)
Expand All @@ -36,14 +37,18 @@ private static string GetParametersWithTypesAndModifiers(
RefKind.None or _ => p.IsParams ? "params " : "",
};

// Build the main "ref int x" or "Arg<int> x" part
var paramString =
$"{modifiers}{typePrefix}{p.Type.ToFullTypeString()}{typePostfix} {p.Name}";
if (p.HasExplicitDefaultValue)

// If the original parameter had a default value, we only append it if we're NOT
// wrapping in Arg<...>. (Skipping avoids e.g. Arg<int> x = 2 which is invalid.)
if (p.HasExplicitDefaultValue && !wrapInArg)
{
var defaultValue =
p.ExplicitDefaultValue != null ? p.ExplicitDefaultValue.ToString() : "null";
var defaultValue = GetDefaultValueForType(p.Type, p.ExplicitDefaultValue);
paramString += $" = {defaultValue}";
}

return paramString;
})
);
Expand Down Expand Up @@ -133,7 +138,68 @@ public static string GetUniqueMethodName(this IMethodSymbol methodSymbol)
var parameterTypes = methodSymbol.Parameters.Select(p =>
(p.RefKind == RefKind.None ? "" : p.RefKind.ToString()) + p.Type.Name
);
var uniqueMethodName = $"{methodName}_{string.Join("_", parameterTypes)}";
var uniqueMethodName =
$"{methodName}_{string.Join("_", methodSymbol.TypeParameters.Select(p => p.Name)).AddSuffixIfNotEmpty("_")}{string.Join("_", parameterTypes)}";

return uniqueMethodName;
}

private static string GetDefaultValueForType(ITypeSymbol type, object? explicitValue)
{
// If the compiler recognized a default value, it sets HasExplicitDefaultValue = true,
// and ExplicitDefaultValue can be null (for '= default') or a constant (for '= 2', '= "x"', etc.)

switch (explicitValue)
{
case null:
// For a non-nullable value type, 'default'
// For reference types or nullable, 'null'
return type.IsValueType ? "default" : "null";
case bool b:
return b ? "true" : "false";
// If we do have a constant, handle string vs. others
// Wrap string in quotes
case string s:
return $"\"{s}\"";
}

if (type.TypeKind == TypeKind.Enum && type is INamedTypeSymbol namedEnum)
{
return GetEnumDefaultValueString(namedEnum, explicitValue);
}

return explicitValue.ToString();
}

private static string GetEnumDefaultValueString(INamedTypeSymbol enumType, object? rawValue) =>
rawValue switch
{
int intValue => FindEnumMember(enumType, intValue),
long longValue => FindEnumMember(enumType, longValue),
byte byteValue => FindEnumMember(enumType, byteValue),
sbyte sbyteValue => FindEnumMember(enumType, sbyteValue),
short shortValue => FindEnumMember(enumType, shortValue),
ushort ushortValue => FindEnumMember(enumType, ushortValue),
uint uintValue => FindEnumMember(enumType, uintValue),
ulong ulongValue => FindEnumMember(enumType, ulongValue),
null => $"{enumType.ToFullTypeString()} /* unknown null */",
// Fallback for unexpected cases:
_ => $"{enumType.ToFullTypeString()} /* unknown = {rawValue} */",
};

private static string FindEnumMember<T>(INamedTypeSymbol enumType, T value)
where T : struct, IComparable
{
foreach (var member in enumType.GetMembers().OfType<IFieldSymbol>())
{
if (member.HasConstantValue && Equals(member.ConstantValue, value))
{
// e.g., "System.DayOfWeek.Monday"
return $"{enumType.ToFullTypeString()}.{member.Name}";
}
}

// Fallback if no matching member is found
return $"{enumType.ToFullTypeString()} /* unknown = {value} */";
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System.Text;
using Microsoft.CodeAnalysis;
using MockMe.Generator.Extensions;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,39 +122,52 @@ internal class {this.mockAsserterTypeName} : MockAsserter
Dictionary<string, PropertyMetadata> callTrackerMeta = [];
Dictionary<string, SetupPropertyMetadata> setupMeta = [];
Dictionary<string, AssertPropertyMetadata> assertMeta = [];
foreach (var method in this.TypeSymbolToMock.GetMembers())
{
if (method is not IMethodSymbol methodSymbol)
{
continue;
}

if (methodSymbol.MethodKind == MethodKind.Constructor)
{
continue;
}

MethodMockGeneratorBase? methodGenerator = MethodGeneratorFactory.Create(
methodSymbol,
this
);

if (methodGenerator is null)
foreach (
var interfaceSymbol in this.TypeSymbolToMock.AllInterfaces.Concat(
[this.TypeSymbolToMock]
)
)
{
foreach (var method in interfaceSymbol.GetMembers())
{
continue;
if (method is not IMethodSymbol methodSymbol)
{
continue;
}

if (methodSymbol.MethodKind == MethodKind.Constructor)
{
continue;
}

MethodMockGeneratorBase? methodGenerator = MethodGeneratorFactory.Create(
methodSymbol,
this
);

if (methodGenerator is null)
{
continue;
}

PatchMethodGeneratorFactory
.Create(interfaceSymbol, methodSymbol)
?.AddPatchMethod(
sb,
assemblyAttributesSource,
staticConstructor,
this.TypeName
);

methodGenerator.AddOriginalCollectionType(setupBuilder);
methodGenerator.AddMethodSetupToStringBuilder(setupBuilder, setupMeta);
methodGenerator.AddMethodCallTrackerToStringBuilder(
callTrackerBuilder,
callTrackerMeta
);
methodGenerator.AddMethodToAsserterClass(asserterBuilder, assertMeta);
}

PatchMethodGeneratorFactory
.Create(this.TypeSymbolToMock, methodSymbol)
?.AddPatchMethod(sb, assemblyAttributesSource, staticConstructor, this.TypeName);

methodGenerator.AddOriginalCollectionType(setupBuilder);
methodGenerator.AddMethodSetupToStringBuilder(setupBuilder, setupMeta);
methodGenerator.AddMethodCallTrackerToStringBuilder(
callTrackerBuilder,
callTrackerMeta
);
methodGenerator.AddMethodToAsserterClass(asserterBuilder, assertMeta);
}

staticConstructor.AppendLine(
Expand Down
13 changes: 12 additions & 1 deletion src/MockMe.PostBuild/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,18 @@
);

//ILReplacer.Replace(assembly, methodToReplace, replacementMethod);
ILManipulator.InsertMethodBodyBeforeExisting(assembly, methodToReplace, replacementMethod);
try
{
ILManipulator.InsertMethodBodyBeforeExisting(
assembly,
methodToReplace,
replacementMethod
);
}
catch
{
// todo...
}
}

assembly.Write(currentAssemblyPath, new() { WriteSymbols = true });
Expand Down
2 changes: 2 additions & 0 deletions tests/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
IDE0130; <!-- Namespace does not match folder structure -->
IDE0034; <!-- Simplify 'default' expression -->
AD0001; <!-- XUnit inline data. This messes up when passing empty array into 'params' method -->
CA1716; <!--Using reserved word in namespace-->
CA1715; <!--Identifiers should have correct prefix-->
$(WarningsNotAsErrors)
</WarningsNotAsErrors>

Expand Down
23 changes: 21 additions & 2 deletions tests/MockMe.Tests.ExampleClasses/ComplexCalculator.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
using System;
using System.Linq;
using System.Threading.Tasks;

//using MockMe.Tests.NuGet;
using MockMe.Tests.ExampleClasses.Interfaces;

namespace MockMe.Tests.ExampleClasses
{
public interface ISymbolVisitor
{
void VisitAddition(IAddition addition);
void VisitSubtraction(ISubtraction subtraction);
void VisitMultiplication(IMultiplication multiplication);
void VisitDivision(IDivision division);
}

public interface ISymbolVisitor<out T>
{
T VisitAddition(IAddition addition);
T VisitSubtraction(ISubtraction subtraction);
T VisitMultiplication(IMultiplication multiplication);
T VisitDivision(IDivision division);
}

public class ComplexCalculator
{
public int ComputeHashForObjects<T>(T[] values)
Expand Down Expand Up @@ -37,5 +52,9 @@ public Task WaitForOperationsToFinish()
{
throw new NotImplementedException();
}

public void Accept(ISymbolVisitor visitor) => throw new NotImplementedException();

public T Accept<T>(ISymbolVisitor<T> visitor) => throw new NotImplementedException();
}
}
24 changes: 21 additions & 3 deletions tests/MockMe.Tests.ExampleClasses/Interfaces/ICalculator.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
namespace MockMe.Tests.ExampleClasses.Interfaces
{
public interface ICalculator
public interface IAddition
{
CalculatorType CalculatorType { get; set; }
int Add(int x, int y);
}

public interface ISubtraction
{
int Subtract(int a, int b);
}

public interface IMultiplication
{
double Multiply(double x, double y);
}

public interface IDivision
{
int Divide(int a, int b);
}

public interface ICalculator : IAddition, ISubtraction, IMultiplication, IDivision
{
CalculatorType CalculatorType { get; set; }
void DivideByZero(double numToDivide);
bool IsOn();
double Multiply(double x, double y);
void TurnOff();
}
}
27 changes: 27 additions & 0 deletions tests/MockMe.Tests.Overloads.Interface/AllDefaultArgs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using System;

namespace MockMe.Tests.Overloads
{
public enum EnumLong : long
{
Unknown = 0,
First = 1,
}

public interface AllDefaultArgs
{
public void MethodWithBoolDefault(bool value = true);
public void MethodWithConstStringDefault(string greeting = "Hello World");
public void MethodWithDateTimeDefault(DateTime date = default);
public void MethodWithEnumDefault(DayOfWeek day = DayOfWeek.Monday);
public void MethodWithEnumLongDefault(EnumLong value = EnumLong.First);
public void MethodWithMultipleDefaults(
double factor = 1,
bool enabled = true,
string label = "default"
);
public void MethodWithNullableDefault(int? arg = 15);
public void MethodWithPrimitiveDefault(int i = 5);
public void MethodWithStringDefault(string greeting = "Hello World");
}
}
21 changes: 13 additions & 8 deletions tests/MockMe.Tests.Overloads.Interface/AllOverloads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
namespace MockMe.Tests.Overloads
#pragma warning restore IDE0130 // Using reserved word in namespace
{
[System.Diagnostics.CodeAnalysis.SuppressMessage(
"Style",
"IDE0040:Add accessibility modifiers",
Justification = "<Pending>"
)]
internal interface AllOverloads
{
public int OutArgument(out int arg);
Expand All @@ -15,19 +20,19 @@ internal interface AllOverloads
internal int InternalProp { get; set; }
internal int InternalMethod();

double this[double index] { set; }
string this[int index] { get; set; }
int this[string index] { get; set; }
public double this[double index] { set; }
public string this[int index] { get; set; }
public int this[string index] { get; set; }

internal int this[float index] { get; set; }
protected int this[decimal index] { get; set; }

int Prop_GetInit { get; init; }
int Prop_GetOnly { get; }
int Prop_GetSet { get; set; }
int Prop_SetOnly { set; }
public int Prop_GetInit { get; init; }
public int Prop_GetOnly { get; }
public int Prop_GetSet { get; set; }
public int Prop_SetOnly { set; }

Task<int> AsyncOfTReturn();
public Task<int> AsyncOfTReturn();
Task<int> AsyncOfTReturn(int p1);
Task<int> AsyncOfTReturn(int p1, int p2);
Task<int> AsyncOfTReturn(int p1, int p2, int p3);
Expand Down
Loading

0 comments on commit d06658b

Please sign in to comment.