Commit

cp
schrockn committed Aug 22, 2024
1 parent 6db9d59 commit e8ce462
Showing 6 changed files with 91 additions and 80 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
from collections.abc import Iterable

import dlt
from dagster_embedded_elt.dlt import (
DagsterDltResource,
DagsterDltTranslator,
dlt_assets,
)

from dagster import AssetExecutionContext, AssetKey
from dagster_embedded_elt.dlt.dlt_computation import RunDlt


@dlt.source
@@ -17,26 +9,23 @@ def example_resource(): ...
return example_resource


class CustomDagsterDltTranslator(DagsterDltTranslator):
def get_asset_key(self, resource: DagsterDltResource) -> AssetKey:
"""Overrides asset key to be the dlt resource name."""
return AssetKey(f"{resource.name}")

def get_deps_asset_keys(self, resource: DagsterDltResource) -> Iterable[AssetKey]:
"""Overrides upstream asset key to be a single source asset."""
return [AssetKey("common_upstream_dlt_dependency")]


@dlt_assets(
dlt_source = example_dlt_source()
dlt_pipeline = dlt.pipeline(
pipeline_name="example_pipeline_name",
dataset_name="example_dataset_name",
destination="snowflake",
progress="log",
)
source_asset_key = "common_upstream_dlt_dependency"
RunDlt(
name="example_dlt_assets",
dlt_source=example_dlt_source(),
dlt_pipeline=dlt.pipeline(
pipeline_name="example_pipeline_name",
dataset_name="example_dataset_name",
destination="snowflake",
progress="log",
),
dagster_dlt_translator=CustomDagsterDltTranslator(),
dlt_source=dlt_source,
dlt_pipeline=dlt_pipeline,
specs=[
RunDlt.default_spec(dlt_source, dlt_pipeline, dlt_resource)._replace(
key=dlt_resource.name, # overrides asset key to be resource name
deps=[source_asset_key], # overrides upstream to be single source asset
)
for dlt_resource in dlt_source.resources.values()
],
)
def dlt_example_assets(context: AssetExecutionContext, dlt: DagsterDltResource):
yield from dlt.run(context=context)
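Because the rendered diff above interleaves removed and added lines, here is a rough sketch of how this example file reads after the commit, reconstructed from the added lines (the example_dlt_source definition is unchanged and omitted; exact blank lines are assumptions):

import dlt

from dagster_embedded_elt.dlt.dlt_computation import RunDlt

dlt_source = example_dlt_source()
dlt_pipeline = dlt.pipeline(
    pipeline_name="example_pipeline_name",
    dataset_name="example_dataset_name",
    destination="snowflake",
    progress="log",
)
source_asset_key = "common_upstream_dlt_dependency"

RunDlt(
    name="example_dlt_assets",
    dlt_source=dlt_source,
    dlt_pipeline=dlt_pipeline,
    specs=[
        # Start from the default spec for each dlt resource, then override it.
        RunDlt.default_spec(dlt_source, dlt_pipeline, dlt_resource)._replace(
            key=dlt_resource.name,  # overrides asset key to be resource name
            deps=[source_asset_key],  # overrides upstream to be single source asset
        )
        for dlt_resource in dlt_source.resources.values()
    ],
)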
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Optional
from typing import Iterable, Optional

import dlt
from dagster_embedded_elt.dlt import DagsterDltResource, dlt_assets
from dagster_embedded_elt.dlt.computation import ComputationContext
from dagster_embedded_elt.dlt.dlt_computation import RunDlt

from dagster import AssetExecutionContext, StaticPartitionsDefinition
from dagster._core.definitions.asset_check_result import AssetCheckResult
from dagster._core.definitions.result import AssetResult

color_partitions = StaticPartitionsDefinition(["red", "green", "blue"])

@@ -19,16 +23,27 @@ def load_colors():
...


@dlt_assets(
dlt_source=example_dlt_source(),
dlt_source = example_dlt_source()
dlt_pipeline = dlt.pipeline(
pipeline_name="example_pipeline_name",
dataset_name="example_dataset_name",
destination="snowflake",
)


class PartitionedRunDlt(RunDlt):
def stream(self, context: ComputationContext) -> Iterable:
color = context.partition_key
yield from DagsterDltResource().run(
context=context, dlt_source=example_dlt_source(color=color)
)


PartitionedRunDlt(
name="example_dlt_assets",
dlt_pipeline=dlt.pipeline(
pipeline_name="example_pipeline_name",
dataset_name="example_dataset_name",
destination="snowflake",
dlt_source=dlt_source,
dlt_pipeline=dlt_pipeline,
specs=RunDlt.default_specs(dlt_source, dlt_pipeline).replace(
partitions_def=color_partitions
),
partitions_def=color_partitions,
)
def compute(context: AssetExecutionContext, dlt: DagsterDltResource):
color = context.partition_key
yield from dlt.run(context=context, dlt_source=example_dlt_source(color=color))
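As above, removed and added lines are interleaved; the partitioned example after the commit reads roughly as follows (a reconstruction from the added lines; imports and the example_dlt_source definition are omitted):

dlt_source = example_dlt_source()
dlt_pipeline = dlt.pipeline(
    pipeline_name="example_pipeline_name",
    dataset_name="example_dataset_name",
    destination="snowflake",
)


class PartitionedRunDlt(RunDlt):
    def stream(self, context: ComputationContext) -> Iterable:
        # Re-create the dlt source for the partition being materialized.
        color = context.partition_key
        yield from DagsterDltResource().run(
            context=context, dlt_source=example_dlt_source(color=color)
        )


PartitionedRunDlt(
    name="example_dlt_assets",
    dlt_source=dlt_source,
    dlt_pipeline=dlt_pipeline,
    # Attach the static color partitions to every generated spec.
    specs=RunDlt.default_specs(dlt_source, dlt_pipeline).replace(
        partitions_def=color_partitions
    ),
)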
Original file line number Diff line number Diff line change
@@ -58,6 +58,11 @@ class ComputationContext:
def __init__(self, context: AssetExecutionContext):
self._ae_context = context

# Temporary convenience property so existing partition_key access keeps working
@property
def partition_key(self) -> Optional[str]:
return self._ae_context.partition_key

def to_asset_execution_context(self) -> AssetExecutionContext:
return self._ae_context

Original file line number Diff line number Diff line change
@@ -37,36 +37,7 @@ def get_description(resource: DltResource) -> Optional[str]:
return None


class Dlt(Computation):
@classmethod
def default_specs(cls, dlt_source: DltSource, dlt_pipeline: Pipeline) -> Specs:
return Specs(
[
Dlt.default_spec(dlt_source, dlt_pipeline, resource)
for resource in dlt_source.selected_resources.values()
]
)

@classmethod
def default_spec(
cls, dlt_source: DltSource, dlt_pipeline: Pipeline, dlt_resource: DltResource
) -> AssetSpec:
return AssetSpec(
key=f"dlt_{dlt_resource.source_name}_{dlt_resource.name}",
deps=get_upstream_deps(dlt_resource),
description=get_description(dlt_resource),
metadata={
META_KEY_SOURCE: dlt_source,
META_KEY_PIPELINE: dlt_pipeline,
META_KEY_TRANSLATOR: None,
},
tags={
"dagster/compute_kind": dlt_pipeline.destination.destination_name,
},
)


class RunDlt(Dlt):
class RunDlt(Computation):
"""Asset Factory for using data load tool (dlt).
Args:
@@ -117,6 +88,33 @@ class RunDlt(Dlt):
)
"""

@classmethod
def default_specs(cls, dlt_source: DltSource, dlt_pipeline: Pipeline) -> Specs:
return Specs(
[
RunDlt.default_spec(dlt_source, dlt_pipeline, resource)
for resource in dlt_source.selected_resources.values()
]
)

@classmethod
def default_spec(
cls, dlt_source: DltSource, dlt_pipeline: Pipeline, dlt_resource: DltResource
) -> AssetSpec:
return AssetSpec(
key=f"dlt_{dlt_resource.source_name}_{dlt_resource.name}",
deps=get_upstream_deps(dlt_resource),
description=get_description(dlt_resource),
metadata={
META_KEY_SOURCE: dlt_source,
META_KEY_PIPELINE: dlt_pipeline,
META_KEY_TRANSLATOR: None,
},
tags={
"dagster/compute_kind": dlt_pipeline.destination.destination_name,
},
)

def __init__(
self,
*,
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
from dlt.extract.source import DltSource
from dlt.pipeline.pipeline import Pipeline

from .computation import ComputationContext
from .constants import META_KEY_PIPELINE, META_KEY_SOURCE, META_KEY_TRANSLATOR
from .translator import DagsterDltTranslator

@@ -114,7 +115,7 @@ def extract_resource_metadata(
@public
def run(
self,
context: Union[OpExecutionContext, AssetExecutionContext],
context: Union[OpExecutionContext, AssetExecutionContext, ComputationContext],
dlt_source: Optional[DltSource] = None,
dlt_pipeline: Optional[Pipeline] = None,
dagster_dlt_translator: Optional[DagsterDltTranslator] = None,
@@ -133,6 +134,11 @@ def run(
Iterator[Union[MaterializeResult, AssetMaterialization]]: An iterator of MaterializeResult or AssetMaterialization
"""
context = (
context.to_asset_execution_context()
if isinstance(context, ComputationContext)
else context
)
# This resource can be used in both `asset` and `op` definitions. In the context of an asset
# execution, we retrieve the dlt source, pipeline, and translator from the asset metadata
# as a fallback mechanism. We give preference to explicit parameters to make it easy to
@@ -148,7 +154,9 @@ def run(
dlt_pipeline or first_asset_metadata.get(META_KEY_PIPELINE), Pipeline
)
dagster_dlt_translator = check.inst(
dagster_dlt_translator or first_asset_metadata.get(META_KEY_TRANSLATOR) or DagsterDltTranslator(),
dagster_dlt_translator
or first_asset_metadata.get(META_KEY_TRANSLATOR)
or DagsterDltTranslator(),
DagsterDltTranslator,
)

Original file line number Diff line number Diff line change
@@ -33,7 +33,6 @@ def test_example_pipeline_asset_keys(dlt_pipeline: Pipeline) -> None:
} == RunDlt(dlt_source=pipeline(), dlt_pipeline=dlt_pipeline).assets_def.keys



def test_example_pipeline_deps(dlt_pipeline: Pipeline) -> None:
# Since repo_issues is a transform of the repo data, its upstream
# asset key should be the repo data asset key as well.
@@ -43,7 +42,6 @@ def test_example_pipeline_deps(dlt_pipeline: Pipeline) -> None:
} == RunDlt(dlt_source=pipeline(), dlt_pipeline=dlt_pipeline).assets_def.asset_deps



def test_example_pipeline_descs(dlt_pipeline: Pipeline) -> None:
@dlt_assets(dlt_source=pipeline(), dlt_pipeline=dlt_pipeline)
def example_pipeline_assets(
@@ -137,7 +135,6 @@ def test_get_materialize_policy(dlt_pipeline: Pipeline):
assert "0 1 * * *" in str(item)



def test_example_pipeline_has_required_metadata_keys(dlt_pipeline: Pipeline):
required_metadata_keys = {
"destination_type",
@@ -321,12 +318,11 @@ def stream(self, context: ComputationContext) -> Iterable:
dagster_dlt_resource = DagsterDltResource()
yield from dagster_dlt_resource.run(context=asset_context, dlt_source=pipeline(month))


async def run_partition(year: str):
return PartitionedDltRun(
specs=RunDlt.default_specs(
dlt_source=pipeline(), dlt_pipeline=dlt_pipeline
).replace(partitions_def=MonthlyPartitionsDefinition(start_date="2022-08-09")),
specs=RunDlt.default_specs(dlt_source=pipeline(), dlt_pipeline=dlt_pipeline).replace(
partitions_def=MonthlyPartitionsDefinition(start_date="2022-08-09")
),
dlt_source=pipeline(),
dlt_pipeline=dlt_pipeline,
).test(partitions=year)
