Skip to content

Commit

Permalink
Enable resolve for query and json parameters (#2208)
Browse files Browse the repository at this point in the history
  • Loading branch information
francescomucio authored Jan 28, 2025
1 parent 852c515 commit aa5355d
Show file tree
Hide file tree
Showing 8 changed files with 545 additions and 89 deletions.
28 changes: 22 additions & 6 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generic API Source"""

from copy import deepcopy
from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union
import graphlib
Expand Down Expand Up @@ -68,7 +69,11 @@ def rest_api(
) -> List[DltResource]:
"""Creates and configures a REST API source with default settings"""
return rest_api_resources(
{"client": client, "resources": resources, "resource_defaults": resource_defaults}
{
"client": client,
"resources": resources,
"resource_defaults": resource_defaults,
}
)


Expand Down Expand Up @@ -346,6 +351,7 @@ def paginate_dependent_resource(
items: List[Dict[str, Any]],
method: HTTPMethodBasic,
path: str,
request_json: Optional[Dict[str, Any]],
params: Dict[str, Any],
paginator: Optional[BasePaginator],
data_selector: Optional[jsonpath.TJsonPath],
Expand All @@ -368,14 +374,22 @@ def paginate_dependent_resource(
)

for item in items:
formatted_path, parent_record = process_parent_data_item(
path, item, resolved_params, include_from_parent
formatted_path, parent_record, updated_params, updated_json = (
process_parent_data_item(
path=path,
item=item,
# params=params,
request_json=request_json,
resolved_params=resolved_params,
include_from_parent=include_from_parent,
)
)

for child_page in client.paginate(
method=method,
path=formatted_path,
params=params,
params=updated_params,
json=updated_json,
paginator=paginator,
data_selector=data_selector,
hooks=hooks,
Expand All @@ -393,6 +407,7 @@ def paginate_dependent_resource(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
params=base_params,
request_json=request_json,
paginator=paginator,
data_selector=endpoint_config.get("data_selector"),
hooks=hooks,
Expand Down Expand Up @@ -435,7 +450,8 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig:
has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS)
if (
isinstance(
auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials)
auth_config,
(APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials),
)
or has_sensitive_key
):
Expand Down Expand Up @@ -481,7 +497,7 @@ def identity_func(x: Any) -> Any:


def _validate_param_type(
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]],
) -> None:
for _, value in request_params.items():
if isinstance(value, dict) and value.get("type") not in PARAM_TYPES:
Expand Down
204 changes: 163 additions & 41 deletions dlt/sources/rest_api/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dlt.common.utils import update_dict_nested, exclude_keys
from dlt.common.typing import add_value_to_literal
from dlt.common import jsonpath
from dlt.common import json

from dlt.extract.incremental import Incremental
from dlt.extract.utils import ensure_table_schema_columns
Expand Down Expand Up @@ -255,7 +256,7 @@ def setup_incremental_object(


def parse_convert_or_deprecated_transform(
config: Union[IncrementalConfig, Dict[str, Any]]
config: Union[IncrementalConfig, Dict[str, Any]],
) -> Optional[Callable[..., Any]]:
convert = config.get("convert", None)
deprecated_transform = config.get("transform", None)
Expand Down Expand Up @@ -296,6 +297,16 @@ def build_resource_dependency_graph(
# find resolved parameters to connect dependent resources
resolved_params = _find_resolved_params(endpoint_resource["endpoint"])

# extract expressions from parameters that are strings
params_expressions = []
for param_value in endpoint_resource["endpoint"].get("params", {}).values():
# If param_value is a plain string (e.g. "{resources.berry.a_property}")
if isinstance(param_value, str):
extracted = _extract_expressions(param_value, "resources.")
params_expressions.extend(extracted)

resolved_params += _expressions_to_resolved_params(params_expressions)

# set of resources in resolved params
named_resources = {rp.resolve_config["resource"] for rp in resolved_params}

Expand Down Expand Up @@ -383,6 +394,15 @@ def _make_endpoint_resource(
return _merge_resource_endpoints(default_config, resource)


def _replace_expression(template: str, params: Dict[str, Any]) -> str:
"""This method is used to replace the expression in the templates
because the the str.format() doesn't like placeholders with dots.
"""
for p in params:
template = template.replace(f"{{{p}}}", str(params[p]))
return template


def _bind_path_params(resource: EndpointResource) -> None:
"""Binds params declared in path to params available in `params`. Pops the
bound params but. Params of type `resolve` and `incremental` are skipped
Expand All @@ -391,40 +411,35 @@ def _bind_path_params(resource: EndpointResource) -> None:
path_params: Dict[str, Any] = {}
assert isinstance(resource["endpoint"], dict) # type guard
resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])]

