Skip to content

Commit

Permalink
[dagster-pandera] Update SchemaModel to DataFrameModel (#22749)
Browse files Browse the repository at this point in the history
## Summary & Motivation

Pandera 0.20.0 removed a deprecated API we were using called
`SchemaModel`. Update to the new name `DataFrameModel`, which has been
around since at least 0.15.0 judging from the release notes.

Also drop numpy<2 pin, which new pandera release correctly handles on
its own.

## How I Tested These Changes

Existing test suite.
  • Loading branch information
smackesey authored Jun 27, 2024
1 parent 684409b commit 00d2941
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# ***** TYPES ****************************************************************


class StockPrices(pa.SchemaModel):
class StockPrices(pa.DataFrameModel):
"""Open/high/low/close prices for a set of stocks by day."""

name: Series[str] = pa.Field(description="Ticker symbol of stock")
Expand All @@ -28,7 +28,7 @@ class StockPrices(pa.SchemaModel):
StockPricesDgType = pandera_schema_to_dagster_type(StockPrices)


class BollingerBands(pa.SchemaModel):
class BollingerBands(pa.DataFrameModel):
"""Bollinger bands for a set of stock prices."""

name: Series[str] = pa.Field(description="Ticker symbol of stock")
Expand All @@ -40,7 +40,7 @@ class BollingerBands(pa.SchemaModel):
BollingerBandsDgType = pandera_schema_to_dagster_type(BollingerBands)


class AnomalousEvents(pa.SchemaModel):
class AnomalousEvents(pa.DataFrameModel):
"""Anomalous price events, defined by a day on which a stock's closing price strayed above or
below its Bollinger bands.
"""
Expand Down
1 change: 1 addition & 0 deletions examples/assets_pandas_type_metadata/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"seaborn",
"pandera",
"pandas",
"pyarrow",
],
extras_require={"dev": ["dagster-webserver", "pytest"]},
)
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@


def pandera_schema_to_dagster_type(
schema: Union[pa.DataFrameSchema, Type[pa.SchemaModel]],
schema: Union[pa.DataFrameSchema, Type[pa.DataFrameModel]],
) -> DagsterType:
"""Convert a Pandera dataframe schema to a `DagsterType`.
Expand All @@ -78,24 +78,24 @@ def pandera_schema_to_dagster_type(
- `failure_sample` a table containing up to the first 10 validation errors.
Args:
schema (Union[pa.DataFrameSchema, Type[pa.SchemaModel]]):
schema (Union[pa.DataFrameSchema, Type[pa.DataFrameModel]]):
Returns:
DagsterType: Dagster Type constructed from the Pandera schema.
"""
if not (
isinstance(schema, pa.DataFrameSchema)
or (isinstance(schema, type) and issubclass(schema, pa.SchemaModel))
or (isinstance(schema, type) and issubclass(schema, pa.DataFrameModel))
):
raise TypeError(
"schema must be a pandera `DataFrameSchema` or a subclass of a pandera `SchemaModel`"
"schema must be a pandera `DataFrameSchema` or a subclass of a pandera `DataFrameModel`"
)

name = _extract_name_from_pandera_schema(schema)
norm_schema = (
schema.to_schema()
if isinstance(schema, type) and issubclass(schema, pa.SchemaModel)
if isinstance(schema, type) and issubclass(schema, pa.DataFrameModel)
else schema
)
tschema = _pandera_schema_to_table_schema(norm_schema)
Expand All @@ -117,9 +117,9 @@ def pandera_schema_to_dagster_type(


def _extract_name_from_pandera_schema(
schema: Union[pa.DataFrameSchema, Type[pa.SchemaModel]],
schema: Union[pa.DataFrameSchema, Type[pa.DataFrameModel]],
) -> str:
if isinstance(schema, type) and issubclass(schema, pa.SchemaModel):
if isinstance(schema, type) and issubclass(schema, pa.DataFrameModel):
return (
getattr(schema.Config, "title", None)
or getattr(schema.Config, "name", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Config(BaseConfig):


def sample_schema_model(**config_attrs):
class SampleSchemaModel(pa.SchemaModel):
class SampleDataframeModel(pa.DataFrameModel):
a: pa.typing.Series[int] = pa.Field(le=10, description="a desc")
b: pa.typing.Series[float] = pa.Field(lt=-1.2, description="b desc")
c: pa.typing.Series[str] = pa.Field(str_startswith="value_", description="c desc")
Expand All @@ -86,7 +86,7 @@ def a_gt_b(cls, df):

Config = make_schema_model_config(**config_attrs)

return SampleSchemaModel
return SampleDataframeModel


@pytest.fixture(
Expand Down
3 changes: 2 additions & 1 deletion python_modules/libraries/dagster-pandera/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@ def get_version() -> str:
install_requires=[
f"dagster{pin}",
"pandas",
"pandera>=0.14.2,<0.20.0",
"pandera>=0.15.0",
# Pin numpy pending release of pandera that either supports numpy 2 or adds a pin
"numpy<2",
],
extras_require={
"test": [
"pytest",
"pyarrow", # optional dep of dagster-pandera
],
},
)
2 changes: 1 addition & 1 deletion python_modules/libraries/dagster-pandera/tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ install_command = uv pip install {opts} {packages}
deps =
-e ../../dagster[test]
-e ../../dagster-pipes
-e .
-e .[test]

allowlist_externals =
/bin/bash
Expand Down

0 comments on commit 00d2941

Please sign in to comment.