-
Notifications
You must be signed in to change notification settings - Fork 28.7k
[SPARK-36989][TESTS][PYTHON] Add type hints data tests #34296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
34baef1
57bef80
1c1ee25
3ea06b9
d87ff75
0a88bc7
c410d86
f2694e9
2a99c4e
77094ae
af7a34a
1ae0494
e052edc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ MINIMUM_MYPY="0.910" | |
MYPY_BUILD="mypy" | ||
PYCODESTYLE_BUILD="pycodestyle" | ||
MINIMUM_PYCODESTYLE="2.7.0" | ||
PYTEST_BUILD="pytest" | ||
|
||
PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}" | ||
|
||
|
@@ -124,10 +125,66 @@ function pycodestyle_test { | |
fi | ||
} | ||
|
||
function mypy_test { | ||
|
||
# Runs mypy over the python/pyspark sources to type-check the annotations. | ||
# Globals: MYPY_BUILD (read) - the mypy command to invoke. | ||
# Outputs: progress and result messages on stdout; on failure also the | ||
# captured mypy report and mypy's exit status. | ||
# Exits with mypy's status when the check fails; otherwise falls through. | ||
function mypy_annotation_test { | ||
local MYPY_REPORT= | ||
local MYPY_STATUS= | ||
|
||
echo "starting mypy annotations test..." | ||
# Capture stdout and stderr of mypy together (2>&1 inside the substitution). | ||
# The invocation uses python/mypy.ini and the cache dir /tmp/.mypy_cache/. | ||
MYPY_REPORT=$( ($MYPY_BUILD \ | ||
--config-file python/mypy.ini \ | ||
--cache-dir /tmp/.mypy_cache/ \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we set Alternatively, we can use |
||
python/pyspark) 2>&1) | ||
# Must directly follow the assignment: $? is the status of the substitution. | ||
MYPY_STATUS=$? | ||
|
||
# Non-zero status: print the report and the status, then exit with it. | ||
if [ "$MYPY_STATUS" -ne 0 ]; then | ||
echo "annotations failed mypy checks:" | ||
echo "$MYPY_REPORT" | ||
echo "$MYPY_STATUS" | ||
exit "$MYPY_STATUS" | ||
else | ||
echo "annotations passed mypy checks." | ||
echo | ||
fi | ||
} | ||
|
||
|
||
# Runs the mypy "data" tests through pytest against python/pyspark. | ||
# Globals: PYTHON_EXECUTABLE (read) - interpreter used for the plugin probe; | ||
# PYTEST_BUILD (read) - the pytest command to invoke. | ||
# Outputs: progress and result messages on stdout; on failure also the | ||
# captured pytest report and pytest's exit status. | ||
# Returns early (skipping the test) when the pytest_mypy_plugins module | ||
# cannot be found; exits with pytest's status when the run fails. | ||
function mypy_data_test { | ||
local PYTEST_REPORT= | ||
local PYTEST_STATUS= | ||
|
||
echo "starting mypy data test..." | ||
|
||
# Probe: the interpreter exits 0 if find_spec locates pytest_mypy_plugins, | ||
# 1 otherwise. The `$?` check below reads this command's status, so no | ||
# other command may be placed between the two. | ||
$PYTHON_EXECUTABLE -c "import importlib.util; import sys; \ | ||
sys.exit(0 if importlib.util.find_spec('pytest_mypy_plugins') else 1)" | ||
|
||
if [ $? -ne 0 ]; then | ||
echo "pytest-mypy-plugins missing. Skipping for now." | ||
return | ||
fi | ||
|
||
# Run pytest with MYPYPATH=python set for that command only, using | ||
# python/pyproject.toml as its config and python/mypy.ini for mypy; | ||
# stdout and stderr are captured together into PYTEST_REPORT. | ||
PYTEST_REPORT=$( (MYPYPATH=python $PYTEST_BUILD \ | ||
-c python/pyproject.toml \ | ||
--rootdir python \ | ||
--mypy-only-local-stub \ | ||
--mypy-ini-file python/mypy.ini \ | ||
python/pyspark ) 2>&1) | ||
|
||
# Status of the command substitution above, i.e. of the pytest run. | ||
PYTEST_STATUS=$? | ||
|
||
# Non-zero status: print the report and the status, then exit with it. | ||
if [ "$PYTEST_STATUS" -ne 0 ]; then | ||
echo "annotations failed data checks:" | ||
echo "$PYTEST_REPORT" | ||
echo "$PYTEST_STATUS" | ||
exit "$PYTEST_STATUS" | ||
else | ||
echo "annotations passed data checks." | ||
echo | ||
fi | ||
} | ||
|
||
|
||
function mypy_test { | ||
if ! hash "$MYPY_BUILD" 2> /dev/null; then | ||
echo "The $MYPY_BUILD command was not found. Skipping for now." | ||
return | ||
|
@@ -142,21 +199,11 @@ function mypy_test { | |
return | ||
fi | ||
|
||
echo "starting mypy test..." | ||
MYPY_REPORT=$( ($MYPY_BUILD --config-file python/mypy.ini python/pyspark) 2>&1) | ||
MYPY_STATUS=$? | ||
|
||
if [ "$MYPY_STATUS" -ne 0 ]; then | ||
echo "mypy checks failed:" | ||
echo "$MYPY_REPORT" | ||
echo "$MYPY_STATUS" | ||
exit "$MYPY_STATUS" | ||
else | ||
echo "mypy checks passed." | ||
echo | ||
fi | ||
mypy_annotation_test | ||
mypy_data_test | ||
} | ||
|
||
|
||
function flake8_test { | ||
local FLAKE8_VERSION= | ||
local EXPECTED_FLAKE8= | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
[tool.pytest.ini_options] | ||
# Pytest is used only to run mypy data tests | ||
python_files = "test_*.yml" | ||
testpaths = [ | ||
"pyspark/tests/typing", | ||
"pyspark/sql/tests/typing", | ||
"pyspark/ml/typing", | ||
] | ||
Comment on lines
+1
to
+25
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
- case: oneVsRest | ||
main: | | ||
from pyspark.ml.classification import ( | ||
OneVsRest, OneVsRestModel, LogisticRegression, LogisticRegressionModel | ||
) | ||
|
||
# Should support | ||
OneVsRest(classifier=LogisticRegression()) | ||
OneVsRest(classifier=LogisticRegressionModel.load("/foo")) # E: Argument "classifier" to "OneVsRest" has incompatible type "LogisticRegressionModel"; expected "Optional[Estimator[<nothing>]]" [arg-type] | ||
OneVsRest(classifier="foo") # E: Argument "classifier" to "OneVsRest" has incompatible type "str"; expected "Optional[Estimator[<nothing>]]" [arg-type] | ||
|
||
|
||
- case: fitFMClassifier | ||
main: | | ||
from pyspark.sql import SparkSession | ||
from pyspark.ml.classification import FMClassifier, FMClassificationModel | ||
|
||
spark = SparkSession.builder.getOrCreate() | ||
fm_model: FMClassificationModel = FMClassifier().fit(spark.read.parquet("/foo")) | ||
fm_model.linear.toArray() | ||
fm_model.factors.numRows |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
- case: BinaryClassificationEvaluator | ||
main: | | ||
from pyspark.ml.evaluation import BinaryClassificationEvaluator | ||
|
||
BinaryClassificationEvaluator().setMetricName("areaUnderROC") | ||
BinaryClassificationEvaluator(metricName="areaUnderPR") | ||
|
||
BinaryClassificationEvaluator().setMetricName("foo") # E: Argument 1 to "setMetricName" of "BinaryClassificationEvaluator" has incompatible type "Literal['foo']"; expected "Union[Literal['areaUnderROC'], Literal['areaUnderPR']]" [arg-type] | ||
BinaryClassificationEvaluator(metricName="bar") # E: Argument "metricName" to "BinaryClassificationEvaluator" has incompatible type "Literal['bar']"; expected "Union[Literal['areaUnderROC'], Literal['areaUnderPR']]" [arg-type] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
- case: stringIndexerOverloads | ||
main: | | ||
from pyspark.ml.feature import StringIndexer | ||
|
||
# No arguments is OK | ||
StringIndexer() | ||
|
||
StringIndexer(inputCol="foo") | ||
StringIndexer(outputCol="bar") | ||
StringIndexer(inputCol="foo", outputCol="bar") | ||
|
||
StringIndexer(inputCols=["foo"]) | ||
StringIndexer(outputCols=["bar"]) | ||
StringIndexer(inputCols=["foo"], outputCols=["bar"]) | ||
|
||
StringIndexer(inputCol="foo", outputCols=["bar"]) | ||
StringIndexer(inputCols=["foo"], outputCol="bar") | ||
|
||
out: | | ||
main:14: error: No overload variant of "StringIndexer" matches argument types "str", "List[str]" [call-overload] | ||
main:14: note: Possible overload variants: | ||
main:14: note: def StringIndexer(self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer | ||
main:14: note: def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer | ||
main:15: error: No overload variant of "StringIndexer" matches argument types "List[str]", "str" [call-overload] | ||
main:15: note: Possible overload variants: | ||
main:15: note: def StringIndexer(self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer | ||
main:15: note: def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
- case: paramGenric | ||
main: | | ||
from pyspark.ml.param import Param, Params, TypeConverters | ||
|
||
class Foo(Params): | ||
foo = Param(Params(), "foo", "foo", TypeConverters.toInt) | ||
def getFoo(self) -> int: | ||
return self.getOrDefault(self.foo) | ||
|
||
class Bar(Params): | ||
bar = Param(Params(), "bar", "bar", TypeConverters.toInt) | ||
def getFoo(self) -> str: | ||
return self.getOrDefault(self.bar) # E: Incompatible return value type (got "int", expected "str") [return-value] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
- case: readLinearSVCModel | ||
main: | | ||
from pyspark.ml.classification import LinearSVCModel | ||
|
||
model1 = LinearSVCModel.load("dummy") | ||
model1.coefficients.toArray() | ||
model1.foo() # E: "LinearSVCModel" has no attribute "foo" [attr-defined] | ||
|
||
model2 = LinearSVCModel.read().load("dummy") | ||
model2.coefficients.toArray() | ||
model2.foo() # E: "LinearSVCModel" has no attribute "foo" [attr-defined] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
- case: loadFMRegressor | ||
main: | | ||
from pyspark.ml.regression import FMRegressor, FMRegressionModel | ||
|
||
fm = FMRegressor.load("/foo") | ||
fm.setMiniBatchFraction(0.1) | ||
|
||
fm_model = FMRegressionModel.load("/bar") | ||
fm_model.factors.numCols | ||
|
||
fm_model.foo() # E: "FMRegressionModel" has no attribute "foo" [attr-defined] | ||
|
||
|
||
- case: loadLinearRegressor | ||
main: | | ||
from pyspark.ml.regression import LinearRegressionModel | ||
|
||
lr_model = LinearRegressionModel.load("/foo") | ||
lr_model.getLabelCol().upper() | ||
|
||
lr_model.foo # E: "LinearRegressionModel" has no attribute "foo" [attr-defined] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
- case: colDateTimeCompare | ||
main: | | ||
import datetime | ||
from pyspark.sql.functions import col | ||
|
||
today = datetime.date.today() | ||
now = datetime.datetime.now() | ||
a_col = col("") | ||
|
||
a_col < today | ||
a_col <= today | ||
a_col == today | ||
a_col >= today | ||
a_col > today | ||
|
||
a_col < now | ||
a_col <= now | ||
a_col == now | ||
a_col >= now | ||
a_col > now |
Uh oh!
There was an error while loading. Please reload this page.