@@ -180,6 +180,7 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
180
180
181
181
self ._keys = []
182
182
self .remote = False
183
+ self .local = False
183
184
self .cache_time = cache_time
184
185
self .time_out = 0
185
186
self .etag = ""
@@ -189,6 +190,8 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
189
190
self .keyusage = keyusage
190
191
self .imp_jwks = None
191
192
self .last_updated = 0
193
+ self .last_remote = None # HTTP Date of last remote update
194
+ self .last_local = None # UNIX timestamp of last local update
192
195
193
196
if httpc :
194
197
self .httpc = httpc
@@ -208,13 +211,13 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
208
211
self .do_keys (keys )
209
212
else :
210
213
self ._set_source (source , fileformat )
211
-
212
- if not self .remote and self .source : # local file
214
+ if self .local :
213
215
self ._do_local (kid )
214
216
215
217
def _set_source (self , source , fileformat ):
216
218
if source .startswith ("file://" ):
217
219
self .source = source [7 :]
220
+ self .local = True
218
221
elif source .startswith ("http://" ) or source .startswith ("https://" ):
219
222
self .source = source
220
223
self .remote = True
@@ -224,6 +227,7 @@ def _set_source(self, source, fileformat):
224
227
if fileformat .lower () in ['rsa' , 'der' , 'jwks' ]:
225
228
if os .path .isfile (source ):
226
229
self .source = source
230
+ self .local = True
227
231
else :
228
232
raise ImportError ('No such file' )
229
233
else :
@@ -235,6 +239,16 @@ def _do_local(self, kid):
235
239
elif self .fileformat == "der" :
236
240
self .do_local_der (self .source , self .keytype , self .keyusage , kid )
237
241
242
+ def _local_update_required (self ) -> bool :
243
+ stat = os .stat (self .source )
244
+ if self .last_local and stat .st_mtime < self .last_local :
245
+ LOGGER .debug ("%s not modfied" , self .source )
246
+ return False
247
+ else :
248
+ LOGGER .debug ("%s modfied" , self .source )
249
+ self .last_local = stat .st_mtime
250
+ return True
251
+
238
252
def do_keys (self , keys ):
239
253
"""
240
254
Go from JWK description to binary keys
@@ -290,12 +304,15 @@ def do_local_jwk(self, filename):
290
304
291
305
:param filename: Name of the file from which the JWKS should be loaded
292
306
"""
307
+ LOGGER .debug ("Reading JWKS from %s" , filename )
293
308
with open (filename ) as input_file :
294
309
_info = json .load (input_file )
295
310
if 'keys' in _info :
296
311
self .do_keys (_info ["keys" ])
297
312
else :
298
313
self .do_keys ([_info ])
314
+ self .last_local = time .time ()
315
+ self .time_out = self .last_local + self .cache_time
299
316
300
317
def do_local_der (self , filename , keytype , keyusage = None , kid = '' ):
301
318
"""
@@ -305,6 +322,7 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
305
322
:param keytype: Presently 'rsa' and 'ec' supported
306
323
:param keyusage: encryption ('enc') or signing ('sig') or both
307
324
"""
325
+ LOGGER .debug ("Reading DER from %s" , filename )
308
326
key_args = {}
309
327
_kty = keytype .lower ()
310
328
if _kty in ['rsa' , 'ec' ]:
@@ -324,6 +342,8 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
324
342
key_args ['kid' ] = kid
325
343
326
344
self .do_keys ([key_args ])
345
+ self .last_local = time .time ()
346
+ self .time_out = self .last_local + self .cache_time
327
347
328
348
def do_remote (self ):
329
349
"""
@@ -336,6 +356,10 @@ def do_remote(self):
336
356
337
357
try :
338
358
LOGGER .debug ('KeyBundle fetch keys from: %s' , self .source )
359
+ if self .last_remote is not None :
360
+ if "headers" not in self .httpc_params :
361
+ self .httpc_params ["headers" ] = {}
362
+ self .httpc_params ["headers" ]["If-Modified-Since" ] = self .last_remote
339
363
_http_resp = self .httpc ('GET' , self .source , ** self .httpc_params )
340
364
except Exception as err :
341
365
LOGGER .error (err )
@@ -357,6 +381,14 @@ def do_remote(self):
357
381
LOGGER .error ("No 'keys' keyword in JWKS" )
358
382
raise UpdateFailed (MALFORMED .format (self .source ))
359
383
384
+ if hasattr (_http_resp , "headers" ):
385
+ headers = getattr (_http_resp , "headers" )
386
+ self .last_remote = headers .get ("last-modified" ) or headers .get ("date" )
387
+
388
+ elif _http_resp .status_code == 304 : # Not modified
389
+ LOGGER .debug ("%s not modified since %s" , self .source , self .last_remote )
390
+ pass
391
+
360
392
else :
361
393
raise UpdateFailed (
362
394
REMOTE_FAILED .format (self .source , _http_resp .status_code ))
@@ -387,14 +419,12 @@ def _parse_remote_response(self, response):
387
419
388
420
def _uptodate (self ):
389
421
res = False
390
- if not self ._keys :
391
- if self .remote : # verify that it's not to old
392
- if time .time () > self .time_out :
393
- if self .update ():
394
- res = True
395
- elif self .remote :
396
- if self .update ():
397
- res = True
422
+ if self .remote or self .local :
423
+ if time .time () > self .time_out :
424
+ if self .local and not self ._local_update_required ():
425
+ res = True
426
+ elif self .update ():
427
+ res = True
398
428
return res
399
429
400
430
def update (self ):
@@ -412,13 +442,13 @@ def update(self):
412
442
self ._keys = []
413
443
414
444
try :
415
- if self .remote is False :
445
+ if self .local :
416
446
if self .fileformat in ["jwks" , "jwk" ]:
417
447
self .do_local_jwk (self .source )
418
448
elif self .fileformat == "der" :
419
449
self .do_local_der (self .source , self .keytype ,
420
450
self .keyusage )
421
- else :
451
+ elif self . remote :
422
452
res = self .do_remote ()
423
453
except Exception as err :
424
454
LOGGER .error ('Key bundle update failed: %s' , err )
@@ -661,8 +691,11 @@ def dump(self):
661
691
"keys" : _keys ,
662
692
"fileformat" : self .fileformat ,
663
693
"last_updated" : self .last_updated ,
694
+ "last_remote" : self .last_remote ,
695
+ "last_local" : self .last_local ,
664
696
"httpc_params" : self .httpc_params ,
665
697
"remote" : self .remote ,
698
+ "local" : self .local ,
666
699
"imp_jwks" : self .imp_jwks ,
667
700
"time_out" : self .time_out ,
668
701
"cache_time" : self .cache_time
@@ -680,7 +713,10 @@ def load(self, spec):
680
713
self .source = spec .get ("source" , None )
681
714
self .fileformat = spec .get ("fileformat" , "jwks" )
682
715
self .last_updated = spec .get ("last_updated" , 0 )
716
+ self .last_remote = spec .get ("last_remote" , None )
717
+ self .last_local = spec .get ("last_local" , None )
683
718
self .remote = spec .get ("remote" , False )
719
+ self .local = spec .get ("local" , False )
684
720
self .imp_jwks = spec .get ('imp_jwks' , None )
685
721
self .time_out = spec .get ('time_out' , 0 )
686
722
self .cache_time = spec .get ('cache_time' , 0 )
0 commit comments