Skip to content

Commit de0b9bc

Browse files
authored
Merge pull request #66 from janste63/inactive_key_fixes
Inactive key fixes
2 parents 2b00abb + 54e59e1 commit de0b9bc

File tree

4 files changed

+44
-25
lines changed

4 files changed

+44
-25
lines changed

src/cryptojwt/key_bundle.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,11 @@ def do_local_jwk(self, filename):
318318
Load a JWKS from a local file
319319
320320
:param filename: Name of the file from which the JWKS should be loaded
321+
:return: True if load was successful or False if file hasn't been modified
321322
"""
323+
if not self._local_update_required():
324+
return False
325+
322326
LOGGER.info("Reading local JWKS from %s", filename)
323327
with open(filename) as input_file:
324328
_info = json.load(input_file)
@@ -328,6 +332,7 @@ def do_local_jwk(self, filename):
328332
self.do_keys([_info])
329333
self.last_local = time.time()
330334
self.time_out = self.last_local + self.cache_time
335+
return True
331336

332337
def do_local_der(self, filename, keytype, keyusage=None, kid=""):
333338
"""
@@ -336,7 +341,11 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
336341
:param filename: Name of the file
337342
:param keytype: Presently 'rsa' and 'ec' supported
338343
:param keyusage: encryption ('enc') or signing ('sig') or both
344+
:return: True if load was successful or False if file hasn't been modified
339345
"""
346+
if not self._local_update_required():
347+
return False
348+
340349
LOGGER.info("Reading local DER from %s", filename)
341350
key_args = {}
342351
_kty = keytype.lower()
@@ -359,12 +368,13 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
359368
self.do_keys([key_args])
360369
self.last_local = time.time()
361370
self.time_out = self.last_local + self.cache_time
371+
return True
362372

363373
def do_remote(self):
364374
"""
365375
Load a JWKS from a webpage.
366376
367-
:return: True or False if load was successful
377+
:return: True if load was successful or False if remote hasn't been modified
368378
"""
369379
# if self.verify_ssl is not None:
370380
# self.httpc_params["verify"] = self.verify_ssl
@@ -390,7 +400,10 @@ def do_remote(self):
390400
LOGGER.error(err)
391401
raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err)))
392402

393-
if _http_resp.status_code == 200: # New content
403+
load_successful = _http_resp.status_code == 200
404+
not_modified = _http_resp.status_code == 304
405+
406+
if load_successful:
394407
self.time_out = time.time() + self.cache_time
395408

396409
self.imp_jwks = self._parse_remote_response(_http_resp)
@@ -408,11 +421,9 @@ def do_remote(self):
408421
if hasattr(_http_resp, "headers"):
409422
headers = getattr(_http_resp, "headers")
410423
self.last_remote = headers.get("last-modified") or headers.get("date")
411-
412-
elif _http_resp.status_code == 304: # Not modified
424+
elif not_modified:
413425
LOGGER.debug("%s not modified since %s", self.source, self.last_remote)
414426
self.time_out = time.time() + self.cache_time
415-
416427
else:
417428
LOGGER.warning(
418429
"HTTP status %d reading remote JWKS from %s",
@@ -424,7 +435,7 @@ def do_remote(self):
424435

425436
self.last_updated = time.time()
426437
self.ignore_errors_until = None
427-
return True
438+
return load_successful
428439

429440
def _parse_remote_response(self, response):
430441
"""
@@ -449,23 +460,20 @@ def _parse_remote_response(self, response):
449460
return None
450461

451462
def _uptodate(self):
452-
res = False
453463
if self.remote or self.local:
454464
if time.time() > self.time_out:
455-
if self.local and not self._local_update_required():
456-
res = True
457-
elif self.update():
458-
res = True
459-
return res
465+
return self.update()
466+
return False
460467

