From ac906ca50b7631bfdf5a6689d9342f562a989663 Mon Sep 17 00:00:00 2001 From: Vincent Chen <62143443+mao3267@users.noreply.github.com> Date: Wed, 12 Feb 2025 15:13:18 +0800 Subject: [PATCH] test: Structured dataset pickleable (#3121) Signed-off-by: mao3267 --- .../test_structured_dataset.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py index ee535697a3..770d757eba 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py @@ -19,7 +19,7 @@ from flytekit.core.workflow import workflow from flytekit.lazy_import.lazy_module import is_imported from flytekit.models import literals -from flytekit.models.literals import StructuredDatasetMetadata +from flytekit.models.literals import StructuredDatasetMetadata, Literal from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType from flytekit.tools.translator import get_serializable from flytekit.types.structured.structured_dataset import ( @@ -713,3 +713,41 @@ def mock_resolve_remote_path(flyte_uri: str) -> typing.Optional[str]: lit = sdte.encode(ctx, sd, df_type=pd.DataFrame, protocol="bq", format="parquet", structured_literal_type=lt) assert lit.scalar.structured_dataset.uri == "bq://blah/blah/blah" + +def test_structured_dataset_pickleable(): + import pickle + + upstream_output = Literal( + scalar=literals.Scalar( + structured_dataset=StructuredDataset( + dataframe=pd.DataFrame({"a": [1, 2], "b": [3, 4]}), + uri="bq://test_uri", + metadata=StructuredDatasetMetadata( + structured_dataset_type=StructuredDatasetType( + columns=[ + StructuredDatasetType.DatasetColumn( + name="a", + literal_type=LiteralType(simple=SimpleType.INTEGER) + ), + StructuredDatasetType.DatasetColumn( + name="b", + literal_type=LiteralType(simple=SimpleType.INTEGER) + ) + ], + format="parquet" + ) + ) + ) + ) + ) + + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), + upstream_output, + StructuredDataset + ) + + pickled_input = pickle.dumps(downstream_input) + unpickled_input = pickle.loads(pickled_input) + + assert downstream_input == unpickled_input