diff --git a/connectors/config.py b/connectors/config.py index 0fd988695..f16fb01fc 100644 --- a/connectors/config.py +++ b/connectors/config.py @@ -121,6 +121,7 @@ def _default_config(): "network_drive": "connectors.sources.network_drive:NASDataSource", "notion": "connectors.sources.notion:NotionDataSource", "onedrive": "connectors.sources.onedrive:OneDriveDataSource", + "onelake": "connectors.sources.onelake:OneLakeDataSource", "oracle": "connectors.sources.oracle:OracleDataSource", "outlook": "connectors.sources.outlook:OutlookDataSource", "postgresql": "connectors.sources.postgresql:PostgreSQLDataSource", diff --git a/connectors/sources/onelake.py b/connectors/sources/onelake.py new file mode 100644 index 000000000..3a525bc5b --- /dev/null +++ b/connectors/sources/onelake.py @@ -0,0 +1,322 @@ +"""OneLake connector to retrieve data from datalakes""" + +from functools import partial + +from azure.identity import ClientSecretCredential +from azure.storage.filedatalake import DataLakeServiceClient + +from connectors.source import BaseDataSource + +ACCOUNT_NAME = "onelake" + + +class OneLakeDataSource(BaseDataSource): + """OneLake""" + + name = "OneLake" + service_type = "onelake" + incremental_sync_enabled = True + + def __init__(self, configuration): + """Set up the connection to the azure base client + + Args: + configuration (DataSourceConfiguration): Object of DataSourceConfiguration class. + """ + super().__init__(configuration=configuration) + self.tenant_id = self.configuration["tenant_id"] + self.client_id = self.configuration["client_id"] + self.client_secret = self.configuration["client_secret"] + self.workspace_name = self.configuration["workspace_name"] + self.data_path = self.configuration["data_path"] + self.account_url = ( + f"https://{self.configuration['account_name']}.dfs.fabric.microsoft.com" + ) + + @classmethod + def get_default_configuration(cls): + """Get the default configuration for OneLake + + Returns: + dictionary: Default configuration + """ + return { + "tenant_id": { + "label": "OneLake tenant id", + "order": 1, + "type": "str", + }, + "client_id": { + "label": "OneLake client id", + "order": 2, + "type": "str", + }, + "client_secret": { + "label": "OneLake client secret", + "order": 3, + "type": "str", + "sensitive": True, + }, + "workspace_name": { + "label": "OneLake workspace name", + "order": 4, + "type": "str", + }, + "data_path": { + "label": "OneLake data path", + "tooltip": "Path in format .Lakehouse/files/", + "order": 5, + "type": "str", + }, + "account_name": { + "tooltip": "In the most cases is 'onelake'", + "default_value": ACCOUNT_NAME, + "label": "Account name", + "order": 6, + "type": "str", + }, + } + + async def ping(self): + """Verify the connection with OneLake""" + + self._logger.info("Generating file system client...") + + try: + await self._get_directory_paths(self.configuration["data_path"]) + self._logger.info( + f"Connection to OneLake successful to {self.configuration['data_path']}" + ) + + except Exception: + self._logger.exception("Error while connecting to OneLake.") + raise + + def _get_token_credentials(self): + """Get the token credentials for OneLake + + Returns: + obj: Token credentials + """ + + tenant_id = self.configuration["tenant_id"] + client_id = self.configuration["client_id"] + client_secret = self.configuration["client_secret"] + + try: + return ClientSecretCredential(tenant_id, client_id, client_secret) + except Exception as e: + self._logger.error(f"Error while getting token credentials: {e}") + raise + + async def _get_service_client(self): + """Get the service client for OneLake + + Returns: + obj: Service client + """ + + try: + return DataLakeServiceClient( + account_url=self.account_url, + credential=self._get_token_credentials(), + ) + except Exception as e: + self._logger.error(f"Error while getting service client: {e}") + raise + + async def _get_file_system_client(self): + """Get the file system client for OneLake + + Returns: + obj: File system client + """ + try: + service_client = await self._get_service_client() + + return service_client.get_file_system_client( + self.configuration["workspace_name"] + ) + except Exception as e: + self._logger.error(f"Error while getting file system client: {e}") + raise + + async def _get_directory_client(self): + """Get the directory client for OneLake + + Returns: + obj: Directory client + """ + + try: + file_system_client = await self._get_file_system_client() + + return file_system_client.get_directory_client( + self.configuration["data_path"] + ) + except Exception as e: + self._logger.error(f"Error while getting directory client: {e}") + raise + + async def _get_file_client(self, file_name): + """Get file client from OneLake + + Args: + file_name (str): name of the file + + Returns: + obj: File client + """ + + try: + directory_client = await self._get_directory_client() + + return directory_client.get_file_client(file_name) + except Exception as e: + self._logger.error(f"Error while getting file client: {e}") + raise + + async def _get_directory_paths(self, directory_path): + """List directory paths from data lake + + Args: + directory_path (str): Directory path + + Returns: + list: List of paths + """ + + try: + file_system_client = await self._get_file_system_client() + + return file_system_client.get_paths(path=directory_path) + except Exception as e: + self._logger.error(f"Error while getting directory paths: {e}") + raise + + def format_file(self, file_client): + """Format file_client to be processed + + Args: + file_client (obj): File object + + Returns: + dict: Formatted file + """ + + try: + file_properties = file_client.get_file_properties() + + return { + "_id": f"{file_client.file_system_name}_{file_properties.name.split('/')[-1]}", + "name": file_properties.name.split("/")[-1], + "created_at": file_properties.creation_time.isoformat(), + "_timestamp": file_properties.last_modified.isoformat(), + "size": file_properties.size, + } + except Exception as e: + self._logger.error( + f"Error while formatting file or getting file properties: {e}" + ) + raise + + async def download_file(self, file_client): + """Download file from OneLake + + Args: + file_client (obj): File client + + Returns: + generator: File stream + """ + + try: + download = file_client.download_file() + stream = download.chunks() + + for chunk in stream: + yield chunk + except Exception as e: + self._logger.error(f"Error while downloading file: {e}") + raise + + async def get_content(self, file_name, doit=None, timestamp=None): + """Obtains the file content for the specified file in `file_name`. + + Args: + file_name (obj): The file name to process to obtain the content. + timestamp (timestamp, optional): Timestamp of blob last modified. Defaults to None. + doit (boolean, optional): Boolean value for whether to get content or not. Defaults to None. + + Returns: + str: Content of the file or None if not applicable. + """ + + if not doit: + return + + file_client = await self._get_file_client(file_name) + file_properties = file_client.get_file_properties() + file_extension = self.get_file_extension(file_name) + + doc = { + "_id": f"{file_client.file_system_name}_{file_properties.name}", # id in format _ + } + + can_be_downloaded = self.can_file_be_downloaded( + file_extension=file_extension, + filename=file_properties.name, + file_size=file_properties.size, + ) + + if not can_be_downloaded: + self._logger.warning( + f"File {file_properties.name} cannot be downloaded. Skipping." + ) + return doc + + self._logger.debug(f"Downloading file {file_properties.name}...") + extracted_doc = await self.download_and_extract_file( + doc=doc, + source_filename=file_properties.name.split("/")[-1], + file_extension=file_extension, + download_func=partial(self.download_file, file_client), + ) + + return extracted_doc if extracted_doc is not None else doc + + async def prepare_files(self, doc_paths): + """Prepare files for processing + + Args: + doc_paths (list): List of paths extracted from OneLake + + Yields: + tuple: File document and partial function to get content + """ + + for path in doc_paths: + file_name = path.name.split("/")[-1] + field_client = await self._get_file_client(file_name) + + yield self.format_file(field_client) + + async def get_docs(self, filtering=None): + """Get documents from OneLake and index them + + Yields: + tuple: dictionary with meta-data of each file and a partial function to get the file content. + """ + + self._logger.info(f"Fetching files from OneLake datalake {self.data_path}") + + directory_paths = await self._get_directory_paths( + self.configuration["data_path"] + ) + + self._logger.debug(f"Found {len(directory_paths)} files in {self.data_path}") + + async for file in self.prepare_files(directory_paths): + file_dict = file + + yield file_dict, partial(self.get_content, file_dict["name"]) diff --git a/requirements/framework.txt b/requirements/framework.txt index 775ad7ac4..94da2fc48 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -44,3 +44,5 @@ notion-client==2.2.1 certifi==2024.7.4 aioboto3==12.4.0 pyasn1<0.6.1 +azure-identity==1.19.0 +azure-storage-file-datalake==12.14.0 \ No newline at end of file diff --git a/tests/sources/fixtures/onelake/connector.json b/tests/sources/fixtures/onelake/connector.json new file mode 100644 index 000000000..1b74c208b --- /dev/null +++ b/tests/sources/fixtures/onelake/connector.json @@ -0,0 +1,44 @@ +{ + "configuration": { + "tenant_id": { + "label": "OneLake tenant id", + "order": 1, + "type": "str", + "value": "tenant-id" + }, + "client_id": { + "label": "OneLake client id", + "order": 2, + "type": "str", + "value": "client-id" + }, + "client_secret": { + "label": "OneLake client secret", + "order": 3, + "type": "str", + "sensitive": true, + "value": "client-secret" + }, + "workspace_name": { + "label": "OneLake workspace name", + "order": 4, + "type": "str", + "value": "testWorkspace" + }, + "data_path": { + "label": "OneLake data path", + "tooltip": "Path in format .Lakehouse/files/", + "order": 5, + "type": "str", + "value": "test_data_path" + }, + "account_name": { + "tooltip": "In the most cases is 'onelake'", + "default_value": "onelake", + "label": "Account name", + "order": 6, + "type": "str", + "value": "onelake" + } + } +} diff --git a/tests/sources/fixtures/onelake/docker-compose.yml b/tests/sources/fixtures/onelake/docker-compose.yml new file mode 100644 index 000000000..941f93f0a --- /dev/null +++ b/tests/sources/fixtures/onelake/docker-compose.yml @@ -0,0 +1,44 @@ +version: "3.9" + +services: + elasticsearch: + image: ${ELASTICSEARCH_DRA_DOCKER_IMAGE} + container_name: elasticsearch + environment: + - cluster.name=docker-cluster + - bootstrap.memory_lock=true + - ES_JAVA_OPTS=-Xms2g -Xmx2g + - ELASTIC_PASSWORD=changeme + - xpack.security.enabled=true + - xpack.security.authc.api_key.enabled=true + - discovery.type=single-node + - action.destructive_requires_name=false + ulimits: + memlock: + soft: -1 + hard: -1 + volumes: + - esdata:/usr/share/elasticsearch/data + ports: + - 9200:9200 + networks: + - esnet + + onelake: + build: + context: ../../../../ + dockerfile: ${DOCKERFILE_FTEST_PATH} + command: .venv/bin/python tests/sources/fixtures/onelake/fixture.py + ports: + - "8000:8000" + volumes: + - .:/python-flask + restart: always + +volumes: + esdata: + driver: local + +networks: + esnet: + driver: bridge diff --git a/tests/sources/fixtures/onelake/fixture.py b/tests/sources/fixtures/onelake/fixture.py new file mode 100644 index 000000000..48ac90fc2 --- /dev/null +++ b/tests/sources/fixtures/onelake/fixture.py @@ -0,0 +1,92 @@ +from typing import Dict, List + +""" +This is a fixture for generating test data and listing paths in a file system. +""" + + +class PathProperties: + def __init__(self, name: str, is_directory: bool = False): + self.name = name + self.is_directory = is_directory + + def to_dict(self): + return {"name": self.name, "is_directory": self.is_directory} + + +class ItemPaged: + def __init__(self, items: List[PathProperties]): + self.items = items + + def __iter__(self): + for item in self.items: + yield item + + +class FileSystemClient: + def __init__(self, file_system_name: str): + self.file_system_name = file_system_name + self.files = {} + + def add_file(self, file_path: str, is_directory: bool = False): + self.files[file_path] = is_directory + + def get_paths(self, path: str = None): + paths = [ + PathProperties(name=file_path, is_directory=is_directory) + for file_path, is_directory in self.files.items() + if path is None or file_path.startswith(path) + ] + return ItemPaged(paths) + + +FILE_SYSTEMS: Dict[str, FileSystemClient] = {} + + +def create_file_system(name: str): + if name not in FILE_SYSTEMS: + FILE_SYSTEMS[name] = FileSystemClient(name) + + +def load(config: Dict): + """ + Loads initial data into the backend based on OneLake configuration. + + Args: + config: Dictionary containing OneLake configuration with format: + { + "configuration": { + "workspace_name": {"value": str}, + "data_path": {"value": str}, + ... + } + } + """ + if not config.get("configuration"): + raise ValueError("Invalid configuration format") + + conf = config["configuration"] + workspace_name = conf["workspace_name"]["value"] + data_path = conf["data_path"]["value"] + + create_file_system(workspace_name) + generate_test_data(workspace_name, data_path, file_count=10000) + + +def generate_test_data(file_system_name: str, folder_path: str, file_count: int): + create_file_system(file_system_name) + file_system = FILE_SYSTEMS[file_system_name] + + file_system.add_file(folder_path, is_directory=True) + + for i in range(file_count): + file_name = f"{folder_path}/file_{i}.txt" + file_system.add_file(file_name) + + +def list_paths(file_system_name: str, folder_path: str): + if file_system_name not in FILE_SYSTEMS: + return [] + + file_system = FILE_SYSTEMS[file_system_name] + return file_system.get_paths(path=folder_path) diff --git a/tests/sources/test_onelake.py b/tests/sources/test_onelake.py new file mode 100644 index 000000000..aa37bf790 --- /dev/null +++ b/tests/sources/test_onelake.py @@ -0,0 +1,606 @@ +from contextlib import asynccontextmanager +from datetime import datetime +from functools import partial +from unittest.mock import AsyncMock, MagicMock, Mock, call, patch + +import pytest + +from connectors.sources.onelake import OneLakeDataSource +from tests.sources.support import create_source + + +@asynccontextmanager +async def create_abs_source( + use_text_extraction_service=False, +): + async with create_source( + OneLakeDataSource, + tenant_id="fake-tenant", + client_id="-fake-client", + client_secret="fake-client", + workspace_name="FakeWorkspace", + data_path="FakeDatalake.Lakehouse/Files/Data", + use_text_extraction_service=use_text_extraction_service, + ) as source: + yield source + + +@pytest.mark.asyncio +async def test_ping_for_successful_connection(): + """Test ping method of OneLakeDataSource class""" + + # Setup + async with create_abs_source() as source: + with patch.object( + source, "_get_directory_paths", new_callable=AsyncMock + ) as mock_get_paths: + mock_get_paths.return_value = [] + + # Run + await source.ping() + + # Check + mock_get_paths.assert_called_once_with(source.configuration["data_path"]) + + +@pytest.mark.asyncio +async def test_ping_for_failed_connection(): + """Test ping method of OneLakeDataSource class with negative case""" + + # Setup + async with create_abs_source() as source: + with patch.object( + source, "_get_directory_paths", new_callable=AsyncMock + ) as mock_get_paths: + mock_get_paths.side_effect = Exception("Something went wrong") + + # Run & Check + with pytest.raises(Exception, match="Something went wrong"): + await source.ping() + + mock_get_paths.assert_called_once_with(source.configuration["data_path"]) + + # Cleanup + mock_get_paths.reset_mock + + +@pytest.mark.asyncio +async def test_get_account_url(): + """Test _get_account_url method of OneLakeDataSource class""" + + # Setup + async with create_abs_source() as source: + account_name = source.configuration["account_name"] + expected_url = f"https://{account_name}.dfs.fabric.microsoft.com" + + # Run + actual_url = source._get_account_url() + + # Check + assert actual_url == expected_url + + +@pytest.mark.asyncio +async def test_get_token_credentials(): + """Test _get_token_credentials method of OneLakeDataSource class""" + + # Setup + async with create_abs_source() as source: + tenant_id = source.configuration["tenant_id"] + client_id = source.configuration["client_id"] + client_secret = source.configuration["client_secret"] + + with patch( + "connectors.sources.onelake.ClientSecretCredential", autospec=True + ) as mock_credential: + mock_instance = mock_credential.return_value + + # Run + credentials = source._get_token_credentials() + + # Check + mock_credential.assert_called_once_with(tenant_id, client_id, client_secret) + assert credentials is mock_instance + + +@pytest.mark.asyncio +async def test_get_token_credentials_error(): + """Test _get_token_credentials method when credential creation fails""" + + async with create_abs_source() as source: + with patch( + "connectors.sources.onelake.ClientSecretCredential", autospec=True + ) as mock_credential: + mock_credential.side_effect = Exception("Credential error") + + with pytest.raises(Exception, match="Credential error"): + source._get_token_credentials() + + +@pytest.mark.asyncio +async def test_get_service_client(): + """Test _get_service_client method of OneLakeDataSource class""" + + # Setup + async with create_abs_source() as source: + mock_service_client = Mock() + mock_account_url = "https://mockaccount.dfs.fabric.microsoft.com" + mock_credentials = Mock() + + with patch( + "connectors.sources.onelake.DataLakeServiceClient", + autospec=True, + ) as mock_client, patch.object( + source, + "_get_account_url", + return_value=mock_account_url, + ), patch.object( + source, "_get_token_credentials", return_value=mock_credentials + ): + mock_client.return_value = mock_service_client + + # Run + service_client = await source._get_service_client() + + # Check + mock_client.assert_called_once_with( + account_url=mock_account_url, + credential=mock_credentials, + ) + assert service_client is mock_service_client + + +@pytest.mark.asyncio +async def test_get_service_client_error(): + """Test _get_service_client method when client creation fails""" + + async with create_abs_source() as source: + with patch( + "connectors.sources.onelake.DataLakeServiceClient", + side_effect=Exception("Service client error"), + ): + with pytest.raises(Exception, match="Service client error"): + await source._get_service_client() + + +@pytest.mark.asyncio +async def test_get_file_system_client(): + """Test _get_file_system_client method of OneLakeDataSource class""" + + # Setup + async with create_abs_source() as source: + mock_file_system_client = Mock() + workspace_name = source.configuration["workspace_name"] + + with patch.object( + source, "_get_service_client", new_callable=AsyncMock + ) as mock_get_service_client: + mock_service_client = Mock() + mock_service_client.get_file_system_client.return_value = ( + mock_file_system_client + ) + mock_get_service_client.return_value = mock_service_client + + # Run + file_system_client = await source._get_file_system_client() + + # Check + mock_service_client.get_file_system_client.assert_called_once_with( + workspace_name + ) + assert file_system_client == mock_file_system_client + + +@pytest.mark.asyncio +async def test_get_file_system_client_error(): + """Test _get_file_system_client method when client creation fails""" + + async with create_abs_source() as source: + mock_service_client = Mock() + mock_service_client.get_file_system_client.side_effect = Exception( + "File system error" + ) + + with patch.object( + source, "_get_service_client", new_callable=AsyncMock + ) as mock_get_service_client: + mock_get_service_client.return_value = mock_service_client + + with pytest.raises(Exception, match="File system error"): + await source._get_file_system_client() + + +@pytest.mark.asyncio +async def test_get_directory_client(): + """Test _get_directory_client method of OneLakeDataSource class""" + + # Setup + async with create_abs_source() as source: + mock_directory_client = Mock() + data_path = source.configuration["data_path"] + + with patch.object( + source, "_get_file_system_client", new_callable=AsyncMock + ) as mock_get_file_system_client: + mock_file_system_client = Mock() + mock_file_system_client.get_directory_client.return_value = ( + mock_directory_client + ) + mock_get_file_system_client.return_value = mock_file_system_client + + # Run + directory_client = await source._get_directory_client() + + # Check + mock_file_system_client.get_directory_client.assert_called_once_with( + data_path + ) + assert directory_client == mock_directory_client + + +@pytest.mark.asyncio +async def test_get_directory_client_error(): + """Test _get_directory_client method when client creation fails""" + + async with create_abs_source() as source: + mock_file_system_client = Mock() + mock_file_system_client.get_directory_client.side_effect = Exception( + "Directory error" + ) + + with patch.object( + source, "_get_file_system_client", new_callable=AsyncMock + ) as mock_get_file_system_client: + mock_get_file_system_client.return_value = mock_file_system_client + + with pytest.raises(Exception, match="Directory error"): + await source._get_directory_client() + + +@pytest.mark.asyncio +async def test_get_file_client_success(): + """Test successful file client retrieval""" + + mock_file_client = Mock() + mock_directory_client = Mock() + mock_directory_client.get_file_client.return_value = mock_file_client + + async with create_abs_source() as source: + with patch.object( + source, "_get_directory_client", new_callable=AsyncMock + ) as mock_get_directory: + mock_get_directory.return_value = mock_directory_client + + result = await source._get_file_client("test.txt") + + assert result == mock_file_client + mock_directory_client.get_file_client.assert_called_once_with("test.txt") + + +@pytest.mark.asyncio +async def test_get_file_client_error(): + """Test file client retrieval with error""" + + async with create_abs_source() as source: + with patch.object( + source, "_get_directory_client", new_callable=AsyncMock + ) as mock_get_directory: + mock_get_directory.side_effect = Exception("Test error") + + with pytest.raises(Exception, match="Test error"): + await source._get_file_client("test.txt") + + +@pytest.mark.asyncio +async def test_get_directory_paths(): + """Test _get_directory_paths method of OneLakeDataSource class""" + + # Setup + async with create_abs_source() as source: + mock_paths = ["path1", "path2"] + directory_path = "mock_directory_path" + + with patch.object( + source, "_get_file_system_client", new_callable=AsyncMock + ) as mock_get_file_system_client: + mock_get_paths = Mock(return_value=mock_paths) + mock_file_system_client = mock_get_file_system_client.return_value + mock_file_system_client.get_paths = mock_get_paths + + # Run + paths = await source._get_directory_paths(directory_path) + + # Check + mock_file_system_client.get_paths.assert_called_once_with( + path=directory_path + ) + assert paths == mock_paths + + +@pytest.mark.asyncio +async def test_get_directory_paths_error(): + """Test _get_directory_paths method when getting paths fails""" + + async with create_abs_source() as source: + directory_path = "mock_directory_path" + with patch.object( + source, "_get_file_system_client", new_callable=AsyncMock + ) as mock_get_file_system_client: + mock_file_system_client = mock_get_file_system_client.return_value + mock_file_system_client.get_paths = AsyncMock( + side_effect=Exception("Path error") + ) + + with pytest.raises(Exception, match="Path error"): + await source._get_directory_paths(directory_path) + + +@pytest.mark.asyncio +async def test_format_file(): + """Test format_file method of OneLakeDataSource class""" + + # Setup + async with create_abs_source() as source: + mock_file_client = MagicMock() + mock_file_properties = MagicMock( + creation_time=datetime(2022, 4, 21, 12, 12, 30), + last_modified=datetime(2022, 4, 22, 15, 45, 10), + size=2048, + name="path/to/file.txt", + ) + + mock_file_properties.name.split.return_value = ["path", "to", "file.txt"] + mock_file_client.get_file_properties.return_value = mock_file_properties + mock_file_client.file_system_name = "my_file_system" + + expected_output = { + "_id": "my_file_system_file.txt", + "name": "file.txt", + "created_at": "2022-04-21T12:12:30", + "_timestamp": "2022-04-22T15:45:10", + "size": 2048, + } + + # Execute + actual_output = source.format_file(mock_file_client) + + # Assert + assert actual_output == expected_output + mock_file_client.get_file_properties.assert_called_once() + + +@pytest.mark.asyncio +async def test_format_file_error(): + """Test format_file method when getting properties fails""" + + async with create_abs_source() as source: + mock_file_client = MagicMock() + mock_file_client.get_file_properties.side_effect = Exception("Properties error") + mock_file_client.file_system_name = "my_file_system" + + with pytest.raises(Exception, match="Properties error"): + source.format_file(mock_file_client) + + +@pytest.mark.asyncio +async def test_format_file_empty_name(): + """Test format_file method with empty file name""" + + async with create_abs_source() as source: + mock_file_client = MagicMock() + mock_file_properties = MagicMock( + creation_time=datetime(2022, 4, 21, 12, 12, 30), + last_modified=datetime(2022, 4, 22, 15, 45, 10), + size=2048, + name="", + ) + mock_file_properties.name.split.return_value = [""] + mock_file_client.get_file_properties.return_value = mock_file_properties + mock_file_client.file_system_name = "my_file_system" + + result = source.format_file(mock_file_client) + assert result["name"] == "" + assert result["_id"] == "my_file_system_" + + +@pytest.mark.asyncio +async def test_download_file(): + """Test download_file method of OneLakeDataSource class""" + + # Setup + mock_file_client = Mock() + mock_download = Mock() + mock_file_client.download_file.return_value = mock_download + + mock_chunks = ["chunk1", "chunk2", "chunk3"] + + mock_download.chunks.return_value = iter(mock_chunks) + + async with create_abs_source() as source: + # Run + chunks = [] + async for chunk in source.download_file(mock_file_client): + chunks.append(chunk) + + # Check + assert chunks == mock_chunks + mock_file_client.download_file.assert_called_once() + mock_download.chunks.assert_called_once() + + +@pytest.mark.asyncio +async def test_download_file_with_error(): + """Test download_file method of OneLakeDataSource class with exception handling""" + + # Setup + mock_file_client = Mock() + mock_download = Mock() + mock_file_client.download_file.return_value = mock_download + mock_download.chunks.side_effect = Exception("Download error") + + async with create_abs_source() as source: + # Run & Check + with pytest.raises(Exception, match="Download error"): + async for _ in source.download_file(mock_file_client): + pass + + mock_file_client.download_file.assert_called_once() + mock_download.chunks.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_content_with_download(): + """Test get_content method when doit=True""" + + mock_configuration = { + "account_name": "mockaccount", + "tenant_id": "mocktenant", + "client_id": "mockclient", + "client_secret": "mocksecret", + "workspace_name": "mockworkspace", + "data_path": "mockpath", + } + + async with create_abs_source() as source: + source.configuration = mock_configuration + + class FileClientMock: + file_system_name = "mockfilesystem" + + class FileProperties: + def __init__(self, name, size): + self.name = name + self.size = size + + def get_file_properties(self): + return self.FileProperties(name="file1.txt", size=2000) + + with patch.object( + source, "_get_file_client", return_value=FileClientMock() + ), patch.object( + source, "can_file_be_downloaded", return_value=True + ), patch.object( + source, + "download_and_extract_file", + return_value={ + "_id": "mockfilesystem_file1.txt", + "_attachment": "TW9jayBjb250ZW50", + }, + ): + actual_response = await source.get_content("file1.txt", doit=True) + assert actual_response == { + "_id": "mockfilesystem_file1.txt", + "_attachment": "TW9jayBjb250ZW50", + } + + +@pytest.mark.asyncio +async def test_get_content_without_download(): + """Test get_content method when doit=False""" + + async with create_abs_source() as source: + source.configuration = { + "account_name": "mockaccount", + "tenant_id": "mocktenant", + "client_id": "mockclient", + "client_secret": "mocksecret", + "workspace_name": "mockworkspace", + "data_path": "mockpath", + } + + class FileClientMock: + file_system_name = "mockfilesystem" + + class FileProperties: + def __init__(self, name, size): + self.name = name + self.size = size + + def get_file_properties(self): + return self.FileProperties(name="file1.txt", size=2000) + + with patch.object(source, "_get_file_client", return_value=FileClientMock()): + actual_response = await source.get_content("file1.txt", doit=False) + assert actual_response is None + + +@pytest.mark.asyncio +async def test_prepare_files(): + """Test prepare_files method of OneLakeDataSource class""" + + # Setup + doc_paths = [ + Mock( + name="doc1", + **{"name.split.return_value": ["folder", "doc1"], "path": "folder/doc1"}, + ), + Mock( + name="doc2", + **{"name.split.return_value": ["folder", "doc2"], "path": "folder/doc2"}, + ), + ] + mock_field_client = Mock() + + async def mock_format_file(*args, **kwargs): + """Mock for the format_file method""" + + return "file_document", "partial_function" + + async with create_abs_source() as source: + with patch.object( + source, "_get_file_client", new_callable=AsyncMock + ) as mock_get_file_client: + mock_get_file_client.return_value = mock_field_client + + with patch.object(source, "format_file", side_effect=mock_format_file): + result = [] + # Run + async for item in source.prepare_files(doc_paths): + result.append(await item) + + # Check results + assert result == [ + ("file_document", "partial_function"), + ("file_document", "partial_function"), + ] + + mock_get_file_client.assert_has_calls([call("doc1"), call("doc2")]) + + +@pytest.mark.asyncio +async def test_get_docs(): + """Test get_docs method of OneLakeDataSource class""" + + mock_paths = [ + Mock(name="doc1", path="folder/doc1"), + Mock(name="doc2", path="folder/doc2"), + ] + + mock_file_docs = [{"name": "doc1", "id": "1"}, {"name": "doc2", "id": "2"}] + + async def mock_prepare_files_impl(paths): + for doc in mock_file_docs: + yield doc + + async with create_abs_source() as source: + with patch.object( + source, "_get_directory_paths", new_callable=AsyncMock + ) as mock_get_paths: + mock_get_paths.return_value = mock_paths + + with patch.object( + source, "prepare_files", side_effect=mock_prepare_files_impl + ): + result = [] + async for doc, get_content in source.get_docs(): + result.append((doc, get_content)) + + mock_get_paths.assert_called_once_with( + source.configuration["data_path"] + ) + assert len(result) == 2 + for (doc, get_content), expected_doc in zip(result, mock_file_docs): + assert doc == expected_doc + assert isinstance(get_content, partial) + assert get_content.func == source.get_content + assert get_content.args == (doc["name"],)