Skip to content

Commit

Permalink
Add a flag to the job subset grpc call to omit including the parent s…
Browse files Browse the repository at this point in the history
…napshot in the response (#20335)

Summary:
The parent snapshot in this case is both potentially large and already
available in the caller without going through the gRPC server - add a
flag that allows us to omit it and start setting that flag by default -
there shouldn't be any back-compat concerns here b/c on old grpc servers
we'll just return it anyway and overwrite it on the caller.

Test Plan: BK

## Summary & Motivation

## How I Tested These Changes
  • Loading branch information
gibsondan authored Mar 12, 2024
1 parent b94c104 commit e556a8a
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 19 deletions.
2 changes: 2 additions & 0 deletions python_modules/dagster/dagster/_api/snapshot_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
def sync_get_external_job_subset_grpc(
api_client: "DagsterGrpcClient",
job_origin: ExternalJobOrigin,
include_parent_snapshot: bool,
op_selection: Optional[Sequence[str]] = None,
asset_selection: Optional[AbstractSet[AssetKey]] = None,
asset_check_selection: Optional[AbstractSet[AssetCheckKey]] = None,
Expand All @@ -36,6 +37,7 @@ def sync_get_external_job_subset_grpc(
op_selection=op_selection,
asset_selection=asset_selection,
asset_check_selection=asset_check_selection,
include_parent_snapshot=include_parent_snapshot,
),
),
ExternalJobSubsetResult,
Expand Down
6 changes: 0 additions & 6 deletions python_modules/dagster/dagster/_core/instance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,12 +1280,6 @@ def _ensure_persisted_job_snapshot(
if not self._run_storage.has_job_snapshot(
job_snapshot.lineage_snapshot.parent_snapshot_id
):
check.invariant(
create_job_snapshot_id(parent_job_snapshot) # type: ignore # (possible none)
== job_snapshot.lineage_snapshot.parent_snapshot_id,
"Parent pipeline snapshot id out of sync with passed parent pipeline snapshot",
)

returned_job_snapshot_id = self._run_storage.add_job_snapshot(
parent_job_snapshot # type: ignore # (possible none)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def get_subset_external_job_result(
selector.op_selection,
selector.asset_selection,
selector.asset_check_selection,
include_parent_snapshot=True,
)

def get_external_execution_plan(
Expand Down Expand Up @@ -808,13 +809,25 @@ def get_subset_external_job_result(

external_repository = self.get_repository(selector.repository_name)
job_handle = JobHandle(selector.job_name, external_repository.handle)
return sync_get_external_job_subset_grpc(
subset = sync_get_external_job_subset_grpc(
self.client,
job_handle.get_external_origin(),
selector.op_selection,
selector.asset_selection,
selector.asset_check_selection,
include_parent_snapshot=False,
op_selection=selector.op_selection,
asset_selection=selector.asset_selection,
asset_check_selection=selector.asset_check_selection,
)
if subset.external_job_data:
full_job = self.get_repository(selector.repository_name).get_full_external_job(
selector.job_name
)
subset = subset._replace(
external_job_data=subset.external_job_data._replace(
parent_job_snapshot=full_job.job_snapshot
)
)

return subset

def get_external_partition_config(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
ExternalRepositoryOrigin,
)
from dagster._core.snap import ExecutionPlanSnapshot
from dagster._core.snap.job_snapshot import JobSnapshot
from dagster._core.utils import toposort
from dagster._serdes import create_snapshot_id
from dagster._utils.cached_method import cached_method
Expand Down Expand Up @@ -514,6 +515,10 @@ def tags(self) -> Mapping[str, object]:
def metadata(self) -> Mapping[str, MetadataValue]:
return self._job_index.job_snapshot.metadata

@property
def job_snapshot(self) -> JobSnapshot:
return self._job_index.job_snapshot

@property
def computed_job_snapshot_id(self) -> str:
return self._snapshot_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,11 @@ def external_repository_data_from_def(
)
else:
job_datas = sorted(
list(map(external_job_data_from_def, jobs)),
list(
map(
lambda job: external_job_data_from_def(job, include_parent_snapshot=False), jobs
)
),
key=lambda pd: pd.name,
)
job_refs = None
Expand Down Expand Up @@ -1783,12 +1787,14 @@ def external_asset_nodes_from_defs(
return asset_nodes


def external_job_data_from_def(job_def: JobDefinition) -> ExternalJobData:
def external_job_data_from_def(
job_def: JobDefinition, include_parent_snapshot: bool
) -> ExternalJobData:
check.inst_param(job_def, "job_def", JobDefinition)
return ExternalJobData(
name=job_def.name,
job_snapshot=job_def.get_job_snapshot(),
parent_job_snapshot=job_def.get_parent_job_snapshot(),
parent_job_snapshot=job_def.get_parent_job_snapshot() if include_parent_snapshot else None,
active_presets=active_presets_from_job_def(job_def),
)

Expand Down
5 changes: 4 additions & 1 deletion python_modules/dagster/dagster/_grpc/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def get_external_pipeline_subset_result(
op_selection: Optional[Sequence[str]],
asset_selection: Optional[AbstractSet[AssetKey]],
asset_check_selection: Optional[AbstractSet[AssetCheckKey]],
include_parent_snapshot: bool,
):
try:
definition = repo_def.get_maybe_subset_job_def(
Expand All @@ -286,7 +287,9 @@ def get_external_pipeline_subset_result(
asset_selection=asset_selection,
asset_check_selection=asset_check_selection,
)
external_job_data = external_job_data_from_def(definition)
external_job_data = external_job_data_from_def(
definition, include_parent_snapshot=include_parent_snapshot
)
return ExternalJobSubsetResult(success=True, external_job_data=external_job_data)
except Exception:
return ExternalJobSubsetResult(
Expand Down
5 changes: 4 additions & 1 deletion python_modules/dagster/dagster/_grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ def ExternalPipelineSubsetSnapshot(
job_subset_snapshot_args.op_selection,
job_subset_snapshot_args.asset_selection,
job_subset_snapshot_args.asset_check_selection,
job_subset_snapshot_args.include_parent_snapshot,
)
)
except Exception:
Expand Down Expand Up @@ -792,7 +793,9 @@ def ExternalJob(
)

job_def = self._get_repo_for_origin(repository_origin).get_job(request.job_name)
ser_job_data = serialize_value(external_job_data_from_def(job_def))
ser_job_data = serialize_value(
external_job_data_from_def(job_def, include_parent_snapshot=False)
)
return api_pb2.ExternalJobReply(serialized_job_data=ser_job_data)
except Exception:
return api_pb2.ExternalJobReply(
Expand Down
5 changes: 5 additions & 0 deletions python_modules/dagster/dagster/_grpc/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ class JobSubsetSnapshotArgs(
("op_selection", Optional[Sequence[str]]),
("asset_selection", Optional[AbstractSet[AssetKey]]),
("asset_check_selection", Optional[AbstractSet[AssetCheckKey]]),
("include_parent_snapshot", bool),
],
)
):
Expand All @@ -504,6 +505,7 @@ def __new__(
op_selection: Optional[Sequence[str]],
asset_selection: Optional[AbstractSet[AssetKey]] = None,
asset_check_selection: Optional[AbstractSet[AssetCheckKey]] = None,
include_parent_snapshot: Optional[bool] = None,
):
return super(JobSubsetSnapshotArgs, cls).__new__(
cls,
Expand All @@ -515,6 +517,9 @@ def __new__(
asset_check_selection=check.opt_nullable_set_param(
asset_check_selection, "asset_check_selection"
),
include_parent_snapshot=(
include_parent_snapshot if include_parent_snapshot is not None else True
),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@ def external_job_from_recon_job(recon_job, op_selection, repository_handle, asse
job_def = recon_job.get_definition()

return ExternalJob(
external_job_data_from_def(job_def),
external_job_data_from_def(job_def, include_parent_snapshot=True),
repository_handle=repository_handle,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from .utils import get_bar_repo_code_location


def _test_job_subset_grpc(job_handle, api_client, op_selection=None):
def _test_job_subset_grpc(job_handle, api_client, op_selection=None, include_parent_snapshot=True):
return sync_get_external_job_subset_grpc(
api_client, job_handle.get_external_origin(), op_selection=op_selection
api_client,
job_handle.get_external_origin(),
op_selection=op_selection,
include_parent_snapshot=include_parent_snapshot,
)


Expand All @@ -40,6 +43,7 @@ def test_job_snapshot_deserialize_error(instance):
job_origin=job_handle.get_external_origin(),
op_selection=None,
asset_selection=None,
include_parent_snapshot=True,
)._replace(job_origin="INVALID"),
)
)
Expand All @@ -57,6 +61,24 @@ def test_job_with_valid_subset_snapshot_api_grpc(instance):
assert isinstance(external_job_subset_result, ExternalJobSubsetResult)
assert external_job_subset_result.success is True
assert external_job_subset_result.external_job_data.name == "foo"
assert (
external_job_subset_result.external_job_data.parent_job_snapshot
== code_location.get_repository("bar_repo").get_full_external_job("foo").job_snapshot
)


def test_job_with_valid_subset_snapshot_without_parent_snapshot(instance):
with get_bar_repo_code_location(instance) as code_location:
job_handle = JobHandle("foo", code_location.get_repository("bar_repo").handle)
api_client = code_location.client

external_job_subset_result = _test_job_subset_grpc(
job_handle, api_client, ["do_something"], include_parent_snapshot=False
)
assert isinstance(external_job_subset_result, ExternalJobSubsetResult)
assert external_job_subset_result.success is True
assert external_job_subset_result.external_job_data.name == "foo"
assert not external_job_subset_result.external_job_data.parent_job_snapshot


def test_job_with_invalid_subset_snapshot_api_grpc(instance):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def repo():


def test_external_job_data(snapshot):
snapshot.assert_match(serialize_pp(external_job_data_from_def(foo_job)))
snapshot.assert_match(
serialize_pp(external_job_data_from_def(foo_job, include_parent_snapshot=True))
)


@mock.patch("dagster._core.remote_representation.job_index.create_job_snapshot_id")
Expand Down

0 comments on commit e556a8a

Please sign in to comment.