diff options
author | Alice Wang <aliceywang@google.com> | 2023-07-27 17:14:37 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2023-07-27 17:14:37 +0000 |
commit | 5ea26c7506ebcbbc12116a863ff111bf660e449e (patch) | |
tree | b1400747c25b795a087317ea0379eba1490fbaaf | |
parent | e9177f61d8ba8e36e12c16682ec4b4bc28cb712c (diff) | |
parent | 564dd7fd18ddd978de2593b8f81f5d186acf6e89 (diff) | |
download | virtio-drivers-5ea26c7506ebcbbc12116a863ff111bf660e449e.tar.gz |
Upgrade virtio-drivers to 0.6.0 am: 508e25a11b am: 358b0b4780 am: 564dd7fd18
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/virtio-drivers/+/2679235
Change-Id: Ic043dc86f9462c0a7513f90c22186e425e25d02d
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r-- | .cargo_vcs_info.json | 2 | ||||
-rw-r--r-- | Android.bp | 6 | ||||
-rw-r--r-- | Cargo.toml | 6 | ||||
-rw-r--r-- | Cargo.toml.orig | 5 | ||||
-rw-r--r-- | METADATA | 10 | ||||
-rw-r--r-- | README.md | 4 | ||||
-rw-r--r-- | src/device/blk.rs | 390 | ||||
-rw-r--r-- | src/device/console.rs | 35 | ||||
-rw-r--r-- | src/device/gpu.rs | 22 | ||||
-rw-r--r-- | src/device/input.rs | 25 | ||||
-rw-r--r-- | src/device/net.rs | 46 | ||||
-rw-r--r-- | src/device/socket/mod.rs | 2 | ||||
-rw-r--r-- | src/device/socket/multiconnectionmanager.rs | 12 | ||||
-rw-r--r-- | src/device/socket/protocol.rs | 3 | ||||
-rw-r--r-- | src/device/socket/singleconnectionmanager.rs | 7 | ||||
-rw-r--r-- | src/device/socket/vsock.rs | 52 | ||||
-rw-r--r-- | src/hal.rs | 14 | ||||
-rw-r--r-- | src/hal/fake.rs | 47 | ||||
-rw-r--r-- | src/queue.rs | 510 | ||||
-rw-r--r-- | src/transport/mod.rs | 21 |
20 files changed, 934 insertions, 285 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json index ec9996c..b5546f9 100644 --- a/.cargo_vcs_info.json +++ b/.cargo_vcs_info.json @@ -1,6 +1,6 @@ { "git": { - "sha1": "de1c3b130e507702f13d142b9bee55670a4a2858" + "sha1": "94e013fd4cfb8ff5061e9fa9a4f43521b49679bb" }, "path_in_vcs": "" }
\ No newline at end of file @@ -22,9 +22,10 @@ license { rust_library_rlib { name: "libvirtio_drivers", + // has rustc warnings crate_name: "virtio_drivers", cargo_env_compat: true, - cargo_pkg_version: "0.5.0", + cargo_pkg_version: "0.6.0", srcs: ["src/lib.rs"], edition: "2018", features: ["alloc"], @@ -47,9 +48,10 @@ rust_library_rlib { rust_test { name: "virtio-drivers_test_src_lib", + // has rustc warnings crate_name: "virtio_drivers", cargo_env_compat: true, - cargo_pkg_version: "0.5.0", + cargo_pkg_version: "0.6.0", srcs: ["src/lib.rs"], test_suites: ["general-tests"], auto_gen_config: true, @@ -12,7 +12,7 @@ [package] edition = "2018" name = "virtio-drivers" -version = "0.5.0" +version = "0.6.0" authors = [ "Jiajie Chen <noc@jiegec.ac.cn>", "Runji Wang <wangrunji0408@163.com>", @@ -38,6 +38,10 @@ version = "0.4" [dependencies.zerocopy] version = "0.6.1" +[dev-dependencies.zerocopy] +version = "0.6.1" +features = ["alloc"] + [features] alloc = ["zerocopy/alloc"] default = ["alloc"] diff --git a/Cargo.toml.orig b/Cargo.toml.orig index 0a47ce9..8709240 100644 --- a/Cargo.toml.orig +++ b/Cargo.toml.orig @@ -1,6 +1,6 @@ [package] name = "virtio-drivers" -version = "0.5.0" +version = "0.6.0" license = "MIT" authors = [ "Jiajie Chen <noc@jiegec.ac.cn>", @@ -22,3 +22,6 @@ zerocopy = "0.6.1" [features] default = ["alloc"] alloc = ["zerocopy/alloc"] + +[dev-dependencies] +zerocopy = { version = "0.6.1", features = ["alloc"] } @@ -1,6 +1,6 @@ # This project was upgraded with external_updater. # Usage: tools/external_updater/updater.sh update rust/crates/virtio-drivers -# For more info, check https://cs.android.com/android/platform/superproject/+/master:tools/external_updater/README.md +# For more info, check https://cs.android.com/android/platform/superproject/+/main:tools/external_updater/README.md name: "virtio-drivers" description: "VirtIO guest drivers." @@ -11,13 +11,13 @@ third_party { } url { type: ARCHIVE - value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.5.0.crate" + value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.6.0.crate" } - version: "0.5.0" + version: "0.6.0" license_type: NOTICE last_upgrade_date { year: 2023 - month: 6 - day: 13 + month: 7 + day: 27 } } @@ -32,8 +32,8 @@ VirtIO guest drivers in Rust. For **no_std** environment. | Feature flag | Supported | | | ---------------------------- | --------- | --------------------------------------- | -| `VIRTIO_F_INDIRECT_DESC` | ❌ | Indirect descriptors | -| `VIRTIO_F_EVENT_IDX` | ❌ | `avail_event` and `used_event` fields | +| `VIRTIO_F_INDIRECT_DESC` | ✅ | Indirect descriptors | +| `VIRTIO_F_EVENT_IDX` | ✅ | `avail_event` and `used_event` fields | | `VIRTIO_F_VERSION_1` | TODO | VirtIO version 1 compliance | | `VIRTIO_F_ACCESS_PLATFORM` | ❌ | Limited device access to memory | | `VIRTIO_F_RING_PACKED` | ❌ | Packed virtqueue layout | diff --git a/src/device/blk.rs b/src/device/blk.rs index ea3aef0..a9ddfb2 100644 --- a/src/device/blk.rs +++ b/src/device/blk.rs @@ -11,6 +11,10 @@ use zerocopy::{AsBytes, FromBytes}; const QUEUE: u16 = 0; const QUEUE_SIZE: u16 = 16; +const SUPPORTED_FEATURES: BlkFeature = BlkFeature::RO + .union(BlkFeature::FLUSH) + .union(BlkFeature::RING_INDIRECT_DESC) + .union(BlkFeature::RING_EVENT_IDX); /// Driver for a VirtIO block device. /// @@ -33,8 +37,8 @@ const QUEUE_SIZE: u16 = 16; /// /// // Read sector 0 and then copy it to sector 1. /// let mut buf = [0; SECTOR_SIZE]; -/// disk.read_block(0, &mut buf)?; -/// disk.write_block(1, &buf)?; +/// disk.read_blocks(0, &mut buf)?; +/// disk.write_blocks(1, &buf)?; /// # Ok(()) /// # } /// ``` @@ -42,24 +46,15 @@ pub struct VirtIOBlk<H: Hal, T: Transport> { transport: T, queue: VirtQueue<H, { QUEUE_SIZE as usize }>, capacity: u64, - readonly: bool, + negotiated_features: BlkFeature, } impl<H: Hal, T: Transport> VirtIOBlk<H, T> { /// Create a new VirtIO-Blk driver. pub fn new(mut transport: T) -> Result<Self> { - let mut readonly = false; - - transport.begin_init(|features| { - let features = BlkFeature::from_bits_truncate(features); - info!("device features: {:?}", features); - readonly = features.contains(BlkFeature::RO); - // negotiate these flags only - let supported_features = BlkFeature::empty(); - (features & supported_features).bits() - }); + let negotiated_features = transport.begin_init(SUPPORTED_FEATURES); - // read configuration space + // Read configuration space. let config = transport.config_space::<BlkConfig>()?; info!("config: {:?}", config); // Safe because config is a valid pointer to the device configuration space. @@ -68,14 +63,19 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { }; info!("found a block device of size {}KB", capacity / 2); - let queue = VirtQueue::new(&mut transport, QUEUE)?; + let queue = VirtQueue::new( + &mut transport, + QUEUE, + negotiated_features.contains(BlkFeature::RING_INDIRECT_DESC), + negotiated_features.contains(BlkFeature::RING_EVENT_IDX), + )?; transport.finish_init(); Ok(VirtIOBlk { transport, queue, capacity, - readonly, + negotiated_features, }) } @@ -86,7 +86,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { /// Returns true if the block device is read-only, or false if it allows writes. pub fn readonly(&self) -> bool { - self.readonly + self.negotiated_features.contains(BlkFeature::RO) } /// Acknowledges a pending interrupt, if any. @@ -96,35 +96,98 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { self.transport.ack_interrupt() } - /// Reads a block into the given buffer. - /// - /// Blocks until the read completes or there is an error. - pub fn read_block(&mut self, block_id: usize, buf: &mut [u8]) -> Result { - assert_eq!(buf.len(), SECTOR_SIZE); - let req = BlkReq { - type_: ReqType::In, - reserved: 0, - sector: block_id as u64, - }; + /// Sends the given request to the device and waits for a response, with no extra data. + fn request(&mut self, request: BlkReq) -> Result { + let mut resp = BlkResp::default(); + self.queue.add_notify_wait_pop( + &[request.as_bytes()], + &mut [resp.as_bytes_mut()], + &mut self.transport, + )?; + resp.status.into() + } + + /// Sends the given request to the device and waits for a response, including the given data. + fn request_read(&mut self, request: BlkReq, data: &mut [u8]) -> Result { + let mut resp = BlkResp::default(); + self.queue.add_notify_wait_pop( + &[request.as_bytes()], + &mut [data, resp.as_bytes_mut()], + &mut self.transport, + )?; + resp.status.into() + } + + /// Sends the given request and data to the device and waits for a response. + fn request_write(&mut self, request: BlkReq, data: &[u8]) -> Result { let mut resp = BlkResp::default(); self.queue.add_notify_wait_pop( - &[req.as_bytes()], - &mut [buf, resp.as_bytes_mut()], + &[request.as_bytes(), data], + &mut [resp.as_bytes_mut()], &mut self.transport, )?; resp.status.into() } - /// Submits a request to read a block, but returns immediately without waiting for the read to - /// complete. + /// Requests the device to flush any pending writes to storage. + /// + /// This will be ignored if the device doesn't support the `VIRTIO_BLK_F_FLUSH` feature. + pub fn flush(&mut self) -> Result { + if self.negotiated_features.contains(BlkFeature::FLUSH) { + self.request(BlkReq { + type_: ReqType::Flush, + ..Default::default() + }) + } else { + Ok(()) + } + } + + /// Gets the device ID. + /// + /// The ID is written as ASCII into the given buffer, which must be 20 bytes long, and the used + /// length returned. + pub fn device_id(&mut self, id: &mut [u8; 20]) -> Result<usize> { + self.request_read( + BlkReq { + type_: ReqType::GetId, + ..Default::default() + }, + id, + )?; + + let length = id.iter().position(|&x| x == 0).unwrap_or(20); + Ok(length) + } + + /// Reads one or more blocks into the given buffer. + /// + /// The buffer length must be a non-zero multiple of [`SECTOR_SIZE`]. + /// + /// Blocks until the read completes or there is an error. + pub fn read_blocks(&mut self, block_id: usize, buf: &mut [u8]) -> Result { + assert_ne!(buf.len(), 0); + assert_eq!(buf.len() % SECTOR_SIZE, 0); + self.request_read( + BlkReq { + type_: ReqType::In, + reserved: 0, + sector: block_id as u64, + }, + buf, + ) + } + + /// Submits a request to read one or more blocks, but returns immediately without waiting for + /// the read to complete. /// /// # Arguments /// - /// * `block_id` - The identifier of the block to read. + /// * `block_id` - The identifier of the first block to read. /// * `req` - A buffer which the driver can use for the request to send to the device. The - /// contents don't matter as `read_block_nb` will initialise it, but like the other buffers it - /// needs to be valid (and not otherwise used) until the corresponding `complete_read_block` - /// call. + /// contents don't matter as `read_blocks_nb` will initialise it, but like the other buffers + /// it needs to be valid (and not otherwise used) until the corresponding + /// `complete_read_blocks` call. Its length must be a non-zero multiple of [`SECTOR_SIZE`]. /// * `buf` - The buffer in memory into which the block should be read. /// * `resp` - A mutable reference to a variable provided by the caller /// to contain the status of the request. The caller can safely @@ -137,7 +200,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { /// Descriptors to allocate, then it returns [`Error::QueueFull`]. /// /// The caller can then call `peek_used` with the returned token to check whether the device has - /// finished handling the request. Once it has, the caller must call `complete_read_block` with + /// finished handling the request. Once it has, the caller must call `complete_read_blocks` with /// the same buffers before reading the response. /// /// ``` @@ -150,13 +213,13 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { /// let mut request = BlkReq::default(); /// let mut buffer = [0; 512]; /// let mut response = BlkResp::default(); - /// let token = unsafe { blk.read_block_nb(42, &mut request, &mut buffer, &mut response) }?; + /// let token = unsafe { blk.read_blocks_nb(42, &mut request, &mut buffer, &mut response) }?; /// /// // Wait for an interrupt to tell us that the request completed... /// assert_eq!(blk.peek_used(), Some(token)); /// /// unsafe { - /// blk.complete_read_block(token, &request, &mut buffer, &mut response)?; + /// blk.complete_read_blocks(token, &request, &mut buffer, &mut response)?; /// } /// if response.status() == RespStatus::OK { /// println!("Successfully read block."); @@ -172,14 +235,15 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { /// `req`, `buf` and `resp` are still borrowed by the underlying VirtIO block device even after /// this method returns. Thus, it is the caller's responsibility to guarantee that they are not /// accessed before the request is completed in order to avoid data races. - pub unsafe fn read_block_nb( + pub unsafe fn read_blocks_nb( &mut self, block_id: usize, req: &mut BlkReq, buf: &mut [u8], resp: &mut BlkResp, ) -> Result<u16> { - assert_eq!(buf.len(), SECTOR_SIZE); + assert_ne!(buf.len(), 0); + assert_eq!(buf.len() % SECTOR_SIZE, 0); *req = BlkReq { type_: ReqType::In, reserved: 0, @@ -194,13 +258,13 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { Ok(token) } - /// Completes a read operation which was started by `read_block_nb`. + /// Completes a read operation which was started by `read_blocks_nb`. /// /// # Safety /// - /// The same buffers must be passed in again as were passed to `read_block_nb` when it returned + /// The same buffers must be passed in again as were passed to `read_blocks_nb` when it returned /// the token. - pub unsafe fn complete_read_block( + pub unsafe fn complete_read_blocks( &mut self, token: u16, req: &BlkReq, @@ -212,55 +276,56 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { resp.status.into() } - /// Writes the contents of the given buffer to a block. + /// Writes the contents of the given buffer to a block or blocks. + /// + /// The buffer length must be a non-zero multiple of [`SECTOR_SIZE`]. /// /// Blocks until the write is complete or there is an error. - pub fn write_block(&mut self, block_id: usize, buf: &[u8]) -> Result { - assert_eq!(buf.len(), SECTOR_SIZE); - let req = BlkReq { - type_: ReqType::Out, - reserved: 0, - sector: block_id as u64, - }; - let mut resp = BlkResp::default(); - self.queue.add_notify_wait_pop( - &[req.as_bytes(), buf], - &mut [resp.as_bytes_mut()], - &mut self.transport, - )?; - resp.status.into() + pub fn write_blocks(&mut self, block_id: usize, buf: &[u8]) -> Result { + assert_ne!(buf.len(), 0); + assert_eq!(buf.len() % SECTOR_SIZE, 0); + self.request_write( + BlkReq { + type_: ReqType::Out, + sector: block_id as u64, + ..Default::default() + }, + buf, + ) } - /// Submits a request to write a block, but returns immediately without waiting for the write to - /// complete. + /// Submits a request to write one or more blocks, but returns immediately without waiting for + /// the write to complete. /// /// # Arguments /// - /// * `block_id` - The identifier of the block to write. + /// * `block_id` - The identifier of the first block to write. /// * `req` - A buffer which the driver can use for the request to send to the device. The - /// contents don't matter as `read_block_nb` will initialise it, but like the other buffers it - /// needs to be valid (and not otherwise used) until the corresponding `complete_read_block` - /// call. - /// * `buf` - The buffer in memory containing the data to write to the block. + /// contents don't matter as `read_blocks_nb` will initialise it, but like the other buffers + /// it needs to be valid (and not otherwise used) until the corresponding + /// `complete_write_blocks` call. + /// * `buf` - The buffer in memory containing the data to write to the blocks. Its length must + /// be a non-zero multiple of [`SECTOR_SIZE`]. /// * `resp` - A mutable reference to a variable provided by the caller /// to contain the status of the request. The caller can safely /// read the variable only after the request is complete. /// /// # Usage /// - /// See [VirtIOBlk::read_block_nb]. + /// See [VirtIOBlk::read_blocks_nb]. /// /// # Safety /// - /// See [VirtIOBlk::read_block_nb]. - pub unsafe fn write_block_nb( + /// See [VirtIOBlk::read_blocks_nb]. + pub unsafe fn write_blocks_nb( &mut self, block_id: usize, req: &mut BlkReq, buf: &[u8], resp: &mut BlkResp, ) -> Result<u16> { - assert_eq!(buf.len(), SECTOR_SIZE); + assert_ne!(buf.len(), 0); + assert_eq!(buf.len() % SECTOR_SIZE, 0); *req = BlkReq { type_: ReqType::Out, reserved: 0, @@ -275,13 +340,13 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { Ok(token) } - /// Completes a write operation which was started by `write_block_nb`. + /// Completes a write operation which was started by `write_blocks_nb`. /// /// # Safety /// - /// The same buffers must be passed in again as were passed to `write_block_nb` when it returned - /// the token. - pub unsafe fn complete_write_block( + /// The same buffers must be passed in again as were passed to `write_blocks_nb` when it + /// returned the token. + pub unsafe fn complete_write_blocks( &mut self, token: u16, req: &BlkReq, @@ -372,8 +437,11 @@ enum ReqType { In = 0, Out = 1, Flush = 4, + GetId = 8, + GetLifetime = 10, Discard = 11, WriteZeroes = 13, + SecureErase = 14, } /// Status of a VirtIOBlk request. @@ -439,6 +507,8 @@ bitflags! { const TOPOLOGY = 1 << 10; /// Device can toggle its cache between writeback and writethrough modes. const CONFIG_WCE = 1 << 11; + /// Device supports multiqueue. + const MQ = 1 << 12; /// Device can support discard command, maximum discard sectors size in /// `max_discard_sectors` and maximum discard segment number in /// `max_discard_seg`. @@ -447,6 +517,10 @@ bitflags! { /// size in `max_write_zeroes_sectors` and maximum write zeroes segment /// number in `max_write_zeroes_seg`. const WRITE_ZEROES = 1 << 14; + /// Device supports providing storage lifetime information. + const LIFETIME = 1 << 15; + /// Device can support the secure erase command. + const SECURE_ERASE = 1 << 16; // device independent const NOTIFY_ON_EMPTY = 1 << 24; // legacy @@ -473,7 +547,7 @@ mod tests { hal::fake::FakeHal, transport::{ fake::{FakeTransport, QueueStatus, State}, - DeviceStatus, DeviceType, + DeviceType, }, }; use alloc::{sync::Arc, vec}; @@ -497,14 +571,11 @@ mod tests { opt_io_size: Volatile::new(0), }; let state = Arc::new(Mutex::new(State { - status: DeviceStatus::empty(), - driver_features: 0, - guest_page_size: 0, - interrupt_pending: false, queues: vec![QueueStatus::default()], + ..Default::default() })); let transport = FakeTransport { - device_type: DeviceType::Console, + device_type: DeviceType::Block, max_queue_size: QUEUE_SIZE.into(), device_features: BlkFeature::RO.bits(), config_space: NonNull::from(&mut config_space), @@ -533,16 +604,13 @@ mod tests { opt_io_size: Volatile::new(0), }; let state = Arc::new(Mutex::new(State { - status: DeviceStatus::empty(), - driver_features: 0, - guest_page_size: 0, - interrupt_pending: false, queues: vec![QueueStatus::default()], + ..Default::default() })); let transport = FakeTransport { - device_type: DeviceType::Console, + device_type: DeviceType::Block, max_queue_size: QUEUE_SIZE.into(), - device_features: 0, + device_features: BlkFeature::RING_INDIRECT_DESC.bits(), config_space: NonNull::from(&mut config_space), state: state.clone(), }; @@ -583,7 +651,7 @@ mod tests { // Read a block from the device. let mut buffer = [0; 512]; - blk.read_block(42, &mut buffer).unwrap(); + blk.read_blocks(42, &mut buffer).unwrap(); assert_eq!(&buffer[0..9], b"Test data"); handle.join().unwrap(); @@ -606,16 +674,13 @@ mod tests { opt_io_size: Volatile::new(0), }; let state = Arc::new(Mutex::new(State { - status: DeviceStatus::empty(), - driver_features: 0, - guest_page_size: 0, - interrupt_pending: false, queues: vec![QueueStatus::default()], + ..Default::default() })); let transport = FakeTransport { - device_type: DeviceType::Console, + device_type: DeviceType::Block, max_queue_size: QUEUE_SIZE.into(), - device_features: 0, + device_features: BlkFeature::RING_INDIRECT_DESC.bits(), config_space: NonNull::from(&mut config_space), state: state.clone(), }; @@ -659,7 +724,146 @@ mod tests { // Write a block to the device. let mut buffer = [0; 512]; buffer[0..9].copy_from_slice(b"Test data"); - blk.write_block(42, &mut buffer).unwrap(); + blk.write_blocks(42, &mut buffer).unwrap(); + + // Request to flush should be ignored as the device doesn't support it. + blk.flush().unwrap(); + + handle.join().unwrap(); + } + + #[test] + fn flush() { + let mut config_space = BlkConfig { + capacity_low: Volatile::new(66), + capacity_high: Volatile::new(0), + size_max: Volatile::new(0), + seg_max: Volatile::new(0), + cylinders: Volatile::new(0), + heads: Volatile::new(0), + sectors: Volatile::new(0), + blk_size: Volatile::new(0), + physical_block_exp: Volatile::new(0), + alignment_offset: Volatile::new(0), + min_io_size: Volatile::new(0), + opt_io_size: Volatile::new(0), + }; + let state = Arc::new(Mutex::new(State { + queues: vec![QueueStatus::default()], + ..Default::default() + })); + let transport = FakeTransport { + device_type: DeviceType::Block, + max_queue_size: QUEUE_SIZE.into(), + device_features: (BlkFeature::RING_INDIRECT_DESC | BlkFeature::FLUSH).bits(), + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let mut blk = VirtIOBlk::<FakeHal, FakeTransport<BlkConfig>>::new(transport).unwrap(); + + // Start a thread to simulate the device waiting for a flush request. + let handle = thread::spawn(move || { + println!("Device waiting for a request."); + State::wait_until_queue_notified(&state, QUEUE); + println!("Transmit queue was notified."); + + state + .lock() + .unwrap() + .read_write_queue::<{ QUEUE_SIZE as usize }>(QUEUE, |request| { + assert_eq!( + request, + BlkReq { + type_: ReqType::Flush, + reserved: 0, + sector: 0, + } + .as_bytes() + ); + + let mut response = Vec::new(); + response.extend_from_slice( + BlkResp { + status: RespStatus::OK, + } + .as_bytes(), + ); + + response + }); + }); + + // Request to flush. + blk.flush().unwrap(); + + handle.join().unwrap(); + } + + #[test] + fn device_id() { + let mut config_space = BlkConfig { + capacity_low: Volatile::new(66), + capacity_high: Volatile::new(0), + size_max: Volatile::new(0), + seg_max: Volatile::new(0), + cylinders: Volatile::new(0), + heads: Volatile::new(0), + sectors: Volatile::new(0), + blk_size: Volatile::new(0), + physical_block_exp: Volatile::new(0), + alignment_offset: Volatile::new(0), + min_io_size: Volatile::new(0), + opt_io_size: Volatile::new(0), + }; + let state = Arc::new(Mutex::new(State { + queues: vec![QueueStatus::default()], + ..Default::default() + })); + let transport = FakeTransport { + device_type: DeviceType::Block, + max_queue_size: QUEUE_SIZE.into(), + device_features: BlkFeature::RING_INDIRECT_DESC.bits(), + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let mut blk = VirtIOBlk::<FakeHal, FakeTransport<BlkConfig>>::new(transport).unwrap(); + + // Start a thread to simulate the device waiting for a flush request. + let handle = thread::spawn(move || { + println!("Device waiting for a request."); + State::wait_until_queue_notified(&state, QUEUE); + println!("Transmit queue was notified."); + + state + .lock() + .unwrap() + .read_write_queue::<{ QUEUE_SIZE as usize }>(QUEUE, |request| { + assert_eq!( + request, + BlkReq { + type_: ReqType::GetId, + reserved: 0, + sector: 0, + } + .as_bytes() + ); + + let mut response = Vec::new(); + response.extend_from_slice(b"device_id\0\0\0\0\0\0\0\0\0\0\0"); + response.extend_from_slice( + BlkResp { + status: RespStatus::OK, + } + .as_bytes(), + ); + + response + }); + }); + + let mut id = [0; 20]; + let length = blk.device_id(&mut id).unwrap(); + assert_eq!(&id[0..length], b"device_id"); handle.join().unwrap(); } diff --git a/src/device/console.rs b/src/device/console.rs index 7d3c7d4..6528276 100644 --- a/src/device/console.rs +++ b/src/device/console.rs @@ -8,11 +8,11 @@ use crate::{Result, PAGE_SIZE}; use alloc::boxed::Box; use bitflags::bitflags; use core::ptr::NonNull; -use log::info; const QUEUE_RECEIVEQ_PORT_0: u16 = 0; const QUEUE_TRANSMITQ_PORT_0: u16 = 1; const QUEUE_SIZE: usize = 2; +const SUPPORTED_FEATURES: Features = Features::RING_EVENT_IDX; /// Driver for a VirtIO console device. /// @@ -65,15 +65,20 @@ pub struct ConsoleInfo { impl<H: Hal, T: Transport> VirtIOConsole<H, T> { /// Creates a new VirtIO console driver. pub fn new(mut transport: T) -> Result<Self> { - transport.begin_init(|features| { - let features = Features::from_bits_truncate(features); - info!("Device features {:?}", features); - let supported_features = Features::empty(); - (features & supported_features).bits() - }); + let negotiated_features = transport.begin_init(SUPPORTED_FEATURES); let config_space = transport.config_space::<Config>()?; - let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0)?; - let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0)?; + let receiveq = VirtQueue::new( + &mut transport, + QUEUE_RECEIVEQ_PORT_0, + false, + negotiated_features.contains(Features::RING_EVENT_IDX), + )?; + let transmitq = VirtQueue::new( + &mut transport, + QUEUE_TRANSMITQ_PORT_0, + false, + negotiated_features.contains(Features::RING_EVENT_IDX), + )?; // Safe because no alignment or initialisation is required for [u8], the DMA buffer is // dereferenceable, and the lifetime of the reference matches the lifetime of the DMA buffer @@ -241,7 +246,7 @@ mod tests { hal::fake::FakeHal, transport::{ fake::{FakeTransport, QueueStatus, State}, - DeviceStatus, DeviceType, + DeviceType, }, }; use alloc::{sync::Arc, vec}; @@ -257,11 +262,8 @@ mod tests { emerg_wr: WriteOnly::default(), }; let state = Arc::new(Mutex::new(State { - status: DeviceStatus::empty(), - driver_features: 0, - guest_page_size: 0, - interrupt_pending: false, queues: vec![QueueStatus::default(), QueueStatus::default()], + ..Default::default() })); let transport = FakeTransport { device_type: DeviceType::Console, @@ -305,11 +307,8 @@ mod tests { emerg_wr: WriteOnly::default(), }; let state = Arc::new(Mutex::new(State { - status: DeviceStatus::empty(), - driver_features: 0, - guest_page_size: 0, - interrupt_pending: false, queues: vec![QueueStatus::default(), QueueStatus::default()], + ..Default::default() })); let transport = FakeTransport { device_type: DeviceType::Console, diff --git a/src/device/gpu.rs b/src/device/gpu.rs index 43e1b76..e19b780 100644 --- a/src/device/gpu.rs +++ b/src/device/gpu.rs @@ -11,6 +11,7 @@ use log::info; use zerocopy::{AsBytes, FromBytes}; const QUEUE_SIZE: u16 = 2; +const SUPPORTED_FEATURES: Features = Features::RING_EVENT_IDX; /// A virtio based graphics adapter. /// @@ -39,12 +40,7 @@ pub struct VirtIOGpu<H: Hal, T: Transport> { impl<H: Hal, T: Transport> VirtIOGpu<H, T> { /// Create a new VirtIO-Gpu driver. pub fn new(mut transport: T) -> Result<Self> { - transport.begin_init(|features| { - let features = Features::from_bits_truncate(features); - info!("Device features {:?}", features); - let supported_features = Features::empty(); - (features & supported_features).bits() - }); + let negotiated_features = transport.begin_init(SUPPORTED_FEATURES); // read configuration space let config_space = transport.config_space::<Config>()?; @@ -57,8 +53,18 @@ impl<H: Hal, T: Transport> VirtIOGpu<H, T> { ); } - let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?; - let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR)?; + let control_queue = VirtQueue::new( + &mut transport, + QUEUE_TRANSMIT, + false, + negotiated_features.contains(Features::RING_EVENT_IDX), + )?; + let cursor_queue = VirtQueue::new( + &mut transport, + QUEUE_CURSOR, + false, + negotiated_features.contains(Features::RING_EVENT_IDX), + )?; let queue_buf_send = FromBytes::new_box_slice_zeroed(PAGE_SIZE); let queue_buf_recv = FromBytes::new_box_slice_zeroed(PAGE_SIZE); diff --git a/src/device/input.rs b/src/device/input.rs index dee2fec..c277b64 100644 --- a/src/device/input.rs +++ b/src/device/input.rs @@ -8,7 +8,6 @@ use crate::volatile::{volread, volwrite, ReadOnly, WriteOnly}; use crate::Result; use alloc::boxed::Box; use core::ptr::NonNull; -use log::info; use zerocopy::{AsBytes, FromBytes}; /// Virtual human interface devices such as keyboards, mice and tablets. @@ -28,18 +27,23 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> { /// Create a new VirtIO-Input driver. pub fn new(mut transport: T) -> Result<Self> { let mut event_buf = Box::new([InputEvent::default(); QUEUE_SIZE]); - transport.begin_init(|features| { - let features = Feature::from_bits_truncate(features); - info!("Device features: {:?}", features); - // negotiate these flags only - let supported_features = Feature::empty(); - (features & supported_features).bits() - }); + + let negotiated_features = transport.begin_init(SUPPORTED_FEATURES); let config = transport.config_space::<Config>()?; - let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT)?; - let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS)?; + let mut event_queue = VirtQueue::new( + &mut transport, + QUEUE_EVENT, + false, + negotiated_features.contains(Feature::RING_EVENT_IDX), + )?; + let status_queue = VirtQueue::new( + &mut transport, + QUEUE_STATUS, + false, + negotiated_features.contains(Feature::RING_EVENT_IDX), + )?; for (i, event) in event_buf.as_mut().iter_mut().enumerate() { // Safe because the buffer lasts as long as the queue. let token = unsafe { event_queue.add(&[], &mut [event.as_bytes_mut()])? }; @@ -193,6 +197,7 @@ pub struct InputEvent { const QUEUE_EVENT: u16 = 0; const QUEUE_STATUS: u16 = 1; +const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX; // a parameter that can change const QUEUE_SIZE: usize = 32; diff --git a/src/device/net.rs b/src/device/net.rs index b9419e7..522997e 100644 --- a/src/device/net.rs +++ b/src/device/net.rs @@ -8,7 +8,7 @@ use crate::{Error, Result}; use alloc::{vec, vec::Vec}; use bitflags::bitflags; use core::{convert::TryInto, mem::size_of}; -use log::{debug, info, warn}; +use log::{debug, warn}; use zerocopy::{AsBytes, FromBytes}; const MAX_BUFFER_LEN: usize = 65535; @@ -112,12 +112,7 @@ pub struct VirtIONet<H: Hal, T: Transport, const QUEUE_SIZE: usize> { impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> { /// Create a new VirtIO-Net driver. pub fn new(mut transport: T, buf_len: usize) -> Result<Self> { - transport.begin_init(|features| { - let features = Features::from_bits_truncate(features); - info!("Device features {:?}", features); - let supported_features = Features::MAC | Features::STATUS; - (features & supported_features).bits() - }); + let negotiated_features = transport.begin_init(SUPPORTED_FEATURES); // read configuration space let config = transport.config_space::<Config>()?; let mac; @@ -139,8 +134,18 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> return Err(Error::InvalidParam); } - let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?; - let mut recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?; + let send_queue = VirtQueue::new( + &mut transport, + QUEUE_TRANSMIT, + false, + negotiated_features.contains(Features::RING_EVENT_IDX), + )?; + let mut recv_queue = VirtQueue::new( + &mut transport, + QUEUE_RECEIVE, + false, + negotiated_features.contains(Features::RING_EVENT_IDX), + )?; const NONE_BUF: Option<RxBuffer> = None; let mut rx_buffers = [NONE_BUF; QUEUE_SIZE]; @@ -243,11 +248,21 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> /// completed. pub fn send(&mut self, tx_buf: TxBuffer) -> Result { let header = VirtioNetHdr::default(); - self.send_queue.add_notify_wait_pop( - &[header.as_bytes(), tx_buf.packet()], - &mut [], - &mut self.transport, - )?; + if tx_buf.packet_len() == 0 { + // Special case sending an empty packet, to avoid adding an empty buffer to the + // virtqueue. + self.send_queue.add_notify_wait_pop( + &[header.as_bytes()], + &mut [], + &mut self.transport, + )?; + } else { + self.send_queue.add_notify_wait_pop( + &[header.as_bytes(), tx_buf.packet()], + &mut [], + &mut self.transport, + )?; + } Ok(()) } } @@ -393,3 +408,6 @@ impl GsoType { const QUEUE_RECEIVE: u16 = 0; const QUEUE_TRANSMIT: u16 = 1; +const SUPPORTED_FEATURES: Features = Features::MAC + .union(Features::STATUS) + .union(Features::RING_EVENT_IDX); diff --git a/src/device/socket/mod.rs b/src/device/socket/mod.rs index bf423bf..acc7def 100644 --- a/src/device/socket/mod.rs +++ b/src/device/socket/mod.rs @@ -19,7 +19,7 @@ mod vsock; pub use error::SocketError; #[cfg(feature = "alloc")] pub use multiconnectionmanager::VsockConnectionManager; -pub use protocol::VsockAddr; +pub use protocol::{VsockAddr, VMADDR_CID_HOST}; #[cfg(feature = "alloc")] pub use singleconnectionmanager::SingleConnectionManager; #[cfg(feature = "alloc")] diff --git a/src/device/socket/multiconnectionmanager.rs b/src/device/socket/multiconnectionmanager.rs index 6aee5cd..430f5e8 100644 --- a/src/device/socket/multiconnectionmanager.rs +++ b/src/device/socket/multiconnectionmanager.rs @@ -380,7 +380,7 @@ mod tests { hal::fake::FakeHal, transport::{ fake::{FakeTransport, QueueStatus, State}, - DeviceStatus, DeviceType, + DeviceType, }, volatile::ReadOnly, }; @@ -407,15 +407,12 @@ mod tests { guest_cid_high: ReadOnly::new(0), }; let state = Arc::new(Mutex::new(State { - status: DeviceStatus::empty(), - driver_features: 0, - guest_page_size: 0, - interrupt_pending: false, queues: vec![ QueueStatus::default(), QueueStatus::default(), QueueStatus::default(), ], + ..Default::default() })); let transport = FakeTransport { device_type: DeviceType::Socket, @@ -622,15 +619,12 @@ mod tests { guest_cid_high: ReadOnly::new(0), }; let state = Arc::new(Mutex::new(State { - status: DeviceStatus::empty(), - driver_features: 0, - guest_page_size: 0, - interrupt_pending: false, queues: vec![ QueueStatus::default(), QueueStatus::default(), QueueStatus::default(), ], + ..Default::default() })); let transport = FakeTransport { device_type: DeviceType::Socket, diff --git a/src/device/socket/protocol.rs b/src/device/socket/protocol.rs index 3587005..ab0650b 100644 --- a/src/device/socket/protocol.rs +++ b/src/device/socket/protocol.rs @@ -12,6 +12,9 @@ use zerocopy::{ AsBytes, FromBytes, }; +/// Well-known CID for the host. +pub const VMADDR_CID_HOST: u64 = 2; + /// Currently only stream sockets are supported. type is 1 for stream socket types. #[derive(Copy, Clone, Debug)] #[repr(u16)] diff --git a/src/device/socket/singleconnectionmanager.rs b/src/device/socket/singleconnectionmanager.rs index 8c9bff6..b5b21e4 100644 --- a/src/device/socket/singleconnectionmanager.rs +++ b/src/device/socket/singleconnectionmanager.rs @@ -191,7 +191,7 @@ mod tests { hal::fake::FakeHal, transport::{ fake::{FakeTransport, QueueStatus, State}, - DeviceStatus, DeviceType, + DeviceType, }, volatile::ReadOnly, }; @@ -218,15 +218,12 @@ mod tests { guest_cid_high: ReadOnly::new(0), }; let state = Arc::new(Mutex::new(State { - status: DeviceStatus::empty(), - driver_features: 0, - guest_page_size: 0, - interrupt_pending: false, queues: vec![ QueueStatus::default(), QueueStatus::default(), QueueStatus::default(), ], + ..Default::default() })); let transport = FakeTransport { device_type: DeviceType::Socket, diff --git a/src/device/socket/vsock.rs b/src/device/socket/vsock.rs index 523930e..2e9978a 100644 --- a/src/device/socket/vsock.rs +++ b/src/device/socket/vsock.rs @@ -19,6 +19,7 @@ pub(crate) const TX_QUEUE_IDX: u16 = 1; const EVENT_QUEUE_IDX: u16 = 2; pub(crate) const QUEUE_SIZE: usize = 8; +const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX; /// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::<VirtioVsockHdr>(). const RX_BUFFER_SIZE: usize = 512; @@ -241,13 +242,7 @@ impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> { impl<H: Hal, T: Transport> VirtIOSocket<H, T> { /// Create a new VirtIO Vsock driver. pub fn new(mut transport: T) -> Result<Self> { - transport.begin_init(|features| { - let features = Feature::from_bits_truncate(features); - debug!("Device features: {:?}", features); - // negotiate these flags only - let supported_features = Feature::empty(); - (features & supported_features).bits() - }); + let negotiated_features = transport.begin_init(SUPPORTED_FEATURES); let config = transport.config_space::<VirtioVsockConfig>()?; debug!("config: {:?}", config); @@ -257,9 +252,24 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> { }; debug!("guest cid: {guest_cid:?}"); - let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?; - let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?; - let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?; + let mut rx = VirtQueue::new( + &mut transport, + RX_QUEUE_IDX, + false, + negotiated_features.contains(Feature::RING_EVENT_IDX), + )?; + let tx = VirtQueue::new( + &mut transport, + TX_QUEUE_IDX, + false, + negotiated_features.contains(Feature::RING_EVENT_IDX), + )?; + let event = VirtQueue::new( + &mut transport, + EVENT_QUEUE_IDX, + false, + negotiated_features.contains(Feature::RING_EVENT_IDX), + )?; // Allocate and add buffers for the RX queue. let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE]; @@ -411,11 +421,16 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> { } fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result { - let _len = self.tx.add_notify_wait_pop( - &[header.as_bytes(), buffer], - &mut [], - &mut self.transport, - )?; + let _len = if buffer.is_empty() { + self.tx + .add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)? + } else { + self.tx.add_notify_wait_pop( + &[header.as_bytes(), buffer], + &mut [], + &mut self.transport, + )? + }; Ok(()) } @@ -503,7 +518,7 @@ mod tests { hal::fake::FakeHal, transport::{ fake::{FakeTransport, QueueStatus, State}, - DeviceStatus, DeviceType, + DeviceType, }, volatile::ReadOnly, }; @@ -518,15 +533,12 @@ mod tests { guest_cid_high: ReadOnly::new(0), }; let state = Arc::new(Mutex::new(State { - status: DeviceStatus::empty(), - driver_features: 0, - guest_page_size: 0, - interrupt_pending: false, queues: vec![ QueueStatus::default(), QueueStatus::default(), QueueStatus::default(), ], + ..Default::default() })); let transport = FakeTransport { device_type: DeviceType::Socket, @@ -13,7 +13,7 @@ pub struct Dma<H: Hal> { paddr: usize, vaddr: NonNull<u8>, pages: usize, - _phantom: PhantomData<H>, + _hal: PhantomData<H>, } impl<H: Hal> Dma<H> { @@ -30,7 +30,7 @@ impl<H: Hal> Dma<H> { paddr, vaddr, pages, - _phantom: PhantomData::default(), + _hal: PhantomData, }) } @@ -118,8 +118,8 @@ pub unsafe trait Hal { /// /// # Safety /// - /// The buffer must be a valid pointer to memory which will not be accessed by any other thread - /// for the duration of this method call. + /// The buffer must be a valid pointer to a non-empty memory range which will not be accessed by + /// any other thread for the duration of this method call. unsafe fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr; /// Unshares the given memory range from the device and (if necessary) copies it back to the @@ -127,9 +127,9 @@ pub unsafe trait Hal { /// /// # Safety /// - /// The buffer must be a valid pointer to memory which will not be accessed by any other thread - /// for the duration of this method call. The `paddr` must be the value previously returned by - /// the corresponding `share` call. + /// The buffer must be a valid pointer to a non-empty memory range which will not be accessed by + /// any other thread for the duration of this method call. The `paddr` must be the value + /// previously returned by the corresponding `share` call. unsafe fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection); } diff --git a/src/hal/fake.rs b/src/hal/fake.rs index 5d46835..fde129f 100644 --- a/src/hal/fake.rs +++ b/src/hal/fake.rs @@ -4,7 +4,11 @@ use crate::{BufferDirection, Hal, PhysAddr, PAGE_SIZE}; use alloc::alloc::{alloc_zeroed, dealloc, handle_alloc_error}; -use core::{alloc::Layout, ptr::NonNull}; +use core::{ + alloc::Layout, + ptr::{self, NonNull}, +}; +use zerocopy::FromBytes; #[derive(Debug)] pub struct FakeHal; @@ -38,18 +42,49 @@ unsafe impl Hal for FakeHal { NonNull::new(paddr as _).unwrap() } - unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr { - let vaddr = buffer.as_ptr() as *mut u8 as usize; + unsafe fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr { + assert_ne!(buffer.len(), 0); + // To ensure that the driver is handling and unsharing buffers properly, allocate a new + // buffer and copy to it if appropriate. + let mut shared_buffer = u8::new_box_slice_zeroed(buffer.len()); + if let BufferDirection::DriverToDevice | BufferDirection::Both = direction { + unsafe { + buffer + .as_ptr() + .cast::<u8>() + .copy_to(shared_buffer.as_mut_ptr(), buffer.len()); + } + } + let vaddr = Box::into_raw(shared_buffer) as *mut u8 as usize; // Nothing to do, as the host already has access to all memory. virt_to_phys(vaddr) } - unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) { - // Nothing to do, as the host already has access to all memory and we didn't copy the buffer - // anywhere else. + unsafe fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection) { + assert_ne!(buffer.len(), 0); + assert_ne!(paddr, 0); + let vaddr = phys_to_virt(paddr); + let shared_buffer = unsafe { + Box::from_raw(ptr::slice_from_raw_parts_mut( + vaddr as *mut u8, + buffer.len(), + )) + }; + if let BufferDirection::DeviceToDriver | BufferDirection::Both = direction { + unsafe { + buffer + .as_ptr() + .cast::<u8>() + .copy_from(shared_buffer.as_ptr(), buffer.len()); + } + } } } fn virt_to_phys(vaddr: usize) -> PhysAddr { vaddr } + +fn phys_to_virt(paddr: PhysAddr) -> usize { + paddr +} diff --git a/src/queue.rs b/src/queue.rs index 758c139..dcafcbc 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -3,6 +3,8 @@ use crate::hal::{BufferDirection, Dma, Hal, PhysAddr}; use crate::transport::Transport; use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE}; +#[cfg(feature = "alloc")] +use alloc::boxed::Box; use bitflags::bitflags; #[cfg(test)] use core::cmp::min; @@ -12,7 +14,7 @@ use core::mem::{size_of, take}; use core::ptr; use core::ptr::NonNull; use core::sync::atomic::{fence, Ordering}; -use zerocopy::FromBytes; +use zerocopy::{AsBytes, FromBytes}; /// The mechanism for bulk data transport on virtio devices. /// @@ -50,11 +52,28 @@ pub struct VirtQueue<H: Hal, const SIZE: usize> { /// Our trusted copy of `avail.idx`. avail_idx: u16, last_used_idx: u16, + /// Whether the `VIRTIO_F_EVENT_IDX` feature has been negotiated. + event_idx: bool, + #[cfg(feature = "alloc")] + indirect: bool, + #[cfg(feature = "alloc")] + indirect_lists: [Option<NonNull<[Descriptor]>>; SIZE], } impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { - /// Create a new VirtQueue. - pub fn new<T: Transport>(transport: &mut T, idx: u16) -> Result<Self> { + /// Creates a new VirtQueue. + /// + /// * `indirect`: Whether to use indirect descriptors. This should be set if the + /// `VIRTIO_F_INDIRECT_DESC` feature has been negotiated with the device. + /// * `event_idx`: Whether to use the `used_event` and `avail_event` fields for notification + /// suppression. This should be set if the `VIRTIO_F_EVENT_IDX` feature has been negotiated + /// with the device. + pub fn new<T: Transport>( + transport: &mut T, + idx: u16, + indirect: bool, + event_idx: bool, + ) -> Result<Self> { if transport.queue_used(idx) { return Err(Error::AlreadyUsed); } @@ -96,6 +115,8 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { } } + #[cfg(feature = "alloc")] + const NONE: Option<NonNull<[Descriptor]>> = None; Ok(VirtQueue { layout, desc, @@ -107,11 +128,18 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { desc_shadow, avail_idx: 0, last_used_idx: 0, + event_idx, + #[cfg(feature = "alloc")] + indirect, + #[cfg(feature = "alloc")] + indirect_lists: [NONE; SIZE], }) } /// Add buffers to the virtqueue, return a token. /// + /// The buffers must not be empty. + /// /// Ref: linux virtio_ring.c virtqueue_add /// /// # Safety @@ -126,15 +154,65 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { if inputs.is_empty() && outputs.is_empty() { return Err(Error::InvalidParam); } - if inputs.len() + outputs.len() + self.num_used as usize > SIZE { + let descriptors_needed = inputs.len() + outputs.len(); + // Only consider indirect descriptors if the alloc feature is enabled, as they require + // allocation. + #[cfg(feature = "alloc")] + if self.num_used as usize + 1 > SIZE + || descriptors_needed > SIZE + || (!self.indirect && self.num_used as usize + descriptors_needed > SIZE) + { + return Err(Error::QueueFull); + } + #[cfg(not(feature = "alloc"))] + if self.num_used as usize + descriptors_needed > SIZE { return Err(Error::QueueFull); } + #[cfg(feature = "alloc")] + let head = if self.indirect && descriptors_needed > 1 { + self.add_indirect(inputs, outputs) + } else { + self.add_direct(inputs, outputs) + }; + #[cfg(not(feature = "alloc"))] + let head = self.add_direct(inputs, outputs); + + let avail_slot = self.avail_idx & (SIZE as u16 - 1); + // Safe because self.avail is properly aligned, dereferenceable and initialised. + unsafe { + (*self.avail.as_ptr()).ring[avail_slot as usize] = head; + } + + // Write barrier so that device sees changes to descriptor table and available ring before + // change to available index. + fence(Ordering::SeqCst); + + // increase head of avail ring + self.avail_idx = self.avail_idx.wrapping_add(1); + // Safe because self.avail is properly aligned, dereferenceable and initialised. + unsafe { + (*self.avail.as_ptr()).idx = self.avail_idx; + } + + // Write barrier so that device can see change to available index after this method returns. + fence(Ordering::SeqCst); + + Ok(head) + } + + fn add_direct<'a, 'b>( + &mut self, + inputs: &'a [&'b [u8]], + outputs: &'a mut [&'b mut [u8]], + ) -> u16 { // allocate descriptors from free list let head = self.free_head; let mut last = self.free_head; for (buffer, direction) in InputOutputIter::new(inputs, outputs) { + assert_ne!(buffer.len(), 0); + // Write to desc_shadow then copy. let desc = &mut self.desc_shadow[usize::from(self.free_head)]; // Safe because our caller promises that the buffers live at least until `pop_used` @@ -156,33 +234,63 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { self.num_used += (inputs.len() + outputs.len()) as u16; - let avail_slot = self.avail_idx & (SIZE as u16 - 1); - // Safe because self.avail is properly aligned, dereferenceable and initialised. - unsafe { - (*self.avail.as_ptr()).ring[avail_slot as usize] = head; + head + } + + #[cfg(feature = "alloc")] + fn add_indirect<'a, 'b>( + &mut self, + inputs: &'a [&'b [u8]], + outputs: &'a mut [&'b mut [u8]], + ) -> u16 { + let head = self.free_head; + + // Allocate and fill in indirect descriptor list. + let mut indirect_list = Descriptor::new_box_slice_zeroed(inputs.len() + outputs.len()); + for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() { + let desc = &mut indirect_list[i]; + // Safe because our caller promises that the buffers live at least until `pop_used` + // returns them. + unsafe { + desc.set_buf::<H>(buffer, direction, DescFlags::NEXT); + } + desc.next = (i + 1) as u16; } + indirect_list + .last_mut() + .unwrap() + .flags + .remove(DescFlags::NEXT); - // Write barrier so that device sees changes to descriptor table and available ring before - // change to available index. - fence(Ordering::SeqCst); + // Need to store pointer to indirect_list too, because direct_desc.set_buf will only store + // the physical DMA address which might be different. + assert!(self.indirect_lists[usize::from(head)].is_none()); + self.indirect_lists[usize::from(head)] = Some(indirect_list.as_mut().into()); - // increase head of avail ring - self.avail_idx = self.avail_idx.wrapping_add(1); - // Safe because self.avail is properly aligned, dereferenceable and initialised. + // Write a descriptor pointing to indirect descriptor list. We use Box::leak to prevent the + // indirect list from being freed when this function returns; recycle_descriptors is instead + // responsible for freeing the memory after the buffer chain is popped. + let direct_desc = &mut self.desc_shadow[usize::from(head)]; + self.free_head = direct_desc.next; unsafe { - (*self.avail.as_ptr()).idx = self.avail_idx; + direct_desc.set_buf::<H>( + Box::leak(indirect_list).as_bytes().into(), + BufferDirection::DriverToDevice, + DescFlags::INDIRECT, + ); } + self.write_desc(head); + self.num_used += 1; - // Write barrier so that device can see change to available index after this method returns. - fence(Ordering::SeqCst); - - Ok(head) + head } /// Add the given buffers to the virtqueue, notifies the device, blocks until the device uses /// them, then pops them. /// /// This assumes that the device isn't processing any other buffers at the same time. + /// + /// The buffers must not be empty. pub fn add_notify_wait_pop<'a>( &mut self, inputs: &'a [&'a [u8]], @@ -216,9 +324,16 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { // Read barrier, so we read a fresh value from the device. fence(Ordering::SeqCst); - // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable - // instance of UsedRing. - unsafe { (*self.used.as_ptr()).flags & 0x0001 == 0 } + if self.event_idx { + // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable + // instance of UsedRing. + let avail_event = unsafe { (*self.used.as_ptr()).avail_event }; + self.avail_idx >= avail_event.wrapping_add(1) + } else { + // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable + // instance of UsedRing. + unsafe { (*self.used.as_ptr()).flags & 0x0001 == 0 } + } } /// Copies the descriptor at the given index from `desc_shadow` to `desc`, so it can be seen by @@ -257,7 +372,16 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { /// Returns the number of free descriptors. pub fn available_desc(&self) -> usize { - SIZE - self.num_used as usize + #[cfg(feature = "alloc")] + if self.indirect { + return if usize::from(self.num_used) == SIZE { + 0 + } else { + SIZE + }; + } + + SIZE - usize::from(self.num_used) } /// Unshares buffers in the list starting at descriptor index `head` and adds them to the free @@ -278,32 +402,75 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { ) { let original_free_head = self.free_head; self.free_head = head; - let mut next = Some(head); - for (buffer, direction) in InputOutputIter::new(inputs, outputs) { - let desc_index = next.expect("Descriptor chain was shorter than expected."); - let desc = &mut self.desc_shadow[usize::from(desc_index)]; - - let paddr = desc.addr; - desc.unset_buf(); - self.num_used -= 1; - next = desc.next(); - if next.is_none() { - desc.next = original_free_head; + let head_desc = &mut self.desc_shadow[usize::from(head)]; + if head_desc.flags.contains(DescFlags::INDIRECT) { + #[cfg(feature = "alloc")] + { + // Find the indirect descriptor list, unshare it and move its descriptor to the free + // list. + let indirect_list = self.indirect_lists[usize::from(head)].take().unwrap(); + // SAFETY: We allocated the indirect list in `add_indirect`, and the device has + // finished accessing it by this point. + let mut indirect_list = unsafe { Box::from_raw(indirect_list.as_ptr()) }; + let paddr = head_desc.addr; + head_desc.unset_buf(); + self.num_used -= 1; + head_desc.next = original_free_head; + + unsafe { + H::unshare( + paddr as usize, + indirect_list.as_bytes_mut().into(), + BufferDirection::DriverToDevice, + ); + } + + // Unshare the buffers in the indirect descriptor list, and free it. + assert_eq!(indirect_list.len(), inputs.len() + outputs.len()); + for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() { + assert_ne!(buffer.len(), 0); + + // SAFETY: The caller ensures that the buffer is valid and matches the + // descriptor from which we got `paddr`. + unsafe { + // Unshare the buffer (and perhaps copy its contents back to the original + // buffer). + H::unshare(indirect_list[i].addr as usize, buffer, direction); + } + } + drop(indirect_list); } + } else { + let mut next = Some(head); - self.write_desc(desc_index); + for (buffer, direction) in InputOutputIter::new(inputs, outputs) { + assert_ne!(buffer.len(), 0); - // Safe because the caller ensures that the buffer is valid and matches the descriptor - // from which we got `paddr`. - unsafe { - // Unshare the buffer (and perhaps copy its contents back to the original buffer). - H::unshare(paddr as usize, buffer, direction); + let desc_index = next.expect("Descriptor chain was shorter than expected."); + let desc = &mut self.desc_shadow[usize::from(desc_index)]; + + let paddr = desc.addr; + desc.unset_buf(); + self.num_used -= 1; + next = desc.next(); + if next.is_none() { + desc.next = original_free_head; + } + + self.write_desc(desc_index); + + // SAFETY: The caller ensures that the buffer is valid and matches the descriptor + // from which we got `paddr`. + unsafe { + // Unshare the buffer (and perhaps copy its contents back to the original buffer). + H::unshare(paddr as usize, buffer, direction); + } } - } - if next.is_some() { - panic!("Descriptor chain was longer than expected."); + if next.is_some() { + panic!("Descriptor chain was longer than expected."); + } } } @@ -501,7 +668,7 @@ fn queue_part_sizes(queue_size: u16) -> (usize, usize, usize) { } #[repr(C, align(16))] -#[derive(Clone, Debug, FromBytes)] +#[derive(AsBytes, Clone, Debug, FromBytes)] pub(crate) struct Descriptor { addr: u64, len: u32, @@ -556,7 +723,7 @@ impl Descriptor { } /// Descriptor flags -#[derive(Copy, Clone, Debug, Default, Eq, FromBytes, PartialEq)] +#[derive(AsBytes, Copy, Clone, Debug, Default, Eq, FromBytes, PartialEq)] #[repr(transparent)] struct DescFlags(u16); @@ -589,7 +756,8 @@ struct UsedRing<const SIZE: usize> { flags: u16, idx: u16, ring: [UsedElem; SIZE], - avail_event: u16, // unused + /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated. + avail_event: u16, } #[repr(C)] @@ -650,6 +818,7 @@ pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>( handler: impl FnOnce(Vec<u8>) -> Vec<u8>, ) { use core::{ops::Deref, slice}; + use zerocopy::LayoutVerified; let available_ring = queue_driver_area as *const AvailRing<QUEUE_SIZE>; let used_ring = queue_device_area as *mut UsedRing<QUEUE_SIZE>; @@ -665,47 +834,99 @@ pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>( let head_descriptor_index = (*available_ring).ring[next_slot as usize]; let mut descriptor = &(*descriptors)[head_descriptor_index as usize]; - // Loop through all input descriptors in the chain, reading data from them. - let mut input = Vec::new(); - while !descriptor.flags.contains(DescFlags::WRITE) { - input.extend_from_slice(slice::from_raw_parts( - descriptor.addr as *const u8, - descriptor.len as usize, - )); + let input_length; + let output; + if descriptor.flags.contains(DescFlags::INDIRECT) { + // The descriptor shouldn't have any other flags if it is indirect. + assert_eq!(descriptor.flags, DescFlags::INDIRECT); + + // Loop through all input descriptors in the indirect descriptor list, reading data from + // them. + let indirect_descriptor_list: &[Descriptor] = LayoutVerified::new_slice( + slice::from_raw_parts(descriptor.addr as *const u8, descriptor.len as usize), + ) + .unwrap() + .into_slice(); + let mut input = Vec::new(); + let mut indirect_descriptor_index = 0; + while indirect_descriptor_index < indirect_descriptor_list.len() { + let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index]; + if indirect_descriptor.flags.contains(DescFlags::WRITE) { + break; + } - if let Some(next) = descriptor.next() { - descriptor = &(*descriptors)[next as usize]; - } else { - break; + input.extend_from_slice(slice::from_raw_parts( + indirect_descriptor.addr as *const u8, + indirect_descriptor.len as usize, + )); + + indirect_descriptor_index += 1; } - } - let input_length = input.len(); + input_length = input.len(); - // Let the test handle the request. - let output = handler(input); + // Let the test handle the request. + output = handler(input); - // Write the response to the remaining descriptors. - let mut remaining_output = output.deref(); - if descriptor.flags.contains(DescFlags::WRITE) { - loop { - assert!(descriptor.flags.contains(DescFlags::WRITE)); + // Write the response to the remaining descriptors. + let mut remaining_output = output.deref(); + while indirect_descriptor_index < indirect_descriptor_list.len() { + let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index]; + assert!(indirect_descriptor.flags.contains(DescFlags::WRITE)); - let length_to_write = min(remaining_output.len(), descriptor.len as usize); + let length_to_write = min(remaining_output.len(), indirect_descriptor.len as usize); ptr::copy( remaining_output.as_ptr(), - descriptor.addr as *mut u8, + indirect_descriptor.addr as *mut u8, length_to_write, ); remaining_output = &remaining_output[length_to_write..]; + indirect_descriptor_index += 1; + } + assert_eq!(remaining_output.len(), 0); + } else { + // Loop through all input descriptors in the chain, reading data from them. + let mut input = Vec::new(); + while !descriptor.flags.contains(DescFlags::WRITE) { + input.extend_from_slice(slice::from_raw_parts( + descriptor.addr as *const u8, + descriptor.len as usize, + )); + if let Some(next) = descriptor.next() { descriptor = &(*descriptors)[next as usize]; } else { break; } } + input_length = input.len(); + + // Let the test handle the request. + output = handler(input); + + // Write the response to the remaining descriptors. + let mut remaining_output = output.deref(); + if descriptor.flags.contains(DescFlags::WRITE) { + loop { + assert!(descriptor.flags.contains(DescFlags::WRITE)); + + let length_to_write = min(remaining_output.len(), descriptor.len as usize); + ptr::copy( + remaining_output.as_ptr(), + descriptor.addr as *mut u8, + length_to_write, + ); + remaining_output = &remaining_output[length_to_write..]; + + if let Some(next) = descriptor.next() { + descriptor = &(*descriptors)[next as usize]; + } else { + break; + } + } + } + assert_eq!(remaining_output.len(), 0); } - assert_eq!(remaining_output.len(), 0); // Mark the buffer as used. (*used_ring).ring[next_slot as usize].id = head_descriptor_index as u32; @@ -718,10 +939,16 @@ pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>( mod tests { use super::*; use crate::{ + device::common::Feature, hal::fake::FakeHal, - transport::mmio::{MmioTransport, VirtIOHeader, MODERN_VERSION}, + transport::{ + fake::{FakeTransport, QueueStatus, State}, + mmio::{MmioTransport, VirtIOHeader, MODERN_VERSION}, + DeviceStatus, DeviceType, + }, }; use core::ptr::NonNull; + use std::sync::{Arc, Mutex}; #[test] fn invalid_queue_size() { @@ -729,7 +956,7 @@ mod tests { let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); // Size not a power of 2. assert_eq!( - VirtQueue::<FakeHal, 3>::new(&mut transport, 0).unwrap_err(), + VirtQueue::<FakeHal, 3>::new(&mut transport, 0, false, false).unwrap_err(), Error::InvalidParam ); } @@ -739,7 +966,7 @@ mod tests { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); assert_eq!( - VirtQueue::<FakeHal, 8>::new(&mut transport, 0).unwrap_err(), + VirtQueue::<FakeHal, 8>::new(&mut transport, 0, false, false).unwrap_err(), Error::InvalidParam ); } @@ -748,9 +975,9 @@ mod tests { fn queue_already_used() { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); - VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap(); + VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap(); assert_eq!( - VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap_err(), + VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap_err(), Error::AlreadyUsed ); } @@ -759,7 +986,7 @@ mod tests { fn add_empty() { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); - let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap(); + let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap(); assert_eq!( unsafe { queue.add(&[], &mut []) }.unwrap_err(), Error::InvalidParam @@ -770,7 +997,7 @@ mod tests { fn add_too_many() { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); - let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap(); + let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap(); assert_eq!(queue.available_desc(), 4); assert_eq!( unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(), @@ -782,7 +1009,7 @@ mod tests { fn add_buffers() { let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); - let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap(); + let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap(); assert_eq!(queue.available_desc(), 4); // Add a buffer chain consisting of two device-readable parts followed by two @@ -837,4 +1064,133 @@ mod tests { ); } } + + #[cfg(feature = "alloc")] + #[test] + fn add_buffers_indirect() { + use core::ptr::slice_from_raw_parts; + + let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4); + let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); + let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true, false).unwrap(); + assert_eq!(queue.available_desc(), 4); + + // Add a buffer chain consisting of two device-readable parts followed by two + // device-writable parts. + let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap(); + + assert_eq!(queue.available_desc(), 4); + assert!(!queue.can_pop()); + + // Safe because the various parts of the queue are properly aligned, dereferenceable and + // initialised, and nothing else is accessing them at the same time. + unsafe { + let indirect_descriptor_index = (*queue.avail.as_ptr()).ring[0]; + assert_eq!(indirect_descriptor_index, token); + assert_eq!( + (*queue.desc.as_ptr())[indirect_descriptor_index as usize].len as usize, + 4 * size_of::<Descriptor>() + ); + assert_eq!( + (*queue.desc.as_ptr())[indirect_descriptor_index as usize].flags, + DescFlags::INDIRECT + ); + + let indirect_descriptors = slice_from_raw_parts( + (*queue.desc.as_ptr())[indirect_descriptor_index as usize].addr + as *const Descriptor, + 4, + ); + assert_eq!((*indirect_descriptors)[0].len, 2); + assert_eq!((*indirect_descriptors)[0].flags, DescFlags::NEXT); + assert_eq!((*indirect_descriptors)[0].next, 1); + assert_eq!((*indirect_descriptors)[1].len, 1); + assert_eq!((*indirect_descriptors)[1].flags, DescFlags::NEXT); + assert_eq!((*indirect_descriptors)[1].next, 2); + assert_eq!((*indirect_descriptors)[2].len, 2); + assert_eq!( + (*indirect_descriptors)[2].flags, + DescFlags::NEXT | DescFlags::WRITE + ); + assert_eq!((*indirect_descriptors)[2].next, 3); + assert_eq!((*indirect_descriptors)[3].len, 1); + assert_eq!((*indirect_descriptors)[3].flags, DescFlags::WRITE); + } + } + + /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed + /// notifications. + #[test] + fn add_notify() { + let mut config_space = (); + let state = Arc::new(Mutex::new(State { + queues: vec![QueueStatus::default()], + ..Default::default() + })); + let mut transport = FakeTransport { + device_type: DeviceType::Block, + max_queue_size: 4, + device_features: 0, + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap(); + + // Add a buffer chain with a single device-readable part. + unsafe { queue.add(&[&[42]], &mut []) }.unwrap(); + + // Check that the transport would be notified. + assert_eq!(queue.should_notify(), true); + + // SAFETY: the various parts of the queue are properly aligned, dereferenceable and + // initialised, and nothing else is accessing them at the same time. + unsafe { + // Suppress notifications. + (*queue.used.as_ptr()).flags = 0x01; + } + + // Check that the transport would not be notified. + assert_eq!(queue.should_notify(), false); + } + + /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed + /// notifications with the `avail_event` index. + #[test] + fn add_notify_event_idx() { + let mut config_space = (); + let state = Arc::new(Mutex::new(State { + queues: vec![QueueStatus::default()], + ..Default::default() + })); + let mut transport = FakeTransport { + device_type: DeviceType::Block, + max_queue_size: 4, + device_features: Feature::RING_EVENT_IDX.bits(), + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, true).unwrap(); + + // Add a buffer chain with a single device-readable part. + assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 0); + + // Check that the transport would be notified. + assert_eq!(queue.should_notify(), true); + + // SAFETY: the various parts of the queue are properly aligned, dereferenceable and + // initialised, and nothing else is accessing them at the same time. + unsafe { + // Suppress notifications. + (*queue.used.as_ptr()).avail_event = 1; + } + + // Check that the transport would not be notified. + assert_eq!(queue.should_notify(), false); + + // Add another buffer chain. + assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 1); + + // Check that the transport should be notified again now. + assert_eq!(queue.should_notify(), true); + } } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 3157e81..f6e9eae 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -6,8 +6,9 @@ pub mod mmio; pub mod pci; use crate::{PhysAddr, Result, PAGE_SIZE}; -use bitflags::bitflags; -use core::ptr::NonNull; +use bitflags::{bitflags, Flags}; +use core::{fmt::Debug, ops::BitAnd, ptr::NonNull}; +use log::debug; /// A VirtIO transport layer. pub trait Transport { @@ -64,17 +65,27 @@ pub trait Transport { /// Begins initializing the device. /// /// Ref: virtio 3.1.1 Device Initialization - fn begin_init(&mut self, negotiate_features: impl FnOnce(u64) -> u64) { + /// + /// Returns the negotiated set of features. + fn begin_init<F: Flags<Bits = u64> + BitAnd<Output = F> + Debug>( + &mut self, + supported_features: F, + ) -> F { self.set_status(DeviceStatus::empty()); self.set_status(DeviceStatus::ACKNOWLEDGE | DeviceStatus::DRIVER); - let features = self.read_device_features(); - self.write_driver_features(negotiate_features(features)); + let device_features = F::from_bits_truncate(self.read_device_features()); + debug!("Device features: {:?}", device_features); + let negotiated_features = device_features & supported_features; + self.write_driver_features(negotiated_features.bits()); + self.set_status( DeviceStatus::ACKNOWLEDGE | DeviceStatus::DRIVER | DeviceStatus::FEATURES_OK, ); self.set_guest_page_size(PAGE_SIZE as u32); + + negotiated_features } /// Finishes initializing the device. |