Skip to content

Commit

Permalink
Merge pull request #24 from vincentclaes/17-easy-access-context
Browse files Browse the repository at this point in the history
easy access to context
  • Loading branch information
vincentclaes authored Jan 20, 2021
2 parents c90c8db + 637f940 commit 3548bf6
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 18 deletions.
2 changes: 1 addition & 1 deletion datajob/datajob_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, datajob_stack, name, **kwargs):
self.project_root = datajob_stack.project_root
self.stage = datajob_stack.stage
self.unique_name = f"{datajob_stack.unique_stack_name}-{self.name}"
self.datajob_context = datajob_stack.datajob_context
self.context = datajob_stack.context
logger.info(f"adding job {self} to stack workflow resources")
datajob_stack.resources.append(self)

Expand Down
10 changes: 5 additions & 5 deletions datajob/datajob_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
include_folder: str = None,
account: str = None,
region: str = None,
scope: core.Construct = core.App(),
scope: core.Construct = None,
**kwargs,
) -> None:
"""
Expand All @@ -36,21 +36,21 @@ def __init__(
)
region = region if region is not None else os.environ.get("AWS_DEFAULT_REGION")
env = {"region": region, "account": account}
self.scope = scope
self.scope = scope if scope else core.App()
self.stage = self.get_stage(stage)
self.unique_stack_name = self._create_unique_stack_name(stack_name, self.stage)
super().__init__(scope=scope, id=self.unique_stack_name, env=env, **kwargs)
self.project_root = project_root
self.include_folder = include_folder
self.resources = []
self.datajob_context = None
self.context = None

def __enter__(self):
"""
As soon as we enter the contextmanager, we create the datajob context.
:return: datajob stack.
"""
self.datajob_context = DatajobContext(
self.context = DatajobContext(
self,
unique_stack_name=self.unique_stack_name,
project_root=self.project_root,
Expand All @@ -74,7 +74,7 @@ def __exit__(self, exc_type, exc_value, traceback):

def add(self, task: str) -> None:
setattr(self, task.unique_name, task)
task.create(datajob_context=self.datajob_context)
task.create()

@staticmethod
def _create_unique_stack_name(stack_name: str, stage: str) -> str:
Expand Down
24 changes: 13 additions & 11 deletions datajob/glue/glue_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def __init__(

def create(self):
s3_url_glue_job = self._deploy_glue_job_code(
datajob_context=self.datajob_context,
context=self.context,
glue_job_name=self.unique_name,
path_to_glue_job=self.job_path,
)
self._create_glue_job(
datajob_context=self.datajob_context,
context=self.context,
glue_job_name=self.unique_name,
s3_url_glue_job=s3_url_glue_job,
arguments=self.arguments,
Expand Down Expand Up @@ -136,16 +136,18 @@ def _get_role(self, role: iam.Role, unique_name: str) -> iam.Role:

@staticmethod
def _create_s3_url_for_job(
glue_job_context: DatajobContext, glue_job_id: str, glue_job_file_name: str
context: DatajobContext, glue_job_id: str, glue_job_file_name: str
) -> str:
"""
construct the path to s3 where the code resides of the glue job..
:param glue_job_context: DatajobContext that contains the name of the deployment bucket.
:param context: DatajobContext that contains the name of the deployment bucket.
:param glue_job_id:
:param glue_job_file_name:
:return:
"""
s3_url_glue_job = f"s3://{glue_job_context.deployment_bucket_name}/{glue_job_id}/{glue_job_file_name}"
s3_url_glue_job = (
f"s3://{context.deployment_bucket_name}/{glue_job_id}/{glue_job_file_name}"
)
logger.debug(f"s3 url for glue job {glue_job_id}: {s3_url_glue_job}")
return s3_url_glue_job

Expand All @@ -164,7 +166,7 @@ def _get_glue_job_dir_and_file_name(path_to_glue_job: str) -> tuple:
return glue_job_dir, glue_job_file_name

def _deploy_glue_job_code(
self, datajob_context: DatajobContext, glue_job_name: str, path_to_glue_job: str
self, context: DatajobContext, glue_job_name: str, path_to_glue_job: str
) -> str:
"""deploy the code of this glue job to the deployment bucket
(can be found in the glue context object)"""
Expand All @@ -181,20 +183,20 @@ def _deploy_glue_job_code(
# todo - sync only the glue job itself.
aws_s3_deployment.Source.asset(glue_job_dir)
],
destination_bucket=datajob_context.deployment_bucket,
destination_bucket=context.deployment_bucket,
destination_key_prefix=glue_job_name,
)

s3_url_glue_job = GlueJob._create_s3_url_for_job(
glue_job_context=datajob_context,
context=context,
glue_job_id=glue_job_name,
glue_job_file_name=glue_job_file_name,
)
return s3_url_glue_job

def _create_glue_job(
self,
datajob_context: DatajobContext,
context: DatajobContext,
glue_job_name: str,
s3_url_glue_job: str = None,
arguments: dict = None,
Expand All @@ -209,10 +211,10 @@ def _create_glue_job(
paths to wheel and business logic and arguments"""
logger.debug(f"creating Glue Job {glue_job_name}")
default_arguments = None
if datajob_context.s3_url_wheel:
if context.s3_url_wheel:
extra_py_files = {
# path to the wheel of this project
"--extra-py-files": datajob_context.s3_url_wheel
"--extra-py-files": context.s3_url_wheel
}
default_arguments = {**extra_py_files, **arguments}
glue.CfnJob(
Expand Down
13 changes: 13 additions & 0 deletions datajob_tests/datajob_context_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import unittest
from datajob.datajob_stack import DataJobStack, DatajobContext


class DatajobContextTest(unittest.TestCase):
    """Smoke test: a DatajobContext can be created against a fresh stack."""

    def test_datajob_context_initiates_without_error(self):
        # Construct the stack and context directly. If instantiation raises,
        # the test fails with the original traceback — clearer than the
        # try/except + assertIsNone(exception_) pattern, which hides the
        # actual error behind an opaque assertion failure.
        djs = DataJobStack(stack_name="some-stack-name")
        DatajobContext(djs, unique_stack_name="some-unique-name")
6 changes: 5 additions & 1 deletion datajob_tests/datajob_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@


class DatajobStackTest(unittest.TestCase):
def test_datajob_stack_with_no_stage(self):
def test_datajob_stack_initiates_without_error(self):
exception_ = None
try:
with DataJobStack(stack_name="some-stack-name") as djs:
pass
except Exception as e:
exception_ = e
self.assertIsNone(exception_)

def test_datajob_stack_with_no_stage(self):
with DataJobStack(stack_name="some-stack-name") as djs:
pass
self.assertEqual(djs.stage, DataJobStack.DEFAULT_STAGE)

0 comments on commit 3548bf6

Please sign in to comment.