diff --git a/.mina/core.toml b/.mina/core.toml index 64c9362f..a0283af2 100644 --- a/.mina/core.toml +++ b/.mina/core.toml @@ -19,7 +19,8 @@ dependencies = [ "graia-amnesia", "loguru", "launart", - "creart" + "creart", + "elaina-flywheel" ] description = "" license = {text = "MIT"} diff --git a/.vscode/settings.json b/.vscode/settings.json index d1094f7e..fa5df180 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -25,4 +25,5 @@ "python.analysis.extraPaths": [ "__pypackages__/3.10/lib" ], + "python.languageServer": "None", } diff --git a/avilla/console/backend.py b/avilla/console/backend.py index 8812928f..909d94b0 100644 --- a/avilla/console/backend.py +++ b/avilla/console/backend.py @@ -1,9 +1,11 @@ from __future__ import annotations +from functools import cached_property import sys from contextlib import suppress from typing import TYPE_CHECKING +from flywheel import InstanceContext from loguru import logger from nonechat import Backend, Frontend from nonechat.info import Event @@ -71,9 +73,16 @@ def on_console_unmount(self): logger.success("Console exit.") logger.warning("Press Ctrl-C for Application exit") + @cached_property + def event_instance_ctx(self): + res = InstanceContext() + res.instances[type(self.account)] = self.account + res.instances[type(self._service.protocol)] = self._service.protocol + return res + async def post_event(self, event: Event): - with suppress(NotImplementedError): - res = await ConsoleCapability(self.account.staff).event_callback(event) + with self.event_instance_ctx.scope(), suppress(NotImplementedError): + res = await ConsoleCapability.event_callback(event) self._service.protocol.post_event(res) return diff --git a/avilla/console/bases.py b/avilla/console/bases.py new file mode 100644 index 00000000..cfec1123 --- /dev/null +++ b/avilla/console/bases.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from flywheel import InstanceOf +from .protocol import ConsoleProtocol +from .account import ConsoleAccount + +class InstanceOfProtocol: + protocol = InstanceOf(ConsoleProtocol) + +class InstanceOfAccount(InstanceOfProtocol): + account = InstanceOf(ConsoleAccount) diff --git a/avilla/console/capability.py b/avilla/console/capability.py index 76dd7a50..e2407a6c 100644 --- a/avilla/console/capability.py +++ b/avilla/console/capability.py @@ -1,29 +1,62 @@ from __future__ import annotations -from typing import Any, TypeVar +from typing import Protocol, TypeVar from graia.amnesia.message import Element as GraiaElement from nonechat.info import Event as ConsoleEvent from nonechat.message import Element as ConsoleElement from avilla.core.event import AvillaEvent -from avilla.core.ryanvk.collector.application import ApplicationCollector -from graia.ryanvk import Fn, TypeOverload +from flywheel import TypeOverload, Fn, FnCompose, FnRecord, OverloadRecorder -CE = TypeVar("CE", bound=ConsoleElement) -GE = TypeVar("GE", bound=GraiaElement) -CV = TypeVar("CV", bound=ConsoleEvent) +CE = TypeVar("CE", bound=ConsoleElement, contravariant=True) +GE = TypeVar("GE", bound=GraiaElement, contravariant=True) +CV = TypeVar("CV", bound=ConsoleEvent, contravariant=True) -class ConsoleCapability((m := ApplicationCollector())._): - @Fn.complex({TypeOverload(): ["event"]}) - async def event_callback(self, event: Any) -> AvillaEvent: - ... +# NOTE: 全使用 global_collect 或是 scoped_collect.globals() 最好。 - @Fn.complex({TypeOverload(): ["element"]}) - async def deserialize_element(self, element: Any) -> GraiaElement: - ... +class ConsoleCapability: + @Fn.declare + class event_callback(FnCompose): + type = TypeOverload("type") - @Fn.complex({TypeOverload(): ["element"]}) - async def serialize_element(self, element: Any) -> ConsoleElement: - ... + async def call(self, record: FnRecord, event: ConsoleEvent): + from loguru import logger + logger.info(event) + entities = self.load(self.type.dig(record, event)) + return await entities.first(event=event) + + class shapecall(Protocol[CV]): + async def __call__(self, event: CV) -> AvillaEvent: ... + + def collect(self, recorder: OverloadRecorder[shapecall[CV]], event: type[CV]): + recorder.use(self.type, event) + + @Fn.declare + class deserialize_element(FnCompose): + type = TypeOverload("type") + + async def call(self, record: FnRecord, element: ConsoleElement): + entities = self.load(self.type.dig(record, element)) + return await entities.first(element=element) + + class shapecall(Protocol[CE]): + async def __call__(self, element: CE) -> GraiaElement: ... + + def collect(self, recorder: OverloadRecorder[shapecall[CE]], element: type[CE]): + recorder.use(self.type, element) + + @Fn.declare + class serialize_element(FnCompose): + type = TypeOverload("type") + + async def call(self, record: FnRecord, element: GraiaElement): + entities = self.load(self.type.dig(record, element)) + return await entities.first(element=element) + + class shapecall(Protocol[GE]): + async def __call__(self, element: GE) -> ConsoleElement: ... + + def collect(self, recorder: OverloadRecorder[shapecall[GE]], element: type[GE]): + recorder.use(self.type, element) diff --git a/avilla/console/perform/action/activity.py b/avilla/console/perform/action/activity.py index 2a9a65a7..d3c80be2 100644 --- a/avilla/console/perform/action/activity.py +++ b/avilla/console/perform/action/activity.py @@ -2,18 +2,17 @@ from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector from avilla.core.selector import Selector -from avilla.standard.core.activity import ActivityTrigger +from avilla.standard.core.activity import start_activity +from flywheel import scoped_collect +from avilla.console.bases import InstanceOfAccount if TYPE_CHECKING: from avilla.console.account import ConsoleAccount # noqa from avilla.console.protocol import ConsoleProtocol # noqa -class ConsoleActivityActionPerform((m := AccountCollector["ConsoleProtocol", "ConsoleAccount"]())._): - m.namespace = "avilla.protocol/console::action/activity" - - @m.entity(ActivityTrigger.trigger, target="land.user.activity(bell)") +class ConsoleActivityActionPerform(m := scoped_collect.env().target, InstanceOfAccount, static=True): + @m.impl(start_activity, target="land.user.activity(bell)") async def bell(self, target: Selector | None = None): await self.account.client.call("bell", {}) diff --git a/avilla/console/perform/action/message.py b/avilla/console/perform/action/message.py index 06317a1a..d4cbb63f 100644 --- a/avilla/console/perform/action/message.py +++ b/avilla/console/perform/action/message.py @@ -12,19 +12,18 @@ from avilla.console.capability import ConsoleCapability from avilla.core.context import Context from avilla.core.message import Message -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.console.bases import InstanceOfAccount from avilla.core.selector import Selector -from avilla.standard.core.message import MessageSend, MessageSent +from avilla.standard.core.message import MessageSent, send_message +from flywheel import scoped_collect if TYPE_CHECKING: from avilla.console.account import ConsoleAccount # noqa from avilla.console.protocol import ConsoleProtocol # noqa -class ConsoleMessageActionPerform((m := AccountCollector["ConsoleProtocol", "ConsoleAccount"]())._): - m.namespace = "avilla.protocol/console::action/message" - - @m.entity(MessageSend.send, target="land.user") +class ConsoleMessageActionPerform(m := scoped_collect.env().target, InstanceOfAccount, static=True): + @m.impl(send_message, target="land.user") async def send_console_message( self, target: Selector, @@ -32,10 +31,8 @@ async def send_console_message( *, reply: Selector | None = None, ) -> Selector: - if TYPE_CHECKING: - assert isinstance(self.protocol, ConsoleProtocol) serialized_msg = ConsoleMessage( - [await ConsoleCapability(self.account.staff).serialize_element(i) for i in message] + [await ConsoleCapability.serialize_element(i) for i in message] ) await self.account.client.call( diff --git a/avilla/console/perform/action/profile.py b/avilla/console/perform/action/profile.py index d1e5352d..0a6ae8b2 100644 --- a/avilla/console/perform/action/profile.py +++ b/avilla/console/perform/action/profile.py @@ -2,24 +2,25 @@ from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.profile import Nick, Summary +from avilla.console.bases import InstanceOfAccount +from flywheel import scoped_collect + +from avilla.core.builtins.capability import CoreCapability if TYPE_CHECKING: from avilla.console.account import ConsoleAccount # noqa from avilla.console.protocol import ConsoleProtocol # noqa -class ConsoleProfileActionPerform((m := AccountCollector["ConsoleProtocol", "ConsoleAccount"]())._): - m.namespace = "avilla.protocol/console::action/profile" - - @m.pull("lang.user", Nick) - async def get_console_nick(self, target: Selector, route: type[Nick]) -> Nick: +class ConsoleProfileActionPerform(m := scoped_collect.env().target, InstanceOfAccount, static=True): + @m.impl(CoreCapability.pull, "lang.user", Nick) + async def get_console_nick(self, target: Selector) -> Nick: console = self.account.client.storage.current_user return Nick(console.nickname, console.nickname, "") - @m.pull("lang.user", Summary) - async def get_summary(self, target: Selector, route: type[Summary]) -> Summary: + @m.impl(CoreCapability.pull, "lang.user", Summary) + async def get_summary(self, target: Selector) -> Summary: console = self.account.client.storage.current_user return Summary(console.nickname, console.nickname) diff --git a/avilla/console/perform/context.py b/avilla/console/perform/context.py index 8b6d0a7d..b7c21e2f 100644 --- a/avilla/console/perform/context.py +++ b/avilla/console/perform/context.py @@ -4,18 +4,17 @@ from avilla.core.builtins.capability import CoreCapability from avilla.core.context import Context -from avilla.core.ryanvk.collector.account import AccountCollector +from ..bases import InstanceOfAccount from avilla.core.selector import Selector +from flywheel import scoped_collect if TYPE_CHECKING: from avilla.console.account import ConsoleAccount # noqa from avilla.console.protocol import ConsoleProtocol # noqa -class ConsoleContextPerform((m := AccountCollector["ConsoleProtocol", "ConsoleAccount"]())._): - m.namespace = "avilla.protocol/console::action/get_context" - - @CoreCapability.get_context.collect(m, target="land.user") +class ConsoleContextPerform(m := scoped_collect.env().target, InstanceOfAccount, static=True): + @m.impl(CoreCapability.get_context, target="land.user") def get_context_from_channel(self, target: Selector, *, via: Selector | None = None): return Context( account=self.account, @@ -25,10 +24,10 @@ def get_context_from_channel(self, target: Selector, *, via: Selector | None = N selft=self.account.route, ) - @m.entity(CoreCapability.channel, target="land.user") + @m.impl(CoreCapability.channel, target="land.user") def channel_from_channel(self, target: Selector): return target["user"] - @m.entity(CoreCapability.user, target="land.user") + @m.impl(CoreCapability.user, target="land.user") def user_from_friend(self, target: Selector): return target["user"] diff --git a/avilla/console/perform/event/message.py b/avilla/console/perform/event/message.py index e8d342d0..8b6a774d 100644 --- a/avilla/console/perform/event/message.py +++ b/avilla/console/perform/event/message.py @@ -9,23 +9,23 @@ from avilla.console.capability import ConsoleCapability from avilla.core.context import Context from avilla.core.message import Message -from avilla.core.ryanvk.collector.account import AccountCollector + from avilla.core.selector import Selector from avilla.standard.core.message import MessageReceived +from flywheel import scoped_collect +from avilla.console.bases import InstanceOfAccount if TYPE_CHECKING: from avilla.console.account import ConsoleAccount # noqa from avilla.console.protocol import ConsoleProtocol # noqa -class ConsoleEventMessagePerform((m := AccountCollector["ConsoleProtocol", "ConsoleAccount"]())._): - m.namespace = "avilla.protocol/console::event/message" - - @m.entity(ConsoleCapability.event_callback, event=MessageEvent) +class ConsoleEventMessagePerform(m := scoped_collect.globals().target, InstanceOfAccount, static=True): + @m.impl(ConsoleCapability.event_callback, event=MessageEvent) async def console_message(self, event: MessageEvent): console = Selector().land(self.account.route["land"]).user(str(event.user.id)) message = MessageChain( - [await ConsoleCapability(self.account.staff).deserialize_element(i) for i in event.message.content] + [await ConsoleCapability.deserialize_element(i) for i in event.message.content] ) context = Context( account=self.account, diff --git a/avilla/console/perform/message/deserialize.py b/avilla/console/perform/message/deserialize.py index 716fa276..8b33735b 100644 --- a/avilla/console/perform/message/deserialize.py +++ b/avilla/console/perform/message/deserialize.py @@ -10,27 +10,27 @@ from avilla.console.capability import ConsoleCapability from avilla.console.element import Markdown, Markup from avilla.core.elements import Face, Text -from avilla.core.ryanvk.collector.application import ApplicationCollector +from flywheel import global_collect +@global_collect +@ConsoleCapability.deserialize_element.impl(element=CslText) +async def text(element: CslText) -> Text: + return Text(element.text) -class ConsoleMessageDeserializePerform((m := ApplicationCollector())._): - m.namespace = "avilla.protocol/console::message" - m.identify = "deserialize" - # LINK: https://github.com/microsoft/pyright/issues/5409 +@global_collect +@ConsoleCapability.deserialize_element.impl(element=CslEmoji) +async def emoji(element: CslEmoji) -> Face: + return Face(element.name) - @m.entity(ConsoleCapability.deserialize_element, element=CslText) - async def text(self, element: CslText) -> Text: - return Text(element.text) - @m.entity(ConsoleCapability.deserialize_element, element=CslEmoji) - async def emoji(self, element: CslEmoji) -> Face: - return Face(element.name) +@global_collect +@ConsoleCapability.deserialize_element.impl(element=CslMarkup) +async def markup(element: CslMarkup) -> Markup: + return Markup(**asdict(element)) - @m.entity(ConsoleCapability.deserialize_element, element=CslMarkup) - async def markup(self, element: CslMarkup) -> Markup: - return Markup(**asdict(element)) - @m.entity(ConsoleCapability.deserialize_element, element=CslMarkdown) - async def markdown(self, element: CslMarkdown) -> Markdown: - return Markdown(**asdict(element)) +@global_collect +@ConsoleCapability.deserialize_element.impl(element=CslMarkdown) +async def markdown(element: CslMarkdown) -> Markdown: + return Markdown(**asdict(element)) diff --git a/avilla/console/perform/message/serialize.py b/avilla/console/perform/message/serialize.py index 3325b95b..54bc3ead 100644 --- a/avilla/console/perform/message/serialize.py +++ b/avilla/console/perform/message/serialize.py @@ -11,31 +11,32 @@ from avilla.console.capability import ConsoleCapability from avilla.console.element import Markdown, Markup from avilla.core.elements import Face, Text -from avilla.core.ryanvk.collector.account import AccountCollector +from flywheel import global_collect if TYPE_CHECKING: from ...account import ConsoleAccount # noqa from ...protocol import ConsoleProtocol # noqa -class ConsoleMessageSerializePerform((m := AccountCollector["ConsoleProtocol", "ConsoleAccount"]())._): - m.namespace = "avilla.protocol/console::message" - m.identify = "serialize" +@global_collect +@ConsoleCapability.serialize_element.impl(element=Text) +async def text(element: Text): + return CslText(element.text) - # LINK: https://github.com/microsoft/pyright/issues/5409 - @m.entity(ConsoleCapability.serialize_element, element=Text) - async def text(self, element: Text): - return CslText(element.text) +@global_collect +@ConsoleCapability.serialize_element.impl(element=Face) +async def emoji(element: Face): + return CslEmoji(element.id) - @m.entity(ConsoleCapability.serialize_element, element=Face) - async def emoji(self, element: Face): - return CslEmoji(element.id) - @m.entity(ConsoleCapability.serialize_element, element=Markup) - async def markup(self, element: Markup): - return CslMarkup(**asdict(element)) +@global_collect +@ConsoleCapability.serialize_element.impl(element=Markup) +async def markup(element: Markup): + return CslMarkup(**asdict(element)) - @m.entity(ConsoleCapability.serialize_element, element=Markdown) - async def markdown(self, element: Markdown): - return CslMarkdown(**asdict(element)) + +@global_collect +@ConsoleCapability.serialize_element.impl(element=Markdown) +async def markdown(element: Markdown): + return CslMarkdown(**asdict(element)) diff --git a/avilla/console/protocol.py b/avilla/console/protocol.py index 1a4e7eb0..25e7af55 100644 --- a/avilla/console/protocol.py +++ b/avilla/console/protocol.py @@ -2,43 +2,35 @@ from avilla.core.application import Avilla from avilla.core.protocol import BaseProtocol -from graia.ryanvk import merge, ref +from avilla.core.utilles import cachedstatic +from flywheel import CollectContext from .service import ConsoleService -def import_perform(): - # isort: off - - import avilla.console.perform.action.activity # noqa - import avilla.console.perform.action.message - import avilla.console.perform.action.profile - import avilla.console.perform.context - import avilla.console.perform.event.message - import avilla.console.perform.message.deserialize - import avilla.console.perform.message.serialize # noqa - - class ConsoleProtocol(BaseProtocol): service: ConsoleService name: str - import_perform() - artifacts = { - **merge( - ref("avilla.protocol/console::action/activity"), - ref("avilla.protocol/console::action/message"), - ref("avilla.protocol/console::action/profile"), - ref("avilla.protocol/console::action/get_context"), - ref("avilla.protocol/console::message", "deserialize"), - ref("avilla.protocol/console::message", "serialize"), - ref("avilla.protocol/console::event/message"), - ), - } + + @cachedstatic + def artifacts(): + with CollectContext().collect_scope() as collect_context: + # isort: off + import avilla.console.perform.action.activity # noqa + import avilla.console.perform.action.message + import avilla.console.perform.action.profile + import avilla.console.perform.context + import avilla.console.perform.event.message + import avilla.console.perform.message.deserialize + import avilla.console.perform.message.serialize # noqa + + return collect_context def __init__(self, name: str = "robot"): self.name = name def ensure(self, avilla: Avilla): + self.artifacts # access at last 1 time. self.avilla = avilla self.service = ConsoleService(self) avilla.launch_manager.add_component(self.service) diff --git a/avilla/core/__init__.py b/avilla/core/__init__.py index d20dc230..d67413cf 100644 --- a/avilla/core/__init__.py +++ b/avilla/core/__init__.py @@ -75,32 +75,29 @@ AvillaLifecycleEvent as AvillaLifecycleEvent, ) from avilla.standard.core.common import Count as Count -from avilla.standard.core.message import MessageEdit as MessageEdit from avilla.standard.core.message import MessageEdited as MessageEdited from avilla.standard.core.message import MessageReceived as MessageReceived -from avilla.standard.core.message import MessageRevoke as MessageRevoke from avilla.standard.core.message import MessageRevoked as MessageRevoked -from avilla.standard.core.message import MessageSend as MessageSend + from avilla.standard.core.message import MessageSent as MessageSent -from avilla.standard.core.privilege import BanCapability as BanCapability + from avilla.standard.core.privilege import BanInfo as BanInfo -from avilla.standard.core.privilege import MuteAllCapability as MuteAllCapability -from avilla.standard.core.privilege import MuteCapability as MuteCapability + from avilla.standard.core.privilege import MuteInfo as MuteInfo from avilla.standard.core.privilege import Privilege as Privilege -from avilla.standard.core.privilege import PrivilegeCapability as PrivilegeCapability + from avilla.standard.core.profile import Nick as Nick -from avilla.standard.core.profile import NickCapability as NickCapability + from avilla.standard.core.profile import Summary as Summary -from avilla.standard.core.profile import SummaryCapability as SummaryCapability -from avilla.standard.core.relation import SceneCapability as SceneCapability + + from avilla.standard.core.request import Answers as Answers from avilla.standard.core.request import Comment as Comment from avilla.standard.core.request import Questions as Questions from avilla.standard.core.request import Reason as Reason from avilla.standard.core.request import RequestAccepted as RequestAccepted from avilla.standard.core.request import RequestCancelled as RequestCancelled -from avilla.standard.core.request import RequestCapability as RequestCapability + from avilla.standard.core.request import RequestEvent as RequestEvent from avilla.standard.core.request import RequestIgnored as RequestIgnored from avilla.standard.core.request import RequestReceived as RequestReceived diff --git a/avilla/core/account.py b/avilla/core/account.py index 5d2791c6..f73c8c61 100644 --- a/avilla/core/account.py +++ b/avilla/core/account.py @@ -4,9 +4,10 @@ from typing import TYPE_CHECKING, Any from statv import Stats, Statv +from flywheel import CollectContext -from avilla.core.ryanvk.staff import Staff from avilla.core.selector import Selector +from avilla.core.builtins.capability import CoreCapability if TYPE_CHECKING: from avilla.core.application import Avilla @@ -29,45 +30,23 @@ class AccountInfo: class BaseAccount: route: Selector avilla: Avilla + artifacts: CollectContext = field(default_factory=CollectContext, init=False) @property def info(self) -> AccountInfo: return self.avilla.accounts[self.route] - @property - def staff(self): - return Staff(self.get_staff_artifacts(), self.get_staff_components()) - @property def available(self) -> bool: return True def get_context(self, target: Selector, *, via: Selector | None = None) -> Context: - return self.staff.get_context(target, via=via) + return CoreCapability.get_context(target, via=via) def get_self_context(self): from avilla.core.context import Context - return Context( - self, - self.route, - self.route, - self.route.into("::"), - self.route, - ) - - def get_staff_components(self): - return {"account": self, "protocol": self.info.protocol, "avilla": self.avilla} - - def get_staff_artifacts(self): - return [ - self.info.artifacts, - self.info.protocol.artifacts, - self.avilla.global_artifacts, - ] - - def __staff_generic__(self, element_type: dict, event_type: dict): - ... + return Context(self, self.route, self.route, self.route.into("::"), self.route) class AccountStatus(Statv): diff --git a/avilla/core/application.py b/avilla/core/application.py index 20751cfe..b83fabaa 100644 --- a/avilla/core/application.py +++ b/avilla/core/application.py @@ -2,7 +2,7 @@ import asyncio import signal -from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, overload +from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, overload from creart import it from graia.amnesia.builtins.memcache import MemcacheService @@ -11,12 +11,11 @@ from launart.service import Service from loguru import logger -from avilla.core._runtime import get_current_avilla +from avilla.core.globals import get_current_avilla from avilla.core.account import AccountInfo, BaseAccount from avilla.core.dispatchers import AvillaBuiltinDispatcher from avilla.core.event import MetadataModified from avilla.core.protocol import BaseProtocol -from avilla.core.ryanvk.staff import Staff from avilla.core.selector import Selector from avilla.core.service import AvillaService from avilla.core.utilles import identity @@ -29,7 +28,6 @@ from avilla.core.event import AvillaEvent from avilla.standard.core.application import AvillaLifecycleEvent - from .resource import Resource T = TypeVar("T") TE = TypeVar("TE", bound="AvillaEvent") @@ -42,7 +40,6 @@ class Avilla: protocols: list[BaseProtocol] accounts: dict[Selector, AccountInfo] service: AvillaService - global_artifacts: dict[Any, Any] def __init__( self, @@ -65,8 +62,6 @@ def __init__( self.launch_manager.add_component(self.service) self.broadcast.finale_dispatchers.append(AvillaBuiltinDispatcher(self)) - self.__init_isolate__() - if message_cache_size > 0: from avilla.core.context import Context from avilla.core.message import Message @@ -176,12 +171,10 @@ def event_record(self, event: AvillaEvent | AvillaLifecycleEvent): ) @overload - def add_event_recorder(self, event_type: type[TE]) -> Callable[[Callable[[TE], None]], Callable[[TE], None]]: - ... + def add_event_recorder(self, event_type: type[TE]) -> Callable[[Callable[[TE], None]], Callable[[TE], None]]: ... @overload - def add_event_recorder(self, event_type: type[TE], recorder: Callable[[TE], None]) -> Callable[[TE], None]: - ... + def add_event_recorder(self, event_type: type[TE], recorder: Callable[[TE], None]) -> Callable[[TE], None]: ... def add_event_recorder( self, event_type: type[TE], recorder: Callable[[TE], None] | None = None @@ -196,45 +189,28 @@ def wrapper(func: Callable[[TE], None]): self.custom_event_recorder[event_type] = recorder # type: ignore return recorder - def __init_isolate__(self): - from avilla.core.builtins.resource_fetch import CoreResourceFetchPerform - - CoreResourceFetchPerform.apply_to(self.global_artifacts) - - def get_staff_components(self): - return {"avilla": self} - - def get_staff_artifacts(self): - return [self.global_artifacts] - - def __staff_generic__(self, element_type: dict, event_type: dict): - ... + @classmethod + def _require_builtins(cls): + import avilla.core.builtins.resource_fetch # noqa: F401 @classmethod def current(cls): return get_current_avilla() - async def fetch_resource(self, resource: Resource[T]) -> T: - return await Staff(self.get_staff_artifacts(), self.get_staff_components()).fetch_resource(resource) - def get_account(self, target: Selector) -> AccountInfo: return self.accounts[target] @overload - def get_accounts(self, *, land: str) -> list[AccountInfo]: - ... + def get_accounts(self, *, land: str) -> list[AccountInfo]: ... @overload - def get_accounts(self, *, pattern: str) -> list[AccountInfo]: - ... + def get_accounts(self, *, pattern: str) -> list[AccountInfo]: ... @overload - def get_accounts(self, *, protocol_type: type[BaseProtocol]) -> list[AccountInfo]: - ... + def get_accounts(self, *, protocol_type: type[BaseProtocol]) -> list[AccountInfo]: ... @overload - def get_accounts(self, *, account_type: type[BaseAccount]) -> list[AccountInfo]: - ... + def get_accounts(self, *, account_type: type[BaseAccount]) -> list[AccountInfo]: ... def get_accounts( self, diff --git a/avilla/core/builtins/capability.py b/avilla/core/builtins/capability.py index 8bb50a39..578217a8 100644 --- a/avilla/core/builtins/capability.py +++ b/avilla/core/builtins/capability.py @@ -1,17 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload +from flywheel.fn.base import Fn +from flywheel.fn.compose import FnCompose +from flywheel.fn.implement import OverloadRecorder +from flywheel.fn.record import FnRecord +from flywheel.overloads import SimpleOverload, TypeOverload from typing_extensions import Unpack from avilla.core.metadata import Metadata, MetadataRoute from avilla.core.resource import Resource, T -from avilla.core.ryanvk.descriptor.query import QuerySchema -from avilla.core.ryanvk.overload.metadata import MetadataOverload -from avilla.core.ryanvk.overload.target import TargetOverload -from graia.ryanvk.capability import Capability -from graia.ryanvk.fn import Fn -from graia.ryanvk.overload import NoneOverload, TypeOverload +from avilla.core.ryanvk.overloads import TargetOverload if TYPE_CHECKING: from avilla.core.context import Context @@ -19,36 +19,124 @@ M = TypeVar("M", bound=Metadata) +MR = TypeVar("MR", bound=Metadata, covariant=True) +R = TypeVar("R", covariant=True) +Res = TypeVar("Res", bound=Resource, contravariant=True) +Res1 = TypeVar("Res1", bound=Resource, covariant=True) -class CoreCapability(Capability): - query = QuerySchema() +class CoreCapability: + # query = QuerySchema() - @Fn.complex( - { - TargetOverload(): ["target"], - NoneOverload(TargetOverload(), default_factory=lambda _: None): ["via"], - } - ) - def get_context(self, target: Selector, *, via: Selector | None = None) -> Context: - ... + target = TargetOverload("target") - @Fn.complex({TargetOverload(): ["target"], MetadataOverload(): ["route"]}) - async def pull(self, target: Selector, route: type[M] | MetadataRoute[Unpack[tuple[Any, ...]], M]) -> Any: - ... + @Fn.declare + class get_context(FnCompose): + via = TargetOverload("via") + novia = SimpleOverload("novia") - @Fn.complex({TypeOverload(): ["resource"]}) - async def fetch(self, resource: Resource[T]) -> T: - ... + def call(self, record: FnRecord, target: Selector, *, via: Selector | None = None) -> Context: + entities = self.load( + CoreCapability.target.dig(record, target), + self.novia.dig(record, via) if via is None else self.via.dig(record, via), + ) - @Fn.complex({TargetOverload(): ["target"]}) - def channel(self, target: Selector) -> str: - ... + if via is None: + return entities.first(target=target) # type: ignore - @Fn.complex({TargetOverload(): ["target"]}) - def guild(self, target: Selector) -> str: - ... + return entities.first(target=target, via=via) - @Fn.complex({TargetOverload(): ["target"]}) - def user(self, target: Selector) -> str: - ... + class shapecall_novia(Protocol): + def __call__(self, target: Selector) -> Context: ... + + class shapecall_via(Protocol): + def __call__(self, target: Selector, via: Selector) -> Context: ... + + @overload + def collect(self, recorder: OverloadRecorder[shapecall_via], target: str, via: str) -> None: ... + + @overload + def collect(self, recorder: OverloadRecorder[shapecall_novia], target: str, via: None = None) -> None: ... + + def collect(self, recorder: OverloadRecorder, target: str, via: str | None = None): + # TODO: 能否使用 predicator? + + recorder.use(CoreCapability.target, (target, {})) + if via is None: + recorder.use(self.novia, None) + else: + recorder.use(self.via, (via, {})) + + @Fn.declare + class pull(FnCompose): + route = SimpleOverload("route") + + async def call( + self, record: FnRecord, target: Selector, route: type[M] | MetadataRoute[Unpack[tuple[Any, ...]], M] + ) -> M: + entities = self.load(CoreCapability.target.dig(record, target), self.route.dig(record, route)) + return await entities.first(target=target) + + class shapecall(Protocol[MR]): + async def __call__(self, target: Selector) -> MR: ... + + def collect( + self, + recorder: OverloadRecorder[shapecall[M]], + target: str, + route: type[M] | MetadataRoute[Unpack[tuple[Any, ...]], M], + ): + recorder.use(CoreCapability.target, (target, {})) + recorder.use(self.route, route) + + @Fn.declare + class fetch(FnCompose): + resource = TypeOverload("resource") + + async def call(self, record: FnRecord, resource: Resource[T]) -> T: + entities = self.load(self.resource.dig(record, resource)) + return await entities.first(resource) + + class shapecall(Protocol[Res, Res1]): + async def __call__(self: CoreCapability.fetch.shapecall[Res, Resource[T]], resource: Res) -> T: ... + + def collect(self, recorder: OverloadRecorder[shapecall[Res, Res]], resource: type[Res]): + recorder.use(self.resource, resource) + + @Fn.declare + class channel(FnCompose): + def call(self, record: FnRecord, target: Selector) -> str: + entities = self.load(CoreCapability.target.dig(record, target)) + return entities.first(target=target) + + class shapecall(Protocol): + def __call__(self, target: Selector) -> str: ... + + def collect(self, recorder: OverloadRecorder[shapecall], target: str): + recorder.use(CoreCapability.target, (target, {})) + + @Fn.declare + class guild(FnCompose): + def call(self, record: FnRecord, target: Selector) -> str: + entities = self.load(CoreCapability.target.dig(record, target)) + return entities.first(target=target) + + class shapecall(Protocol): + def __call__(self, target: Selector) -> str: ... + + def collect(self, recorder: OverloadRecorder[shapecall], target: str): + recorder.use(CoreCapability.target, (target, {})) + + @Fn.declare + class user(FnCompose): + def call(self, record: FnRecord, target: Selector) -> str: + entities = self.load(CoreCapability.target.dig(record, target)) + return entities.first(target=target) + + class shapecall(Protocol): + def __call__(self, target: Selector) -> str: ... + + def collect(self, recorder: OverloadRecorder[shapecall], target: str): + recorder.use(CoreCapability.target, (target, {})) + + # TODO: query diff --git a/avilla/core/builtins/command/__init__.py b/avilla/core/builtins/command/__init__.py index 4d0dbda5..61ced9eb 100644 --- a/avilla/core/builtins/command/__init__.py +++ b/avilla/core/builtins/command/__init__.py @@ -216,8 +216,7 @@ def on( remove_tome: bool = False, dispatchers: Optional[list[T_Dispatcher]] = None, decorators: Optional[list[Decorator]] = None, - ) -> Callable[[TCallable], TCallable]: - ... + ) -> Callable[[TCallable], TCallable]: ... @overload def on( @@ -230,8 +229,7 @@ def on( *, args: Optional[dict[str, Union[TAValue, Args, Arg]]] = None, meta: Optional[CommandMeta] = None, - ) -> Callable[[TCallable], TCallable]: - ... + ) -> Callable[[TCallable], TCallable]: ... def on( self, diff --git a/avilla/core/builtins/resource_fetch.py b/avilla/core/builtins/resource_fetch.py index 36d95925..56562773 100644 --- a/avilla/core/builtins/resource_fetch.py +++ b/avilla/core/builtins/resource_fetch.py @@ -1,38 +1,30 @@ from __future__ import annotations +from flywheel.scoped import scoped_collect + from avilla.core.resource import LocalFileResource, RawResource, UrlResource -from avilla.core.ryanvk.collector.application import ApplicationCollector try: from aiohttp import ClientSession - - aio = True except ImportError: ClientSession = None - aio = False from .capability import CoreCapability -class CoreResourceFetchPerform((m := ApplicationCollector())._): - @m.entity(CoreCapability.fetch, resource=LocalFileResource) +class CoreResourceFetchPerform(m := scoped_collect.globals().target, static=True): + @m.impl(CoreCapability.fetch, resource=LocalFileResource) async def fetch_localfile(self, resource: LocalFileResource): return resource.file.read_bytes() - @m.entity(CoreCapability.fetch, resource=RawResource) + @m.impl(CoreCapability.fetch, resource=RawResource) async def fetch_raw(self, resource: RawResource): return resource.data - if aio: + if ClientSession is not None: - @m.entity(CoreCapability.fetch, resource=UrlResource) + @m.impl(CoreCapability.fetch, resource=UrlResource) async def fetch_url(self, resource: UrlResource): - async with ClientSession() as session: + async with ClientSession() as session: # type: ignore async with session.get(resource.url) as resp: return await resp.read() - - else: - - @m.entity(CoreCapability.fetch, resource=UrlResource) - async def fetch_url(self, resource: UrlResource): - raise NotImplementedError("aiohttp is not installed") diff --git a/avilla/core/context/__init__.py b/avilla/core/context/__init__.py index c7f83c15..fbbd8522 100644 --- a/avilla/core/context/__init__.py +++ b/avilla/core/context/__init__.py @@ -1,18 +1,19 @@ from __future__ import annotations -from collections.abc import Callable -from typing import Any, TypedDict, TypeVar, cast, overload +from contextlib import contextmanager +from typing import Any, TypedDict, TypeVar, cast from typing_extensions import ParamSpec, Unpack -from avilla.core._runtime import cx_context +from flywheel import InstanceContext + +from avilla.core.globals import CONTEXT_CONTEXT_VAR from avilla.core.account import BaseAccount +from avilla.core.builtins.capability import CoreCapability from avilla.core.metadata import Metadata, MetadataRoute from avilla.core.platform import Land from avilla.core.resource import Resource -from avilla.core.ryanvk import Fn -from avilla.core.ryanvk.staff import Staff -from avilla.core.selector import FollowsPredicater, Selectable, Selector +from avilla.core.selector import Selectable, Selector from avilla.core.utilles import classproperty from ._roles import ( @@ -54,15 +55,8 @@ def __init__( scene: Selector, selft: Selector, mediums: list[Selector] | None = None, - prelude_metadatas: dict[Selector, dict[type[Metadata] | MetadataRoute, Metadata]] | None = None, + metadatas: dict[Selector, dict[type[Metadata] | MetadataRoute, Metadata]] | None = None, ) -> None: - self.artifacts = [ - account.info.artifacts, - account.info.protocol.artifacts, - account.avilla.global_artifacts, - ] - # 这里是为了能在 Context 层级进行修改 - self.account = account self.client = ContextClientSelector.from_selector(self, client) @@ -71,8 +65,13 @@ def __init__( self.self = ContextSelfSelector.from_selector(self, selft) self.mediums = [ContextMedium(ContextSelector.from_selector(self, medium)) for medium in mediums or []] - self.cache = {"meta": prelude_metadatas or {}} - self.staff = Staff(self.get_staff_artifacts(), self.get_staff_components()) + self.cache = {"meta": metadatas or {}} + + @property + def instance_context(self): + ins = InstanceContext() + ins.instances.update({type(i): i for i in [self, self.avilla, self.protocol, self.account]}) + return ins @property def protocol(self): @@ -101,24 +100,11 @@ def _collect_metadatas(self, target: Selector | Selectable, *metadatas: Metadata @classproperty @classmethod def current(cls) -> Context: - return cx_context.get() - - def get_staff_components(self): - return { - "context": self, - "protocol": self.protocol, - "account": self.account, - "avilla": self.avilla, - } - - def get_staff_artifacts(self): - return self.artifacts - - def query(self, pattern: str, **predicators: FollowsPredicater): - return self.staff.query_entities(pattern, **predicators) + return CONTEXT_CONTEXT_VAR.get() - async def fetch(self, resource: Resource[_T]) -> _T: - return await self.staff.fetch_resource(resource) + @contextmanager + def lookup_scope(self): + yield async def pull( self, @@ -138,24 +124,4 @@ async def pull( if not route.has_params(): return cast("_MetadataT", meta) - return await self.staff.pull_metadata(target, route) - - @overload - def __getitem__(self, closure: Selector) -> ContextSelector: - ... - - @overload - def __getitem__(self, closure: Fn[P, R]) -> Callable[P, R]: - ... - - def __getitem__(self, closure: Selector | Fn[P, Any]): - if isinstance(closure, Selector): - return ContextSelector(self, closure.pattern) - - def run(*args: P.args, **kwargs: P.kwargs): - return self.staff.call_fn(closure, *args, **kwargs) - - return run - - def __staff_generic__(self, element_type: dict, event_type: dict): - ... + return await CoreCapability.pull(target, route) diff --git a/avilla/core/context/_roles.py b/avilla/core/context/_roles.py index 85bcdb36..b85814b0 100644 --- a/avilla/core/context/_roles.py +++ b/avilla/core/context/_roles.py @@ -4,17 +4,17 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, TypeVar -from graia.amnesia.message import Element, MessageChain, Text from typing_extensions import ParamSpec from avilla.core.builtins.capability import CoreCapability from avilla.core.message import Message from avilla.core.metadata import Metadata from avilla.core.selector import Selector -from avilla.standard.core.activity import ActivityTrigger -from avilla.standard.core.message import MessageSend -from avilla.standard.core.relation import SceneCapability -from avilla.standard.core.request import RequestCapability +from avilla.standard.core.activity import start_activity +from avilla.standard.core.message import send_message +from avilla.standard.core.relation import leave_scene, disband_scene, remove_member +from avilla.standard.core.request import accept_request, reject_request, cancel_request, ignore_request +from graia.amnesia.message import Element, MessageChain, Text from ._selector import ContextSelector @@ -28,22 +28,22 @@ class ContextClientSelector(ContextSelector): def trigger_activity(self, activity: str): - return self.context[ActivityTrigger.trigger](self.activity(activity)) + return start_activity(self.activity(activity)) @property def channel(self) -> str: - return self.context.staff.call_fn(CoreCapability.channel, self) + return CoreCapability.channel(self) @property def guild(self) -> str | None: try: - return self.context.staff.call_fn(CoreCapability.guild, self) + return CoreCapability.guild(self) except NotImplementedError: return None @property def user(self) -> str: - return self.context.staff.call_fn(CoreCapability.user, self) + return CoreCapability.user(self) class ContextEndpointSelector(ContextSelector): @@ -55,31 +55,31 @@ def expects_request(self) -> ContextRequestSelector: @property def channel(self) -> str: - return self.context.staff.call_fn(CoreCapability.channel, self) + return CoreCapability.channel(self) @property def guild(self) -> str | None: try: - return self.context.staff.call_fn(CoreCapability.guild, self) + return CoreCapability.guild(self) except NotImplementedError: return None @property def user(self) -> str: - return self.context.staff.call_fn(CoreCapability.user, self) + return CoreCapability.user(self) class ContextSelfSelector(ContextSelector): def trigger_activity(self, activity: str): - return self.context[ActivityTrigger.trigger](self.activity(activity)) + return start_activity(self.activity(activity)) class ContextSceneSelector(ContextSelector): def leave_scene(self): - return self.context[SceneCapability.leave](self) + return leave_scene(self) def disband_scene(self): - return self.context[SceneCapability.disband](self) + return disband_scene(self) def send_message( self, @@ -101,32 +101,32 @@ def send_message( elif isinstance(reply, str): reply = self.message(reply) - return self.context[MessageSend.send](self, message, reply=reply) + return send_message(self, message, reply=reply) def remove_member(self, target: Selector, reason: str | None = None): - return self.context[SceneCapability.remove_member](target, reason) + return remove_member(target, reason) @property def channel(self) -> str: - return self.context.staff.call_fn(CoreCapability.channel, self) + return CoreCapability.channel(self) @property def guild(self) -> str: - return self.context.staff.call_fn(CoreCapability.guild, self) + return CoreCapability.guild(self) class ContextRequestSelector(ContextEndpointSelector): - def accept_request(self): - return self.context[RequestCapability.accept](self) + def accept(self): + return accept_request(self) - def reject_request(self, reason: str | None = None, forever: bool = False): - return self.context[RequestCapability.reject](self, reason, forever) + def reject(self, reason: str | None = None, forever: bool = False): + return reject_request(self, reason, forever) - def cancel_request(self): - return self.context[RequestCapability.cancel](self) + def cancel(self): + return cancel_request(self) - def ignore_request(self): - return self.context[RequestCapability.ignore](self) + def ignore(self): + return ignore_request(self) @dataclass diff --git a/avilla/core/context/_selector.py b/avilla/core/context/_selector.py index 363e01b2..92836d31 100644 --- a/avilla/core/context/_selector.py +++ b/avilla/core/context/_selector.py @@ -2,12 +2,11 @@ from collections.abc import Mapping from copy import copy, deepcopy -from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, overload +from typing import TYPE_CHECKING, Any, Awaitable, TypeVar -from typing_extensions import Concatenate, ParamSpec, Self, Unpack +from typing_extensions import ParamSpec, Self, Unpack from avilla.core.metadata import Metadata, MetadataRoute -from avilla.core.ryanvk import Fn from avilla.core.selector import EMPTY_MAP, Selector from avilla.standard.core.privilege import Privilege from avilla.standard.core.profile import Avatar, Nick, Summary @@ -38,26 +37,6 @@ def __deepcopy__(self, memo): def from_selector(cls, cx: Context, selector: Selector) -> Self: return cls(cx, selector.pattern) - @overload - def __getitem__(self, item: str) -> str: - ... - - @overload - def __getitem__(self, item: Fn[Concatenate[Selector, P], R]) -> Callable[P, R]: - ... - - def __getitem__( - self, - item: str | Fn[Concatenate[Selector, P], R], - ) -> str | Callable[P, R]: - if isinstance(item, str): - return super().__getitem__(item) - - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - return self.context.staff.call_fn(item, self, *args, **kwargs) - - return wrapper - def pull( self, metadata: type[_MetadataT] | MetadataRoute[Unpack[tuple[Any, ...]], _MetadataT] ) -> Awaitable[_MetadataT]: diff --git a/avilla/core/dispatchers.py b/avilla/core/dispatchers.py index 78394f51..09da9edb 100644 --- a/avilla/core/dispatchers.py +++ b/avilla/core/dispatchers.py @@ -5,7 +5,7 @@ from graia.broadcast.entities.dispatcher import BaseDispatcher -from avilla.core._runtime import cx_protocol +from avilla.core.globals import PROTOCOL_CONTEXT_VAR from avilla.core.account import BaseAccount from avilla.core.context import Context from avilla.core.event import AvillaEvent @@ -28,14 +28,17 @@ async def catch(self, interface: DispatcherInterface[AvillaEvent]): if interface.annotation is Avilla: return self.avilla + if interface.annotation in self.avilla._protocol_map: return self.avilla._protocol_map[interface.annotation] + if ( isclass(interface.annotation) and issubclass(interface.annotation, BaseProtocol) - and isinstance(cx_protocol.get(None), interface.annotation) + and isinstance(PROTOCOL_CONTEXT_VAR.get(None), interface.annotation) ): - return cx_protocol.get(None) + return PROTOCOL_CONTEXT_VAR.get(None) + if ( isinstance(interface.event, AvillaEvent) and isclass(interface.annotation) diff --git a/avilla/core/event.py b/avilla/core/event.py index c1fab43c..4012880a 100644 --- a/avilla/core/event.py +++ b/avilla/core/event.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import AsyncExitStack from dataclasses import dataclass, field from datetime import datetime from inspect import isclass @@ -13,7 +14,7 @@ from avilla.core.metadata import Metadata, MetadataRoute from avilla.core.selector import Selector -from ._runtime import cx_context +from avilla.core.globals import CONTEXT_CONTEXT_VAR if TYPE_CHECKING: from graia.broadcast.interfaces.dispatcher import DispatcherInterface @@ -31,8 +32,11 @@ class Dispatcher(BaseDispatcher): async def beforeExecution(interface: DispatcherInterface[AvillaEvent]): if interface.depth < 1: interface.local_storage["avilla_context"] = interface.event.context - interface.local_storage["_context_token"] = cx_context.set(interface.event.context) - await interface.event.context.staff.exit_stack.__aenter__() + interface.local_storage["_context_token"] = CONTEXT_CONTEXT_VAR.set(interface.event.context) + + stack: AsyncExitStack = interface.local_storage["_depend_lifespan_manager"] + stack.enter_context(interface.event.context.protocol.artifacts.lookup_scope()) + #stack.enter_context(interface.event.context.instance_context.scope()) @staticmethod async def catch(interface: DispatcherInterface[AvillaEvent]): @@ -56,33 +60,27 @@ async def catch(interface: DispatcherInterface[AvillaEvent]): @staticmethod async def afterExecution(interface: DispatcherInterface[AvillaEvent], exc, tb): if interface.depth < 1: - cx_context.reset(interface.local_storage["_context_token"]) - await interface.event.context.staff.exit_stack.__aexit__(type(exc), exc, tb) + CONTEXT_CONTEXT_VAR.reset(interface.local_storage["_context_token"]) @dataclass -class RelationshipEvent(AvillaEvent): - ... +class RelationshipEvent(AvillaEvent): ... @dataclass -class RelationshipCreated(RelationshipEvent): - ... +class RelationshipCreated(RelationshipEvent): ... @dataclass -class DirectSessionCreated(RelationshipCreated): - ... +class DirectSessionCreated(RelationshipCreated): ... @dataclass -class SceneCreated(RelationshipCreated): - ... +class SceneCreated(RelationshipCreated): ... @dataclass -class MemberCreated(RelationshipEvent): - ... +class MemberCreated(RelationshipEvent): ... @dataclass @@ -92,18 +90,15 @@ class RelationshipDestroyed(RelationshipEvent): @dataclass -class DirectSessionDestroyed(RelationshipDestroyed): - ... +class DirectSessionDestroyed(RelationshipDestroyed): ... @dataclass -class SceneDestroyed(RelationshipDestroyed): - ... +class SceneDestroyed(RelationshipDestroyed): ... @dataclass -class MemberDestroyed(RelationshipDestroyed): - ... +class MemberDestroyed(RelationshipDestroyed): ... @dataclass diff --git a/avilla/core/_runtime.py b/avilla/core/globals.py similarity index 68% rename from avilla/core/_runtime.py rename to avilla/core/globals.py index 207160ad..2252d7f9 100644 --- a/avilla/core/_runtime.py +++ b/avilla/core/globals.py @@ -11,29 +11,29 @@ from avilla.core.protocol import BaseProtocol -cx_avilla: Ctx[Avilla] = Ctx("avilla") -cx_protocol: Ctx[BaseProtocol] = Ctx("protocol") -cx_context: Ctx[Context] = Ctx("context") +AVILLA_CONTEXT_VAR: Ctx[Avilla] = Ctx("avilla") +PROTOCOL_CONTEXT_VAR: Ctx[BaseProtocol] = Ctx("protocol") +CONTEXT_CONTEXT_VAR: Ctx[Context] = Ctx("context") def get_current_avilla() -> Avilla: - avilla = cx_avilla.get(None) + avilla = AVILLA_CONTEXT_VAR.get(None) if avilla: return avilla - protocol = cx_protocol.get(None) + protocol = PROTOCOL_CONTEXT_VAR.get(None) if protocol: return protocol.avilla - context = cx_context.get(None) + context = CONTEXT_CONTEXT_VAR.get(None) if context: return context.protocol.avilla raise RuntimeError("no any current avilla") def get_current_protocol(): - protocol = cx_protocol.get(None) + protocol = PROTOCOL_CONTEXT_VAR.get(None) if protocol: return protocol - context = cx_context.get(None) + context = CONTEXT_CONTEXT_VAR.get(None) if context: return context.protocol raise RuntimeError("no any current protocol") @@ -42,7 +42,7 @@ def get_current_protocol(): def require_context(func): @functools.wraps(func) async def wrapper(*args, **kwargs): - if cx_context.get(None): + if CONTEXT_CONTEXT_VAR.get(None): return await func(*args, **kwargs) raise RuntimeError("no any current context") diff --git a/avilla/core/message.py b/avilla/core/message.py index ee9ab4c8..57386b4a 100644 --- a/avilla/core/message.py +++ b/avilla/core/message.py @@ -7,9 +7,8 @@ from avilla.core.platform import Land from avilla.core.selector import Selector -from avilla.standard.core.message.capability import MessageRevoke +from avilla.standard.core.message.capability import revoke_message -from ._runtime import cx_context from .metadata import Metadata @@ -30,4 +29,4 @@ def to_selector(self) -> Selector: return self.scene.message(self.id) async def revoke(self): - await cx_context.get()[MessageRevoke.revoke](self.to_selector()) + await revoke_message(self.to_selector()) diff --git a/avilla/core/metadata.py b/avilla/core/metadata.py index 704da441..2f0d0bea 100644 --- a/avilla/core/metadata.py +++ b/avilla/core/metadata.py @@ -27,14 +27,12 @@ def __init__(self, getitem: T): class MetadataMeta(type): @overload - def __rshift__(cls: type[_MetadataT1], other: type[_MetadataT2]) -> MetadataRoute[_MetadataT1, _MetadataT2]: - ... + def __rshift__(cls: type[_MetadataT1], other: type[_MetadataT2]) -> MetadataRoute[_MetadataT1, _MetadataT2]: ... @overload def __rshift__( cls: type[_MetadataT1], other: MetadataRoute[Unpack[_TVT1]] - ) -> MetadataRoute[_MetadataT1, Unpack[_TVT1]]: - ... + ) -> MetadataRoute[_MetadataT1, Unpack[_TVT1]]: ... def __rshift__(cls: Any, other: type[Metadata] | MetadataRoute) -> MetadataRoute: # sourcery skip: instance-method-first-arg-name @@ -68,8 +66,7 @@ def __init_subclass__(cls) -> None: if TYPE_CHECKING: @classmethod - def inh(cls: type[_MetadataT1]) -> _MetadataT1: - ... + def inh(cls: type[_MetadataT1]) -> _MetadataT1: ... else: @@ -107,14 +104,12 @@ def __init__(self, cells: tuple[type[Metadata], ...]) -> None: @overload def __rshift__( self: MetadataRoute[Unpack[_TVT1]], other: type[_MetadataT1] - ) -> MetadataRoute[Unpack[_TVT1], _MetadataT1]: - ... + ) -> MetadataRoute[Unpack[_TVT1], _MetadataT1]: ... @overload def __rshift__( self: MetadataRoute[Unpack[_TVT1]], other: MetadataRoute[Unpack[_TVT2]] - ) -> MetadataRoute[Unpack[_TVT1], Unpack[_TVT2]]: - ... + ) -> MetadataRoute[Unpack[_TVT1], Unpack[_TVT2]]: ... def __rshift__(self, other: type[Metadata] | MetadataRoute) -> MetadataRoute: if not isinstance(other, (type, MetadataRoute)): @@ -139,8 +134,7 @@ def clear_params(self) -> None: if TYPE_CHECKING: @property - def inh(self: MetadataRoute[Unpack[tuple[Any, ...]], _MetadataT1]) -> _MetadataT1: - ... + def inh(self: MetadataRoute[Unpack[tuple[Any, ...]], _MetadataT1]) -> _MetadataT1: ... else: @@ -172,9 +166,9 @@ def __getattr__(self, item: str) -> Self: def __call__(self, *args: Any, **kwargs: Any) -> Self: prev = self.__steps[-1] - self.__steps[ - -1 - ] = f"{prev}({', '.join(repr(arg) for arg in args)}, {', '.join(f'{key}={repr(value)}' for key, value in kwargs.items())})" + self.__steps[-1] = ( + f"{prev}({', '.join(repr(arg) for arg in args)}, {', '.join(f'{key}={repr(value)}' for key, value in kwargs.items())})" + ) return self def __getitem__(self, item: Any) -> Self: @@ -203,9 +197,9 @@ def __getattr__(self, item: str) -> Self: def __call__(self, *args: Any, **kwargs: Any) -> Self: prev = self.__steps[-1] - self.__steps[ - -1 - ] = f"{prev}({', '.join(repr(arg) for arg in args)}, {', '.join(f'{key}={repr(value)}' for key, value in kwargs.items())})" + self.__steps[-1] = ( + f"{prev}({', '.join(repr(arg) for arg in args)}, {', '.join(f'{key}={repr(value)}' for key, value in kwargs.items())})" + ) return self def __getitem__(self, item: Any) -> Self: diff --git a/avilla/core/platform.py b/avilla/core/platform.py index 776d3086..4e4972c2 100644 --- a/avilla/core/platform.py +++ b/avilla/core/platform.py @@ -5,8 +5,7 @@ @dataclass -class PlatformDescription: - ... +class PlatformDescription: ... PD = TypeVar("PD", bound=PlatformDescription) diff --git a/avilla/core/protocol.py b/avilla/core/protocol.py index 3d0dbdaf..543f355f 100644 --- a/avilla/core/protocol.py +++ b/avilla/core/protocol.py @@ -1,35 +1,43 @@ from __future__ import annotations from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any +from flywheel import CollectContext from typing_extensions import Self -from avilla.core._runtime import cx_avilla, cx_context, cx_protocol from avilla.core.event import AvillaEvent +from avilla.core.globals import AVILLA_CONTEXT_VAR, CONTEXT_CONTEXT_VAR, PROTOCOL_CONTEXT_VAR +from avilla.core.utilles import cachedstatic +from avilla.standard.core.application import AvillaLifecycleEvent if TYPE_CHECKING: from avilla.core.application import Avilla from avilla.core.context import Context -class ProtocolConfig: - ... +class ProtocolConfig: ... class BaseProtocol: avilla: Avilla - artifacts: ClassVar[dict[Any, Any]] - def ensure(self, avilla: Avilla) -> Any: - ... + @cachedstatic + def artifacts(): + with CollectContext().collect_scope() as collect_context: + ... - def configure(self, config: ProtocolConfig) -> Self: - ... + return collect_context - def post_event(self, event: AvillaEvent, context: Context | None = None): - with cx_avilla.use(self.avilla), cx_protocol.use(self), ( - cx_context.use(context) if context is not None else nullcontext() + def ensure(self, avilla: Avilla) -> Any: ... + + def configure(self, config: ProtocolConfig) -> Self: ... + + def post_event(self, event: AvillaEvent | AvillaLifecycleEvent, context: Context | None = None): + with ( + AVILLA_CONTEXT_VAR.use(self.avilla), + PROTOCOL_CONTEXT_VAR.use(self), + CONTEXT_CONTEXT_VAR.use(context) if context is not None else nullcontext(), ): self.avilla.event_record(event) return self.avilla.broadcast.postEvent(event) diff --git a/avilla/core/request.py b/avilla/core/request.py index e3083863..7d3dd0ff 100644 --- a/avilla/core/request.py +++ b/avilla/core/request.py @@ -4,9 +4,8 @@ from datetime import datetime from typing import TYPE_CHECKING -from avilla.standard.core.request.capability import RequestCapability +from avilla.standard.core.request.capability import accept_request, reject_request, cancel_request, ignore_request -from ._runtime import cx_context from .metadata import Metadata from .platform import Land @@ -61,13 +60,13 @@ def to_selector(self) -> Selector: return self.scene.request(request_id) async def accept(self): - return await cx_context.get()[RequestCapability.accept](self.to_selector()) + return await accept_request(self.to_selector()) async def reject(self, reason: str | None = None, forever: bool = False): - return await cx_context.get()[RequestCapability.reject](self.to_selector(), reason, forever) + return await reject_request(self.to_selector(), reason, forever) async def cancel(self): - return await cx_context.get()[RequestCapability.cancel](self.to_selector()) + return await cancel_request(self.to_selector()) async def ignore(self): - return await cx_context.get()[RequestCapability.ignore](self.to_selector()) + return await ignore_request(self.to_selector()) diff --git a/avilla/core/ryanvk/__init__.py b/avilla/core/ryanvk/__init__.py index 53c32e4a..5709e71b 100644 --- a/avilla/core/ryanvk/__init__.py +++ b/avilla/core/ryanvk/__init__.py @@ -1,13 +1 @@ -from avilla.core.ryanvk.collector.context import ( - ContextBasedPerformTemplate as ContextBasedPerformTemplate, -) -from avilla.core.ryanvk.collector.context import ContextCollector as ContextCollector -from avilla.core.ryanvk.descriptor.query import QueryRecord as QueryRecord -from avilla.core.ryanvk.descriptor.query import QuerySchema as QuerySchema -from avilla.core.ryanvk.overload.metadata import MetadataOverload as MetadataOverload -from avilla.core.ryanvk.overload.target import TargetOverload as TargetOverload -from graia.ryanvk.capability import Capability as Capability -from graia.ryanvk.fn import Fn as Fn -from graia.ryanvk.overload import FnOverload as FnOverload -from graia.ryanvk.overload import SimpleOverload as SimpleOverload -from graia.ryanvk.overload import TypeOverload as TypeOverload +from .overloads import TargetOverload as TargetOverload diff --git a/avilla/core/ryanvk/bases.py b/avilla/core/ryanvk/bases.py new file mode 100644 index 00000000..3610a250 --- /dev/null +++ b/avilla/core/ryanvk/bases.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from flywheel import InstanceOf +from avilla.core.application import Avilla + + +class InstanceOfAvilla: + avilla = InstanceOf(Avilla) diff --git a/avilla/core/ryanvk/behavior/__init__.py b/avilla/core/ryanvk/behavior/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/avilla/core/ryanvk/behavior/query.py b/avilla/core/ryanvk/behavior/query.py deleted file mode 100644 index e80ee182..00000000 --- a/avilla/core/ryanvk/behavior/query.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable, TypedDict - -from avilla.core.ryanvk.descriptor.query import QueryRecord, find_querier_steps -from avilla.core.selector import FollowsPredicater, _parse_follows -from graia.ryanvk import FnOverload, OverloadBehavior -from graia.ryanvk.collector import BaseCollector - - -class QueryCollectParams(TypedDict): - target: str - previous: str | None - - -class QueryCallArgs(TypedDict): - pattern: str - predicators: dict[str, FollowsPredicater] - - -class QueryOverload(FnOverload): - identity: str = "query" - - def collect_entity( - self, - collector: BaseCollector, - scope: dict[Any, Any], - entity: Any, - params: QueryCollectParams, - ) -> None: - record_tuple = (collector, entity) - - sign = QueryRecord(params["previous"], params["target"]) - scope.setdefault(sign, set()).add(record_tuple) - - def get_entities( - self, - scope: dict[QueryRecord, set[tuple[BaseCollector, Callable[..., Any]]]], - args: QueryCallArgs, - ): - items = _parse_follows(args["pattern"], **args["predicators"]) - steps = find_querier_steps(scope, items) - - if steps is None: - raise NotImplementedError - - ... - - -class QueryFnBehavior(OverloadBehavior): - ... diff --git a/avilla/core/ryanvk/collector/__init__.py b/avilla/core/ryanvk/collector/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/avilla/core/ryanvk/collector/account.py b/avilla/core/ryanvk/collector/account.py deleted file mode 100644 index bf396ad6..00000000 --- a/avilla/core/ryanvk/collector/account.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar - -from graia.ryanvk import Access, BasePerform - -from .base import AvillaBaseCollector - -if TYPE_CHECKING: - from avilla.core.account import BaseAccount - from avilla.core.protocol import BaseProtocol - - -TProtocol = TypeVar("TProtocol", bound="BaseProtocol") -TAccount = TypeVar("TAccount", bound="BaseAccount") - -TProtocol1 = TypeVar("TProtocol1", bound="BaseProtocol") -TAccount1 = TypeVar("TAccount1", bound="BaseAccount") - -T = TypeVar("T") -T1 = TypeVar("T1") - - -class AccountBasedPerformTemplate(BasePerform, native=True): - __collector__: ClassVar[AccountCollector] - - protocol: Access[BaseProtocol] = Access() - account: Access[BaseAccount] = Access() - - @property - def avilla(self): - return self.protocol.avilla - - @property - def broadcast(self): - return self.avilla.broadcast - - -class AccountCollector(AvillaBaseCollector, Generic[TProtocol, TAccount]): - post_applying: bool = False - - def __init__(self): - super().__init__() - - @property - def _(self): - upper = super()._ - - class LocalPerformTemplate( - Generic[TProtocol1, TAccount1], - AccountBasedPerformTemplate, - upper, - native=True, - ): - protocol: TProtocol1 - account: TAccount1 - - return LocalPerformTemplate[TProtocol, TAccount] diff --git a/avilla/core/ryanvk/collector/application.py b/avilla/core/ryanvk/collector/application.py deleted file mode 100644 index 8dad8b28..00000000 --- a/avilla/core/ryanvk/collector/application.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, ClassVar, TypeVar - -from graia.ryanvk import Access, BasePerform - -from .base import AvillaBaseCollector - -if TYPE_CHECKING: - from avilla.core.application import Avilla - - -T = TypeVar("T") -T1 = TypeVar("T1") - - -class ApplicationBasedPerformTemplate(BasePerform, native=True): - __collector__: ClassVar[ApplicationCollector] - - avilla: Access[Avilla] = Access() - - -class ApplicationCollector(AvillaBaseCollector): - post_applying: bool = False - - def __init__(self): - super().__init__() - - @property - def _(self): - upper = super()._ - - class LocalPerformTemplate( - ApplicationBasedPerformTemplate, - upper, - native=True, - ): - ... - - return LocalPerformTemplate diff --git a/avilla/core/ryanvk/collector/base.py b/avilla/core/ryanvk/collector/base.py deleted file mode 100644 index 5866cfbd..00000000 --- a/avilla/core/ryanvk/collector/base.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, overload - -from graia.ryanvk import BaseCollector - -if TYPE_CHECKING: - from typing_extensions import Unpack - - from avilla.core.metadata import Metadata, MetadataRoute - from avilla.core.selector import Selector - -T = TypeVar("T") -T1 = TypeVar("T1") - -M = TypeVar("M", bound="Metadata") - - -Wrapper = Callable[[T], T] - - -class AvillaBaseCollector(BaseCollector): - def __init__(self): - super().__init__() - - @overload - def pull( - self, target: str, route: type[M] - ) -> Callable[[Callable[[Any, Selector, type[M]], Awaitable[M]]], Callable[[Any, Selector, type[M]], Awaitable[M]]]: - ... - - @overload - def pull( - self, target: str, route: MetadataRoute[Unpack[tuple[Any, ...]], M] - ) -> Callable[[Callable[[Any, Selector, type[M]], Awaitable[M]]], Callable[[Any, Selector, type[M]], Awaitable[M]]]: - ... - - def pull(self, target: str, route: ...) -> ...: - from avilla.core.builtins.capability import CoreCapability - - return self.entity(CoreCapability.pull, target=target, route=route) - - def fetch(self, resource_type: type[T]) -> Wrapper[Callable[[Any, T], Awaitable[Any]]]: - from avilla.core.builtins.capability import CoreCapability - - return self.entity(CoreCapability.fetch, resource=resource_type) # type: ignore diff --git a/avilla/core/ryanvk/collector/context.py b/avilla/core/ryanvk/collector/context.py deleted file mode 100644 index 3f309f23..00000000 --- a/avilla/core/ryanvk/collector/context.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar - -from graia.ryanvk import Access, BasePerform - -from .base import AvillaBaseCollector - -if TYPE_CHECKING: - from avilla.core.account import BaseAccount - from avilla.core.context import Context - from avilla.core.metadata import Metadata - from avilla.core.protocol import BaseProtocol - - -TProtocol = TypeVar("TProtocol", bound="BaseProtocol") -TAccount = TypeVar("TAccount", bound="BaseAccount") - -TProtocol1 = TypeVar("TProtocol1", bound="BaseProtocol") -TAccount1 = TypeVar("TAccount1", bound="BaseAccount") - -T = TypeVar("T") -T1 = TypeVar("T1") -M = TypeVar("M", bound="Metadata") - - -class ContextBasedPerformTemplate(BasePerform, native=True): - __collector__: ClassVar[ContextCollector] - - context: Access[Context] = Access() - - @property - def protocol(self): - return self.context.protocol - - @property - def account(self): - return self.context.account - - -class ContextCollector(AvillaBaseCollector, Generic[TProtocol, TAccount]): - post_applying: bool = False - - def __init__(self): - super().__init__() - - @property - def _(self): - upper = super()._ - - class LocalPerformTemplate( - Generic[TProtocol1, TAccount1], - ContextBasedPerformTemplate, - upper, - native=True, - ): - protocol: TProtocol1 - account: TAccount1 - - return LocalPerformTemplate[TProtocol, TAccount] diff --git a/avilla/core/ryanvk/collector/protocol.py b/avilla/core/ryanvk/collector/protocol.py deleted file mode 100644 index 6854b5c4..00000000 --- a/avilla/core/ryanvk/collector/protocol.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar - -from graia.ryanvk import Access, BasePerform - -from .base import AvillaBaseCollector - -if TYPE_CHECKING: - from avilla.core.protocol import BaseProtocol - - -TProtocol = TypeVar("TProtocol", bound="BaseProtocol") -TProtocol1 = TypeVar("TProtocol1", bound="BaseProtocol") - -T = TypeVar("T") -T1 = TypeVar("T1") - - -class ProtocolBasedPerformTemplate(BasePerform, native=True): - __collector__: ClassVar[ProtocolCollector] - - protocol: Access[BaseProtocol] = Access() - - -class ProtocolCollector(AvillaBaseCollector, Generic[TProtocol]): - post_applying: bool = False - - def __init__(self): - super().__init__() - - @property - def _(self): - upper = super()._ - - class LocalPerformTemplate( - Generic[TProtocol1], - ProtocolBasedPerformTemplate, - upper, - native=True, - ): - protocol: TProtocol1 - - return LocalPerformTemplate[TProtocol] diff --git a/avilla/core/ryanvk/descriptor/__init__.py b/avilla/core/ryanvk/descriptor/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/avilla/core/ryanvk/descriptor/query.py b/avilla/core/ryanvk/descriptor/query.py deleted file mode 100644 index 60fd7a8f..00000000 --- a/avilla/core/ryanvk/descriptor/query.py +++ /dev/null @@ -1,111 +0,0 @@ -from __future__ import annotations - -from collections import deque -from dataclasses import dataclass -from typing import Any, AsyncGenerator, Callable, Container, Protocol, overload - -from avilla.core.selector import Selector, _FollowItem -from graia.ryanvk import BaseCollector - - -@dataclass(unsafe_hash=True) -class QueryRecord: - """仅用作计算路径, 不参与实际运算, 也因此, 该元素仅存在于全局 Artifacts['query'] 中.""" - - previous: str | None - into: str - - -class QueryHandlerPerform(Protocol): - def __call__( - fself, self: Any, predicate: Callable[[str, str], bool] | str, previous: Selector | None = None - ) -> AsyncGenerator[Selector, None]: - ... - - -class QueryHandlerPerformNoPrev(Protocol): - def __call__( - fself, self: Any, predicate: Callable[[str, str], bool] | str, previous: None - ) -> AsyncGenerator[Selector, None]: - ... - - -class QueryHandlerPerformPrev(Protocol): - def __call__( - fself, self: Any, predicate: Callable[[str, str], bool] | str, previous: Selector - ) -> AsyncGenerator[Selector, None]: - ... - - -class QuerySchema: - @overload - def collect( - self, collector: BaseCollector, target: str, previous: None = None - ) -> Callable[[QueryHandlerPerformNoPrev], QueryHandlerPerformNoPrev]: - ... - - @overload - def collect( - self, collector: BaseCollector, target: str, previous: str - ) -> Callable[[QueryHandlerPerformPrev], QueryHandlerPerformPrev]: - ... - - def collect(self, collector: BaseCollector, target: str, previous: ... = None) -> ...: - def receive(entity: QueryHandlerPerform): - collector.artifacts[QueryRecord(previous, target)] = (collector, entity) - return entity - - return receive - - -class QueryHandler(Protocol): - def __call__( - self, predicate: Callable[[str, str], bool] | str, previous: Selector | None = None - ) -> AsyncGenerator[Selector, None]: - ... - - -# 使用 functools.reduce. -async def query_depth_generator( - handler: QueryHandler, - predicate: Callable[[str, str], bool] | str, - previous_generator: AsyncGenerator[Selector, None] | None = None, -): - if previous_generator is not None: - async for previous in previous_generator: - async for current in handler(predicate, previous): - yield current - else: - async for current in handler(predicate): - yield current - - -@dataclass -class _MatchStep: - upper: str - start: int - history: tuple[tuple[tuple[_FollowItem, ...], QueryRecord], ...] - - -def find_querier_steps( - artifacts: Container[Any], - frags: list[_FollowItem], -) -> list[tuple[tuple[_FollowItem, ...], QueryRecord]] | None: - result: list[tuple[tuple[_FollowItem, ...], QueryRecord]] | None = None - queue: deque[_MatchStep] = deque([_MatchStep("", 0, ())]) - whole = ".".join([i.name for i in frags]) - while queue: - head: _MatchStep = queue.popleft() - current_steps: list[_FollowItem] = [] - for curr_frag in frags[head.start :]: - current_steps.append(curr_frag) - steps = ".".join([i.name for i in current_steps]) - full_path = f"{head.upper}.{steps}" if head.upper else steps - head.start += 1 - if (query := ((*current_steps,), QueryRecord(head.upper or None, steps)))[1] in artifacts: - if full_path == whole: - if result is None or len(result) > len(head.history) + 1: - result = [*head.history, query] - else: - queue.append(_MatchStep(full_path, head.start, head.history + (query,))) - return result diff --git a/avilla/core/ryanvk/endpoint/__init__.py b/avilla/core/ryanvk/endpoint/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/avilla/core/ryanvk/endpoint/launart.py b/avilla/core/ryanvk/endpoint/launart.py deleted file mode 100644 index e69de29b..00000000 diff --git a/avilla/core/ryanvk/overload/__init__.py b/avilla/core/ryanvk/overload/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/avilla/core/ryanvk/overload/metadata.py b/avilla/core/ryanvk/overload/metadata.py deleted file mode 100644 index 20f8c475..00000000 --- a/avilla/core/ryanvk/overload/metadata.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -from typing import Any, Callable - -from avilla.core.metadata import Route -from graia.ryanvk.collector import BaseCollector -from graia.ryanvk.overload import FnOverload - - -class MetadataOverload(FnOverload): - def collect_entity( - self, - collector: BaseCollector, - scope: dict[Any, Any], - entity: Any, - params: dict[str, Route], - ) -> None: - record = (collector, entity) - - for param, route in params.items(): - ... - collection: dict[Route, set] = scope.setdefault(param, {}) - collection.setdefault(route, set()).add(record) - - def get_entities(self, scope: dict[Any, Any], args: dict[str, Route]) -> set[tuple[BaseCollector, Callable]]: - sets: list[set] = [] - - for arg_name, route in args.items(): - if arg_name not in scope: - raise NotImplementedError - - collection = scope[arg_name] - - if route not in collection: - raise NotImplementedError - - sets.append(collection[route]) - - return sets.pop().intersection(*sets) - - def merge_scopes(self, *scopes: dict[Any, Any]): - # layout: {param_name: {route: set()}} - result = {} - - for scope in scopes: - for param_name, collection in scope.items(): - result_collection = result.setdefault(param_name, {}) - for route, entities in collection.items(): - routes = result_collection.setdefault(route, set()) - routes.update(entities) - - return result diff --git a/avilla/core/ryanvk/overload/target.py b/avilla/core/ryanvk/overload/target.py deleted file mode 100644 index 950791a4..00000000 --- a/avilla/core/ryanvk/overload/target.py +++ /dev/null @@ -1,155 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable - -from typing_extensions import TypeAlias - -from avilla.core.selector import FollowsPredicater, Selector, _parse_follows -from graia.ryanvk.collector import BaseCollector -from graia.ryanvk.overload import FnOverload - - -@dataclass -class LookupBranchMetadata: - ... - - -@dataclass -class LookupBranch: - metadata: LookupBranchMetadata - levels: LookupCollection - bind: set[tuple[BaseCollector, Callable]] = field(default_factory=set) - - -LookupBranches: TypeAlias = "dict[str | FollowsPredicater | None, LookupBranch]" -LookupCollection: TypeAlias = "dict[str, LookupBranches]" - - -@dataclass -class TargetOverloadConfig: - pattern: str - predicators: dict[str, FollowsPredicater] - - def __init__(self, pattern: str, **predicators: FollowsPredicater): - self.pattern = pattern - self.predicators = predicators - - -def _merge_lookup_collection(current: LookupCollection, other: LookupCollection): - for key, branches in current.items(): - if (other_branches := other.pop(key, None)) is None: - continue - - for header, branch in branches.items(): - if (other_branch := other_branches.pop(header, None)) is None: - continue - - _merge_lookup_collection(branch.levels, other_branch.levels) - branch.bind |= other_branch.bind - - branches |= other_branches - - current |= other - - -class TargetOverload(FnOverload): - def collect_entity( - self, - collector: BaseCollector, - scope: dict[Any, Any], - entity: Any, - params: dict[str, str | TargetOverloadConfig], - ) -> None: - record = (collector, entity) - - for param, pattern in params.items(): - param_scope = scope.setdefault(param, {}) - - if isinstance(pattern, str): - pattern = TargetOverloadConfig(pattern) - - pattern_items = _parse_follows(pattern.pattern, **pattern.predicators) - if not pattern_items: - raise ValueError("invalid target pattern") - - processing_level = param_scope - - if TYPE_CHECKING: - branch = LookupBranch(LookupBranchMetadata(), {}) - - for item in pattern_items: - if item.name not in processing_level: - processing_level[item.name] = {} - - branches = processing_level[item.name] - if (item.literal or item.predicate) in branches: - branch = branches[item.literal or item.predicate] - else: - branch = LookupBranch(LookupBranchMetadata(), {}) - branches[item.literal or item.predicate] = branch - - processing_level = branch.levels - - branch.bind.add(record) - - def get_entities(self, scope: dict[Any, Any], args: dict[str, Selector]) -> set[tuple[BaseCollector, Callable]]: - bind_sets: list[set] = [] - - for arg_name, selector in args.items(): - if arg_name not in scope: - raise NotImplementedError - - def get_bind_set(): - processing_scope: LookupCollection = scope[arg_name] - branch = None - for key, value in selector.pattern.items(): - if (branches := processing_scope.get(key)) is None: - raise NotImplementedError - - if value in branches: - header = value - else: - for _key, branch in branches.items(): - if callable(_key) and _key(value): - header = _key - break # hit predicate - else: - if None in branches: - header = None # hit default - elif "*" in branches: - return branches["*"].bind # hit wildcard - else: - raise NotImplementedError - - branch = branches[header] - processing_scope = branch.levels - - if header is not None and None in branches: - processing_scope = branches[None].levels | processing_scope - if branch is not None and branch.bind: - # branch has bind - return branch.bind - - raise NotImplementedError - - bind_sets.append(get_bind_set()) - return bind_sets.pop().intersection(*bind_sets) - - def merge_scopes(self, *scopes: dict[Any, Any]): - # scope layout: { - # : LookupCollection - # } - param_collections: dict[str, list[LookupCollection]] = {} - result = {} - - for scope in scopes: - for param, collection in scope.items(): - param_collections.setdefault(param, []).append(collection) - - for param, collections in param_collections.items(): - result[param] = current = collections.pop(0) - for other in collections: - _merge_lookup_collection(current, other) - - return result diff --git a/avilla/core/ryanvk/overloads.py b/avilla/core/ryanvk/overloads.py new file mode 100644 index 00000000..2214dd04 --- /dev/null +++ b/avilla/core/ryanvk/overloads.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable + +from flywheel.fn.overload import FnOverload +from typing_extensions import TypeAlias + +from avilla.core.selector import FollowsPredicater, Selector, _parse_follows + + +@dataclass +class LookupBranchMetadata: ... + + +@dataclass +class LookupBranch: + metadata: LookupBranchMetadata + levels: LookupCollection + bind: dict[Callable, None] = field(default_factory=dict) + + +LookupBranches: TypeAlias = "dict[str | FollowsPredicater | None, LookupBranch]" +LookupCollection: TypeAlias = "dict[str, LookupBranches]" + + +@dataclass +class TargetOverloadSignature: + order: str + predicators: dict[str, FollowsPredicater] = field(default_factory=dict) + + def __hash__(self) -> int: + return hash(("TOS", self.order, tuple(self.predicators.items()))) + + +class TargetOverload(FnOverload[TargetOverloadSignature, tuple[str, dict[str, FollowsPredicater]], Selector]): + def digest(self, collect_value: tuple[str, dict[str, FollowsPredicater]]) -> TargetOverloadSignature: + return TargetOverloadSignature(collect_value[0], collect_value[1]) + + def collect(self, scope: dict, signature: TargetOverloadSignature) -> dict[Callable, None]: + pattern_items = _parse_follows(signature.order) + if not pattern_items: + raise ValueError("invalid target pattern") + + processing_level = scope + + if TYPE_CHECKING: + branch = LookupBranch(LookupBranchMetadata(), {}) + + for item in pattern_items: + if item.name not in processing_level: + processing_level[item.name] = {} + + branches = processing_level[item.name] + if (item.literal or item.predicate) in branches: + branch = branches[item.literal or item.predicate] + else: + branch = LookupBranch(LookupBranchMetadata(), {}) + branches[item.literal or item.predicate] = branch + + processing_level = branch.levels + + return branch.bind + + def harvest(self, scope: dict, value: Selector) -> dict[Callable, None]: + processing_scope: LookupCollection = scope + branch = None + + for k, v in value.pattern.items(): + if (branches := processing_scope.get(k)) is None: + return {} + + if v in branches: + header = v + else: + for _key, branch in branches.items(): + if callable(_key) and _key(v): + header = _key + break # hit predicate + else: + if None in branches: + header = None # hit default + elif "*" in branches: + return branches["*"].bind # hit wildcard + else: + return {} + + branch = branches[header] + processing_scope = branch.levels + if header is not None and None in branches: + processing_scope = branches[None].levels | processing_scope + + if branch is not None and branch.bind: + return branch.bind + else: + return {} + + def access(self, scope: dict, signature: TargetOverloadSignature) -> dict[Callable, None] | None: + pattern_items = _parse_follows(signature.order) + if not pattern_items: + raise ValueError("invalid target pattern") + + processing_level = scope + + if TYPE_CHECKING: + branch = LookupBranch(LookupBranchMetadata(), {}) + + for item in pattern_items: + if item.name not in processing_level: + return + + branches: LookupBranches = processing_level[item.name] + if (item.literal or item.predicate) not in branches: + return + + branch = branches[item.literal or item.predicate] + processing_level = branch.levels + + return branch.bind diff --git a/avilla/core/ryanvk/staff.py b/avilla/core/ryanvk/staff.py deleted file mode 100644 index 372b6e2f..00000000 --- a/avilla/core/ryanvk/staff.py +++ /dev/null @@ -1,122 +0,0 @@ -from __future__ import annotations - -from functools import reduce -from typing import TYPE_CHECKING, Any, Callable, ChainMap, overload - -from typing_extensions import ParamSpec, TypeVar, Unpack - -from avilla.core.builtins.capability import CoreCapability -from avilla.core.metadata import MetadataRoute -from avilla.core.selector import ( - FollowsPredicater, - Selector, - _FollowItem, - _parse_follows, -) -from graia.ryanvk import BaseCollector -from graia.ryanvk import Staff as BaseStaff - -from .descriptor.query import find_querier_steps, query_depth_generator - -if TYPE_CHECKING: - from avilla.core.metadata import Metadata - from avilla.core.resource import Resource - - from .descriptor.query import QueryHandler, QueryHandlerPerform - - -T = TypeVar("T") -R = TypeVar("R", covariant=True) -M = TypeVar("M", bound="Metadata") -P = ParamSpec("P") -P1 = ParamSpec("P1") -N = TypeVar("N") -Co = TypeVar("Co", bound="BaseCollector") - - -class Staff(BaseStaff): - """手杖与核心工艺 (Staff & Focus Craft).""" - - def get_context(self, target: Selector, *, via: Selector | None = None): - return self.call_fn(CoreCapability.get_context, target, via=via) - - async def fetch_resource(self, resource: Resource[T]) -> T: - return await self.get_fn_call(CoreCapability.fetch)(resource) - - @overload - async def pull_metadata( - self, - target: Selector, - route: type[M], - ) -> M: - ... - - @overload - async def pull_metadata( - self, - target: Selector, - route: MetadataRoute[Unpack[tuple[Any, ...]], T], - ) -> T: - ... - - async def pull_metadata( - self, - target: Selector, - route: ..., - ): - return await self.call_fn(CoreCapability.pull, target, route) - - async def query_entities(self, pattern: str, **predicators: FollowsPredicater): - items = _parse_follows(pattern, **predicators) - artifact_map = ChainMap(*self.artifact_collections) - steps = find_querier_steps(artifact_map, items) - - if steps is None: - return - - def build_handler(artifact: tuple[BaseCollector, QueryHandlerPerform]) -> QueryHandler: - async def handler(predicate: Callable[[str, str], bool] | str, previous: Selector | None = None): - collector, entity = artifact - - def _get_instance(_staff: Staff, _cls: type[N]) -> N: - if _cls not in _staff.instances: - res = _staff.instances[_cls] = _cls(_staff) - else: - res = _staff.instances[_cls] - - return res - - async for i in entity(_get_instance(self, collector.cls), predicate, previous): - yield i - - return handler - - def build_predicate(_steps: tuple[_FollowItem, ...]) -> Callable[[str, str], bool]: - mapping = {i.name: i for i in _steps} - - def predicater(key: str, value: str) -> bool: - if key not in mapping: - raise KeyError(f"expected existed key: {key}") - item = mapping[key] - if item.literal is not None: - return value == item.literal - elif item.predicate is not None: - return item.predicate(value) - return True - - return predicater - - handlers = [] - for follow_item, query_record in steps: - handlers.append((follow_item, build_handler(artifact_map[query_record]))) - - r = reduce( - lambda previous, current: query_depth_generator(current[1], build_predicate(current[0]), previous), - handlers, - None, - ) - if TYPE_CHECKING: - assert r is not None - - async for i in r: - yield i diff --git a/avilla/core/ryanvk/util.py b/avilla/core/ryanvk/util.py deleted file mode 100644 index c29b16a3..00000000 --- a/avilla/core/ryanvk/util.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from .overload.target import LookupCollection - - -def _merge_lookup_collection(self: LookupCollection, other: LookupCollection): - for key, branches in self.items(): - if (other_branches := other.pop(key, None)) is None: - continue - - for header, branch in branches.items(): - if (other_branch := other_branches.pop(header, None)) is None: - continue - - _merge_lookup_collection(branch.levels, other_branch.levels) - branch.bind |= other_branch.bind - - branches |= other_branches - - self |= other diff --git a/avilla/core/selector.py b/avilla/core/selector.py index 14ba3d54..34295969 100644 --- a/avilla/core/selector.py +++ b/avilla/core/selector.py @@ -221,5 +221,4 @@ def expects( @runtime_checkable class Selectable(Protocol): - def to_selector(self) -> Selector: - ... + def to_selector(self) -> Selector: ... diff --git a/avilla/core/service.py b/avilla/core/service.py index 6ea57057..ecb6ec86 100644 --- a/avilla/core/service.py +++ b/avilla/core/service.py @@ -1,8 +1,10 @@ from __future__ import annotations from collections import defaultdict +from functools import cached_property from typing import TYPE_CHECKING +from flywheel.context import InstanceContext from launart import Launart, Service from loguru import logger @@ -44,11 +46,19 @@ def required(self) -> set[str]: def stages(self): return {"preparing", "blocking", "cleanup"} - def get_interface(self, interface_type): - ... + def get_interface(self, interface_type): ... + + @cached_property + def instances_endpoint(self): + res = InstanceContext() + res.instances[type(self.avilla)] = self.avilla + return res async def launch(self, manager: Launart): + endp = self.instances_endpoint.scope() + async with self.stage("preparing"): + endp.__enter__() await self.avilla.broadcast.postEvent(ApplicationPreparing(self.avilla)) logger.info(AVILLA_ASCII_RAW_LOGO, alt=AVILLA_ASCII_LOGO) @@ -69,3 +79,4 @@ async def launch(self, manager: Launart): await self.avilla.broadcast.postEvent(ApplicationClosing(self.avilla)) await self.avilla.broadcast.postEvent(ApplicationClosed(self.avilla)) + endp.__exit__(None, None, None) diff --git a/avilla/core/typing.py b/avilla/core/typing.py index 7df84506..e4595cdf 100644 --- a/avilla/core/typing.py +++ b/avilla/core/typing.py @@ -17,5 +17,4 @@ @runtime_checkable class Ensureable(Protocol[_T]): - def ensure(self, interact: _T) -> Any: - ... + def ensure(self, interact: _T) -> Any: ... diff --git a/avilla/core/utilles/__init__.py b/avilla/core/utilles/__init__.py index f6bf5004..d1c7936b 100644 --- a/avilla/core/utilles/__init__.py +++ b/avilla/core/utilles/__init__.py @@ -22,3 +22,21 @@ def __init__(self, fget: Callable[[Any], _R_co] | classmethod[_T, [], _R_co]) -> def __get__(self, __obj: _T, __type: type[_T] | None = None, /) -> _R_co: return self.fget[0].__get__(__obj, __type)() + + +class cachedstatic(Generic[_T, _R_co]): + fget: tuple[staticmethod[[], _R_co]] + res: _R_co | None = None + + def __init__(self, fget: Callable[[], _R_co] | staticmethod[[], _R_co]) -> None: + if not isinstance(fget, staticmethod): + fget = staticmethod(fget) + + self.fget = (fget,) + + def __get__(self, __obj: _T, __type: type[_T] | None = None, /) -> _R_co: + if self.res is not None: + return self.res + + self.res = self.fget[0].__get__(__obj, __type)() + return self.res # type: ignore diff --git a/avilla/core/utilles/store.py b/avilla/core/utilles/store.py deleted file mode 100644 index 0d6e8ab1..00000000 --- a/avilla/core/utilles/store.py +++ /dev/null @@ -1,279 +0,0 @@ -""" -参考路径如下: - -## Cache path: - macOS: ~/Library/Caches/ - - Unix: ~/.cache/ (XDG default) - - Windows: C:\\Users\\\\AppData\\Local\\\\Cache - -## Data path: - macOS: ~/Library/Application Support/ - - Unix: ~/.local/share/ or in $XDG_DATA_HOME, if defined - - Win XP (not roaming): C:\\Documents and Settings\\\\Application Data\\ - - Win 7 (not roaming): C:\\Users\\\\AppData\\Local\\ - -## Config path: - macOS: same as user_data_dir - - Unix: ~/.config/ - - Win XP (roaming): C:\\Documents and Settings\\\\Local Settings\\Application Data\\ - - Win 7 (roaming): C:\\Users\\\\AppData\\Roaming\\ - -""" - -import os -import sys -from pathlib import Path -from typing import Callable, Literal, Optional - -from typing_extensions import ParamSpec - -WINDOWS = sys.platform.startswith("win") or (sys.platform == "cli" and os.name == "nt") - - -def user_cache_dir(appname: str) -> Path: - r""" - Return full path to the user-specific cache dir for this application. - "appname" is the name of application. - Typical user cache directories are: - macOS: ~/Library/Caches/ - Unix: ~/.cache/ (XDG default) - Windows: C:\\Users\\\\AppData\\Local\\\\Cache - On Windows the only suggestion in the MSDN docs is that local settings go - in the `CSIDL_LOCAL_APPDATA` directory. This is identical to the - non-roaming app data dir (the default returned by `user_data_dir`). Apps - typically put cache data somewhere *under* the given dir here. Some - examples: - ...\\Mozilla\\Firefox\\Profiles\\\\Cache - ...\\Acme\\SuperApp\\Cache\\1.0 - OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value. - """ - if WINDOWS: - return _get_win_folder("CSIDL_LOCAL_APPDATA") / appname / "Cache" - elif sys.platform == "darwin": - return Path("~/Library/Caches").expanduser() / appname - else: - return Path(os.getenv("XDG_CACHE_HOME", "~/.cache")).expanduser() / appname - - -def user_data_dir(appname: str, roaming: bool = False) -> Path: - r""" - Return full path to the user-specific data dir for this application. - "appname" is the name of application. - If None, just the system directory is returned. - "roaming" (boolean, default False) can be set True to use the Windows - roaming appdata directory. That means that for users on a Windows - network setup for roaming profiles, this user data will be - sync'd on login. See - - for a discussion of issues. - Typical user data directories are: - macOS: ~/Library/Application Support/ - Unix: ~/.local/share/ # or in - $XDG_DATA_HOME, if defined - Win XP (not roaming): C:\\Documents and Settings\\\\ ... - ...Application Data\\ - Win XP (roaming): C:\\Documents and Settings\\\\Local ... - ...Settings\\Application Data\\ - Win 7 (not roaming): C:\\Users\\\\AppData\\Local\\ - Win 7 (roaming): C:\\Users\\\\AppData\\Roaming\\ - For Unix, we follow the XDG spec and support $XDG_DATA_HOME. - That means, by default "~/.local/share/". - """ - if WINDOWS: - const = "CSIDL_APPDATA" if roaming else "CSIDL_LOCAL_APPDATA" - return Path(_get_win_folder(const)) / appname # type: ignore - elif sys.platform == "darwin": - return Path("~/Library/Application Support/").expanduser() / appname - else: - return Path(os.getenv("XDG_DATA_HOME", "~/.local/share")).expanduser() / appname - - -def user_config_dir(appname: str, roaming: bool = True) -> Path: - """Return full path to the user-specific config dir for this application. - "appname" is the name of application. - If None, just the system directory is returned. - "roaming" (boolean, default True) can be set False to not use the - Windows roaming appdata directory. That means that for users on a - Windows network setup for roaming profiles, this user data will be - sync'd on login. See - - for a discussion of issues. - Typical user data directories are: - macOS: same as user_data_dir - Unix: ~/.config/ - Win *: same as user_data_dir - For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME. - That means, by default "~/.config/". - """ - if WINDOWS: - return user_data_dir(appname, roaming=roaming) - elif sys.platform == "darwin": - return user_data_dir(appname) - else: - return Path(os.getenv("XDG_CONFIG_HOME", "~/.config")).expanduser() / appname - - -# -- Windows support functions -- -def _get_win_folder_from_registry( - csidl_name: Literal["CSIDL_APPDATA", "CSIDL_COMMON_APPDATA", "CSIDL_LOCAL_APPDATA"] -) -> Path: - """ - This is a fallback technique at best. I'm not sure if using the - registry for this guarantees us the correct answer for all CSIDL_* - names. - """ - import winreg - - shell_folder_name = { - "CSIDL_APPDATA": "AppData", - "CSIDL_COMMON_APPDATA": "Common AppData", - "CSIDL_LOCAL_APPDATA": "Local AppData", - }[csidl_name] - - key = winreg.OpenKey( - winreg.HKEY_CURRENT_USER, - r"Software\\Microsoft\\Windows\\CurrentVersion\\Explorer\\Shell Folders", - ) - directory, _type = winreg.QueryValueEx(key, shell_folder_name) - return Path(directory) - - -def _get_win_folder_with_ctypes( - csidl_name: Literal["CSIDL_APPDATA", "CSIDL_COMMON_APPDATA", "CSIDL_LOCAL_APPDATA"] -) -> Path: - import ctypes - - csidl_const = { - "CSIDL_APPDATA": 26, - "CSIDL_COMMON_APPDATA": 35, - "CSIDL_LOCAL_APPDATA": 28, - }[csidl_name] - - buf = ctypes.create_unicode_buffer(1024) - ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf) - - # Downgrade to short path name if have highbit chars. See - # . - has_high_char = any(ord(c) > 255 for c in buf) - if has_high_char: - buf2 = ctypes.create_unicode_buffer(1024) - if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024): - buf = buf2 - - return Path(buf.value) - - -if WINDOWS: - try: - _get_win_folder = _get_win_folder_with_ctypes - except ImportError: - _get_win_folder = _get_win_folder_from_registry - - -P = ParamSpec("P") - -APP_NAME = "avilla" -BASE_CACHE_DIR = user_cache_dir(APP_NAME).resolve() -BASE_CONFIG_DIR = user_config_dir(APP_NAME).resolve() -BASE_DATA_DIR = user_data_dir(APP_NAME).resolve() - - -def _ensure_dir(path: Path) -> None: - if not path.exists(): - path.mkdir(parents=True, exist_ok=True) - elif not path.is_dir(): - raise RuntimeError(f"{path} is not a directory") - - -def _auto_create_dir(func: Callable[P, Path]) -> Callable[P, Path]: - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Path: - path = func(*args, **kwargs) - _ensure_dir(path) - return path - - return wrapper - - -@_auto_create_dir -def get_cache_dir(plugin_name: Optional[str]) -> Path: - """ - macOS: ~/Library/Caches/ - - Unix: ~/.cache/ (XDG default) - - Windows: C:\\Users\\\\AppData\\Local\\\\Cache - """ - return BASE_CACHE_DIR / plugin_name if plugin_name else BASE_CACHE_DIR - - -def get_cache_file(plugin_name: Optional[str], filename: str) -> Path: - """ - macOS: ~/Library/Caches/ - - Unix: ~/.cache/ (XDG default) - - Windows: C:\\Users\\\\AppData\\Local\\\\Cache - """ - return get_cache_dir(plugin_name) / filename - - -@_auto_create_dir -def get_config_dir(plugin_name: Optional[str]) -> Path: - """ - macOS: same as user_data_dir - - Unix: ~/.config/ - - Win XP (roaming): C:\\Documents and Settings\\\\Local Settings\\Application Data\\ - - Win 7 (roaming): C:\\Users\\\\AppData\\Roaming\\ - """ - return BASE_CONFIG_DIR / plugin_name if plugin_name else BASE_CONFIG_DIR - - -def get_config_file(plugin_name: Optional[str], filename: str) -> Path: - """ - macOS: same as user_data_dir - - Unix: ~/.config/ - - Win XP (roaming): C:\\Documents and Settings\\\\Local Settings\\Application Data\\ - - Win 7 (roaming): C:\\Users\\\\AppData\\Roaming\\ - """ - return get_config_dir(plugin_name) / filename - - -@_auto_create_dir -def get_data_dir(plugin_name: Optional[str]) -> Path: - """ - macOS: ~/Library/Application Support/ - - Unix: ~/.local/share/ or in $XDG_DATA_HOME, if defined - - Win XP (not roaming): C:\\Documents and Settings\\\\Application Data\\ - - Win 7 (not roaming): C:\\Users\\\\AppData\\Local\\ - """ - return BASE_DATA_DIR / plugin_name if plugin_name else BASE_DATA_DIR - - -def get_data_file(plugin_name: Optional[str], filename: str) -> Path: - """ - macOS: ~/Library/Application Support/ - - Unix: ~/.local/share/ or in $XDG_DATA_HOME, if defined - - Win XP (not roaming): C:\\Documents and Settings\\\\Application Data\\ - - Win 7 (not roaming): C:\\Users\\\\AppData\\Local\\ - """ - return get_data_dir(plugin_name) / filename diff --git a/avilla/elizabeth/capability.py b/avilla/elizabeth/capability.py index a165ad0b..78f939b3 100644 --- a/avilla/elizabeth/capability.py +++ b/avilla/elizabeth/capability.py @@ -5,7 +5,7 @@ from graia.amnesia.message import Element, MessageChain from avilla.core.event import AvillaEvent -from avilla.core.ryanvk.collector.application import ApplicationCollector +from avilla.core.ryanvk_old.collector.application import ApplicationCollector from graia.ryanvk import Fn, PredicateOverload, TypeOverload if TYPE_CHECKING: @@ -14,8 +14,7 @@ class ElizabethCapability((m := ApplicationCollector())._): @Fn.complex({PredicateOverload(lambda _, raw: raw["type"]): ["raw_event"]}) - async def event_callback(self, raw_event: dict) -> AvillaEvent | None: - ... + async def event_callback(self, raw_event: dict) -> AvillaEvent | None: ... @Fn.complex({PredicateOverload(lambda _, raw: raw["type"]): ["raw_element"]}) async def deserialize_element(self, raw_element: dict) -> Element: # type: ignore diff --git a/avilla/elizabeth/collector/connection.py b/avilla/elizabeth/collector/connection.py index 550baa84..1b954dc4 100644 --- a/avilla/elizabeth/collector/connection.py +++ b/avilla/elizabeth/collector/connection.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, ClassVar, TypeVar -from avilla.core.ryanvk.collector.base import AvillaBaseCollector +from avilla.core.ryanvk_old.collector.base import AvillaBaseCollector from graia.ryanvk import Access, BasePerform if TYPE_CHECKING: @@ -32,7 +32,6 @@ class PerformTemplate( ConnectionBasedPerformTemplate, upper, native=True, - ): - ... + ): ... return PerformTemplate diff --git a/avilla/elizabeth/connection/base.py b/avilla/elizabeth/connection/base.py index 212ac8bb..e0ebd056 100644 --- a/avilla/elizabeth/connection/base.py +++ b/avilla/elizabeth/connection/base.py @@ -8,7 +8,7 @@ from typing_extensions import Self from avilla.core.exceptions import InvalidAuthentication -from avilla.core.ryanvk.staff import Staff +from avilla.core.ryanvk_old.staff import Staff from avilla.core.selector import Selector from avilla.elizabeth.capability import ElizabethCapability from avilla.standard.core.account import AccountAvailable @@ -47,18 +47,14 @@ def get_staff_artifacts(self): def staff(self): return Staff(self.get_staff_artifacts(), self.get_staff_components()) - def message_receive(self) -> AsyncIterator[tuple[Self, dict]]: - ... + def message_receive(self) -> AsyncIterator[tuple[Self, dict]]: ... @property - def alive(self) -> bool: - ... + def alive(self) -> bool: ... - async def wait_for_available(self): - ... + async def wait_for_available(self): ... - async def send(self, payload: dict) -> None: - ... + async def send(self, payload: dict) -> None: ... async def message_handle(self): async for connection, data in self.message_receive(): @@ -140,5 +136,4 @@ async def call( finally: del self.response_waiters[echo] - async def call_http(self, method: CallMethod, action: str, params: dict | None = None) -> dict: - ... + async def call_http(self, method: CallMethod, action: str, params: dict | None = None) -> dict: ... diff --git a/avilla/elizabeth/connection/util.py b/avilla/elizabeth/connection/util.py index da9a636a..e3f76462 100644 --- a/avilla/elizabeth/connection/util.py +++ b/avilla/elizabeth/connection/util.py @@ -34,13 +34,11 @@ @overload -def validate_response(data: Any, raising: Literal[False]) -> Any | Exception: - ... +def validate_response(data: Any, raising: Literal[False]) -> Any | Exception: ... @overload -def validate_response(data: Any, raising: Literal[True] = True) -> Any: - ... +def validate_response(data: Any, raising: Literal[True] = True) -> Any: ... def validate_response(data: dict, raising: bool = True): diff --git a/avilla/elizabeth/exception.py b/avilla/elizabeth/exception.py index ea9bc3b4..6492dc05 100644 --- a/avilla/elizabeth/exception.py +++ b/avilla/elizabeth/exception.py @@ -1,6 +1,5 @@ """Ariadne 的异常定义""" - from avilla.core.exceptions import ( InvalidAuthentication, InvalidOperation, diff --git a/avilla/elizabeth/file/capability.py b/avilla/elizabeth/file/capability.py index 9ddb858b..53658832 100644 --- a/avilla/elizabeth/file/capability.py +++ b/avilla/elizabeth/file/capability.py @@ -3,7 +3,7 @@ import os from typing import IO -from avilla.core.ryanvk import Fn, TargetOverload +from avilla.core.ryanvk_old import Fn, TargetOverload from avilla.core.selector import Selector from graia.ryanvk.capability import Capability @@ -12,37 +12,27 @@ class FileUpload(Capability): @Fn.complex({TargetOverload(): ["target"]}) async def upload( self, target: Selector, name: str, file: bytes | IO[bytes] | os.PathLike, path: str | None = None - ) -> Selector: - ... + ) -> Selector: ... class FileDirectoryCreate(Capability): @Fn.complex({TargetOverload(): ["target"]}) - async def create( - self, target: Selector, name: str, parent: str | None = None - ) -> Selector: - ... + async def create(self, target: Selector, name: str, parent: str | None = None) -> Selector: ... class FileDelete(Capability): @Fn.complex({TargetOverload(): ["file"]}) async def delete( - self, file: Selector, - ) -> None: - ... + self, + file: Selector, + ) -> None: ... class FileMove(Capability): @Fn.complex({TargetOverload(): ["file"]}) - async def move( - self, file: Selector, to: Selector - ) -> None: - ... + async def move(self, file: Selector, to: Selector) -> None: ... class FileRename(Capability): @Fn.complex({TargetOverload(): ["file"]}) - async def rename( - self, file: Selector, name: str - ) -> None: - ... \ No newline at end of file + async def rename(self, file: Selector, name: str) -> None: ... diff --git a/avilla/elizabeth/perform/action/activity.py b/avilla/elizabeth/perform/action/activity.py index a49d4e9b..11d55b5a 100644 --- a/avilla/elizabeth/perform/action/activity.py +++ b/avilla/elizabeth/perform/action/activity.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.activity import ActivityTrigger diff --git a/avilla/elizabeth/perform/action/announcement.py b/avilla/elizabeth/perform/action/announcement.py index ae4e3292..c5aac22c 100644 --- a/avilla/elizabeth/perform/action/announcement.py +++ b/avilla/elizabeth/perform/action/announcement.py @@ -8,7 +8,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.qq.announcement import ( Announcement, diff --git a/avilla/elizabeth/perform/action/contact.py b/avilla/elizabeth/perform/action/contact.py index dcc9c49c..b8973754 100644 --- a/avilla/elizabeth/perform/action/contact.py +++ b/avilla/elizabeth/perform/action/contact.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.profile import Avatar, Nick, Summary diff --git a/avilla/elizabeth/perform/action/file.py b/avilla/elizabeth/perform/action/file.py index 91ae19be..c93eaa2d 100644 --- a/avilla/elizabeth/perform/action/file.py +++ b/avilla/elizabeth/perform/action/file.py @@ -7,7 +7,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.elizabeth.file import ( FileData, @@ -45,7 +45,8 @@ async def get_file(self, target: Selector, route: ...) -> FileData: ) file = FileData.parse(result) await cache.set( - f"elizabeth/account({self.account.route['account']}).group({target['group']}).file({target['file']})", file, + f"elizabeth/account({self.account.route['account']}).group({target['group']}).file({target['file']})", + file, timedelta(minutes=5), ) return file diff --git a/avilla/elizabeth/perform/action/friend.py b/avilla/elizabeth/perform/action/friend.py index 454282e9..ada776e9 100644 --- a/avilla/elizabeth/perform/action/friend.py +++ b/avilla/elizabeth/perform/action/friend.py @@ -4,7 +4,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.profile import Avatar, Nick, Summary from avilla.standard.core.relation.capability import RelationshipTerminate diff --git a/avilla/elizabeth/perform/action/group.py b/avilla/elizabeth/perform/action/group.py index aadd62df..30a409b5 100644 --- a/avilla/elizabeth/perform/action/group.py +++ b/avilla/elizabeth/perform/action/group.py @@ -5,7 +5,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService from avilla.core.exceptions import permission_error_message -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.elizabeth.const import PRIVILEGE_LEVEL from avilla.standard.core.privilege import MuteAllCapability, Privilege diff --git a/avilla/elizabeth/perform/action/member.py b/avilla/elizabeth/perform/action/member.py index b3a8a99b..f5e3cc04 100644 --- a/avilla/elizabeth/perform/action/member.py +++ b/avilla/elizabeth/perform/action/member.py @@ -6,7 +6,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService from avilla.core.exceptions import permission_error_message -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.elizabeth.const import PRIVILEGE_LEVEL, PRIVILEGE_TRANS from avilla.standard.core.privilege import ( diff --git a/avilla/elizabeth/perform/action/message.py b/avilla/elizabeth/perform/action/message.py index 39733651..eb221623 100644 --- a/avilla/elizabeth/perform/action/message.py +++ b/avilla/elizabeth/perform/action/message.py @@ -7,7 +7,7 @@ from avilla.core import Context from avilla.core.message import Message -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.elizabeth.capability import ElizabethCapability from avilla.standard.core.message import ( diff --git a/avilla/elizabeth/perform/action/request.py b/avilla/elizabeth/perform/action/request.py index 28b2af3a..d3e856fc 100644 --- a/avilla/elizabeth/perform/action/request.py +++ b/avilla/elizabeth/perform/action/request.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.request import RequestCapability diff --git a/avilla/elizabeth/perform/context.py b/avilla/elizabeth/perform/context.py index adb8ab24..0fdd5c93 100644 --- a/avilla/elizabeth/perform/context.py +++ b/avilla/elizabeth/perform/context.py @@ -4,7 +4,7 @@ from avilla.core.builtins.capability import CoreCapability from avilla.core.context import Context -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/elizabeth/perform/event/message.py b/avilla/elizabeth/perform/event/message.py index 6bbb6c1d..3422a2b5 100644 --- a/avilla/elizabeth/perform/event/message.py +++ b/avilla/elizabeth/perform/event/message.py @@ -152,11 +152,12 @@ async def group_recall(self, raw_event: dict): group = Selector().land("qq").group(str(group_data["id"])) author = group.member(str(raw_event["authorId"])) author_data = await self.connection.call( - "fetch", "memberInfo", + "fetch", + "memberInfo", { "target": group_data["id"], "memberId": raw_event["authorId"], - } + }, ) operator_data = raw_event["operator"] operator = group.member(str(operator_data["id"])) diff --git a/avilla/elizabeth/perform/message/deserialize.py b/avilla/elizabeth/perform/message/deserialize.py index 1bda4079..e5f8921f 100644 --- a/avilla/elizabeth/perform/message/deserialize.py +++ b/avilla/elizabeth/perform/message/deserialize.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from avilla.core.elements import Audio, Face, File, Notice, NoticeAll, Picture, Text -from avilla.core.ryanvk.collector.application import ApplicationCollector +from avilla.core.ryanvk_old.collector.application import ApplicationCollector from avilla.core.selector import Selector from avilla.elizabeth.capability import ElizabethCapability from avilla.elizabeth.resource import ( diff --git a/avilla/elizabeth/perform/message/serialize.py b/avilla/elizabeth/perform/message/serialize.py index d4a9f847..817438f4 100644 --- a/avilla/elizabeth/perform/message/serialize.py +++ b/avilla/elizabeth/perform/message/serialize.py @@ -6,7 +6,7 @@ from avilla.core.elements import Audio, Face, Notice, NoticeAll, Picture, Text from avilla.core.resource import LocalFileResource, RawResource, UrlResource -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.elizabeth.capability import ElizabethCapability from avilla.elizabeth.resource import ElizabethImageResource, ElizabethVoiceResource from avilla.standard.qq.elements import ( diff --git a/avilla/elizabeth/perform/query/announcement.py b/avilla/elizabeth/perform/query/announcement.py index 4f6820e8..9d061315 100644 --- a/avilla/elizabeth/perform/query/announcement.py +++ b/avilla/elizabeth/perform/query/announcement.py @@ -6,7 +6,7 @@ from graia.amnesia.builtins.memcache import MemcacheService from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/elizabeth/perform/query/bot.py b/avilla/elizabeth/perform/query/bot.py index 45df0b6c..6e98b31d 100644 --- a/avilla/elizabeth/perform/query/bot.py +++ b/avilla/elizabeth/perform/query/bot.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Callable, cast from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/elizabeth/perform/query/file.py b/avilla/elizabeth/perform/query/file.py index 9e98578f..f4d43da0 100644 --- a/avilla/elizabeth/perform/query/file.py +++ b/avilla/elizabeth/perform/query/file.py @@ -6,7 +6,7 @@ from graia.amnesia.builtins.memcache import MemcacheService from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.elizabeth.file import FileData @@ -23,7 +23,9 @@ class ElizabethAnnouncementQueryPerform((m := AccountCollector["ElizabethProtoco async def query_group_file(self, predicate: Callable[[str, str], bool] | str, previous: Selector): cache = self.protocol.avilla.launch_manager.get_component(MemcacheService).cache result = await self.account.connection.call( - "fetch", "file_list", {"id": "", "target": int(previous["group"]), "offset": 0, "size": 1, "withDownloadInfo": "True"} + "fetch", + "file_list", + {"id": "", "target": int(previous["group"]), "offset": 0, "size": 1, "withDownloadInfo": "True"}, ) result = cast(list, result) for i in result: diff --git a/avilla/elizabeth/perform/query/friend.py b/avilla/elizabeth/perform/query/friend.py index 2fdc3325..090ba489 100644 --- a/avilla/elizabeth/perform/query/friend.py +++ b/avilla/elizabeth/perform/query/friend.py @@ -6,7 +6,7 @@ from graia.amnesia.builtins.memcache import MemcacheService from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/elizabeth/perform/query/group.py b/avilla/elizabeth/perform/query/group.py index 69e39ac5..43c10a22 100644 --- a/avilla/elizabeth/perform/query/group.py +++ b/avilla/elizabeth/perform/query/group.py @@ -6,7 +6,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/elizabeth/perform/resource_fetch.py b/avilla/elizabeth/perform/resource_fetch.py index ebb2086e..aa58239f 100644 --- a/avilla/elizabeth/perform/resource_fetch.py +++ b/avilla/elizabeth/perform/resource_fetch.py @@ -6,7 +6,7 @@ from avilla.core.builtins.capability import CoreCapability from avilla.core.exceptions import UnknownTarget -from avilla.core.ryanvk.collector.protocol import ProtocolCollector +from avilla.core.ryanvk_old.collector.protocol import ProtocolCollector from avilla.elizabeth.resource import ( ElizabethImageResource, ElizabethResource, diff --git a/avilla/nonebridge/adapter.py b/avilla/nonebridge/adapter.py index 4a46dd83..5d2d9aa5 100644 --- a/avilla/nonebridge/adapter.py +++ b/avilla/nonebridge/adapter.py @@ -4,7 +4,7 @@ from nonebot.adapters import Adapter as BaseAdapter -from avilla.core._runtime import cx_context +from avilla.core.globals import CONTEXT_CONTEXT_VAR from .bot import NoneBridgeBot @@ -30,7 +30,7 @@ def driver(self): async def _call_api(self, bot: NoneBridgeBot, api: str, **data: Any): staff = bot.service.staff - maybe_cx = cx_context.get(None) + maybe_cx = CONTEXT_CONTEXT_VAR.get(None) if maybe_cx is not None: staff = staff.ext(maybe_cx.get_staff_components()) else: diff --git a/avilla/nonebridge/capability.py b/avilla/nonebridge/capability.py index afc7cb4a..7dbf2d9a 100644 --- a/avilla/nonebridge/capability.py +++ b/avilla/nonebridge/capability.py @@ -2,11 +2,8 @@ from typing import TYPE_CHECKING -from graia.amnesia.message import Element, MessageChain -from avilla.core.event import AvillaEvent -from avilla.core.ryanvk.collector.application import ApplicationCollector -from graia.ryanvk import Fn, PredicateOverload, TypeOverload +from avilla.core.ryanvk_old.collector.application import ApplicationCollector if TYPE_CHECKING: pass diff --git a/avilla/nonebridge/perform/event/message.py b/avilla/nonebridge/perform/event/message.py index 7b2c4cad..50f56e47 100644 --- a/avilla/nonebridge/perform/event/message.py +++ b/avilla/nonebridge/perform/event/message.py @@ -1,6 +1,6 @@ from __future__ import annotations -from avilla.core.ryanvk.collector.context import ContextCollector +from avilla.core.ryanvk_old.collector.context import ContextCollector from avilla.standard.core.message import MessageReceived from ...descriptor.event import NoneEventTranslate diff --git a/avilla/nonebridge/service.py b/avilla/nonebridge/service.py index 6f7c1d08..e32b421f 100644 --- a/avilla/nonebridge/service.py +++ b/avilla/nonebridge/service.py @@ -13,10 +13,10 @@ from avilla.core.account import BaseAccount from avilla.core.event import AvillaEvent -from avilla.core.ryanvk.staff import Staff +from avilla.core.ryanvk_old.staff import Staff from avilla.core.utilles import identity from avilla.standard.core.account import AccountRegistered, AccountUnregistered -from graia.ryanvk import merge, ref +from graia.ryanvk import ref from graia.ryanvk.aio import queue_task from .adapter import NoneBridgeAdapter diff --git a/avilla/onebot/v11/bases.py b/avilla/onebot/v11/bases.py new file mode 100644 index 00000000..62a7759e --- /dev/null +++ b/avilla/onebot/v11/bases.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from flywheel import InstanceOf +from avilla.core.ryanvk.bases import InstanceOfAvilla +from avilla.onebot.v11.protocol import OneBot11Protocol +from avilla.onebot.v11.account import OneBot11Account + +class InstanceOfProtocol(InstanceOfAvilla): + protocol = InstanceOf(OneBot11Protocol) + +class InstanceOfAccount(InstanceOfProtocol): + account = InstanceOf(OneBot11Account) diff --git a/avilla/onebot/v11/capability.py b/avilla/onebot/v11/capability.py index 6656ce53..a6616d80 100644 --- a/avilla/onebot/v11/capability.py +++ b/avilla/onebot/v11/capability.py @@ -1,13 +1,18 @@ from __future__ import annotations -from typing import Any +from typing import Any, Protocol, TypeVar +from flywheel.globals import INSTANCE_CONTEXT_VAR from graia.amnesia.message import Element, MessageChain from avilla.core.event import AvillaEvent -from avilla.core.ryanvk.collector.application import ApplicationCollector +from flywheel import Fn, FnCompose, OverloadRecorder, FnRecord, SimpleOverload, TypeOverload +from avilla.core.application import Avilla from avilla.standard.core.application import AvillaLifecycleEvent -from graia.ryanvk import Fn, PredicateOverload, TypeOverload + +El = TypeVar("El", bound=Element) +El_co = TypeVar("El_co", bound=Element, covariant=True) +El_contra = TypeVar("El_contra", bound=Element, contravariant=True) SPECIAL_POST_TYPE = {"message_sent": "message"} @@ -20,38 +25,75 @@ def onebot11_event_type(raw: dict) -> str: ) -class OneBot11Capability((m := ApplicationCollector())._): - @Fn.complex({PredicateOverload(lambda _, raw: onebot11_event_type(raw)): ["raw_event"]}) - async def event_callback(self, raw_event: dict) -> AvillaEvent | AvillaLifecycleEvent | None: - ... +class OneBot11Capability: + @Fn.declare + class event_callback(FnCompose): + event_type = SimpleOverload("raw_event") + + async def call(self, record: FnRecord, raw_event: dict): + entities = self.load(self.event_type.dig(record, onebot11_event_type(raw_event))) + return await entities.first(raw_event=raw_event) + + class shapecall(Protocol): + async def __call__(self, raw_event: dict) -> AvillaEvent | AvillaLifecycleEvent: + ... + + def collect(self, recorder: OverloadRecorder[shapecall], event: str): + recorder.use(self.event_type, event) + + @Fn.declare + class serialize_element(FnCompose): + element = TypeOverload("element") + + async def call(self, record: FnRecord, element: Element): + entities = self.load(self.element.dig(record, element)) + return await entities.first(element=element) + + class shapecall(Protocol[El_contra]): + async def __call__(self, element: El_contra) -> dict: + ... + + def collect(self, recorder: OverloadRecorder[shapecall[El]], element_type: type[El]): + recorder.use(self.element, element_type) + + @Fn.declare + class deserialize_element(FnCompose): + element = SimpleOverload("element") + + async def call(self, record: FnRecord, element: dict): + entities = self.load(self.element.dig(record, element['type'])) + return await entities.first(element=element) - @Fn.complex({PredicateOverload(lambda _, raw: raw["type"]): ["raw_element"]}) - async def deserialize_element(self, raw_element: dict) -> Element: # type: ignore - ... + class shapecall(Protocol): + async def __call__(self, element: dict) -> Element: + ... - @Fn.complex({TypeOverload(): ["element"]}) - async def serialize_element(self, element: Any) -> dict: # type: ignore - ... + def collect(self, recorder: OverloadRecorder[shapecall], element_type: str): + recorder.use(self.element, element_type) - async def deserialize_chain(self, chain: list[dict]): + @staticmethod + async def deserialize_chain(chain: list[dict]): elements = [] for raw_element in chain: - elements.append(await self.deserialize_element(raw_element)) + elements.append(await OneBot11Capability.deserialize_element(raw_element)) return MessageChain(elements) - async def serialize_chain(self, chain: MessageChain): + @staticmethod + async def serialize_chain(chain: MessageChain): elements = [] for element in chain: - elements.append(await self.serialize_element(element)) + elements.append(await OneBot11Capability.serialize_element(element)) return elements - async def handle_event(self, event: dict): - maybe_event = await self.event_callback(event) + @staticmethod + async def handle_event(event: dict): + maybe_event = await OneBot11Capability.event_callback(event) + avilla = INSTANCE_CONTEXT_VAR.get().instances[Avilla] if maybe_event is not None: - self.avilla.event_record(maybe_event) - self.avilla.broadcast.postEvent(maybe_event) + avilla.event_record(maybe_event) + avilla.broadcast.postEvent(maybe_event) diff --git a/avilla/onebot/v11/collector/__init__.py b/avilla/onebot/v11/collector/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/avilla/onebot/v11/collector/connection.py b/avilla/onebot/v11/collector/connection.py deleted file mode 100644 index 5172baaa..00000000 --- a/avilla/onebot/v11/collector/connection.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, ClassVar, TypeVar - -from avilla.core.ryanvk.collector.base import AvillaBaseCollector -from graia.ryanvk import Access, BasePerform - -if TYPE_CHECKING: - from avilla.onebot.v11.net.ws_client import OneBot11WsClientNetworking - from avilla.onebot.v11.protocol import OneBot11Protocol - - -T = TypeVar("T") -T1 = TypeVar("T1") - - -class ConnectionBasedPerformTemplate(BasePerform, native=True): - __collector__: ClassVar[ConnectionCollector] - - protocol: Access[OneBot11Protocol] = Access() - connection: Access[OneBot11WsClientNetworking] = Access() - - -class ConnectionCollector(AvillaBaseCollector): - post_applying: bool = False - - @property - def _(self): - upper = super()._ - - class PerformTemplate( - ConnectionBasedPerformTemplate, - upper, - native=True, - ): - ... - - return PerformTemplate diff --git a/avilla/onebot/v11/net/base.py b/avilla/onebot/v11/net/base.py index c4172619..40a7e097 100644 --- a/avilla/onebot/v11/net/base.py +++ b/avilla/onebot/v11/net/base.py @@ -8,7 +8,6 @@ from typing_extensions import Self from avilla.core.exceptions import ActionFailed -from avilla.core.ryanvk.staff import Staff from avilla.onebot.v11.capability import OneBot11Capability if TYPE_CHECKING: @@ -29,28 +28,14 @@ def __init__(self, protocol: OneBot11Protocol): self.response_waiters = {} self.close_signal = asyncio.Event() - def get_staff_components(self): - return {"connection": self, "protocol": self.protocol, "avilla": self.protocol.avilla} - - def get_staff_artifacts(self): - return [self.protocol.artifacts, self.protocol.avilla.global_artifacts] - - @property - def staff(self): - return Staff(self.get_staff_artifacts(), self.get_staff_components()) - - def message_receive(self) -> AsyncIterator[tuple[Self, dict]]: - ... + def message_receive(self) -> AsyncIterator[tuple[Self, dict]]: ... @property - def alive(self) -> bool: - ... + def alive(self) -> bool: ... - async def wait_for_available(self): - ... + async def wait_for_available(self): ... - async def send(self, payload: dict) -> None: - ... + async def send(self, payload: dict) -> None: ... async def message_handle(self): async for connection, data in self.message_receive(): @@ -61,7 +46,7 @@ async def message_handle(self): async def event_parse_task(data: dict): with suppress(NotImplementedError): - await OneBot11Capability(connection.staff).handle_event(data) + await OneBot11Capability.handle_event(data) return logger.warning(f"received unsupported event: {data}") diff --git a/avilla/onebot/v11/perform/action/admin.py b/avilla/onebot/v11/perform/action/admin.py index 8f858c1d..94474043 100644 --- a/avilla/onebot/v11/perform/action/admin.py +++ b/avilla/onebot/v11/perform/action/admin.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.privilege import PrivilegeCapability @@ -12,9 +12,6 @@ class OneBot11PrivilegeActionPerform((m := AccountCollector["OneBot11Protocol", "OneBot11Account"]())._): - m.namespace = "avilla.protocol/onebot11::action" - m.identify = "admin" - @PrivilegeCapability.upgrade.collect(m, target="land.group.member") async def upgrade_perm(self, target: Selector, dest: str | None = None): result = await self.account.connection.call( diff --git a/avilla/onebot/v11/perform/action/ban.py b/avilla/onebot/v11/perform/action/ban.py index 33cab3e0..4c4c8f37 100644 --- a/avilla/onebot/v11/perform/action/ban.py +++ b/avilla/onebot/v11/perform/action/ban.py @@ -3,7 +3,7 @@ from datetime import timedelta from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.privilege import BanCapability diff --git a/avilla/onebot/v11/perform/action/leave.py b/avilla/onebot/v11/perform/action/leave.py index 5700928f..c0872998 100644 --- a/avilla/onebot/v11/perform/action/leave.py +++ b/avilla/onebot/v11/perform/action/leave.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.relation import SceneCapability diff --git a/avilla/onebot/v11/perform/action/message.py b/avilla/onebot/v11/perform/action/message.py index 0f9d4930..61aeb34f 100644 --- a/avilla/onebot/v11/perform/action/message.py +++ b/avilla/onebot/v11/perform/action/message.py @@ -7,7 +7,7 @@ from avilla.core import Context from avilla.core.message import Message -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.onebot.v11.capability import OneBot11Capability from avilla.standard.core.message import MessageRevoke, MessageSend diff --git a/avilla/onebot/v11/perform/action/mute.py b/avilla/onebot/v11/perform/action/mute.py index 65b5ee6c..099f9110 100644 --- a/avilla/onebot/v11/perform/action/mute.py +++ b/avilla/onebot/v11/perform/action/mute.py @@ -3,7 +3,7 @@ from datetime import timedelta from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.privilege import MuteAllCapability, MuteCapability diff --git a/avilla/onebot/v11/perform/context.py b/avilla/onebot/v11/perform/context.py index fdf3f0ad..7a777976 100644 --- a/avilla/onebot/v11/perform/context.py +++ b/avilla/onebot/v11/perform/context.py @@ -4,7 +4,8 @@ from avilla.core.builtins.capability import CoreCapability from avilla.core.context import Context -from avilla.core.ryanvk.collector.account import AccountCollector +from flywheel import scoped_collect +from avilla.onebot.v11.bases import InstanceOfAccount from avilla.core.selector import Selector if TYPE_CHECKING: @@ -12,10 +13,8 @@ from avilla.onebot.v11.protocol import OneBot11Protocol # noqa -class OneBot11ContextPerform((m := AccountCollector["OneBot11Protocol", "OneBot11Account"]())._): - m.namespace = "avilla.protocol/onebot11::context" - - @m.entity(CoreCapability.get_context, target="land.group") +class OneBot11ContextPerform(m := scoped_collect.env().target, InstanceOfAccount, static=True): + @m.impl(CoreCapability.get_context, target="land.group") def get_context_from_group(self, target: Selector, *, via: Selector | None = None): return Context( self.account, @@ -25,7 +24,7 @@ def get_context_from_group(self, target: Selector, *, via: Selector | None = Non target.member(self.account.route["account"]), ) - @m.entity(CoreCapability.get_context, target="land.friend") + @m.impl(CoreCapability.get_context, target="land.friend") def get_context_from_friend(self, target: Selector, *, via: Selector | None = None): if via: return Context( @@ -37,11 +36,11 @@ def get_context_from_friend(self, target: Selector, *, via: Selector | None = No ) return Context(self.account, target, self.account.route, target, self.account.route) - @m.entity(CoreCapability.get_context, target="land.stranger") + @m.impl(CoreCapability.get_context, target="land.stranger") def get_context_from_stranger(self, target: Selector, *, via: Selector | None = None): return Context(self.account, target, self.account.route, target, self.account.route) - @m.entity(CoreCapability.get_context, target="land.group.member") + @m.impl(CoreCapability.get_context, target="land.group.member") def get_context_from_member(self, target: Selector, *, via: Selector | None = None): return Context( self.account, @@ -51,26 +50,26 @@ def get_context_from_member(self, target: Selector, *, via: Selector | None = No target.into(f"~.member({self.account.route['account']})"), ) - @m.entity(CoreCapability.channel, target="land.group") - @m.entity(CoreCapability.channel, target="land.group.member") + @m.impl(CoreCapability.channel, target="land.group") + @m.impl(CoreCapability.channel, target="land.group.member") def channel_from_group(self, target: Selector): return target["group"] - @m.entity(CoreCapability.guild, target="land.group") - @m.entity(CoreCapability.guild, target="land.group.member") + @m.impl(CoreCapability.guild, target="land.group") + @m.impl(CoreCapability.guild, target="land.group.member") def guild_from_group(self, target: Selector): return target["group"] - @m.entity(CoreCapability.user, target="land.group.member") + @m.impl(CoreCapability.user, target="land.group.member") def user_from_member(self, target: Selector): return target["member"] - @m.entity(CoreCapability.user, target="land.friend") - @m.entity(CoreCapability.channel, target="land.friend") + @m.impl(CoreCapability.user, target="land.friend") + @m.impl(CoreCapability.channel, target="land.friend") def user_from_friend(self, target: Selector): return target["friend"] - @m.entity(CoreCapability.user, target="land.stranger") - @m.entity(CoreCapability.channel, target="land.stranger") + @m.impl(CoreCapability.user, target="land.stranger") + @m.impl(CoreCapability.channel, target="land.stranger") def user_from_stranger(self, target: Selector): return target["stranger"] diff --git a/avilla/onebot/v11/perform/event/lifespan.py b/avilla/onebot/v11/perform/event/lifespan.py index 9c595290..aa485542 100644 --- a/avilla/onebot/v11/perform/event/lifespan.py +++ b/avilla/onebot/v11/perform/event/lifespan.py @@ -9,7 +9,6 @@ from avilla.core.selector import Selector from avilla.onebot.v11.account import OneBot11Account from avilla.onebot.v11.capability import OneBot11Capability -from avilla.onebot.v11.collector.connection import ConnectionCollector from avilla.standard.core.account.event import ( AccountAvailable, AccountRegistered, diff --git a/avilla/onebot/v11/perform/message/deserialize.py b/avilla/onebot/v11/perform/message/deserialize.py index d4c7358a..c4c1abf8 100644 --- a/avilla/onebot/v11/perform/message/deserialize.py +++ b/avilla/onebot/v11/perform/message/deserialize.py @@ -1,13 +1,13 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING +from contextlib import suppress from avilla.core.elements import Face, Notice, NoticeAll, Picture, Reference, Text -from avilla.core.ryanvk.collector.application import ApplicationCollector from avilla.core.selector import Selector from avilla.onebot.v11.capability import OneBot11Capability from avilla.onebot.v11.resource import OneBot11ImageResource +from avilla.core.context import Context from avilla.standard.qq.elements import ( Dice, FlashImage, @@ -18,95 +18,109 @@ Share, Xml, ) -from graia.ryanvk import OptionalAccess - -if TYPE_CHECKING: - from avilla.core.context import Context - from avilla.onebot.v11.account import OneBot11Account - - -class OneBot11MessageDeserializePerform((m := ApplicationCollector())._): - m.namespace = "avilla.protocol/onebot11::message" - m.identify = "deserialize" - - context: OptionalAccess[Context] = OptionalAccess() - account: OptionalAccess[OneBot11Account] = OptionalAccess() - # LINK: https://github.com/microsoft/pyright/issues/5409 - - @m.entity(OneBot11Capability.deserialize_element, raw_element="text") - async def text(self, raw_element: dict) -> Text: - return Text(raw_element["data"]["text"]) - - @m.entity(OneBot11Capability.deserialize_element, raw_element="face") - async def face(self, raw_element: dict) -> Face: - return Face(raw_element["data"]["id"]) - - @m.entity(OneBot11Capability.deserialize_element, raw_element="image") - async def image(self, raw_element: dict) -> Picture | FlashImage: - data: dict = raw_element["data"] - resource = OneBot11ImageResource(Selector().land("qq").picture(file := data["file"]), file, data["url"]) - return FlashImage(resource) if raw_element.get("type") == "flash" else Picture(resource) - - @m.entity(OneBot11Capability.deserialize_element, raw_element="at") - async def at(self, raw_element: dict) -> Notice | NoticeAll: - if raw_element["data"]["qq"] == "all": - return NoticeAll() - if self.context: - return Notice(self.context.scene.member(raw_element["data"]["qq"])) - return Notice(Selector().land("qq").member(raw_element["data"]["qq"])) - - @m.entity(OneBot11Capability.deserialize_element, raw_element="reply") - async def reply(self, raw_element: dict): - if self.context: - return Reference(self.context.scene.message(raw_element["data"]["id"])) - return Reference(Selector().land("qq").message(raw_element["data"]["id"])) - - @m.entity(OneBot11Capability.deserialize_element, raw_element="dice") - async def dice(self, raw_element: dict): - return Dice() - - @m.entity(OneBot11Capability.deserialize_element, raw_element="shake") - async def shake(self, raw_element: dict): - return Poke() - - @m.entity(OneBot11Capability.deserialize_element, raw_element="json") - async def json(self, raw_element: dict): - return Json(raw_element["data"]["content"]) - - @m.entity(OneBot11Capability.deserialize_element, raw_element="xml") - async def xml(self, raw_element: dict): - return Xml(raw_element["data"]["content"]) - - @m.entity(OneBot11Capability.deserialize_element, raw_element="share") - async def share(self, raw_element: dict): - return Share( - raw_element["data"]["url"], - raw_element["data"]["title"], - raw_element["data"].get("content", None), - raw_element["data"].get("image", None), - ) +from flywheel import global_collect +from flywheel.globals import INSTANCE_CONTEXT_VAR - @m.entity(OneBot11Capability.deserialize_element, raw_element="forward") - async def forward(self, raw_element: dict): - elem = Forward(raw_element["data"]["id"]) - if not self.account: - return elem - result = await self.account.connection.call( - "get_forward_msg", - { - "message_id": raw_element["data"]["id"], - }, - ) - if result is None: - return elem - for msg in result["messages"]: - node = Node( - name=msg["sender"]["nickname"], - uid=str(msg["sender"]["user_id"]), - time=datetime.fromtimestamp(msg["time"]), - content=await OneBot11Capability(self.account.staff).deserialize_chain(msg["content"]), - ) - elem.nodes.append(node) +from avilla.onebot.v11.account import OneBot11Account + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="text") +async def text(element: dict) -> Text: + return Text(element["data"]["text"]) + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="face") +async def face(element: dict) -> Face: + return Face(element["data"]["id"]) + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="image") +async def image(element: dict) -> Picture | FlashImage: + data: dict = element["data"] + resource = OneBot11ImageResource(Selector().land("qq").picture(file := data["file"]), file, data["url"]) + return FlashImage(resource) if element.get("type") == "flash" else Picture(resource) + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="at") +async def at(element: dict) -> Notice | NoticeAll: + if element["data"]["qq"] == "all": + return NoticeAll() + with suppress(LookupError): + return Notice(Context.current.scene.member(element["data"]["qq"])) + return Notice(Selector().land("qq").member(element["data"]["qq"])) + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="reply") +async def reply(element: dict): + with suppress(LookupError): + return Reference(Context.current.scene.message(element["data"]["id"])) + return Reference(Selector().land("qq").message(element["data"]["id"])) + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="dice") +async def dice(element: dict): + return Dice() + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="shake") +async def shake(element: dict): + return Poke() + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="json") +async def json(element: dict): + return Json(element["data"]["content"]) + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="xml") +async def xml(element: dict): + return Xml(element["data"]["content"]) + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="share") +async def share(element: dict): + return Share( + element["data"]["url"], + element["data"]["title"], + element["data"].get("content", None), + element["data"].get("image", None), + ) + + +@global_collect +@OneBot11Capability.deserialize_element.impl(element_type="forward") +async def forward(element: dict): + elem = Forward(element["data"]["id"]) + + if OneBot11Account not in INSTANCE_CONTEXT_VAR.get().instances: return elem - # TODO + result = await INSTANCE_CONTEXT_VAR.get().instances[OneBot11Account].connection.call( + "get_forward_msg", + { + "message_id": element["data"]["id"], + }, + ) + if result is None: + return elem + for msg in result["messages"]: + node = Node( + name=msg["sender"]["nickname"], + uid=str(msg["sender"]["user_id"]), + time=datetime.fromtimestamp(msg["time"]), + content=await OneBot11Capability.deserialize_chain(msg["content"]), + ) + elem.nodes.append(node) + return elem + + +# TODO diff --git a/avilla/onebot/v11/perform/message/serialize.py b/avilla/onebot/v11/perform/message/serialize.py index 9c4a4013..bf5133db 100644 --- a/avilla/onebot/v11/perform/message/serialize.py +++ b/avilla/onebot/v11/perform/message/serialize.py @@ -1,11 +1,13 @@ from __future__ import annotations import base64 -from typing import TYPE_CHECKING, cast +from typing import cast +from flywheel import global_collect + +from avilla.core.builtins.capability import CoreCapability from avilla.core.elements import Face, Notice, NoticeAll, Picture, Text from avilla.core.resource import LocalFileResource, RawResource, UrlResource -from avilla.core.ryanvk.collector.account import AccountCollector from avilla.onebot.v11.capability import OneBot11Capability from avilla.onebot.v11.resource import OneBot11ImageResource from avilla.standard.qq.elements import ( @@ -20,137 +22,153 @@ Xml, ) -if TYPE_CHECKING: - from avilla.onebot.v11.account import OneBot11Account # noqa - from avilla.onebot.v11.protocol import OneBot11Protocol # noqa - - -class OneBot11MessageSerializePerform((m := AccountCollector["OneBot11Protocol", "OneBot11Account"]())._): - m.namespace = "avilla.protocol/onebot11::message" - m.identify = "serialize" - - # LINK: https://github.com/microsoft/pyright/issues/5409 - - @m.entity(OneBot11Capability.serialize_element, element=Text) - async def text(self, element: Text) -> dict: - return {"type": "text", "data": {"text": element.text}} - - @m.entity(OneBot11Capability.serialize_element, element=Face) - async def face(self, element: Face) -> dict: - return {"type": "face", "data": {"id": int(element.id)}} - - @m.entity(OneBot11Capability.serialize_element, element=Picture) - async def picture(self, element: Picture) -> dict: - if isinstance(element.resource, OneBot11ImageResource): - return { - "type": "image", - "data": { - "file": element.resource.file, - "url": element.resource.url, - }, - } - elif isinstance(element.resource, UrlResource): - return { - "type": "image", - "data": { - "url": element.resource.url, - }, - } - elif isinstance(element.resource, LocalFileResource): - data = base64.b64encode(element.resource.file.read_bytes()).decode("utf-8") - return { - "type": "image", - "data": { - "file": "base64://" + data, - }, - } - elif isinstance(element.resource, RawResource): - data = base64.b64encode(element.resource.data).decode("utf-8") - return { - "type": "image", - "data": { - "file": "base64://" + data, - }, - } - else: - return { - "type": "image", - "data": { - "file": "base64://" - + base64.b64encode(cast(bytes, await self.account.staff.fetch_resource(element.resource))).decode( - "utf-8" - ), - }, - } - - @m.entity(OneBot11Capability.serialize_element, element=FlashImage) - async def flash_image(self, element: FlashImage): - raw = await self.picture(element) - raw["data"]["type"] = "flash" - return raw - - @m.entity(OneBot11Capability.serialize_element, element=Notice) - async def notice(self, element: Notice): - return {"type": "at", "data": {"qq": element.target["member"]}} - - @m.entity(OneBot11Capability.serialize_element, element=NoticeAll) - async def notice_all(self, element: NoticeAll): - return {"type": "at", "data": {"qq": "all"}} - - @m.entity(OneBot11Capability.serialize_element, element=Dice) - async def dice(self, element: Dice): - return {"type": "dice", "data": {}} - - @m.entity(OneBot11Capability.serialize_element, element=MusicShare) - async def music_share(self, element: MusicShare): - raw = { - "type": "music", + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Text) +async def text(element: Text) -> dict: + return {"type": "text", "data": {"text": element.text}} + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Face) +async def face(element: Face) -> dict: + return {"type": "face", "data": {"id": int(element.id)}} + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Picture) +async def picture(element: Picture) -> dict: + if isinstance(element.resource, OneBot11ImageResource): + return { + "type": "image", "data": { - "type": "custom", - "url": element.url, - "audio": element.audio, - "title": element.title, + "file": element.resource.file, + "url": element.resource.url, }, } - if element.content: - raw["data"]["content"] = element.content - if element.thumbnail: - raw["data"]["image"] = element.thumbnail - return raw - - @m.entity(OneBot11Capability.serialize_element, element=Gift) - async def gift(self, element: Gift): - return {"type": "gift", "data": {"id": element.kind.value, "qq": element.target["member"]}} - - @m.entity(OneBot11Capability.serialize_element, element=Json) - async def json(self, element: Json): - return {"type": "json", "data": {"data": element.content}} - - @m.entity(OneBot11Capability.serialize_element, element=Xml) - async def xml(self, element: Xml): - return {"type": "xml", "data": {"data": element.content}} - - @m.entity(OneBot11Capability.serialize_element, element=App) - async def app(self, element: App): - return {"type": "json", "data": {"data": element.content}} - - @m.entity(OneBot11Capability.serialize_element, element=Share) - async def share(self, element: Share): - res = { - "type": "share", + elif isinstance(element.resource, UrlResource): + return { + "type": "image", "data": { - "url": element.url, - "title": element.title, + "url": element.resource.url, }, } - if element.content: - res["data"]["content"] = element.content - if element.thumbnail: - res["data"]["image"] = element.thumbnail - return res + elif isinstance(element.resource, LocalFileResource): + data = base64.b64encode(element.resource.file.read_bytes()).decode("utf-8") + return { + "type": "image", + "data": { + "file": "base64://" + data, + }, + } + elif isinstance(element.resource, RawResource): + data = base64.b64encode(element.resource.data).decode("utf-8") + return { + "type": "image", + "data": { + "file": "base64://" + data, + }, + } + else: + return { + "type": "image", + "data": { + "file": "base64://" + + base64.b64encode(cast(bytes, await CoreCapability.fetch(element.resource))).decode("utf-8"), + }, + } + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=FlashImage) +async def flash_image(element: FlashImage): + raw = await picture(element) + raw["data"]["type"] = "flash" + return raw + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Notice) +async def notice(element: Notice): + return {"type": "at", "data": {"qq": element.target["member"]}} + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=NoticeAll) +async def notice_all(element: NoticeAll): + return {"type": "at", "data": {"qq": "all"}} + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Dice) +async def dice(element: Dice): + return {"type": "dice", "data": {}} + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=MusicShare) +async def music_share(element: MusicShare): + raw = { + "type": "music", + "data": { + "type": "custom", + "url": element.url, + "audio": element.audio, + "title": element.title, + }, + } + if element.content: + raw["data"]["content"] = element.content + if element.thumbnail: + raw["data"]["image"] = element.thumbnail + return raw + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Gift) +async def gift(element: Gift): + return {"type": "gift", "data": {"id": element.kind.value, "qq": element.target["member"]}} + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Json) +async def json(element: Json): + return {"type": "json", "data": {"data": element.content}} + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Xml) +async def xml(element: Xml): + return {"type": "xml", "data": {"data": element.content}} + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=App) +async def app(element: App): + return {"type": "json", "data": {"data": element.content}} + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Share) +async def share(element: Share): + res = { + "type": "share", + "data": { + "url": element.url, + "title": element.title, + }, + } + if element.content: + res["data"]["content"] = element.content + if element.thumbnail: + res["data"]["image"] = element.thumbnail + return res + + +@global_collect +@OneBot11Capability.serialize_element.impl(element_type=Poke) +async def poke(element: Poke): + return {"type": "shake", "data": {}} - @m.entity(OneBot11Capability.serialize_element, element=Poke) - async def poke(self, element: Poke): - return {"type": "shake", "data": {}} - # TODO +# TODO diff --git a/avilla/onebot/v11/perform/query/group.py b/avilla/onebot/v11/perform/query/group.py index dd43f1fb..68a34cdc 100644 --- a/avilla/onebot/v11/perform/query/group.py +++ b/avilla/onebot/v11/perform/query/group.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Callable, cast from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/onebot/v11/perform/resource_fetch.py b/avilla/onebot/v11/perform/resource_fetch.py index 665739e5..c1b48ee3 100644 --- a/avilla/onebot/v11/perform/resource_fetch.py +++ b/avilla/onebot/v11/perform/resource_fetch.py @@ -1,11 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from aiohttp import ClientSession from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.protocol import ProtocolCollector from avilla.onebot.v11.resource import ( OneBot11FileResource, OneBot11ImageResource, @@ -13,20 +10,15 @@ OneBot11Resource, OneBot11VideoResource, ) +from flywheel import global_collect -if TYPE_CHECKING: - from avilla.onebot.v11.protocol import OneBot11Protocol # noqa - - -class OneBot11ResourceFetchPerform((m := ProtocolCollector["OneBot11Protocol"]())._): - m.namespace = "avilla.protocol/onebot11::resource_fetch" - - @m.entity(CoreCapability.fetch, resource=OneBot11Resource) - @m.entity(CoreCapability.fetch, resource=OneBot11RecordResource) - @m.entity(CoreCapability.fetch, resource=OneBot11FileResource) - @m.entity(CoreCapability.fetch, resource=OneBot11ImageResource) - @m.entity(CoreCapability.fetch, resource=OneBot11VideoResource) - async def fetch_resource(self, resource: OneBot11Resource) -> bytes: - async with ClientSession() as session: - async with session.get(resource.url) as resp: - return await resp.read() +@global_collect +@CoreCapability.fetch.impl(resource=OneBot11RecordResource) # type: ignore +@CoreCapability.fetch.impl(resource=OneBot11FileResource) # type: ignore +@CoreCapability.fetch.impl(resource=OneBot11ImageResource) # type: ignore +@CoreCapability.fetch.impl(resource=OneBot11VideoResource) +@CoreCapability.fetch.impl(resource=OneBot11Resource) +async def fetch_resource(resource: OneBot11Resource) -> bytes: + async with ClientSession() as session: + async with session.get(resource.url) as resp: + return await resp.read() diff --git a/avilla/onebot/v11/protocol.py b/avilla/onebot/v11/protocol.py index aa07efaf..d9709227 100644 --- a/avilla/onebot/v11/protocol.py +++ b/avilla/onebot/v11/protocol.py @@ -6,7 +6,8 @@ from avilla.core.application import Avilla from avilla.core.protocol import BaseProtocol -from graia.ryanvk import merge, ref +from flywheel import CollectContext +from avilla.core.utilles import cachedstatic from .net.ws_client import OneBot11WsClientNetworking from .net.ws_server import OneBot11WsServerNetworking @@ -25,46 +26,27 @@ class OneBot11ReverseConfig: access_token: str | None = None -def _import_performs(): - from avilla.onebot.v11.perform import context, resource_fetch # noqa: F401 - from avilla.onebot.v11.perform.action import admin # noqa: F401 - from avilla.onebot.v11.perform.action import ban # noqa: F401 - from avilla.onebot.v11.perform.action import leave # noqa: F401 - from avilla.onebot.v11.perform.action import message # noqa: F401 - from avilla.onebot.v11.perform.action import mute # noqa: F401 - from avilla.onebot.v11.perform.event import lifespan # noqa: F401 - from avilla.onebot.v11.perform.event import message # noqa: F401, F811 - from avilla.onebot.v11.perform.event import notice # noqa: F401 - from avilla.onebot.v11.perform.event import request # noqa: F401 - from avilla.onebot.v11.perform.message import deserialize # noqa: F401 - from avilla.onebot.v11.perform.message import serialize # noqa: F401 - from avilla.onebot.v11.perform.query import group # noqa: F401 - - -_import_performs() - - class OneBot11Protocol(BaseProtocol): service: OneBot11Service - artifacts = { - **merge( - ref("avilla.protocol/onebot11::context"), - ref("avilla.protocol/onebot11::resource_fetch"), - ref("avilla.protocol/onebot11::action", "admin"), - ref("avilla.protocol/onebot11::action", "ban"), - ref("avilla.protocol/onebot11::action", "leave"), - ref("avilla.protocol/onebot11::action", "message"), - ref("avilla.protocol/onebot11::action", "mute"), - ref("avilla.protocol/onebot11::event", "message"), - ref("avilla.protocol/onebot11::event", "lifespan"), - ref("avilla.protocol/onebot11::event", "notice"), - ref("avilla.protocol/onebot11::event", "request"), - ref("avilla.protocol/onebot11::message", "deserialize"), - ref("avilla.protocol/onebot11::message", "serialize"), - ref("avilla.protocol/onebot11::query", "group"), - ), - } + @cachedstatic + def artifacts(): + with CollectContext().lookup_scope() as collect_context: + from avilla.onebot.v11.perform import context, resource_fetch # noqa: F401 + from avilla.onebot.v11.perform.action import admin # noqa: F401 + from avilla.onebot.v11.perform.action import ban # noqa: F401 + from avilla.onebot.v11.perform.action import leave # noqa: F401 + from avilla.onebot.v11.perform.action import message # noqa: F401 + from avilla.onebot.v11.perform.action import mute # noqa: F401 + from avilla.onebot.v11.perform.event import lifespan # noqa: F401 + from avilla.onebot.v11.perform.event import message # noqa: F401, F811 + from avilla.onebot.v11.perform.event import notice # noqa: F401 + from avilla.onebot.v11.perform.event import request # noqa: F401 + from avilla.onebot.v11.perform.message import deserialize # noqa: F401 + from avilla.onebot.v11.perform.message import serialize # noqa: F401 + from avilla.onebot.v11.perform.query import group # noqa: F401 + + return collect_context def __init__(self): self.service = OneBot11Service(self) diff --git a/avilla/qqapi/capability.py b/avilla/qqapi/capability.py index 9628f458..1f8bec3f 100644 --- a/avilla/qqapi/capability.py +++ b/avilla/qqapi/capability.py @@ -5,8 +5,8 @@ from graia.amnesia.message import Element, MessageChain from avilla.core.event import AvillaEvent -from avilla.core.ryanvk.collector.application import ApplicationCollector -from avilla.core.ryanvk.overload.target import TargetOverload +from avilla.core.ryanvk_old.collector.application import ApplicationCollector +from avilla.core.ryanvk_old.overload.target import TargetOverload from avilla.core.selector import Selector from avilla.standard.core.application.event import AvillaLifecycleEvent from graia.ryanvk import Fn, PredicateOverload, SimpleOverload, TypeOverload @@ -16,16 +16,13 @@ class QQAPICapability((m := ApplicationCollector())._): @Fn.complex({SimpleOverload(): ["event_type"]}) - async def event_callback(self, event_type: str, raw_event: dict) -> AvillaEvent | AvillaLifecycleEvent | None: - ... + async def event_callback(self, event_type: str, raw_event: dict) -> AvillaEvent | AvillaLifecycleEvent | None: ... @Fn.complex({PredicateOverload(lambda _, raw: raw["type"]): ["raw_element"]}) - async def deserialize_element(self, raw_element: dict) -> Element: - ... + async def deserialize_element(self, raw_element: dict) -> Element: ... @Fn.complex({TypeOverload(): ["element"]}) - async def serialize_element(self, element: Any) -> str | tuple[str, Any]: - ... + async def serialize_element(self, element: Any) -> str | tuple[str, Any]: ... @Fn.complex({TargetOverload(): ["target"]}) async def create_dms(self, target: Selector) -> Selector: diff --git a/avilla/qqapi/collector/connection.py b/avilla/qqapi/collector/connection.py index b72681ad..b5faecde 100644 --- a/avilla/qqapi/collector/connection.py +++ b/avilla/qqapi/collector/connection.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, ClassVar, TypeVar -from avilla.core.ryanvk.collector.base import AvillaBaseCollector +from avilla.core.ryanvk_old.collector.base import AvillaBaseCollector from graia.ryanvk import Access, BasePerform if TYPE_CHECKING: @@ -32,7 +32,6 @@ class PerformTemplate( ConnectionBasedPerformTemplate, upper, native=True, - ): - ... + ): ... return PerformTemplate diff --git a/avilla/qqapi/connection/base.py b/avilla/qqapi/connection/base.py index 48f30c51..8b718af8 100644 --- a/avilla/qqapi/connection/base.py +++ b/avilla/qqapi/connection/base.py @@ -8,7 +8,7 @@ from loguru import logger from typing_extensions import Self -from avilla.core.ryanvk.staff import Staff +from avilla.core.ryanvk_old.staff import Staff from avilla.qqapi.audit import MessageAudited, audit_result from avilla.qqapi.capability import QQAPICapability @@ -52,18 +52,14 @@ def get_staff_artifacts(self): def staff(self): return Staff(self.get_staff_artifacts(), self.get_staff_components()) - def message_receive(self, shard: tuple[int, int]) -> AsyncIterator[tuple[Self, dict]]: - ... + def message_receive(self, shard: tuple[int, int]) -> AsyncIterator[tuple[Self, dict]]: ... @property - def alive(self) -> bool: - ... + def alive(self) -> bool: ... - async def wait_for_available(self): - ... + async def wait_for_available(self): ... - async def send(self, payload: dict, shard: tuple[int, int]) -> None: - ... + async def send(self, payload: dict, shard: tuple[int, int]) -> None: ... async def message_handle(self, shard: tuple[int, int]): async for connection, data in self.message_receive(shard): @@ -93,5 +89,4 @@ async def connection_closed(self): self.session_id = None self.close_signal.set() - async def call_http(self, method: CallMethod, action: str, params: dict | None = None) -> dict: - ... + async def call_http(self, method: CallMethod, action: str, params: dict | None = None) -> dict: ... diff --git a/avilla/qqapi/connection/ws_client.py b/avilla/qqapi/connection/ws_client.py index feb50a23..3d3828b8 100644 --- a/avilla/qqapi/connection/ws_client.py +++ b/avilla/qqapi/connection/ws_client.py @@ -210,8 +210,7 @@ def get_staff_components(self): def get_staff_artifacts(self): return [self.protocol.artifacts, self.protocol.avilla.global_artifacts] - def __staff_generic__(self, element_type: dict, event_type: dict): - ... + def __staff_generic__(self, element_type: dict, event_type: dict): ... @property def alive(self): diff --git a/avilla/qqapi/perform/action/channel.py b/avilla/qqapi/perform/action/channel.py index be946128..fb10a2f5 100644 --- a/avilla/qqapi/perform/action/channel.py +++ b/avilla/qqapi/perform/action/channel.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from avilla.core.exceptions import permission_error_message -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.qqapi.const import PRIVILEGE_TRANS from avilla.standard.core.privilege import MuteAllCapability, Privilege diff --git a/avilla/qqapi/perform/action/guild.py b/avilla/qqapi/perform/action/guild.py index 5da5568c..fd47bbc3 100644 --- a/avilla/qqapi/perform/action/guild.py +++ b/avilla/qqapi/perform/action/guild.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.common import Count from avilla.standard.core.profile import Nick, Summary diff --git a/avilla/qqapi/perform/action/guild_member.py b/avilla/qqapi/perform/action/guild_member.py index 77826acb..2f6ffbc7 100644 --- a/avilla/qqapi/perform/action/guild_member.py +++ b/avilla/qqapi/perform/action/guild_member.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from avilla.core.exceptions import permission_error_message -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.qqapi.const import PRIVILEGE_TRANS from avilla.standard.core.privilege import MuteCapability, MuteInfo, Privilege diff --git a/avilla/qqapi/perform/action/message.py b/avilla/qqapi/perform/action/message.py index b52fa4af..42a5f754 100644 --- a/avilla/qqapi/perform/action/message.py +++ b/avilla/qqapi/perform/action/message.py @@ -9,7 +9,7 @@ from avilla.core import Context, CoreCapability, Message from avilla.core.exceptions import ActionFailed -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.qqapi.capability import QQAPICapability from avilla.qqapi.exception import AuditException diff --git a/avilla/qqapi/perform/action/role.py b/avilla/qqapi/perform/action/role.py index 8f19f16e..62825701 100644 --- a/avilla/qqapi/perform/action/role.py +++ b/avilla/qqapi/perform/action/role.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from avilla.core.exceptions import permission_error_message -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.qqapi.const import PRIVILEGE_TRANS from avilla.qqapi.role import ( diff --git a/avilla/qqapi/perform/action/user.py b/avilla/qqapi/perform/action/user.py index c60a6b9c..9fbe8735 100644 --- a/avilla/qqapi/perform/action/user.py +++ b/avilla/qqapi/perform/action/user.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.profile import Nick, Summary diff --git a/avilla/qqapi/perform/context.py b/avilla/qqapi/perform/context.py index 24923b9d..2b5b3d83 100644 --- a/avilla/qqapi/perform/context.py +++ b/avilla/qqapi/perform/context.py @@ -4,7 +4,7 @@ from avilla.core.builtins.capability import CoreCapability from avilla.core.context import Context -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/qqapi/perform/message/deserialize.py b/avilla/qqapi/perform/message/deserialize.py index eb31425c..14eb9302 100644 --- a/avilla/qqapi/perform/message/deserialize.py +++ b/avilla/qqapi/perform/message/deserialize.py @@ -12,7 +12,7 @@ Text, Video, ) -from avilla.core.ryanvk.collector.application import ApplicationCollector +from avilla.core.ryanvk_old.collector.application import ApplicationCollector from avilla.core.selector import Selector from avilla.qqapi.capability import QQAPICapability from avilla.qqapi.element import Ark, ArkKv, Embed, Reference diff --git a/avilla/qqapi/perform/message/serialize.py b/avilla/qqapi/perform/message/serialize.py index effbebd5..af613410 100644 --- a/avilla/qqapi/perform/message/serialize.py +++ b/avilla/qqapi/perform/message/serialize.py @@ -5,7 +5,7 @@ from avilla.core.elements import Audio, Face, Notice, NoticeAll, Picture, Text, Video from avilla.core.resource import LocalFileResource, RawResource, UrlResource -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.qqapi.capability import QQAPICapability from avilla.qqapi.element import Ark, Embed, Keyboard, Markdown, Reference from avilla.qqapi.resource import ( diff --git a/avilla/qqapi/perform/query.py b/avilla/qqapi/perform/query.py index f5e439c8..faae58a4 100644 --- a/avilla/qqapi/perform/query.py +++ b/avilla/qqapi/perform/query.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Callable, cast from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/qqapi/perform/resource_fetch.py b/avilla/qqapi/perform/resource_fetch.py index 7476c95c..643801b0 100644 --- a/avilla/qqapi/perform/resource_fetch.py +++ b/avilla/qqapi/perform/resource_fetch.py @@ -5,7 +5,7 @@ from aiohttp import ClientSession from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.protocol import ProtocolCollector +from avilla.core.ryanvk_old.collector.protocol import ProtocolCollector from avilla.qqapi.resource import ( QQAPIAudioResource, QQAPIFileResource, diff --git a/avilla/qqapi/role/capability.py b/avilla/qqapi/role/capability.py index f47db061..6c23f07c 100644 --- a/avilla/qqapi/role/capability.py +++ b/avilla/qqapi/role/capability.py @@ -1,6 +1,6 @@ from __future__ import annotations -from avilla.core.ryanvk import Fn, TargetOverload +from avilla.core.ryanvk_old import Fn, TargetOverload from avilla.core.selector import Selector from graia.ryanvk.capability import Capability @@ -9,29 +9,24 @@ class RoleCreate(Capability): @Fn.complex({TargetOverload(): ["target"]}) async def create( self, target: Selector, name: str, hoist: bool | None = None, color: int | None = None - ) -> Selector: - ... + ) -> Selector: ... class RoleDelete(Capability): @Fn.complex({TargetOverload(): ["target"]}) - async def delete(self, target: Selector) -> None: - ... + async def delete(self, target: Selector) -> None: ... class RoleEdit(Capability): @Fn.complex({TargetOverload(): ["target"]}) async def edit( self, target: Selector, name: str | None = None, hoist: bool | None = None, color: int | None = None - ) -> None: - ... + ) -> None: ... class RoleMemberCapability(Capability): @Fn.complex({TargetOverload(): ["target"]}) - async def add(self, target: Selector, member: Selector) -> None: - ... + async def add(self, target: Selector, member: Selector) -> None: ... @Fn.complex({TargetOverload(): ["target"]}) - async def remove(self, target: Selector, member: Selector) -> None: - ... + async def remove(self, target: Selector, member: Selector) -> None: ... diff --git a/avilla/red/capability.py b/avilla/red/capability.py index 22890cf5..8a38c2f9 100644 --- a/avilla/red/capability.py +++ b/avilla/red/capability.py @@ -5,8 +5,8 @@ from graia.amnesia.message import Element, MessageChain from avilla.core.event import AvillaEvent -from avilla.core.ryanvk.collector.application import ApplicationCollector -from avilla.core.ryanvk.overload.target import TargetOverload +from avilla.core.ryanvk_old.collector.application import ApplicationCollector +from avilla.core.ryanvk_old.overload.target import TargetOverload from avilla.core.selector import Selector from avilla.standard.core.application.event import AvillaLifecycleEvent from avilla.standard.qq.elements import Forward @@ -15,8 +15,7 @@ class RedCapability((m := ApplicationCollector())._): @Fn.complex({SimpleOverload(): ["event_type"]}) - async def event_callback(self, event_type: str, raw_event: dict) -> AvillaEvent | AvillaLifecycleEvent | None: - ... + async def event_callback(self, event_type: str, raw_event: dict) -> AvillaEvent | AvillaLifecycleEvent | None: ... @Fn.complex({PredicateOverload(lambda _, raw: raw["type"]): ["element"]}) async def deserialize_element(self, element: dict) -> Element: # type: ignore @@ -31,8 +30,7 @@ async def forward_export(self, element: Any) -> dict: # type: ignore ... @Fn.complex({TargetOverload(): ["target"]}) - async def send_forward(self, target: Selector, forward: Forward) -> Selector: - ... + async def send_forward(self, target: Selector, forward: Forward) -> Selector: ... async def deserialize(self, elements: list[dict]): _elements = [] diff --git a/avilla/red/collector/connection.py b/avilla/red/collector/connection.py index 8ae08f0e..a4fa1135 100644 --- a/avilla/red/collector/connection.py +++ b/avilla/red/collector/connection.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, ClassVar, TypeVar -from avilla.core.ryanvk.collector.base import AvillaBaseCollector +from avilla.core.ryanvk_old.collector.base import AvillaBaseCollector from graia.ryanvk import Access, BasePerform if TYPE_CHECKING: @@ -32,7 +32,6 @@ class PerformTemplate( ConnectionBasedPerformTemplate, upper, native=True, - ): - ... + ): ... return PerformTemplate diff --git a/avilla/red/exception.py b/avilla/red/exception.py index b5af6eb3..62af8136 100644 --- a/avilla/red/exception.py +++ b/avilla/red/exception.py @@ -1,6 +1,5 @@ """Ariadne 的异常定义""" - from avilla.core.exceptions import ( InvalidAuthentication, InvalidOperation, diff --git a/avilla/red/net/base.py b/avilla/red/net/base.py index 98b37a5a..5fe57f43 100644 --- a/avilla/red/net/base.py +++ b/avilla/red/net/base.py @@ -7,7 +7,7 @@ from loguru import logger from typing_extensions import Self -from avilla.core.ryanvk.staff import Staff +from avilla.core.ryanvk_old.staff import Staff from avilla.red.account import RedAccount from avilla.red.capability import RedCapability from avilla.red.utils import MsgType, get_msg_types @@ -37,18 +37,14 @@ def get_staff_artifacts(self): def staff(self): return Staff(self.get_staff_artifacts(), self.get_staff_components()) - def message_receive(self) -> AsyncIterator[tuple[Self, dict]]: - ... + def message_receive(self) -> AsyncIterator[tuple[Self, dict]]: ... @property - def alive(self) -> bool: - ... + def alive(self) -> bool: ... - async def wait_for_available(self): - ... + async def wait_for_available(self): ... - async def send(self, payload: dict) -> None: - ... + async def send(self, payload: dict) -> None: ... async def message_handle(self): async for connection, data in self.message_receive(): @@ -121,8 +117,7 @@ async def call(self, action: str, params: dict | None = None) -> None: @overload async def call_http( self, method: Literal["get", "post", "multipart"], action: str, params: dict | None = None - ) -> dict: - ... + ) -> dict: ... @overload async def call_http( @@ -131,10 +126,8 @@ async def call_http( action: str, params: dict | None = None, raw: Literal[True] = True, - ) -> bytes: - ... + ) -> bytes: ... async def call_http( self, method: Literal["get", "post", "multipart"], action: str, params: dict | None = None, raw: bool = False - ) -> dict | bytes: - ... + ) -> dict | bytes: ... diff --git a/avilla/red/net/ws_client.py b/avilla/red/net/ws_client.py index c4e62deb..5127c819 100644 --- a/avilla/red/net/ws_client.py +++ b/avilla/red/net/ws_client.py @@ -95,8 +95,7 @@ async def wait_for_available(self): def get_staff_components(self): return {"connection": self, "protocol": self.protocol, "avilla": self.protocol.avilla} - def __staff_generic__(self, element_type: dict, event_type: dict): - ... + def __staff_generic__(self, element_type: dict, event_type: dict): ... def get_staff_artifacts(self): return [self.protocol.artifacts, self.protocol.avilla.global_artifacts] @@ -135,7 +134,11 @@ async def connection_daemon(self, manager: Launart, session: aiohttp.ClientSessi self.close_signal.set() self.connection = None for v in list(avilla.accounts.values()): - if v.protocol is self.protocol and self.account and v.route["account"] == self.account.route["account"]: # type: ignore + if ( + v.protocol is self.protocol + and self.account + and v.route["account"] == self.account.route["account"] + ): # type: ignore self.account = None del avilla.accounts[v.route] await avilla.broadcast.postEvent(AccountUnregistered(avilla, v.account)) diff --git a/avilla/red/perform/action/friend.py b/avilla/red/perform/action/friend.py index d1012b29..9047e145 100644 --- a/avilla/red/perform/action/friend.py +++ b/avilla/red/perform/action/friend.py @@ -5,7 +5,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService from avilla.core.exceptions import UnknownTarget -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.profile import Nick, Summary diff --git a/avilla/red/perform/action/group.py b/avilla/red/perform/action/group.py index 671b99d5..f3e57e9b 100644 --- a/avilla/red/perform/action/group.py +++ b/avilla/red/perform/action/group.py @@ -5,7 +5,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService from avilla.core.exceptions import UnknownTarget -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.common import Count from avilla.standard.core.privilege import MuteAllCapability diff --git a/avilla/red/perform/action/member.py b/avilla/red/perform/action/member.py index 3aeb4eaa..c4c86305 100644 --- a/avilla/red/perform/action/member.py +++ b/avilla/red/perform/action/member.py @@ -6,7 +6,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService from avilla.core.exceptions import UnknownTarget -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.standard.core.privilege import MuteCapability, MuteInfo from avilla.standard.core.profile import Nick, Summary diff --git a/avilla/red/perform/action/message.py b/avilla/red/perform/action/message.py index 27f94f9e..656d2409 100644 --- a/avilla/red/perform/action/message.py +++ b/avilla/red/perform/action/message.py @@ -8,7 +8,7 @@ from loguru import logger from avilla.core import Context -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector from avilla.red.capability import RedCapability from avilla.standard.core.message import ( diff --git a/avilla/red/perform/context.py b/avilla/red/perform/context.py index fb72168a..c0536e8e 100644 --- a/avilla/red/perform/context.py +++ b/avilla/red/perform/context.py @@ -4,7 +4,7 @@ from avilla.core.builtins.capability import CoreCapability from avilla.core.context import Context -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/red/perform/message/deserialize.py b/avilla/red/perform/message/deserialize.py index 5600412b..ddd30c40 100644 --- a/avilla/red/perform/message/deserialize.py +++ b/avilla/red/perform/message/deserialize.py @@ -16,7 +16,7 @@ Text, Video, ) -from avilla.core.ryanvk.collector.application import ApplicationCollector +from avilla.core.ryanvk_old.collector.application import ApplicationCollector from avilla.core.selector import Selector from avilla.red.capability import RedCapability from avilla.red.resource import ( diff --git a/avilla/red/perform/message/serialize.py b/avilla/red/perform/message/serialize.py index 9212e4ba..712a37f6 100644 --- a/avilla/red/perform/message/serialize.py +++ b/avilla/red/perform/message/serialize.py @@ -7,7 +7,7 @@ from avilla.core.builtins.resource_fetch import CoreResourceFetchPerform from avilla.core.elements import Audio, Face, Notice, NoticeAll, Picture, Text from avilla.core.resource import LocalFileResource, RawResource, UrlResource -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.red.capability import RedCapability from avilla.standard.qq.elements import MarketFace diff --git a/avilla/red/perform/query.py b/avilla/red/perform/query.py index 144c6e32..a7dd8a2e 100644 --- a/avilla/red/perform/query.py +++ b/avilla/red/perform/query.py @@ -6,7 +6,7 @@ from graia.amnesia.builtins.memcache import Memcache, MemcacheService from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.account import AccountCollector +from avilla.core.ryanvk_old.collector.account import AccountCollector from avilla.core.selector import Selector if TYPE_CHECKING: diff --git a/avilla/red/perform/resource_fetch.py b/avilla/red/perform/resource_fetch.py index d48be9bb..67966483 100644 --- a/avilla/red/perform/resource_fetch.py +++ b/avilla/red/perform/resource_fetch.py @@ -6,7 +6,7 @@ from aiohttp import ClientSession from avilla.core.builtins.capability import CoreCapability -from avilla.core.ryanvk.collector.protocol import ProtocolCollector +from avilla.core.ryanvk_old.collector.protocol import ProtocolCollector from avilla.red.resource import ( RedFileResource, RedImageResource, diff --git a/avilla/satori/bases.py b/avilla/satori/bases.py new file mode 100644 index 00000000..a7ee52a0 --- /dev/null +++ b/avilla/satori/bases.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from flywheel import InstanceOf +from satori.client.account import Account + +from .account import SatoriAccount +from .protocol import SatoriProtocol + + +class InstanceOfProtocol: + protocol = InstanceOf(SatoriProtocol) + + +class InstanceOfAccount(InstanceOfProtocol): + account = InstanceOf(SatoriAccount) + + +class InstanceOfConnection(InstanceOfProtocol): + connection = InstanceOf(Account) diff --git a/avilla/satori/capability.py b/avilla/satori/capability.py index 84fab4da..c5097f06 100644 --- a/avilla/satori/capability.py +++ b/avilla/satori/capability.py @@ -1,55 +1,87 @@ from __future__ import annotations -from typing import Any +from typing import Any, Protocol, TypeVar -from graia.amnesia.message import Element, MessageChain -from satori.parser import parse +from flywheel import ( + Fn, + FnCompose, + FnRecord, + OverloadRecorder, + SimpleOverload, + TypeOverload, +) +from graia.amnesia.message import Element as GraiaElement +from graia.amnesia.message import MessageChain +from satori.element import Element as SatoriElement from satori.element import transform -from satori.model import Event +from satori.model import Event as SatoriEvent +from satori.parser import parse from avilla.core.event import AvillaEvent -from avilla.core.ryanvk.collector.application import ApplicationCollector from avilla.standard.core.application.event import AvillaLifecycleEvent -from graia.ryanvk import Fn, PredicateOverload, TypeOverload +SE = TypeVar("SE", bound=SatoriElement, contravariant=True) +GE = TypeVar("GE", bound=GraiaElement, contravariant=True) +SV = TypeVar("SV", bound=SatoriEvent, contravariant=True) + + +class SatoriCapability: + @Fn.declare + class event_callback(FnCompose): + raw_event = SimpleOverload("raw_event") + + async def call(self, record: FnRecord, event: SatoriEvent): + entities = self.load(self.raw_event.dig(record, event.type)) + return await entities.first(event=event) + + class shapecall(Protocol[SV]): + async def __call__(self, event: SV) -> AvillaEvent | AvillaLifecycleEvent | list[Any]: ... + + def collect(self, recorder: OverloadRecorder[shapecall], raw_event: str): + recorder.use(self.raw_event, raw_event) -class SatoriCapability((m := ApplicationCollector())._): - @Fn.complex({PredicateOverload(lambda _, raw: raw.type): ["raw_event"]}) - async def event_callback(self, raw_event: Event) -> AvillaEvent | AvillaLifecycleEvent | list[Any] | None: - ... + @Fn.declare + class deserialize_element(FnCompose): + type = TypeOverload("type") - @Fn.complex({TypeOverload(): ["raw_element"]}) - async def deserialize_element(self, raw_element: Any) -> Element: - ... + async def call(self, record: FnRecord, element: SatoriElement): + entities = self.load(self.type.dig(record, element)) + return await entities.first(element=element) - @Fn.complex({TypeOverload(): ["element"]}) - async def serialize_element(self, element: Any) -> str: - ... + class shapecall(Protocol[SE]): + async def __call__(self, element: SE) -> GraiaElement: ... - async def deserialize(self, content: str): + def collect(self, recorder: OverloadRecorder[shapecall[SE]], element: type[SE]): + recorder.use(self.type, element) + + @Fn.declare + class serialize_element(FnCompose): + type = TypeOverload("type") + + async def call(self, record: FnRecord, element: GraiaElement): + entities = self.load(self.type.dig(record, element)) + return await entities.first(element=element) + + class shapecall(Protocol[GE]): + async def __call__(self, element: GE) -> str: ... + + def collect(self, recorder: OverloadRecorder[shapecall[GE]], element: type[GE]): + recorder.use(self.type, element) + + @staticmethod + async def deserialize(content: str): elements = [] for raw_element in transform(parse(content)): - elements.append(await self.deserialize_element(raw_element)) + elements.append(await SatoriCapability.deserialize_element(raw_element)) return MessageChain(elements) - async def serialize(self, message: MessageChain): + @staticmethod + async def serialize(message: MessageChain): chain = [] for element in message: - chain.append(await self.serialize_element(element)) + chain.append(await SatoriCapability.serialize_element(element)) return "".join(chain) - - async def handle_event(self, event: Event): - maybe_event = await self.event_callback(event) - - if maybe_event is not None: - if isinstance(maybe_event, list): - for _event in maybe_event: - self.avilla.event_record(_event) - self.avilla.broadcast.postEvent(_event) - else: - self.avilla.event_record(maybe_event) - self.avilla.broadcast.postEvent(maybe_event) diff --git a/avilla/satori/collector/__init__.py b/avilla/satori/collector/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/avilla/satori/collector/connection.py b/avilla/satori/collector/connection.py deleted file mode 100644 index 305d9df1..00000000 --- a/avilla/satori/collector/connection.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, ClassVar, TypeVar - -from satori.client.account import Account - -from avilla.core.ryanvk.collector.base import AvillaBaseCollector -from graia.ryanvk import Access, BasePerform - -if TYPE_CHECKING: - from avilla.satori.protocol import SatoriProtocol - - -T = TypeVar("T") -T1 = TypeVar("T1") - - -class ConnectionBasedPerformTemplate(BasePerform, native=True): - __collector__: ClassVar[ConnectionCollector] - - protocol: Access[SatoriProtocol] = Access() - connection: Access[Account] = Access() - - -class ConnectionCollector(AvillaBaseCollector): - post_applying: bool = False - - @property - def _(self): - upper = super()._ - - class PerformTemplate( - ConnectionBasedPerformTemplate, - upper, - native=True, - ): - ... - - return PerformTemplate diff --git a/avilla/satori/element.py b/avilla/satori/element.py index 436cecce..fac0f41b 100644 --- a/avilla/satori/element.py +++ b/avilla/satori/element.py @@ -5,5 +5,4 @@ from avilla.core.elements import Element -class Button(SatoriButton, Element): - ... +class Button(SatoriButton, Element): ... diff --git a/avilla/satori/event.py b/avilla/satori/event.py index d5836be0..bfe1837e 100644 --- a/avilla/satori/event.py +++ b/avilla/satori/event.py @@ -4,10 +4,8 @@ @dataclass -class RoleCreated(RelationshipCreated): - ... +class RoleCreated(RelationshipCreated): ... @dataclass -class RoleDestroyed(RelationshipDestroyed): - ... +class RoleDestroyed(RelationshipDestroyed): ... diff --git a/avilla/satori/model.py b/avilla/satori/model.py deleted file mode 100644 index 443b328f..00000000 --- a/avilla/satori/model.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from datetime import datetime - -from satori.model import Channel, Event, Login, Member, MessageObject, Role, User, Guild, ButtonInteraction, ArgvInteraction - - -class MessageEvent(Event): - channel: Channel - member: Member - message: MessageObject - user: User - - -class DirectEvent(Event): - user: User - - - -class GuildEvent(Event): - guild: Guild - - -class GuildMemberEvent(Event): - guild: Guild - user: User - member: Member - -class GuildRoleEvent(Event): - guild: Guild - role: Role - - -class LoginEvent(Event): - login: Login - - -class ReactionEvent(Event): - channel: Channel - user: User - message: MessageObject - - -class ButtonInteractionEvent(Event): - button: ButtonInteraction - user: User - channel: Channel - - -class CommandInteractionEvent(Event): - message: MessageObject - user: User - channel: Channel - - -class ArgvInteractionEvent(Event): - argv: ArgvInteraction - user: User - channel: Channel diff --git a/avilla/satori/perform/action/message.py b/avilla/satori/perform/action/message.py index a032e769..d582ff49 100644 --- a/avilla/satori/perform/action/message.py +++ b/avilla/satori/perform/action/message.py @@ -2,29 +2,25 @@ from datetime import datetime, timedelta from secrets import token_urlsafe -from typing import TYPE_CHECKING +from flywheel import scoped_collect from graia.amnesia.builtins.memcache import Memcache, MemcacheService from graia.amnesia.message import MessageChain +from avilla.core.builtins.capability import CoreCapability from avilla.core.context import Context from avilla.core.elements import Reference +from avilla.core.globals import CONTEXT_CONTEXT_VAR from avilla.core.message import Message -from avilla.core.ryanvk.collector.account import AccountCollector from avilla.core.selector import Selector +from avilla.satori.bases import InstanceOfAccount from avilla.satori.capability import SatoriCapability -from avilla.standard.core.message import MessageRevoke, MessageSend, MessageSent +from avilla.standard.core.message import MessageSent, revoke_message, send_message -if TYPE_CHECKING: - from avilla.satori.account import SatoriAccount # noqa - from avilla.satori.protocol import SatoriProtocol # noqa +class SatoriMessageActionPerform(m := scoped_collect.env().target, InstanceOfAccount, static=True): -class SatoriMessageActionPerform((m := AccountCollector["SatoriProtocol", "SatoriAccount"]())._): - m.namespace = "avilla.protocol/satori::action" - m.identify = "message" - - @m.entity(MessageSend.send, target="land.guild.channel") + @m.impl(send_message, target="land.guild.channel") async def send_public_message( self, target: Selector, @@ -36,7 +32,7 @@ async def send_public_message( if reply: message = Reference(reply) + message result = await self.account.client.message_create( - channel_id=target["channel"], content=await SatoriCapability(self.account.staff).serialize(message) + channel_id=target["channel"], content=await SatoriCapability.serialize(message) ) for msg in result: _ctx = Context( @@ -46,7 +42,8 @@ async def send_public_message( target, target.member(self.account.route["account"]), ) - content = await SatoriCapability(self.account.staff.ext({"context": _ctx})).deserialize(msg.content) + with CONTEXT_CONTEXT_VAR.use(_ctx): + content = await SatoriCapability.deserialize(msg.content) content = content.exclude(Reference) _msg = Message( id=f"{msg.id}", @@ -65,8 +62,8 @@ async def send_public_message( ) return target.message(token) - @m.entity(MessageSend.send, target="land.user") - @m.entity(MessageSend.send, target="land.private.user") + @m.impl(send_message, target="land.user") + @m.impl(send_message, target="land.private.user") async def send_private_message( self, target: Selector, @@ -79,11 +76,11 @@ async def send_private_message( message = Reference(reply) + message if target.follows("::private.user"): result = await self.account.client.message_create( - channel_id=target["private"], content=await SatoriCapability(self.account.staff).serialize(message) + channel_id=target["private"], content=await SatoriCapability.serialize(message) ) else: result = await self.account.client.send_private_message( - user_id=target["user"], message=await SatoriCapability(self.account.staff).serialize(message) + user_id=target["user"], message=await SatoriCapability.serialize(message) ) for msg in result: _ctx = Context( @@ -93,7 +90,8 @@ async def send_private_message( target, self.account.route, ) - content = await SatoriCapability(self.account.staff.ext({"context": _ctx})).deserialize(msg.content) + with CONTEXT_CONTEXT_VAR.use(_ctx): + content = await SatoriCapability.deserialize(msg.content) content = content.exclude(Reference) _msg = Message( id=f"{msg.id}", @@ -112,7 +110,7 @@ async def send_private_message( ) return target.message(token) - @m.entity(MessageRevoke.revoke, target="land.guild.channel.message") + @m.impl(revoke_message, target="land.guild.channel.message") async def revoke_public_message(self, target: Selector): cache = self.protocol.avilla.launch_manager.get_component(MemcacheService).cache if result := await cache.get(f"satori/account({self.account.route['account']}).messages({target['message']})"): @@ -121,7 +119,7 @@ async def revoke_public_message(self, target: Selector): return await self.account.client.message_delete(channel_id=target["channel"], message_id=target["message"]) - @m.entity(MessageRevoke.revoke, target="land.private.user.message") + @m.impl(revoke_message, target="land.private.user.message") async def revoke_private_message(self, target: Selector): cache = self.protocol.avilla.launch_manager.get_component(MemcacheService).cache if result := await cache.get(f"satori/account({self.account.route['account']}).messages({target['message']})"): @@ -130,26 +128,27 @@ async def revoke_private_message(self, target: Selector): return await self.account.client.message_delete(channel_id=target["private"], message_id=target["message"]) - @m.pull("land.guild.channel.message", Message) - async def get_public_message(self, message: Selector, route: ...) -> Message: + @m.impl(CoreCapability.pull, "land.guild.channel.message", Message) + async def get_public_message(self, target: Selector) -> Message: msg = await self.account.client.message_get( - channel_id=message["channel"], - message_id=message["message"], + channel_id=target["channel"], + message_id=target["message"], ) _ctx = self.account.get_context( - message.info("::guild.channel").member( + target.info("::guild.channel").member( msg.member.user.id if msg.member and msg.member.user else self.account.route["account"] ) ) - content = await SatoriCapability(self.account.staff.ext({"context": _ctx})).deserialize(msg.content) + with CONTEXT_CONTEXT_VAR.use(_ctx): + content = await SatoriCapability.deserialize(msg.content) reply = None if replys := content.get(Reference): - reply = message.info(f"~.message({replys[0].message['message']})") + reply = target.info(f"~.message({replys[0].message['message']})") content = content.exclude(Reference) return Message( id=f"{msg.id}", - scene=message.info("::guild.channel"), - sender=message.info("::guild.channel").member( + scene=target.info("::guild.channel"), + sender=target.info("::guild.channel").member( msg.member.user.id if msg.member and msg.member.user else self.account.route["account"] ), content=content, @@ -157,22 +156,23 @@ async def get_public_message(self, message: Selector, route: ...) -> Message: reply=reply, ) - @m.pull("land.private.user.message", Message) - async def get_private_message(self, message: Selector, route: ...) -> Message: + @m.impl(CoreCapability.pull, "land.private.user.message", Message) + async def get_private_message(self, target: Selector) -> Message: msg = await self.account.client.message_get( - channel_id=message["private"], - message_id=message["message"], + channel_id=target["private"], + message_id=target["message"], ) - _ctx = self.account.get_context(message.info("::private.user")) - content = await SatoriCapability(self.account.staff.ext({"context": _ctx})).deserialize(msg.content) + _ctx = self.account.get_context(target.info("::private.user")) + with CONTEXT_CONTEXT_VAR.use(_ctx): + content = await SatoriCapability.deserialize(msg.content) reply = None if replys := content.get(Reference): - reply = message.info(f"~.message({replys[0].message['message']})") + reply = target.info(f"~.message({replys[0].message['message']})") content = content.exclude(Reference) return Message( id=f"{msg.id}", - scene=message.info("::private.user"), - sender=message.info("::private.user"), + scene=target.info("::private.user"), + sender=target.info("::private.user"), content=content, time=datetime.now(), reply=reply, diff --git a/avilla/satori/perform/action/request.py b/avilla/satori/perform/action/request.py index c240d06b..bae81de4 100644 --- a/avilla/satori/perform/action/request.py +++ b/avilla/satori/perform/action/request.py @@ -1,21 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from flywheel import scoped_collect -from avilla.core.ryanvk.collector.account import AccountCollector from avilla.core.selector import Selector -from avilla.standard.core.request import RequestCapability +from avilla.satori.bases import InstanceOfAccount +from avilla.standard.core.request import accept_request, reject_request -if TYPE_CHECKING: - from avilla.satori.account import SatoriAccount # noqa - from avilla.satori.protocol import SatoriProtocol # noqa +class SatoriRequestActionPerform(m := scoped_collect.env().target, InstanceOfAccount, static=True): -class SatoriRequestActionPerform((m := AccountCollector["SatoriProtocol", "SatoriAccount"]())._): - m.namespace = "avilla.protocol/satori::action" - m.identify = "request" - - @m.entity(RequestCapability.accept, target="land.guild.member.request") + @m.impl(accept_request, target="land.guild.member.request") async def accept_member_join_request(self, target: Selector) -> None: request_id = target.pattern["request"] await self.account.client.guild_approve( @@ -24,8 +18,10 @@ async def accept_member_join_request(self, target: Selector) -> None: comment="", ) - @m.entity(RequestCapability.reject, target="land.guild.member.request") - async def reject_member_join_request(self, target: Selector, reason: str | None = None, forever: bool = False) -> None: + @m.impl(reject_request, target="land.guild.member.request") + async def reject_member_join_request( + self, target: Selector, reason: str | None = None, forever: bool = False + ) -> None: request_id = target.pattern["request"] await self.account.client.guild_approve( request_id=request_id, @@ -33,7 +29,7 @@ async def reject_member_join_request(self, target: Selector, reason: str | None comment=reason or "", ) - @m.entity(RequestCapability.accept, target="land.user.request") + @m.impl(accept_request, target="land.user.request") async def accept_new_friend_request(self, target: Selector) -> None: request_id = target.pattern["request"] await self.account.client.friend_approve( @@ -42,8 +38,10 @@ async def accept_new_friend_request(self, target: Selector) -> None: comment="", ) - @m.entity(RequestCapability.reject, target="land.user.request") - async def reject_new_friend_request(self, target: Selector, reason: str | None = None, forever: bool = False) -> None: + @m.impl(reject_request, target="land.user.request") + async def reject_new_friend_request( + self, target: Selector, reason: str | None = None, forever: bool = False + ) -> None: request_id = target.pattern["request"] await self.account.client.friend_approve( request_id=request_id, @@ -51,7 +49,7 @@ async def reject_new_friend_request(self, target: Selector, reason: str | None = comment=reason or "", ) - @m.entity(RequestCapability.accept, target="land.guild.request") + @m.impl(accept_request, target="land.guild.request") async def accept_bot_invited_request(self, target: Selector) -> None: request_id = target.pattern["request"] await self.account.client.guild_approve( @@ -60,7 +58,7 @@ async def accept_bot_invited_request(self, target: Selector) -> None: comment="", ) - @m.entity(RequestCapability.reject, target="land.guild.request") + @m.impl(reject_request, target="land.guild.request") async def reject_bot_invited_request( self, target: Selector, reason: str | None = None, forever: bool = False ) -> None: diff --git a/avilla/satori/perform/context.py b/avilla/satori/perform/context.py index 15e2312d..def38578 100644 --- a/avilla/satori/perform/context.py +++ b/avilla/satori/perform/context.py @@ -2,20 +2,21 @@ from typing import TYPE_CHECKING +from flywheel import scoped_collect + from avilla.core.builtins.capability import CoreCapability from avilla.core.context import Context -from avilla.core.ryanvk.collector.account import AccountCollector from avilla.core.selector import Selector +from avilla.satori.bases import InstanceOfAccount if TYPE_CHECKING: from avilla.satori.account import SatoriAccount # noqa from avilla.satori.protocol import SatoriProtocol # noqa -class SatoriContextPerform((m := AccountCollector["SatoriProtocol", "SatoriAccount"]())._): - m.namespace = "avilla.protocol/satori::context" +class SatoriContextPerform(m := scoped_collect.env().target, InstanceOfAccount, static=True): - @m.entity(CoreCapability.get_context, target="land.public.channel") + @m.impl(CoreCapability.get_context, target="land.guild.channel") def get_context_from_public(self, target: Selector, *, via: Selector | None = None): return Context( self.account, @@ -25,36 +26,36 @@ def get_context_from_public(self, target: Selector, *, via: Selector | None = No target.member(self.account.route["account"]), ) - @m.entity(CoreCapability.get_context, target="land.private.user") + @m.impl(CoreCapability.get_context, target="land.private.user") def get_context_from_user(self, target: Selector, *, via: Selector | None = None): return Context(self.account, target, self.account.route, target, self.account.route) - @m.entity(CoreCapability.get_context, target="land.public.channel.member") + @m.impl(CoreCapability.get_context, target="land.guild.channel.member") def get_context_from_member(self, target: Selector, *, via: Selector | None = None): return Context( self.account, target, - target.into("::public.channel"), - target.into("::public.channel"), + target.into("::guild.channel"), + target.into("::guild.channel"), target.into(f"~.member({self.account.route['account']})"), ) - @m.entity(CoreCapability.channel, target="land.private.user") + @m.impl(CoreCapability.channel, target="land.private.user") def channel_from_user(self, target: Selector): return target["private"] - @m.entity(CoreCapability.channel, target="land.public.channel") + @m.impl(CoreCapability.channel, target="land.guild.channel") def channel_from_channel(self, target: Selector): return target["channel"] - @m.entity(CoreCapability.guild, target="land.public.channel") + @m.impl(CoreCapability.guild, target="land.guild.channel") def guild_from_channel(self, target: Selector): - return target["public"] + return target["guild"] - @m.entity(CoreCapability.user, target="land.private.user") + @m.impl(CoreCapability.user, target="land.private.user") def user_from_user(self, target: Selector): return target["user"] - @m.entity(CoreCapability.user, target="land.public.channel.member") + @m.impl(CoreCapability.user, target="land.guild.channel.member") def user_from_member(self, target: Selector): return target["member"] diff --git a/avilla/satori/perform/event/activity.py b/avilla/satori/perform/event/activity.py index a518dda4..1b50698c 100644 --- a/avilla/satori/perform/event/activity.py +++ b/avilla/satori/perform/event/activity.py @@ -1,31 +1,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -from satori.model import ChannelType, Event +from flywheel import scoped_collect +from satori.model import ChannelType from avilla.core.context import Context from avilla.core.selector import Selector +from avilla.satori.bases import InstanceOfAccount from avilla.satori.capability import SatoriCapability -from avilla.satori.collector.connection import ConnectionCollector -from avilla.satori.model import ButtonInteractionEvent +from satori.event import ButtonInteractionEvent from avilla.standard.core.activity import ActivityAvailable -class SatoriEventActivityPerform((m := ConnectionCollector())._): - m.namespace = "avilla.protocol/satori::event" - m.identify = "activity" - - @m.entity(SatoriCapability.event_callback, raw_event="interaction/button") - async def button_interaction(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] - if not raw_event.channel: - return - if TYPE_CHECKING: - assert isinstance(raw_event, ButtonInteractionEvent) - if raw_event.channel.type == ChannelType.DIRECT: - private = Selector().land(account.route["land"]).private(raw_event.channel.id) - user = private.user(raw_event.user.id) # type: ignore +class SatoriEventActivityPerform(m := scoped_collect.globals().target, InstanceOfAccount, static=True): + @m.impl(SatoriCapability.event_callback, raw_event="interaction/button") + async def button_interaction(self, event: ButtonInteractionEvent): + account = self.account + if event.channel.type == ChannelType.DIRECT: + private = Selector().land(account.route["land"]).private(event.channel.id) + user = private.user(event.user.id) # type: ignore context = Context( account, user, @@ -33,17 +25,11 @@ async def button_interaction(self, raw_event: Event): user, account.route, ) - activity = private.button(raw_event.button.id) # type: ignore + activity = private.button(event.button.id) # type: ignore else: - public = ( - Selector() - .land(account.route["land"]) - .public(raw_event.guild.id if raw_event.guild else raw_event.channel.id) - ) - channel = public.channel(raw_event.channel.id) - member = channel.member( - raw_event.member.user.id if raw_event.member and raw_event.member.user else raw_event.user.id - ) + public = Selector().land(account.route["land"]).public(event.guild.id if event.guild else event.channel.id) + channel = public.channel(event.channel.id) + member = channel.member(event.member.user.id if event.member and event.member.user else event.user.id) context = Context( account, member, @@ -51,5 +37,5 @@ async def button_interaction(self, raw_event: Event): channel, channel.member(account.route["account"]), ) - activity = channel.button(raw_event.button.id) # type: ignore + activity = channel.button(event.button.id) # type: ignore return ActivityAvailable(context, "button_interaction", context.scene, activity) diff --git a/avilla/satori/perform/event/lifespan.py b/avilla/satori/perform/event/lifespan.py index aba95606..e9124f70 100644 --- a/avilla/satori/perform/event/lifespan.py +++ b/avilla/satori/perform/event/lifespan.py @@ -1,14 +1,15 @@ from __future__ import annotations +from flywheel import scoped_collect from loguru import logger -from satori.account import Account +from satori.client.account import Account from satori.model import Event, LoginStatus from avilla.core.account import AccountInfo from avilla.core.selector import Selector from avilla.satori.account import SatoriAccount +from avilla.satori.bases import InstanceOfConnection from avilla.satori.capability import SatoriCapability -from avilla.satori.collector.connection import ConnectionCollector from avilla.satori.const import platform from avilla.standard.core.account.event import ( AccountAvailable, @@ -18,15 +19,13 @@ ) -class SatoriEventLifespanPerform((m := ConnectionCollector())._): - m.namespace = "avilla.protocol/satori::event" - m.identify = "lifespan" +class SatoriEventLifespanPerform(m := scoped_collect.globals().target, InstanceOfConnection, static=True): - @m.entity(SatoriCapability.event_callback, raw_event="login-added") - async def connect(self, raw_event: Event): - self_id = raw_event.self_id - account = Account(raw_event.platform, self_id, self.connection.config) - route = Selector().land(raw_event.platform).account(self_id) + @m.impl(SatoriCapability.event_callback, raw_event="login-added") + async def connect(self, event: Event): + self_id = event.self_id + account = Account(event.platform, self_id, self.connection.config) + route = Selector().land(event.platform).account(self_id) _account = SatoriAccount(route=route, protocol=self.protocol) self.protocol.service.accounts[account.identity] = account @@ -39,32 +38,33 @@ async def connect(self, raw_event: Event): _account.client = self.connection return AccountRegistered(self.protocol.avilla, _account) - @m.entity(SatoriCapability.event_callback, raw_event="login-updated") - async def enable(self, raw_event: Event): - identity = f"{raw_event.platform}/{raw_event.self_id}" + @m.impl(SatoriCapability.event_callback, raw_event="login-updated") + async def enable(self, event: Event): + identity = f"{event.platform}/{event.self_id}" account = self.protocol.service.accounts.get(identity) if account is None: - logger.warning(f"Unknown account {identity} received enable event {raw_event}") - return + logger.warning(f"Unknown account {identity} received enable event {event}") + raise NotImplementedError _account = self.protocol.service._accounts[identity] - if _account.status.enabled and raw_event.login and raw_event.login.status != LoginStatus.ONLINE: + if _account.status.enabled and event.login and event.login.status != LoginStatus.ONLINE: _account.status.enabled = False logger.warning(f"Account {identity} disabled by remote") account.connected.clear() return AccountUnavailable(self.protocol.avilla, _account) - if not _account.status.enabled and raw_event.login and raw_event.login.status == LoginStatus.ONLINE: + if not _account.status.enabled and event.login and event.login.status == LoginStatus.ONLINE: _account.status.enabled = True account.connected.set() logger.warning(f"Account {identity} enabled by remote") return AccountAvailable(self.protocol.avilla, _account) + raise NotImplementedError - @m.entity(SatoriCapability.event_callback, raw_event="login-removed") - async def disable(self, raw_event: Event): - identity = f"{raw_event.platform}/{raw_event.self_id}" + @m.impl(SatoriCapability.event_callback, raw_event="login-removed") + async def disable(self, event: Event): + identity = f"{event.platform}/{event.self_id}" account = self.protocol.service.accounts.get(identity) if account is None: - logger.warning(f"Unknown account {identity} received disable event {raw_event}") - return + logger.warning(f"Unknown account {identity} received disable event {event}") + raise NotImplementedError _account = self.protocol.service._accounts[identity] _account.status.enabled = False logger.warning(f"Account {identity} disabled by remote") diff --git a/avilla/satori/perform/event/message.py b/avilla/satori/perform/event/message.py index 46c0c5aa..2b26b406 100644 --- a/avilla/satori/perform/event/message.py +++ b/avilla/satori/perform/event/message.py @@ -1,35 +1,32 @@ from __future__ import annotations from datetime import timedelta -from typing import TYPE_CHECKING +from flywheel import scoped_collect from graia.amnesia.builtins.memcache import Memcache, MemcacheService -from satori.model import ChannelType, Event +from satori.model import ChannelType from avilla.core.context import Context from avilla.core.elements import Reference +from avilla.core.globals import CONTEXT_CONTEXT_VAR from avilla.core.message import Message from avilla.core.selector import Selector +from avilla.satori.bases import InstanceOfAccount from avilla.satori.capability import SatoriCapability -from avilla.satori.collector.connection import ConnectionCollector -from avilla.satori.model import MessageEvent +from satori.event import MessageEvent from avilla.standard.core.message import MessageReceived, MessageSent -class SatoriEventMessagePerform((m := ConnectionCollector())._): - m.namespace = "avilla.protocol/satori::event" - m.identify = "message" +class SatoriEventMessagePerform(m := scoped_collect.globals().target, InstanceOfAccount, static=True): - @m.entity(SatoriCapability.event_callback, raw_event="message-created") - async def message_create(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="message-created") + async def message_create(self, event: MessageEvent): + account = self.account cache: Memcache = self.protocol.avilla.launch_manager.get_component(MemcacheService).cache reply = None - if TYPE_CHECKING: - assert isinstance(raw_event, MessageEvent) - if raw_event.channel.type == ChannelType.DIRECT: - private = Selector().land(account.route["land"]).private(raw_event.channel.id) - user = private.user(raw_event.user.id) + if event.channel.type == ChannelType.DIRECT: + private = Selector().land(account.route["land"]).private(event.channel.id) + user = private.user(event.user.id) context = Context( account, user, @@ -37,30 +34,23 @@ async def message_create(self, raw_event: Event): user, account.route, ) - message = await SatoriCapability(account.staff.ext({"context": context})).deserialize( - raw_event.message.content - ) + with CONTEXT_CONTEXT_VAR.use(context): + message = await SatoriCapability.deserialize(event.message.content) if message.get(Reference): reply = message.get_first(Reference).message message = message.exclude(Reference) msg = Message( - id=f"{raw_event.message.id}", + id=f"{event.message.id}", scene=private, sender=private, content=message, - time=raw_event.timestamp, + time=event.timestamp, reply=reply, ) else: - guild = ( - Selector() - .land(account.route["land"]) - .guild(raw_event.guild.id if raw_event.guild else "True") - ) - channel = guild.channel(raw_event.channel.id) - member = channel.member( - raw_event.member.user.id if raw_event.member and raw_event.member.user else raw_event.user.id - ) + guild = Selector().land(account.route["land"]).guild(event.guild.id if event.guild else "True") + channel = guild.channel(event.channel.id) + member = channel.member(event.member.user.id if event.member and event.member.user else event.user.id) context = Context( account, member, @@ -68,26 +58,23 @@ async def message_create(self, raw_event: Event): channel, channel.member(account.route["account"]), ) - message = await SatoriCapability(account.staff.ext({"context": context})).deserialize( - raw_event.message.content - ) + with CONTEXT_CONTEXT_VAR.use(context): + message = await SatoriCapability.deserialize(event.message.content) if message.get(Reference): reply = message.get_first(Reference).message message = message.exclude(Reference) msg = Message( - id=f"{raw_event.message.id}", + id=f"{event.message.id}", scene=channel, sender=member, content=message, - time=raw_event.timestamp, + time=event.timestamp, reply=reply, ) - await cache.set( - f"satori/account({account.route['account']}).message({msg.id})", raw_event, timedelta(minutes=5) - ) + await cache.set(f"satori/account({account.route['account']}).message({msg.id})", event, timedelta(minutes=5)) context._collect_metadatas(msg.to_selector(), msg) return ( MessageSent(context, msg, account) - if msg.sender.last_value == raw_event.self_id + if msg.sender.last_value == event.self_id else MessageReceived(context, msg) ) diff --git a/avilla/satori/perform/event/metadata.py b/avilla/satori/perform/event/metadata.py index 0e78a71d..85ed7417 100644 --- a/avilla/satori/perform/event/metadata.py +++ b/avilla/satori/perform/event/metadata.py @@ -1,28 +1,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from flywheel import scoped_collect + from avilla.core.context import Context from avilla.core.event import MetadataModified, ModifyDetail from avilla.core.selector import Selector +from avilla.satori.bases import InstanceOfAccount from avilla.satori.capability import SatoriCapability -from avilla.satori.collector.connection import ConnectionCollector -from avilla.satori.model import GuildEvent, GuildMemberEvent, GuildRoleEvent -from avilla.standard.core.profile import Summary, Avatar, Nick -from satori.model import Event +from satori.event import GuildEvent, GuildMemberEvent, GuildRoleEvent +from avilla.standard.core.profile import Avatar, Nick, Summary -class SatoriEventMetadataPerform((m := ConnectionCollector())._): - m.namespace = "avilla.protocol/satori::event" - m.identify = "metadata" +class SatoriEventMetadataPerform(m := scoped_collect.globals().target, InstanceOfAccount, static=True): - @m.entity(SatoriCapability.event_callback, raw_event="guild-updated") - async def guild_updated(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-updated") + async def guild_updated(self, event: GuildEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildEvent) - guild = land.guild(raw_event.guild.id) - operator = guild.member(raw_event.operator.id) if raw_event.operator else guild + guild = land.guild(event.guild.id) + operator = guild.member(event.operator.id) if event.operator else guild context = Context( account, operator, @@ -36,7 +32,7 @@ async def guild_updated(self, raw_event: Event): guild, Summary, { - Summary.inh().name: ModifyDetail("set", raw_event.guild.name, None), + Summary.inh().name: ModifyDetail("set", event.guild.name, None), }, ), MetadataModified( @@ -44,8 +40,8 @@ async def guild_updated(self, raw_event: Event): guild, Nick, { - Nick.inh().name: ModifyDetail("set", raw_event.guild.name, None), - Nick.inh().nickname: ModifyDetail("set", raw_event.guild.name, None), + Nick.inh().name: ModifyDetail("set", event.guild.name, None), + Nick.inh().nickname: ModifyDetail("set", event.guild.name, None), }, ), MetadataModified( @@ -53,20 +49,18 @@ async def guild_updated(self, raw_event: Event): guild, Avatar, { - Avatar.inh().url: ModifyDetail("set", raw_event.guild.avatar, None), + Avatar.inh().url: ModifyDetail("set", event.guild.avatar, None), }, - ) + ), ] - @m.entity(SatoriCapability.event_callback, raw_event="guild-member-updated") - async def guild_member_updated(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-member-updated") + async def guild_member_updated(self, event: GuildMemberEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildMemberEvent) - guild = land.guild(raw_event.guild.id) - member = guild.member(raw_event.member.user.id if raw_event.member.user else raw_event.user.id) - operator = guild.member(raw_event.operator.id) if raw_event.operator else member + guild = land.guild(event.guild.id) + member = guild.member(event.member.user.id if event.member.user else event.user.id) + operator = guild.member(event.operator.id) if event.operator else member context = Context( account, operator, @@ -80,7 +74,7 @@ async def guild_member_updated(self, raw_event: Event): member, Summary, { - Summary.inh().name: ModifyDetail("set", raw_event.member.nick, None), + Summary.inh().name: ModifyDetail("set", event.member.nick, None), }, ), MetadataModified( @@ -88,8 +82,8 @@ async def guild_member_updated(self, raw_event: Event): member, Nick, { - Nick.inh().name: ModifyDetail("set", raw_event.user.name, None), - Nick.inh().nickname: ModifyDetail("set", raw_event.member.nick, None), + Nick.inh().name: ModifyDetail("set", event.user.name, None), + Nick.inh().nickname: ModifyDetail("set", event.member.nick, None), }, ), MetadataModified( @@ -97,20 +91,18 @@ async def guild_member_updated(self, raw_event: Event): member, Avatar, { - Avatar.inh().url: ModifyDetail("set", raw_event.member.avatar, None), + Avatar.inh().url: ModifyDetail("set", event.member.avatar, None), }, - ) + ), ] - @m.entity(SatoriCapability.event_callback, raw_event="guild-role-updated") - async def guild_role_updated(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-role-updated") + async def guild_role_updated(self, event: GuildRoleEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildRoleEvent) - guild = land.guild(raw_event.guild.id) - role = guild.role(raw_event.role.id) - operator = guild.member(raw_event.operator.id) if raw_event.operator else guild + guild = land.guild(event.guild.id) + role = guild.role(event.role.id) + operator = guild.member(event.operator.id) if event.operator else guild context = Context( account, operator, @@ -124,7 +116,7 @@ async def guild_role_updated(self, raw_event: Event): role, Summary, { - Summary.inh().name: ModifyDetail("set", raw_event.role.name, None), + Summary.inh().name: ModifyDetail("set", event.role.name, None), }, ), MetadataModified( @@ -132,8 +124,8 @@ async def guild_role_updated(self, raw_event: Event): role, Nick, { - Nick.inh().name: ModifyDetail("set", raw_event.role.name, None), - Nick.inh().nickname: ModifyDetail("set", raw_event.role.name, None), + Nick.inh().name: ModifyDetail("set", event.role.name, None), + Nick.inh().nickname: ModifyDetail("set", event.role.name, None), }, ), ] diff --git a/avilla/satori/perform/event/relationship.py b/avilla/satori/perform/event/relationship.py index c5a87092..0f49a387 100644 --- a/avilla/satori/perform/event/relationship.py +++ b/avilla/satori/perform/event/relationship.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from flywheel import scoped_collect + from avilla.core.context import Context from avilla.core.event import ( MemberCreated, @@ -9,25 +10,20 @@ SceneDestroyed, ) from avilla.core.selector import Selector +from avilla.satori.bases import InstanceOfAccount from avilla.satori.capability import SatoriCapability -from avilla.satori.collector.connection import ConnectionCollector -from avilla.satori.model import GuildEvent, GuildMemberEvent, GuildRoleEvent +from satori.event import GuildEvent, GuildMemberEvent, GuildRoleEvent from avilla.satori.event import RoleCreated, RoleDestroyed -from satori.model import Event -class SatoriEventRelationshipPerform((m := ConnectionCollector())._): - m.namespace = "avilla.protocol/satori::event" - m.identify = "relationship" +class SatoriEventRelationshipPerform(m := scoped_collect.globals().target, InstanceOfAccount, static=True): - @m.entity(SatoriCapability.event_callback, raw_event="guild-added") - async def guild_added(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-added") + async def guild_added(self, event: GuildEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildEvent) - guild = land.guild(raw_event.guild.id) - inviter = guild.member(raw_event.operator.id) if raw_event.operator else guild + guild = land.guild(event.guild.id) + inviter = guild.member(event.operator.id) if event.operator else guild context = Context( account, inviter, @@ -37,14 +33,12 @@ async def guild_added(self, raw_event: Event): ) return SceneCreated(context) - @m.entity(SatoriCapability.event_callback, raw_event="guild-removed") - async def guild_removed(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-removed") + async def guild_removed(self, event: GuildEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildEvent) - guild: Selector = land.guild(raw_event.guild.id) - operator = guild.member(raw_event.operator.id) if raw_event.operator else guild + guild: Selector = land.guild(event.guild.id) + operator = guild.member(event.operator.id) if event.operator else guild context = Context( account, operator, @@ -53,21 +47,18 @@ async def guild_removed(self, raw_event: Event): guild.member(account.route["account"]), ) return SceneDestroyed( - context, - active=bool(raw_event.operator) and raw_event.operator.id == account.route["account"], - indirect=not bool(raw_event.operator), + context, + active=bool(event.operator) and event.operator.id == account.route["account"], + indirect=not bool(event.operator), ) - - @m.entity(SatoriCapability.event_callback, raw_event="guild-member-added") - async def guild_member_added(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-member-added") + async def guild_member_added(self, event: GuildMemberEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildMemberEvent) - guild = land.guild(raw_event.guild.id) - member = guild.member(raw_event.member.user.id if raw_event.member.user else raw_event.user.id) - operator = guild.member(raw_event.operator.id) if raw_event.operator else member + guild = land.guild(event.guild.id) + member = guild.member(event.member.user.id if event.member.user else event.user.id) + operator = guild.member(event.operator.id) if event.operator else member context = Context( account, operator, @@ -77,15 +68,13 @@ async def guild_member_added(self, raw_event: Event): ) return MemberCreated(context) - @m.entity(SatoriCapability.event_callback, event_type="guild-member-removed") - async def guild_member_removed(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-member-removed") + async def guild_member_removed(self, event: GuildMemberEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildMemberEvent) - guild = land.guild(raw_event.guild.id) - member = guild.member(raw_event.member.user.id if raw_event.member.user else raw_event.user.id) - operator = guild.member(raw_event.operator.id) if raw_event.operator else member + guild = land.guild(event.guild.id) + member = guild.member(event.member.user.id if event.member.user else event.user.id) + operator = guild.member(event.operator.id) if event.operator else member context = Context( account, operator, @@ -94,20 +83,18 @@ async def guild_member_removed(self, raw_event: Event): guild.member(account.route["account"]), ) return MemberDestroyed( - context, - active=bool(raw_event.operator) and raw_event.operator.id == member["member"], - indirect=not bool(raw_event.operator), + context, + active=bool(event.operator) and event.operator.id == member["member"], + indirect=not bool(event.operator), ) - @m.entity(SatoriCapability.event_callback, raw_event="guild-role-created") - async def guild_role_created(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-role-created") + async def guild_role_created(self, event: GuildRoleEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildRoleEvent) - guild = land.guild(raw_event.guild.id) - role = guild.role(raw_event.role.id) - operator = guild.member(raw_event.operator.id) if raw_event.operator else guild + guild = land.guild(event.guild.id) + role = guild.role(event.role.id) + operator = guild.member(event.operator.id) if event.operator else guild context = Context( account, operator, @@ -117,15 +104,13 @@ async def guild_role_created(self, raw_event: Event): ) return RoleCreated(context) - @m.entity(SatoriCapability.event_callback, event_type="guild-role-deleted") - async def guild_role_deleted(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-role-deleted") + async def guild_role_deleted(self, event: GuildRoleEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildRoleEvent) - guild = land.guild(raw_event.guild.id) - role = guild.role(raw_event.role.id) - operator = guild.member(raw_event.operator.id) if raw_event.operator else guild + guild = land.guild(event.guild.id) + role = guild.role(event.role.id) + operator = guild.member(event.operator.id) if event.operator else guild context = Context( account, operator, @@ -134,7 +119,7 @@ async def guild_role_deleted(self, raw_event: Event): guild.member(account.route["account"]), ) return RoleDestroyed( - context, - active=bool(raw_event.operator) and raw_event.operator.id == account.route["account"], - indirect=not bool(raw_event.operator), + context, + active=bool(event.operator) and event.operator.id == account.route["account"], + indirect=not bool(event.operator), ) diff --git a/avilla/satori/perform/event/request.py b/avilla/satori/perform/event/request.py index f772de7c..1abf4709 100644 --- a/avilla/satori/perform/event/request.py +++ b/avilla/satori/perform/event/request.py @@ -1,69 +1,63 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING + +from flywheel import scoped_collect from avilla.core.context import Context from avilla.core.request import Request from avilla.core.selector import Selector -from avilla.satori.const import land as LAND +from avilla.satori.bases import InstanceOfAccount from avilla.satori.capability import SatoriCapability -from avilla.satori.collector.connection import ConnectionCollector -from avilla.satori.model import GuildMemberEvent, DirectEvent, GuildEvent - +from avilla.satori.const import land as LAND +from satori.event import GuildMemberEvent, UserEvent, GuildEvent from avilla.standard.core.profile.metadata import Nick, Summary from avilla.standard.core.request import RequestReceived -from satori.model import Event -class SatoriEventRequestPerform((m := ConnectionCollector())._): - m.namespace = "avilla.protocol/satori::event" - m.identify = "request" - @m.entity(SatoriCapability.event_callback, raw_event="guild-member-request") - async def member_join_request(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] +class SatoriEventRequestPerform(m := scoped_collect.globals().target, InstanceOfAccount, static=True): + + @m.impl(SatoriCapability.event_callback, raw_event="guild-member-request") + async def member_join_request(self, event: GuildMemberEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildMemberEvent) - guild = land.guild(raw_event.guild.id) - sender = land.user(raw_event.user.id) + guild = land.guild(event.guild.id) + sender = land.user(event.user.id) context = Context( account, sender, guild, guild, guild.member(account.route["account"]), - mediums=[guild.member(raw_event.operator.id)] if raw_event.operator else None, + mediums=[guild.member(event.operator.id)] if event.operator else None, ) request = Request( - f"{raw_event.message.id if raw_event.message else raw_event.id}", + f"{event.message.id if event.message else event.id}", LAND(account.route["land"]), guild, sender, account, datetime.now(), request_type="satori::guild-member-request", - message=raw_event.message.content if raw_event.message else None, + message=event.message.content if event.message else None, ) context._collect_metadatas( guild, - Nick(raw_event.guild.name, raw_event.guild.name, None), # type: ignore - Summary(raw_event.guild.name, None), # type: ignore + Nick(event.guild.name, event.guild.name, None), # type: ignore + Summary(event.guild.name, None), # type: ignore ) context._collect_metadatas( sender, - Nick(raw_event.user.name, raw_event.member.nick or raw_event.user.name, None), # type: ignore - Summary(raw_event.user.name, None), # type: ignore + Nick(event.user.name, event.member.nick or event.user.name, None), # type: ignore + Summary(event.user.name, None), # type: ignore ) return RequestReceived(context, request) - @m.entity(SatoriCapability.event_callback, raw_event="friend-request") - async def new_friend_request(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="friend-request") + async def new_friend_request(self, event: UserEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, DirectEvent) - sender = land.user(raw_event.user.id) + sender = land.user(event.user.id) context = Context( account, sender, @@ -72,30 +66,28 @@ async def new_friend_request(self, raw_event: Event): account.route, ) request = Request( - f"{raw_event.message.id if raw_event.message else raw_event.id}", + f"{event.message.id if event.message else event.id}", LAND(account.route["land"]), sender, sender, account, datetime.now(), request_type="satori::friend-request", - message=raw_event.message.content if raw_event.message else None, + message=event.message.content if event.message else None, ) context._collect_metadatas( sender, - Nick(raw_event.user.name, raw_event.user.name, None), # type: ignore - Summary(raw_event.user.name, None), # type: ignore + Nick(event.user.name, event.user.name, None), # type: ignore + Summary(event.user.name, None), # type: ignore ) return RequestReceived(context, request) - @m.entity(SatoriCapability.event_callback, raw_event="guild-request") - async def bot_invited_join_group_request(self, raw_event: Event): - account = self.protocol.service._accounts[self.connection.identity] + @m.impl(SatoriCapability.event_callback, raw_event="guild-request") + async def bot_invited_join_group_request(self, event: GuildEvent): + account = self.account land = Selector().land(account.route["land"]) - if TYPE_CHECKING: - assert isinstance(raw_event, GuildEvent) - guild = land.guild(raw_event.guild.id) - operator = guild.member(raw_event.operator.id) if raw_event.operator else guild + guild = land.guild(event.guild.id) + operator = guild.member(event.operator.id) if event.operator else guild context = Context( account, operator, @@ -104,24 +96,24 @@ async def bot_invited_join_group_request(self, raw_event: Event): account.route, ) request = Request( - f"{raw_event.message.id if raw_event.message else raw_event.id}", + f"{event.message.id if event.message else event.id}", LAND(account.route["land"]), operator, operator, account, datetime.now(), request_type="satori::guild-request", - message=raw_event.message.content if raw_event.message else None, + message=event.message.content if event.message else None, ) context._collect_metadatas( guild, - Nick(raw_event.guild.name, raw_event.guild.name, None), # type: ignore - Summary(raw_event.guild.name, None), # type: ignore + Nick(event.guild.name, event.guild.name, None), # type: ignore + Summary(event.guild.name, None), # type: ignore ) - if raw_event.operator: + if event.operator: context._collect_metadatas( operator, - Nick(raw_event.operator.name, raw_event.operator.name, None), # type: ignore - Summary(raw_event.operator.name, None), # type: ignore + Nick(event.operator.name, event.operator.name, None), # type: ignore + Summary(event.operator.name, None), # type: ignore ) return RequestReceived(context, request) diff --git a/avilla/satori/perform/message/deserialize.py b/avilla/satori/perform/message/deserialize.py index 7ea92464..12166049 100644 --- a/avilla/satori/perform/message/deserialize.py +++ b/avilla/satori/perform/message/deserialize.py @@ -1,8 +1,9 @@ from __future__ import annotations +from contextlib import suppress from dataclasses import asdict -from typing import TYPE_CHECKING +from flywheel import global_collect from satori.element import At from satori.element import Audio as SatoriAudio from satori.element import Bold, Br @@ -25,6 +26,7 @@ from satori.element import Underline from satori.element import Video as SatoriVideo +from avilla.core.context import Context from avilla.core.elements import ( Audio, File, @@ -35,7 +37,6 @@ Text, Video, ) -from avilla.core.ryanvk.collector.application import ApplicationCollector from avilla.core.selector import Selector from avilla.satori.capability import SatoriCapability from avilla.satori.element import Button @@ -45,117 +46,156 @@ SatoriImageResource, SatoriVideoResource, ) -from graia.ryanvk import OptionalAccess - -if TYPE_CHECKING: - from avilla.core.context import Context - from avilla.satori.account import SatoriAccount - - -class SatoriMessageDeserializePerform((m := ApplicationCollector())._): - m.namespace = "avilla.protocol/satori::message" - m.identify = "deserialize" - - context: OptionalAccess[Context] = OptionalAccess() - account: OptionalAccess[SatoriAccount] = OptionalAccess() - - # LINK: https://github.com/microsoft/pyright/issues/5409 - - @m.entity(SatoriCapability.deserialize_element, raw_element=SatoriText) - async def text(self, raw_element: SatoriText) -> Text: - return Text(raw_element.text) - - @m.entity(SatoriCapability.deserialize_element, raw_element=At) - async def at(self, raw_element: At) -> Notice | NoticeAll: - if raw_element.type in ("all", "here"): - return NoticeAll() - scene = self.context.scene if self.context else Selector().land("satori") - if raw_element.role: - return Notice(scene.role(raw_element.role)) - return Notice(scene.member(raw_element.id)) # type: ignore - - @m.entity(SatoriCapability.deserialize_element, raw_element=Sharp) - async def sharp(self, raw_element: Sharp) -> Notice: - scene = self.context.scene if self.context else Selector().land("satori") - return Notice(scene.into(f"~.channel({raw_element.id})")) # type: ignore - - @m.entity(SatoriCapability.deserialize_element, raw_element=Link) - async def a(self, raw_element: Link) -> Text: - return Text(raw_element.url, style="link") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Image) - async def img(self, raw_element: Image) -> Picture: - scene = self.context.scene if self.context else Selector().land("satori") - res = SatoriImageResource(**asdict(raw_element)) - res.selector = scene.picture(raw_element.src) - return Picture(res) - - @m.entity(SatoriCapability.deserialize_element, raw_element=SatoriVideo) - async def video(self, raw_element: SatoriVideo) -> Video: - scene = self.context.scene if self.context else Selector().land("satori") - res = SatoriVideoResource(**asdict(raw_element)) - res.selector = scene.video(raw_element.src) - return Video(res) - - @m.entity(SatoriCapability.deserialize_element, raw_element=SatoriAudio) - async def audio(self, raw_element: SatoriAudio) -> Audio: - scene = self.context.scene if self.context else Selector().land("satori") - res = SatoriAudioResource(**asdict(raw_element)) - res.selector = scene.video(raw_element.src) - return Audio(res) - - @m.entity(SatoriCapability.deserialize_element, raw_element=SatoriFile) - async def file(self, raw_element: SatoriFile) -> File: - scene = self.context.scene if self.context else Selector().land("satori") - res = SatoriFileResource(**asdict(raw_element)) - res.selector = scene.video(raw_element.src) - return File(res) - - @m.entity(SatoriCapability.deserialize_element, raw_element=Quote) - async def quote(self, raw_element: Quote) -> Reference: - scene = self.context.scene if self.context else Selector().land("satori") - return Reference(scene.message(raw_element.id)) # type: ignore - - @m.entity(SatoriCapability.deserialize_element, raw_element=Bold) - async def bold(self, raw_element: Bold) -> Text: - return Text(raw_element.dumps(True), style="bold") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Italic) - async def italic(self, raw_element: Italic) -> Text: - return Text(raw_element.dumps(True), style="italic") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Strikethrough) - async def strikethrough(self, raw_element: Strikethrough) -> Text: - return Text(raw_element.dumps(True), style="strikethrough") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Underline) - async def underline(self, raw_element: Underline) -> Text: - return Text(raw_element.dumps(True), style="underline") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Spoiler) - async def spoiler(self, raw_element: Spoiler) -> Text: - return Text(raw_element.dumps(True), style="spoiler") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Code) - async def code(self, raw_element: Code) -> Text: - return Text(raw_element.dumps(True), style="code") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Superscript) - async def superscript(self, raw_element: Superscript) -> Text: - return Text(raw_element.dumps(True), style="superscript") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Subscript) - async def subscript(self, raw_element: Subscript) -> Text: - return Text(raw_element.dumps(True), style="subscript") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Br) - async def br(self, raw_element: Br) -> Text: - return Text("\n", style="br") - - @m.entity(SatoriCapability.deserialize_element, raw_element=Paragraph) - async def paragraph(self, raw_element: Paragraph) -> Text: - return Text(raw_element.dumps(True), style="paragraph") - - @m.entity(SatoriCapability.deserialize_element, raw_element=SatoriButton) - async def button(self, raw_element: SatoriButton) -> Button: - return Button(**asdict(raw_element)) + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=SatoriText) +async def text(element: SatoriText) -> Text: + return Text(element.text) + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=At) +async def at(element: At) -> Notice | NoticeAll: + if element.type in ("all", "here"): + return NoticeAll() + scene = Selector().land("satori") + with suppress(LookupError): + scene = Context.current.scene + if element.role: + return Notice(scene.role(element.role)) + return Notice(scene.member(element.id)) # type: ignore + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Sharp) +async def sharp(element: Sharp) -> Notice: + scene = Selector().land("satori") + with suppress(LookupError): + scene = Context.current.scene + return Notice(scene.into(f"~.channel({element.id})")) # type: ignore + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Link) +async def a(element: Link) -> Text: + return Text(element.url, style="link") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Image) +async def img(element: Image) -> Picture: + scene = Selector().land("satori") + with suppress(LookupError): + scene = Context.current.scene + res = SatoriImageResource(**asdict(element)) + res.selector = scene.picture(element.src) + return Picture(res) + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=SatoriVideo) +async def video(element: SatoriVideo) -> Video: + scene = Selector().land("satori") + with suppress(LookupError): + scene = Context.current.scene + res = SatoriVideoResource(**asdict(element)) + res.selector = scene.video(element.src) + return Video(res) + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=SatoriAudio) +async def audio(element: SatoriAudio) -> Audio: + scene = Selector().land("satori") + with suppress(LookupError): + scene = Context.current.scene + res = SatoriAudioResource(**asdict(element)) + res.selector = scene.video(element.src) + return Audio(res) + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=SatoriFile) +async def file(element: SatoriFile) -> File: + scene = Selector().land("satori") + with suppress(LookupError): + scene = Context.current.scene + res = SatoriFileResource(**asdict(element)) + res.selector = scene.video(element.src) + return File(res) + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Quote) +async def quote(element: Quote) -> Reference: + scene = Selector().land("satori") + with suppress(LookupError): + scene = Context.current.scene + return Reference(scene.message(element.id)) # type: ignore + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Bold) +async def bold(element: Bold) -> Text: + return Text(element.dumps(True), style="bold") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Italic) +async def italic(element: Italic) -> Text: + return Text(element.dumps(True), style="italic") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Strikethrough) +async def strikethrough(element: Strikethrough) -> Text: + return Text(element.dumps(True), style="strikethrough") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Underline) +async def underline(element: Underline) -> Text: + return Text(element.dumps(True), style="underline") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Spoiler) +async def spoiler(element: Spoiler) -> Text: + return Text(element.dumps(True), style="spoiler") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Code) +async def code(element: Code) -> Text: + return Text(element.dumps(True), style="code") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Superscript) +async def superscript(element: Superscript) -> Text: + return Text(element.dumps(True), style="superscript") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Subscript) +async def subscript(element: Subscript) -> Text: + return Text(element.dumps(True), style="subscript") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Br) +async def br(element: Br) -> Text: + return Text("\n", style="br") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=Paragraph) +async def paragraph(element: Paragraph) -> Text: + return Text(element.dumps(True), style="paragraph") + + +@global_collect +@SatoriCapability.deserialize_element.impl(element=SatoriButton) +async def button(element: SatoriButton) -> Button: + return Button(**asdict(element)) diff --git a/avilla/satori/perform/message/serialize.py b/avilla/satori/perform/message/serialize.py index 0d437917..d3b9d278 100644 --- a/avilla/satori/perform/message/serialize.py +++ b/avilla/satori/perform/message/serialize.py @@ -1,101 +1,104 @@ from __future__ import annotations -from typing import TYPE_CHECKING - +from flywheel import global_collect from satori.parser import escape from avilla.core.elements import Audio, File, Notice, NoticeAll, Picture, Text, Video -from avilla.core.ryanvk.collector.account import AccountCollector from avilla.satori.capability import SatoriCapability from avilla.satori.element import Button from avilla.satori.resource import SatoriResource -if TYPE_CHECKING: - from avilla.satori.account import SatoriAccount # noqa - from avilla.satori.protocol import SatoriProtocol # noqa - - -class SatoriMessageSerializePerform((m := AccountCollector["SatoriProtocol", "SatoriAccount"]())._): - m.namespace = "avilla.protocol/satori::message" - m.identify = "serialize" - - # LINK: https://github.com/microsoft/pyright/issues/5409 - - @m.entity(SatoriCapability.serialize_element, element=Text) - async def text(self, element: Text) -> str: - text = escape(element.text) - text.replace("\n", "
") - if not element.style: - return text - style = element.style - if style in {"a", "link"}: - return f'' - if style == { - "b", - "strong", - "bold", - "i", - "em", - "italic", - "u", - "ins", - "underline", - "s", - "del", - "strike", - "spl", - "spoiler", - "code", - "sup", - "sub", - "superscript", - "subscript", - "p", - "paragraph", - }: - return f"<{style}>{text}" + +@global_collect +@SatoriCapability.serialize_element.impl(element=Text) +async def text(element: Text) -> str: + text = escape(element.text) + text.replace("\n", "
") + if not element.style: return text + style = element.style + if style in {"a", "link"}: + return f'
' + if style == { + "b", + "strong", + "bold", + "i", + "em", + "italic", + "u", + "ins", + "underline", + "s", + "del", + "strike", + "spl", + "spoiler", + "code", + "sup", + "sub", + "superscript", + "subscript", + "p", + "paragraph", + }: + return f"<{style}>{text}" + return text + + +@global_collect +@SatoriCapability.serialize_element.impl(element=Notice) +async def notice(element: Notice) -> str: + if "role" in element.target.pattern: + return f'' + if "channel" in element.target.pattern: + return f'' + return f'' + + +@global_collect +@SatoriCapability.serialize_element.impl(element=NoticeAll) +async def notice_all(element: NoticeAll) -> str: + return '' + + +@global_collect +@SatoriCapability.serialize_element.impl(element=Picture) +async def picture(element: Picture) -> str: + res = element.resource + if not isinstance(res, SatoriResource): + raise NotImplementedError("Only SatoriResource is supported.") + return f'' + + +@global_collect +@SatoriCapability.serialize_element.impl(element=Audio) +async def audio(element: Audio) -> str: + res = element.resource + if not isinstance(res, SatoriResource): + raise NotImplementedError("Only SatoriResource is supported.") + return f'