
updates to support 1000-node SPS scaling test using the IDPS SRL pipeline #414

Open

wants to merge 10 commits into base: develop
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
@@ -22,29 +22,29 @@ repos:
- id: check-toml # Checks toml files for parsable syntax.

- repo: https://github.com/igorshubovych/markdownlint-cli
rev: "v0.39.0"
rev: "v0.44.0"
hooks:
- id: markdownlint
args: ["--config", ".markdownlintrc", "--ignore", "CHANGELOG.md"]

- repo: https://github.com/PyCQA/isort
rev: 5.13.2
rev: 6.0.1
hooks:
- id: isort
args: ["--profile=black"]

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
rev: 25.1.0
hooks:
- id: black

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
rev: v0.11.9
hooks:
- id: ruff

- repo: https://github.com/PyCQA/bandit
rev: "1.7.8" # you must change this to newest version
rev: "1.8.3" # you must change this to newest version
hooks:
- id: bandit
args:
@@ -56,7 +56,7 @@ repos:
additional_dependencies: [".[toml]"]

- repo: https://github.com/hadolint/hadolint
rev: v2.13.0-beta
rev: v2.13.1-beta
hooks:
- id: hadolint # requires hadolint is installed (brew install hadolint)
args:
@@ -65,7 +65,7 @@ repos:
- --verbose

- repo: https://github.com/antonbabenko/pre-commit-terraform
rev: v1.89.1
rev: v1.99.0
hooks:
- id: terraform_validate # Validates all Terraform configuration files.
args:
138 changes: 138 additions & 0 deletions airflow/dags/db_cleanup.py
@@ -0,0 +1,138 @@
"""A DB Cleanup DAG maintained by Astronomer."""

import logging
from datetime import UTC, datetime, timedelta
from typing import List, Optional

from airflow.cli.commands.db_command import all_tables
from airflow.decorators import dag, task
from airflow.models.param import Param
from airflow.operators.bash import BashOperator
from airflow.utils.db import reflect_tables
from airflow.utils.db_cleanup import _effective_table_names
from airflow.utils.session import NEW_SESSION, provide_session
from sqlalchemy import func
from sqlalchemy.orm import Session

default_args = {"owner": "unity-sps", "start_date": datetime.utcfromtimestamp(0)}


@dag(
dag_id="astronomer_db_cleanup_dag",
default_args=default_args,
schedule_interval=None,
start_date=datetime(2024, 1, 1),
catchup=False,
is_paused_upon_creation=False,
description=__doc__,
doc_md=__doc__,
render_template_as_native_obj=True,
max_active_tasks=1,
tags=["cleanup"],
params={
"clean_before_timestamp": Param(
default=datetime.now(tz=UTC) - timedelta(days=90),
type="string",
format="date-time",
description="Delete records older than this timestamp. Default is 90 days ago.",
),
"tables": Param(
default=[],
type=["null", "array"],
examples=all_tables,
description="List of tables to clean. Default is all tables.",
),
"dry_run": Param(
default=False,
type="boolean",
description="Print the SQL queries that would be run, but do not execute them. Default is False.",
),
"batch_size_days": Param(
default=7,
type="integer",
description="Number of days in each batch for the cleanup. Default is 7 days.",
),
},
)
def astronomer_db_cleanup_dag():

@provide_session
def get_oldest_timestamp(
tables,
session: Session = NEW_SESSION,
) -> Optional[str]:
oldest_timestamp_list = []
existing_tables = reflect_tables(tables=None, session=session).tables
_, effective_config_dict = _effective_table_names(table_names=tables)
for table_name, table_config in effective_config_dict.items():
if table_name in existing_tables:
orm_model = table_config.orm_model
recency_column = table_config.recency_column
oldest_execution_date = (
session.query(func.min(recency_column)).select_from(orm_model).scalar()
)
if oldest_execution_date:
oldest_timestamp_list.append(oldest_execution_date.isoformat())
else:
logging.info(f"No data found for {table_name}, skipping...")
else:
logging.warning(f"Table {table_name} not found. Skipping.")

if oldest_timestamp_list:
return min(oldest_timestamp_list)

@task
def get_chunked_timestamps(**context) -> List:
batches = []
start_chunk_time = get_oldest_timestamp(context["params"]["tables"])
if start_chunk_time:
start_ts = datetime.fromisoformat(start_chunk_time)
end_ts = datetime.fromisoformat(context["params"]["clean_before_timestamp"])
batch_size_days = context["params"]["batch_size_days"]

while start_ts < end_ts:
batch_end = min(start_ts + timedelta(days=batch_size_days), end_ts)
batches.append({"BATCH_TS": batch_end.isoformat()})
start_ts += timedelta(days=batch_size_days)
return batches

# The "clean_archive_tables" task drops archived tables created by the previous "clean_db" task, in case that task fails due to an error or timeout.
db_archive_cleanup = BashOperator(
task_id="clean_archive_tables",
bash_command="""\
airflow db drop-archived \
{% if params.tables -%}
--tables {{ params.tables|join(',') }} \
{% endif -%}
--yes \
""",
do_xcom_push=False,
trigger_rule="all_done",
)

chunked_timestamps = get_chunked_timestamps()

(
BashOperator.partial(
task_id="db_cleanup",
bash_command="""\
airflow db clean \
--clean-before-timestamp $BATCH_TS \
{% if params.dry_run -%}
--dry-run \
{% endif -%}
--skip-archive \
{% if params.tables -%}
--tables '{{ params.tables|join(',') }}' \
{% endif -%}
--verbose \
--yes \
""",
append_env=True,
do_xcom_push=False,
).expand(env=chunked_timestamps)
>> db_archive_cleanup
)


astronomer_db_cleanup_dag()
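
For reviewers following the dynamic task mapping above, a minimal standalone sketch of the batching logic in get_chunked_timestamps (the dates here are illustrative, not taken from the PR): each dict in the returned list becomes the env of one mapped db_cleanup task, and its BATCH_TS value is the --clean-before-timestamp upper bound for that batch.

from datetime import datetime, timedelta

# Illustrative 21-day window split into 7-day batches, mirroring get_chunked_timestamps.
start_ts = datetime(2024, 1, 1)   # oldest record timestamp (assumed for this sketch)
end_ts = datetime(2024, 1, 22)    # clean_before_timestamp (assumed for this sketch)
batch_size_days = 7

batches = []
while start_ts < end_ts:
    batch_end = min(start_ts + timedelta(days=batch_size_days), end_ts)
    batches.append({"BATCH_TS": batch_end.isoformat()})
    start_ts += timedelta(days=batch_size_days)

print(batches)
# [{'BATCH_TS': '2024-01-08T00:00:00'},
#  {'BATCH_TS': '2024-01-15T00:00:00'},
#  {'BATCH_TS': '2024-01-22T00:00:00'}]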
50 changes: 50 additions & 0 deletions airflow/dags/delete_dag.py
@@ -0,0 +1,50 @@
"""A DB Cleanup DAG maintained by Astronomer."""

from datetime import UTC, datetime, timedelta

from airflow.decorators import dag
from airflow.models.param import Param
from airflow.operators.bash import BashOperator

default_args = {"owner": "unity-sps", "start_date": datetime.utcfromtimestamp(0)}


@dag(
dag_id="delete_dag",
default_args=default_args,
schedule_interval=None,
catchup=False,
is_paused_upon_creation=False,
description=__doc__,
doc_md=__doc__,
render_template_as_native_obj=True,
max_active_tasks=1,
tags=["cleanup"],
params={
"clean_before_timestamp": Param(
default=datetime.now(tz=UTC) - timedelta(days=90),
type="string",
format="date-time",
description="Delete records older than this timestamp. Default is 90 days ago.",
),
"dag_id": Param(type="string"),
},
)
def delete_dag():

delete_dag_task = BashOperator(
task_id="delete_dag_task",
bash_command="airflow dags delete {{ params.dag_id }} --yes",
do_xcom_push=False,
)

db_clean_task = BashOperator(
task_id="db_clean_task",
bash_command="airflow db clean --yes",
do_xcom_push=False,
)

delete_dag_task >> db_clean_task


delete_dag()
15 changes: 9 additions & 6 deletions airflow/dags/edrgen.py
@@ -40,7 +40,7 @@
},
) as dag:

@task
@task(weight_rule="absolute", priority_weight=103)
def prep(params: dict):
context = get_current_context()
dag_run_id = context["dag_run"].run_id
@@ -91,10 +91,13 @@ def prep(params: dict):
prep_task = prep()

edrgen_task = KubernetesPodOperator(
weight_rule="absolute",
priority_weight=104,
task_id="edrgen",
name="edrgen",
namespace="sps",
image="pymonger/srl-idps-edrgen:develop",
image="429178552491.dkr.ecr.us-west-2.amazonaws.com/srl-idps/edrgen:develop",
# image="pymonger/srl-idps-edrgen:develop",
# cmds=[
# "sh",
# "-c",
@@ -109,7 +112,7 @@ def prep(params: dict):
container_logs=True,
service_account_name="airflow-worker",
container_security_context={"privileged": True},
retries=0,
retries=3,
volume_mounts=[
k8s.V1VolumeMount(
name="workers-volume", mount_path="/stage-in", sub_path="{{ dag_run.run_id }}/stage-in"
@@ -126,16 +129,16 @@ def prep(params: dict):
],
node_selector={
"karpenter.sh/nodepool": unity_sps_utils.NODE_POOL_HIGH_WORKLOAD,
"node.kubernetes.io/instance-type": "r7i.2xlarge",
"node.kubernetes.io/instance-type": "c6i.large",
},
labels={"pod": unity_sps_utils.POD_LABEL},
annotations={"karpenter.sh/do-not-disrupt": "true"},
affinity=unity_sps_utils.get_affinity(
capacity_type=["on-demand"], anti_affinity_label=unity_sps_utils.POD_LABEL
capacity_type=["spot"], anti_affinity_label=unity_sps_utils.POD_LABEL
),
)

@task
@task(weight_rule="absolute", priority_weight=105)
def post(params: dict):
context = get_current_context()
dag_run_id = context["dag_run"].run_id
4 changes: 3 additions & 1 deletion airflow/dags/eval_srl_edrgen.py
@@ -32,7 +32,7 @@
},
) as dag:

@task
@task(weight_rule="absolute", priority_weight=100)
def evaluate_edrgen(params: dict):
s3_hook = S3Hook()

@@ -92,6 +92,8 @@ def edrgen_evaluation_successful():
edrgen_evaluation_successful_task = edrgen_evaluation_successful()

trigger_edrgen_task = TriggerDagRunOperator(
weight_rule="absolute",
priority_weight=102,
task_id="trigger_edrgen",
trigger_dag_id="edrgen",
# uncomment the next line if we want to dedup dagRuns for a particular ID
6 changes: 4 additions & 2 deletions airflow/dags/eval_srl_rdrgen.py
@@ -32,7 +32,7 @@
},
) as dag:

@task
@task(weight_rule="absolute", priority_weight=106)
def evaluate_rdrgen(params: dict):
s3_hook = S3Hook()

@@ -69,7 +69,7 @@ def evaluate_rdrgen(params: dict):

evaluate_rdrgen_task = evaluate_rdrgen()

@task.short_circuit()
@task.short_circuit(weight_rule="absolute", priority_weight=107)
def rdrgen_evaluation_successful():
context = get_current_context()
print(f"{context['ti'].xcom_pull(task_ids='evaluate_rdrgen')}")
@@ -81,6 +81,8 @@ def rdrgen_evaluation_successful():
rdrgen_evaluation_successful_task = rdrgen_evaluation_successful()

trigger_rdrgen_task = TriggerDagRunOperator(
weight_rule="absolute",
priority_weight=108,
task_id="trigger_rdrgen",
trigger_dag_id="rdrgen",
# uncomment the next line if we want to dedup dagRuns for a particular ID
6 changes: 4 additions & 2 deletions airflow/dags/eval_srl_vic2png.py
@@ -32,7 +32,7 @@
},
) as dag:

@task
@task(weight_rule="absolute", priority_weight=112)
def evaluate_vic2png(params: dict):
s3_hook = S3Hook()

@@ -64,7 +64,7 @@ def evaluate_vic2png(params: dict):

evaluate_vic2png_task = evaluate_vic2png()

@task.short_circuit()
@task.short_circuit(weight_rule="absolute", priority_weight=113)
def vic2png_evaluation_successful():
context = get_current_context()
print(f"{context['ti'].xcom_pull(task_ids='evaluate_vic2png')}")
@@ -73,6 +73,8 @@ def vic2png_evaluation_successful():
vic2png_evaluation_successful_task = vic2png_evaluation_successful()

trigger_vic2png_task = TriggerDagRunOperator(
weight_rule="absolute",
priority_weight=114,
task_id="trigger_vic2png",
trigger_dag_id="vic2png",
# uncomment the next line if we want to dedup dagRuns for a particular ID
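
Taken together, the scheduling changes in these DAG files define an ascending ladder of absolute priority weights across the SRL pipeline. A summary sketch of the weights visible in the diffs above (with weight_rule="absolute", Airflow uses each task's own priority_weight directly instead of summing downstream weights; the apparent intent is that granules already in flight drain through the later stages before newly triggered runs are scheduled):

# Task priority weights introduced in this PR (higher value = scheduled first).
# Keys are (DAG file, task) pairs; values copied from the diffs above.
PRIORITY_WEIGHTS = {
    ("eval_srl_edrgen.py", "evaluate_edrgen"): 100,
    ("eval_srl_edrgen.py", "trigger_edrgen"): 102,
    ("edrgen.py", "prep"): 103,
    ("edrgen.py", "edrgen"): 104,
    ("edrgen.py", "post"): 105,
    ("eval_srl_rdrgen.py", "evaluate_rdrgen"): 106,
    ("eval_srl_rdrgen.py", "rdrgen_evaluation_successful"): 107,
    ("eval_srl_rdrgen.py", "trigger_rdrgen"): 108,
    ("eval_srl_vic2png.py", "evaluate_vic2png"): 112,
    ("eval_srl_vic2png.py", "vic2png_evaluation_successful"): 113,
    ("eval_srl_vic2png.py", "trigger_vic2png"): 114,
}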