Skip to content

Commit fb77ad8

Browse files
authored
Merge pull request #42 from jschlyter/refresh_keybundle
Refresh keybundle
2 parents 85e7aec + 351bf7e commit fb77ad8

File tree

2 files changed

+51
-12
lines changed

2 files changed

+51
-12
lines changed

src/cryptojwt/key_bundle.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
180180

181181
self._keys = []
182182
self.remote = False
183+
self.local = False
183184
self.cache_time = cache_time
184185
self.time_out = 0
185186
self.etag = ""
@@ -189,6 +190,8 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
189190
self.keyusage = keyusage
190191
self.imp_jwks = None
191192
self.last_updated = 0
193+
self.last_remote = None # HTTP Date of last remote update
194+
self.last_local = None # UNIX timestamp of last local update
192195

193196
if httpc:
194197
self.httpc = httpc
@@ -208,13 +211,13 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
208211
self.do_keys(keys)
209212
else:
210213
self._set_source(source, fileformat)
211-
212-
if not self.remote and self.source: # local file
214+
if self.local:
213215
self._do_local(kid)
214216

215217
def _set_source(self, source, fileformat):
216218
if source.startswith("file://"):
217219
self.source = source[7:]
220+
self.local = True
218221
elif source.startswith("http://") or source.startswith("https://"):
219222
self.source = source
220223
self.remote = True
@@ -224,6 +227,7 @@ def _set_source(self, source, fileformat):
224227
if fileformat.lower() in ['rsa', 'der', 'jwks']:
225228
if os.path.isfile(source):
226229
self.source = source
230+
self.local = True
227231
else:
228232
raise ImportError('No such file')
229233
else:
@@ -235,6 +239,16 @@ def _do_local(self, kid):
235239
elif self.fileformat == "der":
236240
self.do_local_der(self.source, self.keytype, self.keyusage, kid)
237241

242+
def _local_update_required(self) -> bool:
243+
stat = os.stat(self.source)
244+
if self.last_local and stat.st_mtime < self.last_local:
245+
LOGGER.debug("%s not modfied", self.source)
246+
return False
247+
else:
248+
LOGGER.debug("%s modfied", self.source)
249+
self.last_local = stat.st_mtime
250+
return True
251+
238252
def do_keys(self, keys):
239253
"""
240254
Go from JWK description to binary keys
@@ -290,12 +304,15 @@ def do_local_jwk(self, filename):
290304
291305
:param filename: Name of the file from which the JWKS should be loaded
292306
"""
307+
LOGGER.debug("Reading JWKS from %s", filename)
293308
with open(filename) as input_file:
294309
_info = json.load(input_file)
295310
if 'keys' in _info:
296311
self.do_keys(_info["keys"])
297312
else:
298313
self.do_keys([_info])
314+
self.last_local = time.time()
315+
self.time_out = self.last_local + self.cache_time
299316

300317
def do_local_der(self, filename, keytype, keyusage=None, kid=''):
301318
"""
@@ -305,6 +322,7 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
305322
:param keytype: Presently 'rsa' and 'ec' supported
306323
:param keyusage: encryption ('enc') or signing ('sig') or both
307324
"""
325+
LOGGER.debug("Reading DER from %s", filename)
308326
key_args = {}
309327
_kty = keytype.lower()
310328
if _kty in ['rsa', 'ec']:
@@ -324,6 +342,8 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
324342
key_args['kid'] = kid
325343

326344
self.do_keys([key_args])
345+
self.last_local = time.time()
346+
self.time_out = self.last_local + self.cache_time
327347

328348
def do_remote(self):
329349
"""
@@ -336,6 +356,10 @@ def do_remote(self):
336356

337357
try:
338358
LOGGER.debug('KeyBundle fetch keys from: %s', self.source)
359+
if self.last_remote is not None:
360+
if "headers" not in self.httpc_params:
361+
self.httpc_params["headers"] = {}
362+
self.httpc_params["headers"]["If-Modified-Since"] = self.last_remote
339363
_http_resp = self.httpc('GET', self.source, **self.httpc_params)
340364
except Exception as err:
341365
LOGGER.error(err)
@@ -357,6 +381,14 @@ def do_remote(self):
357381
LOGGER.error("No 'keys' keyword in JWKS")
358382
raise UpdateFailed(MALFORMED.format(self.source))
359383

384+
if hasattr(_http_resp, "headers"):
385+
headers = getattr(_http_resp, "headers")
386+
self.last_remote = headers.get("last-modified") or headers.get("date")
387+
388+
elif _http_resp.status_code == 304: # Not modified
389+
LOGGER.debug("%s not modified since %s", self.source, self.last_remote)
390+
pass
391+
360392
else:
361393
raise UpdateFailed(
362394
REMOTE_FAILED.format(self.source, _http_resp.status_code))
@@ -387,14 +419,12 @@ def _parse_remote_response(self, response):
387419

388420
def _uptodate(self):
389421
res = False
390-
if not self._keys:
391-
if self.remote: # verify that it's not to old
392-
if time.time() > self.time_out:
393-
if self.update():
394-
res = True
395-
elif self.remote:
396-
if self.update():
397-
res = True
422+
if self.remote or self.local:
423+
if time.time() > self.time_out:
424+
if self.local and not self._local_update_required():
425+
res = True
426+
elif self.update():
427+
res = True
398428
return res
399429

400430
def update(self):
@@ -412,13 +442,13 @@ def update(self):
412442
self._keys = []
413443

414444
try:
415-
if self.remote is False:
445+
if self.local:
416446
if self.fileformat in ["jwks", "jwk"]:
417447
self.do_local_jwk(self.source)
418448
elif self.fileformat == "der":
419449
self.do_local_der(self.source, self.keytype,
420450
self.keyusage)
421-
else:
451+
elif self.remote:
422452
res = self.do_remote()
423453
except Exception as err:
424454
LOGGER.error('Key bundle update failed: %s', err)
@@ -661,8 +691,11 @@ def dump(self):
661691
"keys": _keys,
662692
"fileformat": self.fileformat,
663693
"last_updated": self.last_updated,
694+
"last_remote": self.last_remote,
695+
"last_local": self.last_local,
664696
"httpc_params": self.httpc_params,
665697
"remote": self.remote,
698+
"local": self.local,
666699
"imp_jwks": self.imp_jwks,
667700
"time_out": self.time_out,
668701
"cache_time": self.cache_time
@@ -680,7 +713,10 @@ def load(self, spec):
680713
self.source = spec.get("source", None)
681714
self.fileformat = spec.get("fileformat", "jwks")
682715
self.last_updated = spec.get("last_updated", 0)
716+
self.last_remote = spec.get("last_remote", None)
717+
self.last_local = spec.get("last_local", None)
683718
self.remote = spec.get("remote", False)
719+
self.local = spec.get("local", False)
684720
self.imp_jwks = spec.get('imp_jwks', None)
685721
self.time_out = spec.get('time_out', 0)
686722
self.cache_time = spec.get('cache_time', 0)

tests/test_03_key_bundle.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,10 @@ def test_export_inactive():
938938
'imp_jwks',
939939
'keys',
940940
'last_updated',
941+
'last_remote',
942+
'last_local',
941943
'remote',
944+
'local',
942945
'time_out'}
943946

944947
kb2 = KeyBundle().load(res)

0 commit comments

Comments
 (0)