Adds prototype spark caching code #726

Status: Open. Wants to merge 2 commits into base: main.

56 changes: 41 additions & 15 deletions hamilton/experimental/h_cache.py
@@ -33,7 +33,7 @@ def write_feather(data: object, filepath: str, name: str) -> None:


@singledispatch
-def read_feather(data: object, filepath: str) -> Any:
+def read_feather(data: object, filepath: str, **kwargs) -> Any:
"""Reads from a feather file"""
raise NotImplementedError(f"No feather reader for type {type(data)} registered.")

@@ -45,7 +45,7 @@ def write_parquet(data: object, filepath: str, name: str) -> None:


@singledispatch
-def read_parquet(data: object, filepath: str) -> Any:
+def read_parquet(data: object, filepath: str, **kwargs) -> Any:
"""Reads from a parquet file"""
raise NotImplementedError(f"No parquet reader for type {type(data)} registered.")

@@ -57,7 +57,7 @@ def write_json(data: object, filepath: str, name: str) -> None:


@singledispatch
-def read_json(data: object, filepath: str) -> Any:
+def read_json(data: object, filepath: str, **kwargs) -> Any:
"""Reads from a json file"""
raise NotImplementedError(f"No json reader for type {type(data)} registered.")

@@ -69,7 +69,7 @@ def write_pickle(data: Any, filepath: str, name: str) -> None:


@singledispatch
-def read_pickle(data: Any, filtepath: str) -> object:
+def read_pickle(data: Any, filtepath: str, **kwargs) -> object:
"""Reads from a pickle file"""
raise NotImplementedError(f"No object reader for type {type(data)} registered.")

@@ -89,13 +89,13 @@ def write_json_pd2(data: pd.Series, filepath: str, name: str) -> None:
return _df.to_json(filepath)

@read_json.register(pd.Series)
-def read_json_pd1(data: pd.Series, filepath: str) -> pd.Series:
+def read_json_pd1(data: pd.Series, filepath: str, **kwargs) -> pd.Series:
"""Reads a series from a feather file."""
_df = pd.read_json(filepath)
return _df[_df.columns[0]]

@read_json.register(pd.DataFrame)
-def read_json_pd2(data: pd.DataFrame, filepath: str) -> pd.DataFrame:
+def read_json_pd2(data: pd.DataFrame, filepath: str, **kwargs) -> pd.DataFrame:
"""Reads a dataframe from a feather file."""
return pd.read_json(filepath)

@@ -113,12 +113,12 @@ def write_feather_pd2(data: pd.Series, filepath: str, name: str) -> None:
data.to_frame(name=name).to_feather(filepath)

@read_feather.register(pd.DataFrame)
-def read_feather_pd1(data: pd.DataFrame, filepath: str) -> pd.DataFrame:
+def read_feather_pd1(data: pd.DataFrame, filepath: str, **kwargs) -> pd.DataFrame:
"""Reads a dataframe from a feather file."""
return pd.read_feather(filepath)

@read_feather.register(pd.Series)
-def read_feather_pd2(data: pd.Series, filepath: str) -> pd.Series:
+def read_feather_pd2(data: pd.Series, filepath: str, **kwargs) -> pd.Series:
"""Reads a series from a feather file."""
_df = pd.read_feather(filepath)
return _df[_df.columns[0]]
@@ -134,12 +134,12 @@ def write_parquet_pd2(data: pd.Series, filepath: str, name: str) -> None:
data.to_frame(name=name).to_parquet(filepath)

@read_parquet.register(pd.DataFrame)
-def read_parquet_pd1(data: pd.DataFrame, filepath: str) -> pd.DataFrame:
+def read_parquet_pd1(data: pd.DataFrame, filepath: str, **kwargs) -> pd.DataFrame:
"""Reads a dataframe from a parquet file."""
return pd.read_parquet(filepath)

@read_parquet.register(pd.Series)
-def read_parquet_pd2(data: pd.Series, filepath: str) -> pd.Series:
+def read_parquet_pd2(data: pd.Series, filepath: str, **kwargs) -> pd.Series:
"""Reads a series from a parquet file."""
_df = pd.read_parquet(filepath)
return _df[_df.columns[0]]
@@ -148,6 +148,23 @@ def read_parquet_pd2(data: pd.Series, filepath: str) -> pd.Series:
pass


except ImportError:
pass

+try:
+    import pyspark.sql as ps
+
+    @write_parquet.register(ps.DataFrame)
+    def write_parquet_ps(data: ps.DataFrame, filepath: str) -> None:
+        """Writes a pyspark dataframe to a parquet file."""
+        data.write.parquet(filepath, mode="overwrite")
+
+    @read_parquet.register(ps.DataFrame)
+    def read_parquet_ps(data: ps.DataFrame, filepath: str, **kwargs) -> pd.DataFrame:
+        """Reads a dataframe from a parquet file."""
+        spark = kwargs["spark_session"]
+        return spark.read.parquet(filepath)
+
+except ImportError:
+    pass

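For context (an editorial sketch, not part of the diff): the new reader pulls the SparkSession out of `**kwargs`, so callers have to pass a `spark_session` keyword through. Assuming an existing session named `spark` and a previously written cache file at an illustrative path, the dispatch mechanics look roughly like this:

from hamilton.experimental.h_cache import read_parquet

# Dispatch happens on the type of the first argument, so an empty pyspark
# dataframe of the expected type is enough to select read_parquet_ps.
empty = spark.createDataFrame([], "id INT, value STRING")
cached = read_parquet(empty, "/tmp/hamilton_cache/my_node.parquet", spark_session=spark)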
@@ -163,7 +180,7 @@ def write_json_dict(data: dict, filepath: str, name: str) -> None:


@read_json.register(dict)
-def read_json_dict(data: dict, filepath: str) -> dict:
+def read_json_dict(data: dict, filepath: str, **kwargs) -> dict:
"""Reads a dictionary from a JSON file."""
with open(filepath, "r", encoding="utf8") as file:
return json.load(file)
@@ -180,7 +197,7 @@ def write_pickle_object(data: object, filepath: str, name: str) -> None:


@read_pickle.register(object)
-def read_pickle_object(data: object, filepath: str) -> object:
+def read_pickle_object(data: object, filepath: str, **kwargs) -> object:
"""Reads a pickle file"""
print(filepath)
with open(filepath, "rb") as file:
@@ -213,7 +230,7 @@ class CachingGraphAdapter(SimplePythonGraphAdapter):
and `name` is the name of the node that is being written.

Reader functions need to have the following signature:
-`def read_<format>(data: Any, filepath: str) -> Any: ...`
+`def read_<format>(data: Any, filepath: str, **kwargs) -> Any: ...`
where `data` is an EMPTY OBJECT of the type you wish to instantiate, and `filepath` is the
path to the file to be read from.

@@ -227,7 +244,7 @@ def write_json_pd1(data: T, filepath: str, name: str) -> None:
...

@read_json.register(T)
-def read_json_dict(data: T, filepath: str) -> T:
+def read_json_dict(data: T, filepath: str, **kwargs) -> T:
...

Usage
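(Editorial example, not part of the PR.) To make the registration pattern above concrete, here is a minimal sketch for a hypothetical node that returns a plain list; the handler names and file handling are illustrative:

import json

from hamilton.experimental.h_cache import read_json, write_json


@write_json.register(list)
def write_json_list(data: list, filepath: str, name: str) -> None:
    """Writes a list to a JSON file; `name` is unused here."""
    with open(filepath, "w", encoding="utf8") as file:
        json.dump(data, file)


@read_json.register(list)
def read_json_list(data: list, filepath: str, **kwargs) -> list:
    """Reads the list back from a JSON file; extra read kwargs are ignored."""
    with open(filepath, "r", encoding="utf8") as file:
        return json.load(file)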
@@ -280,6 +297,8 @@ def __init__(
force_compute: Optional[Set[str]] = None,
writers: Optional[Dict[str, Callable[[Any, str, str], None]]] = None,
readers: Optional[Dict[str, Callable[[Any, str], Any]]] = None,
+read_kwargs: Optional[Dict[str, Any]] = None,
+read_after_write: bool = False,
**kwargs,
):
"""Constructs the adapter.
@@ -288,6 +307,7 @@ def __init__(
:param force_compute: Set of nodes that should be forced to compute even if cache exists.
:param writers: A dictionary of writers for custom formats.
:param readers: A dictionary of readers for custom formats.
+:param read_kwargs: A dictionary of keyword arguments to pass to the readers.
"""

super().__init__(*args, **kwargs)
@@ -298,6 +318,8 @@ def __init__(
self.writers = writers or {}
self.readers = readers or {}

+self.read_kwargs = read_kwargs or {}
+self.read_after_write = read_after_write
self._init_default_readers_writers()

def _init_default_readers_writers(self):
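(Editorial sketch, not part of the diff.) Wiring the two new constructor arguments up with the pyspark handlers above might look like the following; `spark`, `my_module`, and the cache directory are placeholder names, and spark-dataframe nodes are assumed to be tagged with the parquet cache format as described in the class docstring:

from hamilton import driver
from hamilton.experimental.h_cache import CachingGraphAdapter

adapter = CachingGraphAdapter(
    "/tmp/hamilton_cache",                 # directory holding the cached artifacts
    read_kwargs={"spark_session": spark},  # forwarded to every reader call via **kwargs
    read_after_write=True,                 # immediately re-read each result that gets cached
)
dr = driver.Driver({}, my_module, adapter=adapter)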
@@ -331,7 +353,7 @@ def _write_cache(self, fmt: str, data: Any, filepath: str, node_name: str) -> None:

def _read_cache(self, fmt: str, expected_type: Any, filepath: str) -> None:
self._check_format(fmt)
-return self.readers[fmt](expected_type, filepath)
+return self.readers[fmt](expected_type, filepath, **self.read_kwargs)

def _get_empty_expected_type(self, expected_type: Type) -> Any:
if typing_inspect.is_generic_type(expected_type):
@@ -364,6 +386,10 @@ def execute_node(self, node: Node, kwargs: Dict[str, Any]) -> Any:
cache_format,
)
self._write_cache(cache_format, result, filepath, node.name)
+if self.read_after_write:
+    # this could be useful for delayed execution type things as a means to reset
+    # state that they have set internally
+    result = self._read_cache(cache_format, result, filepath)
self.computed_nodes.add(node.name)
return result
empty_expected_type = self._get_empty_expected_type(node.type)
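(Editorial note.) The intuition behind `read_after_write`, assuming a lazy engine such as Spark, is roughly that re-reading right after writing swaps the node's unevaluated plan for a scan of the cached file, so downstream nodes reuse the materialized data instead of re-triggering the original computation. In placeholder terms (`expensive_transform`, `df`, `path`, and `spark` are assumed names):

result = expensive_transform(df)              # lazy Spark plan, nothing computed yet
result.write.parquet(path, mode="overwrite")  # materialize once into the cache
result = spark.read.parquet(path)             # downstream nodes now scan the cached file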