Skip to content

Commit

Permalink
Base security predicates on users' memberships
Browse files Browse the repository at this point in the history
Problem: four group-related security predicate functions trigger DB
queries that unnecessarily fetch *all* of a group's memberships and
members from the DB. For example:

    @requires(authenticated_user, group_found)
    def group_has_user_as_owner(identity, context):
        return any(
            owner.id == identity.user.id
            for owner in context.group.get_members(GroupMembershipRoles.OWNER)
        )

The call to `context.group.get_members()` (`models.Group.get_members()`)
iterates over the `models.Group.memberships` relationship, which loads
all of the group's memberships from the DB. Because of eager-loading on
the memberships relationship, this also loads all of the group's
_members_ (i.e. the `models.User` objects) from the DB, though it does
at least avoid doing a separate query for each member.

The ideal solution would be to do a smaller DB query to check whether a
given user is an owner of a given group without loading all the group's
members or memberships, something like this:

    select(GroupMembership)
    .where(GroupMembership.group == group)
    .where(GroupMembership.user == user)
    .where(GroupMembership.roles.contains([GroupMembershipRoles.OWNER]))

But this can't easily be done because neither `Group.get_members()` nor
the security predicate functions that call it has access to the DB
session. I've tried a few different ways of refactoring the code to make
it possible to do a query like the above but ran into difficulties, see:
#9064 (comment)

So this commit implements an alternative approach: instead of iterating
over the group's memberships, iterate over the *user*'s memberships,
looking for one with a matching group and role. This should be much more
efficient because users typically have far fewer memberships than groups
do, see:
https://hypothes-is.slack.com/archives/C4K6M7P5E/p1731507180110099?thread_ts=1729159093.383279&cid=C4K6M7P5E

The `LongLivedUser.groups` field is removed and replaced with
`LongLivedUser.memberships` because predicates need to make decisions
based on a user's roles and groups don't have roles, memberships do.

`LongLivedUser.from_model(user)` has to be changed to create a
`LongLivedUser` with a list of `LongLivedMembership` objects based on
`user.memberships`. This is a little more complicated, so I also
extended the unit test for this method.

The `group_has_user_as_owner()`, `group_has_user_as_admin()`,
`group_has_user_as_moderator()` and `group_has_user_as_member()`
security predicates are then rewritten based on
`identity.user.memberships` (the new `LongLivedUser.memberships`
attribute) instead of calling `Group.get_members()`.

The `Group.get_members()` method is removed because it's no longer used.
  • Loading branch information
seanh committed Nov 21, 2024
1 parent 04bdf88 commit 7f53416
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 185 deletions.
15 changes: 1 addition & 14 deletions h/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,20 +198,7 @@ def members(self) -> tuple[User, ...]:
be registered with SQLAlchemy and the changes wouldn't be saved to the
DB. So this is a read-only property that returns an immutable tuple.
"""
return self.get_members()

def get_members(self, role: GroupMembershipRoles | None = None) -> tuple[User, ...]:
"""Return a tuple of this group's members."""
if role:
memberships = [
membership
for membership in self.memberships
if role in membership.roles
]
else:
memberships = self.memberships

return tuple(membership.user for membership in memberships)
return tuple(membership.user for membership in self.memberships)

