From f668c6c6010e109d20fbb6126bc53a8fe317486b Mon Sep 17 00:00:00 2001 From: Jiaqi Gao Date: Thu, 2 Jan 2025 02:35:58 -0500 Subject: [PATCH] virtio-pci: fix potential misaligned issue Signed-off-by: Jiaqi Gao --- src/devices/pci/src/config.rs | 138 ++++++++++++++++----------- src/devices/pci/src/lib.rs | 2 + src/devices/virtio/src/lib.rs | 10 ++ src/devices/virtio/src/mem.rs | 25 +++-- src/devices/virtio/src/virtio_pci.rs | 28 +++--- src/devices/vsock/src/virtio_dump.rs | 2 +- 6 files changed, 126 insertions(+), 79 deletions(-) diff --git a/src/devices/pci/src/config.rs b/src/devices/pci/src/config.rs index 189f499b..385781bb 100644 --- a/src/devices/pci/src/config.rs +++ b/src/devices/pci/src/config.rs @@ -95,12 +95,12 @@ pub fn pci_cf8_read8(bus: u8, device: u8, fnc: u8, reg: u8) -> u8 { } } -fn get_device_details(bus: u8, device: u8, func: u8) -> (u16, u16) { - let config_data = ConfigSpacePciEx::read::(bus, device, func, 0); - ( +fn get_device_details(bus: u8, device: u8, func: u8) -> Result<(u16, u16)> { + let config_data = ConfigSpacePciEx::read::(bus, device, func, 0)?; + Ok(( (config_data & 0xffff) as u16, ((config_data & 0xffff0000) >> 0x10) as u16, - ) + )) } pub fn find_device(vendor_id: u16, device_id: u16) -> Option<(u8, u8, u8)> { @@ -108,7 +108,7 @@ pub fn find_device(vendor_id: u16, device_id: u16) -> Option<(u8, u8, u8)> { const INVALID_VENDOR_ID: u16 = 0xffff; for device in 0..MAX_DEVICES { - if (vendor_id, device_id) == get_device_details(0, device, 0) { + if (vendor_id, device_id) == get_device_details(0, device, 0).ok()? { return Some((0, device, 0)); } if vendor_id == INVALID_VENDOR_ID { @@ -191,12 +191,12 @@ impl ConfigSpace { } /// Get vendor_id and device_id - pub fn get_device_details(bus: u8, device: u8, func: u8) -> (u16, u16) { - let config_data = ConfigSpacePciEx::read::(bus, device, func, 0); - ( + pub fn get_device_details(bus: u8, device: u8, func: u8) -> Result<(u16, u16)> { + let config_data = ConfigSpacePciEx::read::(bus, device, func, 0)?; + Ok(( (config_data & 0xffff) as u16, ((config_data & 0xffff0000) >> 0x10) as u16, - ) + )) } fn get_config_address(bus: u8, device: u8, func: u8, offset: u8) -> ConfigAddress { @@ -215,49 +215,73 @@ impl ConfigSpace { pub struct ConfigSpacePciEx; impl ConfigSpacePciEx { #[cfg(not(feature = "fuzz"))] - pub fn read(bus: u8, device: u8, func: u8, offset: u16) -> T { + pub fn read(bus: u8, device: u8, func: u8, offset: u16) -> Result { let addr = PCI_EX_BAR_BASE_ADDRESS + ((bus as u64) << 20) + ((device as u64) << 15) + ((func as u64) << 12) + offset as u64; + if addr % size_of::() as u64 != 0 { + return Err(PciError::Misaligned); + } #[cfg(feature = "iocall")] unsafe { - core::ptr::read_volatile(addr as *const T) + Ok(core::ptr::read_volatile(addr as *const T)) } #[cfg(feature = "tdcall")] - tdx_tdcall::tdx::tdvmcall_mmio_read(addr as usize) + Ok(tdx_tdcall::tdx::tdvmcall_mmio_read(addr as usize)) } #[cfg(feature = "fuzz")] - pub fn read(_bus: u8, _device: u8, _func: u8, offset: u16) -> T { + pub fn read(_bus: u8, _device: u8, _func: u8, offset: u16) -> Result { let base_address = crate::get_fuzz_seed_address(); let address = base_address + offset as u64; - unsafe { core::ptr::read_volatile(address as *const T) } + if address % size_of::() as u64 != 0 { + return Err(PciError::Misaligned); + } + unsafe { Ok(core::ptr::read_volatile(address as *const T)) } } #[cfg(not(feature = "fuzz"))] - pub fn write(bus: u8, device: u8, func: u8, offset: u16, value: T) { + pub fn write( + bus: u8, + device: u8, + func: u8, + offset: u16, + value: T, + ) -> Result<()> { let addr = PCI_EX_BAR_BASE_ADDRESS + ((bus as u64) << 20) + ((device as u64) << 15) + ((func as u64) << 12) + offset as u64; + + if addr % size_of::() as u64 != 0 { + return Err(PciError::Misaligned); + } #[cfg(feature = "iocall")] unsafe { - core::ptr::write_volatile(addr as *mut T, value) + core::ptr::write_volatile(addr as *mut T, value); } #[cfg(feature = "tdcall")] tdx_tdcall::tdx::tdvmcall_mmio_write(addr as *mut T, value); + Ok(()) } #[cfg(feature = "fuzz")] - pub fn write(_bus: u8, _device: u8, _func: u8, offset: u16, value: T) { - unsafe { - let base_address = crate::get_fuzz_seed_address(); - let address = base_address + offset as u64; - core::ptr::write_volatile(address as *mut T, value) + pub fn write( + _bus: u8, + _device: u8, + _func: u8, + offset: u16, + value: T, + ) -> Result<()> { + let base_address = crate::get_fuzz_seed_address(); + let address = base_address + offset as u64; + if address % size_of::() as u64 != 0 { + return Err(PciError::Misaligned); } + unsafe { Ok(core::ptr::write_volatile(address as *mut T, value)) } } } @@ -384,11 +408,11 @@ impl PciDevice { #[cfg(not(feature = "fuzz"))] pub fn init(&mut self) -> Result<()> { let (vendor_id, device_id) = - ConfigSpace::get_device_details(self.bus, self.device, self.func); + ConfigSpace::get_device_details(self.bus, self.device, self.func)?; self.common_header.vendor_id = vendor_id; self.common_header.device_id = device_id; - let command = self.read_u16(0x4); - let status = self.read_u16(0x6); + let command = self.read_u16(0x4)?; + let status = self.read_u16(0x6)?; log::info!( "PCI Device: {}:{}.{} {:x}:{:x}\nbit \t fedcba9876543210\nstate\t {:016b}\ncommand\t {:016b}\n", self.bus, @@ -405,7 +429,7 @@ impl PciDevice { //0x24 offset is last bar while current_bar_offset <= 0x24 { - let bar = self.read_u32(current_bar_offset); + let bar = self.read_u32(current_bar_offset)?; // lsb is 1 for I/O space bars if bar & 1 == 1 { @@ -415,11 +439,11 @@ impl PciDevice { // bits 2-1 are the type 0 is 32-but, 2 is 64 bit match bar >> 1 & 3 { 0 => { - let size = self.get_bar_size(current_bar_offset); + let size = self.get_bar_size(current_bar_offset)?; let addr = if size > 0 { let addr = alloc_mmio32(size)?; - self.set_bar_addr(current_bar_offset, addr); + self.set_bar_addr(current_bar_offset, addr)?; addr } else { bar @@ -432,15 +456,15 @@ impl PciDevice { 2 => { self.bars[current_bar].bar_type = PciBarType::MemorySpace64; - let mut size = self.get_bar_size(current_bar_offset) as u64; + let mut size = self.get_bar_size(current_bar_offset)? as u64; if size == 0 { - size = (self.get_bar_size(current_bar_offset + 4) as u64) << 32; + size = (self.get_bar_size(current_bar_offset + 4)? as u64) << 32; } let addr = if size > 0 { let addr = alloc_mmio64(size)?; - self.set_bar_addr(current_bar_offset, addr as u32); - self.set_bar_addr(current_bar_offset + 4, (addr >> 32) as u32); + self.set_bar_addr(current_bar_offset, addr as u32)?; + self.set_bar_addr(current_bar_offset + 4, (addr >> 32) as u32)?; addr } else { bar as u64 @@ -461,7 +485,7 @@ impl PciDevice { self.write_u16( 0x4, (PciCommand::IO_SPACE | PciCommand::MEMORY_SPACE | PciCommand::BUS_MASTER).bits(), - ); + )?; for bar in &self.bars { log::info!("Bar: type={:?} address={:x}\n", bar.bar_type, bar.address); } @@ -472,18 +496,18 @@ impl PciDevice { #[cfg(feature = "fuzz")] pub fn init(&mut self) -> Result<()> { let (vendor_id, device_id) = - ConfigSpace::get_device_details(self.bus, self.device, self.func); + ConfigSpace::get_device_details(self.bus, self.device, self.func)?; self.common_header.vendor_id = vendor_id; self.common_header.device_id = device_id; - let command = self.read_u16(0x4); - let status = self.read_u16(0x6); + let command = self.read_u16(0x4)?; + let status = self.read_u16(0x6)?; let mut current_bar_offset = 0x10; let mut current_bar = 0; //0x24 offset is last bar while current_bar_offset <= 0x24 { - let bar = self.read_u32(current_bar_offset); + let bar = self.read_u32(current_bar_offset)?; // lsb is 1 for I/O space bars if bar & 1 == 1 { @@ -493,11 +517,11 @@ impl PciDevice { // bits 2-1 are the type 0 is 32-but, 2 is 64 bit match bar >> 1 & 3 { 0 => { - let size = self.read_u32(current_bar_offset); + let size = self.read_u32(current_bar_offset)?; let addr = if size > 0 { let addr = alloc_mmio32(size)?; - self.set_bar_addr(current_bar_offset, addr); + self.set_bar_addr(current_bar_offset, addr)?; addr } else { bar @@ -510,11 +534,11 @@ impl PciDevice { 2 => { self.bars[current_bar].bar_type = PciBarType::MemorySpace64; - let mut size = self.read_u64(current_bar_offset); + let mut size = self.read_u64(current_bar_offset)?; let addr = if size > 0 { let addr = alloc_mmio64(size)?; - self.set_bar_addr(current_bar_offset, addr as u32); - self.set_bar_addr(current_bar_offset + 4, (addr >> 32) as u32); + self.set_bar_addr(current_bar_offset, addr as u32)?; + self.set_bar_addr(current_bar_offset + 4, (addr >> 32) as u32)?; addr } else { bar as u64 @@ -540,56 +564,56 @@ impl PciDevice { Ok(()) } - fn set_bar_addr(&self, offset: u8, addr: u32) { - self.write_u32(offset, addr); + fn set_bar_addr(&self, offset: u8, addr: u32) -> Result<()> { + self.write_u32(offset, addr) } - fn get_bar_size(&self, offset: u8) -> u32 { - let restore = self.read_u32(offset); - self.write_u32(offset, u32::MAX); - let size = self.read_u32(offset); - self.write_u32(offset, restore); + fn get_bar_size(&self, offset: u8) -> Result { + let restore = self.read_u32(offset)?; + self.write_u32(offset, u32::MAX)?; + let size = self.read_u32(offset)?; + self.write_u32(offset, restore)?; - if size == 0 { + Ok(if size == 0 { size } else { !(size & 0xFFFF_FFF0) + 1 - } + }) } - pub fn read_u64(&self, offset: u8) -> u64 { + pub fn read_u64(&self, offset: u8) -> Result { ConfigSpacePciEx::read::(self.bus, self.device, self.func, offset as u16) // let low = ConfigSpace::read32(self.bus, self.device, self.func, offset); // let high = ConfigSpace::read32(self.bus, self.device, self.func, offset + 8); // (low as u64) & ((high as u64) << 8) } - pub fn read_u32(&self, offset: u8) -> u32 { + pub fn read_u32(&self, offset: u8) -> Result { ConfigSpacePciEx::read::(self.bus, self.device, self.func, offset as u16) // ConfigSpace::read32(self.bus, self.device, self.func, offset) } - pub fn read_u16(&self, offset: u8) -> u16 { + pub fn read_u16(&self, offset: u8) -> Result { ConfigSpacePciEx::read::(self.bus, self.device, self.func, offset as u16) // ConfigSpace::read16(self.bus, self.device, self.func, offset) } - pub fn read_u8(&self, offset: u8) -> u8 { + pub fn read_u8(&self, offset: u8) -> Result { ConfigSpacePciEx::read::(self.bus, self.device, self.func, offset as u16) // ConfigSpace::read8(self.bus, self.device, self.func, offset) } - pub fn write_u32(&self, offset: u8, value: u32) { + pub fn write_u32(&self, offset: u8, value: u32) -> Result<()> { ConfigSpacePciEx::write::(self.bus, self.device, self.func, offset as u16, value) // ConfigSpace::write32(self.bus, self.device, self.func, offset, value) } - pub fn write_u16(&self, offset: u8, value: u16) { + pub fn write_u16(&self, offset: u8, value: u16) -> Result<()> { ConfigSpacePciEx::write::(self.bus, self.device, self.func, offset as u16, value) // ConfigSpace::write16(self.bus, self.device, self.func, offset, value) } - pub fn write_u8(&self, offset: u8, value: u8) { + pub fn write_u8(&self, offset: u8, value: u8) -> Result<()> { ConfigSpacePciEx::write::(self.bus, self.device, self.func, offset as u16, value) // ConfigSpace::write8(self.bus, self.device, self.func, offset, value) } diff --git a/src/devices/pci/src/lib.rs b/src/devices/pci/src/lib.rs index c576c721..abdf668d 100644 --- a/src/devices/pci/src/lib.rs +++ b/src/devices/pci/src/lib.rs @@ -21,8 +21,10 @@ pub fn get_fuzz_seed_address() -> u64 { pub type Result = core::result::Result; +#[derive(Debug)] pub enum PciError { InvalidParameter, MmioOutofResource, InvalidBarType, + Misaligned, } diff --git a/src/devices/virtio/src/lib.rs b/src/devices/virtio/src/lib.rs index b3d69074..058307c9 100644 --- a/src/devices/virtio/src/lib.rs +++ b/src/devices/virtio/src/lib.rs @@ -6,6 +6,7 @@ extern crate alloc; use core::fmt::Display; use mem::MemoryRegionError; +use pci::PciError; pub mod consts; mod mem; @@ -47,6 +48,8 @@ pub enum VirtioError { InvalidRingIndex, /// Invalid index for ring InvalidDescriptor, + /// Pci related error + Pci(PciError), } impl Display for VirtioError { @@ -69,6 +72,7 @@ impl Display for VirtioError { VirtioError::InvalidDescriptorIndex => write!(f, "InvalidDescriptorIndex"), VirtioError::InvalidRingIndex => write!(f, "InvalidRingIndex"), VirtioError::InvalidDescriptor => write!(f, "InvalidDescriptor"), + VirtioError::Pci(_) => write!(f, "Pci"), } } } @@ -79,6 +83,12 @@ impl From for VirtioError { } } +impl From for VirtioError { + fn from(e: PciError) -> Self { + VirtioError::Pci(e) + } +} + pub type Result = core::result::Result; /// Trait to allow separation of transport from block driver diff --git a/src/devices/virtio/src/mem.rs b/src/devices/virtio/src/mem.rs index d25c686b..6d071471 100644 --- a/src/devices/virtio/src/mem.rs +++ b/src/devices/virtio/src/mem.rs @@ -179,11 +179,14 @@ impl MemoryRegion { #[cfg(feature = "fuzz")] fn mmio_read(&self, offset: u64) -> Result { - unsafe { - Ok(core::ptr::read_volatile( - (pci::get_fuzz_seed_address() + 0x10c + offset) as *const T, - )) + let address = pci::get_fuzz_seed_address() + 0x10c + offset; + if address as usize % size_of::() != 0 { + return Err(MemoryRegionError { + region: *self, + offset, + }); } + unsafe { Ok(core::ptr::read_volatile(address as *const T)) } } #[cfg(not(feature = "fuzz"))] /// Read a value at given offset with a mechanism suitable for MMIO @@ -193,6 +196,7 @@ impl MemoryRegion { .checked_add(size_of::() as u64) .and_then(|end| if end > self.length { None } else { Some(end) }) .is_none() + || (self.base + offset) % size_of::() as u64 != 0 { return Err(MemoryRegionError { region: *self, @@ -233,11 +237,15 @@ impl MemoryRegion { #[cfg(feature = "fuzz")] fn mmio_write(&self, offset: u64, value: T) -> Result<(), MemoryRegionError> { + let address = pci::get_fuzz_seed_address() + 0x10c + offset; + if address as usize % size_of::() != 0 { + return Err(MemoryRegionError { + region: *self, + offset, + }); + } unsafe { - core::ptr::write_volatile( - (pci::get_fuzz_seed_address() + 0x10c + offset) as *mut T, - value, - ); + core::ptr::write_volatile(address as *mut T, value); } Ok(()) @@ -250,6 +258,7 @@ impl MemoryRegion { .checked_add(size_of::() as u64) .and_then(|end| if end > self.length { None } else { Some(end) }) .is_none() + || (self.base + offset) % size_of::() as u64 != 0 { return Err(MemoryRegionError { region: *self, diff --git a/src/devices/virtio/src/virtio_pci.rs b/src/devices/virtio/src/virtio_pci.rs index 75745706..32fcb76b 100644 --- a/src/devices/virtio/src/virtio_pci.rs +++ b/src/devices/virtio/src/virtio_pci.rs @@ -164,15 +164,15 @@ impl VirtioTransport for VirtioPciTransport { let mut cycle_flag = 0usize; let mut cycle_list = [0u8; CYCLE_LEN]; // Read status register - let status = self.device.read_u16(STATUS_OFFSET); - let device_id = self.device.read_u16(DEVICE_OFFSET); + let status = self.device.read_u16(STATUS_OFFSET)?; + let device_id = self.device.read_u16(DEVICE_OFFSET)?; // bit 4 of status is capability bit if status & 1 << 4 == 0 { return Err(VirtioError::VirtioUnsupportedDevice); } // capabilities list offset is at 0x34 - let mut cap_next = self.device.read_u8(PCI_CAP_POINTER); + let mut cap_next = self.device.read_u8(PCI_CAP_POINTER)?; while cap_next <= u8::MAX - CAP_LEN + 1 && cap_next > 0 { if cycle_list.contains(&cap_next) { @@ -180,7 +180,7 @@ impl VirtioTransport for VirtioPciTransport { } cycle_list[cycle_flag] = cap_next; - let capability = self.device.read_u8(cap_next); + let capability = self.device.read_u8(cap_next)?; // vendor specific capability if capability == VIRTIO_CAPABILITIES_SPECIFIC { // These offsets are into the following structure: @@ -199,11 +199,11 @@ impl VirtioTransport for VirtioPciTransport { return Err(VirtioError::InvalidParameter); } - let cfg_type = self.device.read_u8(cap_next + VIRTIO_CFG_TYPE_OFFSET); + let cfg_type = self.device.read_u8(cap_next + VIRTIO_CFG_TYPE_OFFSET)?; #[allow(clippy::disallowed_names)] - let bar = self.device.read_u8(cap_next + VIRTIO_BAR_OFFSET); - let offset = self.device.read_u32(cap_next + VIRTIO_CAP_OFFSET); - let length = self.device.read_u32(cap_next + VIRTIO_CAP_LENGTH_OFFSET); + let bar = self.device.read_u8(cap_next + VIRTIO_BAR_OFFSET)?; + let offset = self.device.read_u32(cap_next + VIRTIO_CAP_OFFSET)?; + let length = self.device.read_u32(cap_next + VIRTIO_CAP_LENGTH_OFFSET)?; if bar > MAX_BARS_INDEX { return Err(VirtioError::InvalidParameter); @@ -266,7 +266,7 @@ impl VirtioTransport for VirtioPciTransport { // struct virtio_pci_cap cap; // le32 notify_off_multiplier; /* Multiplier for queue_notify_off. */ // }; - self.notify_off_multiplier = self.device.read_u32(cap_next + CAP_LEN); + self.notify_off_multiplier = self.device.read_u32(cap_next + CAP_LEN)?; } fn device_length_check(device_id: u16, length: u32) -> Option { @@ -312,15 +312,17 @@ impl VirtioTransport for VirtioPciTransport { )?; } } else if capability == MSIX_CAPABILITY_ID { - let mcr = self.device.read_u16(cap_next + MSIX_MESSAGE_CONTROL_OFFSET); - let bir = (self.device.read_u32(cap_next + MSIX_BIR_OFFSET) & 0x7) as u8; + let mcr = self + .device + .read_u16(cap_next + MSIX_MESSAGE_CONTROL_OFFSET)?; + let bir = (self.device.read_u32(cap_next + MSIX_BIR_OFFSET)? & 0x7) as u8; // BIR specifies which BAR is used for the Message Table, which should be less than 6 if bir as usize >= self.device.bars.len() { return Err(VirtioError::InvalidParameter); } - let table_offset = self.device.read_u32(cap_next + MSIX_BIR_OFFSET) & 0xffff_fff8; + let table_offset = self.device.read_u32(cap_next + MSIX_BIR_OFFSET)? & 0xffff_fff8; // Message Control: // Bit 15 Bit 14 Bits 13-11 Bits 10-0 // Enable Function Mask Reserved Table Size @@ -348,7 +350,7 @@ impl VirtioTransport for VirtioPciTransport { if cycle_flag >= CYCLE_LEN { return Err(VirtioError::InvalidParameter); } - cap_next = self.device.read_u8(cap_next + 1) + cap_next = self.device.read_u8(cap_next + 1)? } // According to virtio-v1.1 section 4.1.4 Virtio Structure PCI Capabilities diff --git a/src/devices/vsock/src/virtio_dump.rs b/src/devices/vsock/src/virtio_dump.rs index d06a87eb..8b6f22cc 100644 --- a/src/devices/vsock/src/virtio_dump.rs +++ b/src/devices/vsock/src/virtio_dump.rs @@ -65,7 +65,7 @@ struct BarAddress { impl BarAddress { pub fn read(&self, offset: u16) -> T { let device = 1; - pci::ConfigSpacePciEx::read::(0, device, 0, offset) + pci::ConfigSpacePciEx::read::(0, device, 0, offset).expect("Invalid PCI read") } pub fn bars(&mut self) {