Skip to content

Commit

Permalink
Memoize job snapshot ID when called repeatedly on the same object (#26057)
Browse files Browse the repository at this point in the history

Test Plan: BK; verify that a daemon keeps this value across multiple
runs within a single backfill (and, better yet, across consecutive
backfill ticks when the code has not changed)

NOCHANGELOG

> Insert changelog entry or delete this section.

## Summary & Motivation

## How I Tested These Changes

## Changelog

> Insert changelog entry or delete this section.
  • Loading branch information
gibsondan committed Nov 21, 2024
1 parent b25a75b commit 7065cd1
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 96 deletions.
6 changes: 3 additions & 3 deletions python_modules/dagster/dagster/_core/instance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,13 +1290,13 @@ def _ensure_persisted_job_snapshot(
job_snapshot: "JobSnap",
parent_job_snapshot: "Optional[JobSnap]",
) -> str:
from dagster._core.snap import JobSnap, create_job_snapshot_id
from dagster._core.snap import JobSnap

check.inst_param(job_snapshot, "job_snapshot", JobSnap)
check.opt_inst_param(parent_job_snapshot, "parent_job_snapshot", JobSnap)

if job_snapshot.lineage_snapshot:
parent_snapshot_id = create_job_snapshot_id(check.not_none(parent_job_snapshot))
parent_snapshot_id = check.not_none(parent_job_snapshot).snapshot_id

if job_snapshot.lineage_snapshot.parent_snapshot_id != parent_snapshot_id:
warnings.warn(
Expand All @@ -1308,7 +1308,7 @@ def _ensure_persisted_job_snapshot(
check.not_none(parent_job_snapshot), parent_snapshot_id
)

job_snapshot_id = create_job_snapshot_id(job_snapshot)
job_snapshot_id = job_snapshot.snapshot_id
if not self._run_storage.has_job_snapshot(job_snapshot_id):
returned_job_snapshot_id = self._run_storage.add_job_snapshot(job_snapshot)
check.invariant(job_snapshot_id == returned_job_snapshot_id)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from threading import Lock
from typing import Any, Mapping, Optional, Sequence, Union

import dagster._check as check
from dagster._config import ConfigSchemaSnapshot
from dagster._core.snap import DependencyStructureIndex, JobSnap, create_job_snapshot_id
from dagster._core.snap import DependencyStructureIndex, JobSnap
from dagster._core.snap.dagster_types import DagsterTypeSnap
from dagster._core.snap.mode import ModeDefSnap
from dagster._core.snap.node import GraphDefSnap, OpDefSnap
Expand Down Expand Up @@ -47,9 +46,6 @@ def __init__(
for comp_snap in job_snapshot.node_defs_snapshot.graph_def_snaps
}

self._memo_lock = Lock()
self._job_snapshot_id = None

@property
def name(self) -> str:
return self.job_snapshot.name
Expand All @@ -68,10 +64,7 @@ def metadata(self):

@property
def job_snapshot_id(self) -> str:
with self._memo_lock:
if not self._job_snapshot_id:
self._job_snapshot_id = create_job_snapshot_id(self.job_snapshot)
return self._job_snapshot_id
return self.job_snapshot.snapshot_id

def has_dagster_type_name(self, type_name: str) -> bool:
return type_name in self._dagster_type_snaps_by_name_index
Expand Down
5 changes: 1 addition & 4 deletions python_modules/dagster/dagster/_core/snap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@
create_execution_plan_snapshot_id as create_execution_plan_snapshot_id,
snapshot_from_execution_plan as snapshot_from_execution_plan,
)
from dagster._core.snap.job_snapshot import (
JobSnap as JobSnap,
create_job_snapshot_id as create_job_snapshot_id,
)
from dagster._core.snap.job_snapshot import JobSnap as JobSnap
from dagster._core.snap.mode import (
LoggerDefSnap as LoggerDefSnap,
ModeDefSnap as ModeDefSnap,
Expand Down
18 changes: 9 additions & 9 deletions python_modules/dagster/dagster/_core/snap/job_snapshot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import cached_property
from typing import AbstractSet, Any, Dict, Mapping, Optional, Sequence, Union, cast

from dagster import _check as check
Expand Down Expand Up @@ -52,9 +53,8 @@
from dagster._serdes.serdes import RecordSerializer


def create_job_snapshot_id(snapshot: "JobSnap") -> str:
check.inst_param(snapshot, "snapshot", JobSnap)
return create_snapshot_id(snapshot)
def _create_job_snapshot_id(job_snap: "JobSnap") -> str:
    """Compute the snapshot ID for ``job_snap``.

    Thin wrapper that delegates to ``create_snapshot_id``. It is kept as a
    named module-level function so the computation has a single entry point
    that can be referenced (or patched) by its dotted path.

    Args:
        job_snap: The job snapshot to derive an ID for.

    Returns:
        The snapshot ID string produced by ``create_snapshot_id``.
    """
    return create_snapshot_id(job_snap)


class JobSnapSerializer(RecordSerializer["JobSnap"]):
Expand Down Expand Up @@ -157,17 +157,13 @@ def from_job_def(cls, job_def: JobDefinition) -> "JobSnap":
lineage = None
if job_def.op_selection_data:
lineage = JobLineageSnap(
parent_snapshot_id=create_job_snapshot_id(
cls.from_job_def(job_def.op_selection_data.parent_job_def)
),
parent_snapshot_id=job_def.op_selection_data.parent_job_def.get_job_snapshot_id(),
op_selection=sorted(job_def.op_selection_data.op_selection),
resolved_op_selection=job_def.op_selection_data.resolved_op_selection,
)
if job_def.asset_selection_data:
lineage = JobLineageSnap(
parent_snapshot_id=create_job_snapshot_id(
cls.from_job_def(job_def.asset_selection_data.parent_job_def)
),
parent_snapshot_id=job_def.asset_selection_data.parent_job_def.get_job_snapshot_id(),
asset_selection=job_def.asset_selection_data.asset_selection,
asset_check_selection=job_def.asset_selection_data.asset_check_selection,
)
Expand All @@ -187,6 +183,10 @@ def from_job_def(cls, job_def: JobDefinition) -> "JobSnap":
graph_def_name=job_def.graph.name,
)

@cached_property
def snapshot_id(self) -> str:
    """Return this job snapshot's ID, computed by ``_create_job_snapshot_id``.

    ``functools.cached_property`` stores the value on the instance after the
    first access, so repeated reads on the same object do not recompute it.
    """
    # NOTE(review): cached_property needs a writable instance __dict__ — confirm
    # the enclosing class allows that.
    return _create_job_snapshot_id(self)

def get_node_def_snap(self, node_def_name: str) -> Union[OpDefSnap, GraphDefSnap]:
check.str_param(node_def_name, "node_def_name")
for node_def_snap in self.node_defs_snapshot.op_def_snaps:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@
)
from dagster._core.execution.backfill import BulkActionsFilter, BulkActionStatus, PartitionBackfill
from dagster._core.remote_representation.origin import RemoteJobOrigin
from dagster._core.snap import (
ExecutionPlanSnapshot,
JobSnap,
create_execution_plan_snapshot_id,
create_job_snapshot_id,
)
from dagster._core.snap import ExecutionPlanSnapshot, JobSnap, create_execution_plan_snapshot_id
from dagster._core.storage.dagster_run import (
DagsterRun,
DagsterRunStatus,
Expand Down Expand Up @@ -580,7 +575,7 @@ def add_job_snapshot(self, job_snapshot: JobSnap, snapshot_id: Optional[str] = N
check.opt_str_param(snapshot_id, "snapshot_id")

if not snapshot_id:
snapshot_id = create_job_snapshot_id(job_snapshot)
snapshot_id = job_snapshot.snapshot_id

return self._add_snapshot(
snapshot_id=snapshot_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@
from dagster._core.instance.config import DEFAULT_LOCAL_CODE_SERVER_STARTUP_TIMEOUT
from dagster._core.launcher import LaunchRunContext, RunLauncher
from dagster._core.run_coordinator.queued_run_coordinator import QueuedRunCoordinator
from dagster._core.snap import (
create_execution_plan_snapshot_id,
create_job_snapshot_id,
snapshot_from_execution_plan,
)
from dagster._core.snap import create_execution_plan_snapshot_id, snapshot_from_execution_plan
from dagster._core.storage.partition_status_cache import AssetPartitionStatus, AssetStatusCacheValue
from dagster._core.storage.sqlite_storage import (
_event_logs_directory,
Expand Down Expand Up @@ -266,7 +262,7 @@ def test_create_job_snapshot():

run = instance.get_run_by_id(result.run_id)

assert run.job_snapshot_id == create_job_snapshot_id(noop_job.get_job_snapshot())
assert run.job_snapshot_id == noop_job.get_job_snapshot().snapshot_id


def test_create_execution_plan_snapshot():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
RepositorySnap,
TimeWindowPartitionsSnap,
)
from dagster._core.snap.job_snapshot import create_job_snapshot_id
from dagster._core.test_utils import in_process_test_workspace, instance_for_test
from dagster._core.types.loadable_target_origin import LoadableTargetOrigin
from dagster._core.snap.job_snapshot import _create_job_snapshot_id
from dagster._core.test_utils import create_test_daemon_workspace_context, instance_for_test
from dagster._serdes import serialize_pp
from dagster._time import get_current_datetime

Expand Down Expand Up @@ -73,44 +72,60 @@ def test_remote_job_data(snapshot):
)


@mock.patch("dagster._core.remote_representation.job_index.create_job_snapshot_id")
def test_remote_repo_shared_index(snapshot_mock):
# ensure we don't rebuild indexes / snapshot ids repeatedly
import os

snapshot_mock.side_effect = create_job_snapshot_id
with instance_for_test() as instance:
with in_process_test_workspace(
instance, LoadableTargetOrigin(python_file=__file__)
) as workspace:
from dagster._core.workspace.load_target import ModuleTarget

def _fetch_snap_id():
location = workspace.code_locations[0]
ex_repo = next(iter(location.get_repositories().values()))
return ex_repo.get_all_jobs()[0].identifying_job_snapshot_id

_fetch_snap_id()
assert snapshot_mock.call_count == 1
def workspace_load_target():
    """Build the ``ModuleTarget`` used by the shared-index tests.

    The target loads the ``a_repo`` attribute of
    ``dagster_tests.core_tests.snap_tests.test_active_data`` under the location
    name ``test_location``. The working directory is three levels above the
    directory containing this file.
    """
    this_dir = os.path.dirname(__file__)
    working_directory = os.path.join(this_dir, "..", "..", "..")
    return ModuleTarget(
        module_name="dagster_tests.core_tests.snap_tests.test_active_data",
        attribute="a_repo",
        working_directory=working_directory,
        location_name="test_location",
    )

_fetch_snap_id()
assert snapshot_mock.call_count == 1

def test_remote_repo_shared_index_single_threaded():
# ensure we don't rebuild indexes / snapshot ids repeatedly
with mock.patch("dagster._core.snap.job_snapshot._create_job_snapshot_id") as snapshot_mock:
snapshot_mock.side_effect = _create_job_snapshot_id
with instance_for_test() as instance:
with create_test_daemon_workspace_context(
workspace_load_target(),
instance,
) as workspace_process_context:
workspace = workspace_process_context.create_request_context()

@mock.patch("dagster._core.remote_representation.job_index.create_job_snapshot_id")
def test_remote_repo_shared_index_threaded(snapshot_mock):
# ensure we don't rebuild indexes / snapshot ids repeatedly across threads
def _fetch_snap_id():
location = workspace.code_locations[0]
ex_repo = next(iter(location.get_repositories().values()))
return ex_repo.get_all_jobs()[0].identifying_job_snapshot_id

snapshot_mock.side_effect = create_job_snapshot_id
with instance_for_test() as instance:
with in_process_test_workspace(
instance, LoadableTargetOrigin(python_file=__file__)
) as workspace:
_fetch_snap_id()
assert snapshot_mock.call_count == 1

def _fetch_snap_id():
location = workspace.code_locations[0]
ex_repo = next(iter(location.get_repositories().values()))
return ex_repo.get_all_jobs()[0].identifying_job_snapshot_id
_fetch_snap_id()
assert snapshot_mock.call_count == 1

with ThreadPoolExecutor() as executor:
wait([executor.submit(_fetch_snap_id) for _ in range(100)])

assert snapshot_mock.call_count == 1
def test_remote_repo_shared_index_multi_threaded():
    """Check that the job snapshot ID is computed once under concurrent access.

    ``_create_job_snapshot_id`` is patched with a mock that delegates to the
    real implementation, so ``call_count`` records how many times the ID was
    computed while 100 pool tasks read ``identifying_job_snapshot_id`` from the
    first job of the first repository in the workspace.
    """
    # ensure we don't rebuild indexes / snapshot ids repeatedly across threads
    with mock.patch("dagster._core.snap.job_snapshot._create_job_snapshot_id") as snapshot_mock:
        snapshot_mock.side_effect = _create_job_snapshot_id
        with instance_for_test() as instance:
            with create_test_daemon_workspace_context(
                workspace_load_target(),
                instance,
            ) as workspace_process_context:
                workspace = workspace_process_context.create_request_context()

                def _fetch_snap_id():
                    location = workspace.code_locations[0]
                    ex_repo = next(iter(location.get_repositories().values()))
                    return ex_repo.get_all_jobs()[0].identifying_job_snapshot_id

                with ThreadPoolExecutor() as executor:
                    futures = [executor.submit(_fetch_snap_id) for _ in range(100)]
                    wait(futures)

                # wait() only blocks until completion; result() re-raises any
                # exception from a worker so thread failures fail the test.
                snapshot_ids = {future.result() for future in futures}

                # Every thread must observe the same memoized ID.
                assert len(snapshot_ids) == 1
                assert snapshot_mock.call_count == 1
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dagster import GraphOut, In, Out, graph, job, op
from dagster._core.execution.api import create_execution_plan
from dagster._core.snap import create_job_snapshot_id, snapshot_from_execution_plan
from dagster._core.snap import snapshot_from_execution_plan
from dagster._serdes import serialize_pp


Expand All @@ -17,10 +17,7 @@ def noop_job():

snapshot.assert_match(
serialize_pp(
snapshot_from_execution_plan(
execution_plan,
create_job_snapshot_id(noop_job.get_job_snapshot()),
)
snapshot_from_execution_plan(execution_plan, noop_job.get_job_snapshot().snapshot_id)
)
)

Expand All @@ -44,7 +41,7 @@ def noop_job():
serialize_pp(
snapshot_from_execution_plan(
execution_plan,
create_job_snapshot_id(noop_job.get_job_snapshot()),
noop_job.get_job_snapshot().snapshot_id,
)
)
)
Expand Down Expand Up @@ -84,7 +81,7 @@ def do_comps():
serialize_pp(
snapshot_from_execution_plan(
execution_plan,
create_job_snapshot_id(do_comps.get_job_snapshot()),
do_comps.get_job_snapshot().snapshot_id,
)
)
)
Expand All @@ -105,7 +102,7 @@ def noop_job():
serialize_pp(
snapshot_from_execution_plan(
execution_plan,
create_job_snapshot_id(noop_job.get_job_snapshot()),
noop_job.get_job_snapshot().snapshot_id,
)
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
DependencyStructureIndex,
JobSnap,
NodeInvocationSnap,
create_job_snapshot_id,
snap_from_config_type,
)
from dagster._core.snap.dep_snapshot import (
Expand Down Expand Up @@ -51,7 +50,7 @@ def test_empty_job_snap_props(snapshot):
assert job_snapshot == serialize_rt(job_snapshot)

snapshot.assert_match(serialize_pp(job_snapshot))
snapshot.assert_match(create_job_snapshot_id(job_snapshot))
snapshot.assert_match(job_snapshot.snapshot_id)


def test_job_snap_all_props(snapshot):
Expand All @@ -72,7 +71,7 @@ def noop_job():
assert job_snapshot == serialize_rt(job_snapshot)

snapshot.assert_match(serialize_pp(job_snapshot))
snapshot.assert_match(create_job_snapshot_id(job_snapshot))
snapshot.assert_match(job_snapshot.snapshot_id)


def test_noop_deps_snap():
Expand Down Expand Up @@ -107,7 +106,7 @@ def two_op_job():
assert job_snapshot == serialize_rt(job_snapshot)

snapshot.assert_match(serialize_pp(job_snapshot))
snapshot.assert_match(create_job_snapshot_id(job_snapshot))
snapshot.assert_match(job_snapshot.snapshot_id)


def test_basic_dep():
Expand Down Expand Up @@ -177,7 +176,7 @@ def single_dep_job():
assert job_snapshot == serialize_rt(job_snapshot)

snapshot.assert_match(serialize_pp(job_snapshot))
snapshot.assert_match(create_job_snapshot_id(job_snapshot))
snapshot.assert_match(job_snapshot.snapshot_id)


def test_basic_fan_in(snapshot):
Expand Down Expand Up @@ -218,7 +217,7 @@ def fan_in_test():
assert job_snapshot == serialize_rt(job_snapshot)

snapshot.assert_match(serialize_pp(job_snapshot))
snapshot.assert_match(create_job_snapshot_id(job_snapshot))
snapshot.assert_match(job_snapshot.snapshot_id)


def _dict_has_stable_hashes(hydrated_map, snapshot_config_snap_map):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
JobPythonOrigin,
RepositoryPythonOrigin,
)
from dagster._core.snap import JobSnap, create_job_snapshot_id
from dagster._core.snap import JobSnap
from dagster._core.test_utils import instance_for_test
from dagster._utils import file_relative_path
from dagster._utils.hosted_user_process import recon_job_from_origin
Expand Down Expand Up @@ -52,7 +52,7 @@ def get_with_args(_x):


def pid(pipeline_def):
return create_job_snapshot_id(JobSnap.from_job_def(pipeline_def))
return JobSnap.from_job_def(pipeline_def).snapshot_id


@job
Expand Down
Loading

0 comments on commit 7065cd1

Please sign in to comment.