EDRgen Initiator #185

Draft
wants to merge 1 commit into base: develop
239 changes: 54 additions & 185 deletions airflow/dags/eval_srl_edrgen_readiness.py
@@ -1,209 +1,78 @@
import hashlib
import json
import logging
import re
import os
import time
from datetime import datetime
from typing import Dict, List
from urllib.parse import urlparse

import jsonschema
from airflow.decorators import task
from airflow.models import Variable
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from jinja2 import Template
from referencing import Registry, Resource
from airflow.operators.python import PythonOperator

from airflow import DAG

# get the airflow.task logger
task_logger = logging.getLogger("airflow.task")

default_args = {
"owner": "unity-sps",
"start_date": datetime.utcfromtimestamp(0),
}


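# render_template serializes a JSON-like template to a string, strips the surrounding
# quotes from any "{{ ... | tojson }}" placeholder so its rendered JSON value is not left
# wrapped in string quotes, renders the result with Jinja2, and parses it back into a
# Python object. Illustrative example: {"id": "{{ DATA_PRODUCT_ID }}", "files": "{{ FILES | tojson }}"}
# rendered with {"DATA_PRODUCT_ID": "X123", "FILES": [{"key": "a"}]} yields
# {"id": "X123", "files": [{"key": "a"}]} (FILES is a made-up key used only for illustration).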
def render_template(template, values=dict()):
template_str = json.dumps(template)
pattern = r'"{{\s*[^}]+?\|\s*tojson\s*}}"'
replacement_func = lambda m: m.group(0)[1:-1]
template_str = re.sub(pattern, replacement_func, template_str)
template_obj = Template(template_str)
rendered_template_str = template_obj.render(values)
rendered_template = json.loads(rendered_template_str)
return rendered_template


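# identify_dataset matches the incoming filename against each data product template's
# regex_pattern and, on the first match, builds the data product ID from the pattern's
# named capture groups using data_product_id_format.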
@task
def identify_dataset(dp_templates: Dict, dp_filename: str) -> Dict:
matched_dp_obj = None
matched_dp_name = None
matched_dp_template = None
for dp_name, dp_template in dp_templates.items():
matched_dp_obj = re.fullmatch(dp_template["regex_pattern"], dp_filename)

if not matched_dp_obj:
continue

matched_dp_name = dp_name
matched_dp_template = dp_template
break

if not matched_dp_obj:
raise ValueError(f"No data product template matched filename: {dp_filename}")

dp_id = matched_dp_template["data_product_id_format"].format(**matched_dp_obj.groupdict())
return {"data_product_name": matched_dp_name, "data_product_id": dp_id}


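# identify_dags returns the IDs of all DAGs whose run config template lists the matched
# data product among its required_input_data_products.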
@task
def identify_dags(run_config_templates: Dict, data_product_name: str) -> List[str]:
return [
dag_id
for dag_id, rc in run_config_templates.items()
if data_product_name in rc["required_input_data_products"]
]


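# generate_run_configs renders one run config per candidate DAG, filling in the matched
# data product (ISL bucket), its required static data (config bucket), and its expected
# output data products (OSL bucket) from the corresponding templates.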
@task
def generate_run_configs(
run_config_templates: Dict, data_products_templates: Dict, dags: List[str], dp: Dict
) -> Dict:
osl_bucket = Variable.get("osl_bucket")
isl_bucket = Variable.get("isl_bucket")
config_bucket = Variable.get("config_bucket")

run_configs = {}
matched_dp_name = dp["data_product_name"]
matched_dp_id = dp["data_product_id"]
dp_rendered_dict = render_template(
data_products_templates[matched_dp_name],
{"DATA_PRODUCT_ID": matched_dp_id, "ISL_BUCKET": isl_bucket},
)

for dag_id in dags:
rc = run_config_templates[dag_id]
rc_values = {matched_dp_name.upper(): dp_rendered_dict}

for req_static_dp in rc["required_static_data"]:
static_dp = data_products_templates[req_static_dp]
static_dp_rendered_dict = render_template(static_dp, {"CONFIG_BUCKET": config_bucket})
rc_values[req_static_dp.upper()] = static_dp_rendered_dict
def hello_world():
print("Hello World")
time.sleep(30)

for exp_output_dp in rc["expected_output_data_products"]:
exp_dp = data_products_templates[exp_output_dp]
exp_dp_rendered_dict = render_template(exp_dp, {"OSL_BUCKET": osl_bucket})
rc_values[exp_output_dp.upper()] = exp_dp_rendered_dict

run_configs[dag_id] = render_template(rc, rc_values)
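# write_to_shared_data and read_from_shared_data (below) are simple test callables that
# write and read back a small file on the shared task volume, presumably a volume mounted
# at /shared-task-data on the workers.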
def write_to_shared_data():
file_path = "/shared-task-data/test_file.txt" # Adjust the path as necessary
with open(file_path, "w") as f:
f.write("This is a test file written at " + str(datetime.now()) + "\n")
print(f"Successfully written to {file_path}")

return run_configs

def read_from_shared_data():
file_path = "/shared-task-data/test_file.txt" # Adjust the path as necessary
try:
with open(file_path, "r") as f:
contents = f.read()
print(f"File contents:\n{contents}")
except FileNotFoundError:
print("File not found. Make sure the file path is correct.")

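# validate_run_configs validates every rendered run config against the run_config_schema
# Airflow Variable with jsonschema; both schemas are placed in a referencing Registry so
# cross-schema $ref references can resolve.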
@task
def validate_run_configs(run_configs: Dict) -> Dict:
dp_schema = json.loads(Variable.get("data_product_schema"))
rc_schema = json.loads(Variable.get("run_config_schema"))

registry = Registry().with_resources(
[
("dp_schema.json", Resource.from_contents(dp_schema)),
("rc_schema.json", Resource.from_contents(rc_schema)),
]
)

rc_schema_validator = jsonschema.Draft202012Validator(rc_schema, registry=registry)

for _, rc in run_configs.items():
rc_schema_validator.validate(rc)

return run_configs


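# evaluate_dag_triggers uses S3Hook to check that every required input and static data
# file listed in each run config exists, and keeps only the DAGs whose prerequisites are
# all present.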
@task
def evaluate_dag_triggers(run_configs: Dict) -> Dict:
s3_hook = S3Hook()

def file_exists(bucket: str, key: str) -> bool:
return s3_hook.check_for_key(key, bucket_name=bucket)

def get_required_files(run_config: Dict) -> List[Dict]:
required_files = []
for dp in run_config["required_input_data_products"].values():
required_files.extend(dp["files"])
for static_data in run_config["required_static_data"].values():
required_files.extend(static_data["files"])
return required_files

filtered_run_configs = {}
for dag_id, rc in run_configs.items():
required_files = get_required_files(rc)
missing_files = [
f'Bucket: {rf["bucket"]}, Key: {rf["key"]}'
for rf in required_files
if not file_exists(rf["bucket"], rf["key"])
]

if missing_files:
task_logger.warning(
f"Not all required files exist for DAG: {dag_id}. Missing files:\n" + "\n".join(missing_files)
)
else:
task_logger.info(f"All required files exist for DAG: {dag_id}")
filtered_run_configs[dag_id] = rc

return filtered_run_configs


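# trigger_downstream_dags fires a TriggerDagRunOperator per eligible DAG, passing the run
# config via conf and deriving a deterministic run ID by hashing the concatenated input
# file keys, so repeated deliveries of the same files map onto the same DAG run.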
@task
def trigger_downstream_dags(filtered_run_configs: Dict):
for dag_id, rc in filtered_run_configs.items():
required_files = [file["key"] for file in rc["required_input_data_products"]["STACAM_RawDP"]["files"]]
filename_string = "".join(required_files)
dag_run_id = hashlib.sha256(filename_string.encode()).hexdigest()

trigger_dag = TriggerDagRunOperator(
task_id=f"trigger_dag_{dag_id}",
trigger_dag_id=dag_id,
conf={"run_config": rc},
execution_date="{{ execution_date }}",
reset_dag_run=True,
wait_for_completion=True,
poke_interval=60,
allowed_states=["success"],
failed_states=["failed", "upstream_failed"],
dag_run_id=dag_run_id,
)
trigger_dag.execute(context=None)
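# delete_shared_data_file removes the test file again so repeated runs start from a clean
# shared volume.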
def delete_shared_data_file():
file_path = "/shared-task-data/test_file.txt" # Adjust the path as necessary
try:
os.remove(file_path)
print(f"Successfully deleted {file_path}")
except FileNotFoundError:
print("File not found. Make sure the file path is correct.")


with DAG(
dag_id="eval_srl_edrgen_readiness",
default_args=default_args,
schedule=None,
catchup=False,
tags=["srl", "edr"],
doc_md="""
This DAG evaluates SRL EDR generation readiness by checking for the existence of required files
and triggering downstream DAGs if all prerequisites are met.
""",
is_paused_upon_creation=False,
tags=["test"],
) as dag:
payload = "{{ dag_run.conf['payload'] }}"
parsed_url = urlparse(payload)
bucket = parsed_url.netloc
key = parsed_url.path.lstrip("/")
filename = key.split("/")[-1]

dp_templates = json.loads(Variable.get("data_products_templates"))
rc_templates = json.loads(Variable.get("run_config_templates"))

dp = identify_dataset(dp_templates, filename)
dags = identify_dags(rc_templates, dp["data_product_name"])
run_configs = generate_run_configs(rc_templates, dp_templates, dags, dp)
validated_run_configs = validate_run_configs(run_configs)
filtered_run_configs = evaluate_dag_triggers(validated_run_configs)
trigger_dags = trigger_downstream_dags(filtered_run_configs)

(dp >> dags >> run_configs >> validated_run_configs >> filtered_run_configs >> trigger_dags)
hello_world_task = PythonOperator(
task_id="hello_world",
python_callable=hello_world,
)

write_to_shared_data_task = PythonOperator(
task_id="write_to_shared_data",
python_callable=write_to_shared_data,
)

read_from_shared_data_task = PythonOperator(
task_id="read_from_shared_data",
python_callable=read_from_shared_data,
)

delete_shared_data_file_task = PythonOperator(
task_id="delete_shared_data_file",
python_callable=delete_shared_data_file,
)

(
hello_world_task
>> write_to_shared_data_task
>> read_from_shared_data_task
>> delete_shared_data_file_task
)
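# For a quick local check of the new task chain one could run, for example:
#   airflow dags test eval_srl_edrgen_readiness 2024-01-01
# (illustrative date; the shared-data tasks assume /shared-task-data is mounted on the
# worker that runs them).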