fix(datasets): Fix spark tests #773

Closed · wants to merge 1 commit

kedro-datasets/tests/spark/conftest.py (4 additions, 1 deletion)

@@ -27,7 +27,7 @@ def _setup_spark_session():
     ).getOrCreate()


-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(scope="module")
 def spark_session(tmp_path_factory):
     # When running these spark tests with pytest-xdist, we need to make sure
     # that the spark session setup on each test process don't interfere with each other.
@@ -40,3 +40,6 @@ def spark_session(tmp_path_factory):
     spark = _setup_spark_session()
     yield spark
     spark.stop()
+    # Ensure that the spark session is not used after it is stopped
+    # https://stackoverflow.com/a/41512072
+    spark._instantiatedContext = None
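
The spark._instantiatedContext = None workaround above (see the linked Stack Overflow answer) exists because PySpark caches the session built by getOrCreate(); depending on the PySpark version, a later getOrCreate() call can hand back the already-stopped session instead of building a fresh one, and the next test module then fails on a dead session. A minimal sketch of that failure mode follows; it is illustrative only and not code from this PR.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.stop()

# Depending on the PySpark version, the builder may return the cached, stopped
# session here rather than building a new one; clearing the cached instance (as
# conftest.py does above) forces a genuinely fresh session on the next getOrCreate().
maybe_stale = SparkSession.builder.master("local[1]").getOrCreate()
maybe_stale.createDataFrame([("Alex", 31)], ["name", "age"])  # fails if the session is stale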

kedro-datasets/tests/spark/test_deltatable_dataset.py (2 additions, 3 deletions)

@@ -7,7 +7,6 @@
 from kedro.runner import ParallelRunner
 from packaging.version import Version
 from pyspark import __version__
-from pyspark.sql import SparkSession
 from pyspark.sql.types import IntegerType, StringType, StructField, StructType
 from pyspark.sql.utils import AnalysisException

@@ -17,7 +16,7 @@


 @pytest.fixture
-def sample_spark_df():
+def sample_spark_df(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
@@ -27,7 +26,7 @@ def sample_spark_df():

     data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]

-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)
+    return spark_session.createDataFrame(data, schema)


 class TestDeltaTableDataset:
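
The same pattern repeats in the remaining test modules: code that previously called SparkSession.builder.getOrCreate() directly now declares spark_session as a fixture argument (or receives it as a parameter), so the conftest.py fixture above is the single owner of the session lifecycle under pytest-xdist. A minimal sketch of the pattern, using hypothetical fixture and test names rather than code from this PR:

import pytest
from pyspark.sql.types import IntegerType, StringType, StructField, StructType


@pytest.fixture
def people_df(spark_session):  # hypothetical fixture; spark_session comes from conftest.py
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
        ]
    )
    return spark_session.createDataFrame([("Alex", 31), ("Bob", 12)], schema)


def test_people_row_count(people_df):
    # The session is created lazily, only for tests that (transitively) request it.
    assert people_df.count() == 2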

kedro-datasets/tests/spark/test_memory_dataset.py (8 additions, 8 deletions)

@@ -1,13 +1,11 @@
 import pytest
 from kedro.io import MemoryDataset
 from pyspark.sql import DataFrame as SparkDataFrame
-from pyspark.sql import SparkSession
 from pyspark.sql.functions import col, when


-def _update_spark_df(data, idx, jdx, value):
-    session = SparkSession.builder.getOrCreate()
-    data = session.createDataFrame(data.rdd.zipWithIndex()).select(
+def _update_spark_df(spark_session, data, idx, jdx, value):
+    data = spark_session.createDataFrame(data.rdd.zipWithIndex()).select(
         col("_1.*"), col("_2").alias("__id")
     )
     cname = data.columns[idx]
@@ -34,19 +32,21 @@ def memory_dataset(spark_data_frame):
     return MemoryDataset(data=spark_data_frame)


-def test_load_modify_original_data(memory_dataset, spark_data_frame):
+def test_load_modify_original_data(spark_session, memory_dataset, spark_data_frame):
     """Check that the data set object is not updated when the original
     SparkDataFrame is changed."""
-    spark_data_frame = _update_spark_df(spark_data_frame, 1, 1, -5)
+    spark_data_frame = _update_spark_df(spark_session, spark_data_frame, 1, 1, -5)
     assert not _check_equals(memory_dataset.load(), spark_data_frame)


-def test_save_modify_original_data(spark_data_frame):
+def test_save_modify_original_data(spark_session, spark_data_frame):
     """Check that the data set object is not updated when the original
     SparkDataFrame is changed."""
     memory_dataset = MemoryDataset()
     memory_dataset.save(spark_data_frame)
-    spark_data_frame = _update_spark_df(spark_data_frame, 1, 1, "new value")
+    spark_data_frame = _update_spark_df(
+        spark_session, spark_data_frame, 1, 1, "new value"
+    )

     assert not _check_equals(memory_dataset.load(), spark_data_frame)


kedro-datasets/tests/spark/test_spark_dataset.py (2 additions, 3 deletions)

@@ -15,7 +15,6 @@
 from moto import mock_aws
 from packaging.version import Version as PackagingVersion
 from pyspark import __version__
-from pyspark.sql import SparkSession
 from pyspark.sql.functions import col
 from pyspark.sql.types import (
     FloatType,
@@ -102,7 +101,7 @@ def versioned_dataset_s3(version):


 @pytest.fixture
-def sample_spark_df():
+def sample_spark_df(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
@@ -112,7 +111,7 @@ def sample_spark_df():

     data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]

-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)
+    return spark_session.createDataFrame(data, schema)


 @pytest.fixture

kedro-datasets/tests/spark/test_spark_hive_dataset.py (38 additions, 30 deletions)

@@ -45,6 +45,9 @@ def spark_session():
             # in this module so that it always exits last and stops the spark session
             # after tests are finished.
             spark.stop()
+            # Ensure that the spark session is not used after it is stopped
+            # https://stackoverflow.com/a/41512072
+            spark._instantiatedContext = None
     except PermissionError:  # pragma: no cover
         # On Windows machine TemporaryDirectory can't be removed because some
         # files are still used by Java process.
@@ -68,7 +71,7 @@ def spark_session():
 @pytest.fixture(scope="module", autouse=True)
 def spark_test_databases(spark_session):
     """Setup spark test databases for all tests in this module."""
-    dataset = _generate_spark_df_one()
+    dataset = _generate_spark_df_one(spark_session)
     dataset.createOrReplaceTempView("tmp")
     databases = ["default_1", "default_2"]

@@ -100,37 +103,37 @@ def indexRDD(data_frame):
     )


-def _generate_spark_df_one():
+def _generate_spark_df_one(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
             StructField("age", IntegerType(), True),
         ]
     )
     data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1)
+    return spark_session.createDataFrame(data, schema).coalesce(1)


-def _generate_spark_df_upsert():
+def _generate_spark_df_upsert(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
             StructField("age", IntegerType(), True),
         ]
     )
     data = [("Alex", 99), ("Jeremy", 55)]
-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1)
+    return spark_session.createDataFrame(data, schema).coalesce(1)


-def _generate_spark_df_upsert_expected():
+def _generate_spark_df_upsert_expected(spark_session):
     schema = StructType(
         [
             StructField("name", StringType(), True),
             StructField("age", IntegerType(), True),
         ]
     )
     data = [("Alex", 99), ("Bob", 12), ("Clarke", 65), ("Dave", 29), ("Jeremy", 55)]
-    return SparkSession.builder.getOrCreate().createDataFrame(data, schema).coalesce(1)
+    return spark_session.createDataFrame(data, schema).coalesce(1)


 class TestSparkHiveDataset:
@@ -144,11 +147,11 @@ def test_cant_pickle(self):
                 )
             )

-    def test_read_existing_table(self):
+    def test_read_existing_table(self, spark_session):
         dataset = SparkHiveDataset(
             database="default_1", table="table_1", write_mode="overwrite", save_args={}
         )
-        assert_df_equal(_generate_spark_df_one(), dataset.load())
+        assert_df_equal(_generate_spark_df_one(spark_session), dataset.load())

     def test_overwrite_empty_table(self, spark_session):
         spark_session.sql(
@@ -159,8 +162,8 @@ def test_overwrite_empty_table(self, spark_session):
             table="test_overwrite_empty_table",
             write_mode="overwrite",
         )
-        dataset.save(_generate_spark_df_one())
-        assert_df_equal(dataset.load(), _generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
+        assert_df_equal(dataset.load(), _generate_spark_df_one(spark_session))

     def test_overwrite_not_empty_table(self, spark_session):
         spark_session.sql(
@@ -171,9 +174,9 @@ def test_overwrite_not_empty_table(self, spark_session):
             table="test_overwrite_full_table",
             write_mode="overwrite",
         )
-        dataset.save(_generate_spark_df_one())
-        dataset.save(_generate_spark_df_one())
-        assert_df_equal(dataset.load(), _generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
+        dataset.save(_generate_spark_df_one(spark_session))
+        assert_df_equal(dataset.load(), _generate_spark_df_one(spark_session))

     def test_insert_not_empty_table(self, spark_session):
         spark_session.sql(
@@ -184,10 +187,13 @@ def test_insert_not_empty_table(self, spark_session):
             table="test_insert_not_empty_table",
             write_mode="append",
         )
-        dataset.save(_generate_spark_df_one())
-        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
+        dataset.save(_generate_spark_df_one(spark_session))
         assert_df_equal(
-            dataset.load(), _generate_spark_df_one().union(_generate_spark_df_one())
+            dataset.load(),
+            _generate_spark_df_one(spark_session).union(
+                _generate_spark_df_one(spark_session)
+            ),
         )

     def test_upsert_config_err(self):
@@ -207,9 +213,10 @@ def test_upsert_empty_table(self, spark_session):
             write_mode="upsert",
             table_pk=["name"],
         )
-        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
         assert_df_equal(
-            dataset.load().sort("name"), _generate_spark_df_one().sort("name")
+            dataset.load().sort("name"),
+            _generate_spark_df_one(spark_session).sort("name"),
         )

     def test_upsert_not_empty_table(self, spark_session):
@@ -222,15 +229,15 @@ def test_upsert_not_empty_table(self, spark_session):
             write_mode="upsert",
             table_pk=["name"],
         )
-        dataset.save(_generate_spark_df_one())
-        dataset.save(_generate_spark_df_upsert())
+        dataset.save(_generate_spark_df_one(spark_session))
+        dataset.save(_generate_spark_df_upsert(spark_session))

         assert_df_equal(
             dataset.load().sort("name"),
-            _generate_spark_df_upsert_expected().sort("name"),
+            _generate_spark_df_upsert_expected(spark_session).sort("name"),
         )

-    def test_invalid_pk_provided(self):
+    def test_invalid_pk_provided(self, spark_session):
         _test_columns = ["column_doesnt_exist"]
         dataset = SparkHiveDataset(
             database="default_1",
@@ -245,7 +252,7 @@ def test_invalid_pk_provided(self):
                 f"not found in table default_1.table_1",
             ),
         ):
-            dataset.save(_generate_spark_df_one())
+            dataset.save(_generate_spark_df_one(spark_session))

     def test_invalid_write_mode_provided(self):
         pattern = (
@@ -277,15 +284,16 @@ def test_invalid_schema_insert(self, spark_session):
             r"Present on insert only: \[\('age', 'int'\)\]\n"
             r"Present on schema only: \[\('additional_column_on_hive', 'int'\)\]",
         ):
-            dataset.save(_generate_spark_df_one())
+            dataset.save(_generate_spark_df_one(spark_session))

-    def test_insert_to_non_existent_table(self):
+    def test_insert_to_non_existent_table(self, spark_session):
         dataset = SparkHiveDataset(
             database="default_1", table="table_not_yet_created", write_mode="append"
         )
-        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
         assert_df_equal(
-            dataset.load().sort("name"), _generate_spark_df_one().sort("name")
+            dataset.load().sort("name"),
+            _generate_spark_df_one(spark_session).sort("name"),
         )

     def test_read_from_non_existent_table(self):
@@ -300,12 +308,12 @@ def test_read_from_non_existent_table(self):
         ):
             dataset.load()

-    def test_save_delta_format(self, mocker):
+    def test_save_delta_format(self, mocker, spark_session):
         dataset = SparkHiveDataset(
             database="default_1", table="delta_table", save_args={"format": "delta"}
         )
         mocked_save = mocker.patch("pyspark.sql.DataFrameWriter.saveAsTable")
-        dataset.save(_generate_spark_df_one())
+        dataset.save(_generate_spark_df_one(spark_session))
         mocked_save.assert_called_with(
             "default_1.delta_table", mode="errorifexists", format="delta"
         )

kedro-datasets/tests/spark/test_spark_streaming_dataset.py (2 additions, 5 deletions)

@@ -6,7 +6,6 @@
 from moto import mock_aws
 from packaging.version import Version
 from pyspark import __version__
-from pyspark.sql import SparkSession
 from pyspark.sql.types import IntegerType, StringType, StructField, StructType
 from pyspark.sql.utils import AnalysisException

@@ -43,15 +42,13 @@ def sample_spark_df_schema() -> StructType:


 @pytest.fixture
-def sample_spark_streaming_df(tmp_path, sample_spark_df_schema):
+def sample_spark_streaming_df(spark_session, tmp_path, sample_spark_df_schema):
     """Create a sample dataframe for streaming"""
     data = [("0001", 2), ("0001", 7), ("0002", 4)]
     schema_path = (tmp_path / SCHEMA_FILE_NAME).as_posix()
     with open(schema_path, "w", encoding="utf-8") as f:
         json.dump(sample_spark_df_schema.jsonValue(), f)
-    return SparkSession.builder.getOrCreate().createDataFrame(
-        data, sample_spark_df_schema
-    )
+    return spark_session.createDataFrame(data, sample_spark_df_schema)


 @pytest.fixture