Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Use antlr parser to parse the types in extensions #64

Merged
merged 24 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# .codecov.yml
coverage:
ignore:
- "types/parser/baseparser/*.go"
1 change: 1 addition & 0 deletions .github/workflows/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ jobs:
disable_search: true
file: ./coverage.out
fail_ci_if_error: true
codecov_yml_path: .codecov.yml
5 changes: 2 additions & 3 deletions expr/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,7 @@ func TestRoundTripUsingTestData(t *testing.T) {
require.NoError(t, dec.Decode(&tmp))

var (
typeParser, _ = parser.New()
protoSchema proto.NamedStruct
protoSchema proto.NamedStruct
)

raw, err := json.Marshal(tmp["baseSchema"])
Expand Down Expand Up @@ -370,7 +369,7 @@ func TestRoundTripUsingTestData(t *testing.T) {
assert.True(t, e.Equals(e))

if typTest, ok := test["type"].(string); ok {
exp, err := typeParser.ParseString(typTest)
exp, err := parser.ParseType(typTest)
require.NoError(t, err)

assert.Equal(t, exp.String(), e.GetType().String())
Expand Down
19 changes: 9 additions & 10 deletions extensions/simple_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ type ValueArg struct {
}

func (v ValueArg) toTypeString() string {
return v.Value.Expr.(*parser.Type).ShortType()
return v.Value.ValueType.ShortString()
}

func (v ValueArg) argumentMarker() {}
Expand Down Expand Up @@ -146,7 +146,6 @@ func (a *ArgumentList) UnmarshalYAML(fn func(interface{}) error) error {
rv.Elem().Set(reflect.ValueOf(val))
return nil
})

if err != nil {
return fmt.Errorf("failure reading YAML %v", err)
}
Expand Down Expand Up @@ -200,14 +199,14 @@ type Function interface {
}

type ScalarFunctionImpl struct {
Args ArgumentList `yaml:",omitempty"`
Options map[string]Option `yaml:",omitempty"`
Variadic *VariadicBehavior `yaml:",omitempty"`
SessionDependent bool `yaml:"sessionDependent,omitempty"`
Deterministic bool `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty" default:"MIRROR"`
Return parser.TypeExpression `yaml:",omitempty"`
Implementation map[string]string `yaml:",omitempty"`
Args ArgumentList `yaml:",omitempty"`
Options map[string]Option `yaml:",omitempty"`
Variadic *VariadicBehavior `yaml:",omitempty"`
SessionDependent bool `yaml:"sessionDependent,omitempty"`
Deterministic bool `yaml:",omitempty"`
Nullability NullabilityHandling `yaml:",omitempty" default:"MIRROR"`
Return *parser.TypeExpression `yaml:",omitempty"`
Implementation map[string]string `yaml:",omitempty"`
}

func (s *ScalarFunctionImpl) signatureKey() string {
Expand Down
13 changes: 6 additions & 7 deletions extensions/simple_extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/substrait-io/substrait-go/extensions"
"github.com/substrait-io/substrait-go/proto"
"github.com/substrait-io/substrait-go/types"
"github.com/substrait-io/substrait-go/types/parser"
)

func TestUnmarshalSimpleExtension(t *testing.T) {
Expand Down Expand Up @@ -68,17 +67,17 @@ scalar_functions:
assert.Equal(t, "scalar1", f.ScalarFunctions[0].Name)
assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[0].Impls[0].Args[0])
arg1 := f.ScalarFunctions[0].Impls[0].Args[0].(extensions.ValueArg)
assert.Equal(t, "u!customtype1", arg1.Value.String())
typ, err := arg1.Value.Expr.(*parser.Type).TypeDef.RetType()
assert.Equal(t, "u!customtype1", arg1.Value.ValueType.String())
typ, err := arg1.Value.ValueType.ReturnType()
assert.NoError(t, err)
assert.IsType(t, &types.UserDefinedType{}, typ)
assert.Equal(t, proto.Type_NULLABILITY_REQUIRED, typ.GetNullability(), "expected NULLABILITY_REQUIRED")

assert.Equal(t, "scalar2", f.ScalarFunctions[1].Name)
assert.IsType(t, extensions.ValueArg{}, f.ScalarFunctions[1].Impls[0].Args[0])
ret := f.ScalarFunctions[1].Impls[0].Return
assert.Equal(t, "u!customtype2?", ret.String())
typ, err = ret.Expr.(*parser.Type).TypeDef.RetType()
assert.Equal(t, "u!customtype2?", ret.ValueType.String())
typ, err = ret.ValueType.ReturnType()
assert.NoError(t, err)
assert.IsType(t, &types.UserDefinedType{}, typ)
assert.Equal(t, proto.Type_NULLABILITY_NULLABLE, typ.GetNullability(), "expected NULLABILITY_NULLABLE")
Expand Down Expand Up @@ -113,10 +112,10 @@ scalar_functions:

x := f.ScalarFunctions[0].Impls[0].Args[0].(extensions.ValueArg)
assert.Equal(t, "x", x.Name)
assert.Equal(t, "i8", x.Value.String())
assert.Equal(t, "i8", x.Value.ValueType.String())
y := f.ScalarFunctions[0].Impls[0].Args[1].(extensions.ValueArg)
assert.Equal(t, "y", y.Name)
assert.Equal(t, "i8", y.Value.String())
assert.Equal(t, "i8", y.Value.ValueType.String())

assert.Equal(t, map[string]extensions.Option{
"overflow": {Values: []string{"SILENT", "SATURATE", "ERROR"}},
Expand Down
60 changes: 21 additions & 39 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,9 @@
}

if nullHandling == DiscreteNullability {
if t, ok := p.Value.Expr.(*parser.Type); ok {
if isNullable != t.Optional() {
return allNonNull, fmt.Errorf("%w: discrete nullability did not match for arg #%d",
substraitgo.ErrInvalidType, idx)
}
} else {
return allNonNull, substraitgo.ErrNotImplemented
if isNullable != (p.Value.ValueType.GetNullability() == types.NullabilityNullable) {
return allNonNull, fmt.Errorf("%w: discrete nullability did not match for arg #%d",
substraitgo.ErrInvalidType, idx)
}
}
case TypeArg:
Expand All @@ -76,7 +72,7 @@
return allNonNull, nil
}

func EvaluateTypeExpression(nullHandling NullabilityHandling, expr parser.TypeExpression, paramTypeList ArgumentList, variadic *VariadicBehavior, actualTypes []types.Type) (types.Type, error) {
func EvaluateTypeExpression(nullHandling NullabilityHandling, expr types.FuncDefArgType, paramTypeList ArgumentList, variadic *VariadicBehavior, actualTypes []types.Type) (types.Type, error) {
if len(paramTypeList) != len(actualTypes) {
if variadic == nil {
return nil, fmt.Errorf("%w: mismatch in number of arguments provided. got %d, expected %d",
Expand Down Expand Up @@ -111,15 +107,9 @@
}
}

var outType types.Type
if t, ok := expr.Expr.(*parser.Type); ok {
var err error
outType, err = t.RetType()
if err != nil {
return nil, err
}
} else {
return nil, substraitgo.ErrNotImplemented
outType, err := expr.ReturnType()
if err != nil {
return nil, err

Check warning on line 112 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L112

Added line #L112 was not covered by tests
}

if nullHandling == MirrorNullability || nullHandling == "" {
Expand Down Expand Up @@ -219,11 +209,7 @@
for argPos, param := range paramTypeList {
switch paramType := param.(type) {
case ValueArg:
funcDefArgType, err := paramType.Value.Expr.(*parser.Type).ArgType()
if err != nil {
return nil, err
}
out = append(out, funcDefArgType)
out = append(out, paramType.Value.ValueType)

Check warning on line 212 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L212

Added line #L212 was not covered by tests
case EnumArg:
return nil, fmt.Errorf("%w: invalid argument at position %d for match operation", substraitgo.ErrInvalidType, argPos)
case TypeArg:
Expand All @@ -242,11 +228,11 @@
}
splitArgs := strings.Split(argsStr, "_")
for _, argStr := range splitArgs {
parsed, err := defParser.ParseString(argStr)
parsed, err := parser.ParseType(argStr)

Check warning on line 231 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L231

Added line #L231 was not covered by tests
if err != nil {
panic(err)
}
exp := ValueArg{Name: name, Value: parsed}
exp := ValueArg{Name: name, Value: &parser.TypeExpression{ValueType: parsed}}

Check warning on line 235 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L235

Added line #L235 was not covered by tests
args = append(args, exp)
}

Expand Down Expand Up @@ -316,7 +302,7 @@
func (s *ScalarFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
func (s *ScalarFunctionVariant) URI() string { return s.uri }
func (s *ScalarFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes)
}
func (s *ScalarFunctionVariant) CompoundName() string {
return s.name + ":" + s.impl.signatureKey()
Expand Down Expand Up @@ -374,10 +360,6 @@
IntermediateOutputType string
}

var (
defParser, _ = parser.New()
)

func NewAggFuncVariantOpts(id ID, opts AggVariantOptions) *AggregateFunctionVariant {
var aggIntermediate parser.TypeExpression
if opts.Decomposable == "" {
Expand All @@ -389,11 +371,11 @@
substraitgo.ErrInvalidExpr, id))
}

intermediate, err := defParser.ParseString(opts.IntermediateOutputType)
intermediate, err := parser.ParseType(opts.IntermediateOutputType)

Check warning on line 374 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L374

Added line #L374 was not covered by tests
if err != nil {
panic(err)
}
aggIntermediate = *intermediate
aggIntermediate.ValueType = intermediate

Check warning on line 378 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L378

Added line #L378 was not covered by tests
}

simpleName, args := parseFuncName(id.Name)
Expand Down Expand Up @@ -432,7 +414,7 @@
func (s *AggregateFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
func (s *AggregateFunctionVariant) URI() string { return s.uri }
func (s *AggregateFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes)

Check warning on line 417 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L417

Added line #L417 was not covered by tests
}
func (s *AggregateFunctionVariant) CompoundName() string {
return s.name + ":" + s.impl.signatureKey()
Expand All @@ -442,8 +424,8 @@
}
func (s *AggregateFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable }
func (s *AggregateFunctionVariant) Intermediate() (types.FuncDefArgType, error) {
if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok {
return t.ArgType()
if s.impl.Intermediate.ValueType != nil {
return s.impl.Intermediate.ValueType, nil
}
return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType)
}
Expand Down Expand Up @@ -513,11 +495,11 @@
substraitgo.ErrInvalidExpr, id))
}

intermediate, err := defParser.ParseString(opts.IntermediateOutputType)
intermediate, err := parser.ParseType(opts.IntermediateOutputType)

Check warning on line 498 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L498

Added line #L498 was not covered by tests
if err != nil {
panic(err)
}
aggIntermediate = *intermediate
aggIntermediate.ValueType = intermediate

Check warning on line 502 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L502

Added line #L502 was not covered by tests
}

simpleName, args := parseFuncName(id.Name)
Expand Down Expand Up @@ -552,7 +534,7 @@
func (s *WindowFunctionVariant) Nullability() NullabilityHandling { return s.impl.Nullability }
func (s *WindowFunctionVariant) URI() string { return s.uri }
func (s *WindowFunctionVariant) ResolveType(argumentTypes []types.Type) (types.Type, error) {
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return, s.impl.Args, s.impl.Variadic, argumentTypes)
return EvaluateTypeExpression(s.impl.Nullability, s.impl.Return.ValueType, s.impl.Args, s.impl.Variadic, argumentTypes)

Check warning on line 537 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L537

Added line #L537 was not covered by tests
}
func (s *WindowFunctionVariant) CompoundName() string {
return s.name + ":" + s.impl.signatureKey()
Expand All @@ -562,8 +544,8 @@
}
func (s *WindowFunctionVariant) Decomposability() DecomposeType { return s.impl.Decomposable }
func (s *WindowFunctionVariant) Intermediate() (types.FuncDefArgType, error) {
if t, ok := s.impl.Intermediate.Expr.(*parser.Type); ok {
return t.ArgType()
if s.impl.Intermediate.ValueType != nil {
return s.impl.Intermediate.ValueType, nil

Check warning on line 548 in extensions/variants.go

View check run for this annotation

Codecov / codecov/patch

extensions/variants.go#L547-L548

Added lines #L547 - L548 were not covered by tests
}
return nil, fmt.Errorf("%w: bad intermediate type expression", substraitgo.ErrInvalidType)
}
Expand Down
47 changes: 22 additions & 25 deletions extensions/variants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,41 @@ import (

func TestEvaluateTypeExpression(t *testing.T) {
var (
p, _ = parser.New()
i64Null, _ = p.ParseString("i64?")
i64NonNull, _ = p.ParseString("i64")
strNull, _ = p.ParseString("string?")
i64Null, _ = parser.ParseType("i64?")
i64NonNull, _ = parser.ParseType("i64")
strNull, _ = parser.ParseType("string?")
)

tests := []struct {
name string
nulls extensions.NullabilityHandling
ret parser.TypeExpression
ret types.FuncDefArgType
extArgs extensions.ArgumentList
args []types.Type
expected types.Type
err string
}{
{"defaults", "", *i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64Null}},
{"defaults", "", i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64Null}}},
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable}},
&types.Int64Type{Nullability: types.NullabilityNullable}, ""},
{"arg mismatch", "", *strNull, extensions.ArgumentList{extensions.ValueArg{Value: strNull}},
{"arg mismatch", "", strNull, extensions.ArgumentList{extensions.ValueArg{Value: &parser.TypeExpression{ValueType: strNull}}},
[]types.Type{}, nil, "invalid expression: mismatch in number of arguments provided. got 0, expected 1"},
{"missing enum arg", "", *i64Null, extensions.ArgumentList{
extensions.ValueArg{Value: i64NonNull}, extensions.EnumArg{Name: "foo"}},
{"missing enum arg", "", i64Null, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64NonNull}}, extensions.EnumArg{Name: "foo"}},
[]types.Type{&types.Int64Type{}, &types.Int64Type{}}, nil, "invalid type: arg #1 (foo) should be an enum"},
{"discrete null handling", extensions.DiscreteNullability, *strNull, extensions.ArgumentList{
extensions.ValueArg{Value: strNull}},
{"discrete null handling", extensions.DiscreteNullability, strNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: strNull}}},
[]types.Type{&types.StringType{Nullability: types.NullabilityRequired}},
nil, "invalid type: discrete nullability did not match for arg #0"},
{"mirror", extensions.MirrorNullability, *strNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64NonNull}, extensions.ValueArg{Value: i64Null}},
{"mirror", extensions.MirrorNullability, strNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64NonNull}}, extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64Null}}},
[]types.Type{
&types.Int64Type{Nullability: types.NullabilityRequired},
&types.Int64Type{Nullability: types.NullabilityRequired}},
&types.StringType{Nullability: types.NullabilityRequired}, ""},
{"declared output", extensions.DeclaredOutputNullability, *strNull, extensions.ArgumentList{
extensions.ValueArg{Value: strNull}},
{"declared output", extensions.DeclaredOutputNullability, strNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: strNull}}},
[]types.Type{&types.StringType{Nullability: types.NullabilityRequired}},
&types.StringType{Nullability: types.NullabilityNullable}, ""},
}
Expand All @@ -70,31 +69,29 @@ func TestEvaluateTypeExpression(t *testing.T) {

func TestVariantWithVariadic(t *testing.T) {
var (
p, _ = parser.New()
i64Null, _ = p.ParseString("i64?")
i64NonNull, _ = p.ParseString("i64")
// strNull, _ = p.ParseString("string?")
i64Null, _ = parser.ParseType("i64?")
i64NonNull, _ = parser.ParseType("i64")
)

tests := []struct {
name string
nulls extensions.NullabilityHandling
ret parser.TypeExpression
ret types.FuncDefArgType
extArgs extensions.ArgumentList
args []types.Type
expected types.Type
variadic extensions.VariadicBehavior
err string
}{
{"basic", "", *i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64Null}},
{"basic", "", i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64Null}}},
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable},
&types.Int64Type{Nullability: types.NullabilityNullable}},
&types.Int64Type{Nullability: types.NullabilityNullable},
extensions.VariadicBehavior{
Min: 0, ParameterConsistency: extensions.ConsistentParams}, ""},
{"bad arg count", "", *i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: i64Null}},
{"bad arg count", "", i64NonNull, extensions.ArgumentList{
extensions.ValueArg{Value: &parser.TypeExpression{ValueType: i64Null}}},
[]types.Type{&types.Int64Type{Nullability: types.NullabilityNullable},
&types.Int64Type{Nullability: types.NullabilityNullable}},
nil, extensions.VariadicBehavior{
Expand Down
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ toolchain go1.22.3

require (
github.com/alecthomas/participle/v2 v2.0.0
github.com/antlr4-go/antlr/v4 v4.13.1
github.com/cockroachdb/apd/v3 v3.2.1
github.com/creasty/defaults v1.8.0
github.com/goccy/go-yaml v1.9.8
github.com/google/go-cmp v0.5.9
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/stretchr/testify v1.8.2
github.com/substrait-io/substrait v0.57.1
golang.org/x/exp v0.0.0-20230206171751-46f607a40771
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
google.golang.org/protobuf v1.33.0
gopkg.in/yaml.v3 v3.0.1
)
Expand Down
Loading
Loading