enable resolve for query and json parameters
root committed Jan 13, 2025
1 parent cbcff92 commit fae5ae3
Showing 4 changed files with 293 additions and 65 deletions.
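This commit lifts the restriction that every "resolve" param must be bound as a path placeholder (the NotImplementedError removed from _bind_path_params below said "Resolve query params not supported yet"): resolved parent values can now also be passed as query parameters and referenced as {placeholder} fields inside the JSON request body. A minimal sketch of a configuration this is meant to enable, assuming the public rest_api_source config format; the API, resource names and fields are illustrative assumptions, not taken from the commit:

from dlt.sources.rest_api import rest_api_source

source = rest_api_source(
    {
        "client": {"base_url": "https://api.example.com/"},
        "resources": [
            "posts",
            {
                "name": "post_details",
                "endpoint": {
                    # no {placeholder} in the path: before this change such a
                    # config failed with "Resolve query params not supported yet."
                    "path": "posts/search",
                    "method": "POST",
                    "params": {
                        # resolved from the parent; not referenced as a placeholder
                        # anywhere, so it is passed through as a query parameter
                        "author_id": {
                            "type": "resolve",
                            "resource": "posts",
                            "field": "author_id",
                        },
                        # resolved from the parent and bound into the JSON body below
                        "post_id": {
                            "type": "resolve",
                            "resource": "posts",
                            "field": "id",
                        },
                    },
                    # flat body whose "{post_id}" placeholder is filled per parent item
                    "json": {"id": "{post_id}", "fields": "title,body"},
                },
            },
        ],
    }
)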
48 changes: 37 additions & 11 deletions dlt/sources/rest_api/__init__.py
@@ -1,4 +1,5 @@
"""Generic API Source"""

from copy import deepcopy
from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union
import graphlib
@@ -68,7 +69,11 @@ def rest_api(
) -> List[DltResource]:
"""Creates and configures a REST API source with default settings"""
return rest_api_resources(
{"client": client, "resources": resources, "resource_defaults": resource_defaults}
{
"client": client,
"resources": resources,
"resource_defaults": resource_defaults,
}
)


@@ -250,7 +255,9 @@ def create_resources(

resolved_params: List[ResolvedParam] = resolved_param_map[resource_name]

include_from_parent: List[str] = endpoint_resource.get("include_from_parent", [])
include_from_parent: List[str] = endpoint_resource.get(
"include_from_parent", []
)
if not resolved_params and include_from_parent:
raise ValueError(
f"Resource {resource_name} has include_from_parent but is not "
@@ -273,7 +280,9 @@ def create_resources(

hooks = create_response_hooks(endpoint_config.get("response_actions"))

resource_kwargs = exclude_keys(endpoint_resource, {"endpoint", "include_from_parent"})
resource_kwargs = exclude_keys(
endpoint_resource, {"endpoint", "include_from_parent"}
)

def process(
resource: DltResource,
@@ -334,18 +343,23 @@ def paginate_resource(
hooks=hooks,
)

resources[resource_name] = process(resources[resource_name], processing_steps)
resources[resource_name] = process(
resources[resource_name], processing_steps
)

else:
first_param = resolved_params[0]
predecessor = resources[first_param.resolve_config["resource"]]

base_params = exclude_keys(request_params, {x.param_name for x in resolved_params})
base_params = exclude_keys(
request_params, {x.param_name for x in resolved_params}
)

def paginate_dependent_resource(
items: List[Dict[str, Any]],
method: HTTPMethodBasic,
path: str,
request_json: Optional[Dict[str, Any]],
params: Dict[str, Any],
paginator: Optional[BasePaginator],
data_selector: Optional[jsonpath.TJsonPath],
@@ -368,14 +382,22 @@ def paginate_dependent_resource(
)

for item in items:
formatted_path, parent_record = process_parent_data_item(
path, item, resolved_params, include_from_parent
formatted_path, parent_record, updated_params, updated_json = (
process_parent_data_item(
path=path,
item=item,
params=params,
request_json=request_json,
resolved_params=resolved_params,
include_from_parent=include_from_parent,
)
)

for child_page in client.paginate(
method=method,
path=formatted_path,
params=params,
params=updated_params,
json=updated_json,
paginator=paginator,
data_selector=data_selector,
hooks=hooks,
@@ -393,12 +415,15 @@
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
params=base_params,
request_json=request_json,
paginator=paginator,
data_selector=endpoint_config.get("data_selector"),
hooks=hooks,
)

resources[resource_name] = process(resources[resource_name], processing_steps)
resources[resource_name] = process(
resources[resource_name], processing_steps
)

return resources

@@ -435,7 +460,8 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig:
has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS)
if (
isinstance(
auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials)
auth_config,
(APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials),
)
or has_sensitive_key
):
@@ -481,7 +507,7 @@ def identity_func(x: Any) -> Any:


def _validate_param_type(
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]],
) -> None:
for _, value in request_params.items():
if isinstance(value, dict) and value.get("type") not in PARAM_TYPES:
144 changes: 97 additions & 47 deletions dlt/sources/rest_api/config_setup.py
@@ -22,6 +22,7 @@
from dlt.common.utils import update_dict_nested, exclude_keys
from dlt.common.typing import add_value_to_literal
from dlt.common import jsonpath
from dlt.common import json

from dlt.extract.incremental import Incremental
from dlt.extract.utils import ensure_table_schema_columns
@@ -140,7 +141,9 @@ def create_paginator(
paginator_type = paginator_config.get("type", "auto")
paginator_class = get_paginator_class(paginator_type)
return (
paginator_class(**exclude_keys(paginator_config, {"type"})) if paginator_class else None
paginator_class(**exclude_keys(paginator_config, {"type"}))
if paginator_class
else None
)

return None
@@ -195,7 +198,9 @@ def create_auth(auth_config: Optional[AuthConfig]) -> Optional[AuthConfigBase]:
def setup_incremental_object(
request_params: Dict[str, Any],
incremental_config: Optional[IncrementalConfig] = None,
) -> Tuple[Optional[Incremental[Any]], Optional[IncrementalParam], Optional[Callable[..., Any]]]:
) -> Tuple[
Optional[Incremental[Any]], Optional[IncrementalParam], Optional[Callable[..., Any]]
]:
incremental_params: List[str] = []
for param_name, param_config in request_params.items():
if (
@@ -255,7 +260,7 @@ def setup_incremental_object(


def parse_convert_or_deprecated_transform(
config: Union[IncrementalConfig, Dict[str, Any]]
config: Union[IncrementalConfig, Dict[str, Any]],
) -> Optional[Callable[..., Any]]:
convert = config.get("convert", None)
deprecated_transform = config.get("transform", None)
@@ -300,7 +305,9 @@ def build_resource_dependency_graph(
named_resources = {rp.resolve_config["resource"] for rp in resolved_params}

if len(named_resources) > 1:
raise ValueError(f"Multiple parent resources for {resource_name}: {resolved_params}")
raise ValueError(
f"Multiple parent resources for {resource_name}: {resolved_params}"
)
elif len(named_resources) == 1:
# validate the first parameter (note the resource is the same for all params)
first_param = resolved_params[0]
@@ -340,9 +347,9 @@ def expand_and_index_resources(
_bind_path_params(endpoint_resource)

resource_name = endpoint_resource["name"]
assert isinstance(
resource_name, str
), f"Resource name must be a string, got {type(resource_name)}"
assert isinstance(resource_name, str), (
f"Resource name must be a string, got {type(resource_name)}"
)

if resource_name in endpoint_resource_map:
raise ValueError(f"Resource {resource_name} has already been defined")
@@ -392,37 +399,29 @@ def _bind_path_params(resource: EndpointResource) -> None:
assert isinstance(resource["endpoint"], dict) # type guard
resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])]
path = resource["endpoint"]["path"]
for format_ in string.Formatter().parse(path):
name = format_[1]
if name:
params = resource["endpoint"].get("params", {})
if name not in params and name not in path_params:
raise ValueError(
f"The path {path} defined in resource {resource['name']} requires param with"
f" name {name} but it is not found in {params}"
)
if name in resolve_params:
resolve_params.remove(name)
if name in params:
if not isinstance(params[name], dict):
# bind resolved param and pop it from endpoint
path_params[name] = params.pop(name)
else:
param_type = params[name].get("type")
if param_type != "resolve":
raise ValueError(
f"The path {path} defined in resource {resource['name']} tries to bind"
f" param {name} with type {param_type}. Paths can only bind 'resolve'"
" type params."
)
# resolved params are bound later
path_params[name] = "{" + name + "}"

if len(resolve_params) > 0:
raise NotImplementedError(
f"Resource {resource['name']} defines resolve params {resolve_params} that are not"
f" bound in path {path}. Resolve query params not supported yet."
)
for name in _get_placeholders(path):
params = resource["endpoint"].get("params", {})
if name not in params and name not in path_params:
raise ValueError(
f"The path {path} defined in resource {resource['name']} requires param with"
f" name {name} but it is not found in {params}"
)
if name in resolve_params:
resolve_params.remove(name)
if name in params:
if not isinstance(params[name], dict):
# bind resolved param and pop it from endpoint
path_params[name] = params.pop(name)
else:
param_type = params[name].get("type")
if param_type != "resolve":
raise ValueError(
f"The path {path} defined in resource {resource['name']} tries to bind"
f" param {name} with type {param_type}. Paths can only bind 'resolve'"
" type params."
)
# resolved params are bound later
path_params[name] = "{" + name + "}"

resource["endpoint"]["path"] = path.format(**path_params)

@@ -459,7 +458,10 @@ def _find_resolved_params(endpoint_config: Endpoint) -> List[ResolvedParam]:

def _action_type_unless_custom_hook(
action_type: Optional[str], custom_hook: Optional[List[Callable[..., Any]]]
) -> Union[Tuple[str, Optional[List[Callable[..., Any]]]], Tuple[None, List[Callable[..., Any]]],]:
) -> Union[
Tuple[str, Optional[List[Callable[..., Any]]]],
Tuple[None, List[Callable[..., Any]]],
]:
if custom_hook:
return (None, custom_hook)
return (action_type, None)
@@ -577,16 +579,53 @@ def remove_field(response: Response, *args, **kwargs) -> Response:
return None


def _get_placeholders(template: str) -> List[str]:
return [
field_name
for _, field_name, _, _ in string.Formatter().parse(template)
if field_name
]


def _bound_path_parameters(
path: str,
param_values: Dict[str, Any],
):
path_params = _get_placeholders(path)
bound_path = path.format(**param_values)

for param in path_params:
param_values.pop(param)

return bound_path, param_values


def _bound_json_parameters(
request_json: Optional[Dict[str, Any]],
param_values: Dict[str, Any],
):
json_as_template = "{" + json.dumps(request_json) + "}"
json_params = _get_placeholders(json_as_template)

bound_json = json_as_template.format(**param_values)

for param in json_params:
param_values.pop(param)

return json.loads(bound_json), param_values


def process_parent_data_item(
path: str,
item: Dict[str, Any],
params: Dict[str, Any],
request_json: Optional[Dict[str, Any]],
resolved_params: List[ResolvedParam],
include_from_parent: List[str],
) -> Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Dict[str, Any]]]:
parent_resource_name = resolved_params[0].resolve_config["resource"]

param_values = {}

params_values = {}
for resolved_param in resolved_params:
field_values = jsonpath.find_values(resolved_param.field_path, item)

@@ -599,9 +638,14 @@ def process_parent_data_item(
f" {', '.join(item.keys())}"
)

param_values[resolved_param.param_name] = field_values[0]
params_values[resolved_param.param_name] = field_values[0]

bound_path = path.format(**param_values)
bound_path, params_values = _bound_path_parameters(path, params_values)

if request_json:
request_json, params_values = _bound_json_parameters(
request_json, params_values
)

parent_record: Dict[str, Any] = {}
if include_from_parent:
@@ -615,7 +659,7 @@
)
parent_record[child_key] = item[parent_key]

return bound_path, parent_record
return bound_path, parent_record, params_values, request_json


def _merge_resource_endpoints(
Expand Down Expand Up @@ -646,14 +690,20 @@ def _merge_resource_endpoints(
**config_endpoint["params"],
}
# merge columns
if (default_columns := default_config.get("columns")) and (columns := config.get("columns")):
if (default_columns := default_config.get("columns")) and (
columns := config.get("columns")
):
# merge only native dlt formats, skip pydantic and others
if isinstance(columns, (list, dict)) and isinstance(default_columns, (list, dict)):
if isinstance(columns, (list, dict)) and isinstance(
default_columns, (list, dict)
):
# normalize columns
columns = ensure_table_schema_columns(columns)
default_columns = ensure_table_schema_columns(default_columns)
# merge columns with deep merging hints
config["columns"] = merge_columns(copy(default_columns), columns, merge_columns=True)
config["columns"] = merge_columns(
copy(default_columns), columns, merge_columns=True
)

# no need to deep merge resources
merged_resource: EndpointResource = {
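For reference, the JSON binding in the new _bound_json_parameters relies on a str.format trick: the body is serialized with json.dumps, wrapped in one extra pair of braces so the outermost JSON braces become escaped literals, formatted with the resolved values, and parsed back with json.loads. A standalone sketch of that behaviour, assuming placeholders sit in top-level string values of the body; the helper name and values are illustrative, not part of the commit:

import json
import string


def bind_json_placeholders(request_json: dict, values: dict) -> dict:
    # illustrative helper mirroring the approach in _bound_json_parameters above
    # json.dumps({"id": "{post_id}"}) -> '{"id": "{post_id}"}'
    # wrapping it in one more brace pair turns the outer JSON braces into
    # escaped literals ("{{" / "}}") for str.format, leaving only {post_id}
    # as a replacement field
    template = "{" + json.dumps(request_json) + "}"
    placeholders = [
        name for _, name, _, _ in string.Formatter().parse(template) if name
    ]
    bound = template.format(**{name: values[name] for name in placeholders})
    return json.loads(bound)


print(bind_json_placeholders({"id": "{post_id}", "fields": "title,body"}, {"post_id": 42}))
# -> {'id': '42', 'fields': 'title,body'}  (the bound value arrives as a string)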