Skip to content

Commit

Permalink
Refactor how generated result types are handled (#3564)
Browse files Browse the repository at this point in the history
* Refactor how generated result types are handled

Make the generated result type root a global variable similar to the expression root.
Remove the dependencies on the go-diff package, use testify instead.

* Fix linter issues
  • Loading branch information
raphael authored Jul 25, 2024
1 parent c1a4639 commit 4d06dd6
Show file tree
Hide file tree
Showing 31 changed files with 222 additions and 359 deletions.
2 changes: 1 addition & 1 deletion codegen/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ In particular package codegen defines the data structure that represents a
generated file (see File) which is composed of sections, each corresponding to a
Go text template and accompanying data used to render the final code.
THe package also include functions that can generate code that transforms a
The package also includes functions that generate code to transform a
given type into another (see GoTransform).
*/
package codegen
38 changes: 14 additions & 24 deletions codegen/go_transform_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package codegen
import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"goa.design/goa/v3/codegen/testdata"
"goa.design/goa/v3/expr"
)
Expand All @@ -29,40 +32,27 @@ func TestGoTransformHelpers(t *testing.T) {
Type expr.DataType
HelperNames []string
}{
{"simple", simple, []string{}},
{"simple", simple, nil},
{"recursive", recursive, []string{"transformRecursiveToRecursive"}},
{"composite", composite, []string{"transformSimpleToSimple"}},
{"deep", deep, []string{"transformCompositeToComposite", "transformSimpleToSimple"}},
{"deep-array", deepArray, []string{"transformCompositeToComposite", "transformSimpleToSimple"}},
{"simple-alias", simpleAlias, []string{}},
{"nested-map-alias", mapAlias, []string{}},
{"array-map-alias", arrayMapAlias, []string{}},
{"simple-alias", simpleAlias, nil},
{"nested-map-alias", mapAlias, nil},
{"array-map-alias", arrayMapAlias, nil},
{"result-type-collection", collection, []string{"transformResultTypeToResultType"}},
}
for _, c := range tc {
t.Run(c.Name, func(t *testing.T) {
if c.Type == nil {
t.Fatal("source type not found in testdata")
}
require.NotNil(t, c.Type, "source type not found in testdata")
_, funcs, err := GoTransform(&expr.AttributeExpr{Type: c.Type}, &expr.AttributeExpr{Type: c.Type}, "source", "target", defaultCtx, defaultCtx, "", true)
if err != nil {
t.Fatal(err)
}
if len(funcs) != len(c.HelperNames) {
t.Errorf("invalid helpers count, got: %d, expected %d", len(funcs), len(c.HelperNames))
} else {
var diffs []string
actual := make([]string, len(funcs))
for i, f := range funcs {
actual[i] = f.Name
if c.HelperNames[i] != f.Name {
diffs = append(diffs, f.Name)
}
}
if len(diffs) > 0 {
t.Errorf("invalid helper names, got: %v, expected: %v", actual, c.HelperNames)
}
require.NoError(t, err)
assert.Equal(t, len(c.HelperNames), len(funcs), "invalid helpers count")
var actual []string
for _, f := range funcs {
actual = append(actual, f.Name)
}
assert.Equal(t, c.HelperNames, actual, "invalid helper names")
})
}
}
19 changes: 7 additions & 12 deletions codegen/go_transform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package codegen
import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"goa.design/goa/v3/codegen/testdata"
"goa.design/goa/v3/expr"
)
Expand Down Expand Up @@ -182,20 +185,12 @@ func TestGoTransform(t *testing.T) {
t.Run(name, func(t *testing.T) {
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
if c.Source == nil {
t.Fatal("source type not found in testdata")
}
if c.Target == nil {
t.Fatal("target type not found in testdata")
}
require.NotNil(t, c.Source)
require.NotNil(t, c.Target)
code, _, err := GoTransform(&expr.AttributeExpr{Type: c.Source}, &expr.AttributeExpr{Type: c.Target}, "source", "target", c.SourceCtx, c.TargetCtx, "", true)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
code = FormatTestCode(t, "package foo\nfunc transform(){\n"+code+"}")
if code != c.Code {
t.Errorf("invalid code, got:\n%s\ngot vs. expected:\n%s", code, Diff(t, code, c.Code))
}
assert.Equal(t, c.Code, code)
})
}
})
Expand Down
12 changes: 5 additions & 7 deletions codegen/go_transform_union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package codegen
import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"goa.design/goa/v3/codegen/testdata"
"goa.design/goa/v3/expr"
)
Expand Down Expand Up @@ -41,14 +44,9 @@ func TestGoTransformUnion(t *testing.T) {
for _, c := range tc {
t.Run(c.Name, func(t *testing.T) {
code, _, err := GoTransform(c.Source, c.Target, "source", "target", defaultCtx, defaultCtx, "", true)
if err != nil {
t.Errorf("unexpected error %s", err)
return
}
require.NoError(t, err)
code = FormatTestCode(t, "package foo\nfunc transform(){\n"+code+"}")
if code != c.Expected {
t.Errorf("invalid code, got:\n%s\ngot vs. expected:\n%s", code, Diff(t, code, c.Expected))
}
assert.Equal(t, c.Expected, code)
})
}
}
Expand Down
8 changes: 4 additions & 4 deletions codegen/service/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,17 @@ func runDSL(t *testing.T, dsl func()) *expr.RootExpr {
Services = make(ServicesData)
eval.Reset()
expr.Root = new(expr.RootExpr)
err := eval.Register(expr.Root)
require.NoError(t, err)
expr.GeneratedResultTypes = new(expr.ResultTypesRoot)
require.NoError(t, eval.Register(expr.Root))
require.NoError(t, eval.Register(expr.GeneratedResultTypes))
expr.Root.API = expr.NewAPIExpr("test api", func() {})
expr.Root.API.Servers = []*expr.ServerExpr{expr.Root.API.DefaultServer()}

// run DSL (first pass)
require.True(t, eval.Execute(dsl, nil))

// run DSL (second pass)
err = eval.RunDSL()
require.NoError(t, err)
require.NoError(t, eval.RunDSL())

// return generated root
return expr.Root
Expand Down
1 change: 0 additions & 1 deletion codegen/service/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ func TestEndpoint(t *testing.T) {
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
codegen.RunDSL(t, c.DSL)
expr.Root.GeneratedTypes = &expr.GeneratedRoot{}
require.Len(t, expr.Root.Services, 1)
fs := EndpointFile("goa.design/goa/example", expr.Root.Services[0])
require.NotNil(t, fs)
Expand Down
1 change: 0 additions & 1 deletion codegen/service/example_svc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ func TestExampleServiceFiles(t *testing.T) {
for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
codegen.RunDSL(t, c.DSL)
expr.Root.GeneratedTypes = &expr.GeneratedRoot{}
require.Len(t, expr.Root.Services, 3)
fs := ExampleServiceFiles("", expr.Root)
require.Len(t, fs, 3)
Expand Down
69 changes: 13 additions & 56 deletions codegen/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,26 @@ import (
"bytes"
"fmt"
"os"
"os/exec"
"strings"
"testing"

"github.com/stretchr/testify/require"
"goa.design/goa/v3/eval"
"goa.design/goa/v3/expr"

"github.com/sergi/go-diff/diffmatchpatch"
)

// RunDSL returns the DSL root resulting from running the given DSL.
func RunDSL(t *testing.T, dsl func()) *expr.RootExpr {
t.Helper()
eval.Reset()
expr.Root = new(expr.RootExpr)
expr.Root.GeneratedTypes = &expr.GeneratedRoot{}
if err := eval.Register(expr.Root); err != nil {
t.Fatal(err)
}
if err := eval.Register(expr.Root.GeneratedTypes); err != nil {
t.Fatal(err)
}
expr.GeneratedResultTypes = new(expr.ResultTypesRoot)
require.NoError(t, eval.Register(expr.Root))
require.NoError(t, eval.Register(expr.GeneratedResultTypes))
expr.Root.API = expr.NewAPIExpr("test api", func() {})
expr.Root.API.Servers = []*expr.ServerExpr{expr.Root.API.DefaultServer()}
if !eval.Execute(dsl, nil) {
t.Fatal(eval.Context.Error())
}
if err := eval.RunDSL(); err != nil {
t.Fatal(err)
}
require.True(t, eval.Execute(dsl, nil), eval.Context.Error())
require.NoError(t, eval.RunDSL())
return expr.Root
}

Expand All @@ -52,22 +42,16 @@ func SectionsCode(t *testing.T, sections []*SectionTemplate) string {
}

// SectionCodeFromImportsAndMethods generates and formats the code for given import and method definition sections.
func SectionCodeFromImportsAndMethods(t *testing.T, importSection *SectionTemplate, methodSection *SectionTemplate) string {
func SectionCodeFromImportsAndMethods(t *testing.T, importSection, methodSection *SectionTemplate) string {
t.Helper()
var code bytes.Buffer
if err := importSection.Write(&code); err != nil {
t.Fatal(err)
}

require.NoError(t, importSection.Write(&code))
return sectionCodeWithPrefix(t, methodSection, code.String())
}

func sectionCodeWithPrefix(t *testing.T, section *SectionTemplate, prefix string) string {
var code bytes.Buffer
if err := section.Write(&code); err != nil {
t.Fatal(err)
}

require.NoError(t, section.Write(&code))
codestr := code.String()

if len(prefix) > 0 {
Expand All @@ -83,50 +67,23 @@ func FormatTestCode(t *testing.T, code string) string {
t.Helper()
tmp := CreateTempFile(t, code)
defer os.Remove(tmp)
if err := finalizeGoSource(tmp); err != nil {
t.Fatal(err)
}
require.NoError(t, finalizeGoSource(tmp))
content, err := os.ReadFile(tmp)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
return strings.Join(strings.Split(string(content), "\n")[2:], "\n")
}

// Diff returns a diff between s1 and s2. It uses the diff tool if installed
// otherwise degrades to using the dmp package.
func Diff(t *testing.T, s1, s2 string) string {
_, err := exec.LookPath("diff")
supportsDiff := (err == nil)
if !supportsDiff {
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(s1, s2, false)
return dmp.DiffPrettyText(diffs)
}
left := CreateTempFile(t, s1)
right := CreateTempFile(t, s2)
defer os.Remove(left)
defer os.Remove(right)
cmd := exec.Command("diff", left, right)
diffb, _ := cmd.CombinedOutput()
return strings.ReplaceAll(string(diffb), "\t", " ␉ ")
}

// CreateTempFile creates a temporary file and writes the given content.
// It is used only for testing.
func CreateTempFile(t *testing.T, content string) string {
t.Helper()
f, err := os.CreateTemp("", "")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
_, err = f.WriteString(content)
if err != nil {
os.Remove(f.Name())
t.Fatal(err)
}
if err := f.Close(); err != nil {
t.Fatal(err)
}
require.NoError(t, f.Close())
return f.Name()
}
10 changes: 4 additions & 6 deletions codegen/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package codegen
import (
"testing"

"github.com/stretchr/testify/assert"

"goa.design/goa/v3/codegen/testdata"
"goa.design/goa/v3/expr"
)
Expand Down Expand Up @@ -66,18 +68,14 @@ func TestRecursiveValidationCode(t *testing.T) {
ctx := NewAttributeContext(c.Pointer, false, c.UseDefault, "", scope)
code := ValidationCode(&expr.AttributeExpr{Type: c.Type}, nil, ctx, c.Required, expr.IsAlias(c.Type), false, "target")
code = FormatTestCode(t, "package foo\nfunc Validate() (err error){\n"+code+"}")
if code != c.Code {
t.Errorf("invalid code, got:\n%s\ngot vs. expected:\n%s", code, Diff(t, code, c.Code))
}
assert.Equal(t, c.Code, code)
})
}
// Special case of unions with views
t.Run("union-with-view", func(t *testing.T) {
ctx := NewAttributeContext(false, false, false, "", scope)
code := ValidationCode(&expr.AttributeExpr{Type: unionT}, nil, ctx, true, false, true, "target")
code = FormatTestCode(t, "package foo\nfunc Validate() (err error){\n"+code+"}")
if code != testdata.UnionWithViewValidationCode {
t.Errorf("invalid code, got:\n%s\ngot vs. expected:\n%s", code, Diff(t, code, testdata.UnionWithViewValidationCode))
}
assert.Equal(t, testdata.UnionWithViewValidationCode, code)
})
}
2 changes: 1 addition & 1 deletion dsl/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ func Body(args ...any) {
if rt, ok := attr.Type.(*expr.ResultTypeExpr); ok && expr.IsArray(rt.Type) {
// If the attribute type is a result type collection add the type to the
// GeneratedTypes so that the type's DSLFunc is executed.
*expr.Root.GeneratedTypes = append(*expr.Root.GeneratedTypes, rt)
expr.GeneratedResultTypes.Append(rt)
}
if len(args) > 1 {
var ok bool
Expand Down
22 changes: 11 additions & 11 deletions dsl/result_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,19 @@ func ResultType(identifier string, args ...any) *expr.ResultTypeExpr {
}
canonicalID := expr.CanonicalIdentifier(identifier)
// Validate that result type identifier doesn't clash
for _, rt := range expr.Root.ResultTypes {
if re := rt.(*expr.ResultTypeExpr); re.Identifier == canonicalID {
for _, rt := range *expr.GeneratedResultTypes {
if rt.Identifier == canonicalID {
eval.ReportError(
"result type %#v with canonical identifier %#v is defined twice",
identifier, canonicalID)
return nil
}
}
// Now save the type in the API result types map
mt := expr.NewResultTypeExpr(typeName, identifier, fn)
expr.Root.ResultTypes = append(expr.Root.ResultTypes, mt)
// Add the type to the generated types root for later evaluation.
rt := expr.NewResultTypeExpr(typeName, identifier, fn)
expr.Root.ResultTypes = append(expr.Root.ResultTypes, rt)

return mt
return rt
}

// TypeName makes it possible to set the Go struct name for a type or result
Expand Down Expand Up @@ -201,7 +201,7 @@ func View(name string, adsl ...func()) {
switch e := eval.Current().(type) {
case *expr.ResultTypeExpr:
if e.View(name) != nil {
eval.ReportError("multiple expressions for view %#v in result type %#v", name, e.TypeName)
eval.ReportError("view %q is defined multiple times in result type %q", name, e.TypeName)
return
}
at := &expr.AttributeExpr{}
Expand Down Expand Up @@ -340,11 +340,11 @@ func CollectionOf(v any, adsl ...func()) *expr.ResultTypeExpr {
}
id = mime.FormatMediaType(rtype, params)
canonical := expr.CanonicalIdentifier(id)
if mt := expr.Root.GeneratedResultType(canonical); mt != nil {
if mt := expr.GeneratedResultType(canonical); mt != nil {
// Already have a type for this collection, reuse it.
return mt
}
mt := expr.NewResultTypeExpr("", id, func() {
rt := expr.NewResultTypeExpr("", id, func() {
rt, ok := eval.Current().(*expr.ResultTypeExpr)
if !ok {
eval.IncompatibleDSL()
Expand All @@ -371,8 +371,8 @@ func CollectionOf(v any, adsl ...func()) *expr.ResultTypeExpr {
})
// do not execute the DSL right away, will be done last to make sure
// the element DSL has run first.
*expr.Root.GeneratedTypes = append(*expr.Root.GeneratedTypes, mt)
return mt
expr.GeneratedResultTypes.Append(rt)
return rt
}

// Reference sets a type or result type reference. The value itself can be a
Expand Down
Loading

0 comments on commit 4d06dd6

Please sign in to comment.