Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
shiftinv committed Jul 31, 2022
2 parents 3d1a96d + 6769f80 commit 09fcf9b
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 58 deletions.
1 change: 1 addition & 0 deletions changelog/660.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove the internal ``fill_with_flags`` decorator for flags classes and use the built in :meth:`object.__init_subclass__` method.
1 change: 1 addition & 0 deletions changelog/671.doc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update fields listed in :func:`on_user_update` and :func:`on_member_update` docs.
3 changes: 1 addition & 2 deletions disnake/ext/commands/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@

from typing import TYPE_CHECKING

from disnake.flags import BaseFlags, alias_flag_value, all_flags_value, fill_with_flags, flag_value
from disnake.flags import BaseFlags, alias_flag_value, all_flags_value, flag_value

if TYPE_CHECKING:
from typing_extensions import Self

__all__ = ("ApplicationCommandSyncFlags",)


@fill_with_flags()
class ApplicationCommandSyncFlags(BaseFlags):
"""Controls the library's application command syncing policy.
Expand Down
54 changes: 25 additions & 29 deletions disnake/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,25 +117,6 @@ def all_flags_value(flags: Dict[str, int]) -> int:
return functools.reduce(operator.or_, flags.values())


def fill_with_flags(*, inverted: bool = False):
def decorator(cls: Type[BF]) -> Type[BF]:
cls.VALID_FLAGS = {}
for name, value in cls.__dict__.items():
if isinstance(value, flag_value):
value._parent = cls
cls.VALID_FLAGS[name] = value.flag

if inverted:
cls.DEFAULT_VALUE = all_flags_value(cls.VALID_FLAGS)
else:
cls.DEFAULT_VALUE = 0

return cls

return decorator


# n.b. flags must inherit from this and use the decorator above
class BaseFlags:
VALID_FLAGS: ClassVar[Dict[str, int]]
DEFAULT_VALUE: ClassVar[int]
Expand All @@ -151,6 +132,29 @@ def __init__(self, **kwargs: bool):
raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value)

@classmethod
def __init_subclass__(cls, inverted: bool = False, no_fill_flags: bool = False):
# add a way to bypass filling flags, eg for ListBaseFlags.
if no_fill_flags:
return cls

# use the parent's current flags as a base if they exist
cls.VALID_FLAGS = getattr(cls, "VALID_FLAGS", {}).copy()

for name, value in cls.__dict__.items():
if isinstance(value, flag_value):
value._parent = cls
cls.VALID_FLAGS[name] = value.flag

if not cls.VALID_FLAGS:
raise RuntimeError(
"At least one flag must be defined in a BaseFlags subclass, or 'no_fill_flags' must be set to True"
)

cls.DEFAULT_VALUE = all_flags_value(cls.VALID_FLAGS) if inverted else 0

return cls

@classmethod
def _from_value(cls, value: int) -> Self:
self = cls.__new__(cls)
Expand Down Expand Up @@ -295,7 +299,7 @@ def _set_flag(self, o: int, toggle: bool) -> None:
raise TypeError(f"Value to set for {self.__class__.__name__} must be a bool.")


class ListBaseFlags(BaseFlags):
class ListBaseFlags(BaseFlags, no_fill_flags=True):
"""
A base class for flags that aren't powers of 2.
Instead, values are used as exponents to map to powers of 2 to avoid collisions,
Expand Down Expand Up @@ -330,8 +334,7 @@ def __repr__(self) -> str:
return f"<{self.__class__.__name__} values={self.values}>"


@fill_with_flags(inverted=True)
class SystemChannelFlags(BaseFlags):
class SystemChannelFlags(BaseFlags, inverted=True):
"""
Wraps up a Discord system channel flag value.
Expand Down Expand Up @@ -466,7 +469,6 @@ def join_notification_replies(self):
return 8


@fill_with_flags()
class MessageFlags(BaseFlags):
"""
Wraps up a Discord Message flag value.
Expand Down Expand Up @@ -621,7 +623,6 @@ def failed_to_mention_roles_in_thread(self):
return 1 << 8


