diff --git a/src/pyff/api.py b/src/pyff/api.py index f876c1d3..8ad86ab8 100644 --- a/src/pyff/api.py +++ b/src/pyff/api.py @@ -1,6 +1,7 @@ import importlib import threading from datetime import datetime, timedelta +from enum import Enum from json import dumps from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple @@ -21,7 +22,7 @@ from pyff.constants import config from pyff.exceptions import ResourceException from pyff.logs import get_log -from pyff.pipes import plumbing +from pyff.pipes import PipeState, plumbing from pyff.repo import MDRepository from pyff.resource import Resource from pyff.samlmd import entity_display_name @@ -153,16 +154,36 @@ def request_handler(request: Request) -> Response: return r -def process_handler(request: Request) -> Response: +class ContentNegPolicy(Enum): + extension = 'extension' # current default + adaptive = 'adaptive' + header = 'header' # future default + + +def _process_content_negotiate( + policy: ContentNegPolicy, alias: str, path: Optional[str], pfx, request: Request +) -> Tuple[MediaAccept, Optional[str], Optional[str]]: """ - The main request handler for pyFF. Implements API call hooks and content negotiation. + Determine requested content type, based on policy, Accept request header and path extension. - :param request: the HTTP request object - :return: the data to send to the client + content_negotiation_policy is one of three values: + + 1. extension - current default, inspect the path and if it ends in + an extension, e.g. .xml or .json, always strip off the extension to + get the entityID and if no accept header or a wildcard header, then + use the extension to determine the return Content-Type. + + 2. adaptive - only if no accept header or if a wildcard, then inspect + the path and if it ends in an extension strip off the extension to + get the entityID and use the extension to determine the return + Content-Type. + + 3. header - future default, do not inspect the path for an extension and + use only the Accept header to determine the return Content-Type. """ _ctypes = {'xml': 'application/samlmetadata+xml;application/xml;text/xml', 'json': 'application/json'} - def _d(x: Optional[str], do_split: bool = True) -> Tuple[Optional[str], Optional[str]]: + def _split_path(x: Optional[str], do_split: bool = True) -> Tuple[Optional[str], Optional[str]]: """ Split a path into a base component and an extension. """ if x is not None: x = x.strip() @@ -178,6 +199,45 @@ def _d(x: Optional[str], do_split: bool = True) -> Tuple[Optional[str], Optional return x, None + # TODO - sometimes the client sends > 1 accept header value with ','. + accept = str(request.accept).split(',')[0] + valid_accept = accept and not ('application/*' in accept or 'text/*' in accept or '*/*' in accept) + + path_no_extension, extension = _split_path(path, True) + accept_from_extension = accept + if extension: + accept_from_extension = _ctypes.get(extension, accept) + + if policy == ContentNegPolicy.extension: + path = path_no_extension + if not valid_accept: + accept = accept_from_extension + elif policy == ContentNegPolicy.adaptive: + if not valid_accept: + path = path_no_extension + accept = accept_from_extension + + if not accept: + log.warning('Could not determine accepted response type') + raise exc.exception_response(400) + + q: Optional[str] + if pfx and path: + q = f'{{{pfx}}}{path}' + path = f'/{alias}/{path}' + else: + q = path + + return MediaAccept(accept), path, q + + +def process_handler(request: Request) -> Response: + """ + The main request handler for pyFF. Implements API call hooks and content negotiation. + + :param request: the HTTP request object + :return: the data to send to the client + """ log.debug(f'Processing request: {request}') if request.matchdict is None: @@ -215,64 +275,29 @@ def _d(x: Optional[str], do_split: bool = True) -> Tuple[Optional[str], Optional if pfx is None: raise exc.exception_response(404) - # content_negotiation_policy is one of three values: - # 1. extension - current default, inspect the path and if it ends in - # an extension, e.g. .xml or .json, always strip off the extension to - # get the entityID and if no accept header or a wildcard header, then - # use the extension to determine the return Content-Type. - # - # 2. adaptive - only if no accept header or if a wildcard, then inspect - # the path and if it ends in an extension strip off the extension to - # get the entityID and use the extension to determine the return - # Content-Type. - # - # 3. header - future default, do not inspect the path for an extension and - # use only the Accept header to determine the return Content-Type. - policy = config.content_negotiation_policy - - # TODO - sometimes the client sends > 1 accept header value with ','. - accept = str(request.accept).split(',')[0] - valid_accept = accept and not ('application/*' in accept or 'text/*' in accept or '*/*' in accept) - - new_path: Optional[str] = path - path_no_extension, extension = _d(new_path, True) - accept_from_extension = accept - if extension: - accept_from_extension = _ctypes.get(extension, accept) - - if policy == 'extension': - new_path = path_no_extension - if not valid_accept: - accept = accept_from_extension - elif policy == 'adaptive': - if not valid_accept: - new_path = path_no_extension - accept = accept_from_extension - - if not accept: - log.warning('Could not determine accepted response type') - raise exc.exception_response(400) + try: + policy = ContentNegPolicy(config.content_negotiation_policy) + except ValueError: + log.debug( + f'Invalid value for config.content_negotiation_policy: {config.content_negotiation_policy}, ' + f'defaulting to "extension"' + ) + policy = ContentNegPolicy.extension - q: Optional[str] - if pfx and new_path: - q = f'{{{pfx}}}{new_path}' - new_path = f'/{alias}/{new_path}' - else: - q = new_path + accept, new_path, q = _process_content_negotiate(policy, alias, path, pfx, request) try: - accepter = MediaAccept(accept) for p in request.registry.plumbings: - state = { - entry: True, - 'headers': {'Content-Type': None}, - 'accept': accepter, - 'url': request.current_route_url(), - 'select': q, - 'match': match.lower() if match else match, - 'path': new_path, - 'stats': {}, - } + state = PipeState( + entry_name=entry, + headers={'Content-Type': None}, + accept=accept, + url=request.current_route_url(), + select=q, + match=match.lower() if match else match, + path=new_path, + stats={}, + ) r = p.process(request.registry.md, state=state, raise_exceptions=True, scheduler=request.registry.scheduler) log.debug(f'Plumbing process result: {r}') @@ -280,18 +305,16 @@ def _d(x: Optional[str], do_split: bool = True) -> Tuple[Optional[str], Optional r = [] response = Response() - _headers = state.get('headers', {}) - response.headers.update(_headers) - ctype = _headers.get('Content-Type', None) + response.headers.update(state.headers) + ctype = state.headers.get('Content-Type', None) if not ctype: - r, t = _fmt(r, accepter) + r, t = _fmt(r, accept) ctype = t response.text = b2u(r) response.size = len(r) response.content_type = ctype - cache_ttl = int(state.get('cache', 0)) - response.expires = datetime.now() + timedelta(seconds=cache_ttl) + response.expires = datetime.now() + timedelta(seconds=state.cache) return response except ResourceException as ex: import traceback diff --git a/src/pyff/builtins.py b/src/pyff/builtins.py index 3828839c..d2616dd3 100644 --- a/src/pyff/builtins.py +++ b/src/pyff/builtins.py @@ -25,7 +25,7 @@ from pyff.decorators import deprecated from pyff.exceptions import MetadataException from pyff.logs import get_log -from pyff.pipes import PipeException, PipelineCallback, Plumbing, pipe, registry +from pyff.pipes import PipeException, PipeState, PipelineCallback, Plumbing, pipe, registry from pyff.samlmd import ( annotate_entity, discojson_t, @@ -383,14 +383,13 @@ def when(req: Plumbing.Request, condition: str, *values): The condition operates on the state: if 'foo' is present in the state (with any value), then the something branch is followed. If 'bar' is present in the state with the value 'bill' then the other branch is followed. """ - c = req.state.get(condition, None) - if c is None: + if req.state.entry_name is None: log.debug(f'Condition {repr(condition)} not present in state {req.state}') - if c is not None and (not values or _any(values, c)): + if req.state.entry_name is not None and (not values or _any(values, req.state.entry_name)): if not isinstance(req.args, list): raise ValueError('Non-list arguments to "when" not allowed') - return Plumbing(pipeline=req.args, pid="%s.when" % req.plumbing.id).iprocess(req) + return Plumbing(pipeline=req.args, pid=f'{req.plumbing.id}.when').iprocess(req) return req.t @@ -768,9 +767,9 @@ def select(req: Plumbing.Request, *opts): entities = resolve_entities(args, lookup_fn=req.md.store.select) - if req.state.get('match', None): # TODO - allow this to be passed in via normal arguments + if req.state.match: # TODO - allow this to be passed in via normal arguments - match = req.state['match'] + match = req.state.match if isinstance(match, six.string_types): query = [match.lower()] @@ -1435,11 +1434,11 @@ def emit(req: Plumbing.Request, ctype="application/xml", *opts): if not isinstance(d, six.binary_type): d = d.encode("utf-8") m.update(d) - req.state['headers']['ETag'] = m.hexdigest() + req.state.headers['ETag'] = m.hexdigest() else: raise PipeException("Empty") - req.state['headers']['Content-Type'] = ctype + req.state.headers['Content-Type'] = ctype if six.PY2: d = six.u(d) return d @@ -1517,7 +1516,7 @@ def finalize(req: Plumbing.Request, *opts): if name is None or 0 == len(name): name = req.args.get('Name', None) if name is None or 0 == len(name): - name = req.state.get('url', None) + name = req.state.url if name and 'baseURL' in req.args: try: @@ -1569,7 +1568,7 @@ def finalize(req: Plumbing.Request, *opts): # TODO: offset can be None here, if validUntil is not a valid duration or ISO date # What is the right action to take then? if offset: - req.state['cache'] = int(total_seconds(offset) / 50) + req.state.cache = int(total_seconds(offset) / 50) cache_duration = req.args.get('cacheDuration', e.get('cacheDuration', None)) if cache_duration is not None and len(cache_duration) > 0: @@ -1578,7 +1577,7 @@ def finalize(req: Plumbing.Request, *opts): raise PipeException("Unable to parse %s as xs:duration" % cache_duration) e.set('cacheDuration', cache_duration) - req.state['cache'] = int(total_seconds(offset)) + req.state.cache = int(total_seconds(offset)) return req.t diff --git a/src/pyff/pipes.py b/src/pyff/pipes.py index 29abd427..15bdbe35 100644 --- a/src/pyff/pipes.py +++ b/src/pyff/pipes.py @@ -7,17 +7,25 @@ import functools import os import traceback -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Type +from typing import Union import yaml from apscheduler.schedulers.background import BackgroundScheduler from lxml.etree import Element, ElementTree +from pydantic import BaseModel, Field from pyff.logs import get_log from pyff.repo import MDRepository from pyff.store import SAMLStoreBase from pyff.utils import PyffException, is_text, resource_string +if TYPE_CHECKING: + from pyff.api import MediaAccept + + # Avoid static analysers flagging this import as unused + assert MediaAccept + log = get_log(__name__) __author__ = 'leifj' @@ -77,7 +85,7 @@ class PluginsRegistry(dict): def the_something_func(req,*opts): pass - Referencing this function as an entry_point using something = module:the_somethig_func in setup.py allows the + Referencing this function as an entry_point using something = module:the_something_func in setup.py allows the function to be referenced as 'something' in a pipeline. """ @@ -160,15 +168,17 @@ def __deepcopy__(self, memo: Any) -> PipelineCallback: # TODO: This seems... dangerous. What's the need for this? return self - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, t: ElementTree, state: Optional[PipeState] = None) -> Any: log.debug("{!s}: called".format(self.plumbing)) - t = args[0] + if state is None: + state = PipeState() if t is None: raise ValueError("PipelineCallback must be called with a parse-tree argument") + if not isinstance(state, PipeState): + raise ValueError(f'PipelineCallback called with invalid state ({type(state)}') try: - state = kwargs - state[self.entry_point] = True - log.debug("state: {}".format(repr(state))) + state.entry_name = self.entry_point + log.debug("state: {}".format(state)) return self.plumbing.process(self.req.md, store=self.store, state=state, t=t) except Exception as ex: log.debug(traceback.format_exc()) @@ -176,6 +186,19 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: raise ex +class PipeState(BaseModel): + batch: bool = False + entry_name: Optional[str] = None + headers: Dict[str, Any] = Field({}) + accept: Any = None # TODO: Re-arrange classes so that type 'MediaAccept' works + url: str = '' + select: str = '' + match: str = '' + path: str = '' + stats: Dict[str, Any] = Field({}) + cache: int = 0 # cache_ttl + + class Plumbing(object): """ A plumbing instance represents a basic processing chain for SAML metadata. A simple, yet reasonably complete example: @@ -201,7 +224,7 @@ class Plumbing(object): Running this plumbing would bake all metadata found in /var/metadata/registry and at http://md.example.com into an EntitiesDescriptor element with @Name http://example.com/metadata.xml, @cacheDuration set to 1hr and @validUntil - 1 day from the time the 'finalize' command was run. The tree woud be transformed using the "tidy" stylesheets and + 1 day from the time the 'finalize' command was run. The tree would be transformed using the "tidy" stylesheets and would then be signed (using signer.key) and finally published in /var/metadata/public/metadata.xml """ @@ -237,27 +260,25 @@ def __init__( self, pl: Plumbing, md: MDRepository, - t=None, - name=None, - args=None, - state: Optional[Dict[str, Any]] = None, - store=None, + state: Optional[PipeState] = None, + t: Optional[ElementTree] = None, + name: Optional[str] = None, + args: Optional[Union[str, Dict, List]] = None, + store: Optional[SAMLStoreBase] = None, scheduler: Optional[BackgroundScheduler] = None, raise_exceptions: bool = True, ): - if not state: - state = dict() if not args: args = [] self.plumbing: Plumbing = pl self.md: MDRepository = md self.t: ElementTree = t self._id: Optional[str] = None - self.name = name + self.name: Optional[str] = name self.args: Optional[Union[str, Dict, List]] = args - self.state: Dict[str, Any] = state + self.state: PipeState = state if state else PipeState() self.done: bool = False - self._store: SAMLStoreBase = store + self._store: Optional[SAMLStoreBase] = store self.scheduler: Optional[BackgroundScheduler] = scheduler self.raise_exceptions: bool = raise_exceptions self.exception: Optional[BaseException] = None @@ -337,8 +358,8 @@ def iprocess(self, req: Plumbing.Request) -> ElementTree: def process( self, md: MDRepository, + state: PipeState, args: Any = None, - state: Optional[Dict[str, Any]] = None, t: Optional[ElementTree] = None, store: Optional[SAMLStoreBase] = None, raise_exceptions: bool = True, @@ -357,9 +378,6 @@ def process( :param args: Pipeline arguments :return: The result of applying the processing pipeline to t. """ - if not state: - state = dict() - return Plumbing.Request( self, md, t=t, args=args, state=state, store=store, raise_exceptions=raise_exceptions, scheduler=scheduler ).process(self) diff --git a/src/pyff/resource.py b/src/pyff/resource.py index 196c9683..3acc4396 100644 --- a/src/pyff/resource.py +++ b/src/pyff/resource.py @@ -468,7 +468,13 @@ def parse(self, getter: Callable[[str], Response]) -> Deque[Resource]: if self.post: for cb in self.post: if self.t is not None: - self.t = cb(self.t, self.opts.dict()) + # TODO: This used to be + # self.t = cb(self.t, self.opts.dict()) + # but passing self.opts does not seem to be what the callback expected. + # Don't know what to do really. + from pyff.pipes import PipeState + + self.t = cb(self.t, PipeState()) if self.is_expired(): info.expired = True diff --git a/src/pyff/test/test_pipeline.py b/src/pyff/test/test_pipeline.py index 0145e013..5520bc59 100644 --- a/src/pyff/test/test_pipeline.py +++ b/src/pyff/test/test_pipeline.py @@ -11,7 +11,7 @@ from pyff import builtins from pyff.exceptions import MetadataException from pyff.parse import ParserException -from pyff.pipes import PipeException, Plumbing, plumbing +from pyff.pipes import PipeException, PipeState, Plumbing, plumbing from pyff.repo import MDRepository from pyff.resource import ResourceException from pyff.test import ExitException, SignerTestCase @@ -61,7 +61,7 @@ def run_pipeline(self, pl_name, ctx=None, md=None): template = templates.get_template(pl_name) with open(pipeline, "w") as fd: fd.write(template.render(ctx=ctx)) - res = plumbing(pipeline).process(md, state={'batch': True, 'stats': {}}) + res = plumbing(pipeline).process(md, PipeState(entry_name='batch')) os.unlink(pipeline) return res, md, ctx @@ -70,7 +70,7 @@ def exec_pipeline(self, pstr): p = yaml.safe_load(six.StringIO(pstr)) print("\n{}".format(yaml.dump(p))) pl = Plumbing(p, pid="test") - res = pl.process(md, state={'batch': True, 'stats': {}}) + res = pl.process(md, PipeState(entry_name='batch')) return res, md @classmethod diff --git a/src/pyff/test/test_simple_pipeline.py b/src/pyff/test/test_simple_pipeline.py index bb4362c3..78bbfeee 100644 --- a/src/pyff/test/test_simple_pipeline.py +++ b/src/pyff/test/test_simple_pipeline.py @@ -4,7 +4,7 @@ from mako.lookup import TemplateLookup from pyff.constants import NS -from pyff.pipes import plumbing +from pyff.pipes import PipeState, plumbing from pyff.repo import MDRepository from pyff.test import SignerTestCase @@ -24,8 +24,8 @@ def setUp(self): fd.write(self.signer_template.render(ctx=self)) with open(self.validator, "w") as fd: fd.write(self.validator_template.render(ctx=self)) - self.signer_result = plumbing(self.signer).process(self.md_signer, state={'batch': True, 'stats': {}}) - self.validator_result = plumbing(self.validator).process(self.md_validator, state={'batch': True, 'stats': {}}) + self.signer_result = plumbing(self.signer).process(self.md_signer, state=PipeState(batch=True)) + self.validator_result = plumbing(self.validator).process(self.md_validator, state=PipeState(batch=True)) def test_entityid_present(self): eids = [e.get('entityID') for e in self.md_signer.store]