Skip to content

Commit

Permalink
remove pyspark dep from common types (#1268)
Browse files Browse the repository at this point in the history
Signed-off-by: Niels Bantilan <[email protected]>
  • Loading branch information
cosmicBboy authored Jul 17, 2023
1 parent 7eba3a0 commit 8413a3a
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 73 deletions.
4 changes: 1 addition & 3 deletions pandera/api/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
Union,
)

import pandas as pd

from pandera import errors
from pandera.api.base.checks import BaseCheck, CheckResult
from pandera.strategies import SearchStrategy
Expand Down Expand Up @@ -200,7 +198,7 @@ def __init__(

def __call__(
self,
check_obj: Union[pd.DataFrame, pd.Series],
check_obj: Any,
column: Optional[str] = None,
) -> CheckResult:
# pylint: disable=too-many-branches
Expand Down
6 changes: 1 addition & 5 deletions pandera/api/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,7 @@ def register_check_method( # pylint:disable=too-many-branches
if supported_type not in ALLOWED_TYPES:
raise TypeError(msg.format(supported_type))

if check_type is CheckType.ELEMENT_WISE and set(supported_types) != {
pd.DataFrame,
pd.Series,
ps.DataFrame,
}: # type: ignore
if check_type is CheckType.ELEMENT_WISE and set(supported_types) != ALLOWED_TYPES: # type: ignore
raise ValueError(
"Element-wise checks should support DataFrame and Series "
"validation. Use the default setting for the 'supported_types' "
Expand Down
5 changes: 4 additions & 1 deletion pandera/api/hypotheses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Data validation checks for hypothesis testing."""

from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from pandera import errors
from pandera.api.checks import Check
Expand All @@ -9,6 +9,9 @@
DEFAULT_ALPHA = 0.01


T = TypeVar("T")


class Hypothesis(Check):
"""Special type of :class:`Check` that defines hypothesis tests on data."""

Expand Down
8 changes: 4 additions & 4 deletions pandera/backends/pandas/builtin_hypotheses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

from pandera.api.extensions import register_builtin_hypothesis
from pandera.backends.pandas.builtin_checks import PandasData
from pandera.backends.pandas.hypotheses import HAS_SCIPY

if HAS_SCIPY:
from scipy import stats


@register_builtin_hypothesis(
Expand All @@ -20,6 +16,8 @@ def two_sample_ttest(
equal_var: bool = True,
nan_policy: str = "propagate",
) -> Tuple[float, float]:
from scipy import stats # pylint: disable=import-outside-toplevel

assert (
len(samples) == 2
), "Expected two sample ttest data to contain exactly two samples"
Expand All @@ -40,6 +38,8 @@ def one_sample_ttest(
popmean: float,
nan_policy: str = "propagate",
) -> Tuple[float, float]:
from scipy import stats # pylint: disable=import-outside-toplevel

assert (
len(samples) == 1
), "Expected one sample ttest data to contain only one sample"
Expand Down
10 changes: 0 additions & 10 deletions pandera/typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,6 @@
INT16,
INT32,
INT64,
PYSPARK_BINARY,
PYSPARK_BYTEINT,
PYSPARK_DATE,
PYSPARK_DECIMAL,
PYSPARK_FLOAT,
PYSPARK_INT,
PYSPARK_LONGINT,
PYSPARK_SHORTINT,
PYSPARK_STRING,
PYSPARK_TIMESTAMP,
STRING,
UINT8,
UINT16,
Expand Down
48 changes: 2 additions & 46 deletions pandera/typing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import typing_inspect

from pandera import dtypes
from pandera.engines import numpy_engine, pandas_engine, pyspark_engine
from pandera.engines import numpy_engine, pandas_engine

Bool = dtypes.Bool #: ``"bool"`` numpy dtype
Date = dtypes.Date #: ``datetime.date`` object dtype
Expand Down Expand Up @@ -43,17 +43,7 @@
#: fall back on the str-as-object-array representation.
STRING = pandas_engine.STRING #: ``"str"`` numpy dtype
BOOL = pandas_engine.BOOL #: ``"boolean"`` pandas dtype
PYSPARK_STRING = pyspark_engine.String
PYSPARK_INT = pyspark_engine.Int
PYSPARK_LONGINT = pyspark_engine.BigInt
PYSPARK_SHORTINT = pyspark_engine.ShortInt
PYSPARK_BYTEINT = pyspark_engine.ByteInt
PYSPARK_DOUBLE = pyspark_engine.Double
PYSPARK_FLOAT = pyspark_engine.Float
PYSPARK_DECIMAL = pyspark_engine.Decimal
PYSPARK_DATE = pyspark_engine.Date
PYSPARK_TIMESTAMP = pyspark_engine.Timestamp
PYSPARK_BINARY = pyspark_engine.Binary


try:
Geometry = pandas_engine.Geometry # : ``"geometry"`` geopandas dtype
Expand Down Expand Up @@ -101,16 +91,6 @@
String,
STRING,
Geometry,
pyspark_engine.String,
pyspark_engine.Int,
pyspark_engine.BigInt,
pyspark_engine.ShortInt,
pyspark_engine.ByteInt,
pyspark_engine.Float,
pyspark_engine.Decimal,
pyspark_engine.Date,
pyspark_engine.Timestamp,
pyspark_engine.Binary,
],
)
else:
Expand Down Expand Up @@ -152,16 +132,6 @@
Object,
String,
STRING,
pyspark_engine.String,
pyspark_engine.Int,
pyspark_engine.BigInt,
pyspark_engine.ShortInt,
pyspark_engine.ByteInt,
pyspark_engine.Float,
pyspark_engine.Decimal,
pyspark_engine.Date,
pyspark_engine.Timestamp,
pyspark_engine.Binary,
],
)

Expand Down Expand Up @@ -236,20 +206,6 @@ def __get__(
raise AttributeError("Indexes should resolve to pa.Index-s")


class ColumnBase(Generic[GenericDtype]):
"""Representation of pyspark.sql.Column, only used for type annotation.
*new in 0.5.0*
"""

default_dtype: Optional[Type] = None

def __get__(
self, instance: object, owner: Type
) -> str: # pragma: no cover
raise AttributeError("column should resolve to pyspark.sql.Column-s")


class AnnotationInfo: # pylint:disable=too-few-public-methods
"""Captures extra information about an annotation.
Expand Down
15 changes: 11 additions & 4 deletions tests/core/test_extension_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@

import pytest

import pandas as pd

from pandera.api.hypotheses import Hypothesis
from pandera.backends.pandas.hypotheses import HAS_SCIPY


def test_hypotheses_module_import() -> None:
"""Test that Hypothesis built-in methods raise ImportError when scipy is not installed."""
data = pd.Series([1, 2, 3])
if not HAS_SCIPY:
for fn in [
lambda: Hypothesis.two_sample_ttest("sample1", "sample2"), # type: ignore[arg-type]
lambda: Hypothesis.one_sample_ttest(popmean=10),
for fn, check_args in [
(
lambda: Hypothesis.two_sample_ttest("sample1", "sample2"),
pd.DataFrame({"sample1": data, "sample2": data}),
),
(lambda: Hypothesis.one_sample_ttest(popmean=10), data),
]:
with pytest.raises(ImportError):
fn()
check = fn()
check(check_args)

0 comments on commit 8413a3a

Please sign in to comment.