diff options
author | Jeongik Cha <jeongik@google.com> | 2023-09-27 08:11:48 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2023-09-27 08:11:48 +0000 |
commit | 95da31ac976324283db0ec32c35afadc84327ae5 (patch) | |
tree | c689c04bc0ee16244ba8959d8a5bd02c15e5d909 | |
parent | 0c08e13e2cbcd3ff5f93b21ea1cd5db8b4ec4973 (diff) | |
parent | 69aec0289bec504f3c58997212653c59d4151bc4 (diff) | |
download | vmm-sys-util-95da31ac976324283db0ec32c35afadc84327ae5.tar.gz |
Import vmm-sys-util am: 69aec0289b
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/vmm-sys-util/+/2754406
Change-Id: Id396d8d0876870a43e8a961a6298caebdaec0c2d
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
42 files changed, 7624 insertions, 0 deletions
diff --git a/.buildkite/pipeline.windows.yml b/.buildkite/pipeline.windows.yml new file mode 100644 index 0000000..ecfdb56 --- /dev/null +++ b/.buildkite/pipeline.windows.yml @@ -0,0 +1,66 @@ +steps: + - label: "build-msvc-x86" + commands: + - cargo build --release + retry: + automatic: true + agents: + platform: x86_64 + os: windows + plugins: + - docker#v3.7.0: + image: "lpetrut/rust_win_buildtools" + always-pull: true + + - label: "style" + command: cargo fmt --all -- --check + retry: + automatic: true + agents: + platform: x86_64 + os: windows + plugins: + - docker#v3.7.0: + image: "lpetrut/rust_win_buildtools" + always-pull: true + + - label: "unittests-msvc-x86" + commands: + - cargo test --all-features + retry: + automatic: true + agents: + platform: x86_64 + os: windows + plugins: + - docker#v3.7.0: + image: "lpetrut/rust_win_buildtools" + always-pull: true + + - label: "clippy-x86" + commands: + - cargo clippy --all + retry: + automatic: true + agents: + platform: x86_64 + os: windows + plugins: + - docker#v3.7.0: + image: "lpetrut/rust_win_buildtools" + always-pull: true + + - label: "check-warnings-x86" + commands: + - cargo check --all-targets + retry: + automatic: true + agents: + platform: x86_64 + os: windows + plugins: + - docker#v3.7.0: + image: "lpetrut/rust_win_buildtools" + always-pull: true + environment: + - "RUSTFLAGS=-D warnings" diff --git a/.cargo/config b/.cargo/config new file mode 100644 index 0000000..0c55a7c --- /dev/null +++ b/.cargo/config @@ -0,0 +1,2 @@ +[target.aarch64-unknown-linux-musl] +rustflags = [ "-C", "target-feature=+crt-static", "-C", "link-arg=-lgcc"] diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json new file mode 100644 index 0000000..e83a746 --- /dev/null +++ b/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "0e10ca98b55797a64319d746d94379f7cdf81d02" + }, + "path_in_vcs": "" +}
\ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..97b2020 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: +- package-ecosystem: gitsubmodule + directory: "/" + schedule: + interval: monthly + open-pull-requests-limit: 10 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f21ea09 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/target +.idea +**/*.rs.bk +Cargo.lock + diff --git a/Android.bp b/Android.bp new file mode 100644 index 0000000..e72fe88 --- /dev/null +++ b/Android.bp @@ -0,0 +1,17 @@ +// This file is generated by cargo2android.py --config cargo2android.json. +// Do not modify this file as changes will be overridden on upgrade. + + + +rust_library_host { + name: "libvmm_sys_util", + crate_name: "vmm_sys_util", + cargo_env_compat: true, + cargo_pkg_version: "0.11.1", + srcs: ["src/lib.rs"], + edition: "2021", + rustlibs: [ + "libbitflags-1.3.2", + "liblibc", + ], +} diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..2a08682 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,195 @@ +# Changelog +## v0.11.1 + +### Changed +- [[#178](https://github.com/rust-vmm/vmm-sys-util/issues/178)]: Fixed a bug in + `rand_bytes` that was triggering a panic when the number of bytes was not a + multiple of 4. +- [[#181](https://github.com/rust-vmm/vmm-sys-util/pull/181)]: Changed + `TempFile::new_with_prefix()` on linux to use `mkstemp` to prevent name + collisions. + +## v0.11.0 + +### Added +- Added `rand_bytes` function that generates a pseudo random vector of + `len` bytes. +- Added implementation of `std::error::Error` for `fam::Error`. + - Added derive `Eq` and `PartialEq` for error types. + +### Changed +- [[#161](https://github.com/rust-vmm/vmm-sys-util/issues/161)]: Updated the + license to BSD-3-Clause. +- Use edition 2021. 
+- [[vm-memory#199](https://github.com/rust-vmm/vm-memory/issues/199)]: Use caret + dependencies. This is the idiomatic way of specifying dependencies. + With this we reduce the risk of breaking customer code when new releases of + the dependencies are published. +- Renamed `xor_psuedo_rng_u32` to `xor_pseudo_rng_u32` to fix a typo. +- Renamed `xor_psuedo_rng_u8_alphanumerics` to `xor_pseudo_rng_u8_alphanumerics` + to fix a typo. + +## v0.10.0 + +### Added +- Added Android support by using the appropriate macro configuration when + exporting functionality. +- Derive `Debug` for `FamStructWrapper` & `EventFd`. + +### Changed +- The `ioctl_expr` is now a const function instead of a macro. + +## v0.9.0 + +### Changed +* Fixed safety for sock_ctrl_msg::raw_recvmsg() and enhanced documentation +* Fixed sock_cmsg: ensured copy_nonoverlapping safety +* [[#135](https://github.com/rust-vmm/vmm-sys-util/pull/135)]: sock_ctrl_msg: + mark recv_with_fds as unsafe + + +## v0.8.0 + +* Added set_check_for_hangup() to PollContext. +* Added writable()/has_error()/raw_events() to PollEvent. +* Derived Copy/Clone for PollWatchingEvents. +* Fixed the implementation of `write_zeroes` to use `FALLOC_FL_ZERO_RANGE` + instead of `FALLOC_FL_PUNCH_HOLE`. +* Added `write_all_zeroes` to `WriteZeroes`, which calls `write_zeroes` in a + loop until the requested length is met. +* Added a new trait, `WriteZeroesAt`, which allows giving the offset in file + instead of using the current cursor. +* Removed `max_events` from `Epoll::wait` which removes possible undefined + behavior. +* [[#104](https://github.com/rust-vmm/vmm-sys-util/issues/104)]: Fixed FAM + struct `PartialEq` implementation. +* [[#85](https://github.com/rust-vmm/vmm-sys-util/issues/85)]: Fixed FAM struct + `Clone` implementation. +* [[#99](https://github.com/rust-vmm/vmm-sys-util/issues/99)]: Validate the + maximum capacity when initializing FAM Struct. + +# v0.7.0 + +* Switched to Rust edition 2018. 
+* Added the `metric` module that provides a `Metric` interface as well as a + default implementation for `AtomicU64`. + +# v0.6.1 + +* Implemented `From<io::Error>` for `errno::Error`. + +# v0.6.0 + +* Derived Copy for EpollEvent. +* Implemented Debug for EpollEvent. +* Changed `Epoll::ctl` signature such that `EpollEvent` is passed by + value and not by reference. +* Enabled this crate to be used on other Unixes (besides Linux) by using + target_os = linux where appropriate. + +# v0.5.0 + +* Added conditionally compiled `serde` compatibility to `FamStructWrapper`, + gated by the `with-serde` feature. +* Implemented `Into<std::io::Error` for `errno::Error`. +* Added a wrapper over `libc::epoll` used for basic epoll operations. + +# v0.4.0 + +* Added Windows support for TempFile and errno::Error. +* Added `into_file` for TempFile which enables the TempFile to be used as a + regular file. +* Implemented std::error::Error for errno::Error. +* Fixed the implementation of `register_signal_handler` by allowing only + valid signal numbers. + +# v0.3.1 + +* Advertise functionality for obtaining POSIX real time signal base which is + needed to provide absolute signals in the API changed in v0.3.0. + +# v0.3.0 + +* Removed `for_vcpu` argument from `signal::register_signal_handler` and + `signal::validate_signal_num`. Users can now pass absolute values for all + valid signal numbers. +* Removed `flag` argument of `signal::register_signal_handler` public methods, + which now defaults to `libc::SA_SIGINFO`. +* Changed `TempFile::new` and `TempDir::new` to create new temporary files/ + directories inside `$TMPDIR` if set, otherwise inside `/tmp`. +* Added methods which create temporary files/directories with prefix. + +# v0.2.1 + +* Fixed the FamStructWrapper Clone implementation to avoid UB. + +# v0.2.0 + +* fam: updated the macro that generates implementions of FamStruct to + also take a parameter that specifies the name of the flexible array + member. 
+ +# v0.1.1 + +* Fixed the Cargo.toml license. +* Fixed some clippy warnings. + +# v0.1.0 + +This is the first vmm-sys-util crate release. + +It is a collection of modules implementing helpers and utilities used by +multiple rust-vmm components and rust-vmm based VMMs. +Most of the code in this first release is based on either the crosvm or the +Firecracker projects, or both. + +The first release comes with the following Rust modules: + +* aio: Safe wrapper over + [`Linux AIO`](http://man7.org/linux/man-pages/man7/aio.7.html). + +* errno: Structures, helpers, and type definitions for working with + [`errno`](http://man7.org/linux/man-pages/man3/errno.3.html). + +* eventfd: Structure and wrapper functions for working with + [`eventfd`](http://man7.org/linux/man-pages/man2/eventfd.2.html). + +* fallocate: Enum and function for dealing with an allocated disk space + by [`fallocate`](http://man7.org/linux/man-pages/man2/fallocate.2.html). + +* fam: Trait and wrapper for working with C defined FAM structures. + +* file_traits: Traits for handling file synchronization and length. + +* ioctls: Macros and functions for working with + [`ioctl`](http://man7.org/linux/man-pages/man2/ioctl.2.html). + +* poll: Traits and structures for working with + [`epoll`](http://man7.org/linux/man-pages/man7/epoll.7.html) + +* rand: Miscellaneous functions related to getting (pseudo) random + numbers and strings. + +* seek_hole: Traits and implementations over + [`lseek64`](https://linux.die.net/man/3/lseek64). + +* signal: Enums, traits and functions for working with + [`signal`](http://man7.org/linux/man-pages/man7/signal.7.html). + +* sockctrl_msg: Wrapper for sending and receiving messages with file + descriptors on sockets that accept control messages (e.g. Unix domain + sockets). + +* tempdir: Structure for handling temporary directories. + +* tempfile: Struct for handling temporary files as well as any cleanup + required. 
+ +* terminal: Trait for working with + [`termios`](http://man7.org/linux/man-pages/man3/termios.3.html). + +* timerfd: Structure and functions for working with + [`timerfd`](http://man7.org/linux/man-pages/man2/timerfd_create.2.html). + +* write_zeroes: Traits for replacing a range with a hole and writing + zeroes in a file. diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..b0440e9 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @liujing2 @sameo @andreeaflorescu @jiangliu diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..f1da15c --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,39 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. 
+ +[package] +edition = "2021" +name = "vmm-sys-util" +version = "0.11.1" +authors = ["Intel Virtualization Team <vmm-maintainers@intel.com>"] +description = "A system utility set" +readme = "README.md" +keywords = ["utils"] +license = "BSD-3-Clause" +repository = "https://github.com/rust-vmm/vmm-sys-util" +resolver = "2" +[dependencies.libc] +version = "0.2.39" + +[dependencies.serde] +version = "1.0.27" +optional = true + +[dependencies.serde_derive] +version = "1.0.27" +optional = true +[dev-dependencies.serde_json] +version = "1.0.9" + +[features] +with-serde = ["serde", "serde_derive"] +[target."cfg(any(target_os = \"linux\", target_os = \"android\"))".dependencies.bitflags] +version = "1.0" diff --git a/Cargo.toml.orig b/Cargo.toml.orig new file mode 100644 index 0000000..668a2ac --- /dev/null +++ b/Cargo.toml.orig @@ -0,0 +1,24 @@ +[package] +name = "vmm-sys-util" +version = "0.11.1" +authors = ["Intel Virtualization Team <vmm-maintainers@intel.com>"] +description = "A system utility set" +repository = "https://github.com/rust-vmm/vmm-sys-util" +readme = "README.md" +keywords = ["utils"] +edition = "2021" +license = "BSD-3-Clause" + +[features] +with-serde = ["serde", "serde_derive"] + +[dependencies] +libc = "0.2.39" +serde = { version = "1.0.27", optional = true } +serde_derive = { version = "1.0.27", optional = true } + +[target.'cfg(any(target_os = "linux", target_os = "android"))'.dependencies] +bitflags = "1.0" + +[dev-dependencies] +serde_json = "1.0.9" @@ -0,0 +1 @@ +LICENSE-BSD-3-Clause
\ No newline at end of file diff --git a/LICENSE-BSD-3-Clause b/LICENSE-BSD-3-Clause new file mode 100644 index 0000000..8bafca3 --- /dev/null +++ b/LICENSE-BSD-3-Clause @@ -0,0 +1,27 @@ +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/METADATA b/METADATA new file mode 100644 index 0000000..1c8b946 --- /dev/null +++ b/METADATA @@ -0,0 +1,19 @@ +name: "vmm-sys-util" +description: "A system utility set" +third_party { + identifier { + type: "crates.io" + value: "https://crates.io/crates/vmm-sys-util" + } + identifier { + type: "Archive" + value: "https://static.crates.io/crates/vmm-sys-util/vmm-sys-util-0.11.1.crate" + } + version: "0.11.1" + license_type: NOTICE + last_upgrade_date { + year: 2023 + month: 8 + day: 23 + } +} diff --git a/MODULE_LICENSE_BSD_LIKE b/MODULE_LICENSE_BSD_LIKE new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/MODULE_LICENSE_BSD_LIKE @@ -0,0 +1 @@ +include platform/prebuilts/rust:master:/OWNERS diff --git a/README.md b/README.md new file mode 100644 index 0000000..293a11a --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ +# vmm-sys-util + +[![crates.io](https://img.shields.io/crates/v/vmm-sys-util)](https://crates.io/crates/vmm-sys-util) +[![docs.rs](https://img.shields.io/docsrs/vmm-sys-util)](https://docs.rs/vmm-sys-util/) + +This crate is a collection of modules that provides helpers and utilities +used by multiple [rust-vmm](https://github.com/rust-vmm/community) components. + +The crate implements safe wrappers around common utilities for working +with files, event file descriptors, ioctls and others. + +## Support + +**Platforms**: +- x86_64 +- aarch64 + +**Operating Systems**: +- Linux +- Windows (partial support) + +## License + +This code is licensed under [BSD-3-Clause](LICENSE-BSD-3-Clause). diff --git a/cargo2android.json b/cargo2android.json new file mode 100644 index 0000000..1f2fa38 --- /dev/null +++ b/cargo2android.json @@ -0,0 +1,6 @@ +{ + "run": true, + "dep-suffixes": { + "bitflags": "-1.3.2" + } +}
\ No newline at end of file diff --git a/coverage_config_aarch64.json b/coverage_config_aarch64.json new file mode 100644 index 0000000..6b35952 --- /dev/null +++ b/coverage_config_aarch64.json @@ -0,0 +1,5 @@ +{ + "coverage_score": 84.3, + "exclude_path": "", + "crate_features": "" +} diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json new file mode 100644 index 0000000..4277fc0 --- /dev/null +++ b/coverage_config_x86_64.json @@ -0,0 +1,5 @@ +{ + "coverage_score": 87.1, + "exclude_path": "", + "crate_features": "" +} diff --git a/src/errno.rs b/src/errno.rs new file mode 100644 index 0000000..27b9bcf --- /dev/null +++ b/src/errno.rs @@ -0,0 +1,193 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Structures, helpers, and type definitions for working with +//! [`errno`](http://man7.org/linux/man-pages/man3/errno.3.html). + +use std::fmt::{Display, Formatter}; +use std::io; +use std::result; + +/// Wrapper over [`errno`](http://man7.org/linux/man-pages/man3/errno.3.html). +/// +/// The error number is an integer number set by system calls and some libc +/// functions in case of error. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Error(i32); + +/// A specialized [Result](https://doc.rust-lang.org/std/result/enum.Result.html) type +/// for operations that can return `errno`. +/// +/// This typedef is generally used to avoid writing out `errno::Error` directly and is +/// otherwise a direct mapping to `Result`. +pub type Result<T> = result::Result<T, Error>; + +impl Error { + /// Creates a new error from the given error number. + /// + /// # Arguments + /// + /// * `errno`: error number used for creating the `Error`. 
+ /// + /// # Examples + /// + /// ``` + /// # extern crate libc; + /// extern crate vmm_sys_util; + /// # + /// # use libc; + /// use vmm_sys_util::errno::Error; + /// + /// let err = Error::new(libc::EIO); + /// ``` + pub fn new(errno: i32) -> Error { + Error(errno) + } + + /// Returns the last occurred `errno` wrapped in an `Error`. + /// + /// Calling `Error::last()` is the equivalent of using + /// [`errno`](http://man7.org/linux/man-pages/man3/errno.3.html) in C/C++. + /// The result of this function only has meaning after a libc call or syscall + /// where `errno` was set. + /// + /// # Examples + /// + /// ``` + /// # extern crate libc; + /// extern crate vmm_sys_util; + /// # + /// # use libc; + /// # use std::fs::File; + /// # use std::io::{self, Read}; + /// # use std::env::temp_dir; + /// use vmm_sys_util::errno::Error; + /// # + /// // Reading from a file without permissions returns an error. + /// let mut path = temp_dir(); + /// path.push("test"); + /// let mut file = File::create(path).unwrap(); + /// let mut buf: Vec<u8> = Vec::new(); + /// assert!(file.read_to_end(&mut buf).is_err()); + /// + /// // Retrieve the error number of the previous operation using `Error::last()`: + /// let read_err = Error::last(); + /// #[cfg(unix)] + /// assert_eq!(read_err, Error::new(libc::EBADF)); + /// #[cfg(not(unix))] + /// assert_eq!(read_err, Error::new(libc::EIO)); + /// ``` + pub fn last() -> Error { + // It's safe to unwrap because this `Error` was constructed via `last_os_error`. + Error(io::Error::last_os_error().raw_os_error().unwrap()) + } + + /// Returns the raw integer value (`errno`) corresponding to this Error. 
+ /// + /// # Examples + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::errno::Error; + /// + /// let err = Error::new(13); + /// assert_eq!(err.errno(), 13); + /// ``` + pub fn errno(self) -> i32 { + self.0 + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + io::Error::from_raw_os_error(self.0).fmt(f) + } +} + +impl std::error::Error for Error {} + +impl From<io::Error> for Error { + fn from(e: io::Error) -> Self { + Error::new(e.raw_os_error().unwrap_or_default()) + } +} + +impl From<Error> for io::Error { + fn from(err: Error) -> io::Error { + io::Error::from_raw_os_error(err.0) + } +} + +/// Returns the last `errno` as a [`Result`] that is always an error. +/// +/// [`Result`]: type.Result.html +pub fn errno_result<T>() -> Result<T> { + Err(Error::last()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env::temp_dir; + use std::error::Error as _; + use std::fs::OpenOptions; + use std::io::{self, Read}; + + #[test] + pub fn test_errno() { + #[cfg(unix)] + let expected_errno = libc::EBADF; + #[cfg(not(unix))] + let expected_errno = libc::EIO; + + // try to read from a file without read permissions + let mut path = temp_dir(); + path.push("test"); + let mut file = OpenOptions::new() + .read(false) + .write(true) + .create(true) + .truncate(true) + .open(path) + .unwrap(); + let mut buf: Vec<u8> = Vec::new(); + assert!(file.read_to_end(&mut buf).is_err()); + + // Test that errno_result returns Err and the error is the expected one. + let last_err = errno_result::<i32>().unwrap_err(); + assert_eq!(last_err, Error::new(expected_errno)); + + // Test that the inner value of `Error` corresponds to expected_errno. + assert_eq!(last_err.errno(), expected_errno); + assert!(last_err.source().is_none()); + + // Test creating an `Error` from a `std::io::Error`. + assert_eq!(last_err, Error::from(io::Error::last_os_error())); + + // Test that calling `last()` returns the same error as `errno_result()`. 
+ assert_eq!(last_err, Error::last()); + + let last_err: io::Error = last_err.into(); + // Test creating a `std::io::Error` from an `Error` + assert_eq!(io::Error::last_os_error().kind(), last_err.kind()); + } + + #[test] + pub fn test_display() { + // Test the display implementation. + #[cfg(target_os = "linux")] + assert_eq!( + format!("{}", Error::new(libc::EBADF)), + "Bad file descriptor (os error 9)" + ); + #[cfg(not(unix))] + assert_eq!( + format!("{}", Error::new(libc::EIO)), + "Access is denied. (os error 5)" + ); + } +} diff --git a/src/fam.rs b/src/fam.rs new file mode 100644 index 0000000..0d62b0f --- /dev/null +++ b/src/fam.rs @@ -0,0 +1,1048 @@ +// Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Trait and wrapper for working with C defined FAM structures. +//! +//! In C 99 an array of unknown size may appear within a struct definition as the last member +//! (as long as there is at least one other named member). +//! This is known as a flexible array member (FAM). +//! Pre C99, the same behavior could be achieved using zero length arrays. +//! +//! Flexible Array Members are the go-to choice for working with large amounts of data +//! prefixed by header values. +//! +//! For example the KVM API has many structures of this kind. + +#[cfg(feature = "with-serde")] +use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor}; +#[cfg(feature = "with-serde")] +use serde::{ser::SerializeTuple, Serialize, Serializer}; +use std::fmt; +#[cfg(feature = "with-serde")] +use std::marker::PhantomData; +use std::mem::{self, size_of}; + +/// Errors associated with the [`FamStructWrapper`](struct.FamStructWrapper.html) struct. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Error { + /// The max size has been exceeded + SizeLimitExceeded, +} + +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::SizeLimitExceeded => write!(f, "The max size has been exceeded"), + } + } +} + +/// Trait for accessing properties of C defined FAM structures. +/// +/// # Safety +/// +/// This is unsafe due to the number of constraints that aren't checked: +/// * the implementer should be a POD +/// * the implementor should contain a flexible array member of elements of type `Entry` +/// * `Entry` should be a POD +/// +/// Violating these may cause problems. +/// +/// # Example +/// +/// ``` +/// use vmm_sys_util::fam::*; +/// +/// #[repr(C)] +/// #[derive(Default)] +/// pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]); +/// impl<T> __IncompleteArrayField<T> { +/// #[inline] +/// pub fn new() -> Self { +/// __IncompleteArrayField(::std::marker::PhantomData, []) +/// } +/// #[inline] +/// pub unsafe fn as_ptr(&self) -> *const T { +/// ::std::mem::transmute(self) +/// } +/// #[inline] +/// pub unsafe fn as_mut_ptr(&mut self) -> *mut T { +/// ::std::mem::transmute(self) +/// } +/// #[inline] +/// pub unsafe fn as_slice(&self, len: usize) -> &[T] { +/// ::std::slice::from_raw_parts(self.as_ptr(), len) +/// } +/// #[inline] +/// pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] { +/// ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len) +/// } +/// } +/// +/// #[repr(C)] +/// #[derive(Default)] +/// struct MockFamStruct { +/// pub len: u32, +/// pub padding: u32, +/// pub entries: __IncompleteArrayField<u32>, +/// } +/// +/// unsafe impl FamStruct for MockFamStruct { +/// type Entry = u32; +/// +/// fn len(&self) -> usize { +/// self.len as usize +/// } +/// +/// fn set_len(&mut self, len: usize) { +/// self.len = len as u32 +/// } +/// +/// fn max_len() -> usize { +/// 100 
+/// } +/// +/// fn as_slice(&self) -> &[u32] { +/// let len = self.len(); +/// unsafe { self.entries.as_slice(len) } +/// } +/// +/// fn as_mut_slice(&mut self) -> &mut [u32] { +/// let len = self.len(); +/// unsafe { self.entries.as_mut_slice(len) } +/// } +/// } +/// +/// type MockFamStructWrapper = FamStructWrapper<MockFamStruct>; +/// ``` +#[allow(clippy::len_without_is_empty)] +pub unsafe trait FamStruct { + /// The type of the FAM entries + type Entry: PartialEq + Copy; + + /// Get the FAM length + /// + /// These type of structures contain a member that holds the FAM length. + /// This method will return the value of that member. + fn len(&self) -> usize; + + /// Set the FAM length + /// + /// These type of structures contain a member that holds the FAM length. + /// This method will set the value of that member. + fn set_len(&mut self, len: usize); + + /// Get max allowed FAM length + /// + /// This depends on each structure. + /// For example a structure representing the cpuid can contain at most 80 entries. + fn max_len() -> usize; + + /// Get the FAM entries as slice + fn as_slice(&self) -> &[Self::Entry]; + + /// Get the FAM entries as mut slice + fn as_mut_slice(&mut self) -> &mut [Self::Entry]; +} + +/// A wrapper for [`FamStruct`](trait.FamStruct.html). +/// +/// It helps in treating a [`FamStruct`](trait.FamStruct.html) similarly to an actual `Vec`. +#[derive(Debug)] +pub struct FamStructWrapper<T: Default + FamStruct> { + // This variable holds the FamStruct structure. We use a `Vec<T>` to make the allocation + // large enough while still being aligned for `T`. Only the first element of `Vec<T>` + // will actually be used as a `T`. The remaining memory in the `Vec<T>` is for `entries`, + // which must be contiguous. Since the entries are of type `FamStruct::Entry` we must + // be careful to convert the desired capacity of the `FamStructWrapper` + // from `FamStruct::Entry` to `T` when reserving or releasing memory. 
+ mem_allocator: Vec<T>, +} + +impl<T: Default + FamStruct> FamStructWrapper<T> { + /// Convert FAM len to `mem_allocator` len. + /// + /// Get the capacity required by mem_allocator in order to hold + /// the provided number of [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry). + fn mem_allocator_len(fam_len: usize) -> usize { + let wrapper_size_in_bytes = size_of::<T>() + fam_len * size_of::<T::Entry>(); + (wrapper_size_in_bytes + size_of::<T>() - 1) / size_of::<T>() + } + + /// Convert `mem_allocator` len to FAM len. + /// + /// Get the number of elements of type + /// [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry) + /// that fit in a mem_allocator of provided len. + fn fam_len(mem_allocator_len: usize) -> usize { + if mem_allocator_len == 0 { + return 0; + } + + let array_size_in_bytes = (mem_allocator_len - 1) * size_of::<T>(); + array_size_in_bytes / size_of::<T::Entry>() + } + + /// Create a new FamStructWrapper with `num_elements` elements. + /// + /// The elements will be zero-initialized. The type of the elements will be + /// [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry). + /// + /// # Arguments + /// + /// * `num_elements` - The number of elements in the FamStructWrapper. + /// + /// # Errors + /// + /// When `num_elements` is greater than the max possible len, it returns + /// `Error::SizeLimitExceeded`. + pub fn new(num_elements: usize) -> Result<FamStructWrapper<T>, Error> { + if num_elements > T::max_len() { + return Err(Error::SizeLimitExceeded); + } + let required_mem_allocator_capacity = + FamStructWrapper::<T>::mem_allocator_len(num_elements); + + let mut mem_allocator = Vec::with_capacity(required_mem_allocator_capacity); + mem_allocator.push(T::default()); + for _ in 1..required_mem_allocator_capacity { + // SAFETY: Safe as long T follows the requirements of being POD. 
+ mem_allocator.push(unsafe { mem::zeroed() }) + } + mem_allocator[0].set_len(num_elements); + + Ok(FamStructWrapper { mem_allocator }) + } + + /// Create a new FamStructWrapper from a slice of elements. + /// + /// # Arguments + /// + /// * `entries` - The slice of [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry) + /// entries. + /// + /// # Errors + /// + /// When the size of `entries` is greater than the max possible len, it returns + /// `Error::SizeLimitExceeded`. + pub fn from_entries(entries: &[T::Entry]) -> Result<FamStructWrapper<T>, Error> { + let mut adapter = FamStructWrapper::<T>::new(entries.len())?; + + { + let wrapper_entries = adapter.as_mut_fam_struct().as_mut_slice(); + wrapper_entries.copy_from_slice(entries); + } + + Ok(adapter) + } + + /// Create a new FamStructWrapper from the raw content represented as `Vec<T>`. + /// + /// Sometimes we already have the raw content of an FAM struct represented as `Vec<T>`, + /// and want to use the FamStructWrapper as accessors. + /// + /// # Arguments + /// + /// * `content` - The raw content represented as `Vec[T]`. + /// + /// # Safety + /// + /// This function is unsafe because the caller needs to ensure that the raw content is + /// correctly layed out. + pub unsafe fn from_raw(content: Vec<T>) -> Self { + FamStructWrapper { + mem_allocator: content, + } + } + + /// Consume the FamStructWrapper and return the raw content as `Vec<T>`. + pub fn into_raw(self) -> Vec<T> { + self.mem_allocator + } + + /// Get a reference to the actual [`FamStruct`](trait.FamStruct.html) instance. + pub fn as_fam_struct_ref(&self) -> &T { + &self.mem_allocator[0] + } + + /// Get a mut reference to the actual [`FamStruct`](trait.FamStruct.html) instance. + pub fn as_mut_fam_struct(&mut self) -> &mut T { + &mut self.mem_allocator[0] + } + + /// Get a pointer to the [`FamStruct`](trait.FamStruct.html) instance. 
+ /// + /// The caller must ensure that the fam_struct outlives the pointer this + /// function returns, or else it will end up pointing to garbage. + /// + /// Modifying the container referenced by this pointer may cause its buffer + /// to be reallocated, which would also make any pointers to it invalid. + pub fn as_fam_struct_ptr(&self) -> *const T { + self.as_fam_struct_ref() + } + + /// Get a mutable pointer to the [`FamStruct`](trait.FamStruct.html) instance. + /// + /// The caller must ensure that the fam_struct outlives the pointer this + /// function returns, or else it will end up pointing to garbage. + /// + /// Modifying the container referenced by this pointer may cause its buffer + /// to be reallocated, which would also make any pointers to it invalid. + pub fn as_mut_fam_struct_ptr(&mut self) -> *mut T { + self.as_mut_fam_struct() + } + + /// Get the elements slice. + pub fn as_slice(&self) -> &[T::Entry] { + self.as_fam_struct_ref().as_slice() + } + + /// Get the mutable elements slice. + pub fn as_mut_slice(&mut self) -> &mut [T::Entry] { + self.as_mut_fam_struct().as_mut_slice() + } + + /// Get the number of elements of type `FamStruct::Entry` currently in the vec. + fn len(&self) -> usize { + self.as_fam_struct_ref().len() + } + + /// Get the capacity of the `FamStructWrapper` + /// + /// The capacity is measured in elements of type `FamStruct::Entry`. + fn capacity(&self) -> usize { + FamStructWrapper::<T>::fam_len(self.mem_allocator.capacity()) + } + + /// Reserve additional capacity. + /// + /// Reserve capacity for at least `additional` more + /// [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry) elements. + /// + /// If the capacity is already reserved, this method doesn't do anything. + /// If not this will trigger a reallocation of the underlying buffer. 
+ fn reserve(&mut self, additional: usize) { + let desired_capacity = self.len() + additional; + if desired_capacity <= self.capacity() { + return; + } + + let current_mem_allocator_len = self.mem_allocator.len(); + let required_mem_allocator_len = FamStructWrapper::<T>::mem_allocator_len(desired_capacity); + let additional_mem_allocator_len = required_mem_allocator_len - current_mem_allocator_len; + + self.mem_allocator.reserve(additional_mem_allocator_len); + } + + /// Update the length of the FamStructWrapper. + /// + /// The length of `self` will be updated to the specified value. + /// The length of the `T` structure and of `self.mem_allocator` will be updated accordingly. + /// If the len is increased additional capacity will be reserved. + /// If the len is decreased the unnecessary memory will be deallocated. + /// + /// This method might trigger reallocations of the underlying buffer. + /// + /// # Errors + /// + /// When len is greater than the max possible len it returns Error::SizeLimitExceeded. + fn set_len(&mut self, len: usize) -> Result<(), Error> { + let additional_elements: isize = len as isize - self.len() as isize; + // If len == self.len there's nothing to do. + if additional_elements == 0 { + return Ok(()); + } + + // If the len needs to be increased: + if additional_elements > 0 { + // Check if the new len is valid. + if len > T::max_len() { + return Err(Error::SizeLimitExceeded); + } + // Reserve additional capacity. + self.reserve(additional_elements as usize); + } + + let current_mem_allocator_len = self.mem_allocator.len(); + let required_mem_allocator_len = FamStructWrapper::<T>::mem_allocator_len(len); + // Update the len of the `mem_allocator`. + // SAFETY: This is safe since enough capacity has been reserved. + unsafe { + self.mem_allocator.set_len(required_mem_allocator_len); + } + // Zero-initialize the additional elements if any. 
+ for i in current_mem_allocator_len..required_mem_allocator_len { + // SAFETY: Safe as long as the trait is only implemented for POD. This is a requirement + // for the trait implementation. + self.mem_allocator[i] = unsafe { mem::zeroed() } + } + // Update the len of the underlying `FamStruct`. + self.as_mut_fam_struct().set_len(len); + + // If the len needs to be decreased, deallocate unnecessary memory + if additional_elements < 0 { + self.mem_allocator.shrink_to_fit(); + } + + Ok(()) + } + + /// Append an element. + /// + /// # Arguments + /// + /// * `entry` - The element that will be appended to the end of the collection. + /// + /// # Errors + /// + /// When len is already equal to max possible len it returns Error::SizeLimitExceeded. + pub fn push(&mut self, entry: T::Entry) -> Result<(), Error> { + let new_len = self.len() + 1; + self.set_len(new_len)?; + self.as_mut_slice()[new_len - 1] = entry; + + Ok(()) + } + + /// Retain only the elements specified by the predicate. + /// + /// # Arguments + /// + /// * `f` - The function used to evaluate whether an entry will be kept or not. + /// When `f` returns `true` the entry is kept. 
+ pub fn retain<P>(&mut self, mut f: P) + where + P: FnMut(&T::Entry) -> bool, + { + let mut num_kept_entries = 0; + { + let entries = self.as_mut_slice(); + for entry_idx in 0..entries.len() { + let keep = f(&entries[entry_idx]); + if keep { + entries[num_kept_entries] = entries[entry_idx]; + num_kept_entries += 1; + } + } + } + + // This is safe since this method is not increasing the len + self.set_len(num_kept_entries).expect("invalid length"); + } +} + +impl<T: Default + FamStruct + PartialEq> PartialEq for FamStructWrapper<T> { + fn eq(&self, other: &FamStructWrapper<T>) -> bool { + self.as_fam_struct_ref() == other.as_fam_struct_ref() && self.as_slice() == other.as_slice() + } +} + +impl<T: Default + FamStruct> Clone for FamStructWrapper<T> { + fn clone(&self) -> Self { + // The number of entries (self.as_slice().len()) can't be > T::max_len() since `self` is a + // valid `FamStructWrapper`. + let required_mem_allocator_capacity = + FamStructWrapper::<T>::mem_allocator_len(self.as_slice().len()); + + let mut mem_allocator = Vec::with_capacity(required_mem_allocator_capacity); + + // SAFETY: This is safe as long as the requirements for the `FamStruct` trait to be safe + // are met (the implementing type and the entries elements are POD, therefore `Copy`, so + // memory safety can't be violated by the ownership of `fam_struct`). It is also safe + // because we're trying to read a T from a `&T` that is pointing to a properly initialized + // and aligned T. + unsafe { + let fam_struct: T = std::ptr::read(self.as_fam_struct_ref()); + mem_allocator.push(fam_struct); + } + for _ in 1..required_mem_allocator_capacity { + mem_allocator.push( + // SAFETY: This is safe as long as T respects the FamStruct trait and is a POD. 
+ unsafe { mem::zeroed() }, + ) + } + + let mut adapter = FamStructWrapper { mem_allocator }; + { + let wrapper_entries = adapter.as_mut_fam_struct().as_mut_slice(); + wrapper_entries.copy_from_slice(self.as_slice()); + } + adapter + } +} + +impl<T: Default + FamStruct> From<Vec<T>> for FamStructWrapper<T> { + fn from(vec: Vec<T>) -> Self { + FamStructWrapper { mem_allocator: vec } + } +} + +#[cfg(feature = "with-serde")] +impl<T: Default + FamStruct + Serialize> Serialize for FamStructWrapper<T> +where + <T as FamStruct>::Entry: serde::Serialize, +{ + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: Serializer, + { + let mut s = serializer.serialize_tuple(2)?; + s.serialize_element(self.as_fam_struct_ref())?; + s.serialize_element(self.as_slice())?; + s.end() + } +} + +#[cfg(feature = "with-serde")] +impl<'de, T: Default + FamStruct + Deserialize<'de>> Deserialize<'de> for FamStructWrapper<T> +where + <T as FamStruct>::Entry: std::marker::Copy + serde::Deserialize<'de>, +{ + fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> + where + D: Deserializer<'de>, + { + struct FamStructWrapperVisitor<X> { + dummy: PhantomData<X>, + } + + impl<'de, X: Default + FamStruct + Deserialize<'de>> Visitor<'de> for FamStructWrapperVisitor<X> + where + <X as FamStruct>::Entry: std::marker::Copy + serde::Deserialize<'de>, + { + type Value = FamStructWrapper<X>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("FamStructWrapper") + } + + fn visit_seq<V>(self, mut seq: V) -> Result<FamStructWrapper<X>, V::Error> + where + V: SeqAccess<'de>, + { + use serde::de::Error; + + let header = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let entries: Vec<X::Entry> = seq + .next_element()? 
+ .ok_or_else(|| de::Error::invalid_length(1, &self))?; + + let mut result: Self::Value = FamStructWrapper::from_entries(entries.as_slice()) + .map_err(|e| V::Error::custom(format!("{:?}", e)))?; + result.mem_allocator[0] = header; + Ok(result) + } + } + + deserializer.deserialize_tuple(2, FamStructWrapperVisitor { dummy: PhantomData }) + } +} + +/// Generate `FamStruct` implementation for structs with flexible array member. +#[macro_export] +macro_rules! generate_fam_struct_impl { + ($struct_type: ty, $entry_type: ty, $entries_name: ident, + $field_type: ty, $field_name: ident, $max: expr) => { + unsafe impl FamStruct for $struct_type { + type Entry = $entry_type; + + fn len(&self) -> usize { + self.$field_name as usize + } + + fn set_len(&mut self, len: usize) { + self.$field_name = len as $field_type; + } + + fn max_len() -> usize { + $max + } + + fn as_slice(&self) -> &[<Self as FamStruct>::Entry] { + let len = self.len(); + unsafe { self.$entries_name.as_slice(len) } + } + + fn as_mut_slice(&mut self) -> &mut [<Self as FamStruct>::Entry] { + let len = self.len(); + unsafe { self.$entries_name.as_mut_slice(len) } + } + } + }; +} + +#[cfg(test)] +mod tests { + #![allow(clippy::undocumented_unsafe_blocks)] + #[cfg(feature = "with-serde")] + use serde_derive::{Deserialize, Serialize}; + + use super::*; + + const MAX_LEN: usize = 100; + + #[repr(C)] + #[derive(Default, PartialEq, Eq)] + pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]); + impl<T> __IncompleteArrayField<T> { + #[inline] + pub fn new() -> Self { + __IncompleteArrayField(::std::marker::PhantomData, []) + } + #[inline] + pub unsafe fn as_ptr(&self) -> *const T { + self as *const __IncompleteArrayField<T> as *const T + } + #[inline] + pub unsafe fn as_mut_ptr(&mut self) -> *mut T { + self as *mut __IncompleteArrayField<T> as *mut T + } + #[inline] + pub unsafe fn as_slice(&self, len: usize) -> &[T] { + ::std::slice::from_raw_parts(self.as_ptr(), len) + } + #[inline] + pub 
unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] { + ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len) + } + } + + #[cfg(feature = "with-serde")] + impl<T> Serialize for __IncompleteArrayField<T> { + fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> + where + S: Serializer, + { + [0u8; 0].serialize(serializer) + } + } + + #[cfg(feature = "with-serde")] + impl<'de, T> Deserialize<'de> for __IncompleteArrayField<T> { + fn deserialize<D>(_: D) -> std::result::Result<Self, D::Error> + where + D: Deserializer<'de>, + { + Ok(__IncompleteArrayField::new()) + } + } + + #[repr(C)] + #[derive(Default, PartialEq)] + struct MockFamStruct { + pub len: u32, + pub padding: u32, + pub entries: __IncompleteArrayField<u32>, + } + + generate_fam_struct_impl!(MockFamStruct, u32, entries, u32, len, 100); + + type MockFamStructWrapper = FamStructWrapper<MockFamStruct>; + + const ENTRIES_OFFSET: usize = 2; + + const FAM_LEN_TO_MEM_ALLOCATOR_LEN: &[(usize, usize)] = &[ + (0, 1), + (1, 2), + (2, 2), + (3, 3), + (4, 3), + (5, 4), + (10, 6), + (50, 26), + (100, 51), + ]; + + const MEM_ALLOCATOR_LEN_TO_FAM_LEN: &[(usize, usize)] = &[ + (0, 0), + (1, 0), + (2, 2), + (3, 4), + (4, 6), + (5, 8), + (10, 18), + (50, 98), + (100, 198), + ]; + + #[test] + fn test_mem_allocator_len() { + for pair in FAM_LEN_TO_MEM_ALLOCATOR_LEN { + let fam_len = pair.0; + let mem_allocator_len = pair.1; + assert_eq!( + mem_allocator_len, + MockFamStructWrapper::mem_allocator_len(fam_len) + ); + } + } + + #[test] + fn test_wrapper_len() { + for pair in MEM_ALLOCATOR_LEN_TO_FAM_LEN { + let mem_allocator_len = pair.0; + let fam_len = pair.1; + assert_eq!(fam_len, MockFamStructWrapper::fam_len(mem_allocator_len)); + } + } + + #[test] + fn test_new() { + let num_entries = 10; + + let adapter = MockFamStructWrapper::new(num_entries).unwrap(); + assert_eq!(num_entries, adapter.capacity()); + + let u32_slice = unsafe { + std::slice::from_raw_parts( + adapter.as_fam_struct_ptr() as 
*const u32, + num_entries + ENTRIES_OFFSET, + ) + }; + assert_eq!(num_entries, u32_slice[0] as usize); + for entry in u32_slice[1..].iter() { + assert_eq!(*entry, 0); + } + + // It's okay to create a `FamStructWrapper` with the maximum allowed number of entries. + let adapter = MockFamStructWrapper::new(MockFamStruct::max_len()).unwrap(); + assert_eq!(MockFamStruct::max_len(), adapter.capacity()); + + assert!(matches!( + MockFamStructWrapper::new(MockFamStruct::max_len() + 1), + Err(Error::SizeLimitExceeded) + )); + } + + #[test] + fn test_from_entries() { + let num_entries: usize = 10; + + let mut entries = Vec::new(); + for i in 0..num_entries { + entries.push(i as u32); + } + + let adapter = MockFamStructWrapper::from_entries(entries.as_slice()).unwrap(); + let u32_slice = unsafe { + std::slice::from_raw_parts( + adapter.as_fam_struct_ptr() as *const u32, + num_entries + ENTRIES_OFFSET, + ) + }; + assert_eq!(num_entries, u32_slice[0] as usize); + for (i, &value) in entries.iter().enumerate().take(num_entries) { + assert_eq!(adapter.as_slice()[i], value); + } + + let mut entries = Vec::new(); + for i in 0..MockFamStruct::max_len() + 1 { + entries.push(i as u32); + } + + // Can't create a `FamStructWrapper` with a number of entries > MockFamStruct::max_len(). 
+ assert!(matches!( + MockFamStructWrapper::from_entries(entries.as_slice()), + Err(Error::SizeLimitExceeded) + )); + } + + #[test] + fn test_entries_slice() { + let num_entries = 10; + let mut adapter = MockFamStructWrapper::new(num_entries).unwrap(); + + let expected_slice = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + + { + let mut_entries_slice = adapter.as_mut_slice(); + mut_entries_slice.copy_from_slice(expected_slice); + } + + let u32_slice = unsafe { + std::slice::from_raw_parts( + adapter.as_fam_struct_ptr() as *const u32, + num_entries + ENTRIES_OFFSET, + ) + }; + assert_eq!(expected_slice, &u32_slice[ENTRIES_OFFSET..]); + assert_eq!(expected_slice, adapter.as_slice()); + } + + #[test] + fn test_reserve() { + let mut adapter = MockFamStructWrapper::new(0).unwrap(); + + // test that the right capacity is reserved + for pair in FAM_LEN_TO_MEM_ALLOCATOR_LEN { + let num_elements = pair.0; + let required_mem_allocator_len = pair.1; + + adapter.reserve(num_elements); + + assert!(adapter.mem_allocator.capacity() >= required_mem_allocator_len); + assert_eq!(0, adapter.len()); + assert!(adapter.capacity() >= num_elements); + } + + // test that when the capacity is already reserved, the method doesn't do anything + let current_capacity = adapter.capacity(); + adapter.reserve(current_capacity - 1); + assert_eq!(current_capacity, adapter.capacity()); + } + + #[test] + fn test_set_len() { + let mut desired_len = 0; + let mut adapter = MockFamStructWrapper::new(desired_len).unwrap(); + + // keep initial len + assert!(adapter.set_len(desired_len).is_ok()); + assert_eq!(adapter.len(), desired_len); + + // increase len + desired_len = 10; + assert!(adapter.set_len(desired_len).is_ok()); + // check that the len has been increased and zero-initialized elements have been added + assert_eq!(adapter.len(), desired_len); + for element in adapter.as_slice() { + assert_eq!(*element, 0_u32); + } + + // decrease len + desired_len = 5; + assert!(adapter.set_len(desired_len).is_ok()); + 
assert_eq!(adapter.len(), desired_len); + } + + #[test] + fn test_push() { + let mut adapter = MockFamStructWrapper::new(0).unwrap(); + + for i in 0..MAX_LEN { + assert!(adapter.push(i as u32).is_ok()); + assert_eq!(adapter.as_slice()[i], i as u32); + assert_eq!(adapter.len(), i + 1); + assert!( + adapter.mem_allocator.capacity() >= MockFamStructWrapper::mem_allocator_len(i + 1) + ); + } + + assert!(adapter.push(0).is_err()); + } + + #[test] + fn test_retain() { + let mut adapter = MockFamStructWrapper::new(0).unwrap(); + + let mut num_retained_entries = 0; + for i in 0..MAX_LEN { + assert!(adapter.push(i as u32).is_ok()); + if i % 2 == 0 { + num_retained_entries += 1; + } + } + + adapter.retain(|entry| entry % 2 == 0); + + for entry in adapter.as_slice().iter() { + assert_eq!(0, entry % 2); + } + assert_eq!(adapter.len(), num_retained_entries); + assert!( + adapter.mem_allocator.capacity() + >= MockFamStructWrapper::mem_allocator_len(num_retained_entries) + ); + } + + #[test] + fn test_partial_eq() { + let mut wrapper_1 = MockFamStructWrapper::new(0).unwrap(); + let mut wrapper_2 = MockFamStructWrapper::new(0).unwrap(); + let mut wrapper_3 = MockFamStructWrapper::new(0).unwrap(); + + for i in 0..MAX_LEN { + assert!(wrapper_1.push(i as u32).is_ok()); + assert!(wrapper_2.push(i as u32).is_ok()); + assert!(wrapper_3.push(0).is_ok()); + } + + assert!(wrapper_1 == wrapper_2); + assert!(wrapper_1 != wrapper_3); + } + + #[test] + fn test_clone() { + let mut adapter = MockFamStructWrapper::new(0).unwrap(); + + for i in 0..MAX_LEN { + assert!(adapter.push(i as u32).is_ok()); + } + + assert!(adapter == adapter.clone()); + } + + #[test] + fn test_raw_content() { + let data = vec![ + MockFamStruct { + len: 2, + padding: 5, + entries: __IncompleteArrayField::new(), + }, + MockFamStruct { + len: 0xA5, + padding: 0x1e, + entries: __IncompleteArrayField::new(), + }, + ]; + + let mut wrapper = unsafe { MockFamStructWrapper::from_raw(data) }; + { + let payload = wrapper.as_slice(); 
+ assert_eq!(payload[0], 0xA5); + assert_eq!(payload[1], 0x1e); + } + assert_eq!(wrapper.as_mut_fam_struct().padding, 5); + let data = wrapper.into_raw(); + assert_eq!(data[0].len, 2); + assert_eq!(data[0].padding, 5); + } + + #[cfg(feature = "with-serde")] + #[test] + fn test_ser_deser() { + #[repr(C)] + #[derive(Default, PartialEq)] + #[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))] + struct Message { + pub len: u32, + pub padding: u32, + pub value: u32, + #[cfg_attr(feature = "with-serde", serde(skip))] + pub entries: __IncompleteArrayField<u32>, + } + + generate_fam_struct_impl!(Message, u32, entries, u32, len, 100); + + type MessageFamStructWrapper = FamStructWrapper<Message>; + + let data = vec![ + Message { + len: 2, + padding: 0, + value: 42, + entries: __IncompleteArrayField::new(), + }, + Message { + len: 0xA5, + padding: 0x1e, + value: 0, + entries: __IncompleteArrayField::new(), + }, + ]; + + let wrapper = unsafe { MessageFamStructWrapper::from_raw(data) }; + let data_ser = serde_json::to_string(&wrapper).unwrap(); + assert_eq!( + data_ser, + "[{\"len\":2,\"padding\":0,\"value\":42},[165,30]]" + ); + let data_deser = + serde_json::from_str::<MessageFamStructWrapper>(data_ser.as_str()).unwrap(); + assert!(wrapper.eq(&data_deser)); + + let bad_data_ser = r#"{"foo": "bar"}"#; + assert!(serde_json::from_str::<MessageFamStructWrapper>(bad_data_ser).is_err()); + + #[repr(C)] + #[derive(Default)] + #[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))] + struct Message2 { + pub len: u32, + pub padding: u32, + pub value: u32, + #[cfg_attr(feature = "with-serde", serde(skip))] + pub entries: __IncompleteArrayField<u32>, + } + + // Maximum number of entries = 1, so the deserialization should fail because of this reason. 
+ generate_fam_struct_impl!(Message2, u32, entries, u32, len, 1); + + type Message2FamStructWrapper = FamStructWrapper<Message2>; + assert!(serde_json::from_str::<Message2FamStructWrapper>(data_ser.as_str()).is_err()); + } + + #[test] + fn test_clone_multiple_fields() { + #[derive(Default, PartialEq)] + #[repr(C)] + struct Foo { + index: u32, + length: u16, + flags: u32, + entries: __IncompleteArrayField<u32>, + } + + generate_fam_struct_impl!(Foo, u32, entries, u16, length, 100); + + type FooFamStructWrapper = FamStructWrapper<Foo>; + + let mut wrapper = FooFamStructWrapper::new(0).unwrap(); + wrapper.as_mut_fam_struct().index = 1; + wrapper.as_mut_fam_struct().flags = 2; + wrapper.as_mut_fam_struct().length = 3; + wrapper.push(3).unwrap(); + wrapper.push(14).unwrap(); + assert_eq!(wrapper.as_slice().len(), 3 + 2); + assert_eq!(wrapper.as_slice()[3], 3); + assert_eq!(wrapper.as_slice()[3 + 1], 14); + + let mut wrapper2 = wrapper.clone(); + assert_eq!( + wrapper.as_mut_fam_struct().index, + wrapper2.as_mut_fam_struct().index + ); + assert_eq!( + wrapper.as_mut_fam_struct().length, + wrapper2.as_mut_fam_struct().length + ); + assert_eq!( + wrapper.as_mut_fam_struct().flags, + wrapper2.as_mut_fam_struct().flags + ); + assert_eq!(wrapper.as_slice(), wrapper2.as_slice()); + assert_eq!( + wrapper2.as_slice().len(), + wrapper2.as_mut_fam_struct().length as usize + ); + assert!(wrapper == wrapper2); + + wrapper.as_mut_fam_struct().index = 3; + assert!(wrapper != wrapper2); + + wrapper.as_mut_fam_struct().length = 7; + assert!(wrapper != wrapper2); + + wrapper.push(1).unwrap(); + assert_eq!(wrapper.as_mut_fam_struct().length, 8); + assert!(wrapper != wrapper2); + + let mut wrapper2 = wrapper.clone(); + assert!(wrapper == wrapper2); + + // Dropping the original variable should not affect its clone. 
+ drop(wrapper); + assert_eq!(wrapper2.as_mut_fam_struct().index, 3); + assert_eq!(wrapper2.as_mut_fam_struct().length, 8); + assert_eq!(wrapper2.as_mut_fam_struct().flags, 2); + assert_eq!(wrapper2.as_slice(), [0, 0, 0, 3, 14, 0, 0, 1]); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1929816 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,24 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Collection of modules that provides helpers and utilities used by multiple +//! [rust-vmm](https://github.com/rust-vmm/community) components. + +#![deny(missing_docs)] + +#[cfg(any(target_os = "linux", target_os = "android"))] +mod linux; +#[cfg(any(target_os = "linux", target_os = "android"))] +pub use crate::linux::*; + +#[cfg(unix)] +mod unix; +#[cfg(unix)] +pub use crate::unix::*; + +pub mod errno; +pub mod fam; +pub mod metric; +pub mod rand; +pub mod syscall; +pub mod tempfile; diff --git a/src/linux/aio.rs b/src/linux/aio.rs new file mode 100644 index 0000000..1e14ea0 --- /dev/null +++ b/src/linux/aio.rs @@ -0,0 +1,362 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Safe wrapper over [`Linux AIO`](http://man7.org/linux/man-pages/man7/aio.7.html). + +#![allow(non_camel_case_types)] + +/* automatically generated by rust-bindgen from file linux/include/uapi/linux/aio_abi.h + * of commit 69973b8 and then manually edited */ + +use std::io::{Error, Result}; +use std::os::raw::{c_int, c_long, c_uint, c_ulong}; +use std::ptr::null_mut; + +type __s16 = ::std::os::raw::c_short; +type __u16 = ::std::os::raw::c_ushort; +type __u32 = ::std::os::raw::c_uint; +type __s64 = ::std::os::raw::c_longlong; +type __u64 = ::std::os::raw::c_ulonglong; + +/// Read from a file descriptor at a given offset. +pub const IOCB_CMD_PREAD: u32 = 0; +/// Write to a file descriptor at a given offset. 
+pub const IOCB_CMD_PWRITE: u32 = 1; +/// Synchronize a file's in-core metadata and data to storage device. +pub const IOCB_CMD_FSYNC: u32 = 2; +/// Synchronize a file's in-core data to storage device. +pub const IOCB_CMD_FDSYNC: u32 = 3; +/// Noop, this defined by never used by linux kernel. +pub const IOCB_CMD_NOOP: u32 = 6; +/// Read from a file descriptor at a given offset into multiple buffers. +pub const IOCB_CMD_PREADV: u32 = 7; +/// Write to a file descriptor at a given offset from multiple buffers. +pub const IOCB_CMD_PWRITEV: u32 = 8; + +/// Valid flags for the "aio_flags" member of the "struct iocb". +/// Set if the "aio_resfd" member of the "struct iocb" is valid. +pub const IOCB_FLAG_RESFD: u32 = 1; + +/// Maximum number of concurrent requests. +pub const MAX_REQUESTS: usize = 0x10000; + +/// Wrapper over the [`iocb`](https://elixir.bootlin.com/linux/v4.9/source/include/uapi/linux/aio_abi.h#L79) structure. +#[allow(missing_docs)] +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct IoControlBlock { + pub aio_data: __u64, + pub aio_key: __u32, + pub aio_reserved1: __u32, + pub aio_lio_opcode: __u16, + pub aio_reqprio: __s16, + pub aio_fildes: __u32, + pub aio_buf: __u64, + pub aio_nbytes: __u64, + pub aio_offset: __s64, + pub aio_reserved2: __u64, + pub aio_flags: __u32, + pub aio_resfd: __u32, +} + +/// Wrapper over the [`io_event`](https://elixir.bootlin.com/linux/v4.9/source/include/uapi/linux/aio_abi.h#L58) structure. +#[allow(missing_docs)] +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct IoEvent { + pub data: __u64, + pub obj: __u64, + pub res: __s64, + pub res2: __s64, +} + +/// Newtype for [`aio_context_t`](https://elixir.bootlin.com/linux/v4.9/source/include/uapi/linux/aio_abi.h#L33). +#[repr(transparent)] +#[derive(Debug)] +pub struct IoContext(::std::os::raw::c_ulong); + +impl IoContext { + /// Create a new aio context instance. 
+ /// + /// Refer to Linux [`io_setup`](http://man7.org/linux/man-pages/man2/io_setup.2.html). + /// + /// # Arguments + /// * `nr_events`: maximum number of concurrently processing IO operations. + #[allow(clippy::new_ret_no_self)] + pub fn new(nr_events: c_uint) -> Result<Self> { + if nr_events as usize > MAX_REQUESTS { + return Err(Error::from_raw_os_error(libc::EINVAL)); + } + + let mut ctx = IoContext(0); + let rc = + // SAFETY: Safe because we use valid parameters and check the result. + unsafe { libc::syscall(libc::SYS_io_setup, nr_events, &mut ctx as *mut Self) as c_int }; + if rc < 0 { + Err(Error::last_os_error()) + } else { + Ok(ctx) + } + } + + /// Submit asynchronous I/O blocks for processing. + /// + /// Refer to Linux [`io_submit`](http://man7.org/linux/man-pages/man2/io_submit.2.html). + /// + /// # Arguments + /// * `iocbs`: array of AIO control blocks, which will be submitted to the context. + /// + /// # Examples + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::aio::*; + /// # use std::fs::File; + /// # use std::os::unix::io::AsRawFd; + /// + /// let file = File::open("/dev/zero").unwrap(); + /// let ctx = IoContext::new(128).unwrap(); + /// let mut buf: [u8; 4096] = unsafe { std::mem::uninitialized() }; + /// let iocbs = [&mut IoControlBlock { + /// aio_fildes: file.as_raw_fd() as u32, + /// aio_lio_opcode: IOCB_CMD_PREAD as u16, + /// aio_buf: buf.as_mut_ptr() as u64, + /// aio_nbytes: buf.len() as u64, + /// ..Default::default() + /// }]; + /// assert_eq!(ctx.submit(&iocbs[..]).unwrap(), 1); + /// ``` + pub fn submit(&self, iocbs: &[&mut IoControlBlock]) -> Result<usize> { + // SAFETY: It's safe because parameters are valid and we have checked the result. + let rc = unsafe { + libc::syscall( + libc::SYS_io_submit, + self.0, + iocbs.len() as c_ulong, + iocbs.as_ptr(), + ) as c_int + }; + if rc < 0 { + Err(Error::last_os_error()) + } else { + Ok(rc as usize) + } + } + + /// Cancel an outstanding asynchronous I/O operation. 
+ /// + /// Refer to Linux [`io_cancel`](http://man7.org/linux/man-pages/man2/io_cancel.2.html). + /// Note: according to current Linux kernel implementation(v4.19), libc::SYS_io_cancel always + /// return failure, thus rendering it useless. + /// + /// # Arguments + /// * `iocb`: The iocb for the operation to be canceled. + /// * `result`: If the operation is successfully canceled, the event will be copied into the + /// memory pointed to by result without being placed into the completion queue. + pub fn cancel(&self, iocb: &IoControlBlock, result: &mut IoEvent) -> Result<()> { + // SAFETY: It's safe because parameters are valid and we have checked the result. + let rc = unsafe { + libc::syscall( + libc::SYS_io_cancel, + self.0, + iocb as *const IoControlBlock, + result as *mut IoEvent, + ) as c_int + }; + if rc < 0 { + Err(Error::last_os_error()) + } else { + Ok(()) + } + } + + /// Read asynchronous I/O events from the completion queue. + /// + /// Refer to Linux [`io_getevents`](http://man7.org/linux/man-pages/man2/io_getevents.2.html). + /// + /// # Arguments + /// * `min_nr`: read at least min_nr events. + /// * `events`: array to receive the io operation results. + /// * `timeout`: optional amount of time to wait for events. 
+ /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::aio::*; + /// # use std::fs::File; + /// # use std::os::unix::io::AsRawFd; + /// + /// let file = File::open("/dev/zero").unwrap(); + /// let ctx = IoContext::new(128).unwrap(); + /// let mut buf: [u8; 4096] = unsafe { std::mem::uninitialized() }; + /// let iocbs = [ + /// &mut IoControlBlock { + /// aio_fildes: file.as_raw_fd() as u32, + /// aio_lio_opcode: IOCB_CMD_PREAD as u16, + /// aio_buf: buf.as_mut_ptr() as u64, + /// aio_nbytes: buf.len() as u64, + /// ..Default::default() + /// }, + /// &mut IoControlBlock { + /// aio_fildes: file.as_raw_fd() as u32, + /// aio_lio_opcode: IOCB_CMD_PREAD as u16, + /// aio_buf: buf.as_mut_ptr() as u64, + /// aio_nbytes: buf.len() as u64, + /// ..Default::default() + /// }, + /// ]; + /// + /// let mut rc = ctx.submit(&iocbs[..]).unwrap(); + /// let mut events = [unsafe { std::mem::uninitialized::<IoEvent>() }]; + /// rc = ctx.get_events(1, &mut events, None).unwrap(); + /// assert_eq!(rc, 1); + /// assert!(events[0].res > 0); + /// rc = ctx.get_events(1, &mut events, None).unwrap(); + /// assert_eq!(rc, 1); + /// assert!(events[0].res > 0); + /// ``` + pub fn get_events( + &self, + min_nr: c_long, + events: &mut [IoEvent], + timeout: Option<&mut libc::timespec>, + ) -> Result<usize> { + let to = match timeout { + Some(val) => val as *mut libc::timespec, + None => null_mut() as *mut libc::timespec, + }; + + // SAFETY: It's safe because parameters are valid and we have checked the result. + let rc = unsafe { + libc::syscall( + libc::SYS_io_getevents, + self.0, + min_nr, + events.len() as c_long, + events.as_mut_ptr(), + to, + ) as c_int + }; + if rc < 0 { + Err(Error::last_os_error()) + } else { + Ok(rc as usize) + } + } +} + +impl Drop for IoContext { + fn drop(&mut self) { + if self.0 != 0 { + // SAFETY: It's safe because the context is created by us. 
+ let _ = unsafe { libc::syscall(libc::SYS_io_destroy, self.0) as c_int }; + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::fs::File; + use std::os::unix::io::AsRawFd; + + #[test] + fn test_new_context() { + let _ = IoContext::new(0).unwrap_err(); + } + + #[test] + fn test_cancel_request() { + let file = File::open("/dev/zero").unwrap(); + + let ctx = IoContext::new(128).unwrap(); + let mut buf: [u8; 16384] = [0u8; 16384]; + let iocbs = [&mut IoControlBlock { + aio_fildes: file.as_raw_fd() as u32, + aio_lio_opcode: IOCB_CMD_PREAD as u16, + aio_buf: buf.as_mut_ptr() as u64, + aio_nbytes: buf.len() as u64, + ..Default::default() + }]; + + let mut rc = ctx.submit(&iocbs).unwrap(); + assert_eq!(rc, 1); + + let mut result = Default::default(); + let err = ctx + .cancel(iocbs[0], &mut result) + .unwrap_err() + .raw_os_error() + .unwrap(); + assert_eq!(err, libc::EINVAL); + + let mut events = [IoEvent::default()]; + rc = ctx.get_events(1, &mut events, None).unwrap(); + assert_eq!(rc, 1); + assert!(events[0].res > 0); + } + + #[test] + fn test_read_zero() { + let file = File::open("/dev/zero").unwrap(); + + let ctx = IoContext::new(128).unwrap(); + let mut buf: [u8; 4096] = [0u8; 4096]; + let iocbs = [ + &mut IoControlBlock { + aio_fildes: file.as_raw_fd() as u32, + aio_lio_opcode: IOCB_CMD_PREAD as u16, + aio_buf: buf.as_mut_ptr() as u64, + aio_nbytes: buf.len() as u64, + ..Default::default() + }, + &mut IoControlBlock { + aio_fildes: file.as_raw_fd() as u32, + aio_lio_opcode: IOCB_CMD_PREAD as u16, + aio_buf: buf.as_mut_ptr() as u64, + aio_nbytes: buf.len() as u64, + ..Default::default() + }, + ]; + + let mut rc = ctx.submit(&iocbs[..]).unwrap(); + assert_eq!(rc, 2); + + let mut events = [IoEvent::default()]; + rc = ctx.get_events(1, &mut events, None).unwrap(); + assert_eq!(rc, 1); + assert!(events[0].res > 0); + + rc = ctx.get_events(1, &mut events, None).unwrap(); + assert_eq!(rc, 1); + assert!(events[0].res > 0); + } + + #[test] + fn 
bindgen_test_layout_io_event() { + assert_eq!( + ::std::mem::size_of::<IoEvent>(), + 32usize, + concat!("Size of: ", stringify!(IoEvent)) + ); + assert_eq!( + ::std::mem::align_of::<IoEvent>(), + 8usize, + concat!("Alignment of", stringify!(IoEvent)) + ); + } + + #[test] + fn bindgen_test_layout_iocb() { + assert_eq!( + ::std::mem::size_of::<IoControlBlock>(), + 64usize, + concat!("Size of:", stringify!(IoControlBlock)) + ); + assert_eq!( + ::std::mem::align_of::<IoControlBlock>(), + 8usize, + concat!("Alignment of", stringify!(IoControlBlock)) + ); + } +} diff --git a/src/linux/epoll.rs b/src/linux/epoll.rs new file mode 100644 index 0000000..b8e9b7b --- /dev/null +++ b/src/linux/epoll.rs @@ -0,0 +1,522 @@ +// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Safe wrappers over the +//! [`epoll`](http://man7.org/linux/man-pages/man7/epoll.7.html) API. + +use std::io; +use std::ops::{Deref, Drop}; +use std::os::unix::io::{AsRawFd, RawFd}; + +#[cfg(any(target_os = "linux", target_os = "android"))] +use bitflags::bitflags; +use libc::{ + epoll_create1, epoll_ctl, epoll_event, epoll_wait, EPOLLERR, EPOLLET, EPOLLEXCLUSIVE, EPOLLHUP, + EPOLLIN, EPOLLONESHOT, EPOLLOUT, EPOLLPRI, EPOLLRDHUP, EPOLLWAKEUP, EPOLL_CLOEXEC, + EPOLL_CTL_ADD, EPOLL_CTL_DEL, EPOLL_CTL_MOD, +}; + +use crate::syscall::SyscallReturnCode; + +/// Wrapper over `EPOLL_CTL_*` operations that can be performed on a file descriptor. +#[repr(i32)] +pub enum ControlOperation { + /// Add a file descriptor to the interest list. + Add = EPOLL_CTL_ADD, + /// Change the settings associated with a file descriptor that is + /// already in the interest list. + Modify = EPOLL_CTL_MOD, + /// Remove a file descriptor from the interest list. + Delete = EPOLL_CTL_DEL, +} + +bitflags! { + /// The type of events we can monitor a file descriptor for. + pub struct EventSet: u32 { + /// The associated file descriptor is available for read operations. 
+ const IN = EPOLLIN as u32; + /// The associated file descriptor is available for write operations. + const OUT = EPOLLOUT as u32; + /// Error condition happened on the associated file descriptor. + const ERROR = EPOLLERR as u32; + /// This can be used to detect peer shutdown when using Edge Triggered monitoring. + const READ_HANG_UP = EPOLLRDHUP as u32; + /// Sets the Edge Triggered behavior for the associated file descriptor. + /// The default behavior is Level Triggered. + const EDGE_TRIGGERED = EPOLLET as u32; + /// Hang up happened on the associated file descriptor. Note that `epoll_wait` + /// will always wait for this event and it is not necessary to set it in events. + const HANG_UP = EPOLLHUP as u32; + /// There is an exceptional condition on that file descriptor. It is mostly used to + /// set high priority for some data. + const PRIORITY = EPOLLPRI as u32; + /// The event is considered as being "processed" from the time when it is returned + /// by a call to `epoll_wait` until the next call to `epoll_wait` on the same + /// epoll file descriptor, the closure of that file descriptor, the removal of the + /// event file descriptor via EPOLL_CTL_DEL, or the clearing of EPOLLWAKEUP + /// for the event file descriptor via EPOLL_CTL_MOD. + const WAKE_UP = EPOLLWAKEUP as u32; + /// Sets the one-shot behavior for the associated file descriptor. + const ONE_SHOT = EPOLLONESHOT as u32; + /// Sets an exclusive wake up mode for the epoll file descriptor that is being + /// attached to the associated file descriptor. + /// When a wake up event occurs and multiple epoll file descriptors are attached to + /// the same target file using this mode, one or more of the epoll file descriptors + /// will receive an event with `epoll_wait`. The default here is for all those file + /// descriptors to receive an event. + const EXCLUSIVE = EPOLLEXCLUSIVE as u32; + } +} + +/// Wrapper over +/// ['libc::epoll_event'](https://doc.rust-lang.org/1.8.0/libc/struct.epoll_event.html). 
+// We are using `transparent` here to be super sure that this struct and its fields +// have the same alignment as those from the `epoll_event` struct from C. +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct EpollEvent(epoll_event); + +impl std::fmt::Debug for EpollEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{{ events: {}, data: {} }}", self.events(), self.data()) + } +} + +impl Deref for EpollEvent { + type Target = epoll_event; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Default for EpollEvent { + fn default() -> Self { + EpollEvent(epoll_event { + events: 0u32, + u64: 0u64, + }) + } +} + +impl EpollEvent { + /// Create a new epoll_event instance. + /// + /// # Arguments + /// + /// `events` - contains an event mask. + /// `data` - a user data variable. `data` field can be a fd on which + /// we want to monitor the events specified by `events`. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::epoll::{EpollEvent, EventSet}; + /// + /// let event = EpollEvent::new(EventSet::IN, 2); + /// ``` + pub fn new(events: EventSet, data: u64) -> Self { + EpollEvent(epoll_event { + events: events.bits(), + u64: data, + }) + } + + /// Returns the `events` from + /// ['libc::epoll_event'](https://doc.rust-lang.org/1.8.0/libc/struct.epoll_event.html). + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::epoll::{EpollEvent, EventSet}; + /// + /// let event = EpollEvent::new(EventSet::IN, 2); + /// assert_eq!(event.events(), 1); + /// ``` + pub fn events(&self) -> u32 { + self.events + } + + /// Returns the `EventSet` corresponding to `epoll_event.events`. + /// + /// # Panics + /// + /// Panics if `libc::epoll_event` contains invalid events. 
+ /// + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::epoll::{EpollEvent, EventSet}; + /// + /// let event = EpollEvent::new(EventSet::IN, 2); + /// assert_eq!(event.event_set(), EventSet::IN); + /// ``` + pub fn event_set(&self) -> EventSet { + // This unwrap is safe because `epoll_events` can only be user created or + // initialized by the kernel. We trust the kernel to only send us valid + // events. The user can only initialize `epoll_events` using valid events. + EventSet::from_bits(self.events()).unwrap() + } + + /// Returns the `data` from the `libc::epoll_event`. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::epoll::{EpollEvent, EventSet}; + /// + /// let event = EpollEvent::new(EventSet::IN, 2); + /// assert_eq!(event.data(), 2); + /// ``` + pub fn data(&self) -> u64 { + self.u64 + } + + /// Converts the `libc::epoll_event` data to a RawFd. + /// + /// This conversion is lossy when the data does not correspond to a RawFd + /// (data does not fit in a i32). + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::epoll::{EpollEvent, EventSet}; + /// + /// let event = EpollEvent::new(EventSet::IN, 2); + /// assert_eq!(event.fd(), 2); + /// ``` + pub fn fd(&self) -> RawFd { + self.u64 as i32 + } +} + +/// Wrapper over epoll functionality. +#[derive(Debug)] +pub struct Epoll { + epoll_fd: RawFd, +} + +impl Epoll { + /// Create a new epoll file descriptor. + pub fn new() -> io::Result<Self> { + let epoll_fd = SyscallReturnCode( + // SAFETY: Safe because the return code is transformed by `into_result` in a `Result`. + unsafe { epoll_create1(EPOLL_CLOEXEC) }, + ) + .into_result()?; + Ok(Epoll { epoll_fd }) + } + + /// Wrapper for `libc::epoll_ctl`. + /// + /// This can be used for adding, modifying or removing a file descriptor in the + /// interest list of the epoll instance. 
+ /// + /// # Arguments + /// + /// * `operation` - refers to the action to be performed on the file descriptor. + /// * `fd` - the file descriptor on which we want to perform `operation`. + /// * `event` - refers to the `epoll_event` instance that is linked to `fd`. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// + /// use std::os::unix::io::AsRawFd; + /// use vmm_sys_util::epoll::{ControlOperation, Epoll, EpollEvent, EventSet}; + /// use vmm_sys_util::eventfd::EventFd; + /// + /// let epoll = Epoll::new().unwrap(); + /// let event_fd = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + /// epoll + /// .ctl( + /// ControlOperation::Add, + /// event_fd.as_raw_fd() as i32, + /// EpollEvent::new(EventSet::OUT, event_fd.as_raw_fd() as u64), + /// ) + /// .unwrap(); + /// epoll + /// .ctl( + /// ControlOperation::Modify, + /// event_fd.as_raw_fd() as i32, + /// EpollEvent::new(EventSet::IN, 4), + /// ) + /// .unwrap(); + /// ``` + pub fn ctl(&self, operation: ControlOperation, fd: RawFd, event: EpollEvent) -> io::Result<()> { + SyscallReturnCode( + // SAFETY: Safe because we give a valid epoll file descriptor, a valid file descriptor + // to watch, as well as a valid epoll_event structure. We also check the return value. + unsafe { + epoll_ctl( + self.epoll_fd, + operation as i32, + fd, + &event as *const EpollEvent as *mut epoll_event, + ) + }, + ) + .into_empty_result() + } + + /// Wrapper for `libc::epoll_wait`. + /// Returns the number of file descriptors in the interest list that became ready + /// for I/O or `errno` if an error occurred. + /// + /// # Arguments + /// + /// * `timeout` - specifies for how long the `epoll_wait` system call will block + /// (measured in milliseconds). + /// * `events` - points to a memory area that will be used for storing the events + /// returned by `epoll_wait()` call. 
+ /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// + /// use std::os::unix::io::AsRawFd; + /// use vmm_sys_util::epoll::{ControlOperation, Epoll, EpollEvent, EventSet}; + /// use vmm_sys_util::eventfd::EventFd; + /// + /// let epoll = Epoll::new().unwrap(); + /// let event_fd = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + /// + /// let mut ready_events = vec![EpollEvent::default(); 10]; + /// epoll + /// .ctl( + /// ControlOperation::Add, + /// event_fd.as_raw_fd() as i32, + /// EpollEvent::new(EventSet::OUT, 4), + /// ) + /// .unwrap(); + /// let ev_count = epoll.wait(-1, &mut ready_events[..]).unwrap(); + /// assert_eq!(ev_count, 1); + /// ``` + pub fn wait(&self, timeout: i32, events: &mut [EpollEvent]) -> io::Result<usize> { + let events_count = SyscallReturnCode( + // SAFETY: Safe because we give a valid epoll file descriptor and an array of + // epoll_event structures that will be modified by the kernel to indicate information + // about the subset of file descriptors in the interest list. + // We also check the return value. + unsafe { + epoll_wait( + self.epoll_fd, + events.as_mut_ptr() as *mut epoll_event, + events.len() as i32, + timeout, + ) + }, + ) + .into_result()? as usize; + + Ok(events_count) + } +} + +impl AsRawFd for Epoll { + fn as_raw_fd(&self) -> RawFd { + self.epoll_fd + } +} + +impl Drop for Epoll { + fn drop(&mut self) { + // SAFETY: Safe because this fd is opened with `epoll_create` and we trust + // the kernel to give us a valid fd. 
+ unsafe { + libc::close(self.epoll_fd); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::eventfd::EventFd; + + #[test] + fn test_event_ops() { + let mut event = EpollEvent::default(); + assert_eq!(event.events(), 0); + assert_eq!(event.data(), 0); + + event = EpollEvent::new(EventSet::IN, 2); + assert_eq!(event.events(), 1); + assert_eq!(event.event_set(), EventSet::IN); + + assert_eq!(event.data(), 2); + assert_eq!(event.fd(), 2); + } + + #[test] + fn test_events_debug() { + let events = EpollEvent::new(EventSet::IN, 42); + assert_eq!(format!("{:?}", events), "{ events: 1, data: 42 }") + } + + #[test] + fn test_epoll() { + const DEFAULT__TIMEOUT: i32 = 250; + const EVENT_BUFFER_SIZE: usize = 128; + + let epoll = Epoll::new().unwrap(); + assert_eq!(epoll.epoll_fd, epoll.as_raw_fd()); + + // Let's test different scenarios for `epoll_ctl()` and `epoll_wait()` functionality. + + let event_fd_1 = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + // For EPOLLOUT to be available it is enough only to be possible to write a value of + // at least 1 to the eventfd counter without blocking. + // If we write a value greater than 0 to this counter, the fd will be available for + // EPOLLIN events too. + event_fd_1.write(1).unwrap(); + + let mut event_1 = + EpollEvent::new(EventSet::IN | EventSet::OUT, event_fd_1.as_raw_fd() as u64); + + // For EPOLL_CTL_ADD behavior we will try to add some fds with different event masks into + // the interest list of epoll instance. + assert!(epoll + .ctl( + ControlOperation::Add, + event_fd_1.as_raw_fd() as i32, + event_1 + ) + .is_ok()); + + // We can't add twice the same fd to epoll interest list. 
+ assert!(epoll + .ctl( + ControlOperation::Add, + event_fd_1.as_raw_fd() as i32, + event_1 + ) + .is_err()); + + let event_fd_2 = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + event_fd_2.write(1).unwrap(); + assert!(epoll + .ctl( + ControlOperation::Add, + event_fd_2.as_raw_fd() as i32, + // For this fd, we want an Event instance that has `data` field set to other + // value than the value of the fd and `events` without EPOLLIN type set. + EpollEvent::new(EventSet::OUT, 10) + ) + .is_ok()); + + // For the following eventfd we won't write anything to its counter, so we expect EPOLLIN + // event to not be available for this fd, even if we say that we want to monitor this type + // of event via EPOLL_CTL_ADD operation. + let event_fd_3 = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + let event_3 = EpollEvent::new(EventSet::OUT | EventSet::IN, event_fd_3.as_raw_fd() as u64); + assert!(epoll + .ctl( + ControlOperation::Add, + event_fd_3.as_raw_fd() as i32, + event_3 + ) + .is_ok()); + + // Let's check `epoll_wait()` behavior for our epoll instance. + let mut ready_events = vec![EpollEvent::default(); EVENT_BUFFER_SIZE]; + let mut ev_count = epoll.wait(DEFAULT__TIMEOUT, &mut ready_events[..]).unwrap(); + + // We expect to have 3 fds in the ready list of epoll instance. + assert_eq!(ev_count, 3); + + // Let's check also the Event values that are now returned in the ready list. + assert_eq!(ready_events[0].data(), event_fd_1.as_raw_fd() as u64); + // For this fd, `data` field was populated with random data instead of the + // corresponding fd value. + assert_eq!(ready_events[1].data(), 10); + assert_eq!(ready_events[2].data(), event_fd_3.as_raw_fd() as u64); + + // EPOLLIN and EPOLLOUT should be available for this fd. + assert_eq!( + ready_events[0].events(), + (EventSet::IN | EventSet::OUT).bits() + ); + // Only EPOLLOUT is expected because we didn't want to monitor EPOLLIN on this fd. 
+ assert_eq!(ready_events[1].events(), EventSet::OUT.bits()); + // Only EPOLLOUT too because eventfd counter value is 0 (we didn't write a value + // greater than 0 to it). + assert_eq!(ready_events[2].events(), EventSet::OUT.bits()); + + // Now we're gonna modify the Event instance for a fd to test EPOLL_CTL_MOD + // behavior. + // We create here a new Event with some events, other than those previously set, + // that we want to monitor this time on event_fd_1. + event_1 = EpollEvent::new(EventSet::OUT, 20); + assert!(epoll + .ctl( + ControlOperation::Modify, + event_fd_1.as_raw_fd() as i32, + event_1 + ) + .is_ok()); + + let event_fd_4 = EventFd::new(libc::EFD_NONBLOCK).unwrap(); + // Can't modify a fd that wasn't added to epoll interest list. + assert!(epoll + .ctl( + ControlOperation::Modify, + event_fd_4.as_raw_fd() as i32, + EpollEvent::default() + ) + .is_err()); + + let _ = epoll.wait(DEFAULT__TIMEOUT, &mut ready_events[..]).unwrap(); + + // Let's check that Event fields were indeed changed for the `event_fd_1` fd. + assert_eq!(ready_events[0].data(), 20); + // EPOLLOUT is now available for this fd as we've intended with EPOLL_CTL_MOD operation. + assert_eq!(ready_events[0].events(), EventSet::OUT.bits()); + + // Now let's set for a fd to not have any events monitored. + assert!(epoll + .ctl( + ControlOperation::Modify, + event_fd_1.as_raw_fd() as i32, + EpollEvent::default() + ) + .is_ok()); + + // In this particular case we expect to remain only with 2 fds in the ready list. + ev_count = epoll.wait(DEFAULT__TIMEOUT, &mut ready_events[..]).unwrap(); + assert_eq!(ev_count, 2); + + // Let's also delete a fd from the interest list. + assert!(epoll + .ctl( + ControlOperation::Delete, + event_fd_2.as_raw_fd() as i32, + EpollEvent::default() + ) + .is_ok()); + + // We expect to have only one fd remained in the ready list (event_fd_3). 
+ ev_count = epoll.wait(DEFAULT__TIMEOUT, &mut ready_events[..]).unwrap(); + + assert_eq!(ev_count, 1); + assert_eq!(ready_events[0].data(), event_fd_3.as_raw_fd() as u64); + assert_eq!(ready_events[0].events(), EventSet::OUT.bits()); + + // If we try to remove a fd from epoll interest list that wasn't added before it will fail. + assert!(epoll + .ctl( + ControlOperation::Delete, + event_fd_4.as_raw_fd() as i32, + EpollEvent::default() + ) + .is_err()); + } +} diff --git a/src/linux/eventfd.rs b/src/linux/eventfd.rs new file mode 100644 index 0000000..55944f6 --- /dev/null +++ b/src/linux/eventfd.rs @@ -0,0 +1,216 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Structure and wrapper functions for working with +//! [`eventfd`](http://man7.org/linux/man-pages/man2/eventfd.2.html). + +use std::fs::File; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::{io, mem, result}; + +use libc::{c_void, dup, eventfd, read, write}; + +// Reexport commonly used flags from libc. +pub use libc::{EFD_CLOEXEC, EFD_NONBLOCK, EFD_SEMAPHORE}; + +/// A safe wrapper around Linux +/// [`eventfd`](http://man7.org/linux/man-pages/man2/eventfd.2.html). +#[derive(Debug)] +pub struct EventFd { + eventfd: File, +} + +impl EventFd { + /// Create a new EventFd with an initial value. + /// + /// # Arguments + /// + /// * `flag`: The initial value used for creating the `EventFd`. + /// Refer to Linux [`eventfd`](http://man7.org/linux/man-pages/man2/eventfd.2.html). + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::eventfd::{EventFd, EFD_NONBLOCK}; + /// + /// EventFd::new(EFD_NONBLOCK).unwrap(); + /// ``` + pub fn new(flag: i32) -> result::Result<EventFd, io::Error> { + // SAFETY: This is safe because eventfd merely allocated an eventfd for + // our process and we handle the error case. 
+ let ret = unsafe { eventfd(0, flag) }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(EventFd { + // SAFETY: This is safe because we checked ret for success and know + // the kernel gave us an fd that we own. + eventfd: unsafe { File::from_raw_fd(ret) }, + }) + } + } + + /// Add a value to the eventfd's counter. + /// + /// When the addition causes the counter overflow, this would either block + /// until a [`read`](http://man7.org/linux/man-pages/man2/read.2.html) is + /// performed on the file descriptor, or fail with the + /// error EAGAIN if the file descriptor has been made nonblocking. + /// + /// # Arguments + /// + /// * `v`: the value to be added to the eventfd's counter. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::eventfd::{EventFd, EFD_NONBLOCK}; + /// + /// let evt = EventFd::new(EFD_NONBLOCK).unwrap(); + /// evt.write(55).unwrap(); + /// ``` + pub fn write(&self, v: u64) -> result::Result<(), io::Error> { + // SAFETY: This is safe because we made this fd and the pointer we pass + // can not overflow because we give the syscall's size parameter properly. + let ret = unsafe { + write( + self.as_raw_fd(), + &v as *const u64 as *const c_void, + mem::size_of::<u64>(), + ) + }; + if ret <= 0 { + Err(io::Error::last_os_error()) + } else { + Ok(()) + } + } + + /// Read a value from the eventfd. + /// + /// If the counter is zero, this would either block + /// until the counter becomes nonzero, or fail with the + /// error EAGAIN if the file descriptor has been made nonblocking. 
+ /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::eventfd::{EventFd, EFD_NONBLOCK}; + /// + /// let evt = EventFd::new(EFD_NONBLOCK).unwrap(); + /// evt.write(55).unwrap(); + /// assert_eq!(evt.read().unwrap(), 55); + /// ``` + pub fn read(&self) -> result::Result<u64, io::Error> { + let mut buf: u64 = 0; + // SAFETY: This is safe because we made this fd and the pointer we + // pass can not overflow because we give the syscall's size parameter properly. + let ret = unsafe { + read( + self.as_raw_fd(), + &mut buf as *mut u64 as *mut c_void, + mem::size_of::<u64>(), + ) + }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(buf) + } + } + + /// Clone this EventFd. + /// + /// This internally creates a new file descriptor and it will share the same + /// underlying count within the kernel. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::eventfd::{EventFd, EFD_NONBLOCK}; + /// + /// let evt = EventFd::new(EFD_NONBLOCK).unwrap(); + /// let evt_clone = evt.try_clone().unwrap(); + /// evt.write(923).unwrap(); + /// assert_eq!(evt_clone.read().unwrap(), 923); + /// ``` + pub fn try_clone(&self) -> result::Result<EventFd, io::Error> { + // SAFETY: This is safe because we made this fd and properly check that it returns + // without error. + let ret = unsafe { dup(self.as_raw_fd()) }; + if ret < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(EventFd { + // SAFETY: This is safe because we checked ret for success and know the kernel + // gave us an fd that we own. 
+ eventfd: unsafe { File::from_raw_fd(ret) }, + }) + } + } +} + +impl AsRawFd for EventFd { + fn as_raw_fd(&self) -> RawFd { + self.eventfd.as_raw_fd() + } +} + +impl FromRawFd for EventFd { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + EventFd { + eventfd: File::from_raw_fd(fd), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new() { + EventFd::new(EFD_NONBLOCK).unwrap(); + EventFd::new(0).unwrap(); + } + + #[test] + fn test_read_write() { + let evt = EventFd::new(EFD_NONBLOCK).unwrap(); + evt.write(55).unwrap(); + assert_eq!(evt.read().unwrap(), 55); + } + + #[test] + fn test_write_overflow() { + let evt = EventFd::new(EFD_NONBLOCK).unwrap(); + evt.write(std::u64::MAX - 1).unwrap(); + let r = evt.write(1); + match r { + Err(ref inner) if inner.kind() == io::ErrorKind::WouldBlock => (), + _ => panic!("Unexpected"), + } + } + #[test] + fn test_read_nothing() { + let evt = EventFd::new(EFD_NONBLOCK).unwrap(); + let r = evt.read(); + match r { + Err(ref inner) if inner.kind() == io::ErrorKind::WouldBlock => (), + _ => panic!("Unexpected"), + } + } + #[test] + fn test_clone() { + let evt = EventFd::new(EFD_NONBLOCK).unwrap(); + let evt_clone = evt.try_clone().unwrap(); + evt.write(923).unwrap(); + assert_eq!(evt_clone.read().unwrap(), 923); + } +} diff --git a/src/linux/fallocate.rs b/src/linux/fallocate.rs new file mode 100644 index 0000000..e3a7fed --- /dev/null +++ b/src/linux/fallocate.rs @@ -0,0 +1,93 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Enum and function for dealing with an allocated disk space +//! by [`fallocate`](http://man7.org/linux/man-pages/man2/fallocate.2.html). 
+ +use std::os::unix::io::AsRawFd; + +use crate::errno::{errno_result, Error, Result}; + +/// Operation to be performed on a given range when calling [`fallocate`] +/// +/// [`fallocate`]: fn.fallocate.html +pub enum FallocateMode { + /// Deallocating file space. + PunchHole, + /// Zeroing file space. + ZeroRange, +} + +/// A safe wrapper for [`fallocate`](http://man7.org/linux/man-pages/man2/fallocate.2.html). +/// +/// Manipulate the file space with specified operation parameters. +/// +/// # Arguments +/// +/// * `file`: the file to be manipulate. +/// * `mode`: specify the operation to be performed on the given range. +/// * `keep_size`: file size won't be changed even if `offset` + `len` is greater +/// than the file size. +/// * `offset`: the position that manipulates the file from. +/// * `size`: the bytes of the operation range. +/// +/// # Examples +/// +/// ``` +/// extern crate vmm_sys_util; +/// # use std::fs::OpenOptions; +/// # use std::path::PathBuf; +/// use vmm_sys_util::fallocate::{fallocate, FallocateMode}; +/// use vmm_sys_util::tempdir::TempDir; +/// +/// let tempdir = TempDir::new_with_prefix("/tmp/fallocate_test").unwrap(); +/// let mut path = PathBuf::from(tempdir.as_path()); +/// path.push("file"); +/// let mut f = OpenOptions::new() +/// .read(true) +/// .write(true) +/// .create(true) +/// .open(&path) +/// .unwrap(); +/// fallocate(&f, FallocateMode::PunchHole, true, 0, 1).unwrap(); +/// ``` +pub fn fallocate( + file: &dyn AsRawFd, + mode: FallocateMode, + keep_size: bool, + offset: u64, + len: u64, +) -> Result<()> { + let offset = if offset > libc::off64_t::max_value() as u64 { + return Err(Error::new(libc::EINVAL)); + } else { + offset as libc::off64_t + }; + + let len = if len > libc::off64_t::max_value() as u64 { + return Err(Error::new(libc::EINVAL)); + } else { + len as libc::off64_t + }; + + let mut mode = match mode { + FallocateMode::PunchHole => libc::FALLOC_FL_PUNCH_HOLE, + FallocateMode::ZeroRange => 
libc::FALLOC_FL_ZERO_RANGE, + }; + + if keep_size { + mode |= libc::FALLOC_FL_KEEP_SIZE; + } + + // SAFETY: Safe since we pass in a valid fd and fallocate mode, validate offset and len, + // and check the return value. + let ret = unsafe { libc::fallocate64(file.as_raw_fd(), mode, offset, len) }; + if ret < 0 { + errno_result() + } else { + Ok(()) + } +} diff --git a/src/linux/ioctl.rs b/src/linux/ioctl.rs new file mode 100644 index 0000000..a1e29d5 --- /dev/null +++ b/src/linux/ioctl.rs @@ -0,0 +1,410 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Macros and functions for working with +//! [`ioctl`](http://man7.org/linux/man-pages/man2/ioctl.2.html). + +use std::os::raw::{c_int, c_uint, c_ulong, c_void}; +use std::os::unix::io::AsRawFd; + +// The only reason +// [_IOC](https://elixir.bootlin.com/linux/v5.10.129/source/arch/alpha/include/uapi/asm/ioctl.h#L40) +// is a macro in C is because C doesn't have const functions, it is always better when possible to +// use a const function over a macro in Rust. +/// Function to calculate icotl number. Mimic of +/// [_IOC](https://elixir.bootlin.com/linux/v5.10.129/source/arch/alpha/include/uapi/asm/ioctl.h#L40) +/// ``` +/// # use std::os::raw::c_uint; +/// # use vmm_sys_util::ioctl::{ioctl_expr, _IOC_NONE}; +/// const KVMIO: c_uint = 0xAE; +/// ioctl_expr(_IOC_NONE, KVMIO, 0x01, 0); +/// ``` +pub const fn ioctl_expr( + dir: c_uint, + ty: c_uint, + nr: c_uint, + size: c_uint, +) -> ::std::os::raw::c_ulong { + (dir << crate::ioctl::_IOC_DIRSHIFT + | ty << crate::ioctl::_IOC_TYPESHIFT + | nr << crate::ioctl::_IOC_NRSHIFT + | size << crate::ioctl::_IOC_SIZESHIFT) as ::std::os::raw::c_ulong +} + +/// Declare a function that returns an ioctl number. 
+/// +/// ``` +/// # #[macro_use] extern crate vmm_sys_util; +/// # use std::os::raw::c_uint; +/// use vmm_sys_util::ioctl::_IOC_NONE; +/// +/// const KVMIO: c_uint = 0xAE; +/// ioctl_ioc_nr!(KVM_CREATE_VM, _IOC_NONE, KVMIO, 0x01, 0); +/// ``` +#[macro_export] +macro_rules! ioctl_ioc_nr { + ($name:ident, $dir:expr, $ty:expr, $nr:expr, $size:expr) => { + #[allow(non_snake_case)] + #[allow(clippy::cast_lossless)] + pub fn $name() -> ::std::os::raw::c_ulong { + $crate::ioctl::ioctl_expr($dir, $ty, $nr, $size) + } + }; + ($name:ident, $dir:expr, $ty:expr, $nr:expr, $size:expr, $($v:ident),+) => { + #[allow(non_snake_case)] + #[allow(clippy::cast_lossless)] + pub fn $name($($v: ::std::os::raw::c_uint),+) -> ::std::os::raw::c_ulong { + $crate::ioctl::ioctl_expr($dir, $ty, $nr, $size) + } + }; +} + +/// Declare an ioctl that transfers no data. +/// +/// ``` +/// # #[macro_use] extern crate vmm_sys_util; +/// # use std::os::raw::c_uint; +/// const KVMIO: c_uint = 0xAE; +/// ioctl_io_nr!(KVM_CREATE_VM, KVMIO, 0x01); +/// ``` +#[macro_export] +macro_rules! ioctl_io_nr { + ($name:ident, $ty:expr, $nr:expr) => { + ioctl_ioc_nr!($name, $crate::ioctl::_IOC_NONE, $ty, $nr, 0); + }; + ($name:ident, $ty:expr, $nr:expr, $($v:ident),+) => { + ioctl_ioc_nr!($name, $crate::ioctl::_IOC_NONE, $ty, $nr, 0, $($v),+); + }; +} + +/// Declare an ioctl that reads data. +/// +/// ``` +/// # #[macro_use] extern crate vmm_sys_util; +/// const TUNTAP: ::std::os::raw::c_uint = 0x54; +/// ioctl_ior_nr!(TUNGETFEATURES, TUNTAP, 0xcf, ::std::os::raw::c_uint); +/// ``` +#[macro_export] +macro_rules! 
ioctl_ior_nr { + ($name:ident, $ty:expr, $nr:expr, $size:ty) => { + ioctl_ioc_nr!( + $name, + $crate::ioctl::_IOC_READ, + $ty, + $nr, + ::std::mem::size_of::<$size>() as u32 + ); + }; + ($name:ident, $ty:expr, $nr:expr, $size:ty, $($v:ident),+) => { + ioctl_ioc_nr!( + $name, + $crate::ioctl::_IOC_READ, + $ty, + $nr, + ::std::mem::size_of::<$size>() as u32, + $($v),+ + ); + }; +} + +/// Declare an ioctl that writes data. +/// +/// ``` +/// # #[macro_use] extern crate vmm_sys_util; +/// const TUNTAP: ::std::os::raw::c_uint = 0x54; +/// ioctl_iow_nr!(TUNSETQUEUE, TUNTAP, 0xd9, ::std::os::raw::c_int); +/// ``` +#[macro_export] +macro_rules! ioctl_iow_nr { + ($name:ident, $ty:expr, $nr:expr, $size:ty) => { + ioctl_ioc_nr!( + $name, + $crate::ioctl::_IOC_WRITE, + $ty, + $nr, + ::std::mem::size_of::<$size>() as u32 + ); + }; + ($name:ident, $ty:expr, $nr:expr, $size:ty, $($v:ident),+) => { + ioctl_ioc_nr!( + $name, + $crate::ioctl::_IOC_WRITE, + $ty, + $nr, + ::std::mem::size_of::<$size>() as u32, + $($v),+ + ); + }; +} + +/// Declare an ioctl that reads and writes data. +/// +/// ``` +/// # #[macro_use] extern crate vmm_sys_util; +/// const VHOST: ::std::os::raw::c_uint = 0xAF; +/// ioctl_iowr_nr!(VHOST_GET_VRING_BASE, VHOST, 0x12, ::std::os::raw::c_int); +/// ``` +#[macro_export] +macro_rules! ioctl_iowr_nr { + ($name:ident, $ty:expr, $nr:expr, $size:ty) => { + ioctl_ioc_nr!( + $name, + $crate::ioctl::_IOC_READ | $crate::ioctl::_IOC_WRITE, + $ty, + $nr, + ::std::mem::size_of::<$size>() as u32 + ); + }; + ($name:ident, $ty:expr, $nr:expr, $size:ty, $($v:ident),+) => { + ioctl_ioc_nr!( + $name, + $crate::ioctl::_IOC_READ | $crate::ioctl::_IOC_WRITE, + $ty, + $nr, + ::std::mem::size_of::<$size>() as u32, + $($v),+ + ); + }; +} + +// Define IOC_* constants in a module so that we can allow missing docs on it. +// There is not much value in documenting these as it is code generated from +// kernel definitions. 
+#[allow(missing_docs)] +mod ioc { + use std::os::raw::c_uint; + + pub const _IOC_NRBITS: c_uint = 8; + pub const _IOC_TYPEBITS: c_uint = 8; + pub const _IOC_SIZEBITS: c_uint = 14; + pub const _IOC_DIRBITS: c_uint = 2; + pub const _IOC_NRMASK: c_uint = 255; + pub const _IOC_TYPEMASK: c_uint = 255; + pub const _IOC_SIZEMASK: c_uint = 16383; + pub const _IOC_DIRMASK: c_uint = 3; + pub const _IOC_NRSHIFT: c_uint = 0; + pub const _IOC_TYPESHIFT: c_uint = 8; + pub const _IOC_SIZESHIFT: c_uint = 16; + pub const _IOC_DIRSHIFT: c_uint = 30; + pub const _IOC_NONE: c_uint = 0; + pub const _IOC_WRITE: c_uint = 1; + pub const _IOC_READ: c_uint = 2; + pub const IOC_IN: c_uint = 1_073_741_824; + pub const IOC_OUT: c_uint = 2_147_483_648; + pub const IOC_INOUT: c_uint = 3_221_225_472; + pub const IOCSIZE_MASK: c_uint = 1_073_676_288; + pub const IOCSIZE_SHIFT: c_uint = 16; +} +pub use self::ioc::*; + +// The type of the `req` parameter is different for the `musl` library. This will enable +// successful build for other non-musl libraries. +#[cfg(target_env = "musl")] +type IoctlRequest = c_int; +#[cfg(all(not(target_env = "musl"), not(target_os = "android")))] +type IoctlRequest = c_ulong; +#[cfg(all(not(target_env = "musl"), target_os = "android"))] +type IoctlRequest = c_int; +/// Run an [`ioctl`](http://man7.org/linux/man-pages/man2/ioctl.2.html) +/// with no arguments. +/// +/// # Arguments +/// +/// * `fd`: an open file descriptor corresponding to the device on which +/// to call the ioctl. +/// * `req`: a device-dependent request code. +/// +/// # Safety +/// +/// The caller should ensure to pass a valid file descriptor and have the +/// return value checked. 
+/// +/// # Examples +/// +/// ``` +/// # extern crate libc; +/// # #[macro_use] extern crate vmm_sys_util; +/// # +/// # use libc::{open, O_CLOEXEC, O_RDWR}; +/// # use std::fs::File; +/// # use std::os::raw::{c_char, c_uint}; +/// # use std::os::unix::io::FromRawFd; +/// use vmm_sys_util::ioctl::ioctl; +/// +/// const KVMIO: c_uint = 0xAE; +/// const KVM_API_VERSION: u32 = 12; +/// ioctl_io_nr!(KVM_GET_API_VERSION, KVMIO, 0x00); +/// +/// let open_flags = O_RDWR | O_CLOEXEC; +/// let kvm_fd = unsafe { open("/dev/kvm\0".as_ptr() as *const c_char, open_flags) }; +/// +/// let ret = unsafe { ioctl(&File::from_raw_fd(kvm_fd), KVM_GET_API_VERSION()) }; +/// +/// assert_eq!(ret as u32, KVM_API_VERSION); +/// ``` +pub unsafe fn ioctl<F: AsRawFd>(fd: &F, req: c_ulong) -> c_int { + libc::ioctl(fd.as_raw_fd(), req as IoctlRequest, 0) +} + +/// Run an [`ioctl`](http://man7.org/linux/man-pages/man2/ioctl.2.html) +/// with a single value argument. +/// +/// # Arguments +/// +/// * `fd`: an open file descriptor corresponding to the device on which +/// to call the ioctl. +/// * `req`: a device-dependent request code. +/// * `arg`: a single value passed to ioctl. +/// +/// # Safety +/// +/// The caller should ensure to pass a valid file descriptor and have the +/// return value checked. 
+/// +/// # Examples +/// +/// ``` +/// # extern crate libc; +/// # #[macro_use] extern crate vmm_sys_util; +/// # use libc::{open, O_CLOEXEC, O_RDWR}; +/// # use std::fs::File; +/// # use std::os::raw::{c_char, c_uint, c_ulong}; +/// # use std::os::unix::io::FromRawFd; +/// use vmm_sys_util::ioctl::ioctl_with_val; +/// +/// const KVMIO: c_uint = 0xAE; +/// const KVM_CAP_USER_MEMORY: u32 = 3; +/// ioctl_io_nr!(KVM_CHECK_EXTENSION, KVMIO, 0x03); +/// +/// let open_flags = O_RDWR | O_CLOEXEC; +/// let kvm_fd = unsafe { open("/dev/kvm\0".as_ptr() as *const c_char, open_flags) }; +/// +/// let ret = unsafe { +/// ioctl_with_val( +/// &File::from_raw_fd(kvm_fd), +/// KVM_CHECK_EXTENSION(), +/// KVM_CAP_USER_MEMORY as c_ulong, +/// ) +/// }; +/// assert!(ret > 0); +/// ``` +pub unsafe fn ioctl_with_val<F: AsRawFd>(fd: &F, req: c_ulong, arg: c_ulong) -> c_int { + libc::ioctl(fd.as_raw_fd(), req as IoctlRequest, arg) +} + +/// Run an [`ioctl`](http://man7.org/linux/man-pages/man2/ioctl.2.html) +/// with an immutable reference. +/// +/// # Arguments +/// +/// * `fd`: an open file descriptor corresponding to the device on which +/// to call the ioctl. +/// * `req`: a device-dependent request code. +/// * `arg`: an immutable reference passed to ioctl. +/// +/// # Safety +/// +/// The caller should ensure to pass a valid file descriptor and have the +/// return value checked. +pub unsafe fn ioctl_with_ref<F: AsRawFd, T>(fd: &F, req: c_ulong, arg: &T) -> c_int { + libc::ioctl( + fd.as_raw_fd(), + req as IoctlRequest, + arg as *const T as *const c_void, + ) +} + +/// Run an [`ioctl`](http://man7.org/linux/man-pages/man2/ioctl.2.html) +/// with a mutable reference. +/// +/// # Arguments +/// +/// * `fd`: an open file descriptor corresponding to the device on which +/// to call the ioctl. +/// * `req`: a device-dependent request code. +/// * `arg`: a mutable reference passed to ioctl. 
+/// +/// # Safety +/// +/// The caller should ensure to pass a valid file descriptor and have the +/// return value checked. +pub unsafe fn ioctl_with_mut_ref<F: AsRawFd, T>(fd: &F, req: c_ulong, arg: &mut T) -> c_int { + libc::ioctl( + fd.as_raw_fd(), + req as IoctlRequest, + arg as *mut T as *mut c_void, + ) +} + +/// Run an [`ioctl`](http://man7.org/linux/man-pages/man2/ioctl.2.html) +/// with a raw pointer. +/// +/// # Arguments +/// +/// * `fd`: an open file descriptor corresponding to the device on which +/// to call the ioctl. +/// * `req`: a device-dependent request code. +/// * `arg`: a raw pointer passed to ioctl. +/// +/// # Safety +/// +/// The caller should ensure to pass a valid file descriptor and have the +/// return value checked. +pub unsafe fn ioctl_with_ptr<F: AsRawFd, T>(fd: &F, req: c_ulong, arg: *const T) -> c_int { + libc::ioctl(fd.as_raw_fd(), req as IoctlRequest, arg as *const c_void) +} + +/// Run an [`ioctl`](http://man7.org/linux/man-pages/man2/ioctl.2.html) +/// with a mutable raw pointer. +/// +/// # Arguments +/// +/// * `fd`: an open file descriptor corresponding to the device on which +/// to call the ioctl. +/// * `req`: a device-dependent request code. +/// * `arg`: a mutable raw pointer passed to ioctl. +/// +/// # Safety +/// +/// The caller should ensure to pass a valid file descriptor and have the +/// return value checked. 
+pub unsafe fn ioctl_with_mut_ptr<F: AsRawFd, T>(fd: &F, req: c_ulong, arg: *mut T) -> c_int { + libc::ioctl(fd.as_raw_fd(), req as IoctlRequest, arg as *mut c_void) +} + +#[cfg(test)] +mod tests { + const TUNTAP: ::std::os::raw::c_uint = 0x54; + const VHOST: ::std::os::raw::c_uint = 0xAF; + const EVDEV: ::std::os::raw::c_uint = 0x45; + + const KVMIO: ::std::os::raw::c_uint = 0xAE; + + ioctl_io_nr!(KVM_CREATE_VM, KVMIO, 0x01); + ioctl_ior_nr!(TUNGETFEATURES, TUNTAP, 0xcf, ::std::os::raw::c_uint); + ioctl_iow_nr!(TUNSETQUEUE, TUNTAP, 0xd9, ::std::os::raw::c_int); + ioctl_io_nr!(VHOST_SET_OWNER, VHOST, 0x01); + ioctl_iowr_nr!(VHOST_GET_VRING_BASE, VHOST, 0x12, ::std::os::raw::c_int); + ioctl_iowr_nr!(KVM_GET_MSR_INDEX_LIST, KVMIO, 0x2, ::std::os::raw::c_int); + + ioctl_ior_nr!(EVIOCGBIT, EVDEV, 0x20 + evt, [u8; 128], evt); + ioctl_io_nr!(FAKE_IOCTL_2_ARG, EVDEV, 0x01 + x + y, x, y); + + #[test] + fn test_ioctl_macros() { + assert_eq!(0x0000_AE01, KVM_CREATE_VM()); + assert_eq!(0x0000_AF01, VHOST_SET_OWNER()); + assert_eq!(0x8004_54CF, TUNGETFEATURES()); + assert_eq!(0x4004_54D9, TUNSETQUEUE()); + assert_eq!(0xC004_AE02, KVM_GET_MSR_INDEX_LIST()); + assert_eq!(0xC004_AF12, VHOST_GET_VRING_BASE()); + + assert_eq!(0x8080_4522, EVIOCGBIT(2)); + assert_eq!(0x0000_4509, FAKE_IOCTL_2_ARG(3, 5)); + } +} diff --git a/src/linux/mod.rs b/src/linux/mod.rs new file mode 100644 index 0000000..daf86f5 --- /dev/null +++ b/src/linux/mod.rs @@ -0,0 +1,15 @@ +// Copyright 2022 rust-vmm Authors or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: BSD-3-Clause + +#[macro_use] +pub mod ioctl; +pub mod aio; +pub mod epoll; +pub mod eventfd; +pub mod fallocate; +pub mod poll; +pub mod seek_hole; +pub mod signal; +pub mod sock_ctrl_msg; +pub mod timerfd; +pub mod write_zeroes; diff --git a/src/linux/poll.rs b/src/linux/poll.rs new file mode 100644 index 0000000..12809f0 --- /dev/null +++ b/src/linux/poll.rs @@ -0,0 +1,1010 @@ +// Copyright 2019 Intel Corporation. 
All Rights Reserved. +// +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Traits and structures for working with +//! [`epoll`](http://man7.org/linux/man-pages/man7/epoll.7.html) + +use std::cell::{Cell, Ref, RefCell}; +use std::cmp::min; +use std::fs::File; +use std::i32; +use std::i64; +use std::io::{stderr, Cursor, Write}; +use std::marker::PhantomData; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::ptr::null_mut; +use std::slice; +use std::thread; +use std::time::Duration; + +use libc::{ + c_int, epoll_create1, epoll_ctl, epoll_event, epoll_wait, EINTR, EPOLLERR, EPOLLHUP, EPOLLIN, + EPOLLOUT, EPOLL_CLOEXEC, EPOLL_CTL_ADD, EPOLL_CTL_DEL, EPOLL_CTL_MOD, +}; + +use crate::errno::{errno_result, Error, Result}; + +macro_rules! handle_eintr_errno { + ($x:expr) => {{ + let mut res; + loop { + res = $x; + if res != -1 || Error::last() != Error::new(EINTR) { + break; + } + } + res + }}; +} + +const POLL_CONTEXT_MAX_EVENTS: usize = 16; + +/// A wrapper of raw `libc::epoll_event`. +/// +/// This should only be used with [`EpollContext`](struct.EpollContext.html). +pub struct EpollEvents(RefCell<[epoll_event; POLL_CONTEXT_MAX_EVENTS]>); + +impl EpollEvents { + /// Creates a new EpollEvents. + pub fn new() -> EpollEvents { + EpollEvents(RefCell::new( + [epoll_event { events: 0, u64: 0 }; POLL_CONTEXT_MAX_EVENTS], + )) + } +} + +impl Default for EpollEvents { + fn default() -> Self { + Self::new() + } +} + +/// Trait for a token that can be associated with an `fd` in a [`PollContext`](struct.PollContext.html). +/// +/// Simple enums that have no or primitive variant data can use the `#[derive(PollToken)]` +/// custom derive to implement this trait. +pub trait PollToken { + /// Converts this token into a u64 that can be turned back into a token via `from_raw_token`. + fn as_raw_token(&self) -> u64; + + /// Converts a raw token as returned from `as_raw_token` back into a token. 
+ /// + /// It is invalid to give a raw token that was not returned via `as_raw_token` from the same + /// `Self`. The implementation can expect that this will never happen as a result of its usage + /// in `PollContext`. + fn from_raw_token(data: u64) -> Self; +} + +impl PollToken for usize { + fn as_raw_token(&self) -> u64 { + *self as u64 + } + + fn from_raw_token(data: u64) -> Self { + data as Self + } +} + +impl PollToken for u64 { + fn as_raw_token(&self) -> u64 { + *self as u64 + } + + fn from_raw_token(data: u64) -> Self { + data as Self + } +} + +impl PollToken for u32 { + fn as_raw_token(&self) -> u64 { + u64::from(*self) + } + + fn from_raw_token(data: u64) -> Self { + data as Self + } +} + +impl PollToken for u16 { + fn as_raw_token(&self) -> u64 { + u64::from(*self) + } + + fn from_raw_token(data: u64) -> Self { + data as Self + } +} + +impl PollToken for u8 { + fn as_raw_token(&self) -> u64 { + u64::from(*self) + } + + fn from_raw_token(data: u64) -> Self { + data as Self + } +} + +impl PollToken for () { + fn as_raw_token(&self) -> u64 { + 0 + } + + fn from_raw_token(_data: u64) -> Self {} +} + +/// An event returned by [`PollContext::wait`](struct.PollContext.html#method.wait). +pub struct PollEvent<'a, T> { + event: &'a epoll_event, + token: PhantomData<T>, // Needed to satisfy usage of T +} + +impl<'a, T: PollToken> PollEvent<'a, T> { + /// Gets the token associated in + /// [`PollContext::add`](struct.PollContext.html#method.add) with this event. + pub fn token(&self) -> T { + T::from_raw_token(self.event.u64) + } + + /// Get the raw events returned by the kernel. + pub fn raw_events(&self) -> u32 { + self.event.events + } + + /// Checks if the event is readable. + /// + /// True if the `fd` associated with this token in + /// [`PollContext::add`](struct.PollContext.html#method.add) is readable. + pub fn readable(&self) -> bool { + self.event.events & (EPOLLIN as u32) != 0 + } + + /// Checks if the event is writable. 
+ /// + /// True if the `fd` associated with this token in + /// [`PollContext::add`](struct.PollContext.html#method.add) is writable. + pub fn writable(&self) -> bool { + self.event.events & (EPOLLOUT as u32) != 0 + } + + /// Checks if the event has been hangup on. + /// + /// True if the `fd` associated with this token in + /// [`PollContext::add`](struct.PollContext.html#method.add) has been hungup on. + pub fn hungup(&self) -> bool { + self.event.events & (EPOLLHUP as u32) != 0 + } + + /// Checks if the event has associated error conditions. + /// + /// True if the `fd` associated with this token in + /// [`PollContext::add`](struct.PollContext.html#method.add) has associated error conditions. + pub fn has_error(&self) -> bool { + self.event.events & (EPOLLERR as u32) != 0 + } +} + +/// An iterator over a subset of events returned by +/// [`PollContext::wait`](struct.PollContext.html#method.wait). +pub struct PollEventIter<'a, I, T> +where + I: Iterator<Item = &'a epoll_event>, +{ + mask: u32, + iter: I, + tokens: PhantomData<[T]>, // Needed to satisfy usage of T +} + +impl<'a, I, T> Iterator for PollEventIter<'a, I, T> +where + I: Iterator<Item = &'a epoll_event>, + T: PollToken, +{ + type Item = PollEvent<'a, T>; + fn next(&mut self) -> Option<Self::Item> { + let mask = self.mask; + self.iter + .find(|event| (event.events & mask) != 0) + .map(|event| PollEvent { + event, + token: PhantomData, + }) + } +} + +/// The list of events returned by [`PollContext::wait`](struct.PollContext.html#method.wait). +pub struct PollEvents<'a, T> { + count: usize, + events: Ref<'a, [epoll_event; POLL_CONTEXT_MAX_EVENTS]>, + tokens: PhantomData<[T]>, // Needed to satisfy usage of T +} + +impl<'a, T: PollToken> PollEvents<'a, T> { + /// Creates owned structure from borrowed [`PollEvents`](struct.PollEvents.html). + /// + /// Copies the events to an owned structure so the reference to this (and by extension + /// [`PollContext`](struct.PollContext.html)) can be dropped. 
+ pub fn to_owned(&self) -> PollEventsOwned<T> { + PollEventsOwned { + count: self.count, + events: RefCell::new(*self.events), + tokens: PhantomData, + } + } + + /// Iterates over each event. + pub fn iter(&self) -> PollEventIter<'_, slice::Iter<'_, epoll_event>, T> { + PollEventIter { + mask: 0xffff_ffff, + iter: self.events[..self.count].iter(), + tokens: PhantomData, + } + } + + /// Iterates over each readable event. + pub fn iter_readable(&self) -> PollEventIter<'_, slice::Iter<'_, epoll_event>, T> { + PollEventIter { + mask: EPOLLIN as u32, + iter: self.events[..self.count].iter(), + tokens: PhantomData, + } + } + + /// Iterates over each hungup event. + pub fn iter_hungup(&self) -> PollEventIter<'_, slice::Iter<'_, epoll_event>, T> { + PollEventIter { + mask: EPOLLHUP as u32, + iter: self.events[..self.count].iter(), + tokens: PhantomData, + } + } +} + +/// A deep copy of the event records from [`PollEvents`](struct.PollEvents.html). +pub struct PollEventsOwned<T> { + count: usize, + events: RefCell<[epoll_event; POLL_CONTEXT_MAX_EVENTS]>, + tokens: PhantomData<T>, // Needed to satisfy usage of T +} + +impl<T: PollToken> PollEventsOwned<T> { + /// Creates borrowed structure from owned structure + /// [`PollEventsOwned`](struct.PollEventsOwned.html). + /// + /// Takes a reference to the events so it can be iterated via methods in + /// [`PollEvents`](struct.PollEvents.html). + pub fn as_ref(&self) -> PollEvents<'_, T> { + PollEvents { + count: self.count, + events: self.events.borrow(), + tokens: PhantomData, + } + } +} + +/// Watching events taken by [`PollContext`](struct.PollContext.html). +#[derive(Copy, Clone)] +pub struct WatchingEvents(u32); + +impl WatchingEvents { + /// Returns empty `WatchingEvents`. + #[inline(always)] + pub fn empty() -> WatchingEvents { + WatchingEvents(0) + } + + /// Creates a new `WatchingEvents` with a specified value. + /// + /// Builds `WatchingEvents` from raw `epoll_event`. 
+ /// + /// # Arguments + /// + /// * `raw`: the events to be created for watching. + #[inline(always)] + pub fn new(raw: u32) -> WatchingEvents { + WatchingEvents(raw) + } + + /// Sets read events. + /// + /// Sets the events to be readable. + #[inline(always)] + pub fn set_read(self) -> WatchingEvents { + WatchingEvents(self.0 | EPOLLIN as u32) + } + + /// Sets write events. + /// + /// Sets the events to be writable. + #[inline(always)] + pub fn set_write(self) -> WatchingEvents { + WatchingEvents(self.0 | EPOLLOUT as u32) + } + + /// Gets the underlying epoll events. + pub fn get_raw(&self) -> u32 { + self.0 + } +} + +/// A wrapper of linux [`epoll`](http://man7.org/linux/man-pages/man7/epoll.7.html). +/// +/// It provides similar interface to [`PollContext`](struct.PollContext.html). +/// It is thread safe while PollContext is not. It requires user to pass in a reference of +/// EpollEvents while PollContext does not. Always use PollContext if you don't need to access the +/// same epoll from different threads. +/// +/// # Examples +/// +/// ``` +/// extern crate vmm_sys_util; +/// use vmm_sys_util::eventfd::EventFd; +/// use vmm_sys_util::poll::{EpollContext, EpollEvents}; +/// +/// let evt = EventFd::new(0).unwrap(); +/// let ctx: EpollContext<u32> = EpollContext::new().unwrap(); +/// let events = EpollEvents::new(); +/// +/// evt.write(1).unwrap(); +/// ctx.add(&evt, 1).unwrap(); +/// +/// for event in ctx.wait(&events).unwrap().iter_readable() { +/// assert_eq!(event.token(), 1); +/// } +/// ``` +pub struct EpollContext<T> { + epoll_ctx: File, + // Needed to satisfy usage of T + tokens: PhantomData<[T]>, +} + +impl<T: PollToken> EpollContext<T> { + /// Creates a new `EpollContext`. + /// + /// Uses [`epoll_create1`](http://man7.org/linux/man-pages/man2/epoll_create.2.html) + /// to create a new epoll fd. 
+ /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::poll::EpollContext; + /// + /// let ctx: EpollContext<usize> = EpollContext::new().unwrap(); + /// ``` + pub fn new() -> Result<EpollContext<T>> { + // SAFETY: Safe because we check the return value. + let epoll_fd = unsafe { epoll_create1(EPOLL_CLOEXEC) }; + if epoll_fd < 0 { + return errno_result(); + } + Ok(EpollContext { + // SAFETY: Safe because we verified that the FD is valid and we trust `epoll_create1`. + epoll_ctx: unsafe { File::from_raw_fd(epoll_fd) }, + tokens: PhantomData, + }) + } + + /// Adds the given `fd` to this context and associates the given + /// `token` with the `fd`'s readable events. + /// + /// A `fd` can only be added once and does not need to be kept open. + /// If the `fd` is dropped and there were no duplicated file descriptors + /// (i.e. adding the same descriptor with a different FD number) added + /// to this context, events will not be reported by `wait` anymore. + /// + /// # Arguments + /// + /// * `fd`: the target file descriptor to be added. + /// * `token`: a `PollToken` implementation, used to be as u64 of `libc::epoll_event` structure. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::eventfd::EventFd; + /// use vmm_sys_util::poll::EpollContext; + /// + /// let evt = EventFd::new(0).unwrap(); + /// let ctx: EpollContext<u32> = EpollContext::new().unwrap(); + /// ctx.add(&evt, 1).unwrap(); + /// ``` + pub fn add(&self, fd: &dyn AsRawFd, token: T) -> Result<()> { + self.add_fd_with_events(fd, WatchingEvents::empty().set_read(), token) + } + + /// Adds the given `fd` to this context, watching for the specified `events` + /// and associates the given 'token' with those events. + /// + /// A `fd` can only be added once and does not need to be kept open. If the `fd` + /// is dropped and there were no duplicated file descriptors (i.e. 
adding the same + /// descriptor with a different FD number) added to this context, events will + /// not be reported by `wait` anymore. + /// + /// # Arguments + /// + /// * `fd`: the target file descriptor to be added. + /// * `events`: specifies the events to be watched. + /// * `token`: a `PollToken` implementation, used to be as u64 of `libc::epoll_event` structure. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::eventfd::EventFd; + /// use vmm_sys_util::poll::{EpollContext, WatchingEvents}; + /// + /// let evt = EventFd::new(0).unwrap(); + /// let ctx: EpollContext<u32> = EpollContext::new().unwrap(); + /// ctx.add_fd_with_events(&evt, WatchingEvents::empty().set_read(), 1) + /// .unwrap(); + /// ``` + pub fn add_fd_with_events( + &self, + fd: &dyn AsRawFd, + events: WatchingEvents, + token: T, + ) -> Result<()> { + let mut evt = epoll_event { + events: events.get_raw(), + u64: token.as_raw_token(), + }; + // SAFETY: Safe because we give a valid epoll FD and FD to watch, as well as a + // valid epoll_event structure. Then we check the return value. + let ret = unsafe { + epoll_ctl( + self.epoll_ctx.as_raw_fd(), + EPOLL_CTL_ADD, + fd.as_raw_fd(), + &mut evt, + ) + }; + if ret < 0 { + return errno_result(); + }; + Ok(()) + } + + /// Changes the setting associated with the given `fd` in this context. + /// + /// If `fd` was previously added to this context, the watched events will be replaced with + /// `events` and the token associated with it will be replaced with the given `token`. + /// + /// # Arguments + /// + /// * `fd`: the target file descriptor to be performed. + /// * `events`: specifies the events to be watched. + /// * `token`: a `PollToken` implementation, used to be as u64 of `libc::epoll_event` structure. 
+ /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::eventfd::EventFd; + /// use vmm_sys_util::poll::{EpollContext, WatchingEvents}; + /// + /// let evt = EventFd::new(0).unwrap(); + /// let ctx: EpollContext<u32> = EpollContext::new().unwrap(); + /// ctx.add_fd_with_events(&evt, WatchingEvents::empty().set_read(), 1) + /// .unwrap(); + /// ctx.modify(&evt, WatchingEvents::empty().set_write(), 2) + /// .unwrap(); + /// ``` + pub fn modify(&self, fd: &dyn AsRawFd, events: WatchingEvents, token: T) -> Result<()> { + let mut evt = epoll_event { + events: events.0, + u64: token.as_raw_token(), + }; + // SAFETY: Safe because we give a valid epoll FD and FD to modify, as well as a valid + // epoll_event structure. Then we check the return value. + let ret = unsafe { + epoll_ctl( + self.epoll_ctx.as_raw_fd(), + EPOLL_CTL_MOD, + fd.as_raw_fd(), + &mut evt, + ) + }; + if ret < 0 { + return errno_result(); + }; + Ok(()) + } + + /// Deletes the given `fd` from this context. + /// + /// If an `fd`'s token shows up in the list of hangup events, it should be removed using this + /// method or by closing/dropping (if and only if the fd was never dup()'d/fork()'d) the `fd`. + /// Failure to do so will cause the `wait` method to always return immediately, causing ~100% + /// CPU load. + /// + /// # Arguments + /// + /// * `fd`: the target file descriptor to be removed. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// use vmm_sys_util::eventfd::EventFd; + /// use vmm_sys_util::poll::EpollContext; + /// + /// let evt = EventFd::new(0).unwrap(); + /// let ctx: EpollContext<u32> = EpollContext::new().unwrap(); + /// ctx.add(&evt, 1).unwrap(); + /// ctx.delete(&evt).unwrap(); + /// ``` + pub fn delete(&self, fd: &dyn AsRawFd) -> Result<()> { + // SAFETY: Safe because we give a valid epoll FD and FD to stop watching. Then we check + // the return value. 
+ let ret = unsafe {
+ epoll_ctl(
+ self.epoll_ctx.as_raw_fd(),
+ EPOLL_CTL_DEL,
+ fd.as_raw_fd(),
+ null_mut(),
+ )
+ };
+ if ret < 0 {
+ return errno_result();
+ };
+ Ok(())
+ }
+
+ /// Waits for any events to occur in FDs that were previously added to this context.
+ ///
+ /// The events are level-triggered, meaning that if any events are unhandled (i.e. not reading
+ /// for readable events and not closing for hungup events), subsequent calls to `wait` will
+ /// return immediately. The consequence of not handling an event perpetually while calling
+ /// `wait` is that the caller's loop will degenerate to busy loop polling, pinning a CPU to
+ /// ~100% usage.
+ ///
+ /// # Arguments
+ ///
+ /// * `events`: the events to wait for.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// extern crate vmm_sys_util;
+ /// use vmm_sys_util::eventfd::EventFd;
+ /// use vmm_sys_util::poll::{EpollContext, EpollEvents};
+ ///
+ /// let evt = EventFd::new(0).unwrap();
+ /// let ctx: EpollContext<u32> = EpollContext::new().unwrap();
+ /// let events = EpollEvents::new();
+ ///
+ /// evt.write(1).unwrap();
+ /// ctx.add(&evt, 1).unwrap();
+ ///
+ /// for event in ctx.wait(&events).unwrap().iter_readable() {
+ /// assert_eq!(event.token(), 1);
+ /// }
+ /// ```
+ pub fn wait<'a>(&self, events: &'a EpollEvents) -> Result<PollEvents<'a, T>> {
+ self.wait_timeout(events, Duration::new(i64::MAX as u64, 0))
+ }
+
+ /// Like [`wait`](struct.EpollContext.html#method.wait) except will only block for a
+ /// maximum of the given `timeout`.
+ ///
+ /// This may return earlier than `timeout` with zero events if the duration indicated exceeds
+ /// system limits.
+ ///
+ /// # Arguments
+ ///
+ /// * `events`: the events to wait for.
+ /// * `timeout`: specifies the timeout that will block. 
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// extern crate vmm_sys_util;
+ /// # use std::time::Duration;
+ /// use vmm_sys_util::eventfd::EventFd;
+ /// use vmm_sys_util::poll::{EpollContext, EpollEvents};
+ ///
+ /// let evt = EventFd::new(0).unwrap();
+ /// let ctx: EpollContext<u32> = EpollContext::new().unwrap();
+ /// let events = EpollEvents::new();
+ ///
+ /// evt.write(1).unwrap();
+ /// ctx.add(&evt, 1).unwrap();
+ /// for event in ctx
+ ///     .wait_timeout(&events, Duration::new(100, 0))
+ ///     .unwrap()
+ ///     .iter_readable()
+ /// {
+ ///     assert_eq!(event.token(), 1);
+ /// }
+ /// ```
+ pub fn wait_timeout<'a>(
+ &self,
+ events: &'a EpollEvents,
+ timeout: Duration,
+ ) -> Result<PollEvents<'a, T>> {
+ let timeout_millis = if timeout.as_secs() as i64 == i64::max_value() {
+ // We make the convenient assumption that 2^63 seconds is an effectively unbounded time
+ // frame. This is meant to mesh with `wait` calling us with no timeout.
+ -1
+ } else {
+ // In cases where the number of milliseconds would overflow an i32, we substitute the
+ // maximum timeout which is ~24.8 days.
+ let millis = timeout
+ .as_secs()
+ .checked_mul(1_000)
+ .and_then(|ms| ms.checked_add(u64::from(timeout.subsec_nanos()) / 1_000_000))
+ .unwrap_or(i32::max_value() as u64);
+ min(i32::max_value() as u64, millis) as i32
+ };
+ let ret = {
+ let mut epoll_events = events.0.borrow_mut();
+ let max_events = epoll_events.len() as c_int;
+ // SAFETY: Safe because we give an epoll context and a properly sized epoll_events
+ // array pointer, which we trust the kernel to fill in properly. 
+ unsafe { + handle_eintr_errno!(epoll_wait( + self.epoll_ctx.as_raw_fd(), + &mut epoll_events[0], + max_events, + timeout_millis + )) + } + }; + if ret < 0 { + return errno_result(); + } + let epoll_events = events.0.borrow(); + let events = PollEvents { + count: ret as usize, + events: epoll_events, + tokens: PhantomData, + }; + Ok(events) + } +} + +impl<T: PollToken> AsRawFd for EpollContext<T> { + fn as_raw_fd(&self) -> RawFd { + self.epoll_ctx.as_raw_fd() + } +} + +impl<T: PollToken> IntoRawFd for EpollContext<T> { + fn into_raw_fd(self) -> RawFd { + self.epoll_ctx.into_raw_fd() + } +} + +/// Used to poll multiple objects that have file descriptors. +/// +/// # Example +/// +/// ``` +/// # use vmm_sys_util::errno::Result; +/// # use vmm_sys_util::eventfd::EventFd; +/// # use vmm_sys_util::poll::{PollContext, PollEvents}; +/// let evt1 = EventFd::new(0).unwrap(); +/// let evt2 = EventFd::new(0).unwrap(); +/// evt2.write(1).unwrap(); +/// +/// let ctx: PollContext<u32> = PollContext::new().unwrap(); +/// ctx.add(&evt1, 1).unwrap(); +/// ctx.add(&evt2, 2).unwrap(); +/// +/// let pollevents: PollEvents<u32> = ctx.wait().unwrap(); +/// let tokens: Vec<u32> = pollevents.iter_readable().map(|e| e.token()).collect(); +/// assert_eq!(&tokens[..], &[2]); +/// ``` +pub struct PollContext<T> { + epoll_ctx: EpollContext<T>, + + // We use a RefCell here so that the `wait` method only requires an immutable self reference + // while returning the events (encapsulated by PollEvents). Without the RefCell, `wait` would + // hold a mutable reference that lives as long as its returned reference (i.e. the PollEvents), + // even though that reference is immutable. This is terribly inconvenient for the caller because + // the borrow checking would prevent them from using `delete` and `add` while the events are in + // scope. + events: EpollEvents, + + // Hangup busy loop detection variables. See `check_for_hungup_busy_loop`. 
+ check_for_hangup: bool, + hangups: Cell<usize>, + max_hangups: Cell<usize>, +} + +impl<T: PollToken> PollContext<T> { + /// Creates a new `PollContext`. + pub fn new() -> Result<PollContext<T>> { + Ok(PollContext { + epoll_ctx: EpollContext::new()?, + events: EpollEvents::new(), + check_for_hangup: true, + hangups: Cell::new(0), + max_hangups: Cell::new(0), + }) + } + + /// Enable/disable of checking for unhandled hangup events. + pub fn set_check_for_hangup(&mut self, enable: bool) { + self.check_for_hangup = enable; + } + + /// Adds the given `fd` to this context and associates the given `token` with the `fd`'s + /// readable events. + /// + /// A `fd` can only be added once and does not need to be kept open. If the `fd` is dropped and + /// there were no duplicated file descriptors (i.e. adding the same descriptor with a different + /// FD number) added to this context, events will not be reported by `wait` anymore. + /// + /// # Arguments + /// + /// * `fd`: the target file descriptor to be added. + /// * `token`: a `PollToken` implementation, used to be as u64 of `libc::epoll_event` structure. + pub fn add(&self, fd: &dyn AsRawFd, token: T) -> Result<()> { + self.add_fd_with_events(fd, WatchingEvents::empty().set_read(), token) + } + + /// Adds the given `fd` to this context, watching for the specified events and associates the + /// given 'token' with those events. + /// + /// A `fd` can only be added once and does not need to be kept open. If the `fd` is dropped and + /// there were no duplicated file descriptors (i.e. adding the same descriptor with a different + /// FD number) added to this context, events will not be reported by `wait` anymore. + /// + /// # Arguments + /// + /// * `fd`: the target file descriptor to be added. + /// * `events`: specifies the events to be watched. + /// * `token`: a `PollToken` implementation, used to be as u64 of `libc::epoll_event` structure. 
+ pub fn add_fd_with_events(
+ &self,
+ fd: &dyn AsRawFd,
+ events: WatchingEvents,
+ token: T,
+ ) -> Result<()> {
+ self.epoll_ctx.add_fd_with_events(fd, events, token)?;
+ self.hangups.set(0);
+ self.max_hangups.set(self.max_hangups.get() + 1);
+ Ok(())
+ }
+
+ /// Changes the setting associated with the given `fd` in this context.
+ ///
+ /// If `fd` was previously added to this context, the watched events will be replaced with
+ /// `events` and the token associated with it will be replaced with the given `token`.
+ ///
+ /// # Arguments
+ ///
+ /// * `fd`: the target file descriptor to be modified.
+ /// * `events`: specifies the events to be watched.
+ /// * `token`: a `PollToken` implementation, used to be as u64 of `libc::epoll_event` structure.
+ pub fn modify(&self, fd: &dyn AsRawFd, events: WatchingEvents, token: T) -> Result<()> {
+ self.epoll_ctx.modify(fd, events, token)
+ }
+
+ /// Deletes the given `fd` from this context.
+ ///
+ /// If an `fd`'s token shows up in the list of hangup events, it should be removed using this
+ /// method or by closing/dropping (if and only if the fd was never dup()'d/fork()'d) the `fd`.
+ /// Failure to do so will cause the `wait` method to always return immediately, causing ~100%
+ /// CPU load.
+ ///
+ /// # Arguments
+ ///
+ /// * `fd`: the target file descriptor to be removed.
+ pub fn delete(&self, fd: &dyn AsRawFd) -> Result<()> {
+ self.epoll_ctx.delete(fd)?;
+ self.hangups.set(0);
+ self.max_hangups.set(self.max_hangups.get() - 1);
+ Ok(())
+ }
+
+ // This method determines if the user of wait is misusing the `PollContext` by leaving FDs
+ // in this `PollContext` that have been shutdown or hungup on. Such an FD will cause `wait` to
+ // return instantly with a hungup event. If that FD is perpetually left in this context, a busy
+ // loop burning ~100% of one CPU will silently occur with no human visible malfunction.
+ //
+ // How do we know if the client of this context is ignoring hangups? 
A naive implementation
+ // would trigger if consecutive wait calls yield hangup events, but there are legitimate cases
+ // for this, such as two distinct sockets becoming hungup across two consecutive wait calls. A
+ // smarter implementation would only trigger if `delete` wasn't called between waits that
+ // yielded hangups. Sadly `delete` isn't the only way to remove an FD from this context. The
+ // other way is for the client to close the hungup FD, which automatically removes it from this
+ // context. Assuming that the client always uses close, this implementation would too eagerly
+ // trigger.
+ //
+ // The implementation used here keeps an upper bound of FDs in this context using a counter
+ // hooked into add/delete (which is imprecise because close can also remove FDs without us
+ // knowing). The number of consecutive (no add or delete in between) hangups yielded by wait
+ // calls is counted and compared to the upper bound. If the upper bound is exceeded by the
+ // consecutive hangups, the implementation triggers the check and logs.
+ //
+ // This implementation has false negatives because the upper bound can be completely too high,
+ // in the worst case caused by only using close instead of delete. However, this method has the
+ // advantage of always eventually triggering on genuine busy loop cases, requires no dynamic
+ // allocations, is fast and constant time to compute, and has no false positives. 
+ fn check_for_hungup_busy_loop(&self, new_hangups: usize) {
+ let old_hangups = self.hangups.get();
+ let max_hangups = self.max_hangups.get();
+ if old_hangups <= max_hangups && old_hangups + new_hangups > max_hangups {
+ let mut buf = [0u8; 512];
+ let (res, len) = {
+ let mut buf_cursor = Cursor::new(&mut buf[..]);
+ (
+ writeln!(
+ &mut buf_cursor,
+ "[{}:{}] busy poll wait loop with hungup FDs detected on thread {}\n",
+ file!(),
+ line!(),
+ thread::current().name().unwrap_or("")
+ ),
+ buf_cursor.position() as usize,
+ )
+ };
+
+ if res.is_ok() {
+ let _ = stderr().write_all(&buf[..len]);
+ }
+ // This panic is helpful for tests of this functionality.
+ #[cfg(test)]
+ panic!("hungup busy loop detected");
+ }
+ self.hangups.set(old_hangups + new_hangups);
+ }
+
+ /// Waits for any events to occur in FDs that were previously added to this context.
+ ///
+ /// The events are level-triggered, meaning that if any events are unhandled (i.e. not reading
+ /// for readable events and not closing for hungup events), subsequent calls to `wait` will
+ /// return immediately. The consequence of not handling an event perpetually while calling
+ /// `wait` is that the caller's loop will degenerate to busy loop polling, pinning a CPU to
+ /// ~100% usage.
+ ///
+ /// # Panics
+ /// Panics if the returned `PollEvents` structure is not dropped before subsequent `wait` calls.
+ pub fn wait(&self) -> Result<PollEvents<'_, T>> {
+ self.wait_timeout(Duration::new(i64::MAX as u64, 0))
+ }
+
+ /// Like [`wait`](struct.PollContext.html#method.wait) except will only block for a
+ /// maximum of the given `timeout`.
+ ///
+ /// This may return earlier than `timeout` with zero events if the duration indicated exceeds
+ /// system limits.
+ ///
+ /// # Arguments
+ ///
+ /// * `timeout`: specifies the timeout that will block. 
+ pub fn wait_timeout(&self, timeout: Duration) -> Result<PollEvents<'_, T>> { + let events = self.epoll_ctx.wait_timeout(&self.events, timeout)?; + let hangups = events.iter_hungup().count(); + if self.check_for_hangup { + self.check_for_hungup_busy_loop(hangups); + } + Ok(events) + } +} + +impl<T: PollToken> AsRawFd for PollContext<T> { + fn as_raw_fd(&self) -> RawFd { + self.epoll_ctx.as_raw_fd() + } +} + +impl<T: PollToken> IntoRawFd for PollContext<T> { + fn into_raw_fd(self) -> RawFd { + self.epoll_ctx.into_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::eventfd::EventFd; + use std::os::unix::net::UnixStream; + use std::time::Instant; + + #[test] + fn test_poll_context() { + let evt1 = EventFd::new(0).unwrap(); + let evt2 = EventFd::new(0).unwrap(); + evt1.write(1).unwrap(); + evt2.write(1).unwrap(); + let ctx: PollContext<u32> = PollContext::new().unwrap(); + ctx.add(&evt1, 1).unwrap(); + ctx.add(&evt2, 2).unwrap(); + + let mut evt_count = 0; + while evt_count < 2 { + for event in ctx.wait().unwrap().iter_readable() { + evt_count += 1; + match event.token() { + 1 => { + evt1.read().unwrap(); + ctx.delete(&evt1).unwrap(); + } + 2 => { + evt2.read().unwrap(); + ctx.delete(&evt2).unwrap(); + } + _ => panic!("unexpected token"), + }; + } + } + assert_eq!(evt_count, 2); + } + + #[test] + fn test_poll_context_overflow() { + const EVT_COUNT: usize = POLL_CONTEXT_MAX_EVENTS * 2 + 1; + let ctx: PollContext<usize> = PollContext::new().unwrap(); + let mut evts = Vec::with_capacity(EVT_COUNT); + for i in 0..EVT_COUNT { + let evt = EventFd::new(0).unwrap(); + evt.write(1).unwrap(); + ctx.add(&evt, i).unwrap(); + evts.push(evt); + } + let mut evt_count = 0; + while evt_count < EVT_COUNT { + for event in ctx.wait().unwrap().iter_readable() { + evts[event.token()].read().unwrap(); + evt_count += 1; + } + } + } + + #[test] + #[should_panic] + fn test_poll_context_hungup() { + let (s1, s2) = UnixStream::pair().unwrap(); + let ctx: PollContext<u32> = 
PollContext::new().unwrap(); + ctx.add(&s1, 1).unwrap(); + + // Causes s1 to receive hangup events, which we purposefully ignore to trip the detection + // logic in `PollContext`. + drop(s2); + + // Should easily panic within this many iterations. + for _ in 0..1000 { + ctx.wait().unwrap(); + } + } + + #[test] + fn test_poll_context_timeout() { + let mut ctx: PollContext<u32> = PollContext::new().unwrap(); + let dur = Duration::from_millis(10); + let start_inst = Instant::now(); + + ctx.set_check_for_hangup(false); + ctx.wait_timeout(dur).unwrap(); + assert!(start_inst.elapsed() >= dur); + } + + #[test] + fn test_poll_event() { + let event = epoll_event { + events: (EPOLLIN | EPOLLERR | EPOLLOUT | EPOLLHUP) as u32, + u64: 0x10, + }; + let ev = PollEvent::<u32> { + event: &event, + token: PhantomData, + }; + + assert_eq!(ev.token(), 0x10); + assert!(ev.readable()); + assert!(ev.writable()); + assert!(ev.hungup()); + assert!(ev.has_error()); + assert_eq!( + ev.raw_events(), + (EPOLLIN | EPOLLERR | EPOLLOUT | EPOLLHUP) as u32 + ); + } +} diff --git a/src/linux/seek_hole.rs b/src/linux/seek_hole.rs new file mode 100644 index 0000000..1392993 --- /dev/null +++ b/src/linux/seek_hole.rs @@ -0,0 +1,228 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Traits and implementations over [lseek64](https://linux.die.net/man/3/lseek64). + +use std::fs::File; +use std::io::{Error, Result}; +use std::os::unix::io::AsRawFd; + +#[cfg(target_env = "musl")] +use libc::{c_int, lseek64, ENXIO}; + +#[cfg(target_env = "gnu")] +use libc::{lseek64, ENXIO, SEEK_DATA, SEEK_HOLE}; + +#[cfg(all(not(target_env = "musl"), target_os = "android"))] +use libc::{lseek64, ENXIO, SEEK_DATA, SEEK_HOLE}; + +/// A trait for seeking to the next hole or non-hole position in a file. +pub trait SeekHole { + /// Seek to the first hole in a file. 
+ /// + /// Seek at a position greater than or equal to `offset`. If no holes exist + /// after `offset`, the seek position will be set to the end of the file. + /// If `offset` is at or after the end of the file, the seek position is + /// unchanged, and None is returned. + /// + /// Returns the current seek position after the seek or an error. + fn seek_hole(&mut self, offset: u64) -> Result<Option<u64>>; + + /// Seek to the first data in a file. + /// + /// Seek at a position greater than or equal to `offset`. + /// If no data exists after `offset`, the seek position is unchanged, + /// and None is returned. + /// + /// Returns the current offset after the seek or an error. + fn seek_data(&mut self, offset: u64) -> Result<Option<u64>>; +} + +#[cfg(target_env = "musl")] +const SEEK_DATA: c_int = 3; +#[cfg(target_env = "musl")] +const SEEK_HOLE: c_int = 4; + +// Safe wrapper for `libc::lseek64()` +fn lseek(file: &mut File, offset: i64, whence: i32) -> Result<Option<u64>> { + // SAFETY: This is safe because we pass a known-good file descriptor. + let res = unsafe { lseek64(file.as_raw_fd(), offset, whence) }; + + if res < 0 { + // Convert ENXIO into None; pass any other error as-is. 
+ let err = Error::last_os_error(); + if let Some(errno) = Error::raw_os_error(&err) { + if errno == ENXIO { + return Ok(None); + } + } + Err(err) + } else { + Ok(Some(res as u64)) + } +} + +impl SeekHole for File { + fn seek_hole(&mut self, offset: u64) -> Result<Option<u64>> { + lseek(self, offset as i64, SEEK_HOLE) + } + + fn seek_data(&mut self, offset: u64) -> Result<Option<u64>> { + lseek(self, offset as i64, SEEK_DATA) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tempdir::TempDir; + use std::fs::File; + use std::io::{Seek, SeekFrom, Write}; + use std::path::PathBuf; + + fn seek_cur(file: &mut File) -> u64 { + file.seek(SeekFrom::Current(0)).unwrap() + } + + #[test] + fn seek_data() { + let tempdir = TempDir::new_with_prefix("/tmp/seek_data_test").unwrap(); + let mut path = PathBuf::from(tempdir.as_path()); + path.push("test_file"); + let mut file = File::create(&path).unwrap(); + + // Empty file + assert_eq!(file.seek_data(0).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + + // File with non-zero length consisting entirely of a hole + file.set_len(0x10000).unwrap(); + assert_eq!(file.seek_data(0).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + + // seek_data at or after the end of the file should return None + assert_eq!(file.seek_data(0x10000).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + assert_eq!(file.seek_data(0x10001).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + + // Write some data to [0x10000, 0x20000) + let b = [0x55u8; 0x10000]; + file.seek(SeekFrom::Start(0x10000)).unwrap(); + file.write_all(&b).unwrap(); + assert_eq!(file.seek_data(0).unwrap(), Some(0x10000)); + assert_eq!(seek_cur(&mut file), 0x10000); + + // seek_data within data should return the same offset + assert_eq!(file.seek_data(0x10000).unwrap(), Some(0x10000)); + assert_eq!(seek_cur(&mut file), 0x10000); + assert_eq!(file.seek_data(0x10001).unwrap(), Some(0x10001)); + assert_eq!(seek_cur(&mut file), 0x10001); + 
assert_eq!(file.seek_data(0x1FFFF).unwrap(), Some(0x1FFFF)); + assert_eq!(seek_cur(&mut file), 0x1FFFF); + + // Extend the file to add another hole after the data + file.set_len(0x30000).unwrap(); + assert_eq!(file.seek_data(0).unwrap(), Some(0x10000)); + assert_eq!(seek_cur(&mut file), 0x10000); + assert_eq!(file.seek_data(0x1FFFF).unwrap(), Some(0x1FFFF)); + assert_eq!(seek_cur(&mut file), 0x1FFFF); + assert_eq!(file.seek_data(0x20000).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0x1FFFF); + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn seek_hole() { + let tempdir = TempDir::new_with_prefix("/tmp/seek_hole_test").unwrap(); + let mut path = PathBuf::from(tempdir.as_path()); + path.push("test_file"); + let mut file = File::create(&path).unwrap(); + + // Empty file + assert_eq!(file.seek_hole(0).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + + // File with non-zero length consisting entirely of a hole + file.set_len(0x10000).unwrap(); + assert_eq!(file.seek_hole(0).unwrap(), Some(0)); + assert_eq!(seek_cur(&mut file), 0); + assert_eq!(file.seek_hole(0xFFFF).unwrap(), Some(0xFFFF)); + assert_eq!(seek_cur(&mut file), 0xFFFF); + + // seek_hole at or after the end of the file should return None + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x10000).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + assert_eq!(file.seek_hole(0x10001).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + + // Write some data to [0x10000, 0x20000) + let b = [0x55u8; 0x10000]; + file.seek(SeekFrom::Start(0x10000)).unwrap(); + file.write_all(&b).unwrap(); + + // seek_hole within a hole should return the same offset + assert_eq!(file.seek_hole(0).unwrap(), Some(0)); + assert_eq!(seek_cur(&mut file), 0); + assert_eq!(file.seek_hole(0xFFFF).unwrap(), Some(0xFFFF)); + assert_eq!(seek_cur(&mut file), 0xFFFF); + + // seek_hole within data should return the next hole (EOF) + file.seek(SeekFrom::Start(0)).unwrap(); + 
assert_eq!(file.seek_hole(0x10000).unwrap(), Some(0x20000)); + assert_eq!(seek_cur(&mut file), 0x20000); + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x10001).unwrap(), Some(0x20000)); + assert_eq!(seek_cur(&mut file), 0x20000); + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x1FFFF).unwrap(), Some(0x20000)); + assert_eq!(seek_cur(&mut file), 0x20000); + + // seek_hole at EOF after data should return None + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x20000).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + + // Extend the file to add another hole after the data + file.set_len(0x30000).unwrap(); + assert_eq!(file.seek_hole(0).unwrap(), Some(0)); + assert_eq!(seek_cur(&mut file), 0); + assert_eq!(file.seek_hole(0xFFFF).unwrap(), Some(0xFFFF)); + assert_eq!(seek_cur(&mut file), 0xFFFF); + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x10000).unwrap(), Some(0x20000)); + assert_eq!(seek_cur(&mut file), 0x20000); + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x1FFFF).unwrap(), Some(0x20000)); + assert_eq!(seek_cur(&mut file), 0x20000); + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x20000).unwrap(), Some(0x20000)); + assert_eq!(seek_cur(&mut file), 0x20000); + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x20001).unwrap(), Some(0x20001)); + assert_eq!(seek_cur(&mut file), 0x20001); + + // seek_hole at EOF after a hole should return None + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x30000).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + + // Write some data to [0x20000, 0x30000) + file.seek(SeekFrom::Start(0x20000)).unwrap(); + file.write_all(&b).unwrap(); + + // seek_hole within [0x20000, 0x30000) should now find the hole at EOF + assert_eq!(file.seek_hole(0x20000).unwrap(), Some(0x30000)); + assert_eq!(seek_cur(&mut file), 0x30000); + 
file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x20001).unwrap(), Some(0x30000)); + assert_eq!(seek_cur(&mut file), 0x30000); + file.seek(SeekFrom::Start(0)).unwrap(); + assert_eq!(file.seek_hole(0x30000).unwrap(), None); + assert_eq!(seek_cur(&mut file), 0); + } +} diff --git a/src/linux/signal.rs b/src/linux/signal.rs new file mode 100644 index 0000000..45d9009 --- /dev/null +++ b/src/linux/signal.rs @@ -0,0 +1,583 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Enums, traits and functions for working with +//! [`signal`](http://man7.org/linux/man-pages/man7/signal.7.html). + +use libc::{ + c_int, c_void, pthread_kill, pthread_sigmask, pthread_t, sigaction, sigaddset, sigemptyset, + sigfillset, siginfo_t, sigismember, sigpending, sigset_t, sigtimedwait, timespec, EAGAIN, + EINTR, EINVAL, SIG_BLOCK, SIG_UNBLOCK, +}; + +use crate::errno; +use std::fmt::{self, Display}; +use std::io; +use std::mem; +use std::os::unix::thread::JoinHandleExt; +use std::ptr::{null, null_mut}; +use std::result; +use std::thread::JoinHandle; + +/// The error cases enumeration for signal handling. +#[derive(Debug, PartialEq, Eq)] +pub enum Error { + /// Couldn't create a sigset. + CreateSigset(errno::Error), + /// The wrapped signal has already been blocked. + SignalAlreadyBlocked(c_int), + /// Failed to check if the requested signal is in the blocked set already. + CompareBlockedSignals(errno::Error), + /// The signal could not be blocked. + BlockSignal(errno::Error), + /// The signal mask could not be retrieved. + RetrieveSignalMask(c_int), + /// The signal could not be unblocked. + UnblockSignal(errno::Error), + /// Failed to wait for given signal. + ClearWaitPending(errno::Error), + /// Failed to get pending signals. 
+ ClearGetPending(errno::Error), + /// Failed to check if given signal is in the set of pending signals. + ClearCheckPending(errno::Error), +} + +impl Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use self::Error::*; + + match self { + CreateSigset(e) => write!(f, "couldn't create a sigset: {}", e), + SignalAlreadyBlocked(num) => write!(f, "signal {} already blocked", num), + CompareBlockedSignals(e) => write!( + f, + "failed to check whether requested signal is in the blocked set: {}", + e, + ), + BlockSignal(e) => write!(f, "signal could not be blocked: {}", e), + RetrieveSignalMask(errno) => write!( + f, + "failed to retrieve signal mask: {}", + io::Error::from_raw_os_error(*errno), + ), + UnblockSignal(e) => write!(f, "signal could not be unblocked: {}", e), + ClearWaitPending(e) => write!(f, "failed to wait for given signal: {}", e), + ClearGetPending(e) => write!(f, "failed to get pending signals: {}", e), + ClearCheckPending(e) => write!( + f, + "failed to check whether given signal is in the pending set: {}", + e, + ), + } + } +} + +/// A simplified [Result](https://doc.rust-lang.org/std/result/enum.Result.html) type +/// for operations that can return [`Error`](Enum.error.html). +pub type SignalResult<T> = result::Result<T, Error>; + +/// Public alias for a signal handler. +/// [`sigaction`](http://man7.org/linux/man-pages/man2/sigaction.2.html). +pub type SignalHandler = + extern "C" fn(num: c_int, info: *mut siginfo_t, _unused: *mut c_void) -> (); + +extern "C" { + fn __libc_current_sigrtmin() -> c_int; + fn __libc_current_sigrtmax() -> c_int; +} + +/// Return the minimum (inclusive) real-time signal number. +#[allow(non_snake_case)] +pub fn SIGRTMIN() -> c_int { + // SAFETY: We trust this libc function. + unsafe { __libc_current_sigrtmin() } +} + +/// Return the maximum (inclusive) real-time signal number. +#[allow(non_snake_case)] +pub fn SIGRTMAX() -> c_int { + // SAFETY: We trust this libc function. 
+ unsafe { __libc_current_sigrtmax() } +} + +/// Verify that a signal number is valid. +/// +/// Supported signals range from `SIGHUP` to `SIGSYS` and from `SIGRTMIN` to `SIGRTMAX`. +/// We recommend using realtime signals `[SIGRTMIN(), SIGRTMAX()]` for VCPU threads. +/// +/// # Arguments +/// +/// * `num`: the signal number to be verified. +/// +/// # Examples +/// +/// ``` +/// extern crate vmm_sys_util; +/// use vmm_sys_util::signal::validate_signal_num; +/// +/// let num = validate_signal_num(1).unwrap(); +/// ``` +pub fn validate_signal_num(num: c_int) -> errno::Result<()> { + if (libc::SIGHUP..=libc::SIGSYS).contains(&num) || (SIGRTMIN() <= num && num <= SIGRTMAX()) { + Ok(()) + } else { + Err(errno::Error::new(EINVAL)) + } +} + +/// Register the signal handler of `signum`. +/// +/// # Safety +/// +/// This is considered unsafe because the given handler will be called +/// asynchronously, interrupting whatever the thread was doing and therefore +/// must only do async-signal-safe operations. +/// +/// # Arguments +/// +/// * `num`: the signal number to be registered. +/// * `handler`: the signal handler function to register. +/// +/// # Examples +/// +/// ``` +/// # extern crate libc; +/// extern crate vmm_sys_util; +/// # use libc::{c_int, c_void, siginfo_t, SA_SIGINFO}; +/// use vmm_sys_util::signal::{register_signal_handler, SignalHandler}; +/// +/// extern "C" fn handle_signal(_: c_int, _: *mut siginfo_t, _: *mut c_void) {} +/// register_signal_handler(0, handle_signal); +/// ``` + +pub fn register_signal_handler(num: c_int, handler: SignalHandler) -> errno::Result<()> { + validate_signal_num(num)?; + + // signum specifies the signal and can be any valid signal except + // SIGKILL and SIGSTOP. + // [`sigaction`](http://man7.org/linux/man-pages/man2/sigaction.2.html). + if libc::SIGKILL == num || libc::SIGSTOP == num { + return Err(errno::Error::new(EINVAL)); + } + + // SAFETY: Safe, because this is a POD struct. 
+ let mut act: sigaction = unsafe { mem::zeroed() }; + act.sa_sigaction = handler as *const () as usize; + act.sa_flags = libc::SA_SIGINFO; + + // Block all signals while the `handler` is running. + // Blocking other signals is needed to make sure the execution of + // the handler continues uninterrupted if another signal comes. + // SAFETY: The parameters are valid and we trust the sifillset function. + if unsafe { sigfillset(&mut act.sa_mask as *mut sigset_t) } < 0 { + return errno::errno_result(); + } + + // SAFETY: Safe because the parameters are valid and we check the return value. + match unsafe { sigaction(num, &act, null_mut()) } { + 0 => Ok(()), + _ => errno::errno_result(), + } +} + +/// Create a `sigset` with given signals. +/// +/// An array of signal numbers are added into the signal set by +/// [`sigaddset`](http://man7.org/linux/man-pages/man3/sigaddset.3p.html). +/// This is a helper function used when we want to manipulate signals. +/// +/// # Arguments +/// +/// * `signals`: signal numbers to be added to the new `sigset`. +/// +/// # Examples +/// +/// ``` +/// # extern crate libc; +/// extern crate vmm_sys_util; +/// # use libc::sigismember; +/// use vmm_sys_util::signal::create_sigset; +/// +/// let sigset = create_sigset(&[1]).unwrap(); +/// +/// unsafe { +/// assert_eq!(sigismember(&sigset, 1), 1); +/// } +/// ``` +pub fn create_sigset(signals: &[c_int]) -> errno::Result<sigset_t> { + // SAFETY: sigset will actually be initialized by sigemptyset below. + let mut sigset: sigset_t = unsafe { mem::zeroed() }; + + // SAFETY: return value is checked. + let ret = unsafe { sigemptyset(&mut sigset) }; + if ret < 0 { + return errno::errno_result(); + } + + for signal in signals { + // SAFETY: return value is checked. + let ret = unsafe { sigaddset(&mut sigset, *signal) }; + if ret < 0 { + return errno::errno_result(); + } + } + + Ok(sigset) +} + +/// Retrieve the signal mask that is blocked of the current thread. 
+/// +/// Use [`pthread_sigmask`](http://man7.org/linux/man-pages/man3/pthread_sigmask.3.html) +/// to fetch the signal mask which is blocked for the caller, return the signal mask as +/// a vector of c_int. +/// +/// # Examples +/// +/// ``` +/// extern crate vmm_sys_util; +/// use vmm_sys_util::signal::{block_signal, get_blocked_signals}; +/// +/// block_signal(1).unwrap(); +/// assert!(get_blocked_signals().unwrap().contains(&(1))); +/// ``` +pub fn get_blocked_signals() -> SignalResult<Vec<c_int>> { + let mut mask = Vec::new(); + + // SAFETY: return values are checked. + unsafe { + let mut old_sigset: sigset_t = mem::zeroed(); + let ret = pthread_sigmask(SIG_BLOCK, null(), &mut old_sigset as *mut sigset_t); + if ret < 0 { + return Err(Error::RetrieveSignalMask(ret)); + } + + for num in 0..=SIGRTMAX() { + if sigismember(&old_sigset, num) > 0 { + mask.push(num); + } + } + } + + Ok(mask) +} + +/// Mask a given signal. +/// +/// Set the given signal `num` as blocked. +/// If signal is already blocked, the call will fail with +/// [`SignalAlreadyBlocked`](enum.Error.html#variant.SignalAlreadyBlocked). +/// +/// # Arguments +/// +/// * `num`: the signal to be masked. +/// +/// # Examples +/// +/// ``` +/// extern crate vmm_sys_util; +/// use vmm_sys_util::signal::block_signal; +/// +/// block_signal(1).unwrap(); +/// ``` +// Allowing comparison chain because rewriting it with match makes the code less readable. +// Also, the risk of having non-exhaustive checks is low. +#[allow(clippy::comparison_chain)] +pub fn block_signal(num: c_int) -> SignalResult<()> { + let sigset = create_sigset(&[num]).map_err(Error::CreateSigset)?; + + // SAFETY: return values are checked. + unsafe { + let mut old_sigset: sigset_t = mem::zeroed(); + let ret = pthread_sigmask(SIG_BLOCK, &sigset, &mut old_sigset as *mut sigset_t); + if ret < 0 { + return Err(Error::BlockSignal(errno::Error::last())); + } + // Check if the given signal is already blocked. 
+ let ret = sigismember(&old_sigset, num); + if ret < 0 { + return Err(Error::CompareBlockedSignals(errno::Error::last())); + } else if ret > 0 { + return Err(Error::SignalAlreadyBlocked(num)); + } + } + Ok(()) +} + +/// Unmask a given signal. +/// +/// # Arguments +/// +/// * `num`: the signal to be unmasked. +/// +/// # Examples +/// +/// ``` +/// extern crate vmm_sys_util; +/// use vmm_sys_util::signal::{block_signal, get_blocked_signals, unblock_signal}; +/// +/// block_signal(1).unwrap(); +/// assert!(get_blocked_signals().unwrap().contains(&(1))); +/// unblock_signal(1).unwrap(); +/// ``` +pub fn unblock_signal(num: c_int) -> SignalResult<()> { + let sigset = create_sigset(&[num]).map_err(Error::CreateSigset)?; + + // SAFETY: return value is checked. + let ret = unsafe { pthread_sigmask(SIG_UNBLOCK, &sigset, null_mut()) }; + if ret < 0 { + return Err(Error::UnblockSignal(errno::Error::last())); + } + Ok(()) +} + +/// Clear a pending signal. +/// +/// # Arguments +/// +/// * `num`: the signal to be cleared. +/// +/// # Examples +/// +/// ``` +/// # extern crate libc; +/// extern crate vmm_sys_util; +/// # use libc::{pthread_kill, sigismember, sigpending, sigset_t}; +/// # use std::mem; +/// # use std::thread; +/// # use std::time::Duration; +/// use vmm_sys_util::signal::{block_signal, clear_signal, Killable}; +/// +/// block_signal(1).unwrap(); +/// let killable = thread::spawn(move || { +/// thread::sleep(Duration::from_millis(100)); +/// unsafe { +/// let mut chkset: sigset_t = mem::zeroed(); +/// sigpending(&mut chkset); +/// assert_eq!(sigismember(&chkset, 1), 1); +/// } +/// }); +/// unsafe { +/// pthread_kill(killable.pthread_handle(), 1); +/// } +/// clear_signal(1).unwrap(); +/// ``` +pub fn clear_signal(num: c_int) -> SignalResult<()> { + let sigset = create_sigset(&[num]).map_err(Error::CreateSigset)?; + + while { + // SAFETY: This is safe as we are rigorously checking return values + // of libc calls. 
+ unsafe { + let mut siginfo: siginfo_t = mem::zeroed(); + let ts = timespec { + tv_sec: 0, + tv_nsec: 0, + }; + // Attempt to consume one instance of pending signal. If signal + // is not pending, the call will fail with EAGAIN or EINTR. + let ret = sigtimedwait(&sigset, &mut siginfo, &ts); + if ret < 0 { + let e = errno::Error::last(); + match e.errno() { + EAGAIN | EINTR => {} + _ => { + return Err(Error::ClearWaitPending(errno::Error::last())); + } + } + } + + // This sigset will be actually filled with `sigpending` call. + let mut chkset: sigset_t = mem::zeroed(); + // See if more instances of the signal are pending. + let ret = sigpending(&mut chkset); + if ret < 0 { + return Err(Error::ClearGetPending(errno::Error::last())); + } + + let ret = sigismember(&chkset, num); + if ret < 0 { + return Err(Error::ClearCheckPending(errno::Error::last())); + } + + // This is do-while loop condition. + ret != 0 + } + } {} + + Ok(()) +} + +/// Trait for threads that can be signalled via `pthread_kill`. +/// +/// Note that this is only useful for signals between `SIGRTMIN()` and +/// `SIGRTMAX()` because these are guaranteed to not be used by the C +/// runtime. +/// +/// # Safety +/// +/// This is marked unsafe because the implementation of this trait must +/// guarantee that the returned `pthread_t` is valid and has a lifetime at +/// least that of the trait object. +pub unsafe trait Killable { + /// Cast this killable thread as `pthread_t`. + fn pthread_handle(&self) -> pthread_t; + + /// Send a signal to this killable thread. + /// + /// # Arguments + /// + /// * `num`: specify the signal + fn kill(&self, num: c_int) -> errno::Result<()> { + validate_signal_num(num)?; + + // SAFETY: Safe because we ensure we are using a valid pthread handle, + // a valid signal number, and check the return result. 
+ let ret = unsafe { pthread_kill(self.pthread_handle(), num) }; + if ret < 0 { + return errno::errno_result(); + } + Ok(()) + } +} + +// SAFETY: Safe because we fulfill our contract of returning a genuine pthread handle. +unsafe impl<T> Killable for JoinHandle<T> { + fn pthread_handle(&self) -> pthread_t { + // JoinHandleExt::as_pthread_t gives c_ulong, convert it to the + // type that the libc crate expects + assert_eq!(mem::size_of::<pthread_t>(), mem::size_of::<usize>()); + self.as_pthread_t() as usize as pthread_t + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::undocumented_unsafe_blocks)] + use super::*; + use std::thread; + use std::time::Duration; + + // Reserve for each vcpu signal. + static mut SIGNAL_HANDLER_CALLED: bool = false; + + extern "C" fn handle_signal(_: c_int, _: *mut siginfo_t, _: *mut c_void) { + unsafe { + // In the tests, there only uses vcpu signal. + SIGNAL_HANDLER_CALLED = true; + } + } + + fn is_pending(signal: c_int) -> bool { + unsafe { + let mut chkset: sigset_t = mem::zeroed(); + sigpending(&mut chkset); + sigismember(&chkset, signal) == 1 + } + } + + #[test] + fn test_register_signal_handler() { + // testing bad value + assert!(register_signal_handler(libc::SIGKILL, handle_signal).is_err()); + assert!(register_signal_handler(libc::SIGSTOP, handle_signal).is_err()); + assert!(register_signal_handler(SIGRTMAX() + 1, handle_signal).is_err()); + format!("{:?}", register_signal_handler(SIGRTMAX(), handle_signal)); + assert!(register_signal_handler(SIGRTMIN(), handle_signal).is_ok()); + assert!(register_signal_handler(libc::SIGSYS, handle_signal).is_ok()); + } + + #[test] + #[allow(clippy::empty_loop)] + fn test_killing_thread() { + let killable = thread::spawn(|| thread::current().id()); + let killable_id = killable.join().unwrap(); + assert_ne!(killable_id, thread::current().id()); + + // We install a signal handler for the specified signal; otherwise the whole process will + // be brought down when the signal is received, as 
part of the default behaviour. Signal + // handlers are global, so we install this before starting the thread. + register_signal_handler(SIGRTMIN(), handle_signal) + .expect("failed to register vcpu signal handler"); + + let killable = thread::spawn(|| loop {}); + + let res = killable.kill(SIGRTMAX() + 1); + assert!(res.is_err()); + format!("{:?}", res); + + unsafe { + assert!(!SIGNAL_HANDLER_CALLED); + } + + assert!(killable.kill(SIGRTMIN()).is_ok()); + + // We're waiting to detect that the signal handler has been called. + const MAX_WAIT_ITERS: u32 = 20; + let mut iter_count = 0; + loop { + thread::sleep(Duration::from_millis(100)); + + if unsafe { SIGNAL_HANDLER_CALLED } { + break; + } + + iter_count += 1; + // timeout if we wait too long + assert!(iter_count <= MAX_WAIT_ITERS); + } + + // Our signal handler doesn't do anything which influences the killable thread, so the + // previous signal is effectively ignored. If we were to join killable here, we would block + // forever as the loop keeps running. Since we don't join, the thread will become detached + // as the handle is dropped, and will be killed when the process/main thread exits. + } + + #[test] + fn test_block_unblock_signal() { + let signal = SIGRTMIN(); + + // Check if it is blocked. + unsafe { + let mut sigset: sigset_t = mem::zeroed(); + pthread_sigmask(SIG_BLOCK, null(), &mut sigset as *mut sigset_t); + assert_eq!(sigismember(&sigset, signal), 0); + } + + block_signal(signal).unwrap(); + assert!(get_blocked_signals().unwrap().contains(&(signal))); + + unblock_signal(signal).unwrap(); + assert!(!get_blocked_signals().unwrap().contains(&(signal))); + } + + #[test] + fn test_clear_pending() { + let signal = SIGRTMIN() + 1; + + block_signal(signal).unwrap(); + + // Block the signal, which means it won't be delivered until it is + // unblocked. Pending between the time when the signal which is set as blocked + // is generated and when is delivered. 
+ let killable = thread::spawn(move || { + loop { + // Wait for the signal being killed. + thread::sleep(Duration::from_millis(100)); + if is_pending(signal) { + clear_signal(signal).unwrap(); + assert!(!is_pending(signal)); + break; + } + } + }); + + // Send a signal to the thread. + assert!(killable.kill(SIGRTMIN() + 1).is_ok()); + killable.join().unwrap(); + } +} diff --git a/src/linux/sock_ctrl_msg.rs b/src/linux/sock_ctrl_msg.rs new file mode 100644 index 0000000..0c19a11 --- /dev/null +++ b/src/linux/sock_ctrl_msg.rs @@ -0,0 +1,663 @@ +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-3-Clause file. +// SPDX-License-Identifier: BSD-3-Clause + +/* Copied from the crosvm Project, commit 186eb8b */ + +//! Wrapper for sending and receiving messages with file descriptors on sockets that accept +//! control messages (e.g. Unix domain sockets). + +use std::fs::File; +use std::mem::size_of; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::net::{UnixDatagram, UnixStream}; +use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; + +use crate::errno::{Error, Result}; +use libc::{ + c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET, +}; +use std::os::raw::c_int; + +// Each of the following macros performs the same function as their C counterparts. They are each +// macros because they are used to size statically allocated arrays. + +macro_rules! CMSG_ALIGN { + ($len:expr) => { + (($len) as usize + size_of::<c_long>() - 1) & !(size_of::<c_long>() - 1) + }; +} + +macro_rules! CMSG_SPACE { + ($len:expr) => { + size_of::<cmsghdr>() + CMSG_ALIGN!($len) + }; +} + +// This function (macro in the C version) is not used in any compile time constant slots, so is just +// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this +// module supports. 
+#[allow(non_snake_case)] +#[inline(always)] +fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd { + // Essentially returns a pointer to just past the header. + cmsg_buffer.wrapping_offset(1) as *mut RawFd +} + +#[cfg(not(target_env = "musl"))] +macro_rules! CMSG_LEN { + ($len:expr) => { + size_of::<cmsghdr>() + ($len) + }; +} + +#[cfg(target_env = "musl")] +macro_rules! CMSG_LEN { + ($len:expr) => {{ + let sz = size_of::<cmsghdr>() + ($len); + assert!(sz <= (std::u32::MAX as usize)); + sz as u32 + }}; +} + +#[cfg(not(target_env = "musl"))] +fn new_msghdr(iovecs: &mut [iovec]) -> msghdr { + msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: iovecs.as_mut_ptr(), + msg_iovlen: iovecs.len(), + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + } +} + +#[cfg(target_env = "musl")] +fn new_msghdr(iovecs: &mut [iovec]) -> msghdr { + assert!(iovecs.len() <= (std::i32::MAX as usize)); + let mut msg: msghdr = unsafe { std::mem::zeroed() }; + msg.msg_name = null_mut(); + msg.msg_iov = iovecs.as_mut_ptr(); + msg.msg_iovlen = iovecs.len() as i32; + msg.msg_control = null_mut(); + msg +} + +#[cfg(not(target_env = "musl"))] +fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) { + msg.msg_controllen = cmsg_capacity; +} + +#[cfg(target_env = "musl")] +fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) { + assert!(cmsg_capacity <= (std::u32::MAX as usize)); + msg.msg_controllen = cmsg_capacity as u32; +} + +// This function is like CMSG_NEXT, but safer because it reads only from references, although it +// does some pointer arithmetic on cmsg_ptr. 
// Advances `cmsg_ptr` to the next control-message header within the buffer
// described by `msghdr`, returning null when the next (aligned) header would
// extend past `msg_controllen`. Safer than the C CMSG_NXTHDR macro because the
// bounds are taken from references rather than re-derived pointers.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_ptr_alignment))]
fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
    let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr;
    if next_cmsg
        .wrapping_offset(1)
        .wrapping_sub(msghdr.msg_control as usize) as usize
        > msghdr.msg_controllen as usize
    {
        null_mut()
    } else {
        next_cmsg
    }
}

// Enough inline space for a control message carrying 32 file descriptors.
const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);

// Backing storage for a control-message buffer: a fixed inline array (u64 so
// the storage is at least cmsghdr-aligned) for the common small case, or a
// heap allocation for larger fd counts.
enum CmsgBuffer {
    Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
    Heap(Box<[cmsghdr]>),
}

impl CmsgBuffer {
    // Returns a buffer with at least `capacity` bytes, inline when it fits.
    fn with_capacity(capacity: usize) -> CmsgBuffer {
        // Round up to whole cmsghdr units for the heap case.
        let cap_in_cmsghdr_units =
            (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
        if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
            CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
        } else {
            CmsgBuffer::Heap(
                vec![
                    cmsghdr {
                        cmsg_len: 0,
                        cmsg_level: 0,
                        cmsg_type: 0,
                        // musl on 64-bit targets has an extra padding field.
                        #[cfg(all(target_env = "musl", target_pointer_width = "64"))]
                        __pad1: 0,
                    };
                    cap_in_cmsghdr_units
                ]
                .into_boxed_slice(),
            )
        }
    }

    // Pointer to the start of the buffer, viewed as a cmsghdr.
    fn as_mut_ptr(&mut self) -> *mut cmsghdr {
        match self {
            CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
            CmsgBuffer::Heap(a) => a.as_mut_ptr(),
        }
    }
}

// Sends `out_data` over socket `fd` via sendmsg(2), attaching `out_fds` as a
// single SCM_RIGHTS control message when non-empty. Returns the number of
// data bytes written (control data is not counted).
fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
    let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len());
    let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);

    let mut iovecs = Vec::with_capacity(out_data.len());
    for data in out_data {
        iovecs.push(iovec {
            iov_base: data.as_ptr() as *mut c_void,
            iov_len: data.size(),
        });
    }

    let mut msg = new_msghdr(&mut iovecs);

    if !out_fds.is_empty() {
        let cmsg = cmsghdr {
            cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()),
            cmsg_level: SOL_SOCKET,
            cmsg_type: SCM_RIGHTS,
            // musl on 64-bit targets has an extra padding field.
            #[cfg(all(target_env = "musl", target_pointer_width = "64"))]
            __pad1: 0,
        };
        // SAFETY: Check comments below for each call.
        unsafe {
            // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr.
            write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg);
            // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len()
            // file descriptors.
            copy_nonoverlapping(
                out_fds.as_ptr(),
                CMSG_DATA(cmsg_buffer.as_mut_ptr()),
                out_fds.len(),
            );
        }

        msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
        set_msg_controllen(&mut msg, cmsg_capacity);
    }

    // SAFETY: Safe because the msghdr was properly constructed from valid (or null) pointers of
    // the indicated length and we check the return value.
    let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };

    if write_count == -1 {
        Err(Error::last())
    } else {
        Ok(write_count as usize)
    }
}