params = resource["endpoint"].get("params", {})
path = resource["endpoint"]["path"]
for format_ in string.Formatter().parse(path):
name = format_[1]
if name:
params = resource["endpoint"].get("params", {})
if name not in params and name not in path_params:
raise ValueError(
f"The path {path} defined in resource {resource['name']} requires param with"
f" name {name} but it is not found in {params}"
)
if name in resolve_params:
resolve_params.remove(name)
if name in params:
if not isinstance(params[name], dict):
# bind resolved param and pop it from endpoint
path_params[name] = params.pop(name)
else:
param_type = params[name].get("type")
if param_type != "resolve":
raise ValueError(
f"The path {path} defined in resource {resource['name']} tries to bind"
f" param {name} with type {param_type}. Paths can only bind 'resolve'"
" type params."
)
# resolved params are bound later
path_params[name] = "{" + name + "}"

if len(resolve_params) > 0:
raise NotImplementedError(
f"Resource {resource['name']} defines resolve params {resolve_params} that are not"
f" bound in path {path}. Resolve query params not supported yet."
)

resource["endpoint"]["path"] = path.format(**path_params)
for name in _extract_expressions(path):
if name not in params and name not in path_params and name not in resolve_params:
raise ValueError(
f"The path {path} defined in resource {resource['name']} requires param with"
f" name {name} but it is not found in {params}"
)
if name in resolve_params:
resolve_params.remove(name)
if name in params:
if not isinstance(params[name], dict):
# bind resolved param and pop it from endpoint
path_params[name] = params.pop(name)
else:
param_type = params[name].get("type")
if param_type != "resolve":
raise ValueError(
f"The path {path} defined in resource {resource['name']} tries to bind"
f" param {name} with type {param_type}. Paths can only bind 'resolve'"
" type params."
)
# resolved params are bound later
path_params[name] = "{" + name + "}"

# resource["endpoint"]["path"] = path.format(**path_params)
resource["endpoint"]["path"] = _replace_expression(path, path_params)


