Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manage sources being inserted into kqueue #153

Merged
merged 4 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion src/kqueue.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
//! Bindings to kqueue (macOS, iOS, tvOS, watchOS, FreeBSD, NetBSD, OpenBSD, DragonFly BSD).

use std::collections::HashSet;
use std::io;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::sync::RwLock;
use std::time::Duration;

use rustix::event::kqueue;
Expand All @@ -15,13 +17,35 @@ pub struct Poller {
/// File descriptor for the kqueue instance.
kqueue_fd: OwnedFd,

/// List of sources currently registered in this poller.
///
/// This is used to make sure the same source is not registered twice.
sources: RwLock<HashSet<SourceId>>,

/// Notification pipe for waking up the poller.
///
/// On platforms that support `EVFILT_USER`, this uses that to wake up the poller. Otherwise, it
/// uses a pipe.
notify: notify::Notify,
}

/// Identifier for a source.
#[doc(hidden)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum SourceId {
/// Registered file descriptor.
Fd(RawFd),

/// Signal.
Signal(std::os::raw::c_int),

/// Process ID.
Pid(rustix::process::Pid),

/// Timer ID.
Timer(usize),
}

impl Poller {
/// Creates a new poller.
pub fn new() -> io::Result<Poller> {
Expand All @@ -31,6 +55,7 @@ impl Poller {

let poller = Poller {
kqueue_fd,
sources: RwLock::new(HashSet::new()),
notify: notify::Notify::new()?,
};

Expand Down Expand Up @@ -60,6 +85,8 @@ impl Poller {
///
/// The file descriptor must be valid and it must last until it is deleted.
pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> {
self.add_source(SourceId::Fd(fd))?;

// File descriptors don't need to be added explicitly, so just modify the interest.
self.modify(BorrowedFd::borrow_raw(fd), ev, mode)
}
Expand All @@ -79,6 +106,8 @@ impl Poller {
};
let _enter = span.as_ref().map(|s| s.enter());

self.has_source(SourceId::Fd(fd.as_raw_fd()))?;

let mode_flags = mode_to_flags(mode);

let read_flags = if ev.readable {
Expand Down Expand Up @@ -143,10 +172,57 @@ impl Poller {
Ok(())
}

/// Add a source to the sources set.
#[inline]
pub(crate) fn add_source(&self, source: SourceId) -> io::Result<()> {
if self
.sources
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(source)
{
Ok(())
} else {
Err(io::Error::from(io::ErrorKind::AlreadyExists))
}
}

/// Tell if a source is currently inside the set.
#[inline]
pub(crate) fn has_source(&self, source: SourceId) -> io::Result<()> {
if self
.sources
.read()
.unwrap_or_else(|e| e.into_inner())
.contains(&source)
{
Ok(())
} else {
Err(io::Error::from(io::ErrorKind::NotFound))
}
}

/// Remove a source from the sources set.
#[inline]
pub(crate) fn remove_source(&self, source: SourceId) -> io::Result<()> {
if self
.sources
.write()
.unwrap_or_else(|e| e.into_inner())
.remove(&source)
{
Ok(())
} else {
Err(io::Error::from(io::ErrorKind::NotFound))
}
}

/// Deletes a file descriptor.
pub fn delete(&self, fd: BorrowedFd<'_>) -> io::Result<()> {
// Simply delete interest in the file descriptor.
self.modify(fd, Event::none(0), PollMode::Oneshot)
self.modify(fd, Event::none(0), PollMode::Oneshot)?;

self.remove_source(SourceId::Fd(fd.as_raw_fd()))
}

/// Waits for I/O events with an optional timeout.
Expand Down
33 changes: 31 additions & 2 deletions src/os/kqueue.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Functionality that is only available for `kqueue`-based platforms.

use crate::sys::mode_to_flags;
use crate::sys::{mode_to_flags, SourceId};
use crate::{PollMode, Poller};

use std::io;
Expand Down Expand Up @@ -98,10 +98,13 @@ impl<F: Filter> PollerKqueueExt<F> for Poller {
#[inline(always)]
fn add_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
// No difference between adding and modifying in kqueue.
self.poller.add_source(filter.source_id())?;
self.modify_filter(filter, key, mode)
}

fn modify_filter(&self, filter: F, key: usize, mode: PollMode) -> io::Result<()> {
self.poller.has_source(filter.source_id())?;

// Convert the filter into a kevent.
let event = filter.filter(kqueue::EventFlags::ADD | mode_to_flags(mode), key);

Expand All @@ -114,7 +117,9 @@ impl<F: Filter> PollerKqueueExt<F> for Poller {
let event = filter.filter(kqueue::EventFlags::DELETE, 0);

// Delete the filter.
self.poller.submit_changes([event])
self.poller.submit_changes([event])?;

self.poller.remove_source(filter.source_id())
}
}

Expand All @@ -126,6 +131,11 @@ unsafe impl<T: FilterSealed + ?Sized> FilterSealed for &T {
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event {
(**self).filter(flags, key)
}

#[inline(always)]
fn source_id(&self) -> SourceId {
(**self).source_id()
}
}

impl<T: Filter + ?Sized> Filter for &T {}
Expand All @@ -149,6 +159,11 @@ unsafe impl FilterSealed for Signal {
key as _,
)
}

#[inline(always)]
fn source_id(&self) -> SourceId {
SourceId::Signal(self.0)
}
}

impl Filter for Signal {}
Expand Down Expand Up @@ -207,6 +222,11 @@ unsafe impl FilterSealed for Process<'_> {
key as _,
)
}

#[inline(always)]
fn source_id(&self) -> SourceId {
SourceId::Pid(rustix::process::Pid::from_child(self.child))
}
}