// Receives data into `iovecs` and up to `in_fds.len()` file descriptors from
// socket `fd` via recvmsg(2). Returns `(bytes_read, fds_received)`.
//
// If the kernel delivered more fds than `in_fds` can hold, or the control data
// was truncated (MSG_CTRUNC), every received fd — including ones already
// copied out — is closed and ENOBUFS is returned, so no descriptor can leak.
//
// Safety: the caller must guarantee that each iovec points to memory it is
// valid to write `iov_len` arbitrary bytes into.
unsafe fn raw_recvmsg(
    fd: RawFd,
    iovecs: &mut [iovec],
    in_fds: &mut [RawFd],
) -> Result<(usize, usize)> {
    let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
    let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
    let mut msg = new_msghdr(iovecs);

    if !in_fds.is_empty() {
        // MSG control len is size_of(cmsghdr) + size_of(RawFd) * in_fds.len().
        msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
        set_msg_controllen(&mut msg, cmsg_capacity);
    }

    // Safe because the msghdr was properly constructed from valid (or null) pointers of the
    // indicated length and we check the return value.
    // TODO: Should we handle MSG_TRUNC in a specific way?
    let total_read = recvmsg(fd, &mut msg, 0);
    if total_read == -1 {
        return Err(Error::last());
    }

    // Zero bytes and no complete control header: peer sent nothing usable.
    if total_read == 0 && (msg.msg_controllen as usize) < size_of::<cmsghdr>() {
        return Ok((0, 0));
    }

    // Reference to a memory area with a CmsgBuffer, which contains a `cmsghdr` struct followed
    // by a sequence of `in_fds.len()` count RawFds.
    let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
    let mut copied_fds_count = 0;
    // If the control data was truncated, then this might be a sign of an incorrect communication
    // protocol. If MSG_CTRUNC was set we must close the fds from the control data.
    let mut teardown_control_data = msg.msg_flags & libc::MSG_CTRUNC != 0;

    while !cmsg_ptr.is_null() {
        // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such
        // that it only happens when there is at least sizeof(cmsghdr) space after the pointer to
        // read.
        let cmsg = (cmsg_ptr as *mut cmsghdr).read_unaligned();
        if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
            // Number of fds carried by this control message.
            let fds_count = ((cmsg.cmsg_len - CMSG_LEN!(0)) as usize) / size_of::<RawFd>();
            // The sender can transmit more data than we can buffer. If a message is too long to
            // fit in the supplied buffer, excess bytes may be discarded depending on the type of
            // socket the message is received from.
            let fds_to_be_copied_count = std::cmp::min(in_fds.len() - copied_fds_count, fds_count);
            teardown_control_data |= fds_count > fds_to_be_copied_count;
            if teardown_control_data {
                // Allocating space for the cmsg buffer might provide extra space for fds, due to
                // alignment. If these fds can not be stored in the `in_fds` buffer, then all the
                // control data must be dropped due to insufficient buffer space for returning
                // them to the outer scope. This might be a sign of incorrect protocol
                // communication.
                for fd_offset in 0..fds_count {
                    let raw_fds_ptr = CMSG_DATA(cmsg_ptr);
                    // The cmsg_ptr is valid here because it is checked at the beginning of the
                    // loop and it is assured to have `fds_count` fds available.
                    let raw_fd = *(raw_fds_ptr.wrapping_add(fd_offset)) as c_int;
                    libc::close(raw_fd);
                }
            } else {
                // Safe because `cmsg_ptr` is checked against null and we copy from `cmsg_buffer`
                // to `in_fds` according to their current capacity.
                copy_nonoverlapping(
                    CMSG_DATA(cmsg_ptr),
                    in_fds[copied_fds_count..(copied_fds_count + fds_to_be_copied_count)]
                        .as_mut_ptr(),
                    fds_to_be_copied_count,
                );

                copied_fds_count += fds_to_be_copied_count;
            }
        }

        // Remove the previously copied fds.
        if teardown_control_data {
            for fd in in_fds.iter().take(copied_fds_count) {
                // This is safe because we close only the previously copied fds. We do not care
                // about `close` return code.
                libc::close(*fd);
            }

            return Err(Error::new(libc::ENOBUFS));
        }

        cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
    }

    Ok((total_read as usize, copied_fds_count))
}

