diff --git a/airflow/dags/eval_srl_edrgen_readiness.py b/airflow/dags/eval_srl_edrgen_readiness.py
index 95935005..e1f89d7a 100644
--- a/airflow/dags/eval_srl_edrgen_readiness.py
+++ b/airflow/dags/eval_srl_edrgen_readiness.py
@@ -1,209 +1,78 @@
-import hashlib
-import json
-import logging
-import re
+import os
+import time
 from datetime import datetime
-from typing import Dict, List
-from urllib.parse import urlparse

-import jsonschema
-from airflow.decorators import task
-from airflow.models import Variable
-from airflow.operators.trigger_dagrun import TriggerDagRunOperator
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from jinja2 import Template
-from referencing import Registry, Resource
+from airflow.operators.python import PythonOperator

 from airflow import DAG

-# get the airflow.task logger
-task_logger = logging.getLogger("airflow.task")
-
 default_args = {
     "owner": "unity-sps",
     "start_date": datetime.utcfromtimestamp(0),
 }


-def render_template(template, values=dict()):
-    template_str = json.dumps(template)
-    pattern = r'"{{\s*[^}]+?\|\s*tojson\s*}}"'
-    replacement_func = lambda m: m.group(0)[1:-1]
-    template_str = re.sub(pattern, replacement_func, template_str)
-    template_obj = Template(template_str)
-    rendered_template_str = template_obj.render(values)
-    rendered_template = json.loads(rendered_template_str)
-    return rendered_template
-
-
-@task
-def identify_dataset(dp_templates: Dict, dp_filename: str) -> Dict:
-    matched_dp_obj = None
-    matched_dp_name = None
-    matched_dp_template = None
-    for dp_name, dp_template in dp_templates.items():
-        matched_dp_obj = re.fullmatch(dp_template["regex_pattern"], dp_filename)
-
-        if not matched_dp_obj:
-            continue
-
-        matched_dp_name = dp_name
-        matched_dp_template = dp_template
-        break
-
-    if not matched_dp_obj:
-        raise ValueError
-
-    dp_id = matched_dp_template["data_product_id_format"].format(**matched_dp_obj.groupdict())
-    return {"data_product_name": matched_dp_name, "data_product_id": dp_id}
-
-
-@task
-def identify_dags(run_config_templates: Dict, data_product_name: str) -> List[str]:
-    return [
-        dag_id
-        for dag_id, rc in run_config_templates.items()
-        if data_product_name in rc["required_input_data_products"]
-    ]
-
-
-@task
-def generate_run_configs(
-    run_config_templates: Dict, data_products_templates: Dict, dags: List[str], dp: Dict
-) -> Dict:
-    osl_bucket = Variable.get("osl_bucket")
-    isl_bucket = Variable.get("isl_bucket")
-    config_bucket = Variable.get("config_bucket")
-
-    run_configs = {}
-    matched_dp_name = dp["data_product_name"]
-    matched_dp_id = dp["data_product_id"]
-    dp_rendered_dict = render_template(
-        data_products_templates[matched_dp_name],
-        {"DATA_PRODUCT_ID": matched_dp_id, "ISL_BUCKET": isl_bucket},
-    )
-
-    for dag_id in dags:
-        rc = run_config_templates[dag_id]
-        rc_values = {matched_dp_name.upper(): dp_rendered_dict}
-
-        for req_static_dp in rc["required_static_data"]:
-            static_dp = data_products_templates[req_static_dp]
-            static_dp_rendered_dict = render_template(static_dp, {"CONFIG_BUCKET": config_bucket})
-            rc_values[req_static_dp.upper()] = static_dp_rendered_dict
+def hello_world():
+    print("Hello World")
+    time.sleep(30)

-        for exp_output_dp in rc["expected_output_data_products"]:
-            exp_dp = data_products_templates[exp_output_dp]
-            exp_dp_rendered_dict = render_template(exp_dp, {"OSL_BUCKET": osl_bucket})
-            rc_values[exp_output_dp.upper()] = exp_dp_rendered_dict
-
-        run_configs[dag_id] = render_template(rc, rc_values)
+def write_to_shared_data():
+    file_path = "/shared-task-data/test_file.txt"  # Adjust the path as necessary
+    with open(file_path, "w") as f:
+        f.write("This is a test file written at " + str(datetime.now()) + "\n")
+    print(f"Successfully written to {file_path}")

-    return run_configs


+def read_from_shared_data():
+    file_path = "/shared-task-data/test_file.txt"  # Adjust the path as necessary
+    try:
+        with open(file_path, "r") as f:
+            contents = f.read()
+            print(f"File contents:\n{contents}")
+    except FileNotFoundError:
+        print("File not found. Make sure the file path is correct.")

-@task
-def validate_run_configs(run_configs: Dict) -> Dict:
-    dp_schema = json.loads(Variable.get("data_product_schema"))
-    rc_schema = json.loads(Variable.get("run_config_schema"))
-    registry = Registry().with_resources(
-        [
-            ("dp_schema.json", Resource.from_contents(dp_schema)),
-            ("rc_schema.json", Resource.from_contents(rc_schema)),
-        ]
-    )
-
-    rc_schema_validator = jsonschema.Draft202012Validator(rc_schema, registry=registry)
-
-    for _, rc in run_configs.items():
-        rc_schema_validator.validate(rc)
-
-    return run_configs
-
-
-@task
-def evaluate_dag_triggers(run_configs: Dict) -> Dict:
-    s3_hook = S3Hook()
-
-    def file_exists(bucket: str, key: str) -> bool:
-        return s3_hook.check_for_key(key, bucket_name=bucket)
-
-    def get_required_files(run_config: Dict) -> List[Dict]:
-        required_files = []
-        for dp in run_config["required_input_data_products"].values():
-            required_files.extend(dp["files"])
-        for static_data in run_config["required_static_data"].values():
-            required_files.extend(static_data["files"])
-        return required_files
-
-    filtered_run_configs = {}
-    for dag_id, rc in run_configs.items():
-        required_files = get_required_files(rc)
-        missing_files = [
-            f'Bucket: {rf["bucket"]}, Key: {rf["key"]}'
-            for rf in required_files
-            if not file_exists(rf["bucket"], rf["key"])
-        ]
-
-        if missing_files:
-            task_logger.warning(
-                f"Not all required files exist for DAG: {dag_id}. Missing files:\n" + "\n".join(missing_files)
-            )
-        else:
-            task_logger.info(f"All required files exist for DAG: {dag_id}")
-            filtered_run_configs[dag_id] = rc
-
-    return filtered_run_configs
-
-
-@task
-def trigger_downstream_dags(filtered_run_configs: Dict):
-    for dag_id, rc in filtered_run_configs.items():
-        required_files = [file["key"] for file in rc["required_input_data_products"]["STACAM_RawDP"]["files"]]
-        filename_string = "".join(required_files)
-        dag_run_id = hashlib.sha256(filename_string.encode()).hexdigest()
-
-        trigger_dag = TriggerDagRunOperator(
-            task_id=f"trigger_dag_{dag_id}",
-            trigger_dag_id=dag_id,
-            conf={"run_config": rc},
-            execution_date="{{ execution_date }}",
-            reset_dag_run=True,
-            wait_for_completion=True,
-            poke_interval=60,
-            allowed_states=["success"],
-            failed_states=["failed", "upstream_failed"],
-            dag_run_id=dag_run_id,
-        )
-        trigger_dag.execute(context=None)
+def delete_shared_data_file():
+    file_path = "/shared-task-data/test_file.txt"  # Adjust the path as necessary
+    try:
+        os.remove(file_path)
+        print(f"Successfully deleted {file_path}")
+    except FileNotFoundError:
+        print("File not found. Make sure the file path is correct.")


 with DAG(
     dag_id="eval_srl_edrgen_readiness",
     default_args=default_args,
     schedule=None,
-    catchup=False,
-    tags=["srl", "edr"],
-    doc_md="""
-    This DAG evaluates SRL EDR generation readiness by checking for the existence of required files
-    and triggering downstream DAGs if all prerequisites are met.
-    """,
+    is_paused_upon_creation=False,
+    tags=["test"],
 ) as dag:
-    payload = "{{ dag_run.conf['payload'] }}"
-    parsed_url = urlparse(payload)
-    bucket = parsed_url.netloc
-    key = parsed_url.path.lstrip("/")
-    filename = key.split("/")[-1]
-
-    dp_templates = json.loads(Variable.get("data_products_templates"))
-    rc_templates = json.loads(Variable.get("run_config_templates"))
-
-    dp = identify_dataset(dp_templates, filename)
-    dags = identify_dags(rc_templates, dp["data_product_name"])
-    run_configs = generate_run_configs(rc_templates, dp_templates, dags, dp)
-    validated_run_configs = validate_run_configs(run_configs)
-    filtered_run_configs = evaluate_dag_triggers(validated_run_configs)
-    trigger_dags = trigger_downstream_dags(filtered_run_configs)
-
-    (dp >> dags >> run_configs >> validated_run_configs >> filtered_run_configs >> trigger_dags)
+    hello_world_task = PythonOperator(
+        task_id="hello_world",
+        python_callable=hello_world,
+    )
+
+    write_to_shared_data_task = PythonOperator(
+        task_id="write_to_shared_data",
+        python_callable=write_to_shared_data,
+    )
+
+    read_from_shared_data_task = PythonOperator(
+        task_id="read_from_shared_data",
+        python_callable=read_from_shared_data,
+    )
+
+    delete_shared_data_file_task = PythonOperator(
+        task_id="delete_shared_data_file",
+        python_callable=delete_shared_data_file,
+    )
+
+    (
+        hello_world_task
+        >> write_to_shared_data_task
+        >> read_from_shared_data_task
+        >> delete_shared_data_file_task
+    )
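Because the replacement DAG is plain Python with a single linear task chain, it can be smoke-tested without a running scheduler. Below is a minimal pytest-style sketch, assuming the repository keeps its DAGs under airflow/dags and the DAG id stays eval_srl_edrgen_readiness; the test module name and the dag_folder path are illustrative and not part of this diff. It uses Airflow's DagBag to catch import errors and to confirm the hello_world >> write >> read >> delete ordering.

# test_eval_srl_edrgen_readiness.py -- illustrative test module, not included in this diff
from airflow.models import DagBag


def test_dag_imports_and_task_chain():
    # Parse DAG files from the repo's dags folder (path assumed; adjust to the checkout layout).
    dagbag = DagBag(dag_folder="airflow/dags", include_examples=False)
    assert not dagbag.import_errors, f"DAG import errors: {dagbag.import_errors}"

    dag = dagbag.get_dag("eval_srl_edrgen_readiness")
    assert dag is not None

    # All four tasks defined in the DAG file should be present.
    assert set(dag.task_ids) == {
        "hello_world",
        "write_to_shared_data",
        "read_from_shared_data",
        "delete_shared_data_file",
    }

    # The tasks should form one linear chain in the declared order.
    assert dag.get_task("hello_world").downstream_task_ids == {"write_to_shared_data"}
    assert dag.get_task("write_to_shared_data").downstream_task_ids == {"read_from_shared_data"}
    assert dag.get_task("read_from_shared_data").downstream_task_ids == {"delete_shared_data_file"}
    assert dag.get_task("delete_shared_data_file").downstream_task_ids == set()

The write/read/delete round trip against /shared-task-data still needs a deployed environment where that volume is actually mounted into the task pods; there it can be exercised end to end with `airflow dags trigger eval_srl_edrgen_readiness` (or a manual trigger from the UI) and the task logs checked for the printed file contents.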