From 07aef5a30324daeb2075ea01f702e2b188b4ab49 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa <49385643+MinuraPunchihewa@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:23:18 +0530 Subject: [PATCH] feat(datasets): Improved Dependency Management for Spark-based Datasets (#911) * added the skeleton for the utils sub pkg Signed-off-by: Minura Punchihewa * moved the utility funcs from spark_dataset to relevant modules in _utils Signed-off-by: Minura Punchihewa * updated the use of utility funcs in spark_dataset Signed-off-by: Minura Punchihewa * fixed import in databricks_utils Signed-off-by: Minura Punchihewa * renamed _strip_dbfs_prefix to strip_dbfs_prefix Signed-off-by: Minura Punchihewa * updated the other modules that import from spark_dataset to use _utils Signed-off-by: Minura Punchihewa * updated the use of strip_dbfs_prefix in spark_dataset Signed-off-by: Minura Punchihewa * fixed lint issues Signed-off-by: Minura Punchihewa * removed the base deps for spark, pandas and delta from databricks datasets Signed-off-by: Minura Punchihewa * moved the file based utility funcs to databricks_utils Signed-off-by: Minura Punchihewa * fixed the imports of the file based utility funcs Signed-off-by: Minura Punchihewa * fixed lint issues Signed-off-by: Minura Punchihewa * fixed the use of _get_spark() in tests Signed-off-by: Minura Punchihewa * fixed uses of databricks utils in tests Signed-off-by: Minura Punchihewa * fixed more tests Signed-off-by: Minura Punchihewa * fixed more lint issues Signed-off-by: Minura Punchihewa * fixed more tests Signed-off-by: Minura Punchihewa * fixed more tests Signed-off-by: Minura Punchihewa * improved type hints for spark & databricks utility funcs Signed-off-by: Minura Punchihewa * fixed more lint issues Signed-off-by: Minura Punchihewa * further improved type hints for utility funcs Signed-off-by: Minura Punchihewa * fixed a couple of incorrect type hints Signed-off-by: Minura Punchihewa * fixed several incorrect type hints Signed-off-by: Minura Punchihewa * updated the release notes Signed-off-by: Minura Punchihewa * Reorder release notes Signed-off-by: Merel Theisen --------- Signed-off-by: Minura Punchihewa Signed-off-by: Merel Theisen Co-authored-by: Nok Lam Chan Co-authored-by: Merel Theisen <49397448+merelcht@users.noreply.github.com> Co-authored-by: Merel Theisen --- kedro-datasets/RELEASE.md | 2 +- .../kedro_datasets/_utils/__init__.py | 0 .../kedro_datasets/_utils/databricks_utils.py | 105 ++++++++++++ .../kedro_datasets/_utils/spark_utils.py | 29 ++++ .../databricks/_base_table_dataset.py | 22 +-- .../spark/deltatable_dataset.py | 15 +- .../kedro_datasets/spark/spark_dataset.py | 154 +++--------------- .../spark/spark_hive_dataset.py | 6 +- .../spark/spark_jdbc_dataset.py | 4 +- .../spark/spark_streaming_dataset.py | 21 +-- kedro-datasets/pyproject.toml | 4 +- .../tests/spark/test_deltatable_dataset.py | 4 +- .../tests/spark/test_spark_dataset.py | 44 ++--- .../tests/spark/test_spark_jdbc_dataset.py | 6 +- .../spark/test_spark_streaming_dataset.py | 4 +- 15 files changed, 220 insertions(+), 200 deletions(-) create mode 100644 kedro-datasets/kedro_datasets/_utils/__init__.py create mode 100644 kedro-datasets/kedro_datasets/_utils/databricks_utils.py create mode 100644 kedro-datasets/kedro_datasets/_utils/spark_utils.py diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index daff41362..482b3c76f 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -1,7 +1,6 @@ # Upcoming Release 6.0.0 ## Major features and improvements - - Added functionality to save Pandas DataFrame directly to Snowflake, facilitating seemless `.csv` ingestion - Added Python 3.9, 3.10 and 3.11 support for SnowflakeTableDataset - Added the following new **experimental** datasets: @@ -13,6 +12,7 @@ ## Bug fixes and other changes - Implemented Snowflake's (local testing framework)[https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally] for testing purposes +- Improved the dependency management for Spark-based datasets by refactoring the Spark and Databricks utility functions used across the datasets. ## Breaking Changes - Demoted `video.VideoDataset` from core to experimental dataset. diff --git a/kedro-datasets/kedro_datasets/_utils/__init__.py b/kedro-datasets/kedro_datasets/_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/kedro_datasets/_utils/databricks_utils.py b/kedro-datasets/kedro_datasets/_utils/databricks_utils.py new file mode 100644 index 000000000..858b5e5e0 --- /dev/null +++ b/kedro-datasets/kedro_datasets/_utils/databricks_utils.py @@ -0,0 +1,105 @@ +import os +from fnmatch import fnmatch +from pathlib import PurePosixPath +from typing import TYPE_CHECKING, Union + +from pyspark.sql import SparkSession + +if TYPE_CHECKING: + from databricks.connect import DatabricksSession + from pyspark.dbutils import DBUtils + + +def parse_glob_pattern(pattern: str) -> str: + special = ("*", "?", "[") + clean = [] + for part in pattern.split("/"): + if any(char in part for char in special): + break + clean.append(part) + return "/".join(clean) + + +def split_filepath(filepath: str | os.PathLike) -> tuple[str, str]: + split_ = str(filepath).split("://", 1) + if len(split_) == 2: # noqa: PLR2004 + return split_[0] + "://", split_[1] + return "", split_[0] + + +def strip_dbfs_prefix(path: str, prefix: str = "/dbfs") -> str: + return path[len(prefix) :] if path.startswith(prefix) else path + + +def dbfs_glob(pattern: str, dbutils: "DBUtils") -> list[str]: + """Perform a custom glob search in DBFS using the provided pattern. + It is assumed that version paths are managed by Kedro only. + + Args: + pattern: Glob pattern to search for. + dbutils: dbutils instance to operate with DBFS. + + Returns: + List of DBFS paths prefixed with '/dbfs' that satisfy the glob pattern. + """ + pattern = strip_dbfs_prefix(pattern) + prefix = parse_glob_pattern(pattern) + matched = set() + filename = pattern.split("/")[-1] + + for file_info in dbutils.fs.ls(prefix): + if file_info.isDir(): + path = str( + PurePosixPath(strip_dbfs_prefix(file_info.path, "dbfs:")) / filename + ) + if fnmatch(path, pattern): + path = "/dbfs" + path + matched.add(path) + return sorted(matched) + + +def get_dbutils(spark: Union[SparkSession, "DatabricksSession"]) -> "DBUtils": + """Get the instance of 'dbutils' or None if the one could not be found.""" + dbutils = globals().get("dbutils") + if dbutils: + return dbutils + + try: + from pyspark.dbutils import DBUtils + + dbutils = DBUtils(spark) + except ImportError: + try: + import IPython + except ImportError: + pass + else: + ipython = IPython.get_ipython() + dbutils = ipython.user_ns.get("dbutils") if ipython else None + + return dbutils + + +def dbfs_exists(pattern: str, dbutils: "DBUtils") -> bool: + """Perform an `ls` list operation in DBFS using the provided pattern. + It is assumed that version paths are managed by Kedro. + Broad `Exception` is present due to `dbutils.fs.ExecutionError` that + cannot be imported directly. + Args: + pattern: Filepath to search for. + dbutils: dbutils instance to operate with DBFS. + Returns: + Boolean value if filepath exists. + """ + pattern = strip_dbfs_prefix(pattern) + file = parse_glob_pattern(pattern) + try: + dbutils.fs.ls(file) + return True + except Exception: + return False + + +def deployed_on_databricks() -> bool: + """Check if running on Databricks.""" + return "DATABRICKS_RUNTIME_VERSION" in os.environ diff --git a/kedro-datasets/kedro_datasets/_utils/spark_utils.py b/kedro-datasets/kedro_datasets/_utils/spark_utils.py new file mode 100644 index 000000000..e55012275 --- /dev/null +++ b/kedro-datasets/kedro_datasets/_utils/spark_utils.py @@ -0,0 +1,29 @@ +from typing import TYPE_CHECKING, Union + +from pyspark.sql import SparkSession + +if TYPE_CHECKING: + from databricks.connect import DatabricksSession + + +def get_spark() -> Union[SparkSession, "DatabricksSession"]: + """ + Returns the SparkSession. In case databricks-connect is available we use it for + extended configuration mechanisms and notebook compatibility, + otherwise we use classic pyspark. + """ + try: + # When using databricks-connect >= 13.0.0 (a.k.a databricks-connect-v2) + # the remote session is instantiated using the databricks module + # If the databricks-connect module is installed, we use a remote session + from databricks.connect import DatabricksSession + + # We can't test this as there's no Databricks test env available + spark = DatabricksSession.builder.getOrCreate() # pragma: no cover + + except ImportError: + # For "normal" spark sessions that don't use databricks-connect + # we get spark normally + spark = SparkSession.builder.getOrCreate() + + return spark diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 4cbb2ea37..95fba67a7 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -19,7 +19,7 @@ from pyspark.sql.types import StructType from pyspark.sql.utils import AnalysisException, ParseException -from kedro_datasets.spark.spark_dataset import _get_spark +from kedro_datasets._utils.spark_utils import get_spark logger = logging.getLogger(__name__) pd.DataFrame.iteritems = pd.DataFrame.items @@ -183,7 +183,7 @@ def exists(self) -> bool: """ if self.catalog: try: - _get_spark().sql(f"USE CATALOG `{self.catalog}`") + get_spark().sql(f"USE CATALOG `{self.catalog}`") except (ParseException, AnalysisException) as exc: logger.warning( "catalog %s not found or unity not enabled. Error message: %s", @@ -192,7 +192,7 @@ def exists(self) -> bool: ) try: return ( - _get_spark() + get_spark() .sql(f"SHOW TABLES IN `{self.database}`") .filter(f"tableName = '{self.table}'") .count() @@ -359,7 +359,7 @@ def _load(self) -> DataFrame | pd.DataFrame: if self._version and self._version.load >= 0: try: data = ( - _get_spark() + get_spark() .read.format("delta") .option("versionAsOf", self._version.load) .table(self._table.full_table_location()) @@ -367,7 +367,7 @@ def _load(self) -> DataFrame | pd.DataFrame: except Exception as exc: raise VersionNotFoundError(self._version.load) from exc else: - data = _get_spark().table(self._table.full_table_location()) + data = get_spark().table(self._table.full_table_location()) if self._table.dataframe_type == "pandas": data = data.toPandas() return data @@ -391,13 +391,13 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None: if schema: cols = schema.fieldNames() if self._table.dataframe_type == "pandas": - data = _get_spark().createDataFrame( + data = get_spark().createDataFrame( data.loc[:, cols], schema=self._table.schema() ) else: data = data.select(*cols) elif self._table.dataframe_type == "pandas": - data = _get_spark().createDataFrame(data) + data = get_spark().createDataFrame(data) method = getattr(self, f"_save_{self._table.write_mode}", None) if method: @@ -456,7 +456,7 @@ def _save_upsert(self, update_data: DataFrame) -> None: update_data (DataFrame): The Spark dataframe to upsert. """ if self._exists(): - base_data = _get_spark().table(self._table.full_table_location()) + base_data = get_spark().table(self._table.full_table_location()) base_columns = base_data.columns update_columns = update_data.columns @@ -479,11 +479,11 @@ def _save_upsert(self, update_data: DataFrame) -> None: ) update_data.createOrReplaceTempView("update") - _get_spark().conf.set("fullTableAddress", self._table.full_table_location()) - _get_spark().conf.set("whereExpr", where_expr) + get_spark().conf.set("fullTableAddress", self._table.full_table_location()) + get_spark().conf.set("whereExpr", where_expr) upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr} WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *""" - _get_spark().sql(upsert_sql) + get_spark().sql(upsert_sql) else: self._save_append(update_data) diff --git a/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py b/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py index efe693efa..5b5690912 100644 --- a/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py @@ -10,11 +10,8 @@ from kedro.io.core import AbstractDataset, DatasetError from pyspark.sql.utils import AnalysisException -from kedro_datasets.spark.spark_dataset import ( - _get_spark, - _split_filepath, - _strip_dbfs_prefix, -) +from kedro_datasets._utils.databricks_utils import split_filepath, strip_dbfs_prefix +from kedro_datasets._utils.spark_utils import get_spark class DeltaTableDataset(AbstractDataset[None, DeltaTable]): @@ -81,7 +78,7 @@ def __init__( metadata: Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins. """ - fs_prefix, filepath = _split_filepath(filepath) + fs_prefix, filepath = split_filepath(filepath) self._fs_prefix = fs_prefix self._filepath = PurePosixPath(filepath) @@ -89,16 +86,16 @@ def __init__( def load(self) -> DeltaTable: load_path = self._fs_prefix + str(self._filepath) - return DeltaTable.forPath(_get_spark(), load_path) + return DeltaTable.forPath(get_spark(), load_path) def save(self, data: None) -> NoReturn: raise DatasetError(f"{self.__class__.__name__} is a read only dataset type") def _exists(self) -> bool: - load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) + load_path = strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) try: - _get_spark().read.load(path=load_path, format="delta") + get_spark().read.load(path=load_path, format="delta") except AnalysisException as exception: # `AnalysisException.desc` is deprecated with pyspark >= 3.4 message = exception.desc if hasattr(exception, "desc") else str(exception) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py index 6d96c440a..ef24aa7f0 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py @@ -6,7 +6,6 @@ import json import logging -import os from copy import deepcopy from fnmatch import fnmatch from functools import partial @@ -24,130 +23,23 @@ get_filepath_str, get_protocol_and_path, ) -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import DataFrame from pyspark.sql.types import StructType from pyspark.sql.utils import AnalysisException from s3fs import S3FileSystem -logger = logging.getLogger(__name__) - - -def _get_spark() -> Any: - """ - Returns the SparkSession. In case databricks-connect is available we use it for - extended configuration mechanisms and notebook compatibility, - otherwise we use classic pyspark. - """ - try: - # When using databricks-connect >= 13.0.0 (a.k.a databricks-connect-v2) - # the remote session is instantiated using the databricks module - # If the databricks-connect module is installed, we use a remote session - from databricks.connect import DatabricksSession - - # We can't test this as there's no Databricks test env available - spark = DatabricksSession.builder.getOrCreate() # pragma: no cover - - except ImportError: - # For "normal" spark sessions that don't use databricks-connect - # we get spark normally - spark = SparkSession.builder.getOrCreate() - - return spark - - -def _parse_glob_pattern(pattern: str) -> str: - special = ("*", "?", "[") - clean = [] - for part in pattern.split("/"): - if any(char in part for char in special): - break - clean.append(part) - return "/".join(clean) - - -def _split_filepath(filepath: str | os.PathLike) -> tuple[str, str]: - split_ = str(filepath).split("://", 1) - if len(split_) == 2: # noqa: PLR2004 - return split_[0] + "://", split_[1] - return "", split_[0] - - -def _strip_dbfs_prefix(path: str, prefix: str = "/dbfs") -> str: - return path[len(prefix) :] if path.startswith(prefix) else path - - -def _dbfs_glob(pattern: str, dbutils: Any) -> list[str]: - """Perform a custom glob search in DBFS using the provided pattern. - It is assumed that version paths are managed by Kedro only. - - Args: - pattern: Glob pattern to search for. - dbutils: dbutils instance to operate with DBFS. - - Returns: - List of DBFS paths prefixed with '/dbfs' that satisfy the glob pattern. - """ - pattern = _strip_dbfs_prefix(pattern) - prefix = _parse_glob_pattern(pattern) - matched = set() - filename = pattern.split("/")[-1] - - for file_info in dbutils.fs.ls(prefix): - if file_info.isDir(): - path = str( - PurePosixPath(_strip_dbfs_prefix(file_info.path, "dbfs:")) / filename - ) - if fnmatch(path, pattern): - path = "/dbfs" + path - matched.add(path) - return sorted(matched) - - -def _get_dbutils(spark: SparkSession) -> Any: - """Get the instance of 'dbutils' or None if the one could not be found.""" - dbutils = globals().get("dbutils") - if dbutils: - return dbutils - - try: - from pyspark.dbutils import DBUtils - - dbutils = DBUtils(spark) - except ImportError: - try: - import IPython - except ImportError: - pass - else: - ipython = IPython.get_ipython() - dbutils = ipython.user_ns.get("dbutils") if ipython else None - - return dbutils - - -def _dbfs_exists(pattern: str, dbutils: Any) -> bool: - """Perform an `ls` list operation in DBFS using the provided pattern. - It is assumed that version paths are managed by Kedro. - Broad `Exception` is present due to `dbutils.fs.ExecutionError` that - cannot be imported directly. - Args: - pattern: Filepath to search for. - dbutils: dbutils instance to operate with DBFS. - Returns: - Boolean value if filepath exists. - """ - pattern = _strip_dbfs_prefix(pattern) - file = _parse_glob_pattern(pattern) - try: - dbutils.fs.ls(file) - return True - except Exception: - return False - +from kedro_datasets._utils.databricks_utils import ( + dbfs_exists, + dbfs_glob, + deployed_on_databricks, + get_dbutils, + parse_glob_pattern, + split_filepath, + strip_dbfs_prefix, +) +from kedro_datasets._utils.spark_utils import get_spark -def _deployed_on_databricks() -> bool: - """Check if running on Databricks.""" - return "DATABRICKS_RUNTIME_VERSION" in os.environ +logger = logging.getLogger(__name__) class KedroHdfsInsecureClient(InsecureClient): @@ -174,7 +66,7 @@ def hdfs_glob(self, pattern: str) -> list[str]: Returns: List of HDFS paths that satisfy the glob pattern. """ - prefix = _parse_glob_pattern(pattern) or "/" + prefix = parse_glob_pattern(pattern) or "/" matched = set() try: for dpath, _, fnames in self.walk(prefix): @@ -308,7 +200,7 @@ def __init__( # noqa: PLR0913 This is ignored by Kedro, but may be consumed by users or external plugins. """ credentials = deepcopy(credentials) or {} - fs_prefix, filepath = _split_filepath(filepath) + fs_prefix, filepath = split_filepath(filepath) path = PurePosixPath(filepath) exists_function = None glob_function = None @@ -317,7 +209,7 @@ def __init__( # noqa: PLR0913 if ( not filepath.startswith("/dbfs/") and fs_prefix not in (protocol + "://" for protocol in CLOUD_PROTOCOLS) - and _deployed_on_databricks() + and deployed_on_databricks() ): logger.warning( "Using SparkDataset on Databricks without the `/dbfs/` prefix in the " @@ -349,10 +241,10 @@ def __init__( # noqa: PLR0913 elif filepath.startswith("/dbfs/"): # dbfs add prefix to Spark path by default # See https://github.com/kedro-org/kedro-plugins/issues/117 - dbutils = _get_dbutils(_get_spark()) + dbutils = get_dbutils(get_spark()) if dbutils: - glob_function = partial(_dbfs_glob, dbutils=dbutils) - exists_function = partial(_dbfs_exists, dbutils=dbutils) + glob_function = partial(dbfs_glob, dbutils=dbutils) + exists_function = partial(dbfs_exists, dbutils=dbutils) else: filesystem = fsspec.filesystem(fs_prefix.strip("://"), **credentials) exists_function = filesystem.exists @@ -414,8 +306,8 @@ def _describe(self) -> dict[str, Any]: } def load(self) -> DataFrame: - load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) - read_obj = _get_spark().read + load_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) + read_obj = get_spark().read # Pass schema if defined if self._schema: @@ -424,14 +316,14 @@ def load(self) -> DataFrame: return read_obj.load(load_path, self._file_format, **self._load_args) def save(self, data: DataFrame) -> None: - save_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_save_path())) + save_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_save_path())) data.write.save(save_path, self._file_format, **self._save_args) def _exists(self) -> bool: - load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) + load_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) try: - _get_spark().read.load(load_path, self._file_format) + get_spark().read.load(load_path, self._file_format) except AnalysisException as exception: # `AnalysisException.desc` is deprecated with pyspark >= 3.4 message = exception.desc if hasattr(exception, "desc") else str(exception) diff --git a/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py index 8908c0fac..3b10657de 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py @@ -11,7 +11,7 @@ from pyspark.sql import DataFrame, Window from pyspark.sql.functions import col, lit, row_number -from kedro_datasets.spark.spark_dataset import _get_spark +from kedro_datasets._utils.spark_utils import get_spark class SparkHiveDataset(AbstractDataset[DataFrame, DataFrame]): @@ -149,7 +149,7 @@ def _create_hive_table(self, data: DataFrame, mode: str | None = None): ) def load(self) -> DataFrame: - return _get_spark().read.table(self._full_table_address) + return get_spark().read.table(self._full_table_address) def save(self, data: DataFrame) -> None: self._validate_save(data) @@ -201,7 +201,7 @@ def _validate_save(self, data: DataFrame): def _exists(self) -> bool: return ( - _get_spark() + get_spark() ._jsparkSession.catalog() .tableExists(self._database, self._table) ) diff --git a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py index 36aec10ad..eae504ef3 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py @@ -6,7 +6,7 @@ from kedro.io.core import AbstractDataset, DatasetError from pyspark.sql import DataFrame -from kedro_datasets.spark.spark_dataset import _get_spark +from kedro_datasets._utils.spark_utils import get_spark class SparkJDBCDataset(AbstractDataset[DataFrame, DataFrame]): @@ -167,7 +167,7 @@ def _describe(self) -> dict[str, Any]: } def load(self) -> DataFrame: - return _get_spark().read.jdbc(self._url, self._table, **self._load_args) + return get_spark().read.jdbc(self._url, self._table, **self._load_args) def save(self, data: DataFrame) -> None: return data.write.jdbc(self._url, self._table, **self._save_args) diff --git a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py index 294ea35cb..623d89b7e 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py @@ -8,12 +8,9 @@ from pyspark.sql import DataFrame from pyspark.sql.utils import AnalysisException -from kedro_datasets.spark.spark_dataset import ( - SparkDataset, - _get_spark, - _split_filepath, - _strip_dbfs_prefix, -) +from kedro_datasets._utils.databricks_utils import split_filepath, strip_dbfs_prefix +from kedro_datasets._utils.spark_utils import get_spark +from kedro_datasets.spark.spark_dataset import SparkDataset class SparkStreamingDataset(AbstractDataset): @@ -80,7 +77,7 @@ def __init__( # noqa: PLR0913 self._load_args = load_args self.metadata = metadata - fs_prefix, filepath = _split_filepath(filepath) + fs_prefix, filepath = split_filepath(filepath) self._fs_prefix = fs_prefix self._filepath = PurePosixPath(filepath) @@ -111,9 +108,9 @@ def load(self) -> DataFrame: Returns: Data from filepath as pyspark dataframe. """ - load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) + load_path = strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) data_stream_reader = ( - _get_spark() + get_spark() .readStream.schema(self._schema) .format(self._file_format) .options(**self._load_args) @@ -125,7 +122,7 @@ def save(self, data: DataFrame) -> None: Args: data: PySpark streaming dataframe for saving """ - save_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) + save_path = strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) output_constructor = data.writeStream.format(self._file_format) output_mode = ( self._save_args.pop("output_mode", None) if self._save_args else None @@ -142,10 +139,10 @@ def save(self, data: DataFrame) -> None: ) def _exists(self) -> bool: - load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) + load_path = strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) try: - _get_spark().readStream.schema(self._schema).load( + get_spark().readStream.schema(self._schema).load( load_path, self._file_format ) except AnalysisException as exception: diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 67e8e38f9..9ae3af9aa 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -37,7 +37,7 @@ dask-csvdataset = ["dask[dataframe]>=2021.10"] dask-parquetdataset = ["dask[complete]>=2021.10", "triad>=0.6.7, <1.0"] dask = ["kedro-datasets[dask-parquetdataset, dask-csvdataset]"] -databricks-managedtabledataset = ["kedro-datasets[spark-base,pandas-base,delta-base,hdfs-base,s3fs-base]"] +databricks-managedtabledataset = ["kedro-datasets[hdfs-base,s3fs-base]"] databricks = ["kedro-datasets[databricks-managedtabledataset]"] geopandas-genericdataset = ["geopandas>=0.8.0, <2.0", "fiona >=1.8, <2.0"] @@ -171,7 +171,7 @@ yaml-yamldataset = ["kedro-datasets[pandas-base]", "PyYAML>=4.2, <7.0"] yaml = ["kedro-datasets[yaml-yamldataset]"] # Experimental Datasets -databricks-externaltabledataset = ["kedro-datasets[spark-base,pandas-base,delta-base,hdfs-base,s3fs-base]"] +databricks-externaltabledataset = ["kedro-datasets[hdfs-base,s3fs-base]"] langchain-chatopenaidataset = ["langchain-openai~=0.1.7"] langchain-openaiembeddingsdataset = ["langchain-openai~=0.1.7"] langchain-chatanthropicdataset = ["langchain-anthropic~=0.1.13", "langchain-community~=0.2.0"] diff --git a/kedro-datasets/tests/spark/test_deltatable_dataset.py b/kedro-datasets/tests/spark/test_deltatable_dataset.py index 938e90a31..613251c5f 100644 --- a/kedro-datasets/tests/spark/test_deltatable_dataset.py +++ b/kedro-datasets/tests/spark/test_deltatable_dataset.py @@ -72,12 +72,12 @@ def test_exists_raises_error(self, mocker): delta_ds = DeltaTableDataset(filepath="") if SPARK_VERSION >= Version("3.4.0"): mocker.patch( - "kedro_datasets.spark.deltatable_dataset._get_spark", + "kedro_datasets.spark.deltatable_dataset.get_spark", side_effect=AnalysisException("Other Exception"), ) else: mocker.patch( - "kedro_datasets.spark.deltatable_dataset._get_spark", + "kedro_datasets.spark.deltatable_dataset.get_spark", side_effect=AnalysisException("Other Exception", []), ) with pytest.raises(DatasetError, match="Other Exception"): diff --git a/kedro-datasets/tests/spark/test_spark_dataset.py b/kedro-datasets/tests/spark/test_spark_dataset.py index bc40f9512..c96c547b2 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_dataset.py @@ -26,14 +26,14 @@ ) from pyspark.sql.utils import AnalysisException +from kedro_datasets._utils.databricks_utils import ( + dbfs_exists, + dbfs_glob, + get_dbutils, +) from kedro_datasets.pandas import CSVDataset, ParquetDataset from kedro_datasets.pickle import PickleDataset from kedro_datasets.spark import SparkDataset -from kedro_datasets.spark.spark_dataset import ( - _dbfs_exists, - _dbfs_glob, - _get_dbutils, -) FOLDER_NAME = "fake_folder" FILENAME = "test.parquet" @@ -410,12 +410,12 @@ def test_exists_raises_error(self, mocker): spark_dataset = SparkDataset(filepath="") if SPARK_VERSION >= PackagingVersion("3.4.0"): mocker.patch( - "kedro_datasets.spark.spark_dataset._get_spark", + "kedro_datasets.spark.spark_dataset.get_spark", side_effect=AnalysisException("Other Exception"), ) else: mocker.patch( - "kedro_datasets.spark.spark_dataset._get_spark", + "kedro_datasets.spark.spark_dataset.get_spark", side_effect=AnalysisException("Other Exception", []), ) with pytest.raises(DatasetError, match="Other Exception"): @@ -636,7 +636,7 @@ def test_dbfs_glob(self, mocker): pattern = "/tmp/file/*/file" expected = ["/dbfs/tmp/file/date1/file", "/dbfs/tmp/file/date2/file"] - result = _dbfs_glob(pattern, dbutils_mock) + result = dbfs_glob(pattern, dbutils_mock) assert result == expected dbutils_mock.fs.ls.assert_called_once_with("/tmp/file") @@ -650,15 +650,15 @@ def test_dbfs_exists(self, mocker): FileInfo("/tmp/file/"), ] - assert _dbfs_exists(test_path, dbutils_mock) + assert dbfs_exists(test_path, dbutils_mock) # add side effect to test that non-existence is handled dbutils_mock.fs.ls.side_effect = Exception() - assert not _dbfs_exists(test_path, dbutils_mock) + assert not dbfs_exists(test_path, dbutils_mock) def test_ds_init_no_dbutils(self, mocker): get_dbutils_mock = mocker.patch( - "kedro_datasets.spark.spark_dataset._get_dbutils", + "kedro_datasets.spark.spark_dataset.get_dbutils", return_value=None, ) @@ -669,7 +669,7 @@ def test_ds_init_no_dbutils(self, mocker): def test_ds_init_dbutils_available(self, mocker): get_dbutils_mock = mocker.patch( - "kedro_datasets.spark.spark_dataset._get_dbutils", + "kedro_datasets.spark.spark_dataset.get_dbutils", return_value="mock", ) @@ -677,23 +677,23 @@ def test_ds_init_dbutils_available(self, mocker): get_dbutils_mock.assert_called_once() assert dataset._glob_function.__class__.__name__ == "partial" - assert dataset._glob_function.func.__name__ == "_dbfs_glob" + assert dataset._glob_function.func.__name__ == "dbfs_glob" assert dataset._glob_function.keywords == { "dbutils": get_dbutils_mock.return_value } def test_get_dbutils_from_globals(self, mocker): mocker.patch( - "kedro_datasets.spark.spark_dataset.globals", + "kedro_datasets._utils.databricks_utils.globals", return_value={"dbutils": "dbutils_from_globals"}, ) - assert _get_dbutils("spark") == "dbutils_from_globals" + assert get_dbutils("spark") == "dbutils_from_globals" def test_get_dbutils_from_pyspark(self, mocker): dbutils_mock = mocker.Mock() dbutils_mock.DBUtils.return_value = "dbutils_from_pyspark" mocker.patch.dict("sys.modules", {"pyspark.dbutils": dbutils_mock}) - assert _get_dbutils("spark") == "dbutils_from_pyspark" + assert get_dbutils("spark") == "dbutils_from_pyspark" dbutils_mock.DBUtils.assert_called_once_with("spark") def test_get_dbutils_from_ipython(self, mocker): @@ -702,13 +702,13 @@ def test_get_dbutils_from_ipython(self, mocker): "dbutils": "dbutils_from_ipython" } mocker.patch.dict("sys.modules", {"IPython": ipython_mock}) - assert _get_dbutils("spark") == "dbutils_from_ipython" + assert get_dbutils("spark") == "dbutils_from_ipython" ipython_mock.get_ipython.assert_called_once_with() def test_get_dbutils_no_modules(self, mocker): mocker.patch("kedro_datasets.spark.spark_dataset.globals", return_value={}) mocker.patch.dict("sys.modules", {"pyspark": None, "IPython": None}) - assert _get_dbutils("spark") is None + assert get_dbutils("spark") is None @pytest.mark.parametrize("os_name", ["nt", "posix"]) def test_regular_path_in_different_os(self, os_name, mocker): @@ -737,7 +737,7 @@ def test_no_version(self, versioned_dataset_s3): def test_load_latest(self, mocker, versioned_dataset_s3): get_spark = mocker.patch( - "kedro_datasets.spark.spark_dataset._get_spark", + "kedro_datasets.spark.spark_dataset.get_spark", ) mocked_glob = mocker.patch.object(versioned_dataset_s3, "_glob_function") mocked_glob.return_value = [ @@ -762,7 +762,7 @@ def test_load_exact(self, mocker): version=Version(ts, None), ) get_spark = mocker.patch( - "kedro_datasets.spark.spark_dataset._get_spark", + "kedro_datasets.spark.spark_dataset.get_spark", ) ds_s3.load() @@ -857,7 +857,7 @@ def test_load_latest(self, mocker, version): versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) get_spark = mocker.patch( - "kedro_datasets.spark.spark_dataset._get_spark", + "kedro_datasets.spark.spark_dataset.get_spark", ) versioned_hdfs.load() @@ -876,7 +876,7 @@ def test_load_exact(self, mocker): filepath=f"hdfs://{HDFS_PREFIX}", version=Version(ts, None) ) get_spark = mocker.patch( - "kedro_datasets.spark.spark_dataset._get_spark", + "kedro_datasets.spark.spark_dataset.get_spark", ) versioned_hdfs.load() diff --git a/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py b/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py index 119c82164..af9be9cac 100644 --- a/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py @@ -88,7 +88,7 @@ def test_except_bad_credentials(mocker, spark_jdbc_args_credentials_with_none_pa def test_load(mocker, spark_jdbc_args): spark = mocker.patch( - "kedro_datasets.spark.spark_jdbc_dataset._get_spark" + "kedro_datasets.spark.spark_jdbc_dataset.get_spark" ).return_value dataset = SparkJDBCDataset(**spark_jdbc_args) dataset.load() @@ -97,7 +97,7 @@ def test_load(mocker, spark_jdbc_args): def test_load_credentials(mocker, spark_jdbc_args_credentials): spark = mocker.patch( - "kedro_datasets.spark.spark_jdbc_dataset._get_spark" + "kedro_datasets.spark.spark_jdbc_dataset.get_spark" ).return_value dataset = SparkJDBCDataset(**spark_jdbc_args_credentials) dataset.load() @@ -110,7 +110,7 @@ def test_load_credentials(mocker, spark_jdbc_args_credentials): def test_load_args(mocker, spark_jdbc_args_save_load): spark = mocker.patch( - "kedro_datasets.spark.spark_jdbc_dataset._get_spark" + "kedro_datasets.spark.spark_jdbc_dataset.get_spark" ).return_value dataset = SparkJDBCDataset(**spark_jdbc_args_save_load) dataset.load() diff --git a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py index 330c8d10d..4c44c31e2 100644 --- a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py @@ -174,12 +174,12 @@ def test_exists_raises_error(self, mocker): if SPARK_VERSION >= Version("3.4.0"): mocker.patch( - "kedro_datasets.spark.spark_streaming_dataset._get_spark", + "kedro_datasets.spark.spark_streaming_dataset.get_spark", side_effect=AnalysisException("Other Exception"), ) else: mocker.patch( - "kedro_datasets.spark.spark_streaming_dataset._get_spark", + "kedro_datasets.spark.spark_streaming_dataset.get_spark", side_effect=AnalysisException("Other Exception", []), ) with pytest.raises(DatasetError, match="Other Exception"):