This repository has been archived by the owner on Mar 11, 2024. It is now read-only.
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Showing 7 changed files with 918 additions and 49 deletions.
@@ -0,0 +1,161 @@
from typing import Optional, Sequence, Type

import polars as pl
from dagster import InputContext, OutputContext
from dagster._core.storage.db_io_manager import DbTypeHandler, TableSlice
from dagster_gcp.bigquery.io_manager import BigQueryClient, BigQueryIOManager
from google.cloud import bigquery

from dagster_polars.io_managers.utils import get_polars_metadata

# The code below is mostly copied from `dagster-gcp-pandas`,
# with a few improvements.

class BigQueryPolarsTypeHandler(DbTypeHandler[pl.DataFrame]):
    """Plugin for the BigQuery I/O Manager that can store and load Polars DataFrames as BigQuery tables.

    Examples:
        .. code-block:: python

            from dagster_gcp import BigQueryIOManager
            from dagster_bigquery_polars import BigQueryPolarsTypeHandler
            from dagster import Definitions, EnvVar

            class MyBigQueryIOManager(BigQueryIOManager):
                @staticmethod
                def type_handlers() -> Sequence[DbTypeHandler]:
                    return [BigQueryPolarsTypeHandler()]

            @asset(
                key_prefix=["my_dataset"]  # my_dataset will be used as the dataset in BigQuery
            )
            def my_table() -> pl.DataFrame:  # the name of the asset will be the table name
                ...

            defs = Definitions(
                assets=[my_table],
                resources={
                    "io_manager": MyBigQueryIOManager(project=EnvVar("GCP_PROJECT"))
                }
            )
    """

    def handle_output(self, context: OutputContext, table_slice: TableSlice, obj: pl.DataFrame, connection):
        """Stores the Polars DataFrame in BigQuery."""
        assert isinstance(connection, bigquery.Client)
        assert context.metadata is not None
        job_config = bigquery.LoadJobConfig(write_disposition=context.metadata.get("write_disposition"))

        # FIXME: load_table_from_dataframe writes the DataFrame to a temporary Parquet file
        # and then calls load_table_from_file. This can cause problems in cloud environments;
        # therefore, it's better to use load_table_from_uri with GCS,
        # but this requires the remote filesystem to be available in this code.
        job = connection.load_table_from_dataframe(
            dataframe=obj.to_pandas(),
            destination=f"{table_slice.schema}.{table_slice.table}",
            project=table_slice.database,  # type: ignore
            location=context.resource_config.get("location") if context.resource_config else None,  # type: ignore
            timeout=context.resource_config.get("timeout") if context.resource_config else None,  # type: ignore
            job_config=job_config,
        )
        job.result()

        context.add_output_metadata(get_polars_metadata(context=context, df=obj))

    def load_input(self, context: InputContext, table_slice: TableSlice, connection) -> pl.DataFrame:
        """Loads the input as a Polars DataFrame."""
        assert isinstance(connection, bigquery.Client)

        if table_slice.partition_dimensions and len(context.asset_partition_keys) == 0:
            return pl.DataFrame()
        result = connection.query(
            query=BigQueryClient.get_select_statement(table_slice),
            project=table_slice.database,  # type: ignore
            location=context.resource_config.get("location") if context.resource_config else None,  # type: ignore
            timeout=context.resource_config.get("timeout") if context.resource_config else None,  # type: ignore
        ).to_arrow()

        return pl.DataFrame(result)

    @property
    def supported_types(self):
        return [pl.DataFrame]


class BigQueryPolarsIOManager(BigQueryIOManager):
    """An I/O manager definition that reads inputs from and writes Polars DataFrames to BigQuery.

    Returns:
        IOManagerDefinition

    Examples:
        .. code-block:: python

            from dagster_gcp_polars import BigQueryPolarsIOManager
            from dagster import Definitions, EnvVar

            @asset(
                key_prefix=["my_dataset"]  # will be used as the dataset in BigQuery
            )
            def my_table() -> pl.DataFrame:  # the name of the asset will be the table name
                ...

            defs = Definitions(
                assets=[my_table],
                resources={
                    "io_manager": BigQueryPolarsIOManager(project=EnvVar("GCP_PROJECT"))
                }
            )

    You can tell Dagster in which dataset to create tables by setting the "dataset" configuration value.
    If you do not provide a dataset as configuration to the I/O manager, Dagster will determine a dataset based
    on the assets and ops using the I/O manager. For assets, the dataset will be determined from the asset key,
    as shown in the above example. The final prefix before the asset name will be used as the dataset. For example,
    if the asset "my_table" had the key prefix ["gcp", "bigquery", "my_dataset"], the dataset "my_dataset" will be
    used. For ops, the dataset can be specified by including a "schema" entry in output metadata. If "schema" is
    not provided via config or on the asset/op, "public" will be used for the dataset.

    .. code-block:: python

        @op(
            out={"my_table": Out(metadata={"schema": "my_dataset"})}
        )
        def make_my_table() -> pl.DataFrame:
            # the returned value will be stored at my_dataset.my_table
            ...

    To only use specific columns of a table as input to a downstream op or asset, add the metadata "columns" to the
    In or AssetIn.

    .. code-block:: python

        @asset(
            ins={"my_table": AssetIn("my_table", metadata={"columns": ["a"]})}
        )
        def my_table_a(my_table: pl.DataFrame) -> pl.DataFrame:
            # my_table will just contain the data from column "a"
            ...

    If you cannot upload a file to your Dagster deployment, or otherwise cannot
    `authenticate with GCP <https://cloud.google.com/docs/authentication/provide-credentials-adc>`_
    via a standard method, you can provide a service account key as the "gcp_credentials" configuration.
    Dagster will store this key in a temporary file and set GOOGLE_APPLICATION_CREDENTIALS to point to the file.
    After the run completes, the file will be deleted, and GOOGLE_APPLICATION_CREDENTIALS will be
    unset. The key must be base64-encoded to avoid issues with newlines in the key. You can retrieve
    the base64-encoded key with this shell command: cat $GOOGLE_APPLICATION_CREDENTIALS | base64

    The "write_disposition" metadata key can be used to set the `write_disposition` parameter
    of `bigquery.LoadJobConfig`. For example, set it to `"WRITE_APPEND"` to append to an existing table instead of
    overwriting it.
    """

    @staticmethod
    def type_handlers() -> Sequence[DbTypeHandler]:
        return [BigQueryPolarsTypeHandler()]

    @staticmethod
    def default_load_type() -> Optional[Type]:
        return pl.DataFrame
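
As a quick illustration of the "write_disposition" hook described in the docstring above, here is a hedged sketch (not part of this commit) of an asset that opts into appending. The asset name and data are invented, and it assumes the asset's definition metadata is what handle_output reads from context.metadata:

from dagster import asset
import polars as pl


@asset(
    key_prefix=["my_dataset"],
    metadata={"write_disposition": "WRITE_APPEND"},  # hypothetical: forwarded to bigquery.LoadJobConfig by handle_output
)
def daily_events() -> pl.DataFrame:
    # each materialization would append rows to my_dataset.daily_events instead of overwriting it
    return pl.DataFrame({"event": ["click", "view"], "count": [3, 5]})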
@@ -0,0 +1,131 @@
import json
import sys
from datetime import date, datetime, time, timedelta
from pprint import pformat
from typing import Any, Dict, Mapping, Optional, Tuple, Union

import polars as pl
from dagster import MetadataValue, OutputContext, TableColumn, TableMetadataValue, TableRecord, TableSchema

POLARS_DATA_FRAME_ANNOTATIONS = [
    Any,
    pl.DataFrame,
    Dict[str, pl.DataFrame],
    Mapping[str, pl.DataFrame],
    type(None),
    None,
]

POLARS_LAZY_FRAME_ANNOTATIONS = [
    pl.LazyFrame,
    Dict[str, pl.LazyFrame],
    Mapping[str, pl.LazyFrame],
]


# built-in generics like dict[str, ...] are only subscriptable on Python 3.9+
if sys.version_info >= (3, 9):
    POLARS_DATA_FRAME_ANNOTATIONS.append(dict[str, pl.DataFrame])  # type: ignore
    POLARS_LAZY_FRAME_ANNOTATIONS.append(dict[str, pl.LazyFrame])  # type: ignore


def cast_polars_single_value_to_dagster_table_types(val: Any):
    if val is None:
        return ""
    elif isinstance(val, (date, datetime, time, timedelta)):
        return str(val)
    elif isinstance(val, (list, dict)):
        # default=str because sometimes the object can be a list of datetimes or something like this
        return json.dumps(val, default=str)
    else:
        return val


def get_metadata_schema(
    df: pl.DataFrame,
    descriptions: Optional[Dict[str, str]] = None,
):
    descriptions = descriptions or {}
    return TableSchema(
        columns=[
            TableColumn(name=col, type=str(pl_type), description=descriptions.get(col))
            for col, pl_type in df.schema.items()
        ]
    )


def get_metadata_table_and_schema(
    context: OutputContext,
    df: pl.DataFrame,
    n_rows: Optional[int] = 5,
    fraction: Optional[float] = None,
    descriptions: Optional[Dict[str, str]] = None,
) -> Tuple[TableSchema, Optional[TableMetadataValue]]:
    assert not fraction and n_rows, "only one of n_rows and fraction should be set"
    n_rows = min(n_rows, len(df))

    schema = get_metadata_schema(df, descriptions=descriptions)

    df_sample = df.sample(n=n_rows, fraction=fraction, shuffle=True)

    try:
        # this can fail sometimes
        # because TableRecord doesn't support all Python types
        table = MetadataValue.table(
            records=[
                TableRecord(
                    {
                        col: cast_polars_single_value_to_dagster_table_types(  # type: ignore
                            df_sample.to_dicts()[i][col]
                        )
                        for col in df.columns
                    }
                )
                for i in range(len(df_sample))
            ],
            schema=schema,
        )

    except TypeError as e:
        context.log.error(
            f"Failed to create table sample metadata. Will only record table schema metadata. "
            f"Reason:\n{e}\n"
            f"Schema:\n{df.schema}\n"
            f"Polars sample:\n{df_sample}\n"
            f"dict sample:\n{pformat(df_sample.to_dicts())}"
        )
        return schema, None

    return schema, table


def get_polars_df_stats(
    df: pl.DataFrame,
) -> Dict[str, Dict[str, Union[str, int, float]]]:
    describe = df.describe().fill_null(pl.lit("null"))
    return {
        col: {stat: describe[col][i] for i, stat in enumerate(describe["describe"].to_list())}
        for col in describe.columns[1:]
    }


def get_polars_metadata(context: OutputContext, df: pl.DataFrame) -> Dict[str, MetadataValue]:
    assert context.metadata is not None
    schema, table = get_metadata_table_and_schema(
        context=context,
        df=df,
        n_rows=context.metadata.get("n_rows", 5),
        fraction=context.metadata.get("fraction"),
        descriptions=context.metadata.get("descriptions"),
    )

    metadata = {
        "stats": MetadataValue.json(get_polars_df_stats(df)),
        "row_count": MetadataValue.int(len(df)),
    }

    if table is not None:
        metadata["table"] = table
    else:
        metadata["schema"] = schema

    return metadata
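
For orientation, a hedged sketch of how an asset could tune what get_polars_metadata records via its definition metadata. The asset and column names are invented, and routing definition metadata into context.metadata is an assumption based on the lookups above:

import polars as pl
from dagster import asset


@asset(
    key_prefix=["my_dataset"],
    metadata={
        "n_rows": 10,  # hypothetical: sample size for the "table" metadata entry (the default above is 5)
        "descriptions": {"amount": "order total in USD"},  # hypothetical: column descriptions for the TableSchema
    },
)
def orders() -> pl.DataFrame:
    # the I/O manager attaches "stats", "row_count", and "table" (or "schema") metadata to this output
    return pl.DataFrame({"amount": [10.5, 3.0], "currency": ["USD", "USD"]})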