Skip to content

Commit

Permalink
feat: add GetFunctionRegistry in LocalFunctionRegistry (substrait-io#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran authored Jan 6, 2025
1 parent 298eedc commit 18b7daf
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 0 deletions.
1 change: 1 addition & 0 deletions functions/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func (d *dialectImpl) LocalizeFunctionRegistry(registry FunctionRegistry) (Local
allFunctions: allVariants,
idToLocalFunctionMap: makeLocalFunctionVariantsMap(allVariants),
localTypeRegistry: localTypeRegistry,
funcRegistry: registry,
}, nil
}

Expand Down
1 change: 1 addition & 0 deletions functions/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ window_functions:
localRegistry, err := dialect.LocalizeFunctionRegistry(gFunctionRegistry)
assert.NoError(t, err)
assert.Equal(t, t.Name(), localRegistry.GetDialect().Name())
assert.Equal(t, gFunctionRegistry, localRegistry.GetFunctionRegistry())
}

func TestBadDialects(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions functions/local_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type localFunctionRegistryImpl struct {

idToLocalFunctionMap map[extensions.ID]localFunctionVariant
localTypeRegistry LocalTypeRegistry
funcRegistry FunctionRegistry
}

func makeLocalFunctionVariantsMap(functions []extensions.FunctionVariant) map[extensions.ID]localFunctionVariant {
Expand All @@ -44,6 +45,10 @@ func (l *localFunctionRegistryImpl) GetDialect() Dialect {
return l.dialect
}

func (l *localFunctionRegistryImpl) GetFunctionRegistry() FunctionRegistry {
return l.funcRegistry
}

func (l *localFunctionRegistryImpl) GetScalarFunctions(name FunctionName, numArgs int) []*LocalScalarFunctionVariant {
return getFunctionVariantsByCount(getOrEmpty(name, l.scalarFunctions), numArgs)
}
Expand Down
1 change: 1 addition & 0 deletions functions/registries.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ type FunctionRegistry interface {
type LocalFunctionRegistry interface {
functionRegistryBase[FunctionName, *LocalScalarFunctionVariant, *LocalAggregateFunctionVariant, *LocalWindowFunctionVariant]
GetDialect() Dialect
GetFunctionRegistry() FunctionRegistry
GetScalarFunctionByInvocation(scalarFuncInvocation *expr.ScalarFunction) (*LocalScalarFunctionVariant, error)
GetAggregateFunctionByInvocation(aggregateFuncInvocation *expr.AggregateFunction) (*LocalAggregateFunctionVariant, error)
GetWindowFunctionByInvocation(windowFuncInvocation *expr.WindowFunction) (*LocalWindowFunctionVariant, error)
Expand Down

0 comments on commit 18b7daf

Please sign in to comment.