Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : add the ability to identify if tcp connection has failed #185

Merged
merged 14 commits into from
Jan 26, 2024
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ features = [
[dev-dependencies]
easy-parallel = "3.1.0"
fastrand = "2.0.0"
socket2 = "0.5.5"

[target.'cfg(unix)'.dev-dependencies]
libc = "0.2"
Expand Down
36 changes: 36 additions & 0 deletions examples/tcp_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use std::{io, net};

use polling::Event;
use socket2::Type;

fn main() -> io::Result<()> {
let socket = socket2::Socket::new(socket2::Domain::IPV4, Type::STREAM, None)?;
let poller = polling::Poller::new()?;
unsafe {
poller.add(&socket, Event::new(0, true, true))?;
}
let addr = net::SocketAddr::new(net::Ipv4Addr::LOCALHOST.into(), 8080);
socket.set_nonblocking(true)?;
let _ = socket.connect(&addr.into());

let mut events = polling::Events::new();

events.clear();
poller.wait(&mut events, None)?;

let event = events.iter().next();
let event = match event {
Some(event) => event,
None => {
println!("no event");
return Ok(());
}
};

println!("event: {:?}", event);
if event.is_connect_failed().unwrap_or_default() {
println!("connect failed");
}

Ok(())
}
8 changes: 8 additions & 0 deletions src/epoll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,14 @@ impl EventExtra {
pub fn is_pri(&self) -> bool {
self.flags.contains(epoll::EventFlags::PRI)
}

#[inline]
pub fn is_connect_failed(&self) -> Option<bool> {
Some(
self.flags.contains(epoll::EventFlags::ERR)
|| self.flags.contains(epoll::EventFlags::HUP),
)
}
}

/// The notifier for Linux.
Expand Down
6 changes: 6 additions & 0 deletions src/iocp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,12 @@ impl EventExtra {
pub fn set_pri(&mut self, active: bool) {
self.flags.set(AfdPollMask::RECEIVE_EXPEDITED, active);
}

/// Check if TCP connect failed.
#[inline]
pub fn is_connect_failed(&self) -> Option<bool> {
Some(self.flags.intersects(AfdPollMask::CONNECT_FAIL))
}
}

/// A packet used to wake up the poller with an event.
Expand Down
5 changes: 5 additions & 0 deletions src/kqueue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,11 @@ impl EventExtra {
pub fn is_pri(&self) -> bool {
false
}

#[inline]
pub fn is_connect_failed(&self) -> Option<bool> {
None
}
}

pub(crate) fn mode_to_flags(mode: PollMode) -> kqueue::EventFlags {
Expand Down
59 changes: 59 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,65 @@ impl Event {
self.extra.is_pri()
}

/// Tells if this event is the result of a connection failure.
///
/// This function checks if a TCP connection has failed. It corresponds to the `EPOLLERR` or `EPOLLHUP` event in Linux
/// and `CONNECT_FAILED` event in Windows IOCP.
///
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have an example.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

/// # Examples
///
/// ```
/// use std::{io, net};
/// // Assuming polling and socket2 are included as dependencies in Cargo.toml
/// use polling::Event;
/// use socket2::Type;
///
/// fn main() -> io::Result<()> {
/// let socket = socket2::Socket::new(socket2::Domain::IPV4, Type::STREAM, None)?;
/// let poller = polling::Poller::new()?;
/// unsafe {
/// poller.add(&socket, Event::new(0, true, true))?;
/// }
/// let addr = net::SocketAddr::new(net::Ipv4Addr::LOCALHOST.into(), 8080);
/// socket.set_nonblocking(true)?;
/// let _ = socket.connect(&addr.into());
///
/// let mut events = polling::Events::new();
///
/// events.clear();
/// poller.wait(&mut events, None)?;
///
/// let event = events.iter().next();
///
/// let event = match event {
/// Some(event) => event,
/// None => {
/// println!("no event");
/// return Ok(());
/// },
/// };
///
/// println!("event: {:?}", event);
/// if event
/// .is_connect_failed()
/// .unwrap_or_default()
/// {
/// println!("connect failed");
/// }
///
/// Ok(())
/// }
/// ```
///
/// # Returns
///
/// Returns `Some(true)` if the connection has failed, `Some(false)` if the connection has not failed,
/// or `None` if the platform does not support detecting this condition.
#[inline]
pub fn is_connect_failed(&self) -> Option<bool> {
self.extra.is_connect_failed()
}

/// Remove any extra information from this event.
#[inline]
pub fn clear_extra(&mut self) {
Expand Down
5 changes: 5 additions & 0 deletions src/poll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,11 @@ impl EventExtra {
pub fn is_pri(&self) -> bool {
self.flags.contains(PollFlags::PRI)
}

#[inline]
pub fn is_connect_failed(&self) -> Option<bool> {
Some(self.flags.contains(PollFlags::ERR) || self.flags.contains(PollFlags::HUP))
}
}

fn cvt_mode_as_remove(mode: PollMode) -> io::Result<bool> {
Expand Down
5 changes: 5 additions & 0 deletions src/port.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,9 @@ impl EventExtra {
pub fn is_pri(&self) -> bool {
self.flags.contains(PollFlags::PRI)
}

#[inline]
pub fn is_connect_failed(&self) -> Option<bool> {
Some(self.flags.contains(PollFlags::ERR) || self.flags.contains(PollFlags::HUP))
}
}
Loading