Skip to content

Commit

Permalink
Revert "Remove UART thread (#598)" (#604)
Browse files Browse the repository at this point in the history
This reverts commit 1ac01ed.
  • Loading branch information
puddly authored Dec 29, 2023
1 parent 09b2782 commit d917e99
Show file tree
Hide file tree
Showing 7 changed files with 454 additions and 6 deletions.
1 change: 1 addition & 0 deletions bellows/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
vol.Optional(CONF_EZSP_POLICIES, default={}): vol.Schema(
{vol.Optional(str): int}
),
vol.Optional(CONF_USE_THREAD, default=True): cv_boolean,
}
)

Expand Down
6 changes: 3 additions & 3 deletions bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def startup_reset(self) -> None:
async def initialize(cls, zigpy_config: dict) -> EZSP:
"""Return initialized EZSP instance."""
ezsp = cls(zigpy_config[conf.CONF_DEVICE])
await ezsp.connect()
await ezsp.connect(use_thread=zigpy_config[conf.CONF_USE_THREAD])

try:
await ezsp.startup_reset()
Expand All @@ -139,9 +139,9 @@ async def initialize(cls, zigpy_config: dict) -> EZSP:

return ezsp

async def connect(self) -> None:
async def connect(self, *, use_thread: bool = True) -> None:
assert self._gw is None
self._gw = await bellows.uart.connect(self._config, self)
self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread)
self._protocol = v4.EZSPv4(self.handle_callback, self._gw)

async def reset(self):
Expand Down
122 changes: 122 additions & 0 deletions bellows/thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import functools
import logging
import sys

LOGGER = logging.getLogger(__name__)


class EventLoopThread:
"""Run a parallel event loop in a separate thread."""

def __init__(self):
self.loop = None
self.thread_complete = None

def run_coroutine_threadsafe(self, coroutine):
current_loop = asyncio.get_event_loop()
future = asyncio.run_coroutine_threadsafe(coroutine, self.loop)
return asyncio.wrap_future(future, loop=current_loop)

def _thread_main(self, init_task):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

try:
self.loop.run_until_complete(init_task)
self.loop.run_forever()
finally:
self.loop.close()
self.loop = None

async def start(self):
current_loop = asyncio.get_event_loop()
if self.loop is not None and not self.loop.is_closed():
return

executor_opts = {"max_workers": 1}
if sys.version_info[:2] >= (3, 6):
executor_opts["thread_name_prefix"] = __name__
executor = ThreadPoolExecutor(**executor_opts)

thread_started_future = current_loop.create_future()

async def init_task():
current_loop.call_soon_threadsafe(thread_started_future.set_result, None)

# Use current loop so current loop has a reference to the long-running thread
# as one of its tasks
thread_complete = current_loop.run_in_executor(
executor, self._thread_main, init_task()
)
self.thread_complete = thread_complete
current_loop.call_soon(executor.shutdown, False)
await thread_started_future
return thread_complete

def force_stop(self):
if self.loop is None:
return

def cancel_tasks_and_stop_loop():
tasks = asyncio.all_tasks(loop=self.loop)

for task in tasks:
self.loop.call_soon_threadsafe(task.cancel)

gather = asyncio.gather(*tasks, return_exceptions=True)
gather.add_done_callback(
lambda _: self.loop.call_soon_threadsafe(self.loop.stop)
)

self.loop.call_soon_threadsafe(cancel_tasks_and_stop_loop)


class ThreadsafeProxy:
"""Proxy class which enforces threadsafe non-blocking calls
This class can be used to wrap an object to ensure any calls
using that object's methods are done on a particular event loop
"""

def __init__(self, obj, obj_loop):
self._obj = obj
self._obj_loop = obj_loop

def __getattr__(self, name):
func = getattr(self._obj, name)
if not callable(func):
raise TypeError(
"Can only use ThreadsafeProxy with callable attributes: {}.{}".format(
self._obj.__class__.__name__, name
)
)

def func_wrapper(*args, **kwargs):
loop = self._obj_loop
curr_loop = asyncio.get_running_loop()
call = functools.partial(func, *args, **kwargs)
if loop == curr_loop:
return call()
if loop.is_closed():
# Disconnected
LOGGER.warning("Attempted to use a closed event loop")
return
if asyncio.iscoroutinefunction(func):
future = asyncio.run_coroutine_threadsafe(call(), loop)
return asyncio.wrap_future(future, loop=curr_loop)
else:

def check_result_wrapper():
result = call()
if result is not None:
raise TypeError(
(
"ThreadsafeProxy can only wrap functions with no return"
"value \nUse an async method to return values: {}.{}"
).format(self._obj.__class__.__name__, name)
)

loop.call_soon_threadsafe(check_result_wrapper)

return func_wrapper
22 changes: 21 additions & 1 deletion bellows/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import zigpy.config
import zigpy.serial

from bellows.thread import EventLoopThread, ThreadsafeProxy
import bellows.types as t

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -363,7 +364,7 @@ def _unstuff(self, s):
return out


async def connect(config, application):
async def _connect(config, application):
loop = asyncio.get_event_loop()

connection_future = loop.create_future()
Expand All @@ -386,4 +387,23 @@ async def connect(config, application):

await connection_future

thread_safe_protocol = ThreadsafeProxy(protocol, loop)
return thread_safe_protocol, connection_done_future


async def connect(config, application, use_thread=True):
if use_thread:
application = ThreadsafeProxy(application, asyncio.get_event_loop())
thread = EventLoopThread()
await thread.start()
try:
protocol, connection_done = await thread.run_coroutine_threadsafe(
_connect(config, application)
)
except Exception:
thread.force_stop()
raise
connection_done.add_done_callback(lambda _: thread.force_stop())
else:
protocol, _ = await _connect(config, application)
return protocol
9 changes: 7 additions & 2 deletions bellows/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
import zigpy.zdo.types as zdo_t

import bellows
from bellows.config import CONF_EZSP_CONFIG, CONF_EZSP_POLICIES, CONFIG_SCHEMA
from bellows.config import (
CONF_EZSP_CONFIG,
CONF_EZSP_POLICIES,
CONF_USE_THREAD,
CONFIG_SCHEMA,
)
from bellows.exception import ControllerError, EzspError, StackAlreadyRunning
import bellows.ezsp
from bellows.ezsp.v8.types.named import EmberDeviceUpdate
Expand Down Expand Up @@ -133,7 +138,7 @@ async def _get_board_info(self) -> tuple[str, str, str] | tuple[None, None, None

async def connect(self) -> None:
ezsp = bellows.ezsp.EZSP(self.config[zigpy.config.CONF_DEVICE])
await ezsp.connect()
await ezsp.connect(use_thread=self.config[CONF_USE_THREAD])

try:
await ezsp.startup_reset()
Expand Down
Loading

0 comments on commit d917e99

Please sign in to comment.