/// Trait for file descriptors that can send and receive socket control messages via `sendmsg`
/// and `recvmsg`.
///
/// # Examples
///
/// ```
/// # extern crate libc;
/// extern crate vmm_sys_util;
/// use vmm_sys_util::sock_ctrl_msg::ScmSocket;
/// # use vmm_sys_util::eventfd::{EventFd, EFD_NONBLOCK};
/// # use std::fs::File;
/// # use std::io::Write;
/// # use std::os::unix::io::{AsRawFd, FromRawFd};
/// # use std::os::unix::net::UnixDatagram;
/// # use std::slice::from_raw_parts;
///
/// # use libc::{c_void, iovec};
///
/// let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
/// let evt = EventFd::new(0).expect("failed to create eventfd");
///
/// let write_count = s1
///     .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()])
///     .expect("failed to send fd");
///
/// let mut files = [0; 2];
/// let mut buf = [0u8];
/// let mut iovecs = [iovec {
///     iov_base: buf.as_mut_ptr() as *mut c_void,
///     iov_len: buf.len(),
/// }];
/// let (read_count, file_count) = unsafe {
///     s2.recv_with_fds(&mut iovecs[..], &mut files)
///         .expect("failed to recv fd")
/// };
///
/// let mut file = unsafe { File::from_raw_fd(files[0]) };
/// file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
///     .expect("failed to write to sent fd");
///
/// assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
/// ```
pub trait ScmSocket {
    /// Gets the file descriptor of this socket.
    fn socket_fd(&self) -> RawFd;

