Skip to content

Commit 72cf265

Browse files
committed
WIP: refactor pipe state into a pydantic basemodel (not working yet)
1 parent 4e4de10 commit 72cf265

File tree

6 files changed

+78
-57
lines changed

6 files changed

+78
-57
lines changed

src/pyff/api.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pyff.constants import config
2323
from pyff.exceptions import ResourceException
2424
from pyff.logs import get_log
25-
from pyff.pipes import plumbing
25+
from pyff.pipes import PipeState, plumbing
2626
from pyff.repo import MDRepository
2727
from pyff.resource import Resource
2828
from pyff.samlmd import entity_display_name
@@ -288,35 +288,33 @@ def process_handler(request: Request) -> Response:
288288

289289
try:
290290
for p in request.registry.plumbings:
291-
state = {
292-
entry: True,
293-
'headers': {'Content-Type': None},
294-
'accept': accept,
295-
'url': request.current_route_url(),
296-
'select': q,
297-
'match': match.lower() if match else match,
298-
'path': new_path,
299-
'stats': {},
300-
}
291+
state = PipeState(
292+
entry_name=entry,
293+
headers={'Content-Type': None},
294+
accept=accept,
295+
url=request.current_route_url(),
296+
select=q,
297+
match=match.lower() if match else match,
298+
path=new_path,
299+
stats={},
300+
)
301301

302302
r = p.process(request.registry.md, state=state, raise_exceptions=True, scheduler=request.registry.scheduler)
303303
log.debug(f'Plumbing process result: {r}')
304304
if r is None:
305305
r = []
306306

307307
response = Response()
308-
_headers = state.get('headers', {})
309-
response.headers.update(_headers)
310-
ctype = _headers.get('Content-Type', None)
308+
response.headers.update(state.headers)
309+
ctype = state.headers.get('Content-Type', None)
311310
if not ctype:
312311
r, t = _fmt(r, accept)
313312
ctype = t
314313

315314
response.text = b2u(r)
316315
response.size = len(r)
317316
response.content_type = ctype
318-
cache_ttl = int(state.get('cache', 0))
319-
response.expires = datetime.now() + timedelta(seconds=cache_ttl)
317+
response.expires = datetime.now() + timedelta(seconds=state.cache)
320318
return response
321319
except ResourceException as ex:
322320
import traceback

