Skip to content

Commit

Permalink
feat: update test file format to support aggregate functions (#736)
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran authored Nov 9, 2024
1 parent 9cccb04 commit c18c0c1
Show file tree
Hide file tree
Showing 14 changed files with 12,480 additions and 6,782 deletions.
7 changes: 3 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ repos:
- id: flake8
- repo: local
hooks:
- id: check-substrait-extensions
name: Check Substrait extensions
entry: pytest tests/test_extensions.py::test_read_substrait_extensions
- id: check-substrait-extensions_coverage
name: Check Substrait extensions and test coverage
entry: pytest tests/test_extensions.py::test_substrait_extension_coverage
language: python
pass_filenames: false

11 changes: 11 additions & 0 deletions grammar/FuncTestCaseLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Whitespace : [ \t\n\r]+ -> channel(HIDDEN) ;

TripleHash: '###';
SubstraitScalarTest: 'SUBSTRAIT_SCALAR_TEST';
SubstraitAggregateTest: 'SUBSTRAIT_AGGREGATE_TEST';
SubstraitInclude: 'SUBSTRAIT_INCLUDE';

FormatVersion
Expand All @@ -20,6 +21,7 @@ DescriptionLine
: '# ' ~[\r\n]* '\r'? '\n'
;

Define: 'DEFINE';
ErrorResult: '<!ERROR>';
UndefineResult: '<!UNDEFINED>';
Overflow: 'OVERFLOW';
Expand All @@ -29,6 +31,11 @@ Saturate: 'SATURATE';
Silent: 'SILENT';
TieToEven: 'TIE_TO_EVEN';
NaN: 'NAN';
AcceptNulls: 'ACCEPT_NULLS';
IgnoreNulls: 'IGNORE_NULLS';
NullHandling: 'NULL_HANDLING';
SpacesOnly: 'SPACES_ONLY';
Truncate: 'TRUNCATE';

IntegerLiteral
: [+-]? Int
Expand Down Expand Up @@ -102,3 +109,7 @@ NullLiteral: 'null';
StringLiteral
: '\'' ('\\' . | '\'\'' | ~['\\])* '\''
;

ColumnName
: 'COL' Int
;
187 changes: 150 additions & 37 deletions grammar/FuncTestCaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ header
;

version
: TripleHash SubstraitScalarTest Colon FormatVersion
: TripleHash (SubstraitScalarTest | SubstraitAggregateTest) Colon FormatVersion
;

include
Expand All @@ -27,11 +27,12 @@ testGroupDescription
;

testCase
: functionName=Identifier OParen arguments CParen ( OBracket func_options CBracket )? Eq result
: functionName=identifier OParen arguments CParen ( OBracket func_options CBracket )? Eq result
;

testGroup
: testGroupDescription (testCase)+
: testGroupDescription (testCase)+ #scalarFuncTestGroup
| testGroupDescription (aggFuncTestCase)+ #aggregateFuncTestGroup
;

arguments
Expand All @@ -56,6 +57,64 @@ argument
| timestampTzArg
| intervalYearArg
| intervalDayArg
| listArg
;

aggFuncTestCase
: aggFuncCall ( OBracket func_options CBracket )? Eq result
;

aggFuncCall
: tableData funcName=identifier OParen qualifiedAggregateFuncArgs CParen #multiArgAggregateFuncCall
| tableRows functName=identifier OParen aggregateFuncArgs CParen #compactAggregateFuncCall
| functName=identifier OParen dataColumn CParen #singleArgAggregateFuncCall
;

tableData
: Define tableName=Identifier OParen dataType (Comma dataType)* CParen Eq tableRows
;

tableRows
: OParen (columnValues (Comma columnValues)*)? CParen
;

dataColumn
: columnValues DoubleColon dataType
;

columnValues
: OParen (literal (Comma literal)*)? CParen
;

literal
: NullLiteral
| numericLiteral
| BooleanLiteral
| StringLiteral
| DateLiteral
| TimeLiteral
| TimestampLiteral
| TimestampTzLiteral
| IntervalYearLiteral
| IntervalDayLiteral
;

qualifiedAggregateFuncArgs
: qualifiedAggregateFuncArg (Comma qualifiedAggregateFuncArg)*
;

aggregateFuncArgs
: aggregateFuncArg (Comma aggregateFuncArg)*
;

qualifiedAggregateFuncArg
: tableName=Identifier Dot ColumnName
| argument
;

aggregateFuncArg
: ColumnName DoubleColon dataType
| argument
;

numericLiteral
Expand All @@ -66,7 +125,7 @@ floatLiteral
: FloatLiteral | NaN
;

nullArg: NullLiteral DoubleColon datatype;
nullArg: NullLiteral DoubleColon dataType;

intArg: IntegerLiteral DoubleColon (I8 | I16 | I32 | I64);

Expand All @@ -77,11 +136,11 @@ decimalArg
;

booleanArg
: BooleanLiteral DoubleColon Bool
: BooleanLiteral DoubleColon booleanType
;

stringArg
: StringLiteral DoubleColon Str
: StringLiteral DoubleColon stringType
;

dateArg
Expand All @@ -93,19 +152,27 @@ timeArg
;

timestampArg
: TimestampLiteral DoubleColon Ts
: TimestampLiteral DoubleColon timestampType
;

timestampTzArg
: TimestampTzLiteral DoubleColon TsTZ
: TimestampTzLiteral DoubleColon timestampTZType
;

intervalYearArg
: IntervalYearLiteral DoubleColon IYear
: IntervalYearLiteral DoubleColon intervalYearType
;

intervalDayArg
: IntervalDayLiteral DoubleColon IDay
: IntervalDayLiteral DoubleColon intervalDayType
;

listArg
: literalList DoubleColon listType
;

literalList
: OBracket (literal (Comma literal)*)? CBracket
;

intervalYearLiteral
Expand All @@ -126,53 +193,88 @@ timeInterval
| fractionalSeconds=IntegerLiteral FractionalSecondSuffix
;

datatype
dataType
: scalarType
| parameterizedType
;

scalarType
: Bool #Boolean
| I8 #i8
| I16 #i16
| I32 #i32
| I64 #i64
| FP32 #fp32
| FP64 #fp64
| Str #string
| Binary #binary
| Ts #timestamp
| TsTZ #timestampTz
| Date #date
| Time #time
| IDay #intervalDay
| IYear #intervalYear
| UUID #uuid
| UserDefined Identifier #userDefined
: booleanType #boolean
| I8 #i8
| I16 #i16
| I32 #i32
| I64 #i64
| FP32 #fp32
| FP64 #fp64
| stringType #string
| binaryType #binary
| timestampType #timestamp
| timestampTZType #timestampTz
| Date #date
| Time #time
| intervalDayType #intervalDay
| intervalYearType #intervalYear
| UUID #uuid
| UserDefined Identifier #userDefined
;

booleanType
: (Bool | Boolean)
;

stringType
: (Str | String)
;

binaryType
: (Binary | VBin)
;

timestampType
: (Ts | Timestamp)
;

timestampTZType
: (TsTZ | Timestamp_TZ)
;

intervalYearType
: (IYear | Interval_Year)
;

intervalDayType
: (IDay | Interval_Day)
;

fixedCharType
: FChar isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #fixedChar
: (FChar | FixedChar) isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #fixedChar
;

varCharType
: VChar isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #varChar
: (VChar | VarChar) isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #varChar
;

fixedBinaryType
: FBin isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #fixedBinary
: (FBin | FixedBinary) isnull=QMark? OAngleBracket len=numericParameter CAngleBracket #fixedBinary
;

decimalType
: Dec isnull=QMark? (OAngleBracket precision=numericParameter Comma scale=numericParameter CAngleBracket)? #decimal
: (Dec | Decimal) isnull=QMark?
(OAngleBracket precision=numericParameter Comma scale=numericParameter CAngleBracket)? #decimal
;

precisionTimestampType
: PTs isnull=QMark? OAngleBracket precision=numericParameter CAngleBracket #precisionTimestamp
: (PTs | Precision_Timestamp) isnull=QMark?
OAngleBracket precision=numericParameter CAngleBracket #precisionTimestamp
;

precisionTimestampTZType
: PTsTZ isnull=QMark? OAngleBracket precision=numericParameter CAngleBracket #precisionTimestampTZ
: (PTsTZ | Precision_Timestamp_TZ) isnull=QMark?
OAngleBracket precision=numericParameter CAngleBracket #precisionTimestampTZ
;

listType
: List isnull=QMark? OAngleBracket elemType=dataType CAngleBracket #list
;

parameterizedType
Expand All @@ -185,7 +287,6 @@ parameterizedType
// TODO implement the rest of the parameterized types
// | Struct isnull='?'? Lt expr (Comma expr)* Gt #struct
// | NStruct isnull='?'? Lt Identifier expr (Comma Identifier expr)* Gt #nStruct
// | List isnull='?'? Lt expr Gt #list
// | Map isnull='?'? Lt key=expr Comma value=expr Gt #map
;

Expand All @@ -202,14 +303,26 @@ func_option
;

option_name
: Overflow | Rounding
: Overflow | Rounding | NullHandling | SpacesOnly
| Identifier
;

option_value
: Error | Saturate | Silent | TieToEven | NaN
: Error | Saturate | Silent | TieToEven | NaN | Truncate | AcceptNulls | IgnoreNulls
| BooleanLiteral
| NullLiteral
| Identifier
;

func_options
: func_option (Comma func_option)*
;

nonReserved // IMPORTANT: this rule must only contain tokens
: And | Or | Truncate
;

identifier
: nonReserved
| Identifier
;
18 changes: 18 additions & 0 deletions tests/cases/arithmetic/max.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
### SUBSTRAIT_AGGREGATE_TEST: v1.0
### SUBSTRAIT_INCLUDE: '/extensions/functions_arithmetic.yaml'

# basic: Basic examples without any special cases
max((20, -3, 1, -10, 0, 5)::i8) = 20::i8
max((-32768, 32767, 20000, -30000)::i16) = 32767::i16
max((-214748648, 214748647, 21470048, 4000000)::i32) = 214748647::i32
max((2000000000, -3217908979, 629000000, -100000000, 0, 987654321)::i64) = 2000000000::i64
max((2.5, 0, 5.0, -2.5, -7.5)::fp32) = 5.0::fp32
max((1.5e+308, 1.5e+10, -1.5e+8, -1.5e+7, -1.5e+70)::fp64) = 1.5e+308::fp64

# null_handling: Examples with null as input or output
max((Null, Null, Null)::i16) = Null::i16
max(()::i16) = Null::i16
max((2000000000, Null, 629000000, -100000000, Null, 987654321)::i64) = 2000000000::i64
max((Null, inf)::fp64) = inf::fp64
max((Null, -inf, -1.5e+8, -1.5e+7, -1.5e+70)::fp64) = -1.5e+7::fp64
max((1.5e+308, 1.5e+10, Null, -1.5e+7, Null)::fp64) = 1.5e+308::fp64
Loading

0 comments on commit c18c0c1

Please sign in to comment.