This repository has been archived by the owner on Mar 11, 2024. It is now read-only.

⭐ add BigQueryPolarsIOManager (#8)
danielgafni authored Jun 25, 2023
1 parent ffc9627 commit 53225ee
Showing 7 changed files with 918 additions and 49 deletions.
13 changes: 10 additions & 3 deletions README.md
@@ -3,14 +3,15 @@
[Polars](https://github.com/pola-rs/polars) integration library for [Dagster](https://github.com/dagster-io/dagster).

## Features
- All IOManagers log various metadata about the DataFrame - size, schema, sample, stats, ...
- For all IOManagers the `"columns"` input metadata key can be used to select a subset of columns to load (see the sketch after this list)
- `BasePolarsUPathIOManager` is a base class for IO managers that work with Polars DataFrames. Shouldn't be used directly unless you want to implement your own `IOManager`.
- returns the correct type (`polars.DataFrame` or `polars.LazyFrame`) based on the type annotation
- logs various metadata about the DataFrame - size, schema, sample, stats, ...
- the "columns" input metadata value can be used to select a subset of columns
- inherits all the features of the `UPathIOManager` - works with local and remote filesystems (like S3),
supports loading multiple partitions (use `dict[str, pl.DataFrame]` type annotation), ...
- Implemented serialization formats:
- `PolarsParquetIOManager` - for reading and writing files in Apache Parquet format. Supports reading partitioned Parquet datasets (for example, often produced by Spark).
- `BigQueryPolarsIOManager` - for reading and writing data from/to [BigQuery](https://cloud.google.com/bigquery). Supports writing partitioned tables (`"partition_expr"` input metadata key must be specified).
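
For example, a downstream asset can ask the IO manager to load only some columns of an upstream DataFrame. A minimal sketch (the asset names below are illustrative):
```python
import polars as pl
from dagster import AssetIn, asset


@asset
def upstream() -> pl.DataFrame:
    return pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})


@asset(ins={"upstream": AssetIn("upstream", metadata={"columns": ["a", "b"]})})
def downstream(upstream: pl.DataFrame) -> pl.DataFrame:
    # the IO manager loads only columns "a" and "b"
    return upstream
```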

## Quickstart

@@ -20,6 +21,12 @@
pip install dagster-polars
```

To use the `BigQueryPolarsIOManager` you need to install the `gcp` extra:
```shell
pip install 'dagster-polars[gcp]'
```
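
A minimal wiring sketch for the BigQuery IO manager (the asset, dataset and environment-variable names are illustrative, and the import assumes the package's top-level export):
```python
import polars as pl
from dagster import Definitions, EnvVar, asset

from dagster_polars import BigQueryPolarsIOManager


@asset(key_prefix=["my_dataset"])  # "my_dataset" is used as the BigQuery dataset
def my_table() -> pl.DataFrame:  # the asset name is used as the table name
    return pl.DataFrame({"a": [1, 2, 3]})


defs = Definitions(
    assets=[my_table],
    resources={"io_manager": BigQueryPolarsIOManager(project=EnvVar("GCP_PROJECT"))},
)
```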


### Usage
```python
import polars as pl
@@ -61,6 +68,6 @@ poetry run pytest
```

## TODO

- [ ] Add `PolarsDeltaIOManager`
- [ ] Data validation like in [dagster-pandas](https://docs.dagster.io/integrations/pandas#validating-pandas-dataframes-with-dagster-types)
- [ ] Maybe use `DagsterTypeLoader`?
19 changes: 3 additions & 16 deletions dagster_polars/io_managers/base.py
@@ -22,6 +22,8 @@
from pydantic.fields import Field, PrivateAttr
from upath import UPath

from dagster_polars.io_managers.utils import get_polars_metadata

POLARS_DATA_FRAME_ANNOTATIONS = [
Any,
pl.DataFrame,
@@ -175,19 +177,4 @@ def load_from_path(self, path: UPath, context: InputContext) -> Union[pl.DataFra
raise NotImplementedError(f"Can't load object for type annotation {context.dagster_type.typing_type}")

def get_metadata(self, context: OutputContext, obj: pl.DataFrame) -> Dict[str, MetadataValue]:
assert context.metadata is not None
schema, table = get_metadata_table_and_schema(
context=context, df=obj, descriptions=context.metadata.get("descriptions")
)

metadata = {
"stats": MetadataValue.json(get_polars_df_stats(obj)),
"row_count": MetadataValue.int(len(obj)),
}

if table is not None:
metadata["table"] = table
else:
metadata["schema"] = schema

return metadata
return get_polars_metadata(context, obj)
161 changes: 161 additions & 0 deletions dagster_polars/io_managers/bigquery.py
@@ -0,0 +1,161 @@
from typing import Optional, Sequence, Type

import polars as pl
from dagster import InputContext, OutputContext
from dagster._core.storage.db_io_manager import DbTypeHandler, TableSlice
from dagster_gcp.bigquery.io_manager import BigQueryClient, BigQueryIOManager
from google.cloud import bigquery as bigquery

from dagster_polars.io_managers.utils import get_polars_metadata

# The code below is mostly copied from `dagster-gcp-pandas`
# with a few improvements


class BigQueryPolarsTypeHandler(DbTypeHandler[pl.DataFrame]):
"""Plugin for the BigQuery I/O Manager that can store and load Polars DataFrames as BigQuery tables.
Examples:
.. code-block:: python
from dagster_gcp import BigQueryIOManager
from dagster_polars.io_managers.bigquery import BigQueryPolarsTypeHandler
from dagster import Definitions, EnvVar, asset
class MyBigQueryIOManager(BigQueryIOManager):
@staticmethod
def type_handlers() -> Sequence[DbTypeHandler]:
return [BigQueryPolarsTypeHandler()]
@asset(
key_prefix=["my_dataset"] # my_dataset will be used as the dataset in BigQuery
)
def my_table() -> pl.DataFrame: # the name of the asset will be the table name
...
defs = Definitions(
assets=[my_table],
resources={
"io_manager": MyBigQueryIOManager(project=EnvVar("GCP_PROJECT"))
}
)
"""

def handle_output(self, context: OutputContext, table_slice: TableSlice, obj: pl.DataFrame, connection):
"""Stores the polars DataFrame in BigQuery."""
assert isinstance(connection, bigquery.Client)
assert context.metadata is not None
job_config = bigquery.LoadJobConfig(write_disposition=context.metadata.get("write_disposition"))

# FIXME: load_table_from_dataframe writes the DataFrame to a temporary Parquet file
# and then calls load_table_from_file. This can cause problems in cloud environments;
# it would be better to use load_table_from_uri with GCS,
# but that requires the remote filesystem to be available to this code.
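# A possible alternative (sketch only, not implemented here): stage the DataFrame as a
# Parquet file on GCS and load it with the client's load_table_from_uri, e.g.
#   connection.load_table_from_uri(
#       "gs://<bucket>/<staging-path>.parquet",  # hypothetical staging location
#       destination=f"{table_slice.schema}.{table_slice.table}",
#       job_config=bigquery.LoadJobConfig(source_format=bigquery.SourceFormat.PARQUET),
#   ).result()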
job = connection.load_table_from_dataframe(
dataframe=obj.to_pandas(),
destination=f"{table_slice.schema}.{table_slice.table}",
project=table_slice.database, # type: ignore
location=context.resource_config.get("location") if context.resource_config else None, # type: ignore
timeout=context.resource_config.get("timeout") if context.resource_config else None, # type: ignore
job_config=job_config,
)
job.result()

context.add_output_metadata(get_polars_metadata(context=context, df=obj))

def load_input(self, context: InputContext, table_slice: TableSlice, connection) -> pl.DataFrame:
"""Loads the input as a Polars DataFrame."""
assert isinstance(connection, bigquery.Client)

if table_slice.partition_dimensions and len(context.asset_partition_keys) == 0:
return pl.DataFrame()
result = connection.query(
query=BigQueryClient.get_select_statement(table_slice),
project=table_slice.database, # type: ignore
location=context.resource_config.get("location") if context.resource_config else None, # type: ignore
timeout=context.resource_config.get("timeout") if context.resource_config else None, # type: ignore
).to_arrow()

return pl.DataFrame(result)

@property
def supported_types(self):
return [pl.DataFrame]


class BigQueryPolarsIOManager(BigQueryIOManager):
"""An I/O manager definition that reads inputs from and writes polars DataFrames to BigQuery.
Returns:
IOManagerDefinition
Examples:
.. code-block:: python
from dagster_polars import BigQueryPolarsIOManager
from dagster import Definitions, EnvVar, asset
@asset(
key_prefix=["my_dataset"] # will be used as the dataset in BigQuery
)
def my_table() -> pl.DataFrame: # the name of the asset will be the table name
...
defs = Definitions(
assets=[my_table],
resources={
"io_manager": BigQueryPolarsIOManager(project=EnvVar("GCP_PROJECT"))
}
)
You can tell Dagster in which dataset to create tables by setting the "dataset" configuration value.
If you do not provide a dataset as configuration to the I/O manager, Dagster will determine a dataset based
on the assets and ops using the I/O Manager. For assets, the dataset will be determined from the asset key,
as shown in the above example. The final prefix before the asset name will be used as the dataset. For example,
if the asset "my_table" had the key prefix ["gcp", "bigquery", "my_dataset"], the dataset "my_dataset" will be
used. For ops, the dataset can be specified by including a "schema" entry in output metadata. If "schema" is
not provided via config or on the asset/op, "public" will be used for the dataset.
.. code-block:: python
@op(
out={"my_table": Out(metadata={"schema": "my_dataset"})}
)
def make_my_table() -> pl.DataFrame:
# the returned value will be stored at my_dataset.my_table
...
To only use specific columns of a table as input to a downstream op or asset, add the metadata "columns" to the
In or AssetIn.
.. code-block:: python
@asset(
ins={"my_table": AssetIn("my_table", metadata={"columns": ["a"]})}
)
def my_table_a(my_table: pl.DataFrame) -> pl.DataFrame:
# my_table will just contain the data from column "a"
...
If you cannot upload a file to your Dagster deployment, or otherwise cannot
`authenticate with GCP <https://cloud.google.com/docs/authentication/provide-credentials-adc>`_
via a standard method, you can provide a service account key as the "gcp_credentials" configuration.
Dagster will store this key in a temporary file and set GOOGLE_APPLICATION_CREDENTIALS to point to the file.
After the run completes, the file will be deleted, and GOOGLE_APPLICATION_CREDENTIALS will be
unset. The key must be base64 encoded to avoid issues with newlines in the keys. You can retrieve
the base64 encoded key with this shell command: cat $GOOGLE_APPLICATION_CREDENTIALS | base64
The "write_disposition" metadata key can be used to set the `write_disposition` parameter
of `bigquery.JobConfig`. For example, set it to `"WRITE_APPEND"` to append to an existing table intead of
overwriting it.
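For example, a hypothetical asset that appends to its table:
.. code-block:: python
@asset(
metadata={"write_disposition": "WRITE_APPEND"}
)
def my_table() -> pl.DataFrame:
...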
"""

@staticmethod
def type_handlers() -> Sequence[DbTypeHandler]:
return [BigQueryPolarsTypeHandler()]

@staticmethod
def default_load_type() -> Optional[Type]:
return pl.DataFrame
131 changes: 131 additions & 0 deletions dagster_polars/io_managers/utils.py
@@ -0,0 +1,131 @@
import json
import sys
from datetime import date, datetime, time, timedelta
from pprint import pformat
from typing import Any, Dict, Mapping, Optional, Tuple, Union

import polars as pl
from dagster import MetadataValue, OutputContext, TableColumn, TableMetadataValue, TableRecord, TableSchema

POLARS_DATA_FRAME_ANNOTATIONS = [
Any,
pl.DataFrame,
Dict[str, pl.DataFrame],
Mapping[str, pl.DataFrame],
type(None),
None,
]

POLARS_LAZY_FRAME_ANNOTATIONS = [
pl.LazyFrame,
Dict[str, pl.LazyFrame],
Mapping[str, pl.LazyFrame],
]


if sys.version_info >= (3, 9):
POLARS_DATA_FRAME_ANNOTATIONS.append(dict[str, pl.DataFrame]) # type: ignore
POLARS_LAZY_FRAME_ANNOTATIONS.append(dict[str, pl.LazyFrame]) # type: ignore


def cast_polars_single_value_to_dagster_table_types(val: Any):
if val is None:
return ""
elif isinstance(val, (date, datetime, time, timedelta)):
return str(val)
elif isinstance(val, (list, dict)):
# default=str because sometimes the object can be a list of datetimes or something like this
return json.dumps(val, default=str)
else:
return val


def get_metadata_schema(
df: pl.DataFrame,
descriptions: Optional[Dict[str, str]] = None,
):
descriptions = descriptions or {}
return TableSchema(
columns=[
TableColumn(name=col, type=str(pl_type), description=descriptions.get(col))
for col, pl_type in df.schema.items()
]
)


def get_metadata_table_and_schema(
context: OutputContext,
df: pl.DataFrame,
n_rows: Optional[int] = 5,
fraction: Optional[float] = None,
descriptions: Optional[Dict[str, str]] = None,
) -> Tuple[TableSchema, Optional[TableMetadataValue]]:
assert not (n_rows and fraction), "only one of n_rows and fraction should be set"
if n_rows is not None:
n_rows = min(n_rows, len(df))

schema = get_metadata_schema(df, descriptions=descriptions)

df_sample = df.sample(n=n_rows, fraction=fraction, shuffle=True)

try:
# this can fail sometimes
# because TableRecord doesn't support all python types
table = MetadataValue.table(
records=[
TableRecord(
{
col: cast_polars_single_value_to_dagster_table_types( # type: ignore
df_sample.to_dicts()[i][col]
)
for col in df.columns
}
)
for i in range(len(df_sample))
],
schema=schema,
)

except TypeError as e:
context.log.error(
f"Failed to create table sample metadata. Will only record table schema metadata. "
f"Reason:\n{e}\n"
f"Schema:\n{df.schema}\n"
f"Polars sample:\n{df_sample}\n"
f"dict sample:\n{pformat(df_sample.to_dicts())}"
)
return schema, None

return schema, table


def get_polars_df_stats(
df: pl.DataFrame,
) -> Dict[str, Dict[str, Union[str, int, float]]]:
describe = df.describe().fill_null(pl.lit("null"))
return {
col: {stat: describe[col][i] for i, stat in enumerate(describe["describe"].to_list())}
for col in describe.columns[1:]
}


def get_polars_metadata(context: OutputContext, df: pl.DataFrame) -> Dict[str, MetadataValue]:
assert context.metadata is not None
schema, table = get_metadata_table_and_schema(
context=context,
df=df,
n_rows=context.metadata.get("n_rows", 5),
fraction=context.metadata.get("fraction"),
descriptions=context.metadata.get("descriptions"),
)

metadata = {
"stats": MetadataValue.json(get_polars_df_stats(df)),
"row_count": MetadataValue.int(len(df)),
}

if table is not None:
metadata["table"] = table
else:
metadata["schema"] = schema

return metadata