aboutsummaryrefslogtreecommitdiff
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause

use std::{
    collections::{HashMap, HashSet, VecDeque},
    ops::Deref,
    os::unix::{
        net::UnixStream,
        prelude::{AsRawFd, RawFd},
    },
    sync::{Arc, RwLock},
};

use log::{info, warn};
use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
use vm_memory::bitmap::BitmapSlice;

use crate::{
    rxops::*,
    vhu_vsock::{
        CidMap, ConnMapKey, Error, Result, VSOCK_HOST_CID, VSOCK_OP_REQUEST, VSOCK_OP_RST,
        VSOCK_TYPE_STREAM,
    },
    vhu_vsock_thread::VhostUserVsockThread,
    vsock_conn::*,
};

pub(crate) type RawPktsQ = VecDeque<RawVsockPacket>;

pub(crate) struct RawVsockPacket {
    pub header: [u8; PKT_HEADER_SIZE],
    pub data: Vec<u8>,
}

impl RawVsockPacket {
    fn from_vsock_packet<B: BitmapSlice>(pkt: &VsockPacket<B>) -> Result<Self> {
        let mut raw_pkt = Self {
            header: [0; PKT_HEADER_SIZE],
            data: vec![0; pkt.len() as usize],
        };

        pkt.header_slice().copy_to(&mut raw_pkt.header);
        if !pkt.is_empty() {
            pkt.data_slice()
                .ok_or(Error::PktBufMissing)?
                .copy_to(raw_pkt.data.as_mut());
        }

        Ok(raw_pkt)
    }
}

pub(crate) struct VsockThreadBackend {
    /// Map of ConnMapKey objects indexed by raw file descriptors.
    pub listener_map: HashMap<RawFd, ConnMapKey>,
    /// Map of vsock connection objects indexed by ConnMapKey objects.
    pub conn_map: HashMap<ConnMapKey, VsockConnection<UnixStream>>,
    /// Queue of ConnMapKey objects indicating pending rx operations.
    pub backend_rxq: VecDeque<ConnMapKey>,
    /// Map of host-side unix streams indexed by raw file descriptors.
    pub stream_map: HashMap<i32, UnixStream>,
    /// Host side socket for listening to new connections from the host.
    host_socket_path: String,
    /// epoll for registering new host-side connections.
    epoll_fd: i32,
    /// CID of the guest.
    guest_cid: u64,
    /// Set of allocated local ports.
    pub local_port_set: HashSet<u32>,
    tx_buffer_size: u32,
    /// Maps the guest CID to the corresponding backend. Used for sibling VM communication.
    pub cid_map: Arc<RwLock<CidMap>>,
    /// Queue of raw vsock packets recieved from sibling VMs to be sent to the guest.
    pub raw_pkts_queue: Arc<RwLock<RawPktsQ>>,
    /// Set of groups assigned to the device which it is allowed to communicate with.
    groups_set: Arc<RwLock<HashSet<String>>>,
}

impl VsockThreadBackend {
    /// New instance of VsockThreadBackend.
    pub fn new(
        host_socket_path: String,
        epoll_fd: i32,
        guest_cid: u64,
        tx_buffer_size: u32,
        groups_set: Arc<RwLock<HashSet<String>>>,
        cid_map: Arc<RwLock<CidMap>>,
    ) -> Self {
        Self {
            listener_map: HashMap::new(),
            conn_map: HashMap::new(),
            backend_rxq: VecDeque::new(),
            // Need this map to prevent connected stream from closing
            // TODO: think of a better solution
            stream_map: HashMap::new(),
            host_socket_path,
            epoll_fd,
            guest_cid,
            local_port_set: HashSet::new(),
            tx_buffer_size,
            cid_map,
            raw_pkts_queue: Arc::new(RwLock::new(VecDeque::new())),
            groups_set,
        }
    }

    /// Checks if there are pending rx requests in the backend rxq.
    pub fn pending_rx(&self) -> bool {
        !self.backend_rxq.is_empty()
    }

    /// Checks if there are pending raw vsock packets to be sent to the guest.
    pub fn pending_raw_pkts(&self) -> bool {
        !self.raw_pkts_queue.read().unwrap().is_empty()
    }

