diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py index 4c0ab1511..69d54e9cd 100644 --- a/src/saml2/assertion.py +++ b/src/saml2/assertion.py @@ -488,6 +488,38 @@ def not_on_or_after(self, sp_entity_id): return in_a_while(**self.get_lifetime(sp_entity_id)) + @staticmethod + def _subject_id_or_pairwise_id( + ava, + required, + ): + """ + If both subject-id and pairwise-id are present in required attributes, check if both are present in + available attributes and if so default to pairwise-id only. + """ + if required is None: + return ava, None + + # check if both subject-id and pairwise-id are in required attributes + subject_id = None + pairwise_id = None + for item in required: + if item["name"] == "urn:oasis:names:tc:SAML:attribute:subject-id": + subject_id = item + if item["name"] == "urn:oasis:names:tc:SAML:attribute:pairwise-id": + pairwise_id = item + + # if both are in required attributes, check if both are in available attributes and if so remove subject-id + # from required attributes and ava + if subject_id and pairwise_id: + if all( + [friendly_name in ava for friendly_name in [subject_id["friendly_name"], pairwise_id["friendly_name"]]] + ): + required.pop(required.index(subject_id)) + ava.pop(subject_id["friendly_name"]) + + return ava, required + def filter(self, ava, sp_entity_id, mdstore=None, required=None, optional=None): """What attribute and attribute values returns depends on what the SP or the registration authority has said it wants in the request @@ -518,6 +550,10 @@ def filter(self, ava, sp_entity_id, mdstore=None, required=None, optional=None): subject_ava = ava.copy() + # make sure we don't assert both subject-id and pairwise-id if subject-id requirement is "any" + if self.metadata_store and self.metadata_store.subject_id_requirement_type(sp_entity_id) == "any": + subject_ava, required = self._subject_id_or_pairwise_id(subject_ava, required) + # entity category restrictions _ent_rest = self.get_entity_categories(sp_entity_id, mds=mdstore, required=required) if _ent_rest: diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py index 2ea9742f0..3b780298f 100644 --- a/src/saml2/mdstore.py +++ b/src/saml2/mdstore.py @@ -1304,14 +1304,17 @@ def attribute_requirement(self, entity_id, index=None): if entity_id in md_source: return md_source.attribute_requirement(entity_id, index) - def subject_id_requirement(self, entity_id): + def subject_id_requirement_type(self, entity_id): try: entity_attributes = self.entity_attributes(entity_id) except KeyError: - return [] + return "" subject_id_reqs = entity_attributes.get("urn:oasis:names:tc:SAML:profiles:subject-id:req") or [] - subject_id_req = next(iter(subject_id_reqs), None) + return next(iter(subject_id_reqs), None) + + def subject_id_requirement(self, entity_id): + subject_id_req = self.subject_id_requirement_type(entity_id) if subject_id_req == "any": return [ { diff --git a/tests/test_37_entity_categories.py b/tests/test_37_entity_categories.py index 894b03cf3..49c959b52 100644 --- a/tests/test_37_entity_categories.py +++ b/tests/test_37_entity_categories.py @@ -400,3 +400,39 @@ def test_filter_ava_refeds_personalized_access(): assert _eq(ava["eduPersonScopedAffiliation"], ["student@example.com"]) assert _eq(ava["eduPersonAssurance"], ["http://www.swamid.se/policy/assurance/al1"]) assert _eq(ava["schacHomeOrganization"], ["example.com"]) + + +def test_filter_subject_id_or_pairwise_id(): + entity_id = "https://esi-coco.example.edu/saml2/metadata/" + mds = MetadataStore(ATTRCONV, sec_config, disable_ssl_certificate_validation=True) + mds.imp([{"class": "saml2.mdstore.MetaDataFile", "metadata": [(full_path("entity_esi_and_coco_sp.xml"),)]}]) + + policy_conf = {"default": {"lifetime": {"minutes": 15}, "entity_categories": ["swamid"]}} + + policy = Policy(policy_conf, mds) + + ava = { + "subject-id": ["subject-id"], + "pairwise-id": ["pairwise-id"], + } + + required_attributes = [ + { + "__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute", + "name": "urn:oasis:names:tc:SAML:attribute:pairwise-id", + "name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + "friendly_name": "pairwise-id", + "is_required": "true", + }, + { + "__class__": "urn:oasis:names:tc:SAML:2.0:metadata&RequestedAttribute", + "name": "urn:oasis:names:tc:SAML:attribute:subject-id", + "name_format": "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + "friendly_name": "subject-id", + "is_required": "true", + }, + ] + + ava = policy.filter(ava, entity_id, required=required_attributes) + + assert _eq(list(ava.keys()), ["pairwise-id"])