Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable resolve for query and json parameters #2208

Merged
28 changes: 22 additions & 6 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Generic API Source"""

from copy import deepcopy
from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union
import graphlib
Expand Down Expand Up @@ -68,7 +69,11 @@ def rest_api(
) -> List[DltResource]:
"""Creates and configures a REST API source with default settings"""
return rest_api_resources(
{"client": client, "resources": resources, "resource_defaults": resource_defaults}
{
"client": client,
"resources": resources,
"resource_defaults": resource_defaults,
}
)


Expand Down Expand Up @@ -346,6 +351,7 @@ def paginate_dependent_resource(
items: List[Dict[str, Any]],
method: HTTPMethodBasic,
path: str,
request_json: Optional[Dict[str, Any]],
params: Dict[str, Any],
paginator: Optional[BasePaginator],
data_selector: Optional[jsonpath.TJsonPath],
Expand All @@ -368,14 +374,22 @@ def paginate_dependent_resource(
)

for item in items:
formatted_path, parent_record = process_parent_data_item(
path, item, resolved_params, include_from_parent
formatted_path, parent_record, updated_params, updated_json = (
process_parent_data_item(
path=path,
item=item,
# params=params,
request_json=request_json,
resolved_params=resolved_params,
include_from_parent=include_from_parent,
)
)

for child_page in client.paginate(
method=method,
path=formatted_path,
params=params,
params=updated_params,
json=updated_json,
paginator=paginator,
data_selector=data_selector,
hooks=hooks,
Expand All @@ -393,6 +407,7 @@ def paginate_dependent_resource(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
params=base_params,
request_json=request_json,
paginator=paginator,
data_selector=endpoint_config.get("data_selector"),
hooks=hooks,
Expand Down Expand Up @@ -435,7 +450,8 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig:
has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS)
if (
isinstance(
auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials)
auth_config,
(APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials),
)
or has_sensitive_key
):
Expand Down Expand Up @@ -481,7 +497,7 @@ def identity_func(x: Any) -> Any:


def _validate_param_type(
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]],
) -> None:
for _, value in request_params.items():
if isinstance(value, dict) and value.get("type") not in PARAM_TYPES:
Expand Down
206 changes: 166 additions & 40 deletions dlt/sources/rest_api/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dlt.common.utils import update_dict_nested, exclude_keys
from dlt.common.typing import add_value_to_literal
from dlt.common import jsonpath
from dlt.common import json

from dlt.extract.incremental import Incremental
from dlt.extract.utils import ensure_table_schema_columns
Expand Down Expand Up @@ -255,7 +256,7 @@ def setup_incremental_object(


def parse_convert_or_deprecated_transform(
config: Union[IncrementalConfig, Dict[str, Any]]
config: Union[IncrementalConfig, Dict[str, Any]],
) -> Optional[Callable[..., Any]]:
convert = config.get("convert", None)
deprecated_transform = config.get("transform", None)
Expand Down Expand Up @@ -296,6 +297,20 @@ def build_resource_dependency_graph(
# find resolved parameters to connect dependent resources
resolved_params = _find_resolved_params(endpoint_resource["endpoint"])

# extract more resolved params from path expressions
# path_expressions = _extract_expressions(endpoint_resource["endpoint"]["path"], "resources.")
# resolved_params += _expressions_to_resolved_params(path_expressions)

# extract expressions from parameters that are strings
params_expressions = []
for param_value in endpoint_resource["endpoint"].get("params", {}).values():
# If param_value is a plain string (e.g. "{resources.berry.a_property}")
if isinstance(param_value, str):
extracted = _extract_expressions(param_value, "resources.")
params_expressions.extend(extracted)

resolved_params += _expressions_to_resolved_params(params_expressions)

# set of resources in resolved params
named_resources = {rp.resolve_config["resource"] for rp in resolved_params}

Expand Down Expand Up @@ -383,6 +398,15 @@ def _make_endpoint_resource(
return _merge_resource_endpoints(default_config, resource)


def _replace_expression(template: str, params: Dict[str, Any]):
"""This method is used to replace the expression in the templates
because the the str.format() doesn't like placeholders with dots.
"""
for p in params:
template = template.replace(f"{{{p}}}", str(params[p]))
return template


def _bind_path_params(resource: EndpointResource) -> None:
"""Binds params declared in path to params available in `params`. Pops the
bound params but. Params of type `resolve` and `incremental` are skipped
Expand All @@ -391,40 +415,35 @@ def _bind_path_params(resource: EndpointResource) -> None:
path_params: Dict[str, Any] = {}
assert isinstance(resource["endpoint"], dict) # type guard
resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])]