    /// Deliver a vsock packet to the guest vsock driver.
    ///
    /// Returns:
    /// - `Ok(())` if the packet was successfully filled in
    /// - `Err(Error::EmptyBackendRxQ) if there was no available data
    pub fn recv_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> {
        // Pop an event from the backend_rxq
        let key = self.backend_rxq.pop_front().ok_or(Error::EmptyBackendRxQ)?;
        let conn = match self.conn_map.get_mut(&key) {
            Some(conn) => conn,
            None => {
                // assume that the connection does not exist
                return Ok(());
            }
        };

        if conn.rx_queue.peek() == Some(RxOps::Reset) {
            // Handle RST events here
            let conn = self.conn_map.remove(&key).unwrap();
            self.listener_map.remove(&conn.stream.as_raw_fd());
            self.stream_map.remove(&conn.stream.as_raw_fd());
            self.local_port_set.remove(&conn.local_port);
            VhostUserVsockThread::epoll_unregister(conn.epoll_fd, conn.stream.as_raw_fd())
                .unwrap_or_else(|err| {
                    warn!(
                        "Could not remove epoll listener for fd {:?}: {:?}",
                        conn.stream.as_raw_fd(),
                        err
                    )
                });

            // Initialize the packet header to contain a VSOCK_OP_RST operation
            pkt.set_op(VSOCK_OP_RST)
                .set_src_cid(VSOCK_HOST_CID)
                .set_dst_cid(conn.guest_cid)
                .set_src_port(conn.local_port)
                .set_dst_port(conn.peer_port)
                .set_len(0)
                .set_type(VSOCK_TYPE_STREAM)
                .set_flags(0)
                .set_buf_alloc(0)
                .set_fwd_cnt(0);

            return Ok(());
        }

        // Handle other packet types per connection
        conn.recv_pkt(pkt)?;

        Ok(())
    }

