Skip to content

Commit 298851b

Browse files
1st1andr-04
andcommitted
Experimental fix for #169
Co-authored-by: Andrey Egorov <[email protected]>
1 parent 700582a commit 298851b

File tree

5 files changed

+185
-39
lines changed

5 files changed

+185
-39
lines changed

tests/test_sockets.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,91 @@ def test_socket_sync_remove_and_immediately_close(self):
189189
self.assertEqual(sock.fileno(), -1)
190190
self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop))
191191

192+
def test_sock_cancel_add_reader_race(self):
193+
srv_sock_conn = None
194+
195+
async def server():
196+
nonlocal srv_sock_conn
197+
sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
198+
sock_server.setblocking(False)
199+
with sock_server:
200+
sock_server.bind(('127.0.0.1', 0))
201+
sock_server.listen()
202+
fut = asyncio.ensure_future(
203+
client(sock_server.getsockname()), loop=self.loop)
204+
srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
205+
srv_sock_conn.setsockopt(
206+
socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
207+
with srv_sock_conn:
208+
await fut
209+
210+
async def client(addr):
211+
sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
212+
sock_client.setblocking(False)
213+
with sock_client:
214+
await self.loop.sock_connect(sock_client, addr)
215+
_, pending_read_futs = await asyncio.wait(
216+
[self.loop.sock_recv(sock_client, 1)],
217+
timeout=1, loop=self.loop)
218+
219+
async def send_server_data():
220+
# Wait a little bit to let reader future cancel and
221+
# schedule the removal of the reader callback. Right after
222+
# "rfut.cancel()" we will call "loop.sock_recv()", which
223+
# will add a reader. This will make a race between
224+
# remove- and add-reader.
225+
await asyncio.sleep(0.1, loop=self.loop)
226+
await self.loop.sock_sendall(srv_sock_conn, b'1')
227+
self.loop.create_task(send_server_data())
228+
229+
for rfut in pending_read_futs:
230+
rfut.cancel()
231+
232+
data = await self.loop.sock_recv(sock_client, 1)
233+
234+
self.assertEqual(data, b'1')
235+
236+
self.loop.run_until_complete(server())
237+
238+
def test_sock_send_before_cancel(self):
239+
srv_sock_conn = None
240+
241+
async def server():
242+
nonlocal srv_sock_conn
243+
sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
244+
sock_server.setblocking(False)
245+
with sock_server:
246+
sock_server.bind(('127.0.0.1', 0))
247+
sock_server.listen()
248+
fut = asyncio.ensure_future(
249+
client(sock_server.getsockname()), loop=self.loop)
250+
srv_sock_conn, _ = await self.loop.sock_accept(sock_server)
251+
with srv_sock_conn:
252+
await fut
253+
254+
async def client(addr):
255+
await asyncio.sleep(0.01, loop=self.loop)
256+
sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
257+
sock_client.setblocking(False)
258+
with sock_client:
259+
await self.loop.sock_connect(sock_client, addr)
260+
_, pending_read_futs = await asyncio.wait(
261+
[self.loop.sock_recv(sock_client, 1)],
262+
timeout=1, loop=self.loop)
263+
264+
# server can send the data in a random time, even before
265+
# the previous result future has cancelled.
266+
await self.loop.sock_sendall(srv_sock_conn, b'1')
267+
268+
for rfut in pending_read_futs:
269+
rfut.cancel()
270+
271+
data = await self.loop.sock_recv(sock_client, 1)
272+
273+
self.assertEqual(data, b'1')
274+
275+
self.loop.run_until_complete(server())
276+
192277

193278
class TestUVSockets(_TestSockets, tb.UVTestCase):
194279

uvloop/handles/poll.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ cdef class UVPoll(UVHandle):
1313
cdef int is_active(self)
1414

1515
cdef is_reading(self)
16+
cdef is_writing(self)
17+
1618
cdef start_reading(self, Handle callback)
1719
cdef start_writing(self, Handle callback)
1820
cdef stop_reading(self)

uvloop/handles/poll.pyx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ cdef class UVPoll(UVHandle):
8787
cdef is_reading(self):
8888
return self._is_alive() and self.reading_handle is not None
8989

90+
cdef is_writing(self):
91+
return self._is_alive() and self.writing_handle is not None
92+
9093
cdef start_reading(self, Handle callback):
9194
cdef:
9295
int mask = 0

uvloop/loop.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,12 @@ cdef class Loop:
177177
cdef _track_process(self, UVProcess proc)
178178
cdef _untrack_process(self, UVProcess proc)
179179

180-
cdef _new_reader_future(self, sock)
181-
cdef _new_writer_future(self, sock)
182180
cdef _add_reader(self, fd, Handle handle)
181+
cdef _has_reader(self, fd)
183182
cdef _remove_reader(self, fd)
184183

185184
cdef _add_writer(self, fd, Handle handle)
185+
cdef _has_writer(self, fd)
186186
cdef _remove_writer(self, fd)
187187

188188
cdef _sock_recv(self, fut, sock, n)

uvloop/loop.pyx

Lines changed: 93 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,20 @@ cdef class Loop:
742742

743743
return result
744744

745+
cdef _has_reader(self, fileobj):
746+
cdef:
747+
UVPoll poll
748+
749+
self._check_closed()
750+
fd = self._fileobj_to_fd(fileobj)
751+
752+
try:
753+
poll = <UVPoll>(self._polls[fd])
754+
except KeyError:
755+
return False
756+
757+
return poll.is_reading()
758+
745759
cdef _add_writer(self, fileobj, Handle handle):
746760
cdef:
747761
UVPoll poll
@@ -791,6 +805,20 @@ cdef class Loop:
791805

792806
return result
793807

808+
cdef _has_writer(self, fileobj):
809+
cdef:
810+
UVPoll poll
811+
812+
self._check_closed()
813+
fd = self._fileobj_to_fd(fileobj)
814+
815+
try:
816+
poll = <UVPoll>(self._polls[fd])
817+
except KeyError:
818+
return False
819+
820+
return poll.is_writing()
821+
794822
cdef _getaddrinfo(self, object host, object port,
795823
int family, int type,
796824
int proto, int flags,
@@ -845,35 +873,17 @@ cdef class Loop:
845873
nr.query(addr, flags)
846874
return fut
847875

848-
cdef _new_reader_future(self, sock):
849-
def _on_cancel(fut):
850-
# Check if the future was cancelled and if the socket
851-
# is still open, i.e.
852-
#
853-
# loop.remove_reader(sock)
854-
# sock.close()
855-
# fut.cancel()
856-
#
857-
# wasn't called by the user.
858-
if fut.cancelled() and sock.fileno() != -1:
859-
self._remove_reader(sock)
860-
861-
fut = self._new_future()
862-
fut.add_done_callback(_on_cancel)
863-
return fut
864-
865-
cdef _new_writer_future(self, sock):
866-
def _on_cancel(fut):
867-
if fut.cancelled() and sock.fileno() != -1:
868-
self._remove_writer(sock)
869-
870-
fut = self._new_future()
871-
fut.add_done_callback(_on_cancel)
872-
return fut
873-
874876
cdef _sock_recv(self, fut, sock, n):
875-
cdef:
876-
Handle handle
877+
if UVLOOP_DEBUG:
878+
if fut.cancelled():
879+
# Shouldn't happen with _SyncSocketReaderFuture.
880+
raise RuntimeError(
881+
f'_sock_recv is called on a cancelled Future')
882+
883+
if not self._has_reader(sock):
884+
raise RuntimeError(
885+
f'socket {sock!r} does not have a reader '
886+
f'in the _sock_recv callback')
877887

878888
try:
879889
data = sock.recv(n)
@@ -889,8 +899,16 @@ cdef class Loop:
889899
self._remove_reader(sock)
890900

891901
cdef _sock_recv_into(self, fut, sock, buf):
892-
cdef:
893-
Handle handle
902+
if UVLOOP_DEBUG:
903+
if fut.cancelled():
904+
# Shouldn't happen with _SyncSocketReaderFuture.
905+
raise RuntimeError(
906+
f'_sock_recv_into is called on a cancelled Future')
907+
908+
if not self._has_reader(sock):
909+
raise RuntimeError(
910+
f'socket {sock!r} does not have a reader '
911+
f'in the _sock_recv_into callback')
894912

895913
try:
896914
data = sock.recv_into(buf)
@@ -910,6 +928,17 @@ cdef class Loop:
910928
Handle handle
911929
int n
912930

931+
if UVLOOP_DEBUG:
932+
if fut.cancelled():
933+
# Shouldn't happen with _SyncSocketReaderFuture.
934+
raise RuntimeError(
935+
f'_sock_sendall is called on a cancelled Future')
936+
937+
if not self._has_writer(sock):
938+
raise RuntimeError(
939+
f'socket {sock!r} does not have a writer '
940+
f'in the _sock_sendall callback')
941+
913942
try:
914943
n = sock.send(data)
915944
except (BlockingIOError, InterruptedError):
@@ -940,9 +969,6 @@ cdef class Loop:
940969
self._add_writer(sock, handle)
941970

942971
cdef _sock_accept(self, fut, sock):
943-
cdef:
944-
Handle handle
945-
946972
try:
947973
conn, address = sock.accept()
948974
conn.setblocking(False)
@@ -2217,7 +2243,7 @@ cdef class Loop:
22172243
if self._debug and sock.gettimeout() != 0:
22182244
raise ValueError("the socket must be non-blocking")
22192245

2220-
fut = self._new_reader_future(sock)
2246+
fut = _SyncSocketReaderFuture(sock, self)
22212247
handle = new_MethodHandle3(
22222248
self,
22232249
"Loop._sock_recv",
@@ -2243,7 +2269,7 @@ cdef class Loop:
22432269
if self._debug and sock.gettimeout() != 0:
22442270
raise ValueError("the socket must be non-blocking")
22452271

2246-
fut = self._new_reader_future(sock)
2272+
fut = _SyncSocketReaderFuture(sock, self)
22472273
handle = new_MethodHandle3(
22482274
self,
22492275
"Loop._sock_recv_into",
@@ -2294,7 +2320,7 @@ cdef class Loop:
22942320
data = memoryview(data)
22952321
data = data[n:]
22962322

2297-
fut = self._new_writer_future(sock)
2323+
fut = _SyncSocketWriterFuture(sock, self)
22982324
handle = new_MethodHandle3(
22992325
self,
23002326
"Loop._sock_sendall",
@@ -2324,7 +2350,7 @@ cdef class Loop:
23242350
if self._debug and sock.gettimeout() != 0:
23252351
raise ValueError("the socket must be non-blocking")
23262352

2327-
fut = self._new_reader_future(sock)
2353+
fut = _SyncSocketReaderFuture(sock, self)
23282354
handle = new_MethodHandle2(
23292355
self,
23302356
"Loop._sock_accept",
@@ -2908,6 +2934,36 @@ cdef inline void __loop_free_buffer(Loop loop):
29082934
loop._recv_buffer_in_use = 0
29092935

29102936

2937+
class _SyncSocketReaderFuture(aio_Future):
2938+
2939+
def __init__(self, sock, loop):
2940+
aio_Future.__init__(self, loop=loop)
2941+
self.__sock = sock
2942+
self.__loop = loop
2943+
2944+
def cancel(self):
2945+
if self.__sock is not None and self.__sock.fileno() != -1:
2946+
self.__loop.remove_reader(self.__sock)
2947+
self.__sock = None
2948+
2949+
aio_Future.cancel(self)
2950+
2951+
2952+
class _SyncSocketWriterFuture(aio_Future):
2953+
2954+
def __init__(self, sock, loop):
2955+
aio_Future.__init__(self, loop=loop)
2956+
self.__sock = sock
2957+
self.__loop = loop
2958+
2959+
def cancel(self):
2960+
if self.__sock is not None and self.__sock.fileno() != -1:
2961+
self.__loop.remove_writer(self.__sock)
2962+
self.__sock = None
2963+
2964+
aio_Future.cancel(self)
2965+
2966+
29112967
include "cbhandles.pyx"
29122968
include "pseudosock.pyx"
29132969

0 commit comments

Comments
 (0)