diff options
-rw-r--r-- | .buildkite/pipeline.yml | 17 | ||||
-rw-r--r-- | .github/dependabot.yml | 7 | ||||
-rw-r--r-- | .gitmodules | 3 | ||||
-rw-r--r-- | Android.bp | 6 | ||||
-rw-r--r-- | Cargo.toml | 9 | ||||
-rw-r--r-- | METADATA | 6 | ||||
-rw-r--r-- | coverage_config_x86_64.json | 2 | ||||
m--------- | rust-vmm-ci | 0 | ||||
-rw-r--r-- | src/vhost_user/connection.rs | 152 | ||||
-rw-r--r-- | src/vhost_user/dummy_slave.rs | 52 | ||||
-rw-r--r-- | src/vhost_user/master.rs | 150 | ||||
-rw-r--r-- | src/vhost_user/master_req_handler.rs | 78 | ||||
-rw-r--r-- | src/vhost_user/message.rs | 210 | ||||
-rw-r--r-- | src/vhost_user/mod.rs | 31 | ||||
-rw-r--r-- | src/vhost_user/slave_fs_cache.rs | 33 | ||||
-rw-r--r-- | src/vhost_user/slave_req_handler.rs | 262 |
16 files changed, 649 insertions, 369 deletions
diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml new file mode 100644 index 0000000..0e77e1f --- /dev/null +++ b/.buildkite/pipeline.yml @@ -0,0 +1,17 @@ +# Copyright 2021 The Chromium OS Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE-BSD-Google file. + +steps: + - label: "clippy-x86-custom" + commands: + - cargo clippy --all-features --all-targets --workspace -- -D warnings + retry: + automatic: false + agents: + platform: x86_64.metal + os: linux + plugins: + - docker#v3.0.1: + image: "rustvmm/dev:v12" + always-pull: true diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..4fcd556 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: +- package-ecosystem: gitsubmodule + directory: "/" + schedule: + interval: daily + open-pull-requests-limit: 10 diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index bda97eb..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "rust-vmm-ci"] - path = rust-vmm-ci - url = https://github.com/rust-vmm/rust-vmm-ci.git @@ -1,4 +1,4 @@ -// This file is generated by cargo2android.py --run. +// This file is generated by cargo2android.py --run --device --features default,vhost-user,vhost-user-master,vhost-user-slave --global_defaults crosvm_defaults. // Do not modify this file as changes will be overridden on upgrade. package { @@ -40,8 +40,9 @@ license { rust_library { name: "libvmm_vhost", - crate_name: "vmm_vhost", + defaults: ["crosvm_defaults"], host_supported: true, + crate_name: "vmm_vhost", srcs: ["src/lib.rs"], edition: "2018", features: [ @@ -56,5 +57,4 @@ rust_library { "libsys_util", "libtempfile", ], - defaults: ["crosvm_defaults"], } @@ -21,10 +21,13 @@ vhost-user-slave = ["vhost-user"] [dependencies] bitflags = ">=1.0.1" libc = ">=0.2.39" - -sys_util = { path = "../../../external/crosvm/sys_util" } # provided by ebuild -tempfile = { path = "../../../external/crosvm/tempfile" } # provided by ebuild +sys_util = "*" +tempfile = "*" vm-memory = { version = "0.2.0", optional = true } [dev-dependencies] vm-memory = { version = "0.2.0", features=["backend-mmap"] } + +[patch.crates-io] +sys_util = { path = "../../../external/crosvm/sys_util" } # ignored by ebuild +tempfile = { path = "../../../external/crosvm/tempfile" } # ignored by ebuild @@ -9,11 +9,11 @@ third_party { type: GIT value: "https://chromium.googlesource.com/chromiumos/third_party/rust-vmm/vhost" } - version: "eaca5d36a2701c99b354ab5bc0954a78dfc9ff4f" + version: "d65bd280d9f4e192a884f1761e4b097c11aae6de" license_type: NOTICE last_upgrade_date { year: 2021 - month: 5 - day: 19 + month: 9 + day: 22 } } diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json index 2b2c164..c3e6939 100644 --- a/coverage_config_x86_64.json +++ b/coverage_config_x86_64.json @@ -1 +1 @@ -{"coverage_score": 81.2, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"} +{"coverage_score": 82.3, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"} diff --git a/rust-vmm-ci b/rust-vmm-ci -Subproject 24d66cdae63d4aa7f8de01b616c015b97604a11 +Subproject d2ab3c090833aec72eee7da1e3884032206b00e diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs index f92db45..ea8461a 100644 --- a/src/vhost_user/connection.rs +++ b/src/vhost_user/connection.rs @@ -5,9 +5,10 @@ #![allow(dead_code)] +use std::fs::File; use std::io::ErrorKind; use std::marker::PhantomData; -use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use std::path::{Path, PathBuf}; use std::{mem, slice}; @@ -301,7 +302,7 @@ impl<R: Req> Endpoint<R> { } /// Reads bytes from the socket into the given scatter/gather vectors with optional attached - /// file descriptors. + /// file. /// /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little /// tricky to pass file descriptors through such a communication channel. Let's assume that a @@ -311,29 +312,37 @@ impl<R: Req> Endpoint<R> { /// 2) message(packet) boundaries must be respected on the receive side. /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the /// attached file descriptors will get lost. + /// Note that this function wraps received file descriptors as `File`. /// /// # Return: - /// * - (number of bytes received, [received fds]) on success + /// * - (number of bytes received, [received files]) on success /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. - pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> { + pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<File>>)> { let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES]; let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?; - let rfds = match fds { + + let files = match fds { 0 => None, n => { - let mut fds = Vec::with_capacity(n); - fds.extend_from_slice(&fd_array[0..n]); - Some(fds) + let files = fd_array + .iter() + .take(n) + .map(|fd| { + // Safe because we have the ownership of `fd`. + unsafe { File::from_raw_fd(*fd) } + }) + .collect(); + Some(files) } }; - Ok((bytes, rfds)) + Ok((bytes, files)) } /// Reads all bytes from the socket into the given scatter/gather vectors with optional - /// attached file descriptors. Will loop until all data has been transfered. + /// attached files. Will loop until all data has been transferred. /// /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little /// tricky to pass file descriptors through such a communication channel. Let's assume that a @@ -343,6 +352,7 @@ impl<R: Req> Endpoint<R> { /// 2) message(packet) boundaries must be respected on the receive side. /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the /// attached file descriptors will get lost. + /// Note that this function wraps received file descriptors as `File`. /// /// # Return: /// * - (number of bytes received, [received fds]) on success @@ -351,7 +361,7 @@ impl<R: Req> Endpoint<R> { pub fn recv_into_iovec_all( &mut self, iovs: &mut [iovec], - ) -> Result<(usize, Option<Vec<RawFd>>)> { + ) -> Result<(usize, Option<Vec<File>>)> { let mut data_read = 0; let mut data_total = 0; let mut rfds = None; @@ -392,46 +402,46 @@ impl<R: Req> Endpoint<R> { } /// Reads bytes from the socket into a new buffer with optional attached - /// file descriptors. Received file descriptors are set close-on-exec. + /// files. Received file descriptors are set close-on-exec and converted to `File`. /// /// # Return: - /// * - (number of bytes received, buf, [received fds]) on success. + /// * - (number of bytes received, buf, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. pub fn recv_into_buf( &mut self, buf_size: usize, - ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> { + ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> { let mut buf = vec![0u8; buf_size]; - let (bytes, rfds) = { + let (bytes, files) = { let mut iovs = [iovec { iov_base: buf.as_mut_ptr() as *mut c_void, iov_len: buf_size, }]; self.recv_into_iovec(&mut iovs)? }; - Ok((bytes, buf, rfds)) + Ok((bytes, buf, files)) } - /// Receive a header-only message with optional attached file descriptors. + /// Receive a header-only message with optional attached files. /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be /// accepted and all other file descriptor will be discard silently. /// /// # Return: - /// * - (message header, [received fds]) on success. + /// * - (message header, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. /// * - PartialMessage: received a partial message. /// * - InvalidMessage: received a invalid message. - pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> { + pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut iovs = [iovec { iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), }]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { return Err(Error::PartialMessage); @@ -439,7 +449,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, rfds)) + Ok((hdr, files)) } /// Receive a message with optional attached file descriptors. @@ -447,7 +457,7 @@ impl<R: Req> Endpoint<R> { /// accepted and all other file descriptor will be discard silently. /// /// # Return: - /// * - (message header, message body, [received fds]) on success. + /// * - (message header, message body, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. @@ -455,7 +465,7 @@ impl<R: Req> Endpoint<R> { /// * - InvalidMessage: received a invalid message. pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>( &mut self, - ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> { + ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut body: T = Default::default(); let mut iovs = [ @@ -468,7 +478,7 @@ impl<R: Req> Endpoint<R> { iov_len: mem::size_of::<T>(), }, ]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); if bytes != total { @@ -477,7 +487,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, body, rfds)) + Ok((hdr, body, files)) } /// Receive a message with header and optional content. Callers need to @@ -488,7 +498,7 @@ impl<R: Req> Endpoint<R> { /// silently. /// /// # Return: - /// * - (message header, message size, [received fds]) on success. + /// * - (message header, message size, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. @@ -497,7 +507,7 @@ impl<R: Req> Endpoint<R> { pub fn recv_body_into_buf( &mut self, buf: &mut [u8], - ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> { + ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut iovs = [ iovec { @@ -509,7 +519,7 @@ impl<R: Req> Endpoint<R> { iov_len: buf.len(), }, ]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; if bytes < mem::size_of::<VhostUserMsgHeader<R>>() { return Err(Error::PartialMessage); @@ -517,7 +527,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds)) + Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files)) } /// Receive a message with optional payload and attached file descriptors. @@ -525,7 +535,7 @@ impl<R: Req> Endpoint<R> { /// accepted and all other file descriptor will be discard silently. /// /// # Return: - /// * - (message header, message body, size of payload, [received fds]) on success. + /// * - (message header, message body, size of payload, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. @@ -535,7 +545,7 @@ impl<R: Req> Endpoint<R> { pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>( &mut self, buf: &mut [u8], - ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> { + ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut body: T = Default::default(); let mut iovs = [ @@ -552,7 +562,7 @@ impl<R: Req> Endpoint<R> { iov_len: buf.len(), }, ]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); if bytes < total { @@ -561,17 +571,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, body, bytes - total, rfds)) - } - - /// Close all raw file descriptors. - pub fn close_rfds(rfds: Option<Vec<RawFd>>) { - if let Some(fds) = rfds { - for fd in fds { - // safe because the rawfds are valid and we don't care about the result. - let _ = unsafe { libc::close(fd) }; - } - } + Ok((hdr, body, bytes - total, files)) } } @@ -604,7 +604,6 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { #[cfg(test)] mod tests { use super::*; - use std::fs::File; use std::io::{Read, Seek, SeekFrom, Write}; use std::os::unix::io::FromRawFd; use tempfile::{tempfile, Builder, TempDir}; @@ -685,14 +684,14 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap(); assert_eq!(bytes, 4); assert_eq!(&buf1[..], &buf2[..]); - assert!(rfds.is_some()); - let fds = rfds.unwrap(); + assert!(files.is_some()); + let files = files.unwrap(); { - assert_eq!(fds.len(), 1); - let mut file = unsafe { File::from_raw_fd(fds[0]) }; + assert_eq!(files.len(), 1); + let mut file = &files[0]; let mut content = String::new(); file.seek(SeekFrom::Start(0)).unwrap(); file.read_to_string(&mut content).unwrap(); @@ -710,23 +709,23 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[..2], &buf2[..]); - assert!(rfds.is_some()); - let fds = rfds.unwrap(); + assert!(files.is_some()); + let files = files.unwrap(); { - assert_eq!(fds.len(), 3); - let mut file = unsafe { File::from_raw_fd(fds[1]) }; + assert_eq!(files.len(), 3); + let mut file = &files[1]; let mut content = String::new(); file.seek(SeekFrom::Start(0)).unwrap(); file.read_to_string(&mut content).unwrap(); assert_eq!(content, "test"); } - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[2..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); // Following communication pattern should not work: // Sending side: data(header, body) with fds @@ -742,10 +741,10 @@ mod tests { let (bytes, buf4) = slave.recv_data(2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[..2], &buf4[..]); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[2..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); // Following communication pattern should work: // Sending side: data, data with fds @@ -760,28 +759,28 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap(); assert_eq!(bytes, 4); assert_eq!(&buf1[..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[..2], &buf2[..]); - assert!(rfds.is_some()); - let fds = rfds.unwrap(); + assert!(files.is_some()); + let files = files.unwrap(); { - assert_eq!(fds.len(), 3); - let mut file = unsafe { File::from_raw_fd(fds[1]) }; + assert_eq!(files.len(), 3); + let mut file = &files[1]; let mut content = String::new(); file.seek(SeekFrom::Start(0)).unwrap(); file.read_to_string(&mut content).unwrap(); assert_eq!(content, "test"); } - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[2..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); // Following communication pattern should not work: // Sending side: data1, data2 with fds @@ -799,9 +798,9 @@ mod tests { let (bytes, _) = slave.recv_data(5).unwrap(); assert_eq!(bytes, 5); - let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap(); assert_eq!(bytes, 3); - assert!(rfds.is_none()); + assert!(files.is_none()); // If the target fd array is too small, extra file descriptors will get lost. let len = master @@ -812,12 +811,9 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap(); assert_eq!(bytes, 4); - assert!(rfds.is_some()); - - Endpoint::<MasterReq>::close_rfds(rfds); - Endpoint::<MasterReq>::close_rfds(None); + assert!(files.is_some()); } #[test] @@ -844,15 +840,15 @@ mod tests { mem::size_of::<u64>(), ) }; - let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap(); + let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap(); assert_eq!(hdr1, hdr2); assert_eq!(bytes, 8); assert_eq!(features1, features2); - assert!(rfds.is_none()); + assert!(files.is_none()); master.send_header(&hdr1, None).unwrap(); - let (hdr2, rfds) = slave.recv_header().unwrap(); + let (hdr2, files) = slave.recv_header().unwrap(); assert_eq!(hdr1, hdr2); - assert!(rfds.is_none()); + assert!(files.is_none()); } } diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs index b2b83d2..cc9a9fb 100644 --- a/src/vhost_user/dummy_slave.rs +++ b/src/vhost_user/dummy_slave.rs @@ -1,7 +1,7 @@ // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use std::os::unix::io::RawFd; +use std::fs::File; use super::message::*; use super::*; @@ -20,11 +20,12 @@ pub struct DummySlaveReqHandler { pub queue_num: usize, pub vring_num: [u32; MAX_QUEUE_NUM], pub vring_base: [u32; MAX_QUEUE_NUM], - pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM], - pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM], - pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub call_fd: [Option<File>; MAX_QUEUE_NUM], + pub kick_fd: [Option<File>; MAX_QUEUE_NUM], + pub err_fd: [Option<File>; MAX_QUEUE_NUM], pub vring_started: [bool; MAX_QUEUE_NUM], pub vring_enabled: [bool; MAX_QUEUE_NUM], + pub inflight_file: Option<File>, } impl DummySlaveReqHandler { @@ -83,7 +84,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { Ok(()) } - fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> { + fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _files: Vec<File>) -> Result<()> { Ok(()) } @@ -134,14 +135,10 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { )) } - fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> { if index as usize >= self.queue_num || index as usize > self.queue_num { return Err(Error::InvalidParam); } - if self.kick_fd[index as usize].is_some() { - // Close file descriptor set by previous operations. - let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) }; - } self.kick_fd[index as usize] = fd; // Quotation from vhost-user spec: @@ -155,26 +152,18 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { Ok(()) } - fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> { if index as usize >= self.queue_num || index as usize > self.queue_num { return Err(Error::InvalidParam); } - if self.call_fd[index as usize].is_some() { - // Close file descriptor set by previous operations. - let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) }; - } self.call_fd[index as usize] = fd; Ok(()) } - fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> { if index as usize >= self.queue_num || index as usize > self.queue_num { return Err(Error::InvalidParam); } - if self.err_fd[index as usize].is_some() { - // Close file descriptor set by previous operations. - let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) }; - } self.err_fd[index as usize] = fd; Ok(()) } @@ -245,11 +234,32 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { Ok(()) } + fn get_inflight_fd( + &mut self, + inflight: &VhostUserInflight, + ) -> Result<(VhostUserInflight, File)> { + let file = tempfile::tempfile().unwrap(); + self.inflight_file = Some(file.try_clone().unwrap()); + Ok(( + VhostUserInflight { + mmap_size: 0x1000, + mmap_offset: 0, + num_queues: inflight.num_queues, + queue_size: inflight.queue_size, + }, + file, + )) + } + + fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> Result<()> { + Ok(()) + } + fn get_max_mem_slots(&mut self) -> Result<u64> { Ok(MAX_MEM_SLOTS as u64) } - fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: RawFd) -> Result<()> { + fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: File) -> Result<()> { Ok(()) } diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index 16f0e02..9a65fbe 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -3,6 +3,7 @@ //! Traits and Struct for vhost-user master. +use std::fs::File; use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; @@ -13,7 +14,7 @@ use sys_util::EventFd; use super::connection::Endpoint; use super::message::*; -use super::{Error as VhostUserError, Result as VhostUserResult}; +use super::{take_single_file, Error as VhostUserError, Result as VhostUserResult}; use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; use crate::{Error, Result}; @@ -49,7 +50,16 @@ pub trait VhostUserMaster: VhostBackend { fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>; /// Setup slave communication channel. - fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>; + fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()>; + + /// Retrieve shared buffer for inflight I/O tracking. + fn get_inflight_fd( + &mut self, + inflight: &VhostUserInflight, + ) -> Result<(VhostUserInflight, File)>; + + /// Set shared buffer for inflight I/O tracking. + fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()>; /// Query the maximum amount of memory slots supported by the backend. fn get_max_mem_slots(&mut self) -> Result<u64>; @@ -84,6 +94,7 @@ impl Master { protocol_features_ready: false, max_queue_num, error: None, + hdr_flags: VhostUserHeaderFlag::empty(), })), } } @@ -125,6 +136,12 @@ impl Master { Ok(Self::new(endpoint, max_queue_num)) } + + /// Set the header flags that should be applied to all following messages. + pub fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) { + let mut node = self.node(); + node.hdr_flags = flags; + } } impl VhostBackend for Master { @@ -141,11 +158,9 @@ impl VhostBackend for Master { fn set_features(&self, features: u64) -> Result<()> { let mut node = self.node(); let val = VhostUserU64::new(features); - let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?; - // Don't wait for ACK here because the protocol feature negotiation process hasn't been - // completed yet. + let hdr = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?; node.acked_virtio_features = features & node.virtio_features; - Ok(()) + node.wait_for_ack(&hdr).map_err(|e| e.into()) } /// Set the current Master as an owner of the session. @@ -153,18 +168,14 @@ impl VhostBackend for Master { // We unwrap() the return value to assert that we are not expecting threads to ever fail // while holding the lock. let mut node = self.node(); - let _ = node.send_request_header(MasterReq::SET_OWNER, None)?; - // Don't wait for ACK here because the protocol feature negotiation process hasn't been - // completed yet. - Ok(()) + let hdr = node.send_request_header(MasterReq::SET_OWNER, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) } fn reset_owner(&self) -> Result<()> { let mut node = self.node(); - let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?; - // Don't wait for ACK here because the protocol feature negotiation process hasn't been - // completed yet. - Ok(()) + let hdr = node.send_request_header(MasterReq::RESET_OWNER, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) } /// Set the memory map regions on the slave so it can translate the vring @@ -220,8 +231,8 @@ impl VhostBackend for Master { fn set_log_fd(&self, fd: RawFd) -> Result<()> { let mut node = self.node(); let fds = [fd]; - node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?; - Ok(()) + let hdr = node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) } /// Set the size of the queue. @@ -283,8 +294,8 @@ impl VhostBackend for Master { if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } - node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?; - Ok(()) + let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) } /// Set the event file descriptor for adding buffers to the vring. @@ -296,8 +307,8 @@ impl VhostBackend for Master { if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } - node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?; - Ok(()) + let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) } /// Set the event file descriptor to signal when error occurs. @@ -308,8 +319,8 @@ impl VhostBackend for Master { if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } - node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?; - Ok(()) + let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) } } @@ -317,7 +328,7 @@ impl VhostUserMaster for Master { fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { let mut node = self.node(); let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); - if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 { + if node.virtio_features & flag == 0 { return error_code(VhostUserError::InvalidOperation); } let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?; @@ -334,16 +345,16 @@ impl VhostUserMaster for Master { fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> { let mut node = self.node(); let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); - if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 { + if node.virtio_features & flag == 0 { return error_code(VhostUserError::InvalidOperation); } let val = VhostUserU64::new(features.bits()); - let _ = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?; + let hdr = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?; // Don't wait for ACK here because the protocol feature negotiation process hasn't been // completed yet. node.acked_protocol_features = features.bits(); node.protocol_features_ready = true; - Ok(()) + node.wait_for_ack(&hdr).map_err(|e| e.into()) } fn get_queue_num(&mut self) -> Result<u64> { @@ -401,7 +412,6 @@ impl VhostUserMaster for Master { let (body_reply, buf_reply, rfds) = node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?; if rfds.is_some() { - Endpoint::<MasterReq>::close_rfds(rfds); return error_code(VhostUserError::InvalidMessage); } else if body_reply.size == 0 { return error_code(VhostUserError::SlaveInternalError); @@ -434,15 +444,47 @@ impl VhostUserMaster for Master { node.wait_for_ack(&hdr).map_err(|e| e.into()) } - fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> { + fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()> { let mut node = self.node(); if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { return error_code(VhostUserError::InvalidOperation); } + let fds = [fd.as_raw_fd()]; + let hdr = node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } - let fds = [fd]; - node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?; - Ok(()) + fn get_inflight_fd( + &mut self, + inflight: &VhostUserInflight, + ) -> Result<(VhostUserInflight, File)> { + let mut node = self.node(); + if node.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 { + return error_code(VhostUserError::InvalidOperation); + } + + let hdr = node.send_request_with_body(MasterReq::GET_INFLIGHT_FD, inflight, None)?; + let (inflight, files) = node.recv_reply_with_files::<VhostUserInflight>(&hdr)?; + + match take_single_file(files) { + Some(file) => Ok((inflight, file)), + None => error_code(VhostUserError::IncorrectFds), + } + } + + fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()> { + let mut node = self.node(); + if node.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 { + return error_code(VhostUserError::InvalidOperation); + } + + if inflight.mmap_size == 0 || inflight.num_queues == 0 || inflight.queue_size == 0 || fd < 0 + { + return error_code(VhostUserError::InvalidParam); + } + + let hdr = node.send_request_with_body(MasterReq::SET_INFLIGHT_FD, inflight, Some(&[fd]))?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) } fn get_max_mem_slots(&mut self) -> Result<u64> { @@ -546,6 +588,8 @@ struct MasterInternal { max_queue_num: u64, // Internal flag to mark failure state. error: Option<i32>, + // List of header flags. + hdr_flags: VhostUserHeaderFlag, } impl MasterInternal { @@ -555,7 +599,7 @@ impl MasterInternal { fds: Option<&[RawFd]>, ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { self.check_state()?; - let hdr = Self::new_request_header(code, 0); + let hdr = self.new_request_header(code, 0); self.main_sock.send_header(&hdr, fds)?; Ok(hdr) } @@ -571,7 +615,7 @@ impl MasterInternal { } self.check_state()?; - let hdr = Self::new_request_header(code, mem::size_of::<T>() as u32); + let hdr = self.new_request_header(code, mem::size_of::<T>() as u32); self.main_sock.send_message(&hdr, msg, fds)?; Ok(hdr) } @@ -594,7 +638,7 @@ impl MasterInternal { } self.check_state()?; - let hdr = Self::new_request_header(code, len as u32); + let hdr = self.new_request_header(code, len as u32); self.main_sock .send_message_with_payload(&hdr, msg, payload, fds)?; Ok(hdr) @@ -615,7 +659,7 @@ impl MasterInternal { // This flag is set when there is no file descriptor in the ancillary data. This signals // that polling will be used instead of waiting for the call. let msg = VhostUserU64::new(queue_index as u64); - let hdr = Self::new_request_header(code, mem::size_of::<VhostUserU64>() as u32); + let hdr = self.new_request_header(code, mem::size_of::<VhostUserU64>() as u32); self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?; Ok(hdr) } @@ -631,16 +675,31 @@ impl MasterInternal { let (reply, body, rfds) = self.main_sock.recv_body::<T>()?; if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(VhostUserError::InvalidMessage); } Ok(body) } + fn recv_reply_with_files<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + ) -> VhostUserResult<(T, Option<Vec<File>>)> { + if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let (reply, body, files) = self.main_sock.recv_body::<T>()?; + if !reply.is_reply_for(&hdr) || files.is_none() || !body.is_valid() { + return Err(VhostUserError::InvalidMessage); + } + Ok((body, files)) + } + fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>( &mut self, hdr: &VhostUserMsgHeader<MasterReq>, - ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> { + ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<File>>)> { if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.get_size() as usize <= mem::size_of::<T>() || hdr.get_size() as usize > MAX_MSG_SIZE @@ -651,18 +710,17 @@ impl MasterInternal { self.check_state()?; let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()]; - let (reply, body, bytes, rfds) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?; + let (reply, body, bytes, files) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?; if !reply.is_reply_for(hdr) || reply.get_size() as usize != mem::size_of::<T>() + bytes - || rfds.is_some() + || files.is_some() || !body.is_valid() + || bytes != buf.len() { - Endpoint::<MasterReq>::close_rfds(rfds); - return Err(VhostUserError::InvalidMessage); - } else if bytes != buf.len() { return Err(VhostUserError::InvalidMessage); } - Ok((body, buf, rfds)) + + Ok((body, buf, files)) } fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> { @@ -675,7 +733,6 @@ impl MasterInternal { let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?; if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(VhostUserError::InvalidMessage); } if body.value != 0 { @@ -698,9 +755,8 @@ impl MasterInternal { } #[inline] - fn new_request_header(request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> { - // TODO: handle NEED_REPLY flag - VhostUserMsgHeader::new(request, 0x1, size) + fn new_request_header(&self, request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> { + VhostUserMsgHeader::new(request, self.hdr_flags.bits() | 0x1, size) } } diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs index 8cba188..0ecda4e 100644 --- a/src/vhost_user/master_req_handler.rs +++ b/src/vhost_user/master_req_handler.rs @@ -1,6 +1,7 @@ // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::fs::File; use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; @@ -33,9 +34,7 @@ pub trait VhostUserMasterReqHandler { } /// Handle virtio-fs map file requests. - fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } @@ -50,14 +49,12 @@ pub trait VhostUserMasterReqHandler { } /// Handle virtio-fs file IO requests. - fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); - // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawFd); } /// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. @@ -70,9 +67,7 @@ pub trait VhostUserMasterReqHandlerMut { } /// Handle virtio-fs map file requests. - fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } @@ -87,9 +82,7 @@ pub trait VhostUserMasterReqHandlerMut { } /// Handle virtio-fs file IO requests. - fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } @@ -102,7 +95,7 @@ impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { self.lock().unwrap().handle_config_change() } - fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { self.lock().unwrap().fs_slave_map(fs, fd) } @@ -114,7 +107,7 @@ impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { self.lock().unwrap().fs_slave_sync(fs) } - fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { self.lock().unwrap().fs_slave_io(fs, fd) } } @@ -206,8 +199,8 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { // . recv optional message body and payload according size field in // message header // . validate message body and optional payload - let (hdr, rfds) = self.sub_sock.recv_header()?; - let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (hdr, files) = self.sub_sock.recv_header()?; + self.check_attached_files(&hdr, &files)?; let (size, buf) = match hdr.get_size() { 0 => (0, vec![0u8; 0]), len => { @@ -231,9 +224,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { } SlaveReq::FS_MAP => { let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; - // check_attached_rfds() has validated rfds + // check_attached_files() has validated files self.backend - .fs_slave_map(&msg, rfds.unwrap()[0]) + .fs_slave_map(&msg, &files.unwrap()[0]) .map_err(Error::ReqHandlerError) } SlaveReq::FS_UNMAP => { @@ -250,9 +243,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { } SlaveReq::FS_IO => { let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; - // check_attached_rfds() has validated rfds + // check_attached_files() has validated files self.backend - .fs_slave_io(&msg, rfds.unwrap()[0]) + .fs_slave_io(&msg, &files.unwrap()[0]) .map_err(Error::ReqHandlerError) } _ => Err(Error::InvalidMessage), @@ -286,34 +279,21 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { Ok(()) } - fn check_attached_rfds( + fn check_attached_files( &self, hdr: &VhostUserMsgHeader<SlaveReq>, - rfds: Option<Vec<RawFd>>, - ) -> Result<Option<Vec<RawFd>>> { + files: &Option<Vec<File>>, + ) -> Result<()> { match hdr.get_code() { SlaveReq::FS_MAP | SlaveReq::FS_IO => { - // Expect an fd set with a single fd. - match rfds { - None => Err(Error::InvalidMessage), - Some(fds) => { - if fds.len() != 1 { - Endpoint::<SlaveReq>::close_rfds(Some(fds)); - Err(Error::InvalidMessage) - } else { - Ok(Some(fds)) - } - } - } - } - _ => { - if rfds.is_some() { - Endpoint::<SlaveReq>::close_rfds(rfds); - Err(Error::InvalidMessage) - } else { - Ok(rfds) + // Expect a single file is passed. + match files { + Some(files) if files.len() == 1 => Ok(()), + _ => Err(Error::InvalidMessage), } } + _ if files.is_some() => Err(Error::InvalidMessage), + _ => Ok(()), } } @@ -390,9 +370,11 @@ mod tests { impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { /// Handle virtio-fs map file requests from the slave. - fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - // Safe because we have just received the rawfd from kernel. - unsafe { libc::close(fd) }; + fn fs_slave_map( + &mut self, + _fs: &VhostUserFSSlaveMsg, + _fd: &dyn AsRawFd, + ) -> HandlerResult<u64> { Ok(0) } @@ -437,7 +419,7 @@ mod tests { }); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd) .unwrap(); // When REPLY_ACK has not been negotiated, the master has no way to detect failure from // slave side. @@ -468,7 +450,7 @@ mod tests { fs_cache.set_reply_ack_flag(true); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd) .unwrap(); fs_cache .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs index 32b2f8c..fc33e1b 100644 --- a/src/vhost_user/message.rs +++ b/src/vhost_user/message.rs @@ -7,6 +7,7 @@ #![allow(dead_code)] #![allow(non_camel_case_types)] +#![allow(clippy::upper_case_acronyms)] use std::fmt::Debug; use std::marker::PhantomData; @@ -140,9 +141,9 @@ pub enum MasterReq { MAX_CMD = 41, } -impl Into<u32> for MasterReq { - fn into(self) -> u32 { - self as u32 +impl From<MasterReq> for u32 { + fn from(req: MasterReq) -> u32 { + req as u32 } } @@ -180,9 +181,9 @@ pub enum SlaveReq { MAX_CMD = 10, } -impl Into<u32> for SlaveReq { - fn into(self) -> u32 { - self as u32 +impl From<SlaveReq> for u32 { + fn from(req: SlaveReq) -> u32 { + req as u32 } } @@ -222,9 +223,8 @@ bitflags! { /// Common message header for vhost-user requests and replies. /// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the /// machine native byte order. -#[allow(safe_packed_borrows)] #[repr(packed)] -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Copy)] pub(super) struct VhostUserMsgHeader<R: Req> { request: u32, flags: u32, @@ -232,6 +232,28 @@ pub(super) struct VhostUserMsgHeader<R: Req> { _r: PhantomData<R>, } +impl<R: Req> Debug for VhostUserMsgHeader<R> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Point") + .field("request", &{ self.request }) + .field("flags", &{ self.flags }) + .field("size", &{ self.size }) + .finish() + } +} + +impl<R: Req> Clone for VhostUserMsgHeader<R> { + fn clone(&self) -> VhostUserMsgHeader<R> { + *self + } +} + +impl<R: Req> PartialEq for VhostUserMsgHeader<R> { + fn eq(&self, other: &Self) -> bool { + self.request == other.request && self.flags == other.flags && self.size == other.size + } +} + impl<R: Req> VhostUserMsgHeader<R> { /// Create a new instance of `VhostUserMsgHeader`. pub fn new(request: R, flags: u32, size: u32) -> Self { @@ -248,7 +270,7 @@ impl<R: Req> VhostUserMsgHeader<R> { /// Get message type. pub fn get_code(&self) -> R { // It's safe because R is marked as repr(u32). - unsafe { std::mem::transmute_copy::<u32, R>(&self.request) } + unsafe { std::mem::transmute_copy::<u32, R>(&{ self.request }) } } /// Set message type. @@ -673,6 +695,42 @@ impl VhostUserMsgValidator for VhostUserConfig { /// Payload for the VhostUserConfig message. pub type VhostUserConfigPayload = Vec<u8>; +/// Single memory region descriptor as payload for ADD_MEM_REG and REM_MEM_REG +/// requests. +#[repr(C)] +#[derive(Default, Clone)] +pub struct VhostUserInflight { + /// Size of the area to track inflight I/O. + pub mmap_size: u64, + /// Offset of this area from the start of the supplied file descriptor. + pub mmap_offset: u64, + /// Number of virtqueues. + pub num_queues: u16, + /// Size of virtqueues. + pub queue_size: u16, +} + +impl VhostUserInflight { + /// Create a new instance. + pub fn new(mmap_size: u64, mmap_offset: u64, num_queues: u16, queue_size: u16) -> Self { + VhostUserInflight { + mmap_size, + mmap_offset, + num_queues, + queue_size, + } + } +} + +impl VhostUserMsgValidator for VhostUserInflight { + fn is_valid(&self) -> bool { + if self.num_queues == 0 || self.queue_size == 0 { + return false; + } + true + } +} + /* * TODO: support dirty log, live migration and IOTLB operations. #[repr(packed)] @@ -744,6 +802,137 @@ impl VhostUserMsgValidator for VhostUserFSSlaveMsg { } } +/// Inflight I/O descriptor state for split virtqueues +#[repr(packed)] +#[derive(Clone, Copy, Default)] +pub struct DescStateSplit { + /// Indicate whether this descriptor (only head) is inflight or not. + pub inflight: u8, + /// Padding + padding: [u8; 5], + /// List of last batch of used descriptors, only when batching is used for submitting + pub next: u16, + /// Preserve order of fetching available descriptors, only for head descriptor + pub counter: u64, +} + +impl DescStateSplit { + /// New instance of DescStateSplit struct + pub fn new() -> Self { + Self::default() + } +} + +/// Inflight I/O queue region for split virtqueues +#[repr(packed)] +pub struct QueueRegionSplit { + /// Features flags of this region + pub features: u64, + /// Version of this region + pub version: u16, + /// Number of DescStateSplit entries + pub desc_num: u16, + /// List to track last batch of used descriptors + pub last_batch_head: u16, + /// Idx value of used ring + pub used_idx: u16, + /// Pointer to an array of DescStateSplit entries + pub desc: u64, +} + +impl QueueRegionSplit { + /// New instance of QueueRegionSplit struct + pub fn new(features: u64, queue_size: u16) -> Self { + QueueRegionSplit { + features, + version: 1, + desc_num: queue_size, + last_batch_head: 0, + used_idx: 0, + desc: 0, + } + } +} + +/// Inflight I/O descriptor state for packed virtqueues +#[repr(packed)] +#[derive(Clone, Copy, Default)] +pub struct DescStatePacked { + /// Indicate whether this descriptor (only head) is inflight or not. + pub inflight: u8, + /// Padding + padding: u8, + /// Link to next free entry + pub next: u16, + /// Link to last entry of descriptor list, only for head + pub last: u16, + /// Length of descriptor list, only for head + pub num: u16, + /// Preserve order of fetching avail descriptors, only for head + pub counter: u64, + /// Buffer ID + pub id: u16, + /// Descriptor flags + pub flags: u16, + /// Buffer length + pub len: u32, + /// Buffer address + pub addr: u64, +} + +impl DescStatePacked { + /// New instance of DescStatePacked struct + pub fn new() -> Self { + Self::default() + } +} + +/// Inflight I/O queue region for packed virtqueues +#[repr(packed)] +pub struct QueueRegionPacked { + /// Features flags of this region + pub features: u64, + /// version of this region + pub version: u16, + /// size of descriptor state array + pub desc_num: u16, + /// head of free DescStatePacked entry list + pub free_head: u16, + /// old head of free DescStatePacked entry list + pub old_free_head: u16, + /// used idx of descriptor ring + pub used_idx: u16, + /// old used idx of descriptor ring + pub old_used_idx: u16, + /// device ring wrap counter + pub used_wrap_counter: u8, + /// old device ring wrap counter + pub old_used_wrap_counter: u8, + /// Padding + padding: [u8; 7], + /// Pointer to array tracking state of each descriptor from descriptor ring + pub desc: u64, +} + +impl QueueRegionPacked { + /// New instance of QueueRegionPacked struct + pub fn new(features: u64, queue_size: u16) -> Self { + QueueRegionPacked { + features, + version: 1, + desc_num: queue_size, + free_head: 0, + old_free_head: 0, + used_idx: 0, + old_used_idx: 0, + used_wrap_counter: 0, + old_used_wrap_counter: 0, + padding: [0; 7], + desc: 0, + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -824,7 +1013,10 @@ mod tests { hdr.set_version(0x1); assert!(hdr.is_valid()); + // Test Debug, Clone, PartiaEq trait assert_eq!(hdr, hdr.clone()); + assert_eq!(hdr.clone().get_code(), hdr.get_code()); + assert_eq!(format!("{:?}", hdr.clone()), format!("{:?}", hdr)); } #[test] diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs index 9ef6453..5d8ce31 100644 --- a/src/vhost_user/mod.rs +++ b/src/vhost_user/mod.rs @@ -18,6 +18,7 @@ //! Most messages that can be sent via the Unix domain socket implementing vhost-user have an //! equivalent ioctl to the kernel implementation. +use std::fs::File; use std::io::Error as IOError; pub mod message; @@ -175,6 +176,16 @@ pub type Result<T> = std::result::Result<T, Error>; /// Result of request handler. pub type HandlerResult<T> = std::result::Result<T, IOError>; +/// Utility function to take the first element from option of a vector of files. +/// Returns `None` if the vector contains no file or more than one file. +pub(crate) fn take_single_file(files: Option<Vec<File>>) -> Option<File> { + let mut files = files?; + if files.len() != 1 { + return None; + } + Some(files.swap_remove(0)) +} + #[cfg(all(test, feature = "vhost-user-slave"))] mod dummy_slave; @@ -308,6 +319,11 @@ mod tests { VhostUserProtocolFeatures::all().bits() ); + // get_inflight_fd() + slave.handle_request().unwrap(); + // set_inflight_fd() + slave.handle_request().unwrap(); + // get_queue_num() slave.handle_request().unwrap(); @@ -360,6 +376,19 @@ mod tests { assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); master.set_protocol_features(features).unwrap(); + // Retrieve inflight I/O tracking information + let (inflight_info, inflight_file) = master + .get_inflight_fd(&VhostUserInflight { + num_queues: 2, + queue_size: 256, + ..Default::default() + }) + .unwrap(); + // Set the buffer back to the backend + master + .set_inflight_fd(&inflight_info, inflight_file.as_raw_fd()) + .unwrap(); + let num = master.get_queue_num().unwrap(); assert_eq!(num, 2); @@ -384,7 +413,7 @@ mod tests { assert_eq!(offset, 0x100); assert_eq!(reply_payload[0], 0xa5); - master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap(); + master.set_slave_request_fd(&eventfd).unwrap(); master.set_vring_enable(0, true).unwrap(); // unimplemented yet diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs index a9c4ed2..ee5fd9b 100644 --- a/src/vhost_user/slave_fs_cache.rs +++ b/src/vhost_user/slave_fs_cache.rs @@ -3,7 +3,7 @@ use std::io; use std::mem; -use std::os::unix::io::RawFd; +use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; use std::sync::{Arc, Mutex, MutexGuard}; @@ -55,7 +55,6 @@ impl SlaveFsCacheReqInternal { let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?; if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { - Endpoint::<SlaveReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } if body.value != 0 { @@ -129,8 +128,8 @@ impl SlaveFsCacheReq { impl VhostUserMasterReqHandler for SlaveFsCacheReq { /// Forward vhost-user-fs map file requests to the slave. - fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { - self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd])) + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { + self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd.as_raw_fd()])) } /// Forward vhost-user-fs unmap file requests to the master. @@ -158,31 +157,21 @@ mod tests { #[test] fn test_slave_fs_cache_send_failure() { let (p1, p2) = UnixStream::pair().unwrap(); - let fd = p2.as_raw_fd(); let fs_cache = SlaveFsCacheReq::from_stream(p1); fs_cache.set_failed(libc::ECONNRESET); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &p2) .unwrap_err(); fs_cache .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) .unwrap_err(); fs_cache.node().error = None; - - drop(p2); - fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) - .unwrap_err(); - fs_cache - .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) - .unwrap_err(); } #[test] fn test_slave_fs_cache_recv_negative() { let (p1, p2) = UnixStream::pair().unwrap(); - let fd = p2.as_raw_fd(); let fs_cache = SlaveFsCacheReq::from_stream(p1); let mut master = Endpoint::<SlaveReq>::from_stream(p2); @@ -194,33 +183,35 @@ mod tests { ); let body = VhostUserU64::new(0); - master.send_message(&hdr, &body, Some(&[fd])).unwrap(); + master + .send_message(&hdr, &body, Some(&[master.as_raw_fd()])) + .unwrap(); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap(); fs_cache.set_reply_ack_flag(true); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap_err(); hdr.set_code(SlaveReq::FS_UNMAP); master.send_message(&hdr, &body, None).unwrap(); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap_err(); hdr.set_code(SlaveReq::FS_MAP); let body = VhostUserU64::new(1); master.send_message(&hdr, &body, None).unwrap(); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap_err(); let body = VhostUserU64::new(0); master.send_message(&hdr, &body, None).unwrap(); fs_cache - .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) .unwrap(); } } diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index 9d7ea10..402030c 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -1,16 +1,16 @@ // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::fs::File; use std::mem; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; use std::slice; use std::sync::{Arc, Mutex}; use super::connection::Endpoint; use super::message::*; -use super::slave_fs_cache::SlaveFsCacheReq; -use super::{Error, Result}; +use super::{take_single_file, Error, Result}; /// Services provided to the master by the slave with interior mutability. /// @@ -38,7 +38,7 @@ pub trait VhostUserSlaveReqHandler { fn reset_owner(&self) -> Result<()>; fn get_features(&self) -> Result<u64>; fn set_features(&self, features: u64) -> Result<()>; - fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; fn set_vring_addr( &self, @@ -51,9 +51,9 @@ pub trait VhostUserSlaveReqHandler { ) -> Result<()>; fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>; - fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>; fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>; fn set_protocol_features(&self, features: u64) -> Result<()>; @@ -61,9 +61,11 @@ pub trait VhostUserSlaveReqHandler { fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; - fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {} + fn set_slave_req_fd(&self, _vu_req: File) {} + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>; + fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>; fn get_max_mem_slots(&self) -> Result<u64>; - fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>; } @@ -76,7 +78,7 @@ pub trait VhostUserSlaveReqHandlerMut { fn reset_owner(&mut self) -> Result<()>; fn get_features(&mut self) -> Result<u64>; fn set_features(&mut self, features: u64) -> Result<()>; - fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; fn set_vring_addr( &mut self, @@ -89,9 +91,9 @@ pub trait VhostUserSlaveReqHandlerMut { ) -> Result<()>; fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; - fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>; fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; fn set_protocol_features(&mut self, features: u64) -> Result<()>; @@ -104,9 +106,14 @@ pub trait VhostUserSlaveReqHandlerMut { flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>; fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; - fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {} + fn set_slave_req_fd(&mut self, _vu_req: File) {} + fn get_inflight_fd( + &mut self, + inflight: &VhostUserInflight, + ) -> Result<(VhostUserInflight, File)>; + fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>; fn get_max_mem_slots(&mut self) -> Result<u64>; - fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; } @@ -127,8 +134,8 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().set_features(features) } - fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> { - self.lock().unwrap().set_mem_table(ctx, fds) + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> { + self.lock().unwrap().set_mem_table(ctx, files) } fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { @@ -157,15 +164,15 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().get_vring_base(index) } - fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> { self.lock().unwrap().set_vring_kick(index, fd) } - fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> { self.lock().unwrap().set_vring_call(index, fd) } - fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> { self.lock().unwrap().set_vring_err(index, fd) } @@ -193,15 +200,23 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().set_config(offset, buf, flags) } - fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) { + fn set_slave_req_fd(&self, vu_req: File) { self.lock().unwrap().set_slave_req_fd(vu_req) } + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> { + self.lock().unwrap().get_inflight_fd(inflight) + } + + fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> { + self.lock().unwrap().set_inflight_fd(inflight, file) + } + fn get_max_mem_slots(&self) -> Result<u64> { self.lock().unwrap().get_max_mem_slots() } - fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()> { + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> { self.lock().unwrap().add_mem_region(region, fd) } @@ -253,6 +268,11 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } } + /// Create a vhost-user slave endpoint from a connected socket. + pub fn from_stream(socket: UnixStream, backend: Arc<S>) -> Self { + Self::new(Endpoint::from_stream(socket), backend) + } + /// Create a new vhost-user slave endpoint. /// /// # Arguments @@ -286,8 +306,9 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { // . recv optional message body and payload according size field in // message header // . validate message body and optional payload - let (hdr, rfds) = self.main_sock.recv_header()?; - let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (hdr, files) = self.main_sock.recv_header()?; + self.check_attached_files(&hdr, &files)?; + let (size, buf) = match hdr.get_size() { 0 => (0, vec![0u8; 0]), len => { @@ -302,11 +323,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { match hdr.get_code() { MasterReq::SET_OWNER => { self.check_request_size(&hdr, size, 0)?; - self.backend.set_owner()?; + let res = self.backend.set_owner(); + self.send_ack_message(&hdr, res)?; } MasterReq::RESET_OWNER => { self.check_request_size(&hdr, size, 0)?; - self.backend.reset_owner()?; + let res = self.backend.reset_owner(); + self.send_ack_message(&hdr, res)?; } MasterReq::GET_FEATURES => { self.check_request_size(&hdr, size, 0)?; @@ -318,12 +341,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_FEATURES => { let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; - self.backend.set_features(msg.value)?; + let res = self.backend.set_features(msg.value); self.acked_virtio_features = msg.value; self.update_reply_ack_flag(); + self.send_ack_message(&hdr, res)?; } MasterReq::SET_MEM_TABLE => { - let res = self.set_mem_table(&hdr, size, &buf, rfds); + let res = self.set_mem_table(&hdr, size, &buf, files); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_NUM => { @@ -359,20 +383,20 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_VRING_CALL => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; - let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.set_vring_call(index, rfds); + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_call(index, file); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_KICK => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; - let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.set_vring_kick(index, rfds); + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_kick(index, file); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_ERR => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; - let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.set_vring_err(index, rfds); + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_err(index, file); self.send_ack_message(&hdr, res)?; } MasterReq::GET_PROTOCOL_FEATURES => { @@ -385,9 +409,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_PROTOCOL_FEATURES => { let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; - self.backend.set_protocol_features(msg.value)?; + let res = self.backend.set_protocol_features(msg.value); self.acked_protocol_features = msg.value; self.update_reply_ack_flag(); + self.send_ack_message(&hdr, res)?; } MasterReq::GET_QUEUE_NUM => { if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { @@ -426,14 +451,40 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { return Err(Error::InvalidOperation); } self.check_request_size(&hdr, size, hdr.get_size() as usize)?; - self.set_config(&hdr, size, &buf)?; + let res = self.set_config(size, &buf); + self.send_ack_message(&hdr, res)?; } MasterReq::SET_SLAVE_REQ_FD => { if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { return Err(Error::InvalidOperation); } self.check_request_size(&hdr, size, hdr.get_size() as usize)?; - self.set_slave_req_fd(&hdr, rfds)?; + let res = self.set_slave_req_fd(files); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_INFLIGHT_FD => { + if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + + let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; + let (inflight, file) = self.backend.get_inflight_fd(&msg)?; + let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?; + self.main_sock + .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?; + } + MasterReq::SET_INFLIGHT_FD => { + if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + let file = take_single_file(files).ok_or(Error::IncorrectFds)?; + let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; + let res = self.backend.set_inflight_fd(&msg, file); + self.send_ack_message(&hdr, res)?; } MasterReq::GET_MAX_MEM_SLOTS => { if self.acked_protocol_features @@ -454,18 +505,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { { return Err(Error::InvalidOperation); } - let fd = if let Some(fds) = &rfds { - if fds.len() != 1 { - return Err(Error::InvalidParam); - } - fds[0] - } else { + let mut files = files.ok_or(Error::InvalidParam)?; + if files.len() != 1 { return Err(Error::InvalidParam); - }; - + } let msg = self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; - let res = self.backend.add_mem_region(&msg, fd); + let res = self.backend.add_mem_region(&msg, files.swap_remove(0)); self.send_ack_message(&hdr, res)?; } MasterReq::REM_MEM_REG => { @@ -493,37 +539,28 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], - rfds: Option<Vec<RawFd>>, + files: Option<Vec<File>>, ) -> Result<()> { self.check_request_size(&hdr, size, hdr.get_size() as usize)?; // check message size is consistent let hdrsize = mem::size_of::<VhostUserMemory>(); if size < hdrsize { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) }; if !msg.is_valid() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } // validate number of fds matching number of memory regions - let fds = match rfds { - None => return Err(Error::InvalidMessage), - Some(fds) => { - if fds.len() != msg.num_regions as usize { - Endpoint::<MasterReq>::close_rfds(Some(fds)); - return Err(Error::InvalidMessage); - } - fds - } - }; + let files = files.ok_or(Error::InvalidMessage)?; + if files.len() != msg.num_regions as usize { + return Err(Error::InvalidMessage); + } // Validate memory regions let regions = unsafe { @@ -534,12 +571,11 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { }; for region in regions.iter() { if !region.is_valid() { - Endpoint::<MasterReq>::close_rfds(Some(fds)); return Err(Error::InvalidMessage); } } - self.backend.set_mem_table(®ions, &fds) + self.backend.set_mem_table(®ions, files) } fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { @@ -580,12 +616,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { Ok(()) } - fn set_config( - &mut self, - hdr: &VhostUserMsgHeader<MasterReq>, - size: usize, - buf: &[u8], - ) -> Result<()> { + fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> { if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() { return Err(Error::InvalidMessage); } @@ -602,35 +633,20 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { None => return Err(Error::InvalidMessage), } - let res = self.backend.set_config(msg.offset, buf, flags); - self.send_ack_message(&hdr, res)?; - Ok(()) + self.backend.set_config(msg.offset, buf, flags) } - fn set_slave_req_fd( - &mut self, - hdr: &VhostUserMsgHeader<MasterReq>, - rfds: Option<Vec<RawFd>>, - ) -> Result<()> { - if let Some(fds) = rfds { - if fds.len() == 1 { - let sock = unsafe { UnixStream::from_raw_fd(fds[0]) }; - let vu_req = SlaveFsCacheReq::from_stream(sock); - self.backend.set_slave_req_fd(vu_req); - self.send_ack_message(&hdr, Ok(())) - } else { - Err(Error::InvalidMessage) - } - } else { - Err(Error::InvalidMessage) - } + fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> { + let file = take_single_file(files).ok_or(Error::InvalidMessage)?; + self.backend.set_slave_req_fd(file); + Ok(()) } fn handle_vring_fd_request( &mut self, buf: &[u8], - rfds: Option<Vec<RawFd>>, - ) -> Result<(u8, Option<RawFd>)> { + files: Option<Vec<File>>, + ) -> Result<(u8, Option<File>)> { if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { return Err(Error::InvalidMessage); } @@ -640,28 +656,19 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } // Bits (0-7) of the payload contain the vring index. Bit 8 is the - // invalid FD flag. This flag is set when there is no file descriptor + // invalid FD flag. This bit is set when there is no file descriptor // in the ancillary data. This signals that polling will be used // instead of waiting for the call. - let nofd = (msg.value & 0x100u64) == 0x100u64; - - let mut rfd = None; - match rfds { - Some(fds) => { - if !nofd && fds.len() == 1 { - rfd = Some(fds[0]); - } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) { - Endpoint::<MasterReq>::close_rfds(Some(fds)); - return Err(Error::InvalidMessage); - } - } - None => { - if !nofd { - return Err(Error::InvalidMessage); - } - } + // If Bit 8 is unset, the data must contain a file descriptor. + let has_fd = (msg.value & 0x100u64) == 0; + + let file = take_single_file(files); + + if has_fd && file.is_none() || !has_fd && file.is_some() { + return Err(Error::InvalidMessage); } - Ok((msg.value as u8, rfd)) + + Ok((msg.value as u8, file)) } fn check_state(&self) -> Result<()> { @@ -687,29 +694,23 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { Ok(()) } - fn check_attached_rfds( + fn check_attached_files( &self, hdr: &VhostUserMsgHeader<MasterReq>, - rfds: Option<Vec<RawFd>>, - ) -> Result<Option<Vec<RawFd>>> { + files: &Option<Vec<File>>, + ) -> Result<()> { match hdr.get_code() { - MasterReq::SET_MEM_TABLE => Ok(rfds), - MasterReq::SET_VRING_CALL => Ok(rfds), - MasterReq::SET_VRING_KICK => Ok(rfds), - MasterReq::SET_VRING_ERR => Ok(rfds), - MasterReq::SET_LOG_BASE => Ok(rfds), - MasterReq::SET_LOG_FD => Ok(rfds), - MasterReq::SET_SLAVE_REQ_FD => Ok(rfds), - MasterReq::SET_INFLIGHT_FD => Ok(rfds), - MasterReq::ADD_MEM_REG => Ok(rfds), - _ => { - if rfds.is_some() { - Endpoint::<MasterReq>::close_rfds(rfds); - Err(Error::InvalidMessage) - } else { - Ok(rfds) - } - } + MasterReq::SET_MEM_TABLE + | MasterReq::SET_VRING_CALL + | MasterReq::SET_VRING_KICK + | MasterReq::SET_VRING_ERR + | MasterReq::SET_LOG_BASE + | MasterReq::SET_LOG_FD + | MasterReq::SET_SLAVE_REQ_FD + | MasterReq::SET_INFLIGHT_FD + | MasterReq::ADD_MEM_REG => Ok(()), + _ if files.is_some() => Err(Error::InvalidMessage), + _ => Ok(()), } } @@ -731,7 +732,6 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); let pflag = VhostUserProtocolFeatures::REPLY_ACK; if (self.virtio_features & vflag) != 0 - && (self.acked_virtio_features & vflag) != 0 && self.protocol_features.contains(pflag) && (self.acked_protocol_features & pflag.bits()) != 0 { @@ -774,7 +774,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { let msg = VhostUserU64::new(val); self.main_sock.send_message(&hdr, &msg, None)?; } - Ok(()) + res } fn send_reply_message<T>( |