    /// Sends the given data and file descriptor over the socket.
    ///
    /// On success, returns the number of bytes sent.
    ///
    /// # Arguments
    ///
    /// * `buf` - A buffer of data to send on the `socket`.
    /// * `fd` - A file descriptor to be sent.
    fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> {
        self.send_with_fds(&[buf], &[fd])
    }

    /// Sends the given data and file descriptors over the socket.
    ///
    /// On success, returns the number of bytes sent.
    ///
    /// # Arguments
    ///
    /// * `bufs` - A list of data buffers to send on the `socket`.
    /// * `fds` - A list of file descriptors to be sent.
    fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> Result<usize> {
        raw_sendmsg(self.socket_fd(), bufs, fds)
    }

    /// Receives data and potentially a file descriptor from the socket.
    ///
    /// On success, returns the number of bytes and an optional file descriptor.
    ///
    /// # Arguments
    ///
    /// * `buf` - A buffer to receive data from the socket.
    fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
        let mut fd = [0];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];

        // SAFETY: Safe because we have mutably borrowed buf and it's safe to write arbitrary data
        // to a slice.
        let (read_count, fd_count) = unsafe { self.recv_with_fds(&mut iovecs[..], &mut fd)? };
        let file = if fd_count == 0 {
            None
        } else {
            // SAFETY: Safe because the first fd from recv_with_fds is owned by us and valid
            // because this branch was taken.
            Some(unsafe { File::from_raw_fd(fd[0]) })
        };
        Ok((read_count, file))
    }

    /// Receives data and file descriptors from the socket.
    ///
    /// On success, returns the number of bytes and file descriptors received as a tuple
    /// `(bytes count, files count)`.
    ///
    /// # Arguments
    ///
    /// * `iovecs` - A list of iovec to receive data from the socket.
    /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the
    ///           number of valid file descriptors is indicated by the second element of the
    ///           returned tuple. The caller owns these file descriptors, but they will not be
    ///           closed on drop like a `File`-like type would be. It is recommended that each
    ///           valid file descriptor gets wrapped in a drop type that closes it after this
    ///           returns.
    ///
    /// # Safety
    ///
    /// It is the caller's responsibility to ensure it is safe for arbitrary data to be
    /// written to the iovec pointers.
    unsafe fn recv_with_fds(
        &self,
        iovecs: &mut [iovec],
        fds: &mut [RawFd],
    ) -> Result<(usize, usize)> {
        raw_recvmsg(self.socket_fd(), iovecs, fds)
    }
}

