diff --git a/Cargo.toml b/Cargo.toml index 404d376..4f6907b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,9 +12,9 @@ repository = "https://github.com/rust-netlink/netlink-packet-core" description = "netlink packet types" [dependencies] -anyhow = "1.0.31" byteorder = "1.3.2" -netlink-packet-utils = "0.5.2" +netlink-packet-utils = { git = "https://github.com/miguelfrde/netlink-packet-utils.git" } +thiserror = "2.0.9" [dev-dependencies] netlink-packet-route = "0.13.0" diff --git a/src/buffer.rs b/src/buffer.rs index 3801583..24d0471 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1,9 +1,8 @@ // SPDX-License-Identifier: MIT use byteorder::{ByteOrder, NativeEndian}; -use netlink_packet_utils::DecodeError; -use crate::{Field, Rest}; +use crate::{CoreError, Field, Rest}; const LENGTH: Field = 0..4; const MESSAGE_TYPE: Field = 4..6; @@ -156,33 +155,20 @@ impl> NetlinkBuffer { /// 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; /// assert!(NetlinkBuffer::new_checked(&BYTES[..]).is_err()); /// ``` - pub fn new_checked(buffer: T) -> Result, DecodeError> { + pub fn new_checked(buffer: T) -> Result, CoreError> { let packet = Self::new(buffer); packet.check_buffer_length()?; Ok(packet) } - fn check_buffer_length(&self) -> Result<(), DecodeError> { + fn check_buffer_length(&self) -> Result<(), CoreError> { let len = self.buffer.as_ref().len(); if len < PORT_NUMBER.end { - Err(format!( - "invalid netlink buffer: length is {} but netlink packets are at least {} bytes", - len, PORT_NUMBER.end - ) - .into()) + Err(CoreError::PacketTooShort { received: len, expected: PORT_NUMBER.end }) } else if len < self.length() as usize { - Err(format!( - "invalid netlink buffer: length field says {} the buffer is {} bytes long", - self.length(), - len - ) - .into()) + Err(CoreError::NonmatchingLength { expected: self.length(), actual: len }) } else if (self.length() as usize) < PORT_NUMBER.end { - Err(format!( - "invalid netlink buffer: length field says {} but netlink packets are at least {} bytes", - self.length(), - len - ).into()) + Err(CoreError::InvalidLength { given: self.length(), at_least: len }) } else { Ok(()) } diff --git a/src/done.rs b/src/done.rs index f81507c..486d2cb 100644 --- a/src/done.rs +++ b/src/done.rs @@ -3,9 +3,8 @@ use std::mem::size_of; use byteorder::{ByteOrder, NativeEndian}; -use netlink_packet_utils::DecodeError; -use crate::{Emitable, Field, Parseable, Rest}; +use crate::{CoreError, Emitable, Field, Parseable, Rest}; const CODE: Field = 0..4; const EXTENDED_ACK: Rest = 4..; @@ -27,20 +26,16 @@ impl> DoneBuffer { self.buffer } - pub fn new_checked(buffer: T) -> Result { + pub fn new_checked(buffer: T) -> Result { let packet = Self::new(buffer); packet.check_buffer_length()?; Ok(packet) } - fn check_buffer_length(&self) -> Result<(), DecodeError> { + fn check_buffer_length(&self) -> Result<(), CoreError> { let len = self.buffer.as_ref().len(); if len < DONE_HEADER_LEN { - Err(format!( - "invalid DoneBuffer: length is {len} but DoneBuffer are \ - at least {DONE_HEADER_LEN} bytes" - ) - .into()) + Err(CoreError::InvalidDoneBuffer { received: len }) } else { Ok(()) } @@ -100,7 +95,9 @@ impl Emitable for DoneMessage { impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable> for DoneMessage { - fn parse(buf: &DoneBuffer<&'buffer T>) -> Result { + type Error = CoreError; + + fn parse(buf: &DoneBuffer<&'buffer T>) -> Result { Ok(DoneMessage { code: buf.code(), extended_ack: buf.extended_ack().to_vec(), diff --git a/src/error.rs b/src/error.rs index f7951f7..d38f987 100644 --- a/src/error.rs +++ b/src/error.rs @@ -3,9 +3,8 @@ use std::{fmt, io, mem::size_of, num::NonZeroI32}; use byteorder::{ByteOrder, NativeEndian}; -use netlink_packet_utils::DecodeError; -use crate::{Emitable, Field, Parseable, Rest}; +use crate::{CoreError, Emitable, Field, Parseable, Rest}; const CODE: Field = 0..4; const PAYLOAD: Rest = 4..; @@ -27,20 +26,16 @@ impl> ErrorBuffer { self.buffer } - pub fn new_checked(buffer: T) -> Result { + pub fn new_checked(buffer: T) -> Result { let packet = Self::new(buffer); packet.check_buffer_length()?; Ok(packet) } - fn check_buffer_length(&self) -> Result<(), DecodeError> { + fn check_buffer_length(&self) -> Result<(), CoreError> { let len = self.buffer.as_ref().len(); if len < ERROR_HEADER_LEN { - Err(format!( - "invalid ErrorBuffer: length is {len} but ErrorBuffer are \ - at least {ERROR_HEADER_LEN} bytes" - ) - .into()) + Err(CoreError::InvalidErrorBuffer { received: len }) } else { Ok(()) } @@ -118,9 +113,11 @@ impl Emitable for ErrorMessage { impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable> for ErrorMessage { + type Error = CoreError; + fn parse( buf: &ErrorBuffer<&'buffer T>, - ) -> Result { + ) -> Result { // FIXME: The payload of an error is basically a truncated packet, which // requires custom logic to parse correctly. For now we just // return it as a Vec let header: NetlinkHeader = { diff --git a/src/header.rs b/src/header.rs index 22e2445..50edfa7 100644 --- a/src/header.rs +++ b/src/header.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: MIT -use netlink_packet_utils::DecodeError; - +use crate::CoreError; use crate::{buffer::NETLINK_HEADER_LEN, Emitable, NetlinkBuffer, Parseable}; /// A Netlink header representation. A netlink header has the following @@ -57,7 +56,9 @@ impl Emitable for NetlinkHeader { impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for NetlinkHeader { - fn parse(buf: &NetlinkBuffer<&'a T>) -> Result { + type Error = CoreError; + + fn parse(buf: &NetlinkBuffer<&'a T>) -> Result { Ok(NetlinkHeader { length: buf.length(), message_type: buf.message_type(), diff --git a/src/lib.rs b/src/lib.rs index f7f1508..8aa1735 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -270,3 +270,51 @@ pub use self::constants::*; pub(crate) use self::utils::traits::*; pub(crate) use netlink_packet_utils as utils; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum CoreError { + #[error("invalid netlink buffer: length is {received} but netlink packets are at least {expected} bytes")] + PacketTooShort { received: usize, expected: usize }, + + #[error("invalid netlink buffer: length field says {expected} but the buffer is {actual} bytes long")] + NonmatchingLength { expected: u32, actual: usize }, + + #[error("invalid netlink buffer: length field says {given} but netlink packets are at least {at_least} bytes")] + InvalidLength { given: u32, at_least: usize }, + + #[error( + "invalid ErrorBuffer: length is {received}, expected at least 4 bytes" + )] + InvalidErrorBuffer { received: usize }, + + #[error( + "invalid DoneBuffer: length is {received}, expected at least 4 bytes" + )] + InvalidDoneBuffer { received: usize }, + + #[error("invalid Netlink header")] + InvalidHeader { + #[source] + due_to: Box, + }, + + #[error("invalid Netlink message of type NLMSG_ERROR")] + InvalidErrorMsg { + #[source] + due_to: Box, + }, + + #[error("invalid Netlink message of type NLMSG_DONE")] + InvalidDoneMsg { + #[source] + due_to: Box, + }, + + #[error("failed to parse the netlink message, of type {message_type}")] + ParseFailure { + message_type: u16, + #[source] + due_to: Box, + }, +} diff --git a/src/message.rs b/src/message.rs index 4bc7dda..361ba4c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -2,12 +2,9 @@ use std::fmt::Debug; -use anyhow::Context; -use netlink_packet_utils::DecodeError; - use crate::{ payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN}, - DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, + CoreError, DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable, Parseable, }; @@ -39,7 +36,7 @@ where I: NetlinkDeserializable, { /// Parse the given buffer as a netlink message - pub fn deserialize(buffer: &[u8]) -> Result { + pub fn deserialize(buffer: &[u8]) -> Result { let netlink_buffer = NetlinkBuffer::new_checked(&buffer)?; >>::parse(&netlink_buffer) } @@ -88,33 +85,43 @@ where B: AsRef<[u8]> + 'buffer, I: NetlinkDeserializable, { - fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result { + type Error = CoreError; + + fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result { use self::NetlinkPayload::*; let header = >>::parse(buf) - .context("failed to parse netlink header")?; + .map_err(|e| CoreError::InvalidHeader { due_to: e.into() })?; let bytes = buf.payload(); let payload = match header.message_type { NLMSG_ERROR => { let msg = ErrorBuffer::new_checked(&bytes) .and_then(|buf| ErrorMessage::parse(&buf)) - .context("failed to parse NLMSG_ERROR")?; + .map_err(|e| CoreError::InvalidErrorMsg { + due_to: e.into(), + })?; Error(msg) } NLMSG_NOOP => Noop, NLMSG_DONE => { let msg = DoneBuffer::new_checked(&bytes) .and_then(|buf| DoneMessage::parse(&buf)) - .context("failed to parse NLMSG_DONE")?; + .map_err(|e| CoreError::InvalidDoneMsg { + due_to: e.into(), + })?; Done(msg) } NLMSG_OVERRUN => Overrun(bytes.to_vec()), message_type => { - let inner_msg = I::deserialize(&header, bytes).context( - format!("Failed to parse message with type {message_type}"), - )?; + let inner_msg = + I::deserialize(&header, bytes).map_err(|e| { + CoreError::ParseFailure { + message_type, + due_to: e.into(), + } + })?; InnerMessage(inner_msg) } };