diff --git a/dlt/sources/rest_api/__init__.py b/dlt/sources/rest_api/__init__.py
index fda3aa9a5b..56f9322089 100644
--- a/dlt/sources/rest_api/__init__.py
+++ b/dlt/sources/rest_api/__init__.py
@@ -1,4 +1,5 @@
 """Generic API Source"""
+
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union
 import graphlib
@@ -68,7 +69,11 @@ def rest_api(
 ) -> List[DltResource]:
     """Creates and configures a REST API source with default settings"""
     return rest_api_resources(
-        {"client": client, "resources": resources, "resource_defaults": resource_defaults}
+        {
+            "client": client,
+            "resources": resources,
+            "resource_defaults": resource_defaults,
+        }
     )


@@ -346,6 +351,7 @@ def paginate_dependent_resource(
     items: List[Dict[str, Any]],
     method: HTTPMethodBasic,
     path: str,
+    request_json: Optional[Dict[str, Any]],
     params: Dict[str, Any],
     paginator: Optional[BasePaginator],
     data_selector: Optional[jsonpath.TJsonPath],
@@ -368,14 +374,22 @@ def paginate_dependent_resource(
     )

     for item in items:
-        formatted_path, parent_record = process_parent_data_item(
-            path, item, resolved_params, include_from_parent
+        formatted_path, parent_record, updated_params, updated_json = (
+            process_parent_data_item(
+                path=path,
+                item=item,
+                request_json=request_json,
+                resolved_params=resolved_params,
+                include_from_parent=include_from_parent,
+            )
         )

         for child_page in client.paginate(
             method=method,
             path=formatted_path,
-            params=params,
+            params=updated_params,
+            json=updated_json,
             paginator=paginator,
             data_selector=data_selector,
             hooks=hooks,
@@ -393,6 +407,7 @@ def paginate_dependent_resource(
             method=endpoint_config.get("method", "get"),
             path=endpoint_config.get("path"),
             params=base_params,
+            request_json=request_json,
             paginator=paginator,
             data_selector=endpoint_config.get("data_selector"),
             hooks=hooks,
@@ -435,7 +450,8 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig:
     has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS)
     if (
         isinstance(
-            auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials)
+            auth_config,
+            (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials),
         )
         or has_sensitive_key
     ):
@@ -481,7 +497,7 @@ def identity_func(x: Any) -> Any:


 def _validate_param_type(
-    request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]
+    request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]],
 ) -> None:
     for _, value in request_params.items():
         if isinstance(value, dict) and value.get("type") not in PARAM_TYPES:
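For orientation, a minimal sketch (not part of the patch) of the new call contract used above: process_parent_data_item now returns a 4-tuple, and resolved values that were not consumed by the path or JSON body flow into client.paginate() as query params. The toy path, item, and resolved param below are hypothetical; the imports assume the modules as changed in this diff.

    from dlt.sources.rest_api.config_setup import process_parent_data_item
    from dlt.sources.rest_api.typing import ResolvedParam

    bound_path, parent_record, updated_params, updated_json = process_parent_data_item(
        path="posts/{post_id}/comments",  # hypothetical endpoint path
        item={"id": 42},                  # a record emitted by the parent resource
        resolved_params=[
            ResolvedParam("post_id", {"type": "resolve", "resource": "posts", "field": "id"})
        ],
        include_from_parent=None,
        request_json=None,
    )
    # bound_path == "posts/42/comments"; updated_params == {} because "post_id"
    # was consumed by the path. Any resolved value not bound in the path or
    # JSON body would come back in updated_params and be sent as a query param.
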
diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py
index 0425c28582..13a1cc52f7 100644
--- a/dlt/sources/rest_api/config_setup.py
+++ b/dlt/sources/rest_api/config_setup.py
@@ -22,6 +22,7 @@
 from dlt.common.utils import update_dict_nested, exclude_keys
 from dlt.common.typing import add_value_to_literal
 from dlt.common import jsonpath
+from dlt.common import json
 from dlt.extract.incremental import Incremental
 from dlt.extract.utils import ensure_table_schema_columns

@@ -255,7 +256,7 @@ def setup_incremental_object(


 def parse_convert_or_deprecated_transform(
-    config: Union[IncrementalConfig, Dict[str, Any]]
+    config: Union[IncrementalConfig, Dict[str, Any]],
 ) -> Optional[Callable[..., Any]]:
     convert = config.get("convert", None)
     deprecated_transform = config.get("transform", None)
@@ -296,6 +297,16 @@ def build_resource_dependency_graph(
         # find resolved parameters to connect dependent resources
         resolved_params = _find_resolved_params(endpoint_resource["endpoint"])
+        # extract expressions from parameters that are strings
+        params_expressions = []
+        for param_value in endpoint_resource["endpoint"].get("params", {}).values():
+            # If param_value is a plain string (e.g. "{resources.berry.a_property}")
+            if isinstance(param_value, str):
+                extracted = _extract_expressions(param_value, "resources.")
+                params_expressions.extend(extracted)
+
+        resolved_params += _expressions_to_resolved_params(params_expressions)
+
         # set of resources in resolved params
         named_resources = {rp.resolve_config["resource"] for rp in resolved_params}

@@ -383,6 +394,15 @@ def _make_endpoint_resource(
     return _merge_resource_endpoints(default_config, resource)


+def _replace_expression(template: str, params: Dict[str, Any]) -> str:
+    """Replaces expressions in the template. Used instead of str.format(),
+    which rejects placeholders that contain dots.
+    """
+    for p in params:
+        template = template.replace(f"{{{p}}}", str(params[p]))
+    return template
+
+
 def _bind_path_params(resource: EndpointResource) -> None:
     """Binds params declared in path to params available in `params`. Pops the bound params but.
     Params of type `resolve` and `incremental` are skipped
@@ -391,40 +411,35 @@
     path_params: Dict[str, Any] = {}
     assert isinstance(resource["endpoint"], dict)  # type guard
     resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])]
+
+    params = resource["endpoint"].get("params", {})
     path = resource["endpoint"]["path"]
-    for format_ in string.Formatter().parse(path):
-        name = format_[1]
-        if name:
-            params = resource["endpoint"].get("params", {})
-            if name not in params and name not in path_params:
-                raise ValueError(
-                    f"The path {path} defined in resource {resource['name']} requires param with"
-                    f" name {name} but it is not found in {params}"
-                )
-            if name in resolve_params:
-                resolve_params.remove(name)
-            if name in params:
-                if not isinstance(params[name], dict):
-                    # bind resolved param and pop it from endpoint
-                    path_params[name] = params.pop(name)
-                else:
-                    param_type = params[name].get("type")
-                    if param_type != "resolve":
-                        raise ValueError(
-                            f"The path {path} defined in resource {resource['name']} tries to bind"
-                            f" param {name} with type {param_type}. Paths can only bind 'resolve'"
-                            " type params."
-                        )
-                    # resolved params are bound later
-                    path_params[name] = "{" + name + "}"
-
-    if len(resolve_params) > 0:
-        raise NotImplementedError(
-            f"Resource {resource['name']} defines resolve params {resolve_params} that are not"
-            f" bound in path {path}. Resolve query params not supported yet."
-        )
-
-    resource["endpoint"]["path"] = path.format(**path_params)
+    for name in _extract_expressions(path):
+        if name not in params and name not in path_params and name not in resolve_params:
+            raise ValueError(
+                f"The path {path} defined in resource {resource['name']} requires param with"
+                f" name {name} but it is not found in {params}"
+            )
+        if name in resolve_params:
+            resolve_params.remove(name)
+        if name in params:
+            if not isinstance(params[name], dict):
+                # bind resolved param and pop it from endpoint
+                path_params[name] = params.pop(name)
+            else:
+                param_type = params[name].get("type")
+                if param_type != "resolve":
+                    raise ValueError(
+                        f"The path {path} defined in resource {resource['name']} tries to bind"
+                        f" param {name} with type {param_type}. Paths can only bind 'resolve'"
+                        " type params."
+                    )
+                # resolved params are bound later
+                path_params[name] = "{" + name + "}"
+
+    # resource["endpoint"]["path"] = path.format(**path_params)
+    resource["endpoint"]["path"] = _replace_expression(path, path_params)


 def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint:
@@ -450,12 +465,27 @@ def _find_resolved_params(endpoint_config: Endpoint) -> List[ResolvedParam]:
     Resolved params are of type ResolveParamConfig (bound param with a key "type" set to "resolve".)
     """
-    return [
+
+    resolved_params = [
         ResolvedParam(key, value)  # type: ignore[arg-type]
         for key, value in endpoint_config.get("params", {}).items()
         if (isinstance(value, dict) and value.get("type") == "resolve")
     ]

+    path_expressions = _extract_expressions(endpoint_config["path"], "resources.")
+
+    json_expressions = (
+        _extract_expressions(endpoint_config["json"], "resources.")
+        if endpoint_config.get("json")
+        else []
+    )
+
+    resolved_params += _expressions_to_resolved_params(
+        path_expressions
+    ) + _expressions_to_resolved_params(json_expressions)
+
+    return resolved_params
+

 def _action_type_unless_custom_hook(
     action_type: Optional[str], custom_hook: Optional[List[Callable[..., Any]]]
@@ -577,16 +607,96 @@ def remove_field(response: Response, *args, **kwargs) -> Response:
     return None


+def _extract_expressions(
+    template: Union[str, Dict[str, Any]],
+    prefix: str = "",
+) -> List[str]:
+    """Takes a template and extracts expressions that start with a prefix.
+    Args:
+        template (Union[str, Dict[str, Any]]): A string or nested structure with expressions to extract
+        prefix (str): A string that marks the beginning of an expression
+    Example:
+        >>> _extract_expressions("blog/{resources.blog.id}/comments", "resources.")
+        ["resources.blog.id"]
+    """
+
+    expressions = set()
+
+    def recursive_search(value: Union[str, List[Any], Dict[str, Any]]) -> None:
+        if isinstance(value, dict):
+            for key, val in value.items():
+                recursive_search(key)
+                recursive_search(val)
+        elif isinstance(value, list):
+            for item in value:
+                recursive_search(item)
+        elif isinstance(value, str):
+            e = [
+                field_name
+                for _, field_name, _, _ in string.Formatter().parse(value)
+                if field_name and field_name.startswith(prefix)
+            ]
+            expressions.update(e)
+
+    recursive_search(template)
+    return list(expressions)
+
+
+def _expressions_to_resolved_params(expressions: List[str]) -> List[ResolvedParam]:
+    resolved_params = []
+    # We assume that the expressions are in the format 'resources.<resource>.<field>'
+    # and not more complex expressions
+    for expression in expressions:
+        parts = expression.strip().split(".")
+        if len(parts) != 3:
+            raise ValueError(
+                f"Invalid definition of {expression}. Expected format:"
+                " 'resources.<resource>.<field>'"
+            )
+        resolved_params.append(
+            ResolvedParam(
+                expression,
+                {
+                    "type": "resolve",
+                    "resource": parts[1],
+                    "field": parts[2],
+                },
+            )
+        )
+    return resolved_params
+
+
+def _bound_path_parameters(
+    path: str,
+    param_values: Dict[str, Any],
+) -> Tuple[str, List[str]]:
+    path_params = _extract_expressions(path)
+    bound_path = _replace_expression(path, param_values)
+
+    return bound_path, path_params
+
+
+def _bound_json_parameters(
+    request_json: Dict[str, Any],
+    param_values: Dict[str, Any],
+) -> Tuple[Dict[str, Any], List[str]]:
+    json_params = _extract_expressions(request_json)
+    bound_json = _replace_expression(json.dumps(request_json), param_values)
+
+    return json.loads(bound_json), json_params
+
+
 def process_parent_data_item(
     path: str,
     item: Dict[str, Any],
+    # params: Dict[str, Any],
     resolved_params: List[ResolvedParam],
     include_from_parent: List[str],
-) -> Tuple[str, Dict[str, Any]]:
+    request_json: Optional[Dict[str, Any]] = None,
+) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
     parent_resource_name = resolved_params[0].resolve_config["resource"]

-    param_values = {}
-
+    params_values = {}
     for resolved_param in resolved_params:
         field_values = jsonpath.find_values(resolved_param.field_path, item)
@@ -594,14 +704,18 @@
             field_path = resolved_param.resolve_config["field"]
             raise ValueError(
                 f"Transformer expects a field '{field_path}' to be present in the incoming data"
-                f" from resource {parent_resource_name} in order to bind it to path param"
+                f" from resource {parent_resource_name} in order to bind it to param"
                 f" {resolved_param.param_name}. Available parent fields are"
                 f" {', '.join(item.keys())}"
             )
-        param_values[resolved_param.param_name] = field_values[0]
+        params_values[resolved_param.param_name] = field_values[0]
+
+    bound_path, path_params = _bound_path_parameters(path, params_values)

-    bound_path = path.format(**param_values)
+    json_params: List[str] = []
+    if request_json:
+        request_json, json_params = _bound_json_parameters(request_json, params_values)

     parent_record: Dict[str, Any] = {}
     if include_from_parent:
@@ -615,7 +729,15 @@
             )
             parent_record[child_key] = item[parent_key]

-    return bound_path, parent_record
+    # params that were not bound in the path or the JSON body
+    # are returned and used as query params
+    params_values = {
+        param_name: param_value
+        for param_name, param_value in params_values.items()
+        if param_name not in path_params and param_name not in json_params
+    }
+
+    return bound_path, parent_record, params_values, request_json


 def _merge_resource_endpoints(
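To make the new helpers above concrete, here is a small usage sketch (not part of the patch); the expected values in the comments follow directly from the definitions in config_setup.py, and the placeholder names are illustrative.

    from dlt.sources.rest_api.config_setup import (
        _expressions_to_resolved_params,
        _extract_expressions,
        _replace_expression,
    )

    # _extract_expressions recurses through strings, lists and dicts and collects
    # str.format()-style placeholders, optionally filtered by prefix:
    _extract_expressions("posts/{resources.posts.id}/comments", "resources.")
    # -> ["resources.posts.id"]
    _extract_expressions({"post_id": "{resources.posts.id}", "limit": 5}, "resources.")
    # -> ["resources.posts.id"]

    # each 'resources.<resource>.<field>' expression becomes an implicit ResolvedParam:
    _expressions_to_resolved_params(["resources.posts.id"])
    # -> [ResolvedParam("resources.posts.id",
    #         {"type": "resolve", "resource": "posts", "field": "id"})]

    # _replace_expression substitutes placeholders by literal string replacement,
    # which is why it works where str.format() rejects dotted field names:
    _replace_expression("posts/{resources.posts.id}/comments", {"resources.posts.id": 42})
    # -> "posts/42/comments"
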
diff --git a/tests/load/sources/rest_api/test_rest_api_source.py b/tests/load/sources/rest_api/test_rest_api_source.py
index 25a9952ba4..583a67e69a 100644
--- a/tests/load/sources/rest_api/test_rest_api_source.py
+++ b/tests/load/sources/rest_api/test_rest_api_source.py
@@ -56,9 +56,9 @@ def test_rest_api_source(destination_config: DestinationTestConfiguration, reque

     assert table_counts.keys() == {"pokemon_list", "berry", "location"}

-    assert table_counts["pokemon_list"] == 1302
+    assert table_counts["pokemon_list"] == 1304
     assert table_counts["berry"] == 64
-    assert table_counts["location"] == 1036
+    assert table_counts["location"] == 1039


 @pytest.mark.parametrize(
diff --git a/tests/sources/rest_api/configurations/test_resolve_config.py b/tests/sources/rest_api/configurations/test_resolve_config.py
index 791c7aa5c6..e31470c4b7 100644
--- a/tests/sources/rest_api/configurations/test_resolve_config.py
+++ b/tests/sources/rest_api/configurations/test_resolve_config.py
@@ -80,64 +80,76 @@ def test_bind_path_param() -> None:
         _bind_path_params(tp_4)
     assert tp_4 == tp_5

-    # resolved param will remain unbounded and
-    tp_6 = deepcopy(three_params)
-    tp_6["endpoint"]["path"] = "{org}/{repo}/issues/1234/comments"  # type: ignore[index]
-    with pytest.raises(NotImplementedError):
-        _bind_path_params(tp_6)
-

 def test_process_parent_data_item() -> None:
     resolve_params = [
         ResolvedParam("id", {"field": "obj_id", "resource": "issues", "type": "resolve"})
     ]
-    bound_path, parent_record = process_parent_data_item(
-        "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_params, None
+
+    bound_path, parent_record, params_values, request_json = process_parent_data_item(
+        path="dlt-hub/dlt/issues/{id}/comments",
+        item={"obj_id": 12345},
+        resolved_params=resolve_params,
+        include_from_parent=None,
     )
     assert bound_path == "dlt-hub/dlt/issues/12345/comments"
     assert parent_record == {}

-    bound_path, parent_record = process_parent_data_item(
-        "dlt-hub/dlt/issues/{id}/comments", {"obj_id": 12345}, resolve_params, ["obj_id"]
+    bound_path, parent_record, params_values, request_json = process_parent_data_item(
+        path="dlt-hub/dlt/issues/{id}/comments",
+        item={"obj_id": 12345},
+        resolved_params=resolve_params,
+        include_from_parent=["obj_id"],
     )
     assert parent_record == {"_issues_obj_id": 12345}

-    bound_path, parent_record = process_parent_data_item(
-        "dlt-hub/dlt/issues/{id}/comments",
-        {"obj_id": 12345, "obj_node": "node_1"},
-        resolve_params,
-        ["obj_id", "obj_node"],
+    bound_path, parent_record, params_values, request_json = process_parent_data_item(
+        path="dlt-hub/dlt/issues/{id}/comments",
+        item={"obj_id": 12345, "obj_node": "node_1"},
+        resolved_params=resolve_params,
+        include_from_parent=["obj_id", "obj_node"],
     )
     assert parent_record == {"_issues_obj_id": 12345, "_issues_obj_node": "node_1"}

     # test nested data
     resolve_param_nested = [
         ResolvedParam(
-            "id", {"field": "some_results.obj_id", "resource": "issues", "type": "resolve"}
+            "id",
+            {"field": "some_results.obj_id", "resource": "issues", "type": "resolve"},
         )
     ]
     item = {"some_results": {"obj_id": 12345}}
-    bound_path, parent_record = process_parent_data_item(
-        "dlt-hub/dlt/issues/{id}/comments", item, resolve_param_nested, None
+    bound_path, parent_record, params_values, request_json = process_parent_data_item(
+        path="dlt-hub/dlt/issues/{id}/comments",
+        item=item,
+        resolved_params=resolve_param_nested,
+        include_from_parent=None,
    )
     assert bound_path == "dlt-hub/dlt/issues/12345/comments"

     # param path not found
     with pytest.raises(ValueError) as val_ex:
-        bound_path, parent_record = process_parent_data_item(
-            "dlt-hub/dlt/issues/{id}/comments", {"_id": 12345}, resolve_params, None
+        bound_path, parent_record, params_values, request_json = process_parent_data_item(
+            path="dlt-hub/dlt/issues/{id}/comments",
+            item={"_id": 12345},
+            resolved_params=resolve_params,
+            include_from_parent=None,
         )
     assert "Transformer expects a field 'obj_id'" in str(val_ex.value)

     # included path not found
     with pytest.raises(ValueError) as val_ex:
-        bound_path, parent_record = process_parent_data_item(
-            "dlt-hub/dlt/issues/{id}/comments",
-            {"obj_id": 12345, "obj_node": "node_1"},
-            resolve_params,
-            ["obj_id", "node"],
+        bound_path, parent_record, params_values, request_json = process_parent_data_item(
+            path="dlt-hub/dlt/issues/{id}/comments",
+            item={"_id": 12345, "obj_node": "node_1"},
+            resolved_params=resolve_params,
+            include_from_parent=["obj_id", "node"],
         )
-    assert "in order to include it in child records under _issues_node" in str(val_ex.value)
+    assert (
+        "Transformer expects a field 'obj_id' to be present in the incoming data from resource"
+        " issues in order to bind it to"
+        in str(val_ex.value)
+    )

     # Resolve multiple parameters from a single record
     multi_resolve_params = [
@@ -145,22 +157,22 @@ def test_process_parent_data_item() -> None:
         ResolvedParam("id", {"field": "id", "resource": "comments", "type": "resolve"}),
     ]

-    bound_path, parent_record = process_parent_data_item(
-        "dlt-hub/dlt/issues/{issue_id}/comments/{id}",
-        {"issue": 12345, "id": 56789},
-        multi_resolve_params,
-        None,
+    bound_path, parent_record, params_values, request_json = process_parent_data_item(
+        path="dlt-hub/dlt/issues/{issue_id}/comments/{id}",
+        item={"issue": 12345, "id": 56789},
+        resolved_params=multi_resolve_params,
+        include_from_parent=None,
     )
     assert bound_path == "dlt-hub/dlt/issues/12345/comments/56789"
     assert parent_record == {}

     # param path not found with multiple parameters
     with pytest.raises(ValueError) as val_ex:
-        bound_path, parent_record = process_parent_data_item(
-            "dlt-hub/dlt/issues/{issue_id}/comments/{id}",
-            {"_issue": 12345, "id": 56789},
-            multi_resolve_params,
-            None,
+        bound_path, parent_record, params_values, request_json = process_parent_data_item(
+            path="dlt-hub/dlt/issues/{issue_id}/comments/{id}",
+            item={"_issue": 12345, "id": 56789},
+            resolved_params=multi_resolve_params,
+            include_from_parent=None,
         )
     assert "Transformer expects a field 'issue'" in str(val_ex.value)
diff --git a/tests/sources/rest_api/conftest.py b/tests/sources/rest_api/conftest.py
index bc58a18e5c..992a5b2838 100644
--- a/tests/sources/rest_api/conftest.py
+++ b/tests/sources/rest_api/conftest.py
@@ -131,6 +131,11 @@ def post_detail(request, context):
         post_id = request.url.split("/")[-1]
         return {"id": int(post_id), "body": f"Post body {post_id}"}

+    @router.get(r"/posts\?post_id=\d+$")
+    def post_detail_via_query_param(request, context):
+        post_id = int(request.qs.get("post_id", [0])[0])
+        return {"id": int(post_id), "body": f"Post body {post_id}"}
+
     @router.get(r"/posts/\d+/some_details_404")
     def post_detail_404(request, context):
         """Return 404 for post with id > 0. Used to test ignoring 404 errors."""
@@ -169,6 +174,20 @@ def post_detail_204(request, context):
     def posts_with_results_key(request, context):
         return paginate_by_page_number(request, generate_posts(), records_key="many-results")

+    @router.post(r"/posts/search_by_id/\d+$")
+    def search_posts_by_id(request, context):
+        body = request.json()
+        post_id = body.get("post_id", 0)
+        title = body.get("more", {}).get("title", 0)
+
+        more_array = body.get("more_array", [])[0]
+        return {
+            "id": int(post_id),
+            "title": title,
+            "body": f"Post body {post_id}",
+            "more": f"More is equal to id: {more_array}",
+        }
+
     @router.post(r"/posts/search$")
     def search_posts(request, context):
         body = request.json()
diff --git a/tests/sources/rest_api/integration/test_offline.py b/tests/sources/rest_api/integration/test_offline.py
index 6054af3a1f..4d843fa472 100644
--- a/tests/sources/rest_api/integration/test_offline.py
+++ b/tests/sources/rest_api/integration/test_offline.py
@@ -10,6 +10,7 @@
 from dlt.sources.helpers.rest_client.paginators import BaseReferencePaginator
 from dlt.sources.rest_api import (
     ClientConfig,
+    DltSource,
     Endpoint,
     EndpointResource,
     RESTAPIConfig,
@@ -100,6 +101,218 @@ def test_load_mock_api(mock_api_server):
     )


+def test_load_mock_api_with_query_params(mock_api_server):
+    pipeline = dlt.pipeline(
+        pipeline_name="rest_api_mock",
+        destination="duckdb",
+        dataset_name="rest_api_mock",
+        full_refresh=True,
+    )
+
+    mock_source: DltSource = rest_api_source(
+        {
+            "client": {"base_url": "https://api.example.com"},
+            "resources": [
+                "posts",
+                {
+                    "name": "post_details",
+                    "endpoint": {
+                        "path": "posts",
+                        "params": {
+                            "post_id": {
+                                "type": "resolve",
+                                "resource": "posts",
+                                "field": "id",
+                            }
+                        },
+                    },
+                },
+            ],
+        }
+    )
+
+    load_info = pipeline.run(mock_source)
+    print(load_info)
+    assert_load_info(load_info)
+    table_names = [t["name"] for t in pipeline.default_schema.data_tables()]
+    table_counts = load_table_counts(pipeline, *table_names)
+
+    assert table_counts.keys() == {"posts", "post_details"}
+
+    assert table_counts["posts"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES
+    assert table_counts["post_details"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES
+
+    with pipeline.sql_client() as client:
+        posts_table = client.make_qualified_table_name("posts")
+        posts_details_table = client.make_qualified_table_name("post_details")
+
+    print(pipeline.default_schema.to_pretty_yaml())
+
+    assert_query_data(
+        pipeline,
+        f"SELECT title FROM {posts_table} ORDER BY id limit 25",
+        [f"Post {i}" for i in range(25)],
+    )
+
+    assert_query_data(
+        pipeline,
+        f"SELECT body FROM {posts_details_table} ORDER BY id limit 25",
+        [f"Post body {i}" for i in range(25)],
+    )
+
+
+def test_load_mock_api_with_json_resolved(mock_api_server):
+    pipeline = dlt.pipeline(
+        pipeline_name="rest_api_mock",
+        destination="duckdb",
+        dataset_name="rest_api_mock",
+        full_refresh=True,
+    )
+
+    mock_source = rest_api_source(
+        {
+            "client": {"base_url": "https://api.example.com"},
+            "resources": [
+                "posts",
+                {
+                    "name": "post_details",
+                    "endpoint": {
+                        "path": "posts/search_by_id/{post_id}",
+                        "method": "POST",
+                        "json": {
+                            "post_id": "{post_id}",
+                            "limit": 5,
+                            "more": {
+                                "title": "{post_title}",
+                            },
+                            "more_array": [
+                                "{post_id}",
+                            ],
+                        },
+                        "params": {
+                            "post_id": {
+                                "type": "resolve",
+                                "resource": "posts",
+                                "field": "id",
+                            },
+                            "post_title": {
+                                "type": "resolve",
+                                "resource": "posts",
+                                "field": "title",
+                            },
+                        },
+                    },
+                },
+            ],
+        }
+    )
+
+    load_info = pipeline.run(mock_source)
+    print(load_info)
+    assert_load_info(load_info)
+    table_names = [t["name"] for t in pipeline.default_schema.data_tables()]
+    table_counts = load_table_counts(pipeline, *table_names)
+
+    assert table_counts.keys() == {"posts", "post_details"}
+
+    assert table_counts["posts"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES
+    assert table_counts["post_details"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES
+
+    with pipeline.sql_client() as client:
+        posts_table = client.make_qualified_table_name("posts")
+        posts_details_table = client.make_qualified_table_name("post_details")
+
+    print(pipeline.default_schema.to_pretty_yaml())
+
+    assert_query_data(
+        pipeline,
+        f"SELECT title FROM {posts_table} ORDER BY id limit 25",
+        [f"Post {i}" for i in range(25)],
+    )
+
+    assert_query_data(
+        pipeline,
+        f"SELECT body FROM {posts_details_table} ORDER BY id limit 25",
+        [f"Post body {i}" for i in range(25)],
+    )
+
+
+def test_load_mock_api_with_json_resolved_with_implicit_param(mock_api_server):
+    pipeline = dlt.pipeline(
+        pipeline_name="rest_api_mock",
+        destination="duckdb",
+        dataset_name="rest_api_mock",
+        full_refresh=True,
+    )
+
+    mock_source = rest_api_source(
+        {
+            "client": {"base_url": "https://api.example.com"},
+            "resources": [
+                "posts",
+                {
+                    "name": "post_details",
+                    "endpoint": {
+                        "path": "posts/search_by_id/{resources.posts.id}",
+                        "method": "POST",
+                        "json": {
+                            "post_id": "{resources.posts.id}",
+                            "limit": 5,
+                            "more": {
+                                "title": "{resources.posts.title}",
+                            },
+                            "more_array": [
+                                "{resources.posts.id}",
+                            ],
+                        },
+                    },
+                },
+            ],
+        }
+    )
+
+    load_info = pipeline.run(mock_source)
+    print(load_info)
+    assert_load_info(load_info)
+    table_names = [t["name"] for t in pipeline.default_schema.data_tables()]
+    table_counts = load_table_counts(pipeline, *table_names)
+
+    assert table_counts.keys() == {"posts", "post_details"}
+
+    assert table_counts["posts"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES
+    assert table_counts["post_details"] == DEFAULT_PAGE_SIZE * DEFAULT_TOTAL_PAGES
+
+    with pipeline.sql_client() as client:
+        posts_table = client.make_qualified_table_name("posts")
+        posts_details_table = client.make_qualified_table_name("post_details")
+
+    print(pipeline.default_schema.to_pretty_yaml())
+
+    assert_query_data(
+        pipeline,
+        f"SELECT title FROM {posts_table} ORDER BY id limit 25",
+        [f"Post {i}" for i in range(25)],
+    )
+
+    assert_query_data(
+        pipeline,
+        f"SELECT body FROM {posts_details_table} ORDER BY id limit 25",
+        [f"Post body {i}" for i in range(25)],
+    )
+
+    assert_query_data(
+        pipeline,
+        f"SELECT title FROM {posts_details_table} ORDER BY id limit 25",
+        [f"Post {i}" for i in range(25)],
+    )
+
+    assert_query_data(
+        pipeline,
+        f"SELECT more FROM {posts_details_table} ORDER BY id limit 25",
+        [f"More is equal to id: {i}" for i in range(25)],
+    )
+
+
 def test_source_with_post_request(mock_api_server):
     class JSONBodyPageCursorPaginator(BaseReferencePaginator):
         def update_state(self, response: Response, data: Optional[List[Any]] = None) -> None:
@@ -120,7 +333,11 @@ def update_request(self, request: Request) -> None:
                 "endpoint": {
                     "path": "/posts/search",
                     "method": "POST",
-                    "json": {"ids_greater_than": 50, "page_size": 25, "page_count": 4},
+                    "json": {
+                        "ids_greater_than": 50,
+                        "page_size": 25,
+                        "page_count": 4,
+                    },
                     "paginator": JSONBodyPageCursorPaginator(),
                 },
             }
diff --git a/tests/sources/rest_api/integration/test_processing_steps.py b/tests/sources/rest_api/integration/test_processing_steps.py
index bbe90dda06..df18f98292 100644
--- a/tests/sources/rest_api/integration/test_processing_steps.py
+++ b/tests/sources/rest_api/integration/test_processing_steps.py
@@ -202,6 +202,36 @@ def test_rest_api_source_filtered_child(mock_api_server) -> None:
     assert len(data) == 2


+def test_rest_api_source_filtered_child_with_implicit_param(mock_api_server) -> None:
+    config: RESTAPIConfig = {
+        "client": {
+            "base_url": "https://api.example.com",
+        },
+        "resources": [
+            {
+                "name": "posts",
+                "endpoint": "posts",
+                "processing_steps": [
+                    {"filter": lambda x: x["id"] in (1, 2)},  # type: ignore[typeddict-item]
+                ],
+            },
+            {
+                "name": "comments",
+                "endpoint": {
+                    "path": "/posts/{resources.posts.id}/comments",
+                },
+                "processing_steps": [
+                    {"filter": lambda x: x["id"] == 1},  # type: ignore[typeddict-item]
+                ],
+            },
+        ],
+    }
+    mock_source = rest_api_source(config)
+
+    data = list(mock_source.with_resources("comments"))
+    assert len(data) == 2
+
+
 def test_rest_api_source_filtered_and_map_child(mock_api_server) -> None:
     def extend_body(row):
         row["body"] = f"{row['_posts_title']} - {row['body']}"
@@ -243,3 +273,41 @@ def extend_body(row):

     data = list(mock_source.with_resources("comments"))
     assert data[0]["body"] == "Post 2 - Comment 0 for post 2"
+
+
+def test_rest_api_source_filtered_and_map_child_with_implicit_param(
+    mock_api_server,
+) -> None:
+    def extend_body(row):
+        row["body"] = f"{row['_posts_title']} - {row['body']}"
+        return row
+
+    config: RESTAPIConfig = {
+        "client": {
+            "base_url": "https://api.example.com",
+        },
+        "resources": [
+            {
+                "name": "posts",
+                "endpoint": "posts",
+                "processing_steps": [
+                    {"filter": lambda x: x["id"] in (1, 2)},  # type: ignore[typeddict-item]
+                ],
+            },
+            {
+                "name": "comments",
+                "endpoint": {
+                    "path": "/posts/{resources.posts.id}/comments",
+                },
+                "include_from_parent": ["title"],
+                "processing_steps": [
+                    {"map": extend_body},  # type: ignore[typeddict-item]
+                    {"filter": lambda x: x["body"].startswith("Post 2")},  # type: ignore[typeddict-item]
+                ],
+            },
+        ],
+    }
+    mock_source = rest_api_source(config)
+
+    data = list(mock_source.with_resources("comments"))
+    assert data[0]["body"] == "Post 2 - Comment 0 for post 2"
diff --git a/tests/sources/rest_api/test_rest_api_source.py b/tests/sources/rest_api/test_rest_api_source.py
index 904bcaf159..f7d0b03a5b 100644
--- a/tests/sources/rest_api/test_rest_api_source.py
+++ b/tests/sources/rest_api/test_rest_api_source.py
@@ -1,7 +1,9 @@
 import dlt
 import pytest

-from dlt.common.configuration.specs.config_providers_context import ConfigProvidersContainer
+from dlt.common.configuration.specs.config_providers_context import (
+    ConfigProvidersContainer,
+)
 from dlt.sources.rest_api.typing import RESTAPIConfig
 from dlt.sources.helpers.rest_client.paginators import SinglePagePaginator

@@ -79,9 +81,9 @@ def test_rest_api_source(destination_name: str, invocation_type: str) -> None:

     assert table_counts.keys() == {"pokemon_list", "berry", "location"}

-    assert table_counts["pokemon_list"] == 1302
+    assert table_counts["pokemon_list"] == 1304
     assert table_counts["berry"] == 64
-    assert table_counts["location"] == 1036
+    assert table_counts["location"] == 1039


 @pytest.mark.parametrize("destination_name", ALL_DESTINATIONS)
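Taken together, the tests above boil down to this user-facing pattern: placeholders of the form {resources.<name>.<field>} in a path, query param, or JSON body implicitly make an endpoint depend on another resource, with no explicit "resolve" param block required. A minimal, illustrative config (hypothetical base URL and pipeline name, mirroring the mock API used in the tests):

    import dlt
    from dlt.sources.rest_api import rest_api_source

    source = rest_api_source(
        {
            "client": {"base_url": "https://api.example.com"},
            "resources": [
                "posts",
                {
                    "name": "comments",
                    "endpoint": {
                        # implicit dependency on the "posts" resource
                        "path": "posts/{resources.posts.id}/comments",
                    },
                },
            ],
        }
    )

    pipeline = dlt.pipeline(pipeline_name="rest_api_example", destination="duckdb")
    pipeline.run(source)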