scopes = sa.orm.relationship(
"GroupScope", backref="group", cascade="all, delete-orphan"
Expand Down
32 changes: 28 additions & 4 deletions h/security/identity.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
"""Data classes used to represent authenticated users."""

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List, Optional, Self

from h.models import AuthClient, Group, User


@dataclass
class LongLivedMembership:
"""A membership object that isn't connected to SQLAlchemy."""

group: "LongLivedGroup"
user: "LongLivedUser"
roles: List[str]


@dataclass
class LongLivedGroup:
"""
Expand Down Expand Up @@ -35,23 +44,38 @@ class LongLivedUser:
id: int
userid: str
authority: str
groups: List[LongLivedGroup]
staff: bool
admin: bool
memberships: List[LongLivedMembership] = field(default_factory=list)

@classmethod
def from_model(cls, user: User):
"""Create a long lived model from a DB model object."""

return LongLivedUser(
long_lived_user = LongLivedUser(
id=user.id,
userid=user.userid,
authority=user.authority,
admin=user.admin,
staff=user.staff,
groups=[LongLivedGroup.from_model(group) for group in user.groups],
)

groups = {}

for membership in user.memberships:
groups.setdefault(
membership.group.id, LongLivedGroup.from_model(membership.group)
)
long_lived_user.memberships.append(
LongLivedMembership(
group=groups[membership.group.id],
user=long_lived_user,
roles=membership.roles,
)
)

return long_lived_user


@dataclass
class LongLivedAuthClient:
Expand Down
23 changes: 12 additions & 11 deletions h/security/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,31 +144,32 @@ def group_created_by_user(identity, context):

@requires(authenticated_user, group_found)
def group_has_user_as_owner(identity, context):
return any(
owner.id == identity.user.id
for owner in context.group.get_members(GroupMembershipRoles.OWNER)
)
return _group_has_user_as_role(identity, context, GroupMembershipRoles.OWNER)


@requires(authenticated_user, group_found)
def group_has_user_as_admin(identity, context):
return any(
admin.id == identity.user.id
for admin in context.group.get_members(GroupMembershipRoles.ADMIN)
)
return _group_has_user_as_role(identity, context, GroupMembershipRoles.ADMIN)


@requires(authenticated_user, group_found)
def group_has_user_as_moderator(identity, context):
return _group_has_user_as_role(identity, context, GroupMembershipRoles.MODERATOR)


def _group_has_user_as_role(identity, context, role):
return any(
moderator.id == identity.user.id
for moderator in context.group.get_members(GroupMembershipRoles.MODERATOR)
membership.group.id == context.group.id and role in membership.roles
for membership in identity.user.memberships
)


@requires(authenticated_user, group_found)
def group_has_user_as_member(identity, context):
return any(member.id == identity.user.id for member in context.group.members)
return any(
membership.group.id == context.group.id
for membership in identity.user.memberships
)


@requires(authenticated_user, group_found)
Expand Down
27 changes: 0 additions & 27 deletions tests/unit/h/models/group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,33 +251,6 @@ def test_members_is_immutable(factories):
group.members.append(new_member)


def test_get_members(factories):
group = factories.Group()
owners = factories.User.create_batch(2)
admins = factories.User.create_batch(2)
moderators = factories.User.create_batch(2)
members = factories.User.create_batch(2)
group.memberships.extend(
models.GroupMembership(user=owner, roles=["owner"]) for owner in owners
)
group.memberships.extend(
models.GroupMembership(user=admin, roles=["admin"]) for admin in admins
)
group.memberships.extend(
models.GroupMembership(user=moderator, roles=["moderator"])
for moderator in moderators
)
group.memberships.extend(
models.GroupMembership(user=member, roles=["member"]) for member in members
)

assert group.get_members(role="owner") == (*owners,)
assert group.get_members(role="admin") == (*admins,)
assert group.get_members(role="moderator") == (*moderators,)
assert group.get_members(role="member") == (*members,)
assert group.get_members() == (*owners, *admins, *moderators, *members)


class TestGroupMembership:
def test_defaults(self, db_session, user, group):
membership = models.GroupMembership(user_id=user.id, group_id=group.id)
Expand Down
55 changes: 39 additions & 16 deletions tests/unit/h/security/identity_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from unittest.mock import sentinel
from unittest.mock import call, sentinel

import pytest
from h_matchers import Any

from h.models import GroupMembership
from h.models import GroupMembership, GroupMembershipRoles
from h.security.identity import (
Identity,
LongLivedAuthClient,
LongLivedGroup,
LongLivedMembership,
LongLivedUser,
)

Expand All @@ -24,22 +25,45 @@ def test_from_models(self, factories):


class TestLongLivedUser:
def test_from_models(self, factories, LongLivedGroup):
group = factories.Group.build()
user = factories.User.build(memberships=[GroupMembership(group=group)])
def test_from_model(self, db_session, factories, LongLivedGroup):
groups = factories.Group.create_batch(size=2)
user = factories.User(
memberships=[
GroupMembership(group=groups[0], roles=[GroupMembershipRoles.MEMBER]),
GroupMembership(group=groups[1], roles=[GroupMembershipRoles.ADMIN]),
]
)
LongLivedGroup.from_model.side_effect = [
sentinel.long_lived_group_1,
sentinel.long_lived_group_2,
]
db_session.flush()

model = LongLivedUser.from_model(user)

LongLivedGroup.from_model.assert_called_once_with(group)
assert model == Any.instance_of(LongLivedUser).with_attrs(
{
"id": user.id,
"userid": user.userid,
"authority": user.authority,
"admin": user.admin,
"staff": user.staff,
"groups": [LongLivedGroup.from_model.return_value],
}
assert LongLivedGroup.from_model.call_args_list == [
call(groups[0]),
call(groups[1]),
]

assert model == LongLivedUser(
id=user.id,
userid=user.userid,
authority=user.authority,
admin=user.admin,
staff=user.staff,
memberships=[
LongLivedMembership(
group=sentinel.long_lived_group_1,
user=model,
roles=[GroupMembershipRoles.MEMBER],
),
LongLivedMembership(
group=sentinel.long_lived_group_2,
user=model,
roles=[GroupMembershipRoles.ADMIN],
),
],
)

@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -94,7 +118,6 @@ def test_from_models_with_None(self, LongLivedUser, LongLivedAuthClient):
id=sentinel.id,
userid=sentinel.userid,
authority=sentinel.authority,
groups=[],
staff=False,
admin=False,
)
Expand Down
Loading

0 comments on commit 7f53416

Please sign in to comment.