params = resource["endpoint"].get("params", {})
path = resource["endpoint"]["path"]
for format_ in string.Formatter().parse(path):
name = format_[1]
if name:
params = resource["endpoint"].get("params", {})
if name not in params and name not in path_params:
raise ValueError(
f"The path {path} defined in resource {resource['name']} requires param with"
f" name {name} but it is not found in {params}"
)
if name in resolve_params:
resolve_params.remove(name)
if name in params:
if not isinstance(params[name], dict):
# bind resolved param and pop it from endpoint
path_params[name] = params.pop(name)
else:
param_type = params[name].get("type")
if param_type != "resolve":
raise ValueError(
f"The path {path} defined in resource {resource['name']} tries to bind"
f" param {name} with type {param_type}. Paths can only bind 'resolve'"
" type params."
)
# resolved params are bound later
path_params[name] = "{" + name + "}"

if len(resolve_params) > 0:
raise NotImplementedError(
f"Resource {resource['name']} defines resolve params {resolve_params} that are not"
f" bound in path {path}. Resolve query params not supported yet."
)

resource["endpoint"]["path"] = path.format(**path_params)
for name in _extract_expressions(path):
if name not in params and name not in path_params and name not in resolve_params:
raise ValueError(
f"The path {path} defined in resource {resource['name']} requires param with"
f" name {name} but it is not found in {params}"
)
if name in resolve_params:
resolve_params.remove(name)
if name in params:
if not isinstance(params[name], dict):
# bind resolved param and pop it from endpoint
path_params[name] = params.pop(name)
else:
param_type = params[name].get("type")
if param_type != "resolve":
raise ValueError(
f"The path {path} defined in resource {resource['name']} tries to bind"
f" param {name} with type {param_type}. Paths can only bind 'resolve'"
" type params."
)
# resolved params are bound later
path_params[name] = "{" + name + "}"

# resource["endpoint"]["path"] = path.format(**path_params)
resource["endpoint"]["path"] = _replace_expression(path, path_params)


def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint:
Expand All @@ -450,12 +469,27 @@ def _find_resolved_params(endpoint_config: Endpoint) -> List[ResolvedParam]:

Resolved params are of type ResolveParamConfig (bound param with a key "type" set to "resolve".)
"""
return [

resolved_params = [
ResolvedParam(key, value) # type: ignore[arg-type]
for key, value in endpoint_config.get("params", {}).items()
if (isinstance(value, dict) and value.get("type") == "resolve")
]

path_expressions = _extract_expressions(endpoint_config["path"], "resources.")

json_expressions = (
_extract_expressions(endpoint_config["json"], "resources.")
if endpoint_config.get("json")
else []
)

resolved_params += _expressions_to_resolved_params(
path_expressions
) + _expressions_to_resolved_params(json_expressions)

return resolved_params


def _action_type_unless_custom_hook(
action_type: Optional[str], custom_hook: Optional[List[Callable[..., Any]]]
Expand Down Expand Up @@ -577,31 +611,115 @@ def remove_field(response: Response, *args, **kwargs) -> Response:
return None


def _extract_expressions(
template: Union[str, Dict],
prefix: str = "",
) -> List[str]:
"""Takes a template string and extracts expressions that start with a prefix.
Args:
template (str): A string with expressions to extract
prefix (str): A string that marks the beginning of an expression
Example:
>>> _extract_expressions("blog/{resources.blog.id}/comments", "resources.")
["resources.blog.id"]
"""

expressions = set()

def recursive_search(value):
if isinstance(value, dict):
for key, val in value.items():
recursive_search(key)
recursive_search(val)
elif isinstance(value, list):
for item in value:
recursive_search(item)
elif isinstance(value, str):
e = [
field_name
for _, field_name, _, _ in string.Formatter().parse(value)
if field_name and field_name.startswith(prefix)
]
expressions.update(e)

recursive_search(template)
return list(expressions)


def _expressions_to_resolved_params(expressions: List[str]) -> List[ResolvedParam]:
    """Turn ``resources.<resource>.<field>`` expressions into resolved params.

    Each expression is stripped and split on dots and must yield exactly three
    segments. The second and third segments become the ``resource`` and
    ``field`` entries of a ``resolve`` config, and the expression itself is
    passed as the first argument of ``ResolvedParam``.

    Raises:
        ValueError: If an expression does not split into exactly three
            dot-separated segments.
    """

    def to_resolved_param(expression: str) -> ResolvedParam:
        segments = expression.strip().split(".")
        # Only the simple three-segment form is supported; more complex
        # expressions are rejected. The leading segment is not checked here.
        if len(segments) != 3:
            raise ValueError(
                f"Invalid definition of {expression}. Expected format:"
                " 'resources.<resource>.<field>'"
            )
        _, resource_name, field_name = segments
        return ResolvedParam(
            expression,
            {
                "type": "resolve",
                "resource": resource_name,
                "field": field_name,
            },
        )

    return [to_resolved_param(expression) for expression in expressions]


def _bound_path_parameters(
    path: str,
    param_values: Dict[str, Any],
):
    """Bind placeholders in ``path`` and report which names the path uses.

    Returns:
        A 2-tuple ``(bound_path, path_params)``: the path after
        ``_replace_expression`` has substituted ``param_values`` into it, and
        the placeholder names that ``_extract_expressions`` finds in the
        original, unbound path.
    """
    # Names are extracted from the unbound path before substitution.
    referenced_names = _extract_expressions(path)
    return _replace_expression(path, param_values), referenced_names


def _bound_json_parameters(
    request_json: Dict[str, Any],
    param_values: Dict[str, Any],
) -> Tuple[Dict[str, Any], List[str]]:
    """Bind placeholders inside a JSON request body.

    The body is walked recursively and placeholders are substituted inside
    each string (dictionary keys included) with ``_replace_expression``.
    Substituting into the serialized JSON text instead would let a value
    containing a quote, a backslash or a newline corrupt the document or
    change its structure.

    Args:
        request_json: The JSON body template. It is not modified.
        param_values: Mapping of placeholder name to value.

    Returns:
        A 2-tuple ``(bound_json, json_params)``: a new structure with the
        placeholders substituted, and the placeholder names that
        ``_extract_expressions`` finds in ``request_json``.
    """
    json_params = _extract_expressions(request_json)

    def bind(node: Any) -> Any:
        if isinstance(node, str):
            return _replace_expression(node, param_values)
        if isinstance(node, dict):
            return {
                (_replace_expression(key, param_values) if isinstance(key, str) else key): bind(
                    value
                )
                for key, value in node.items()
            }
        if isinstance(node, (list, tuple)):
            # Tuples become lists, as they would after a JSON round trip.
            return [bind(element) for element in node]
        # Numbers, booleans, None and other values carry no placeholders.
        return node

    return bind(request_json), json_params


def process_parent_data_item(
path: str,
item: Dict[str, Any],
# params: Dict[str, Any],,
resolved_params: List[ResolvedParam],
include_from_parent: List[str],
request_json: Optional[Dict[str, Any]] = [],
) -> Tuple[str, Dict[str, Any]]:
parent_resource_name = resolved_params[0].resolve_config["resource"]

param_values = {}

params_values = {}
for resolved_param in resolved_params:
field_values = jsonpath.find_values(resolved_param.field_path, item)

if not field_values:
field_path = resolved_param.resolve_config["field"]
raise ValueError(
f"Transformer expects a field '{field_path}' to be present in the incoming data"
f" from resource {parent_resource_name} in order to bind it to path param"
f" from resource {parent_resource_name} in order to bind it to param"
f" {resolved_param.param_name}. Available parent fields are"
f" {', '.join(item.keys())}"
)

param_values[resolved_param.param_name] = field_values[0]
params_values[resolved_param.param_name] = field_values[0]

bound_path, path_params = _bound_path_parameters(path, params_values)

bound_path = path.format(**param_values)
json_params = []
if request_json:
request_json, json_params = _bound_json_parameters(request_json, params_values)

parent_record: Dict[str, Any] = {}
if include_from_parent:
Expand All @@ -615,7 +733,15 @@ def process_parent_data_item(
)
parent_record[child_key] = item[parent_key]

return bound_path, parent_record
# the params not present in the params already bound,
# will be returned and used as query params
params_values = {
param_name: param_value
for param_name, param_value in params_values.items()
if param_name not in path_params and param_name not in json_params
}

return bound_path, parent_record, params_values, request_json


def _merge_resource_endpoints(
Expand Down
Loading
Loading