diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems
index f875a91df..9e719e412 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems
@@ -74,6 +74,7 @@
+
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IMethodSymbolExtensions.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IMethodSymbolExtensions.cs
new file mode 100644
index 000000000..23f4786d3
--- /dev/null
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/IMethodSymbolExtensions.cs
@@ -0,0 +1,78 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Immutable;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+
+namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions;
+
+///
+/// Extension methods for the type.
+///
+internal static class IMethodSymbolExtensions
+{
+ ///
+ /// Checks whether all input symbols are -s in the same override hierarchy.
+ ///
+ /// The input set to check.
+ /// Whether all input symbols are -s in the same override hierarchy.
+ public static bool AreAllInSameOverriddenMethodHierarchy(this ImmutableArray symbols)
+ {
+ IMethodSymbol? baseSymbol = null;
+
+ // Look for the base method
+ foreach (ISymbol currentSymbol in symbols)
+ {
+ // If any input symbol is not a method, we can stop right away
+ if (currentSymbol is not IMethodSymbol methodSymbol)
+ {
+ return false;
+ }
+
+ if (methodSymbol.IsVirtual)
+ {
+ // If we already found a base method, all methods can't possibly be in the same hierarchy
+ if (baseSymbol is not null)
+ {
+ return false;
+ }
+
+ baseSymbol = methodSymbol;
+ }
+ }
+
+ // If we didn't find any, stop here
+ if (baseSymbol is null)
+ {
+ return false;
+ }
+
+ // Verify all methods are in the same tree
+ foreach (ISymbol currentSymbol in symbols)
+ {
+ IMethodSymbol methodSymbol = (IMethodSymbol)currentSymbol;
+
+ // Ignore the base method
+ if (SymbolEqualityComparer.Default.Equals(methodSymbol, baseSymbol))
+ {
+ continue;
+ }
+
+ // If the current method isn't an override, then fail
+ if (methodSymbol.OverriddenMethod is not { } overriddenMethod)
+ {
+ return false;
+ }
+
+ // The current method must be overriding another one in the set
+ if (!symbols.Any(symbol => SymbolEqualityComparer.Default.Equals(symbol, overriddenMethod)))
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+}
diff --git a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs
index d40086135..57eb1ebcf 100644
--- a/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs
+++ b/src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs
@@ -810,8 +810,10 @@ private static bool TryGetCanExecuteExpressionType(
diagnostics.Add(InvalidCanExecuteMemberNameError, methodSymbol, memberName, methodSymbol.ContainingType);
}
- else if (canExecuteSymbols.Length > 1)
+ else if (canExecuteSymbols.Length > 1 && !canExecuteSymbols.AreAllInSameOverriddenMethodHierarchy())
{
+ // We specifically allow targeting methods which are overridden: they'll be more than one,
+ // but it doesn't matter since you'd only ever call "one", being the most derived one.
diagnostics.Add(MultipleCanExecuteMemberNameMatchesError, methodSymbol, memberName, methodSymbol.ContainingType);
}
else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], commandTypeArguments, out canExecuteExpressionType))
diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs
index d3d44ec5b..d91dfe2b2 100644
--- a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs
+++ b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs
@@ -2060,6 +2060,60 @@ partial class MyViewModel
VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyViewModel.Test.g.cs", result));
}
+ [TestMethod]
+ public void RelayCommandWithOverriddenCanExecute_TargetsOverriddenMethod()
+ {
+ string source = """
+ using CommunityToolkit.Mvvm.ComponentModel;
+ using CommunityToolkit.Mvvm.Input;
+
+ namespace MyApp;
+
+ public partial class BaseViewModel : ObservableObject
+ {
+ protected virtual bool CanDoStuff()
+ {
+ return false;
+ }
+ }
+
+ public partial class SampleViewModel : BaseViewModel
+ {
+ [RelayCommand(CanExecute = nameof(CanDoStuff)]
+ private void DoStuff()
+ {
+ }
+
+ protected override bool CanDoStuff()
+ {
+ return true;
+ }
+ }
+ """;
+
+ string result = """
+ //
+ #pragma warning disable
+ #nullable enable
+ namespace MyApp
+ {
+ ///
+ partial class SampleViewModel
+ {
+ /// The backing field for .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", )]
+ private global::CommunityToolkit.Mvvm.Input.RelayCommand? doStuffCommand;
+ /// Gets an instance wrapping .
+ [global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", )]
+ [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
+ public global::CommunityToolkit.Mvvm.Input.IRelayCommand DoStuffCommand => doStuffCommand ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(DoStuff), CanDoStuff);
+ }
+ }
+ """;
+
+ VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.SampleViewModel.DoStuff.g.cs", result));
+ }
+
[TestMethod]
public void ObservableProperty_AnnotatedFieldHasValueIdentifier()
{
diff --git a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs
index cb866377d..f18ff7dfd 100644
--- a/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs
+++ b/tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsDiagnostics.cs
@@ -2259,6 +2259,80 @@ await CSharpAnalyzerWithLanguageVersionTest(source, LanguageVersion.CSharp12);
+ }
+
+ [TestMethod]
+ public void RelayCommandWithOverriddenCanExecute_WithOneMethodNotInTheSameHierarchy_Warns()
+ {
+ const string source = """
+ using CommunityToolkit.Mvvm.ComponentModel;
+ using CommunityToolkit.Mvvm.Input;
+
+ namespace MyApp
+ {
+ public partial class BaseViewModel : ObservableObject
+ {
+ protected virtual bool CanDoStuff()
+ {
+ return false;
+ }
+
+ protected bool CanDoStuff(string x)
+ {
+ }
+ }
+
+ public partial class SampleViewModel : BaseViewModel
+ {
+ [RelayCommand(CanExecute = nameof(CanDoStuff)]
+ private void DoStuff()
+ {
+ }
+
+ private override bool CanDoStuff()
+ {
+ return true;
+ }
+ }
+ }
+ """;
+
+ VerifyGeneratedDiagnostics(source, "MVVMTK0010");
+ }
+
[TestMethod]
public async Task WinRTClassUsingNotifyPropertyChangedAttributesAnalyzer_NotTargetingWindows_DoesNotWarn()
{
@@ -2451,10 +2525,23 @@ internal static async Task VerifyAnalyzerDiagnosticsAndSuccessfulGenerationThe diagnostic ids to expect for the input source code.
internal static void VerifyGeneratedDiagnostics(string source, params string[] diagnosticsIds)
where TGenerator : class, IIncrementalGenerator, new()
+ {
+ VerifyGeneratedDiagnostics(source, LanguageVersion.CSharp8, diagnosticsIds);
+ }
+
+ ///
+ /// Verifies the output of a source generator.
+ ///
+ /// The generator type to use.
+ /// The input source to process.
+ /// The language version to use to parse code and run tests.
+ /// The diagnostic ids to expect for the input source code.
+ internal static void VerifyGeneratedDiagnostics(string source, LanguageVersion languageVersion, params string[] diagnosticsIds)
+ where TGenerator : class, IIncrementalGenerator, new()
{
IIncrementalGenerator generator = new TGenerator();
- VerifyGeneratedDiagnostics(CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.CSharp8)), new[] { generator }, diagnosticsIds, []);
+ VerifyGeneratedDiagnostics(CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(languageVersion)), new[] { generator }, diagnosticsIds, []);
}
///