diff --git a/examples/add_route.rs b/examples/add_route.rs index 27e7bcf..ab7ff81 100644 --- a/examples/add_route.rs +++ b/examples/add_route.rs @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -use std::env; +use std::{env, net::Ipv4Addr}; use ipnetwork::Ipv4Network; -use rtnetlink::{new_connection, Error, Handle}; +use rtnetlink::{new_connection, Error, Handle, RouteMessageBuilder}; const TEST_TABLE_ID: u32 = 299; @@ -40,15 +40,12 @@ async fn add_route( gateway: &Ipv4Network, handle: Handle, ) -> Result<(), Error> { - let route = handle.route(); - route - .add() - .v4() + let route = RouteMessageBuilder::::new() .destination_prefix(dest.ip(), dest.prefix()) .gateway(gateway.ip()) .table_id(TEST_TABLE_ID) - .execute() - .await?; + .build(); + handle.route().add(route).execute().await?; Ok(()) } diff --git a/examples/add_route_pref_src.rs b/examples/add_route_pref_src.rs index 17002ae..3cb42e7 100644 --- a/examples/add_route_pref_src.rs +++ b/examples/add_route_pref_src.rs @@ -4,7 +4,7 @@ use futures::TryStreamExt; use std::{env, net::Ipv4Addr}; use ipnetwork::Ipv4Network; -use rtnetlink::{new_connection, Error, Handle}; +use rtnetlink::{new_connection, Error, Handle, RouteMessageBuilder}; #[tokio::main] async fn main() -> Result<(), ()> { @@ -53,15 +53,12 @@ async fn add_route( .header .index; - let route = handle.route(); - route - .add() - .v4() + let route = RouteMessageBuilder::::new() .destination_prefix(dest.ip(), dest.prefix()) .output_interface(iface_idx) .pref_source(source) - .execute() - .await?; + .build(); + handle.route().add(route).execute().await?; Ok(()) } diff --git a/examples/listen.rs b/examples/listen.rs index b7c25c6..7f20452 100644 --- a/examples/listen.rs +++ b/examples/listen.rs @@ -6,7 +6,10 @@ use futures::stream::StreamExt; use netlink_sys::{AsyncSocket, SocketAddr}; use rtnetlink::{ - constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE}, + constants::{ + RTMGRP_IPV4_IFADDR, RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_IFADDR, + RTMGRP_IPV6_ROUTE, RTMGRP_LINK, + }, new_connection, }; @@ -18,7 +21,11 @@ async fn main() -> Result<(), String> { // These flags specify what kinds of broadcast messages we want to listen // for. - let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE; + let mgroup_flags = RTMGRP_LINK + | RTMGRP_IPV4_IFADDR + | RTMGRP_IPV4_ROUTE + | RTMGRP_IPV6_IFADDR + | RTMGRP_IPV6_ROUTE; // A netlink socket address is created with said flags. let addr = SocketAddr::new(0, mgroup_flags); diff --git a/src/route/add.rs b/src/route/add.rs index 1416049..ebf8067 100644 --- a/src/route/add.rs +++ b/src/route/add.rs @@ -1,28 +1,19 @@ // SPDX-License-Identifier: MIT use futures::stream::StreamExt; -use std::{ - marker::PhantomData, - net::{Ipv4Addr, Ipv6Addr}, -}; +use std::{marker::PhantomData, net::IpAddr}; use netlink_packet_core::{ NetlinkMessage, NLM_F_ACK, NLM_F_CREATE, NLM_F_EXCL, NLM_F_REPLACE, NLM_F_REQUEST, }; -use netlink_packet_route::{ - route::{ - RouteAddress, RouteAttribute, RouteHeader, RouteMessage, RouteProtocol, - RouteScope, RouteType, - }, - AddressFamily, RouteNetlinkMessage, -}; +use netlink_packet_route::{route::RouteMessage, RouteNetlinkMessage}; use crate::{try_nl, Error, Handle}; /// A request to create a new route. This is equivalent to the `ip route add` /// commands. -pub struct RouteAddRequest { +pub struct RouteAddRequest { handle: Handle, message: RouteMessage, replace: bool, @@ -30,14 +21,7 @@ pub struct RouteAddRequest { } impl RouteAddRequest { - pub(crate) fn new(handle: Handle) -> Self { - let mut message = RouteMessage::default(); - - message.header.table = RouteHeader::RT_TABLE_MAIN; - message.header.protocol = RouteProtocol::Static; - message.header.scope = RouteScope::Universe; - message.header.kind = RouteType::Unicast; - + pub(crate) fn new(handle: Handle, message: RouteMessage) -> Self { RouteAddRequest { handle, message, @@ -46,91 +30,8 @@ impl RouteAddRequest { } } - /// Sets the input interface index. - pub fn input_interface(mut self, index: u32) -> Self { - self.message.attributes.push(RouteAttribute::Iif(index)); - self - } - - /// Sets the output interface index. - pub fn output_interface(mut self, index: u32) -> Self { - self.message.attributes.push(RouteAttribute::Oif(index)); - self - } - - /// Sets the route priority (metric) - pub fn priority(mut self, priority: u32) -> Self { - self.message - .attributes - .push(RouteAttribute::Priority(priority)); - self - } - - /// Sets the route table. - /// - /// Default is main route table. - #[deprecated(note = "Please use `table_id` instead")] - pub fn table(mut self, table: u8) -> Self { - self.message.header.table = table; - self - } - - /// Sets the route table ID. - /// - /// Default is main route table. - pub fn table_id(mut self, table: u32) -> Self { - if table > 255 { - self.message.attributes.push(RouteAttribute::Table(table)); - } else { - self.message.header.table = table as u8; - } - self - } - - /// Sets the route protocol. - /// - /// Default is static route protocol. - pub fn protocol(mut self, protocol: RouteProtocol) -> Self { - self.message.header.protocol = protocol; - self - } - - /// Sets the route scope. - /// - /// Default is universe route scope. - pub fn scope(mut self, scope: RouteScope) -> Self { - self.message.header.scope = scope; - self - } - - /// Sets the route kind. - /// - /// Default is unicast route kind. - pub fn kind(mut self, kind: RouteType) -> Self { - self.message.header.kind = kind; - self - } - - /// Build an IP v4 route request - pub fn v4(mut self) -> RouteAddRequest { - self.message.header.address_family = AddressFamily::Inet; - RouteAddRequest { - handle: self.handle, - message: self.message, - replace: false, - _phantom: Default::default(), - } - } - - /// Build an IP v6 route request - pub fn v6(mut self) -> RouteAddRequest { - self.message.header.address_family = AddressFamily::Inet6; - RouteAddRequest { - handle: self.handle, - message: self.message, - replace: false, - _phantom: Default::default(), - } + pub fn message_mut(&mut self) -> &mut RouteMessage { + &mut self.message } /// Replace existing matching route. @@ -160,89 +61,4 @@ impl RouteAddRequest { } Ok(()) } - - /// Return a mutable reference to the request message. - pub fn message_mut(&mut self) -> &mut RouteMessage { - &mut self.message - } -} - -impl RouteAddRequest { - /// Sets the source address prefix. - pub fn source_prefix(mut self, addr: Ipv4Addr, prefix_length: u8) -> Self { - self.message.header.source_prefix_length = prefix_length; - self.message - .attributes - .push(RouteAttribute::Source(RouteAddress::Inet(addr))); - self - } - - /// Sets the preferred source address. - pub fn pref_source(mut self, addr: Ipv4Addr) -> Self { - self.message - .attributes - .push(RouteAttribute::PrefSource(RouteAddress::Inet(addr))); - self - } - - /// Sets the destination address prefix. - pub fn destination_prefix( - mut self, - addr: Ipv4Addr, - prefix_length: u8, - ) -> Self { - self.message.header.destination_prefix_length = prefix_length; - self.message - .attributes - .push(RouteAttribute::Destination(RouteAddress::Inet(addr))); - self - } - - /// Sets the gateway (via) address. - pub fn gateway(mut self, addr: Ipv4Addr) -> Self { - self.message - .attributes - .push(RouteAttribute::Gateway(RouteAddress::Inet(addr))); - self - } -} - -impl RouteAddRequest { - /// Sets the source address prefix. - pub fn source_prefix(mut self, addr: Ipv6Addr, prefix_length: u8) -> Self { - self.message.header.source_prefix_length = prefix_length; - self.message - .attributes - .push(RouteAttribute::Source(RouteAddress::Inet6(addr))); - self - } - - /// Sets the preferred source address. - pub fn pref_source(mut self, addr: Ipv6Addr) -> Self { - self.message - .attributes - .push(RouteAttribute::PrefSource(RouteAddress::Inet6(addr))); - self - } - - /// Sets the destination address prefix. - pub fn destination_prefix( - mut self, - addr: Ipv6Addr, - prefix_length: u8, - ) -> Self { - self.message.header.destination_prefix_length = prefix_length; - self.message - .attributes - .push(RouteAttribute::Destination(RouteAddress::Inet6(addr))); - self - } - - /// Sets the gateway (via) address. - pub fn gateway(mut self, addr: Ipv6Addr) -> Self { - self.message - .attributes - .push(RouteAttribute::Gateway(RouteAddress::Inet6(addr))); - self - } } diff --git a/src/route/builder.rs b/src/route/builder.rs new file mode 100644 index 0000000..c8024c1 --- /dev/null +++ b/src/route/builder.rs @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: MIT + +use std::{ + marker::PhantomData, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, +}; + +use netlink_packet_route::{ + route::{ + RouteAddress, RouteAttribute, RouteHeader, RouteMessage, RouteProtocol, + RouteScope, RouteType, + }, + AddressFamily, +}; + +pub struct RouteMessageBuilder { + message: RouteMessage, + _phantom: PhantomData, +} + +impl RouteMessageBuilder { + fn new_no_address_family() -> Self { + let mut message = RouteMessage::default(); + message.header.table = RouteHeader::RT_TABLE_MAIN; + message.header.protocol = RouteProtocol::Static; + message.header.scope = RouteScope::Universe; + message.header.kind = RouteType::Unicast; + Self { + message, + _phantom: Default::default(), + } + } + + /// Sets the input interface index. + pub fn input_interface(mut self, index: u32) -> Self { + self.message.attributes.push(RouteAttribute::Iif(index)); + self + } + + /// Sets the output interface index. + pub fn output_interface(mut self, index: u32) -> Self { + self.message.attributes.push(RouteAttribute::Oif(index)); + self + } + + /// Sets the route priority (metric) + pub fn priority(mut self, priority: u32) -> Self { + self.message + .attributes + .push(RouteAttribute::Priority(priority)); + self + } + + /// Sets the route table ID. + /// + /// Default is main route table. + pub fn table_id(mut self, table: u32) -> Self { + if table > 255 { + self.message.attributes.push(RouteAttribute::Table(table)); + } else { + self.message.header.table = table as u8; + } + self + } + + /// Sets the route protocol. + /// + /// Default is static route protocol. + pub fn protocol(mut self, protocol: RouteProtocol) -> Self { + self.message.header.protocol = protocol; + self + } + + /// Sets the route scope. + /// + /// Default is universe route scope. + pub fn scope(mut self, scope: RouteScope) -> Self { + self.message.header.scope = scope; + self + } + + /// Sets the route kind. + /// + /// Default is unicast route kind. + pub fn kind(mut self, kind: RouteType) -> Self { + self.message.header.kind = kind; + self + } + + /// Return a mutable reference to the request message. + pub fn get_mut(&mut self) -> &mut RouteMessage { + &mut self.message + } + + pub fn build(self) -> RouteMessage { + self.message + } +} + +impl RouteMessageBuilder { + pub fn new() -> Self { + let mut builder = Self::new_no_address_family(); + builder.get_mut().header.address_family = AddressFamily::Inet; + builder + } + + /// Sets the source address prefix. + pub fn source_prefix(mut self, addr: Ipv4Addr, prefix_length: u8) -> Self { + self.message.header.source_prefix_length = prefix_length; + self.message + .attributes + .push(RouteAttribute::Source(RouteAddress::Inet(addr))); + self + } + + /// Sets the preferred source address. + pub fn pref_source(mut self, addr: Ipv4Addr) -> Self { + self.message + .attributes + .push(RouteAttribute::PrefSource(RouteAddress::Inet(addr))); + self + } + + /// Sets the destination address prefix. + pub fn destination_prefix( + mut self, + addr: Ipv4Addr, + prefix_length: u8, + ) -> Self { + self.message.header.destination_prefix_length = prefix_length; + self.message + .attributes + .push(RouteAttribute::Destination(RouteAddress::Inet(addr))); + self + } + + /// Sets the gateway (via) address. + pub fn gateway(mut self, addr: Ipv4Addr) -> Self { + self.message + .attributes + .push(RouteAttribute::Gateway(RouteAddress::Inet(addr))); + self + } +} + +impl Default for RouteMessageBuilder { + fn default() -> Self { + Self::new() + } +} + +impl RouteMessageBuilder { + pub fn new() -> Self { + let mut builder = Self::new_no_address_family(); + builder.get_mut().header.address_family = AddressFamily::Inet6; + builder + } + + /// Sets the source address prefix. + pub fn source_prefix(mut self, addr: Ipv6Addr, prefix_length: u8) -> Self { + self.message.header.source_prefix_length = prefix_length; + self.message + .attributes + .push(RouteAttribute::Source(RouteAddress::Inet6(addr))); + self + } + + /// Sets the preferred source address. + pub fn pref_source(mut self, addr: Ipv6Addr) -> Self { + self.message + .attributes + .push(RouteAttribute::PrefSource(RouteAddress::Inet6(addr))); + self + } + + /// Sets the destination address prefix. + pub fn destination_prefix( + mut self, + addr: Ipv6Addr, + prefix_length: u8, + ) -> Self { + self.message.header.destination_prefix_length = prefix_length; + self.message + .attributes + .push(RouteAttribute::Destination(RouteAddress::Inet6(addr))); + self + } + + /// Sets the gateway (via) address. + pub fn gateway(mut self, addr: Ipv6Addr) -> Self { + self.message + .attributes + .push(RouteAttribute::Gateway(RouteAddress::Inet6(addr))); + self + } +} + +impl Default for RouteMessageBuilder { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, thiserror::Error)] +pub enum InvalidRouteMessage { + #[error("invalid address family {:?}", _0)] + AddressFamily(AddressFamily), + + #[error("invalid gateway {}", _0)] + Gateway(IpAddr), + + #[error("invalid preferred source {}", _0)] + PrefSource(IpAddr), + + #[error("invalid source prefix {}/{}", _0, _1)] + SourcePrefix(IpAddr, u8), + + #[error("invalid destination prefix {}/{}", _0, _1)] + DestinationPrefix(IpAddr, u8), +} + +impl RouteMessageBuilder { + pub fn new() -> Self { + Self::new_no_address_family() + } + + /// Sets the source address prefix. + pub fn source_prefix( + mut self, + addr: IpAddr, + prefix_length: u8, + ) -> Result { + self.set_address_family_from_ip_addr(addr); + match self.message.header.address_family { + AddressFamily::Inet => { + if addr.is_ipv6() || prefix_length > 32 { + return Err(InvalidRouteMessage::SourcePrefix( + addr, + prefix_length, + )); + } + } + AddressFamily::Inet6 => { + if addr.is_ipv4() || prefix_length > 128 { + return Err(InvalidRouteMessage::SourcePrefix( + addr, + prefix_length, + )); + } + } + af => return Err(InvalidRouteMessage::AddressFamily(af)), + }; + self.message + .attributes + .push(RouteAttribute::Source(addr.into())); + self.message.header.source_prefix_length = prefix_length; + Ok(self) + } + + /// Sets the preferred source address. + pub fn pref_source( + mut self, + addr: IpAddr, + ) -> Result { + self.set_address_family_from_ip_addr(addr); + match self.message.header.address_family { + AddressFamily::Inet => { + if addr.is_ipv6() { + return Err(InvalidRouteMessage::PrefSource(addr)); + }; + } + AddressFamily::Inet6 => { + if addr.is_ipv4() { + return Err(InvalidRouteMessage::PrefSource(addr)); + }; + } + af => { + return Err(InvalidRouteMessage::AddressFamily(af)); + } + } + self.message + .attributes + .push(RouteAttribute::PrefSource(addr.into())); + Ok(self) + } + + /// Sets the destination address prefix. + pub fn destination_prefix( + mut self, + addr: IpAddr, + prefix_length: u8, + ) -> Result { + self.set_address_family_from_ip_addr(addr); + match self.message.header.address_family { + AddressFamily::Inet => { + if addr.is_ipv6() || prefix_length > 32 { + return Err(InvalidRouteMessage::DestinationPrefix( + addr, + prefix_length, + )); + } + } + AddressFamily::Inet6 => { + if addr.is_ipv4() || prefix_length > 128 { + return Err(InvalidRouteMessage::DestinationPrefix( + addr, + prefix_length, + )); + } + } + af => { + return Err(InvalidRouteMessage::AddressFamily(af)); + } + }; + self.message.header.destination_prefix_length = prefix_length; + self.message + .attributes + .push(RouteAttribute::Destination(addr.into())); + Ok(self) + } + + /// Sets the gateway (via) address. + pub fn gateway( + mut self, + addr: IpAddr, + ) -> Result { + self.set_address_family_from_ip_addr(addr); + match self.message.header.address_family { + AddressFamily::Inet => { + if addr.is_ipv6() { + return Err(InvalidRouteMessage::Gateway(addr)); + }; + } + AddressFamily::Inet6 => { + if addr.is_ipv4() { + return Err(InvalidRouteMessage::Gateway(addr)); + }; + } + af => { + return Err(InvalidRouteMessage::AddressFamily(af)); + } + } + self.message + .attributes + .push(RouteAttribute::Gateway(addr.into())); + Ok(self) + } + + /// If it is not set already, set the address family based on the + /// given IP address. This is a noop is the address family is + /// already set. + fn set_address_family_from_ip_addr(&mut self, addr: IpAddr) { + if self.message.header.address_family != AddressFamily::Unspec { + return; + } + if addr.is_ipv4() { + self.message.header.address_family = AddressFamily::Inet; + } else { + self.message.header.address_family = AddressFamily::Inet6; + } + } +} + +impl Default for RouteMessageBuilder { + fn default() -> Self { + Self::new() + } +} diff --git a/src/route/handle.rs b/src/route/handle.rs index e2d2116..c3bc820 100644 --- a/src/route/handle.rs +++ b/src/route/handle.rs @@ -19,8 +19,8 @@ impl RouteHandle { } /// Add an routing table entry (equivalent to `ip route add`) - pub fn add(&self) -> RouteAddRequest { - RouteAddRequest::new(self.0.clone()) + pub fn add(&self, route: RouteMessage) -> RouteAddRequest { + RouteAddRequest::new(self.0.clone(), route) } /// Delete the given routing table entry (equivalent to `ip route del`) diff --git a/src/route/mod.rs b/src/route/mod.rs index 56747b3..e4b476c 100644 --- a/src/route/mod.rs +++ b/src/route/mod.rs @@ -11,3 +11,6 @@ pub use self::del::*; mod get; pub use self::get::*; + +mod builder; +pub use self::builder::RouteMessageBuilder;