src/pyff/builtins.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pyff.decorators import deprecated
2626
from pyff.exceptions import MetadataException
2727
from pyff.logs import get_log
28-
from pyff.pipes import PipeException, PipelineCallback, Plumbing, pipe, registry
28+
from pyff.pipes import PipeException, PipeState, PipelineCallback, Plumbing, pipe, registry
2929
from pyff.samlmd import (
3030
annotate_entity,
3131
discojson_t,
@@ -383,14 +383,13 @@ def when(req: Plumbing.Request, condition: str, *values):
383383
The condition operates on the state: if 'foo' is present in the state (with any value), then the something branch is
384384
followed. If 'bar' is present in the state with the value 'bill' then the other branch is followed.
385385
"""
386-
c = req.state.get(condition, None)
387-
if c is None:
386+
if req.state.entry_name is None:
388387
log.debug(f'Condition {repr(condition)} not present in state {req.state}')
389-
if c is not None and (not values or _any(values, c)):
388+
if req.state.entry_name is not None and (not values or _any(values, req.state.entry_name)):
390389
if not isinstance(req.args, list):
391390
raise ValueError('Non-list arguments to "when" not allowed')
392391

393-
return Plumbing(pipeline=req.args, pid="%s.when" % req.plumbing.id).iprocess(req)
392+
return Plumbing(pipeline=req.args, pid=f'{req.plumbing.id}.when').iprocess(req)
394393
return req.t
395394

396395

@@ -768,9 +767,9 @@ def select(req: Plumbing.Request, *opts):
768767

769768
entities = resolve_entities(args, lookup_fn=req.md.store.select)
770769

771-
if req.state.get('match', None): # TODO - allow this to be passed in via normal arguments
770+
if req.state.match: # TODO - allow this to be passed in via normal arguments
772771

773-
match = req.state['match']
772+
match = req.state.match
774773

775774
if isinstance(match, six.string_types):
776775
query = [match.lower()]
@@ -1435,11 +1434,11 @@ def emit(req: Plumbing.Request, ctype="application/xml", *opts):
14351434
if not isinstance(d, six.binary_type):
14361435
d = d.encode("utf-8")
14371436
m.update(d)
1438-
req.state['headers']['ETag'] = m.hexdigest()
1437+
req.state.headers['ETag'] = m.hexdigest()
14391438
else:
14401439
raise PipeException("Empty")
14411440

1442-
req.state['headers']['Content-Type'] = ctype
1441+
req.state.headers['Content-Type'] = ctype
14431442
if six.PY2:
14441443
d = six.u(d)
14451444
return d
@@ -1517,7 +1516,7 @@ def finalize(req: Plumbing.Request, *opts):
15171516
if name is None or 0 == len(name):
15181517
name = req.args.get('Name', None)
15191518
if name is None or 0 == len(name):
1520-
name = req.state.get('url', None)
1519+
name = req.state.url
15211520
if name and 'baseURL' in req.args:
15221521

15231522
try:
@@ -1569,7 +1568,7 @@ def finalize(req: Plumbing.Request, *opts):
15691568
# TODO: offset can be None here, if validUntil is not a valid duration or ISO date
15701569
# What is the right action to take then?
15711570
if offset:
1572-
req.state['cache'] = int(total_seconds(offset) / 50)
1571+
req.state.cache = int(total_seconds(offset) / 50)
15731572

15741573
cache_duration = req.args.get('cacheDuration', e.get('cacheDuration', None))
15751574
if cache_duration is not None and len(cache_duration) > 0:
@@ -1578,7 +1577,7 @@ def finalize(req: Plumbing.Request, *opts):
15781577
raise PipeException("Unable to parse %s as xs:duration" % cache_duration)
15791578

15801579
e.set('cacheDuration', cache_duration)
1581-
req.state['cache'] = int(total_seconds(offset))
1580+
req.state.cache = int(total_seconds(offset))
15821581

15831582
return req.t
15841583

src/pyff/pipes.py

+40-22
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,25 @@
77
import functools
88
import os
99
import traceback
10-
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
10+
from typing import Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Type
11+
from typing import Union
1112

1213
import yaml
1314
from apscheduler.schedulers.background import BackgroundScheduler
1415
from lxml.etree import Element, ElementTree
16+
from pydantic import BaseModel, Field
1517

1618
from pyff.logs import get_log
1719
from pyff.repo import MDRepository
1820
from pyff.store import SAMLStoreBase
1921
from pyff.utils import PyffException, is_text, resource_string
2022

23+
if TYPE_CHECKING:
24+
from pyff.api import MediaAccept
25+
26+
# Avoid static analysers flagging this import as unused
27+
assert MediaAccept
28+
2129
log = get_log(__name__)
2230

2331
__author__ = 'leifj'
@@ -77,7 +85,7 @@ class PluginsRegistry(dict):
7785
def the_something_func(req,*opts):
7886
pass
7987
80-
Referencing this function as an entry_point using something = module:the_somethig_func in setup.py allows the
88+
Referencing this function as an entry_point using something = module:the_something_func in setup.py allows the
8189
function to be referenced as 'something' in a pipeline.
8290
"""
8391

@@ -160,22 +168,37 @@ def __deepcopy__(self, memo: Any) -> PipelineCallback:
160168
# TODO: This seems... dangerous. What's the need for this?
161169
return self
162170

163-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
171+
def __call__(self, t: ElementTree, state: Optional[PipeState] = None) -> Any:
164172
log.debug("{!s}: called".format(self.plumbing))
165-
t = args[0]
173+
if state is None:
174+
state = PipeState()
166175
if t is None:
167176
raise ValueError("PipelineCallback must be called with a parse-tree argument")
177+
if not isinstance(state, PipeState):
178+
raise ValueError(f'PipelineCallback called with invalid state ({type(state)}')
168179
try:
169-
state = kwargs
170-
state[self.entry_point] = True
171-
log.debug("state: {}".format(repr(state)))
180+
state.entry_name = self.entry_point
181+
log.debug("state: {}".format(state))
172182
return self.plumbing.process(self.req.md, store=self.store, state=state, t=t)
173183
except Exception as ex:
174184
log.debug(traceback.format_exc())
175185
log.error(f'Got an exception executing the plumbing process: {ex}')
176186
raise ex
177187

178188

189+
class PipeState(BaseModel):
190+
batch: bool = False
191+
entry_name: Optional[str] = None
192+
headers: Dict[str, Any] = Field({})
193+
accept: Any = None # TODO: Re-arrange classes so that type 'MediaAccept' works
194+
url: str = ''
195+
select: str = ''
196+
match: str = ''
197+
path: str = ''
198+
stats: Dict[str, Any] = Field({})
199+
cache: int = 0 # cache_ttl
200+
201+
179202
class Plumbing(object):
180203
"""
181204
A plumbing instance represents a basic processing chain for SAML metadata. A simple, yet reasonably complete example:
@@ -201,7 +224,7 @@ class Plumbing(object):
201224
202225
Running this plumbing would bake all metadata found in /var/metadata/registry and at http://md.example.com into an
203226
EntitiesDescriptor element with @Name http://example.com/metadata.xml, @cacheDuration set to 1hr and @validUntil
204-
1 day from the time the 'finalize' command was run. The tree woud be transformed using the "tidy" stylesheets and
227+
1 day from the time the 'finalize' command was run. The tree would be transformed using the "tidy" stylesheets and
205228
would then be signed (using signer.key) and finally published in /var/metadata/public/metadata.xml
206229
"""
207230

@@ -237,27 +260,25 @@ def __init__(
237260
self,
238261
pl: Plumbing,
239262
md: MDRepository,
240-
t=None,
241-
name=None,
242-
args=None,
243-
state: Optional[Dict[str, Any]] = None,
244-
store=None,
263+
state: Optional[PipeState] = None,
264+
t: Optional[ElementTree] = None,
265+
name: Optional[str] = None,
266+
args: Optional[Union[str, Dict, List]] = None,
267+
store: Optional[SAMLStoreBase] = None,
245268
scheduler: Optional[BackgroundScheduler] = None,
246269
raise_exceptions: bool = True,
247270
):
248-
if not state:
249-
state = dict()
250271
if not args:
251272
args = []
252273
self.plumbing: Plumbing = pl
253274
self.md: MDRepository = md
254275
self.t: ElementTree = t
255276
self._id: Optional[str] = None
256-
self.name = name
277+
self.name: Optional[str] = name
257278
self.args: Optional[Union[str, Dict, List]] = args
258-
self.state: Dict[str, Any] = state
279+
self.state: PipeState = state if state else PipeState()
259280
self.done: bool = False
260-
self._store: SAMLStoreBase = store
281+
self._store: Optional[SAMLStoreBase] = store
261282
self.scheduler: Optional[BackgroundScheduler] = scheduler
262283
self.raise_exceptions: bool = raise_exceptions
263284
self.exception: Optional[BaseException] = None
@@ -337,8 +358,8 @@ def iprocess(self, req: Plumbing.Request) -> ElementTree:
337358
def process(
338359
self,
339360
md: MDRepository,
361+
state: PipeState,
340362
args: Any = None,
341-
state: Optional[Dict[str, Any]] = None,
342363
t: Optional[ElementTree] = None,
343364
store: Optional[SAMLStoreBase] = None,
344365
raise_exceptions: bool = True,
@@ -357,9 +378,6 @@ def process(
357378
:param args: Pipeline arguments
358379
:return: The result of applying the processing pipeline to t.
359380
"""
360-
if not state:
361-
state = dict()
362-
363381
return Plumbing.Request(
364382
self, md, t=t, args=args, state=state, store=store, raise_exceptions=raise_exceptions, scheduler=scheduler
365383
).process(self)

src/pyff/resource.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,13 @@ def parse(self, getter: Callable[[str], Response]) -> Deque[Resource]:
468468
if self.post:
469469
for cb in self.post:
470470
if self.t is not None:
471-
self.t = cb(self.t, self.opts.dict())
471+
# TODO: This used to be
472+
# self.t = cb(self.t, self.opts.dict())
473+
# but passing self.opts does not seem to be what the callback expected.
474+
# Don't know what to do really.
475+
from pyff.pipes import PipeState
476+
477+
self.t = cb(self.t, PipeState())
472478

473479
if self.is_expired():
474480
info.expired = True

src/pyff/test/test_pipeline.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pyff import builtins
1212
from pyff.exceptions import MetadataException
1313
from pyff.parse import ParserException
14-
from pyff.pipes import PipeException, Plumbing, plumbing
14+
from pyff.pipes import PipeException, PipeState, Plumbing, plumbing
1515
from pyff.repo import MDRepository
1616
from pyff.resource import ResourceException
1717
from pyff.test import ExitException, SignerTestCase
@@ -61,7 +61,7 @@ def run_pipeline(self, pl_name, ctx=None, md=None):
6161
template = templates.get_template(pl_name)
6262
with open(pipeline, "w") as fd:
6363
fd.write(template.render(ctx=ctx))
64-
res = plumbing(pipeline).process(md, state={'batch': True, 'stats': {}})
64+
res = plumbing(pipeline).process(md, PipeState(entry_name='batch'))
6565
os.unlink(pipeline)
6666
return res, md, ctx
6767

@@ -70,7 +70,7 @@ def exec_pipeline(self, pstr):
7070
p = yaml.safe_load(six.StringIO(pstr))
7171
print("\n{}".format(yaml.dump(p)))
7272
pl = Plumbing(p, pid="test")
73-
res = pl.process(md, state={'batch': True, 'stats': {}})
73+
res = pl.process(md, PipeState(entry_name='batch'))
7474
return res, md
7575

7676
@classmethod

src/pyff/test/test_simple_pipeline.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mako.lookup import TemplateLookup
55

66
from pyff.constants import NS
7-
from pyff.pipes import plumbing
7+
from pyff.pipes import PipeState, plumbing
88
from pyff.repo import MDRepository
99
from pyff.test import SignerTestCase
1010

@@ -24,8 +24,8 @@ def setUp(self):
2424
fd.write(self.signer_template.render(ctx=self))
2525
with open(self.validator, "w") as fd:
2626
fd.write(self.validator_template.render(ctx=self))
27-
self.signer_result = plumbing(self.signer).process(self.md_signer, state={'batch': True, 'stats': {}})
28-
self.validator_result = plumbing(self.validator).process(self.md_validator, state={'batch': True, 'stats': {}})
27+
self.signer_result = plumbing(self.signer).process(self.md_signer, state=PipeState(batch=True))
28+
self.validator_result = plumbing(self.validator).process(self.md_validator, state=PipeState(batch=True))
2929

3030
def test_entityid_present(self):
3131
eids = [e.get('entityID') for e in self.md_signer.store]

0 commit comments

Comments
 (0)