Skip to content

Commit

Permalink
Optimize ast walker
Browse files Browse the repository at this point in the history
  • Loading branch information
auxten committed Jul 28, 2021
1 parent 63fafe1 commit 5ca790c
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
8 changes: 7 additions & 1 deletion example/walk.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"log"

"github.com/auxten/postgresql-parser/pkg/sql/parser"
"github.com/auxten/postgresql-parser/pkg/walk"
)

Expand All @@ -19,6 +20,11 @@ func main() {
return false
},
}
_, _ = w.Walk(sql, nil)
stmts, err := parser.Parse(sql)
if err != nil {
return
}

_, _ = w.Walk(stmts, nil)
return
}
36 changes: 24 additions & 12 deletions pkg/walk/walker.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ func (rc ReferredCols) ToList() []string {
return cols
}

func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
stmts, err := parser.Parse(sql)
if err != nil {
return false, err
}
func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err error) {

w.unknownNodes = make([]interface{}, 0)
asts := make([]tree.NodeFormatter, len(stmts))
Expand Down Expand Up @@ -67,6 +63,8 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
walk(node.Expr)
case *tree.Array:
walk(node.Exprs)
case tree.AsOfClause:
walk(node.Expr)
case *tree.BinaryExpr:
walk(node.Left, node.Right)
case *tree.CaseExpr:
Expand Down Expand Up @@ -127,7 +125,6 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
if node.With != nil {
walk(node.With)
}
walk(node.Select)
if node.OrderBy != nil {
for _, order := range node.OrderBy {
walk(order)
Expand All @@ -136,6 +133,7 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
if node.Limit != nil {
walk(node.Limit)
}
walk(node.Select)
case *tree.Order:
walk(node.Expr, node.Table)
case *tree.Limit:
Expand All @@ -148,9 +146,6 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
if node.Having != nil {
walk(node.Having)
}
for _, table := range node.From.Tables {
walk(table)
}
if node.DistinctOn != nil {
for _, distinct := range node.DistinctOn {
walk(distinct)
Expand All @@ -161,6 +156,10 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
walk(group)
}
}
walk(node.From.AsOf)
for _, table := range node.From.Tables {
walk(table)
}
case tree.SelectExpr:
walk(node.Expr)
case tree.SelectExprs:
Expand Down Expand Up @@ -192,6 +191,10 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
}
case *tree.Where:
walk(node.Expr)
case tree.Window:
for _, windowDef := range node {
walk(windowDef)
}
case *tree.WindowDef:
walk(node.Partitions)
if node.Frame != nil {
Expand All @@ -206,13 +209,14 @@ func (w *AstWalker) Walk(sql string, ctx interface{}) (ok bool, err error) {
}
case *tree.WindowFrameBound:
walk(node.OffsetExpr)
case *tree.Window:
case *tree.With:
for _, expr := range node.CTEList {
walk(expr)
}
default:
w.unknownNodes = append(w.unknownNodes, node)
if w.unknownNodes != nil {
w.unknownNodes = append(w.unknownNodes, node)
}
}
}
}
Expand Down Expand Up @@ -257,7 +261,15 @@ func ColNamesInSelect(sql string) (referredCols ReferredCols, err error) {
return false
},
}
_, err = w.Walk(sql, referredCols)
stmts, err := parser.Parse(sql)
if err != nil {
return
}

_, err = w.Walk(stmts, referredCols)
if err != nil {
return
}
for _, col := range w.unknownNodes {
log.Printf("unhandled column type %T", col)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/walk/walker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func TestReferredVarsInSelectStatement(t *testing.T) {
referredCols, err := func() (ReferredCols, error) {
return ColNamesInSelect(tc.sql)
}()
if err.Error() != tc.err.Error() {
if err != nil && err.Error() != tc.err.Error() {
t.Errorf("Expect %s, got %s", tc.err, err)
}
cols := referredCols.ToList()
Expand Down

0 comments on commit 5ca790c

Please sign in to comment.