Skip to content

Commit

Permalink
Allow types defined as instantiated generic interfaces to generate mocks
Browse files Browse the repository at this point in the history
Fixes issue vektra#787. Allow `*ast.IndexExpr` in a `*ast.TypeSpec` to be a mock
target.
  • Loading branch information
LandonTClipp committed Aug 2, 2024
1 parent 5a3e47a commit 9738b5b
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 32 deletions.
5 changes: 0 additions & 5 deletions cmd/mockery.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,6 @@ func (r *RootApp) Run() error {
log.Error().Err(err).Msg("unable to parse packages")
return err
}
log.Info().Msg("done parsing, loading")
if err := parser.Load(); err != nil {
log.Err(err).Msgf("failed to load parser")
return nil
}
log.Info().Msg("done loading, visiting interface nodes")
for _, iface := range parser.Interfaces() {
ifaceLog := log.
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions pkg/fixtures/instantiated_generic_interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package test

type GenericInterface[M any] interface {
Func(arg *M) int
}

type InstantiatedGenericInterface GenericInterface[float32]
1 change: 0 additions & 1 deletion pkg/fixtures/variadic.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ type VariadicFunction = func(args1 string, args2 ...interface{}) interface{}
type Variadic interface {
VariadicFunction(str string, vFunc VariadicFunction) error
}

2 changes: 1 addition & 1 deletion pkg/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (s *GeneratorSuite) getInterfaceFromFile(interfacePath, interfaceName strin
)

s.Require().NoError(
s.parser.Load(),
s.parser.Load(context.Background()),
)

iface, err := s.parser.Find(interfaceName)
Expand Down
2 changes: 1 addition & 1 deletion pkg/outputter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ packages:
m.config.Config = confPath.String()

require.NoError(t, parser.ParsePackages(ctx, []string{tt.packagePath}))
require.NoError(t, parser.Load())
require.NoError(t, parser.Load(context.Background()))
for _, intf := range parser.Interfaces() {
t.Logf("generating interface: %s %s", intf.QualifiedName, intf.Name)
require.NoError(t, m.Generate(ctx, intf))
Expand Down
58 changes: 40 additions & 18 deletions pkg/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,27 @@ import (
"golang.org/x/tools/go/packages"
)

type parserEntry struct {
type fileEntry struct {
fileName string
pkg *packages.Package
syntax *ast.File
interfaces []string
}

func (f *fileEntry) ParseInterfaces(ctx context.Context) {
nv := NewNodeVisitor(ctx)
ast.Walk(nv, f.syntax)
f.interfaces = nv.DeclaredInterfaces()
}

type packageLoadEntry struct {
pkgs []*packages.Package
err error
}

type Parser struct {
entries []*parserEntry
entriesByFileName map[string]*parserEntry
files []*fileEntry
entriesByFileName map[string]*fileEntry
parserPackages []*types.Package
conf packages.Config
packageLoadCache map[string]packageLoadEntry
Expand All @@ -52,7 +58,7 @@ func NewParser(buildTags []string) *Parser {
}
return &Parser{
parserPackages: make([]*types.Package, 0),
entriesByFileName: map[string]*parserEntry{},
entriesByFileName: map[string]*fileEntry{},
conf: conf,
packageLoadCache: map[string]packageLoadEntry{},
}
Expand Down Expand Up @@ -86,18 +92,21 @@ func (p *Parser) ParsePackages(ctx context.Context, packageNames []string) error
Str("package", pkg.PkgPath).
Str("file", file).
Msgf("found file")
entry := parserEntry{
entry := fileEntry{
fileName: file,
pkg: pkg,
syntax: pkg.Syntax[fileIdx],
}
p.entries = append(p.entries, &entry)
entry.ParseInterfaces(ctx)
p.files = append(p.files, &entry)
p.entriesByFileName[file] = &entry
}
}
return nil
}

// DEPRECATED: Parse is part of the deprecated, legacy mockery behavior. This is not
// used when the packages feature is enabled.
func (p *Parser) Parse(ctx context.Context, path string) error {
// To support relative paths to mock targets w/ vendor deps, we need to provide eventual
// calls to build.Context.Import with an absolute path. It needs to be absolute because
Expand Down Expand Up @@ -164,30 +173,28 @@ func (p *Parser) Parse(ctx context.Context, path string) error {
if _, ok := p.entriesByFileName[f]; ok {
continue
}
entry := parserEntry{
entry := fileEntry{
fileName: f,
pkg: pkg,
syntax: pkg.Syntax[idx],
}
p.entries = append(p.entries, &entry)
p.files = append(p.files, &entry)
p.entriesByFileName[f] = &entry
}
}

return nil
}

func (p *Parser) Load() error {
for _, entry := range p.entries {
nv := NewNodeVisitor()
ast.Walk(nv, entry.syntax)
entry.interfaces = nv.DeclaredInterfaces()
func (p *Parser) Load(ctx context.Context) error {
for _, entry := range p.files {
entry.ParseInterfaces(ctx)
}
return nil
}

func (p *Parser) Find(name string) (*Interface, error) {
for _, entry := range p.entries {
for _, entry := range p.files {
for _, iface := range entry.interfaces {
if iface == name {
list := p.packageInterfaces(entry.pkg.Types, entry.fileName, []string{name}, nil)
Expand All @@ -202,7 +209,7 @@ func (p *Parser) Find(name string) (*Interface, error) {

func (p *Parser) Interfaces() []*Interface {
ifaces := make(sortableIFaceList, 0)
for _, entry := range p.entries {
for _, entry := range p.files {
declaredIfaces := entry.interfaces
ifaces = p.packageInterfaces(entry.pkg.Types, entry.fileName, declaredIfaces, ifaces)
}
Expand Down Expand Up @@ -314,12 +321,15 @@ func (s sortableIFaceList) Less(i, j int) bool {
}

type NodeVisitor struct {
declaredInterfaces []string
declaredInterfaces []string
genericInstantiationInterface map[string]any
ctx context.Context
}

func NewNodeVisitor() *NodeVisitor {
func NewNodeVisitor(ctx context.Context) *NodeVisitor {
return &NodeVisitor{
declaredInterfaces: make([]string, 0),
ctx: ctx,
}
}

Expand All @@ -328,11 +338,23 @@ func (nv *NodeVisitor) DeclaredInterfaces() []string {
}

func (nv *NodeVisitor) Visit(node ast.Node) ast.Visitor {
log := zerolog.Ctx(nv.ctx)

switch n := node.(type) {
case *ast.TypeSpec:
log := log.With().
Str("node-name", n.Name.Name).
Str("node-type", fmt.Sprintf("%T", n.Type)).
Logger()

switch n.Type.(type) {
case *ast.InterfaceType, *ast.FuncType:
case *ast.InterfaceType, *ast.FuncType, *ast.IndexExpr:
log.Debug().
Str("node-type", fmt.Sprintf("%T", n.Type)).
Msg("found node with acceptable type for mocking")
nv.declaredInterfaces = append(nv.declaredInterfaces, n.Name.Name)
default:
log.Debug().Msg("Found node with unacceptable type for mocking. Rejecting.")
}
}
return nv
Expand Down
10 changes: 5 additions & 5 deletions pkg/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestFileParse(t *testing.T) {
err := parser.Parse(ctx, testFile)
assert.NoError(t, err)

err = parser.Load()
err = parser.Load(context.Background())
assert.NoError(t, err)

node, err := parser.Find("Requester")
Expand All @@ -38,7 +38,7 @@ func TestBuildTagInFilename(t *testing.T) {
err = parser.Parse(ctx, getFixturePath("buildtag", "filename", "iface_freebsd.go"))
assert.NoError(t, err)

err = parser.Load()
err = parser.Load(context.Background())
assert.NoError(t, err) // Expect "redeclared in this block" if tags aren't respected

nodes := parser.Interfaces()
Expand All @@ -60,7 +60,7 @@ func TestBuildTagInComment(t *testing.T) {
err = parser.Parse(ctx, getFixturePath("buildtag", "comment", "freebsd_iface.go"))
assert.NoError(t, err)

err = parser.Load()
err = parser.Load(context.Background())
assert.NoError(t, err) // Expect "redeclared in this block" if tags aren't respected

nodes := parser.Interfaces()
Expand All @@ -78,7 +78,7 @@ func TestCustomBuildTag(t *testing.T) {
err = parser.Parse(ctx, getFixturePath("buildtag", "comment", "custom2_iface.go"))
assert.NoError(t, err)

err = parser.Load()
err = parser.Load(context.Background())
assert.NoError(t, err) // Expect "redeclared in this block" if tags aren't respected

found := false
Expand All @@ -94,6 +94,6 @@ func TestCustomBuildTag(t *testing.T) {
func TestParsePackages(t *testing.T) {
parser := NewParser([]string{})
require.NoError(t, parser.ParsePackages(context.Background(), []string{"github.com/vektra/mockery/v2/pkg/fixtures"}))
assert.NotEqual(t, 0, len(parser.entries))
assert.NotEqual(t, 0, len(parser.files))

}
Loading

0 comments on commit 9738b5b

Please sign in to comment.