diff options
Diffstat (limited to 'src/vhost_user/slave_req_handler.rs')
-rw-r--r-- | src/vhost_user/slave_req_handler.rs | 262 |
1 files changed, 131 insertions, 131 deletions
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index 9d7ea10..402030c 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -1,16 +1,16 @@ // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::fs::File; use std::mem; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; use std::slice; use std::sync::{Arc, Mutex}; use super::connection::Endpoint; use super::message::*; -use super::slave_fs_cache::SlaveFsCacheReq; -use super::{Error, Result}; +use super::{take_single_file, Error, Result}; /// Services provided to the master by the slave with interior mutability. /// @@ -38,7 +38,7 @@ pub trait VhostUserSlaveReqHandler { fn reset_owner(&self) -> Result<()>; fn get_features(&self) -> Result<u64>; fn set_features(&self, features: u64) -> Result<()>; - fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; fn set_vring_addr( &self, @@ -51,9 +51,9 @@ pub trait VhostUserSlaveReqHandler { ) -> Result<()>; fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>; - fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>; fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>; fn set_protocol_features(&self, features: u64) -> Result<()>; @@ -61,9 +61,11 @@ pub trait VhostUserSlaveReqHandler { fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; - fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {} + fn set_slave_req_fd(&self, _vu_req: File) {} + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>; + fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>; fn get_max_mem_slots(&self) -> Result<u64>; - fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>; } @@ -76,7 +78,7 @@ pub trait VhostUserSlaveReqHandlerMut { fn reset_owner(&mut self) -> Result<()>; fn get_features(&mut self) -> Result<u64>; fn set_features(&mut self, features: u64) -> Result<()>; - fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; fn set_vring_addr( &mut self, @@ -89,9 +91,9 @@ pub trait VhostUserSlaveReqHandlerMut { ) -> Result<()>; fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; - fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; - fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>; fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; fn set_protocol_features(&mut self, features: u64) -> Result<()>; @@ -104,9 +106,14 @@ pub trait VhostUserSlaveReqHandlerMut { flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>; fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; - fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {} + fn set_slave_req_fd(&mut self, _vu_req: File) {} + fn get_inflight_fd( + &mut self, + inflight: &VhostUserInflight, + ) -> Result<(VhostUserInflight, File)>; + fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>; fn get_max_mem_slots(&mut self) -> Result<u64>; - fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; } @@ -127,8 +134,8 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().set_features(features) } - fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> { - self.lock().unwrap().set_mem_table(ctx, fds) + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> { + self.lock().unwrap().set_mem_table(ctx, files) } fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { @@ -157,15 +164,15 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().get_vring_base(index) } - fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> { self.lock().unwrap().set_vring_kick(index, fd) } - fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> { self.lock().unwrap().set_vring_call(index, fd) } - fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> { self.lock().unwrap().set_vring_err(index, fd) } @@ -193,15 +200,23 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { self.lock().unwrap().set_config(offset, buf, flags) } - fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) { + fn set_slave_req_fd(&self, vu_req: File) { self.lock().unwrap().set_slave_req_fd(vu_req) } + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> { + self.lock().unwrap().get_inflight_fd(inflight) + } + + fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> { + self.lock().unwrap().set_inflight_fd(inflight, file) + } + fn get_max_mem_slots(&self) -> Result<u64> { self.lock().unwrap().get_max_mem_slots() } - fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()> { + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> { self.lock().unwrap().add_mem_region(region, fd) } @@ -253,6 +268,11 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } } + /// Create a vhost-user slave endpoint from a connected socket. + pub fn from_stream(socket: UnixStream, backend: Arc<S>) -> Self { + Self::new(Endpoint::from_stream(socket), backend) + } + /// Create a new vhost-user slave endpoint. /// /// # Arguments @@ -286,8 +306,9 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { // . recv optional message body and payload according size field in // message header // . validate message body and optional payload - let (hdr, rfds) = self.main_sock.recv_header()?; - let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (hdr, files) = self.main_sock.recv_header()?; + self.check_attached_files(&hdr, &files)?; + let (size, buf) = match hdr.get_size() { 0 => (0, vec![0u8; 0]), len => { @@ -302,11 +323,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { match hdr.get_code() { MasterReq::SET_OWNER => { self.check_request_size(&hdr, size, 0)?; - self.backend.set_owner()?; + let res = self.backend.set_owner(); + self.send_ack_message(&hdr, res)?; } MasterReq::RESET_OWNER => { self.check_request_size(&hdr, size, 0)?; - self.backend.reset_owner()?; + let res = self.backend.reset_owner(); + self.send_ack_message(&hdr, res)?; } MasterReq::GET_FEATURES => { self.check_request_size(&hdr, size, 0)?; @@ -318,12 +341,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_FEATURES => { let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; - self.backend.set_features(msg.value)?; + let res = self.backend.set_features(msg.value); self.acked_virtio_features = msg.value; self.update_reply_ack_flag(); + self.send_ack_message(&hdr, res)?; } MasterReq::SET_MEM_TABLE => { - let res = self.set_mem_table(&hdr, size, &buf, rfds); + let res = self.set_mem_table(&hdr, size, &buf, files); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_NUM => { @@ -359,20 +383,20 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_VRING_CALL => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; - let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.set_vring_call(index, rfds); + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_call(index, file); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_KICK => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; - let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.set_vring_kick(index, rfds); + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_kick(index, file); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_ERR => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; - let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.set_vring_err(index, rfds); + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_err(index, file); self.send_ack_message(&hdr, res)?; } MasterReq::GET_PROTOCOL_FEATURES => { @@ -385,9 +409,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_PROTOCOL_FEATURES => { let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; - self.backend.set_protocol_features(msg.value)?; + let res = self.backend.set_protocol_features(msg.value); self.acked_protocol_features = msg.value; self.update_reply_ack_flag(); + self.send_ack_message(&hdr, res)?; } MasterReq::GET_QUEUE_NUM => { if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { @@ -426,14 +451,40 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { return Err(Error::InvalidOperation); } self.check_request_size(&hdr, size, hdr.get_size() as usize)?; - self.set_config(&hdr, size, &buf)?; + let res = self.set_config(size, &buf); + self.send_ack_message(&hdr, res)?; } MasterReq::SET_SLAVE_REQ_FD => { if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { return Err(Error::InvalidOperation); } self.check_request_size(&hdr, size, hdr.get_size() as usize)?; - self.set_slave_req_fd(&hdr, rfds)?; + let res = self.set_slave_req_fd(files); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_INFLIGHT_FD => { + if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + + let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; + let (inflight, file) = self.backend.get_inflight_fd(&msg)?; + let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?; + self.main_sock + .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?; + } + MasterReq::SET_INFLIGHT_FD => { + if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + let file = take_single_file(files).ok_or(Error::IncorrectFds)?; + let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; + let res = self.backend.set_inflight_fd(&msg, file); + self.send_ack_message(&hdr, res)?; } MasterReq::GET_MAX_MEM_SLOTS => { if self.acked_protocol_features @@ -454,18 +505,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { { return Err(Error::InvalidOperation); } - let fd = if let Some(fds) = &rfds { - if fds.len() != 1 { - return Err(Error::InvalidParam); - } - fds[0] - } else { + let mut files = files.ok_or(Error::InvalidParam)?; + if files.len() != 1 { return Err(Error::InvalidParam); - }; - + } let msg = self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; - let res = self.backend.add_mem_region(&msg, fd); + let res = self.backend.add_mem_region(&msg, files.swap_remove(0)); self.send_ack_message(&hdr, res)?; } MasterReq::REM_MEM_REG => { @@ -493,37 +539,28 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], - rfds: Option<Vec<RawFd>>, + files: Option<Vec<File>>, ) -> Result<()> { self.check_request_size(&hdr, size, hdr.get_size() as usize)?; // check message size is consistent let hdrsize = mem::size_of::<VhostUserMemory>(); if size < hdrsize { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) }; if !msg.is_valid() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() { - Endpoint::<MasterReq>::close_rfds(rfds); return Err(Error::InvalidMessage); } // validate number of fds matching number of memory regions - let fds = match rfds { - None => return Err(Error::InvalidMessage), - Some(fds) => { - if fds.len() != msg.num_regions as usize { - Endpoint::<MasterReq>::close_rfds(Some(fds)); - return Err(Error::InvalidMessage); - } - fds - } - }; + let files = files.ok_or(Error::InvalidMessage)?; + if files.len() != msg.num_regions as usize { + return Err(Error::InvalidMessage); + } // Validate memory regions let regions = unsafe { @@ -534,12 +571,11 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { }; for region in regions.iter() { if !region.is_valid() { - Endpoint::<MasterReq>::close_rfds(Some(fds)); return Err(Error::InvalidMessage); } } - self.backend.set_mem_table(®ions, &fds) + self.backend.set_mem_table(®ions, files) } fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { @@ -580,12 +616,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { Ok(()) } - fn set_config( - &mut self, - hdr: &VhostUserMsgHeader<MasterReq>, - size: usize, - buf: &[u8], - ) -> Result<()> { + fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> { if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() { return Err(Error::InvalidMessage); } @@ -602,35 +633,20 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { None => return Err(Error::InvalidMessage), } - let res = self.backend.set_config(msg.offset, buf, flags); - self.send_ack_message(&hdr, res)?; - Ok(()) + self.backend.set_config(msg.offset, buf, flags) } - fn set_slave_req_fd( - &mut self, - hdr: &VhostUserMsgHeader<MasterReq>, - rfds: Option<Vec<RawFd>>, - ) -> Result<()> { - if let Some(fds) = rfds { - if fds.len() == 1 { - let sock = unsafe { UnixStream::from_raw_fd(fds[0]) }; - let vu_req = SlaveFsCacheReq::from_stream(sock); - self.backend.set_slave_req_fd(vu_req); - self.send_ack_message(&hdr, Ok(())) - } else { - Err(Error::InvalidMessage) - } - } else { - Err(Error::InvalidMessage) - } + fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> { + let file = take_single_file(files).ok_or(Error::InvalidMessage)?; + self.backend.set_slave_req_fd(file); + Ok(()) } fn handle_vring_fd_request( &mut self, buf: &[u8], - rfds: Option<Vec<RawFd>>, - ) -> Result<(u8, Option<RawFd>)> { + files: Option<Vec<File>>, + ) -> Result<(u8, Option<File>)> { if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { return Err(Error::InvalidMessage); } @@ -640,28 +656,19 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } // Bits (0-7) of the payload contain the vring index. Bit 8 is the - // invalid FD flag. This flag is set when there is no file descriptor + // invalid FD flag. This bit is set when there is no file descriptor // in the ancillary data. This signals that polling will be used // instead of waiting for the call. - let nofd = (msg.value & 0x100u64) == 0x100u64; - - let mut rfd = None; - match rfds { - Some(fds) => { - if !nofd && fds.len() == 1 { - rfd = Some(fds[0]); - } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) { - Endpoint::<MasterReq>::close_rfds(Some(fds)); - return Err(Error::InvalidMessage); - } - } - None => { - if !nofd { - return Err(Error::InvalidMessage); - } - } + // If Bit 8 is unset, the data must contain a file descriptor. + let has_fd = (msg.value & 0x100u64) == 0; + + let file = take_single_file(files); + + if has_fd && file.is_none() || !has_fd && file.is_some() { + return Err(Error::InvalidMessage); } - Ok((msg.value as u8, rfd)) + + Ok((msg.value as u8, file)) } fn check_state(&self) -> Result<()> { @@ -687,29 +694,23 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { Ok(()) } - fn check_attached_rfds( + fn check_attached_files( &self, hdr: &VhostUserMsgHeader<MasterReq>, - rfds: Option<Vec<RawFd>>, - ) -> Result<Option<Vec<RawFd>>> { + files: &Option<Vec<File>>, + ) -> Result<()> { match hdr.get_code() { - MasterReq::SET_MEM_TABLE => Ok(rfds), - MasterReq::SET_VRING_CALL => Ok(rfds), - MasterReq::SET_VRING_KICK => Ok(rfds), - MasterReq::SET_VRING_ERR => Ok(rfds), - MasterReq::SET_LOG_BASE => Ok(rfds), - MasterReq::SET_LOG_FD => Ok(rfds), - MasterReq::SET_SLAVE_REQ_FD => Ok(rfds), - MasterReq::SET_INFLIGHT_FD => Ok(rfds), - MasterReq::ADD_MEM_REG => Ok(rfds), - _ => { - if rfds.is_some() { - Endpoint::<MasterReq>::close_rfds(rfds); - Err(Error::InvalidMessage) - } else { - Ok(rfds) - } - } + MasterReq::SET_MEM_TABLE + | MasterReq::SET_VRING_CALL + | MasterReq::SET_VRING_KICK + | MasterReq::SET_VRING_ERR + | MasterReq::SET_LOG_BASE + | MasterReq::SET_LOG_FD + | MasterReq::SET_SLAVE_REQ_FD + | MasterReq::SET_INFLIGHT_FD + | MasterReq::ADD_MEM_REG => Ok(()), + _ if files.is_some() => Err(Error::InvalidMessage), + _ => Ok(()), } } @@ -731,7 +732,6 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); let pflag = VhostUserProtocolFeatures::REPLY_ACK; if (self.virtio_features & vflag) != 0 - && (self.acked_virtio_features & vflag) != 0 && self.protocol_features.contains(pflag) && (self.acked_protocol_features & pflag.bits()) != 0 { @@ -774,7 +774,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { let msg = VhostUserU64::new(val); self.main_sock.send_message(&hdr, &msg, None)?; } - Ok(()) + res } fn send_reply_message<T>( |