Skip to content

Commit

Permalink
feat(generate): render concurrently (#308)
Browse files Browse the repository at this point in the history
* feat(generate): render concurrently

* perf(generate): set GOMAXPROCS
  • Loading branch information
tr1v3r authored Dec 21, 2021
1 parent 97b4aa8 commit 539a653
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 43 deletions.
91 changes: 48 additions & 43 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@ import (
"context"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"text/template"

"golang.org/x/tools/imports"

"gorm.io/gorm"
"gorm.io/gorm/schema"

"gorm.io/gen/internal/check"
"gorm.io/gen/internal/model"
"gorm.io/gen/internal/parser"
tmpl "gorm.io/gen/internal/template"
"gorm.io/gen/internal/utils/pools"
)

// T generic type
Expand All @@ -32,6 +34,8 @@ type M map[string]interface{}
// RowsAffected execute affected raws
type RowsAffected int64

func init() { runtime.GOMAXPROCS(runtime.NumCPU()) }

// NewGenerator create a new generator
func NewGenerator(cfg Config) *Generator {
err := cfg.Revise()
Expand Down Expand Up @@ -243,19 +247,30 @@ func (g *Generator) generateQueryFile() (err error) {
return nil
}

errChan := make(chan error)
pool := pools.NewPool(runtime.NumCPU())
// generate query code for all struct
for _, info := range g.Data {
err = g.generateSingleQueryFile(info)
if err != nil {
return err
}
pool.Wait()
go func(info *genInfo) {
defer pool.Done()
err = g.generateSingleQueryFile(info)
if err != nil {
errChan <- err
}

if g.WithUnitTest {
err = g.generateQueryUnitTestFile(info)
if err != nil { // do not panic
g.db.Logger.Error(context.Background(), "generate unit test fail: %s", err)
if g.WithUnitTest {
err = g.generateQueryUnitTestFile(info)
if err != nil { // do not panic
g.db.Logger.Error(context.Background(), "generate unit test fail: %s", err)
}
}
}
}(info)
}
select {
case err = <-errChan:
return err
case <-pool.AsyncWaitAll():
}

// generate query file
Expand Down Expand Up @@ -385,24 +400,35 @@ func (g *Generator) generateModelFile() error {
return fmt.Errorf("create model pkg path(%s) fail: %s", modelOutPath, err)
}

errChan := make(chan error)
pool := pools.NewPool(runtime.NumCPU())
for _, data := range g.modelData {
if data == nil || !data.GenBaseStruct {
continue
}

var buf bytes.Buffer
err = render(tmpl.Model, &buf, data)
if err != nil {
return err
}
pool.Wait()
go func(data *check.BaseStruct) {
defer pool.Done()
var buf bytes.Buffer
err = render(tmpl.Model, &buf, data)
if err != nil {
errChan <- err
}

modelFile := modelOutPath + data.FileName + ".gen.go"
err = g.output(modelFile, buf.Bytes())
if err != nil {
return err
}
modelFile := modelOutPath + data.FileName + ".gen.go"
err = g.output(modelFile, buf.Bytes())
if err != nil {
errChan <- err
}

g.successInfo(fmt.Sprintf("generate model file(table <%s> -> {%s.%s}): %s", data.TableName, data.StructInfo.Package, data.StructInfo.Type, modelFile))
g.successInfo(fmt.Sprintf("generate model file(table <%s> -> {%s.%s}): %s", data.TableName, data.StructInfo.Package, data.StructInfo.Type, modelFile))
}(data)
}
select {
case err = <-errChan:
return err
case <-pool.AsyncWaitAll():
}
return nil
}
Expand Down Expand Up @@ -433,7 +459,7 @@ func (g *Generator) output(fileName string, content []byte) error {
}
return fmt.Errorf("cannot format file: %w", err)
}
return outputFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, result)
return ioutil.WriteFile(fileName, result, 0640)
}

func (g *Generator) pushBaseStruct(base *check.BaseStruct) (*genInfo, error) {
Expand All @@ -448,27 +474,6 @@ func (g *Generator) pushBaseStruct(base *check.BaseStruct) (*genInfo, error) {
return g.Data[structName], nil
}

func outputFile(filename string, flag int, data []byte) error {
out, err := os.OpenFile(filename, flag, 0640)
if err != nil {
return fmt.Errorf("open out file fail: %w", err)
}
return output(out, data)
}

func output(wr io.WriteCloser, data []byte) (err error) {
defer func() {
if e := wr.Close(); e != nil {
err = fmt.Errorf("close file fail: %w", e)
}
}()

if _, err = wr.Write(data); err != nil {
return fmt.Errorf("write file fail: %w", err)
}
return nil
}

func render(tmpl string, wr io.Writer, data interface{}) error {
t, err := template.New(tmpl).Parse(tmpl)
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions internal/utils/pools/export.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Package pools : goroutine pools
package pools

// NewPool return a new pool
func NewPool(size int) Pool {
var p pool
p.Init(size)
return &p
}
73 changes: 73 additions & 0 deletions internal/utils/pools/pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package pools

import "sync"

// Pool goroutine pool
type Pool interface {
// Wait 等待令牌
Wait()
// Done 归还令牌
Done()
// Num 当前发放的令牌书
Num() int
// Size 总令牌数
Size() int

// WaitAll 同步等待令牌全部归还
WaitAll()
// AsyncWaitAll 异步等待令牌全部归还
AsyncWaitAll() <-chan struct{}
}

type pool struct {
pool chan struct{}

wg sync.WaitGroup
}

func (p *pool) Init(size int) {
if size >= 0 {
p.pool = make(chan struct{}, size)
}
}

func (p *pool) Wait() {
if p.pool != nil {
p.wg.Add(1)
p.pool <- struct{}{}
}
}

func (p *pool) Done() {
if p.pool != nil {
<-p.pool
p.wg.Done()
}
}

func (p *pool) Num() int {
if p.pool != nil {
return len(p.pool)
}
return 0
}

func (p *pool) Size() int {
if p.pool != nil {
return cap(p.pool)
}
return 0
}

func (p *pool) WaitAll() {
p.wg.Wait()
}

func (p *pool) AsyncWaitAll() <-chan struct{} {
sig := make(chan struct{})
go func() {
p.WaitAll()
sig <- struct{}{}
}()
return sig
}

0 comments on commit 539a653

Please sign in to comment.