Skip to content

Commit

Permalink
Make ConnectInfo work with ListenerExt::tap_io
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte committed Dec 1, 2024
1 parent 76c4ba9 commit 23864cd
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
14 changes: 12 additions & 2 deletions axum/src/extract/connect_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//!
//! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info

use crate::extension::AddExtension;
use crate::{extension::AddExtension, serve};

use super::{Extension, FromRequestParts};
use http::request::Parts;
Expand Down Expand Up @@ -84,7 +84,6 @@ pub trait Connected<T>: Clone + Send + Sync + 'static {

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
use crate::serve;
use tokio::net::TcpListener;

impl Connected<serve::IncomingStream<'_, TcpListener>> for SocketAddr {
Expand All @@ -100,6 +99,17 @@ impl Connected<SocketAddr> for SocketAddr {
}
}

impl<'a, L, F> Connected<serve::IncomingStream<'a, serve::TapIo<L, F>>> for L::Addr
where
L: serve::Listener,
L::Addr: Clone + Sync + 'static,
F: FnMut(&mut L::Io) + Send + 'static,
{
fn connect_info(stream: serve::IncomingStream<'a, serve::TapIo<L, F>>) -> Self {
stream.remote_addr().clone()
}
}

impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
where
S: Clone,
Expand Down
38 changes: 38 additions & 0 deletions axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ mod tests {
extract::connect_info::Connected,
handler::{Handler, HandlerWithoutStateExt},
routing::get,
serve::ListenerExt,
Router,
};

Expand All @@ -452,14 +453,29 @@ mod tests {

let addr = "0.0.0.0:0";

let tcp_nodelay_listener = || async {
TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| {
if let Err(err) = tcp_stream.set_nodelay(true) {
eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}");
}
})
};

// router
serve(TcpListener::bind(addr).await.unwrap(), router.clone());
serve(tcp_nodelay_listener().await, router.clone())
.await
.unwrap();
serve(UnixListener::bind("").unwrap(), router.clone());

serve(
TcpListener::bind(addr).await.unwrap(),
router.clone().into_make_service(),
);
serve(
tcp_nodelay_listener().await,
router.clone().into_make_service(),
);
serve(
UnixListener::bind("").unwrap(),
router.clone().into_make_service(),
Expand All @@ -471,19 +487,30 @@ mod tests {
.clone()
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
tcp_nodelay_listener().await,
router
.clone()
.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
router.into_make_service_with_connect_info::<UdsConnectInfo>(),
);

// method router
serve(TcpListener::bind(addr).await.unwrap(), get(handler));
serve(tcp_nodelay_listener().await, get(handler));
serve(UnixListener::bind("").unwrap(), get(handler));

serve(
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service(),
);
serve(
tcp_nodelay_listener().await,
get(handler).into_make_service(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service(),
Expand All @@ -493,6 +520,10 @@ mod tests {
TcpListener::bind(addr).await.unwrap(),
get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
tcp_nodelay_listener().await,
get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
Expand All @@ -503,24 +534,31 @@ mod tests {
TcpListener::bind(addr).await.unwrap(),
handler.into_service(),
);
serve(tcp_nodelay_listener().await, handler.into_service());
serve(UnixListener::bind("").unwrap(), handler.into_service());

serve(
TcpListener::bind(addr).await.unwrap(),
handler.with_state(()),
);
serve(tcp_nodelay_listener().await, handler.with_state(()));
serve(UnixListener::bind("").unwrap(), handler.with_state(()));

serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service(),
);
serve(tcp_nodelay_listener().await, handler.into_make_service());
serve(UnixListener::bind("").unwrap(), handler.into_make_service());

serve(
TcpListener::bind(addr).await.unwrap(),
handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
tcp_nodelay_listener().await,
handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
);
serve(
UnixListener::bind("").unwrap(),
handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
Expand Down

0 comments on commit 23864cd

Please sign in to comment.