[SPARK-36989][TESTS][PYTHON] Add type hints data tests #34296

Closed · wants to merge 13 commits
2 changes: 2 additions & 0 deletions .github/workflows/build_and_test.yml
@@ -442,6 +442,8 @@ jobs:
# Jinja2 3.0.0+ causes error when building with Sphinx.
# See also https://issues.apache.org/jira/browse/SPARK-35375.
python3.9 -m pip install flake8 pydata_sphinx_theme 'mypy==0.910' numpydoc 'jinja2<3.0.0' 'black==21.5b2'
# TODO Update to PyPI
python3.9 -m pip install git+https://github.com/typeddjango/pytest-mypy-plugins.git@b0020061f48e85743ee3335bd62a3a608d17c6bd
- name: Install R linter dependencies and SparkR
run: |
apt-get install -y libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev
75 changes: 61 additions & 14 deletions dev/lint-python
@@ -23,6 +23,7 @@ MINIMUM_MYPY="0.910"
MYPY_BUILD="mypy"
PYCODESTYLE_BUILD="pycodestyle"
MINIMUM_PYCODESTYLE="2.7.0"
PYTEST_BUILD="pytest"

PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}"

@@ -124,10 +125,66 @@ function pycodestyle_test {
fi
}

function mypy_test {

function mypy_annotation_test {
local MYPY_REPORT=
local MYPY_STATUS=

echo "starting mypy annotations test..."
MYPY_REPORT=$( ($MYPY_BUILD \
--config-file python/mypy.ini \
--cache-dir /tmp/.mypy_cache/ \
Member Author:
If we set --cache-dir here explicitly, we can reuse the cache for the data tests.

Alternatively, we can use $PWD as the --mypy-testing-base in the data tests, but I'd prefer to avoid that, because temporary files are written there.
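For reference, a rough sketch of that alternative, had we gone that way (illustrative only; --mypy-testing-base is the pytest-mypy-plugins option mentioned above, and the remaining flags mirror mypy_data_test below):

# Not taken: share mypy's cache by using the repo root as the testing base,
# at the cost of the plugin writing its temporary files under $PWD.
MYPYPATH=python pytest \
  --mypy-testing-base "$PWD" \
  --mypy-ini-file python/mypy.ini \
  python/pyspark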

python/pyspark) 2>&1)
MYPY_STATUS=$?

if [ "$MYPY_STATUS" -ne 0 ]; then
echo "annotations failed mypy checks:"
echo "$MYPY_REPORT"
echo "$MYPY_STATUS"
exit "$MYPY_STATUS"
else
echo "annotations passed mypy checks."
echo
fi
}


function mypy_data_test {
local PYTEST_REPORT=
local PYTEST_STATUS=

echo "starting mypy data test..."

$PYTHON_EXECUTABLE -c "import importlib.util; import sys; \
sys.exit(0 if importlib.util.find_spec('pytest_mypy_plugins') else 1)"

if [ $? -ne 0 ]; then
echo "pytest-mypy-plugins missing. Skipping for now."
return
fi

PYTEST_REPORT=$( (MYPYPATH=python $PYTEST_BUILD \
-c python/pyproject.toml \
--rootdir python \
--mypy-only-local-stub \
--mypy-ini-file python/mypy.ini \
python/pyspark ) 2>&1)

PYTEST_STATUS=$?

if [ "$PYTEST_STATUS" -ne 0 ]; then
echo "annotations failed data checks:"
echo "$PYTEST_REPORT"
echo "$PYTEST_STATUS"
exit "$PYTEST_STATUS"
else
echo "annotations passed data checks."
echo
fi
}


function mypy_test {
if ! hash "$MYPY_BUILD" 2> /dev/null; then
echo "The $MYPY_BUILD command was not found. Skipping for now."
return
@@ -142,21 +199,11 @@ function mypy_test {
return
fi

echo "starting mypy test..."
MYPY_REPORT=$( ($MYPY_BUILD --config-file python/mypy.ini python/pyspark) 2>&1)
MYPY_STATUS=$?

if [ "$MYPY_STATUS" -ne 0 ]; then
echo "mypy checks failed:"
echo "$MYPY_REPORT"
echo "$MYPY_STATUS"
exit "$MYPY_STATUS"
else
echo "mypy checks passed."
echo
fi
mypy_annotation_test
mypy_data_test
}


function flake8_test {
local FLAKE8_VERSION=
local EXPECTED_FLAKE8=
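For anyone reproducing the checks locally, a minimal sketch (assumes Python 3; the pytest-mypy-plugins revision is the one pinned in build_and_test.yml above):

# Install the type-check dependencies
python3 -m pip install 'mypy==0.910' pytest
python3 -m pip install git+https://github.com/typeddjango/pytest-mypy-plugins.git@b0020061f48e85743ee3335bd62a3a608d17c6bd

# Run the Python lint suite, including both mypy checks added here
./dev/lint-python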
25 changes: 25 additions & 0 deletions python/pyproject.toml
@@ -0,0 +1,25 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

[tool.pytest.ini_options]
# Pytest is used only to run mypy data tests
python_files = "test_*.yml"
testpaths = [
"pyspark/tests/typing",
"pyspark/sql/tests/typing",
"pyspark/ml/typing",
]
Member Author (on lines +1 to +25):
Should this config be placed in python or in dev? (We already have dev/tox.ini and python/mypy.ini.)
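Either location works mechanically; the only moving part is the -c path handed to pytest. A sketch based on the mypy_data_test invocation above (the dev/ variant is hypothetical):

# As wired up in dev/lint-python:
MYPYPATH=python pytest -c python/pyproject.toml --rootdir python \
  --mypy-only-local-stub --mypy-ini-file python/mypy.ini python/pyspark

# Hypothetical dev/ placement; only the config path changes:
MYPYPATH=python pytest -c dev/pyproject.toml --rootdir python \
  --mypy-only-local-stub --mypy-ini-file python/mypy.ini python/pyspark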

38 changes: 38 additions & 0 deletions python/pyspark/ml/tests/typing/test_classification.yml
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

- case: oneVsRest
main: |
from pyspark.ml.classification import (
OneVsRest, OneVsRestModel, LogisticRegression, LogisticRegressionModel
)

# Should support
OneVsRest(classifier=LogisticRegression())
OneVsRest(classifier=LogisticRegressionModel.load("/foo")) # E: Argument "classifier" to "OneVsRest" has incompatible type "LogisticRegressionModel"; expected "Optional[Estimator[<nothing>]]" [arg-type]
OneVsRest(classifier="foo") # E: Argument "classifier" to "OneVsRest" has incompatible type "str"; expected "Optional[Estimator[<nothing>]]" [arg-type]


- case: fitFMClassifier
main: |
from pyspark.sql import SparkSession
from pyspark.ml.classification import FMClassifier, FMClassificationModel

spark = SparkSession.builder.getOrCreate()
fm_model: FMClassificationModel = FMClassifier().fit(spark.read.parquet("/foo"))
fm_model.linear.toArray()
fm_model.factors.numRows
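(A note on the format: in these pytest-mypy-plugins cases, an inline "# E: ..." comment asserts that mypy reports exactly that error on that line, while a case with no such comments, like fitFMClassifier above, asserts that the snippet type-checks cleanly. test_feature.yml below shows the third style, an out: block matched against mypy's full output.)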
26 changes: 26 additions & 0 deletions python/pyspark/ml/tests/typing/test_evaluation.yml
@@ -0,0 +1,26 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

- case: BinaryClassificationEvaluator
main: |
from pyspark.ml.evaluation import BinaryClassificationEvaluator

BinaryClassificationEvaluator().setMetricName("areaUnderROC")
BinaryClassificationEvaluator(metricName="areaUnderPR")

BinaryClassificationEvaluator().setMetricName("foo") # E: Argument 1 to "setMetricName" of "BinaryClassificationEvaluator" has incompatible type "Literal['foo']"; expected "Union[Literal['areaUnderROC'], Literal['areaUnderPR']]" [arg-type]
BinaryClassificationEvaluator(metricName="bar") # E: Argument "metricName" to "BinaryClassificationEvaluator" has incompatible type "Literal['bar']"; expected "Union[Literal['areaUnderROC'], Literal['areaUnderPR']]" [arg-type]
44 changes: 44 additions & 0 deletions python/pyspark/ml/tests/typing/test_feature.yml
@@ -0,0 +1,44 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

- case: stringIndexerOverloads
main: |
from pyspark.ml.feature import StringIndexer

# No arguments is OK
StringIndexer()

StringIndexer(inputCol="foo")
StringIndexer(outputCol="bar")
StringIndexer(inputCol="foo", outputCol="bar")

StringIndexer(inputCols=["foo"])
StringIndexer(outputCols=["bar"])
StringIndexer(inputCols=["foo"], outputCols=["bar"])

StringIndexer(inputCol="foo", outputCols=["bar"])
StringIndexer(inputCols=["foo"], outputCol="bar")

out: |
main:14: error: No overload variant of "StringIndexer" matches argument types "str", "List[str]" [call-overload]
main:14: note: Possible overload variants:
main:14: note: def StringIndexer(self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
main:14: note: def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
main:15: error: No overload variant of "StringIndexer" matches argument types "List[str]", "str" [call-overload]
main:15: note: Possible overload variants:
main:15: note: def StringIndexer(self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
main:15: note: def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
30 changes: 30 additions & 0 deletions python/pyspark/ml/tests/typing/test_param.yml
@@ -0,0 +1,30 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

- case: paramGeneric
main: |
from pyspark.ml.param import Param, Params, TypeConverters

class Foo(Params):
foo = Param(Params(), "foo", "foo", TypeConverters.toInt)
def getFoo(self) -> int:
return self.getOrDefault(self.foo)

class Bar(Params):
bar = Param(Params(), "bar", "bar", TypeConverters.toInt)
def getBar(self) -> str:
return self.getOrDefault(self.bar) # E: Incompatible return value type (got "int", expected "str") [return-value]
28 changes: 28 additions & 0 deletions python/pyspark/ml/tests/typing/test_readable.yml
@@ -0,0 +1,28 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

- case: readLinearSVCModel
main: |
from pyspark.ml.classification import LinearSVCModel

model1 = LinearSVCModel.load("dummy")
model1.coefficients.toArray()
model1.foo() # E: "LinearSVCModel" has no attribute "foo" [attr-defined]

model2 = LinearSVCModel.read().load("dummy")
model2.coefficients.toArray()
model2.foo() # E: "LinearSVCModel" has no attribute "foo" [attr-defined]
38 changes: 38 additions & 0 deletions python/pyspark/ml/tests/typing/test_regression.yml
@@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

- case: loadFMRegressor
main: |
from pyspark.ml.regression import FMRegressor, FMRegressionModel

fm = FMRegressor.load("/foo")
fm.setMiniBatchFraction(0.1)

fm_model = FMRegressionModel.load("/bar")
fm_model.factors.numCols

fm_model.foo() # E: "FMRegressionModel" has no attribute "foo" [attr-defined]


- case: loadLinearRegressor
main: |
from pyspark.ml.regression import LinearRegressionModel

lr_model = LinearRegressionModel.load("/foo")
lr_model.getLabelCol().upper()

lr_model.foo # E: "LinearRegressionModel" has no attribute "foo" [attr-defined]
37 changes: 37 additions & 0 deletions python/pyspark/sql/tests/typing/test_column.yml
@@ -0,0 +1,37 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

- case: colDateTimeCompare
main: |
import datetime
from pyspark.sql.functions import col

today = datetime.date.today()
now = datetime.datetime.now()
a_col = col("")

a_col < today
a_col <= today
a_col == today
a_col >= today
a_col > today

a_col < now
a_col <= now
a_col == now
a_col >= now
a_col > now
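To iterate on a single case rather than the whole suite, pytest's -k filter should work against the case name (a sketch, reusing the flags from dev/lint-python):

MYPYPATH=python pytest -c python/pyproject.toml --rootdir python \
  --mypy-only-local-stub --mypy-ini-file python/mypy.ini \
  python/pyspark/sql/tests/typing/test_column.yml -k colDateTimeCompare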