diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index ed87b8e2..685a7cb7 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -304,10 +304,12 @@ func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPac } sort.Strings(sortedPaths) + packagesName := createPackageMap(sortedPaths) + g.packageMap = make(map[string]string, len(im)) localNames := make(map[string]bool, len(im)) for _, pth := range sortedPaths { - base, ok := lookupPackageName(pth) + base, ok := packagesName[pth] if !ok { base = sanitize(path.Base(pth)) } @@ -622,22 +624,30 @@ func (g *generator) Output() []byte { return src } -func lookupPackageName(importPath string) (string, bool) { +// createPackageMap returns a map of import path to package name +// for specified importPaths. +func createPackageMap(importPaths []string) map[string]string { var pkg struct { - Name string + Name string + ImportPath string } + pkgMap := make(map[string]string) b := bytes.NewBuffer(nil) - cmd := exec.Command("go", "list", "-json", importPath) + args := []string{"list", "-json"} + args = append(args, importPaths...) + cmd := exec.Command("go", args...) cmd.Stdout = b - err := cmd.Run() - if err != nil { - return "", false - } - err = json.Unmarshal(b.Bytes(), &pkg) - if err != nil { - return "", false + cmd.Run() + dec := json.NewDecoder(b) + for dec.More() { + err := dec.Decode(&pkg) + if err != nil { + log.Printf("failed to decode 'go list' output: %v", err) + continue + } + pkgMap[pkg.ImportPath] = pkg.Name } - return pkg.Name, true + return pkgMap } func printVersion() { diff --git a/mockgen/mockgen_test.go b/mockgen/mockgen_test.go index 130d3cf4..515efb2f 100644 --- a/mockgen/mockgen_test.go +++ b/mockgen/mockgen_test.go @@ -335,10 +335,7 @@ func TestGetArgNames(t *testing.T) { } } -func Test_lookupPackageName(t *testing.T) { - type args struct { - importPath string - } +func Test_createPackageMap(t *testing.T) { tests := []struct { name string importPath string @@ -350,14 +347,19 @@ func Test_lookupPackageName(t *testing.T) { {"modules", "rsc.io/quote/v3", "quote", true}, {"fail", "this/should/not/work", "", false}, } + var importPaths []string + for _, t := range tests { + importPaths = append(importPaths, t.importPath) + } + packages := createPackageMap(importPaths) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotPackageName, gotOk := lookupPackageName(tt.importPath) + gotPackageName, gotOk := packages[tt.importPath] if gotPackageName != tt.wantPackageName { - t.Errorf("lookupPackageName() gotPackageName = %v, wantPackageName %v", gotPackageName, tt.wantPackageName) + t.Errorf("createPackageMap() gotPackageName = %v, wantPackageName = %v", gotPackageName, tt.wantPackageName) } if gotOk != tt.wantOK { - t.Errorf("lookupPackageName() gotOk = %v, wantOK %v", gotOk, tt.wantOK) + t.Errorf("createPackageMap() gotOk = %v, wantOK = %v", gotOk, tt.wantOK) } }) } diff --git a/mockgen/parse.go b/mockgen/parse.go index d88f3c95..a8edde80 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -432,6 +432,15 @@ func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { // importsOfFile returns a map of package name to import path // of the imports in file. func importsOfFile(file *ast.File) (normalImports map[string]string, dotImports []string) { + var importPaths []string + for _, is := range file.Imports { + if is.Name != nil { + continue + } + importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes + importPaths = append(importPaths, importPath) + } + packagesName := createPackageMap(importPaths) normalImports = make(map[string]string) dotImports = make([]string, 0) for _, is := range file.Imports { @@ -445,7 +454,7 @@ func importsOfFile(file *ast.File) (normalImports map[string]string, dotImports } pkgName = is.Name.Name } else { - pkg, ok := lookupPackageName(importPath) + pkg, ok := packagesName[importPath] if !ok { // Fallback to import path suffix. Note that this is uncertain. _, last := path.Split(importPath)