Skip to content

Commit

Permalink
[dagster-airlift] MaterializeAssetsOperator (#25092)
Browse files Browse the repository at this point in the history
## Summary & Motivation
Add an operator for materializing a provided set of assets.

## How I Tested These Changes
Added a new test suite for this behavior.

## Changelog
NOCHANGELOG
  • Loading branch information
dpeng817 authored Oct 10, 2024
1 parent 11fd93f commit 6a49a39
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Any, Iterable, Mapping, Sequence, Union

from airflow.utils.context import Context

from dagster_airlift.in_airflow.base_asset_operator import BaseDagsterAssetsOperator


class BaseMaterializeAssetsOperator(BaseDagsterAssetsOperator):
"""An operator base class that proxies execution to a user-provided list of Dagster assets.
Will throw an error at runtime if not all assets can be found on the corresponding Dagster instance.
Args:
asset_key_paths (Sequence[Union[str, Sequence[str]]]): A sequence of asset key paths to materialize.
Each path in the sequence can be a string, which is treated as an asset key path with a single
component, or a sequence of strings representing a path with multiple components. For more,
see the docs on asset keys: https://docs.dagster.io/concepts/assets/software-defined-assets#multi-component-asset-keys
"""

def __init__(self, asset_key_paths: Sequence[Union[str, Sequence[str]]], *args, **kwargs):
self.asset_key_paths = [
(path,) if isinstance(path, str) else tuple(path) for path in asset_key_paths
]
super().__init__(*args, **kwargs)

def filter_asset_nodes(
self, context: Context, asset_nodes: Sequence[Mapping[str, Any]]
) -> Iterable[Mapping[str, Any]]:
hashable_path_to_node = {tuple(node["assetKey"]["path"]): node for node in asset_nodes}
if not all(path in hashable_path_to_node for path in self.asset_key_paths):
raise ValueError(
f"Could not find all asset key paths {self.asset_key_paths} in the asset nodes. Found: {list(hashable_path_to_node.keys())}"
)
yield from [hashable_path_to_node[path] for path in self.asset_key_paths]
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging
import os
from datetime import datetime

import requests
from airflow import DAG
from airflow.utils.context import Context
from dagster_airlift.in_airflow.materialize_assets_operator import BaseMaterializeAssetsOperator

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
requests_log = logging.getLogger("requests.packages.urllib3")
requests_log.setLevel(logging.INFO)
requests_log.propagate = True


default_args = {
"owner": "airflow",
"depends_on_past": False,
"retries": 1,
}

dag = DAG(
"the_dag",
default_args=default_args,
schedule_interval=None,
is_paused_upon_creation=False,
start_date=datetime(2023, 1, 1),
)


class BlankSessionAssetsOperator(BaseMaterializeAssetsOperator):
"""An assets operator which opens a blank session and expects the dagster URL to be set in the environment.
The dagster url is expected to be set in the environment as DAGSTER_URL.
"""

def get_dagster_session(self, context: Context) -> requests.Session:
return requests.Session()

def get_dagster_url(self, context: Context) -> str:
return os.environ["DAGSTER_URL"]


the_task = BlankSessionAssetsOperator(
# Test both string syntax and list of strings syntax.
task_id="some_task",
dag=dag,
asset_key_paths=["some_asset", ["other_asset"], ["nested", "asset"]],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dagster import Definitions, asset


@asset
def some_asset():
return "asset_value"


@asset
def other_asset():
return "other_asset_value"


@asset(key=["nested", "asset"])
def nested_asset():
return "nested_asset_value"


defs = Definitions(assets=[some_asset, other_asset, nested_asset])
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import time
from pathlib import Path

import pytest
import requests
from dagster import AssetKey, DagsterInstance, DagsterRunStatus
from dagster._core.test_utils import environ
from dagster._time import get_current_timestamp
from dagster_airlift.constants import DAG_RUN_ID_TAG_KEY


def _test_project_dir() -> Path:
return Path(__file__).parent / "materialize_assets_operator_test_project"


@pytest.fixture(name="dags_dir")
def dags_dir() -> Path:
return _test_project_dir() / "dags"


@pytest.fixture(name="dagster_defs_path")
def dagster_defs_path_fixture() -> str:
return str(_test_project_dir() / "dagster_defs.py")


def test_dagster_operator(airflow_instance: None, dagster_dev: None, dagster_home: str) -> None:
"""Tests that dagster operator can correctly map airflow tasks to dagster tasks, and kick off executions."""
response = requests.post(
"http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin"), json={}
)
assert response.status_code == 200, response.json()
# Wait until the run enters a terminal state
terminal_status = None
start_time = get_current_timestamp()
dag_run = None
while get_current_timestamp() - start_time < 30:
response = requests.get(
"http://localhost:8080/api/v1/dags/the_dag/dagRuns", auth=("admin", "admin")
)
assert response.status_code == 200, response.json()
dag_runs = response.json()["dag_runs"]
if dag_runs[0]["state"] in ["success", "failed"]:
terminal_status = dag_runs[0]["state"]
dag_run = dag_runs[0]
break
time.sleep(1)
assert terminal_status == "success", (
"Never reached terminal status"
if terminal_status is None
else f"terminal status was {terminal_status}"
)
with environ({"DAGSTER_HOME": dagster_home}):
instance = DagsterInstance.get()
runs = instance.get_runs()
# The graphql endpoint kicks off a run for each of the assets provided in the asset_key_paths. There are two assets,
# but they exist within the same job, so there should only be 1 run.
assert len(runs) == 1
the_run = [ # noqa
run
for run in runs
if set(list(run.asset_selection)) # type: ignore
== {AssetKey(["some_asset"]), AssetKey(["other_asset"]), AssetKey(["nested", "asset"])}
][0]
assert the_run.status == DagsterRunStatus.SUCCESS

assert isinstance(dag_run, dict)
assert "dag_run_id" in dag_run
assert the_run.tags[DAG_RUN_ID_TAG_KEY] == dag_run["dag_run_id"]

0 comments on commit 6a49a39

Please sign in to comment.