Skip to content

Commit

Permalink
feat: Use antlr parser to parse the types in extensions (substrait-io#64
Browse files Browse the repository at this point in the history
)

* Use antlr parser to parse the types in extensions
* support UserDefineType in grammar and parser
* Remove old parser code.
* Add tests for parameterized types
* Remove function_test_format grammar files
* Move parser under types folder
* Fetch grammar from substraite core repo
  • Loading branch information
scgkiran authored Nov 5, 2024
1 parent 5ba7f5e commit bdb436b
Show file tree
Hide file tree
Showing 40 changed files with 8,744 additions and 948 deletions.
4 changes: 4 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# .codecov.yml
coverage:
ignore:
- "types/parser/baseparser/*.go"
1 change: 1 addition & 0 deletions .github/workflows/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ jobs:
disable_search: true
file: ./coverage.out
fail_ci_if_error: true
codecov_yml_path: .codecov.yml
5 changes: 2 additions & 3 deletions expr/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,7 @@ func TestRoundTripUsingTestData(t *testing.T) {
require.NoError(t, dec.Decode(&tmp))

var (
typeParser, _ = parser.New()
protoSchema proto.NamedStruct
protoSchema proto.NamedStruct
)

raw, err := json.Marshal(tmp["baseSchema"])
Expand Down Expand Up @@ -370,7 +369,7 @@ func TestRoundTripUsingTestData(t *testing.T) {
assert.True(t, e.Equals(e))

if typTest, ok := test["type"].(string); ok {
exp, err := typeParser.ParseString(typTest)
exp, err := parser.ParseType(typTest)
require.NoError(t, err)

assert.Equal(t, exp.String(), e.GetType().String())
Expand Down
19 changes: 9 additions & 10 deletions extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ type ValueArg struct {
}

func (v ValueArg) toTypeString() string {
return v.Value.Expr.(*parser.Type).ShortType()
return v.Value.ValueType.ShortString()
}

func (v ValueArg) argumentMarker() {}
Expand Down Expand Up @@ -146,7 +146,6 @@ func (a *ArgumentList) UnmarshalYAML(fn func(interface{}) error) error {
rv.Elem().Set(reflect.ValueOf(val))
return nil
})

if err != nil {
return fmt.Errorf("failure reading YAML %v", err)
}
Expand Down Expand Up @@ -200,14 +199,14 @@ type Function interface {
}

type ScalarFunctionImpl struct {
Args ArgumentList `yaml:",omitempty"`
Options map[string]Option `yaml:",omitempty"`
Variadic *VariadicBehavior `yaml:",omitempty"`
SessionDependent bool `yaml:"sessionDependent,omitempty"`
Deterministic bool `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty" default:"MIRROR"`
Return parser.TypeExpression `yaml:",omitempty"`
Implementation map[string]string `yaml:",omitempty"`
Args ArgumentList `yaml:",omitempty"`
Options map[string]Option `yaml:",omitempty"`
Variadic *VariadicBehavior `yaml:",omitempty"`
SessionDependent bool `yaml:"sessionDependent,omitempty"`
Deterministic bool `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty" default:"MIRROR"`
Return *parser.TypeExpression `yaml:",omitempty"`
Implementation map[string]string `yaml:",omitempty"`
}

func (s *ScalarFunctionImpl) signatureKey() string {
Expand Down
13 changes: 6 additions & 7 deletions extensions/simple_extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/substrait-io/substrait-go/extensions"
"github.com/substrait-io/substrait-go/proto"
"github.com/substrait-io/substrait-go/types"
"github.com/substrait-io/substrait-go/types/parser"
)

func TestUnmarshalSimpleExtension(t *testing.T) {
Expand Down Expand Up @@ -68,17 +67,17 @@ scalar_functions:
assert.Equal(t, "scalar1", f.ScalarFunctions[0].Name)
assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[0].Impls[0].Args[0])
arg1 := f.ScalarFunctions[0].Impls[0].Args[0].(extensions.ValueArg)
assert.Equal(t, "u!customtype1", arg1.Value.String())
typ, err := arg1.Value.Expr.(*parser.Type).TypeDef.RetType()
assert.Equal(t, "u!customtype1", arg1.Value.ValueType.String())
typ, err := arg1.Value.ValueType.ReturnType()
assert.NoError(t, err)
assert.IsType(t, &types.UserDefinedType{}, typ)
assert.Equal(t, proto.Type_NULLABILITY_REQUIRED, typ.GetNullability(), "expected NULLABILITY_REQUIRED")

assert.Equal(t, "scalar2", f.ScalarFunctions[1].Name)
assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[1].Impls[0].Args[0])
ret := f.ScalarFunctions[1].Impls[0].Return
assert.Equal(t, "u!customtype2?", ret.String())
typ, err = ret.Expr.(*parser.Type).TypeDef.RetType()
assert.Equal(t, "u!customtype2?", ret.ValueType.String())
typ, err = ret.ValueType.ReturnType()
assert.NoError(t, err)
assert.IsType(t, &types.UserDefinedType{}, typ)
assert.Equal(t, proto.Type_NULLABILITY_NULLABLE, typ.GetNullability(), "expected NULLABILITY_NULLABLE")
Expand Down Expand Up @@ -113,10 +112,10 @@ scalar_functions:

x := f.ScalarFunctions[0].Impls[0].Args[0].(extensions.ValueArg)
assert.Equal(t, "x", x.Name)
assert.Equal(t, "i8", x.Value.String())
assert.Equal(t, "i8", x.Value.ValueType.String())
y := f.ScalarFunctions[0].Impls[0].Args[1].(extensions.ValueArg)
assert.Equal(t, "y", y.Name)
assert.Equal(t, "i8", y.Value.String())
assert.Equal(t, "i8", y.Value.ValueType.String())

assert.Equal(t, map[string]extensions.Option{
"overflow": {Values: []string{"SILENT", "SATURATE", "ERROR"}},
Expand Down
60 changes: 21 additions & 39 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,9 @@ func validateType(arg Argument, actual types.Type, idx int, nullHandling Nullabi
}

if nullHandling == DiscreteNullability {
if t, ok := p.Value.Expr.(*parser.Type); ok {
if isNullable != t.Optional() {
return allNonNull, fmt.Errorf("%w: discrete nullability did not match for arg #%d",
substraitgo.ErrInvalidType, idx)
}
} else {
return allNonNull, substraitgo.ErrNotImplemented
if isNullable != (p.Value.ValueType.GetNullability() == types.NullabilityNullable) {
return allNonNull, fmt.Errorf("%w: discrete nullability did not match for arg #%d",
substraitgo.ErrInvalidType, idx)
}
}
case TypeArg:
Expand All @@ -76,7 +72,7 @@ func validateType(arg Argument, actual types.Type, idx int, nullHandling Nullabi
return allNonNull, nil
}

func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, variadic *VariadicBehavior, actualTypes []types.Type) (types.Type, error) {
func EvaluateTypeExpression(nullHandling NullabilityHandling, expr types.FuncDefArgType, paramTypeList ArgumentList, variadic *VariadicBehavior, actualTypes []types.Type) (types.Type, error) {
if len(paramTypeList) != len(actualTypes) {
if variadic == nil {
return nil, fmt.Errorf("%w: mismatch in number of arguments provided. got %d, expected %d",
Expand Down Expand Up @@ -111,15 +107,9 @@ func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeEx
}
}

var outType types.Type
if t, ok := expr.Expr.(*parser.Type); ok {
var err error
outType, err = t.RetType()
if err != nil {
return nil, err
}
} else {
return nil, substraitgo.ErrNotImplemented
outType, err := expr.ReturnType()
if err != nil {
return nil, err
}

if nullHandling == MirrorNullability || nullHandling == "" {
Expand Down Expand Up @@ -219,11 +209,7 @@ func getFuncDefFromArgList(paramTypeList ArgumentList) ([]types.FuncDefArgType,
for argPos, param := range paramTypeList {
switch paramType := param.(type) {
case ValueArg:
funcDefArgType, err := paramType.Value.Expr.(*parser.Type).ArgType()
if err != nil {
return nil, err
}
out = append(out, funcDefArgType)
out = append(out, paramType.Value.ValueType)
case EnumArg:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
case TypeArg:
Expand All @@ -242,11 +228,11 @@ func parseFuncName(compoundName string) (name string, args ArgumentList) {
}
splitArgs := strings.Split(argsStr, "_")
for _, argStr := range splitArgs {
parsed, err := defParser.ParseString(argStr)
parsed, err := parser.ParseType(argStr)
if err != nil {
panic(err)
}
exp := ValueArg{Name: name, Value: parsed}
exp := ValueArg{Name: name, Value: &parser.TypeExpression{ValueType: parsed}}
args = append(args, exp)
}

Expand Down Expand Up @@ -316,7 +302,7 @@ func (s *ScalarFunctionVariant) SessionDependent() bool { return s.imp
func (s *ScalarFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
func (s *ScalarFunctionVariant) URI() string { return s.uri }
func (s *ScalarFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes)
}
func (s *ScalarFunctionVariant) CompoundName() string {
return s.name + ":" + s.impl.signatureKey()
Expand Down Expand Up @@ -374,10 +360,6 @@ type AggVariantOptions struct {
IntermediateOutputType string
}

var (
defParser, _ = parser.New()
)

func NewAggFuncVariantOpts(id ID, opts AggVariantOptions) *AggregateFunctionVariant {
var aggIntermediate parser.TypeExpression
if opts.Decomposable == "" {
Expand All @@ -389,11 +371,11 @@ func NewAggFuncVariantOpts(id ID, opts AggVariantOptions) *AggregateFunctionVari
substraitgo.ErrInvalidExpr, id))
}

intermediate, err := defParser.ParseString(opts.IntermediateOutputType)
intermediate, err := parser.ParseType(opts.IntermediateOutputType)
if err != nil {
panic(err)
}
aggIntermediate = *intermediate
aggIntermediate.ValueType = intermediate
}

simpleName, args := parseFuncName(id.Name)
Expand Down Expand Up @@ -432,7 +414,7 @@ func (s *AggregateFunctionVariant) SessionDependent() bool { return s.
func (s *AggregateFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
func (s *AggregateFunctionVariant) URI() string { return s.uri }
func (s *AggregateFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes)
}
func (s *AggregateFunctionVariant) CompoundName() string {
return s.name + ":" + s.impl.signatureKey()
Expand All @@ -442,8 +424,8 @@ func (s *AggregateFunctionVariant) ID() ID {
}
func (s *AggregateFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable }
func (s *AggregateFunctionVariant) Intermediate() (types.FuncDefArgType, error) {
if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok {
return t.ArgType()
if s.impl.Intermediate.ValueType != nil {
return s.impl.Intermediate.ValueType, nil
}
return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType)
}
Expand Down Expand Up @@ -513,11 +495,11 @@ func NewWindowFuncVariantOpts(id ID, opts WindowVariantOpts) *WindowFunctionVari
substraitgo.ErrInvalidExpr, id))
}

intermediate, err := defParser.ParseString(opts.IntermediateOutputType)
intermediate, err := parser.ParseType(opts.IntermediateOutputType)
if err != nil {
panic(err)
}
aggIntermediate = *intermediate
aggIntermediate.ValueType = intermediate
}

simpleName, args := parseFuncName(id.Name)
Expand Down Expand Up @@ -552,7 +534,7 @@ func (s *WindowFunctionVariant) SessionDependent() bool { return s.imp
func (s *WindowFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
func (s *WindowFunctionVariant) URI() string { return s.uri }
func (s *WindowFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes)
}
func (s *WindowFunctionVariant) CompoundName() string {
return s.name + ":" + s.impl.signatureKey()
Expand All @@ -562,8 +544,8 @@ func (s *WindowFunctionVariant) ID() ID {
}
func (s *WindowFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable }
func (s *WindowFunctionVariant) Intermediate() (types.FuncDefArgType, error) {
if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok {
return t.ArgType()
if s.impl.Intermediate.ValueType != nil {
return s.impl.Intermediate.ValueType, nil
}
return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType)
}
Expand Down
47 changes: 22 additions & 25 deletions extensions/variants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,41 @@ import (

func TestEvaluateTypeExpression(t *testing.T) {
var (
p, _ = parser.New()
i64Null, _ = p.ParseString("i64?")
i64NonNull, _ = p.ParseString("i64")
strNull, _ = p.ParseString("string?")
i64Null, _ = parser.ParseType("i64?")
i64NonNull, _ = parser.ParseType("i64")
strNull, _ = parser.ParseType("string?")
)

tests := []struct {
name string
nulls extensions.NullabilityHandling
ret parser.TypeExpression
ret types.FuncDefArgType
extArgs extensions.ArgumentList
args []types.Type
expected types.Type
err string
}{
{"defaults", "", *i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64Null}},
{"defaults", "", i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64Null}}},
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable}},
&types.Int64Type{Nullability: types.NullabilityNullable}, ""},
{"arg mismatch", "", *strNull, extensions.ArgumentList{extensions.ValueArg{Value: strNull}},
{"arg mismatch", "", strNull, extensions.ArgumentList{extensions.ValueArg{Value: &parser.TypeExpression{ValueType: strNull}}},
[]types.Type{}, nil, "invalid expression: mismatch in number of arguments provided. got 0, expected 1"},
{"missing enum arg", "", *i64Null, extensions.ArgumentList{
extensions.ValueArg{Value: i64NonNull}, extensions.EnumArg{Name: "foo"}},
{"missing enum arg", "", i64Null, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64NonNull}}, extensions.EnumArg{Name: "foo"}},
[]types.Type{&types.Int64Type{}, &types.Int64Type{}}, nil, "invalid type: arg #1 (foo) should be an enum"},
{"discrete null handling", extensions.DiscreteNullability, *strNull, extensions.ArgumentList{
extensions.ValueArg{Value: strNull}},
{"discrete null handling", extensions.DiscreteNullability, strNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: strNull}}},
[]types.Type{&types.StringType{Nullability: types.NullabilityRequired}},
nil, "invalid type: discrete nullability did not match for arg #0"},
{"mirror", extensions.MirrorNullability, *strNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64NonNull}, extensions.ValueArg{Value: i64Null}},
{"mirror", extensions.MirrorNullability, strNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64NonNull}}, extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64Null}}},
[]types.Type{
&types.Int64Type{Nullability: types.NullabilityRequired},
&types.Int64Type{Nullability: types.NullabilityRequired}},
&types.StringType{Nullability: types.NullabilityRequired}, ""},
{"declared output", extensions.DeclaredOutputNullability, *strNull, extensions.ArgumentList{
extensions.ValueArg{Value: strNull}},
{"declared output", extensions.DeclaredOutputNullability, strNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: strNull}}},
[]types.Type{&types.StringType{Nullability: types.NullabilityRequired}},
&types.StringType{Nullability: types.NullabilityNullable}, ""},
}
Expand All @@ -70,31 +69,29 @@ func TestEvaluateTypeExpression(t *testing.T) {

func TestVariantWithVariadic(t *testing.T) {
var (
p, _ = parser.New()
i64Null, _ = p.ParseString("i64?")
i64NonNull, _ = p.ParseString("i64")
// strNull, _ = p.ParseString("string?")
i64Null, _ = parser.ParseType("i64?")
i64NonNull, _ = parser.ParseType("i64")
)

tests := []struct {
name string
nulls extensions.NullabilityHandling
ret parser.TypeExpression
ret types.FuncDefArgType
extArgs extensions.ArgumentList
args []types.Type
expected types.Type
variadic extensions.VariadicBehavior
err string
}{
{"basic", "", *i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64Null}},
{"basic", "", i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64Null}}},
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable},
&types.Int64Type{Nullability: types.NullabilityNullable}},
&types.Int64Type{Nullability: types.NullabilityNullable},
extensions.VariadicBehavior{
Min: 0, ParameterConsistency: extensions.ConsistentParams}, ""},
{"bad arg count", "", *i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64Null}},
{"bad arg count", "", i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64Null}}},
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable},
&types.Int64Type{Nullability: types.NullabilityNullable}},
nil, extensions.VariadicBehavior{
Expand Down
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ toolchain go1.22.3

require (
github.com/alecthomas/participle/v2 v2.0.0
github.com/antlr4-go/antlr/v4 v4.13.1
github.com/cockroachdb/apd/v3 v3.2.1
github.com/creasty/defaults v1.8.0
github.com/goccy/go-yaml v1.9.8
github.com/google/go-cmp v0.5.9
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/stretchr/testify v1.8.2
github.com/substrait-io/substrait v0.57.1
golang.org/x/exp v0.0.0-20230206171751-46f607a40771
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
google.golang.org/protobuf v1.33.0
gopkg.in/yaml.v3 v3.0.1
)
Expand Down
Loading

0 comments on commit bdb436b

Please sign in to comment.