impl ScmSocket for UnixDatagram {
    fn socket_fd(&self) -> RawFd {
        self.as_raw_fd()
    }
}

impl ScmSocket for UnixStream {
    fn socket_fd(&self) -> RawFd {
        self.as_raw_fd()
    }
}

/// Trait for types that can be converted into an `iovec` that can be referenced by a syscall for
/// the lifetime of this object.
///
/// # Safety
///
/// This is marked unsafe because the implementation must ensure that the returned pointer and size
/// is valid and that the lifetime of the returned pointer is at least that of the trait object.
pub unsafe trait IntoIovec {
    /// Gets the base pointer of this `iovec`.
    fn as_ptr(&self) -> *const c_void;

    /// Gets the size in bytes of this `iovec`.
    fn size(&self) -> usize;
}

// SAFETY: Safe because this slice can not have another mutable reference and its pointer and
// size are guaranteed to be valid.
+unsafe impl<'a> IntoIovec for &'a [u8] { + // Clippy false positive: https://github.com/rust-lang/rust-clippy/issues/3480 + #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_asref))] + fn as_ptr(&self) -> *const c_void { + self.as_ref().as_ptr() as *const c_void + } + + fn size(&self) -> usize { + self.len() + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::undocumented_unsafe_blocks)] + use super::*; + use crate::eventfd::EventFd; + + use std::io::Write; + use std::mem::size_of; + use std::os::raw::c_long; + use std::os::unix::net::UnixDatagram; + use std::slice::from_raw_parts; + + use libc::cmsghdr; + + #[test] + fn buffer_len() { + assert_eq!(CMSG_SPACE!(0), size_of::<cmsghdr>()); + assert_eq!( + CMSG_SPACE!(size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() + ); + if size_of::<RawFd>() == 4 { + assert_eq!( + CMSG_SPACE!(2 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() + ); + assert_eq!( + CMSG_SPACE!(3 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 2 + ); + assert_eq!( + CMSG_SPACE!(4 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 2 + ); + } else if size_of::<RawFd>() == 8 { + assert_eq!( + CMSG_SPACE!(2 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 2 + ); + assert_eq!( + CMSG_SPACE!(3 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 3 + ); + assert_eq!( + CMSG_SPACE!(4 * size_of::<RawFd>()), + size_of::<cmsghdr>() + size_of::<c_long>() * 4 + ); + } + } + + #[test] + fn send_recv_no_fd() { + let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); + + let write_count = s1 + .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[]) + .expect("failed to send data"); + + assert_eq!(write_count, 6); + + let mut buf = [0u8; 6]; + let mut files = [0; 1]; + let mut iovecs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }]; + let (read_count, file_count) = unsafe { + 
s2.recv_with_fds(&mut iovecs[..], &mut files) + .expect("failed to recv data") + }; + + assert_eq!(read_count, 6); + assert_eq!(file_count, 0); + assert_eq!(buf, [1, 1, 2, 21, 34, 55]); + } + + #[test] + fn send_recv_only_fd() { + let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); + + let evt = EventFd::new(0).expect("failed to create eventfd"); + let write_count = s1 + .send_with_fd([].as_ref(), evt.as_raw_fd()) + .expect("failed to send fd"); + + assert_eq!(write_count, 0); + + let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd"); + + let mut file = file_opt.unwrap(); + + assert_eq!(read_count, 0); + assert!(file.as_raw_fd() >= 0); + assert_ne!(file.as_raw_fd(), s1.as_raw_fd()); + assert_ne!(file.as_raw_fd(), s2.as_raw_fd()); + assert_ne!(file.as_raw_fd(), evt.as_raw_fd()); + + file.write_all(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) + .expect("failed to write to sent fd"); + + assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); + } + + #[test] + fn send_recv_with_fd() { + let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); + + let evt = EventFd::new(0).expect("failed to create eventfd"); + let write_count = s1 + .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()]) + .expect("failed to send fd"); + + assert_eq!(write_count, 1); + + let mut files = [0; 2]; + let mut buf = [0u8]; + let mut iovecs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }]; + let (read_count, file_count) = unsafe { + s2.recv_with_fds(&mut iovecs[..], &mut files) + .expect("failed to recv fd") + }; + + assert_eq!(read_count, 1); + assert_eq!(buf[0], 237); + assert_eq!(file_count, 1); + assert!(files[0] >= 0); + assert_ne!(files[0], s1.as_raw_fd()); + assert_ne!(files[0], s2.as_raw_fd()); + assert_ne!(files[0], evt.as_raw_fd()); + + let mut file = unsafe { File::from_raw_fd(files[0]) }; + + file.write_all(unsafe { from_raw_parts(&1203u64 as 
*const u64 as *const u8, 8) }) + .expect("failed to write to sent fd"); + + assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); + } + + #[test] + // Exercise the code paths that activate the issue of receiving the all the ancillary data, + // but missing to provide enough buffer space to store it. + fn send_more_recv_less1() { + let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); + + let evt1 = EventFd::new(0).expect("failed to create eventfd"); + let evt2 = EventFd::new(0).expect("failed to create eventfd"); + let evt3 = EventFd::new(0).expect("failed to create eventfd"); + let evt4 = EventFd::new(0).expect("failed to create eventfd"); + let write_count = s1 + .send_with_fds( + &[[237].as_ref()], + &[ + evt1.as_raw_fd(), + evt2.as_raw_fd(), + evt3.as_raw_fd(), + evt4.as_raw_fd(), + ], + ) + .expect("failed to send fd"); + + assert_eq!(write_count, 1); + + let mut files = [0; 2]; + let mut buf = [0u8]; + let mut iovecs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }]; + assert!(unsafe { s2.recv_with_fds(&mut iovecs[..], &mut files).is_err() }); + } + + // Exercise the code paths that activate the issue of receiving part of the sent ancillary + // data due to insufficient buffer space, activating `msg_flags` `MSG_CTRUNC` flag. 
+ #[test] + fn send_more_recv_less2() { + let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); + + let evt1 = EventFd::new(0).expect("failed to create eventfd"); + let evt2 = EventFd::new(0).expect("failed to create eventfd"); + let evt3 = EventFd::new(0).expect("failed to create eventfd"); + let evt4 = EventFd::new(0).expect("failed to create eventfd"); + let write_count = s1 + .send_with_fds( + &[[237].as_ref()], + &[ + evt1.as_raw_fd(), + evt2.as_raw_fd(), + evt3.as_raw_fd(), + evt4.as_raw_fd(), + ], + ) + .expect("failed to send fd"); + + assert_eq!(write_count, 1); + + let mut files = [0; 1]; + let mut buf = [0u8]; + let mut iovecs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }]; + assert!(unsafe { s2.recv_with_fds(&mut iovecs[..], &mut files).is_err() }); + } +} diff --git a/src/linux/timerfd.rs b/src/linux/timerfd.rs new file mode 100644 index 0000000..80f0789 --- /dev/null +++ b/src/linux/timerfd.rs @@ -0,0 +1,281 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Structure and functions for working with +//! [`timerfd`](http://man7.org/linux/man-pages/man2/timerfd_create.2.html). + +use std::fs::File; +use std::mem; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::ptr; +use std::time::Duration; + +use libc::{self, timerfd_create, timerfd_gettime, timerfd_settime, CLOCK_MONOTONIC, TFD_CLOEXEC}; + +use crate::errno::{errno_result, Result}; + +/// A safe wrapper around a Linux +/// [`timerfd`](http://man7.org/linux/man-pages/man2/timerfd_create.2.html). +pub struct TimerFd(File); + +impl TimerFd { + /// Create a new [`TimerFd`](struct.TimerFd.html). + /// + /// This creates a nonsettable monotonically increasing clock that does not + /// change after system startup. 
The timer is initally disarmed and must be + /// armed by calling [`reset`](fn.reset.html). + pub fn new() -> Result<TimerFd> { + // SAFETY: Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC) }; + if ret < 0 { + return errno_result(); + } + + // SAFETY: Safe because we uniquely own the file descriptor. + Ok(TimerFd(unsafe { File::from_raw_fd(ret) })) + } + + /// Arm the [`TimerFd`](struct.TimerFd.html). + /// + /// Set the timer to expire after `dur`. + /// + /// # Arguments + /// + /// * `dur`: Specify the initial expiration of the timer. + /// * `interval`: Specify the period for repeated expirations, depending on the + /// value passed. If `interval` is not `None`, it represents the period after + /// the initial expiration. Otherwise the timer will expire just once. Cancels + /// any existing duration and repeating interval. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// # use std::time::Duration; + /// use vmm_sys_util::timerfd::TimerFd; + /// + /// let mut timer = TimerFd::new().unwrap(); + /// let dur = Duration::from_millis(100); + /// let interval = Duration::from_millis(100); + /// + /// timer.reset(dur, Some(interval)).unwrap(); + /// ``` + pub fn reset(&mut self, dur: Duration, interval: Option<Duration>) -> Result<()> { + // SAFETY: Safe because we are zero-initializing a struct with only primitive member fields. + let mut spec: libc::itimerspec = unsafe { mem::zeroed() }; + // https://github.com/rust-lang/libc/issues/1848 + #[cfg_attr(target_env = "musl", allow(deprecated))] + { + spec.it_value.tv_sec = dur.as_secs() as libc::time_t; + } + // nsec always fits in i32 because subsec_nanos is defined to be less than one billion. 
+ let nsec = dur.subsec_nanos() as i32; + spec.it_value.tv_nsec = libc::c_long::from(nsec); + + if let Some(int) = interval { + // https://github.com/rust-lang/libc/issues/1848 + #[cfg_attr(target_env = "musl", allow(deprecated))] + { + spec.it_interval.tv_sec = int.as_secs() as libc::time_t; + } + // nsec always fits in i32 because subsec_nanos is defined to be less than one billion. + let nsec = int.subsec_nanos() as i32; + spec.it_interval.tv_nsec = libc::c_long::from(nsec); + } + + // SAFETY: Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { timerfd_settime(self.as_raw_fd(), 0, &spec, ptr::null_mut()) }; + if ret < 0 { + return errno_result(); + } + + Ok(()) + } + + /// Wait until the timer expires. + /// + /// The return value represents the number of times the timer has expired since + /// the last time `wait` was called. If the timer has not yet expired once, + /// this call will block until it does. + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// # use std::time::Duration; + /// # use std::thread::sleep; + /// use vmm_sys_util::timerfd::TimerFd; + /// + /// let mut timer = TimerFd::new().unwrap(); + /// let dur = Duration::from_millis(100); + /// let interval = Duration::from_millis(100); + /// timer.reset(dur, Some(interval)).unwrap(); + /// + /// sleep(dur * 3); + /// let count = timer.wait().unwrap(); + /// assert!(count >= 3); + /// ``` + pub fn wait(&mut self) -> Result<u64> { + let mut count = 0u64; + + // SAFETY: Safe because this will only modify |buf| and we check the return value. + let ret = unsafe { + libc::read( + self.as_raw_fd(), + &mut count as *mut _ as *mut libc::c_void, + mem::size_of_val(&count), + ) + }; + if ret < 0 { + return errno_result(); + } + + // The bytes in the buffer are guaranteed to be in native byte-order so we don't need to + // use from_le or from_be. + Ok(count) + } + + /// Tell if the timer is armed. 
+ /// + /// Returns `Ok(true)` if the timer is currently armed, otherwise the errno set by + /// [`timerfd_gettime`](http://man7.org/linux/man-pages/man2/timerfd_create.2.html). + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// # use std::time::Duration; + /// use vmm_sys_util::timerfd::TimerFd; + /// + /// let mut timer = TimerFd::new().unwrap(); + /// let dur = Duration::from_millis(100); + /// + /// timer.reset(dur, None).unwrap(); + /// assert!(timer.is_armed().unwrap()); + /// ``` + pub fn is_armed(&self) -> Result<bool> { + // SAFETY: Safe because we are zero-initializing a struct with only primitive member fields. + let mut spec: libc::itimerspec = unsafe { mem::zeroed() }; + + // SAFETY: Safe because timerfd_gettime is trusted to only modify `spec`. + let ret = unsafe { timerfd_gettime(self.as_raw_fd(), &mut spec) }; + if ret < 0 { + return errno_result(); + } + + Ok(spec.it_value.tv_sec != 0 || spec.it_value.tv_nsec != 0) + } + + /// Disarm the timer. + /// + /// Set zero to disarm the timer, referring to + /// [`timerfd_settime`](http://man7.org/linux/man-pages/man2/timerfd_create.2.html). + /// + /// # Examples + /// + /// ``` + /// extern crate vmm_sys_util; + /// # use std::time::Duration; + /// use vmm_sys_util::timerfd::TimerFd; + /// + /// let mut timer = TimerFd::new().unwrap(); + /// let dur = Duration::from_millis(100); + /// + /// timer.reset(dur, None).unwrap(); + /// timer.clear().unwrap(); + /// ``` + pub fn clear(&mut self) -> Result<()> { + // SAFETY: Safe because we are zero-initializing a struct with only primitive member fields. + let spec: libc::itimerspec = unsafe { mem::zeroed() }; + + // SAFETY: Safe because this doesn't modify any memory and we check the return value. 
+ let ret = unsafe { timerfd_settime(self.as_raw_fd(), 0, &spec, ptr::null_mut()) }; + if ret < 0 { + return errno_result(); + } + + Ok(()) + } +} + +impl AsRawFd for TimerFd { + fn as_raw_fd(&self) -> RawFd { + self.0.as_raw_fd() + } +} + +impl FromRawFd for TimerFd { + /// This function is unsafe as the primitives currently returned + /// have the contract that they are the sole owner of the file + /// descriptor they are wrapping. Usage of this function could + /// accidentally allow violating this contract which can cause memory + /// unsafety in code that relies on it being true. + unsafe fn from_raw_fd(fd: RawFd) -> Self { + TimerFd(File::from_raw_fd(fd)) + } +} + +impl IntoRawFd for TimerFd { + fn into_raw_fd(self) -> RawFd { + self.0.into_raw_fd() + } +} + +#[cfg(test)] +mod tests { + #![allow(clippy::undocumented_unsafe_blocks)] + use super::*; + use std::thread::sleep; + use std::time::{Duration, Instant}; + + #[test] + fn test_from_raw_fd() { + let ret = unsafe { timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC) }; + let tfd = unsafe { TimerFd::from_raw_fd(ret) }; + assert!(!tfd.is_armed().unwrap()); + } + + #[test] + fn test_into_raw_fd() { + let tfd = TimerFd::new().expect("failed to create timerfd"); + let fd = tfd.into_raw_fd(); + assert!(fd > 0); + } + #[test] + fn test_one_shot() { + let mut tfd = TimerFd::new().expect("failed to create timerfd"); + assert!(!tfd.is_armed().unwrap()); + + let dur = Duration::from_millis(200); + let now = Instant::now(); + tfd.reset(dur, None).expect("failed to arm timer"); + + assert!(tfd.is_armed().unwrap()); + + let count = tfd.wait().expect("unable to wait for timer"); + + assert_eq!(count, 1); + assert!(now.elapsed() >= dur); + tfd.clear().expect("unable to clear the timer"); + assert!(!tfd.is_armed().unwrap()); + } + + #[test] + fn test_repeating() { + let mut tfd = TimerFd::new().expect("failed to create timerfd"); + + let dur = Duration::from_millis(200); + let interval = Duration::from_millis(100); + 
tfd.reset(dur, Some(interval)).expect("failed to arm timer"); + + sleep(dur * 3); + + let count = tfd.wait().expect("unable to wait for timer"); + assert!(count >= 5, "count = {}", count); + tfd.clear().expect("unable to clear the timer"); + assert!(!tfd.is_armed().unwrap()); + } +} diff --git a/src/linux/write_zeroes.rs b/src/linux/write_zeroes.rs new file mode 100644 index 0000000..e6084da --- /dev/null +++ b/src/linux/write_zeroes.rs @@ -0,0 +1,330 @@ +// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Traits for replacing a range with a hole and writing zeroes in a file. + +use std::cmp::min; +use std::fs::File; +use std::io::{Error, ErrorKind, Result, Seek, SeekFrom}; +use std::os::unix::fs::FileExt; + +use crate::fallocate::{fallocate, FallocateMode}; + +/// A trait for deallocating space in a file. +pub trait PunchHole { + /// Replace a range of bytes with a hole. + /// + /// # Arguments + /// + /// * `offset`: offset of the file where to replace with a hole. + /// * `length`: the number of bytes of the hole to replace with. + fn punch_hole(&mut self, offset: u64, length: u64) -> Result<()>; +} + +impl PunchHole for File { + fn punch_hole(&mut self, offset: u64, length: u64) -> Result<()> { + fallocate(self, FallocateMode::PunchHole, true, offset, length as u64) + .map_err(|e| Error::from_raw_os_error(e.errno())) + } +} + +/// A trait for writing zeroes to a stream. +pub trait WriteZeroes { + /// Write up to `length` bytes of zeroes to the stream, returning how many bytes were written. + /// + /// # Arguments + /// + /// * `length`: the number of bytes of zeroes to write to the stream. + fn write_zeroes(&mut self, length: usize) -> Result<usize>; + + /// Write zeroes to the stream until `length` bytes have been written. 
+ /// + /// This method will continuously write zeroes until the requested `length` is satisfied or an + /// unrecoverable error is encountered. + /// + /// # Arguments + /// + /// * `length`: the exact number of bytes of zeroes to write to the stream. + fn write_all_zeroes(&mut self, mut length: usize) -> Result<()> { + while length > 0 { + match self.write_zeroes(length) { + Ok(0) => return Err(Error::from(ErrorKind::WriteZero)), + Ok(bytes_written) => { + length = length + .checked_sub(bytes_written) + .ok_or_else(|| Error::from(ErrorKind::Other))? + } + // If the operation was interrupted, we should retry it. + Err(e) => { + if e.kind() != ErrorKind::Interrupted { + return Err(e); + } + } + } + } + Ok(()) + } +} + +/// A trait for writing zeroes to an arbitrary position in a file. +pub trait WriteZeroesAt { + /// Write up to `length` bytes of zeroes starting at `offset`, returning how many bytes were + /// written. + /// + /// # Arguments + /// + /// * `offset`: offset of the file where to write zeroes. + /// * `length`: the number of bytes of zeroes to write to the stream. + fn write_zeroes_at(&mut self, offset: u64, length: usize) -> Result<usize>; + + /// Write zeroes starting at `offset` until `length` bytes have been written. + /// + /// This method will continuously write zeroes until the requested `length` is satisfied or an + /// unrecoverable error is encountered. + /// + /// # Arguments + /// + /// * `offset`: offset of the file where to write zeroes. + /// * `length`: the exact number of bytes of zeroes to write to the stream. 
+ fn write_all_zeroes_at(&mut self, mut offset: u64, mut length: usize) -> Result<()> { + while length > 0 { + match self.write_zeroes_at(offset, length) { + Ok(0) => return Err(Error::from(ErrorKind::WriteZero)), + Ok(bytes_written) => { + length = length + .checked_sub(bytes_written) + .ok_or_else(|| Error::from(ErrorKind::Other))?; + offset = offset + .checked_add(bytes_written as u64) + .ok_or_else(|| Error::from(ErrorKind::Other))?; + } + Err(e) => { + // If the operation was interrupted, we should retry it. + if e.kind() != ErrorKind::Interrupted { + return Err(e); + } + } + } + } + Ok(()) + } +} + +impl WriteZeroesAt for File { + fn write_zeroes_at(&mut self, offset: u64, length: usize) -> Result<usize> { + // Try to use fallocate() first, since it is more efficient than writing zeroes with + // write(). + if fallocate(self, FallocateMode::ZeroRange, true, offset, length as u64).is_ok() { + return Ok(length); + } + + // Fall back to write(). + // fallocate() failed; fall back to writing a buffer of zeroes until we have written up + // to `length`. + let buf_size = min(length, 0x10000); + let buf = vec![0u8; buf_size]; + let mut num_written: usize = 0; + while num_written < length { + let remaining = length - num_written; + let write_size = min(remaining, buf_size); + num_written += self.write_at(&buf[0..write_size], offset + num_written as u64)?; + } + Ok(length) + } +} + +impl<T: WriteZeroesAt + Seek> WriteZeroes for T { + fn write_zeroes(&mut self, length: usize) -> Result<usize> { + let offset = self.seek(SeekFrom::Current(0))?; + let num_written = self.write_zeroes_at(offset, length)?; + // Advance the seek cursor as if we had done a real write(). 
+ self.seek(SeekFrom::Current(num_written as i64))?; + Ok(length) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::io::{Read, Seek, SeekFrom, Write}; + + use crate::tempfile::TempFile; + + #[test] + fn test_small_write_zeroes() { + const NON_ZERO_VALUE: u8 = 0x55; + const BUF_SIZE: usize = 5678; + + let mut f = TempFile::new().unwrap().into_file(); + f.set_len(16384).unwrap(); + + // Write buffer of non-zero bytes to offset 1234. + let orig_data = [NON_ZERO_VALUE; BUF_SIZE]; + f.seek(SeekFrom::Start(1234)).unwrap(); + f.write_all(&orig_data).unwrap(); + + // Read back the data plus some overlap on each side. + let mut readback = [0u8; 16384]; + f.seek(SeekFrom::Start(0)).unwrap(); + f.read_exact(&mut readback).unwrap(); + // Bytes before the write should still be 0. + for read in &readback[0..1234] { + assert_eq!(*read, 0); + } + // Bytes that were just written should have `NON_ZERO_VALUE` value. + for read in &readback[1234..(1234 + BUF_SIZE)] { + assert_eq!(*read, NON_ZERO_VALUE); + } + // Bytes after the written area should still be 0. + for read in &readback[(1234 + BUF_SIZE)..] { + assert_eq!(*read, 0); + } + + // Overwrite some of the data with zeroes. + f.seek(SeekFrom::Start(2345)).unwrap(); + f.write_all_zeroes(4321).unwrap(); + // Verify seek position after `write_all_zeroes()`. + assert_eq!(f.seek(SeekFrom::Current(0)).unwrap(), 2345 + 4321); + + // Read back the data and verify that it is now zero. + f.seek(SeekFrom::Start(0)).unwrap(); + f.read_exact(&mut readback).unwrap(); + // Bytes before the write should still be 0. + for read in &readback[0..1234] { + assert_eq!(*read, 0); + } + // Original data should still exist before the zeroed region. + for read in &readback[1234..2345] { + assert_eq!(*read, NON_ZERO_VALUE); + } + // Verify that `write_all_zeroes()` zeroed the intended region. + for read in &readback[2345..(2345 + 4321)] { + assert_eq!(*read, 0); + } + // Original data should still exist after the zeroed region. 
+ for read in &readback[(2345 + 4321)..(1234 + BUF_SIZE)] { + assert_eq!(*read, NON_ZERO_VALUE); + } + // The rest of the file should still be 0. + for read in &readback[(1234 + BUF_SIZE)..] { + assert_eq!(*read, 0); + } + } + + #[test] + fn test_large_write_zeroes() { + const NON_ZERO_VALUE: u8 = 0x55; + const SIZE: usize = 0x2_0000; + + let mut f = TempFile::new().unwrap().into_file(); + f.set_len(16384).unwrap(); + + // Write buffer of non-zero bytes. The size of the buffer will be the new + // size of the file. + let orig_data = [NON_ZERO_VALUE; SIZE]; + f.seek(SeekFrom::Start(0)).unwrap(); + f.write_all(&orig_data).unwrap(); + assert_eq!(f.metadata().unwrap().len(), SIZE as u64); + + // Overwrite some of the data with zeroes. + f.seek(SeekFrom::Start(0)).unwrap(); + f.write_all_zeroes(0x1_0001).unwrap(); + // Verify seek position after `write_all_zeroes()`. + assert_eq!(f.seek(SeekFrom::Current(0)).unwrap(), 0x1_0001); + + // Read back the data and verify that it is now zero. + let mut readback = [0u8; SIZE]; + f.seek(SeekFrom::Start(0)).unwrap(); + f.read_exact(&mut readback).unwrap(); + // Verify that `write_all_zeroes()` zeroed the intended region. + for read in &readback[0..0x1_0001] { + assert_eq!(*read, 0); + } + // Original data should still exist after the zeroed region. + for read in &readback[0x1_0001..SIZE] { + assert_eq!(*read, NON_ZERO_VALUE); + } + + // Now let's zero a certain region by using `write_all_zeroes_at()`. + f.write_all_zeroes_at(0x1_8001, 0x200).unwrap(); + f.seek(SeekFrom::Start(0)).unwrap(); + f.read_exact(&mut readback).unwrap(); + + // Original data should still exist before the zeroed region. + for read in &readback[0x1_0001..0x1_8001] { + assert_eq!(*read, NON_ZERO_VALUE); + } + // Verify that `write_all_zeroes_at()` zeroed the intended region. + for read in &readback[0x1_8001..(0x1_8001 + 0x200)] { + assert_eq!(*read, 0); + } + // Original data should still exist after the zeroed region. 
+ for read in &readback[(0x1_8001 + 0x200)..SIZE] { + assert_eq!(*read, NON_ZERO_VALUE); + } + } + + #[test] + fn test_punch_hole() { + const NON_ZERO_VALUE: u8 = 0x55; + const SIZE: usize = 0x2_0000; + + let mut f = TempFile::new().unwrap().into_file(); + f.set_len(16384).unwrap(); + + // Write buffer of non-zero bytes. The size of the buffer will be the new + // size of the file. + let orig_data = [NON_ZERO_VALUE; SIZE]; + f.seek(SeekFrom::Start(0)).unwrap(); + f.write_all(&orig_data).unwrap(); + assert_eq!(f.metadata().unwrap().len(), SIZE as u64); + + // Punch a hole at offset 0x10001. + // Subsequent reads from this range will return zeros. + f.punch_hole(0x1_0001, 0x200).unwrap(); + + // Read back the data. + let mut readback = [0u8; SIZE]; + f.seek(SeekFrom::Start(0)).unwrap(); + f.read_exact(&mut readback).unwrap(); + // Original data should still exist before the hole. + for read in &readback[0..0x1_0001] { + assert_eq!(*read, NON_ZERO_VALUE); + } + // Verify that `punch_hole()` zeroed the intended region. + for read in &readback[0x1_0001..(0x1_0001 + 0x200)] { + assert_eq!(*read, 0); + } + // Original data should still exist after the hole. + for read in &readback[(0x1_0001 + 0x200)..] { + assert_eq!(*read, NON_ZERO_VALUE); + } + + // Punch a hole at the end of the file. + // Subsequent reads from this range should return zeros. + f.punch_hole(SIZE as u64 - 0x400, 0x400).unwrap(); + // Even though we punched a hole at the end of the file, the file size should remain the + // same since FALLOC_FL_PUNCH_HOLE must be used with FALLOC_FL_KEEP_SIZE. + assert_eq!(f.metadata().unwrap().len(), SIZE as u64); + + let mut readback = [0u8; 0x400]; + f.seek(SeekFrom::Start(SIZE as u64 - 0x400)).unwrap(); + f.read_exact(&mut readback).unwrap(); + // Verify that `punch_hole()` zeroed the intended region. + for read in &readback[0..0x400] { + assert_eq!(*read, 0); + } + + // Punching a hole of len 0 should return an error. 
+ assert!(f.punch_hole(0x200, 0x0).is_err()); + // Zeroing a region of len 0 should not return an error since we have a fallback path + // in `write_zeroes_at()` for `fallocate()` failure. + assert!(f.write_zeroes_at(0x200, 0x0).is_ok()); + } +} diff --git a/src/metric.rs b/src/metric.rs new file mode 100644 index 0000000..c728161 --- /dev/null +++ b/src/metric.rs @@ -0,0 +1,178 @@ +// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: BSD-3-Clause +//! The purpose of this module is to provide abstractions for working with +//! metrics in the context of rust-vmm components where there is a strong need +//! to have metrics as an optional feature. +//! +//! As multiple stakeholders are using these components, there are also +//! questions regarding the serialization format, as metrics are expected to be +//! flexible enough to allow different formatting, serialization and writers. +//! When using the rust-vmm metrics, the expectation is that VMMs built on top +//! of these components can choose what metrics they’re interested in and also +//! can add their own custom metrics without the need to maintain forks. + +use std::sync::atomic::{AtomicU64, Ordering}; + +/// Abstraction over the common metric operations. +/// +/// An object implementing `Metric` is expected to have an inner counter that +/// can be incremented and reset. The `Metric` trait can be used for +/// implementing a metric system backend (or an aggregator). +pub trait Metric { + /// Adds `value` to the current counter. + fn add(&self, value: u64); + /// Increments by 1 unit the current counter. + fn inc(&self) { + self.add(1); + } + /// Returns current value of the counter. + fn count(&self) -> u64; + /// Resets the metric counter. + fn reset(&self); + /// Set the metric counter `value`. + fn set(&self, value: u64); +} + +impl Metric for AtomicU64 { + /// Adds `value` to the current counter. 
+    ///
+    /// According to
+    /// [`fetch_add` documentation](https://doc.rust-lang.org/std/sync/atomic/struct.AtomicU64.html#method.fetch_add),
+    /// in case of an integer overflow, the counter starts over from 0.
+    fn add(&self, value: u64) {
+        self.fetch_add(value, Ordering::Relaxed);
+    }
+
+    /// Returns current value of the counter.
+    fn count(&self) -> u64 {
+        self.load(Ordering::Relaxed)
+    }
+
+    /// Resets the metric counter to 0.
+    fn reset(&self) {
+        self.store(0, Ordering::Relaxed)
+    }
+
+    /// Set the metric counter `value`.
+    fn set(&self, value: u64) {
+        self.store(value, Ordering::Relaxed);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::metric::Metric;
+
+    use std::sync::atomic::AtomicU64;
+    use std::sync::Arc;
+
+    struct Dog<T: DogEvents> {
+        metrics: T,
+    }
+
+    // Trait that declares events that can happen during the lifetime of the
+    // `Dog` which should also have associated events (such as metrics).
+    trait DogEvents {
+        // Event to be called when the dog `bark`s.
+        fn inc_bark(&self);
+        // Event to be called when the dog `eat`s.
+        fn inc_eat(&self);
+        // Event to be called when the dog `eat`s a lot.
+        fn set_eat(&self, no_times: u64);
+    }
+
+    impl<T: DogEvents> Dog<T> {
+        fn bark(&self) {
+            println!("bark! bark!");
+            self.metrics.inc_bark();
+        }
+
+        fn eat(&self) {
+            println!("nom! nom!");
+            self.metrics.inc_eat();
+        }
+
+        fn eat_more_times(&self, no_times: u64) {
+            self.metrics.set_eat(no_times);
+        }
+    }
+
+    impl<T: DogEvents> Dog<T> {
+        fn new_with_metrics(metrics: T) -> Self {
+            Self { metrics }
+        }
+    }
+
+    #[test]
+    fn test_main() {
+        // The `Metric` trait is implemented for `AtomicU64` so we can easily use it as the
+        // counter for the dog events.
+ #[derive(Default, Debug)] + struct DogEventMetrics { + bark: AtomicU64, + eat: AtomicU64, + } + + impl DogEvents for Arc<DogEventMetrics> { + fn inc_bark(&self) { + self.bark.inc(); + } + + fn inc_eat(&self) { + self.eat.inc(); + } + + fn set_eat(&self, no_times: u64) { + self.eat.set(no_times); + } + } + + impl DogEventMetrics { + fn reset(&self) { + self.bark.reset(); + self.eat.reset(); + } + } + + // This is the central object of mini-app built in this example. + // All the metrics that might be needed by the app are referenced through the + // `SystemMetrics` object. The `SystemMetric` also decides how to format the metrics. + // In this simple example, the metrics are formatted with the dummy Debug formatter. + #[derive(Default)] + struct SystemMetrics { + pub(crate) dog_metrics: Arc<DogEventMetrics>, + } + + impl SystemMetrics { + fn serialize(&self) -> String { + let mut serialized_metrics = format!("{:#?}", &self.dog_metrics); + // We can choose to reset the metrics right after we format them for serialization. + self.dog_metrics.reset(); + + serialized_metrics.retain(|c| !c.is_whitespace()); + serialized_metrics + } + } + + let system_metrics = SystemMetrics::default(); + let dog = Dog::new_with_metrics(system_metrics.dog_metrics.clone()); + dog.bark(); + dog.bark(); + dog.eat(); + + let expected_metrics = String::from("DogEventMetrics{bark:2,eat:1,}"); + let actual_metrics = system_metrics.serialize(); + assert_eq!(expected_metrics, actual_metrics); + + assert_eq!(system_metrics.dog_metrics.eat.count(), 0); + assert_eq!(system_metrics.dog_metrics.bark.count(), 0); + + // Set `std::u64::MAX` value to `eat` metric. + dog.eat_more_times(std::u64::MAX); + assert_eq!(system_metrics.dog_metrics.eat.count(), std::u64::MAX); + // Check that `add()` wraps around on overflow. 
+        dog.eat();
+        dog.eat();
+        assert_eq!(system_metrics.dog_metrics.eat.count(), 1);
+    }
+}
diff --git a/src/rand.rs b/src/rand.rs
new file mode 100644
index 0000000..097341f
--- /dev/null
+++ b/src/rand.rs
@@ -0,0 +1,163 @@
+// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+
+//! Miscellaneous functions related to getting (pseudo) random numbers and
+//! strings.
+//!
+//! NOTE! This should not be used when you do need __real__ random numbers such
+//! as for encryption but will probably be suitable when you want locally
+//! unique ID's that will not be shared over the network.
+
+use std::ffi::OsString;
+use std::str;
+
+/// Gets an ever increasing u64 (at least for this process).
+///
+/// The number retrieved will be based upon the time of the last reboot (x86_64)
+/// and something undefined for other architectures.
+pub fn timestamp_cycles() -> u64 {
+    #[cfg(target_arch = "x86_64")]
+    // SAFETY: Safe because there's nothing that can go wrong with this call.
+    unsafe {
+        std::arch::x86_64::_rdtsc() as u64
+    }
+
+    #[cfg(not(target_arch = "x86_64"))]
+    {
+        // Nanoseconds per second, used to fold the monotonic clock reading
+        // into a single ever-increasing u64.
+        const MONOTONIC_CLOCK_MULTIPLIER: u64 = 1_000_000_000;
+
+        let mut ts = libc::timespec {
+            tv_sec: 0,
+            tv_nsec: 0,
+        };
+        // SAFETY: We initialized the parameters correctly and we trust the function.
+        unsafe {
+            libc::clock_gettime(libc::CLOCK_MONOTONIC, &mut ts);
+        }
+        (ts.tv_sec as u64) * MONOTONIC_CLOCK_MULTIPLIER + (ts.tv_nsec as u64)
+    }
+}
+
+/// Generate pseudo random u32 numbers based on the current timestamp.
+pub fn xor_pseudo_rng_u32() -> u32 {
+    let mut t: u32 = timestamp_cycles() as u32;
+    // Taken from https://en.wikipedia.org/wiki/Xorshift
+    t ^= t << 13;
+    t ^= t >> 17;
+    t ^ (t << 5)
+}
+
+// This will get an array of numbers that can safely be converted to strings
+// because they will be in the range [a-zA-Z0-9]. The return vector could be any
+// size between 0 and 4.
+fn xor_pseudo_rng_u8_alphanumerics(rand_fn: &dyn Fn() -> u32) -> Vec<u8> { + rand_fn() + .to_ne_bytes() + .to_vec() + .drain(..) + .filter(|val| { + (48..=57).contains(val) || (65..=90).contains(val) || (97..=122).contains(val) + }) + .collect() +} + +fn xor_pseudo_rng_u8_bytes(rand_fn: &dyn Fn() -> u32) -> Vec<u8> { + rand_fn().to_ne_bytes().to_vec() +} + +fn rand_alphanumerics_impl(rand_fn: &dyn Fn() -> u32, len: usize) -> OsString { + let mut buf = OsString::new(); + let mut done = 0; + loop { + for n in xor_pseudo_rng_u8_alphanumerics(rand_fn) { + done += 1; + buf.push(str::from_utf8(&[n]).unwrap_or("_")); + if done >= len { + return buf; + } + } + } +} + +fn rand_bytes_impl(rand_fn: &dyn Fn() -> u32, len: usize) -> Vec<u8> { + let mut buf: Vec<Vec<u8>> = Vec::new(); + let mut num = if len % 4 == 0 { len / 4 } else { len / 4 + 1 }; + while num > 0 { + buf.push(xor_pseudo_rng_u8_bytes(rand_fn)); + num -= 1; + } + buf.into_iter().flatten().take(len).collect() +} + +/// Gets a pseudo random OsString of length `len` with characters in the +/// range [a-zA-Z0-9]. +pub fn rand_alphanumerics(len: usize) -> OsString { + rand_alphanumerics_impl(&xor_pseudo_rng_u32, len) +} + +/// Get a pseudo random vector of `len` bytes. +pub fn rand_bytes(len: usize) -> Vec<u8> { + rand_bytes_impl(&xor_pseudo_rng_u32, len) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_timestamp_cycles() { + for _ in 0..1000 { + assert!(timestamp_cycles() < timestamp_cycles()); + } + } + + #[test] + fn test_xor_pseudo_rng_u32() { + for _ in 0..1000 { + assert_ne!(xor_pseudo_rng_u32(), xor_pseudo_rng_u32()); + } + } + + #[test] + fn test_xor_pseudo_rng_u8_alphas() { + let i = 3612982; // 55 (shifted 16 places), 33 (shifted 8 places), 54... + // The 33 will be discarded as it is not a valid letter + // (upper or lower) or number. 
+ let s = xor_pseudo_rng_u8_alphanumerics(&|| i); + assert_eq!(vec![54, 55], s); + } + + #[test] + fn test_rand_alphanumerics_impl() { + let s = rand_alphanumerics_impl(&|| 14134, 5); + assert_eq!("67676", s); + } + + #[test] + fn test_rand_alphanumerics() { + let s = rand_alphanumerics(5); + assert_eq!(5, s.len()); + } + + #[test] + fn test_xor_pseudo_rng_u8_bytes() { + let i = 3612982; // 55 (shifted 16 places), 33 (shifted 8 places), 54... + // The 33 will be discarded as it is not a valid letter + // (upper or lower) or number. + let s = xor_pseudo_rng_u8_bytes(&|| i); + assert_eq!(vec![54, 33, 55, 0], s); + } + + #[test] + fn test_rand_bytes_impl() { + let s = rand_bytes_impl(&|| 1234567, 4); + assert_eq!(vec![135, 214, 18, 0], s); + } + + #[test] + fn test_rand_bytes() { + for i in 0..8 { + assert_eq!(i, rand_bytes(i).len()); + } + } +} diff --git a/src/syscall.rs b/src/syscall.rs new file mode 100644 index 0000000..6fb4d64 --- /dev/null +++ b/src/syscall.rs @@ -0,0 +1,50 @@ +// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Wrapper for interpreting syscall exit codes. + +use std::os::raw::c_int; + +/// Wrapper to interpret syscall exit codes and provide a rustacean `io::Result`. +pub struct SyscallReturnCode(pub c_int); + +impl SyscallReturnCode { + /// Returns the last OS error if value is -1 or Ok(value) otherwise. + pub fn into_result(self) -> std::io::Result<c_int> { + if self.0 == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(self.0) + } + } + /// Returns the last OS error if value is -1 or Ok(()) otherwise. 
+    pub fn into_empty_result(self) -> std::io::Result<()> {
+        self.into_result().map(|_| ())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_syscall_ops() {
+        let mut syscall_code = SyscallReturnCode(1);
+        match syscall_code.into_result() {
+            Ok(_value) => (),
+            _ => unreachable!(),
+        }
+
+        syscall_code = SyscallReturnCode(-1);
+        assert!(syscall_code.into_result().is_err());
+
+        syscall_code = SyscallReturnCode(1);
+        match syscall_code.into_empty_result() {
+            Ok(()) => (),
+            _ => unreachable!(),
+        }
+
+        syscall_code = SyscallReturnCode(-1);
+        assert!(syscall_code.into_empty_result().is_err());
+    }
+}
diff --git a/src/tempfile.rs b/src/tempfile.rs
new file mode 100644
index 0000000..7d70f66
--- /dev/null
+++ b/src/tempfile.rs
@@ -0,0 +1,304 @@
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-3-Clause file.
+//
+// SPDX-License-Identifier: BSD-3-Clause
+
+//! Struct for handling temporary files as well as any cleanup required.
+//!
+//! The temporary files will be created with a name available as well as having
+//! an exposed `fs::File` for reading/writing.
+//!
+//! The file will be removed when the object goes out of scope.
+//!
+//! # Examples
+//!
+//! ```
+//! use std::env::temp_dir;
+//! use std::io::Write;
+//! use std::path::{Path, PathBuf};
+//! use vmm_sys_util::tempfile::TempFile;
+//!
+//! let mut prefix = temp_dir();
+//! prefix.push("tempfile");
+//! let t = TempFile::new_with_prefix(prefix).unwrap();
+//! let mut f = t.as_file();
+//! f.write_all(b"hello world").unwrap();
+//! f.sync_all().unwrap();
+//! ```
+
+use std::env::temp_dir;
+use std::ffi::OsStr;
+use std::fs;
+use std::fs::File;
+use std::path::{Path, PathBuf};
+
+use libc;
+
+use crate::errno::{errno_result, Error, Result};
+
+/// Wrapper for working with temporary files.
+///
+/// The file will be maintained for the lifetime of the `TempFile` object.
+pub struct TempFile { + path: PathBuf, + file: Option<File>, +} + +impl TempFile { + /// Creates the TempFile using a prefix. + /// + /// # Arguments + /// + /// `prefix`: The path and filename where to create the temporary file. Six + /// random alphanumeric characters will be added to the end of this to form + /// the filename. + #[cfg(unix)] + pub fn new_with_prefix<P: AsRef<OsStr>>(prefix: P) -> Result<TempFile> { + use std::ffi::CString; + use std::os::unix::{ffi::OsStrExt, io::FromRawFd}; + + let mut os_fname = prefix.as_ref().to_os_string(); + os_fname.push("XXXXXX"); + + let raw_fname = match CString::new(os_fname.as_bytes()) { + Ok(c_string) => c_string.into_raw(), + Err(_) => return Err(Error::new(libc::EINVAL)), + }; + + // SAFETY: Safe because `raw_fname` originates from CString::into_raw, meaning + // it is a pointer to a nul-terminated sequence of characters. + let fd = unsafe { libc::mkstemp(raw_fname) }; + if fd == -1 { + return errno_result(); + } + + // SAFETY: raw_fname originates from a call to CString::into_raw. The length + // of the string has not changed, as mkstemp returns a valid file name, and + // '\0' cannot be part of a valid filename. + let c_tempname = unsafe { CString::from_raw(raw_fname) }; + let os_tempname = OsStr::from_bytes(c_tempname.as_bytes()); + + // SAFETY: Safe because we checked `fd != -1` above and we uniquely own the file + // descriptor. This `fd` will be freed etc when `File` and thus + // `TempFile` goes out of scope. + let file = unsafe { File::from_raw_fd(fd) }; + + Ok(TempFile { + path: PathBuf::from(os_tempname), + file: Some(file), + }) + } + + /// Creates the TempFile using a prefix. + /// + /// # Arguments + /// + /// `prefix`: The path and filename where to create the temporary file. Six + /// random alphanumeric characters will be added to the end of this to form + /// the filename. 
+ #[cfg(windows)] + pub fn new_with_prefix<P: AsRef<OsStr>>(prefix: P) -> Result<TempFile> { + use crate::rand::rand_alphanumerics; + use std::fs::OpenOptions; + + let file_path_str = format!( + "{}{}", + prefix.as_ref().to_str().unwrap_or_default(), + rand_alphanumerics(6).to_str().unwrap_or_default() + ); + let file_path_buf = PathBuf::from(&file_path_str); + + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(true) + .open(file_path_buf.as_path())?; + + Ok(TempFile { + path: file_path_buf, + file: Some(file), + }) + } + + /// Creates the TempFile inside a specific location. + /// + /// # Arguments + /// + /// `path`: The path where to create a temporary file with a filename formed from + /// six random alphanumeric characters. + pub fn new_in(path: &Path) -> Result<Self> { + let mut path_buf = path.canonicalize().unwrap(); + // This `push` adds a trailing slash ("/whatever/path" -> "/whatever/path/"). + // This is safe for paths with an already existing trailing slash. + path_buf.push(""); + let temp_file = TempFile::new_with_prefix(path_buf.as_path())?; + Ok(temp_file) + } + + /// Creates the TempFile. + /// + /// Creates a temporary file inside `$TMPDIR` if set, otherwise inside `/tmp`. + /// The filename will consist of six random alphanumeric characters. + pub fn new() -> Result<Self> { + let in_tmp_dir = temp_dir(); + let temp_file = TempFile::new_in(in_tmp_dir.as_path())?; + Ok(temp_file) + } + + /// Removes the temporary file. + /// + /// Calling this is optional as dropping a `TempFile` object will also + /// remove the file. Calling remove explicitly allows for better error + /// handling. + pub fn remove(&mut self) -> Result<()> { + fs::remove_file(&self.path).map_err(Error::from) + } + + /// Returns the path to the file if the `TempFile` object that is wrapping the file + /// is still in scope. 
+ /// + /// If we remove the file by explicitly calling [`remove`](#method.remove), + /// `as_path()` can still be used to return the path to that file (even though that + /// path does not point at an existing entity anymore). + /// Calling `as_path()` after `remove()` is useful, for example, when you need a + /// random path string, but don't want an actual resource at that path. + pub fn as_path(&self) -> &Path { + &self.path + } + + /// Returns a reference to the File. + pub fn as_file(&self) -> &File { + // It's safe to unwrap because `file` can be `None` only after calling `into_file` + // which consumes this object. + self.file.as_ref().unwrap() + } + + /// Consumes the TempFile, returning the wrapped file. + /// + /// This also removes the file from the system. The file descriptor remains opened and + /// it can be used until the returned file is dropped. + pub fn into_file(mut self) -> File { + self.file.take().unwrap() + } +} + +impl Drop for TempFile { + fn drop(&mut self) { + let _ = self.remove(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::{Read, Write}; + + #[test] + fn test_create_file_with_prefix() { + fn between(lower: u8, upper: u8, to_check: u8) -> bool { + (to_check >= lower) && (to_check <= upper) + } + + let mut prefix = temp_dir(); + prefix.push("asdf"); + let t = TempFile::new_with_prefix(&prefix).unwrap(); + let path = t.as_path().to_owned(); + + // Check filename exists + assert!(path.is_file()); + + // Check filename is in the correct location + assert!(path.starts_with(temp_dir())); + + // Check filename has random added + assert_eq!(path.file_name().unwrap().to_string_lossy().len(), 10); + + // Check filename has only ascii letters / numbers + for n in path.file_name().unwrap().to_string_lossy().bytes() { + assert!(between(b'0', b'9', n) || between(b'a', b'z', n) || between(b'A', b'Z', n)); + } + + // Check we can write to the file + let mut f = t.as_file(); + f.write_all(b"hello world").unwrap(); + 
f.sync_all().unwrap(); + assert_eq!(f.metadata().unwrap().len(), 11); + } + + #[test] + fn test_create_file_new() { + let t = TempFile::new().unwrap(); + let path = t.as_path().to_owned(); + + // Check filename is in the correct location + assert!(path.starts_with(temp_dir().canonicalize().unwrap())); + } + + #[test] + fn test_create_file_new_in() { + let t = TempFile::new_in(temp_dir().as_path()).unwrap(); + let path = t.as_path().to_owned(); + + // Check filename exists + assert!(path.is_file()); + + // Check filename is in the correct location + assert!(path.starts_with(temp_dir().canonicalize().unwrap())); + + let t = TempFile::new_in(temp_dir().as_path()).unwrap(); + let path = t.as_path().to_owned(); + + // Check filename is in the correct location + assert!(path.starts_with(temp_dir().canonicalize().unwrap())); + } + + #[test] + fn test_remove_file() { + let mut prefix = temp_dir(); + prefix.push("asdf"); + + let mut t = TempFile::new_with_prefix(prefix).unwrap(); + let path = t.as_path().to_owned(); + + // Check removal. + assert!(t.remove().is_ok()); + assert!(!path.exists()); + + // Calling `as_path()` after the file was removed is allowed. + let path_2 = t.as_path().to_owned(); + assert_eq!(path, path_2); + + // Check trying to remove a second time returns an error. 
+ assert!(t.remove().is_err()); + } + + #[test] + fn test_drop_file() { + let mut prefix = temp_dir(); + prefix.push("asdf"); + + let t = TempFile::new_with_prefix(prefix).unwrap(); + let path = t.as_path().to_owned(); + + assert!(path.starts_with(temp_dir())); + drop(t); + assert!(!path.exists()); + } + + #[test] + fn test_into_file() { + let mut prefix = temp_dir(); + prefix.push("asdf"); + + let text = b"hello world"; + let temp_file = TempFile::new_with_prefix(prefix).unwrap(); + let path = temp_file.as_path().to_owned(); + fs::write(path, text).unwrap(); + + let mut file = temp_file.into_file(); + let mut buf: Vec<u8> = Vec::new(); + file.read_to_end(&mut buf).unwrap(); + assert_eq!(buf, text); + } +} diff --git a/src/unix/file_traits.rs b/src/unix/file_traits.rs new file mode 100644 index 0000000..ce9fa29 --- /dev/null +++ b/src/unix/file_traits.rs @@ -0,0 +1,101 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Traits for handling file synchronization and length. + +use std::fs::File; +use std::io::Result; + +/// A trait for flushing the contents of a file to disk. +/// +/// This is equivalent to +/// [`std::fd::File::sync_all`](https://doc.rust-lang.org/std/fs/struct.File.html#method.sync_all) +/// method, but wrapped in a trait so that it can be implemented for other types. +pub trait FileSync { + /// Flush buffers related to this file to disk. + fn fsync(&mut self) -> Result<()>; +} + +impl FileSync for File { + fn fsync(&mut self) -> Result<()> { + self.sync_all() + } +} + +/// A trait for setting the size of a file. +/// +/// This is equivalent to +/// [`std::fd::File::set_len`](https://doc.rust-lang.org/std/fs/struct.File.html#method.set_len) +/// method, but wrapped in a trait so that it can be implemented for other types. +pub trait FileSetLen { + /// Set the size of this file. 
+ /// + /// This is the moral equivalent of + /// [`ftruncate`](http://man7.org/linux/man-pages/man3/ftruncate.3p.html). + /// + /// # Arguments + /// + /// * `len`: the size to set for file. + fn set_len(&self, len: u64) -> Result<()>; +} + +impl FileSetLen for File { + fn set_len(&self, len: u64) -> Result<()> { + File::set_len(self, len) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs::OpenOptions; + use std::io::{Seek, SeekFrom, Write}; + use std::path::PathBuf; + + use crate::tempdir::TempDir; + + #[test] + fn test_fsync() { + let tempdir = TempDir::new_with_prefix("/tmp/fsync_test").unwrap(); + let mut path = PathBuf::from(tempdir.as_path()); + path.push("file"); + let mut f = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .open(&path) + .unwrap(); + f.write_all(b"Hello, world!").unwrap(); + f.fsync().unwrap(); + assert_eq!(f.metadata().unwrap().len(), 13); + } + + #[test] + fn test_set_len() { + let tempdir = TempDir::new_with_prefix("/tmp/set_len_test").unwrap(); + let mut path = PathBuf::from(tempdir.as_path()); + path.push("file"); + let mut f = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .open(&path) + .unwrap(); + f.set_len(10).unwrap(); + assert_eq!(f.seek(SeekFrom::End(0)).unwrap(), 10); + } + + #[test] + fn test_set_len_fails_when_file_not_opened_for_writing() { + let tempdir = TempDir::new_with_prefix("/tmp/set_len_test").unwrap(); + let mut path = PathBuf::from(tempdir.as_path()); + path.push("file"); + File::create(path.clone()).unwrap(); + let f = OpenOptions::new().read(true).open(&path).unwrap(); + let result = f.set_len(10); + assert!(result.is_err()); + } +} diff --git a/src/unix/mod.rs b/src/unix/mod.rs new file mode 100644 index 0000000..5c26a9c --- /dev/null +++ b/src/unix/mod.rs @@ -0,0 +1,5 @@ +// Copyright 2022 rust-vmm Authors or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: BSD-3-Clause
+pub mod file_traits;
+pub mod tempdir;
+pub mod terminal;
diff --git a/src/unix/tempdir.rs b/src/unix/tempdir.rs
new file mode 100644
index 0000000..101d35a
--- /dev/null
+++ b/src/unix/tempdir.rs
@@ -0,0 +1,205 @@
+// Copyright 2019 Intel Corporation. All Rights Reserved.
+//
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+//
+// SPDX-License-Identifier: BSD-3-Clause
+
+//! Structure for handling temporary directories.
+use std::env::temp_dir;
+use std::ffi::{CString, OsStr, OsString};
+use std::fs;
+use std::os::unix::ffi::OsStringExt;
+use std::path::{Path, PathBuf};
+
+use crate::errno::{errno_result, Error, Result};
+
+/// Wrapper over a temporary directory.
+///
+/// The directory will be maintained for the lifetime of the `TempDir` object.
+pub struct TempDir {
+    path: PathBuf,
+}
+
+impl TempDir {
+    /// Creates a new temporary directory with `prefix`.
+    ///
+    /// The directory will be removed when the object goes out of scope.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use vmm_sys_util::tempdir::TempDir;
+    /// let t = TempDir::new_with_prefix("/tmp/testdir").unwrap();
+    /// ```
+    pub fn new_with_prefix<P: AsRef<OsStr>>(prefix: P) -> Result<TempDir> {
+        let mut dir_string = prefix.as_ref().to_os_string();
+        dir_string.push("XXXXXX");
+        // unwrap this result as the internal bytes can't have a null with a valid path.
+        let dir_name = CString::new(dir_string.into_vec()).unwrap();
+        let mut dir_bytes = dir_name.into_bytes_with_nul();
+        // SAFETY: Creating the directory isn't unsafe. The fact that it modifies the guts of the
+        // path is also OK because it only overwrites the last 6 Xs added above.
+        let ret = unsafe { libc::mkdtemp(dir_bytes.as_mut_ptr() as *mut libc::c_char) };
+        if ret.is_null() {
+            return errno_result();
+        }
+        dir_bytes.pop(); // Remove the null because from_vec can't handle it.
+        Ok(TempDir {
+            path: PathBuf::from(OsString::from_vec(dir_bytes)),
+        })
+    }
+
+    /// Creates a new temporary directory inside `path`.
+    ///
+    /// The directory will be removed when the object goes out of scope.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use std::path::Path;
+    /// # use vmm_sys_util::tempdir::TempDir;
+    /// let t = TempDir::new_in(Path::new("/tmp/")).unwrap();
+    /// ```
+    pub fn new_in(path: &Path) -> Result<TempDir> {
+        let mut path_buf = path.canonicalize().unwrap();
+        // This `push` adds a trailing slash ("/whatever/path" -> "/whatever/path/").
+        // This is safe for paths with an already existing trailing slash.
+        path_buf.push("");
+        let temp_dir = TempDir::new_with_prefix(path_buf)?;
+        Ok(temp_dir)
+    }
+
+    /// Creates a new temporary directory inside `$TMPDIR` if set, otherwise in `/tmp`.
+    ///
+    /// The directory will be removed when the object goes out of scope.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use vmm_sys_util::tempdir::TempDir;
+    /// let t = TempDir::new().unwrap();
+    /// ```
+    pub fn new() -> Result<TempDir> {
+        let mut in_tmp_dir = temp_dir();
+        // This `push` adds a trailing slash ("/tmp" -> "/tmp/").
+        // This is safe for paths with an already existing trailing slash.
+        in_tmp_dir.push("");
+        let temp_dir = TempDir::new_in(in_tmp_dir.as_path())?;
+        Ok(temp_dir)
+    }
+
+    /// Removes the temporary directory.
+    ///
+    /// Calling this is optional as when a `TempDir` object goes out of scope,
+    /// the directory will be removed.
+    /// Calling remove explicitly allows for better error handling.
+    ///
+    /// # Errors
+    ///
+    /// This function can only be called once per object. An error is returned
+    /// otherwise.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use std::path::Path;
+    /// # use std::path::PathBuf;
+    /// # use vmm_sys_util::tempdir::TempDir;
+    /// let temp_dir = TempDir::new_with_prefix("/tmp/testdir").unwrap();
+    /// temp_dir.remove().unwrap();
+    /// ```
+    pub fn remove(&self) -> Result<()> {
+        fs::remove_dir_all(&self.path).map_err(Error::from)
+    }
+
+    /// Returns the path to the tempdir.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use std::path::Path;
+    /// # use std::path::PathBuf;
+    /// # use vmm_sys_util::tempdir::TempDir;
+    /// let temp_dir = TempDir::new_with_prefix("/tmp/testdir").unwrap();
+    /// assert!(temp_dir.as_path().exists());
+    /// ```
+    pub fn as_path(&self) -> &Path {
+        self.path.as_ref()
+    }
+}
+
+impl Drop for TempDir {
+    fn drop(&mut self) {
+        let _ = self.remove();
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_create_dir() {
+        let t = TempDir::new().unwrap();
+        let path = t.as_path();
+        assert!(path.exists());
+        assert!(path.is_dir());
+        assert!(path.starts_with(temp_dir()));
+    }
+
+    #[test]
+    fn test_create_dir_with_prefix() {
+        let t = TempDir::new_with_prefix("/tmp/testdir").unwrap();
+        let path = t.as_path();
+        assert!(path.exists());
+        assert!(path.is_dir());
+        assert!(path.to_str().unwrap().contains("/tmp/testdir"));
+    }
+
+    #[test]
+    fn test_remove_dir() {
+        use crate::tempfile::TempFile;
+        let t = TempDir::new().unwrap();
+        let path = t.as_path().to_owned();
+        assert!(t.remove().is_ok());
+        // Calling remove twice returns error.
+ assert!(t.remove().is_err()); + assert!(!path.exists()); + + let t = TempDir::new().unwrap(); + let mut file = TempFile::new_in(t.as_path()).unwrap(); + let t2 = TempDir::new_in(t.as_path()).unwrap(); + let mut file2 = TempFile::new_in(t2.as_path()).unwrap(); + let path2 = t2.as_path().to_owned(); + assert!(t.remove().is_ok()); + // Calling t2.remove returns error because parent dir has removed + assert!(t2.remove().is_err()); + assert!(!path2.exists()); + assert!(file.remove().is_err()); + assert!(file2.remove().is_err()); + } + + #[test] + fn test_create_dir_in() { + let t = TempDir::new_in(Path::new("/tmp")).unwrap(); + let path = t.as_path(); + assert!(path.exists()); + assert!(path.is_dir()); + assert!(path.starts_with("/tmp/")); + + let t = TempDir::new_in(Path::new("/tmp")).unwrap(); + let path = t.as_path(); + assert!(path.exists()); + assert!(path.is_dir()); + assert!(path.starts_with("/tmp")); + } + + #[test] + fn test_drop() { + use std::mem::drop; + let t = TempDir::new_with_prefix("/tmp/asdf").unwrap(); + let path = t.as_path().to_owned(); + // Force tempdir object to go out of scope. + drop(t); + + assert!(!(path.exists())); + } +} diff --git a/src/unix/terminal.rs b/src/unix/terminal.rs new file mode 100644 index 0000000..093502a --- /dev/null +++ b/src/unix/terminal.rs @@ -0,0 +1,190 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// +// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// SPDX-License-Identifier: BSD-3-Clause + +//! Trait for working with [`termios`](http://man7.org/linux/man-pages/man3/termios.3.html). 

use std::io::StdinLock;
use std::mem::zeroed;
use std::os::unix::io::RawFd;

use libc::{
    c_int, fcntl, isatty, read, tcgetattr, tcsetattr, termios, ECHO, F_GETFL, F_SETFL, ICANON,
    ISIG, O_NONBLOCK, STDIN_FILENO, TCSANOW,
};

use crate::errno::{errno_result, Result};

/// Read the terminal attributes of `fd` with `tcgetattr`, let `f` mutate a
/// copy of them, and apply the result immediately with `tcsetattr(TCSANOW)`.
///
/// Silently succeeds as a no-op when `fd` is not a TTY (per `isatty`), so
/// callers may invoke this unconditionally on e.g. redirected stdin.
fn modify_mode<F: FnOnce(&mut termios)>(fd: RawFd, f: F) -> Result<()> {
    // SAFETY: Safe because we check the return value of isatty.
    if unsafe { isatty(fd) } != 1 {
        return Ok(());
    }

    // SAFETY: The following pair are safe because termios gets totally overwritten by tcgetattr
    // and we check the return result.
    let mut termios: termios = unsafe { zeroed() };
    // SAFETY: The parameter is valid and we check the result.
    let ret = unsafe { tcgetattr(fd, &mut termios as *mut _) };
    if ret < 0 {
        return errno_result();
    }
    let mut new_termios = termios;
    f(&mut new_termios);
    // SAFETY: Safe because the syscall will only read the extent of termios and we check the
    // return result.
    let ret = unsafe { tcsetattr(fd, TCSANOW, &new_termios as *const _) };
    if ret < 0 {
        return errno_result();
    }

    Ok(())
}

/// Get the file-status flags of `fd` via `fcntl(F_GETFL)`.
fn get_flags(fd: RawFd) -> Result<c_int> {
    // SAFETY: Safe because no third parameter is expected and we check the return result.
    let ret = unsafe { fcntl(fd, F_GETFL) };
    if ret < 0 {
        return errno_result();
    }
    Ok(ret)
}

/// Set the file-status flags of `fd` via `fcntl(F_SETFL)`.
fn set_flags(fd: RawFd, flags: c_int) -> Result<()> {
    // SAFETY: Safe because we supply the third parameter and we check the return result.
    let ret = unsafe { fcntl(fd, F_SETFL, flags) };
    if ret < 0 {
        return errno_result();
    }
    Ok(())
}

/// Trait for file descriptors that are TTYs, according to
/// [`isatty`](http://man7.org/linux/man-pages/man3/isatty.3.html).
///
/// # Safety
///
/// This is marked unsafe because the implementation must ensure that the returned
/// RawFd is a valid fd and that the lifetime of the returned fd is at least that
/// of the trait object.
pub unsafe trait Terminal {
    /// Get the file descriptor of the TTY.
    fn tty_fd(&self) -> RawFd;

    /// Set this terminal to canonical mode (`ICANON | ECHO | ISIG`).
    ///
    /// Enable canonical mode with `ISIG` that generates a signal when receiving
    /// any of the characters INTR, QUIT, SUSP, or DSUSP, and with `ECHO` that echoes
    /// the input characters. Refer to
    /// [`termios`](http://man7.org/linux/man-pages/man3/termios.3.html).
    fn set_canon_mode(&self) -> Result<()> {
        modify_mode(self.tty_fd(), |t| t.c_lflag |= ICANON | ECHO | ISIG)
    }

    /// Set this terminal to raw mode.
    ///
    /// Unset the canonical mode with (`!(ICANON | ECHO | ISIG)`), which means
    /// input is available character by character, echoing is disabled, and the
    /// special signals for the characters INTR, QUIT, SUSP, or DSUSP are disabled.
    fn set_raw_mode(&self) -> Result<()> {
        modify_mode(self.tty_fd(), |t| t.c_lflag &= !(ICANON | ECHO | ISIG))
    }

    /// Set this terminal to non-blocking mode.
    ///
    /// If `non_block` is `true`, then `read_raw` will not block.
    /// If `non_block` is `false`, then `read_raw` may block if
    /// there is nothing to read.
    fn set_non_block(&self, non_block: bool) -> Result<()> {
        let old_flags = get_flags(self.tty_fd())?;
        let new_flags = if non_block {
            old_flags | O_NONBLOCK
        } else {
            old_flags & !O_NONBLOCK
        };
        // Only issue the F_SETFL syscall when the flag actually changes.
        if new_flags != old_flags {
            set_flags(self.tty_fd(), new_flags)?
        }
        Ok(())
    }

    /// Read from a [`Terminal`](trait.Terminal.html).
    ///
    /// Read up to `out.len()` bytes from this terminal without any buffering.
    /// This may block, depending on if non-blocking was enabled with `set_non_block`
    /// or if there are any bytes to read.
    /// If there is at least one byte that is readable, this will not block.
    ///
    /// # Examples
    ///
    /// ```
    /// extern crate vmm_sys_util;
    /// # use std::io;
    /// # use std::os::unix::io::RawFd;
    /// use vmm_sys_util::terminal::Terminal;
    ///
    /// let stdin_handle = io::stdin();
    /// let stdin = stdin_handle.lock();
    /// assert!(stdin.set_non_block(true).is_ok());
    ///
    /// let mut out = [0u8; 0];
    /// assert_eq!(stdin.read_raw(&mut out[..]).unwrap(), 0);
    /// ```
    fn read_raw(&self, out: &mut [u8]) -> Result<usize> {
        // SAFETY: Safe because read will only modify the pointer up to the length we give it and
        // we check the return result.
        let ret = unsafe { read(self.tty_fd(), out.as_mut_ptr() as *mut _, out.len()) };
        if ret < 0 {
            return errno_result();
        }

        Ok(ret as usize)
    }
}

// SAFETY: Safe because we return a genuine terminal fd that never changes and shares our lifetime.
unsafe impl<'a> Terminal for StdinLock<'a> {
    fn tty_fd(&self) -> RawFd {
        STDIN_FILENO
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::undocumented_unsafe_blocks)]
    use super::*;
    use std::fs::File;
    use std::io;
    use std::os::unix::io::AsRawFd;
    use std::path::Path;

    unsafe impl Terminal for File {
        fn tty_fd(&self) -> RawFd {
            self.as_raw_fd()
        }
    }

    #[test]
    fn test_a_tty() {
        let stdin_handle = io::stdin();
        let stdin = stdin_handle.lock();

        assert!(stdin.set_canon_mode().is_ok());
        assert!(stdin.set_raw_mode().is_ok());
        assert!(stdin.set_raw_mode().is_ok());
        assert!(stdin.set_canon_mode().is_ok());
        assert!(stdin.set_non_block(true).is_ok());
        let mut out = [0u8; 0];
        assert!(stdin.read_raw(&mut out[..]).is_ok());
    }

    #[test]
    fn test_a_non_tty() {
        // Non-TTY fds are a no-op for mode changes (see modify_mode's isatty check).
        let file = File::open(Path::new("/dev/zero")).unwrap();
        assert!(file.set_canon_mode().is_ok());
    }
}