From 25ade2537cb5cbfa139543366437878210ece6d3 Mon Sep 17 00:00:00 2001 From: John Nunley Date: Mon, 25 Sep 2023 11:22:26 -0700 Subject: [PATCH 1/4] Manage sources being inserted into kqueue Thus far, our kqueue implementation has been a relatively thin layer on top of the OS kqueue. However, kqueue doesn't keep track of when the same source is inserted twice, or when a source that doesn't exist is removed. In the interest of keeping consistent behavior between backends this commit adds a system for tracking when sources are inserted. Closes #151 Signed-off-by: John Nunley --- src/kqueue.rs | 77 +++++++++++++++++++++++++++++++++++++++++++++++- src/os/kqueue.rs | 35 ++++++++++++++++++++-- 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/src/kqueue.rs b/src/kqueue.rs index 40c481c..58b9935 100644 --- a/src/kqueue.rs +++ b/src/kqueue.rs @@ -1,7 +1,9 @@ //! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD). +use std::collections::HashSet; use std::io; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd}; +use std::sync::RwLock; use std::time::Duration; use rustix::event::kqueue; @@ -15,6 +17,11 @@ pub struct Poller { /// File descriptor for the kqueue instance. kqueue_fd: OwnedFd, + /// List of sources currently registered in this poller. + /// + /// This is used to make sure the same source is not registered twice. + sources: RwLock>, + /// Notification pipe for waking up the poller. /// /// On platforms that support `EVFILT_USER`, this uses that to wake up the poller. Otherwise, it @@ -22,6 +29,22 @@ pub struct Poller { notify: notify::Notify, } +/// Identifier for a source. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub(crate) enum SourceId { + /// Registered file descriptor. + Fd(RawFd), + + /// Signal. + Signal(std::os::raw::c_int), + + /// Process ID. + Pid(rustix::process::Pid), + + /// Timer ID. + Timer(usize), +} + impl Poller { /// Creates a new poller. pub fn new() -> io::Result { @@ -31,6 +54,7 @@ impl Poller { let poller = Poller { kqueue_fd, + sources: RwLock::new(HashSet::new()), notify: notify::Notify::new()?, }; @@ -60,6 +84,8 @@ impl Poller { /// /// The file descriptor must be valid and it must last until it is deleted. pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { + self.add_source(SourceId::Fd(fd))?; + // File descriptors don't need to be added explicitly, so just modify the interest. self.modify(BorrowedFd::borrow_raw(fd), ev, mode) } @@ -79,6 +105,8 @@ impl Poller { }; let _enter = span.as_ref().map(|s| s.enter()); + self.has_source(SourceId::Fd(fd.as_raw_fd()))?; + let mode_flags = mode_to_flags(mode); let read_flags = if ev.readable { @@ -143,10 +171,57 @@ impl Poller { Ok(()) } + /// Add a source to the sources set. + #[inline] + pub(crate) fn add_source(&self, source: SourceId) -> io::Result<()> { + if self + .sources + .write() + .unwrap_or_else(|e| e.into_inner()) + .insert(source) + { + Ok(()) + } else { + Err(io::Error::from(io::ErrorKind::AlreadyExists)) + } + } + + /// Tell if a source is currently inside the set. + #[inline] + pub(crate) fn has_source(&self, source: SourceId) -> io::Result<()> { + if self + .sources + .read() + .unwrap_or_else(|e| e.into_inner()) + .contains(&source) + { + Ok(()) + } else { + Err(io::Error::from(io::ErrorKind::NotFound)) + } + } + + /// Remove a source from the sources set. + #[inline] + pub(crate) fn remove_source(&self, source: SourceId) -> io::Result<()> { + if self + .sources + .write() + .unwrap_or_else(|e| e.into_inner()) + .remove(&source) + { + Ok(()) + } else { + Err(io::Error::from(io::ErrorKind::NotFound)) + } + } + /// Deletes a file descriptor. pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> { // Simply delete interest in the file descriptor. - self.modify(fd, Event::none(0), PollMode::Oneshot) + self.modify(fd, Event::none(0), PollMode::Oneshot)?; + + self.remove_source(SourceId::Fd(fd.as_raw_fd())) } /// Waits for I/O events with an optional timeout. diff --git a/src/os/kqueue.rs b/src/os/kqueue.rs index f670527..e08136d 100644 --- a/src/os/kqueue.rs +++ b/src/os/kqueue.rs @@ -1,6 +1,8 @@ //! Functionality that is only available for `kqueue`-based platforms. -use crate::sys::mode_to_flags; +#![allow(private_interfaces)] + +use crate::sys::{mode_to_flags, SourceId}; use crate::{PollMode, Poller}; use std::io; @@ -98,10 +100,13 @@ impl PollerKqueueExt for Poller { #[inline(always)] fn add_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> { // No difference between adding and modifying in kqueue. + self.poller.add_source(filter.source_id())?; self.modify_filter(filter, key, mode) } fn modify_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> { + self.poller.has_source(filter.source_id())?; + // Convert the filter into a kevent. let event = filter.filter(kqueue::EventFlags::ADD | mode_to_flags(mode), key); @@ -114,7 +119,9 @@ impl PollerKqueueExt for Poller { let event = filter.filter(kqueue::EventFlags::DELETE, 0); // Delete the filter. - self.poller.submit_changes([event]) + self.poller.submit_changes([event])?; + + self.poller.remove_source(filter.source_id()) } } @@ -126,6 +133,11 @@ unsafe impl FilterSealed for &T { fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event { (**self).filter(flags, key) } + + #[inline(always)] + fn source_id(&self) -> SourceId { + (**self).source_id() + } } impl Filter for &T {} @@ -149,6 +161,11 @@ unsafe impl FilterSealed for Signal { key as _, ) } + + #[inline(always)] + fn source_id(&self) -> SourceId { + SourceId::Signal(self.0) + } } impl Filter for Signal {} @@ -207,6 +224,11 @@ unsafe impl FilterSealed for Process<'_> { key as _, ) } + + #[inline(always)] + fn source_id(&self) -> SourceId { + SourceId::Pid(rustix::process::Pid::from_child(self.child)) + } } impl Filter for Process<'_> {} @@ -234,11 +256,17 @@ unsafe impl FilterSealed for Timer { key as _, ) } + + #[inline(always)] + fn source_id(&self) -> SourceId { + SourceId::Timer(self.id) + } } impl Filter for Timer {} mod __private { + use crate::sys::SourceId; use rustix::event::kqueue; #[doc(hidden)] @@ -247,5 +275,8 @@ mod __private { /// /// This filter's flags must have `EV_RECEIPT`. fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event; + + /// Get the source ID for this source. + fn source_id(&self) -> SourceId; } } From 7d9e452e9ac55f923dbee13dede6e3d727273bbf Mon Sep 17 00:00:00 2001 From: John Nunley Date: Mon, 25 Sep 2023 11:27:02 -0700 Subject: [PATCH 2/4] Fix nightly-exempt issue Signed-off-by: John Nunley --- src/kqueue.rs | 3 ++- src/os/kqueue.rs | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/kqueue.rs b/src/kqueue.rs index 58b9935..359305b 100644 --- a/src/kqueue.rs +++ b/src/kqueue.rs @@ -30,8 +30,9 @@ pub struct Poller { } /// Identifier for a source. +#[doc(hidden)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub(crate) enum SourceId { +pub enum SourceId { /// Registered file descriptor. Fd(RawFd), diff --git a/src/os/kqueue.rs b/src/os/kqueue.rs index e08136d..aae36c3 100644 --- a/src/os/kqueue.rs +++ b/src/os/kqueue.rs @@ -1,7 +1,5 @@ //! Functionality that is only available for `kqueue`-based platforms. -#![allow(private_interfaces)] - use crate::sys::{mode_to_flags, SourceId}; use crate::{PollMode, Poller}; From 07c032831ddaf4025216b96bc962307abcfb9fe0 Mon Sep 17 00:00:00 2001 From: John Nunley Date: Mon, 25 Sep 2023 13:01:07 -0700 Subject: [PATCH 3/4] Add test for this use case Signed-off-by: John Nunley --- tests/io.rs | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/io.rs b/tests/io.rs index 2e6ce04..a6dc25c 100644 --- a/tests/io.rs +++ b/tests/io.rs @@ -1,6 +1,7 @@ use polling::{Event, Events, Poller}; use std::io::{self, Write}; use std::net::{TcpListener, TcpStream}; +use std::sync::Arc; use std::time::Duration; #[test] @@ -38,6 +39,38 @@ fn basic_io() { poller.delete(&read).unwrap(); } +#[test] +fn insert_twice() { + let (read, mut write) = tcp_pair().unwrap(); + let read = Arc::new(read); + + let poller = Poller::new().unwrap(); + unsafe { + poller.add(&read, Event::readable(1)).unwrap(); + assert_eq!( + poller.add(&read, Event::readable(1)).unwrap_err().kind(), + io::ErrorKind::AlreadyExists + ); + } + + write.write_all(&[1]).unwrap(); + let mut events = Events::new(); + assert_eq!( + poller + .wait(&mut events, Some(Duration::from_secs(1))) + .unwrap(), + 1 + ); + + assert_eq!(events.len(), 1); + assert_eq!( + events.iter().next().unwrap().with_no_extra(), + Event::readable(1) + ); + + poller.delete(&read).unwrap(); +} + fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> { let listener = TcpListener::bind("127.0.0.1:0")?; let a = TcpStream::connect(listener.local_addr()?)?; From 2de376e6643a0a300b8ebe8ffafe14739ed04905 Mon Sep 17 00:00:00 2001 From: John Nunley Date: Mon, 25 Sep 2023 13:09:51 -0700 Subject: [PATCH 4/4] Fix Windows compile failure Signed-off-by: John Nunley --- tests/io.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/io.rs b/tests/io.rs index a6dc25c..dc42103 100644 --- a/tests/io.rs +++ b/tests/io.rs @@ -41,14 +41,24 @@ fn basic_io() { #[test] fn insert_twice() { + #[cfg(unix)] + use std::os::unix::io::AsRawFd; + #[cfg(windows)] + use std::os::windows::io::AsRawSocket; + let (read, mut write) = tcp_pair().unwrap(); let read = Arc::new(read); let poller = Poller::new().unwrap(); unsafe { - poller.add(&read, Event::readable(1)).unwrap(); + #[cfg(unix)] + let read = read.as_raw_fd(); + #[cfg(windows)] + let read = read.as_raw_socket(); + + poller.add(read, Event::readable(1)).unwrap(); assert_eq!( - poller.add(&read, Event::readable(1)).unwrap_err().kind(), + poller.add(read, Event::readable(1)).unwrap_err().kind(), io::ErrorKind::AlreadyExists ); }