Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace many occurrences of RawFd with OwnedFd/BorrowedFd #106

Merged
merged 8 commits into from
Feb 21, 2024
26 changes: 17 additions & 9 deletions core/src/mfd.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
//! Implementation of a memory fd

use std::io::{Read, Seek, SeekFrom, Write};
use std::os::fd::{AsRawFd, RawFd};
use std::{
io::{Read, Seek, SeekFrom, Write},
os::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd},
};

use anyhow::{bail, Result};
use memfd::{FileSeal, Memfd, MemfdOptions};
Expand Down Expand Up @@ -64,15 +66,9 @@ impl Mfd {
Ok(())
}

/// Returns the actual FD behind the mfd
/// TODO: Implement the AsRawFd trait instead
pub fn get_fd(&self) -> RawFd {
self.0.as_raw_fd()
}

/// Creates a memfd from a fd
// TODO: Use some Rust try_from stuff
pub fn from_fd(fd: RawFd) -> Result<Self> {
pub fn from_fd(fd: OwnedFd) -> Result<Self> {
let fd = match Memfd::try_from_fd(fd) {
Ok(memfd) => memfd,
Err(_) => bail!("cannot get Memfd from RawFd"),
Expand All @@ -81,6 +77,18 @@ impl Mfd {
}
}

impl AsFd for Mfd {
fn as_fd(&self) -> BorrowedFd<'_> {
self.0.as_file().as_fd()
}
}

impl AsRawFd for Mfd {
fn as_raw_fd(&self) -> RawFd {
self.0.as_raw_fd()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
13 changes: 7 additions & 6 deletions core/src/sampling.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashSet;
use std::convert::AsRef;
use std::os::fd::{AsFd, BorrowedFd};
use std::os::unix::prelude::{AsRawFd, OwnedFd, RawFd};
use std::time::Instant;

Expand Down Expand Up @@ -106,7 +107,7 @@ impl Sampling {
let (dir, fd, port) = if self.source_port.partition.eq(part.as_ref()) {
(
PortDirection::Source,
self.source_fd(),
self.source_fd().as_raw_fd(),
&self.source_port.port,
)
} else if let Some(port) = self
Expand All @@ -116,7 +117,7 @@ impl Sampling {
{
(
PortDirection::Destination,
self.destination_fd(),
self.destination_fd().as_raw_fd(),
&port.port,
)
} else {
Expand Down Expand Up @@ -197,12 +198,12 @@ impl Sampling {
Ok(())
}

pub fn source_fd(&self) -> RawFd {
self.source.as_raw_fd()
pub fn source_fd(&self) -> BorrowedFd {
self.source.as_fd()
}

pub fn destination_fd(&self) -> RawFd {
self.destination.as_raw_fd()
pub fn destination_fd(&self) -> BorrowedFd {
self.destination.as_fd()
}
}

Expand Down
22 changes: 11 additions & 11 deletions core/src/shmem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ use crate::error::{ResultExt, SystemError, TypedError, TypedResult};

#[derive(Debug)]
/// Internal data type for a mutable typed memory map
pub struct TypedMmapMut<T: Send + Sized> {
pub struct TypedMmapMut<'a, T: Send + Sized> {
mmap: MmapMut,
_p: PhantomData<T>,
_p: PhantomData<&'a T>,
}

impl<T: Send + Sized> TypedMmapMut<T> {
impl<'a, T: Send + Sized> TypedMmapMut<'a, T> {
/// Returns the length of the memory map
pub fn len(&self) -> usize {
self.mmap.len()
Expand All @@ -26,19 +26,19 @@ impl<T: Send + Sized> TypedMmapMut<T> {
}
}

impl<T: Send + Sized> AsRef<T> for TypedMmapMut<T> {
impl<'a, T: Send + Sized> AsRef<T> for TypedMmapMut<'a, T> {
fn as_ref(&self) -> &T {
unsafe { (self.mmap.as_ptr() as *const T).as_ref() }.unwrap()
}
}

impl<T: Send + Sized> AsMut<T> for TypedMmapMut<T> {
impl<'a, T: Send + Sized> AsMut<T> for TypedMmapMut<'a, T> {
fn as_mut(&mut self) -> &mut T {
unsafe { (self.mmap.as_mut_ptr() as *mut T).as_mut() }.unwrap()
}
}

impl<T: Send + Sized> TryFrom<MmapMut> for TypedMmapMut<T> {
impl<'a, T: Send + Sized> TryFrom<MmapMut> for TypedMmapMut<'a, T> {
type Error = TypedError;

fn try_from(mmap: MmapMut) -> TypedResult<Self> {
Expand All @@ -59,12 +59,12 @@ impl<T: Send + Sized> TryFrom<MmapMut> for TypedMmapMut<T> {

#[derive(Debug)]
/// Internal data type for a mutable typed memory map
pub struct TypedMmap<T: Send + Sized> {
pub struct TypedMmap<'a, T: Send + Sized> {
mmap: Mmap,
_p: PhantomData<T>,
_p: PhantomData<&'a T>,
}

impl<T: Send + Sized> TypedMmap<T> {
impl<'a, T: Send + Sized> TypedMmap<'a, T> {
/// Returns the length of the memory map
pub fn len(&self) -> usize {
self.mmap.len()
Expand All @@ -76,13 +76,13 @@ impl<T: Send + Sized> TypedMmap<T> {
}
}

impl<T: Send + Sized> AsRef<T> for TypedMmap<T> {
impl<'a, T: Send + Sized> AsRef<T> for TypedMmap<'a, T> {
fn as_ref(&self) -> &T {
unsafe { (self.mmap.as_ptr() as *const T).as_ref() }.unwrap()
}
}

impl<T: Send + Sized> TryFrom<Mmap> for TypedMmap<T> {
impl<'a, T: Send + Sized> TryFrom<Mmap> for TypedMmap<'a, T> {
type Error = TypedError;

fn try_from(mmap: Mmap) -> TypedResult<Self> {
Expand Down
44 changes: 24 additions & 20 deletions hypervisor/src/hypervisor/syscall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,42 @@

use std::io::IoSliceMut;
use std::num::NonZeroUsize;
use std::os::fd::RawFd;
use std::os::fd::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
use std::time::{Duration, Instant};

use a653rs_linux_core::mfd::{Mfd, Seals};
use a653rs_linux_core::syscall::{SyscallRequ, SyscallResp};
use anyhow::{bail, Result};
use anyhow::{anyhow, bail, Result};
use libc::EINTR;
use nix::sys::socket::{recvmsg, ControlMessageOwned, MsgFlags};
use nix::{cmsg_space, unistd};
use polling::{Event, Events, Poller};

/// Receives an FD triple from fd
// TODO: Use generics here
fn recv_fd_triple(fd: RawFd) -> Result<[RawFd; 3]> {
fn recv_fd_triple(fd: BorrowedFd) -> Result<[OwnedFd; 3]> {
let mut cmsg = cmsg_space!([RawFd; 3]);
let mut iobuf = [0u8];
let mut iov = [IoSliceMut::new(&mut iobuf)];
let res = recvmsg::<()>(fd, &mut iov, Some(&mut cmsg), MsgFlags::empty())?;
let res = recvmsg::<()>(fd.as_raw_fd(), &mut iov, Some(&mut cmsg), MsgFlags::empty())?;

let fds: Vec<RawFd> = match res.cmsgs().next().unwrap() {
ControlMessageOwned::ScmRights(fds) => fds,
_ => bail!("received an unknown cmsg"),
};
if fds.len() != 3 {
bail!("received fds but not a tripe")
}

Ok([fds[0], fds[1], fds[2]])
let fds = fds
.into_iter()
.map(|fd| unsafe { OwnedFd::from_raw_fd(fd) })
.collect::<Vec<_>>();
fds.try_into()
.map_err(|_| anyhow!("received fds but not a tripe"))
}

/// Waits for readable data on fd
fn wait_fds(fd: RawFd, timeout: Option<Duration>) -> Result<bool> {
fn wait_fds(fd: BorrowedFd, timeout: Option<Duration>) -> Result<bool> {
let poller = Poller::new()?;
let mut events = Events::with_capacity(NonZeroUsize::MIN);
unsafe { poller.add(fd, Event::readable(0))? };
unsafe { poller.add(fd.as_raw_fd(), Event::readable(0))? };
loop {
match poller.wait(&mut events, timeout) {
Ok(0) => return Ok(false),
Expand All @@ -56,7 +57,7 @@ fn wait_fds(fd: RawFd, timeout: Option<Duration>) -> Result<bool> {
/// Handles an unlimited amount of system calls, until timeout is reached
///
/// Returns the amount of executed system calls
pub fn handle(fd: RawFd, timeout: Option<Duration>) -> Result<u32> {
pub fn handle(fd: BorrowedFd, timeout: Option<Duration>) -> Result<u32> {
let start = Instant::now();
let mut nsyscalls: u32 = 0;

Expand All @@ -77,10 +78,9 @@ pub fn handle(fd: RawFd, timeout: Option<Duration>) -> Result<u32> {
assert!(res);
}

let fds = recv_fd_triple(fd)?;
let mut requ_fd = Mfd::from_fd(fds[0])?;
let mut resp_fd = Mfd::from_fd(fds[1])?;
let event_fd = fds[2];
let [requ_fd, resp_fd, event_fd] = recv_fd_triple(fd)?;
let mut requ_fd = Mfd::from_fd(requ_fd)?;
let mut resp_fd = Mfd::from_fd(resp_fd)?;

// Fetch the request
let requ = SyscallRequ::deserialize(&requ_fd.read_all()?)?;
Expand All @@ -96,7 +96,7 @@ pub fn handle(fd: RawFd, timeout: Option<Duration>) -> Result<u32> {

// Trigger the event
let buf = 1_u64.to_ne_bytes();
unistd::write(event_fd, &buf)?;
unistd::write(event_fd.as_raw_fd(), &buf)?;

nsyscalls += 1;
}
Expand All @@ -107,7 +107,7 @@ pub fn handle(fd: RawFd, timeout: Option<Duration>) -> Result<u32> {
#[cfg(test)]
mod tests {
use std::io::IoSlice;
use std::os::fd::AsRawFd;
use std::os::fd::{AsFd, AsRawFd};

use a653rs_linux_core::syscall::ApexSyscall;
use nix::sys::eventfd::{eventfd, EfdFlags};
Expand Down Expand Up @@ -147,7 +147,11 @@ mod tests {

// Send the fds to the responder
{
let fds = [requ_fd.get_fd(), resp_fd.get_fd(), event_fd.as_raw_fd()];
let fds = [
requ_fd.as_raw_fd(),
resp_fd.as_raw_fd(),
event_fd.as_raw_fd(),
];
let cmsg = [ControlMessage::ScmRights(&fds)];
let buffer = 0_u64.to_be_bytes();
let iov = [IoSlice::new(buffer.as_slice())];
Expand All @@ -171,7 +175,7 @@ mod tests {
});

let response_thread = std::thread::spawn(move || {
let n = handle(responder.as_raw_fd(), Some(Duration::from_secs(1))).unwrap();
let n = handle(responder.as_fd(), Some(Duration::from_secs(1))).unwrap();
assert_eq!(n, 1);
});

Expand Down
6 changes: 3 additions & 3 deletions partition/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::net::{TcpStream, UdpSocket};
#[cfg(feature = "socket")]
use a653rs_linux_core::ipc::IoReceiver;

use std::os::fd::{AsRawFd, IntoRawFd, RawFd};
use std::os::fd::{AsRawFd, OwnedFd};
use std::os::unix::prelude::FromRawFd;
use std::time::{Duration, Instant};

Expand Down Expand Up @@ -107,7 +107,7 @@ pub(crate) static SIGNAL_STACK: Lazy<MmapMut> = Lazy::new(|| {
.unwrap()
});

pub(crate) static SYSCALL: Lazy<RawFd> = Lazy::new(|| {
pub(crate) static SYSCALL: Lazy<OwnedFd> = Lazy::new(|| {
let syscall_socket = socket::socket(
AddressFamily::Unix,
SockType::Datagram,
Expand All @@ -122,7 +122,7 @@ pub(crate) static SYSCALL: Lazy<RawFd> = Lazy::new(|| {
)
.unwrap();

syscall_socket.into_raw_fd()
syscall_socket
});

#[cfg(feature = "socket")]
Expand Down
Loading
Loading