Merge pull request #75 from vincentclaes/69-directed-graph
create dag for stepfunctions workflow
vincentclaes authored Jun 21, 2021
2 parents a0662d9 + 336adec commit bce4d45
Showing 10 changed files with 200 additions and 281 deletions.
42 changes: 29 additions & 13 deletions README.md
@@ -44,7 +44,7 @@ with DataJobStack(scope=app, id="data-pipeline-simple") as datajob_stack:

# We define 2 glue jobs with the relative path to the source code.
task1 = GlueJob(
datajob_stack=datajob_stack, name="task1", job_path="glue_jobs/task1.py"
datajob_stack=datajob_stack, name="task1", job_path="glue_jobs/task.py"
)
task2 = GlueJob(
datajob_stack=datajob_stack, name="task2", job_path="glue_jobs/task2.py"
@@ -76,6 +76,8 @@ cdk bootstrap aws://$AWS_ACCOUNT/$AWS_DEFAULT_REGION

### Deploy

Deploy the pipeline using CDK.

```shell
cd examples/data_pipeline_simple
cdk deploy --app "python datajob_stack.py"
@@ -99,6 +101,20 @@ cdk destroy --app "python datajob_stack.py"
# Functionality

<details>
<summary>Deploy to a stage</summary>

Specify a stage to deploy an isolated pipeline.

Typical examples would be `dev`, `prod`, ...

```shell
cdk deploy --app "python datajob_stack.py" --stage my-stage
```

</details>

<details>

<summary>Using datajob's S3 data bucket</summary>

Dynamically pass the `datajob_stack` data bucket name to the arguments of your GlueJob by calling
@@ -140,14 +156,6 @@ with DataJobStack(

```

deploy to stage `my-stage`:

```shell
datajob deploy --config datajob_stack.py --stage my-stage --package setuppy
```

`datajob_stack.context.data_bucket_name` will evaluate to `datajob-python-pyspark-my-stage`

you can find this example [here](./examples/data_pipeline_pyspark/glue_job/glue_pyspark_example.py)
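A minimal, illustrative sketch of wiring the bucket name into a job (the `arguments` parameter and the `--source-bucket` key are assumptions for illustration, not taken from this change):

```python
# Hypothetical sketch: pass the stack's data bucket name to a glue job argument.
task1 = GlueJob(
    datajob_stack=datajob_stack,
    name="task1",
    job_path="glue_jobs/task1.py",
    arguments={"--source-bucket": datajob_stack.context.data_bucket_name},  # assumed parameter
)
```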

</details>
@@ -229,10 +237,18 @@ full example can be found in [examples/data_pipeline_pyspark](examples/data_pipe
<summary>Orchestrate stepfunctions tasks in parallel</summary>

```python
# task1 and task2 are orchestrated in parallel.
# task3 will only start when both task1 and task2 have succeeded.
[task1, task2] >> task3
# task2 comes after task1, and task4 comes after task3.
# task5 depends on both task2 and task4 to be finished.
# Therefore task1 and task3 can run in parallel,
# as well as task2 and task4.
with StepfunctionsWorkflow(datajob_stack=datajob_stack, name="workflow") as sfn:
    task1 >> task2
    task3 >> task4
    task2 >> task5
    task4 >> task5

```
More can be found in [examples/data_pipeline_parallel](./examples/data_pipeline_parallel)

</details>

@@ -323,7 +339,7 @@ app = core.App()
datajob_stack = DataJobStack(scope=app, id="data-pipeline-pkg", project_root=current_dir)
datajob_stack.init_datajob_context()

task1 = GlueJob(datajob_stack=datajob_stack, name="task1", job_path="glue_jobs/task1.py")
task1 = GlueJob(datajob_stack=datajob_stack, name="task1", job_path="glue_jobs/task.py")
task2 = GlueJob(datajob_stack=datajob_stack, name="task2", job_path="glue_jobs/task2.py")

with StepfunctionsWorkflow(datajob_stack=datajob_stack, name="workflow") as step_functions_workflow:
149 changes: 80 additions & 69 deletions datajob/stepfunctions/stepfunctions_workflow.py
@@ -2,10 +2,15 @@
import tempfile
import uuid
from pathlib import Path
from typing import Union
from typing import Union, Iterator

import boto3
import contextvars
import toposort

from collections import defaultdict


from aws_cdk import aws_iam as iam
from aws_cdk import cloudformation_include as cfn_inc
from aws_cdk import core
@@ -23,8 +28,13 @@
__workflow = contextvars.ContextVar("workflow")


class StepfunctionsWorkflowException(Exception):
    pass


class StepfunctionsWorkflow(DataJobBase):
"""Class that defines the methods to create and execute an orchestration using the step functions sdk.
"""Class that defines the methods to create and execute an orchestration
using the step functions sdk.
example:
@@ -33,7 +43,6 @@ class StepfunctionsWorkflow(DataJobBase):
some-glue-job-1 >> [some-glue-job-2,some-glue-job-3] >> some-glue-job-4
tech_skills_parser_orchestration.execute()
"""

def __init__(
@@ -46,8 +55,8 @@ def __init__(
**kwargs,
):
super().__init__(datajob_stack, name, **kwargs)
self.chain_of_tasks = []
self.workflow = None
self.chain_of_tasks = None
self.role = (
self.get_role(
unique_name=self.unique_name, service_principal="states.amazonaws.com"
@@ -59,54 +68,88 @@ def __init__(
region if region is not None else os.environ.get("AWS_DEFAULT_REGION")
)
self.notification = self._setup_notification(notification)
# init directed graph dict where values are a set.
# we do it like this so that we can use toposort.
self.directed_graph = defaultdict(set)

def add_task(self, task_other):
def add_task(self, some_task: DataJobBase) -> GlueStartJobRunStep:
"""add a task to the workflow we would like to orchestrate."""
job_name = task_other.unique_name
job_name = some_task.unique_name
logger.debug(f"adding task with name {job_name}")
task = StepfunctionsWorkflow._create_glue_start_job_run_step(job_name=job_name)
self.chain_of_tasks.append(task)
return StepfunctionsWorkflow._create_glue_start_job_run_step(job_name=job_name)

def add_parallel_tasks(self, task_others):
"""add tasks in parallel (wrapped in a list) to the workflow we would like to orchestrate."""
deploy_pipelines = Parallel(state_id=uuid.uuid4().hex)
for one_other_task in task_others:
def add_parallel_tasks(self, parallel_tasks: Iterator[DataJobBase]) -> Parallel:
"""add tasks in parallel (wrapped in a list) to the workflow we would
like to orchestrate."""
parallel_pipelines = Parallel(state_id=uuid.uuid4().hex)
for one_other_task in parallel_tasks:
task_unique_name = one_other_task.unique_name
logger.debug(f"adding parallel task with name {task_unique_name}")
deploy_pipelines.add_branch(
parallel_pipelines.add_branch(
StepfunctionsWorkflow._create_glue_start_job_run_step(
job_name=task_unique_name
)
)
self.chain_of_tasks.append(deploy_pipelines)
return parallel_pipelines

@staticmethod
def _create_glue_start_job_run_step(job_name):
def _create_glue_start_job_run_step(job_name: str) -> GlueStartJobRunStep:
logger.debug("creating a step for a glue job.")
return GlueStartJobRunStep(
job_name, wait_for_completion=True, parameters={"JobName": job_name}
)

def _construct_toposorted_chain_of_tasks(self) -> steps.Chain:
"""Take the directed graph and toposort so that we can efficiently
organize our workflow, i.e. parallelize where possible.
if we have 2 elements where one of both is an Ellipsis object we need to orchestrate just 1 job.
In the other case we will loop over the toposorted dag and assign a stepfunctions task
or assign multiple tasks in parallel.
Returns: toposorted chain of tasks
"""
self.chain_of_tasks = steps.Chain()
directed_graph_toposorted = list(toposort.toposort(self.directed_graph))
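        # Illustrative (not part of this change): with task1 >> task2, task3 >> task4,
        # task2 >> task5 and task4 >> task5, the directed graph maps each task to the
        # set of tasks it depends on, and toposort yields levels that can run in
        # parallel: [{task1, task3}, {task2, task4}, {task5}].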
# if we have length of 2 and the second is an Ellipsis object we have scheduled 1 task.
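        # Illustrative (assumed usage): a single task is presumably scheduled as
        # "task1 >> ...", which records directed_graph[Ellipsis] = {task1} and
        # toposorts to [{task1}, {Ellipsis}] - hence the length-2 check below.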
if len(directed_graph_toposorted) == 2 and isinstance(
list(directed_graph_toposorted[1])[0], type(Ellipsis)
):
task = self.add_task(next(iter(directed_graph_toposorted[0])))
self.chain_of_tasks.append(task)
else:
for element in directed_graph_toposorted:
if len(element) == 1:
task = self.add_task(next(iter(element)))
elif len(element) > 1:
task = self.add_parallel_tasks(element)
else:
raise StepfunctionsWorkflowException(
"cannot have an index in the directed graph with 0 elements"
)
self.chain_of_tasks.append(task)
return self.chain_of_tasks

def _build_workflow(self):
"""create a step functions workflow from the chain_of_tasks."""
self.chain_of_tasks = self._construct_toposorted_chain_of_tasks()
logger.debug(
f"creating a chain from all the different steps. \n {self.chain_of_tasks}"
)
workflow_definition = steps.Chain(self.chain_of_tasks)
workflow_definition = self._integrate_notification_in_workflow(
workflow_definition=workflow_definition
self.chain_of_tasks = self._integrate_notification_in_workflow(
chain_of_tasks=self.chain_of_tasks
)
logger.debug(f"creating a workflow with name {self.unique_name}")
self.client = boto3.client("stepfunctions")
self.workflow = Workflow(
name=self.unique_name,
definition=workflow_definition,
definition=self.chain_of_tasks,
role=self.role.role_arn,
client=self.client,
)

def create(self):
"""create sfn stack"""
"""create sfn stack."""
with tempfile.TemporaryDirectory() as tmp_dir:
sfn_cf_file_path = str(Path(tmp_dir, self.unique_name))
with open(sfn_cf_file_path, "w") as text_file:
@@ -126,12 +169,13 @@ def _setup_notification(
return SnsTopic(self.datajob_stack, name, notification)

def _integrate_notification_in_workflow(
self, workflow_definition: steps.Chain
self, chain_of_tasks: steps.Chain
) -> steps.Chain:
"""If a notification is defined we configure an SNS with email subscription to alert the user
if the stepfunctions workflow failed or succeeded.
"""If a notification is defined we configure an SNS with email
subscription to alert the user if the stepfunctions workflow failed or
succeeded.
:param workflow_definition: the workflow definition that contains all the steps we want to execute.
:param chain_of_tasks: the workflow definition that contains all the steps we want to execute.
:return: if notification is set, we adapt the workflow to include an SnsPublishStep on failure or on success.
If notification is not set, we return the workflow as we received it.
"""
@@ -159,14 +203,14 @@ def _integrate_notification_in_workflow(
error_equals=["States.ALL"], next_step=failure_notification
)
workflow_with_notification = Parallel(state_id="notification")
workflow_with_notification.add_branch(workflow_definition)
workflow_with_notification.add_branch(chain_of_tasks)
workflow_with_notification.add_catch(catch_error)
workflow_with_notification.next(pass_notification)
return steps.Chain([workflow_with_notification])
logger.debug(
"No notification is configured, returning the workflow definition."
)
return workflow_definition
return chain_of_tasks

def __enter__(self):
"""first steps we have to do when entering the context manager."""
@@ -182,9 +226,9 @@ def __exit__(self, exc_type, exc_value, traceback):


def task(self):
"""
Task that can configured in the orchestration of a StepfunctionsWorkflow.
You have to use this as a decorator for any class that you want to use in the orchestration.
"""Task that can configured in the orchestration of a
StepfunctionsWorkflow. You have to use this as a decorator for any class
that you want to use in the orchestration.
example:
@@ -202,22 +246,14 @@ class GlueJob(core.Construct):
"""

def __rshift__(self, other, *args, **kwargs):
"""called when doing
- task1 >> task2
- task1 >> [task2,task3]
"""called when doing task1 >> task2.
        Syntactic sugar for >>.
"""
_handle_first(self=self)
_connect(other)
return self
_connect(self=self, other=other)
return other

setattr(self, "__rshift__", __rshift__)

def __rrshift__(other, self, *args, **kwargs):
"""Called for [task1, task2] >> task3 because list don't have __rshift__ operators.
Therefore we reverse the order of the arguments and call __rshift__"""
__rshift__(self=self, other=other)

setattr(self, "__rrshift__", __rrshift__)
return self


@@ -232,31 +268,6 @@ def _get_workflow():
return None


def _handle_first(self):
work_flow = _get_workflow()
if not work_flow.chain_of_tasks:
_connect(self)


def _connect(job):
if isinstance(job, list):
logger.debug("we have a list, so these are jobs orchestrated in parallel.")
_connect_parallel_job(job)
elif isinstance(job, type(Ellipsis)):
logger.debug("we have an ellipsis object, do nothing ...")
return
else:
logger.debug("default action is to connect a single job.")
_connect_single_job(job)


def _connect_parallel_job(job):
work_flow = _get_workflow()
work_flow.add_parallel_tasks(job)
return job


def _connect_single_job(job):
def _connect(self, other):
work_flow = _get_workflow()
work_flow.add_task(job)
return job
work_flow.directed_graph[other].add(self)
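    # Illustrative (not part of this change): task1 >> task2 ends up calling
    # _connect(self=task1, other=task2), so directed_graph[task2] contains {task1}.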