Skip to content

Commit

Permalink
feat: BI-6025 revisionId in us_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
ForrestGump committed Jan 17, 2025
1 parent df1c0a7 commit f1d2b8d
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 27 deletions.
23 changes: 14 additions & 9 deletions lib/dl_api_lib/dl_api_lib/app/control_api/resources/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
NoReturn,
Optional,
)

from flask import request
Expand Down Expand Up @@ -37,7 +39,10 @@
DatabaseUnavailable,
USPermissionRequired,
)
from dl_core.us_connection_base import ConnectionBase
from dl_core.us_connection_base import (
ConnectionBase,
DataSourceTemplate,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -205,28 +210,28 @@ def put(self, connection_id): # type: ignore # TODO: fix
us_manager.save(conn)


def _dump_source_templates(tpls) -> dict: # type: ignore # TODO: fix
def _dump_source_templates(tpls: Optional[list[DataSourceTemplate]]) -> Optional[list[dict[str, Any]]]:
if tpls is None:
return None # type: ignore # TODO: fix
return [dict(tpl._asdict(), parameter_hash=tpl.get_param_hash()) for tpl in tpls] # type: ignore # TODO: fix
return None
return [dict(tpl._asdict(), parameter_hash=tpl.get_param_hash()) for tpl in tpls]


@ns.route("/<connection_id>/info/metadata_sources")
class ConnectionInfoMetadataSources(BIResource):
@schematic_request(ns=ns, responses={200: ("Success", ConnectionSourceTemplatesResponseSchema())})
def get(self, connection_id): # type: ignore # TODO: fix
connection = self.get_us_manager().get_by_id(connection_id, expected_type=ConnectionBase)
def get(self, connection_id: str) -> dict[str, Optional[list[dict[str, Any]]]]:
connection: ConnectionBase = self.get_us_manager().get_by_id(connection_id, expected_type=ConnectionBase)

localizer = self.get_service_registry().get_localizer()
source_template_templates = connection.get_data_source_template_templates(localizer=localizer) # type: ignore # 2024-01-24 # TODO: "USEntry" has no attribute "get_data_source_template_templates" [attr-defined]
source_template_templates = connection.get_data_source_template_templates(localizer=localizer)

source_templates = []
source_templates: Optional[list[DataSourceTemplate]] = []
try:
need_permission_on_entry(connection, USPermissionKind.read)
except USPermissionRequired:
pass
else:
source_templates = connection.get_data_source_local_templates() # type: ignore # 2024-01-24 # TODO: "USEntry" has no attribute "get_data_source_local_templates" [attr-defined]
source_templates = connection.get_data_source_local_templates()

return {
"sources": _dump_source_templates(source_templates),
Expand Down
14 changes: 12 additions & 2 deletions lib/dl_core/dl_core/united_storage_client_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,18 @@ async def _request(self, request_data: UStorageClientBase.RequestData) -> dict:
)
return self._get_us_json_from_response(response_adapter)

async def get_entry(self, entry_id: str) -> dict:
return await self._request(self._req_data_get_entry(entry_id=entry_id))
async def get_entry(
self,
entry_id: str,
params: Optional[dict[str, str]] = None,
include_permissions: bool = True,
include_links: bool = True,
) -> dict:
return await self._request(
self._req_data_get_entry(
entry_id=entry_id, params=params, include_permissions=include_permissions, include_links=include_links
)
)

async def create_entry(
self,
Expand Down
16 changes: 13 additions & 3 deletions lib/dl_core/dl_core/us_manager/schema_migration/factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from dl_core.us_connection import get_schema_migration_cls
from dl_core.us_manager.schema_migration.base import BaseEntrySchemaMigration
from dl_core.us_manager.schema_migration.dataset import DatasetSchemaMigration
from dl_core.us_manager.schema_migration.factory_base import EntrySchemaMigrationFactoryBase


if TYPE_CHECKING:
from dl_core.services_registry import ServicesRegistry


class DummyEntrySchemaMigrationFactory(EntrySchemaMigrationFactoryBase):
def get_schema_migration(
self,
entry_scope: str,
entry_type: str,
service_registry: ServicesRegistry | None = None,
) -> BaseEntrySchemaMigration:
return BaseEntrySchemaMigration()

Expand All @@ -18,11 +27,12 @@ def get_schema_migration(
self,
entry_scope: str,
entry_type: str,
service_registry: ServicesRegistry | None = None,
) -> BaseEntrySchemaMigration:
if entry_scope == "dataset":
return DatasetSchemaMigration()
return DatasetSchemaMigration(services_registry=service_registry)
elif entry_scope == "connection":
schema_migration_cls = get_schema_migration_cls(conn_type_name=entry_type)
return schema_migration_cls()
return schema_migration_cls(services_registry=service_registry)
else:
return BaseEntrySchemaMigration()
return BaseEntrySchemaMigration(services_registry=service_registry)
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING

from dl_core.us_manager.schema_migration.base import BaseEntrySchemaMigration


if TYPE_CHECKING:
from dl_core.services_registry import ServicesRegistry


class EntrySchemaMigrationFactoryBase(abc.ABC):
@abc.abstractmethod
def get_schema_migration(
self,
entry_scope: str,
entry_type: str,
service_registry: ServicesRegistry | None = None,
) -> BaseEntrySchemaMigration:
pass
8 changes: 8 additions & 0 deletions lib/dl_core/dl_core/us_manager/us_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
)
from dl_core.us_manager.crypto.main import CryptoController
from dl_core.us_manager.local_cache import USEntryBuffer
from dl_core.us_manager.schema_migration.base import BaseEntrySchemaMigration
from dl_core.us_manager.schema_migration.factory import DefaultEntrySchemaMigrationFactory
from dl_core.us_manager.schema_migration.factory_base import EntrySchemaMigrationFactoryBase
from dl_core.us_manager.storage_schemas.connection_schema_registry import MAP_TYPE_TO_SCHEMA_MAP_TYPE_TO_SCHEMA
Expand Down Expand Up @@ -172,6 +173,13 @@ def get_lifecycle_manager(
entry=entry, us_manager=self, service_registry=service_registry
)

def get_schema_migration(
self, entry_scope: str, entry_type: str, service_registry: Optional[ServicesRegistry] = None
) -> BaseEntrySchemaMigration:
if service_registry is None:
service_registry = self.get_services_registry()
return self._schema_migration_factory.get_schema_migration(entry_scope, entry_type, service_registry)

# TODO FIX: Prevent saving entries with tenant ID that doesn't match current tenant ID
def set_tenant_override(self, tenant: TenantDef) -> None:
if not self._us_auth_context.is_tenant_id_mutable():
Expand Down
27 changes: 21 additions & 6 deletions lib/dl_core/dl_core/us_manager/us_manager_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,37 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
LOGGER.warning("Error during closing AsyncUSManager", exc_info=True)

@overload
async def get_by_id(self, entry_id: str, expected_type: type(None) = None) -> USEntry: # type: ignore # TODO: fix
async def get_by_id(
self,
entry_id: str,
expected_type: None = None,
params: Optional[dict[str, str]] = None,
) -> USEntry:
pass

@overload # noqa
async def get_by_id(self, entry_id: str, expected_type: Type[_ENTRY_TV] = None) -> _ENTRY_TV: # type: ignore # TODO: fix
@overload
async def get_by_id(
self,
entry_id: str,
expected_type: Optional[Type[_ENTRY_TV]] = None,
params: Optional[dict[str, str]] = None,
) -> _ENTRY_TV:
pass

@generic_profiler_async("us-fetch-entity") # type: ignore # TODO: fix
async def get_by_id(self, entry_id: str, expected_type: Type[_ENTRY_TV] = None) -> _ENTRY_TV: # type: ignore # TODO: fix
async def get_by_id(
self,
entry_id: str,
expected_type: Optional[Type[USEntry]] = None,
params: Optional[dict[str, str]] = None,
) -> USEntry:
with self._enrich_us_exception(
entry_id=entry_id,
entry_scope=expected_type.scope if expected_type is not None else None,
):
us_resp = await self._us_client.get_entry(entry_id)
us_resp = await self._us_client.get_entry(entry_id, params=params)

obj: _ENTRY_TV = self._entry_dict_to_obj(us_resp, expected_type) # type: ignore # TODO: fix
obj = self._entry_dict_to_obj(us_resp, expected_type)
await self.get_lifecycle_manager(entry=obj).post_init_async_hook()

return obj
Expand Down
39 changes: 32 additions & 7 deletions lib/dl_core/dl_core/us_manager/us_manager_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,22 +169,39 @@ def delete(self, entry: USEntry) -> None:
LOGGER.exception("Error during post-delete hook execution for entry %s", entry.uuid)

@overload
def get_by_id(self, entry_id: str, expected_type: type(None) = None) -> USEntry: # type: ignore # TODO: fix
def get_by_id(
self,
entry_id: str,
expected_type: None = None,
params: Optional[dict[str, str]] = None,
) -> USEntry:
pass

@overload # noqa
def get_by_id(self, entry_id: str, expected_type: Type[_ENTRY_TV] = None) -> _ENTRY_TV: # type: ignore # TODO: fix
@overload
def get_by_id(
self,
entry_id: str,
expected_type: Optional[Type[_ENTRY_TV]] = None,
params: Optional[dict[str, str]] = None,
) -> _ENTRY_TV:
pass

@generic_profiler("us-fetch-entity")
def get_by_id(self, entry_id: str, expected_type: Optional[Type[USEntry]] = None) -> USEntry:
def get_by_id(
self,
entry_id: str,
expected_type: Optional[Type[USEntry]] = None,
params: Optional[dict[str, str]] = None,
) -> USEntry:
with self._enrich_us_exception(
entry_id=entry_id,
entry_scope=expected_type.scope if expected_type is not None else None,
):
us_resp = self._us_client.get_entry(entry_id)
us_resp = self._us_client.get_entry(entry_id, params=params)

obj = self._entry_dict_to_obj(us_resp, expected_type)
await_sync(self.get_lifecycle_manager(entry=obj).post_init_async_hook())

return obj

@overload
Expand Down Expand Up @@ -242,8 +259,16 @@ def get_collection(
if raise_on_broken_entry:
raise

def get_raw_entry(self, entry_id: str, include_permissions: bool = True, include_links: bool = True) -> dict:
return self._us_client.get_entry(entry_id, include_permissions=include_permissions, include_links=include_links)
def get_raw_entry(
self,
entry_id: str,
params: Optional[dict[str, str]] = None,
include_permissions: bool = True,
include_links: bool = True,
) -> dict[str, Any]:
return self._us_client.get_entry(
entry_id, params=params, include_permissions=include_permissions, include_links=include_links
)

def get_raw_collection(
self,
Expand Down

0 comments on commit f1d2b8d

Please sign in to comment.