diff options
author | Andrew Walbran <qwandor@google.com> | 2023-06-14 16:37:52 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2023-06-14 16:37:52 +0000 |
commit | 5cd661c16e064e6f55849e776e263950443c2ae8 (patch) | |
tree | b374b383add3775d221d8a98373ff7b8c216b48a | |
parent | 149b5d2ec015fac8d0ece693d7b627b25d2f6b0b (diff) | |
parent | de718b8f609090e37fff649b116c9ede2ba322ba (diff) | |
download | virtio-drivers-5cd661c16e064e6f55849e776e263950443c2ae8.tar.gz |
Update to 0.5.0. am: 91b9730f88 am: acd64406dc am: 23d2bf9625 am: d75e97c86b am: 9bfad29090 am: de718b8f60
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/virtio-drivers/+/2624430
Change-Id: If437f323c679e5ef4744d3512c2d27d150034723
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r-- | .cargo_vcs_info.json | 2 | ||||
-rw-r--r-- | .github/workflows/main.yml | 7 | ||||
-rw-r--r-- | Android.bp | 12 | ||||
-rw-r--r-- | Cargo.toml | 6 | ||||
-rw-r--r-- | Cargo.toml.orig | 6 | ||||
-rw-r--r-- | METADATA | 8 | ||||
-rw-r--r-- | README.md | 7 | ||||
-rw-r--r-- | cargo2android.json | 2 | ||||
-rw-r--r-- | patches/Android.bp.patch | 21 | ||||
-rw-r--r-- | src/device/blk.rs | 17 | ||||
-rw-r--r-- | src/device/common.rs | 1 | ||||
-rw-r--r-- | src/device/console.rs | 41 | ||||
-rw-r--r-- | src/device/gpu.rs | 32 | ||||
-rw-r--r-- | src/device/mod.rs | 2 | ||||
-rw-r--r-- | src/device/net.rs | 11 | ||||
-rw-r--r-- | src/device/socket/mod.rs | 20 | ||||
-rw-r--r-- | src/device/socket/multiconnectionmanager.rs | 763 | ||||
-rw-r--r-- | src/device/socket/protocol.rs | 31 | ||||
-rw-r--r-- | src/device/socket/singleconnectionmanager.rs | 447 | ||||
-rw-r--r-- | src/device/socket/vsock.rs | 544 | ||||
-rw-r--r-- | src/hal.rs | 7 | ||||
-rw-r--r-- | src/lib.rs | 2 | ||||
-rw-r--r-- | src/queue.rs | 11 | ||||
-rw-r--r-- | src/transport/fake.rs | 29 | ||||
-rw-r--r-- | src/transport/mmio.rs | 7 | ||||
-rw-r--r-- | src/transport/mod.rs | 9 | ||||
-rw-r--r-- | src/transport/pci.rs | 7 | ||||
-rw-r--r-- | src/transport/pci/bus.rs | 2 |
28 files changed, 1651 insertions, 403 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json index d2f328c..ec9996c 100644 --- a/.cargo_vcs_info.json +++ b/.cargo_vcs_info.json @@ -1,6 +1,6 @@ { "git": { - "sha1": "8e52adace55c5e082ba2effffcb70bf480d76ec0" + "sha1": "de1c3b130e507702f13d142b9bee55670a4a2858" }, "path_in_vcs": "" }
\ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 025b7e9..b991d6b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -66,6 +66,7 @@ jobs: example: - aarch64 - riscv + - x86_64 include: - example: aarch64 toolchain: stable @@ -75,6 +76,10 @@ jobs: toolchain: nightly-2022-11-03 target: riscv64imac-unknown-none-elf packages: qemu-system-misc + - example: x86_64 + toolchain: nightly + target: x86_64-unknown-none + packages: qemu-system-x86 steps: - uses: actions/checkout@v2 - name: Install QEMU @@ -93,4 +98,4 @@ jobs: run: make kernel - name: Run working-directory: examples/${{ matrix.example }} - run: QEMU_ARGS="-display none" make qemu + run: QEMU_ARGS="-display none" make qemu accel="off" @@ -24,11 +24,12 @@ rust_library_rlib { name: "libvirtio_drivers", crate_name: "virtio_drivers", cargo_env_compat: true, - cargo_pkg_version: "0.4.0", + cargo_pkg_version: "0.5.0", srcs: ["src/lib.rs"], edition: "2018", + features: ["alloc"], rustlibs: [ - "libbitflags-1.3.2", + "libbitflags", "liblog_rust_nostd", "libzerocopy_nostd", ], @@ -48,14 +49,15 @@ rust_test { name: "virtio-drivers_test_src_lib", crate_name: "virtio_drivers", cargo_env_compat: true, - cargo_pkg_version: "0.4.0", + cargo_pkg_version: "0.5.0", srcs: ["src/lib.rs"], test_suites: ["general-tests"], auto_gen_config: true, edition: "2018", + features: ["alloc"], rustlibs: [ - "libbitflags-1.3.2", + "libbitflags", "liblog_rust", - "libzerocopy", + "libzerocopy_nostd", ], } @@ -12,7 +12,7 @@ [package] edition = "2018" name = "virtio-drivers" -version = "0.4.0" +version = "0.5.0" authors = [ "Jiajie Chen <noc@jiegec.ac.cn>", "Runji Wang <wangrunji0408@163.com>", @@ -30,7 +30,7 @@ license = "MIT" repository = "https://github.com/rcore-os/virtio-drivers" [dependencies.bitflags] -version = "1.3" +version = "2.3.0" [dependencies.log] version = "0.4" @@ -39,5 +39,5 @@ version = "0.4" version = "0.6.1" [features] -alloc = [] +alloc = ["zerocopy/alloc"] default = ["alloc"] diff --git a/Cargo.toml.orig b/Cargo.toml.orig index 431ccd2..0a47ce9 100644 --- a/Cargo.toml.orig +++ b/Cargo.toml.orig @@ -1,6 +1,6 @@ [package] name = "virtio-drivers" -version = "0.4.0" +version = "0.5.0" license = "MIT" authors = [ "Jiajie Chen <noc@jiegec.ac.cn>", @@ -16,9 +16,9 @@ categories = ["hardware-support", "no-std"] [dependencies] log = "0.4" -bitflags = "1.3" +bitflags = "2.3.0" zerocopy = "0.6.1" [features] default = ["alloc"] -alloc = [] +alloc = ["zerocopy/alloc"] @@ -11,13 +11,13 @@ third_party { } url { type: ARCHIVE - value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.4.0.crate" + value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.5.0.crate" } - version: "0.4.0" + version: "0.5.0" license_type: NOTICE last_upgrade_date { year: 2023 - month: 4 - day: 19 + month: 6 + day: 13 } } @@ -44,7 +44,12 @@ VirtIO guest drivers in Rust. For **no_std** environment. ## Examples & Tests -- x86_64 (TODO) +### [x86_64](./examples/x86_64) + +```bash +cd examples/x86_64 +make qemu +``` ### [aarch64](./examples/aarch64) diff --git a/cargo2android.json b/cargo2android.json index 8bd0ae6..ede59d0 100644 --- a/cargo2android.json +++ b/cargo2android.json @@ -1,7 +1,7 @@ { "dependencies": true, "device": true, - "features": "", + "features": "alloc", "force-rlib": true, "no-host": "true", "no-std": true, diff --git a/patches/Android.bp.patch b/patches/Android.bp.patch index 8896c52..2a063c6 100644 --- a/patches/Android.bp.patch +++ b/patches/Android.bp.patch @@ -1,26 +1,23 @@ diff --git a/Android.bp b/Android.bp -index f635115..d111544 100644 +index e2b5dd5..6132b98 100644 --- a/Android.bp +++ b/Android.bp -@@ -28,9 +28,9 @@ rust_library_rlib { - srcs: ["src/lib.rs"], - edition: "2018", +@@ -30,8 +30,8 @@ rust_library_rlib { + features: ["alloc"], rustlibs: [ -- "libbitflags", + "libbitflags", - "liblog_rust", - "libzerocopy", -+ "libbitflags-1.3.2", + "liblog_rust_nostd", + "libzerocopy_nostd", ], apex_available: [ "//apex_available:platform", -@@ -54,7 +54,7 @@ rust_test { - auto_gen_config: true, - edition: "2018", +@@ -58,6 +58,6 @@ rust_test { rustlibs: [ -- "libbitflags", -+ "libbitflags-1.3.2", + "libbitflags", "liblog_rust", - "libzerocopy", +- "libzerocopy", ++ "libzerocopy_nostd", ], + } diff --git a/src/device/blk.rs b/src/device/blk.rs index d095047..ea3aef0 100644 --- a/src/device/blk.rs +++ b/src/device/blk.rs @@ -417,6 +417,7 @@ impl Default for BlkResp { pub const SECTOR_SIZE: usize = 512; bitflags! { + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] struct BlkFeature: u64 { /// Device supports request barriers. (legacy) const BARRIER = 1 << 0; @@ -477,7 +478,7 @@ mod tests { }; use alloc::{sync::Arc, vec}; use core::{mem::size_of, ptr::NonNull}; - use std::{sync::Mutex, thread, time::Duration}; + use std::{sync::Mutex, thread}; #[test] fn config() { @@ -500,7 +501,7 @@ mod tests { driver_features: 0, guest_page_size: 0, interrupt_pending: false, - queues: vec![QueueStatus::default(); 1], + queues: vec![QueueStatus::default()], })); let transport = FakeTransport { device_type: DeviceType::Console, @@ -536,7 +537,7 @@ mod tests { driver_features: 0, guest_page_size: 0, interrupt_pending: false, - queues: vec![QueueStatus::default(); 1], + queues: vec![QueueStatus::default()], })); let transport = FakeTransport { device_type: DeviceType::Console, @@ -550,9 +551,7 @@ mod tests { // Start a thread to simulate the device waiting for a read request. let handle = thread::spawn(move || { println!("Device waiting for a request."); - while !state.lock().unwrap().queues[usize::from(QUEUE)].notified { - thread::sleep(Duration::from_millis(10)); - } + State::wait_until_queue_notified(&state, QUEUE); println!("Transmit queue was notified."); state @@ -611,7 +610,7 @@ mod tests { driver_features: 0, guest_page_size: 0, interrupt_pending: false, - queues: vec![QueueStatus::default(); 1], + queues: vec![QueueStatus::default()], })); let transport = FakeTransport { device_type: DeviceType::Console, @@ -625,9 +624,7 @@ mod tests { // Start a thread to simulate the device waiting for a write request. let handle = thread::spawn(move || { println!("Device waiting for a request."); - while !state.lock().unwrap().queues[usize::from(QUEUE)].notified { - thread::sleep(Duration::from_millis(10)); - } + State::wait_until_queue_notified(&state, QUEUE); println!("Transmit queue was notified."); state diff --git a/src/device/common.rs b/src/device/common.rs index 2c8be3e..1081319 100644 --- a/src/device/common.rs +++ b/src/device/common.rs @@ -3,6 +3,7 @@ use bitflags::bitflags; bitflags! { + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] pub(crate) struct Feature: u64 { // device independent const NOTIFY_ON_EMPTY = 1 << 24; // legacy diff --git a/src/device/console.rs b/src/device/console.rs index e0b0356..7d3c7d4 100644 --- a/src/device/console.rs +++ b/src/device/console.rs @@ -1,10 +1,11 @@ //! Driver for VirtIO console devices. -use crate::hal::{BufferDirection, Dma, Hal}; +use crate::hal::Hal; use crate::queue::VirtQueue; use crate::transport::Transport; use crate::volatile::{volread, ReadOnly, WriteOnly}; -use crate::Result; +use crate::{Result, PAGE_SIZE}; +use alloc::boxed::Box; use bitflags::bitflags; use core::ptr::NonNull; use log::info; @@ -38,13 +39,12 @@ const QUEUE_SIZE: usize = 2; /// # Ok(()) /// # } /// ``` -pub struct VirtIOConsole<'a, H: Hal, T: Transport> { +pub struct VirtIOConsole<H: Hal, T: Transport> { transport: T, config_space: NonNull<Config>, receiveq: VirtQueue<H, QUEUE_SIZE>, transmitq: VirtQueue<H, QUEUE_SIZE>, - queue_buf_dma: Dma<H>, - queue_buf_rx: &'a mut [u8], + queue_buf_rx: Box<[u8; PAGE_SIZE]>, cursor: usize, pending_len: usize, /// The token of the outstanding receive request, if there is one. @@ -62,7 +62,7 @@ pub struct ConsoleInfo { pub max_ports: u32, } -impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> { +impl<H: Hal, T: Transport> VirtIOConsole<H, T> { /// Creates a new VirtIO console driver. pub fn new(mut transport: T) -> Result<Self> { transport.begin_init(|features| { @@ -74,12 +74,11 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> { let config_space = transport.config_space::<Config>()?; let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0)?; let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0)?; - let queue_buf_dma = Dma::new(1, BufferDirection::DeviceToDriver)?; // Safe because no alignment or initialisation is required for [u8], the DMA buffer is // dereferenceable, and the lifetime of the reference matches the lifetime of the DMA buffer // (which we don't otherwise access). - let queue_buf_rx = unsafe { queue_buf_dma.raw_slice().as_mut() }; + let queue_buf_rx = Box::new([0; PAGE_SIZE]); transport.finish_init(); let mut console = VirtIOConsole { @@ -87,7 +86,6 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> { config_space, receiveq, transmitq, - queue_buf_dma, queue_buf_rx, cursor: 0, pending_len: 0, @@ -118,7 +116,10 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> { if self.receive_token.is_none() && self.cursor == self.pending_len { // Safe because the buffer lasts at least as long as the queue, and there are no other // outstanding requests using the buffer. - self.receive_token = Some(unsafe { self.receiveq.add(&[], &mut [self.queue_buf_rx]) }?); + self.receive_token = Some(unsafe { + self.receiveq + .add(&[], &mut [self.queue_buf_rx.as_mut_slice()]) + }?); if self.receiveq.should_notify() { self.transport.notify(QUEUE_RECEIVEQ_PORT_0); } @@ -148,8 +149,11 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> { // Safe because we are passing the same buffer as we passed to `VirtQueue::add` in // `poll_retrieve` and it is still valid. let len = unsafe { - self.receiveq - .pop_used(receive_token, &[], &mut [self.queue_buf_rx])? + self.receiveq.pop_used( + receive_token, + &[], + &mut [self.queue_buf_rx.as_mut_slice()], + )? }; flag = true; assert_ne!(len, 0); @@ -188,7 +192,7 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> { } } -impl<H: Hal, T: Transport> Drop for VirtIOConsole<'_, H, T> { +impl<H: Hal, T: Transport> Drop for VirtIOConsole<H, T> { fn drop(&mut self) { // Clear any pointers pointing to DMA regions, so the device doesn't try to access them // after they have been freed. @@ -206,6 +210,7 @@ struct Config { } bitflags! { + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] struct Features: u64 { const SIZE = 1 << 0; const MULTIPORT = 1 << 1; @@ -241,7 +246,7 @@ mod tests { }; use alloc::{sync::Arc, vec}; use core::ptr::NonNull; - use std::{sync::Mutex, thread, time::Duration}; + use std::{sync::Mutex, thread}; #[test] fn receive() { @@ -256,7 +261,7 @@ mod tests { driver_features: 0, guest_page_size: 0, interrupt_pending: false, - queues: vec![QueueStatus::default(); 2], + queues: vec![QueueStatus::default(), QueueStatus::default()], })); let transport = FakeTransport { device_type: DeviceType::Console, @@ -304,7 +309,7 @@ mod tests { driver_features: 0, guest_page_size: 0, interrupt_pending: false, - queues: vec![QueueStatus::default(); 2], + queues: vec![QueueStatus::default(), QueueStatus::default()], })); let transport = FakeTransport { device_type: DeviceType::Console, @@ -318,9 +323,7 @@ mod tests { // Start a thread to simulate the device waiting for characters. let handle = thread::spawn(move || { println!("Device waiting for a character."); - while !state.lock().unwrap().queues[usize::from(QUEUE_TRANSMITQ_PORT_0)].notified { - thread::sleep(Duration::from_millis(10)); - } + State::wait_until_queue_notified(&state, QUEUE_TRANSMITQ_PORT_0); println!("Transmit queue was notified."); let data = state diff --git a/src/device/gpu.rs b/src/device/gpu.rs index b1b53bd..43e1b76 100644 --- a/src/device/gpu.rs +++ b/src/device/gpu.rs @@ -4,7 +4,8 @@ use crate::hal::{BufferDirection, Dma, Hal}; use crate::queue::VirtQueue; use crate::transport::Transport; use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly}; -use crate::{pages, Error, Result}; +use crate::{pages, Error, Result, PAGE_SIZE}; +use alloc::boxed::Box; use bitflags::bitflags; use log::info; use zerocopy::{AsBytes, FromBytes}; @@ -18,7 +19,7 @@ const QUEUE_SIZE: u16 = 2; /// a gpu with 3D support on the host machine. /// In 2D mode the virtio-gpu device provides support for ARGB Hardware cursors /// and multiple scanouts (aka heads). -pub struct VirtIOGpu<'a, H: Hal, T: Transport> { +pub struct VirtIOGpu<H: Hal, T: Transport> { transport: T, rect: Option<Rect>, /// DMA area of frame buffer. @@ -29,17 +30,13 @@ pub struct VirtIOGpu<'a, H: Hal, T: Transport> { control_queue: VirtQueue<H, { QUEUE_SIZE as usize }>, /// Queue for sending cursor commands. cursor_queue: VirtQueue<H, { QUEUE_SIZE as usize }>, - /// DMA region for sending data to the device. - dma_send: Dma<H>, - /// DMA region for receiving data from the device. - dma_recv: Dma<H>, /// Send buffer for queue. - queue_buf_send: &'a mut [u8], + queue_buf_send: Box<[u8]>, /// Recv buffer for queue. - queue_buf_recv: &'a mut [u8], + queue_buf_recv: Box<[u8]>, } -impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> { +impl<H: Hal, T: Transport> VirtIOGpu<H, T> { /// Create a new VirtIO-Gpu driver. pub fn new(mut transport: T) -> Result<Self> { transport.begin_init(|features| { @@ -63,10 +60,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> { let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?; let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR)?; - let dma_send = Dma::new(1, BufferDirection::DriverToDevice)?; - let dma_recv = Dma::new(1, BufferDirection::DeviceToDriver)?; - let queue_buf_send = unsafe { dma_send.raw_slice().as_mut() }; - let queue_buf_recv = unsafe { dma_recv.raw_slice().as_mut() }; + let queue_buf_send = FromBytes::new_box_slice_zeroed(PAGE_SIZE); + let queue_buf_recv = FromBytes::new_box_slice_zeroed(PAGE_SIZE); transport.finish_init(); @@ -77,8 +72,6 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> { rect: None, control_queue, cursor_queue, - dma_send, - dma_recv, queue_buf_send, queue_buf_recv, }) @@ -177,8 +170,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> { fn request<Req: AsBytes, Rsp: FromBytes>(&mut self, req: Req) -> Result<Rsp> { req.write_to_prefix(&mut *self.queue_buf_send).unwrap(); self.control_queue.add_notify_wait_pop( - &[self.queue_buf_send], - &mut [self.queue_buf_recv], + &[&self.queue_buf_send], + &mut [&mut self.queue_buf_recv], &mut self.transport, )?; Ok(Rsp::read_from_prefix(&*self.queue_buf_recv).unwrap()) @@ -188,7 +181,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> { fn cursor_request<Req: AsBytes>(&mut self, req: Req) -> Result { req.write_to_prefix(&mut *self.queue_buf_send).unwrap(); self.cursor_queue.add_notify_wait_pop( - &[self.queue_buf_send], + &[&self.queue_buf_send], &mut [], &mut self.transport, )?; @@ -286,7 +279,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> { } } -impl<H: Hal, T: Transport> Drop for VirtIOGpu<'_, H, T> { +impl<H: Hal, T: Transport> Drop for VirtIOGpu<H, T> { fn drop(&mut self) { // Clear any pointers pointing to DMA regions, so the device doesn't try to access them // after they have been freed. @@ -313,6 +306,7 @@ struct Config { const EVENT_DISPLAY: u32 = 1 << 0; bitflags! { + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] struct Features: u64 { /// virgl 3D mode is supported. const VIRGL = 1 << 0; diff --git a/src/device/mod.rs b/src/device/mod.rs index ca68901..00fa6fe 100644 --- a/src/device/mod.rs +++ b/src/device/mod.rs @@ -1,7 +1,9 @@ //! Drivers for specific VirtIO devices. pub mod blk; +#[cfg(feature = "alloc")] pub mod console; +#[cfg(feature = "alloc")] pub mod gpu; #[cfg(feature = "alloc")] pub mod input; diff --git a/src/device/net.rs b/src/device/net.rs index 4441f63..b9419e7 100644 --- a/src/device/net.rs +++ b/src/device/net.rs @@ -262,6 +262,7 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> Drop for VirtIONet<H, T, QUE } bitflags! { + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] struct Features: u64 { /// Device handles packets with partial checksum. /// This "checksum offload" is a common feature on modern network cards. @@ -323,6 +324,7 @@ bitflags! { } bitflags! { + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] struct Status: u16 { const LINK_UP = 1; const ANNOUNCE = 2; @@ -330,6 +332,7 @@ bitflags! { } bitflags! { + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] struct InterruptStatus : u32 { const USED_RING_UPDATE = 1 << 0; const CONFIGURATION_CHANGE = 1 << 1; @@ -364,10 +367,12 @@ pub struct VirtioNetHdr { // payload starts from here } +#[derive(AsBytes, Copy, Clone, Debug, Default, Eq, FromBytes, PartialEq)] +#[repr(transparent)] +struct Flags(u8); + bitflags! { - #[repr(transparent)] - #[derive(AsBytes, Default, FromBytes)] - struct Flags: u8 { + impl Flags: u8 { const NEEDS_CSUM = 1; const DATA_VALID = 2; const RSC_INFO = 4; diff --git a/src/device/socket/mod.rs b/src/device/socket/mod.rs index 65280aa..bf423bf 100644 --- a/src/device/socket/mod.rs +++ b/src/device/socket/mod.rs @@ -1,8 +1,26 @@ -//! This module implements the virtio vsock device. +//! Driver for VirtIO socket devices. +//! +//! To use the driver, you should first create a [`VirtIOSocket`] instance with your VirtIO +//! transport, and then create a [`VsockConnectionManager`] wrapping it to keep track of +//! connections. If you only want to have a single outgoing vsock connection at once, you can use +//! [`SingleConnectionManager`] for a slightly simpler interface. +//! +//! See [`VsockConnectionManager`] for a usage example. mod error; +#[cfg(feature = "alloc")] +mod multiconnectionmanager; mod protocol; +#[cfg(feature = "alloc")] +mod singleconnectionmanager; +#[cfg(feature = "alloc")] mod vsock; pub use error::SocketError; +#[cfg(feature = "alloc")] +pub use multiconnectionmanager::VsockConnectionManager; +pub use protocol::VsockAddr; +#[cfg(feature = "alloc")] +pub use singleconnectionmanager::SingleConnectionManager; +#[cfg(feature = "alloc")] pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType}; diff --git a/src/device/socket/multiconnectionmanager.rs b/src/device/socket/multiconnectionmanager.rs new file mode 100644 index 0000000..6aee5cd --- /dev/null +++ b/src/device/socket/multiconnectionmanager.rs @@ -0,0 +1,763 @@ +use super::{ + protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket, + VsockEvent, VsockEventType, +}; +use crate::{transport::Transport, Hal, Result}; +use alloc::{boxed::Box, vec::Vec}; +use core::cmp::min; +use core::convert::TryInto; +use core::hint::spin_loop; +use log::debug; +use zerocopy::FromBytes; + +const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024; + +/// A higher level interface for VirtIO socket (vsock) devices. +/// +/// This keeps track of multiple vsock connections. +/// +/// # Example +/// +/// ``` +/// # use virtio_drivers::{Error, Hal}; +/// # use virtio_drivers::transport::Transport; +/// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager}; +/// +/// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> { +/// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?); +/// +/// // Start a thread to call `socket.poll()` and handle events. +/// +/// let remote_address = VsockAddr { cid: 2, port: 42 }; +/// let local_port = 1234; +/// socket.connect(remote_address, local_port)?; +/// +/// // Wait until `socket.poll()` returns an event indicating that the socket is connected. +/// +/// socket.send(remote_address, local_port, "Hello world".as_bytes())?; +/// +/// socket.shutdown(remote_address, local_port)?; +/// # Ok(()) +/// # } +/// ``` +pub struct VsockConnectionManager<H: Hal, T: Transport> { + driver: VirtIOSocket<H, T>, + connections: Vec<Connection>, + listening_ports: Vec<u32>, +} + +#[derive(Debug)] +struct Connection { + info: ConnectionInfo, + buffer: RingBuffer, + /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is + /// still data in the buffer. + peer_requested_shutdown: bool, +} + +impl Connection { + fn new(peer: VsockAddr, local_port: u32) -> Self { + let mut info = ConnectionInfo::new(peer, local_port); + info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap(); + Self { + info, + buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY), + peer_requested_shutdown: false, + } + } +} + +impl<H: Hal, T: Transport> VsockConnectionManager<H, T> { + /// Construct a new connection manager wrapping the given low-level VirtIO socket driver. + pub fn new(driver: VirtIOSocket<H, T>) -> Self { + Self { + driver, + connections: Vec::new(), + listening_ports: Vec::new(), + } + } + + /// Returns the CID which has been assigned to this guest. + pub fn guest_cid(&self) -> u64 { + self.driver.guest_cid() + } + + /// Allows incoming connections on the given port number. + pub fn listen(&mut self, port: u32) { + if !self.listening_ports.contains(&port) { + self.listening_ports.push(port); + } + } + + /// Stops allowing incoming connections on the given port number. + pub fn unlisten(&mut self, port: u32) { + self.listening_ports.retain(|p| *p != port); + } + + /// Sends a request to connect to the given destination. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Connected` event indicating that the peer has accepted the connection + /// before sending data. + pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result { + if self.connections.iter().any(|connection| { + connection.info.dst == destination && connection.info.src_port == src_port + }) { + return Err(SocketError::ConnectionExists.into()); + } + + let new_connection = Connection::new(destination, src_port); + + self.driver.connect(&new_connection.info)?; + debug!("Connection requested: {:?}", new_connection.info); + self.connections.push(new_connection); + Ok(()) + } + + /// Sends the buffer to the destination. + pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result { + let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; + + self.driver.send(buffer, &mut connection.info) + } + + /// Polls the vsock device to receive data or other updates. + pub fn poll(&mut self) -> Result<Option<VsockEvent>> { + let guest_cid = self.driver.guest_cid(); + let connections = &mut self.connections; + + let result = self.driver.poll(|event, body| { + let connection = get_connection_for_event(connections, &event, guest_cid); + + // Skip events which don't match any connection we know about, unless they are a + // connection request. + let connection = if let Some((_, connection)) = connection { + connection + } else if let VsockEventType::ConnectionRequest = event.event_type { + // If the requested connection already exists or the CID isn't ours, ignore it. + if connection.is_some() || event.destination.cid != guest_cid { + return Ok(None); + } + // Add the new connection to our list, at least for now. It will be removed again + // below if we weren't listening on the port. + connections.push(Connection::new(event.source, event.destination.port)); + connections.last_mut().unwrap() + } else { + return Ok(None); + }; + + // Update stored connection info. + connection.info.update_for_event(&event); + + if let VsockEventType::Received { length } = event.event_type { + // Copy to buffer + if !connection.buffer.add(body) { + return Err(SocketError::OutputBufferTooShort(length).into()); + } + } + + Ok(Some(event)) + })?; + + let Some(event) = result else { + return Ok(None); + }; + + // The connection must exist because we found it above in the callback. + let (connection_index, connection) = + get_connection_for_event(connections, &event, guest_cid).unwrap(); + + match event.event_type { + VsockEventType::ConnectionRequest => { + if self.listening_ports.contains(&event.destination.port) { + self.driver.accept(&connection.info)?; + } else { + // Reject the connection request and remove it from our list. + self.driver.force_close(&connection.info)?; + self.connections.swap_remove(connection_index); + + // No need to pass the request on to the client, as we've already rejected it. + return Ok(None); + } + } + VsockEventType::Connected => {} + VsockEventType::Disconnected { reason } => { + // Wait until client reads all data before removing connection. + if connection.buffer.is_empty() { + if reason == DisconnectReason::Shutdown { + self.driver.force_close(&connection.info)?; + } + self.connections.swap_remove(connection_index); + } else { + connection.peer_requested_shutdown = true; + } + } + VsockEventType::Received { .. } => { + // Already copied the buffer in the callback above. + } + VsockEventType::CreditRequest => { + // If the peer requested credit, send an update. + self.driver.credit_update(&connection.info)?; + // No need to pass the request on to the client, we've already handled it. + return Ok(None); + } + VsockEventType::CreditUpdate => {} + } + + Ok(Some(event)) + } + + /// Reads data received from the given connection. + pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> { + let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?; + + // Copy from ring buffer + let bytes_read = connection.buffer.drain(buffer); + + connection.info.done_forwarding(bytes_read); + + // If buffer is now empty and the peer requested shutdown, finish shutting down the + // connection. + if connection.peer_requested_shutdown && connection.buffer.is_empty() { + self.driver.force_close(&connection.info)?; + self.connections.swap_remove(connection_index); + } + + Ok(bytes_read) + } + + /// Blocks until we get some event from the vsock device. + pub fn wait_for_event(&mut self) -> Result<VsockEvent> { + loop { + if let Some(event) = self.poll()? { + return Ok(event); + } else { + spin_loop(); + } + } + } + + /// Requests to shut down the connection cleanly. + /// + /// This returns as soon as the request is sent; you should wait until `poll` returns a + /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the + /// shutdown. + pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result { + let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; + + self.driver.shutdown(&connection.info) + } + + /// Forcibly closes the connection without waiting for the peer. + pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result { + let (index, connection) = get_connection(&mut self.connections, destination, src_port)?; + + self.driver.force_close(&connection.info)?; + + self.connections.swap_remove(index); + Ok(()) + } +} + +/// Returns the connection from the given list matching the given peer address and local port, and +/// its index. +/// +/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list. +fn get_connection( + connections: &mut [Connection], + peer: VsockAddr, + local_port: u32, +) -> core::result::Result<(usize, &mut Connection), SocketError> { + connections + .iter_mut() + .enumerate() + .find(|(_, connection)| { + connection.info.dst == peer && connection.info.src_port == local_port + }) + .ok_or(SocketError::NotConnected) +} + +/// Returns the connection from the given list matching the event, if any, and its index. +fn get_connection_for_event<'a>( + connections: &'a mut [Connection], + event: &VsockEvent, + local_cid: u64, +) -> Option<(usize, &'a mut Connection)> { + connections + .iter_mut() + .enumerate() + .find(|(_, connection)| event.matches_connection(&connection.info, local_cid)) +} + +#[derive(Debug)] +struct RingBuffer { + buffer: Box<[u8]>, + /// The number of bytes currently in the buffer. + used: usize, + /// The index of the first used byte in the buffer. + start: usize, +} + +impl RingBuffer { + pub fn new(capacity: usize) -> Self { + Self { + buffer: FromBytes::new_box_slice_zeroed(capacity), + used: 0, + start: 0, + } + } + + /// Returns the number of bytes currently used in the buffer. + pub fn used(&self) -> usize { + self.used + } + + /// Returns true iff there are currently no bytes in the buffer. + pub fn is_empty(&self) -> bool { + self.used == 0 + } + + /// Returns the number of bytes currently free in the buffer. + pub fn available(&self) -> usize { + self.buffer.len() - self.used + } + + /// Adds the given bytes to the buffer if there is enough capacity for them all. + /// + /// Returns true if they were added, or false if they were not. + pub fn add(&mut self, bytes: &[u8]) -> bool { + if bytes.len() > self.available() { + return false; + } + + // The index of the first available position in the buffer. + let first_available = (self.start + self.used) % self.buffer.len(); + // The number of bytes to copy from `bytes` to `buffer` between `first_available` and + // `buffer.len()`. + let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available); + self.buffer[first_available..first_available + copy_length_before_wraparound] + .copy_from_slice(&bytes[0..copy_length_before_wraparound]); + if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) { + self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound); + } + self.used += bytes.len(); + + true + } + + /// Reads and removes as many bytes as possible from the buffer, up to the length of the given + /// buffer. + pub fn drain(&mut self, out: &mut [u8]) -> usize { + let bytes_read = min(self.used, out.len()); + + // The number of bytes to copy out between `start` and the end of the buffer. + let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start); + // The number of bytes to copy out from the beginning of the buffer after wrapping around. + let read_after_wraparound = bytes_read + .checked_sub(read_before_wraparound) + .unwrap_or_default(); + + out[0..read_before_wraparound] + .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]); + out[read_before_wraparound..bytes_read] + .copy_from_slice(&self.buffer[0..read_after_wraparound]); + + self.used -= bytes_read; + self.start = (self.start + bytes_read) % self.buffer.len(); + + bytes_read + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + device::socket::{ + protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp}, + vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX}, + }, + hal::fake::FakeHal, + transport::{ + fake::{FakeTransport, QueueStatus, State}, + DeviceStatus, DeviceType, + }, + volatile::ReadOnly, + }; + use alloc::{sync::Arc, vec}; + use core::{mem::size_of, ptr::NonNull}; + use std::{sync::Mutex, thread}; + use zerocopy::{AsBytes, FromBytes}; + + #[test] + fn send_recv() { + let host_cid = 2; + let guest_cid = 66; + let host_port = 1234; + let guest_port = 4321; + let host_address = VsockAddr { + cid: host_cid, + port: host_port, + }; + let hello_from_guest = "Hello from guest"; + let hello_from_host = "Hello from host"; + + let mut config_space = VirtioVsockConfig { + guest_cid_low: ReadOnly::new(66), + guest_cid_high: ReadOnly::new(0), + }; + let state = Arc::new(Mutex::new(State { + status: DeviceStatus::empty(), + driver_features: 0, + guest_page_size: 0, + interrupt_pending: false, + queues: vec![ + QueueStatus::default(), + QueueStatus::default(), + QueueStatus::default(), + ], + })); + let transport = FakeTransport { + device_type: DeviceType::Socket, + max_queue_size: 32, + device_features: 0, + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let mut socket = VsockConnectionManager::new( + VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(), + ); + + // Start a thread to simulate the device. + let handle = thread::spawn(move || { + // Wait for connection request. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from( + state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Request.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 1024.into(), + fwd_cnt: 0.into(), + } + ); + + // Accept connection and give the peer enough credit to send the message. + state.lock().unwrap().write_to_queue::<QUEUE_SIZE>( + RX_QUEUE_IDX, + VirtioVsockHdr { + op: VirtioVsockOp::Response.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: guest_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: 0.into(), + } + .as_bytes(), + ); + + // Expect the guest to send some data. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + let request = state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX); + assert_eq!( + request.len(), + size_of::<VirtioVsockHdr>() + hello_from_guest.len() + ); + assert_eq!( + VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Rw.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: (hello_from_guest.len() as u32).into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 1024.into(), + fwd_cnt: 0.into(), + } + ); + assert_eq!( + &request[size_of::<VirtioVsockHdr>()..], + hello_from_guest.as_bytes() + ); + + println!("Host sending"); + + // Send a response. + let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()]; + VirtioVsockHdr { + op: VirtioVsockOp::Rw.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: guest_port.into(), + len: (hello_from_host.len() as u32).into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: (hello_from_guest.len() as u32).into(), + } + .write_to_prefix(response.as_mut_slice()); + response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes()); + state + .lock() + .unwrap() + .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response); + + // Expect a shutdown. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from( + state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Shutdown.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 1024.into(), + fwd_cnt: (hello_from_host.len() as u32).into(), + } + ); + }); + + socket.connect(host_address, guest_port).unwrap(); + assert_eq!( + socket.wait_for_event().unwrap(), + VsockEvent { + source: host_address, + destination: VsockAddr { + cid: guest_cid, + port: guest_port, + }, + event_type: VsockEventType::Connected, + buffer_status: VsockBufferStatus { + buffer_allocation: 50, + forward_count: 0, + }, + } + ); + println!("Guest sending"); + socket + .send(host_address, guest_port, "Hello from guest".as_bytes()) + .unwrap(); + println!("Guest waiting to receive."); + assert_eq!( + socket.wait_for_event().unwrap(), + VsockEvent { + source: host_address, + destination: VsockAddr { + cid: guest_cid, + port: guest_port, + }, + event_type: VsockEventType::Received { + length: hello_from_host.len() + }, + buffer_status: VsockBufferStatus { + buffer_allocation: 50, + forward_count: hello_from_guest.len() as u32, + }, + } + ); + println!("Guest getting received data."); + let mut buffer = [0u8; 64]; + assert_eq!( + socket.recv(host_address, guest_port, &mut buffer).unwrap(), + hello_from_host.len() + ); + assert_eq!( + &buffer[0..hello_from_host.len()], + hello_from_host.as_bytes() + ); + socket.shutdown(host_address, guest_port).unwrap(); + + handle.join().unwrap(); + } + + #[test] + fn incoming_connection() { + let host_cid = 2; + let guest_cid = 66; + let host_port = 1234; + let guest_port = 4321; + let wrong_guest_port = 4444; + let host_address = VsockAddr { + cid: host_cid, + port: host_port, + }; + + let mut config_space = VirtioVsockConfig { + guest_cid_low: ReadOnly::new(66), + guest_cid_high: ReadOnly::new(0), + }; + let state = Arc::new(Mutex::new(State { + status: DeviceStatus::empty(), + driver_features: 0, + guest_page_size: 0, + interrupt_pending: false, + queues: vec![ + QueueStatus::default(), + QueueStatus::default(), + QueueStatus::default(), + ], + })); + let transport = FakeTransport { + device_type: DeviceType::Socket, + max_queue_size: 32, + device_features: 0, + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let mut socket = VsockConnectionManager::new( + VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(), + ); + + socket.listen(guest_port); + + // Start a thread to simulate the device. + let handle = thread::spawn(move || { + // Send a connection request for a port the guest isn't listening on. + println!("Host sending connection request to wrong port"); + state.lock().unwrap().write_to_queue::<QUEUE_SIZE>( + RX_QUEUE_IDX, + VirtioVsockHdr { + op: VirtioVsockOp::Request.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: wrong_guest_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: 0.into(), + } + .as_bytes(), + ); + + // Expect a rejection. + println!("Host waiting for rejection"); + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from( + state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Rst.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: wrong_guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 1024.into(), + fwd_cnt: 0.into(), + } + ); + + // Send a connection request for a port the guest is listening on. + println!("Host sending connection request to right port"); + state.lock().unwrap().write_to_queue::<QUEUE_SIZE>( + RX_QUEUE_IDX, + VirtioVsockHdr { + op: VirtioVsockOp::Request.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: guest_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: 0.into(), + } + .as_bytes(), + ); + + // Expect a response. + println!("Host waiting for response"); + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from( + state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Response.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 1024.into(), + fwd_cnt: 0.into(), + } + ); + + println!("Host finished"); + }); + + // Expect an incoming connection. + println!("Guest expecting incoming connection."); + assert_eq!( + socket.wait_for_event().unwrap(), + VsockEvent { + source: host_address, + destination: VsockAddr { + cid: guest_cid, + port: guest_port, + }, + event_type: VsockEventType::ConnectionRequest, + buffer_status: VsockBufferStatus { + buffer_allocation: 50, + forward_count: 0, + }, + } + ); + + handle.join().unwrap(); + } +} diff --git a/src/device/socket/protocol.rs b/src/device/socket/protocol.rs index abc1702..3587005 100644 --- a/src/device/socket/protocol.rs +++ b/src/device/socket/protocol.rs @@ -2,6 +2,7 @@ use super::error::{self, SocketError}; use crate::volatile::ReadOnly; +use bitflags::bitflags; use core::{ convert::{TryFrom, TryInto}, fmt, @@ -17,6 +18,8 @@ use zerocopy::{ pub enum SocketType { /// Stream sockets provide in-order, guaranteed, connection-oriented delivery without message boundaries. Stream = 1, + /// seqpacket socket type introduced in virtio-v1.2. + SeqPacket = 2, } impl From<SocketType> for U16<LittleEndian> { @@ -40,7 +43,7 @@ pub struct VirtioVsockConfig { /// The message header for data packets sent on the tx/rx queues #[repr(packed)] -#[derive(AsBytes, Clone, Copy, Debug, FromBytes)] +#[derive(AsBytes, Clone, Copy, Debug, Eq, FromBytes, PartialEq)] pub struct VirtioVsockHdr { pub src_cid: U64<LittleEndian>, pub dst_cid: U64<LittleEndian>, @@ -182,3 +185,29 @@ impl fmt::Debug for VirtioVsockOp { } } } + +bitflags! { + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] + pub(crate) struct Feature: u64 { + /// stream socket type is supported. + const STREAM = 1 << 0; + /// seqpacket socket type is supported. + const SEQ_PACKET = 1 << 1; + + // device independent + const NOTIFY_ON_EMPTY = 1 << 24; // legacy + const ANY_LAYOUT = 1 << 27; // legacy + const RING_INDIRECT_DESC = 1 << 28; + const RING_EVENT_IDX = 1 << 29; + const UNUSED = 1 << 30; // legacy + const VERSION_1 = 1 << 32; // detect legacy + + // since virtio v1.1 + const ACCESS_PLATFORM = 1 << 33; + const RING_PACKED = 1 << 34; + const IN_ORDER = 1 << 35; + const ORDER_PLATFORM = 1 << 36; + const SR_IOV = 1 << 37; + const NOTIFICATION_DATA = 1 << 38; + } +} diff --git a/src/device/socket/singleconnectionmanager.rs b/src/device/socket/singleconnectionmanager.rs new file mode 100644 index 0000000..8c9bff6 --- /dev/null +++ b/src/device/socket/singleconnectionmanager.rs @@ -0,0 +1,447 @@ +use super::{ + protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent, + VsockEventType, +}; +use crate::{transport::Transport, Hal, Result}; +use core::hint::spin_loop; +use log::debug; + +/// A higher level interface for VirtIO socket (vsock) devices. +/// +/// This can only keep track of a single vsock connection. If you want to support multiple +/// simultaneous connections, try [`VsockConnectionManager`](super::VsockConnectionManager). +pub struct SingleConnectionManager<H: Hal, T: Transport> { + driver: VirtIOSocket<H, T>, + connection_info: Option<ConnectionInfo>, +} + +impl<H: Hal, T: Transport> SingleConnectionManager<H, T> { + /// Construct a new connection manager wrapping the given low-level VirtIO socket driver. + pub fn new(driver: VirtIOSocket<H, T>) -> Self { + Self { + driver, + connection_info: None, + } + } + + /// Returns the CID which has been assigned to this guest. + pub fn guest_cid(&self) -> u64 { + self.driver.guest_cid() + } + + /// Sends a request to connect to the given destination. + /// + /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a + /// `VsockEventType::Connected` event indicating that the peer has accepted the connection + /// before sending data. + pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result { + if self.connection_info.is_some() { + return Err(SocketError::ConnectionExists.into()); + } + + let new_connection_info = ConnectionInfo::new(destination, src_port); + + self.driver.connect(&new_connection_info)?; + debug!("Connection requested: {:?}", new_connection_info); + self.connection_info = Some(new_connection_info); + Ok(()) + } + + /// Sends the buffer to the destination. + pub fn send(&mut self, buffer: &[u8]) -> Result { + let connection_info = self + .connection_info + .as_mut() + .ok_or(SocketError::NotConnected)?; + connection_info.buf_alloc = 0; + self.driver.send(buffer, connection_info) + } + + /// Polls the vsock device to receive data or other updates. + /// + /// A buffer must be provided to put the data in if there is some to + /// receive. + pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> { + let Some(connection_info) = &mut self.connection_info else { + return Err(SocketError::NotConnected.into()); + }; + + // Tell the peer that we have space to receive some data. + connection_info.buf_alloc = buffer.len() as u32; + self.driver.credit_update(connection_info)?; + + self.poll_rx_queue(buffer) + } + + /// Blocks until we get some event from the vsock device. + /// + /// A buffer must be provided to put the data in if there is some to + /// receive. + pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> { + loop { + if let Some(event) = self.poll_recv(buffer)? { + return Ok(event); + } else { + spin_loop(); + } + } + } + + fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> { + let guest_cid = self.driver.guest_cid(); + let self_connection_info = &mut self.connection_info; + + self.driver.poll(|event, borrowed_body| { + let Some(connection_info) = self_connection_info else { + return Ok(None); + }; + + // Skip packets which don't match our current connection. + if !event.matches_connection(connection_info, guest_cid) { + debug!( + "Skipping {:?} as connection is {:?}", + event, connection_info + ); + return Ok(None); + } + + // Update stored connection info. + connection_info.update_for_event(&event); + + match event.event_type { + VsockEventType::ConnectionRequest => { + // TODO: Send Rst or handle incoming connections. + } + VsockEventType::Connected => {} + VsockEventType::Disconnected { .. } => { + *self_connection_info = None; + } + VsockEventType::Received { length } => { + body.get_mut(0..length) + .ok_or(SocketError::OutputBufferTooShort(length))? + .copy_from_slice(borrowed_body); + connection_info.done_forwarding(length); + } + VsockEventType::CreditRequest => { + // No point sending a credit update until `poll_recv` is called with a buffer, + // as otherwise buf_alloc would just be 0 anyway. + } + VsockEventType::CreditUpdate => {} + } + + Ok(Some(event)) + }) + } + + /// Requests to shut down the connection cleanly. + /// + /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a + /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the + /// shutdown. + pub fn shutdown(&mut self) -> Result { + let connection_info = self + .connection_info + .as_mut() + .ok_or(SocketError::NotConnected)?; + connection_info.buf_alloc = 0; + + self.driver.shutdown(connection_info) + } + + /// Forcibly closes the connection without waiting for the peer. + pub fn force_close(&mut self) -> Result { + let connection_info = self + .connection_info + .as_mut() + .ok_or(SocketError::NotConnected)?; + connection_info.buf_alloc = 0; + + self.driver.force_close(connection_info)?; + self.connection_info = None; + Ok(()) + } + + /// Blocks until the peer either accepts our connection request (with a + /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a + /// `VIRTIO_VSOCK_OP_RST`). + pub fn wait_for_connect(&mut self) -> Result { + loop { + match self.wait_for_recv(&mut [])?.event_type { + VsockEventType::Connected => return Ok(()), + VsockEventType::Disconnected { .. } => { + return Err(SocketError::ConnectionFailed.into()) + } + VsockEventType::Received { .. } => return Err(SocketError::InvalidOperation.into()), + VsockEventType::ConnectionRequest + | VsockEventType::CreditRequest + | VsockEventType::CreditUpdate => {} + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + device::socket::{ + protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp}, + vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX}, + }, + hal::fake::FakeHal, + transport::{ + fake::{FakeTransport, QueueStatus, State}, + DeviceStatus, DeviceType, + }, + volatile::ReadOnly, + }; + use alloc::{sync::Arc, vec}; + use core::{mem::size_of, ptr::NonNull}; + use std::{sync::Mutex, thread}; + use zerocopy::{AsBytes, FromBytes}; + + #[test] + fn send_recv() { + let host_cid = 2; + let guest_cid = 66; + let host_port = 1234; + let guest_port = 4321; + let host_address = VsockAddr { + cid: host_cid, + port: host_port, + }; + let hello_from_guest = "Hello from guest"; + let hello_from_host = "Hello from host"; + + let mut config_space = VirtioVsockConfig { + guest_cid_low: ReadOnly::new(66), + guest_cid_high: ReadOnly::new(0), + }; + let state = Arc::new(Mutex::new(State { + status: DeviceStatus::empty(), + driver_features: 0, + guest_page_size: 0, + interrupt_pending: false, + queues: vec![ + QueueStatus::default(), + QueueStatus::default(), + QueueStatus::default(), + ], + })); + let transport = FakeTransport { + device_type: DeviceType::Socket, + max_queue_size: 32, + device_features: 0, + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let mut socket = SingleConnectionManager::new( + VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(), + ); + + // Start a thread to simulate the device. + let handle = thread::spawn(move || { + // Wait for connection request. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from( + state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Request.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 0.into(), + fwd_cnt: 0.into(), + } + ); + + // Accept connection and give the peer enough credit to send the message. + state.lock().unwrap().write_to_queue::<QUEUE_SIZE>( + RX_QUEUE_IDX, + VirtioVsockHdr { + op: VirtioVsockOp::Response.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: guest_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: 0.into(), + } + .as_bytes(), + ); + + // Expect a credit update. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from( + state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::CreditUpdate.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 0.into(), + fwd_cnt: 0.into(), + } + ); + + // Expect the guest to send some data. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + let request = state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX); + assert_eq!( + request.len(), + size_of::<VirtioVsockHdr>() + hello_from_guest.len() + ); + assert_eq!( + VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Rw.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: (hello_from_guest.len() as u32).into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 0.into(), + fwd_cnt: 0.into(), + } + ); + assert_eq!( + &request[size_of::<VirtioVsockHdr>()..], + hello_from_guest.as_bytes() + ); + + // Send a response. + let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()]; + VirtioVsockHdr { + op: VirtioVsockOp::Rw.into(), + src_cid: host_cid.into(), + dst_cid: guest_cid.into(), + src_port: host_port.into(), + dst_port: guest_port.into(), + len: (hello_from_host.len() as u32).into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 50.into(), + fwd_cnt: (hello_from_guest.len() as u32).into(), + } + .write_to_prefix(response.as_mut_slice()); + response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes()); + state + .lock() + .unwrap() + .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response); + + // Expect a credit update. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from( + state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::CreditUpdate.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 64.into(), + fwd_cnt: 0.into(), + } + ); + + // Expect a shutdown. + State::wait_until_queue_notified(&state, TX_QUEUE_IDX); + assert_eq!( + VirtioVsockHdr::read_from( + state + .lock() + .unwrap() + .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX) + .as_slice() + ) + .unwrap(), + VirtioVsockHdr { + op: VirtioVsockOp::Shutdown.into(), + src_cid: guest_cid.into(), + dst_cid: host_cid.into(), + src_port: guest_port.into(), + dst_port: host_port.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + flags: 0.into(), + buf_alloc: 0.into(), + fwd_cnt: (hello_from_host.len() as u32).into(), + } + ); + }); + + socket.connect(host_address, guest_port).unwrap(); + socket.wait_for_connect().unwrap(); + socket.send(hello_from_guest.as_bytes()).unwrap(); + let mut buffer = [0u8; 64]; + let event = socket.wait_for_recv(&mut buffer).unwrap(); + assert_eq!( + event, + VsockEvent { + source: VsockAddr { + cid: host_cid, + port: host_port, + }, + destination: VsockAddr { + cid: guest_cid, + port: guest_port, + }, + event_type: VsockEventType::Received { + length: hello_from_host.len() + }, + buffer_status: VsockBufferStatus { + buffer_allocation: 50, + forward_count: hello_from_guest.len() as u32, + }, + } + ); + assert_eq!( + &buffer[0..hello_from_host.len()], + hello_from_host.as_bytes() + ); + socket.shutdown().unwrap(); + + handle.join().unwrap(); + } +} diff --git a/src/device/socket/vsock.rs b/src/device/socket/vsock.rs index 686d7a6..523930e 100644 --- a/src/device/socket/vsock.rs +++ b/src/device/socket/vsock.rs @@ -2,29 +2,31 @@ #![deny(unsafe_op_in_unsafe_fn)] use super::error::SocketError; -use super::protocol::{VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr}; -use crate::device::common::Feature; -use crate::hal::{BufferDirection, Dma, Hal}; +use super::protocol::{Feature, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr}; +use crate::hal::Hal; use crate::queue::VirtQueue; use crate::transport::Transport; use crate::volatile::volread; -use crate::Result; -use core::hint::spin_loop; +use crate::{Error, Result}; +use alloc::boxed::Box; use core::mem::size_of; -use core::ptr::NonNull; -use log::{debug, info}; +use core::ptr::{null_mut, NonNull}; +use log::debug; use zerocopy::{AsBytes, FromBytes}; -const RX_QUEUE_IDX: u16 = 0; -const TX_QUEUE_IDX: u16 = 1; +pub(crate) const RX_QUEUE_IDX: u16 = 0; +pub(crate) const TX_QUEUE_IDX: u16 = 1; const EVENT_QUEUE_IDX: u16 = 2; -const QUEUE_SIZE: usize = 8; +pub(crate) const QUEUE_SIZE: usize = 8; + +/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::<VirtioVsockHdr>(). +const RX_BUFFER_SIZE: usize = 512; #[derive(Clone, Debug, Default, PartialEq, Eq)] -struct ConnectionInfo { - dst: VsockAddr, - src_port: u32, +pub struct ConnectionInfo { + pub dst: VsockAddr, + pub src_port: u32, /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in /// bytes it has allocated for packet bodies. peer_buf_alloc: u32, @@ -33,6 +35,9 @@ struct ConnectionInfo { peer_fwd_cnt: u32, /// The number of bytes of packet bodies which we have sent to the peer. tx_cnt: u32, + /// The number of bytes of buffer space we have allocated to receive packet bodies from the + /// peer. + pub buf_alloc: u32, /// The number of bytes of packet bodies which we have received from the peer and handled. fwd_cnt: u32, /// Whether we have recently requested credit from the peer. @@ -43,6 +48,35 @@ struct ConnectionInfo { } impl ConnectionInfo { + pub fn new(destination: VsockAddr, src_port: u32) -> Self { + Self { + dst: destination, + src_port, + ..Default::default() + } + } + + /// Updates this connection info with the peer buffer allocation and forwarded count from the + /// given event. + pub fn update_for_event(&mut self, event: &VsockEvent) { + self.peer_buf_alloc = event.buffer_status.buffer_allocation; + self.peer_fwd_cnt = event.buffer_status.forward_count; + + if let VsockEventType::CreditUpdate = event.event_type { + self.has_pending_credit_request = false; + } + } + + /// Increases the forwarded count recorded for this connection by the given number of bytes. + /// + /// This should be called once received data has been passed to the client, so there is buffer + /// space available for more. + pub fn done_forwarding(&mut self, length: usize) { + self.fwd_cnt += length as u32; + } + + /// Returns the number of bytes of RX buffer space the peer has available to receive packet body + /// data from us. fn peer_free(&self) -> u32 { self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt) } @@ -53,6 +87,7 @@ impl ConnectionInfo { dst_cid: self.dst.cid.into(), src_port: self.src_port.into(), dst_port: self.dst.port.into(), + buf_alloc: self.buf_alloc.into(), fwd_cnt: self.fwd_cnt.into(), ..Default::default() } @@ -66,10 +101,77 @@ pub struct VsockEvent { pub source: VsockAddr, /// The destination of the event, i.e. the CID and port on our side. pub destination: VsockAddr, + /// The peer's buffer status for the connection. + pub buffer_status: VsockBufferStatus, /// The type of event. pub event_type: VsockEventType, } +impl VsockEvent { + /// Returns whether the event matches the given connection. + pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool { + self.source == connection_info.dst + && self.destination.cid == guest_cid + && self.destination.port == connection_info.src_port + } + + fn from_header(header: &VirtioVsockHdr) -> Result<Self> { + let op = header.op()?; + let buffer_status = VsockBufferStatus { + buffer_allocation: header.buf_alloc.into(), + forward_count: header.fwd_cnt.into(), + }; + let source = header.source(); + let destination = header.destination(); + + let event_type = match op { + VirtioVsockOp::Request => { + header.check_data_is_empty()?; + VsockEventType::ConnectionRequest + } + VirtioVsockOp::Response => { + header.check_data_is_empty()?; + VsockEventType::Connected + } + VirtioVsockOp::CreditUpdate => { + header.check_data_is_empty()?; + VsockEventType::CreditUpdate + } + VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => { + header.check_data_is_empty()?; + debug!("Disconnected from the peer"); + let reason = if op == VirtioVsockOp::Rst { + DisconnectReason::Reset + } else { + DisconnectReason::Shutdown + }; + VsockEventType::Disconnected { reason } + } + VirtioVsockOp::Rw => VsockEventType::Received { + length: header.len() as usize, + }, + VirtioVsockOp::CreditRequest => { + header.check_data_is_empty()?; + VsockEventType::CreditRequest + } + VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()), + }; + + Ok(VsockEvent { + source, + destination, + buffer_status, + event_type, + }) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct VsockBufferStatus { + pub buffer_allocation: u32, + pub forward_count: u32, +} + /// The reason why a vsock connection was closed. #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum DisconnectReason { @@ -83,6 +185,8 @@ pub enum DisconnectReason { /// Details of the type of an event received from a VirtIO socket. #[derive(Clone, Debug, Eq, PartialEq)] pub enum VsockEventType { + /// The peer requests to establish a connection with us. + ConnectionRequest, /// The connection was successfully established. Connected, /// The connection was closed. @@ -95,9 +199,16 @@ pub enum VsockEventType { /// The length of the data in bytes. length: usize, }, + /// The peer requests us to send a credit update. + CreditRequest, + /// The peer just sent us a credit update with nothing else. + CreditUpdate, } -/// Driver for a VirtIO socket device. +/// Low-level driver for a VirtIO socket device. +/// +/// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than +/// using this directly. pub struct VirtIOSocket<H: Hal, T: Transport> { transport: T, /// Virtqueue to receive packets. @@ -108,10 +219,7 @@ pub struct VirtIOSocket<H: Hal, T: Transport> { /// The guest_cid field contains the guest’s context ID, which uniquely identifies /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed. guest_cid: u64, - rx_buf_dma: Dma<H>, - - /// Currently the device is only allowed to be connected to one destination at a time. - connection_info: Option<ConnectionInfo>, + rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE], } impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> { @@ -121,6 +229,12 @@ impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> { self.transport.queue_unset(RX_QUEUE_IDX); self.transport.queue_unset(TX_QUEUE_IDX); self.transport.queue_unset(EVENT_QUEUE_IDX); + + for buffer in self.rx_queue_buffers { + // Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be + // used anywhere else after the driver is destroyed. + unsafe { drop(Box::from_raw(buffer.as_ptr())) }; + } } } @@ -129,35 +243,40 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> { pub fn new(mut transport: T) -> Result<Self> { transport.begin_init(|features| { let features = Feature::from_bits_truncate(features); - info!("Device features: {:?}", features); + debug!("Device features: {:?}", features); // negotiate these flags only let supported_features = Feature::empty(); (features & supported_features).bits() }); let config = transport.config_space::<VirtioVsockConfig>()?; - info!("config: {:?}", config); + debug!("config: {:?}", config); // Safe because config is a valid pointer to the device configuration space. let guest_cid = unsafe { volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32 }; - info!("guest cid: {guest_cid:?}"); + debug!("guest cid: {guest_cid:?}"); let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?; let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?; let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?; - // Allocates 4 KiB memory as the rx buffer. - let rx_buf_dma = Dma::new( - 1, // pages - BufferDirection::DeviceToDriver, - )?; - let rx_buf = rx_buf_dma.raw_slice(); - // Safe because `rx_buf` lives as long as the `rx` queue. - unsafe { - Self::fill_rx_queue(&mut rx, rx_buf, &mut transport)?; + // Allocate and add buffers for the RX queue. + let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE]; + for (i, rx_queue_buffer) in rx_queue_buffers.iter_mut().enumerate() { + let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromBytes::new_box_zeroed(); + // Safe because the buffer lives as long as the queue, as specified in the function + // safety requirement, and we don't access it until it is popped. + let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?; + assert_eq!(i, token.into()); + *rx_queue_buffer = Box::into_raw(buffer); } + let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap()); + transport.finish_init(); + if rx.should_notify() { + transport.notify(RX_QUEUE_IDX); + } Ok(Self { transport, @@ -165,85 +284,41 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> { tx, event, guest_cid, - rx_buf_dma, - connection_info: None, + rx_queue_buffers, }) } - /// Fills the `rx` queue with the buffer `rx_buf`. - /// - /// # Safety - /// - /// `rx_buf` must live at least as long as the `rx` queue, and the parts of the buffer which are - /// in the queue must not be used anywhere else at the same time. - unsafe fn fill_rx_queue( - rx: &mut VirtQueue<H, { QUEUE_SIZE }>, - rx_buf: NonNull<[u8]>, - transport: &mut T, - ) -> Result { - if rx_buf.len() < size_of::<VirtioVsockHdr>() * QUEUE_SIZE { - return Err(SocketError::BufferTooShort.into()); - } - for i in 0..QUEUE_SIZE { - // Safe because the buffer lives as long as the queue, as specified in the function - // safety requirement, and we don't access it until it is popped. - unsafe { - let buffer = Self::as_mut_sub_rx_buffer(rx_buf, i)?; - let token = rx.add(&[], &mut [buffer])?; - assert_eq!(i, token.into()); - } - } - - if rx.should_notify() { - transport.notify(RX_QUEUE_IDX); - } - Ok(()) + /// Returns the CID which has been assigned to this guest. + pub fn guest_cid(&self) -> u64 { + self.guest_cid } /// Sends a request to connect to the given destination. /// - /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a + /// This returns as soon as the request is sent; you should wait until `poll` returns a /// `VsockEventType::Connected` event indicating that the peer has accepted the connection /// before sending data. - pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result { - if self.connection_info.is_some() { - return Err(SocketError::ConnectionExists.into()); - } - let new_connection_info = ConnectionInfo { - dst: VsockAddr { - cid: dst_cid, - port: dst_port, - }, - src_port, - ..Default::default() - }; + pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result { let header = VirtioVsockHdr { op: VirtioVsockOp::Request.into(), - ..new_connection_info.new_header(self.guest_cid) + ..connection_info.new_header(self.guest_cid) }; - // Sends a header only packet to the tx queue to connect the device to the listening - // socket at the given destination. - self.send_packet_to_tx_queue(&header, &[])?; - - self.connection_info = Some(new_connection_info); - debug!("Connection requested: {:?}", self.connection_info); - Ok(()) + // Sends a header only packet to the TX queue to connect the device to the listening socket + // at the given destination. + self.send_packet_to_tx_queue(&header, &[]) } - /// Blocks until the peer either accepts our connection request (with a - /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a - /// `VIRTIO_VSOCK_OP_RST`). - pub fn wait_for_connect(&mut self) -> Result { - match self.wait_for_recv(&mut [])?.event_type { - VsockEventType::Connected => Ok(()), - VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()), - VsockEventType::Received { .. } => Err(SocketError::InvalidOperation.into()), - } + /// Accepts the given connection from a peer. + pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result { + let header = VirtioVsockHdr { + op: VirtioVsockOp::Response.into(), + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) } - /// Requests the peer to send us a credit update for the current connection. - fn request_credit(&mut self) -> Result { - let connection_info = self.connection_info()?; + /// Requests the peer to send us a credit update for the given connection. + fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result { let header = VirtioVsockHdr { op: VirtioVsockOp::CreditRequest.into(), ..connection_info.new_header(self.guest_cid) @@ -252,21 +327,16 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> { } /// Sends the buffer to the destination. - pub fn send(&mut self, buffer: &[u8]) -> Result { - let mut connection_info = self.connection_info()?; - - let result = self.check_peer_buffer_is_sufficient(&mut connection_info, buffer.len()); - self.connection_info = Some(connection_info.clone()); - result?; + pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result { + self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?; let len = buffer.len() as u32; let header = VirtioVsockHdr { op: VirtioVsockOp::Rw.into(), len: len.into(), - buf_alloc: 0.into(), ..connection_info.new_header(self.guest_cid) }; - self.connection_info.as_mut().unwrap().tx_cnt += len; + connection_info.tx_cnt += len; self.send_packet_to_tx_queue(&header, buffer) } @@ -281,59 +351,48 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> { // Request an update of the cached peer credit, if we haven't already done so, and tell // the caller to try again later. if !connection_info.has_pending_credit_request { - self.request_credit()?; + self.request_credit(connection_info)?; connection_info.has_pending_credit_request = true; } Err(SocketError::InsufficientBufferSpaceInPeer.into()) } } - /// Polls the vsock device to receive data or other updates. - /// - /// A buffer must be provided to put the data in if there is some to - /// receive. - pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> { - let connection_info = self.connection_info()?; - - // Tell the peer that we have space to receive some data. + /// Tells the peer how much buffer space we have to receive data. + pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result { let header = VirtioVsockHdr { op: VirtioVsockOp::CreditUpdate.into(), - buf_alloc: (buffer.len() as u32).into(), ..connection_info.new_header(self.guest_cid) }; - self.send_packet_to_tx_queue(&header, &[])?; - - // Handle entries from the RX virtqueue until we find one that generates an event. - let event = self.poll_rx_queue(buffer)?; + self.send_packet_to_tx_queue(&header, &[]) + } - if self.rx.should_notify() { - self.transport.notify(RX_QUEUE_IDX); - } + /// Polls the RX virtqueue for the next event, and calls the given handler function to handle + /// it. + pub fn poll( + &mut self, + handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>, + ) -> Result<Option<VsockEvent>> { + let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else { + return Ok(None); + }; - Ok(event) - } + let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body)); - /// Blocks until we get some event from the vsock device. - /// - /// A buffer must be provided to put the data in if there is some to - /// receive. - pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> { - loop { - if let Some(event) = self.poll_recv(buffer)? { - return Ok(event); - } else { - spin_loop(); - } + unsafe { + // TODO: What about if both handler and this give errors? + self.add_buffer_to_rx_queue(token)?; } + + result } - /// Request to shut down the connection cleanly. + /// Requests to shut down the connection cleanly. /// - /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a + /// This returns as soon as the request is sent; you should wait until `poll` returns a /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the /// shutdown. - pub fn shutdown(&mut self) -> Result { - let connection_info = self.connection_info()?; + pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result { let header = VirtioVsockHdr { op: VirtioVsockOp::Shutdown.into(), ..connection_info.new_header(self.guest_cid) @@ -342,14 +401,12 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> { } /// Forcibly closes the connection without waiting for the peer. - pub fn force_close(&mut self) -> Result { - let connection_info = self.connection_info()?; + pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result { let header = VirtioVsockHdr { op: VirtioVsockOp::Rst.into(), ..connection_info.new_header(self.guest_cid) }; self.send_packet_to_tx_queue(&header, &[])?; - self.connection_info = None; Ok(()) } @@ -362,118 +419,39 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> { Ok(()) } - /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet - /// which generates a `VsockEvent`. + /// Adds the buffer at the given index in `rx_queue_buffers` back to the RX queue. /// - /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which - /// don't result in any events to return. - fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> { - loop { - let mut connection_info = self.connection_info.clone().unwrap_or_default(); - let Some(header) = self.pop_packet_from_rx_queue(body)? else{ - return Ok(None); - }; - - let op = header.op()?; - - // Skip packets which don't match our current connection. - if header.source() != connection_info.dst - || header.dst_cid.get() != self.guest_cid - || header.dst_port.get() != connection_info.src_port - { - debug!( - "Skipping {:?} as connection is {:?}", - header, connection_info - ); - continue; - } - - connection_info.peer_buf_alloc = header.buf_alloc.into(); - connection_info.peer_fwd_cnt = header.fwd_cnt.into(); - if self.connection_info.is_some() { - self.connection_info = Some(connection_info.clone()); - debug!("Connection info updated: {:?}", self.connection_info); - } + /// # Safety + /// + /// The buffer must not currently be in the RX queue, and no other references to it must exist + /// between when this method is called and when it is popped from the queue. + unsafe fn add_buffer_to_rx_queue(&mut self, index: u16) -> Result { + // Safe because the buffer lives as long as the queue, and the caller guarantees that it's + // not currently in the queue or referred to anywhere else until it is popped. + unsafe { + let buffer = self + .rx_queue_buffers + .get_mut(usize::from(index)) + .ok_or(Error::WrongToken)? + .as_mut(); + let new_token = self.rx.add(&[], &mut [buffer])?; + // If the RX buffer somehow gets assigned a different token, then our safety assumptions + // are broken and we can't safely continue to do anything with the device. + assert_eq!(new_token, index); + } - match op { - VirtioVsockOp::Request => { - header.check_data_is_empty()?; - // TODO: Send a Rst, or support listening. - } - VirtioVsockOp::Response => { - header.check_data_is_empty()?; - return Ok(Some(VsockEvent { - source: connection_info.dst, - destination: VsockAddr { - cid: self.guest_cid, - port: connection_info.src_port, - }, - event_type: VsockEventType::Connected, - })); - } - VirtioVsockOp::CreditUpdate => { - header.check_data_is_empty()?; - connection_info.has_pending_credit_request = false; - if self.connection_info.is_some() { - self.connection_info = Some(connection_info.clone()); - } - - // Virtio v1.1 5.10.6.3 - // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously - // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates - // any time a change in buffer space occurs. - continue; - } - VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => { - header.check_data_is_empty()?; - - self.connection_info = None; - info!("Disconnected from the peer"); - - let reason = if op == VirtioVsockOp::Rst { - DisconnectReason::Reset - } else { - DisconnectReason::Shutdown - }; - return Ok(Some(VsockEvent { - source: connection_info.dst, - destination: VsockAddr { - cid: self.guest_cid, - port: connection_info.src_port, - }, - event_type: VsockEventType::Disconnected { reason }, - })); - } - VirtioVsockOp::Rw => { - self.connection_info.as_mut().unwrap().fwd_cnt += header.len(); - return Ok(Some(VsockEvent { - source: connection_info.dst, - destination: VsockAddr { - cid: self.guest_cid, - port: connection_info.src_port, - }, - event_type: VsockEventType::Received { - length: header.len() as usize, - }, - })); - } - VirtioVsockOp::CreditRequest => { - header.check_data_is_empty()?; - // TODO: Send a credit update. - } - VirtioVsockOp::Invalid => { - return Err(SocketError::InvalidOperation.into()); - } - } + if self.rx.should_notify() { + self.transport.notify(RX_QUEUE_IDX); } + + Ok(()) } - /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies - /// the body into the given buffer. + /// Pops one packet from the RX queue, if there is one pending. Returns the header, and a + /// reference to the buffer containing the body. /// - /// Returns `None` if there is no pending packet, or an error if the body is bigger than the - /// buffer supplied. - fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> { + /// Returns `None` if there is no pending packet. + fn pop_packet_from_rx_queue(&mut self) -> Result<Option<(VirtioVsockHdr, &[u8], u16)>> { let Some(token) = self.rx.peek_used() else { return Ok(None); }; @@ -481,89 +459,53 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> { // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the // buffer back to the RX queue then we don't access it again until next time it is popped. - let header = unsafe { - let buffer = Self::as_mut_sub_rx_buffer(self.rx_buf_dma.raw_slice(), token.into())?; + let (header, body) = unsafe { + let buffer = self.rx_queue_buffers[usize::from(token)].as_mut(); let _len = self.rx.pop_used(token, &[], &mut [buffer])?; // Read the header and body from the buffer. Don't check the result yet, because we need // to add the buffer back to the queue either way. - let header_result = read_header_and_body(buffer, body); - - // Add the buffer back to the RX queue. - let new_token = self.rx.add(&[], &mut [buffer])?; - // If the RX buffer somehow gets assigned a different token, then our safety assumptions - // are broken and we can't safely continue to do anything with the device. - assert_eq!(new_token, token); + let header_result = read_header_and_body(buffer); + if header_result.is_err() { + // If there was an error, add the buffer back immediately. Ignore any errors, as we + // need to return the first error. + let _ = self.add_buffer_to_rx_queue(token); + } header_result }?; debug!("Received packet {:?}. Op {:?}", header, header.op()); - Ok(Some(header)) - } - - fn connection_info(&self) -> Result<ConnectionInfo> { - self.connection_info - .clone() - .ok_or(SocketError::NotConnected.into()) - } - - /// Gets a reference to a subslice of the RX buffer to be used for the given entry in the RX - /// virtqueue. - /// - /// # Safety - /// - /// `rx_buf` must be a valid dereferenceable pointer. - /// The returned reference has an arbitrary lifetime `'a`. This lifetime must not overlap with - /// any other references to the same subslice of the RX buffer or outlive the buffer. - unsafe fn as_mut_sub_rx_buffer<'a>( - mut rx_buf: NonNull<[u8]>, - i: usize, - ) -> Result<&'a mut [u8]> { - let buffer_size = rx_buf.len() / QUEUE_SIZE; - let start = buffer_size - .checked_mul(i) - .ok_or(SocketError::InvalidNumber)?; - let end = start - .checked_add(buffer_size) - .ok_or(SocketError::InvalidNumber)?; - // Safe because no alignment or initialisation is required for [u8], and our caller assures - // us that `rx_buf` is dereferenceable and that the lifetime of the slice we are creating - // won't overlap with any other references to the same slice or outlive it. - unsafe { - rx_buf - .as_mut() - .get_mut(start..end) - .ok_or(SocketError::BufferTooShort.into()) - } + Ok(Some((header, body, token))) } } -fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> { - let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?; +fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> { + // Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`. + let header = VirtioVsockHdr::read_from_prefix(buffer).unwrap(); let body_length = header.len() as usize; + + // This could fail if the device returns an unreasonably long body length. let data_end = size_of::<VirtioVsockHdr>() .checked_add(body_length) .ok_or(SocketError::InvalidNumber)?; + // This could fail if the device returns a body length longer than the buffer we gave it. let data = buffer .get(size_of::<VirtioVsockHdr>()..data_end) .ok_or(SocketError::BufferTooShort)?; - body.get_mut(0..body_length) - .ok_or(SocketError::OutputBufferTooShort(body_length))? - .copy_from_slice(data); - Ok(header) + Ok((header, data)) } #[cfg(test)] mod tests { use super::*; - use crate::volatile::ReadOnly; use crate::{ hal::fake::FakeHal, transport::{ fake::{FakeTransport, QueueStatus, State}, DeviceStatus, DeviceType, }, + volatile::ReadOnly, }; use alloc::{sync::Arc, vec}; use core::ptr::NonNull; @@ -580,7 +522,11 @@ mod tests { driver_features: 0, guest_page_size: 0, interrupt_pending: false, - queues: vec![QueueStatus::default(); 3], + queues: vec![ + QueueStatus::default(), + QueueStatus::default(), + QueueStatus::default(), + ], })); let transport = FakeTransport { device_type: DeviceType::Socket, @@ -591,6 +537,6 @@ mod tests { }; let socket = VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(); - assert_eq!(socket.guest_cid, 0x00_0000_0042); + assert_eq!(socket.guest_cid(), 0x00_0000_0042); } } @@ -19,6 +19,8 @@ pub struct Dma<H: Hal> { impl<H: Hal> Dma<H> { /// Allocates the given number of pages of physically contiguous memory to be used for DMA in /// the given direction. + /// + /// The pages will be zeroed. pub fn new(pages: usize, direction: BufferDirection) -> Result<Self> { let (paddr, vaddr) = H::dma_alloc(pages, direction); if paddr == 0 { @@ -67,7 +69,8 @@ impl<H: Hal> Drop for Dma<H> { /// Implementations of this trait must follow the "implementation safety" requirements documented /// for each method. Callers must follow the safety requirements documented for the unsafe methods. pub unsafe trait Hal { - /// Allocates the given number of contiguous physical pages of DMA memory for VirtIO use. + /// Allocates and zeroes the given number of contiguous physical pages of DMA memory for VirtIO + /// use. /// /// Returns both the physical address which the device can use to access the memory, and a /// pointer to the start of it which the driver can use to access it. @@ -77,7 +80,7 @@ pub unsafe trait Hal { /// Implementations of this method must ensure that the `NonNull<u8>` returned is a /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, aligned to /// [`PAGE_SIZE`], and won't alias any other allocations or references in the program until it - /// is deallocated by `dma_dealloc`. + /// is deallocated by `dma_dealloc`. The pages must be zeroed. fn dma_alloc(pages: usize, direction: BufferDirection) -> (PhysAddr, NonNull<u8>); /// Deallocates the given contiguous physical DMA memory pages. @@ -24,12 +24,14 @@ //! //! ``` //! # use virtio_drivers::Hal; +//! # #[cfg(feature = "alloc")] //! use virtio_drivers::{ //! device::console::VirtIOConsole, //! transport::{mmio::MmioTransport, DeviceType, Transport}, //! }; //! +//! # #[cfg(feature = "alloc")] //! # fn example<HalImpl: Hal>(transport: MmioTransport) { //! if transport.device_type() == DeviceType::Console { //! let mut console = VirtIOConsole::<HalImpl, _>::new(transport).unwrap(); diff --git a/src/queue.rs b/src/queue.rs index d6baf17..758c139 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -60,7 +60,7 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { } if !SIZE.is_power_of_two() || SIZE > u16::MAX.into() - || transport.max_queue_size() < SIZE as u32 + || transport.max_queue_size(idx) < SIZE as u32 { return Err(Error::InvalidParam); } @@ -555,10 +555,13 @@ impl Descriptor { } } +/// Descriptor flags +#[derive(Copy, Clone, Debug, Default, Eq, FromBytes, PartialEq)] +#[repr(transparent)] +struct DescFlags(u16); + bitflags! { - /// Descriptor flags - #[derive(FromBytes)] - struct DescFlags: u16 { + impl DescFlags: u16 { const NEXT = 1; const WRITE = 2; const INDIRECT = 4; diff --git a/src/transport/fake.rs b/src/transport/fake.rs index a578db2..6ab61fc 100644 --- a/src/transport/fake.rs +++ b/src/transport/fake.rs @@ -4,8 +4,13 @@ use crate::{ PhysAddr, Result, }; use alloc::{sync::Arc, vec::Vec}; -use core::{any::TypeId, ptr::NonNull}; -use std::sync::Mutex; +use core::{ + any::TypeId, + ptr::NonNull, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use std::{sync::Mutex, thread}; /// A fake implementation of [`Transport`] for unit tests. #[derive(Debug)] @@ -30,12 +35,14 @@ impl<C> Transport for FakeTransport<C> { self.state.lock().unwrap().driver_features = driver_features; } - fn max_queue_size(&self) -> u32 { + fn max_queue_size(&mut self, _queue: u16) -> u32 { self.max_queue_size } fn notify(&mut self, queue: u16) { - self.state.lock().unwrap().queues[queue as usize].notified = true; + self.state.lock().unwrap().queues[queue as usize] + .notified + .store(true, Ordering::SeqCst); } fn get_status(&self) -> DeviceStatus { @@ -168,13 +175,23 @@ impl State { handler, ) } + + /// Waits until the given queue is notified. + pub fn wait_until_queue_notified(state: &Mutex<Self>, queue_index: u16) { + while !state.lock().unwrap().queues[usize::from(queue_index)] + .notified + .swap(false, Ordering::SeqCst) + { + thread::sleep(Duration::from_millis(10)); + } + } } -#[derive(Clone, Debug, Default, Eq, PartialEq)] +#[derive(Debug, Default)] pub struct QueueStatus { pub size: u32, pub descriptors: PhysAddr, pub driver_area: PhysAddr, pub device_area: PhysAddr, - pub notified: bool, + pub notified: AtomicBool, } diff --git a/src/transport/mmio.rs b/src/transport/mmio.rs index 026646b..d938a97 100644 --- a/src/transport/mmio.rs +++ b/src/transport/mmio.rs @@ -338,9 +338,12 @@ impl Transport for MmioTransport { } } - fn max_queue_size(&self) -> u32 { + fn max_queue_size(&mut self, queue: u16) -> u32 { // Safe because self.header points to a valid VirtIO MMIO region. - unsafe { volread!(self.header, queue_num_max) } + unsafe { + volwrite!(self.header, queue_sel, queue.into()); + volread!(self.header, queue_num_max) + } } fn notify(&mut self, queue: u16) { diff --git a/src/transport/mod.rs b/src/transport/mod.rs index f88293c..3157e81 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -20,8 +20,8 @@ pub trait Transport { /// Writes device features. fn write_driver_features(&mut self, driver_features: u64); - /// Gets the max size of queue. - fn max_queue_size(&self) -> u32; + /// Gets the max size of the given queue. + fn max_queue_size(&mut self, queue: u16) -> u32; /// Notifies the given queue on the device. fn notify(&mut self, queue: u16); @@ -65,6 +65,7 @@ pub trait Transport { /// /// Ref: virtio 3.1.1 Device Initialization fn begin_init(&mut self, negotiate_features: impl FnOnce(u64) -> u64) { + self.set_status(DeviceStatus::empty()); self.set_status(DeviceStatus::ACKNOWLEDGE | DeviceStatus::DRIVER); let features = self.read_device_features(); @@ -91,8 +92,8 @@ pub trait Transport { } bitflags! { - /// The device status field. - #[derive(Default)] + /// The device status field. Writing 0 into this field resets the device. + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] pub struct DeviceStatus: u32 { /// Indicates that the guest OS has found the device and recognized it /// as a valid virtio device. diff --git a/src/transport/pci.rs b/src/transport/pci.rs index b8bcb15..27401fe 100644 --- a/src/transport/pci.rs +++ b/src/transport/pci.rs @@ -231,10 +231,13 @@ impl Transport for PciTransport { } } - fn max_queue_size(&self) -> u32 { + fn max_queue_size(&mut self, queue: u16) -> u32 { // Safe because the common config pointer is valid and we checked in get_bar_region that it // was aligned. - unsafe { volread!(self.common_cfg, queue_size) }.into() + unsafe { + volwrite!(self.common_cfg, queue_select, queue); + volread!(self.common_cfg, queue_size).into() + } } fn notify(&mut self, queue: u16) { diff --git a/src/transport/pci/bus.rs b/src/transport/pci/bus.rs index dd6f520..0a3014b 100644 --- a/src/transport/pci/bus.rs +++ b/src/transport/pci/bus.rs @@ -24,6 +24,7 @@ pub const PCI_CAP_ID_VNDR: u8 = 0x09; bitflags! { /// The status register in PCI configuration space. + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] pub struct Status: u16 { // Bits 0-2 are reserved. /// The state of the device's INTx# signal. @@ -53,6 +54,7 @@ bitflags! { bitflags! { /// The command register in PCI configuration space. + #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] pub struct Command: u16 { /// The device can respond to I/O Space accesses. const IO_SPACE = 1 << 0; |