diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems index f875a91df..9e719e412 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems @@ -74,6 +74,7 @@ + diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IMethodSymbolExtensions.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IMethodSymbolExtensions.cs new file mode 100644 index 000000000..23f4786d3 --- /dev/null +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IMethodSymbolExtensions.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions; + +/// +/// Extension methods for the type. +/// +internal static class IMethodSymbolExtensions +{ + /// + /// Checks whether all input symbols are -s in the same override hierarchy. + /// + /// The input set to check. + /// Whether all input symbols are -s in the same override hierarchy. + public static bool AreAllInSameOverriddenMethodHierarchy(this ImmutableArray symbols) + { + IMethodSymbol? baseSymbol = null; + + // Look for the base method + foreach (ISymbol currentSymbol in symbols) + { + // If any input symbol is not a method, we can stop right away + if (currentSymbol is not IMethodSymbol methodSymbol) + { + return false; + } + + if (methodSymbol.IsVirtual) + { + // If we already found a base method, all methods can't possibly be in the same hierarchy + if (baseSymbol is not null) + { + return false; + } + + baseSymbol = methodSymbol; + } + } + + // If we didn't find any, stop here + if (baseSymbol is null) + { + return false; + } + + // Verify all methods are in the same tree + foreach (ISymbol currentSymbol in symbols) + { + IMethodSymbol methodSymbol = (IMethodSymbol)currentSymbol; + + // Ignore the base method + if (SymbolEqualityComparer.Default.Equals(methodSymbol, baseSymbol)) + { + continue; + } + + // If the current method isn't an override, then fail + if (methodSymbol.OverriddenMethod is not { } overriddenMethod) + { + return false; + } + + // The current method must be overriding another one in the set + if (!symbols.Any(symbol => SymbolEqualityComparer.Default.Equals(symbol, overriddenMethod))) + { + return false; + } + } + + return true; + } +} diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs index d40086135..57eb1ebcf 100644 --- a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs +++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs @@ -810,8 +810,10 @@ private static bool TryGetCanExecuteExpressionType( diagnostics.Add(InvalidCanExecuteMemberNameError, methodSymbol, memberName, methodSymbol.ContainingType); } - else if (canExecuteSymbols.Length > 1) + else if (canExecuteSymbols.Length > 1 && !canExecuteSymbols.AreAllInSameOverriddenMethodHierarchy()) { + // We specifically allow targeting methods which are overridden: they'll be more than one, + // but it doesn't matter since you'd only ever call "one", being the most derived one. diagnostics.Add(MultipleCanExecuteMemberNameMatchesError, methodSymbol, memberName, methodSymbol.ContainingType); } else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], commandTypeArguments, out canExecuteExpressionType)) diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs index d3d44ec5b..d91dfe2b2 100644 --- a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs +++ b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs @@ -2060,6 +2060,60 @@ partial class MyViewModel VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyViewModel.Test.g.cs", result)); } + [TestMethod] + public void RelayCommandWithOverriddenCanExecute_TargetsOverriddenMethod() + { + string source = """ + using CommunityToolkit.Mvvm.ComponentModel; + using CommunityToolkit.Mvvm.Input; + + namespace MyApp; + + public partial class BaseViewModel : ObservableObject + { + protected virtual bool CanDoStuff() + { + return false; + } + } + + public partial class SampleViewModel : BaseViewModel + { + [RelayCommand(CanExecute = nameof(CanDoStuff)] + private void DoStuff() + { + } + + protected override bool CanDoStuff() + { + return true; + } + } + """; + + string result = """ + // + #pragma warning disable + #nullable enable + namespace MyApp + { + /// + partial class SampleViewModel + { + /// The backing field for . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", )] + private global::CommunityToolkit.Mvvm.Input.RelayCommand? doStuffCommand; + /// Gets an instance wrapping . + [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", )] + [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + public global::CommunityToolkit.Mvvm.Input.IRelayCommand DoStuffCommand => doStuffCommand ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(DoStuff), CanDoStuff); + } + } + """; + + VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.SampleViewModel.DoStuff.g.cs", result)); + } + [TestMethod] public void ObservableProperty_AnnotatedFieldHasValueIdentifier() { diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs index cb866377d..f18ff7dfd 100644 --- a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs +++ b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs @@ -2259,6 +2259,80 @@ await CSharpAnalyzerWithLanguageVersionTest(source, LanguageVersion.CSharp12); + } + + [TestMethod] + public void RelayCommandWithOverriddenCanExecute_WithOneMethodNotInTheSameHierarchy_Warns() + { + const string source = """ + using CommunityToolkit.Mvvm.ComponentModel; + using CommunityToolkit.Mvvm.Input; + + namespace MyApp + { + public partial class BaseViewModel : ObservableObject + { + protected virtual bool CanDoStuff() + { + return false; + } + + protected bool CanDoStuff(string x) + { + } + } + + public partial class SampleViewModel : BaseViewModel + { + [RelayCommand(CanExecute = nameof(CanDoStuff)] + private void DoStuff() + { + } + + private override bool CanDoStuff() + { + return true; + } + } + } + """; + + VerifyGeneratedDiagnostics(source, "MVVMTK0010"); + } + [TestMethod] public async Task WinRTClassUsingNotifyPropertyChangedAttributesAnalyzer_NotTargetingWindows_DoesNotWarn() { @@ -2451,10 +2525,23 @@ internal static async Task VerifyAnalyzerDiagnosticsAndSuccessfulGenerationThe diagnostic ids to expect for the input source code. internal static void VerifyGeneratedDiagnostics(string source, params string[] diagnosticsIds) where TGenerator : class, IIncrementalGenerator, new() + { + VerifyGeneratedDiagnostics(source, LanguageVersion.CSharp8, diagnosticsIds); + } + + /// + /// Verifies the output of a source generator. + /// + /// The generator type to use. + /// The input source to process. + /// The language version to use to parse code and run tests. + /// The diagnostic ids to expect for the input source code. + internal static void VerifyGeneratedDiagnostics(string source, LanguageVersion languageVersion, params string[] diagnosticsIds) + where TGenerator : class, IIncrementalGenerator, new() { IIncrementalGenerator generator = new TGenerator(); - VerifyGeneratedDiagnostics(CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.CSharp8)), new[] { generator }, diagnosticsIds, []); + VerifyGeneratedDiagnostics(CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(languageVersion)), new[] { generator }, diagnosticsIds, []); } ///