aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Walbran <qwandor@google.com>2023-06-14 16:37:52 +0000
committerAutomerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>2023-06-14 16:37:52 +0000
commit5cd661c16e064e6f55849e776e263950443c2ae8 (patch)
treeb374b383add3775d221d8a98373ff7b8c216b48a
parent149b5d2ec015fac8d0ece693d7b627b25d2f6b0b (diff)
parentde718b8f609090e37fff649b116c9ede2ba322ba (diff)
downloadvirtio-drivers-5cd661c16e064e6f55849e776e263950443c2ae8.tar.gz
Update to 0.5.0. am: 91b9730f88 am: acd64406dc am: 23d2bf9625 am: d75e97c86b am: 9bfad29090 am: de718b8f60
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/virtio-drivers/+/2624430 Change-Id: If437f323c679e5ef4744d3512c2d27d150034723 Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r--.cargo_vcs_info.json2
-rw-r--r--.github/workflows/main.yml7
-rw-r--r--Android.bp12
-rw-r--r--Cargo.toml6
-rw-r--r--Cargo.toml.orig6
-rw-r--r--METADATA8
-rw-r--r--README.md7
-rw-r--r--cargo2android.json2
-rw-r--r--patches/Android.bp.patch21
-rw-r--r--src/device/blk.rs17
-rw-r--r--src/device/common.rs1
-rw-r--r--src/device/console.rs41
-rw-r--r--src/device/gpu.rs32
-rw-r--r--src/device/mod.rs2
-rw-r--r--src/device/net.rs11
-rw-r--r--src/device/socket/mod.rs20
-rw-r--r--src/device/socket/multiconnectionmanager.rs763
-rw-r--r--src/device/socket/protocol.rs31
-rw-r--r--src/device/socket/singleconnectionmanager.rs447
-rw-r--r--src/device/socket/vsock.rs544
-rw-r--r--src/hal.rs7
-rw-r--r--src/lib.rs2
-rw-r--r--src/queue.rs11
-rw-r--r--src/transport/fake.rs29
-rw-r--r--src/transport/mmio.rs7
-rw-r--r--src/transport/mod.rs9
-rw-r--r--src/transport/pci.rs7
-rw-r--r--src/transport/pci/bus.rs2
28 files changed, 1651 insertions, 403 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json
index d2f328c..ec9996c 100644
--- a/.cargo_vcs_info.json
+++ b/.cargo_vcs_info.json
@@ -1,6 +1,6 @@
{
"git": {
- "sha1": "8e52adace55c5e082ba2effffcb70bf480d76ec0"
+ "sha1": "de1c3b130e507702f13d142b9bee55670a4a2858"
},
"path_in_vcs": ""
} \ No newline at end of file
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 025b7e9..b991d6b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -66,6 +66,7 @@ jobs:
example:
- aarch64
- riscv
+ - x86_64
include:
- example: aarch64
toolchain: stable
@@ -75,6 +76,10 @@ jobs:
toolchain: nightly-2022-11-03
target: riscv64imac-unknown-none-elf
packages: qemu-system-misc
+ - example: x86_64
+ toolchain: nightly
+ target: x86_64-unknown-none
+ packages: qemu-system-x86
steps:
- uses: actions/checkout@v2
- name: Install QEMU
@@ -93,4 +98,4 @@ jobs:
run: make kernel
- name: Run
working-directory: examples/${{ matrix.example }}
- run: QEMU_ARGS="-display none" make qemu
+ run: QEMU_ARGS="-display none" make qemu accel="off"
diff --git a/Android.bp b/Android.bp
index d111544..6132b98 100644
--- a/Android.bp
+++ b/Android.bp
@@ -24,11 +24,12 @@ rust_library_rlib {
name: "libvirtio_drivers",
crate_name: "virtio_drivers",
cargo_env_compat: true,
- cargo_pkg_version: "0.4.0",
+ cargo_pkg_version: "0.5.0",
srcs: ["src/lib.rs"],
edition: "2018",
+ features: ["alloc"],
rustlibs: [
- "libbitflags-1.3.2",
+ "libbitflags",
"liblog_rust_nostd",
"libzerocopy_nostd",
],
@@ -48,14 +49,15 @@ rust_test {
name: "virtio-drivers_test_src_lib",
crate_name: "virtio_drivers",
cargo_env_compat: true,
- cargo_pkg_version: "0.4.0",
+ cargo_pkg_version: "0.5.0",
srcs: ["src/lib.rs"],
test_suites: ["general-tests"],
auto_gen_config: true,
edition: "2018",
+ features: ["alloc"],
rustlibs: [
- "libbitflags-1.3.2",
+ "libbitflags",
"liblog_rust",
- "libzerocopy",
+ "libzerocopy_nostd",
],
}
diff --git a/Cargo.toml b/Cargo.toml
index 7f9968a..2fac186 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@
[package]
edition = "2018"
name = "virtio-drivers"
-version = "0.4.0"
+version = "0.5.0"
authors = [
"Jiajie Chen <noc@jiegec.ac.cn>",
"Runji Wang <wangrunji0408@163.com>",
@@ -30,7 +30,7 @@ license = "MIT"
repository = "https://github.com/rcore-os/virtio-drivers"
[dependencies.bitflags]
-version = "1.3"
+version = "2.3.0"
[dependencies.log]
version = "0.4"
@@ -39,5 +39,5 @@ version = "0.4"
version = "0.6.1"
[features]
-alloc = []
+alloc = ["zerocopy/alloc"]
default = ["alloc"]
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
index 431ccd2..0a47ce9 100644
--- a/Cargo.toml.orig
+++ b/Cargo.toml.orig
@@ -1,6 +1,6 @@
[package]
name = "virtio-drivers"
-version = "0.4.0"
+version = "0.5.0"
license = "MIT"
authors = [
"Jiajie Chen <noc@jiegec.ac.cn>",
@@ -16,9 +16,9 @@ categories = ["hardware-support", "no-std"]
[dependencies]
log = "0.4"
-bitflags = "1.3"
+bitflags = "2.3.0"
zerocopy = "0.6.1"
[features]
default = ["alloc"]
-alloc = []
+alloc = ["zerocopy/alloc"]
diff --git a/METADATA b/METADATA
index b865aa4..a90ace3 100644
--- a/METADATA
+++ b/METADATA
@@ -11,13 +11,13 @@ third_party {
}
url {
type: ARCHIVE
- value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.4.0.crate"
+ value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.5.0.crate"
}
- version: "0.4.0"
+ version: "0.5.0"
license_type: NOTICE
last_upgrade_date {
year: 2023
- month: 4
- day: 19
+ month: 6
+ day: 13
}
}
diff --git a/README.md b/README.md
index fdb61d8..ad63d11 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,12 @@ VirtIO guest drivers in Rust. For **no_std** environment.
## Examples & Tests
-- x86_64 (TODO)
+### [x86_64](./examples/x86_64)
+
+```bash
+cd examples/x86_64
+make qemu
+```
### [aarch64](./examples/aarch64)
diff --git a/cargo2android.json b/cargo2android.json
index 8bd0ae6..ede59d0 100644
--- a/cargo2android.json
+++ b/cargo2android.json
@@ -1,7 +1,7 @@
{
"dependencies": true,
"device": true,
- "features": "",
+ "features": "alloc",
"force-rlib": true,
"no-host": "true",
"no-std": true,
diff --git a/patches/Android.bp.patch b/patches/Android.bp.patch
index 8896c52..2a063c6 100644
--- a/patches/Android.bp.patch
+++ b/patches/Android.bp.patch
@@ -1,26 +1,23 @@
diff --git a/Android.bp b/Android.bp
-index f635115..d111544 100644
+index e2b5dd5..6132b98 100644
--- a/Android.bp
+++ b/Android.bp
-@@ -28,9 +28,9 @@ rust_library_rlib {
- srcs: ["src/lib.rs"],
- edition: "2018",
+@@ -30,8 +30,8 @@ rust_library_rlib {
+ features: ["alloc"],
rustlibs: [
-- "libbitflags",
+ "libbitflags",
- "liblog_rust",
- "libzerocopy",
-+ "libbitflags-1.3.2",
+ "liblog_rust_nostd",
+ "libzerocopy_nostd",
],
apex_available: [
"//apex_available:platform",
-@@ -54,7 +54,7 @@ rust_test {
- auto_gen_config: true,
- edition: "2018",
+@@ -58,6 +58,6 @@ rust_test {
rustlibs: [
-- "libbitflags",
-+ "libbitflags-1.3.2",
+ "libbitflags",
"liblog_rust",
- "libzerocopy",
+- "libzerocopy",
++ "libzerocopy_nostd",
],
+ }
diff --git a/src/device/blk.rs b/src/device/blk.rs
index d095047..ea3aef0 100644
--- a/src/device/blk.rs
+++ b/src/device/blk.rs
@@ -417,6 +417,7 @@ impl Default for BlkResp {
pub const SECTOR_SIZE: usize = 512;
bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
struct BlkFeature: u64 {
/// Device supports request barriers. (legacy)
const BARRIER = 1 << 0;
@@ -477,7 +478,7 @@ mod tests {
};
use alloc::{sync::Arc, vec};
use core::{mem::size_of, ptr::NonNull};
- use std::{sync::Mutex, thread, time::Duration};
+ use std::{sync::Mutex, thread};
#[test]
fn config() {
@@ -500,7 +501,7 @@ mod tests {
driver_features: 0,
guest_page_size: 0,
interrupt_pending: false,
- queues: vec![QueueStatus::default(); 1],
+ queues: vec![QueueStatus::default()],
}));
let transport = FakeTransport {
device_type: DeviceType::Console,
@@ -536,7 +537,7 @@ mod tests {
driver_features: 0,
guest_page_size: 0,
interrupt_pending: false,
- queues: vec![QueueStatus::default(); 1],
+ queues: vec![QueueStatus::default()],
}));
let transport = FakeTransport {
device_type: DeviceType::Console,
@@ -550,9 +551,7 @@ mod tests {
// Start a thread to simulate the device waiting for a read request.
let handle = thread::spawn(move || {
println!("Device waiting for a request.");
- while !state.lock().unwrap().queues[usize::from(QUEUE)].notified {
- thread::sleep(Duration::from_millis(10));
- }
+ State::wait_until_queue_notified(&state, QUEUE);
println!("Transmit queue was notified.");
state
@@ -611,7 +610,7 @@ mod tests {
driver_features: 0,
guest_page_size: 0,
interrupt_pending: false,
- queues: vec![QueueStatus::default(); 1],
+ queues: vec![QueueStatus::default()],
}));
let transport = FakeTransport {
device_type: DeviceType::Console,
@@ -625,9 +624,7 @@ mod tests {
// Start a thread to simulate the device waiting for a write request.
let handle = thread::spawn(move || {
println!("Device waiting for a request.");
- while !state.lock().unwrap().queues[usize::from(QUEUE)].notified {
- thread::sleep(Duration::from_millis(10));
- }
+ State::wait_until_queue_notified(&state, QUEUE);
println!("Transmit queue was notified.");
state
diff --git a/src/device/common.rs b/src/device/common.rs
index 2c8be3e..1081319 100644
--- a/src/device/common.rs
+++ b/src/device/common.rs
@@ -3,6 +3,7 @@
use bitflags::bitflags;
bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub(crate) struct Feature: u64 {
// device independent
const NOTIFY_ON_EMPTY = 1 << 24; // legacy
diff --git a/src/device/console.rs b/src/device/console.rs
index e0b0356..7d3c7d4 100644
--- a/src/device/console.rs
+++ b/src/device/console.rs
@@ -1,10 +1,11 @@
//! Driver for VirtIO console devices.
-use crate::hal::{BufferDirection, Dma, Hal};
+use crate::hal::Hal;
use crate::queue::VirtQueue;
use crate::transport::Transport;
use crate::volatile::{volread, ReadOnly, WriteOnly};
-use crate::Result;
+use crate::{Result, PAGE_SIZE};
+use alloc::boxed::Box;
use bitflags::bitflags;
use core::ptr::NonNull;
use log::info;
@@ -38,13 +39,12 @@ const QUEUE_SIZE: usize = 2;
/// # Ok(())
/// # }
/// ```
-pub struct VirtIOConsole<'a, H: Hal, T: Transport> {
+pub struct VirtIOConsole<H: Hal, T: Transport> {
transport: T,
config_space: NonNull<Config>,
receiveq: VirtQueue<H, QUEUE_SIZE>,
transmitq: VirtQueue<H, QUEUE_SIZE>,
- queue_buf_dma: Dma<H>,
- queue_buf_rx: &'a mut [u8],
+ queue_buf_rx: Box<[u8; PAGE_SIZE]>,
cursor: usize,
pending_len: usize,
/// The token of the outstanding receive request, if there is one.
@@ -62,7 +62,7 @@ pub struct ConsoleInfo {
pub max_ports: u32,
}
-impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
+impl<H: Hal, T: Transport> VirtIOConsole<H, T> {
/// Creates a new VirtIO console driver.
pub fn new(mut transport: T) -> Result<Self> {
transport.begin_init(|features| {
@@ -74,12 +74,11 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
let config_space = transport.config_space::<Config>()?;
let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0)?;
let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0)?;
- let queue_buf_dma = Dma::new(1, BufferDirection::DeviceToDriver)?;
// Safe because no alignment or initialisation is required for [u8], the DMA buffer is
// dereferenceable, and the lifetime of the reference matches the lifetime of the DMA buffer
// (which we don't otherwise access).
- let queue_buf_rx = unsafe { queue_buf_dma.raw_slice().as_mut() };
+ let queue_buf_rx = Box::new([0; PAGE_SIZE]);
transport.finish_init();
let mut console = VirtIOConsole {
@@ -87,7 +86,6 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
config_space,
receiveq,
transmitq,
- queue_buf_dma,
queue_buf_rx,
cursor: 0,
pending_len: 0,
@@ -118,7 +116,10 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
if self.receive_token.is_none() && self.cursor == self.pending_len {
// Safe because the buffer lasts at least as long as the queue, and there are no other
// outstanding requests using the buffer.
- self.receive_token = Some(unsafe { self.receiveq.add(&[], &mut [self.queue_buf_rx]) }?);
+ self.receive_token = Some(unsafe {
+ self.receiveq
+ .add(&[], &mut [self.queue_buf_rx.as_mut_slice()])
+ }?);
if self.receiveq.should_notify() {
self.transport.notify(QUEUE_RECEIVEQ_PORT_0);
}
@@ -148,8 +149,11 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
// Safe because we are passing the same buffer as we passed to `VirtQueue::add` in
// `poll_retrieve` and it is still valid.
let len = unsafe {
- self.receiveq
- .pop_used(receive_token, &[], &mut [self.queue_buf_rx])?
+ self.receiveq.pop_used(
+ receive_token,
+ &[],
+ &mut [self.queue_buf_rx.as_mut_slice()],
+ )?
};
flag = true;
assert_ne!(len, 0);
@@ -188,7 +192,7 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
}
}
-impl<H: Hal, T: Transport> Drop for VirtIOConsole<'_, H, T> {
+impl<H: Hal, T: Transport> Drop for VirtIOConsole<H, T> {
fn drop(&mut self) {
// Clear any pointers pointing to DMA regions, so the device doesn't try to access them
// after they have been freed.
@@ -206,6 +210,7 @@ struct Config {
}
bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
struct Features: u64 {
const SIZE = 1 << 0;
const MULTIPORT = 1 << 1;
@@ -241,7 +246,7 @@ mod tests {
};
use alloc::{sync::Arc, vec};
use core::ptr::NonNull;
- use std::{sync::Mutex, thread, time::Duration};
+ use std::{sync::Mutex, thread};
#[test]
fn receive() {
@@ -256,7 +261,7 @@ mod tests {
driver_features: 0,
guest_page_size: 0,
interrupt_pending: false,
- queues: vec![QueueStatus::default(); 2],
+ queues: vec![QueueStatus::default(), QueueStatus::default()],
}));
let transport = FakeTransport {
device_type: DeviceType::Console,
@@ -304,7 +309,7 @@ mod tests {
driver_features: 0,
guest_page_size: 0,
interrupt_pending: false,
- queues: vec![QueueStatus::default(); 2],
+ queues: vec![QueueStatus::default(), QueueStatus::default()],
}));
let transport = FakeTransport {
device_type: DeviceType::Console,
@@ -318,9 +323,7 @@ mod tests {
// Start a thread to simulate the device waiting for characters.
let handle = thread::spawn(move || {
println!("Device waiting for a character.");
- while !state.lock().unwrap().queues[usize::from(QUEUE_TRANSMITQ_PORT_0)].notified {
- thread::sleep(Duration::from_millis(10));
- }
+ State::wait_until_queue_notified(&state, QUEUE_TRANSMITQ_PORT_0);
println!("Transmit queue was notified.");
let data = state
diff --git a/src/device/gpu.rs b/src/device/gpu.rs
index b1b53bd..43e1b76 100644
--- a/src/device/gpu.rs
+++ b/src/device/gpu.rs
@@ -4,7 +4,8 @@ use crate::hal::{BufferDirection, Dma, Hal};
use crate::queue::VirtQueue;
use crate::transport::Transport;
use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly};
-use crate::{pages, Error, Result};
+use crate::{pages, Error, Result, PAGE_SIZE};
+use alloc::boxed::Box;
use bitflags::bitflags;
use log::info;
use zerocopy::{AsBytes, FromBytes};
@@ -18,7 +19,7 @@ const QUEUE_SIZE: u16 = 2;
/// a gpu with 3D support on the host machine.
/// In 2D mode the virtio-gpu device provides support for ARGB Hardware cursors
/// and multiple scanouts (aka heads).
-pub struct VirtIOGpu<'a, H: Hal, T: Transport> {
+pub struct VirtIOGpu<H: Hal, T: Transport> {
transport: T,
rect: Option<Rect>,
/// DMA area of frame buffer.
@@ -29,17 +30,13 @@ pub struct VirtIOGpu<'a, H: Hal, T: Transport> {
control_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
/// Queue for sending cursor commands.
cursor_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
- /// DMA region for sending data to the device.
- dma_send: Dma<H>,
- /// DMA region for receiving data from the device.
- dma_recv: Dma<H>,
/// Send buffer for queue.
- queue_buf_send: &'a mut [u8],
+ queue_buf_send: Box<[u8]>,
/// Recv buffer for queue.
- queue_buf_recv: &'a mut [u8],
+ queue_buf_recv: Box<[u8]>,
}
-impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
+impl<H: Hal, T: Transport> VirtIOGpu<H, T> {
/// Create a new VirtIO-Gpu driver.
pub fn new(mut transport: T) -> Result<Self> {
transport.begin_init(|features| {
@@ -63,10 +60,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?;
let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR)?;
- let dma_send = Dma::new(1, BufferDirection::DriverToDevice)?;
- let dma_recv = Dma::new(1, BufferDirection::DeviceToDriver)?;
- let queue_buf_send = unsafe { dma_send.raw_slice().as_mut() };
- let queue_buf_recv = unsafe { dma_recv.raw_slice().as_mut() };
+ let queue_buf_send = FromBytes::new_box_slice_zeroed(PAGE_SIZE);
+ let queue_buf_recv = FromBytes::new_box_slice_zeroed(PAGE_SIZE);
transport.finish_init();
@@ -77,8 +72,6 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
rect: None,
control_queue,
cursor_queue,
- dma_send,
- dma_recv,
queue_buf_send,
queue_buf_recv,
})
@@ -177,8 +170,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
fn request<Req: AsBytes, Rsp: FromBytes>(&mut self, req: Req) -> Result<Rsp> {
req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
self.control_queue.add_notify_wait_pop(
- &[self.queue_buf_send],
- &mut [self.queue_buf_recv],
+ &[&self.queue_buf_send],
+ &mut [&mut self.queue_buf_recv],
&mut self.transport,
)?;
Ok(Rsp::read_from_prefix(&*self.queue_buf_recv).unwrap())
@@ -188,7 +181,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
fn cursor_request<Req: AsBytes>(&mut self, req: Req) -> Result {
req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
self.cursor_queue.add_notify_wait_pop(
- &[self.queue_buf_send],
+ &[&self.queue_buf_send],
&mut [],
&mut self.transport,
)?;
@@ -286,7 +279,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
}
}
-impl<H: Hal, T: Transport> Drop for VirtIOGpu<'_, H, T> {
+impl<H: Hal, T: Transport> Drop for VirtIOGpu<H, T> {
fn drop(&mut self) {
// Clear any pointers pointing to DMA regions, so the device doesn't try to access them
// after they have been freed.
@@ -313,6 +306,7 @@ struct Config {
const EVENT_DISPLAY: u32 = 1 << 0;
bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
struct Features: u64 {
/// virgl 3D mode is supported.
const VIRGL = 1 << 0;
diff --git a/src/device/mod.rs b/src/device/mod.rs
index ca68901..00fa6fe 100644
--- a/src/device/mod.rs
+++ b/src/device/mod.rs
@@ -1,7 +1,9 @@
//! Drivers for specific VirtIO devices.
pub mod blk;
+#[cfg(feature = "alloc")]
pub mod console;
+#[cfg(feature = "alloc")]
pub mod gpu;
#[cfg(feature = "alloc")]
pub mod input;
diff --git a/src/device/net.rs b/src/device/net.rs
index 4441f63..b9419e7 100644
--- a/src/device/net.rs
+++ b/src/device/net.rs
@@ -262,6 +262,7 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> Drop for VirtIONet<H, T, QUE
}
bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
struct Features: u64 {
/// Device handles packets with partial checksum.
/// This "checksum offload" is a common feature on modern network cards.
@@ -323,6 +324,7 @@ bitflags! {
}
bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
struct Status: u16 {
const LINK_UP = 1;
const ANNOUNCE = 2;
@@ -330,6 +332,7 @@ bitflags! {
}
bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
struct InterruptStatus : u32 {
const USED_RING_UPDATE = 1 << 0;
const CONFIGURATION_CHANGE = 1 << 1;
@@ -364,10 +367,12 @@ pub struct VirtioNetHdr {
// payload starts from here
}
+#[derive(AsBytes, Copy, Clone, Debug, Default, Eq, FromBytes, PartialEq)]
+#[repr(transparent)]
+struct Flags(u8);
+
bitflags! {
- #[repr(transparent)]
- #[derive(AsBytes, Default, FromBytes)]
- struct Flags: u8 {
+ impl Flags: u8 {
const NEEDS_CSUM = 1;
const DATA_VALID = 2;
const RSC_INFO = 4;
diff --git a/src/device/socket/mod.rs b/src/device/socket/mod.rs
index 65280aa..bf423bf 100644
--- a/src/device/socket/mod.rs
+++ b/src/device/socket/mod.rs
@@ -1,8 +1,26 @@
-//! This module implements the virtio vsock device.
+//! Driver for VirtIO socket devices.
+//!
+//! To use the driver, you should first create a [`VirtIOSocket`] instance with your VirtIO
+//! transport, and then create a [`VsockConnectionManager`] wrapping it to keep track of
+//! connections. If you only want to have a single outgoing vsock connection at once, you can use
+//! [`SingleConnectionManager`] for a slightly simpler interface.
+//!
+//! See [`VsockConnectionManager`] for a usage example.
mod error;
+#[cfg(feature = "alloc")]
+mod multiconnectionmanager;
mod protocol;
+#[cfg(feature = "alloc")]
+mod singleconnectionmanager;
+#[cfg(feature = "alloc")]
mod vsock;
pub use error::SocketError;
+#[cfg(feature = "alloc")]
+pub use multiconnectionmanager::VsockConnectionManager;
+pub use protocol::VsockAddr;
+#[cfg(feature = "alloc")]
+pub use singleconnectionmanager::SingleConnectionManager;
+#[cfg(feature = "alloc")]
pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};
diff --git a/src/device/socket/multiconnectionmanager.rs b/src/device/socket/multiconnectionmanager.rs
new file mode 100644
index 0000000..6aee5cd
--- /dev/null
+++ b/src/device/socket/multiconnectionmanager.rs
@@ -0,0 +1,763 @@
+use super::{
+ protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
+ VsockEvent, VsockEventType,
+};
+use crate::{transport::Transport, Hal, Result};
+use alloc::{boxed::Box, vec::Vec};
+use core::cmp::min;
+use core::convert::TryInto;
+use core::hint::spin_loop;
+use log::debug;
+use zerocopy::FromBytes;
+
+const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
+
+/// A higher level interface for VirtIO socket (vsock) devices.
+///
+/// This keeps track of multiple vsock connections.
+///
+/// # Example
+///
+/// ```
+/// # use virtio_drivers::{Error, Hal};
+/// # use virtio_drivers::transport::Transport;
+/// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager};
+///
+/// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> {
+/// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?);
+///
+/// // Start a thread to call `socket.poll()` and handle events.
+///
+/// let remote_address = VsockAddr { cid: 2, port: 42 };
+/// let local_port = 1234;
+/// socket.connect(remote_address, local_port)?;
+///
+/// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
+///
+/// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
+///
+/// socket.shutdown(remote_address, local_port)?;
+/// # Ok(())
+/// # }
+/// ```
+pub struct VsockConnectionManager<H: Hal, T: Transport> {
+ driver: VirtIOSocket<H, T>,
+ connections: Vec<Connection>,
+ listening_ports: Vec<u32>,
+}
+
+#[derive(Debug)]
+struct Connection {
+ info: ConnectionInfo,
+ buffer: RingBuffer,
+ /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
+ /// still data in the buffer.
+ peer_requested_shutdown: bool,
+}
+
+impl Connection {
+ fn new(peer: VsockAddr, local_port: u32) -> Self {
+ let mut info = ConnectionInfo::new(peer, local_port);
+ info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
+ Self {
+ info,
+ buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
+ peer_requested_shutdown: false,
+ }
+ }
+}
+
+impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
+ /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
+ pub fn new(driver: VirtIOSocket<H, T>) -> Self {
+ Self {
+ driver,
+ connections: Vec::new(),
+ listening_ports: Vec::new(),
+ }
+ }
+
+ /// Returns the CID which has been assigned to this guest.
+ pub fn guest_cid(&self) -> u64 {
+ self.driver.guest_cid()
+ }
+
+ /// Allows incoming connections on the given port number.
+ pub fn listen(&mut self, port: u32) {
+ if !self.listening_ports.contains(&port) {
+ self.listening_ports.push(port);
+ }
+ }
+
+ /// Stops allowing incoming connections on the given port number.
+ pub fn unlisten(&mut self, port: u32) {
+ self.listening_ports.retain(|p| *p != port);
+ }
+
+ /// Sends a request to connect to the given destination.
+ ///
+ /// This returns as soon as the request is sent; you should wait until `poll` returns a
+ /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
+ /// before sending data.
+ pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+ if self.connections.iter().any(|connection| {
+ connection.info.dst == destination && connection.info.src_port == src_port
+ }) {
+ return Err(SocketError::ConnectionExists.into());
+ }
+
+ let new_connection = Connection::new(destination, src_port);
+
+ self.driver.connect(&new_connection.info)?;
+ debug!("Connection requested: {:?}", new_connection.info);
+ self.connections.push(new_connection);
+ Ok(())
+ }
+
+ /// Sends the buffer to the destination.
+ pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
+ let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
+
+ self.driver.send(buffer, &mut connection.info)
+ }
+
+ /// Polls the vsock device to receive data or other updates.
+ pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
+ let guest_cid = self.driver.guest_cid();
+ let connections = &mut self.connections;
+
+ let result = self.driver.poll(|event, body| {
+ let connection = get_connection_for_event(connections, &event, guest_cid);
+
+ // Skip events which don't match any connection we know about, unless they are a
+ // connection request.
+ let connection = if let Some((_, connection)) = connection {
+ connection
+ } else if let VsockEventType::ConnectionRequest = event.event_type {
+ // If the requested connection already exists or the CID isn't ours, ignore it.
+ if connection.is_some() || event.destination.cid != guest_cid {
+ return Ok(None);
+ }
+ // Add the new connection to our list, at least for now. It will be removed again
+ // below if we weren't listening on the port.
+ connections.push(Connection::new(event.source, event.destination.port));
+ connections.last_mut().unwrap()
+ } else {
+ return Ok(None);
+ };
+
+ // Update stored connection info.
+ connection.info.update_for_event(&event);
+
+ if let VsockEventType::Received { length } = event.event_type {
+ // Copy to buffer
+ if !connection.buffer.add(body) {
+ return Err(SocketError::OutputBufferTooShort(length).into());
+ }
+ }
+
+ Ok(Some(event))
+ })?;
+
+ let Some(event) = result else {
+ return Ok(None);
+ };
+
+ // The connection must exist because we found it above in the callback.
+ let (connection_index, connection) =
+ get_connection_for_event(connections, &event, guest_cid).unwrap();
+
+ match event.event_type {
+ VsockEventType::ConnectionRequest => {
+ if self.listening_ports.contains(&event.destination.port) {
+ self.driver.accept(&connection.info)?;
+ } else {
+ // Reject the connection request and remove it from our list.
+ self.driver.force_close(&connection.info)?;
+ self.connections.swap_remove(connection_index);
+
+ // No need to pass the request on to the client, as we've already rejected it.
+ return Ok(None);
+ }
+ }
+ VsockEventType::Connected => {}
+ VsockEventType::Disconnected { reason } => {
+ // Wait until client reads all data before removing connection.
+ if connection.buffer.is_empty() {
+ if reason == DisconnectReason::Shutdown {
+ self.driver.force_close(&connection.info)?;
+ }
+ self.connections.swap_remove(connection_index);
+ } else {
+ connection.peer_requested_shutdown = true;
+ }
+ }
+ VsockEventType::Received { .. } => {
+ // Already copied the buffer in the callback above.
+ }
+ VsockEventType::CreditRequest => {
+ // If the peer requested credit, send an update.
+ self.driver.credit_update(&connection.info)?;
+ // No need to pass the request on to the client, we've already handled it.
+ return Ok(None);
+ }
+ VsockEventType::CreditUpdate => {}
+ }
+
+ Ok(Some(event))
+ }
+
+ /// Reads data received from the given connection.
+ pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
+ let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
+
+ // Copy from ring buffer
+ let bytes_read = connection.buffer.drain(buffer);
+
+ connection.info.done_forwarding(bytes_read);
+
+ // If buffer is now empty and the peer requested shutdown, finish shutting down the
+ // connection.
+ if connection.peer_requested_shutdown && connection.buffer.is_empty() {
+ self.driver.force_close(&connection.info)?;
+ self.connections.swap_remove(connection_index);
+ }
+
+ Ok(bytes_read)
+ }
+
+ /// Blocks until we get some event from the vsock device.
+ pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
+ loop {
+ if let Some(event) = self.poll()? {
+ return Ok(event);
+ } else {
+ spin_loop();
+ }
+ }
+ }
+
+ /// Requests to shut down the connection cleanly.
+ ///
+ /// This returns as soon as the request is sent; you should wait until `poll` returns a
+ /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+ /// shutdown.
+ pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+ let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
+
+ self.driver.shutdown(&connection.info)
+ }
+
+ /// Forcibly closes the connection without waiting for the peer.
+ pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+ let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
+
+ self.driver.force_close(&connection.info)?;
+
+ self.connections.swap_remove(index);
+ Ok(())
+ }
+}
+
+/// Returns the connection from the given list matching the given peer address and local port, and
+/// its index.
+///
+/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
+fn get_connection(
+ connections: &mut [Connection],
+ peer: VsockAddr,
+ local_port: u32,
+) -> core::result::Result<(usize, &mut Connection), SocketError> {
+ connections
+ .iter_mut()
+ .enumerate()
+ .find(|(_, connection)| {
+ connection.info.dst == peer && connection.info.src_port == local_port
+ })
+ .ok_or(SocketError::NotConnected)
+}
+
+/// Returns the connection from the given list matching the event, if any, and its index.
+fn get_connection_for_event<'a>(
+ connections: &'a mut [Connection],
+ event: &VsockEvent,
+ local_cid: u64,
+) -> Option<(usize, &'a mut Connection)> {
+ connections
+ .iter_mut()
+ .enumerate()
+ .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
+}
+
+#[derive(Debug)]
+struct RingBuffer {
+ buffer: Box<[u8]>,
+ /// The number of bytes currently in the buffer.
+ used: usize,
+ /// The index of the first used byte in the buffer.
+ start: usize,
+}
+
+impl RingBuffer {
+ pub fn new(capacity: usize) -> Self {
+ Self {
+ buffer: FromBytes::new_box_slice_zeroed(capacity),
+ used: 0,
+ start: 0,
+ }
+ }
+
+ /// Returns the number of bytes currently used in the buffer.
+ pub fn used(&self) -> usize {
+ self.used
+ }
+
+ /// Returns true iff there are currently no bytes in the buffer.
+ pub fn is_empty(&self) -> bool {
+ self.used == 0
+ }
+
+ /// Returns the number of bytes currently free in the buffer.
+ pub fn available(&self) -> usize {
+ self.buffer.len() - self.used
+ }
+
+ /// Adds the given bytes to the buffer if there is enough capacity for them all.
+ ///
+ /// Returns true if they were added, or false if they were not.
+ pub fn add(&mut self, bytes: &[u8]) -> bool {
+ if bytes.len() > self.available() {
+ return false;
+ }
+
+ // The index of the first available position in the buffer.
+ let first_available = (self.start + self.used) % self.buffer.len();
+ // The number of bytes to copy from `bytes` to `buffer` between `first_available` and
+ // `buffer.len()`.
+ let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
+ self.buffer[first_available..first_available + copy_length_before_wraparound]
+ .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
+ if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
+ self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
+ }
+ self.used += bytes.len();
+
+ true
+ }
+
+ /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
+ /// buffer.
+ pub fn drain(&mut self, out: &mut [u8]) -> usize {
+ let bytes_read = min(self.used, out.len());
+
+ // The number of bytes to copy out between `start` and the end of the buffer.
+ let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
+ // The number of bytes to copy out from the beginning of the buffer after wrapping around.
+ let read_after_wraparound = bytes_read
+ .checked_sub(read_before_wraparound)
+ .unwrap_or_default();
+
+ out[0..read_before_wraparound]
+ .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
+ out[read_before_wraparound..bytes_read]
+ .copy_from_slice(&self.buffer[0..read_after_wraparound]);
+
+ self.used -= bytes_read;
+ self.start = (self.start + bytes_read) % self.buffer.len();
+
+ bytes_read
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ device::socket::{
+ protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
+ vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
+ },
+ hal::fake::FakeHal,
+ transport::{
+ fake::{FakeTransport, QueueStatus, State},
+ DeviceStatus, DeviceType,
+ },
+ volatile::ReadOnly,
+ };
+ use alloc::{sync::Arc, vec};
+ use core::{mem::size_of, ptr::NonNull};
+ use std::{sync::Mutex, thread};
+ use zerocopy::{AsBytes, FromBytes};
+
+ #[test]
+ fn send_recv() {
+ let host_cid = 2;
+ let guest_cid = 66;
+ let host_port = 1234;
+ let guest_port = 4321;
+ let host_address = VsockAddr {
+ cid: host_cid,
+ port: host_port,
+ };
+ let hello_from_guest = "Hello from guest";
+ let hello_from_host = "Hello from host";
+
+ let mut config_space = VirtioVsockConfig {
+ guest_cid_low: ReadOnly::new(66),
+ guest_cid_high: ReadOnly::new(0),
+ };
+ let state = Arc::new(Mutex::new(State {
+ status: DeviceStatus::empty(),
+ driver_features: 0,
+ guest_page_size: 0,
+ interrupt_pending: false,
+ queues: vec![
+ QueueStatus::default(),
+ QueueStatus::default(),
+ QueueStatus::default(),
+ ],
+ }));
+ let transport = FakeTransport {
+ device_type: DeviceType::Socket,
+ max_queue_size: 32,
+ device_features: 0,
+ config_space: NonNull::from(&mut config_space),
+ state: state.clone(),
+ };
+ let mut socket = VsockConnectionManager::new(
+ VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
+ );
+
+ // Start a thread to simulate the device.
+ let handle = thread::spawn(move || {
+ // Wait for connection request.
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ assert_eq!(
+ VirtioVsockHdr::read_from(
+ state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+ .as_slice()
+ )
+ .unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Request.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: guest_port.into(),
+ dst_port: host_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 1024.into(),
+ fwd_cnt: 0.into(),
+ }
+ );
+
+ // Accept connection and give the peer enough credit to send the message.
+ state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+ RX_QUEUE_IDX,
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Response.into(),
+ src_cid: host_cid.into(),
+ dst_cid: guest_cid.into(),
+ src_port: host_port.into(),
+ dst_port: guest_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 50.into(),
+ fwd_cnt: 0.into(),
+ }
+ .as_bytes(),
+ );
+
+ // Expect the guest to send some data.
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ let request = state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
+ assert_eq!(
+ request.len(),
+ size_of::<VirtioVsockHdr>() + hello_from_guest.len()
+ );
+ assert_eq!(
+ VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Rw.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: guest_port.into(),
+ dst_port: host_port.into(),
+ len: (hello_from_guest.len() as u32).into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 1024.into(),
+ fwd_cnt: 0.into(),
+ }
+ );
+ assert_eq!(
+ &request[size_of::<VirtioVsockHdr>()..],
+ hello_from_guest.as_bytes()
+ );
+
+ println!("Host sending");
+
+ // Send a response.
+ let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Rw.into(),
+ src_cid: host_cid.into(),
+ dst_cid: guest_cid.into(),
+ src_port: host_port.into(),
+ dst_port: guest_port.into(),
+ len: (hello_from_host.len() as u32).into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 50.into(),
+ fwd_cnt: (hello_from_guest.len() as u32).into(),
+ }
+ .write_to_prefix(response.as_mut_slice());
+ response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
+ state
+ .lock()
+ .unwrap()
+ .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
+
+ // Expect a shutdown.
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ assert_eq!(
+ VirtioVsockHdr::read_from(
+ state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+ .as_slice()
+ )
+ .unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Shutdown.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: guest_port.into(),
+ dst_port: host_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 1024.into(),
+ fwd_cnt: (hello_from_host.len() as u32).into(),
+ }
+ );
+ });
+
+ socket.connect(host_address, guest_port).unwrap();
+ assert_eq!(
+ socket.wait_for_event().unwrap(),
+ VsockEvent {
+ source: host_address,
+ destination: VsockAddr {
+ cid: guest_cid,
+ port: guest_port,
+ },
+ event_type: VsockEventType::Connected,
+ buffer_status: VsockBufferStatus {
+ buffer_allocation: 50,
+ forward_count: 0,
+ },
+ }
+ );
+ println!("Guest sending");
+ socket
+ .send(host_address, guest_port, "Hello from guest".as_bytes())
+ .unwrap();
+ println!("Guest waiting to receive.");
+ assert_eq!(
+ socket.wait_for_event().unwrap(),
+ VsockEvent {
+ source: host_address,
+ destination: VsockAddr {
+ cid: guest_cid,
+ port: guest_port,
+ },
+ event_type: VsockEventType::Received {
+ length: hello_from_host.len()
+ },
+ buffer_status: VsockBufferStatus {
+ buffer_allocation: 50,
+ forward_count: hello_from_guest.len() as u32,
+ },
+ }
+ );
+ println!("Guest getting received data.");
+ let mut buffer = [0u8; 64];
+ assert_eq!(
+ socket.recv(host_address, guest_port, &mut buffer).unwrap(),
+ hello_from_host.len()
+ );
+ assert_eq!(
+ &buffer[0..hello_from_host.len()],
+ hello_from_host.as_bytes()
+ );
+ socket.shutdown(host_address, guest_port).unwrap();
+
+ handle.join().unwrap();
+ }
+
+ #[test]
+ fn incoming_connection() {
+ let host_cid = 2;
+ let guest_cid = 66;
+ let host_port = 1234;
+ let guest_port = 4321;
+ let wrong_guest_port = 4444;
+ let host_address = VsockAddr {
+ cid: host_cid,
+ port: host_port,
+ };
+
+ let mut config_space = VirtioVsockConfig {
+ guest_cid_low: ReadOnly::new(66),
+ guest_cid_high: ReadOnly::new(0),
+ };
+ let state = Arc::new(Mutex::new(State {
+ status: DeviceStatus::empty(),
+ driver_features: 0,
+ guest_page_size: 0,
+ interrupt_pending: false,
+ queues: vec![
+ QueueStatus::default(),
+ QueueStatus::default(),
+ QueueStatus::default(),
+ ],
+ }));
+ let transport = FakeTransport {
+ device_type: DeviceType::Socket,
+ max_queue_size: 32,
+ device_features: 0,
+ config_space: NonNull::from(&mut config_space),
+ state: state.clone(),
+ };
+ let mut socket = VsockConnectionManager::new(
+ VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
+ );
+
+ socket.listen(guest_port);
+
+ // Start a thread to simulate the device.
+ let handle = thread::spawn(move || {
+ // Send a connection request for a port the guest isn't listening on.
+ println!("Host sending connection request to wrong port");
+ state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+ RX_QUEUE_IDX,
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Request.into(),
+ src_cid: host_cid.into(),
+ dst_cid: guest_cid.into(),
+ src_port: host_port.into(),
+ dst_port: wrong_guest_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 50.into(),
+ fwd_cnt: 0.into(),
+ }
+ .as_bytes(),
+ );
+
+ // Expect a rejection.
+ println!("Host waiting for rejection");
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ assert_eq!(
+ VirtioVsockHdr::read_from(
+ state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+ .as_slice()
+ )
+ .unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Rst.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: wrong_guest_port.into(),
+ dst_port: host_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 1024.into(),
+ fwd_cnt: 0.into(),
+ }
+ );
+
+ // Send a connection request for a port the guest is listening on.
+ println!("Host sending connection request to right port");
+ state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+ RX_QUEUE_IDX,
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Request.into(),
+ src_cid: host_cid.into(),
+ dst_cid: guest_cid.into(),
+ src_port: host_port.into(),
+ dst_port: guest_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 50.into(),
+ fwd_cnt: 0.into(),
+ }
+ .as_bytes(),
+ );
+
+ // Expect a response.
+ println!("Host waiting for response");
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ assert_eq!(
+ VirtioVsockHdr::read_from(
+ state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+ .as_slice()
+ )
+ .unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Response.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: guest_port.into(),
+ dst_port: host_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 1024.into(),
+ fwd_cnt: 0.into(),
+ }
+ );
+
+ println!("Host finished");
+ });
+
+ // Expect an incoming connection.
+ println!("Guest expecting incoming connection.");
+ assert_eq!(
+ socket.wait_for_event().unwrap(),
+ VsockEvent {
+ source: host_address,
+ destination: VsockAddr {
+ cid: guest_cid,
+ port: guest_port,
+ },
+ event_type: VsockEventType::ConnectionRequest,
+ buffer_status: VsockBufferStatus {
+ buffer_allocation: 50,
+ forward_count: 0,
+ },
+ }
+ );
+
+ handle.join().unwrap();
+ }
+}
diff --git a/src/device/socket/protocol.rs b/src/device/socket/protocol.rs
index abc1702..3587005 100644
--- a/src/device/socket/protocol.rs
+++ b/src/device/socket/protocol.rs
@@ -2,6 +2,7 @@
use super::error::{self, SocketError};
use crate::volatile::ReadOnly;
+use bitflags::bitflags;
use core::{
convert::{TryFrom, TryInto},
fmt,
@@ -17,6 +18,8 @@ use zerocopy::{
pub enum SocketType {
/// Stream sockets provide in-order, guaranteed, connection-oriented delivery without message boundaries.
Stream = 1,
+ /// seqpacket socket type introduced in virtio-v1.2.
+ SeqPacket = 2,
}
impl From<SocketType> for U16<LittleEndian> {
@@ -40,7 +43,7 @@ pub struct VirtioVsockConfig {
/// The message header for data packets sent on the tx/rx queues
#[repr(packed)]
-#[derive(AsBytes, Clone, Copy, Debug, FromBytes)]
+#[derive(AsBytes, Clone, Copy, Debug, Eq, FromBytes, PartialEq)]
pub struct VirtioVsockHdr {
pub src_cid: U64<LittleEndian>,
pub dst_cid: U64<LittleEndian>,
@@ -182,3 +185,29 @@ impl fmt::Debug for VirtioVsockOp {
}
}
}
+
+bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
+ pub(crate) struct Feature: u64 {
+ /// stream socket type is supported.
+ const STREAM = 1 << 0;
+ /// seqpacket socket type is supported.
+ const SEQ_PACKET = 1 << 1;
+
+ // device independent
+ const NOTIFY_ON_EMPTY = 1 << 24; // legacy
+ const ANY_LAYOUT = 1 << 27; // legacy
+ const RING_INDIRECT_DESC = 1 << 28;
+ const RING_EVENT_IDX = 1 << 29;
+ const UNUSED = 1 << 30; // legacy
+ const VERSION_1 = 1 << 32; // detect legacy
+
+ // since virtio v1.1
+ const ACCESS_PLATFORM = 1 << 33;
+ const RING_PACKED = 1 << 34;
+ const IN_ORDER = 1 << 35;
+ const ORDER_PLATFORM = 1 << 36;
+ const SR_IOV = 1 << 37;
+ const NOTIFICATION_DATA = 1 << 38;
+ }
+}
diff --git a/src/device/socket/singleconnectionmanager.rs b/src/device/socket/singleconnectionmanager.rs
new file mode 100644
index 0000000..8c9bff6
--- /dev/null
+++ b/src/device/socket/singleconnectionmanager.rs
@@ -0,0 +1,447 @@
+use super::{
+ protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent,
+ VsockEventType,
+};
+use crate::{transport::Transport, Hal, Result};
+use core::hint::spin_loop;
+use log::debug;
+
+/// A higher level interface for VirtIO socket (vsock) devices.
+///
+/// This can only keep track of a single vsock connection. If you want to support multiple
+/// simultaneous connections, try [`VsockConnectionManager`](super::VsockConnectionManager).
+pub struct SingleConnectionManager<H: Hal, T: Transport> {
+ driver: VirtIOSocket<H, T>,
+ connection_info: Option<ConnectionInfo>,
+}
+
+impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
+ /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
+ pub fn new(driver: VirtIOSocket<H, T>) -> Self {
+ Self {
+ driver,
+ connection_info: None,
+ }
+ }
+
+ /// Returns the CID which has been assigned to this guest.
+ pub fn guest_cid(&self) -> u64 {
+ self.driver.guest_cid()
+ }
+
+ /// Sends a request to connect to the given destination.
+ ///
+ /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+ /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
+ /// before sending data.
+ pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+ if self.connection_info.is_some() {
+ return Err(SocketError::ConnectionExists.into());
+ }
+
+ let new_connection_info = ConnectionInfo::new(destination, src_port);
+
+ self.driver.connect(&new_connection_info)?;
+ debug!("Connection requested: {:?}", new_connection_info);
+ self.connection_info = Some(new_connection_info);
+ Ok(())
+ }
+
+ /// Sends the buffer to the destination.
+ pub fn send(&mut self, buffer: &[u8]) -> Result {
+ let connection_info = self
+ .connection_info
+ .as_mut()
+ .ok_or(SocketError::NotConnected)?;
+ connection_info.buf_alloc = 0;
+ self.driver.send(buffer, connection_info)
+ }
+
+ /// Polls the vsock device to receive data or other updates.
+ ///
+ /// A buffer must be provided to put the data in if there is some to
+ /// receive.
+ pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
+ let Some(connection_info) = &mut self.connection_info else {
+ return Err(SocketError::NotConnected.into());
+ };
+
+ // Tell the peer that we have space to receive some data.
+ connection_info.buf_alloc = buffer.len() as u32;
+ self.driver.credit_update(connection_info)?;
+
+ self.poll_rx_queue(buffer)
+ }
+
+ /// Blocks until we get some event from the vsock device.
+ ///
+ /// A buffer must be provided to put the data in if there is some to
+ /// receive.
+ pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
+ loop {
+ if let Some(event) = self.poll_recv(buffer)? {
+ return Ok(event);
+ } else {
+ spin_loop();
+ }
+ }
+ }
+
+ fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
+ let guest_cid = self.driver.guest_cid();
+ let self_connection_info = &mut self.connection_info;
+
+ self.driver.poll(|event, borrowed_body| {
+ let Some(connection_info) = self_connection_info else {
+ return Ok(None);
+ };
+
+ // Skip packets which don't match our current connection.
+ if !event.matches_connection(connection_info, guest_cid) {
+ debug!(
+ "Skipping {:?} as connection is {:?}",
+ event, connection_info
+ );
+ return Ok(None);
+ }
+
+ // Update stored connection info.
+ connection_info.update_for_event(&event);
+
+ match event.event_type {
+ VsockEventType::ConnectionRequest => {
+ // TODO: Send Rst or handle incoming connections.
+ }
+ VsockEventType::Connected => {}
+ VsockEventType::Disconnected { .. } => {
+ *self_connection_info = None;
+ }
+ VsockEventType::Received { length } => {
+ body.get_mut(0..length)
+ .ok_or(SocketError::OutputBufferTooShort(length))?
+ .copy_from_slice(borrowed_body);
+ connection_info.done_forwarding(length);
+ }
+ VsockEventType::CreditRequest => {
+ // No point sending a credit update until `poll_recv` is called with a buffer,
+ // as otherwise buf_alloc would just be 0 anyway.
+ }
+ VsockEventType::CreditUpdate => {}
+ }
+
+ Ok(Some(event))
+ })
+ }
+
+ /// Requests to shut down the connection cleanly.
+ ///
+ /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+ /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+ /// shutdown.
+ pub fn shutdown(&mut self) -> Result {
+ let connection_info = self
+ .connection_info
+ .as_mut()
+ .ok_or(SocketError::NotConnected)?;
+ connection_info.buf_alloc = 0;
+
+ self.driver.shutdown(connection_info)
+ }
+
+ /// Forcibly closes the connection without waiting for the peer.
+ pub fn force_close(&mut self) -> Result {
+ let connection_info = self
+ .connection_info
+ .as_mut()
+ .ok_or(SocketError::NotConnected)?;
+ connection_info.buf_alloc = 0;
+
+ self.driver.force_close(connection_info)?;
+ self.connection_info = None;
+ Ok(())
+ }
+
+ /// Blocks until the peer either accepts our connection request (with a
+ /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a
+ /// `VIRTIO_VSOCK_OP_RST`).
+ pub fn wait_for_connect(&mut self) -> Result {
+ loop {
+ match self.wait_for_recv(&mut [])?.event_type {
+ VsockEventType::Connected => return Ok(()),
+ VsockEventType::Disconnected { .. } => {
+ return Err(SocketError::ConnectionFailed.into())
+ }
+ VsockEventType::Received { .. } => return Err(SocketError::InvalidOperation.into()),
+ VsockEventType::ConnectionRequest
+ | VsockEventType::CreditRequest
+ | VsockEventType::CreditUpdate => {}
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ device::socket::{
+ protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
+ vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
+ },
+ hal::fake::FakeHal,
+ transport::{
+ fake::{FakeTransport, QueueStatus, State},
+ DeviceStatus, DeviceType,
+ },
+ volatile::ReadOnly,
+ };
+ use alloc::{sync::Arc, vec};
+ use core::{mem::size_of, ptr::NonNull};
+ use std::{sync::Mutex, thread};
+ use zerocopy::{AsBytes, FromBytes};
+
+ #[test]
+ fn send_recv() {
+ let host_cid = 2;
+ let guest_cid = 66;
+ let host_port = 1234;
+ let guest_port = 4321;
+ let host_address = VsockAddr {
+ cid: host_cid,
+ port: host_port,
+ };
+ let hello_from_guest = "Hello from guest";
+ let hello_from_host = "Hello from host";
+
+ let mut config_space = VirtioVsockConfig {
+ guest_cid_low: ReadOnly::new(66),
+ guest_cid_high: ReadOnly::new(0),
+ };
+ let state = Arc::new(Mutex::new(State {
+ status: DeviceStatus::empty(),
+ driver_features: 0,
+ guest_page_size: 0,
+ interrupt_pending: false,
+ queues: vec![
+ QueueStatus::default(),
+ QueueStatus::default(),
+ QueueStatus::default(),
+ ],
+ }));
+ let transport = FakeTransport {
+ device_type: DeviceType::Socket,
+ max_queue_size: 32,
+ device_features: 0,
+ config_space: NonNull::from(&mut config_space),
+ state: state.clone(),
+ };
+ let mut socket = SingleConnectionManager::new(
+ VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
+ );
+
+ // Start a thread to simulate the device.
+ let handle = thread::spawn(move || {
+ // Wait for connection request.
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ assert_eq!(
+ VirtioVsockHdr::read_from(
+ state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+ .as_slice()
+ )
+ .unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Request.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: guest_port.into(),
+ dst_port: host_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 0.into(),
+ fwd_cnt: 0.into(),
+ }
+ );
+
+ // Accept connection and give the peer enough credit to send the message.
+ state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+ RX_QUEUE_IDX,
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Response.into(),
+ src_cid: host_cid.into(),
+ dst_cid: guest_cid.into(),
+ src_port: host_port.into(),
+ dst_port: guest_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 50.into(),
+ fwd_cnt: 0.into(),
+ }
+ .as_bytes(),
+ );
+
+ // Expect a credit update.
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ assert_eq!(
+ VirtioVsockHdr::read_from(
+ state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+ .as_slice()
+ )
+ .unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::CreditUpdate.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: guest_port.into(),
+ dst_port: host_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 0.into(),
+ fwd_cnt: 0.into(),
+ }
+ );
+
+ // Expect the guest to send some data.
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ let request = state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
+ assert_eq!(
+ request.len(),
+ size_of::<VirtioVsockHdr>() + hello_from_guest.len()
+ );
+ assert_eq!(
+ VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Rw.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: guest_port.into(),
+ dst_port: host_port.into(),
+ len: (hello_from_guest.len() as u32).into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 0.into(),
+ fwd_cnt: 0.into(),
+ }
+ );
+ assert_eq!(
+ &request[size_of::<VirtioVsockHdr>()..],
+ hello_from_guest.as_bytes()
+ );
+
+ // Send a response.
+ let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Rw.into(),
+ src_cid: host_cid.into(),
+ dst_cid: guest_cid.into(),
+ src_port: host_port.into(),
+ dst_port: guest_port.into(),
+ len: (hello_from_host.len() as u32).into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 50.into(),
+ fwd_cnt: (hello_from_guest.len() as u32).into(),
+ }
+ .write_to_prefix(response.as_mut_slice());
+ response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
+ state
+ .lock()
+ .unwrap()
+ .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
+
+ // Expect a credit update.
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ assert_eq!(
+ VirtioVsockHdr::read_from(
+ state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+ .as_slice()
+ )
+ .unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::CreditUpdate.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: guest_port.into(),
+ dst_port: host_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 64.into(),
+ fwd_cnt: 0.into(),
+ }
+ );
+
+ // Expect a shutdown.
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+ assert_eq!(
+ VirtioVsockHdr::read_from(
+ state
+ .lock()
+ .unwrap()
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+ .as_slice()
+ )
+ .unwrap(),
+ VirtioVsockHdr {
+ op: VirtioVsockOp::Shutdown.into(),
+ src_cid: guest_cid.into(),
+ dst_cid: host_cid.into(),
+ src_port: guest_port.into(),
+ dst_port: host_port.into(),
+ len: 0.into(),
+ socket_type: SocketType::Stream.into(),
+ flags: 0.into(),
+ buf_alloc: 0.into(),
+ fwd_cnt: (hello_from_host.len() as u32).into(),
+ }
+ );
+ });
+
+ socket.connect(host_address, guest_port).unwrap();
+ socket.wait_for_connect().unwrap();
+ socket.send(hello_from_guest.as_bytes()).unwrap();
+ let mut buffer = [0u8; 64];
+ let event = socket.wait_for_recv(&mut buffer).unwrap();
+ assert_eq!(
+ event,
+ VsockEvent {
+ source: VsockAddr {
+ cid: host_cid,
+ port: host_port,
+ },
+ destination: VsockAddr {
+ cid: guest_cid,
+ port: guest_port,
+ },
+ event_type: VsockEventType::Received {
+ length: hello_from_host.len()
+ },
+ buffer_status: VsockBufferStatus {
+ buffer_allocation: 50,
+ forward_count: hello_from_guest.len() as u32,
+ },
+ }
+ );
+ assert_eq!(
+ &buffer[0..hello_from_host.len()],
+ hello_from_host.as_bytes()
+ );
+ socket.shutdown().unwrap();
+
+ handle.join().unwrap();
+ }
+}
diff --git a/src/device/socket/vsock.rs b/src/device/socket/vsock.rs
index 686d7a6..523930e 100644
--- a/src/device/socket/vsock.rs
+++ b/src/device/socket/vsock.rs
@@ -2,29 +2,31 @@
#![deny(unsafe_op_in_unsafe_fn)]
use super::error::SocketError;
-use super::protocol::{VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr};
-use crate::device::common::Feature;
-use crate::hal::{BufferDirection, Dma, Hal};
+use super::protocol::{Feature, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr};
+use crate::hal::Hal;
use crate::queue::VirtQueue;
use crate::transport::Transport;
use crate::volatile::volread;
-use crate::Result;
-use core::hint::spin_loop;
+use crate::{Error, Result};
+use alloc::boxed::Box;
use core::mem::size_of;
-use core::ptr::NonNull;
-use log::{debug, info};
+use core::ptr::{null_mut, NonNull};
+use log::debug;
use zerocopy::{AsBytes, FromBytes};
-const RX_QUEUE_IDX: u16 = 0;
-const TX_QUEUE_IDX: u16 = 1;
+pub(crate) const RX_QUEUE_IDX: u16 = 0;
+pub(crate) const TX_QUEUE_IDX: u16 = 1;
const EVENT_QUEUE_IDX: u16 = 2;
-const QUEUE_SIZE: usize = 8;
+pub(crate) const QUEUE_SIZE: usize = 8;
+
+/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::<VirtioVsockHdr>().
+const RX_BUFFER_SIZE: usize = 512;
#[derive(Clone, Debug, Default, PartialEq, Eq)]
-struct ConnectionInfo {
- dst: VsockAddr,
- src_port: u32,
+pub struct ConnectionInfo {
+ pub dst: VsockAddr,
+ pub src_port: u32,
/// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
/// bytes it has allocated for packet bodies.
peer_buf_alloc: u32,
@@ -33,6 +35,9 @@ struct ConnectionInfo {
peer_fwd_cnt: u32,
/// The number of bytes of packet bodies which we have sent to the peer.
tx_cnt: u32,
+ /// The number of bytes of buffer space we have allocated to receive packet bodies from the
+ /// peer.
+ pub buf_alloc: u32,
/// The number of bytes of packet bodies which we have received from the peer and handled.
fwd_cnt: u32,
/// Whether we have recently requested credit from the peer.
@@ -43,6 +48,35 @@ struct ConnectionInfo {
}
impl ConnectionInfo {
+ pub fn new(destination: VsockAddr, src_port: u32) -> Self {
+ Self {
+ dst: destination,
+ src_port,
+ ..Default::default()
+ }
+ }
+
+ /// Updates this connection info with the peer buffer allocation and forwarded count from the
+ /// given event.
+ pub fn update_for_event(&mut self, event: &VsockEvent) {
+ self.peer_buf_alloc = event.buffer_status.buffer_allocation;
+ self.peer_fwd_cnt = event.buffer_status.forward_count;
+
+ if let VsockEventType::CreditUpdate = event.event_type {
+ self.has_pending_credit_request = false;
+ }
+ }
+
+ /// Increases the forwarded count recorded for this connection by the given number of bytes.
+ ///
+ /// This should be called once received data has been passed to the client, so there is buffer
+ /// space available for more.
+ pub fn done_forwarding(&mut self, length: usize) {
+ self.fwd_cnt += length as u32;
+ }
+
+ /// Returns the number of bytes of RX buffer space the peer has available to receive packet body
+ /// data from us.
fn peer_free(&self) -> u32 {
self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
}
@@ -53,6 +87,7 @@ impl ConnectionInfo {
dst_cid: self.dst.cid.into(),
src_port: self.src_port.into(),
dst_port: self.dst.port.into(),
+ buf_alloc: self.buf_alloc.into(),
fwd_cnt: self.fwd_cnt.into(),
..Default::default()
}
@@ -66,10 +101,77 @@ pub struct VsockEvent {
pub source: VsockAddr,
/// The destination of the event, i.e. the CID and port on our side.
pub destination: VsockAddr,
+ /// The peer's buffer status for the connection.
+ pub buffer_status: VsockBufferStatus,
/// The type of event.
pub event_type: VsockEventType,
}
+impl VsockEvent {
+ /// Returns whether the event matches the given connection.
+ pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
+ self.source == connection_info.dst
+ && self.destination.cid == guest_cid
+ && self.destination.port == connection_info.src_port
+ }
+
+ fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
+ let op = header.op()?;
+ let buffer_status = VsockBufferStatus {
+ buffer_allocation: header.buf_alloc.into(),
+ forward_count: header.fwd_cnt.into(),
+ };
+ let source = header.source();
+ let destination = header.destination();
+
+ let event_type = match op {
+ VirtioVsockOp::Request => {
+ header.check_data_is_empty()?;
+ VsockEventType::ConnectionRequest
+ }
+ VirtioVsockOp::Response => {
+ header.check_data_is_empty()?;
+ VsockEventType::Connected
+ }
+ VirtioVsockOp::CreditUpdate => {
+ header.check_data_is_empty()?;
+ VsockEventType::CreditUpdate
+ }
+ VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
+ header.check_data_is_empty()?;
+ debug!("Disconnected from the peer");
+ let reason = if op == VirtioVsockOp::Rst {
+ DisconnectReason::Reset
+ } else {
+ DisconnectReason::Shutdown
+ };
+ VsockEventType::Disconnected { reason }
+ }
+ VirtioVsockOp::Rw => VsockEventType::Received {
+ length: header.len() as usize,
+ },
+ VirtioVsockOp::CreditRequest => {
+ header.check_data_is_empty()?;
+ VsockEventType::CreditRequest
+ }
+ VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
+ };
+
+ Ok(VsockEvent {
+ source,
+ destination,
+ buffer_status,
+ event_type,
+ })
+ }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct VsockBufferStatus {
+ pub buffer_allocation: u32,
+ pub forward_count: u32,
+}
+
/// The reason why a vsock connection was closed.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DisconnectReason {
@@ -83,6 +185,8 @@ pub enum DisconnectReason {
/// Details of the type of an event received from a VirtIO socket.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VsockEventType {
+ /// The peer requests to establish a connection with us.
+ ConnectionRequest,
/// The connection was successfully established.
Connected,
/// The connection was closed.
@@ -95,9 +199,16 @@ pub enum VsockEventType {
/// The length of the data in bytes.
length: usize,
},
+ /// The peer requests us to send a credit update.
+ CreditRequest,
+ /// The peer just sent us a credit update with nothing else.
+ CreditUpdate,
}
-/// Driver for a VirtIO socket device.
+/// Low-level driver for a VirtIO socket device.
+///
+/// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than
+/// using this directly.
pub struct VirtIOSocket<H: Hal, T: Transport> {
transport: T,
/// Virtqueue to receive packets.
@@ -108,10 +219,7 @@ pub struct VirtIOSocket<H: Hal, T: Transport> {
/// The guest_cid field contains the guest’s context ID, which uniquely identifies
/// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
guest_cid: u64,
- rx_buf_dma: Dma<H>,
-
- /// Currently the device is only allowed to be connected to one destination at a time.
- connection_info: Option<ConnectionInfo>,
+ rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
}
impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
@@ -121,6 +229,12 @@ impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
self.transport.queue_unset(RX_QUEUE_IDX);
self.transport.queue_unset(TX_QUEUE_IDX);
self.transport.queue_unset(EVENT_QUEUE_IDX);
+
+ for buffer in self.rx_queue_buffers {
+ // Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be
+ // used anywhere else after the driver is destroyed.
+ unsafe { drop(Box::from_raw(buffer.as_ptr())) };
+ }
}
}
@@ -129,35 +243,40 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
pub fn new(mut transport: T) -> Result<Self> {
transport.begin_init(|features| {
let features = Feature::from_bits_truncate(features);
- info!("Device features: {:?}", features);
+ debug!("Device features: {:?}", features);
// negotiate these flags only
let supported_features = Feature::empty();
(features & supported_features).bits()
});
let config = transport.config_space::<VirtioVsockConfig>()?;
- info!("config: {:?}", config);
+ debug!("config: {:?}", config);
// Safe because config is a valid pointer to the device configuration space.
let guest_cid = unsafe {
volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
};
- info!("guest cid: {guest_cid:?}");
+ debug!("guest cid: {guest_cid:?}");
let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
- // Allocates 4 KiB memory as the rx buffer.
- let rx_buf_dma = Dma::new(
- 1, // pages
- BufferDirection::DeviceToDriver,
- )?;
- let rx_buf = rx_buf_dma.raw_slice();
- // Safe because `rx_buf` lives as long as the `rx` queue.
- unsafe {
- Self::fill_rx_queue(&mut rx, rx_buf, &mut transport)?;
+ // Allocate and add buffers for the RX queue.
+ let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
+ for (i, rx_queue_buffer) in rx_queue_buffers.iter_mut().enumerate() {
+ let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromBytes::new_box_zeroed();
+ // Safe because the buffer lives as long as the queue, as specified in the function
+ // safety requirement, and we don't access it until it is popped.
+ let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
+ assert_eq!(i, token.into());
+ *rx_queue_buffer = Box::into_raw(buffer);
}
+ let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
+
transport.finish_init();
+ if rx.should_notify() {
+ transport.notify(RX_QUEUE_IDX);
+ }
Ok(Self {
transport,
@@ -165,85 +284,41 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
tx,
event,
guest_cid,
- rx_buf_dma,
- connection_info: None,
+ rx_queue_buffers,
})
}
- /// Fills the `rx` queue with the buffer `rx_buf`.
- ///
- /// # Safety
- ///
- /// `rx_buf` must live at least as long as the `rx` queue, and the parts of the buffer which are
- /// in the queue must not be used anywhere else at the same time.
- unsafe fn fill_rx_queue(
- rx: &mut VirtQueue<H, { QUEUE_SIZE }>,
- rx_buf: NonNull<[u8]>,
- transport: &mut T,
- ) -> Result {
- if rx_buf.len() < size_of::<VirtioVsockHdr>() * QUEUE_SIZE {
- return Err(SocketError::BufferTooShort.into());
- }
- for i in 0..QUEUE_SIZE {
- // Safe because the buffer lives as long as the queue, as specified in the function
- // safety requirement, and we don't access it until it is popped.
- unsafe {
- let buffer = Self::as_mut_sub_rx_buffer(rx_buf, i)?;
- let token = rx.add(&[], &mut [buffer])?;
- assert_eq!(i, token.into());
- }
- }
-
- if rx.should_notify() {
- transport.notify(RX_QUEUE_IDX);
- }
- Ok(())
+ /// Returns the CID which has been assigned to this guest.
+ pub fn guest_cid(&self) -> u64 {
+ self.guest_cid
}
/// Sends a request to connect to the given destination.
///
- /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+ /// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Connected` event indicating that the peer has accepted the connection
/// before sending data.
- pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
- if self.connection_info.is_some() {
- return Err(SocketError::ConnectionExists.into());
- }
- let new_connection_info = ConnectionInfo {
- dst: VsockAddr {
- cid: dst_cid,
- port: dst_port,
- },
- src_port,
- ..Default::default()
- };
+ pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Request.into(),
- ..new_connection_info.new_header(self.guest_cid)
+ ..connection_info.new_header(self.guest_cid)
};
- // Sends a header only packet to the tx queue to connect the device to the listening
- // socket at the given destination.
- self.send_packet_to_tx_queue(&header, &[])?;
-
- self.connection_info = Some(new_connection_info);
- debug!("Connection requested: {:?}", self.connection_info);
- Ok(())
+ // Sends a header only packet to the TX queue to connect the device to the listening socket
+ // at the given destination.
+ self.send_packet_to_tx_queue(&header, &[])
}
- /// Blocks until the peer either accepts our connection request (with a
- /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a
- /// `VIRTIO_VSOCK_OP_RST`).
- pub fn wait_for_connect(&mut self) -> Result {
- match self.wait_for_recv(&mut [])?.event_type {
- VsockEventType::Connected => Ok(()),
- VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
- VsockEventType::Received { .. } => Err(SocketError::InvalidOperation.into()),
- }
+ /// Accepts the given connection from a peer.
+ pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
+ let header = VirtioVsockHdr {
+ op: VirtioVsockOp::Response.into(),
+ ..connection_info.new_header(self.guest_cid)
+ };
+ self.send_packet_to_tx_queue(&header, &[])
}
- /// Requests the peer to send us a credit update for the current connection.
- fn request_credit(&mut self) -> Result {
- let connection_info = self.connection_info()?;
+ /// Requests the peer to send us a credit update for the given connection.
+ fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditRequest.into(),
..connection_info.new_header(self.guest_cid)
@@ -252,21 +327,16 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
}
/// Sends the buffer to the destination.
- pub fn send(&mut self, buffer: &[u8]) -> Result {
- let mut connection_info = self.connection_info()?;
-
- let result = self.check_peer_buffer_is_sufficient(&mut connection_info, buffer.len());
- self.connection_info = Some(connection_info.clone());
- result?;
+ pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
+ self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
let len = buffer.len() as u32;
let header = VirtioVsockHdr {
op: VirtioVsockOp::Rw.into(),
len: len.into(),
- buf_alloc: 0.into(),
..connection_info.new_header(self.guest_cid)
};
- self.connection_info.as_mut().unwrap().tx_cnt += len;
+ connection_info.tx_cnt += len;
self.send_packet_to_tx_queue(&header, buffer)
}
@@ -281,59 +351,48 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
// Request an update of the cached peer credit, if we haven't already done so, and tell
// the caller to try again later.
if !connection_info.has_pending_credit_request {
- self.request_credit()?;
+ self.request_credit(connection_info)?;
connection_info.has_pending_credit_request = true;
}
Err(SocketError::InsufficientBufferSpaceInPeer.into())
}
}
- /// Polls the vsock device to receive data or other updates.
- ///
- /// A buffer must be provided to put the data in if there is some to
- /// receive.
- pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
- let connection_info = self.connection_info()?;
-
- // Tell the peer that we have space to receive some data.
+ /// Tells the peer how much buffer space we have to receive data.
+ pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::CreditUpdate.into(),
- buf_alloc: (buffer.len() as u32).into(),
..connection_info.new_header(self.guest_cid)
};
- self.send_packet_to_tx_queue(&header, &[])?;
-
- // Handle entries from the RX virtqueue until we find one that generates an event.
- let event = self.poll_rx_queue(buffer)?;
+ self.send_packet_to_tx_queue(&header, &[])
+ }
- if self.rx.should_notify() {
- self.transport.notify(RX_QUEUE_IDX);
- }
+ /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
+ /// it.
+ pub fn poll(
+ &mut self,
+ handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
+ ) -> Result<Option<VsockEvent>> {
+ let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else {
+ return Ok(None);
+ };
- Ok(event)
- }
+ let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));
- /// Blocks until we get some event from the vsock device.
- ///
- /// A buffer must be provided to put the data in if there is some to
- /// receive.
- pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
- loop {
- if let Some(event) = self.poll_recv(buffer)? {
- return Ok(event);
- } else {
- spin_loop();
- }
+ unsafe {
+ // TODO: What about if both handler and this give errors?
+ self.add_buffer_to_rx_queue(token)?;
}
+
+ result
}
- /// Request to shut down the connection cleanly.
+ /// Requests to shut down the connection cleanly.
///
- /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+ /// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
/// shutdown.
- pub fn shutdown(&mut self) -> Result {
- let connection_info = self.connection_info()?;
+ pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Shutdown.into(),
..connection_info.new_header(self.guest_cid)
@@ -342,14 +401,12 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
}
/// Forcibly closes the connection without waiting for the peer.
- pub fn force_close(&mut self) -> Result {
- let connection_info = self.connection_info()?;
+ pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Rst.into(),
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])?;
- self.connection_info = None;
Ok(())
}
@@ -362,118 +419,39 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
Ok(())
}
- /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet
- /// which generates a `VsockEvent`.
+ /// Adds the buffer at the given index in `rx_queue_buffers` back to the RX queue.
///
- /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
- /// don't result in any events to return.
- fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
- loop {
- let mut connection_info = self.connection_info.clone().unwrap_or_default();
- let Some(header) = self.pop_packet_from_rx_queue(body)? else{
- return Ok(None);
- };
-
- let op = header.op()?;
-
- // Skip packets which don't match our current connection.
- if header.source() != connection_info.dst
- || header.dst_cid.get() != self.guest_cid
- || header.dst_port.get() != connection_info.src_port
- {
- debug!(
- "Skipping {:?} as connection is {:?}",
- header, connection_info
- );
- continue;
- }
-
- connection_info.peer_buf_alloc = header.buf_alloc.into();
- connection_info.peer_fwd_cnt = header.fwd_cnt.into();
- if self.connection_info.is_some() {
- self.connection_info = Some(connection_info.clone());
- debug!("Connection info updated: {:?}", self.connection_info);
- }
+ /// # Safety
+ ///
+ /// The buffer must not currently be in the RX queue, and no other references to it must exist
+ /// between when this method is called and when it is popped from the queue.
+ unsafe fn add_buffer_to_rx_queue(&mut self, index: u16) -> Result {
+ // Safe because the buffer lives as long as the queue, and the caller guarantees that it's
+ // not currently in the queue or referred to anywhere else until it is popped.
+ unsafe {
+ let buffer = self
+ .rx_queue_buffers
+ .get_mut(usize::from(index))
+ .ok_or(Error::WrongToken)?
+ .as_mut();
+ let new_token = self.rx.add(&[], &mut [buffer])?;
+ // If the RX buffer somehow gets assigned a different token, then our safety assumptions
+ // are broken and we can't safely continue to do anything with the device.
+ assert_eq!(new_token, index);
+ }
- match op {
- VirtioVsockOp::Request => {
- header.check_data_is_empty()?;
- // TODO: Send a Rst, or support listening.
- }
- VirtioVsockOp::Response => {
- header.check_data_is_empty()?;
- return Ok(Some(VsockEvent {
- source: connection_info.dst,
- destination: VsockAddr {
- cid: self.guest_cid,
- port: connection_info.src_port,
- },
- event_type: VsockEventType::Connected,
- }));
- }
- VirtioVsockOp::CreditUpdate => {
- header.check_data_is_empty()?;
- connection_info.has_pending_credit_request = false;
- if self.connection_info.is_some() {
- self.connection_info = Some(connection_info.clone());
- }
-
- // Virtio v1.1 5.10.6.3
- // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
- // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
- // any time a change in buffer space occurs.
- continue;
- }
- VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
- header.check_data_is_empty()?;
-
- self.connection_info = None;
- info!("Disconnected from the peer");
-
- let reason = if op == VirtioVsockOp::Rst {
- DisconnectReason::Reset
- } else {
- DisconnectReason::Shutdown
- };
- return Ok(Some(VsockEvent {
- source: connection_info.dst,
- destination: VsockAddr {
- cid: self.guest_cid,
- port: connection_info.src_port,
- },
- event_type: VsockEventType::Disconnected { reason },
- }));
- }
- VirtioVsockOp::Rw => {
- self.connection_info.as_mut().unwrap().fwd_cnt += header.len();
- return Ok(Some(VsockEvent {
- source: connection_info.dst,
- destination: VsockAddr {
- cid: self.guest_cid,
- port: connection_info.src_port,
- },
- event_type: VsockEventType::Received {
- length: header.len() as usize,
- },
- }));
- }
- VirtioVsockOp::CreditRequest => {
- header.check_data_is_empty()?;
- // TODO: Send a credit update.
- }
- VirtioVsockOp::Invalid => {
- return Err(SocketError::InvalidOperation.into());
- }
- }
+ if self.rx.should_notify() {
+ self.transport.notify(RX_QUEUE_IDX);
}
+
+ Ok(())
}
- /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
- /// the body into the given buffer.
+ /// Pops one packet from the RX queue, if there is one pending. Returns the header, and a
+ /// reference to the buffer containing the body.
///
- /// Returns `None` if there is no pending packet, or an error if the body is bigger than the
- /// buffer supplied.
- fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
+ /// Returns `None` if there is no pending packet.
+ fn pop_packet_from_rx_queue(&mut self) -> Result<Option<(VirtioVsockHdr, &[u8], u16)>> {
let Some(token) = self.rx.peek_used() else {
return Ok(None);
};
@@ -481,89 +459,53 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
// Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
// buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
// buffer back to the RX queue then we don't access it again until next time it is popped.
- let header = unsafe {
- let buffer = Self::as_mut_sub_rx_buffer(self.rx_buf_dma.raw_slice(), token.into())?;
+ let (header, body) = unsafe {
+ let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
// Read the header and body from the buffer. Don't check the result yet, because we need
// to add the buffer back to the queue either way.
- let header_result = read_header_and_body(buffer, body);
-
- // Add the buffer back to the RX queue.
- let new_token = self.rx.add(&[], &mut [buffer])?;
- // If the RX buffer somehow gets assigned a different token, then our safety assumptions
- // are broken and we can't safely continue to do anything with the device.
- assert_eq!(new_token, token);
+ let header_result = read_header_and_body(buffer);
+ if header_result.is_err() {
+ // If there was an error, add the buffer back immediately. Ignore any errors, as we
+ // need to return the first error.
+ let _ = self.add_buffer_to_rx_queue(token);
+ }
header_result
}?;
debug!("Received packet {:?}. Op {:?}", header, header.op());
- Ok(Some(header))
- }
-
- fn connection_info(&self) -> Result<ConnectionInfo> {
- self.connection_info
- .clone()
- .ok_or(SocketError::NotConnected.into())
- }
-
- /// Gets a reference to a subslice of the RX buffer to be used for the given entry in the RX
- /// virtqueue.
- ///
- /// # Safety
- ///
- /// `rx_buf` must be a valid dereferenceable pointer.
- /// The returned reference has an arbitrary lifetime `'a`. This lifetime must not overlap with
- /// any other references to the same subslice of the RX buffer or outlive the buffer.
- unsafe fn as_mut_sub_rx_buffer<'a>(
- mut rx_buf: NonNull<[u8]>,
- i: usize,
- ) -> Result<&'a mut [u8]> {
- let buffer_size = rx_buf.len() / QUEUE_SIZE;
- let start = buffer_size
- .checked_mul(i)
- .ok_or(SocketError::InvalidNumber)?;
- let end = start
- .checked_add(buffer_size)
- .ok_or(SocketError::InvalidNumber)?;
- // Safe because no alignment or initialisation is required for [u8], and our caller assures
- // us that `rx_buf` is dereferenceable and that the lifetime of the slice we are creating
- // won't overlap with any other references to the same slice or outlive it.
- unsafe {
- rx_buf
- .as_mut()
- .get_mut(start..end)
- .ok_or(SocketError::BufferTooShort.into())
- }
+ Ok(Some((header, body, token)))
}
}
-fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
- let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
+fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
+ // Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`.
+ let header = VirtioVsockHdr::read_from_prefix(buffer).unwrap();
let body_length = header.len() as usize;
+
+ // This could fail if the device returns an unreasonably long body length.
let data_end = size_of::<VirtioVsockHdr>()
.checked_add(body_length)
.ok_or(SocketError::InvalidNumber)?;
+ // This could fail if the device returns a body length longer than the buffer we gave it.
let data = buffer
.get(size_of::<VirtioVsockHdr>()..data_end)
.ok_or(SocketError::BufferTooShort)?;
- body.get_mut(0..body_length)
- .ok_or(SocketError::OutputBufferTooShort(body_length))?
- .copy_from_slice(data);
- Ok(header)
+ Ok((header, data))
}
#[cfg(test)]
mod tests {
use super::*;
- use crate::volatile::ReadOnly;
use crate::{
hal::fake::FakeHal,
transport::{
fake::{FakeTransport, QueueStatus, State},
DeviceStatus, DeviceType,
},
+ volatile::ReadOnly,
};
use alloc::{sync::Arc, vec};
use core::ptr::NonNull;
@@ -580,7 +522,11 @@ mod tests {
driver_features: 0,
guest_page_size: 0,
interrupt_pending: false,
- queues: vec![QueueStatus::default(); 3],
+ queues: vec![
+ QueueStatus::default(),
+ QueueStatus::default(),
+ QueueStatus::default(),
+ ],
}));
let transport = FakeTransport {
device_type: DeviceType::Socket,
@@ -591,6 +537,6 @@ mod tests {
};
let socket =
VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
- assert_eq!(socket.guest_cid, 0x00_0000_0042);
+ assert_eq!(socket.guest_cid(), 0x00_0000_0042);
}
}
diff --git a/src/hal.rs b/src/hal.rs
index 6295f5f..a9fad1c 100644
--- a/src/hal.rs
+++ b/src/hal.rs
@@ -19,6 +19,8 @@ pub struct Dma<H: Hal> {
impl<H: Hal> Dma<H> {
/// Allocates the given number of pages of physically contiguous memory to be used for DMA in
/// the given direction.
+ ///
+ /// The pages will be zeroed.
pub fn new(pages: usize, direction: BufferDirection) -> Result<Self> {
let (paddr, vaddr) = H::dma_alloc(pages, direction);
if paddr == 0 {
@@ -67,7 +69,8 @@ impl<H: Hal> Drop for Dma<H> {
/// Implementations of this trait must follow the "implementation safety" requirements documented
/// for each method. Callers must follow the safety requirements documented for the unsafe methods.
pub unsafe trait Hal {
- /// Allocates the given number of contiguous physical pages of DMA memory for VirtIO use.
+ /// Allocates and zeroes the given number of contiguous physical pages of DMA memory for VirtIO
+ /// use.
///
/// Returns both the physical address which the device can use to access the memory, and a
/// pointer to the start of it which the driver can use to access it.
@@ -77,7 +80,7 @@ pub unsafe trait Hal {
/// Implementations of this method must ensure that the `NonNull<u8>` returned is a
/// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, aligned to
/// [`PAGE_SIZE`], and won't alias any other allocations or references in the program until it
- /// is deallocated by `dma_dealloc`.
+ /// is deallocated by `dma_dealloc`. The pages must be zeroed.
fn dma_alloc(pages: usize, direction: BufferDirection) -> (PhysAddr, NonNull<u8>);
/// Deallocates the given contiguous physical DMA memory pages.
diff --git a/src/lib.rs b/src/lib.rs
index 754dd51..f2f2f12 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -24,12 +24,14 @@
//!
//! ```
//! # use virtio_drivers::Hal;
+//! # #[cfg(feature = "alloc")]
//! use virtio_drivers::{
//! device::console::VirtIOConsole,
//! transport::{mmio::MmioTransport, DeviceType, Transport},
//! };
//!
+//! # #[cfg(feature = "alloc")]
//! # fn example<HalImpl: Hal>(transport: MmioTransport) {
//! if transport.device_type() == DeviceType::Console {
//! let mut console = VirtIOConsole::<HalImpl, _>::new(transport).unwrap();
diff --git a/src/queue.rs b/src/queue.rs
index d6baf17..758c139 100644
--- a/src/queue.rs
+++ b/src/queue.rs
@@ -60,7 +60,7 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
}
if !SIZE.is_power_of_two()
|| SIZE > u16::MAX.into()
- || transport.max_queue_size() < SIZE as u32
+ || transport.max_queue_size(idx) < SIZE as u32
{
return Err(Error::InvalidParam);
}
@@ -555,10 +555,13 @@ impl Descriptor {
}
}
+/// Descriptor flags
+#[derive(Copy, Clone, Debug, Default, Eq, FromBytes, PartialEq)]
+#[repr(transparent)]
+struct DescFlags(u16);
+
bitflags! {
- /// Descriptor flags
- #[derive(FromBytes)]
- struct DescFlags: u16 {
+ impl DescFlags: u16 {
const NEXT = 1;
const WRITE = 2;
const INDIRECT = 4;
diff --git a/src/transport/fake.rs b/src/transport/fake.rs
index a578db2..6ab61fc 100644
--- a/src/transport/fake.rs
+++ b/src/transport/fake.rs
@@ -4,8 +4,13 @@ use crate::{
PhysAddr, Result,
};
use alloc::{sync::Arc, vec::Vec};
-use core::{any::TypeId, ptr::NonNull};
-use std::sync::Mutex;
+use core::{
+ any::TypeId,
+ ptr::NonNull,
+ sync::atomic::{AtomicBool, Ordering},
+ time::Duration,
+};
+use std::{sync::Mutex, thread};
/// A fake implementation of [`Transport`] for unit tests.
#[derive(Debug)]
@@ -30,12 +35,14 @@ impl<C> Transport for FakeTransport<C> {
self.state.lock().unwrap().driver_features = driver_features;
}
- fn max_queue_size(&self) -> u32 {
+ fn max_queue_size(&mut self, _queue: u16) -> u32 {
self.max_queue_size
}
fn notify(&mut self, queue: u16) {
- self.state.lock().unwrap().queues[queue as usize].notified = true;
+ self.state.lock().unwrap().queues[queue as usize]
+ .notified
+ .store(true, Ordering::SeqCst);
}
fn get_status(&self) -> DeviceStatus {
@@ -168,13 +175,23 @@ impl State {
handler,
)
}
+
+ /// Waits until the given queue is notified.
+ pub fn wait_until_queue_notified(state: &Mutex<Self>, queue_index: u16) {
+ while !state.lock().unwrap().queues[usize::from(queue_index)]
+ .notified
+ .swap(false, Ordering::SeqCst)
+ {
+ thread::sleep(Duration::from_millis(10));
+ }
+ }
}
-#[derive(Clone, Debug, Default, Eq, PartialEq)]
+#[derive(Debug, Default)]
pub struct QueueStatus {
pub size: u32,
pub descriptors: PhysAddr,
pub driver_area: PhysAddr,
pub device_area: PhysAddr,
- pub notified: bool,
+ pub notified: AtomicBool,
}
diff --git a/src/transport/mmio.rs b/src/transport/mmio.rs
index 026646b..d938a97 100644
--- a/src/transport/mmio.rs
+++ b/src/transport/mmio.rs
@@ -338,9 +338,12 @@ impl Transport for MmioTransport {
}
}
- fn max_queue_size(&self) -> u32 {
+ fn max_queue_size(&mut self, queue: u16) -> u32 {
// Safe because self.header points to a valid VirtIO MMIO region.
- unsafe { volread!(self.header, queue_num_max) }
+ unsafe {
+ volwrite!(self.header, queue_sel, queue.into());
+ volread!(self.header, queue_num_max)
+ }
}
fn notify(&mut self, queue: u16) {
diff --git a/src/transport/mod.rs b/src/transport/mod.rs
index f88293c..3157e81 100644
--- a/src/transport/mod.rs
+++ b/src/transport/mod.rs
@@ -20,8 +20,8 @@ pub trait Transport {
/// Writes device features.
fn write_driver_features(&mut self, driver_features: u64);
- /// Gets the max size of queue.
- fn max_queue_size(&self) -> u32;
+ /// Gets the max size of the given queue.
+ fn max_queue_size(&mut self, queue: u16) -> u32;
/// Notifies the given queue on the device.
fn notify(&mut self, queue: u16);
@@ -65,6 +65,7 @@ pub trait Transport {
///
/// Ref: virtio 3.1.1 Device Initialization
fn begin_init(&mut self, negotiate_features: impl FnOnce(u64) -> u64) {
+ self.set_status(DeviceStatus::empty());
self.set_status(DeviceStatus::ACKNOWLEDGE | DeviceStatus::DRIVER);
let features = self.read_device_features();
@@ -91,8 +92,8 @@ pub trait Transport {
}
bitflags! {
- /// The device status field.
- #[derive(Default)]
+ /// The device status field. Writing 0 into this field resets the device.
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct DeviceStatus: u32 {
/// Indicates that the guest OS has found the device and recognized it
/// as a valid virtio device.
diff --git a/src/transport/pci.rs b/src/transport/pci.rs
index b8bcb15..27401fe 100644
--- a/src/transport/pci.rs
+++ b/src/transport/pci.rs
@@ -231,10 +231,13 @@ impl Transport for PciTransport {
}
}
- fn max_queue_size(&self) -> u32 {
+ fn max_queue_size(&mut self, queue: u16) -> u32 {
// Safe because the common config pointer is valid and we checked in get_bar_region that it
// was aligned.
- unsafe { volread!(self.common_cfg, queue_size) }.into()
+ unsafe {
+ volwrite!(self.common_cfg, queue_select, queue);
+ volread!(self.common_cfg, queue_size).into()
+ }
}
fn notify(&mut self, queue: u16) {
diff --git a/src/transport/pci/bus.rs b/src/transport/pci/bus.rs
index dd6f520..0a3014b 100644
--- a/src/transport/pci/bus.rs
+++ b/src/transport/pci/bus.rs
@@ -24,6 +24,7 @@ pub const PCI_CAP_ID_VNDR: u8 = 0x09;
bitflags! {
/// The status register in PCI configuration space.
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Status: u16 {
// Bits 0-2 are reserved.
/// The state of the device's INTx# signal.
@@ -53,6 +54,7 @@ bitflags! {
bitflags! {
/// The command register in PCI configuration space.
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct Command: u16 {
/// The device can respond to I/O Space accesses.
const IO_SPACE = 1 << 0;