diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 3a66564..21bc633 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -1,20 +1,12 @@
 name: Test
 
 on:
-  push:
-    branches:
-      - xyz
-#    paths:
-#      - ".github/workflows/test.yaml"
-#      - "pyspark_cdm/**"
-#      - "tests/**"
-#      - "poetry.lock"
-#  pull_request:
-#    paths:
-#      - ".github/workflows/test.yaml"
-#      - "pyspark_cdm/**"
-#      - "tests/**"
-#      - "poetry.lock"
+  pull_request:
+    paths:
+      - ".github/workflows/test.yaml"
+      - "pyspark_cdm/**"
+      - "tests/**"
+      - "poetry.lock"
 
 jobs:
   test:
diff --git a/pyspark_cdm/entity.py b/pyspark_cdm/entity.py
index 12a5468..0f69cf3 100644
--- a/pyspark_cdm/entity.py
+++ b/pyspark_cdm/entity.py
@@ -25,10 +25,12 @@
 from pyspark.sql import DataFrame
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 
+
 def log_attempt_number(retry_state):
     """Print a message after retrying."""
     print(f"Retrying: {retry_state.attempt_number}...")
 
+
 class Entity:
     def __init__(
         self,
@@ -140,15 +142,19 @@ def schema(self) -> StructType:
         return catalog.schema
 
     @retry(
-        stop=stop_after_attempt(2),
-        wait=wait_random_exponential(multiplier=3, max=60),
-        after=log_attempt_number,
-    )
-    def get_dataframe(self, spark) -> DataFrame:
+        stop=stop_after_attempt(2),
+        wait=wait_random_exponential(multiplier=3, max=60),
+        after=log_attempt_number,
+    )
+    def get_dataframe(
+        self,
+        spark,
+        alter_schema=lambda schema: schema,
+    ) -> DataFrame:
         return spark.read.csv(
             list(self.file_paths),
             header=False,
-            schema=self.schema,
+            schema=alter_schema(self.schema),
             inferSchema=False,
             multiLine=True,
             escape='"',
diff --git a/tests/test_pyspark_cdm/test_entity.py b/tests/test_pyspark_cdm/test_entity.py
index 2fdf411..e67bccf 100644
--- a/tests/test_pyspark_cdm/test_entity.py
+++ b/tests/test_pyspark_cdm/test_entity.py
@@ -2,8 +2,8 @@
 from cdm.persistence.modeljson.types.local_entity import LocalEntity
 import pytest
 from pyspark_cdm import Entity
-from tests.consts import MANIFEST_SAMPLE_PATH, MODEL_SAMPLE_PATH
-from pyspark.sql.types import StructType
+from tests.consts import MANIFEST_SAMPLE_PATH
+import pyspark.sql.types as T
 from pyspark.sql import DataFrame
 
 
@@ -65,7 +65,7 @@ def test_entity_schema(entity: Entity):
     """
     Make sure that the schema property correctly returns a StructType.
     """
-    assert type(entity.schema) == StructType
+    assert type(entity.schema) == T.StructType
 
 
 def test_entity_dataframe(entity: Entity, spark):
@@ -75,3 +75,15 @@ def test_entity_dataframe(entity: Entity, spark):
     df = entity.get_dataframe(spark=spark)
     assert type(df) == DataFrame
     assert df.count() == 3
+
+
+def test_entity_alter_schema(entity: Entity, spark):
+    """
+    Make sure that the alter_schema parameter correctly alters the schema of the dataframe.
+    """
+
+    def alter_schema(schema):
+        return T.StructType([T.StructField("_id", T.StringType()), *schema[1:]])
+
+    df = entity.get_dataframe(spark=spark, alter_schema=alter_schema)
+    assert df.columns[0] == "_id"
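
A note on the `alter_schema` hook that this diff adds to `pyspark_cdm/entity.py`: the callable receives the entity's CDM-derived `StructType`, and whatever `StructType` it returns is handed to `spark.read.csv`; the identity lambda is the default, so existing callers are unaffected. Below is a minimal usage sketch, not part of the diff: `rename_first_column` and `load_with_renamed_id` are hypothetical helper names, and `entity`/`spark` are assumed to be constructed the way the test fixtures do it.

```python
import pyspark.sql.types as T
from pyspark.sql import DataFrame

from pyspark_cdm import Entity


def rename_first_column(schema: T.StructType) -> T.StructType:
    # Swap the first field for a StringType column named "_id",
    # keeping the remaining fields as-is (mirrors the new test).
    return T.StructType([T.StructField("_id", T.StringType()), *schema[1:]])


def load_with_renamed_id(entity: Entity, spark) -> DataFrame:
    # alter_schema is applied to entity.schema before the CSV read,
    # so the returned dataframe's first column is "_id".
    return entity.get_dataframe(spark=spark, alter_schema=rename_first_column)
```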
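
And on the retry wiring around `get_dataframe`: with tenacity, `stop_after_attempt(2)` permits exactly one retry, `wait_random_exponential(multiplier=3, max=60)` sleeps a random, exponentially growing interval (capped at 60 seconds) between attempts, and `after=log_attempt_number` fires after each failed attempt. A self-contained sketch of that behaviour, with a hypothetical `flaky_read` standing in for the Spark call:

```python
from tenacity import retry, stop_after_attempt, wait_random_exponential


def log_attempt_number(retry_state):
    """Print a message after retrying."""
    print(f"Retrying: {retry_state.attempt_number}...")


_calls = {"count": 0}


@retry(
    stop=stop_after_attempt(2),
    wait=wait_random_exponential(multiplier=3, max=60),
    after=log_attempt_number,
)
def flaky_read() -> str:
    # Hypothetical stand-in for spark.read.csv: fails once, then succeeds,
    # so the after-callback and the single permitted retry are observable.
    _calls["count"] += 1
    if _calls["count"] == 1:
        raise IOError("transient storage error")
    return "ok"


print(flaky_read())  # "Retrying: 1..." then, after a short wait, "ok"
```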