impl Filter for Process<'_> {}
Expand Down Expand Up @@ -234,11 +254,17 @@ unsafe impl FilterSealed for Timer {
key as _,
)
}

#[inline(always)]
fn source_id(&self) -> SourceId {
SourceId::Timer(self.id)
}
}

impl Filter for Timer {}

mod __private {
use crate::sys::SourceId;
use rustix::event::kqueue;

#[doc(hidden)]
Expand All @@ -247,5 +273,8 @@ mod __private {
///
/// This filter's flags must have `EV_RECEIPT`.
fn filter(&self, flags: kqueue::EventFlags, key: usize) -> kqueue::Event;

/// Get the source ID for this source.
fn source_id(&self) -> SourceId;
}
}
43 changes: 43 additions & 0 deletions tests/io.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polling::{Event, Events, Poller};
use std::io::{self, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use std::time::Duration;

#[test]
Expand Down Expand Up @@ -38,6 +39,48 @@ fn basic_io() {
poller.delete(&read).unwrap();
}

#[test]
fn insert_twice() {
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;

let (read, mut write) = tcp_pair().unwrap();
let read = Arc::new(read);

let poller = Poller::new().unwrap();
unsafe {
#[cfg(unix)]
let read = read.as_raw_fd();
#[cfg(windows)]
let read = read.as_raw_socket();

poller.add(read, Event::readable(1)).unwrap();
assert_eq!(
poller.add(read, Event::readable(1)).unwrap_err().kind(),
io::ErrorKind::AlreadyExists
);
}

write.write_all(&[1]).unwrap();
let mut events = Events::new();
assert_eq!(
poller
.wait(&mut events, Some(Duration::from_secs(1)))
.unwrap(),
1
);

assert_eq!(events.len(), 1);
assert_eq!(
events.iter().next().unwrap().with_no_extra(),
Event::readable(1)
);

poller.delete(&read).unwrap();
}

fn tcp_pair() -> io::Result<(TcpStream, TcpStream)> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let a = TcpStream::connect(listener.local_addr()?)?;
Expand Down