Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change: Reuse matches in CPE match strings API #1082

Merged
merged 6 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions pontos/nvd/cpe_match/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@
MAX_CPE_MATCHES_PER_PAGE = 500


def _result_iterator(data: JSON) -> Iterator[CPEMatchString]:
    """
    Turn an API response JSON document into an iterator of
    CPEMatchString models.

    Args:
        data: JSON response data containing the match strings

    Returns:
        A lazy iterator yielding one CPEMatchString per entry
    """
    entries: list[dict[str, Any]] = data.get("match_strings", [])  # type: ignore
    return map(
        lambda entry: CPEMatchString.from_dict(entry["match_string"]),
        entries,
    )


class CPEMatchApi(NVDApi):
"""
API for querying the NIST NVD CPE match information.
Expand Down Expand Up @@ -83,6 +76,7 @@ def __init__(
timeout=timeout,
rate_limit=rate_limit,
)
self._cpe_match_cache: dict[str, Any] = {}

def cpe_matches(
self,
Expand Down Expand Up @@ -157,12 +151,30 @@ def cpe_matches(
return NVDResults(
self,
params,
_result_iterator,
self._result_iterator,
request_results=request_results,
results_per_page=results_per_page,
start_index=start_index,
)

def _result_iterator(self, data: JSON) -> Iterator[CPEMatchString]:
    """
    Build an iterator over all CPEMatchStrings contained in the given
    API response JSON.

    Args:
        data: The JSON response data to get the match strings from

    Returns:
        An iterator over the CPEMatchStrings
    """
    entries: list[dict[str, Any]] = data.get("match_strings", [])  # type: ignore
    return map(
        lambda entry: CPEMatchString.from_dict_with_cache(
            entry["match_string"], self._cpe_match_cache
        ),
        entries,
    )

async def cpe_match(self, match_criteria_id: str) -> CPEMatchString:
"""
Returns a single CPE match for the given match criteria id.
Expand Down Expand Up @@ -201,7 +213,9 @@ async def cpe_match(self, match_criteria_id: str) -> CPEMatchString:
)

match_string = match_strings[0]
return CPEMatchString.from_dict(match_string["match_string"])
return CPEMatchString.from_dict_with_cache(
match_string["match_string"], self._cpe_match_cache
)

async def __aenter__(self) -> "CPEMatchApi":
await super().__aenter__()
Expand Down
30 changes: 29 additions & 1 deletion pontos/nvd/models/cpe_match_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional
from typing import Any, List, Optional
from uuid import UUID

from pontos.models import Model
Expand Down Expand Up @@ -55,3 +55,31 @@ class CPEMatchString(Model):
version_start_excluding: Optional[str] = None
version_end_including: Optional[str] = None
version_end_excluding: Optional[str] = None

@classmethod
def from_dict_with_cache(
    cls,
    data: dict[str, Any],
    cpe_match_cache: Optional[dict[str, CPEMatch]] = None,
) -> "CPEMatchString":
    """
    Create a CPEMatchString model from a dict, reusing duplicate
    CPEMatch objects to reduce memory usage when a cache dict is given.

    Args:
        data: The JSON dict to generate the model from
        cpe_match_cache: A dict mapping cpe_name_id to previously seen
            CPEMatch objects, or None to disable caching and reuse of
            CPE matches

    Returns:
        The created CPEMatchString, with duplicate matches replaced by
        their previously cached equivalents when a cache is given
    """
    new_match_string = cls.from_dict(data)

    # The docstring previously promised None support but the code
    # crashed on it; without a cache there is nothing to deduplicate.
    if cpe_match_cache is None:
        return new_match_string

    for index, match in enumerate(new_match_string.matches):
        cached_match: Optional[CPEMatch] = cpe_match_cache.get(
            match.cpe_name_id
        )
        # Reuse the cached object only if the name also matches;
        # otherwise store the newer match as the cache entry.
        if cached_match is not None and cached_match.cpe_name == match.cpe_name:
            new_match_string.matches[index] = cached_match
        else:
            cpe_match_cache[match.cpe_name_id] = match
    return new_match_string
57 changes: 57 additions & 0 deletions tests/nvd/cpe_match/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,63 @@ async def test_cpe_matches_request_results(self):
with self.assertRaises(Exception):
cpe_match = await anext(it)

async def test_cpe_match_caching(self):
    """
    Verify that CPEMatch entries repeated across paged responses are
    deduplicated into identical objects via the API's match cache,
    while entries that differ are kept as distinct objects.
    """
    match_criteria_id = uuid4()
    cpe_name_id = uuid4()

    responses = create_cpe_match_responses(
        match_criteria_id=match_criteria_id,
        cpe_name_id=cpe_name_id,
        results_per_response=3,
    )
    self.http_client.get.side_effect = responses
    # Collect the raw "matches" lists of every match_string in each
    # mocked response so individual entries can be tweaked below.
    response_matches = [
        [
            match_string["match_string"]["matches"]
            for match_string in response.json.return_value["match_strings"]
        ]
        for response in responses
    ]

    # Make matches of first match_string identical in each response
    response_matches[1][0][0]["cpe_name"] = response_matches[0][0][0][
        "cpe_name"
    ]
    response_matches[1][0][0]["cpe_name_id"] = response_matches[0][0][0][
        "cpe_name_id"
    ]
    # Make matches of second match_string only have the same cpe_name_id
    response_matches[1][1][0]["cpe_name_id"] = response_matches[0][1][0][
        "cpe_name_id"
    ]
    # Leave matches of third match_string different from each other

    it = aiter(self.api.cpe_matches(request_results=10))
    received = [item async for item in it]

    # First matches in each response of three items must be identical objects
    self.assertIs(received[0].matches[0], received[3].matches[0])

    # Second matches in each response of three items must only have same cpe_name_id
    self.assertIsNot(received[1].matches[0], received[4].matches[0])
    self.assertEqual(
        received[1].matches[0].cpe_name_id,
        received[4].matches[0].cpe_name_id,
    )
    self.assertNotEqual(
        received[1].matches[0].cpe_name, received[4].matches[0].cpe_name
    )

    # Third matches in each response of three items must be different
    self.assertIsNot(received[2].matches[0], received[5].matches[0])
    self.assertNotEqual(
        received[2].matches[0].cpe_name_id,
        received[5].matches[0].cpe_name_id,
    )
    self.assertNotEqual(
        received[2].matches[0].cpe_name, received[5].matches[0].cpe_name
    )

async def test_context_manager(self):
async with self.api:
pass
Expand Down
Loading