From 5c5d27ea0c9e7727b24f496c7bb16eae75c40e7e Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Fri, 25 Oct 2024 19:45:01 +0200 Subject: [PATCH 1/2] Experimental in-process workflow runner on `runtime_ctx` --- src/databricks/labs/ucx/runtime.py | 36 ++-- tests/integration/conftest.py | 158 +++++++++--------- tests/integration/contexts/common.py | 2 +- tests/integration/contexts/runtime.py | 74 ++++++++ .../hive_metastore/test_workflows.py | 9 +- 5 files changed, 178 insertions(+), 101 deletions(-) diff --git a/src/databricks/labs/ucx/runtime.py b/src/databricks/labs/ucx/runtime.py index c5954afc62..8cc5f3b3b2 100644 --- a/src/databricks/labs/ucx/runtime.py +++ b/src/databricks/labs/ucx/runtime.py @@ -44,23 +44,25 @@ def __init__(self, workflows: list[Workflow]): @classmethod def all(cls): - return cls( - [ - Assessment(), - MigrationProgress(), - GroupMigration(), - TableMigration(), - MigrateHiveSerdeTablesInPlace(), - MigrateExternalTablesCTAS(), - ValidateGroupPermissions(), - RemoveWorkspaceLocalGroups(), - ScanTablesInMounts(), - MigrateTablesInMounts(), - PermissionsMigrationAPI(), - MigrationRecon(), - Failing(), - ] - ) + return cls(Workflows.definitions()) + + @classmethod + def definitions(cls): + return [ + Assessment(), + MigrationProgress(), + GroupMigration(), + TableMigration(), + MigrateHiveSerdeTablesInPlace(), + MigrateExternalTablesCTAS(), + ValidateGroupPermissions(), + RemoveWorkspaceLocalGroups(), + ScanTablesInMounts(), + MigrateTablesInMounts(), + PermissionsMigrationAPI(), + MigrationRecon(), + Failing(), + ] def tasks(self) -> list[Task]: return self._tasks diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 527359f07a..168bbc384b 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -308,58 +308,58 @@ def get_azure_spark_conf(): @pytest.fixture def runtime_ctx( ws, - make_catalog_fixture, - make_schema_fixture, - make_table_fixture, - make_udf_fixture, - make_group_fixture, - make_job_fixture, - make_query_fixture, - make_dashboard_fixture, - env_or_skip_fixture, - make_random_fixture, + make_catalog, + make_schema, + make_table, + make_udf, + make_group, + make_job, + make_query, + make_dashboard, + env_or_skip, + make_random, ) -> MockRuntimeContext: return MockRuntimeContext( ws, - make_catalog_fixture, - make_schema_fixture, - make_table_fixture, - make_udf_fixture, - make_group_fixture, - make_job_fixture, - make_query_fixture, - make_dashboard_fixture, - env_or_skip_fixture, - make_random_fixture, + make_catalog, + make_schema, + make_table, + make_udf, + make_group, + make_job, + make_query, + make_dashboard, + env_or_skip, + make_random, ) @pytest.fixture def az_cli_ctx( ws, - make_catalog_fixture, - make_schema_fixture, - make_table_fixture, - make_udf_fixture, - make_group_fixture, - make_job_fixture, - make_query_fixture, - make_dashboard_fixture, - env_or_skip_fixture, - make_random_fixture, + make_catalog, + make_schema, + make_table, + make_udf, + make_group, + make_job, + make_query, + make_dashboard, + env_or_skip, + make_random, ): return MockLocalAzureCli( ws, - make_catalog_fixture, - make_schema_fixture, - make_table_fixture, - make_udf_fixture, - make_group_fixture, - make_job_fixture, - make_query_fixture, - make_dashboard_fixture, - env_or_skip_fixture, - make_random_fixture, + make_catalog, + make_schema, + make_table, + make_udf, + make_group, + make_job, + make_query, + make_dashboard, + env_or_skip, + make_random, ) @@ -374,36 +374,36 @@ def aws_cli_ctx(installation_ctx, 
env_or_skip): @pytest.fixture def installation_ctx( - make_acc_group_fixture, - make_user_fixture, + make_acc_group, + make_user, watchdog_purge_suffix, ws, - make_catalog_fixture, - make_schema_fixture, - make_table_fixture, - make_udf_fixture, - make_group_fixture, - make_job_fixture, - make_query_fixture, - make_dashboard_fixture, - env_or_skip_fixture, - make_random_fixture, + make_catalog, + make_schema, + make_table, + make_udf, + make_group, + make_job, + make_query, + make_dashboard, + env_or_skip, + make_random, ) -> Generator[MockInstallationContext, None, None]: ctx = MockInstallationContext( - make_acc_group_fixture, - make_user_fixture, + make_acc_group, + make_user, watchdog_purge_suffix, ws, - make_catalog_fixture, - make_schema_fixture, - make_table_fixture, - make_udf_fixture, - make_group_fixture, - make_job_fixture, - make_query_fixture, - make_dashboard_fixture, - env_or_skip_fixture, - make_random_fixture, + make_catalog, + make_schema, + make_table, + make_udf, + make_group, + make_job, + make_query, + make_dashboard, + env_or_skip, + make_random, ) yield ctx.replace(workspace_client=ws) ctx.workspace_installation.uninstall() @@ -480,8 +480,14 @@ def prepare_regular_tables(context, external_csv, schema) -> dict[str, TableInfo @pytest.fixture -def prepare_tables_for_migration( - ws, installation_ctx, make_catalog, make_random, make_mounted_location, env_or_skip, make_storage_dir, request +def prepare_tables_for_migration( # TODO: make this a function, so that installation_ctx / runtime_ctx could be swapped + runtime_ctx, + make_catalog, + make_random, + make_mounted_location, + env_or_skip, + make_storage_dir, + request, ) -> tuple[dict[str, TableInfo], SchemaInfo]: # Here we use pytest indirect parametrization, so the test function can pass arguments to this fixture and the # arguments will be available in the request.param. 
If the argument is "hiveserde", we will prepare hiveserde @@ -492,24 +498,24 @@ def prepare_tables_for_migration( random = make_random(5).lower() # create external and managed tables to be migrated if is_hiveserde: - schema = installation_ctx.make_schema(catalog_name="hive_metastore", name=f"hiveserde_in_place_{random}") + schema = runtime_ctx.make_schema(catalog_name="hive_metastore", name=f"hiveserde_in_place_{random}") table_base_dir = make_storage_dir( path=f'dbfs:/mnt/{env_or_skip("TEST_MOUNT_NAME")}/a/hiveserde_in_place_{random}' ) - tables = prepare_hiveserde_tables(installation_ctx, random, schema, table_base_dir) + tables = prepare_hiveserde_tables(runtime_ctx, random, schema, table_base_dir) else: - schema = installation_ctx.make_schema(catalog_name="hive_metastore", name=f"migrate_{random}") - tables = prepare_regular_tables(installation_ctx, make_mounted_location, schema) + schema = runtime_ctx.make_schema(catalog_name="hive_metastore", name=f"migrate_{random}") + tables = prepare_regular_tables(runtime_ctx, make_mounted_location, schema) # create destination catalog and schema dst_catalog = make_catalog() - dst_schema = installation_ctx.make_schema(catalog_name=dst_catalog.name, name=schema.name) + dst_schema = runtime_ctx.make_schema(catalog_name=dst_catalog.name, name=schema.name) migrate_rules = [Rule.from_src_dst(table, dst_schema) for _, table in tables.items()] - installation_ctx.with_table_mapping_rules(migrate_rules) - installation_ctx.with_dummy_resource_permission() - installation_ctx.save_tables(is_hiveserde=is_hiveserde) - installation_ctx.save_mounts() - installation_ctx.with_dummy_grants_and_tacls() + runtime_ctx.with_table_mapping_rules(migrate_rules) + runtime_ctx.with_dummy_resource_permission() + runtime_ctx.save_tables(is_hiveserde=is_hiveserde) + runtime_ctx.save_mounts() + runtime_ctx.with_dummy_grants_and_tacls() return tables, dst_schema diff --git a/tests/integration/contexts/common.py b/tests/integration/contexts/common.py index e9044f8324..e6a920d875 100644 --- a/tests/integration/contexts/common.py +++ b/tests/integration/contexts/common.py @@ -407,7 +407,7 @@ def _crawl(self) -> Iterable[crawlers.Result]: class StaticServicePrincipalCrawler(AzureServicePrincipalCrawler): def __init__(self, dummy: list[AzureServicePrincipalInfo]): - super().__init__(create_autospec(WorkspaceClient), create_autospec(SqlBackend), "...") + super().__init__(create_autospec(WorkspaceClient), create_autospec(SqlBackend), "dummy") self._dummy = dummy def _try_fetch(self) -> Iterable[AzureServicePrincipalInfo]: diff --git a/tests/integration/contexts/runtime.py b/tests/integration/contexts/runtime.py index 79bb17f974..053f635154 100644 --- a/tests/integration/contexts/runtime.py +++ b/tests/integration/contexts/runtime.py @@ -1,8 +1,82 @@ +import logging +from datetime import timedelta +from functools import cached_property + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.lsql.backends import SqlBackend + +from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.contexts.workflow_task import RuntimeContext +from databricks.labs.ucx.install import deploy_schema +from databricks.labs.ucx.installer.workflows import DeployedWorkflows +from databricks.labs.ucx.runtime import Workflows from tests.integration.contexts.common import IntegrationContext +logger = logging.getLogger(__name__) + class MockRuntimeContext(IntegrationContext, RuntimeContext): def __init__(self, *args): super().__init__(*args) 
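+        # also initialize RuntimeContext explicitly; the cooperative super().__init__ above presumably does not chain into it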
RuntimeContext.__init__(self) + + @cached_property + def deployed_workflows(self) -> DeployedWorkflows: + return InProcessDeployedWorkflows(self) + + @cached_property + def workspace_installation(self): + return MockWorkspaceInstallation(self.sql_backend, self.config, self.installation) + + +class InProcessDeployedWorkflows(DeployedWorkflows): + """This class runs workflows on the client side instead of deploying them to Databricks.""" + + def __init__(self, ctx: RuntimeContext): + super().__init__(ctx.workspace_client, ctx.install_state) + self._workflows = {workflow.name: workflow for workflow in Workflows.definitions()} + self._ctx = ctx + + def run_workflow(self, step: str, skip_job_wait: bool = False, max_wait: timedelta = timedelta(minutes=20)): + workflow = self._workflows[step] + incoming = {task.name: 0 for task in workflow.tasks()} + queue = [] + for task in workflow.tasks(): + task.workflow = workflow.name + incoming[task.name] += len(task.depends_on) + for task in workflow.tasks(): + if incoming[task.name] == 0: + queue.append(task) + while queue: + task = queue.pop(0) + fn = getattr(workflow, task.name) + # TODO: capture error logs and fail if there is ERROR event, to simulate parse_logs meta-task + fn(self._ctx) + for dep in task.depends_on: + incoming[dep] -= 1 + if incoming[dep] == 0: + queue.append(dep) + + def relay_logs(self, workflow: str | None = None): + pass # noop + + +class MockWorkspaceInstallation: + def __init__(self, sql_backend: SqlBackend, config: WorkspaceConfig, installation: Installation): + self._sql_backend = sql_backend + self._config = config + self._installation = installation + + def run(self): + deploy_schema(self._sql_backend, self._config.inventory_database) + + @property + def config(self): + return self._config + + @property + def folder(self): + return self._installation.install_folder() + + def uninstall(self): + pass # noop diff --git a/tests/integration/hive_metastore/test_workflows.py b/tests/integration/hive_metastore/test_workflows.py index 9e85e7364a..42c1017c4d 100644 --- a/tests/integration/hive_metastore/test_workflows.py +++ b/tests/integration/hive_metastore/test_workflows.py @@ -12,15 +12,10 @@ ], indirect=("prepare_tables_for_migration",), ) -def test_table_migration_job_refreshes_migration_status( - ws, - installation_ctx, - prepare_tables_for_migration, - workflow, -): +def test_table_migration_job_refreshes_migration_status(runtime_ctx, prepare_tables_for_migration, workflow): """The migration status should be refreshed after the migration job.""" tables, _ = prepare_tables_for_migration - ctx = installation_ctx.replace( + ctx = runtime_ctx.replace( extend_prompts={ r".*Do you want to update the existing installation?.*": 'yes', }, From 5a735213fc95fccabe3759ad5a0e4534d3f82bef Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Fri, 25 Oct 2024 19:57:35 +0200 Subject: [PATCH 2/2] ... 
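
Fix dependency bookkeeping in InProcessDeployedWorkflows.run_workflow:
the first revision decremented the in-degree counters of a task's
dependencies instead of its dependents, and appended task names to a
queue that otherwise held Task objects. Track downstream tasks in an
explicit map, keep the queue as task names throughout, and only enqueue
a task once all of its dependencies have completed.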
--- tests/integration/contexts/runtime.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/integration/contexts/runtime.py b/tests/integration/contexts/runtime.py index 053f635154..44975a293f 100644 --- a/tests/integration/contexts/runtime.py +++ b/tests/integration/contexts/runtime.py @@ -40,22 +40,25 @@ def __init__(self, ctx: RuntimeContext): def run_workflow(self, step: str, skip_job_wait: bool = False, max_wait: timedelta = timedelta(minutes=20)): workflow = self._workflows[step] incoming = {task.name: 0 for task in workflow.tasks()} - queue = [] + downstreams = {task.name: [] for task in workflow.tasks()} + queue: list[str] = [] for task in workflow.tasks(): task.workflow = workflow.name - incoming[task.name] += len(task.depends_on) + for dep in task.depends_on: + downstreams[dep].append(task.name) + incoming[task.name] += 1 for task in workflow.tasks(): if incoming[task.name] == 0: - queue.append(task) + queue.append(task.name) while queue: - task = queue.pop(0) - fn = getattr(workflow, task.name) + task_name = queue.pop(0) + fn = getattr(workflow, task_name) # TODO: capture error logs and fail if there is ERROR event, to simulate parse_logs meta-task fn(self._ctx) - for dep in task.depends_on: - incoming[dep] -= 1 - if incoming[dep] == 0: - queue.append(dep) + for dep_name in downstreams[task_name]: + incoming[dep_name] -= 1 + if incoming[dep_name] == 0: + queue.append(dep_name) def relay_logs(self, workflow: str | None = None): pass # noop
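
For reference, run_workflow in InProcessDeployedWorkflows above schedules a
workflow's tasks with Kahn's algorithm: count each task's unfinished
dependencies, start with the tasks that have none, and release a dependent
task once all of its dependencies have run. The sketch below is a minimal,
self-contained illustration of that scheduling logic; the Task dataclass is
a stand-in for the real databricks.labs.ucx task class, and execution_order
is a hypothetical helper that exists only in this sketch, not in the patch.

from dataclasses import dataclass, field


@dataclass
class Task:
    name: str
    depends_on: list[str] = field(default_factory=list)


def execution_order(tasks: list[Task]) -> list[str]:
    """Order task names so that every task comes after its dependencies."""
    # in-degree: number of not-yet-finished dependencies per task
    incoming = {t.name: len(t.depends_on) for t in tasks}
    # reverse edges: dependency name -> names of the tasks waiting on it
    downstreams: dict[str, list[str]] = {t.name: [] for t in tasks}
    for t in tasks:
        for dep in t.depends_on:
            downstreams[dep].append(t.name)
    # seed the queue with tasks that depend on nothing, as run_workflow does
    queue = [name for name, degree in incoming.items() if degree == 0]
    order: list[str] = []
    while queue:
        name = queue.pop(0)
        order.append(name)  # run_workflow invokes getattr(workflow, name)(ctx) here
        for dep_name in downstreams[name]:
            incoming[dep_name] -= 1
            if incoming[dep_name] == 0:
                queue.append(dep_name)
    return order


# hypothetical three-task workflow: "report" waits on both of the others
assert execution_order(
    [
        Task("crawl"),
        Task("assess", depends_on=["crawl"]),
        Task("report", depends_on=["assess", "crawl"]),
    ]
) == ["crawl", "assess", "report"]

With this runner in place, an integration test can exercise a workflow
entirely in-process, along the lines of
runtime_ctx.deployed_workflows.run_workflow("migrate-tables"); the step name
here is illustrative, as test_workflows.py above parametrizes the workflow
name instead.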