    /// Deliver a guest generated packet to its destination in the backend.
    ///
    /// Absorbs unexpected packets, handles rest to respective connection
    /// object.
    ///
    /// Returns:
    /// - always `Ok(())` if packet has been consumed correctly
    pub fn send_pkt<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) -> Result<()> {
        if pkt.src_cid() != self.guest_cid {
            warn!(
                "vsock: dropping packet with inconsistent src_cid: {:?} from guest configured with CID: {:?}",
                pkt.src_cid(), self.guest_cid
            );
            return Ok(());
        }

        let dst_cid = pkt.dst_cid();
        if dst_cid != VSOCK_HOST_CID {
            let cid_map = self.cid_map.read().unwrap();
            if cid_map.contains_key(&dst_cid) {
                let (sibling_raw_pkts_queue, sibling_groups_set, sibling_event_fd) =
                    cid_map.get(&dst_cid).unwrap();

                if self
                    .groups_set
                    .read()
                    .unwrap()
                    .is_disjoint(sibling_groups_set.read().unwrap().deref())
                {
                    info!(
                        "vsock: dropping packet for cid: {:?} due to group mismatch",
                        dst_cid
                    );
                    return Ok(());
                }

                sibling_raw_pkts_queue
                    .write()
                    .unwrap()
                    .push_back(RawVsockPacket::from_vsock_packet(pkt)?);
                let _ = sibling_event_fd.write(1);
            } else {
                warn!("vsock: dropping packet for unknown cid: {:?}", dst_cid);
            }

            return Ok(());
        }

        // TODO: Rst if packet has unsupported type
        if pkt.type_() != VSOCK_TYPE_STREAM {
            info!("vsock: dropping packet of unknown type");
            return Ok(());
        }

        let key = ConnMapKey::new(pkt.dst_port(), pkt.src_port());

        // TODO: Handle cases where connection does not exist and packet op
        // is not VSOCK_OP_REQUEST
        if !self.conn_map.contains_key(&key) {
            // The packet contains a new connection request
            if pkt.op() == VSOCK_OP_REQUEST {
                self.handle_new_guest_conn(pkt);
            } else {
                // TODO: send back RST
            }
            return Ok(());
        }

        if pkt.op() == VSOCK_OP_RST {
            // Handle an RST packet from the guest here
            let conn = self.conn_map.get(&key).unwrap();
            if conn.rx_queue.contains(RxOps::Reset.bitmask()) {
                return Ok(());
            }
            let conn = self.conn_map.remove(&key).unwrap();
            self.listener_map.remove(&conn.stream.as_raw_fd());
            self.stream_map.remove(&conn.stream.as_raw_fd());
            self.local_port_set.remove(&conn.local_port);
            VhostUserVsockThread::epoll_unregister(conn.epoll_fd, conn.stream.as_raw_fd())
                .unwrap_or_else(|err| {
                    warn!(
                        "Could not remove epoll listener for fd {:?}: {:?}",
                        conn.stream.as_raw_fd(),
                        err
                    )
                });
            return Ok(());
        }

        // Forward this packet to its listening connection
        let conn = self.conn_map.get_mut(&key).unwrap();
        conn.send_pkt(pkt)?;

        if conn.rx_queue.pending_rx() {
            // Required if the connection object adds new rx operations
            self.backend_rxq.push_back(key);
        }

        Ok(())
    }

    /// Deliver a raw vsock packet sent from a sibling VM to the guest vsock driver.
    ///
    /// Returns:
    /// - `Ok(())` if packet was successfully filled in
    /// - `Err(Error::EmptyRawPktsQueue)` if there was no available data
    pub fn recv_raw_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> {
        let raw_vsock_pkt = self
            .raw_pkts_queue
            .write()
            .unwrap()
            .pop_front()
            .ok_or(Error::EmptyRawPktsQueue)?;

        pkt.set_header_from_raw(&raw_vsock_pkt.header).unwrap();
        if !raw_vsock_pkt.data.is_empty() {
            let buf = pkt.data_slice().ok_or(Error::PktBufMissing)?;
            buf.copy_from(&raw_vsock_pkt.data);
        }

        Ok(())
    }

    /// Handle a new guest initiated connection, i.e from the peer, the guest driver.
    ///
    /// Attempts to connect to a host side unix socket listening on a path
    /// corresponding to the destination port as follows:
    /// - "{self.host_sock_path}_{local_port}""
    fn handle_new_guest_conn<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) {
        let port_path = format!("{}_{}", self.host_socket_path, pkt.dst_port());

        UnixStream::connect(port_path)
            .and_then(|stream| stream.set_nonblocking(true).map(|_| stream))
            .map_err(Error::UnixConnect)
            .and_then(|stream| self.add_new_guest_conn(stream, pkt))
            .unwrap_or_else(|_| self.enq_rst());
    }

    /// Wrapper to add new connection to relevant HashMaps.
    fn add_new_guest_conn<B: BitmapSlice>(
        &mut self,
        stream: UnixStream,
        pkt: &VsockPacket<B>,
    ) -> Result<()> {
        let conn = VsockConnection::new_peer_init(
            stream.try_clone().map_err(Error::UnixConnect)?,
            pkt.dst_cid(),
            pkt.dst_port(),
            pkt.src_cid(),
            pkt.src_port(),
            self.epoll_fd,
            pkt.buf_alloc(),
            self.tx_buffer_size,
        );
        let stream_fd = conn.stream.as_raw_fd();
        self.listener_map
            .insert(stream_fd, ConnMapKey::new(pkt.dst_port(), pkt.src_port()));

        self.conn_map
            .insert(ConnMapKey::new(pkt.dst_port(), pkt.src_port()), conn);
        self.backend_rxq
            .push_back(ConnMapKey::new(pkt.dst_port(), pkt.src_port()));

        self.stream_map.insert(stream_fd, stream);
        self.local_port_set.insert(pkt.dst_port());

        VhostUserVsockThread::epoll_register(
            self.epoll_fd,
            stream_fd,
            epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
        )?;
        Ok(())
    }

    /// Enqueue RST packets to be sent to guest.
    fn enq_rst(&mut self) {
        // TODO
        dbg!("New guest conn error: Enqueue RST");
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::vhu_vsock::{VhostUserVsockBackend, VsockConfig, VSOCK_OP_RW};
    use std::os::unix::net::UnixListener;
    use tempfile::tempdir;
    use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};

    const DATA_LEN: usize = 16;
    const CONN_TX_BUF_SIZE: u32 = 64 * 1024;
    const GROUP_NAME: &str = "default";

    #[test]
    fn test_vsock_thread_backend() {
        const CID: u64 = 3;
        const VSOCK_PEER_PORT: u32 = 1234;

        let test_dir = tempdir().expect("Could not create a temp test directory.");

        let vsock_socket_path = test_dir.path().join("test_vsock_thread_backend.vsock");
        let vsock_peer_path = test_dir.path().join("test_vsock_thread_backend.vsock_1234");

        let _listener = UnixListener::bind(&vsock_peer_path).unwrap();

        let epoll_fd = epoll::create(false).unwrap();

        let groups_set: HashSet<String> = vec![GROUP_NAME.to_string()].into_iter().collect();

        let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));

        let mut vtp = VsockThreadBackend::new(
            vsock_socket_path.display().to_string(),
            epoll_fd,
            CID,
            CONN_TX_BUF_SIZE,
            Arc::new(RwLock::new(groups_set)),
            cid_map,
        );

        assert!(!vtp.pending_rx());

        let mut pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN];
        let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);

        // SAFETY: Safe as hdr_raw and data_raw are guaranteed to be valid.
        let mut packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };

        assert_eq!(
            vtp.recv_pkt(&mut packet).unwrap_err().to_string(),
            Error::EmptyBackendRxQ.to_string()
        );

        assert!(vtp.send_pkt(&packet).is_ok());

        packet.set_type(VSOCK_TYPE_STREAM);
        assert!(vtp.send_pkt(&packet).is_ok());

        packet.set_src_cid(CID);
        packet.set_dst_cid(VSOCK_HOST_CID);
        packet.set_dst_port(VSOCK_PEER_PORT);
        assert!(vtp.send_pkt(&packet).is_ok());

        packet.set_op(VSOCK_OP_REQUEST);
        assert!(vtp.send_pkt(&packet).is_ok());

        packet.set_op(VSOCK_OP_RW);
        assert!(vtp.send_pkt(&packet).is_ok());

        packet.set_op(VSOCK_OP_RST);
        assert!(vtp.send_pkt(&packet).is_ok());

        assert!(vtp.recv_pkt(&mut packet).is_ok());

        // cleanup
        let _ = std::fs::remove_file(&vsock_peer_path);
        let _ = std::fs::remove_file(&vsock_socket_path);

        test_dir.close().unwrap();
    }

    #[test]
    fn test_vsock_thread_backend_sibling_vms() {
        const CID: u64 = 3;
        const SIBLING_CID: u64 = 4;
        const SIBLING_LISTENING_PORT: u32 = 1234;

        let test_dir = tempdir().expect("Could not create a temp test directory.");

        let vsock_socket_path = test_dir
            .path()
            .join("test_vsock_thread_backend.vsock")
            .display()
            .to_string();
        let sibling_vhost_socket_path = test_dir
            .path()
            .join("test_vsock_thread_backend_sibling.socket")
            .display()
            .to_string();
        let sibling_vsock_socket_path = test_dir
            .path()
            .join("test_vsock_thread_backend_sibling.vsock")
            .display()
            .to_string();

        let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));

        let sibling_config = VsockConfig::new(
            SIBLING_CID,
            sibling_vhost_socket_path,
            sibling_vsock_socket_path,
            CONN_TX_BUF_SIZE,
            vec!["group1", "group2", "group3"]
                .into_iter()
                .map(String::from)
                .collect(),
        );

        let sibling_backend =
            Arc::new(VhostUserVsockBackend::new(sibling_config, cid_map.clone()).unwrap());

        let epoll_fd = epoll::create(false).unwrap();

        let groups_set: HashSet<String> = vec!["groupA", "groupB", "group3"]
            .into_iter()
            .map(String::from)
            .collect();

        let mut vtp = VsockThreadBackend::new(
            vsock_socket_path,
            epoll_fd,
            CID,
            CONN_TX_BUF_SIZE,
            Arc::new(RwLock::new(groups_set)),
            cid_map,
        );

        assert!(!vtp.pending_raw_pkts());

        let mut pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN];
        let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);

        // SAFETY: Safe as hdr_raw and data_raw are guaranteed to be valid.
        let mut packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };

        assert_eq!(
            vtp.recv_raw_pkt(&mut packet).unwrap_err().to_string(),
            Error::EmptyRawPktsQueue.to_string()
        );

        packet.set_type(VSOCK_TYPE_STREAM);
        packet.set_src_cid(CID);
        packet.set_dst_cid(SIBLING_CID);
        packet.set_dst_port(SIBLING_LISTENING_PORT);
        packet.set_op(VSOCK_OP_RW);
        packet.set_len(DATA_LEN as u32);
        packet
            .data_slice()
            .unwrap()
            .copy_from(&[0xCAu8, 0xFEu8, 0xBAu8, 0xBEu8]);

        assert!(vtp.send_pkt(&packet).is_ok());
        assert!(sibling_backend.threads[0]
            .lock()
            .unwrap()
            .thread_backend
            .pending_raw_pkts());

        let mut recvd_pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN];
        let (recvd_hdr_raw, recvd_data_raw) = recvd_pkt_raw.split_at_mut(PKT_HEADER_SIZE);

        let mut recvd_packet =
            // SAFETY: Safe as recvd_hdr_raw and recvd_data_raw are guaranteed to be valid.
            unsafe { VsockPacket::new(recvd_hdr_raw, Some(recvd_data_raw)).unwrap() };

        assert!(sibling_backend.threads[0]
            .lock()
            .unwrap()
            .thread_backend
            .recv_raw_pkt(&mut recvd_packet)
            .is_ok());

        assert_eq!(recvd_packet.type_(), VSOCK_TYPE_STREAM);
        assert_eq!(recvd_packet.src_cid(), CID);
        assert_eq!(recvd_packet.dst_cid(), SIBLING_CID);
        assert_eq!(recvd_packet.dst_port(), SIBLING_LISTENING_PORT);
        assert_eq!(recvd_packet.op(), VSOCK_OP_RW);
        assert_eq!(recvd_packet.len(), DATA_LEN as u32);

        assert_eq!(recvd_data_raw[0], 0xCAu8);
        assert_eq!(recvd_data_raw[1], 0xFEu8);
        assert_eq!(recvd_data_raw[2], 0xBAu8);
        assert_eq!(recvd_data_raw[3], 0xBEu8);

        test_dir.close().unwrap();
    }
}