@fill_with_flags()
class PublicUserFlags(BaseFlags):
"""
Wraps up the Discord User Public flags.
Expand Down Expand Up @@ -814,7 +815,6 @@ def all(self) -> List[UserFlags]:
return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)]


@fill_with_flags()
class Intents(BaseFlags):
"""
Wraps up a Discord gateway intent flag.
Expand Down Expand Up @@ -1449,7 +1449,6 @@ def automod(self):
return (1 << 20) | (1 << 21)


@fill_with_flags()
class MemberCacheFlags(BaseFlags):
"""Controls the library's cache policy when it comes to members.
Expand Down Expand Up @@ -1632,7 +1631,6 @@ def _voice_only(self):
return self.value == 1


@fill_with_flags()
class ApplicationFlags(BaseFlags):
"""
Wraps up the Discord Application flags.
Expand Down Expand Up @@ -1777,7 +1775,6 @@ def gateway_message_content_limited(self):
return 1 << 19


@fill_with_flags()
class ChannelFlags(BaseFlags):
"""Wraps up the Discord Channel flags.
Expand Down Expand Up @@ -1872,7 +1869,6 @@ def pinned(self):
return 1 << 1


@fill_with_flags()
class AutoModKeywordPresets(ListBaseFlags):
"""
Wraps up the pre-defined auto moderation keyword lists, provided by Discord.
Expand Down
3 changes: 1 addition & 2 deletions disnake/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Iterator, Optional, Set, Tuple

from .flags import BaseFlags, alias_flag_value, fill_with_flags, flag_value
from .flags import BaseFlags, alias_flag_value, flag_value

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -68,7 +68,6 @@ def wrapped(cls):
return wrapped


@fill_with_flags()
class Permissions(BaseFlags):
"""Wraps up the Discord permission value.
Expand Down
48 changes: 25 additions & 23 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -362,23 +362,6 @@ This section documents events related to :class:`Client` and its connectivity to
WebSocket library. It can be :class:`bytes` to denote a binary
message or :class:`str` to denote a regular text message.

.. function:: on_user_update(before, after)

Called when a :class:`User` updates their profile.

This is called when one or more of the following things change:

- avatar
- username
- discriminator

This requires :attr:`Intents.members` to be enabled.

:param before: The updated user's old info.
:type before: :class:`User`
:param after: The updated user's updated info.
:type after: :class:`User`

Channels/Threads
~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -851,21 +834,22 @@ Members

.. function:: on_member_update(before, after)

Called when a :class:`Member` updates their profile.
Called when a :class:`Member` is updated.

This is called when one or more of the following things change, but is not limited to:

- avatar (guild-specific)
- current_timeout
- nickname
- roles
- pending
- timeout
- guild specific avatar
- premium_since
- roles

This requires :attr:`Intents.members` to be enabled.

:param before: The updated member's old info.
:param before: The member's old info.
:type before: :class:`Member`
:param after: The updated member's updated info.
:param after: The member's updated info.
:type after: :class:`Member`

.. function:: on_member_ban(guild, user)
Expand Down Expand Up @@ -910,6 +894,24 @@ Members
:param after: The updated member's updated info.
:type after: :class:`Member`

.. function:: on_user_update(before, after)

Called when a :class:`User` is updated.

This is called when one or more of the following things change, but is not limited to:

- avatar
- discriminator
- name
- public_flags

This requires :attr:`Intents.members` to be enabled.

:param before: The user's old info.
:type before: :class:`User`
:param after: The user's updated info.
:type after: :class:`User`


Scheduled Events
++++++++++++++++
Expand Down
3 changes: 1 addition & 2 deletions tests/test_flags.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest

from disnake.flags import ListBaseFlags, fill_with_flags, flag_value
from disnake.flags import ListBaseFlags, flag_value


@fill_with_flags()
class _ListFlags(ListBaseFlags):
@flag_value
def flag1(self):
Expand Down

0 comments on commit 09fcf9b

Please sign in to comment.