Updates code with locations that need to change to have a generic with_columns decorator
elijahbenizzy committed Apr 11, 2024
1 parent 3ca633e commit c370d2f
Showing 1 changed file with 11 additions and 0 deletions.
hamilton/plugins/h_spark.py (11 additions, 0 deletions)
@@ -371,6 +371,8 @@ def _fabricate_spark_function(
     return FunctionType(func_code, {**globals(), **{"partial_fn": partial_fn}}, func_name)


+# TODO -- change this to have a different implementation based on the dataframe type.
+# This will likely have to be custom to each dataframe type.
 def _lambda_udf(df: DataFrame, node_: node.Node, actual_kwargs: Dict[str, Any]) -> DataFrame:
     """Function to create a lambda UDF for a function.
@@ -1080,12 +1082,16 @@ def create_selector_node(
     """

     def new_callable(**kwargs) -> DataFrame:
+        # TODO -- change to have a `select` that's generic to the library
+        # Use the registry
         return kwargs[upstream_name].select(*columns)

     return node.Node(
         name=node_name,
+        # TODO -- change to have the right dataframe type (from the registry)
         typ=DataFrame,
         callabl=new_callable,
+        # TODO -- change to have the right dataframe type (from the registry)
         input_types={upstream_name: DataFrame},
     )
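
The `.select(*columns)` call in `new_callable` is pyspark-specific (the pandas equivalent would be `df[columns]`). A sketch of the registry-backed, library-generic `select` the TODO hints at; the registry and example registrations are assumptions, not existing code:

from typing import Any, Callable, Dict, List, Type

# Hypothetical registry: dataframe type -> column-selection function.
_SELECT_IMPLS: Dict[Type, Callable[[Any, List[str]], Any]] = {}

def generic_select(df: Any, columns: List[str]) -> Any:
    """Selects columns using whichever registered implementation matches type(df)."""
    for df_type, select_fn in _SELECT_IMPLS.items():
        if isinstance(df, df_type):
            return select_fn(df, columns)
    raise NotImplementedError(f"no select implementation for {type(df)}")

# Example registrations (hypothetical):
# _SELECT_IMPLS[pyspark.sql.DataFrame] = lambda df, cols: df.select(*cols)
# _SELECT_IMPLS[pandas.DataFrame] = lambda df, cols: df[cols]

With this in place, `new_callable` would become `return generic_select(kwargs[upstream_name], columns)`.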

@@ -1107,8 +1113,10 @@ def new_callable(**kwargs) -> DataFrame:

     return node.Node(
         name=node_name,
+        # TODO -- change to have the right dataframe type (from the registry)
         typ=DataFrame,
         callabl=new_callable,
+        # TODO -- change to have the right dataframe type (from the registry)
         input_types={upstream_name: DataFrame},
     )
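
Both of these node definitions hardcode pyspark's `DataFrame` for `typ` and `input_types`. The TODOs point at resolving the type from a registry instead; a sketch under the assumption of a registry keyed by library name (nothing here exists yet):

from typing import Dict, Type

# Hypothetical registry: library name -> concrete dataframe type,
# e.g. {"pyspark": pyspark.sql.DataFrame, "pandas": pandas.DataFrame}.
DATAFRAME_TYPES: Dict[str, Type] = {}

def resolve_dataframe_type(library: str) -> Type:
    """Looks up the dataframe type to use for a node's typ/input_types."""
    if library not in DATAFRAME_TYPES:
        raise KeyError(f"no dataframe type registered for {library!r}")
    return DATAFRAME_TYPES[library]

`typ=DataFrame` and `input_types={upstream_name: DataFrame}` would then read `typ=df_type` with `df_type = resolve_dataframe_type(library)`.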

@@ -1195,7 +1203,9 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node]:
             column for column in node_.input_types if column in columns_passed_in_from_dataframe
         }
         # In the case that we are using pyspark UDFs
+        # TODO -- use the right dataframe type to do this correctly
         if require_columns.is_decorated_pyspark_udf(node_):
+            # TODO -- change to use the right "sparkification" function that is dataframe-type-agnostic
             sparkified = require_columns.sparkify_node(
                 node_,
                 current_dataframe_node,
@@ ... @@
             )
         # otherwise we're using pandas/primitive UDFs
         else:
+            # TODO -- change to use the right "sparkification" function that is dataframe-type-agnostic
             sparkified = sparkify_node_with_udf(
                 node_,
                 current_dataframe_node,
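
This branch hardcodes two pyspark-specific node rewriters, `require_columns.sparkify_node` and `sparkify_node_with_udf`. A dataframe-type-agnostic version could look the rewriter up by library and UDF kind; everything in this sketch is hypothetical:

from typing import Callable, Dict, Tuple

# Hypothetical registry: (library, is_native_udf) -> node-rewriting function.
SPARKIFY_FNS: Dict[Tuple[str, bool], Callable] = {}

def resolve_sparkify_fn(library: str, is_native_udf: bool) -> Callable:
    """Picks the node rewriter for the given library and UDF kind."""
    key = (library, is_native_udf)
    if key not in SPARKIFY_FNS:
        raise NotImplementedError(f"no node rewriter registered for {key}")
    return SPARKIFY_FNS[key]

Both arms of the `if`/`else` above would then call `resolve_sparkify_fn(...)` with the same arguments they pass today.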
