Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨feat(source-microsoft-sharepoint): add all sites iteration #55912

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ data:
connectorSubtype: file
connectorType: source
definitionId: 59353119-f0f2-4e5a-a8ba-15d887bc34f6
dockerImageTag: 0.8.1
dockerImageTag: 0.9.0
dockerRepository: airbyte/source-microsoft-sharepoint
githubIssueLabel: source-microsoft-sharepoint
icon: microsoft-sharepoint.svg
Expand Down
196 changes: 106 additions & 90 deletions airbyte-integrations/connectors/source-microsoft-sharepoint/poetry.lock

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = [ "poetry-core>=1.0.0",]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
version = "0.8.1"
version = "0.9.0"
name = "source-microsoft-sharepoint"
description = "Source implementation for Microsoft SharePoint."
authors = [ "Airbyte <[email protected]>",]
Expand All @@ -17,10 +17,10 @@ include = "source_microsoft_sharepoint"

[tool.poetry.dependencies]
python = "^3.11,<3.12"
msal = "==1.25.0"
msal = "==1.27.0"
Office365-REST-Python-Client = "==2.5.5"
smart-open = "==6.4.0"
airbyte-cdk = {extras = ["file-based"], version = "^6"}
airbyte-cdk = {extras = ["file-based"], version = "^6.38.5"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set this back to ^6 so that it gets the CDK updates?


[tool.poetry.scripts]
source-microsoft-sharepoint = "source_microsoft_sharepoint.run:run"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@
from io import IOBase
from os import makedirs, path
from os.path import getsize
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Tuple

import requests
import smart_open
from msal import ConfidentialClientApplication
from office365.entity_collection import EntityCollection
from office365.graph_client import GraphClient
from office365.onedrive.drives.drive import Drive
from office365.runtime.auth.token_response import TokenResponse
from office365.sharepoint.client_context import ClientContext
from office365.sharepoint.search.service import SearchService

from airbyte_cdk import AirbyteTracedException, FailureType
from airbyte_cdk.sources.file_based.exceptions import FileSizeLimitError
Expand All @@ -24,7 +29,18 @@
from source_microsoft_sharepoint.spec import SourceMicrosoftSharePointSpec

from .exceptions import ErrorDownloadingFile, ErrorFetchingMetadata
from .utils import FolderNotFoundException, MicrosoftSharePointRemoteFile, execute_query_with_retry, filter_http_urls
from .utils import (
FolderNotFoundException,
MicrosoftSharePointRemoteFile,
execute_query_with_retry,
filter_http_urls,
get_site,
get_site_prefix,
)


SITE_TITLE = "Title"
SITE_PATH = "Path"


class SourceMicrosoftSharePointClient:
Expand All @@ -50,9 +66,20 @@ def client(self):
self._client = GraphClient(self._get_access_token)
return self._client

def _get_access_token(self):
@staticmethod
def _get_scope(tenant_prefix: str = None):
"""
Returns the scope for the access token.
We use admin site to retrieve objects like Sites.
"""
if tenant_prefix:
admin_site_url = f"https://{tenant_prefix}-admin.sharepoint.com"
return [f"{admin_site_url}/.default"]
return ["https://graph.microsoft.com/.default"]

def _get_access_token(self, tenant_prefix: str = None):
"""Retrieves an access token for SharePoint access."""
scope = ["https://graph.microsoft.com/.default"]
scope = self._get_scope(tenant_prefix)
refresh_token = self.config.credentials.refresh_token if hasattr(self.config.credentials, "refresh_token") else None

if refresh_token:
Expand All @@ -67,6 +94,13 @@ def _get_access_token(self):

return result

def get_token_response_object_wrapper(self, tenant_prefix: str):
def get_token_response_object():
token = self._get_access_token(tenant_prefix=tenant_prefix)
return TokenResponse.from_json(token)

return get_token_response_object


class SourceMicrosoftSharePointStreamReader(AbstractFileBasedStreamReader):
"""
Expand Down Expand Up @@ -103,6 +137,20 @@ def get_access_token(self):
# Directly fetch a new access token from the auth_client each time it's called
return self.auth_client._get_access_token()["access_token"]

def get_token_response_object(self, tenant_prefix: str = None) -> Callable:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like if we set the value to None by default, this should be optional. However, I don't see a case where tenant_prefix is not provided in the current code. What am I missing?

I see the _get_scope supports this case but I'm not sure when this can happen

"""
When building a ClientContext using with_access_token() method,
the token_func param is expected to be a method/callable that returns a TokenResponse object.
tenant_prefix is used to determine the scope of the access token.
return: A callable that returns a TokenResponse object.
"""
return self.auth_client.get_token_response_object_wrapper(tenant_prefix=tenant_prefix)

def get_client_context(self):
site_url, root_site_prefix = get_site_prefix(get_site(self.one_drive_client))
client_context = ClientContext(site_url).with_access_token(self.get_token_response_object(tenant_prefix=root_site_prefix))
return client_context

@config.setter
def config(self, value: SourceMicrosoftSharePointSpec):
"""
Expand Down Expand Up @@ -202,11 +250,67 @@ def _get_files_by_drive_name(self, drives, folder_path):

yield from self._list_directories_and_files(folder, folder_path_url)

def get_all_sites(self) -> List[MutableMapping[str, Any]]:
"""
Retrieves all SharePoint sites from the current tenant.

Returns:
List[MutableMapping[str, Any]]: A list of site information.
"""
_, root_site_prefix = get_site_prefix(get_site(self.one_drive_client))
ctx = self.get_client_context()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like get_site_prefix is also called within get_client_context. Since get_client_context should be private, should we pass the output of get_site_prefix as arguments for get_client_config? It would prevent us from calling this twice in the same context

search_service = SearchService(ctx)
# ignore default OneDrive site with NOT Path:https://prefix-my.sharepoint.com
search_job = search_service.post_query(f"contentclass:STS_Site NOT Path:https://{root_site_prefix}-my.sharepoint.com")
search_job_result = execute_query_with_retry(search_job)

found_sites = []
if search_job.value and search_job_result.value.PrimaryQueryResult:
table = search_job_result.value.PrimaryQueryResult.RelevantResults.Table
for row in table.Rows:
found_site = {}
data = row.Cells
found_site[SITE_TITLE] = data.get(SITE_TITLE)
found_site[SITE_PATH] = data.get(SITE_PATH)
found_sites.append(found_site)
else:
raise Exception("No site collections found")

return found_sites

def get_drives_from_sites(self, sites: List[MutableMapping[str, Any]]) -> EntityCollection:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a private method, right? Should we prefix the name by _?

"""
Retrieves SharePoint drives from the provided sites.
Args:
sites (List[MutableMapping[str, Any]]): A list of site information.

Returns:
EntityCollection: A collection of SharePoint drives.
"""
all_sites_drives = EntityCollection(context=self.one_drive_client, item_type=Drive)
for site in sites:
drives = execute_query_with_retry(self.one_drive_client.sites.get_by_url(site[SITE_PATH]).drives.get())
for site_drive in drives:
all_sites_drives.add_child(site_drive)
return all_sites_drives

def get_site_drive(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this was outside of the scope of the PR but can we add typing here?

"""
Retrieves SharePoint drives based on the provided site URL.
It iterates over the sites if something like sharepoint.com/sites/ is in the site_url.
Returns:
EntityCollection: A collection of SharePoint drives.

Raises:
AirbyteTracedException: If an error occurs while retrieving drives.
"""
try:
if not self.config.site_url:
# get main site drives
drives = execute_query_with_retry(self.one_drive_client.drives.get())
elif re.search(r"sharepoint\.com/sites/?$", self.config.site_url):
# get all sites and then get drives from each site
return self.get_drives_from_sites(self.get_all_sites())
else:
# get drives for site drives provided in the config
drives = execute_query_with_retry(self.one_drive_client.sites.get_by_url(self.config.site_url).drives.get())
Expand Down Expand Up @@ -398,17 +502,3 @@ def get_file(self, file: MicrosoftSharePointRemoteFile, local_directory: str, lo
raise AirbyteTracedException(
f"There was an error while trying to download the file {file.uri}: {str(e)}", failure_type=FailureType.config_error
)

def get_file_acl_permissions(self):
return None

def load_identity_groups(self):
return None

@property
def identities_schema(self):
return None

@property
def file_permissions_schema(self):
return None
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import time
from datetime import datetime
from enum import Enum
from functools import lru_cache
from http import HTTPStatus
from typing import List

from office365.graph_client import GraphClient
from office365.onedrive.sites.site import Site

from airbyte_cdk import AirbyteTracedException, FailureType
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
Expand Down Expand Up @@ -150,3 +155,22 @@ def build(self) -> str:
query_string = "&".join(self._segments)
query_string = "?" + query_string if query_string else ""
return f"{self._scheme}://{self._host}{self._path}{query_string}"


@lru_cache(maxsize=None)
def get_site(graph_client: GraphClient, site_url: str = None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like get_site is always called with get_site_prefix. Should we make this private to avoid exposing more things than needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how the lru_cache works but it seems like GraphClient does not implement __eq__ which I assume will only hit the cache if it is the same reference object. Is this what we want here?

if site_url:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I know some tools would consider not returning immediately as a code smell (example). I would refactor this as such:

def get_site(graph_client: GraphClient, site_url: str = None):
    if site_url:
        return execute_query_with_retry(graph_client.sites.get_by_url(site_url))
    return execute_query_with_retry(graph_client.sites.root.get())

site = execute_query_with_retry(graph_client.sites.get_by_url(site_url))
else:
site = execute_query_with_retry(graph_client.sites.root.get())
return site


def get_site_prefix(site: Site):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add return types for these added functions?

site_url = site.web_url
host_name = site.site_collection.hostname
host_name_parts: List = host_name.split(".") # e.g. "contoso.sharepoint.com" => ["contoso", "sharepoint", "com"]
if len(host_name_parts) < 2:
raise ValueError(f"Invalid host name: {host_name}")

return site_url, host_name_parts[0]
Loading
Loading