diff --git a/datajob/datajob_base.py b/datajob/datajob_base.py
index 26bd54e..8c42104 100644
--- a/datajob/datajob_base.py
+++ b/datajob/datajob_base.py
@@ -17,7 +17,7 @@ def __init__(self, datajob_stack, name, **kwargs):
         self.project_root = datajob_stack.project_root
         self.stage = datajob_stack.stage
         self.unique_name = f"{datajob_stack.unique_stack_name}-{self.name}"
-        self.datajob_context = datajob_stack.datajob_context
+        self.context = datajob_stack.context
         logger.info(f"adding job {self} to stack workflow resources")
         datajob_stack.resources.append(self)

diff --git a/datajob/datajob_stack.py b/datajob/datajob_stack.py
index b6799ff..4a22394 100644
--- a/datajob/datajob_stack.py
+++ b/datajob/datajob_stack.py
@@ -17,7 +17,7 @@ def __init__(
         include_folder: str = None,
         account: str = None,
         region: str = None,
-        scope: core.Construct = core.App(),
+        scope: core.Construct = None,
         **kwargs,
     ) -> None:
         """
@@ -36,21 +36,21 @@
         )
         region = region if region is not None else os.environ.get("AWS_DEFAULT_REGION")
         env = {"region": region, "account": account}
-        self.scope = scope
+        self.scope = scope if scope else core.App()
         self.stage = self.get_stage(stage)
         self.unique_stack_name = self._create_unique_stack_name(stack_name, self.stage)
-        super().__init__(scope=scope, id=self.unique_stack_name, env=env, **kwargs)
+        super().__init__(scope=self.scope, id=self.unique_stack_name, env=env, **kwargs)
         self.project_root = project_root
         self.include_folder = include_folder
         self.resources = []
-        self.datajob_context = None
+        self.context = None

     def __enter__(self):
         """
         As soon as we enter the contextmanager, we create the datajob context.
         :return: datajob stack.
         """
-        self.datajob_context = DatajobContext(
+        self.context = DatajobContext(
             self,
             unique_stack_name=self.unique_stack_name,
             project_root=self.project_root,
@@ -74,7 +74,7 @@ def __exit__(self, exc_type, exc_value, traceback):

     def add(self, task: str) -> None:
         setattr(self, task.unique_name, task)
-        task.create(datajob_context=self.datajob_context)
+        task.create()

     @staticmethod
     def _create_unique_stack_name(stack_name: str, stage: str) -> str:
diff --git a/datajob/glue/glue_job.py b/datajob/glue/glue_job.py
index b265557..2ec5a30 100644
--- a/datajob/glue/glue_job.py
+++ b/datajob/glue/glue_job.py
@@ -65,12 +65,12 @@ def __init__(

     def create(self):
         s3_url_glue_job = self._deploy_glue_job_code(
-            datajob_context=self.datajob_context,
+            context=self.context,
             glue_job_name=self.unique_name,
             path_to_glue_job=self.job_path,
         )
         self._create_glue_job(
-            datajob_context=self.datajob_context,
+            context=self.context,
             glue_job_name=self.unique_name,
             s3_url_glue_job=s3_url_glue_job,
             arguments=self.arguments,
@@ -136,16 +136,18 @@ def _get_role(self, role: iam.Role, unique_name: str) -> iam.Role:

     @staticmethod
     def _create_s3_url_for_job(
-        glue_job_context: DatajobContext, glue_job_id: str, glue_job_file_name: str
+        context: DatajobContext, glue_job_id: str, glue_job_file_name: str
     ) -> str:
         """
         construct the path to s3 where the code resides of the glue job..
-        :param glue_job_context: DatajobContext that contains the name of the deployment bucket.
+        :param context: DatajobContext that contains the name of the deployment bucket.
         :param glue_job_id:
         :param glue_job_file_name:
         :return:
         """
-        s3_url_glue_job = f"s3://{glue_job_context.deployment_bucket_name}/{glue_job_id}/{glue_job_file_name}"
+        s3_url_glue_job = (
+            f"s3://{context.deployment_bucket_name}/{glue_job_id}/{glue_job_file_name}"
+        )
         logger.debug(f"s3 url for glue job {glue_job_id}: {s3_url_glue_job}")
         return s3_url_glue_job

@@ -164,7 +166,7 @@ def _get_glue_job_dir_and_file_name(path_to_glue_job: str) -> tuple:
         return glue_job_dir, glue_job_file_name

     def _deploy_glue_job_code(
-        self, datajob_context: DatajobContext, glue_job_name: str, path_to_glue_job: str
+        self, context: DatajobContext, glue_job_name: str, path_to_glue_job: str
     ) -> str:
         """deploy the code of this glue job to the deployment bucket
         (can be found in the glue context object)"""
@@ -181,12 +183,12 @@
                 # todo - sync only the glue job itself.
                 aws_s3_deployment.Source.asset(glue_job_dir)
             ],
-            destination_bucket=datajob_context.deployment_bucket,
+            destination_bucket=context.deployment_bucket,
             destination_key_prefix=glue_job_name,
         )

         s3_url_glue_job = GlueJob._create_s3_url_for_job(
-            glue_job_context=datajob_context,
+            context=context,
             glue_job_id=glue_job_name,
             glue_job_file_name=glue_job_file_name,
         )
@@ -194,7 +196,7 @@

     def _create_glue_job(
         self,
-        datajob_context: DatajobContext,
+        context: DatajobContext,
         glue_job_name: str,
         s3_url_glue_job: str = None,
         arguments: dict = None,
@@ -209,10 +211,10 @@
         paths to wheel and business logic and arguments"""
         logger.debug(f"creating Glue Job {glue_job_name}")
         default_arguments = None
-        if datajob_context.s3_url_wheel:
+        if context.s3_url_wheel:
             extra_py_files = {
                 # path to the wheel of this project
-                "--extra-py-files": datajob_context.s3_url_wheel
+                "--extra-py-files": context.s3_url_wheel
             }
             default_arguments = {**extra_py_files, **arguments}
         glue.CfnJob(
diff --git a/datajob_tests/datajob_context_test.py b/datajob_tests/datajob_context_test.py
new file mode 100644
index 0000000..5077fa0
--- /dev/null
+++ b/datajob_tests/datajob_context_test.py
@@ -0,0 +1,13 @@
+import unittest
+from datajob.datajob_stack import DataJobStack, DatajobContext
+
+
+class DatajobContextTest(unittest.TestCase):
+    def test_datajob_context_initiates_without_error(self):
+        exception_ = None
+        try:
+            djs = DataJobStack(stack_name="some-stack-name")
+            DatajobContext(djs, unique_stack_name="some-unique-name")
+        except Exception as e:
+            exception_ = e
+        self.assertIsNone(exception_)
diff --git a/datajob_tests/datajob_stack_test.py b/datajob_tests/datajob_stack_test.py
index 2e80606..c5053d2 100644
--- a/datajob_tests/datajob_stack_test.py
+++ b/datajob_tests/datajob_stack_test.py
@@ -3,7 +3,7 @@


 class DatajobStackTest(unittest.TestCase):
-    def test_datajob_stack_with_no_stage(self):
+    def test_datajob_stack_initiates_without_error(self):
         exception_ = None
         try:
             with DataJobStack(stack_name="some-stack-name") as djs:
@@ -11,4 +11,8 @@
         except Exception as e:
             exception_ = e
         self.assertIsNone(exception_)
+
+    def test_datajob_stack_with_no_stage(self):
+        with DataJobStack(stack_name="some-stack-name") as djs:
+            pass
         self.assertEqual(djs.stage, DataJobStack.DEFAULT_STAGE)
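
Note for reviewers: a minimal end-to-end sketch of how the refactored API reads after this change. The GlueJob constructor arguments shown (name, job_path) and the job path itself are assumptions inferred from the attributes used in this diff, not confirmed signatures.

    from datajob.datajob_stack import DataJobStack
    from datajob.glue.glue_job import GlueJob

    # scope now defaults to None; the stack lazily creates its own core.App()
    # when no scope is passed in.
    with DataJobStack(stack_name="some-stack-name") as stack:
        # entering the context manager instantiates DatajobContext and exposes
        # it as stack.context (renamed from stack.datajob_context)
        job = GlueJob(
            stack,
            name="my-job",  # hypothetical arguments, for illustration only
            job_path="glue_jobs/my_job.py",
        )
        # add() now calls task.create() with no arguments; the job picks up
        # the shared context from stack.context instead of receiving it
        stack.add(job)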