def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint:
Expand All @@ -450,12 +465,27 @@ def _find_resolved_params(endpoint_config: Endpoint) -> List[ResolvedParam]:
Resolved params are of type ResolveParamConfig (bound param with a key "type" set to "resolve".)
"""
return [

resolved_params = [
ResolvedParam(key, value) # type: ignore[arg-type]
for key, value in endpoint_config.get("params", {}).items()
if (isinstance(value, dict) and value.get("type") == "resolve")
]

path_expressions = _extract_expressions(endpoint_config["path"], "resources.")

json_expressions = (
_extract_expressions(endpoint_config["json"], "resources.")
if endpoint_config.get("json")
else []
)

resolved_params += _expressions_to_resolved_params(
path_expressions
) + _expressions_to_resolved_params(json_expressions)

return resolved_params


def _action_type_unless_custom_hook(
action_type: Optional[str], custom_hook: Optional[List[Callable[..., Any]]]
Expand Down Expand Up @@ -577,31 +607,115 @@ def remove_field(response: Response, *args, **kwargs) -> Response:
return None


def _extract_expressions(
template: Union[str, Dict[str, Any]],
prefix: str = "",
) -> List[str]:
"""Takes a template string and extracts expressions that start with a prefix.
Args:
template (str): A string with expressions to extract
prefix (str): A string that marks the beginning of an expression
Example:
>>> _extract_expressions("blog/{resources.blog.id}/comments", "resources.")
["resources.blog.id"]
"""

expressions = set()

def recursive_search(value: Union[str, List[Any], Dict[str, Any]]) -> None:
if isinstance(value, dict):
for key, val in value.items():
recursive_search(key)
recursive_search(val)
elif isinstance(value, list):
for item in value:
recursive_search(item)
elif isinstance(value, str):
e = [
field_name
for _, field_name, _, _ in string.Formatter().parse(value)
if field_name and field_name.startswith(prefix)
]
expressions.update(e)

recursive_search(template)
return list(expressions)


def _expressions_to_resolved_params(expressions: List[str]) -> List[ResolvedParam]:
resolved_params = []
# We assume that the expressions are in the format 'resources.<resource>.<field>'
# and not more complex expressions
for expression in expressions:
parts = expression.strip().split(".")
if len(parts) != 3:
raise ValueError(
f"Invalid definition of {expression}. Expected format:"
" 'resources.<resource>.<field>'"
)
resolved_params.append(
ResolvedParam(
expression,
{
"type": "resolve",
"resource": parts[1],
"field": parts[2],
},
)
)
return resolved_params


def _bound_path_parameters(
path: str,
param_values: Dict[str, Any],
) -> Tuple[str, List[str]]:
path_params = _extract_expressions(path)
bound_path = _replace_expression(path, param_values)

return bound_path, path_params


def _bound_json_parameters(
request_json: Dict[str, Any],
param_values: Dict[str, Any],
) -> Tuple[Dict[str, Any], List[str]]:
json_params = _extract_expressions(request_json)
bound_json = _replace_expression(json.dumps(request_json), param_values)

return json.loads(bound_json), json_params


def process_parent_data_item(
path: str,
item: Dict[str, Any],
# params: Dict[str, Any],,
resolved_params: List[ResolvedParam],
include_from_parent: List[str],
) -> Tuple[str, Dict[str, Any]]:
request_json: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
parent_resource_name = resolved_params[0].resolve_config["resource"]

param_values = {}

params_values = {}
for resolved_param in resolved_params:
field_values = jsonpath.find_values(resolved_param.field_path, item)

if not field_values:
field_path = resolved_param.resolve_config["field"]
raise ValueError(
f"Transformer expects a field '{field_path}' to be present in the incoming data"
f" from resource {parent_resource_name} in order to bind it to path param"
f" from resource {parent_resource_name} in order to bind it to param"
f" {resolved_param.param_name}. Available parent fields are"
f" {', '.join(item.keys())}"
)

param_values[resolved_param.param_name] = field_values[0]
params_values[resolved_param.param_name] = field_values[0]

bound_path, path_params = _bound_path_parameters(path, params_values)

bound_path = path.format(**param_values)
json_params: List[str] = []
if request_json:
request_json, json_params = _bound_json_parameters(request_json, params_values)

parent_record: Dict[str, Any] = {}
if include_from_parent:
Expand All @@ -615,7 +729,15 @@ def process_parent_data_item(
)
parent_record[child_key] = item[parent_key]

return bound_path, parent_record
# the params not present in the params already bound,
# will be returned and used as query params
params_values = {
param_name: param_value
for param_name, param_value in params_values.items()
if param_name not in path_params and param_name not in json_params
}

return bound_path, parent_record, params_values, request_json


def _merge_resource_endpoints(
Expand Down
4 changes: 2 additions & 2 deletions tests/load/sources/rest_api/test_rest_api_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def test_rest_api_source(destination_config: DestinationTestConfiguration, reque

assert table_counts.keys() == {"pokemon_list", "berry", "location"}

assert table_counts["pokemon_list"] == 1302
assert table_counts["pokemon_list"] == 1304
assert table_counts["berry"] == 64
assert table_counts["location"] == 1036
assert table_counts["location"] == 1039


@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit aa5355d

Please sign in to comment.