Skip to content

Commit

Permalink
Adding proper process executor shutdown to ensure freeing the port in…
Browse files Browse the repository at this point in the history
… tcp container.
  • Loading branch information
rcschrg committed Feb 24, 2024
1 parent b032fd0 commit e3898dc
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 23 deletions.
5 changes: 4 additions & 1 deletion mango/agent/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,13 @@ async def shutdown(self):
await self._check_inbox_task
except asyncio.CancelledError:
pass

try:
await self._scheduler.stop()
except asyncio.CancelledError:
pass
try:
self._scheduler.shutdown()
except asyncio.CancelledError:
pass
finally:
logger.info("Agent %s: Shutdown successful", self.aid)
1 change: 1 addition & 0 deletions mango/container/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This module contains the abstract Container class and the subclasses
TCPContainer and MQTTContainer
"""

import asyncio
import logging
import time
Expand Down
46 changes: 31 additions & 15 deletions mango/util/scheduling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Module for commonly used time based scheduled task executed inside one agent.
"""

import asyncio
import concurrent.futures
import datetime
Expand All @@ -18,9 +19,10 @@ class Suspendable:
Wraps a coroutine, intercepting __await__ to add the functionality of suspending.
"""

def __init__(self, coro, ext_contr_event=None):
def __init__(self, coro, ext_contr_event=None, kill_event=None):
self._coro = coro

self._kill_event = kill_event
if ext_contr_event is not None:
self._can_run = ext_contr_event
else:
Expand All @@ -43,6 +45,9 @@ def __await__(self):
except BaseException as err:
send, message = iter_throw, err

if self._kill_event is not None and self._kill_event.is_set():
return None

try:
# throw error or resume coroutine
signal = send(message)
Expand Down Expand Up @@ -411,13 +416,16 @@ def __init__(
self._process_pool_exec = concurrent.futures.ProcessPoolExecutor(
max_workers=num_process_parallel, initializer=_create_asyncio_context
)
self._manager = None
self._suspendable = suspendable
self._observable = observable

@staticmethod
def _run_task_in_p_context(task, suspend_event):
def _run_task_in_p_context(task, suspend_event, kill_event):
try:
coro = Suspendable(task.run(), ext_contr_event=suspend_event)
coro = Suspendable(
task.run(), ext_contr_event=suspend_event, kill_event=kill_event
)

return asyncio.get_event_loop().run_until_complete(coro)
finally:
Expand Down Expand Up @@ -608,17 +616,24 @@ def schedule_process_task(self, task: ScheduledProcessTask, src=None):
"""

loop = asyncio.get_running_loop()
manager = Manager()
event = manager.Event()
if self._manager is None:
self._manager = Manager()
event = self._manager.Event()
kill_event = self._manager.Event()
kill_event.clear()
event.set()
l_task = asyncio.ensure_future(
loop.run_in_executor(
self._process_pool_exec, Scheduler._run_task_in_p_context, task, event
self._process_pool_exec,
Scheduler._run_task_in_p_context,
task,
event,
kill_event,
)
)
l_task.add_done_callback(self._remove_process_task)
l_task.add_done_callback(task.on_stop)
self._scheduled_process_tasks.append((task, l_task, event, src))
self._scheduled_process_tasks.append((task, l_task, (event, kill_event), src))
return l_task

def schedule_timestamp_process_task(
Expand Down Expand Up @@ -747,8 +762,8 @@ def suspend(self, given_src):
if src == given_src and coro is not None:
coro.suspend()
for _, _, event, src in self._scheduled_process_tasks:
if src == given_src and event is not None:
event.clear()
if src == given_src and event[0] is not None:
event[0].clear()

def resume(self, given_src):
"""Resume a set of tasks triggered by the given src object.
Expand All @@ -763,15 +778,14 @@ def resume(self, given_src):
if src == given_src and coro is not None:
coro.resume()
for _, _, event, src in self._scheduled_process_tasks:
if src == given_src and event is not None:
event.set()
if src == given_src and event[0] is not None:
event[0].set()

def _remove_process_task(self, fut=asyncio.Future):
for i in range(len(self._scheduled_process_tasks)):
_, task, event, _ = self._scheduled_process_tasks[i]
if task == fut:
del self._scheduled_process_tasks[i]
event.set()
break

# methods for removing tasks, stopping or shutting down
Expand Down Expand Up @@ -836,6 +850,8 @@ def shutdown(self):
"""
# resume all process so they can get shutdown
for _, _, event, _ in self._scheduled_process_tasks:
if event is not None:
event.set()
self._process_pool_exec.shutdown()
if event[1] is not None:
event[1].set()
self._process_pool_exec.shutdown(wait=True, cancel_futures=True)
if self._manager is not None:
self._manager.shutdown()
16 changes: 9 additions & 7 deletions tests/unit_tests/util/scheduling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ async def increase_counter():
)

# WHEN
t = scheduler.schedule_task(RecurrentScheduledTask(increase_counter, recurrency, clock))
t = scheduler.schedule_task(
RecurrentScheduledTask(increase_counter, recurrency, clock)
)
try:
new_time = start + datetime.timedelta(days=1)
clock.set_time(new_time.timestamp())
Expand Down Expand Up @@ -73,6 +75,7 @@ async def increase_counter():
assert task._is_done.done()
assert len(l) == 2


@pytest.mark.asyncio
async def test_recurrent_wait():
# GIVEN
Expand All @@ -84,14 +87,15 @@ async def test_recurrent_wait():

async def increase_counter():
l.append(clock._time)

tomorrow = start + datetime.timedelta(days=1)
aftertomorrow = start + datetime.timedelta(days=2)
recurrency = rrule.rrule(
rrule.DAILY, interval=1, dtstart=tomorrow, until=end
)
recurrency = rrule.rrule(rrule.DAILY, interval=1, dtstart=tomorrow, until=end)

# WHEN
t = scheduler.schedule_task(RecurrentScheduledTask(increase_counter, recurrency, clock))
t = scheduler.schedule_task(
RecurrentScheduledTask(increase_counter, recurrency, clock)
)
task = scheduler._scheduled_tasks[0][0]
try:
clock.set_time(start.timestamp())
Expand Down Expand Up @@ -364,8 +368,6 @@ async def test_task_as_process_suspend():

scheduler.resume(marker)

scheduler.shutdown()


@pytest.mark.asyncio
async def test_future_wait_task():
Expand Down

0 comments on commit e3898dc

Please sign in to comment.