Skip to content

Commit

Permalink
Merge pull request #83 from vincentclaes/81-aws-cdk-sfns
Browse files Browse the repository at this point in the history
use aws cdk stepfunctions
  • Loading branch information
vincentclaes authored Jun 22, 2021
2 parents f84258d + 37325e6 commit 8884143
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 3,242 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,4 @@ RUN /bin/bash -c "source $NVM_DIR/nvm.sh && nvm install $NODE_VERSION && nvm use
ENV NODE_PATH $NVM_DIR/versions/node/$NODE_VERSION/lib/node_modules
ENV PATH $NVM_DIR/versions/node/$NODE_VERSION/bin:$PATH
ENV AWS_DEFAULT_REGION=eu-west-1
RUN npm install -g aws-cdk@1.98.0
RUN npm install -g aws-cdk@latest
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ tests:
poetry run pytest

run-examples:
cd "${CURDIR}/examples/data_pipeline_simple" && poetry run datajob synthesize --config datajob_stack.py --stage dev
cd "${CURDIR}/examples/data_pipeline_with_packaged_project" && poetry run datajob synthesize --config datajob_stack.py --stage dev --package setuppy
cd "${CURDIR}/examples/data_pipeline_pyspark" && poetry run datajob synthesize --config datajob_stack.py --stage dev --package setuppy
cd "${CURDIR}/examples/data_pipeline_simple" && poetry run cdk synth --app "python datajob_stack.py"
cd "${CURDIR}/examples/data_pipeline_with_packaged_project" && poetry run python setup.py bdist_wheel && poetry run cdk synth --app "python datajob_stack.py"
#cd "${CURDIR}/examples/data_pipeline_pyspark" && poetry run python setup.py bdist_wheel && poetry run cdk synth --app "python datajob_stack.py"

gh-actions:
make install
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Beware that we depend on [aws cdk cli](https://github.com/aws/aws-cdk)!

pip install datajob
npm install -g aws-cdk@1.98.0 # latest version of datajob depends this version
npm install -g aws-cdk@1.109.0 # latest version of datajob depends on this version

# Quickstart

Expand Down
51 changes: 25 additions & 26 deletions datajob/stepfunctions/stepfunctions_workflow.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,17 @@
import os
import tempfile
import uuid
from pathlib import Path
from collections import defaultdict
from typing import Union, Iterator

import boto3
import contextvars
import toposort

from collections import defaultdict


from aws_cdk import aws_iam as iam
from aws_cdk import cloudformation_include as cfn_inc
from aws_cdk import core
from stepfunctions.steps import Catch, Pass, Fail
from stepfunctions.steps.service import SnsPublishStep
from stepfunctions import steps
from aws_cdk.aws_stepfunctions import CfnStateMachine
from stepfunctions.steps import Catch, Chain
from stepfunctions.steps.compute import GlueStartJobRunStep
from stepfunctions.steps.service import SnsPublishStep
from stepfunctions.steps.states import Parallel
from stepfunctions.workflow import Workflow

Expand Down Expand Up @@ -99,7 +93,7 @@ def _create_glue_start_job_run_step(job_name: str) -> GlueStartJobRunStep:
job_name, wait_for_completion=True, parameters={"JobName": job_name}
)

def _construct_toposorted_chain_of_tasks(self) -> steps.Chain:
def _construct_toposorted_chain_of_tasks(self) -> Chain:
"""Take the directed graph and toposort so that we can efficiently
organize our workflow, i.e. parallelize where possible.
Expand All @@ -109,7 +103,7 @@ def _construct_toposorted_chain_of_tasks(self) -> steps.Chain:
Returns: toposorted chain of tasks
"""
self.chain_of_tasks = steps.Chain()
self.chain_of_tasks = Chain()
directed_graph_toposorted = list(toposort.toposort(self.directed_graph))
# if we have length of 2 and the second is an Ellipsis object we have scheduled 1 task.
if len(directed_graph_toposorted) == 2 and isinstance(
Expand Down Expand Up @@ -150,11 +144,16 @@ def _build_workflow(self):

def create(self):
"""create sfn stack."""
with tempfile.TemporaryDirectory() as tmp_dir:
sfn_cf_file_path = str(Path(tmp_dir, self.unique_name))
with open(sfn_cf_file_path, "w") as text_file:
text_file.write(self.workflow.get_cloudformation_template())
cfn_inc.CfnInclude(self, self.unique_name, template_file=sfn_cf_file_path)
import json

cfn_template = json.dumps(self.workflow.definition.to_dict())
CfnStateMachine(
scope=self.datajob_stack,
id=self.unique_name,
state_machine_name=self.unique_name,
role_arn=self.role.role_arn,
definition_string=cfn_template,
)

def _setup_notification(
self, notification: Union[str, list]
Expand All @@ -168,9 +167,7 @@ def _setup_notification(
name = f"{self.name}-notification"
return SnsTopic(self.datajob_stack, name, notification)

def _integrate_notification_in_workflow(
self, chain_of_tasks: steps.Chain
) -> steps.Chain:
def _integrate_notification_in_workflow(self, chain_of_tasks: Chain) -> Chain:
"""If a notification is defined we configure an SNS with email
subscription to alert the user if the stepfunctions workflow failed or
succeeded.
Expand Down Expand Up @@ -206,7 +203,7 @@ def _integrate_notification_in_workflow(
workflow_with_notification.add_branch(chain_of_tasks)
workflow_with_notification.add_catch(catch_error)
workflow_with_notification.next(pass_notification)
return steps.Chain([workflow_with_notification])
return Chain([workflow_with_notification])
logger.debug(
"No notification is configured, returning the workflow definition."
)
Expand All @@ -218,14 +215,14 @@ def __enter__(self):
_set_workflow(self)
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""steps we have to do when exiting the context manager."""
self._build_workflow()
_set_workflow(None)
logger.info(f"step functions workflow {self.unique_name} created")


def task(self):
def task(self: DataJobBase) -> DataJobBase:
"""Task that can configured in the orchestration of a
StepfunctionsWorkflow. You have to use this as a decorator for any class
that you want to use in the orchestration.
Expand All @@ -245,7 +242,9 @@ class GlueJob(core.Construct):
glue_job_1 >> glue_job_2
"""

def __rshift__(self, other, *args, **kwargs):
def __rshift__(
self: DataJobBase, other: DataJobBase, *args, **kwargs
) -> DataJobBase:
"""called when doing task1 >> task2.
Syntactic sugar for >>.
Expand All @@ -257,7 +256,7 @@ def __rshift__(self, other, *args, **kwargs):
return self


def _set_workflow(workflow):
def _set_workflow(workflow: Workflow):
__workflow.set(workflow)


Expand All @@ -268,6 +267,6 @@ def _get_workflow():
return None


def _connect(self, other):
def _connect(self, other: DataJobBase) -> None:
work_flow = _get_workflow()
work_flow.directed_graph[other].add(self)
Loading

0 comments on commit 8884143

Please sign in to comment.