461468
def update(self):
462469
"""
463470
Reload the keys if necessary.
464471
465472
This is a forced update, will happen even if cache time has not elapsed.
466473
Replaced keys will be marked as inactive and not removed.
474+
475+
:return: True if update was ok or False if we encountered an error during update.
467476
"""
468-
res = True # An update was successful
469477
if self.source:
470478
_old_keys = self._keys # just in case
471479

@@ -475,24 +483,27 @@ def update(self):
475483
try:
476484
if self.local:
477485
if self.fileformat in ["jwks", "jwk"]:
478-
self.do_local_jwk(self.source)
486+
updated = self.do_local_jwk(self.source)
479487
elif self.fileformat == "der":
480-
self.do_local_der(self.source, self.keytype, self.keyusage)
488+
updated = self.do_local_der(self.source, self.keytype, self.keyusage)
481489
elif self.remote:
482-
res = self.do_remote()
490+
updated = self.do_remote()
483491
except Exception as err:
484492
LOGGER.error("Key bundle update failed: %s", err)
485493
self._keys = _old_keys # restore
486494
return False
487495

488-
now = time.time()
489-
for _key in _old_keys:
490-
if _key not in self._keys:
491-
if not _key.inactive_since: # If already marked don't mess
492-
_key.inactive_since = now
493-
self._keys.append(_key)
496+
if updated:
497+
now = time.time()
498+
for _key in _old_keys:
499+
if _key not in self._keys:
500+
if not _key.inactive_since: # If already marked don't mess
501+
_key.inactive_since = now
502+
self._keys.append(_key)
503+
else:
504+
self._keys = _old_keys
494505

495-
return res
506+
return True
496507

497508
def get(self, typ="", only_active=True):
498509
"""

tests/test_03_key_bundle.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ def test_update_2():
567567
ec_key = new_ec_key(crv="P-256", key_ops=["sign"])
568568
_jwks = {"keys": [rsa_key.serialize(), ec_key.serialize()]}
569569

570+
time.sleep(0.5)
570571
with open(fname, "w") as fp:
571572
fp.write(json.dumps(_jwks))
572573

@@ -1009,7 +1010,7 @@ def test_remote_not_modified():
10091010

10101011
with responses.RequestsMock() as rsps:
10111012
rsps.add(method="GET", url=source, status=304, headers=headers)
1012-
assert kb.do_remote()
1013+
assert not kb.do_remote()
10131014
assert kb.last_remote == headers.get("Last-Modified")
10141015
timeout2 = kb.time_out
10151016

@@ -1019,6 +1020,7 @@ def test_remote_not_modified():
10191020
kb2 = KeyBundle().load(exp)
10201021
assert kb2.source == source
10211022
assert len(kb2.keys()) == 3
1023+
assert len(kb2.active_keys()) == 3
10221024
assert len(kb2.get("rsa")) == 1
10231025
assert len(kb2.get("oct")) == 1
10241026
assert len(kb2.get("ec")) == 1

tests/test_04_key_jar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,12 @@ def test_aud(self):
746746
keys = self.bob_keyjar.get_jwt_verify_keys(_jwt.jwt, no_kid_issuer=no_kid_issuer)
747747
assert len(keys) == 1
748748

749+
def test_inactive_verify_key(self):
750+
_jwt = factory(self.sjwt_b)
751+
self.alice_keyjar.return_issuer("Bob")[0].mark_all_as_inactive()
752+
keys = self.alice_keyjar.get_jwt_verify_keys(_jwt.jwt)
753+
assert len(keys) == 0
754+
749755

750756
def test_copy():
751757
kj = KeyJar()

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ envlist = py{36,37,38},quality
44
[testenv]
55
passenv = CI TRAVIS TRAVIS_*
66
commands =
7-
py.test --cov=cryptojwt --isort --black {posargs}
7+
pytest -vvv -ra --cov=cryptojwt --isort --black {posargs}
88
codecov
99
extras = testing
1010
deps =

0 commit comments

Comments
 (0)