From f8e46b90d1347847fb452892eb88a12fca37224c Mon Sep 17 00:00:00 2001 From: Jason Ginchereau Date: Fri, 8 Mar 2024 21:04:55 -1000 Subject: [PATCH 1/2] JSExport method overload support --- src/NodeApi.DotNetHost/JSMarshaller.cs | 34 ++++-- src/NodeApi.Generator/ModuleGenerator.cs | 110 ++++++++++-------- .../Interop/JSPropertyDescriptorListOfT.cs | 10 ++ test/TestCases/napi-dotnet/Overloads.cs | 55 +++++++++ test/TestCases/napi-dotnet/overloads.js | 46 ++++++++ 5 files changed, 199 insertions(+), 56 deletions(-) create mode 100644 test/TestCases/napi-dotnet/Overloads.cs create mode 100644 test/TestCases/napi-dotnet/overloads.js diff --git a/src/NodeApi.DotNetHost/JSMarshaller.cs b/src/NodeApi.DotNetHost/JSMarshaller.cs index b3d948af..462afa95 100644 --- a/src/NodeApi.DotNetHost/JSMarshaller.cs +++ b/src/NodeApi.DotNetHost/JSMarshaller.cs @@ -43,6 +43,12 @@ public class JSMarshaller /// public const string ResultPropertyName = "result"; + /// + /// Keeps track of the names of all generated lambda expressions in order to automatically + /// avoid collisions, which can occur with overloaded methods. + /// + private readonly HashSet _expressionNames = new(); + [ThreadStatic] private static JSMarshaller? s_current; @@ -805,7 +811,7 @@ Expression ParameterToJSValue(int index) => InlineOrInvoke( return Expression.Lambda( _delegates.Value.GetToJSDelegateType(method.ReturnType, parameters), Expression.Block(method.ReturnType, new[] { resultVariable }, statements), - $"to_{FullMethodName(method)}", + FullMethodName(method, "to_"), parameters); } catch (Exception ex) @@ -874,7 +880,7 @@ public LambdaExpression BuildFromJSFunctionExpression(MethodInfo method) return Expression.Lambda( JSMarshallerDelegates.GetFromJSDelegateType(method.DeclaringType!), body: Expression.Block(typeof(JSValue), variables, statements), - $"from_{FullMethodName(method)}", + FullMethodName(method, "from_"), parameters: new[] { thisParameter, s_argsParameter }); } catch (Exception ex) @@ -1265,6 +1271,7 @@ public Expression> BuildMethodOverloadDescriptorExpre * return JSCallbackOverload.CreateDescriptor(methodName, overloads); */ + string name = FullMethodName(methods[0]); ParameterExpression overloadsVariable = Expression.Variable(typeof(JSCallbackOverload[]), "overloads"); var statements = new Expression[methods.Length + 2]; @@ -1304,7 +1311,7 @@ public Expression> BuildMethodOverloadDescriptorExpre typeof(JSCallbackDescriptor), new[] { overloadsVariable }, statements), - name: FullMethodName(methods[0]), + name, Array.Empty()); } @@ -3015,17 +3022,30 @@ private static bool IsTypedArrayType(Type elementType) || elementType == typeof(double); } - private static string FullMethodName(MethodInfo method) + private string FullMethodName(MethodInfo method, string? prefix = null) { - string prefix = string.Empty; string name = method.Name; if (name.StartsWith("get_") || name.StartsWith("set_")) { - prefix = name.Substring(0, 4); + prefix ??= name.Substring(0, 4); name = name.Substring(4); } + else + { + prefix ??= string.Empty; + } + + // Ensure the generated name is unique by appending a counter suffix if necessary. + string fullName = $"{prefix}{FullTypeName(method.DeclaringType!)}_{name}"; + string suffix = string.Empty; + for (int i = 2; _expressionNames.Contains(fullName + suffix); i++) + { + suffix = $"_{i}"; + } - return $"{prefix}{FullTypeName(method.DeclaringType!)}_{name}"; + fullName += suffix; + _expressionNames.Add(fullName); + return fullName; } internal static string FullTypeName(Type type) diff --git a/src/NodeApi.Generator/ModuleGenerator.cs b/src/NodeApi.Generator/ModuleGenerator.cs index e86db0ad..eef65176 100644 --- a/src/NodeApi.Generator/ModuleGenerator.cs +++ b/src/NodeApi.Generator/ModuleGenerator.cs @@ -368,18 +368,20 @@ private void ExportModule( s += $"exportsValue = new JSModuleBuilder<{ns}.{moduleType.Name}>()"; s.IncreaseIndent(); - // Export non-static members of the module class. - foreach (ISymbol? member in moduleType.GetMembers() - .Where((m) => m.DeclaredAccessibility == Accessibility.Public && !m.IsStatic)) + // Export public non-static members of the module class. + IEnumerable members = moduleType.GetMembers() + .Where((m) => m.DeclaredAccessibility == Accessibility.Public && !m.IsStatic); + + foreach (IPropertySymbol property in members.OfType()) { - if (member is IMethodSymbol method && method.MethodKind == MethodKind.Ordinary) - { - ExportMethod(ref s, method); - } - else if (member is IPropertySymbol property) - { - ExportProperty(ref s, property); - } + ExportProperty(ref s, property, GetExportName(property)); + } + + foreach (IGrouping methodGroup in members.OfType() + .Where((m) => m.MethodKind == MethodKind.Ordinary) + .GroupBy(GetExportName)) + { + ExportMethod(ref s, methodGroup, methodGroup.Key); } } else @@ -401,11 +403,6 @@ private void ExportModule( // Export tagged static properties as properties on the module. ExportProperty(ref s, exportProperty, exportName); } - else if (exportItem is IMethodSymbol exportMethod) - { - // Export tagged static methods as top-level functions on the module. - ExportMethod(ref s, exportMethod, exportName); - } else if (exportItem is ITypeSymbol exportDelegate && exportDelegate.TypeKind == TypeKind.Delegate) { @@ -413,6 +410,13 @@ private void ExportModule( } } + // Export tagged static methods as top-level functions on the module. + foreach (IGrouping methodGroup in exportItems.OfType() + .GroupBy(GetExportName)) + { + ExportMethod(ref s, methodGroup, methodGroup.Key); + } + if (moduleType != null) { // Construct an instance of the custom module class when the module is initialized. @@ -434,10 +438,8 @@ private void ExportModule( private void ExportType( ref SourceBuilder s, ITypeSymbol type, - string? exportName = null) + string exportName) { - exportName ??= type.Name; - string propertyAttributes = string.Empty; if (type.ContainingType != null) { @@ -547,22 +549,15 @@ private void ExportMembers( { bool isStreamClass = typeof(System.IO.Stream).IsAssignableFrom(type.AsType()); - foreach (ISymbol member in type.GetMembers() - .Where((m) => m.DeclaredAccessibility == Accessibility.Public)) - { - if (isStreamClass && !member.IsStatic) - { - // Only static members on stream subclasses are exported to JS. - continue; - } + IEnumerable members = type.GetMembers() + .Where((m) => m.DeclaredAccessibility == Accessibility.Public) + .Where((m) => !isStreamClass || m.IsStatic); - if (member is IMethodSymbol method && method.MethodKind == MethodKind.Ordinary) - { - ExportMethod(ref s, method); - } - else if (member is IPropertySymbol property) + foreach (ISymbol member in members) + { + if (member is IPropertySymbol property) { - ExportProperty(ref s, property); + ExportProperty(ref s, property, GetExportName(member)); } else if (type.TypeKind == TypeKind.Enum && member is IFieldSymbol field) { @@ -571,9 +566,16 @@ private void ExportMembers( } else if (member is INamedTypeSymbol nestedType) { - ExportType(ref s, nestedType); + ExportType(ref s, nestedType, GetExportName(member)); } } + + foreach (IGrouping methodGroup in members + .OfType().Where((m) => m.MethodKind == MethodKind.Ordinary) + .GroupBy(GetExportName)) + { + ExportMethod(ref s, methodGroup, methodGroup.Key); + } } /// @@ -581,21 +583,32 @@ private void ExportMembers( /// private void ExportMethod( ref SourceBuilder s, - IMethodSymbol method, - string? exportName = null) + IEnumerable methods, + string exportName) { - exportName ??= ToCamelCase(method.Name); + // TODO: Support exporting generic methods. + methods = methods.Where((m) => !m.IsGenericMethod); + + IMethodSymbol? method = methods.FirstOrDefault(); + if (method == null) + { + return; + } - // An adapter method may be used to support marshalling arbitrary parameters, - // if the method does not match the `JSCallback` signature. string attributes = "JSPropertyAttributes.DefaultMethod" + (method.IsStatic ? " | JSPropertyAttributes.Static" : string.Empty); - if (method.IsGenericMethod) + + if (methods.Count() == 1 && !IsMethodCallbackAdapterRequired(method)) { - // TODO: Export generic method. + // No adapter is needed for a method with a JSCallback signature. + string ns = GetNamespace(method); + string className = method.ContainingType.Name; + s += $".AddMethod(\"{exportName}\", " + + $"{ns}.{className}.{method.Name},\n\t{attributes})"; } - else if (IsMethodCallbackAdapterRequired(method)) + else if (methods.Count() == 1) { + // An adapter method supports marshalling arbitrary parameters. Expression adapter = _marshaller.BuildFromJSMethodExpression(method.AsMethodInfo()); _callbackAdapters.Add(adapter.Name!, adapter); @@ -603,10 +616,11 @@ private void ExportMethod( } else { - string ns = GetNamespace(method); - string className = method.ContainingType.Name; - s += $".AddMethod(\"{exportName}\", " + - $"{ns}.{className}.{method.Name},\n\t{attributes})"; + // An adapter method provides overload resolution. + LambdaExpression adapter = _marshaller.BuildMethodOverloadDescriptorExpression( + methods.Select((m) => m.AsMethodInfo()).ToArray()); + _callbackAdapters.Add(adapter.Name!, adapter); + s += $".AddMethod(\"{exportName}\", {adapter.Name}(),\n\t{attributes})"; } } @@ -616,10 +630,8 @@ private void ExportMethod( private void ExportProperty( ref SourceBuilder s, IPropertySymbol property, - string? exportName = null) + string exportName) { - exportName ??= ToCamelCase(property.Name); - bool writable = property.SetMethod != null || (!property.IsStatic && property.ContainingType.TypeKind == TypeKind.Struct); string attributes = "JSPropertyAttributes.Enumerable | JSPropertyAttributes.Configurable" + diff --git a/src/NodeApi/Interop/JSPropertyDescriptorListOfT.cs b/src/NodeApi/Interop/JSPropertyDescriptorListOfT.cs index 5eb33c5a..311ffe8b 100644 --- a/src/NodeApi/Interop/JSPropertyDescriptorListOfT.cs +++ b/src/NodeApi/Interop/JSPropertyDescriptorListOfT.cs @@ -215,4 +215,14 @@ public TDerived AddMethod( attributes, data); } + + public TDerived AddMethod( + string name, + JSCallbackDescriptor callbackDescriptor, + JSPropertyAttributes attributes = JSPropertyAttributes.DefaultMethod) + { + Properties.Add(JSPropertyDescriptor.Function( + name, callbackDescriptor.Callback, attributes, callbackDescriptor.Data)); + return (TDerived)(object)this; + } } diff --git a/test/TestCases/napi-dotnet/Overloads.cs b/test/TestCases/napi-dotnet/Overloads.cs new file mode 100644 index 00000000..bcd134b0 --- /dev/null +++ b/test/TestCases/napi-dotnet/Overloads.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.JavaScript.NodeApi.TestCases; + +[JSExport] +public class Overloads +{ + public Overloads() + { + } + + public Overloads(int intValue) + { + IntValue = intValue; + } + + public Overloads(string stringValue) + { + StringValue = stringValue; + } + + public Overloads(int intValue, string stringValue) + { + IntValue = intValue; + StringValue = stringValue; + } + + public int? IntValue { get; private set; } + + public string? StringValue { get; private set; } + + public void SetValue(int intValue) + { + IntValue = intValue; + } + + public void SetValue(string stringValue) + { + StringValue = stringValue; + } + + public void SetValue(int intValue, string stringValue) + { + IntValue = intValue; + StringValue = stringValue; + } + + // Method with overloaded name in C# is given a non-overloaded export name. + [JSExport("setDoubleValue")] + public void SetValue(double doubleValue) + { + IntValue = (int)doubleValue; + } +} diff --git a/test/TestCases/napi-dotnet/overloads.js b/test/TestCases/napi-dotnet/overloads.js new file mode 100644 index 00000000..e28eea8a --- /dev/null +++ b/test/TestCases/napi-dotnet/overloads.js @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +const assert = require('assert'); + +/** @type {import('./napi-dotnet')} */ +const binding = require('../common').binding; + +const Overloads = binding.Overloads; + +// Overloaded constructor +const emptyObj = new Overloads(); +assert.strictEqual(emptyObj.intValue, undefined); +assert.strictEqual(emptyObj.stringValue, undefined); + +const intObj = new Overloads(1); +assert.strictEqual(intObj.intValue, 1); +assert.strictEqual(intObj.stringValue, undefined); + +const stringObj = new Overloads('two'); +assert.strictEqual(stringObj.intValue, undefined); +assert.strictEqual(stringObj.stringValue, 'two'); + +const comboObj = new Overloads(3, 'three'); +assert.strictEqual(comboObj.intValue, 3); +assert.strictEqual(comboObj.stringValue, 'three'); + +// Overloaded method +const obj1 = new Overloads(); +obj1.setValue(1); +assert.strictEqual(obj1.intValue, 1); +assert.strictEqual(obj1.stringValue, undefined); + +const obj2 = new Overloads(); +obj2.setValue('two'); +assert.strictEqual(obj2.intValue, undefined); +assert.strictEqual(obj2.stringValue, 'two'); + +const obj3 = new Overloads(); +obj3.setValue(3, 'three'); +assert.strictEqual(obj3.intValue, 3); +assert.strictEqual(obj3.stringValue, 'three'); + +const obj4 = new Overloads(); +obj4.setDoubleValue(4.0); +assert.strictEqual(obj4.intValue, 4); From ed21504b12d2108939c14ef305b8cc176a40983a Mon Sep 17 00:00:00 2001 From: Jason Ginchereau Date: Sat, 9 Mar 2024 11:41:42 -1000 Subject: [PATCH 2/2] Fix constructor overload parameter type reference bug --- src/NodeApi.Generator/SymbolExtensions.cs | 9 ++++++--- test/TestCases/napi-dotnet/Overloads.cs | 10 ++++++++++ test/TestCases/napi-dotnet/overloads.js | 15 +++++++++++++-- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/NodeApi.Generator/SymbolExtensions.cs b/src/NodeApi.Generator/SymbolExtensions.cs index 6ab3762b..87480877 100644 --- a/src/NodeApi.Generator/SymbolExtensions.cs +++ b/src/NodeApi.Generator/SymbolExtensions.cs @@ -414,7 +414,8 @@ private static ConstructorBuilder BuildSymbolicConstructor( IReadOnlyList parameters = constructorSymbol.Parameters; for (int i = 0; i < parameters.Count; i++) { - constructorBuilder.DefineParameter(i, ParameterAttributes.None, parameters[i].Name); + // The parameter index is offset by 1. + constructorBuilder.DefineParameter(i + 1, ParameterAttributes.None, parameters[i].Name); } if (isDelegateConstructor) @@ -556,8 +557,10 @@ public static ConstructorInfo AsConstructorInfo(this IMethodSymbol methodSymbol) parameter.Type.AsType(type.GenericTypeArguments, buildType: true); } - ConstructorInfo? constructorInfo = type.GetConstructor( - methodSymbol.Parameters.Select((p) => p.Type.AsType()).ToArray()); + BindingFlags bindingFlags = BindingFlags.Public | BindingFlags.Instance; + ConstructorInfo? constructorInfo = type.GetConstructors(bindingFlags) + .FirstOrDefault((c) => c.GetParameters().Select((p) => p.Name).SequenceEqual( + methodSymbol.Parameters.Select((p) => p.Name))); return constructorInfo ?? throw new InvalidOperationException( $"Constructor not found for type: {type.Name}"); } diff --git a/test/TestCases/napi-dotnet/Overloads.cs b/test/TestCases/napi-dotnet/Overloads.cs index bcd134b0..50e8b6e3 100644 --- a/test/TestCases/napi-dotnet/Overloads.cs +++ b/test/TestCases/napi-dotnet/Overloads.cs @@ -26,6 +26,11 @@ public Overloads(int intValue, string stringValue) StringValue = stringValue; } + public Overloads(ITestInterface obj) + { + StringValue = obj.Value; + } + public int? IntValue { get; private set; } public string? StringValue { get; private set; } @@ -46,6 +51,11 @@ public void SetValue(int intValue, string stringValue) StringValue = stringValue; } + public void SetValue(ITestInterface obj) + { + StringValue = obj.Value; + } + // Method with overloaded name in C# is given a non-overloaded export name. [JSExport("setDoubleValue")] public void SetValue(double doubleValue) diff --git a/test/TestCases/napi-dotnet/overloads.js b/test/TestCases/napi-dotnet/overloads.js index e28eea8a..72d185cc 100644 --- a/test/TestCases/napi-dotnet/overloads.js +++ b/test/TestCases/napi-dotnet/overloads.js @@ -7,6 +7,7 @@ const assert = require('assert'); const binding = require('../common').binding; const Overloads = binding.Overloads; +const ClassObject = binding.ClassObject; // Overloaded constructor const emptyObj = new Overloads(); @@ -25,6 +26,11 @@ const comboObj = new Overloads(3, 'three'); assert.strictEqual(comboObj.intValue, 3); assert.strictEqual(comboObj.stringValue, 'three'); +const objValue = new ClassObject(); +objValue.value = 'test'; +const objFromClass = new Overloads(objValue); +assert.strictEqual(objFromClass.stringValue, 'test'); + // Overloaded method const obj1 = new Overloads(); obj1.setValue(1); @@ -42,5 +48,10 @@ assert.strictEqual(obj3.intValue, 3); assert.strictEqual(obj3.stringValue, 'three'); const obj4 = new Overloads(); -obj4.setDoubleValue(4.0); -assert.strictEqual(obj4.intValue, 4); +obj4.setValue(objValue); +assert.strictEqual(obj4.stringValue, 'test'); + +const obj5 = new Overloads(); +obj5.setDoubleValue(5.0); +assert.strictEqual(obj5.intValue, 5); +