Skip to content

Handle 'CanExecute' with method overrides #1081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
<Compile Include="$(MSBuildThisFileDirectory)Extensions\GeneratorAttributeSyntaxContextWithOptions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\IncrementalGeneratorInitializationContextExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\IncrementalValuesProviderExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\IMethodSymbolExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\MethodDeclarationSyntaxExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\SymbolInfoExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\ISymbolExtensions.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions;

/// <summary>
/// Extension methods for the <see cref="IMethodSymbol"/> type.
/// </summary>
internal static class IMethodSymbolExtensions
{
/// <summary>
/// Checks whether all input symbols are <see cref="IMethodSymbol"/>-s in the same override hierarchy.
/// </summary>
/// <param name="symbols">The input <see cref="ISymbol"/> set to check.</param>
/// <returns>Whether all input symbols are <see cref="IMethodSymbol"/>-s in the same override hierarchy.</returns>
public static bool AreAllInSameOverriddenMethodHierarchy(this ImmutableArray<ISymbol> symbols)
{
IMethodSymbol? baseSymbol = null;

// Look for the base method
foreach (ISymbol currentSymbol in symbols)
{
// If any input symbol is not a method, we can stop right away
if (currentSymbol is not IMethodSymbol methodSymbol)
{
return false;
}

if (methodSymbol.IsVirtual)
{
// If we already found a base method, all methods can't possibly be in the same hierarchy
if (baseSymbol is not null)
{
return false;
}

baseSymbol = methodSymbol;
}
}

// If we didn't find any, stop here
if (baseSymbol is null)
{
return false;
}

// Verify all methods are in the same tree
foreach (ISymbol currentSymbol in symbols)
{
IMethodSymbol methodSymbol = (IMethodSymbol)currentSymbol;

// Ignore the base method
if (SymbolEqualityComparer.Default.Equals(methodSymbol, baseSymbol))
{
continue;
}

// If the current method isn't an override, then fail
if (methodSymbol.OverriddenMethod is not { } overriddenMethod)
{
return false;
}

// The current method must be overriding another one in the set
if (!symbols.Any(symbol => SymbolEqualityComparer.Default.Equals(symbol, overriddenMethod)))
{
return false;
}
}

return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -810,8 +810,10 @@ private static bool TryGetCanExecuteExpressionType(

diagnostics.Add(InvalidCanExecuteMemberNameError, methodSymbol, memberName, methodSymbol.ContainingType);
}
else if (canExecuteSymbols.Length > 1)
else if (canExecuteSymbols.Length > 1 && !canExecuteSymbols.AreAllInSameOverriddenMethodHierarchy())
{
// We specifically allow targeting methods which are overridden: they'll be more than one,
// but it doesn't matter since you'd only ever call "one", being the most derived one.
diagnostics.Add(MultipleCanExecuteMemberNameMatchesError, methodSymbol, memberName, methodSymbol.ContainingType);
}
else if (TryGetCanExecuteExpressionFromSymbol(canExecuteSymbols[0], commandTypeArguments, out canExecuteExpressionType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2060,6 +2060,60 @@ partial class MyViewModel
VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyViewModel.Test.g.cs", result));
}

[TestMethod]
public void RelayCommandWithOverriddenCanExecute_TargetsOverriddenMethod()
{
string source = """
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;

namespace MyApp;

public partial class BaseViewModel : ObservableObject
{
protected virtual bool CanDoStuff()
{
return false;
}
}

public partial class SampleViewModel : BaseViewModel
{
[RelayCommand(CanExecute = nameof(CanDoStuff)]
private void DoStuff()
{
}

protected override bool CanDoStuff()
{
return true;
}
}
""";

string result = """
// <auto-generated/>
#pragma warning disable
#nullable enable
namespace MyApp
{
/// <inheritdoc/>
partial class SampleViewModel
{
/// <summary>The backing field for <see cref="DoStuffCommand"/>.</summary>
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", <ASSEMBLY_VERSION>)]
private global::CommunityToolkit.Mvvm.Input.RelayCommand? doStuffCommand;
/// <summary>Gets an <see cref="global::CommunityToolkit.Mvvm.Input.IRelayCommand"/> instance wrapping <see cref="DoStuff"/>.</summary>
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", <ASSEMBLY_VERSION>)]
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
public global::CommunityToolkit.Mvvm.Input.IRelayCommand DoStuffCommand => doStuffCommand ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(DoStuff), CanDoStuff);
}
}
""";

VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.SampleViewModel.DoStuff.g.cs", result));
}

[TestMethod]
public void ObservableProperty_AnnotatedFieldHasValueIdentifier()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2259,6 +2259,80 @@ await CSharpAnalyzerWithLanguageVersionTest<WinRTRelayCommandIsNotGeneratedBinda
editorconfig: [("_MvvmToolkitIsUsingWindowsRuntimePack", true)]);
}

[TestMethod]
public void RelayCommandWithOverriddenCanExecute_DoesNotWarn()
{
const string source = """
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;

namespace MyApp
{
public partial class BaseViewModel : ObservableObject
{
protected virtual bool CanDoStuff()
{
return false;
}
}

public partial class SampleViewModel : BaseViewModel
{
[RelayCommand(CanExecute = nameof(CanDoStuff))]
private void DoStuff()
{
}

protected override bool CanDoStuff()
{
return true;
}
}
}
""";

VerifyGeneratedDiagnostics<RelayCommandGenerator>(source, LanguageVersion.CSharp12);
}

[TestMethod]
public void RelayCommandWithOverriddenCanExecute_WithOneMethodNotInTheSameHierarchy_Warns()
{
const string source = """
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;

namespace MyApp
{
public partial class BaseViewModel : ObservableObject
{
protected virtual bool CanDoStuff()
{
return false;
}

protected bool CanDoStuff(string x)
{
}
}

public partial class SampleViewModel : BaseViewModel
{
[RelayCommand(CanExecute = nameof(CanDoStuff)]
private void DoStuff()
{
}

private override bool CanDoStuff()
{
return true;
}
}
}
""";

VerifyGeneratedDiagnostics<RelayCommandGenerator>(source, "MVVMTK0010");
}

[TestMethod]
public async Task WinRTClassUsingNotifyPropertyChangedAttributesAnalyzer_NotTargetingWindows_DoesNotWarn()
{
Expand Down Expand Up @@ -2451,10 +2525,23 @@ internal static async Task VerifyAnalyzerDiagnosticsAndSuccessfulGeneration<TAna
/// <param name="diagnosticsIds">The diagnostic ids to expect for the input source code.</param>
internal static void VerifyGeneratedDiagnostics<TGenerator>(string source, params string[] diagnosticsIds)
where TGenerator : class, IIncrementalGenerator, new()
{
VerifyGeneratedDiagnostics<TGenerator>(source, LanguageVersion.CSharp8, diagnosticsIds);
}

/// <summary>
/// Verifies the output of a source generator.
/// </summary>
/// <typeparam name="TGenerator">The generator type to use.</typeparam>
/// <param name="source">The input source to process.</param>
/// <param name="languageVersion">The language version to use to parse code and run tests.</param>
/// <param name="diagnosticsIds">The diagnostic ids to expect for the input source code.</param>
internal static void VerifyGeneratedDiagnostics<TGenerator>(string source, LanguageVersion languageVersion, params string[] diagnosticsIds)
where TGenerator : class, IIncrementalGenerator, new()
{
IIncrementalGenerator generator = new TGenerator();

VerifyGeneratedDiagnostics(CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.CSharp8)), new[] { generator }, diagnosticsIds, []);
VerifyGeneratedDiagnostics(CSharpSyntaxTree.ParseText(source, CSharpParseOptions.Default.WithLanguageVersion(languageVersion)), new[] { generator }, diagnosticsIds, []);
}

/// <summary>
Expand Down
Loading