Skip to content

Commit

Permalink
Refactor get_dataframe method to allow altering schema (#7)
Browse files Browse the repository at this point in the history
* Refactor get_dataframe method to allow altering schema

* Fixed testing pipeline
  • Loading branch information
JulesHuisman authored Feb 14, 2024
1 parent dc6949e commit 6bff211
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 23 deletions.
20 changes: 6 additions & 14 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
name: Test

on:
push:
branches:
- xyz
# paths:
# - ".github/workflows/test.yaml"
# - "pyspark_cdm/**"
# - "tests/**"
# - "poetry.lock"
# pull_request:
# paths:
# - ".github/workflows/test.yaml"
# - "pyspark_cdm/**"
# - "tests/**"
# - "poetry.lock"
pull_request:
paths:
- ".github/workflows/test.yaml"
- "pyspark_cdm/**"
- "tests/**"
- "poetry.lock"

jobs:
test:
Expand Down
18 changes: 12 additions & 6 deletions pyspark_cdm/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from pyspark.sql import DataFrame
from tenacity import retry, stop_after_attempt, wait_random_exponential


def log_attempt_number(retry_state):
    """Print a message after retrying.

    Passed to tenacity's ``retry(after=...)`` hook; receives the
    RetryCallState and reports which attempt just finished.
    """
    attempt = retry_state.attempt_number
    print("Retrying: {}...".format(attempt))


class Entity:
def __init__(
self,
Expand Down Expand Up @@ -140,15 +142,19 @@ def schema(self) -> StructType:
return catalog.schema

@retry(
stop=stop_after_attempt(2),
wait=wait_random_exponential(multiplier=3, max=60),
after=log_attempt_number,
)
def get_dataframe(self, spark) -> DataFrame:
stop=stop_after_attempt(2),
wait=wait_random_exponential(multiplier=3, max=60),
after=log_attempt_number,
)
def get_dataframe(
self,
spark,
alter_schema=lambda schema: schema,
) -> DataFrame:
return spark.read.csv(
list(self.file_paths),
header=False,
schema=self.schema,
schema=alter_schema(self.schema),
inferSchema=False,
multiLine=True,
escape='"',
Expand Down
18 changes: 15 additions & 3 deletions tests/test_pyspark_cdm/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from cdm.persistence.modeljson.types.local_entity import LocalEntity
import pytest
from pyspark_cdm import Entity
from tests.consts import MANIFEST_SAMPLE_PATH, MODEL_SAMPLE_PATH
from pyspark.sql.types import StructType
from tests.consts import MANIFEST_SAMPLE_PATH
import pyspark.sql.types as T
from pyspark.sql import DataFrame


Expand Down Expand Up @@ -65,7 +65,7 @@ def test_entity_schema(entity: Entity):
"""
Make sure that the schema property correctly returns a StructType.
"""
assert type(entity.schema) == StructType
assert type(entity.schema) == T.StructType


def test_entity_dataframe(entity: Entity, spark):
Expand All @@ -75,3 +75,15 @@ def test_entity_dataframe(entity: Entity, spark):
df = entity.get_dataframe(spark=spark)
assert type(df) == DataFrame
assert df.count() == 3


def test_entity_alter_schema(entity: Entity, spark):
    """
    Make sure that the alter_schema parameter correctly alters the schema of the dataframe.
    """

    def rename_first_field(schema):
        # Replace the first field with a string column named "_id",
        # keeping the remaining fields untouched.
        fields = [T.StructField("_id", T.StringType())]
        fields.extend(schema[1:])
        return T.StructType(fields)

    frame = entity.get_dataframe(spark=spark, alter_schema=rename_first_field)
    assert frame.columns[0] == "_id"

0 comments on commit 6bff211

Please sign in to comment.