author     Jeongik Cha <jeongik@google.com>  2023-09-27 08:11:51 +0000
committer  Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>  2023-09-27 08:11:51 +0000
commit     fd8b88fc46ac8d4465b4ea1c963622be2c7f857c (patch)
tree       5cd4944617e8f2802a8fe259dcbde9a8cc6326bc
parent     088f9ee4aac7215be65f1941c100f7e14362e2f4 (diff)
parent     777a4130a112a67b398cf73537c055a15030da3f (diff)
download   vhost-user-backend-fd8b88fc46ac8d4465b4ea1c963622be2c7f857c.tar.gz
Import vhost-user-backend am: 777a4130a1
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/vhost-user-backend/+/2752353

Change-Id: I2c8695bbda6410995972a5b2a72377b789bab551
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r--  .cargo_vcs_info.json                    6
-rw-r--r--  Android.bp                             22
-rw-r--r--  CHANGELOG.md                          104
-rw-r--r--  Cargo.toml                             78
-rw-r--r--  Cargo.toml.orig                        27
-rw-r--r--  LICENSE                               235
-rw-r--r--  LICENSE-APACHE                        202
-rw-r--r--  LICENSE-BSD-3-Clause                   27
-rw-r--r--  METADATA                               19
-rw-r--r--  MODULE_LICENSE_APACHE2                  0
-rw-r--r--  OWNERS                                  1
-rw-r--r--  README.md                             117
-rw-r--r--  cargo2android.json                      7
-rw-r--r--  docs/vhost_architecture.drawio        171
-rw-r--r--  docs/vhost_architecture.png           bin 0 -> 146074 bytes
-rw-r--r--  src/backend.rs                        551
-rw-r--r--  src/event_loop.rs                     270
-rw-r--r--  src/handler.rs                        618
-rw-r--r--  src/lib.rs                            270
-rw-r--r--  src/net.rs                             19
-rw-r--r--  src/vdpa.rs                           126
-rw-r--r--  src/vhost_kern/mod.rs                 467
-rw-r--r--  src/vhost_kern/net.rs                 177
-rw-r--r--  src/vhost_kern/vdpa.rs                560
-rw-r--r--  src/vhost_kern/vhost_binding.rs       545
-rw-r--r--  src/vhost_kern/vsock.rs               196
-rw-r--r--  src/vhost_user/connection.rs          903
-rw-r--r--  src/vhost_user/dummy_slave.rs         294
-rw-r--r--  src/vhost_user/master.rs             1110
-rw-r--r--  src/vhost_user/master_req_handler.rs  466
-rw-r--r--  src/vhost_user/message.rs            1403
-rw-r--r--  src/vhost_user/mod.rs                 540
-rw-r--r--  src/vhost_user/slave.rs                86
-rw-r--r--  src/vhost_user/slave_req.rs           219
-rw-r--r--  src/vhost_user/slave_req_handler.rs   833
-rw-r--r--  src/vring.rs                          585
-rw-r--r--  src/vsock.rs                           30
-rw-r--r--  tests/vhost-user-server.rs            292
38 files changed, 11576 insertions, 0 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json
new file mode 100644
index 0000000..03acefd
--- /dev/null
+++ b/.cargo_vcs_info.json
@@ -0,0 +1,6 @@
+{
+ "git": {
+ "sha1": "6ca88e160a8d34be54a158cc1bc702e9250e97dd"
+ },
+ "path_in_vcs": "crates/vhost-user-backend"
+}
\ No newline at end of file
diff --git a/Android.bp b/Android.bp
new file mode 100644
index 0000000..487cdd1
--- /dev/null
+++ b/Android.bp
@@ -0,0 +1,22 @@
+// This file is generated by cargo2android.py --config cargo2android.json.
+// Do not modify this file as changes will be overridden on upgrade.
+
+
+
+rust_library_host {
+ name: "libvhost_user_backend",
+ crate_name: "vhost_user_backend",
+ cargo_env_compat: true,
+ cargo_pkg_version: "0.10.1",
+ srcs: ["src/lib.rs"],
+ edition: "2018",
+ rustlibs: [
+ "liblibc",
+ "liblog_rust",
+ "libvhost_android",
+ "libvirtio_bindings",
+ "libvirtio_queue",
+ "libvm_memory_android",
+ "libvmm_sys_util",
+ ],
+}
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..4ae24f3
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,104 @@
+# Changelog
+## [Unreleased]
+
+### Added
+
+### Changed
+
+### Fixed
+
+### Deprecated
+
+## v0.10.1
+
+### Fixed
+- [[#180]](https://github.com/rust-vmm/vhost/pull/180) vhost-user-backend: fetch 'used' index from guest
+
+## v0.10.0
+
+### Added
+- [[#169]](https://github.com/rust-vmm/vhost/pull/169) vhost-user-backend: Add support for Xen memory mappings
+
+### Fixed
+- [[#161]](https://github.com/rust-vmm/vhost/pull/161) get_vring_base should not reset the queue
+
+## v0.9.0
+
+### Added
+- [[#138]](https://github.com/rust-vmm/vhost/pull/138): vhost-user-backend: add repository metadata
+
+### Changed
+- Updated dependency virtio-bindings 0.1.0 -> 0.2.0
+- Updated dependency virtio-queue 0.7.0 -> 0.8.0
+- Updated dependency vm-memory 0.10.0 -> 0.11.0
+
+### Fixed
+- [[#154]](https://github.com/rust-vmm/vhost/pull/154): Fix return value of GET_VRING_BASE message
+- [[#142]](https://github.com/rust-vmm/vhost/pull/142): vhost_user: Slave requests aren't only FS specific
+
+## v0.8.0
+
+### Added
+- [[#120]](https://github.com/rust-vmm/vhost/pull/120): vhost_kern: vdpa: Add missing ioctls
+
+### Changed
+- Updated dependency vhost 0.5 -> 0.6
+- Updated dependency virtio-queue 0.6 -> 0.7.0
+- Updated dependency vm-memory 0.9 -> 0.10.0
+- Updated dependency vmm-sys-util 0.10 -> 0.11.0
+
+## v0.7.0
+
+### Changed
+
+- Started using caret dependencies
+- Updated dependency nix 0.24 -> 0.25
+- Updated dependency log 0.4.6 -> 0.4.17
+- Updated dependency vhost 0.4 -> 0.5
+- Updated dependency virtio-queue 0.5.0 -> 0.6
+- Updated dependency vm-memory 0.7 -> 0.9
+
+## v0.6.0
+
+### Changed
+
+- Moved to rust-vmm/virtio-queue v0.5.0
+
+### Fixed
+
+- Fixed vring initialization logic
+
+## v0.5.1
+
+### Changed
+- Moved to rust-vmm/vmm-sys-util 0.10.0
+
+## v0.5.0
+
+### Changed
+
+- Moved to rust-vmm/virtio-queue v0.4.0
+
+## v0.4.0
+
+### Changed
+
+- Moved to rust-vmm/virtio-queue v0.3.0
+- Relaxed rust-vmm/vm-memory dependency to require ">=0.7"
+
+## v0.3.0
+
+### Changed
+
+- Moved to rust-vmm/vhost v0.4.0
+
+## v0.2.0
+
+### Added
+
+- Ability to run the daemon as a client
+- VringEpollHandler implements AsRawFd
+
+## v0.1.0
+
+First release
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..e8709f7
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,78 @@
+# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
+#
+# When uploading crates to the registry Cargo will automatically
+# "normalize" Cargo.toml files for maximal compatibility
+# with all versions of Cargo and also rewrite `path` dependencies
+# to registry (e.g., crates.io) dependencies.
+#
+# If you are reading this file be aware that the original Cargo.toml
+# will likely look very different (and much more reasonable).
+# See Cargo.toml.orig for the original contents.
+
+[package]
+edition = "2018"
+name = "vhost-user-backend"
+version = "0.10.1"
+authors = ["The Cloud Hypervisor Authors"]
+description = "A framework to build vhost-user backend service daemon"
+readme = "README.md"
+keywords = [
+ "vhost-user",
+ "virtio",
+]
+license = "Apache-2.0"
+repository = "https://github.com/rust-vmm/vhost"
+
+[dependencies.libc]
+version = "0.2.39"
+
+[dependencies.log]
+version = "0.4.17"
+
+[dependencies.vhost]
+version = "0.8"
+features = ["vhost-user-slave"]
+
+[dependencies.virtio-bindings]
+version = "0.2.1"
+
+[dependencies.virtio-queue]
+version = "0.9.0"
+
+[dependencies.vm-memory]
+version = "0.12.0"
+features = [
+ "backend-mmap",
+ "backend-atomic",
+]
+
+[dependencies.vmm-sys-util]
+version = "0.11.0"
+
+[dev-dependencies.nix]
+version = "0.26"
+
+[dev-dependencies.tempfile]
+version = "3.2.0"
+
+[dev-dependencies.vhost]
+version = "0.8"
+features = [
+ "test-utils",
+ "vhost-user-master",
+ "vhost-user-slave",
+]
+
+[dev-dependencies.vm-memory]
+version = "0.12.0"
+features = [
+ "backend-mmap",
+ "backend-atomic",
+ "backend-bitmap",
+]
+
+[features]
+xen = [
+ "vm-memory/xen",
+ "vhost/xen",
+]
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
new file mode 100644
index 0000000..a42ea78
--- /dev/null
+++ b/Cargo.toml.orig
@@ -0,0 +1,27 @@
+[package]
+name = "vhost-user-backend"
+version = "0.10.1"
+authors = ["The Cloud Hypervisor Authors"]
+keywords = ["vhost-user", "virtio"]
+description = "A framework to build vhost-user backend service daemon"
+repository = "https://github.com/rust-vmm/vhost"
+edition = "2018"
+license = "Apache-2.0"
+
+[features]
+xen = ["vm-memory/xen", "vhost/xen"]
+
+[dependencies]
+libc = "0.2.39"
+log = "0.4.17"
+vhost = { path = "../vhost", version = "0.8", features = ["vhost-user-slave"] }
+virtio-bindings = "0.2.1"
+virtio-queue = "0.9.0"
+vm-memory = { version = "0.12.0", features = ["backend-mmap", "backend-atomic"] }
+vmm-sys-util = "0.11.0"
+
+[dev-dependencies]
+nix = "0.26"
+vhost = { path = "../vhost", version = "0.8", features = ["test-utils", "vhost-user-master", "vhost-user-slave"] }
+vm-memory = { version = "0.12.0", features = ["backend-mmap", "backend-atomic", "backend-bitmap"] }
+tempfile = "3.2.0"
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..c3968e2
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,235 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+---
+
+// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Alibaba Inc. nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+---
+
diff --git a/LICENSE-APACHE b/LICENSE-APACHE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE-APACHE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/LICENSE-BSD-3-Clause b/LICENSE-BSD-3-Clause
new file mode 100644
index 0000000..1ff0cd7
--- /dev/null
+++ b/LICENSE-BSD-3-Clause
@@ -0,0 +1,27 @@
+// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Alibaba Inc. nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/METADATA b/METADATA
new file mode 100644
index 0000000..76caad4
--- /dev/null
+++ b/METADATA
@@ -0,0 +1,19 @@
+name: "vhost-user-backend"
+description: "A framework to build vhost-user backend service daemon"
+third_party {
+ identifier {
+ type: "crates.io"
+ value: "https://crates.io/crates/vhost-user-backend"
+ }
+ identifier {
+ type: "Archive"
+ value: "https://static.crates.io/crates/vhost-user-backend/vhost-user-backend-0.10.1.crate"
+ }
+ version: "0.10.1"
+ license_type: NOTICE
+ last_upgrade_date {
+ year: 2023
+ month: 8
+ day: 23
+ }
+}
diff --git a/MODULE_LICENSE_APACHE2 b/MODULE_LICENSE_APACHE2
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/MODULE_LICENSE_APACHE2
diff --git a/OWNERS b/OWNERS
new file mode 100644
index 0000000..45dc4dd
--- /dev/null
+++ b/OWNERS
@@ -0,0 +1 @@
+include platform/prebuilts/rust:master:/OWNERS
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..a526ab1
--- /dev/null
+++ b/README.md
@@ -0,0 +1,117 @@
+# vhost-user-backend
+
+## Design
+
+The `vhost-user-backend` crate provides a framework to implement `vhost-user` backend services,
+which includes the following external public APIs:
+- A daemon control object (`VhostUserDaemon`) to start and stop the service daemon.
+- A vhost-user backend trait (`VhostUserBackendMut`) to handle vhost-user control messages and virtio
+ messages.
+- A vring access trait (`VringT`) to access virtio queues, and three implementations of the trait:
+ `VringState`, `VringMutex` and `VringRwLock`.
+
+## Usage
+The `vhost-user-backend` crate provides a framework to implement vhost-user backend services. The main interface provided by the `vhost-user-backend` library is the `VhostUserDaemon` struct:
+```rust
+pub struct VhostUserDaemon<S, V, B = ()>
+where
+ S: VhostUserBackend<V, B>,
+ V: VringT<GM<B>> + Clone + Send + Sync + 'static,
+ B: Bitmap + 'static,
+{
+ pub fn new(name: String, backend: S, atomic_mem: GuestMemoryAtomic<GuestMemoryMmap<B>>) -> Result<Self>;
+ pub fn start(&mut self, listener: Listener) -> Result<()>;
+ pub fn wait(&mut self) -> Result<()>;
+ pub fn get_epoll_handlers(&self) -> Vec<Arc<VringEpollHandler<S, V, B>>>;
+}
+```
+
+### Create a `VhostUserDaemon` Instance
+`VhostUserDaemon::new()` creates a `VhostUserDaemon` instance. The client needs to
+pass in a `VhostUserBackend` object, which will be used to configure the `VhostUserDaemon`
+instance, handle control messages from the vhost-user master, and handle virtio requests from
+virtio queues. A group of worker threads will be created to handle virtio requests from the
+configured virtio queues.
+
+### Start the `VhostUserDaemon`
+The `VhostUserDaemon::start()` method waits for an incoming connection from the vhost-user master
+on the `listener`. Once a connection is established, a main thread will be created to handle
+vhost-user messages from the vhost-user master.
+
+### Stop the `VhostUserDaemon`
+The `VhostUserDaemon::stop()` method waits for the main thread to exit. An exit event must be sent
+to the main thread by writing to the `exit_event` EventFd before waiting for it to exit.
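+
+A minimal start/stop sketch, reusing the `daemon` created above (the socket path
+is only an example):
+```rust
+use vhost::vhost_user::Listener;
+
+// Create the unix socket the vhost-user master will connect to; the second
+// argument asks the listener to unlink any stale socket file first.
+let listener = Listener::new("/tmp/my-backend.sock", true)
+    .expect("failed to create the vhost-user listener");
+
+// Accept a connection and spawn the main thread handling vhost-user messages.
+daemon.start(listener).expect("failed to start the daemon");
+
+// ... to shut down, write to the backend's `exit_event` EventFd before waiting,
+// then wait for the main thread to exit ...
+daemon.wait().expect("error while waiting for the daemon");
+```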
+
+### Threading Model
+The main thread and the virtio queue worker threads access the underlying virtio
+queues concurrently, so all virtio queues must be protected for this
+multi-threading model. The main thread only accesses virtio queues for
+configuration, however, so clients can adopt locking policies optimized for the
+virtio queue worker threads.
+
+## Example
+Example code to handle virtio messages from a virtio queue:
+```rust
+impl VhostUserBackendMut for VhostUserService {
+ fn process_queue(&mut self, vring: &VringMutex) -> Result<bool> {
+ let mut used_any = false;
+ let mem = match &self.mem {
+ Some(m) => m.memory(),
+ None => return Err(Error::NoMemoryConfigured),
+ };
+
+ let mut vring_state = vring.get_mut();
+
+ while let Some(avail_desc) = vring_state
+ .get_queue_mut()
+ .iter()
+ .map_err(|_| Error::IterateQueue)?
+ .next()
+ {
+ let head_index = avail_desc.head_index();
+
+ // Process the request...
+ used_any = true;
+
+ if self.event_idx {
+ if vring_state.add_used(head_index, 0).is_err() {
+ warn!("Couldn't return used descriptors to the ring");
+ }
+
+ match vring_state.needs_notification() {
+ Err(_) => {
+ warn!("Couldn't check if queue needs to be notified");
+ vring_state.signal_used_queue().unwrap();
+ }
+ Ok(needs_notification) => {
+ if needs_notification {
+ vring_state.signal_used_queue().unwrap();
+ }
+ }
+ }
+ } else {
+ if vring_state.add_used(head_index, 0).is_err() {
+ warn!("Couldn't return used descriptors to the ring");
+ }
+ vring_state.signal_used_queue().unwrap();
+ }
+ }
+
+ Ok(used_any)
+ }
+}
+```
+
+## Xen support
+
+Supporting Xen requires special handling while mapping the guest memory. The
+`vm-memory` crate implements Xen memory mapping support via a separate feature,
+`xen`, and this crate uses the same feature name to enable Xen support.
+
+Also, for Xen mappings, the memory regions passed by the frontend contain a few
+extra fields as described in the vhost-user protocol documentation.
+
+It was decided by the `rust-vmm` maintainers to keep the interface simple and
+build the crate for either standard Unix memory mapping or Xen, and not both.
+
+## License
+
+This project is licensed under
+
+- [Apache License](http://www.apache.org/licenses/LICENSE-2.0), Version 2.0
diff --git a/cargo2android.json b/cargo2android.json
new file mode 100644
index 0000000..6f1f3ce
--- /dev/null
+++ b/cargo2android.json
@@ -0,0 +1,7 @@
+{
+ "run": true,
+ "dep-suffixes": {
+ "vhost": "_android",
+ "vm_memory": "_android"
+ }
+}
\ No newline at end of file
diff --git a/docs/vhost_architecture.drawio b/docs/vhost_architecture.drawio
new file mode 100644
index 0000000..8c669d8
--- /dev/null
+++ b/docs/vhost_architecture.drawio
@@ -0,0 +1,171 @@
+<mxfile host="65bd71144e" modified="2021-02-22T05:37:26.833Z" agent="5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Code/1.53.0 Chrome/87.0.4280.141 Electron/11.2.1 Safari/537.36" etag="HWRXqybJYJqQhnlJWfmB" version="14.2.4" type="embed">
+ <diagram id="xCgrIAQPDQM0eynUYBOE" name="Page-1">
+ <mxGraphModel dx="3446" dy="1284" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
+ <root>
+ <mxCell id="0"/>
+ <mxCell id="1" parent="0"/>
+ <mxCell id="46" value="&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1">
+ <mxGeometry x="1620" y="27" width="450" height="990" as="geometry"/>
+ </mxCell>
+ <mxCell id="47" value="" style="shape=hexagon;perimeter=hexagonPerimeter2;whiteSpace=wrap;html=1;fixedSize=1;rounded=0;labelBackgroundColor=none;sketch=0;fillColor=none;fontSize=25;dashed=1;strokeWidth=6;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="790" y="237" width="1260" height="750" as="geometry"/>
+ </mxCell>
+ <mxCell id="44" value="" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1">
+ <mxGeometry x="-10" y="37" width="1250" height="670" as="geometry"/>
+ </mxCell>
+ <mxCell id="2" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;MasterReqHandler&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1">
+ <mxGeometry x="830" y="477" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="4" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;VhostUserMasterReqHandler&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1">
+ <mxGeometry x="840" y="597" width="360" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="6" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;entryX=1;entryY=0.5;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="2">
+ <mxGeometry relative="1" as="geometry">
+ <Array as="points">
+ <mxPoint x="1280" y="792"/>
+ <mxPoint x="1280" y="502"/>
+ </Array>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="5" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Slave&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1">
+ <mxGeometry x="1715" y="767" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="7" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;VhostUserMasterReqHandlerMut&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1630" y="657" width="390" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="8" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="2" target="4">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="950" y="657" as="sourcePoint"/>
+ <mxPoint x="680" y="717" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="10" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;SlaveListener&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1360" y="472" width="190" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="11" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;SlaveReqHandler&lt;/pre&gt;&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1712" y="387" width="210" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="14" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;VhostUserSlaveReqHandler&lt;/pre&gt;&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1652" y="537" width="330" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="15" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="11" target="14">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1202" y="567" as="sourcePoint"/>
+ <mxPoint x="1202" y="667" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="16" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;VhostBackend&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#00994D;strokeColor=#009900;" vertex="1" parent="1">
+ <mxGeometry x="390" y="197" width="250" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="17" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;VhostKernBackend&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1">
+ <mxGeometry x="530" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="18" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;VhostVdpaBackend&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="270" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="19" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Master&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="820" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="20" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;VhostSoftBackend&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="10" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="21" value="Handle virtque in VMM" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="10" y="557" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="23" value="Handle virtque in hardware" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="270" y="807" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="24" value="Handle virtque in kernel" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1">
+ <mxGeometry x="530" y="807" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="25" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#0000CC;" edge="1" parent="1" source="24" target="17">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="930" y="647" as="sourcePoint"/>
+ <mxPoint x="930" y="747" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="26" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="11">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="840" y="917" as="sourcePoint"/>
+ <mxPoint x="840" y="1017" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="27" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="23" target="18">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="420" y="807" as="sourcePoint"/>
+ <mxPoint x="420" y="907" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="28" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="21" target="20">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="240" y="857" as="sourcePoint"/>
+ <mxPoint x="240" y="957" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="30" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="20" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="910" y="647" as="sourcePoint"/>
+ <mxPoint x="910" y="747" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="31" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;strokeColor=#00994D;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="18" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1000" y="177" as="sourcePoint"/>
+ <mxPoint x="530" y="227" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="32" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="17" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1010" y="127" as="sourcePoint"/>
+ <mxPoint x="1505" y="-73" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="35" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Endpoint&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1360" y="552" width="190" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="36" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Message&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1360" y="632" width="190" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="37" value="&lt;pre style=&quot;font-size: 16.5pt ; font-weight: 700 ; font-family: &amp;quot;jetbrains mono&amp;quot; , monospace&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;VhostUserMaster&lt;/pre&gt;&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="980" y="257" width="230" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="38" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="19" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1030" y="527" as="sourcePoint"/>
+ <mxPoint x="515" y="257" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="39" value="Handle virtque in remote process" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="850" y="807" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="41" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="7">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1860" y="187" as="sourcePoint"/>
+ <mxPoint x="1860" y="267" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="43" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="37">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1430" y="187" as="sourcePoint"/>
+ <mxPoint x="2102" y="187" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="49" value="&lt;pre style=&quot;font-size: 16.5pt ; font-weight: 700 ; font-family: &amp;#34;jetbrains mono&amp;#34; , monospace&quot;&gt;&lt;pre style=&quot;font-family: &amp;#34;jetbrains mono&amp;#34; , monospace ; font-size: 16.5pt&quot;&gt;Trait&lt;/pre&gt;&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="60" y="1017" width="130" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="51" value="Vhost-user protocol" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fontColor=#FF00FF;fillColor=none;strokeColor=none;" vertex="1" parent="1">
+ <mxGeometry x="1220" y="817" width="330" height="150" as="geometry"/>
+ </mxCell>
+ <mxCell id="52" value="Vhost-user server" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1">
+ <mxGeometry x="1680" y="57" width="330" height="150" as="geometry"/>
+ </mxCell>
+ <mxCell id="53" value="VMM" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1">
+ <mxGeometry x="20" y="47" width="240" height="150" as="geometry"/>
+ </mxCell>
+ <mxCell id="54" value="&lt;pre style=&quot;font-family: &amp;#34;jetbrains mono&amp;#34; , monospace ; font-size: 16.5pt&quot;&gt;Struct&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1">
+ <mxGeometry x="240" y="1022" width="140" height="55" as="geometry"/>
+ </mxCell>
+ </root>
+ </mxGraphModel>
+ </diagram>
+</mxfile>
\ No newline at end of file
diff --git a/docs/vhost_architecture.png b/docs/vhost_architecture.png
new file mode 100644
index 0000000..4d1e2bc
--- /dev/null
+++ b/docs/vhost_architecture.png
Binary files differ
diff --git a/src/backend.rs b/src/backend.rs
new file mode 100644
index 0000000..43ab7b9
--- /dev/null
+++ b/src/backend.rs
@@ -0,0 +1,551 @@
+// Copyright 2019 Intel Corporation. All Rights Reserved.
+// Copyright 2019-2021 Alibaba Cloud. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+//! Traits for vhost user backend servers to implement virtio data plane services.
+//!
+//! Define two traits for vhost user backend servers to implement virtio data plane services.
+//! The only difference between the two traits is mutability. The [VhostUserBackend] trait is
+//! designed with interior mutability, so the implementor may choose the suitable way to protect
+//! itself from concurrent accesses. The [VhostUserBackendMut] is designed without interior
+//! mutability, and an implementation of:
+//! ```ignore
+//! impl<T: VhostUserBackendMut> VhostUserBackend for RwLock<T> { }
+//! ```
+//! is provided for convenience.
+//!
+//! [VhostUserBackend]: trait.VhostUserBackend.html
+//! [VhostUserBackendMut]: trait.VhostUserBackendMut.html
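+//!
+//! A hedged usage sketch (illustrative only; `MyBackend` is a hypothetical type
+//! implementing `VhostUserBackendMut`):
+//! ```ignore
+//! let backend = Arc::new(RwLock::new(MyBackend::new()));
+//! // Thanks to the blanket impls below, `Arc<RwLock<MyBackend>>` implements
+//! // `VhostUserBackend` and can be handed to `VhostUserDaemon::new()`.
+//! ```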
+
+use std::io::Result;
+use std::ops::Deref;
+use std::sync::{Arc, Mutex, RwLock};
+
+use vhost::vhost_user::message::VhostUserProtocolFeatures;
+use vhost::vhost_user::Slave;
+use vm_memory::bitmap::Bitmap;
+use vmm_sys_util::epoll::EventSet;
+use vmm_sys_util::eventfd::EventFd;
+
+use super::vring::VringT;
+use super::GM;
+
+/// Trait with interior mutability for vhost user backend servers to implement concrete services.
+///
+/// To support multi-threading and asynchronous IO, we enforce the `Send + Sync` bound.
+pub trait VhostUserBackend<V, B = ()>: Send + Sync
+where
+ V: VringT<GM<B>>,
+ B: Bitmap + 'static,
+{
+ /// Get number of queues supported.
+ fn num_queues(&self) -> usize;
+
+ /// Get maximum queue size supported.
+ fn max_queue_size(&self) -> usize;
+
+ /// Get available virtio features.
+ fn features(&self) -> u64;
+
+ /// Set acknowledged virtio features.
+ fn acked_features(&self, _features: u64) {}
+
+ /// Get available vhost protocol features.
+ fn protocol_features(&self) -> VhostUserProtocolFeatures;
+
+    /// Enable or disable the virtio EVENT_IDX feature.
+ fn set_event_idx(&self, enabled: bool);
+
+ /// Get virtio device configuration.
+ ///
+ /// A default implementation is provided as we cannot expect all backends to implement this
+ /// function.
+ fn get_config(&self, _offset: u32, _size: u32) -> Vec<u8> {
+ Vec::new()
+ }
+
+ /// Set virtio device configuration.
+ ///
+ /// A default implementation is provided as we cannot expect all backends to implement this
+ /// function.
+ fn set_config(&self, _offset: u32, _buf: &[u8]) -> Result<()> {
+ Ok(())
+ }
+
+ /// Update guest memory regions.
+ fn update_memory(&self, mem: GM<B>) -> Result<()>;
+
+    /// Set the handler for communicating with the master over the slave communication channel.
+ ///
+ /// A default implementation is provided as we cannot expect all backends to implement this
+ /// function.
+ fn set_slave_req_fd(&self, _slave: Slave) {}
+
+    /// Get the mapping from queue index to worker thread index.
+ ///
+ /// A return value of [2, 2, 4] means: the first two queues will be handled by worker thread 0,
+ /// the following two queues will be handled by worker thread 1, and the last four queues will
+ /// be handled by worker thread 2.
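+    ///
+    /// Note: judging from how the `VringEpollHandler` instances are assembled in
+    /// src/handler.rs, each entry is consumed as a bitmask of queue indexes assigned
+    /// to the corresponding worker thread; the default implementation below assigns
+    /// every queue to a single thread.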
+ fn queues_per_thread(&self) -> Vec<u64> {
+ vec![0xffff_ffff]
+ }
+
+ /// Provide an optional exit EventFd for the specified worker thread.
+ ///
+    /// If an `EventFd` is returned, it will be monitored for IO events through epoll. When the
+    /// returned `EventFd` is written to, the worker thread will exit.
+ fn exit_event(&self, _thread_index: usize) -> Option<EventFd> {
+ None
+ }
+
+ /// Handle IO events for backend registered file descriptors.
+ ///
+ /// This function gets called if the backend registered some additional listeners onto specific
+ /// file descriptors. The library can handle virtqueues on its own, but does not know what to
+ /// do with events happening on custom listeners.
+ fn handle_event(
+ &self,
+ device_event: u16,
+ evset: EventSet,
+ vrings: &[V],
+ thread_id: usize,
+ ) -> Result<bool>;
+}
+
+/// Trait without interior mutability for vhost user backend servers to implement concrete services.
+pub trait VhostUserBackendMut<V, B = ()>: Send + Sync
+where
+ V: VringT<GM<B>>,
+ B: Bitmap + 'static,
+{
+ /// Get number of queues supported.
+ fn num_queues(&self) -> usize;
+
+ /// Get maximum queue size supported.
+ fn max_queue_size(&self) -> usize;
+
+ /// Get available virtio features.
+ fn features(&self) -> u64;
+
+ /// Set acknowledged virtio features.
+ fn acked_features(&mut self, _features: u64) {}
+
+ /// Get available vhost protocol features.
+ fn protocol_features(&self) -> VhostUserProtocolFeatures;
+
+    /// Enable or disable the virtio EVENT_IDX feature.
+ fn set_event_idx(&mut self, enabled: bool);
+
+ /// Get virtio device configuration.
+ ///
+ /// A default implementation is provided as we cannot expect all backends to implement this
+ /// function.
+ fn get_config(&self, _offset: u32, _size: u32) -> Vec<u8> {
+ Vec::new()
+ }
+
+ /// Set virtio device configuration.
+ ///
+ /// A default implementation is provided as we cannot expect all backends to implement this
+ /// function.
+ fn set_config(&mut self, _offset: u32, _buf: &[u8]) -> Result<()> {
+ Ok(())
+ }
+
+ /// Update guest memory regions.
+ fn update_memory(&mut self, mem: GM<B>) -> Result<()>;
+
+    /// Set the handler for communicating with the master over the slave communication channel.
+ ///
+ /// A default implementation is provided as we cannot expect all backends to implement this
+ /// function.
+ fn set_slave_req_fd(&mut self, _slave: Slave) {}
+
+    /// Get the mapping from queue index to worker thread index.
+ ///
+ /// A return value of [2, 2, 4] means: the first two queues will be handled by worker thread 0,
+ /// the following two queues will be handled by worker thread 1, and the last four queues will
+ /// be handled by worker thread 2.
+ fn queues_per_thread(&self) -> Vec<u64> {
+ vec![0xffff_ffff]
+ }
+
+ /// Provide an optional exit EventFd for the specified worker thread.
+ ///
+    /// If an `EventFd` is returned, it will be monitored for IO events through epoll. When the
+    /// returned `EventFd` is written to, the worker thread will exit.
+ fn exit_event(&self, _thread_index: usize) -> Option<EventFd> {
+ None
+ }
+
+ /// Handle IO events for backend registered file descriptors.
+ ///
+ /// This function gets called if the backend registered some additional listeners onto specific
+ /// file descriptors. The library can handle virtqueues on its own, but does not know what to
+ /// do with events happening on custom listeners.
+ fn handle_event(
+ &mut self,
+ device_event: u16,
+ evset: EventSet,
+ vrings: &[V],
+ thread_id: usize,
+ ) -> Result<bool>;
+}
+
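+// Blanket implementations: sharing a backend behind `Arc`, or wrapping a
+// `VhostUserBackendMut` implementor in `Mutex` or `RwLock`, yields a type that
+// satisfies `VhostUserBackend`, so callers can choose their own locking strategy.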
+impl<T: VhostUserBackend<V, B>, V, B> VhostUserBackend<V, B> for Arc<T>
+where
+ V: VringT<GM<B>>,
+ B: Bitmap + 'static,
+{
+ fn num_queues(&self) -> usize {
+ self.deref().num_queues()
+ }
+
+ fn max_queue_size(&self) -> usize {
+ self.deref().max_queue_size()
+ }
+
+ fn features(&self) -> u64 {
+ self.deref().features()
+ }
+
+ fn acked_features(&self, features: u64) {
+ self.deref().acked_features(features)
+ }
+
+ fn protocol_features(&self) -> VhostUserProtocolFeatures {
+ self.deref().protocol_features()
+ }
+
+ fn set_event_idx(&self, enabled: bool) {
+ self.deref().set_event_idx(enabled)
+ }
+
+ fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
+ self.deref().get_config(offset, size)
+ }
+
+ fn set_config(&self, offset: u32, buf: &[u8]) -> Result<()> {
+ self.deref().set_config(offset, buf)
+ }
+
+ fn update_memory(&self, mem: GM<B>) -> Result<()> {
+ self.deref().update_memory(mem)
+ }
+
+ fn set_slave_req_fd(&self, slave: Slave) {
+ self.deref().set_slave_req_fd(slave)
+ }
+
+ fn queues_per_thread(&self) -> Vec<u64> {
+ self.deref().queues_per_thread()
+ }
+
+ fn exit_event(&self, thread_index: usize) -> Option<EventFd> {
+ self.deref().exit_event(thread_index)
+ }
+
+ fn handle_event(
+ &self,
+ device_event: u16,
+ evset: EventSet,
+ vrings: &[V],
+ thread_id: usize,
+ ) -> Result<bool> {
+ self.deref()
+ .handle_event(device_event, evset, vrings, thread_id)
+ }
+}
+
+impl<T: VhostUserBackendMut<V, B>, V, B> VhostUserBackend<V, B> for Mutex<T>
+where
+ V: VringT<GM<B>>,
+ B: Bitmap + 'static,
+{
+ fn num_queues(&self) -> usize {
+ self.lock().unwrap().num_queues()
+ }
+
+ fn max_queue_size(&self) -> usize {
+ self.lock().unwrap().max_queue_size()
+ }
+
+ fn features(&self) -> u64 {
+ self.lock().unwrap().features()
+ }
+
+ fn acked_features(&self, features: u64) {
+ self.lock().unwrap().acked_features(features)
+ }
+
+ fn protocol_features(&self) -> VhostUserProtocolFeatures {
+ self.lock().unwrap().protocol_features()
+ }
+
+ fn set_event_idx(&self, enabled: bool) {
+ self.lock().unwrap().set_event_idx(enabled)
+ }
+
+ fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
+ self.lock().unwrap().get_config(offset, size)
+ }
+
+ fn set_config(&self, offset: u32, buf: &[u8]) -> Result<()> {
+ self.lock().unwrap().set_config(offset, buf)
+ }
+
+ fn update_memory(&self, mem: GM<B>) -> Result<()> {
+ self.lock().unwrap().update_memory(mem)
+ }
+
+ fn set_slave_req_fd(&self, slave: Slave) {
+ self.lock().unwrap().set_slave_req_fd(slave)
+ }
+
+ fn queues_per_thread(&self) -> Vec<u64> {
+ self.lock().unwrap().queues_per_thread()
+ }
+
+ fn exit_event(&self, thread_index: usize) -> Option<EventFd> {
+ self.lock().unwrap().exit_event(thread_index)
+ }
+
+ fn handle_event(
+ &self,
+ device_event: u16,
+ evset: EventSet,
+ vrings: &[V],
+ thread_id: usize,
+ ) -> Result<bool> {
+ self.lock()
+ .unwrap()
+ .handle_event(device_event, evset, vrings, thread_id)
+ }
+}
+
+impl<T: VhostUserBackendMut<V, B>, V, B> VhostUserBackend<V, B> for RwLock<T>
+where
+ V: VringT<GM<B>>,
+ B: Bitmap + 'static,
+{
+ fn num_queues(&self) -> usize {
+ self.read().unwrap().num_queues()
+ }
+
+ fn max_queue_size(&self) -> usize {
+ self.read().unwrap().max_queue_size()
+ }
+
+ fn features(&self) -> u64 {
+ self.read().unwrap().features()
+ }
+
+ fn acked_features(&self, features: u64) {
+ self.write().unwrap().acked_features(features)
+ }
+
+ fn protocol_features(&self) -> VhostUserProtocolFeatures {
+ self.read().unwrap().protocol_features()
+ }
+
+ fn set_event_idx(&self, enabled: bool) {
+ self.write().unwrap().set_event_idx(enabled)
+ }
+
+ fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
+ self.read().unwrap().get_config(offset, size)
+ }
+
+ fn set_config(&self, offset: u32, buf: &[u8]) -> Result<()> {
+ self.write().unwrap().set_config(offset, buf)
+ }
+
+ fn update_memory(&self, mem: GM<B>) -> Result<()> {
+ self.write().unwrap().update_memory(mem)
+ }
+
+ fn set_slave_req_fd(&self, slave: Slave) {
+ self.write().unwrap().set_slave_req_fd(slave)
+ }
+
+ fn queues_per_thread(&self) -> Vec<u64> {
+ self.read().unwrap().queues_per_thread()
+ }
+
+ fn exit_event(&self, thread_index: usize) -> Option<EventFd> {
+ self.read().unwrap().exit_event(thread_index)
+ }
+
+ fn handle_event(
+ &self,
+ device_event: u16,
+ evset: EventSet,
+ vrings: &[V],
+ thread_id: usize,
+ ) -> Result<bool> {
+ self.write()
+ .unwrap()
+ .handle_event(device_event, evset, vrings, thread_id)
+ }
+}
+
+#[cfg(test)]
+pub mod tests {
+ use super::*;
+ use crate::VringRwLock;
+ use std::sync::Mutex;
+ use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap};
+
+ pub struct MockVhostBackend {
+ events: u64,
+ event_idx: bool,
+ acked_features: u64,
+ }
+
+ impl MockVhostBackend {
+ pub fn new() -> Self {
+ MockVhostBackend {
+ events: 0,
+ event_idx: false,
+ acked_features: 0,
+ }
+ }
+ }
+
+ impl VhostUserBackendMut<VringRwLock, ()> for MockVhostBackend {
+ fn num_queues(&self) -> usize {
+ 2
+ }
+
+ fn max_queue_size(&self) -> usize {
+ 256
+ }
+
+ fn features(&self) -> u64 {
+ 0xffff_ffff_ffff_ffff
+ }
+
+ fn acked_features(&mut self, features: u64) {
+ self.acked_features = features;
+ }
+
+ fn protocol_features(&self) -> VhostUserProtocolFeatures {
+ VhostUserProtocolFeatures::all()
+ }
+
+ fn set_event_idx(&mut self, enabled: bool) {
+ self.event_idx = enabled;
+ }
+
+ fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
+ assert_eq!(offset, 0x200);
+ assert_eq!(size, 8);
+
+ vec![0xa5u8; 8]
+ }
+
+ fn set_config(&mut self, offset: u32, buf: &[u8]) -> Result<()> {
+ assert_eq!(offset, 0x200);
+ assert_eq!(buf.len(), 8);
+ assert_eq!(buf, &[0xa5u8; 8]);
+
+ Ok(())
+ }
+
+ fn update_memory(&mut self, _atomic_mem: GuestMemoryAtomic<GuestMemoryMmap>) -> Result<()> {
+ Ok(())
+ }
+
+ fn set_slave_req_fd(&mut self, _slave: Slave) {}
+
+ fn queues_per_thread(&self) -> Vec<u64> {
+ vec![1, 1]
+ }
+
+ fn exit_event(&self, _thread_index: usize) -> Option<EventFd> {
+ let event_fd = EventFd::new(0).unwrap();
+
+ Some(event_fd)
+ }
+
+ fn handle_event(
+ &mut self,
+ _device_event: u16,
+ _evset: EventSet,
+ _vrings: &[VringRwLock],
+ _thread_id: usize,
+ ) -> Result<bool> {
+ self.events += 1;
+
+ Ok(false)
+ }
+ }
+
+ #[test]
+ fn test_new_mock_backend_mutex() {
+ let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
+
+ assert_eq!(backend.num_queues(), 2);
+ assert_eq!(backend.max_queue_size(), 256);
+ assert_eq!(backend.features(), 0xffff_ffff_ffff_ffff);
+ assert_eq!(
+ backend.protocol_features(),
+ VhostUserProtocolFeatures::all()
+ );
+ assert_eq!(backend.queues_per_thread(), [1, 1]);
+
+ assert_eq!(backend.get_config(0x200, 8), vec![0xa5; 8]);
+ backend.set_config(0x200, &[0xa5; 8]).unwrap();
+
+ backend.acked_features(0xffff);
+ assert_eq!(backend.lock().unwrap().acked_features, 0xffff);
+
+ backend.set_event_idx(true);
+ assert!(backend.lock().unwrap().event_idx);
+
+ let _ = backend.exit_event(0).unwrap();
+
+ let mem = GuestMemoryAtomic::new(
+ GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
+ );
+ backend.update_memory(mem).unwrap();
+ }
+
+ #[test]
+ fn test_new_mock_backend_rwlock() {
+ let backend = Arc::new(RwLock::new(MockVhostBackend::new()));
+
+ assert_eq!(backend.num_queues(), 2);
+ assert_eq!(backend.max_queue_size(), 256);
+ assert_eq!(backend.features(), 0xffff_ffff_ffff_ffff);
+ assert_eq!(
+ backend.protocol_features(),
+ VhostUserProtocolFeatures::all()
+ );
+ assert_eq!(backend.queues_per_thread(), [1, 1]);
+
+ assert_eq!(backend.get_config(0x200, 8), vec![0xa5; 8]);
+ backend.set_config(0x200, &[0xa5; 8]).unwrap();
+
+ backend.acked_features(0xffff);
+ assert_eq!(backend.read().unwrap().acked_features, 0xffff);
+
+ backend.set_event_idx(true);
+ assert!(backend.read().unwrap().event_idx);
+
+ let _ = backend.exit_event(0).unwrap();
+
+ let mem = GuestMemoryAtomic::new(
+ GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
+ );
+ backend.update_memory(mem.clone()).unwrap();
+
+ let vring = VringRwLock::new(mem, 0x1000).unwrap();
+ backend
+ .handle_event(0x1, EventSet::IN, &[vring], 0)
+ .unwrap();
+ }
+}
diff --git a/src/event_loop.rs b/src/event_loop.rs
new file mode 100644
index 0000000..f10aad3
--- /dev/null
+++ b/src/event_loop.rs
@@ -0,0 +1,270 @@
+// Copyright 2019 Intel Corporation. All Rights Reserved.
+// Copyright 2019-2021 Alibaba Cloud. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+use std::fmt::{Display, Formatter};
+use std::io::{self, Result};
+use std::marker::PhantomData;
+use std::os::unix::io::{AsRawFd, RawFd};
+
+use vm_memory::bitmap::Bitmap;
+use vmm_sys_util::epoll::{ControlOperation, Epoll, EpollEvent, EventSet};
+use vmm_sys_util::eventfd::EventFd;
+
+use super::backend::VhostUserBackend;
+use super::vring::VringT;
+use super::GM;
+
+/// Errors related to vring epoll event handling.
+#[derive(Debug)]
+pub enum VringEpollError {
+ /// Failed to create epoll file descriptor.
+ EpollCreateFd(io::Error),
+ /// Failed while waiting for events.
+ EpollWait(io::Error),
+    /// Could not register the exit event.
+ RegisterExitEvent(io::Error),
+ /// Failed to read the event from kick EventFd.
+ HandleEventReadKick(io::Error),
+ /// Failed to handle the event from the backend.
+ HandleEventBackendHandling(io::Error),
+}
+
+impl Display for VringEpollError {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+ match self {
+ VringEpollError::EpollCreateFd(e) => write!(f, "cannot create epoll fd: {}", e),
+ VringEpollError::EpollWait(e) => write!(f, "failed to wait for epoll event: {}", e),
+ VringEpollError::RegisterExitEvent(e) => write!(f, "cannot register exit event: {}", e),
+ VringEpollError::HandleEventReadKick(e) => {
+ write!(f, "cannot read vring kick event: {}", e)
+ }
+ VringEpollError::HandleEventBackendHandling(e) => {
+ write!(f, "failed to handle epoll event: {}", e)
+ }
+ }
+ }
+}
+
+impl std::error::Error for VringEpollError {}
+
+/// Result of vring epoll operations.
+pub type VringEpollResult<T> = std::result::Result<T, VringEpollError>;
+
+/// Epoll event handler to manage and process epoll events for registered file descriptors.
+///
+/// The `VringEpollHandler` structure provides interfaces to:
+/// - add file descriptors to be monitored by the epoll fd
+/// - remove registered file descriptors from the epoll fd
+/// - run the event loop to handle pending events on the epoll fd
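+///
+/// A hedged usage sketch (illustrative; `handler` would be obtained from
+/// `VhostUserDaemon::get_epoll_handlers()`, and `my_fd`/`MY_TOKEN` are hypothetical):
+/// ```ignore
+/// // Tokens 0..=num_queues are reserved for the vrings and the exit event, so a
+/// // custom listener must use a larger token value.
+/// handler.register_listener(my_fd, EventSet::IN, MY_TOKEN)?;
+/// // ... and once the fd no longer needs to be watched:
+/// handler.unregister_listener(my_fd, EventSet::IN, MY_TOKEN)?;
+/// ```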
+pub struct VringEpollHandler<S, V, B> {
+ epoll: Epoll,
+ backend: S,
+ vrings: Vec<V>,
+ thread_id: usize,
+ exit_event_fd: Option<EventFd>,
+ phantom: PhantomData<B>,
+}
+
+impl<S, V, B> VringEpollHandler<S, V, B> {
+ /// Send `exit event` to break the event loop.
+ pub fn send_exit_event(&self) {
+ if let Some(eventfd) = self.exit_event_fd.as_ref() {
+ let _ = eventfd.write(1);
+ }
+ }
+}
+
+impl<S, V, B> VringEpollHandler<S, V, B>
+where
+ S: VhostUserBackend<V, B>,
+ V: VringT<GM<B>>,
+ B: Bitmap + 'static,
+{
+ /// Create a `VringEpollHandler` instance.
+ pub(crate) fn new(backend: S, vrings: Vec<V>, thread_id: usize) -> VringEpollResult<Self> {
+ let epoll = Epoll::new().map_err(VringEpollError::EpollCreateFd)?;
+ let exit_event_fd = backend.exit_event(thread_id);
+
+ if let Some(exit_event_fd) = &exit_event_fd {
+ let id = backend.num_queues();
+ epoll
+ .ctl(
+ ControlOperation::Add,
+ exit_event_fd.as_raw_fd(),
+ EpollEvent::new(EventSet::IN, id as u64),
+ )
+ .map_err(VringEpollError::RegisterExitEvent)?;
+ }
+
+ Ok(VringEpollHandler {
+ epoll,
+ backend,
+ vrings,
+ thread_id,
+ exit_event_fd,
+ phantom: PhantomData,
+ })
+ }
+
+ /// Register an event into the epoll fd.
+ ///
+ /// When this event is later triggered, the backend implementation of `handle_event` will be
+ /// called.
+ pub fn register_listener(&self, fd: RawFd, ev_type: EventSet, data: u64) -> Result<()> {
+ // `data` range [0...num_queues] is reserved for queues and exit event.
+ if data <= self.backend.num_queues() as u64 {
+ Err(io::Error::from_raw_os_error(libc::EINVAL))
+ } else {
+ self.register_event(fd, ev_type, data)
+ }
+ }
+
+ /// Unregister an event from the epoll fd.
+ ///
+ /// If the event is triggered after this function has been called, the event will be silently
+ /// dropped.
+ pub fn unregister_listener(&self, fd: RawFd, ev_type: EventSet, data: u64) -> Result<()> {
+ // `data` range [0...num_queues] is reserved for queues and exit event.
+ if data <= self.backend.num_queues() as u64 {
+ Err(io::Error::from_raw_os_error(libc::EINVAL))
+ } else {
+ self.unregister_event(fd, ev_type, data)
+ }
+ }
+
+ pub(crate) fn register_event(&self, fd: RawFd, ev_type: EventSet, data: u64) -> Result<()> {
+ self.epoll
+ .ctl(ControlOperation::Add, fd, EpollEvent::new(ev_type, data))
+ }
+
+ pub(crate) fn unregister_event(&self, fd: RawFd, ev_type: EventSet, data: u64) -> Result<()> {
+ self.epoll
+ .ctl(ControlOperation::Delete, fd, EpollEvent::new(ev_type, data))
+ }
+
+ /// Run the event poll loop to handle all pending events on registered fds.
+ ///
+ /// The event loop will be terminated once an event is received from the `exit event fd`
+ /// associated with the backend.
+ pub(crate) fn run(&self) -> VringEpollResult<()> {
+ const EPOLL_EVENTS_LEN: usize = 100;
+ let mut events = vec![EpollEvent::new(EventSet::empty(), 0); EPOLL_EVENTS_LEN];
+
+ 'epoll: loop {
+ let num_events = match self.epoll.wait(-1, &mut events[..]) {
+ Ok(res) => res,
+ Err(e) => {
+ if e.kind() == io::ErrorKind::Interrupted {
+ // It's well defined from the epoll_wait() syscall
+ // documentation that the epoll loop can be interrupted
+ // before any of the requested events occurred or the
+                        // timeout expired. In both cases, epoll_wait() returns
+                        // an EINTR error, which should not be treated as a
+                        // regular error; instead, it is appropriate to simply
+                        // retry the epoll_wait() call.
+ continue;
+ }
+ return Err(VringEpollError::EpollWait(e));
+ }
+ };
+
+ for event in events.iter().take(num_events) {
+ let evset = match EventSet::from_bits(event.events) {
+ Some(evset) => evset,
+ None => {
+ let evbits = event.events;
+ println!("epoll: ignoring unknown event set: 0x{:x}", evbits);
+ continue;
+ }
+ };
+
+ let ev_type = event.data() as u16;
+
+ // handle_event() returns true if an event is received from the exit event fd.
+ if self.handle_event(ev_type, evset)? {
+ break 'epoll;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn handle_event(&self, device_event: u16, evset: EventSet) -> VringEpollResult<bool> {
+ if self.exit_event_fd.is_some() && device_event as usize == self.backend.num_queues() {
+ return Ok(true);
+ }
+
+ if (device_event as usize) < self.vrings.len() {
+ let vring = &self.vrings[device_event as usize];
+ let enabled = vring
+ .read_kick()
+ .map_err(VringEpollError::HandleEventReadKick)?;
+
+ // If the vring is not enabled, it should not be processed.
+ if !enabled {
+ return Ok(false);
+ }
+ }
+
+ self.backend
+ .handle_event(device_event, evset, &self.vrings, self.thread_id)
+ .map_err(VringEpollError::HandleEventBackendHandling)
+ }
+}
+
+impl<S, V, B> AsRawFd for VringEpollHandler<S, V, B> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.epoll.as_raw_fd()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::super::backend::tests::MockVhostBackend;
+ use super::super::vring::VringRwLock;
+ use super::*;
+ use std::sync::{Arc, Mutex};
+ use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap};
+ use vmm_sys_util::eventfd::EventFd;
+
+ #[test]
+ fn test_vring_epoll_handler() {
+ let mem = GuestMemoryAtomic::new(
+ GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
+ );
+ let vring = VringRwLock::new(mem, 0x1000).unwrap();
+ let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
+
+ let handler = VringEpollHandler::new(backend, vec![vring], 0x1).unwrap();
+
+ let eventfd = EventFd::new(0).unwrap();
+ handler
+ .register_listener(eventfd.as_raw_fd(), EventSet::IN, 3)
+ .unwrap();
+ // Register an already registered fd.
+ handler
+ .register_listener(eventfd.as_raw_fd(), EventSet::IN, 3)
+ .unwrap_err();
+        // Register with an invalid `data` value.
+ handler
+ .register_listener(eventfd.as_raw_fd(), EventSet::IN, 1)
+ .unwrap_err();
+
+ handler
+ .unregister_listener(eventfd.as_raw_fd(), EventSet::IN, 3)
+ .unwrap();
+        // Unregister an already unregistered fd.
+ handler
+ .unregister_listener(eventfd.as_raw_fd(), EventSet::IN, 3)
+ .unwrap_err();
+        // Unregister with an invalid `data` value.
+ handler
+ .unregister_listener(eventfd.as_raw_fd(), EventSet::IN, 1)
+ .unwrap_err();
+ // Check we retrieve the correct file descriptor
+ assert_eq!(handler.as_raw_fd(), handler.epoll.as_raw_fd());
+ }
+}
diff --git a/src/handler.rs b/src/handler.rs
new file mode 100644
index 0000000..262bf6c
--- /dev/null
+++ b/src/handler.rs
@@ -0,0 +1,618 @@
+// Copyright 2019 Intel Corporation. All Rights Reserved.
+// Copyright 2019-2021 Alibaba Cloud. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+use std::error;
+use std::fs::File;
+use std::io;
+use std::os::unix::io::AsRawFd;
+use std::sync::Arc;
+use std::thread;
+
+use vhost::vhost_user::message::{
+ VhostUserConfigFlags, VhostUserMemoryRegion, VhostUserProtocolFeatures,
+ VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags,
+ VhostUserVringState,
+};
+use vhost::vhost_user::{
+ Error as VhostUserError, Result as VhostUserResult, Slave, VhostUserSlaveReqHandlerMut,
+};
+use virtio_bindings::bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
+use virtio_queue::{Error as VirtQueError, QueueT};
+use vm_memory::bitmap::Bitmap;
+use vm_memory::mmap::NewBitmap;
+use vm_memory::{GuestAddress, GuestAddressSpace, GuestMemoryMmap, GuestRegionMmap};
+use vmm_sys_util::epoll::EventSet;
+
+use super::backend::VhostUserBackend;
+use super::event_loop::VringEpollHandler;
+use super::event_loop::{VringEpollError, VringEpollResult};
+use super::vring::VringT;
+use super::GM;
+
+const MAX_MEM_SLOTS: u64 = 32;
+
+#[derive(Debug)]
+/// Errors related to vhost-user handler.
+pub enum VhostUserHandlerError {
+ /// Failed to create a `Vring`.
+ CreateVring(VirtQueError),
+ /// Failed to create vring worker.
+ CreateEpollHandler(VringEpollError),
+ /// Failed to spawn vring worker.
+ SpawnVringWorker(io::Error),
+ /// Could not find the mapping from memory regions.
+ MissingMemoryMapping,
+}
+
+impl std::fmt::Display for VhostUserHandlerError {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ match self {
+ VhostUserHandlerError::CreateVring(e) => {
+ write!(f, "failed to create vring: {}", e)
+ }
+ VhostUserHandlerError::CreateEpollHandler(e) => {
+ write!(f, "failed to create vring epoll handler: {}", e)
+ }
+ VhostUserHandlerError::SpawnVringWorker(e) => {
+ write!(f, "failed spawning the vring worker: {}", e)
+ }
+ VhostUserHandlerError::MissingMemoryMapping => write!(f, "Missing memory mapping"),
+ }
+ }
+}
+
+impl error::Error for VhostUserHandlerError {}
+
+/// Result of vhost-user handler operations.
+pub type VhostUserHandlerResult<T> = std::result::Result<T, VhostUserHandlerError>;
+
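+/// Mapping of a VMM virtual-address range to the guest-physical range backing it,
+/// used by `vmm_va_to_gpa()` to translate vring addresses received over vhost-user.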
+struct AddrMapping {
+ vmm_addr: u64,
+ size: u64,
+ gpa_base: u64,
+}
+
+pub struct VhostUserHandler<S, V, B: Bitmap + 'static> {
+ backend: S,
+ handlers: Vec<Arc<VringEpollHandler<S, V, B>>>,
+ owned: bool,
+ features_acked: bool,
+ acked_features: u64,
+ acked_protocol_features: u64,
+ num_queues: usize,
+ max_queue_size: usize,
+ queues_per_thread: Vec<u64>,
+ mappings: Vec<AddrMapping>,
+ atomic_mem: GM<B>,
+ vrings: Vec<V>,
+ worker_threads: Vec<thread::JoinHandle<VringEpollResult<()>>>,
+}
+
+// Ensure VhostUserHandler: Clone + Send + Sync + 'static.
+impl<S, V, B> VhostUserHandler<S, V, B>
+where
+ S: VhostUserBackend<V, B> + Clone + 'static,
+ V: VringT<GM<B>> + Clone + Send + Sync + 'static,
+ B: Bitmap + Clone + Send + Sync + 'static,
+{
+ pub(crate) fn new(backend: S, atomic_mem: GM<B>) -> VhostUserHandlerResult<Self> {
+ let num_queues = backend.num_queues();
+ let max_queue_size = backend.max_queue_size();
+ let queues_per_thread = backend.queues_per_thread();
+
+ let mut vrings = Vec::new();
+ for _ in 0..num_queues {
+ let vring = V::new(atomic_mem.clone(), max_queue_size as u16)
+ .map_err(VhostUserHandlerError::CreateVring)?;
+ vrings.push(vring);
+ }
+
+ let mut handlers = Vec::new();
+ let mut worker_threads = Vec::new();
+ for (thread_id, queues_mask) in queues_per_thread.iter().enumerate() {
+ let mut thread_vrings = Vec::new();
+ for (index, vring) in vrings.iter().enumerate() {
+ if (queues_mask >> index) & 1u64 == 1u64 {
+ thread_vrings.push(vring.clone());
+ }
+ }
+
+ let handler = Arc::new(
+ VringEpollHandler::new(backend.clone(), thread_vrings, thread_id)
+ .map_err(VhostUserHandlerError::CreateEpollHandler)?,
+ );
+ let handler2 = handler.clone();
+ let worker_thread = thread::Builder::new()
+ .name("vring_worker".to_string())
+ .spawn(move || handler2.run())
+ .map_err(VhostUserHandlerError::SpawnVringWorker)?;
+
+ handlers.push(handler);
+ worker_threads.push(worker_thread);
+ }
+
+ Ok(VhostUserHandler {
+ backend,
+ handlers,
+ owned: false,
+ features_acked: false,
+ acked_features: 0,
+ acked_protocol_features: 0,
+ num_queues,
+ max_queue_size,
+ queues_per_thread,
+ mappings: Vec::new(),
+ atomic_mem,
+ vrings,
+ worker_threads,
+ })
+ }
+}
+
+impl<S, V, B: Bitmap> VhostUserHandler<S, V, B> {
+ pub(crate) fn send_exit_event(&self) {
+ for handler in self.handlers.iter() {
+ handler.send_exit_event();
+ }
+ }
+
+ fn vmm_va_to_gpa(&self, vmm_va: u64) -> VhostUserHandlerResult<u64> {
+ for mapping in self.mappings.iter() {
+ if vmm_va >= mapping.vmm_addr && vmm_va < mapping.vmm_addr + mapping.size {
+ return Ok(vmm_va - mapping.vmm_addr + mapping.gpa_base);
+ }
+ }
+
+ Err(VhostUserHandlerError::MissingMemoryMapping)
+ }
+}
+
+impl<S, V, B> VhostUserHandler<S, V, B>
+where
+ S: VhostUserBackend<V, B>,
+ V: VringT<GM<B>>,
+ B: Bitmap,
+{
+ pub(crate) fn get_epoll_handlers(&self) -> Vec<Arc<VringEpollHandler<S, V, B>>> {
+ self.handlers.clone()
+ }
+
+ fn vring_needs_init(&self, vring: &V) -> bool {
+ let vring_state = vring.get_ref();
+
+ // If the vring wasn't initialized and we already have an EventFd for
+ // VRING_KICK, initialize it now.
+ !vring_state.get_queue().ready() && vring_state.get_kick().is_some()
+ }
+
+ fn initialize_vring(&self, vring: &V, index: u8) -> VhostUserResult<()> {
+ assert!(vring.get_ref().get_kick().is_some());
+
+ if let Some(fd) = vring.get_ref().get_kick() {
+ for (thread_index, queues_mask) in self.queues_per_thread.iter().enumerate() {
+ let shifted_queues_mask = queues_mask >> index;
+ if shifted_queues_mask & 1u64 == 1u64 {
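+                    // The epoll token for this queue is its position among the queues
+                    // assigned to this worker thread, i.e. the number of lower-indexed
+                    // queues present in the same mask.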
+ let evt_idx = queues_mask.count_ones() - shifted_queues_mask.count_ones();
+ self.handlers[thread_index]
+ .register_event(fd.as_raw_fd(), EventSet::IN, u64::from(evt_idx))
+ .map_err(VhostUserError::ReqHandlerError)?;
+ break;
+ }
+ }
+ }
+
+ self.vrings[index as usize].set_queue_ready(true);
+
+ Ok(())
+ }
+
+ /// Helper to check if VirtioFeature enabled
+ fn check_feature(&self, feat: VhostUserVirtioFeatures) -> VhostUserResult<()> {
+ if self.acked_features & feat.bits() != 0 {
+ Ok(())
+ } else {
+ Err(VhostUserError::InactiveFeature(feat))
+ }
+ }
+}
+
+impl<S, V, B> VhostUserSlaveReqHandlerMut for VhostUserHandler<S, V, B>
+where
+ S: VhostUserBackend<V, B>,
+ V: VringT<GM<B>>,
+ B: NewBitmap + Clone,
+{
+ fn set_owner(&mut self) -> VhostUserResult<()> {
+ if self.owned {
+ return Err(VhostUserError::InvalidOperation("already claimed"));
+ }
+ self.owned = true;
+ Ok(())
+ }
+
+ fn reset_owner(&mut self) -> VhostUserResult<()> {
+ self.owned = false;
+ self.features_acked = false;
+ self.acked_features = 0;
+ self.acked_protocol_features = 0;
+ Ok(())
+ }
+
+ fn get_features(&mut self) -> VhostUserResult<u64> {
+ Ok(self.backend.features())
+ }
+
+ fn set_features(&mut self, features: u64) -> VhostUserResult<()> {
+ if (features & !self.backend.features()) != 0 {
+ return Err(VhostUserError::InvalidParam);
+ }
+
+ self.acked_features = features;
+ self.features_acked = true;
+
+ // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated,
+ // the ring is initialized in an enabled state.
+ // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated,
+ // the ring is initialized in a disabled state. Client must not
+ // pass data to/from the backend until ring is enabled by
+ // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has
+ // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
+ let vring_enabled =
+ self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0;
+ for vring in self.vrings.iter_mut() {
+ vring.set_enabled(vring_enabled);
+ }
+
+ self.backend.acked_features(self.acked_features);
+
+ Ok(())
+ }
+
+ fn set_mem_table(
+ &mut self,
+ ctx: &[VhostUserMemoryRegion],
+ files: Vec<File>,
+ ) -> VhostUserResult<()> {
+        // We need to create guest memory regions from the list of VhostUserMemoryRegion
+        // entries that we get from the caller.
+ let mut regions = Vec::new();
+ let mut mappings: Vec<AddrMapping> = Vec::new();
+
+ for (region, file) in ctx.iter().zip(files) {
+ regions.push(
+ GuestRegionMmap::new(
+ region.mmap_region(file)?,
+ GuestAddress(region.guest_phys_addr),
+ )
+ .map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?,
+ );
+ mappings.push(AddrMapping {
+ vmm_addr: region.user_addr,
+ size: region.memory_size,
+ gpa_base: region.guest_phys_addr,
+ });
+ }
+
+ let mem = GuestMemoryMmap::from_regions(regions).map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?;
+
+ // Updating the inner GuestMemory object here will cause all our vrings to
+ // see the new one the next time they call to `atomic_mem.memory()`.
+ self.atomic_mem.lock().unwrap().replace(mem);
+
+ self.backend
+ .update_memory(self.atomic_mem.clone())
+ .map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?;
+ self.mappings = mappings;
+
+ Ok(())
+ }
+
+ fn set_vring_num(&mut self, index: u32, num: u32) -> VhostUserResult<()> {
+ if index as usize >= self.num_queues || num == 0 || num as usize > self.max_queue_size {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.vrings[index as usize].set_queue_size(num as u16);
+ Ok(())
+ }
+
+ fn set_vring_addr(
+ &mut self,
+ index: u32,
+ _flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ _log: u64,
+ ) -> VhostUserResult<()> {
+ if index as usize >= self.num_queues {
+ return Err(VhostUserError::InvalidParam);
+ }
+
+ if !self.mappings.is_empty() {
+ let desc_table = self.vmm_va_to_gpa(descriptor).map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?;
+ let avail_ring = self.vmm_va_to_gpa(available).map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?;
+ let used_ring = self.vmm_va_to_gpa(used).map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?;
+ self.vrings[index as usize]
+ .set_queue_info(desc_table, avail_ring, used_ring)
+ .map_err(|_| VhostUserError::InvalidParam)?;
+
+ // SET_VRING_BASE will only restore the 'avail' index, however, after the guest driver
+ // changes, for instance, after reboot, the 'used' index should be reset to 0.
+ //
+            // So let's fetch the 'used' index from the vring as set by the guest here, to keep
+            // compatibility with QEMU's vhost-user library, just in case any implementation
+            // expects the 'used' index to be set when receiving a SET_VRING_ADDR message.
+            //
+            // Note: I'm not sure why QEMU's vhost-user library sets the 'used' index here,
+            // _probably_ to make sure that the VQ is already configured. A better solution would
+            // be to receive the 'used' index in SET_VRING_BASE, as is done when using packed VQs.
+ let idx = self.vrings[index as usize]
+ .queue_used_idx()
+ .map_err(|_| VhostUserError::SlaveInternalError)?;
+ self.vrings[index as usize].set_queue_next_used(idx);
+
+ Ok(())
+ } else {
+ Err(VhostUserError::InvalidParam)
+ }
+ }
+
+ fn set_vring_base(&mut self, index: u32, base: u32) -> VhostUserResult<()> {
+ let event_idx: bool = (self.acked_features & (1 << VIRTIO_RING_F_EVENT_IDX)) != 0;
+
+ self.vrings[index as usize].set_queue_next_avail(base as u16);
+ self.vrings[index as usize].set_queue_event_idx(event_idx);
+ self.backend.set_event_idx(event_idx);
+
+ Ok(())
+ }
+
+ fn get_vring_base(&mut self, index: u32) -> VhostUserResult<VhostUserVringState> {
+ if index as usize >= self.num_queues {
+ return Err(VhostUserError::InvalidParam);
+ }
+
+ // Quote from vhost-user specification:
+ // Client must start ring upon receiving a kick (that is, detecting
+ // that file descriptor is readable) on the descriptor specified by
+ // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
+ // VHOST_USER_GET_VRING_BASE.
+ self.vrings[index as usize].set_queue_ready(false);
+
+ if let Some(fd) = self.vrings[index as usize].get_ref().get_kick() {
+ for (thread_index, queues_mask) in self.queues_per_thread.iter().enumerate() {
+ let shifted_queues_mask = queues_mask >> index;
+ if shifted_queues_mask & 1u64 == 1u64 {
+ let evt_idx = queues_mask.count_ones() - shifted_queues_mask.count_ones();
+ self.handlers[thread_index]
+ .unregister_event(fd.as_raw_fd(), EventSet::IN, u64::from(evt_idx))
+ .map_err(VhostUserError::ReqHandlerError)?;
+ break;
+ }
+ }
+ }
+
+ let next_avail = self.vrings[index as usize].queue_next_avail();
+
+ self.vrings[index as usize].set_kick(None);
+ self.vrings[index as usize].set_call(None);
+
+ Ok(VhostUserVringState::new(index, u32::from(next_avail)))
+ }
+
+ fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostUserResult<()> {
+ if index as usize >= self.num_queues {
+ return Err(VhostUserError::InvalidParam);
+ }
+
+ // SAFETY: EventFd requires that it has sole ownership of its fd. So
+ // does File, so this is safe.
+ // Ideally, we'd have a generic way to refer to a uniquely-owned fd,
+ // such as that proposed by Rust RFC #3128.
+ self.vrings[index as usize].set_kick(file);
+
+ if self.vring_needs_init(&self.vrings[index as usize]) {
+ self.initialize_vring(&self.vrings[index as usize], index)?;
+ }
+
+ Ok(())
+ }
+
+ fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostUserResult<()> {
+ if index as usize >= self.num_queues {
+ return Err(VhostUserError::InvalidParam);
+ }
+
+ self.vrings[index as usize].set_call(file);
+
+ if self.vring_needs_init(&self.vrings[index as usize]) {
+ self.initialize_vring(&self.vrings[index as usize], index)?;
+ }
+
+ Ok(())
+ }
+
+ fn set_vring_err(&mut self, index: u8, file: Option<File>) -> VhostUserResult<()> {
+ if index as usize >= self.num_queues {
+ return Err(VhostUserError::InvalidParam);
+ }
+
+ self.vrings[index as usize].set_err(file);
+
+ Ok(())
+ }
+
+ fn get_protocol_features(&mut self) -> VhostUserResult<VhostUserProtocolFeatures> {
+ Ok(self.backend.protocol_features())
+ }
+
+ fn set_protocol_features(&mut self, features: u64) -> VhostUserResult<()> {
+        // Note: a slave that reported VHOST_USER_F_PROTOCOL_FEATURES must
+ // support this message even before VHOST_USER_SET_FEATURES was
+ // called.
+ self.acked_protocol_features = features;
+ Ok(())
+ }
+
+ fn get_queue_num(&mut self) -> VhostUserResult<u64> {
+ Ok(self.num_queues as u64)
+ }
+
+ fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostUserResult<()> {
+ // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
+ // has been negotiated.
+ self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
+
+ if index as usize >= self.num_queues {
+ return Err(VhostUserError::InvalidParam);
+ }
+
+ // Slave must not pass data to/from the backend until ring is
+ // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
+ // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
+ // with parameter 0.
+ self.vrings[index as usize].set_enabled(enable);
+
+ Ok(())
+ }
+
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ _flags: VhostUserConfigFlags,
+ ) -> VhostUserResult<Vec<u8>> {
+ Ok(self.backend.get_config(offset, size))
+ }
+
+ fn set_config(
+ &mut self,
+ offset: u32,
+ buf: &[u8],
+ _flags: VhostUserConfigFlags,
+ ) -> VhostUserResult<()> {
+ self.backend
+ .set_config(offset, buf)
+ .map_err(VhostUserError::ReqHandlerError)
+ }
+
+ fn set_slave_req_fd(&mut self, slave: Slave) {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() != 0 {
+ slave.set_reply_ack_flag(true);
+ }
+
+ self.backend.set_slave_req_fd(slave);
+ }
+
+ fn get_inflight_fd(
+ &mut self,
+ _inflight: &vhost::vhost_user::message::VhostUserInflight,
+ ) -> VhostUserResult<(vhost::vhost_user::message::VhostUserInflight, File)> {
+ // Assume the backend hasn't negotiated the inflight feature; it
+ // wouldn't be correct for the backend to do so, as we don't (yet)
+ // provide a way for it to handle such requests.
+ Err(VhostUserError::InvalidOperation("not supported"))
+ }
+
+ fn set_inflight_fd(
+ &mut self,
+ _inflight: &vhost::vhost_user::message::VhostUserInflight,
+ _file: File,
+ ) -> VhostUserResult<()> {
+ Err(VhostUserError::InvalidOperation("not supported"))
+ }
+
+ fn get_max_mem_slots(&mut self) -> VhostUserResult<u64> {
+ Ok(MAX_MEM_SLOTS)
+ }
+
+ fn add_mem_region(
+ &mut self,
+ region: &VhostUserSingleMemoryRegion,
+ file: File,
+ ) -> VhostUserResult<()> {
+ let guest_region = Arc::new(
+ GuestRegionMmap::new(
+ region.mmap_region(file)?,
+ GuestAddress(region.guest_phys_addr),
+ )
+ .map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?,
+ );
+
+ let mem = self
+ .atomic_mem
+ .memory()
+ .insert_region(guest_region)
+ .map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?;
+
+ self.atomic_mem.lock().unwrap().replace(mem);
+
+ self.backend
+ .update_memory(self.atomic_mem.clone())
+ .map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?;
+
+ self.mappings.push(AddrMapping {
+ vmm_addr: region.user_addr,
+ size: region.memory_size,
+ gpa_base: region.guest_phys_addr,
+ });
+
+ Ok(())
+ }
+
+ fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> VhostUserResult<()> {
+ let (mem, _) = self
+ .atomic_mem
+ .memory()
+ .remove_region(GuestAddress(region.guest_phys_addr), region.memory_size)
+ .map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?;
+
+ self.atomic_mem.lock().unwrap().replace(mem);
+
+ self.backend
+ .update_memory(self.atomic_mem.clone())
+ .map_err(|e| {
+ VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e))
+ })?;
+
+ self.mappings
+ .retain(|mapping| mapping.gpa_base != region.guest_phys_addr);
+
+ Ok(())
+ }
+}
+
+impl<S, V, B: Bitmap> Drop for VhostUserHandler<S, V, B> {
+ fn drop(&mut self) {
+ // Signal all working threads to exit.
+ self.send_exit_event();
+
+ for thread in self.worker_threads.drain(..) {
+ if let Err(e) = thread.join() {
+ error!("Error in vring worker: {:?}", e);
+ }
+ }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..c65a19e
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,270 @@
+// Copyright 2019 Intel Corporation. All Rights Reserved.
+// Copyright 2019-2021 Alibaba Cloud Computing. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+//! A simple framework to run a vhost-user backend service.
+
+#[macro_use]
+extern crate log;
+
+use std::fmt::{Display, Formatter};
+use std::sync::{Arc, Mutex};
+use std::thread;
+
+use vhost::vhost_user::{Error as VhostUserError, Listener, SlaveListener, SlaveReqHandler};
+use vm_memory::bitmap::Bitmap;
+use vm_memory::mmap::NewBitmap;
+use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap};
+
+use self::handler::VhostUserHandler;
+
+mod backend;
+pub use self::backend::{VhostUserBackend, VhostUserBackendMut};
+
+mod event_loop;
+pub use self::event_loop::VringEpollHandler;
+
+mod handler;
+pub use self::handler::VhostUserHandlerError;
+
+mod vring;
+pub use self::vring::{
+ VringMutex, VringRwLock, VringState, VringStateGuard, VringStateMutGuard, VringT,
+};
+
+/// An alias for `GuestMemoryAtomic<GuestMemoryMmap<B>>` to simplify code.
+type GM<B> = GuestMemoryAtomic<GuestMemoryMmap<B>>;
+
+#[derive(Debug)]
+/// Errors related to vhost-user daemon.
+pub enum Error {
+ /// Failed to create a new vhost-user handler.
+ NewVhostUserHandler(VhostUserHandlerError),
+ /// Failed creating vhost-user slave listener.
+ CreateSlaveListener(VhostUserError),
+ /// Failed creating vhost-user slave handler.
+ CreateSlaveReqHandler(VhostUserError),
+ /// Failed starting daemon thread.
+ StartDaemon(std::io::Error),
+ /// Failed waiting for daemon thread.
+ WaitDaemon(std::boxed::Box<dyn std::any::Any + std::marker::Send>),
+ /// Failed handling a vhost-user request.
+ HandleRequest(VhostUserError),
+}
+
+impl Display for Error {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+ match self {
+ Error::NewVhostUserHandler(e) => write!(f, "cannot create vhost user handler: {}", e),
+ Error::CreateSlaveListener(e) => write!(f, "cannot create slave listener: {}", e),
+ Error::CreateSlaveReqHandler(e) => write!(f, "cannot create slave req handler: {}", e),
+ Error::StartDaemon(e) => write!(f, "failed to start daemon: {}", e),
+ Error::WaitDaemon(_e) => write!(f, "failed to wait for daemon exit"),
+ Error::HandleRequest(e) => write!(f, "failed to handle request: {}", e),
+ }
+ }
+}
+
+/// Result of vhost-user daemon operations.
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Implement a simple framework to run a vhost-user service daemon.
+///
+/// This structure is the public API the backend is allowed to interact with in order to run
+/// a fully functional vhost-user daemon.
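+///
+/// A minimal usage sketch (illustrative only; `MyBackend` stands for any implementor
+/// of `VhostUserBackend`, and the socket path is arbitrary):
+/// ```ignore
+/// let mem = GuestMemoryAtomic::new(
+///     GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(),
+/// );
+/// let backend = Arc::new(Mutex::new(MyBackend::new()));
+/// let mut daemon = VhostUserDaemon::new("my-device".to_string(), backend, mem).unwrap();
+/// // Serve a single connection on a vhost-user UNIX socket created by this process.
+/// let listener = Listener::new("/tmp/my-device.sock", true).unwrap();
+/// daemon.start(listener).unwrap();
+/// daemon.wait().unwrap();
+/// ```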
+pub struct VhostUserDaemon<S, V, B: Bitmap + 'static = ()> {
+ name: String,
+ handler: Arc<Mutex<VhostUserHandler<S, V, B>>>,
+ main_thread: Option<thread::JoinHandle<Result<()>>>,
+}
+
+impl<S, V, B> VhostUserDaemon<S, V, B>
+where
+ S: VhostUserBackend<V, B> + Clone + 'static,
+ V: VringT<GM<B>> + Clone + Send + Sync + 'static,
+ B: NewBitmap + Clone + Send + Sync,
+{
+ /// Create the daemon instance, providing the backend implementation of `VhostUserBackend`.
+ ///
+    /// Under the hood, this will start dedicated worker threads responsible for listening to
+    /// registered events. Those events can be vring events or custom events from the backend,
+    /// but they are registered later in the sequence.
+ pub fn new(
+ name: String,
+ backend: S,
+ atomic_mem: GuestMemoryAtomic<GuestMemoryMmap<B>>,
+ ) -> Result<Self> {
+ let handler = Arc::new(Mutex::new(
+ VhostUserHandler::new(backend, atomic_mem).map_err(Error::NewVhostUserHandler)?,
+ ));
+
+ Ok(VhostUserDaemon {
+ name,
+ handler,
+ main_thread: None,
+ })
+ }
+
+ /// Run a dedicated thread handling all requests coming through the socket.
+    /// This runs in an infinite loop that should terminate once the other
+    /// end of the socket (the VMM) hangs up.
+    ///
+    /// This function is the common code for starting a new daemon, whether it
+    /// acts as a client or a server.
+ fn start_daemon(
+ &mut self,
+ mut handler: SlaveReqHandler<Mutex<VhostUserHandler<S, V, B>>>,
+ ) -> Result<()> {
+ let handle = thread::Builder::new()
+ .name(self.name.clone())
+ .spawn(move || loop {
+ handler.handle_request().map_err(Error::HandleRequest)?;
+ })
+ .map_err(Error::StartDaemon)?;
+
+ self.main_thread = Some(handle);
+
+ Ok(())
+ }
+
+ /// Connect to the vhost-user socket and run a dedicated thread handling
+ /// all requests coming through this socket. This runs in an infinite loop
+    /// that should terminate once the other end of the socket (the VMM)
+ /// hangs up.
+ pub fn start_client(&mut self, socket_path: &str) -> Result<()> {
+ let slave_handler = SlaveReqHandler::connect(socket_path, self.handler.clone())
+ .map_err(Error::CreateSlaveReqHandler)?;
+ self.start_daemon(slave_handler)
+ }
+
+ /// Listen to the vhost-user socket and run a dedicated thread handling all requests coming
+ /// through this socket.
+ ///
+    /// This runs in an infinite loop that should terminate once the other end of the socket
+ /// (the VMM) disconnects.
+    // TODO: the current implementation has a limitation: only one incoming connection will be
+ // handled from the listener. Should it be enhanced to support reconnection?
+ pub fn start(&mut self, listener: Listener) -> Result<()> {
+ let mut slave_listener = SlaveListener::new(listener, self.handler.clone())
+ .map_err(Error::CreateSlaveListener)?;
+ let slave_handler = self.accept(&mut slave_listener)?;
+ self.start_daemon(slave_handler)
+ }
+
+ fn accept(
+ &self,
+ slave_listener: &mut SlaveListener<Mutex<VhostUserHandler<S, V, B>>>,
+ ) -> Result<SlaveReqHandler<Mutex<VhostUserHandler<S, V, B>>>> {
+ loop {
+ match slave_listener.accept() {
+ Err(e) => return Err(Error::CreateSlaveListener(e)),
+ Ok(Some(v)) => return Ok(v),
+ Ok(None) => continue,
+ }
+ }
+ }
+
+ /// Wait for the thread handling the vhost-user socket connection to terminate.
+ pub fn wait(&mut self) -> Result<()> {
+ if let Some(handle) = self.main_thread.take() {
+ match handle.join().map_err(Error::WaitDaemon)? {
+ Ok(()) => Ok(()),
+ Err(Error::HandleRequest(VhostUserError::SocketBroken(_))) => Ok(()),
+ Err(e) => Err(e),
+ }
+ } else {
+ Ok(())
+ }
+ }
+
+ /// Retrieve the vring epoll handler.
+ ///
+ /// This is necessary to perform further actions like registering and unregistering some extra
+ /// event file descriptors.
+ pub fn get_epoll_handlers(&self) -> Vec<Arc<VringEpollHandler<S, V, B>>> {
+ // Do not expect poisoned lock.
+ self.handler.lock().unwrap().get_epoll_handlers()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::backend::tests::MockVhostBackend;
+ use super::*;
+ use std::os::unix::net::{UnixListener, UnixStream};
+ use std::sync::Barrier;
+ use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap};
+
+ #[test]
+ fn test_new_daemon() {
+ let mem = GuestMemoryAtomic::new(
+ GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
+ );
+ let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
+ let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap();
+
+ let handlers = daemon.get_epoll_handlers();
+ assert_eq!(handlers.len(), 2);
+
+ let barrier = Arc::new(Barrier::new(2));
+ let tmpdir = tempfile::tempdir().unwrap();
+ let mut path = tmpdir.path().to_path_buf();
+ path.push("socket");
+
+ let barrier2 = barrier.clone();
+ let path1 = path.clone();
+ let thread = thread::spawn(move || {
+ barrier2.wait();
+ let socket = UnixStream::connect(&path1).unwrap();
+ barrier2.wait();
+ drop(socket)
+ });
+
+ let listener = Listener::new(&path, false).unwrap();
+ barrier.wait();
+ daemon.start(listener).unwrap();
+ barrier.wait();
+ // Above process generates a `HandleRequest(PartialMessage)` error.
+ daemon.wait().unwrap_err();
+ daemon.wait().unwrap();
+ thread.join().unwrap();
+ }
+
+ #[test]
+ fn test_new_daemon_client() {
+ let mem = GuestMemoryAtomic::new(
+ GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
+ );
+ let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
+ let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap();
+
+ let handlers = daemon.get_epoll_handlers();
+ assert_eq!(handlers.len(), 2);
+
+ let barrier = Arc::new(Barrier::new(2));
+ let tmpdir = tempfile::tempdir().unwrap();
+ let mut path = tmpdir.path().to_path_buf();
+ path.push("socket");
+
+ let barrier2 = barrier.clone();
+ let path1 = path.clone();
+ let thread = thread::spawn(move || {
+ let listener = UnixListener::bind(&path1).unwrap();
+ barrier2.wait();
+ let (stream, _) = listener.accept().unwrap();
+ barrier2.wait();
+ drop(stream)
+ });
+
+ barrier.wait();
+ daemon
+ .start_client(path.as_path().to_str().unwrap())
+ .unwrap();
+ barrier.wait();
+ // Above process generates a `HandleRequest(PartialMessage)` error.
+ daemon.wait().unwrap_err();
+ daemon.wait().unwrap();
+ thread.join().unwrap();
+ }
+}
diff --git a/src/net.rs b/src/net.rs
new file mode 100644
index 0000000..06a5e47
--- /dev/null
+++ b/src/net.rs
@@ -0,0 +1,19 @@
+// Copyright (C) 2021 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+
+//! Trait to control vhost-net backend drivers.
+
+use std::fs::File;
+
+use crate::backend::VhostBackend;
+use crate::Result;
+
+/// Trait to control vhost-net backend drivers.
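+///
+/// A hedged usage sketch (illustrative; assumes a concrete implementor such as the
+/// `vhost_kern::net` device, with `tap_file` being an already-opened TAP `File`):
+/// ```ignore
+/// net.set_backend(0, Some(&tap_file))?; // attach the tap fd to virtqueue 0
+/// net.set_backend(0, None)?;            // detach it again
+/// ```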
+pub trait VhostNet: VhostBackend {
+ /// Set fd as VHOST_NET backend.
+ ///
+ /// # Arguments
+    /// * `queue_idx` - Index of the virtqueue
+    /// * `fd` - The file descriptor which serves as the backend
+ fn set_backend(&self, queue_idx: usize, fd: Option<&File>) -> Result<()>;
+}
diff --git a/src/vdpa.rs b/src/vdpa.rs
new file mode 100644
index 0000000..6af01cf
--- /dev/null
+++ b/src/vdpa.rs
@@ -0,0 +1,126 @@
+// Copyright (C) 2021 Red Hat, Inc. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+
+//! Trait to control vhost-vdpa backend drivers.
+
+use vmm_sys_util::eventfd::EventFd;
+
+use crate::backend::VhostBackend;
+use crate::Result;
+
+/// vhost vdpa IOVA range
+pub struct VhostVdpaIovaRange {
+ /// First address that can be mapped by vhost-vDPA.
+ pub first: u64,
+ /// Last address that can be mapped by vhost-vDPA.
+ pub last: u64,
+}
+
+/// Trait to control vhost-vdpa backend drivers.
+///
+/// vDPA (virtio Data Path Acceleration) devices have a datapath compliant with the
+/// virtio specification, while the control path is vendor specific.
+/// vDPA devices can be either physically located on the hardware or emulated
+/// in software.
+///
+/// Compared to vhost acceleration, vDPA offers more control over the device
+/// lifecycle.
+/// For this reason, the vhost-vdpa interface extends the vhost API, offering
+/// additional APIs for controlling the device (e.g. changing the state or
+/// accessing the configuration space).
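+///
+/// A hedged usage sketch (illustrative; `vdpa` stands for any implementor, e.g. the
+/// device in `vhost_kern::vdpa` behind the `vhost-vdpa` feature; the status value and
+/// DMA buffer come from the caller):
+/// ```ignore
+/// vdpa.set_owner()?;                      // claim the device (from `VhostBackend`)
+/// let device_id = vdpa.get_device_id()?;  // virtio device type id
+/// let range = vdpa.get_iova_range()?;     // addresses usable with `dma_map()`
+/// vdpa.dma_map(range.first, 0x1000, buffer.as_ptr(), false)?;
+/// vdpa.set_status(status)?;               // drive the virtio status handshake
+/// ```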
+pub trait VhostVdpa: VhostBackend {
+ /// Get the device id.
+ /// The device ids follow the same definition of the device id defined in virtio-spec.
+ fn get_device_id(&self) -> Result<u32>;
+
+ /// Get the status.
+ /// The status bits follow the same definition of the device status defined in virtio-spec.
+ fn get_status(&self) -> Result<u8>;
+
+ /// Set the status.
+ /// The status bits follow the same definition of the device status defined in virtio-spec.
+ ///
+ /// # Arguments
+ /// * `status` - Status bits to set
+ fn set_status(&self, status: u8) -> Result<()>;
+
+ /// Get the device configuration.
+ ///
+ /// # Arguments
+ /// * `offset` - Offset in the device configuration space
+ /// * `buffer` - Buffer for configuration data
+ fn get_config(&self, offset: u32, buffer: &mut [u8]) -> Result<()>;
+
+ /// Set the device configuration.
+ ///
+ /// # Arguments
+ /// * `offset` - Offset in the device configuration space
+ /// * `buffer` - Buffer for configuration data
+ fn set_config(&self, offset: u32, buffer: &[u8]) -> Result<()>;
+
+ /// Set the status for a given vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to enable/disable.
+ /// * `enabled` - true to enable the vring, false to disable it.
+ fn set_vring_enable(&self, queue_index: usize, enabled: bool) -> Result<()>;
+
+ /// Get the maximum number of descriptors in the vring supported by the device.
+ fn get_vring_num(&self) -> Result<u16>;
+
+    /// Set the eventfd to trigger when the device configuration changes.
+ ///
+ /// # Arguments
+ /// * `fd` - EventFd to trigger.
+ fn set_config_call(&self, fd: &EventFd) -> Result<()>;
+
+ /// Get the valid I/O virtual addresses range supported by the device.
+ fn get_iova_range(&self) -> Result<VhostVdpaIovaRange>;
+
+ /// Get the config size
+ fn get_config_size(&self) -> Result<u32>;
+
+ /// Get the count of all virtqueues
+ fn get_vqs_count(&self) -> Result<u32>;
+
+ /// Get the number of virtqueue groups
+ fn get_group_num(&self) -> Result<u32>;
+
+ /// Get the number of address spaces
+ fn get_as_num(&self) -> Result<u32>;
+
+ /// Get the group for a virtqueue.
+ /// The virtqueue index is stored in the index field of
+ /// vhost_vring_state. The group for this specific virtqueue is
+ /// returned via num field of vhost_vring_state.
+ fn get_vring_group(&self, queue_index: u32) -> Result<u32>;
+
+ /// Set the ASID for a virtqueue group. The group index is stored in
+ /// the index field of vhost_vring_state, the ASID associated with this
+ /// group is stored at num field of vhost_vring_state.
+ fn set_group_asid(&self, group_index: u32, asid: u32) -> Result<()>;
+
+    /// Suspend a device so that it does not process virtqueue requests anymore.
+    ///
+    /// After the ioctl returns, the device must preserve all the necessary state
+    /// (the virtqueue vring base plus the possible device-specific states) that is
+    /// required for restoring it in the future. The device must not change its
+    /// configuration after that point.
+ fn suspend(&self) -> Result<()>;
+
+ /// Map DMA region.
+ ///
+ /// # Arguments
+ /// * `iova` - I/O virtual address.
+ /// * `size` - Size of the I/O mapping.
+ /// * `vaddr` - Virtual address in the current process.
+ /// * `readonly` - true if the region is read-only, false if reads and writes are allowed.
+ fn dma_map(&self, iova: u64, size: u64, vaddr: *const u8, readonly: bool) -> Result<()>;
+
+ /// Unmap DMA region.
+ ///
+ /// # Arguments
+ /// * `iova` - I/O virtual address.
+ /// * `size` - Size of the I/O mapping.
+ fn dma_unmap(&self, iova: u64, size: u64) -> Result<()>;
+}
diff --git a/src/vhost_kern/mod.rs b/src/vhost_kern/mod.rs
new file mode 100644
index 0000000..1fa5000
--- /dev/null
+++ b/src/vhost_kern/mod.rs
@@ -0,0 +1,467 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+//
+// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-Google file.
+
+//! Traits and structs to control Linux in-kernel vhost drivers.
+//!
+//! The initial vhost implementation is part of the Linux kernel and uses an ioctl interface to
+//! communicate with userspace applications. This submodule provides ioctl-based interfaces to
+//! control the in-kernel net, scsi and vsock vhost drivers.
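+//!
+//! A hedged usage sketch (illustrative; `dev` is any `VhostKernBackend` implementor,
+//! e.g. the vsock device behind the `vhost-vsock` feature, and `wanted` is a
+//! caller-chosen feature mask):
+//! ```ignore
+//! dev.set_owner()?;                       // must be the first vhost ioctl
+//! let supported = dev.get_features()?;
+//! dev.set_features(supported & wanted)?;  // ack a subset of what the kernel offers
+//! ```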
+
+use std::mem;
+use std::os::unix::io::{AsRawFd, RawFd};
+
+use libc::{c_void, ssize_t, write};
+
+use vm_memory::{Address, GuestAddress, GuestAddressSpace, GuestMemory, GuestUsize};
+use vmm_sys_util::eventfd::EventFd;
+use vmm_sys_util::ioctl::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref};
+
+use super::{
+ Error, Result, VhostBackend, VhostIotlbBackend, VhostIotlbMsg, VhostIotlbMsgParser,
+ VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData, VHOST_MAX_MEMORY_REGIONS,
+};
+
+pub mod vhost_binding;
+use self::vhost_binding::*;
+
+#[cfg(feature = "vhost-net")]
+pub mod net;
+#[cfg(feature = "vhost-vdpa")]
+pub mod vdpa;
+#[cfg(feature = "vhost-vsock")]
+pub mod vsock;
+
+#[inline]
+fn ioctl_result<T>(rc: i32, res: T) -> Result<T> {
+ if rc < 0 {
+ Err(Error::IoctlError(std::io::Error::last_os_error()))
+ } else {
+ Ok(res)
+ }
+}
+
+#[inline]
+fn io_result<T>(rc: isize, res: T) -> Result<T> {
+ if rc < 0 {
+ Err(Error::IOError(std::io::Error::last_os_error()))
+ } else {
+ Ok(res)
+ }
+}
+
+/// Represent an in-kernel vhost device backend.
+pub trait VhostKernBackend: AsRawFd {
+ /// Associated type to access guest memory.
+ type AS: GuestAddressSpace;
+
+ /// Get the object to access the guest's memory.
+ fn mem(&self) -> &Self::AS;
+
+ /// Check whether the ring configuration is valid.
+ fn is_valid(&self, config_data: &VringConfigData) -> bool {
+ let queue_size = config_data.queue_size;
+ if queue_size > config_data.queue_max_size
+ || queue_size == 0
+ || (queue_size & (queue_size - 1)) != 0
+ {
+ return false;
+ }
+
+ let m = self.mem().memory();
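+ // Per the virtio 1.x split-ring layout: the descriptor table needs 16 bytes per
+ // descriptor, the available ring needs 6 bytes (flags, idx and event word) plus
+ // 2 bytes per entry, and the used ring needs 6 bytes plus 8 bytes per entry.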
+ let desc_table_size = 16 * u64::from(queue_size) as GuestUsize;
+ let avail_ring_size = 6 + 2 * u64::from(queue_size) as GuestUsize;
+ let used_ring_size = 6 + 8 * u64::from(queue_size) as GuestUsize;
+ if GuestAddress(config_data.desc_table_addr)
+ .checked_add(desc_table_size)
+ .map_or(true, |v| !m.address_in_range(v))
+ {
+ return false;
+ }
+ if GuestAddress(config_data.avail_ring_addr)
+ .checked_add(avail_ring_size)
+ .map_or(true, |v| !m.address_in_range(v))
+ {
+ return false;
+ }
+ if GuestAddress(config_data.used_ring_addr)
+ .checked_add(used_ring_size)
+ .map_or(true, |v| !m.address_in_range(v))
+ {
+ return false;
+ }
+
+ config_data.is_log_addr_valid()
+ }
+}
+
+impl<T: VhostKernBackend> VhostBackend for T {
+ /// Get a bitmask of supported virtio/vhost features.
+ fn get_features(&self) -> Result<u64> {
+ let mut avail_features: u64 = 0;
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_mut_ref(self, VHOST_GET_FEATURES(), &mut avail_features) };
+ ioctl_result(ret, avail_features)
+ }
+
+ /// Inform the vhost subsystem which features to enable. This should be a subset of
+ /// supported features from VHOST_GET_FEATURES.
+ ///
+ /// # Arguments
+ /// * `features` - Bitmask of features to set.
+ fn set_features(&self, features: u64) -> Result<()> {
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_FEATURES(), &features) };
+ ioctl_result(ret, ())
+ }
+
+ /// Set the current process as the owner of this file descriptor.
+ /// This must be run before any other vhost ioctls.
+ fn set_owner(&self) -> Result<()> {
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) };
+ ioctl_result(ret, ())
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) };
+ ioctl_result(ret, ())
+ }
+
+ /// Set the guest memory mappings for vhost to use.
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ if regions.is_empty() || regions.len() > VHOST_MAX_MEMORY_REGIONS {
+ return Err(Error::InvalidGuestMemory);
+ }
+
+ let mut vhost_memory = VhostMemory::new(regions.len() as u16);
+ for (index, region) in regions.iter().enumerate() {
+ vhost_memory.set_region(
+ index as u32,
+ &vhost_memory_region {
+ guest_phys_addr: region.guest_phys_addr,
+ memory_size: region.memory_size,
+ userspace_addr: region.userspace_addr,
+ flags_padding: 0u64,
+ },
+ )?;
+ }
+
+ // SAFETY: This ioctl is called with a pointer that is valid for the lifetime
+ // of this function. The kernel will make its own copy of the memory
+ // tables. As always, check the return value.
+ let ret = unsafe { ioctl_with_ptr(self, VHOST_SET_MEM_TABLE(), vhost_memory.as_ptr()) };
+ ioctl_result(ret, ())
+ }
+
+ /// Set base address for page modification logging.
+ ///
+ /// # Arguments
+ /// * `base` - Base address for page modification logging.
+ /// * `region` - Optional dirty log region; the in-kernel backend does not support it,
+ /// so this must be `None`.
+ fn set_log_base(&self, base: u64, region: Option<VhostUserDirtyLogRegion>) -> Result<()> {
+ if region.is_some() {
+ return Err(Error::LogAddress);
+ }
+
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_BASE(), &base) };
+ ioctl_result(ret, ())
+ }
+
+ /// Specify an eventfd file descriptor to signal on log write.
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+ let val: i32 = fd;
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_FD(), &val) };
+ ioctl_result(ret, ())
+ }
+
+ /// Set the number of descriptors in the vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to set descriptor count for.
+ /// * `num` - Number of descriptors in the queue.
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+ let vring_state = vhost_vring_state {
+ index: queue_index as u32,
+ num: u32::from(num),
+ };
+
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_NUM(), &vring_state) };
+ ioctl_result(ret, ())
+ }
+
+ /// Set the addresses for a given vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to set addresses for.
+ /// * `config_data` - Vring config data, addresses of desc_table, avail_ring
+ /// and used_ring are in the guest address space.
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ if !self.is_valid(config_data) {
+ return Err(Error::InvalidQueue);
+ }
+
+ // The addresses are converted into the host address space.
+ let vring_addr = config_data.to_vhost_vring_addr(queue_index, self.mem())?;
+
+ // SAFETY: This ioctl is called on a valid vhost fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ADDR(), &vring_addr) };
+ ioctl_result(ret, ())
+ }
+
+ /// Set the first index to look for available descriptors.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `base` - Index where available descriptors start.
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+ let vring_state = vhost_vring_state {
+ index: queue_index as u32,
+ num: u32::from(base),
+ };
+
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_BASE(), &vring_state) };
+ ioctl_result(ret, ())
+ }
+
+ /// Get the available vring base offset for a given queue.
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+ let vring_state = vhost_vring_state {
+ index: queue_index as u32,
+ num: 0,
+ };
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_GET_VRING_BASE(), &vring_state) };
+ ioctl_result(ret, vring_state.num)
+ }
+
+ /// Set the eventfd to trigger when buffers have been used by the host.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd to trigger.
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let vring_file = vhost_vring_file {
+ index: queue_index as u32,
+ fd: fd.as_raw_fd(),
+ };
+
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_CALL(), &vring_file) };
+ ioctl_result(ret, ())
+ }
+
+ /// Set the eventfd that will be signaled by the guest when buffers are
+ /// available for the host to process.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd that will be signaled from guest.
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let vring_file = vhost_vring_file {
+ index: queue_index as u32,
+ fd: fd.as_raw_fd(),
+ };
+
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_KICK(), &vring_file) };
+ ioctl_result(ret, ())
+ }
+
+ /// Set the eventfd to signal an error from the vhost backend.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd that will be signaled from the backend.
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let vring_file = vhost_vring_file {
+ index: queue_index as u32,
+ fd: fd.as_raw_fd(),
+ };
+
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ERR(), &vring_file) };
+ ioctl_result(ret, ())
+ }
+}
+
+/// Interface to handle in-kernel backend features.
+pub trait VhostKernFeatures: Sized + AsRawFd {
+ /// Get features acked with the vhost backend.
+ fn get_backend_features_acked(&self) -> u64;
+
+ /// Set features acked with the vhost backend.
+ fn set_backend_features_acked(&mut self, features: u64);
+
+ /// Get a bitmask of supported vhost backend features.
+ fn get_backend_features(&self) -> Result<u64> {
+ let mut avail_features: u64 = 0;
+
+ let ret =
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ unsafe { ioctl_with_mut_ref(self, VHOST_GET_BACKEND_FEATURES(), &mut avail_features) };
+ ioctl_result(ret, avail_features)
+ }
+
+ /// Inform the vhost subsystem which backend features to enable. This should
+ /// be a subset of supported features from VHOST_GET_BACKEND_FEATURES.
+ ///
+ /// # Arguments
+ /// * `features` - Bitmask of features to set.
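+ ///
+ /// A typical negotiation sketch (illustrative only; `dev` stands for any type
+ /// implementing this trait):
+ ///
+ /// ```ignore
+ /// let supported = dev.get_backend_features().unwrap();
+ /// dev.set_backend_features(supported & (1 << VHOST_BACKEND_F_IOTLB_MSG_V2)).unwrap();
+ /// ```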
+ fn set_backend_features(&mut self, features: u64) -> Result<()> {
+ // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_BACKEND_FEATURES(), &features) };
+
+ if ret >= 0 {
+ self.set_backend_features_acked(features);
+ }
+
+ ioctl_result(ret, ())
+ }
+}
+
+/// Handle IOTLB messages for an in-kernel vhost device backend.
+impl<I: VhostKernBackend + VhostKernFeatures> VhostIotlbBackend for I {
+ /// Send an IOTLB message to the in-kernel vhost backend.
+ ///
+ /// # Arguments
+ /// * `msg` - IOTLB message to send.
+ fn send_iotlb_msg(&self, msg: &VhostIotlbMsg) -> Result<()> {
+ let ret: ssize_t;
+
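+ // Use the vhost_msg_v2 layout only if the backend acknowledged
+ // VHOST_BACKEND_F_IOTLB_MSG_V2; otherwise fall back to the legacy vhost_msg layout.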
+ if self.get_backend_features_acked() & (1 << VHOST_BACKEND_F_IOTLB_MSG_V2) != 0 {
+ let mut msg_v2 = vhost_msg_v2 {
+ type_: VHOST_IOTLB_MSG_V2,
+ ..Default::default()
+ };
+
+ msg_v2.__bindgen_anon_1.iotlb.iova = msg.iova;
+ msg_v2.__bindgen_anon_1.iotlb.size = msg.size;
+ msg_v2.__bindgen_anon_1.iotlb.uaddr = msg.userspace_addr;
+ msg_v2.__bindgen_anon_1.iotlb.perm = msg.perm as u8;
+ msg_v2.__bindgen_anon_1.iotlb.type_ = msg.msg_type as u8;
+
+ // SAFETY: This is safe because we are using a valid vhost fd, and
+ // a valid pointer and size to the vhost_msg_v2 structure.
+ ret = unsafe {
+ write(
+ self.as_raw_fd(),
+ &msg_v2 as *const vhost_msg_v2 as *const c_void,
+ mem::size_of::<vhost_msg_v2>(),
+ )
+ };
+ } else {
+ let mut msg_v1 = vhost_msg {
+ type_: VHOST_IOTLB_MSG,
+ ..Default::default()
+ };
+
+ msg_v1.__bindgen_anon_1.iotlb.iova = msg.iova;
+ msg_v1.__bindgen_anon_1.iotlb.size = msg.size;
+ msg_v1.__bindgen_anon_1.iotlb.uaddr = msg.userspace_addr;
+ msg_v1.__bindgen_anon_1.iotlb.perm = msg.perm as u8;
+ msg_v1.__bindgen_anon_1.iotlb.type_ = msg.msg_type as u8;
+
+ // SAFETY: This is safe because we are using a valid vhost fd, and
+ // a valid pointer and size to the vhost_msg structure.
+ ret = unsafe {
+ write(
+ self.as_raw_fd(),
+ &msg_v1 as *const vhost_msg as *const c_void,
+ mem::size_of::<vhost_msg>(),
+ )
+ };
+ }
+
+ io_result(ret, ())
+ }
+}
+
+impl VhostIotlbMsgParser for vhost_msg {
+ fn parse(&self, msg: &mut VhostIotlbMsg) -> Result<()> {
+ if self.type_ != VHOST_IOTLB_MSG {
+ return Err(Error::InvalidIotlbMsg);
+ }
+
+ // SAFETY: We trust the kernel to return a structure with the union
+ // fields properly initialized. We are sure it is a vhost_msg, because
+ // we checked that `self.type_` is VHOST_IOTLB_MSG.
+ unsafe {
+ if self.__bindgen_anon_1.iotlb.type_ == 0 {
+ return Err(Error::InvalidIotlbMsg);
+ }
+
+ msg.iova = self.__bindgen_anon_1.iotlb.iova;
+ msg.size = self.__bindgen_anon_1.iotlb.size;
+ msg.userspace_addr = self.__bindgen_anon_1.iotlb.uaddr;
+ msg.perm = mem::transmute(self.__bindgen_anon_1.iotlb.perm);
+ msg.msg_type = mem::transmute(self.__bindgen_anon_1.iotlb.type_);
+ }
+
+ Ok(())
+ }
+}
+
+impl VhostIotlbMsgParser for vhost_msg_v2 {
+ fn parse(&self, msg: &mut VhostIotlbMsg) -> Result<()> {
+ if self.type_ != VHOST_IOTLB_MSG_V2 {
+ return Err(Error::InvalidIotlbMsg);
+ }
+
+ // SAFETY: We trust the kernel to return a structure with the union
+ // fields properly initialized. We are sure it is a vhost_msg_v2, because
+ // we checked that `self.type_` is VHOST_IOTLB_MSG_V2.
+ unsafe {
+ if self.__bindgen_anon_1.iotlb.type_ == 0 {
+ return Err(Error::InvalidIotlbMsg);
+ }
+
+ msg.iova = self.__bindgen_anon_1.iotlb.iova;
+ msg.size = self.__bindgen_anon_1.iotlb.size;
+ msg.userspace_addr = self.__bindgen_anon_1.iotlb.uaddr;
+ msg.perm = mem::transmute(self.__bindgen_anon_1.iotlb.perm);
+ msg.msg_type = mem::transmute(self.__bindgen_anon_1.iotlb.type_);
+ }
+
+ Ok(())
+ }
+}
+
+impl VringConfigData {
+ /// Convert the config (guest address space) into vhost_vring_addr
+ /// (host address space).
+ pub fn to_vhost_vring_addr<AS: GuestAddressSpace>(
+ &self,
+ queue_index: usize,
+ mem: &AS,
+ ) -> Result<vhost_vring_addr> {
+ let desc_addr = mem
+ .memory()
+ .get_host_address(GuestAddress(self.desc_table_addr))
+ .map_err(|_| Error::DescriptorTableAddress)?;
+ let avail_addr = mem
+ .memory()
+ .get_host_address(GuestAddress(self.avail_ring_addr))
+ .map_err(|_| Error::AvailAddress)?;
+ let used_addr = mem
+ .memory()
+ .get_host_address(GuestAddress(self.used_ring_addr))
+ .map_err(|_| Error::UsedAddress)?;
+ Ok(vhost_vring_addr {
+ index: queue_index as u32,
+ flags: self.flags,
+ desc_user_addr: desc_addr as u64,
+ used_user_addr: used_addr as u64,
+ avail_user_addr: avail_addr as u64,
+ log_guest_addr: self.get_log_addr(),
+ })
+ }
+}
diff --git a/src/vhost_kern/net.rs b/src/vhost_kern/net.rs
new file mode 100644
index 0000000..89f390c
--- /dev/null
+++ b/src/vhost_kern/net.rs
@@ -0,0 +1,177 @@
+// Copyright (C) 2021 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+
+//! Kernel-based vhost-net backend
+
+use std::fs::{File, OpenOptions};
+use std::os::unix::fs::OpenOptionsExt;
+use std::os::unix::io::{AsRawFd, RawFd};
+
+use vm_memory::GuestAddressSpace;
+use vmm_sys_util::ioctl::ioctl_with_ref;
+
+use super::vhost_binding::*;
+use super::{ioctl_result, Error, Result, VhostKernBackend};
+
+use crate::net::*;
+
+const VHOST_NET_PATH: &str = "/dev/vhost-net";
+
+/// Handle for running VHOST_NET ioctls.
+pub struct Net<AS: GuestAddressSpace> {
+ fd: File,
+ mem: AS,
+}
+
+impl<AS: GuestAddressSpace> Net<AS> {
+ /// Open a handle to a new VHOST-NET instance.
+ pub fn new(mem: AS) -> Result<Self> {
+ Ok(Net {
+ fd: OpenOptions::new()
+ .read(true)
+ .write(true)
+ .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK)
+ .open(VHOST_NET_PATH)
+ .map_err(Error::VhostOpen)?,
+ mem,
+ })
+ }
+}
+
+impl<AS: GuestAddressSpace> VhostNet for Net<AS> {
+ fn set_backend(&self, queue_index: usize, fd: Option<&File>) -> Result<()> {
+ let vring_file = vhost_vring_file {
+ index: queue_index as u32,
+ fd: fd.map_or(-1, |v| v.as_raw_fd()),
+ };
+
+ // SAFETY: Safe because the vhost-net fd is valid and we check the return value
+ let ret = unsafe { ioctl_with_ref(self, VHOST_NET_SET_BACKEND(), &vring_file) };
+ ioctl_result(ret, ())
+ }
+}
+
+impl<AS: GuestAddressSpace> VhostKernBackend for Net<AS> {
+ type AS = AS;
+
+ fn mem(&self) -> &Self::AS {
+ &self.mem
+ }
+}
+
+impl<AS: GuestAddressSpace> AsRawFd for Net<AS> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd.as_raw_fd()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap};
+ use vmm_sys_util::eventfd::EventFd;
+
+ use super::*;
+ use crate::{
+ VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData,
+ };
+ use serial_test::serial;
+
+ #[test]
+ #[serial]
+ fn test_net_new_device() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let net = Net::new(&m).unwrap();
+
+ assert!(net.as_raw_fd() >= 0);
+ assert!(net.mem().find_region(GuestAddress(0x100)).is_some());
+ assert!(net.mem().find_region(GuestAddress(0x10_0000)).is_none());
+ }
+
+ #[test]
+ #[serial]
+ fn test_net_is_valid() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let net = Net::new(&m).unwrap();
+
+ let mut config = VringConfigData {
+ queue_max_size: 32,
+ queue_size: 32,
+ flags: 0,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: None,
+ };
+ assert!(net.is_valid(&config));
+
+ config.queue_size = 0;
+ assert!(!net.is_valid(&config));
+ config.queue_size = 31;
+ assert!(!net.is_valid(&config));
+ config.queue_size = 33;
+ assert!(!net.is_valid(&config));
+ }
+
+ #[test]
+ #[serial]
+ fn test_net_ioctls() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let net = Net::new(&m).unwrap();
+ let backend = OpenOptions::new()
+ .read(true)
+ .write(true)
+ .open("/dev/null")
+ .unwrap();
+
+ let features = net.get_features().unwrap();
+ net.set_features(features).unwrap();
+
+ net.set_owner().unwrap();
+
+ net.set_mem_table(&[]).unwrap_err();
+
+ let region = VhostUserMemoryRegionInfo::new(
+ 0x0,
+ 0x10_0000,
+ m.get_host_address(GuestAddress(0x0)).unwrap() as u64,
+ 0,
+ -1,
+ );
+ net.set_mem_table(&[region]).unwrap();
+
+ net.set_log_base(
+ 0x4000,
+ Some(VhostUserDirtyLogRegion {
+ mmap_size: 0x1000,
+ mmap_offset: 0x10,
+ mmap_handle: 1,
+ }),
+ )
+ .unwrap_err();
+ net.set_log_base(0x4000, None).unwrap();
+
+ let eventfd = EventFd::new(0).unwrap();
+ net.set_log_fd(eventfd.as_raw_fd()).unwrap();
+
+ net.set_vring_num(0, 32).unwrap();
+
+ let config = VringConfigData {
+ queue_max_size: 32,
+ queue_size: 32,
+ flags: 0,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: None,
+ };
+ net.set_vring_addr(0, &config).unwrap();
+ net.set_vring_base(0, 1).unwrap();
+ net.set_vring_call(0, &eventfd).unwrap();
+ net.set_vring_kick(0, &eventfd).unwrap();
+ net.set_vring_err(0, &eventfd).unwrap();
+ assert_eq!(net.get_vring_base(0).unwrap(), 1);
+
+ net.set_backend(0, Some(&backend)).unwrap_err();
+ net.set_backend(0, None).unwrap();
+ }
+}
diff --git a/src/vhost_kern/vdpa.rs b/src/vhost_kern/vdpa.rs
new file mode 100644
index 0000000..65e0123
--- /dev/null
+++ b/src/vhost_kern/vdpa.rs
@@ -0,0 +1,560 @@
+// Copyright (C) 2021 Red Hat, Inc. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+
+//! Kernel-based vhost-vdpa backend.
+
+use std::fs::{File, OpenOptions};
+use std::io::Error as IOError;
+use std::os::raw::{c_uchar, c_uint};
+use std::os::unix::fs::OpenOptionsExt;
+use std::os::unix::io::{AsRawFd, RawFd};
+
+use vm_memory::GuestAddressSpace;
+use vmm_sys_util::eventfd::EventFd;
+use vmm_sys_util::fam::*;
+use vmm_sys_util::ioctl::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref};
+
+use super::vhost_binding::*;
+use super::{ioctl_result, Error, Result, VhostKernBackend, VhostKernFeatures};
+use crate::vdpa::*;
+use crate::{VhostAccess, VhostIotlbBackend, VhostIotlbMsg, VhostIotlbType, VringConfigData};
+
+// Implement the FamStruct trait for vhost_vdpa_config
+generate_fam_struct_impl!(
+ vhost_vdpa_config,
+ c_uchar,
+ buf,
+ c_uint,
+ len,
+ c_uint::MAX as usize
+);
+
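+// `VhostVdpaConfig` wraps the flexible-array-member struct `vhost_vdpa_config` so that
+// `get_config()`/`set_config()` can pass the `off`/`len` header followed by a variable-length
+// configuration buffer to the kernel in a single allocation.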
+type VhostVdpaConfig = FamStructWrapper<vhost_vdpa_config>;
+
+/// Handle for running VHOST_VDPA ioctls.
+pub struct VhostKernVdpa<AS: GuestAddressSpace> {
+ fd: File,
+ mem: AS,
+ backend_features_acked: u64,
+}
+
+impl<AS: GuestAddressSpace> VhostKernVdpa<AS> {
+ /// Open a handle to a new VHOST-VDPA instance.
+ pub fn new(path: &str, mem: AS) -> Result<Self> {
+ Ok(VhostKernVdpa {
+ fd: OpenOptions::new()
+ .read(true)
+ .write(true)
+ .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK)
+ .open(path)
+ .map_err(Error::VhostOpen)?,
+ mem,
+ backend_features_acked: 0,
+ })
+ }
+
+ /// Create a `VhostKernVdpa` object with given content.
+ pub fn with(fd: File, mem: AS, backend_features_acked: u64) -> Self {
+ VhostKernVdpa {
+ fd,
+ mem,
+ backend_features_acked,
+ }
+ }
+
+ /// Set the addresses for a given vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to set addresses for.
+ /// * `config_data` - Vring config data, addresses of desc_table, avail_ring
+ /// and used_ring are in the guest address space.
+ pub fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ if !self.is_valid(config_data) {
+ return Err(Error::InvalidQueue);
+ }
+
+ // vDPA backends expect IOVAs (which map 1:1 to GPAs when no IOMMU is involved).
+ let vring_addr = vhost_vring_addr {
+ index: queue_index as u32,
+ flags: config_data.flags,
+ desc_user_addr: config_data.desc_table_addr,
+ used_user_addr: config_data.used_ring_addr,
+ avail_user_addr: config_data.avail_ring_addr,
+ log_guest_addr: config_data.get_log_addr(),
+ };
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ADDR(), &vring_addr) };
+ ioctl_result(ret, ())
+ }
+}
+
+impl<AS: GuestAddressSpace> VhostVdpa for VhostKernVdpa<AS> {
+ fn get_device_id(&self) -> Result<u32> {
+ let mut device_id: u32 = 0;
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_DEVICE_ID(), &mut device_id) };
+ ioctl_result(ret, device_id)
+ }
+
+ fn get_status(&self) -> Result<u8> {
+ let mut status: u8 = 0;
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_STATUS(), &mut status) };
+ ioctl_result(ret, status)
+ }
+
+ fn set_status(&self, status: u8) -> Result<()> {
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_VDPA_SET_STATUS(), &status) };
+ ioctl_result(ret, ())
+ }
+
+ fn get_config(&self, offset: u32, buffer: &mut [u8]) -> Result<()> {
+ let mut config = VhostVdpaConfig::new(buffer.len())
+ .map_err(|_| Error::IoctlError(IOError::from_raw_os_error(libc::ENOMEM)))?;
+
+ config.as_mut_fam_struct().off = offset;
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe {
+ ioctl_with_ptr(
+ self,
+ VHOST_VDPA_GET_CONFIG(),
+ config.as_mut_fam_struct_ptr(),
+ )
+ };
+
+ buffer.copy_from_slice(config.as_slice());
+
+ ioctl_result(ret, ())
+ }
+
+ fn set_config(&self, offset: u32, buffer: &[u8]) -> Result<()> {
+ let mut config = VhostVdpaConfig::new(buffer.len())
+ .map_err(|_| Error::IoctlError(IOError::from_raw_os_error(libc::ENOMEM)))?;
+
+ config.as_mut_fam_struct().off = offset;
+ config.as_mut_slice().copy_from_slice(buffer);
+
+ let ret =
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ unsafe { ioctl_with_ptr(self, VHOST_VDPA_SET_CONFIG(), config.as_fam_struct_ptr()) };
+ ioctl_result(ret, ())
+ }
+
+ fn set_vring_enable(&self, queue_index: usize, enabled: bool) -> Result<()> {
+ let vring_state = vhost_vring_state {
+ index: queue_index as u32,
+ num: enabled as u32,
+ };
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_VDPA_SET_VRING_ENABLE(), &vring_state) };
+ ioctl_result(ret, ())
+ }
+
+ fn get_vring_num(&self) -> Result<u16> {
+ let mut vring_num: u16 = 0;
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_VRING_NUM(), &mut vring_num) };
+ ioctl_result(ret, vring_num)
+ }
+
+ fn set_config_call(&self, fd: &EventFd) -> Result<()> {
+ let event_fd: ::std::os::raw::c_int = fd.as_raw_fd();
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_VDPA_SET_CONFIG_CALL(), &event_fd) };
+ ioctl_result(ret, ())
+ }
+
+ fn get_iova_range(&self) -> Result<VhostVdpaIovaRange> {
+ let mut low_iova_range = vhost_vdpa_iova_range { first: 0, last: 0 };
+
+ let ret =
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_IOVA_RANGE(), &mut low_iova_range) };
+
+ let iova_range = VhostVdpaIovaRange {
+ first: low_iova_range.first,
+ last: low_iova_range.last,
+ };
+
+ ioctl_result(ret, iova_range)
+ }
+
+ fn get_config_size(&self) -> Result<u32> {
+ let mut config_size: u32 = 0;
+
+ let ret =
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_CONFIG_SIZE(), &mut config_size) };
+ ioctl_result(ret, config_size)
+ }
+
+ fn get_vqs_count(&self) -> Result<u32> {
+ let mut vqs_count: u32 = 0;
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_VQS_COUNT(), &mut vqs_count) };
+ ioctl_result(ret, vqs_count)
+ }
+
+ fn get_group_num(&self) -> Result<u32> {
+ let mut group_num: u32 = 0;
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_GROUP_NUM(), &mut group_num) };
+ ioctl_result(ret, group_num)
+ }
+
+ fn get_as_num(&self) -> Result<u32> {
+ let mut as_num: u32 = 0;
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_AS_NUM(), &mut as_num) };
+ ioctl_result(ret, as_num)
+ }
+
+ fn get_vring_group(&self, queue_index: u32) -> Result<u32> {
+ let mut vring_state = vhost_vring_state {
+ index: queue_index,
+ ..Default::default()
+ };
+
+ let ret =
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_VRING_GROUP(), &mut vring_state) };
+ ioctl_result(ret, vring_state.num)
+ }
+
+ fn set_group_asid(&self, group_index: u32, asid: u32) -> Result<()> {
+ let vring_state = vhost_vring_state {
+ index: group_index,
+ num: asid,
+ };
+
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_ref(self, VHOST_VDPA_SET_GROUP_ASID(), &vring_state) };
+ ioctl_result(ret, ())
+ }
+
+ fn suspend(&self) -> Result<()> {
+ // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl(self, VHOST_VDPA_SUSPEND()) };
+ ioctl_result(ret, ())
+ }
+
+ fn dma_map(&self, iova: u64, size: u64, vaddr: *const u8, readonly: bool) -> Result<()> {
+ let iotlb = VhostIotlbMsg {
+ iova,
+ size,
+ userspace_addr: vaddr as u64,
+ perm: match readonly {
+ true => VhostAccess::ReadOnly,
+ false => VhostAccess::ReadWrite,
+ },
+ msg_type: VhostIotlbType::Update,
+ };
+
+ self.send_iotlb_msg(&iotlb)
+ }
+
+ fn dma_unmap(&self, iova: u64, size: u64) -> Result<()> {
+ let iotlb = VhostIotlbMsg {
+ iova,
+ size,
+ msg_type: VhostIotlbType::Invalidate,
+ ..Default::default()
+ };
+
+ self.send_iotlb_msg(&iotlb)
+ }
+}
+
+impl<AS: GuestAddressSpace> VhostKernBackend for VhostKernVdpa<AS> {
+ type AS = AS;
+
+ fn mem(&self) -> &Self::AS {
+ &self.mem
+ }
+
+ /// Check whether the ring configuration is valid.
+ fn is_valid(&self, config_data: &VringConfigData) -> bool {
+ let queue_size = config_data.queue_size;
+ if queue_size > config_data.queue_max_size
+ || queue_size == 0
+ || (queue_size & (queue_size - 1)) != 0
+ {
+ return false;
+ }
+
+ // Since vDPA could be dealing with IOVAs corresponding to GVAs, it
+ // wouldn't make sense to go through the validation of the descriptor
+ // table address, available ring address and used ring address against
+ // the guest memory representation we have access to.
+
+ config_data.is_log_addr_valid()
+ }
+}
+
+impl<AS: GuestAddressSpace> AsRawFd for VhostKernVdpa<AS> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd.as_raw_fd()
+ }
+}
+
+impl<AS: GuestAddressSpace> VhostKernFeatures for VhostKernVdpa<AS> {
+ fn get_backend_features_acked(&self) -> u64 {
+ self.backend_features_acked
+ }
+
+ fn set_backend_features_acked(&mut self, features: u64) {
+ self.backend_features_acked = features;
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ const VHOST_VDPA_PATH: &str = "/dev/vhost-vdpa-0";
+
+ use std::alloc::{alloc, dealloc, Layout};
+ use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap};
+ use vmm_sys_util::eventfd::EventFd;
+
+ use super::*;
+ use crate::{
+ VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData,
+ };
+ use serial_test::serial;
+ use std::io::ErrorKind;
+
+ /// Macro to skip a test if the vhost-vdpa device path is not found.
+ ///
+ /// vDPA simulators are available since Linux 5.7, but the CI may have
+ /// an older kernel, so for now we skip the test if we don't find
+ /// the device.
+ macro_rules! unwrap_not_found {
+ ( $e:expr ) => {
+ match $e {
+ Ok(v) => v,
+ Err(error) => match error {
+ Error::VhostOpen(ref e) if e.kind() == ErrorKind::NotFound => {
+ println!("Err: {:?} SKIPPED", e);
+ return;
+ }
+ e => panic!("Err: {:?}", e),
+ },
+ }
+ };
+ }
+
+ macro_rules! validate_ioctl {
+ ( $e:expr, $ref_value:expr ) => {
+ match $e {
+ Ok(v) => assert_eq!(v, $ref_value),
+ Err(error) => match error {
+ Error::IoctlError(e) if e.raw_os_error().unwrap() == libc::ENOTTY => {
+ println!("Err: {:?} SKIPPED", e);
+ }
+ e => panic!("Err: {:?}", e),
+ },
+ }
+ };
+ }
+
+ #[test]
+ #[serial]
+ fn test_vdpa_kern_new_device() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let vdpa = unwrap_not_found!(VhostKernVdpa::new(VHOST_VDPA_PATH, &m));
+
+ assert!(vdpa.as_raw_fd() >= 0);
+ assert!(vdpa.mem().find_region(GuestAddress(0x100)).is_some());
+ assert!(vdpa.mem().find_region(GuestAddress(0x10_0000)).is_none());
+ }
+
+ #[test]
+ #[serial]
+ fn test_vdpa_kern_is_valid() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let vdpa = unwrap_not_found!(VhostKernVdpa::new(VHOST_VDPA_PATH, &m));
+
+ let mut config = VringConfigData {
+ queue_max_size: 32,
+ queue_size: 32,
+ flags: 0,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: None,
+ };
+ assert!(vdpa.is_valid(&config));
+
+ config.queue_size = 0;
+ assert!(!vdpa.is_valid(&config));
+ config.queue_size = 31;
+ assert!(!vdpa.is_valid(&config));
+ config.queue_size = 33;
+ assert!(!vdpa.is_valid(&config));
+ }
+
+ #[test]
+ #[serial]
+ fn test_vdpa_kern_ioctls() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let vdpa = unwrap_not_found!(VhostKernVdpa::new(VHOST_VDPA_PATH, &m));
+
+ let features = vdpa.get_features().unwrap();
+ // VIRTIO_F_VERSION_1 (bit 32) should be set
+ assert_ne!(features & (1 << 32), 0);
+ vdpa.set_features(features).unwrap();
+
+ vdpa.set_owner().unwrap();
+
+ vdpa.set_mem_table(&[]).unwrap_err();
+
+ let region = VhostUserMemoryRegionInfo::new(
+ 0x0,
+ 0x10_0000,
+ m.get_host_address(GuestAddress(0x0)).unwrap() as u64,
+ 0,
+ -1,
+ );
+ vdpa.set_mem_table(&[region]).unwrap();
+
+ let device_id = vdpa.get_device_id().unwrap();
+ assert!(device_id > 0);
+
+ assert_eq!(vdpa.get_status().unwrap(), 0x0);
+ vdpa.set_status(0x1).unwrap();
+ assert_eq!(vdpa.get_status().unwrap(), 0x1);
+
+ let mut vec = vec![0u8; 8];
+ vdpa.get_config(0, &mut vec).unwrap();
+ vdpa.set_config(0, &vec).unwrap();
+
+ let eventfd = EventFd::new(0).unwrap();
+
+ // set_log_base() and set_log_fd() are not supported by vhost-vdpa
+ vdpa.set_log_base(
+ 0x4000,
+ Some(VhostUserDirtyLogRegion {
+ mmap_size: 0x1000,
+ mmap_offset: 0x10,
+ mmap_handle: 1,
+ }),
+ )
+ .unwrap_err();
+ vdpa.set_log_base(0x4000, None).unwrap_err();
+ vdpa.set_log_fd(eventfd.as_raw_fd()).unwrap_err();
+
+ let max_queues = vdpa.get_vring_num().unwrap();
+ vdpa.set_vring_num(0, max_queues + 1).unwrap_err();
+
+ vdpa.set_vring_num(0, 32).unwrap();
+
+ let config = VringConfigData {
+ queue_max_size: 32,
+ queue_size: 32,
+ flags: 0,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: None,
+ };
+ vdpa.set_vring_addr(0, &config).unwrap();
+ vdpa.set_vring_base(0, 1).unwrap();
+ vdpa.set_vring_call(0, &eventfd).unwrap();
+ vdpa.set_vring_kick(0, &eventfd).unwrap();
+ vdpa.set_vring_err(0, &eventfd).unwrap();
+
+ vdpa.set_config_call(&eventfd).unwrap();
+
+ let iova_range = vdpa.get_iova_range().unwrap();
+ // vDPA-block simulator returns [0, u64::MAX] range
+ assert_eq!(iova_range.first, 0);
+ assert_eq!(iova_range.last, u64::MAX);
+
+ let (config_size, vqs_count, group_num, as_num, vring_group) = if device_id == 1 {
+ (24, 3, 2, 2, 0)
+ } else if device_id == 2 {
+ (60, 1, 1, 1, 0)
+ } else {
+ panic!("Unexpected device id {}", device_id)
+ };
+
+ validate_ioctl!(vdpa.get_config_size(), config_size);
+ validate_ioctl!(vdpa.get_vqs_count(), vqs_count);
+ validate_ioctl!(vdpa.get_group_num(), group_num);
+ validate_ioctl!(vdpa.get_as_num(), as_num);
+ validate_ioctl!(vdpa.get_vring_group(0), vring_group);
+ validate_ioctl!(vdpa.set_group_asid(0, 12345), ());
+
+ if vdpa.get_backend_features().unwrap() & (1 << VHOST_BACKEND_F_SUSPEND)
+ == (1 << VHOST_BACKEND_F_SUSPEND)
+ {
+ validate_ioctl!(vdpa.suspend(), ());
+ }
+
+ assert_eq!(vdpa.get_vring_base(0).unwrap(), 1);
+
+ vdpa.set_vring_enable(0, true).unwrap();
+ vdpa.set_vring_enable(0, false).unwrap();
+ }
+
+ #[test]
+ #[serial]
+ fn test_vdpa_kern_dma() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let mut vdpa = unwrap_not_found!(VhostKernVdpa::new(VHOST_VDPA_PATH, &m));
+
+ let features = vdpa.get_features().unwrap();
+ // VIRTIO_F_VERSION_1 (bit 32) should be set
+ assert_ne!(features & (1 << 32), 0);
+ vdpa.set_features(features).unwrap();
+
+ let backend_features = vdpa.get_backend_features().unwrap();
+ assert_ne!(backend_features & (1 << VHOST_BACKEND_F_IOTLB_MSG_V2), 0);
+ vdpa.set_backend_features(backend_features).unwrap();
+
+ vdpa.set_owner().unwrap();
+
+ vdpa.dma_map(0xFFFF_0000, 0xFFFF, std::ptr::null::<u8>(), false)
+ .unwrap_err();
+
+ let layout = Layout::from_size_align(0xFFFF, 1).unwrap();
+
+ // SAFETY: Safe because layout has non-zero size.
+ let ptr = unsafe { alloc(layout) };
+
+ vdpa.dma_map(0xFFFF_0000, 0xFFFF, ptr, false).unwrap();
+ vdpa.dma_unmap(0xFFFF_0000, 0xFFFF).unwrap();
+
+ // SAFETY: Safe because `ptr` is allocated with the same allocator
+ // using the same `layout`.
+ unsafe { dealloc(ptr, layout) };
+ }
+}
diff --git a/src/vhost_kern/vhost_binding.rs b/src/vhost_kern/vhost_binding.rs
new file mode 100644
index 0000000..5ebaa56
--- /dev/null
+++ b/src/vhost_kern/vhost_binding.rs
@@ -0,0 +1,545 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+//
+// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-Google file.
+
+/* Auto-generated by bindgen then manually edited for simplicity */
+
+#![allow(non_upper_case_globals)]
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+#![allow(missing_docs)]
+#![allow(clippy::missing_safety_doc)]
+
+use crate::{Error, Result};
+use std::os::raw;
+
+pub const VHOST: raw::c_uint = 0xaf;
+pub const VHOST_VRING_F_LOG: raw::c_uint = 0;
+pub const VHOST_ACCESS_RO: raw::c_uchar = 1;
+pub const VHOST_ACCESS_WO: raw::c_uchar = 2;
+pub const VHOST_ACCESS_RW: raw::c_uchar = 3;
+pub const VHOST_IOTLB_MISS: raw::c_uchar = 1;
+pub const VHOST_IOTLB_UPDATE: raw::c_uchar = 2;
+pub const VHOST_IOTLB_INVALIDATE: raw::c_uchar = 3;
+pub const VHOST_IOTLB_ACCESS_FAIL: raw::c_uchar = 4;
+pub const VHOST_IOTLB_BATCH_BEGIN: raw::c_uchar = 5;
+pub const VHOST_IOTLB_BATCH_END: raw::c_uchar = 6;
+pub const VHOST_IOTLB_MSG: raw::c_int = 1;
+pub const VHOST_IOTLB_MSG_V2: raw::c_uint = 2;
+pub const VHOST_PAGE_SIZE: raw::c_uint = 4096;
+pub const VHOST_VIRTIO: raw::c_uint = 175;
+pub const VHOST_VRING_LITTLE_ENDIAN: raw::c_uint = 0;
+pub const VHOST_VRING_BIG_ENDIAN: raw::c_uint = 1;
+pub const VHOST_F_LOG_ALL: raw::c_uint = 26;
+pub const VHOST_NET_F_VIRTIO_NET_HDR: raw::c_uint = 27;
+pub const VHOST_SCSI_ABI_VERSION: raw::c_uint = 1;
+pub const VHOST_BACKEND_F_IOTLB_MSG_V2: raw::c_ulonglong = 0x1;
+pub const VHOST_BACKEND_F_IOTLB_BATCH: raw::c_ulonglong = 0x2;
+pub const VHOST_BACKEND_F_IOTLB_ASID: raw::c_ulonglong = 0x3;
+pub const VHOST_BACKEND_F_SUSPEND: raw::c_ulonglong = 0x4;
+
+ioctl_ior_nr!(VHOST_GET_FEATURES, VHOST, 0x00, raw::c_ulonglong);
+ioctl_iow_nr!(VHOST_SET_FEATURES, VHOST, 0x00, raw::c_ulonglong);
+ioctl_io_nr!(VHOST_SET_OWNER, VHOST, 0x01);
+ioctl_io_nr!(VHOST_RESET_OWNER, VHOST, 0x02);
+ioctl_iow_nr!(VHOST_SET_MEM_TABLE, VHOST, 0x03, vhost_memory);
+ioctl_iow_nr!(VHOST_SET_LOG_BASE, VHOST, 0x04, raw::c_ulonglong);
+ioctl_iow_nr!(VHOST_SET_LOG_FD, VHOST, 0x07, raw::c_int);
+ioctl_iow_nr!(VHOST_SET_VRING_NUM, VHOST, 0x10, vhost_vring_state);
+ioctl_iow_nr!(VHOST_SET_VRING_ADDR, VHOST, 0x11, vhost_vring_addr);
+ioctl_iow_nr!(VHOST_SET_VRING_BASE, VHOST, 0x12, vhost_vring_state);
+ioctl_iowr_nr!(VHOST_GET_VRING_BASE, VHOST, 0x12, vhost_vring_state);
+ioctl_iow_nr!(VHOST_SET_VRING_KICK, VHOST, 0x20, vhost_vring_file);
+ioctl_iow_nr!(VHOST_SET_VRING_CALL, VHOST, 0x21, vhost_vring_file);
+ioctl_iow_nr!(VHOST_SET_VRING_ERR, VHOST, 0x22, vhost_vring_file);
+ioctl_iow_nr!(VHOST_SET_BACKEND_FEATURES, VHOST, 0x25, raw::c_ulonglong);
+ioctl_ior_nr!(VHOST_GET_BACKEND_FEATURES, VHOST, 0x26, raw::c_ulonglong);
+ioctl_iow_nr!(VHOST_NET_SET_BACKEND, VHOST, 0x30, vhost_vring_file);
+ioctl_iow_nr!(VHOST_SCSI_SET_ENDPOINT, VHOST, 0x40, vhost_scsi_target);
+ioctl_iow_nr!(VHOST_SCSI_CLEAR_ENDPOINT, VHOST, 0x41, vhost_scsi_target);
+ioctl_iow_nr!(VHOST_SCSI_GET_ABI_VERSION, VHOST, 0x42, raw::c_int);
+ioctl_iow_nr!(VHOST_SCSI_SET_EVENTS_MISSED, VHOST, 0x43, raw::c_uint);
+ioctl_iow_nr!(VHOST_SCSI_GET_EVENTS_MISSED, VHOST, 0x44, raw::c_uint);
+ioctl_iow_nr!(VHOST_VSOCK_SET_GUEST_CID, VHOST, 0x60, raw::c_ulonglong);
+ioctl_iow_nr!(VHOST_VSOCK_SET_RUNNING, VHOST, 0x61, raw::c_int);
+ioctl_ior_nr!(VHOST_VDPA_GET_DEVICE_ID, VHOST, 0x70, raw::c_uint);
+ioctl_ior_nr!(VHOST_VDPA_GET_STATUS, VHOST, 0x71, raw::c_uchar);
+ioctl_iow_nr!(VHOST_VDPA_SET_STATUS, VHOST, 0x72, raw::c_uchar);
+ioctl_ior_nr!(VHOST_VDPA_GET_CONFIG, VHOST, 0x73, vhost_vdpa_config);
+ioctl_iow_nr!(VHOST_VDPA_SET_CONFIG, VHOST, 0x74, vhost_vdpa_config);
+ioctl_iow_nr!(VHOST_VDPA_SET_VRING_ENABLE, VHOST, 0x75, vhost_vring_state);
+ioctl_ior_nr!(VHOST_VDPA_GET_VRING_NUM, VHOST, 0x76, raw::c_ushort);
+ioctl_iow_nr!(VHOST_VDPA_SET_CONFIG_CALL, VHOST, 0x77, raw::c_int);
+ioctl_ior_nr!(
+ VHOST_VDPA_GET_IOVA_RANGE,
+ VHOST,
+ 0x78,
+ vhost_vdpa_iova_range
+);
+ioctl_ior_nr!(VHOST_VDPA_GET_CONFIG_SIZE, VHOST, 0x79, raw::c_uint);
+ioctl_ior_nr!(VHOST_VDPA_GET_VQS_COUNT, VHOST, 0x80, raw::c_uint);
+ioctl_ior_nr!(VHOST_VDPA_GET_GROUP_NUM, VHOST, 0x81, raw::c_uint);
+ioctl_ior_nr!(VHOST_VDPA_GET_AS_NUM, VHOST, 0x7a, raw::c_uint);
+ioctl_iowr_nr!(VHOST_VDPA_GET_VRING_GROUP, VHOST, 0x7b, vhost_vring_state);
+ioctl_iow_nr!(VHOST_VDPA_SET_GROUP_ASID, VHOST, 0x7c, vhost_vring_state);
+ioctl_io_nr!(VHOST_VDPA_SUSPEND, VHOST, 0x7d);
+
+#[repr(C)]
+#[derive(Default)]
+pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>);
+
+impl<T> __IncompleteArrayField<T> {
+ #[inline]
+ pub fn new() -> Self {
+ __IncompleteArrayField(::std::marker::PhantomData)
+ }
+
+ #[inline]
+ #[allow(clippy::trivially_copy_pass_by_ref)]
+ #[allow(clippy::useless_transmute)]
+ pub unsafe fn as_ptr(&self) -> *const T {
+ ::std::mem::transmute(self)
+ }
+
+ #[inline]
+ #[allow(clippy::useless_transmute)]
+ pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
+ ::std::mem::transmute(self)
+ }
+
+ #[inline]
+ pub unsafe fn as_slice(&self, len: usize) -> &[T] {
+ ::std::slice::from_raw_parts(self.as_ptr(), len)
+ }
+
+ #[inline]
+ pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] {
+ ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len)
+ }
+}
+
+impl<T> ::std::fmt::Debug for __IncompleteArrayField<T> {
+ fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ fmt.write_str("__IncompleteArrayField")
+ }
+}
+
+impl<T> ::std::clone::Clone for __IncompleteArrayField<T> {
+ #[inline]
+ fn clone(&self) -> Self {
+ Self::new()
+ }
+}
+
+impl<T> ::std::marker::Copy for __IncompleteArrayField<T> {}
+
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct vhost_vring_state {
+ pub index: raw::c_uint,
+ pub num: raw::c_uint,
+}
+
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct vhost_vring_file {
+ pub index: raw::c_uint,
+ pub fd: raw::c_int,
+}
+
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct vhost_vring_addr {
+ pub index: raw::c_uint,
+ pub flags: raw::c_uint,
+ pub desc_user_addr: raw::c_ulonglong,
+ pub used_user_addr: raw::c_ulonglong,
+ pub avail_user_addr: raw::c_ulonglong,
+ pub log_guest_addr: raw::c_ulonglong,
+}
+
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct vhost_iotlb_msg {
+ pub iova: raw::c_ulonglong,
+ pub size: raw::c_ulonglong,
+ pub uaddr: raw::c_ulonglong,
+ pub perm: raw::c_uchar,
+ pub type_: raw::c_uchar,
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct vhost_msg {
+ pub type_: raw::c_int,
+ pub __bindgen_anon_1: vhost_msg__bindgen_ty_1,
+}
+
+impl Default for vhost_msg {
+ fn default() -> Self {
+ // SAFETY: Zeroing all bytes is fine because they represent a valid
+ // value for all members of the structure
+ unsafe { ::std::mem::zeroed() }
+ }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub union vhost_msg__bindgen_ty_1 {
+ pub iotlb: vhost_iotlb_msg,
+ pub padding: [raw::c_uchar; 64usize],
+ _bindgen_union_align: [u64; 8usize],
+}
+
+impl Default for vhost_msg__bindgen_ty_1 {
+ fn default() -> Self {
+ // SAFETY: Zeroing all bytes is fine because they represent a valid
+ // value for all members of the structure
+ unsafe { ::std::mem::zeroed() }
+ }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct vhost_msg_v2 {
+ pub type_: raw::c_uint,
+ pub reserved: raw::c_uint,
+ pub __bindgen_anon_1: vhost_msg_v2__bindgen_ty_1,
+}
+
+impl Default for vhost_msg_v2 {
+ fn default() -> Self {
+ // SAFETY: Zeroing all bytes is fine because they represent a valid
+ // value for all members of the structure
+ unsafe { ::std::mem::zeroed() }
+ }
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub union vhost_msg_v2__bindgen_ty_1 {
+ pub iotlb: vhost_iotlb_msg,
+ pub padding: [raw::c_uchar; 64usize],
+ _bindgen_union_align: [u64; 8usize],
+}
+
+impl Default for vhost_msg_v2__bindgen_ty_1 {
+ fn default() -> Self {
+ // SAFETY: Zeroing all bytes is fine because they represent a valid
+ // value for all members of the structure
+ unsafe { ::std::mem::zeroed() }
+ }
+}
+
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct vhost_memory_region {
+ pub guest_phys_addr: raw::c_ulonglong,
+ pub memory_size: raw::c_ulonglong,
+ pub userspace_addr: raw::c_ulonglong,
+ pub flags_padding: raw::c_ulonglong,
+}
+
+#[repr(C)]
+#[derive(Debug, Default, Clone)]
+pub struct vhost_memory {
+ pub nregions: raw::c_uint,
+ pub padding: raw::c_uint,
+ pub regions: __IncompleteArrayField<vhost_memory_region>,
+ __force_alignment: [u64; 0],
+}
+
+#[repr(C)]
+#[derive(Copy, Clone)]
+pub struct vhost_scsi_target {
+ pub abi_version: raw::c_int,
+ pub vhost_wwpn: [raw::c_char; 224usize],
+ pub vhost_tpgt: raw::c_ushort,
+ pub reserved: raw::c_ushort,
+}
+
+impl Default for vhost_scsi_target {
+ fn default() -> Self {
+ // SAFETY: Zeroing all bytes is fine because they represent a valid
+ // value for all members of the structure
+ unsafe { ::std::mem::zeroed() }
+ }
+}
+
+#[repr(C)]
+#[derive(Debug, Default)]
+pub struct vhost_vdpa_config {
+ pub off: raw::c_uint,
+ pub len: raw::c_uint,
+ pub buf: __IncompleteArrayField<raw::c_uchar>,
+}
+
+#[repr(C)]
+#[derive(Debug, Copy, Clone)]
+pub struct vhost_vdpa_iova_range {
+ pub first: raw::c_ulonglong,
+ pub last: raw::c_ulonglong,
+}
+
+/// Helper to support vhost::set_mem_table()
+pub struct VhostMemory {
+ buf: Vec<vhost_memory>,
+}
+
+impl VhostMemory {
+ // Limit number of regions to u16 to simplify error handling
+ pub fn new(entries: u16) -> Self {
+ let size = std::mem::size_of::<vhost_memory_region>() * entries as usize;
+ let count = (size + 2 * std::mem::size_of::<vhost_memory>() - 1)
+ / std::mem::size_of::<vhost_memory>();
+ let mut buf: Vec<vhost_memory> = vec![Default::default(); count];
+ buf[0].nregions = u32::from(entries);
+ VhostMemory { buf }
+ }
+
+ pub fn as_ptr(&self) -> *const char {
+ &self.buf[0] as *const vhost_memory as *const char
+ }
+
+ pub fn get_header(&self) -> &vhost_memory {
+ &self.buf[0]
+ }
+
+ pub fn get_region(&self, index: u32) -> Option<&vhost_memory_region> {
+ if index >= self.buf[0].nregions {
+ return None;
+ }
+ // SAFETY: Safe because we have allocated enough space for `nregions` regions.
+ let regions = unsafe { self.buf[0].regions.as_slice(self.buf[0].nregions as usize) };
+ Some(&regions[index as usize])
+ }
+
+ pub fn set_region(&mut self, index: u32, region: &vhost_memory_region) -> Result<()> {
+ if index >= self.buf[0].nregions {
+ return Err(Error::InvalidGuestMemory);
+ }
+ // SAFETY: Safe because we have allocated enough space for `nregions` regions and checked the index.
+ let regions = unsafe { self.buf[0].regions.as_mut_slice(index as usize + 1) };
+ regions[index as usize] = *region;
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn bindgen_test_layout_vhost_vring_state() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_vring_state>(),
+ 8usize,
+ concat!("Size of: ", stringify!(vhost_vring_state))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_vring_state>(),
+ 4usize,
+ concat!("Alignment of ", stringify!(vhost_vring_state))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_vring_file() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_vring_file>(),
+ 8usize,
+ concat!("Size of: ", stringify!(vhost_vring_file))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_vring_file>(),
+ 4usize,
+ concat!("Alignment of ", stringify!(vhost_vring_file))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_vring_addr() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_vring_addr>(),
+ 40usize,
+ concat!("Size of: ", stringify!(vhost_vring_addr))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_vring_addr>(),
+ 8usize,
+ concat!("Alignment of ", stringify!(vhost_vring_addr))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_msg__bindgen_ty_1() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_msg__bindgen_ty_1>(),
+ 64usize,
+ concat!("Size of: ", stringify!(vhost_msg__bindgen_ty_1))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_msg__bindgen_ty_1>(),
+ 8usize,
+ concat!("Alignment of ", stringify!(vhost_msg__bindgen_ty_1))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_msg() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_msg>(),
+ 72usize,
+ concat!("Size of: ", stringify!(vhost_msg))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_msg>(),
+ 8usize,
+ concat!("Alignment of ", stringify!(vhost_msg))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_msg_v2__bindgen_ty_1() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_msg_v2__bindgen_ty_1>(),
+ 64usize,
+ concat!("Size of: ", stringify!(vhost_msg_v2__bindgen_ty_1))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_msg_v2__bindgen_ty_1>(),
+ 8usize,
+ concat!("Alignment of ", stringify!(vhost_msg_v2__bindgen_ty_1))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_msg_v2() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_msg_v2>(),
+ 72usize,
+ concat!("Size of: ", stringify!(vhost_msg_v2))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_msg_v2>(),
+ 8usize,
+ concat!("Alignment of ", stringify!(vhost_msg_v2))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_memory_region() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_memory_region>(),
+ 32usize,
+ concat!("Size of: ", stringify!(vhost_memory_region))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_memory_region>(),
+ 8usize,
+ concat!("Alignment of ", stringify!(vhost_memory_region))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_memory() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_memory>(),
+ 8usize,
+ concat!("Size of: ", stringify!(vhost_memory))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_memory>(),
+ 8usize,
+ concat!("Alignment of ", stringify!(vhost_memory))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_iotlb_msg() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_iotlb_msg>(),
+ 32usize,
+ concat!("Size of: ", stringify!(vhost_iotlb_msg))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_iotlb_msg>(),
+ 8usize,
+ concat!("Alignment of ", stringify!(vhost_iotlb_msg))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_scsi_target() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_scsi_target>(),
+ 232usize,
+ concat!("Size of: ", stringify!(vhost_scsi_target))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_scsi_target>(),
+ 4usize,
+ concat!("Alignment of ", stringify!(vhost_scsi_target))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_vdpa_config() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_vdpa_config>(),
+ 8usize,
+ concat!("Size of: ", stringify!(vhost_vdpa_config))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_vdpa_config>(),
+ 4usize,
+ concat!("Alignment of ", stringify!(vhost_vdpa_config))
+ );
+ }
+
+ #[test]
+ fn bindgen_test_layout_vhost_vdpa_iova_range() {
+ assert_eq!(
+ ::std::mem::size_of::<vhost_vdpa_iova_range>(),
+ 16usize,
+ concat!("Size of: ", stringify!(vhost_vdpa_iova_range))
+ );
+ assert_eq!(
+ ::std::mem::align_of::<vhost_vdpa_iova_range>(),
+ 8usize,
+ concat!("Alignment of ", stringify!(vhost_vdpa_iova_range))
+ );
+ }
+
+ #[test]
+ fn test_vhostmemory() {
+ let mut obj = VhostMemory::new(2);
+ let region = vhost_memory_region {
+ guest_phys_addr: 0x1000u64,
+ memory_size: 0x2000u64,
+ userspace_addr: 0x300000u64,
+ flags_padding: 0u64,
+ };
+ assert!(obj.get_region(2).is_none());
+
+ {
+ let header = obj.get_header();
+ assert_eq!(header.nregions, 2u32);
+ }
+ {
+ assert!(obj.set_region(0, &region).is_ok());
+ assert!(obj.set_region(1, &region).is_ok());
+ assert!(obj.set_region(2, &region).is_err());
+ }
+
+ let region1 = obj.get_region(1).unwrap();
+ assert_eq!(region1.guest_phys_addr, 0x1000u64);
+ assert_eq!(region1.memory_size, 0x2000u64);
+ assert_eq!(region1.userspace_addr, 0x300000u64);
+ }
+}
diff --git a/src/vhost_kern/vsock.rs b/src/vhost_kern/vsock.rs
new file mode 100644
index 0000000..9bc788e
--- /dev/null
+++ b/src/vhost_kern/vsock.rs
@@ -0,0 +1,196 @@
+// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+//
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-Google file.
+
+//! Kernel-based vhost-vsock backend.
+
+use std::fs::{File, OpenOptions};
+use std::os::unix::fs::OpenOptionsExt;
+use std::os::unix::io::{AsRawFd, RawFd};
+
+use vm_memory::GuestAddressSpace;
+use vmm_sys_util::ioctl::ioctl_with_ref;
+
+use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING};
+use super::{ioctl_result, Error, Result, VhostKernBackend};
+use crate::vsock::VhostVsock;
+
+const VHOST_PATH: &str = "/dev/vhost-vsock";
+
+/// Handle for running VHOST_VSOCK ioctls.
+pub struct Vsock<AS: GuestAddressSpace> {
+ fd: File,
+ mem: AS,
+}
+
+impl<AS: GuestAddressSpace> Vsock<AS> {
+ /// Open a handle to a new VHOST-VSOCK instance.
+ pub fn new(mem: AS) -> Result<Self> {
+ Ok(Vsock {
+ fd: OpenOptions::new()
+ .read(true)
+ .write(true)
+ .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK)
+ .open(VHOST_PATH)
+ .map_err(Error::VhostOpen)?,
+ mem,
+ })
+ }
+
+ fn set_running(&self, running: bool) -> Result<()> {
+ let on: ::std::os::raw::c_int = if running { 1 } else { 0 };
+
+ // SAFETY: This ioctl is called on a valid vhost-vsock fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_RUNNING(), &on) };
+ ioctl_result(ret, ())
+ }
+}
+
+impl<AS: GuestAddressSpace> VhostVsock for Vsock<AS> {
+ fn set_guest_cid(&self, cid: u64) -> Result<()> {
+ // SAFETY: This ioctl is called on a valid vhost-vsock fd and has its
+ // return value checked.
+ let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_GUEST_CID(), &cid) };
+ ioctl_result(ret, ())
+ }
+
+ fn start(&self) -> Result<()> {
+ self.set_running(true)
+ }
+
+ fn stop(&self) -> Result<()> {
+ self.set_running(false)
+ }
+}
+
+impl<AS: GuestAddressSpace> VhostKernBackend for Vsock<AS> {
+ type AS = AS;
+
+ fn mem(&self) -> &Self::AS {
+ &self.mem
+ }
+}
+
+impl<AS: GuestAddressSpace> AsRawFd for Vsock<AS> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd.as_raw_fd()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap};
+ use vmm_sys_util::eventfd::EventFd;
+
+ use super::*;
+ use crate::{
+ VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData,
+ };
+
+ #[test]
+ fn test_vsock_new_device() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let vsock = Vsock::new(&m).unwrap();
+
+ assert!(vsock.as_raw_fd() >= 0);
+ assert!(vsock.mem().find_region(GuestAddress(0x100)).is_some());
+ assert!(vsock.mem().find_region(GuestAddress(0x10_0000)).is_none());
+ }
+
+ #[test]
+ fn test_vsock_is_valid() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let vsock = Vsock::new(&m).unwrap();
+
+ let mut config = VringConfigData {
+ queue_max_size: 32,
+ queue_size: 32,
+ flags: 0,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: None,
+ };
+ assert!(vsock.is_valid(&config));
+
+ config.queue_size = 0;
+ assert!(!vsock.is_valid(&config));
+ config.queue_size = 31;
+ assert!(!vsock.is_valid(&config));
+ config.queue_size = 33;
+ assert!(!vsock.is_valid(&config));
+ }
+
+ #[test]
+ fn test_vsock_ioctls() {
+ let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let vsock = Vsock::new(&m).unwrap();
+
+ let features = vsock.get_features().unwrap();
+ vsock.set_features(features).unwrap();
+
+ vsock.set_owner().unwrap();
+
+ vsock.set_mem_table(&[]).unwrap_err();
+
+ /*
+ let region = VhostUserMemoryRegionInfo {
+ guest_phys_addr: 0x0,
+ memory_size: 0x10_0000,
+ userspace_addr: 0,
+ mmap_offset: 0,
+ mmap_handle: -1,
+ };
+ vsock.set_mem_table(&[region]).unwrap_err();
+ */
+
+ let region = VhostUserMemoryRegionInfo::new(
+ 0x0,
+ 0x10_0000,
+ m.get_host_address(GuestAddress(0x0)).unwrap() as u64,
+ 0,
+ -1,
+ );
+ vsock.set_mem_table(&[region]).unwrap();
+
+ vsock
+ .set_log_base(
+ 0x4000,
+ Some(VhostUserDirtyLogRegion {
+ mmap_size: 0x1000,
+ mmap_offset: 0x10,
+ mmap_handle: 1,
+ }),
+ )
+ .unwrap_err();
+ vsock.set_log_base(0x4000, None).unwrap();
+
+ let eventfd = EventFd::new(0).unwrap();
+ vsock.set_log_fd(eventfd.as_raw_fd()).unwrap();
+
+ vsock.set_vring_num(0, 32).unwrap();
+
+ let config = VringConfigData {
+ queue_max_size: 32,
+ queue_size: 32,
+ flags: 0,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: None,
+ };
+ vsock.set_vring_addr(0, &config).unwrap();
+ vsock.set_vring_base(0, 1).unwrap();
+ vsock.set_vring_call(0, &eventfd).unwrap();
+ vsock.set_vring_kick(0, &eventfd).unwrap();
+ vsock.set_vring_err(0, &eventfd).unwrap();
+ assert_eq!(vsock.get_vring_base(0).unwrap(), 1);
+ vsock.set_guest_cid(0xdead).unwrap();
+ //vsock.start().unwrap();
+ //vsock.stop().unwrap();
+ }
+}
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs
new file mode 100644
index 0000000..4a62e12
--- /dev/null
+++ b/src/vhost_user/connection.rs
@@ -0,0 +1,903 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Structs for Unix Domain Socket listener and endpoint.
+
+#![allow(dead_code)]
+
+use std::fs::File;
+use std::io::ErrorKind;
+use std::marker::PhantomData;
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::os::unix::net::{UnixListener, UnixStream};
+use std::path::{Path, PathBuf};
+use std::{mem, slice};
+
+use libc::{c_void, iovec};
+use vm_memory::ByteValued;
+use vmm_sys_util::sock_ctrl_msg::ScmSocket;
+
+use super::message::*;
+use super::{Error, Result};
+
+/// Unix domain socket listener for accepting incoming connections.
+pub struct Listener {
+ fd: UnixListener,
+ path: Option<PathBuf>,
+}
+
+impl Listener {
+ /// Create a unix domain socket listener.
+ ///
+ /// # Return:
+ /// * - the new Listener object on success.
+ /// * - SocketError: failed to create listener socket.
+ pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
+ if unlink {
+ let _ = std::fs::remove_file(&path);
+ }
+ let fd = UnixListener::bind(&path).map_err(Error::SocketError)?;
+ Ok(Listener {
+ fd,
+ path: Some(path.as_ref().to_owned()),
+ })
+ }
+
+ /// Accept an incoming connection.
+ ///
+ /// # Return:
+ /// * - Some(UnixStream): new UnixStream object if new incoming connection is available.
+ /// * - None: no incoming connection available.
+ /// * - SocketError: errors from accept().
+ pub fn accept(&self) -> Result<Option<UnixStream>> {
+ loop {
+ match self.fd.accept() {
+ Ok((socket, _addr)) => return Ok(Some(socket)),
+ Err(e) => {
+ match e.kind() {
+ // No incoming connection available.
+ ErrorKind::WouldBlock => return Ok(None),
+ // New connection closed by peer.
+ ErrorKind::ConnectionAborted => return Ok(None),
+ // Interrupted by signals, retry
+ ErrorKind::Interrupted => continue,
+ _ => return Err(Error::SocketError(e)),
+ }
+ }
+ }
+ }
+ }
+
+ /// Change blocking status on the listener.
+ ///
+ /// # Return:
+ /// * - () on success.
+ /// * - SocketError: failure from set_nonblocking().
+ pub fn set_nonblocking(&self, block: bool) -> Result<()> {
+ self.fd.set_nonblocking(block).map_err(Error::SocketError)
+ }
+}
+
+impl AsRawFd for Listener {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd.as_raw_fd()
+ }
+}
+
+impl FromRawFd for Listener {
+ unsafe fn from_raw_fd(fd: RawFd) -> Self {
+ Listener {
+ fd: UnixListener::from_raw_fd(fd),
+ path: None,
+ }
+ }
+}
+
+impl Drop for Listener {
+ fn drop(&mut self) {
+ if let Some(path) = &self.path {
+ let _ = std::fs::remove_file(path);
+ }
+ }
+}
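+
+// Editor's illustrative sketch (not part of the imported crate sources): a
+// typical non-blocking accept loop built on `Listener`. The socket path is an
+// arbitrary example chosen by the caller.
+//
+//     let listener = Listener::new("/path/to/vhost.sock", true).unwrap();
+//     listener.set_nonblocking(true).unwrap();
+//     let stream = loop {
+//         match listener.accept().unwrap() {
+//             Some(stream) => break stream, // hand the UnixStream to an Endpoint
+//             None => std::thread::sleep(std::time::Duration::from_millis(10)),
+//         }
+//     };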
+
+/// Unix domain socket endpoint for vhost-user connection.
+pub(super) struct Endpoint<R: Req> {
+ sock: UnixStream,
+ _r: PhantomData<R>,
+}
+
+impl<R: Req> Endpoint<R> {
+ /// Create a new stream by connecting to the server at `path`.
+ ///
+ /// # Return:
+ /// * - the new Endpoint object on success.
+ /// * - SocketConnect: failed to connect to peer.
+ pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
+ let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
+ Ok(Self::from_stream(sock))
+ }
+
+ /// Create an endpoint from a stream object.
+ pub fn from_stream(sock: UnixStream) -> Self {
+ Endpoint {
+ sock,
+ _r: PhantomData,
+ }
+ }
+
+ /// Sends bytes from scatter-gather vectors over the socket with optional attached file
+ /// descriptors.
+ ///
+ /// # Return:
+ /// * - number of bytes sent on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
+ let rfds = match fds {
+ Some(rfds) => rfds,
+ _ => &[],
+ };
+ self.sock.send_with_fds(iovs, rfds).map_err(Into::into)
+ }
+
+ /// Sends all bytes from scatter-gather vectors over the socket with optional attached file
+ /// descriptors. Will loop until all data has been transferred.
+ ///
+ /// # Return:
+ /// * - number of bytes sent on success
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
+ let mut data_sent = 0;
+ let mut data_total = 0;
+ let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect();
+ for len in &iov_lens {
+ data_total += len;
+ }
+
+ while (data_total - data_sent) > 0 {
+ let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent);
+ let iov = &iovs[nr_skip][offset..];
+
+ let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat();
+ let sfds = if data_sent == 0 { fds } else { None };
+
+ let sent = self.send_iovec(data, sfds);
+ match sent {
+ Ok(0) => return Ok(data_sent),
+ Ok(n) => data_sent += n,
+ Err(e) => match e {
+ Error::SocketRetry(_) => {}
+ _ => return Err(e),
+ },
+ }
+ }
+ Ok(data_sent)
+ }
+
+ /// Sends bytes from a slice over the socket with optional attached file descriptors.
+ ///
+ /// # Return:
+ /// * - number of bytes sent on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
+ self.send_iovec(&[data], fds)
+ }
+
+ /// Sends a header-only message with optional attached file descriptors.
+ ///
+ /// # Return:
+ /// * - () on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: only part of the message could be sent.
+ pub fn send_header(
+ &mut self,
+ hdr: &VhostUserMsgHeader<R>,
+ fds: Option<&[RawFd]>,
+ ) -> Result<()> {
+ // SAFETY: Safe because there can't be another mutable reference to hdr.
+ let iovs = unsafe {
+ [slice::from_raw_parts(
+ hdr as *const VhostUserMsgHeader<R> as *const u8,
+ mem::size_of::<VhostUserMsgHeader<R>>(),
+ )]
+ };
+ let bytes = self.send_iovec_all(&iovs[..], fds)?;
+ if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
+ return Err(Error::PartialMessage);
+ }
+ Ok(())
+ }
+
+ /// Send a message with header and body. Optional file descriptors may be attached to
+ /// the message.
+ ///
+ /// # Return:
+ /// * - () on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: only part of the message could be sent.
+ pub fn send_message<T: ByteValued>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<R>,
+ body: &T,
+ fds: Option<&[RawFd]>,
+ ) -> Result<()> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
+ let bytes = self.send_iovec_all(&[hdr.as_slice(), body.as_slice()], fds)?;
+ if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
+ return Err(Error::PartialMessage);
+ }
+ Ok(())
+ }
+
+ /// Send a message with header, body and payload. Optional file descriptors
+ /// may also be attached to the message.
+ ///
+ /// # Return:
+ /// * - () on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - OversizedMsg: message size is too big.
+ /// * - PartialMessage: only part of the message could be sent.
+ /// * - IncorrectFds: wrong number of attached fds.
+ pub fn send_message_with_payload<T: ByteValued>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<R>,
+ body: &T,
+ payload: &[u8],
+ fds: Option<&[RawFd]>,
+ ) -> Result<()> {
+ let len = payload.len();
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
+ if len > MAX_MSG_SIZE - mem::size_of::<T>() {
+ return Err(Error::OversizedMsg);
+ }
+ if let Some(fd_arr) = fds {
+ if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
+ return Err(Error::IncorrectFds);
+ }
+ }
+
+ let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
+ let len = self.send_iovec_all(&[hdr.as_slice(), body.as_slice(), payload], fds)?;
+ if len != total {
+ return Err(Error::PartialMessage);
+ }
+ Ok(())
+ }
+
+ /// Reads up to `len` bytes from the socket into a newly allocated buffer.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, buf) on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> {
+ let mut rbuf = vec![0u8; len];
+ let mut iovs = [iovec {
+ iov_base: rbuf.as_mut_ptr() as *mut c_void,
+ iov_len: len,
+ }];
+ // SAFETY: Safe because we own rbuf and it's safe to fill a byte array with arbitrary data.
+ let (bytes, _) = unsafe { self.sock.recv_with_fds(&mut iovs, &mut [])? };
+ Ok((bytes, rbuf))
+ }
+
+ /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
+ /// files.
+ ///
+ /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
+ /// tricky to pass file descriptors through such a communication channel. Let's assume that a
+ /// sender sends a message with some file descriptors attached. To successfully receive those
+ /// attached file descriptors, the receiver must obey the following rules:
+ /// 1) file descriptors are attached to a message.
+ /// 2) message(packet) boundaries must be respected on the receive side.
+ /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
+ /// attached file descriptors will get lost.
+ /// Note that this function wraps received file descriptors as `File`.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, [received files]) on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ ///
+ /// # Safety
+ ///
+ /// It is the caller's responsibility to ensure it is safe for arbitrary data to be
+ /// written to the iovec pointers.
+ pub unsafe fn recv_into_iovec(
+ &mut self,
+ iovs: &mut [iovec],
+ ) -> Result<(usize, Option<Vec<File>>)> {
+ let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
+ let (bytes, fds) = self.sock.recv_with_fds(iovs, &mut fd_array)?;
+
+ let files = match fds {
+ 0 => None,
+ n => {
+ let files = fd_array
+ .iter()
+ .take(n)
+ .map(|fd| {
+ // Safe because we have the ownership of `fd`.
+ File::from_raw_fd(*fd)
+ })
+ .collect();
+ Some(files)
+ }
+ };
+
+ Ok((bytes, files))
+ }
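+
+ // Editor's illustrative sketch (not part of the imported crate sources): the
+ // packet-boundary rule described above in practice, assuming `sender` and
+ // `receiver` are two connected `Endpoint`s and `payload` is a small byte slice.
+ // The first recv on the receiving side must cover the whole message that
+ // carried the fds, otherwise they are lost (see the `send_fd` test below).
+ //
+ //     sender.send_slice(&payload, Some(&[file.as_raw_fd()])).unwrap();
+ //     let (bytes, _buf, files) = receiver.recv_into_buf(payload.len()).unwrap();
+ //     assert!(files.is_some()); // fds arrive with the first, full-size read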
+
+ /// Reads all bytes from the socket into the given scatter/gather vectors with optional
+ /// attached files. Will loop until all data has been transferred.
+ ///
+ /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
+ /// tricky to pass file descriptors through such a communication channel. Let's assume that a
+ /// sender sends a message with some file descriptors attached. To successfully receive those
+ /// attached file descriptors, the receiver must obey the following rules:
+ /// 1) file descriptors are attached to a message.
+ /// 2) message(packet) boundaries must be respected on the receive side.
+ /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
+ /// attached file descriptors will get lost.
+ /// Note that this function wraps received file descriptors as `File`.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, [received fds]) on success
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ ///
+ /// # Safety
+ ///
+ /// It is the caller's responsibility to ensure it is safe for arbitrary data to be
+ /// written to the iovec pointers.
+ pub unsafe fn recv_into_iovec_all(
+ &mut self,
+ iovs: &mut [iovec],
+ ) -> Result<(usize, Option<Vec<File>>)> {
+ let mut data_read = 0;
+ let mut data_total = 0;
+ let mut rfds = None;
+ let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect();
+ for len in &iov_lens {
+ data_total += len;
+ }
+
+ while (data_total - data_read) > 0 {
+ let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read);
+ let iov = &mut iovs[nr_skip];
+
+ let mut data = [
+ &[iovec {
+ iov_base: (iov.iov_base as usize + offset) as *mut c_void,
+ iov_len: iov.iov_len - offset,
+ }],
+ &iovs[(nr_skip + 1)..],
+ ]
+ .concat();
+
+ let res = self.recv_into_iovec(&mut data);
+ match res {
+ Ok((0, _)) => return Ok((data_read, rfds)),
+ Ok((n, fds)) => {
+ if data_read == 0 {
+ rfds = fds;
+ }
+ data_read += n;
+ }
+ Err(e) => match e {
+ Error::SocketRetry(_) => {}
+ _ => return Err(e),
+ },
+ }
+ }
+ Ok((data_read, rfds))
+ }
+
+ /// Reads bytes from the socket into a new buffer with optional attached
+ /// files. Received file descriptors are set close-on-exec and converted to `File`.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, buf, [received files]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_into_buf(
+ &mut self,
+ buf_size: usize,
+ ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> {
+ let mut buf = vec![0u8; buf_size];
+ let (bytes, files) = {
+ let mut iovs = [iovec {
+ iov_base: buf.as_mut_ptr() as *mut c_void,
+ iov_len: buf_size,
+ }];
+ // SAFETY: Safe because we own buf and it's safe to fill a byte array with arbitrary data.
+ unsafe { self.recv_into_iovec(&mut iovs)? }
+ };
+ Ok((bytes, buf, files))
+ }
+
+ /// Receive a header-only message with optional attached files.
+ /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
+ /// accepted and all other file descriptors will be discarded silently.
+ ///
+ /// # Return:
+ /// * - (message header, [received files]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ /// * - InvalidMessage: received an invalid message.
+ pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut iovs = [iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ }];
+ // SAFETY: Safe because we own hdr and it's ByteValued.
+ let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
+
+ if bytes == 0 {
+ return Err(Error::Disconnected);
+ } else if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, files))
+ }
+
+ /// Receive a message with optional attached file descriptors.
+ /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
+ /// accepted and all other file descriptors will be discarded silently.
+ ///
+ /// # Return:
+ /// * - (message header, message body, [received files]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ /// * - InvalidMessage: received an invalid message.
+ pub fn recv_body<T: ByteValued + Sized + VhostUserMsgValidator>(
+ &mut self,
+ ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut body: T = Default::default();
+ let mut iovs = [
+ iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ },
+ iovec {
+ iov_base: (&mut body as *mut T) as *mut c_void,
+ iov_len: mem::size_of::<T>(),
+ },
+ ];
+ // SAFETY: Safe because we own hdr and body and they're ByteValued.
+ let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
+
+ let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
+ if bytes != total {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() || !body.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, body, files))
+ }
+
+ /// Receive a message with header and optional content. Callers need to
+ /// pre-allocate a big enough buffer to receive the message body and
+ /// optional payload. If there are attached file descriptors associated
+ /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors
+ /// will be accepted and all other file descriptors will be discarded
+ /// silently.
+ ///
+ /// # Return:
+ /// * - (message header, message size, [received files]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ /// * - InvalidMessage: received an invalid message.
+ pub fn recv_body_into_buf(
+ &mut self,
+ buf: &mut [u8],
+ ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut iovs = [
+ iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ },
+ iovec {
+ iov_base: buf.as_mut_ptr() as *mut c_void,
+ iov_len: buf.len(),
+ },
+ ];
+ // SAFETY: Safe because we own hdr and have a mutable borrow of buf, and hdr is ByteValued
+ // and it's safe to fill a byte slice with arbitrary data.
+ let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
+
+ if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files))
+ }
+
+ /// Receive a message with optional payload and attached file descriptors.
+ /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
+ /// accepted and all other file descriptors will be discarded silently.
+ ///
+ /// # Return:
+ /// * - (message header, message body, size of payload, [received files]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ /// * - InvalidMessage: received an invalid message.
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
+ pub fn recv_payload_into_buf<T: ByteValued + Sized + VhostUserMsgValidator>(
+ &mut self,
+ buf: &mut [u8],
+ ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut body: T = Default::default();
+ let mut iovs = [
+ iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ },
+ iovec {
+ iov_base: (&mut body as *mut T) as *mut c_void,
+ iov_len: mem::size_of::<T>(),
+ },
+ iovec {
+ iov_base: buf.as_mut_ptr() as *mut c_void,
+ iov_len: buf.len(),
+ },
+ ];
+ // SAFETY: Safe because we own hdr and body and have a mutable borrow of buf, and
+ // hdr and body are ByteValued, and it's safe to fill a byte slice with
+ // arbitrary data.
+ let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? };
+
+ let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
+ if bytes < total {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() || !body.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, body, bytes - total, files))
+ }
+}
+
+impl<T: Req> AsRawFd for Endpoint<T> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sock.as_raw_fd()
+ }
+}
+
+// Given a slice of sizes and the `skip_size`, return the offset of `skip_size` in the slice.
+// For example:
+// let iov_lens = vec![4, 4, 5];
+// let size = 6;
+// assert_eq!(get_sub_iovs_offset(&iov_len, size), (1, 2));
+fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
+ let mut size = skip_size;
+ let mut nr_skip = 0;
+
+ for len in iov_lens {
+ if size >= *len {
+ size -= *len;
+ nr_skip += 1;
+ } else {
+ break;
+ }
+ }
+ (nr_skip, size)
+}
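+
+// Editor's illustrative sketch (not part of the imported crate sources): how
+// the send/recv loops above use this helper to resume a partial transfer.
+//
+//     let iov_lens = [4usize, 4, 5];
+//     // 6 bytes already moved: skip the first iov, then 2 bytes into the second.
+//     assert_eq!(get_sub_iovs_offset(&iov_lens, 6), (1, 2));
+//     // 8 bytes already moved: the first two iovs are fully consumed.
+//     assert_eq!(get_sub_iovs_offset(&iov_lens, 8), (2, 0));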
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::io::{Read, Seek, SeekFrom, Write};
+ use vmm_sys_util::rand::rand_alphanumerics;
+ use vmm_sys_util::tempfile::TempFile;
+
+ fn temp_path() -> PathBuf {
+ PathBuf::from(format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ ))
+ }
+
+ #[test]
+ fn create_listener() {
+ let path = temp_path();
+ let listener = Listener::new(path, true).unwrap();
+
+ assert!(listener.as_raw_fd() > 0);
+ }
+
+ #[test]
+ fn create_listener_from_raw_fd() {
+ let path = temp_path();
+ let file = File::create(path).unwrap();
+
+ // SAFETY: Safe because `file` contains a valid fd to a file just created.
+ let listener = unsafe { Listener::from_raw_fd(file.as_raw_fd()) };
+
+ assert!(listener.as_raw_fd() > 0);
+ }
+
+ #[test]
+ fn accept_connection() {
+ let path = temp_path();
+ let listener = Listener::new(path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+
+ // accept on a fd without incoming connection
+ let conn = listener.accept().unwrap();
+ assert!(conn.is_none());
+ }
+
+ #[test]
+ fn send_data() {
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ let buf1 = vec![0x1, 0x2, 0x3, 0x4];
+ let mut len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
+ assert_eq!(bytes, 4);
+ assert_eq!(&buf1[..], &buf2[..bytes]);
+
+ len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf2[..]);
+ let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ }
+
+ #[test]
+ fn send_fd() {
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ let mut fd = TempFile::new().unwrap().into_file();
+ write!(fd, "test").unwrap();
+
+ // Normal case for sending/receiving file descriptors
+ let buf1 = vec![0x1, 0x2, 0x3, 0x4];
+ let len = master
+ .send_slice(&buf1[..], Some(&[fd.as_raw_fd()]))
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap();
+ assert_eq!(bytes, 4);
+ assert_eq!(&buf1[..], &buf2[..]);
+ assert!(files.is_some());
+ let files = files.unwrap();
+ {
+ assert_eq!(files.len(), 1);
+ let mut file = &files[0];
+ let mut content = String::new();
+ file.seek(SeekFrom::Start(0)).unwrap();
+ file.read_to_string(&mut content).unwrap();
+ assert_eq!(content, "test");
+ }
+
+ // Following communication pattern should work:
+ // Sending side: data(header, body) with fds
+ // Receiving side: data(header) with fds, data(body)
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf2[..]);
+ assert!(files.is_some());
+ let files = files.unwrap();
+ {
+ assert_eq!(files.len(), 3);
+ let mut file = &files[1];
+ let mut content = String::new();
+ file.seek(SeekFrom::Start(0)).unwrap();
+ file.read_to_string(&mut content).unwrap();
+ assert_eq!(content, "test");
+ }
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ assert!(files.is_none());
+
+ // Following communication pattern should not work:
+ // Sending side: data(header, body) with fds
+ // Receiving side: data(header), data(body) with fds
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf4) = slave.recv_data(2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf4[..]);
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ assert!(files.is_none());
+
+ // Following communication pattern should work:
+ // Sending side: data, data with fds
+ // Receiving side: data, data with fds
+ let len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap();
+ assert_eq!(bytes, 4);
+ assert_eq!(&buf1[..], &buf2[..]);
+ assert!(files.is_none());
+
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf2[..]);
+ assert!(files.is_some());
+ let files = files.unwrap();
+ {
+ assert_eq!(files.len(), 3);
+ let mut file = &files[1];
+ let mut content = String::new();
+ file.seek(SeekFrom::Start(0)).unwrap();
+ file.read_to_string(&mut content).unwrap();
+ assert_eq!(content, "test");
+ }
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ assert!(files.is_none());
+
+ // Following communication pattern should not work:
+ // Sending side: data1, data2 with fds
+ // Receiving side: data + partial of data2, left of data2 with fds
+ let len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, _) = slave.recv_data(5).unwrap();
+ assert_eq!(bytes, 5);
+
+ let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
+ assert_eq!(bytes, 3);
+ assert!(files.is_none());
+
+ // If the target fd array is too small, extra file descriptors will get lost.
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
+ assert_eq!(bytes, 4);
+ assert!(files.is_some());
+ }
+
+ #[test]
+ fn send_recv() {
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ let mut hdr1 =
+ VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
+ hdr1.set_need_reply(true);
+ let features1 = 0x1u64;
+ master.send_message(&hdr1, &features1, None).unwrap();
+
+ let mut features2 = 0u64;
+
+ // SAFETY: Safe because features2 is valid and it's an `u64`.
+ let slice = unsafe {
+ slice::from_raw_parts_mut(
+ (&mut features2 as *mut u64) as *mut u8,
+ mem::size_of::<u64>(),
+ )
+ };
+ let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap();
+ assert_eq!(hdr1, hdr2);
+ assert_eq!(bytes, 8);
+ assert_eq!(features1, features2);
+ assert!(files.is_none());
+
+ master.send_header(&hdr1, None).unwrap();
+ let (hdr2, files) = slave.recv_header().unwrap();
+ assert_eq!(hdr1, hdr2);
+ assert!(files.is_none());
+ }
+
+ #[test]
+ fn partial_message() {
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
+ let mut master = UnixStream::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ write!(master, "a").unwrap();
+ drop(master);
+ assert!(matches!(slave.recv_header(), Err(Error::PartialMessage)));
+ }
+
+ #[test]
+ fn disconnected() {
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
+ let _ = UnixStream::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ assert!(matches!(slave.recv_header(), Err(Error::Disconnected)));
+ }
+}
diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs
new file mode 100644
index 0000000..ae728a0
--- /dev/null
+++ b/src/vhost_user/dummy_slave.rs
@@ -0,0 +1,294 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::fs::File;
+
+use super::message::*;
+use super::*;
+
+pub const MAX_QUEUE_NUM: usize = 2;
+pub const MAX_VRING_NUM: usize = 256;
+pub const MAX_MEM_SLOTS: usize = 32;
+pub const VIRTIO_FEATURES: u64 = 0x40000003;
+
+#[derive(Default)]
+pub struct DummySlaveReqHandler {
+ pub owned: bool,
+ pub features_acked: bool,
+ pub acked_features: u64,
+ pub acked_protocol_features: u64,
+ pub queue_num: usize,
+ pub vring_num: [u32; MAX_QUEUE_NUM],
+ pub vring_base: [u32; MAX_QUEUE_NUM],
+ pub call_fd: [Option<File>; MAX_QUEUE_NUM],
+ pub kick_fd: [Option<File>; MAX_QUEUE_NUM],
+ pub err_fd: [Option<File>; MAX_QUEUE_NUM],
+ pub vring_started: [bool; MAX_QUEUE_NUM],
+ pub vring_enabled: [bool; MAX_QUEUE_NUM],
+ pub inflight_file: Option<File>,
+}
+
+impl DummySlaveReqHandler {
+ pub fn new() -> Self {
+ DummySlaveReqHandler {
+ queue_num: MAX_QUEUE_NUM,
+ ..Default::default()
+ }
+ }
+
+ /// Helper to check whether a virtio feature flag has been acked.
+ fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> {
+ if self.acked_features & feat.bits() != 0 {
+ Ok(())
+ } else {
+ Err(Error::InactiveFeature(feat))
+ }
+ }
+
+ /// Helper to check whether a VhostUserProtocolFeatures flag has been acked.
+ fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> {
+ if self.acked_protocol_features & feat.bits() != 0 {
+ Ok(())
+ } else {
+ Err(Error::InactiveOperation(feat))
+ }
+ }
+}
+
+impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
+ fn set_owner(&mut self) -> Result<()> {
+ if self.owned {
+ return Err(Error::InvalidOperation("already claimed"));
+ }
+ self.owned = true;
+ Ok(())
+ }
+
+ fn reset_owner(&mut self) -> Result<()> {
+ self.owned = false;
+ self.features_acked = false;
+ self.acked_features = 0;
+ self.acked_protocol_features = 0;
+ Ok(())
+ }
+
+ fn get_features(&mut self) -> Result<u64> {
+ Ok(VIRTIO_FEATURES)
+ }
+
+ fn set_features(&mut self, features: u64) -> Result<()> {
+ if !self.owned {
+ return Err(Error::InvalidOperation("not owned"));
+ } else if self.features_acked {
+ return Err(Error::InvalidOperation("features already set"));
+ } else if (features & !VIRTIO_FEATURES) != 0 {
+ return Err(Error::InvalidParam);
+ }
+
+ self.acked_features = features;
+ self.features_acked = true;
+
+ // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated,
+ // the ring is initialized in an enabled state.
+ // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated,
+ // the ring is initialized in a disabled state. Client must not
+ // pass data to/from the backend until ring is enabled by
+ // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has
+ // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
+ let vring_enabled =
+ self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0;
+ for enabled in &mut self.vring_enabled {
+ *enabled = vring_enabled;
+ }
+
+ Ok(())
+ }
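+
+ // Editor's illustrative sketch (not part of the imported crate sources): the
+ // enable rule above reduces to a single bit test on the acked feature set.
+ //
+ //     let proto = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ //     // Rings start enabled only when VHOST_USER_F_PROTOCOL_FEATURES was not acked;
+ //     // VIRTIO_FEATURES (0x40000003) includes that bit, so this would be false here.
+ //     let rings_start_enabled = VIRTIO_FEATURES & proto == 0;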
+
+ fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _files: Vec<File>) -> Result<()> {
+ Ok(())
+ }
+
+ fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
+ if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM {
+ return Err(Error::InvalidParam);
+ }
+ self.vring_num[index as usize] = num;
+ Ok(())
+ }
+
+ fn set_vring_addr(
+ &mut self,
+ index: u32,
+ _flags: VhostUserVringAddrFlags,
+ _descriptor: u64,
+ _used: u64,
+ _available: u64,
+ _log: u64,
+ ) -> Result<()> {
+ if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ Ok(())
+ }
+
+ fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
+ if index as usize >= self.queue_num || base as usize >= MAX_VRING_NUM {
+ return Err(Error::InvalidParam);
+ }
+ self.vring_base[index as usize] = base;
+ Ok(())
+ }
+
+ fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
+ if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ // Quotation from vhost-user spec:
+ // Client must start ring upon receiving a kick (that is, detecting
+ // that file descriptor is readable) on the descriptor specified by
+ // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
+ // VHOST_USER_GET_VRING_BASE.
+ self.vring_started[index as usize] = false;
+ Ok(VhostUserVringState::new(
+ index,
+ self.vring_base[index as usize],
+ ))
+ }
+
+ fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
+ if index as usize >= self.queue_num || index as usize > self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ self.kick_fd[index as usize] = fd;
+
+ // Quotation from vhost-user spec:
+ // Client must start ring upon receiving a kick (that is, detecting
+ // that file descriptor is readable) on the descriptor specified by
+ // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
+ // VHOST_USER_GET_VRING_BASE.
+ //
+ // So we should add fd to event monitor(select, poll, epoll) here.
+ self.vring_started[index as usize] = true;
+ Ok(())
+ }
+
+ fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
+ if index as usize >= self.queue_num || index as usize > self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ self.call_fd[index as usize] = fd;
+ Ok(())
+ }
+
+ fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
+ if index as usize >= self.queue_num || index as usize > self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ self.err_fd[index as usize] = fd;
+ Ok(())
+ }
+
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
+ Ok(VhostUserProtocolFeatures::all())
+ }
+
+ fn set_protocol_features(&mut self, features: u64) -> Result<()> {
+ // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must
+ // support this message even before VHOST_USER_SET_FEATURES was
+ // called.
+ // What happens if the master calls set_features() with
+ // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this
+ // interface?
+ self.acked_protocol_features = features;
+ Ok(())
+ }
+
+ fn get_queue_num(&mut self) -> Result<u64> {
+ Ok(MAX_QUEUE_NUM as u64)
+ }
+
+ fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
+ // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
+ // has been negotiated.
+ self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
+
+ if index as usize >= self.queue_num || index as usize > self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+
+ // Slave must not pass data to/from the backend until ring is
+ // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
+ // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
+ // with parameter 0.
+ self.vring_enabled[index as usize] = enable;
+ Ok(())
+ }
+
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ _flags: VhostUserConfigFlags,
+ ) -> Result<Vec<u8>> {
+ self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
+
+ if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
+ || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
+ || size + offset > VHOST_USER_CONFIG_SIZE
+ {
+ return Err(Error::InvalidParam);
+ }
+ assert_eq!(offset, 0x100);
+ assert_eq!(size, 4);
+ Ok(vec![0xa5; size as usize])
+ }
+
+ fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> {
+ let size = buf.len() as u32;
+ self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
+
+ if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
+ || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
+ || size + offset > VHOST_USER_CONFIG_SIZE
+ {
+ return Err(Error::InvalidParam);
+ }
+ assert_eq!(offset, 0x100);
+ assert_eq!(buf.len(), 4);
+ assert_eq!(buf, &[0xa5; 4]);
+ Ok(())
+ }
+
+ fn get_inflight_fd(
+ &mut self,
+ inflight: &VhostUserInflight,
+ ) -> Result<(VhostUserInflight, File)> {
+ let file = tempfile::tempfile().unwrap();
+ self.inflight_file = Some(file.try_clone().unwrap());
+ Ok((
+ VhostUserInflight {
+ mmap_size: 0x1000,
+ mmap_offset: 0,
+ num_queues: inflight.num_queues,
+ queue_size: inflight.queue_size,
+ },
+ file,
+ ))
+ }
+
+ fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> Result<()> {
+ Ok(())
+ }
+
+ fn get_max_mem_slots(&mut self) -> Result<u64> {
+ Ok(MAX_MEM_SLOTS as u64)
+ }
+
+ fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: File) -> Result<()> {
+ Ok(())
+ }
+
+ fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> Result<()> {
+ Ok(())
+ }
+}
diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs
new file mode 100644
index 0000000..feeb984
--- /dev/null
+++ b/src/vhost_user/master.rs
@@ -0,0 +1,1110 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Traits and structs for the vhost-user master.
+
+use std::fs::File;
+use std::mem;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::path::Path;
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use vm_memory::ByteValued;
+use vmm_sys_util::eventfd::EventFd;
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{take_single_file, Error as VhostUserError, Result as VhostUserResult};
+use crate::backend::{
+ VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData,
+};
+use crate::{Error, Result};
+
+/// Trait for the vhost-user master to provide extra methods not yet covered by `VhostBackend`.
+pub trait VhostUserMaster: VhostBackend {
+ /// Get the protocol feature bitmask from the underlying vhost implementation.
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
+
+ /// Enable protocol features in the underlying vhost implementation.
+ fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>;
+
+ /// Query how many queues the backend supports.
+ fn get_queue_num(&mut self) -> Result<u64>;
+
+ /// Signal slave to enable or disable corresponding vring.
+ ///
+ /// Slave must not pass data to/from the backend until ring is enabled by
+ /// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been
+ /// disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
+ fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()>;
+
+ /// Fetch the contents of the virtio device configuration space.
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ flags: VhostUserConfigFlags,
+ buf: &[u8],
+ ) -> Result<(VhostUserConfig, VhostUserConfigPayload)>;
+
+ /// Change the virtio device configuration space. It also can be used for live migration on the
+ /// destination host to set readonly configuration space fields.
+ fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>;
+
+ /// Setup slave communication channel.
+ fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()>;
+
+ /// Retrieve shared buffer for inflight I/O tracking.
+ fn get_inflight_fd(
+ &mut self,
+ inflight: &VhostUserInflight,
+ ) -> Result<(VhostUserInflight, File)>;
+
+ /// Set shared buffer for inflight I/O tracking.
+ fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()>;
+
+ /// Query the maximum amount of memory slots supported by the backend.
+ fn get_max_mem_slots(&mut self) -> Result<u64>;
+
+ /// Add a new guest memory mapping for vhost to use.
+ fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;
+
+ /// Remove a guest memory mapping from vhost.
+ fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;
+}
+
+fn error_code<T>(err: VhostUserError) -> Result<T> {
+ Err(Error::VhostUserProtocol(err))
+}
+
+/// Struct for the vhost-user master endpoint.
+#[derive(Clone)]
+pub struct Master {
+ node: Arc<Mutex<MasterInternal>>,
+}
+
+impl Master {
+ /// Create a new instance.
+ fn new(ep: Endpoint<MasterReq>, max_queue_num: u64) -> Self {
+ Master {
+ node: Arc::new(Mutex::new(MasterInternal {
+ main_sock: ep,
+ virtio_features: 0,
+ acked_virtio_features: 0,
+ protocol_features: 0,
+ acked_protocol_features: 0,
+ protocol_features_ready: false,
+ max_queue_num,
+ error: None,
+ hdr_flags: VhostUserHeaderFlag::empty(),
+ })),
+ }
+ }
+
+ fn node(&self) -> MutexGuard<MasterInternal> {
+ self.node.lock().unwrap()
+ }
+
+ /// Create a new instance from a Unix stream socket.
+ pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
+ Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num)
+ }
+
+ /// Create a new vhost-user master endpoint.
+ ///
+ /// Will retry as the backend may not be ready to accept the connection.
+ ///
+ /// # Arguments
+ /// * `path` - path of Unix domain socket listener to connect to
+ pub fn connect<P: AsRef<Path>>(path: P, max_queue_num: u64) -> Result<Self> {
+ let mut retry_count = 5;
+ let endpoint = loop {
+ match Endpoint::<MasterReq>::connect(&path) {
+ Ok(endpoint) => break Ok(endpoint),
+ Err(e) => match &e {
+ VhostUserError::SocketConnect(why) => {
+ if why.kind() == std::io::ErrorKind::ConnectionRefused && retry_count > 0 {
+ std::thread::sleep(std::time::Duration::from_millis(100));
+ retry_count -= 1;
+ continue;
+ } else {
+ break Err(e);
+ }
+ }
+ _ => break Err(e),
+ },
+ }
+ }?;
+
+ Ok(Self::new(endpoint, max_queue_num))
+ }
+
+ /// Set the header flags that should be applied to all following messages.
+ pub fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) {
+ let mut node = self.node();
+ node.hdr_flags = flags;
+ }
+}
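+
+// Editor's illustrative sketch (not part of the imported crate sources):
+// connecting a master to a backend that is already listening on the given
+// path; `connect` retries briefly on ECONNREFUSED as documented above. The
+// path and queue count are arbitrary examples.
+//
+//     let master = Master::connect("/path/to/backend.sock", 2).unwrap();
+//     master.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); // request explicit acks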
+
+impl VhostBackend for Master {
+ /// Get from the underlying vhost implementation the feature bitmask.
+ fn get_features(&self) -> Result<u64> {
+ let mut node = self.node();
+ let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?;
+ let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+ node.virtio_features = val.value;
+ Ok(node.virtio_features)
+ }
+
+ /// Enable features in the underlying vhost implementation using a bitmask.
+ fn set_features(&self, features: u64) -> Result<()> {
+ let mut node = self.node();
+ let val = VhostUserU64::new(features);
+ let hdr = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?;
+ node.acked_virtio_features = features & node.virtio_features;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ /// Set the current Master as an owner of the session.
+ fn set_owner(&self) -> Result<()> {
+ // We unwrap() the return value to assert that we are not expecting threads to ever fail
+ // while holding the lock.
+ let mut node = self.node();
+ let hdr = node.send_request_header(MasterReq::SET_OWNER, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ let mut node = self.node();
+ let hdr = node.send_request_header(MasterReq::RESET_OWNER, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ /// Set the memory map regions on the slave so it can translate the vring
+ /// addresses. In the ancillary data there is an array of file descriptors.
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let mut ctx = VhostUserMemoryContext::new();
+ for region in regions.iter() {
+ if region.memory_size == 0 || region.mmap_handle < 0 {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ ctx.append(&region.to_region(), region.mmap_handle);
+ }
+
+ let mut node = self.node();
+ let body = VhostUserMemory::new(ctx.regions.len() as u32);
+ // SAFETY: Safe because ctx.regions is a valid Vec() at this point.
+ let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() };
+ let hdr = node.send_request_with_payload(
+ MasterReq::SET_MEM_TABLE,
+ &body,
+ payload,
+ Some(ctx.fds.as_slice()),
+ )?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ // Clippy doesn't seem to know that if let with && is still experimental
+ #[allow(clippy::unnecessary_unwrap)]
+ fn set_log_base(&self, base: u64, region: Option<VhostUserDirtyLogRegion>) -> Result<()> {
+ let mut node = self.node();
+ let val = VhostUserU64::new(base);
+
+ if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0
+ && region.is_some()
+ {
+ let region = region.unwrap();
+ let log = VhostUserLog {
+ mmap_size: region.mmap_size,
+ mmap_offset: region.mmap_offset,
+ };
+ let hdr = node.send_request_with_body(
+ MasterReq::SET_LOG_BASE,
+ &log,
+ Some(&[region.mmap_handle]),
+ )?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ } else {
+ let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, None)?;
+ Ok(())
+ }
+ }
+
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+ let mut node = self.node();
+ let fds = [fd];
+ let hdr = node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ /// Set the size of the queue.
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let val = VhostUserVringState::new(queue_index as u32, num.into());
+ let hdr = node.send_request_with_body(MasterReq::SET_VRING_NUM, &val, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ /// Sets the addresses of the different aspects of the vring.
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num
+ || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0
+ {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data);
+ let hdr = node.send_request_with_body(MasterReq::SET_VRING_ADDR, &val, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ /// Sets the base offset in the available vring.
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let val = VhostUserVringState::new(queue_index as u32, base.into());
+ let hdr = node.send_request_with_body(MasterReq::SET_VRING_BASE, &val, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let req = VhostUserVringState::new(queue_index as u32, 0);
+ let hdr = node.send_request_with_body(MasterReq::GET_VRING_BASE, &req, None)?;
+ let reply = node.recv_reply::<VhostUserVringState>(&hdr)?;
+ Ok(reply.num)
+ }
+
+ /// Set the event file descriptor to signal when buffers are used.
+ /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+ /// is set when there is no file descriptor in the ancillary data. This signals that polling
+ /// will be used instead of waiting for the call.
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+ let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ /// Set the event file descriptor for adding buffers to the vring.
+ /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+ /// is set when there is no file descriptor in the ancillary data. This signals that polling
+ /// should be used instead of waiting for a kick.
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+ let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ /// Set the event file descriptor to signal when error occurs.
+ /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+ /// is set when there is no file descriptor in the ancillary data.
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+ let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+}
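+
+// Editor's illustrative sketch (not part of the imported crate sources): the
+// usual ordering of the VhostBackend calls above when bringing up one queue,
+// assuming `regions`, `vring_cfg`, `call_evt` and `kick_evt` were prepared by
+// the caller (compare the kernel-vhost vsock test earlier in this change).
+//
+//     master.set_owner().unwrap();
+//     let features = master.get_features().unwrap();
+//     master.set_features(features).unwrap();
+//     master.set_mem_table(&regions).unwrap();
+//     master.set_vring_num(0, 256).unwrap();
+//     master.set_vring_addr(0, &vring_cfg).unwrap();
+//     master.set_vring_base(0, 0).unwrap();
+//     master.set_vring_call(0, &call_evt).unwrap();
+//     master.set_vring_kick(0, &kick_evt).unwrap();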
+
+impl VhostUserMaster for Master {
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
+ let mut node = self.node();
+ node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
+ let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?;
+ let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+ node.protocol_features = val.value;
+ // Should we support forward compatibility?
+ // If so, just mask out unrecognized flags instead of returning errors.
+ match VhostUserProtocolFeatures::from_bits(node.protocol_features) {
+ Some(val) => Ok(val),
+ None => error_code(VhostUserError::InvalidMessage),
+ }
+ }
+
+ fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
+ let mut node = self.node();
+ node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
+ let val = VhostUserU64::new(features.bits());
+ let hdr = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?;
+ // Don't wait for ACK here because the protocol feature negotiation process hasn't been
+ // completed yet.
+ node.acked_protocol_features = features.bits();
+ node.protocol_features_ready = true;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn get_queue_num(&mut self) -> Result<u64> {
+ let mut node = self.node();
+ node.check_proto_feature(VhostUserProtocolFeatures::MQ)?;
+
+ let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?;
+ let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+ if val.value > VHOST_USER_MAX_VRINGS {
+ return error_code(VhostUserError::InvalidMessage);
+ }
+ node.max_queue_num = val.value;
+ Ok(node.max_queue_num)
+ }
+
+ fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> {
+ let mut node = self.node();
+ // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled.
+ if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
+ return error_code(VhostUserError::InactiveFeature(
+ VhostUserVirtioFeatures::PROTOCOL_FEATURES,
+ ));
+ } else if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let flag = if enable { 1 } else { 0 };
+ let val = VhostUserVringState::new(queue_index as u32, flag);
+ let hdr = node.send_request_with_body(MasterReq::SET_VRING_ENABLE, &val, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ flags: VhostUserConfigFlags,
+ buf: &[u8],
+ ) -> Result<(VhostUserConfig, VhostUserConfigPayload)> {
+ let body = VhostUserConfig::new(offset, size, flags);
+ if !body.is_valid() {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let mut node = self.node();
+ // depends on VhostUserProtocolFeatures::CONFIG
+ node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
+
+ // vhost-user spec states that:
+ // "Master payload: virtio device config space"
+ // "Slave payload: virtio device config space"
+ let hdr = node.send_request_with_payload(MasterReq::GET_CONFIG, &body, buf, None)?;
+ let (body_reply, buf_reply, rfds) =
+ node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?;
+ if rfds.is_some() {
+ return error_code(VhostUserError::InvalidMessage);
+ } else if body_reply.size == 0 {
+ return error_code(VhostUserError::SlaveInternalError);
+ } else if body_reply.size != body.size
+ || body_reply.size as usize != buf.len()
+ || body_reply.offset != body.offset
+ {
+ return error_code(VhostUserError::InvalidMessage);
+ }
+
+ Ok((body_reply, buf_reply))
+ }
+
+ fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()> {
+ if buf.len() > MAX_MSG_SIZE {
+ return error_code(VhostUserError::InvalidParam);
+ }
+ let body = VhostUserConfig::new(offset, buf.len() as u32, flags);
+ if !body.is_valid() {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let mut node = self.node();
+ // depends on VhostUserProtocolFeatures::CONFIG
+ node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
+
+ let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()> {
+ let mut node = self.node();
+ node.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?;
+ let fds = [fd.as_raw_fd()];
+ let hdr = node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn get_inflight_fd(
+ &mut self,
+ inflight: &VhostUserInflight,
+ ) -> Result<(VhostUserInflight, File)> {
+ let mut node = self.node();
+ node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
+
+ let hdr = node.send_request_with_body(MasterReq::GET_INFLIGHT_FD, inflight, None)?;
+ let (inflight, files) = node.recv_reply_with_files::<VhostUserInflight>(&hdr)?;
+
+ match take_single_file(files) {
+ Some(file) => Ok((inflight, file)),
+ None => error_code(VhostUserError::IncorrectFds),
+ }
+ }
+
+ fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()> {
+ let mut node = self.node();
+ node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
+
+ if inflight.mmap_size == 0 || inflight.num_queues == 0 || inflight.queue_size == 0 || fd < 0
+ {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let hdr = node.send_request_with_body(MasterReq::SET_INFLIGHT_FD, inflight, Some(&[fd]))?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn get_max_mem_slots(&mut self) -> Result<u64> {
+ let mut node = self.node();
+ node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
+
+ let hdr = node.send_request_header(MasterReq::GET_MAX_MEM_SLOTS, None)?;
+ let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+
+ Ok(val.value)
+ }
+
+ fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
+ let mut node = self.node();
+ node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
+ if region.memory_size == 0 || region.mmap_handle < 0 {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let body = region.to_single_region();
+ let fds = [region.mmap_handle];
+ let hdr = node.send_request_with_body(MasterReq::ADD_MEM_REG, &body, Some(&fds))?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
+ let mut node = self.node();
+ node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
+ if region.memory_size == 0 {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let body = region.to_single_region();
+ let hdr = node.send_request_with_body(MasterReq::REM_MEM_REG, &body, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+}
+
+impl AsRawFd for Master {
+ fn as_raw_fd(&self) -> RawFd {
+ let node = self.node();
+ node.main_sock.as_raw_fd()
+ }
+}
+
+/// Context object to pass guest memory configuration to VhostUserMaster::set_mem_table().
+struct VhostUserMemoryContext {
+ regions: VhostUserMemoryPayload,
+ fds: Vec<RawFd>,
+}
+
+impl VhostUserMemoryContext {
+ /// Create a context object.
+ pub fn new() -> Self {
+ VhostUserMemoryContext {
+ regions: VhostUserMemoryPayload::new(),
+ fds: Vec::new(),
+ }
+ }
+
+ /// Append a user memory region and corresponding RawFd into the context object.
+ pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawFd) {
+ self.regions.push(*region);
+ self.fds.push(fd);
+ }
+}
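+
+// A minimal usage sketch (illustrative only; `region` and `fd` stand for a
+// caller-supplied VhostUserMemoryRegion and its backing file descriptor):
+//
+//     let mut ctx = VhostUserMemoryContext::new();
+//     ctx.append(&region, fd);
+//     // ctx.regions / ctx.fds then back the SET_MEM_TABLE request payload.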
+
+struct MasterInternal {
+ // Used to send requests to the slave.
+ main_sock: Endpoint<MasterReq>,
+ // Cached virtio features from the slave.
+ virtio_features: u64,
+ // Cached acked virtio features from the driver.
+ acked_virtio_features: u64,
+ // Cached vhost-user protocol features from the slave.
+ protocol_features: u64,
+ // Cached acked vhost-user protocol features.
+ acked_protocol_features: u64,
+ // Whether the cached vhost-user protocol features are ready to use.
+ protocol_features_ready: bool,
+ // Cached maximum number of queues supported by the slave.
+ max_queue_num: u64,
+ // Internal flag to mark failure state.
+ error: Option<i32>,
+ // List of header flags.
+ hdr_flags: VhostUserHeaderFlag,
+}
+
+impl MasterInternal {
+ fn send_request_header(
+ &mut self,
+ code: MasterReq,
+ fds: Option<&[RawFd]>,
+ ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+ self.check_state()?;
+ let hdr = self.new_request_header(code, 0);
+ self.main_sock.send_header(&hdr, fds)?;
+ Ok(hdr)
+ }
+
+ fn send_request_with_body<T: ByteValued>(
+ &mut self,
+ code: MasterReq,
+ msg: &T,
+ fds: Option<&[RawFd]>,
+ ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ let hdr = self.new_request_header(code, mem::size_of::<T>() as u32);
+ self.main_sock.send_message(&hdr, msg, fds)?;
+ Ok(hdr)
+ }
+
+ fn send_request_with_payload<T: ByteValued>(
+ &mut self,
+ code: MasterReq,
+ msg: &T,
+ payload: &[u8],
+ fds: Option<&[RawFd]>,
+ ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+ let len = mem::size_of::<T>() + payload.len();
+ if len > MAX_MSG_SIZE {
+ return Err(VhostUserError::InvalidParam);
+ }
+ if let Some(fd_arr) = fds {
+ if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
+ return Err(VhostUserError::InvalidParam);
+ }
+ }
+ self.check_state()?;
+
+ let hdr = self.new_request_header(code, len as u32);
+ self.main_sock
+ .send_message_with_payload(&hdr, msg, payload, fds)?;
+ Ok(hdr)
+ }
+
+ fn send_fd_for_vring(
+ &mut self,
+ code: MasterReq,
+ queue_index: usize,
+ fd: RawFd,
+ ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+ if queue_index as u64 >= self.max_queue_num {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ // Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag.
+ // This flag is set when there is no file descriptor in the ancillary data. This signals
+ // that polling will be used instead of waiting for the call.
+ let msg = VhostUserU64::new(queue_index as u64);
+ let hdr = self.new_request_header(code, mem::size_of::<VhostUserU64>() as u32);
+ self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?;
+ Ok(hdr)
+ }
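+
+ // Note: this helper always attaches a file descriptor, so it never needs to
+ // set the "invalid FD" flag (bit 8, i.e. mask 0x100) described above.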
+
+ fn recv_reply<T: ByteValued + Sized + VhostUserMsgValidator>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ ) -> VhostUserResult<T> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ let (reply, body, rfds) = self.main_sock.recv_body::<T>()?;
+ if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
+ return Err(VhostUserError::InvalidMessage);
+ }
+ Ok(body)
+ }
+
+ fn recv_reply_with_files<T: ByteValued + Sized + VhostUserMsgValidator>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ ) -> VhostUserResult<(T, Option<Vec<File>>)> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ let (reply, body, files) = self.main_sock.recv_body::<T>()?;
+ if !reply.is_reply_for(hdr) || files.is_none() || !body.is_valid() {
+ return Err(VhostUserError::InvalidMessage);
+ }
+ Ok((body, files))
+ }
+
+ fn recv_reply_with_payload<T: ByteValued + Sized + VhostUserMsgValidator>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<File>>)> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE
+ || hdr.get_size() as usize <= mem::size_of::<T>()
+ || hdr.get_size() as usize > MAX_MSG_SIZE
+ || hdr.is_reply()
+ {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()];
+ let (reply, body, bytes, files) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
+ if !reply.is_reply_for(hdr)
+ || reply.get_size() as usize != mem::size_of::<T>() + bytes
+ || files.is_some()
+ || !body.is_valid()
+ || bytes != buf.len()
+ {
+ return Err(VhostUserError::InvalidMessage);
+ }
+
+ Ok((body, buf, files))
+ }
+
+ fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() == 0
+ || !hdr.is_need_reply()
+ {
+ return Ok(());
+ }
+ self.check_state()?;
+
+ let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?;
+ if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
+ return Err(VhostUserError::InvalidMessage);
+ }
+ if body.value != 0 {
+ return Err(VhostUserError::SlaveInternalError);
+ }
+ Ok(())
+ }
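+
+ // Note: the ack above is only awaited when the REPLY_ACK protocol feature has
+ // been acked and the request header carries the NEED_REPLY flag; otherwise
+ // wait_for_ack() returns immediately without reading from the socket.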
+
+ fn check_feature(&self, feat: VhostUserVirtioFeatures) -> VhostUserResult<()> {
+ if self.virtio_features & feat.bits() != 0 {
+ Ok(())
+ } else {
+ Err(VhostUserError::InactiveFeature(feat))
+ }
+ }
+
+ fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> VhostUserResult<()> {
+ if self.acked_protocol_features & feat.bits() != 0 {
+ Ok(())
+ } else {
+ Err(VhostUserError::InactiveOperation(feat))
+ }
+ }
+
+ fn check_state(&self) -> VhostUserResult<()> {
+ match self.error {
+ Some(e) => Err(VhostUserError::SocketBroken(
+ std::io::Error::from_raw_os_error(e),
+ )),
+ None => Ok(()),
+ }
+ }
+
+ #[inline]
+ fn new_request_header(&self, request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> {
+ VhostUserMsgHeader::new(request, self.hdr_flags.bits() | 0x1, size)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::super::connection::Listener;
+ use super::*;
+ use vmm_sys_util::rand::rand_alphanumerics;
+
+ use std::path::PathBuf;
+
+ fn temp_path() -> PathBuf {
+ PathBuf::from(format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ ))
+ }
+
+ fn create_pair<P: AsRef<Path>>(path: P) -> (Master, Endpoint<MasterReq>) {
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let master = Master::connect(path, 2).unwrap();
+ let slave = listener.accept().unwrap().unwrap();
+ (master, Endpoint::from_stream(slave))
+ }
+
+ #[test]
+ fn create_master() {
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+
+ let master = Master::connect(&path, 1).unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap());
+
+ assert!(master.as_raw_fd() > 0);
+ // Send two messages continuously
+ master.set_owner().unwrap();
+ master.reset_owner().unwrap();
+
+ let (hdr, rfds) = slave.recv_header().unwrap();
+ assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER);
+ assert_eq!(hdr.get_size(), 0);
+ assert_eq!(hdr.get_version(), 0x1);
+ assert!(rfds.is_none());
+
+ let (hdr, rfds) = slave.recv_header().unwrap();
+ assert_eq!(hdr.get_code().unwrap(), MasterReq::RESET_OWNER);
+ assert_eq!(hdr.get_size(), 0);
+ assert_eq!(hdr.get_version(), 0x1);
+ assert!(rfds.is_none());
+ }
+
+ #[test]
+ fn test_create_failure() {
+ let path = temp_path();
+ let _ = Listener::new(&path, true).unwrap();
+ let _ = Listener::new(&path, false).is_err();
+ assert!(Master::connect(&path, 1).is_err());
+
+ let listener = Listener::new(&path, true).unwrap();
+ assert!(Listener::new(&path, false).is_err());
+ listener.set_nonblocking(true).unwrap();
+
+ let _master = Master::connect(&path, 1).unwrap();
+ let _slave = listener.accept().unwrap().unwrap();
+ }
+
+ #[test]
+ fn test_features() {
+ let path = temp_path();
+ let (master, mut peer) = create_pair(path);
+
+ master.set_owner().unwrap();
+ let (hdr, rfds) = peer.recv_header().unwrap();
+ assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER);
+ assert_eq!(hdr.get_size(), 0);
+ assert_eq!(hdr.get_version(), 0x1);
+ assert!(rfds.is_none());
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(0x15);
+ peer.send_message(&hdr, &msg, None).unwrap();
+ let features = master.get_features().unwrap();
+ assert_eq!(features, 0x15u64);
+ let (_hdr, rfds) = peer.recv_header().unwrap();
+ assert!(rfds.is_none());
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(0x15);
+ peer.send_message(&hdr, &msg, None).unwrap();
+ master.set_features(0x15).unwrap();
+ let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
+ assert!(rfds.is_none());
+ let val = msg.value;
+ assert_eq!(val, 0x15);
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+ let msg = 0x15u32;
+ peer.send_message(&hdr, &msg, None).unwrap();
+ assert!(master.get_features().is_err());
+ }
+
+ #[test]
+ fn test_protocol_features() {
+ let path = temp_path();
+ let (mut master, mut peer) = create_pair(path);
+
+ master.set_owner().unwrap();
+ let (hdr, rfds) = peer.recv_header().unwrap();
+ assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER);
+ assert!(rfds.is_none());
+
+ assert!(master.get_protocol_features().is_err());
+ assert!(master
+ .set_protocol_features(VhostUserProtocolFeatures::all())
+ .is_err());
+
+ let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(vfeatures);
+ peer.send_message(&hdr, &msg, None).unwrap();
+ let features = master.get_features().unwrap();
+ assert_eq!(features, vfeatures);
+ let (_hdr, rfds) = peer.recv_header().unwrap();
+ assert!(rfds.is_none());
+
+ master.set_features(vfeatures).unwrap();
+ let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
+ assert!(rfds.is_none());
+ let val = msg.value;
+ assert_eq!(val, vfeatures);
+
+ let pfeatures = VhostUserProtocolFeatures::all();
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_PROTOCOL_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(pfeatures.bits());
+ peer.send_message(&hdr, &msg, None).unwrap();
+ let features = master.get_protocol_features().unwrap();
+ assert_eq!(features, pfeatures);
+ let (_hdr, rfds) = peer.recv_header().unwrap();
+ assert!(rfds.is_none());
+
+ master.set_protocol_features(pfeatures).unwrap();
+ let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
+ assert!(rfds.is_none());
+ let val = msg.value;
+ assert_eq!(val, pfeatures.bits());
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::SET_PROTOCOL_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(pfeatures.bits());
+ peer.send_message(&hdr, &msg, None).unwrap();
+ assert!(master.get_protocol_features().is_err());
+ }
+
+ #[test]
+ fn test_master_set_config_negative() {
+ let path = temp_path();
+ let (mut master, _peer) = create_pair(path);
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+
+ {
+ let mut node = master.node();
+ node.virtio_features = 0xffff_ffff;
+ node.acked_virtio_features = 0xffff_ffff;
+ node.protocol_features = 0xffff_ffff;
+ node.acked_protocol_features = 0xffff_ffff;
+ }
+
+ master
+ .set_config(0, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap();
+ master
+ .set_config(
+ VHOST_USER_CONFIG_SIZE,
+ VhostUserConfigFlags::WRITABLE,
+ &buf[0..4],
+ )
+ .unwrap_err();
+ master
+ .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+ master
+ .set_config(
+ 0x100,
+ // SAFETY: This is a negative test, so we are setting unexpected flags.
+ unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) },
+ &buf[0..4],
+ )
+ .unwrap_err();
+ master
+ .set_config(VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE, &buf)
+ .unwrap_err();
+ master
+ .set_config(VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE, &[])
+ .unwrap_err();
+ }
+
+ fn create_pair2() -> (Master, Endpoint<MasterReq>) {
+ let path = temp_path();
+ let (master, peer) = create_pair(path);
+
+ {
+ let mut node = master.node();
+ node.virtio_features = 0xffff_ffff;
+ node.acked_virtio_features = 0xffff_ffff;
+ node.protocol_features = 0xffff_ffff;
+ node.acked_protocol_features = 0xffff_ffff;
+ }
+
+ (master, peer)
+ }
+
+ #[test]
+ fn test_master_get_config_negative0() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ hdr.set_code(MasterReq::GET_FEATURES);
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ hdr.set_code(MasterReq::GET_CONFIG);
+ }
+
+ #[test]
+ fn test_master_get_config_negative1() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ hdr.set_reply(false);
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative2() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+ }
+
+ #[test]
+ fn test_master_get_config_negative3() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = 0;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative4() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = 0x101;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative5() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = (MAX_MSG_SIZE + 1) as u32;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative6() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.size = 6;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_set_mem_table_failure() {
+ let (master, _peer) = create_pair2();
+
+ master.set_mem_table(&[]).unwrap_err();
+ let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1];
+ master.set_mem_table(&tables).unwrap_err();
+ }
+}
diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs
new file mode 100644
index 0000000..c9c528b
--- /dev/null
+++ b/src/vhost_user/master_req_handler.rs
@@ -0,0 +1,466 @@
+// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::fs::File;
+use std::mem;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::sync::{Arc, Mutex};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error, HandlerResult, Result};
+
+/// Define services provided by masters for the slave communication channel.
+///
+/// The vhost-user specification defines a slave communication channel, by which slaves can
+/// request services from masters. The [VhostUserMasterReqHandler] trait defines the services
+/// provided by masters, and it's used on both the master side and the slave side.
+/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] proxies
+/// service requests to the master. The [Slave] is an example stub forwarder.
+/// - on the master side, the [MasterReqHandler] forwards service requests to a handler
+/// implementing [VhostUserMasterReqHandler].
+///
+/// The [VhostUserMasterReqHandler] trait is designed with interior mutability to improve
+/// performance in multi-threaded contexts.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [Slave]: struct.Slave.html
+pub trait VhostUserMasterReqHandler {
+ /// Handle device configuration change notifications.
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs map file requests.
+ fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs unmap file requests.
+ fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs sync file requests.
+ fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs file IO requests.
+ fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawFd);
+}
+
+/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+pub trait VhostUserMasterReqHandlerMut {
+ /// Handle device configuration change notifications.
+ fn handle_config_change(&mut self) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs map file requests.
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs unmap file requests.
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs sync file requests.
+ fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs file IO requests.
+ fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+}
+
+impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ self.lock().unwrap().handle_config_change()
+ }
+
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_map(fs, fd)
+ }
+
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_unmap(fs)
+ }
+
+ fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_sync(fs)
+ }
+
+ fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_io(fs, fd)
+ }
+}
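+
+// A minimal backend sketch (illustrative only; `MyBackend` is a hypothetical
+// type): implement VhostUserMasterReqHandlerMut for the backend and wrap it in
+// a `Mutex` so that the blanket impl above provides VhostUserMasterReqHandler.
+//
+//     struct MyBackend;
+//     impl VhostUserMasterReqHandlerMut for MyBackend {
+//         fn handle_config_change(&mut self) -> HandlerResult<u64> { Ok(0) }
+//     }
+//     let backend = std::sync::Arc::new(std::sync::Mutex::new(MyBackend));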
+
+/// Server to handle service requests from slaves over the slave communication channel.
+///
+/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from
+/// slaves on the slave communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserMasterReqHandler] to do the real work.
+///
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
+ // underlying Unix domain socket for communication
+ sub_sock: Endpoint<SlaveReq>,
+ tx_sock: UnixStream,
+ // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
+ reply_ack_negotiated: bool,
+ // the VirtIO backend device object
+ backend: Arc<S>,
+ // whether the endpoint has encountered any failure
+ error: Option<i32>,
+}
+
+impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
+ /// Create a server to handle service requests from slaves on the slave communication channel.
+ ///
+ /// This opens a pair of connected anonymous sockets to form the slave communication channel.
+ /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by
+ /// [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
+ pub fn new(backend: Arc<S>) -> Result<Self> {
+ let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
+
+ Ok(MasterReqHandler {
+ sub_sock: Endpoint::<SlaveReq>::from_stream(rx),
+ tx_sock: tx,
+ reply_ack_negotiated: false,
+ backend,
+ error: None,
+ })
+ }
+
+ /// Get the socket fd for the slave to communicate with the master.
+ ///
+ /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
+ pub fn get_tx_raw_fd(&self) -> RawFd {
+ self.tx_sock.as_raw_fd()
+ }
+
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, the
+ /// handler replies with an ACK message to every slave-to-master request that carries the
+ /// NEED_REPLY flag in its message header.
+ pub fn set_reply_ack_flag(&mut self, enable: bool) {
+ self.reply_ack_negotiated = enable;
+ }
+
+ /// Mark the endpoint as failed, or restore it to the normal state when `error` is 0.
+ pub fn set_failed(&mut self, error: i32) {
+ if error == 0 {
+ self.error = None;
+ } else {
+ self.error = Some(error);
+ }
+ }
+
+ /// Main entry point to serve requests arriving on the slave communication channel.
+ ///
+ /// The caller needs to:
+ /// - serialize calls to this function
+ /// - decide what to do when an error happens
+ /// - optionally recover from failure
+ ///
+ /// A usage sketch follows the function body below.
+ pub fn handle_request(&mut self) -> Result<u64> {
+ // Return error if the endpoint is already in failed state.
+ self.check_state()?;
+
+ // The underlying communication channel is a Unix domain socket in
+ // stream mode, and recvmsg() is a little tricky here. To successfully
+ // receive attached file descriptors, we need to receive messages and
+ // corresponding attached file descriptors in this way:
+ // . recv message header and optional attached file
+ // . validate message header
+ // . recv optional message body and payload according to the size field
+ // in the message header
+ // . validate message body and optional payload
+ let (hdr, files) = self.sub_sock.recv_header()?;
+ self.check_attached_files(&hdr, &files)?;
+ let (size, buf) = match hdr.get_size() {
+ 0 => (0, vec![0u8; 0]),
+ len => {
+ if len as usize > MAX_MSG_SIZE {
+ return Err(Error::InvalidMessage);
+ }
+ let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?;
+ if size2 != len as usize {
+ return Err(Error::InvalidMessage);
+ }
+ (size2, rbuf)
+ }
+ };
+
+ let res = match hdr.get_code() {
+ Ok(SlaveReq::CONFIG_CHANGE_MSG) => {
+ self.check_msg_size(&hdr, size, 0)?;
+ self.backend
+ .handle_config_change()
+ .map_err(Error::ReqHandlerError)
+ }
+ Ok(SlaveReq::FS_MAP) => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_files() has validated files
+ self.backend
+ .fs_slave_map(&msg, &files.unwrap()[0])
+ .map_err(Error::ReqHandlerError)
+ }
+ Ok(SlaveReq::FS_UNMAP) => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ self.backend
+ .fs_slave_unmap(&msg)
+ .map_err(Error::ReqHandlerError)
+ }
+ Ok(SlaveReq::FS_SYNC) => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ self.backend
+ .fs_slave_sync(&msg)
+ .map_err(Error::ReqHandlerError)
+ }
+ Ok(SlaveReq::FS_IO) => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_files() has validated files
+ self.backend
+ .fs_slave_io(&msg, &files.unwrap()[0])
+ .map_err(Error::ReqHandlerError)
+ }
+ _ => Err(Error::InvalidMessage),
+ };
+
+ self.send_ack_message(&hdr, &res)?;
+
+ res
+ }
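+
+ // A minimal driving-loop sketch (illustrative only; wiring the returned fd to
+ // the slave and integrating with an event loop are left to the caller):
+ //
+ //     let mut handler = MasterReqHandler::new(backend)?;
+ //     let tx_fd = handler.get_tx_raw_fd();
+ //     // send tx_fd to the slave, e.g. via VhostUserMaster::set_slave_request_fd()
+ //     loop {
+ //         if handler.handle_request().is_err() {
+ //             break;
+ //         }
+ //     }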
+
+ fn check_state(&self) -> Result<()> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(()),
+ }
+ }
+
+ fn check_msg_size(
+ &self,
+ hdr: &VhostUserMsgHeader<SlaveReq>,
+ size: usize,
+ expected: usize,
+ ) -> Result<()> {
+ if hdr.get_size() as usize != expected
+ || hdr.is_reply()
+ || hdr.get_version() != 0x1
+ || size != expected
+ {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(())
+ }
+
+ fn check_attached_files(
+ &self,
+ hdr: &VhostUserMsgHeader<SlaveReq>,
+ files: &Option<Vec<File>>,
+ ) -> Result<()> {
+ match hdr.get_code() {
+ Ok(SlaveReq::FS_MAP | SlaveReq::FS_IO) => {
+ // Expect that a single file is passed.
+ match files {
+ Some(files) if files.len() == 1 => Ok(()),
+ _ => Err(Error::InvalidMessage),
+ }
+ }
+ _ if files.is_some() => Err(Error::InvalidMessage),
+ _ => Ok(()),
+ }
+ }
+
+ fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
+ &self,
+ hdr: &VhostUserMsgHeader<SlaveReq>,
+ size: usize,
+ buf: &[u8],
+ ) -> Result<T> {
+ self.check_msg_size(hdr, size, mem::size_of::<T>())?;
+ // SAFETY: Safe because we checked that `buf` size is equal to T size.
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(msg)
+ }
+
+ fn new_reply_header<T: Sized>(
+ &self,
+ req: &VhostUserMsgHeader<SlaveReq>,
+ ) -> Result<VhostUserMsgHeader<SlaveReq>> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::InvalidParam);
+ }
+ self.check_state()?;
+ Ok(VhostUserMsgHeader::new(
+ req.get_code()?,
+ VhostUserHeaderFlag::REPLY.bits(),
+ mem::size_of::<T>() as u32,
+ ))
+ }
+
+ fn send_ack_message(
+ &mut self,
+ req: &VhostUserMsgHeader<SlaveReq>,
+ res: &Result<u64>,
+ ) -> Result<()> {
+ if self.reply_ack_negotiated && req.is_need_reply() {
+ let hdr = self.new_reply_header::<VhostUserU64>(req)?;
+ let def_err = libc::EINVAL;
+ let val = match res {
+ Ok(n) => *n,
+ Err(e) => match e {
+ Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() {
+ Some(rawerr) => -rawerr as u64,
+ None => -def_err as u64,
+ },
+ _ => -def_err as u64,
+ },
+ };
+ let msg = VhostUserU64::new(val);
+ self.sub_sock.send_message(&hdr, &msg, None)?;
+ }
+ Ok(())
+ }
+}
+
+impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sub_sock.as_raw_fd()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[cfg(feature = "vhost-user-slave")]
+ use crate::vhost_user::Slave;
+ #[cfg(feature = "vhost-user-slave")]
+ use std::os::unix::io::FromRawFd;
+
+ struct MockMasterReqHandler {}
+
+ impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
+ /// Handle virtio-fs map file requests from the slave.
+ fn fs_slave_map(
+ &mut self,
+ _fs: &VhostUserFSSlaveMsg,
+ _fd: &dyn AsRawFd,
+ ) -> HandlerResult<u64> {
+ Ok(0)
+ }
+
+ /// Handle virtio-fs unmap file requests from the slave.
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+ }
+
+ #[test]
+ fn test_new_master_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ assert!(handler.get_tx_raw_fd() >= 0);
+ assert!(handler.as_raw_fd() >= 0);
+ handler.check_state().unwrap();
+
+ assert_eq!(handler.error, None);
+ handler.set_failed(libc::EAGAIN);
+ assert_eq!(handler.error, Some(libc::EAGAIN));
+ handler.check_state().unwrap_err();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ // SAFETY: Safe because `handler` contains valid fds, and we are
+ // checking if `dup` returns a valid fd.
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicated tx fd!");
+ }
+ // SAFETY: Safe because we checked if fd is valid.
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let slave = Slave::from_stream(stream);
+
+ std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ slave
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd)
+ .unwrap();
+ // When REPLY_ACK has not been negotiated, the master has no way to detect failure from
+ // slave side.
+ slave
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler_with_ack() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+ handler.set_reply_ack_flag(true);
+
+ // SAFETY: Safe because `handler` contains valid fds, and we are
+ // checking if `dup` returns a valid fd.
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicated tx fd!");
+ }
+ // SAFETY: Safe because we checked if fd is valid.
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let slave = Slave::from_stream(stream);
+
+ std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ slave.set_reply_ack_flag(true);
+ slave
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd)
+ .unwrap();
+ slave
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ }
+}
diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs
new file mode 100644
index 0000000..bbd8eb9
--- /dev/null
+++ b/src/vhost_user/message.rs
@@ -0,0 +1,1403 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Define communication messages for the vhost-user protocol.
+//!
+//! For message definition, please refer to the [vhost-user spec](https://qemu.readthedocs.io/en/latest/interop/vhost-user.html).
+
+#![allow(dead_code)]
+#![allow(non_camel_case_types)]
+#![allow(clippy::upper_case_acronyms)]
+
+use std::fmt::Debug;
+use std::fs::File;
+use std::io;
+use std::marker::PhantomData;
+use std::ops::Deref;
+
+use vm_memory::{mmap::NewBitmap, ByteValued, Error as MmapError, FileOffset, MmapRegion};
+
+#[cfg(feature = "xen")]
+use vm_memory::{GuestAddress, MmapRange, MmapXenFlags};
+
+use super::{Error, Result};
+use crate::VringConfigData;
+
+/// The vhost-user specification uses a field of u32 to store message length.
+/// On the other hand, preallocated buffers are needed to receive messages from the Unix domain
+/// socket. Preallocating a 4GB buffer for each vhost-user message would be pure overhead.
+/// Among all defined vhost-user messages, only VhostUserConfig and VhostUserMemory have variable
+/// message sizes. For VhostUserConfig, a maximum size of 4K is enough because the user
+/// configuration space for virtio devices is (4K - 0x100) bytes at most. For the VhostUserMemory,
+/// 4K should be enough too because it can support 255 memory regions at most.
+pub const MAX_MSG_SIZE: usize = 0x1000;
+
+/// The VhostUserMemory message has a variable message size and a variable number of attached file
+/// descriptors. Each user memory region entry in the message payload occupies 32 bytes, so the
+/// maximum number of attached file descriptors could be derived from the maximum message size.
+/// But Rust only implements the Default and AsMut traits for arrays with 0 - 32 entries, so the
+/// maximum number is further reduced to 32.
+// pub const MAX_ATTACHED_FD_ENTRIES: usize = (MAX_MSG_SIZE - 8) / 32;
+pub const MAX_ATTACHED_FD_ENTRIES: usize = 32;
+
+/// Starting position (inclusion) of the device configuration space in virtio devices.
+pub const VHOST_USER_CONFIG_OFFSET: u32 = 0x100;
+
+/// Ending position (exclusion) of the device configuration space in virtio devices.
+pub const VHOST_USER_CONFIG_SIZE: u32 = 0x1000;
+
+/// Maximum number of vrings supported.
+pub const VHOST_USER_MAX_VRINGS: u64 = 0x8000u64;
+
+pub(super) trait Req:
+ Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Send + Sync + Into<u32>
+{
+ fn is_valid(value: u32) -> bool;
+}
+
+/// Type of requests sent from masters to slaves.
+#[repr(u32)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum MasterReq {
+ /// Null operation.
+ NOOP = 0,
+ /// Get from the underlying vhost implementation the features bit mask.
+ GET_FEATURES = 1,
+ /// Enable features in the underlying vhost implementation using a bit mask.
+ SET_FEATURES = 2,
+ /// Set the current Master as an owner of the session.
+ SET_OWNER = 3,
+ /// No longer used.
+ RESET_OWNER = 4,
+ /// Set the memory map regions on the slave so it can translate the vring addresses.
+ SET_MEM_TABLE = 5,
+ /// Set logging shared memory space.
+ SET_LOG_BASE = 6,
+ /// Set the logging file descriptor, which is passed as ancillary data.
+ SET_LOG_FD = 7,
+ /// Set the size of the queue.
+ SET_VRING_NUM = 8,
+ /// Set the addresses of the different aspects of the vring.
+ SET_VRING_ADDR = 9,
+ /// Set the base offset in the available vring.
+ SET_VRING_BASE = 10,
+ /// Get the available vring base offset.
+ GET_VRING_BASE = 11,
+ /// Set the event file descriptor for adding buffers to the vring.
+ SET_VRING_KICK = 12,
+ /// Set the event file descriptor to signal when buffers are used.
+ SET_VRING_CALL = 13,
+ /// Set the event file descriptor to signal when error occurs.
+ SET_VRING_ERR = 14,
+ /// Get the protocol feature bit mask from the underlying vhost implementation.
+ GET_PROTOCOL_FEATURES = 15,
+ /// Enable protocol features in the underlying vhost implementation.
+ SET_PROTOCOL_FEATURES = 16,
+ /// Query how many queues the backend supports.
+ GET_QUEUE_NUM = 17,
+ /// Signal slave to enable or disable corresponding vring.
+ SET_VRING_ENABLE = 18,
+ /// Ask vhost user backend to broadcast a fake RARP to notify the migration is terminated
+ /// for guest that does not support GUEST_ANNOUNCE.
+ SEND_RARP = 19,
+ /// Set host MTU value exposed to the guest.
+ NET_SET_MTU = 20,
+ /// Set the socket file descriptor for slave initiated requests.
+ SET_SLAVE_REQ_FD = 21,
+ /// Send IOTLB messages with struct vhost_iotlb_msg as payload.
+ IOTLB_MSG = 22,
+ /// Set the endianness of a VQ for legacy devices.
+ SET_VRING_ENDIAN = 23,
+ /// Fetch the contents of the virtio device configuration space.
+ GET_CONFIG = 24,
+ /// Change the contents of the virtio device configuration space.
+ SET_CONFIG = 25,
+ /// Create a session for crypto operation.
+ CREATE_CRYPTO_SESSION = 26,
+ /// Close a session for crypto operation.
+ CLOSE_CRYPTO_SESSION = 27,
+ /// Advise slave that a migration with postcopy enabled is underway.
+ POSTCOPY_ADVISE = 28,
+ /// Advise slave that a transition to postcopy mode has happened.
+ POSTCOPY_LISTEN = 29,
+ /// Advise that postcopy migration has now completed.
+ POSTCOPY_END = 30,
+ /// Get a shared buffer from slave.
+ GET_INFLIGHT_FD = 31,
+ /// Send the shared inflight buffer back to slave.
+ SET_INFLIGHT_FD = 32,
+ /// Sets the GPU protocol socket file descriptor.
+ GPU_SET_SOCKET = 33,
+ /// Ask the vhost user backend to disable all rings and reset all internal
+ /// device state to the initial state.
+ RESET_DEVICE = 34,
+ /// Indicate that a buffer was added to the vring instead of signalling it
+ /// using the vring’s kick file descriptor.
+ VRING_KICK = 35,
+ /// Return a u64 payload containing the maximum number of memory slots.
+ GET_MAX_MEM_SLOTS = 36,
+ /// Update the memory tables by adding the region described.
+ ADD_MEM_REG = 37,
+ /// Update the memory tables by removing the region described.
+ REM_MEM_REG = 38,
+ /// Notify the backend with updated device status as defined in the VIRTIO
+ /// specification.
+ SET_STATUS = 39,
+ /// Query the backend for its device status as defined in the VIRTIO
+ /// specification.
+ GET_STATUS = 40,
+ /// Upper bound of valid commands.
+ MAX_CMD = 41,
+}
+
+impl From<MasterReq> for u32 {
+ fn from(req: MasterReq) -> u32 {
+ req as u32
+ }
+}
+
+impl Req for MasterReq {
+ fn is_valid(value: u32) -> bool {
+ (value > MasterReq::NOOP as u32) && (value < MasterReq::MAX_CMD as u32)
+ }
+}
+
+/// Type of requests sent from slaves to masters.
+#[repr(u32)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum SlaveReq {
+ /// Null operation.
+ NOOP = 0,
+ /// Send IOTLB messages with struct vhost_iotlb_msg as payload.
+ IOTLB_MSG = 1,
+ /// Notify that the virtio device's configuration space has changed.
+ CONFIG_CHANGE_MSG = 2,
+ /// Set host notifier for a specified queue.
+ VRING_HOST_NOTIFIER_MSG = 3,
+ /// Indicate that a buffer was used from the vring.
+ VRING_CALL = 4,
+ /// Indicate that an error occurred on the specific vring.
+ VRING_ERR = 5,
+ /// Virtio-fs draft: map file content into the window.
+ FS_MAP = 6,
+ /// Virtio-fs draft: unmap file content from the window.
+ FS_UNMAP = 7,
+ /// Virtio-fs draft: sync file content.
+ FS_SYNC = 8,
+ /// Virtio-fs draft: perform a read/write from an fd directly to GPA.
+ FS_IO = 9,
+ /// Upper bound of valid commands.
+ MAX_CMD = 10,
+}
+
+impl From<SlaveReq> for u32 {
+ fn from(req: SlaveReq) -> u32 {
+ req as u32
+ }
+}
+
+impl Req for SlaveReq {
+ fn is_valid(value: u32) -> bool {
+ (value > SlaveReq::NOOP as u32) && (value < SlaveReq::MAX_CMD as u32)
+ }
+}
+
+/// Vhost message validator.
+pub trait VhostUserMsgValidator {
+ /// Validate message syntax only.
+ /// It doesn't validate message semantics, such as the protocol version number or
+ /// dependencies on feature flags.
+ fn is_valid(&self) -> bool {
+ true
+ }
+}
+
+// Bit mask for common message flags.
+bitflags! {
+ /// Common message flags for vhost-user requests and replies.
+ pub struct VhostUserHeaderFlag: u32 {
+ /// Bits[0..2] is message version number.
+ const VERSION = 0x3;
+ /// Mark message as reply.
+ const REPLY = 0x4;
+ /// Sender anticipates a reply message from the peer.
+ const NEED_REPLY = 0x8;
+ /// All valid bits.
+ const ALL_FLAGS = 0xc;
+ /// All reserved bits.
+ const RESERVED_BITS = !0xf;
+ }
+}
+
+/// Common message header for vhost-user requests and replies.
+/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the
+/// machine native byte order.
+#[repr(packed)]
+#[derive(Copy)]
+pub(super) struct VhostUserMsgHeader<R: Req> {
+ request: u32,
+ flags: u32,
+ size: u32,
+ _r: PhantomData<R>,
+}
+
+impl<R: Req> Debug for VhostUserMsgHeader<R> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("VhostUserMsgHeader")
+ .field("request", &{ self.request })
+ .field("flags", &{ self.flags })
+ .field("size", &{ self.size })
+ .finish()
+ }
+}
+
+impl<R: Req> Clone for VhostUserMsgHeader<R> {
+ fn clone(&self) -> VhostUserMsgHeader<R> {
+ *self
+ }
+}
+
+impl<R: Req> PartialEq for VhostUserMsgHeader<R> {
+ fn eq(&self, other: &Self) -> bool {
+ self.request == other.request && self.flags == other.flags && self.size == other.size
+ }
+}
+
+impl<R: Req> VhostUserMsgHeader<R> {
+ /// Create a new instance of `VhostUserMsgHeader`.
+ pub fn new(request: R, flags: u32, size: u32) -> Self {
+ // Default to protocol version 1
+ let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1;
+ VhostUserMsgHeader {
+ request: request.into(),
+ flags: fl,
+ size,
+ _r: PhantomData,
+ }
+ }
+
+ /// Get message type.
+ pub fn get_code(&self) -> Result<R> {
+ if R::is_valid(self.request) {
+ // SAFETY: It's safe because R is marked as repr(u32), and the value is valid.
+ Ok(unsafe { std::mem::transmute_copy::<u32, R>(&{ self.request }) })
+ } else {
+ Err(Error::InvalidMessage)
+ }
+ }
+
+ /// Set message type.
+ pub fn set_code(&mut self, request: R) {
+ self.request = request.into();
+ }
+
+ /// Get message version number.
+ pub fn get_version(&self) -> u32 {
+ self.flags & 0x3
+ }
+
+ /// Set message version number.
+ pub fn set_version(&mut self, ver: u32) {
+ self.flags &= !0x3;
+ self.flags |= ver & 0x3;
+ }
+
+ /// Check whether it's a reply message.
+ pub fn is_reply(&self) -> bool {
+ (self.flags & VhostUserHeaderFlag::REPLY.bits()) != 0
+ }
+
+ /// Mark message as reply.
+ pub fn set_reply(&mut self, is_reply: bool) {
+ if is_reply {
+ self.flags |= VhostUserHeaderFlag::REPLY.bits();
+ } else {
+ self.flags &= !VhostUserHeaderFlag::REPLY.bits();
+ }
+ }
+
+ /// Check whether reply for this message is requested.
+ pub fn is_need_reply(&self) -> bool {
+ (self.flags & VhostUserHeaderFlag::NEED_REPLY.bits()) != 0
+ }
+
+ /// Mark that reply for this message is needed.
+ pub fn set_need_reply(&mut self, need_reply: bool) {
+ if need_reply {
+ self.flags |= VhostUserHeaderFlag::NEED_REPLY.bits();
+ } else {
+ self.flags &= !VhostUserHeaderFlag::NEED_REPLY.bits();
+ }
+ }
+
+ /// Check whether it's the reply message for the request `req`.
+ pub fn is_reply_for(&self, req: &VhostUserMsgHeader<R>) -> bool {
+ if let (Ok(code1), Ok(code2)) = (self.get_code(), req.get_code()) {
+ self.is_reply() && !req.is_reply() && code1 == code2
+ } else {
+ false
+ }
+ }
+
+ /// Get message size.
+ pub fn get_size(&self) -> u32 {
+ self.size
+ }
+
+ /// Set message size.
+ pub fn set_size(&mut self, size: u32) {
+ self.size = size;
+ }
+}
+
+impl<R: Req> Default for VhostUserMsgHeader<R> {
+ fn default() -> Self {
+ VhostUserMsgHeader {
+ request: 0,
+ flags: 0x1,
+ size: 0,
+ _r: PhantomData,
+ }
+ }
+}
+
+// SAFETY: Safe because all fields of VhostUserMsgHeader are POD.
+unsafe impl<R: Req> ByteValued for VhostUserMsgHeader<R> {}
+
+impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> {
+ #[allow(clippy::if_same_then_else)]
+ fn is_valid(&self) -> bool {
+ if self.get_code().is_err() {
+ return false;
+ } else if self.size as usize > MAX_MSG_SIZE {
+ return false;
+ } else if self.get_version() != 0x1 {
+ return false;
+ } else if (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) != 0 {
+ return false;
+ }
+ true
+ }
+}
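+
+// A short illustrative sketch of header construction and inspection, based on
+// the accessors above (values mirror those used in the unit tests):
+//
+//     let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+//     assert_eq!(hdr.get_version(), 0x1); // version 1 is set by default
+//     assert!(hdr.is_reply());            // flag 0x4 marks a reply
+//     hdr.set_need_reply(true);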
+
+// Bit mask for transport specific flags in VirtIO feature set defined by vhost-user.
+bitflags! {
+ /// Transport specific flags in VirtIO feature set defined by vhost-user.
+ pub struct VhostUserVirtioFeatures: u64 {
+ /// Feature flag for the protocol feature.
+ const PROTOCOL_FEATURES = 0x4000_0000;
+ }
+}
+
+// Bit mask for vhost-user protocol feature flags.
+bitflags! {
+ /// Vhost-user protocol feature flags.
+ pub struct VhostUserProtocolFeatures: u64 {
+ /// Support multiple queues.
+ const MQ = 0x0000_0001;
+ /// Support logging through shared memory fd.
+ const LOG_SHMFD = 0x0000_0002;
+ /// Support broadcasting fake RARP packet.
+ const RARP = 0x0000_0004;
+ /// Support sending reply messages for requests with NEED_REPLY flag set.
+ const REPLY_ACK = 0x0000_0008;
+ /// Support setting MTU for virtio-net devices.
+ const MTU = 0x0000_0010;
+ /// Allow the slave to send requests to the master by an optional communication channel.
+ const SLAVE_REQ = 0x0000_0020;
+ /// Support setting slave endian by SET_VRING_ENDIAN.
+ const CROSS_ENDIAN = 0x0000_0040;
+ /// Support crypto operations.
+ const CRYPTO_SESSION = 0x0000_0080;
+ /// Support sending userfault_fd from slaves to masters.
+ const PAGEFAULT = 0x0000_0100;
+ /// Support Virtio device configuration.
+ const CONFIG = 0x0000_0200;
+ /// Allow the slave to send fds (at most 8 descriptors in each message) to the master.
+ const SLAVE_SEND_FD = 0x0000_0400;
+ /// Allow the slave to register a host notifier.
+ const HOST_NOTIFIER = 0x0000_0800;
+ /// Support inflight shmfd.
+ const INFLIGHT_SHMFD = 0x0000_1000;
+ /// Support resetting the device.
+ const RESET_DEVICE = 0x0000_2000;
+ /// Support inband notifications.
+ const INBAND_NOTIFICATIONS = 0x0000_4000;
+ /// Support configuring memory slots.
+ const CONFIGURE_MEM_SLOTS = 0x0000_8000;
+ /// Support reporting status.
+ const STATUS = 0x0001_0000;
+ /// Support Xen mmap.
+ const XEN_MMAP = 0x0002_0000;
+ }
+}
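+
+// A small illustrative sketch of combining and testing protocol feature bits
+// using standard bitflags operations (values follow the definitions above):
+//
+//     let features = VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::REPLY_ACK;
+//     assert!(features.contains(VhostUserProtocolFeatures::REPLY_ACK));
+//     assert_eq!(features.bits(), 0x0000_0009);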
+
+/// A generic message to encapsulate a 64-bit value.
+#[repr(packed)]
+#[derive(Copy, Clone, Default)]
+pub struct VhostUserU64 {
+ /// The encapsulated 64-bit common value.
+ pub value: u64,
+}
+
+impl VhostUserU64 {
+ /// Create a new instance.
+ pub fn new(value: u64) -> Self {
+ VhostUserU64 { value }
+ }
+}
+
+// SAFETY: Safe because all fields of VhostUserU64 are POD.
+unsafe impl ByteValued for VhostUserU64 {}
+
+impl VhostUserMsgValidator for VhostUserU64 {}
+
+/// Memory region descriptor for the SET_MEM_TABLE request.
+#[repr(packed)]
+#[derive(Copy, Clone, Default)]
+pub struct VhostUserMemory {
+ /// Number of memory regions in the payload.
+ pub num_regions: u32,
+ /// Padding for alignment.
+ pub padding1: u32,
+}
+
+impl VhostUserMemory {
+ /// Create a new instance.
+ pub fn new(cnt: u32) -> Self {
+ VhostUserMemory {
+ num_regions: cnt,
+ padding1: 0,
+ }
+ }
+}
+
+// SAFETY: Safe because all fields of VhostUserMemory are POD.
+unsafe impl ByteValued for VhostUserMemory {}
+
+impl VhostUserMsgValidator for VhostUserMemory {
+ #[allow(clippy::if_same_then_else)]
+ fn is_valid(&self) -> bool {
+ if self.padding1 != 0 {
+ return false;
+ } else if self.num_regions == 0 || self.num_regions > MAX_ATTACHED_FD_ENTRIES as u32 {
+ return false;
+ }
+ true
+ }
+}
+
+/// Memory region descriptors as payload for the SET_MEM_TABLE request.
+#[repr(packed)]
+#[derive(Default, Clone, Copy)]
+pub struct VhostUserMemoryRegion {
+ /// Guest physical address of the memory region.
+ pub guest_phys_addr: u64,
+ /// Size of the memory region.
+ pub memory_size: u64,
+ /// Virtual address in the current process.
+ pub user_addr: u64,
+ /// Offset where region starts in the mapped memory.
+ pub mmap_offset: u64,
+
+ #[cfg(feature = "xen")]
+ /// Xen specific flags.
+ pub xen_mmap_flags: u32,
+
+ #[cfg(feature = "xen")]
+ /// Xen specific data.
+ pub xen_mmap_data: u32,
+}
+
+impl VhostUserMemoryRegion {
+ fn is_valid_common(&self) -> bool {
+ self.memory_size != 0
+ && self.guest_phys_addr.checked_add(self.memory_size).is_some()
+ && self.user_addr.checked_add(self.memory_size).is_some()
+ && self.mmap_offset.checked_add(self.memory_size).is_some()
+ }
+}
+
+#[cfg(not(feature = "xen"))]
+impl VhostUserMemoryRegion {
+ /// Create a new instance.
+ pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self {
+ VhostUserMemoryRegion {
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ }
+ }
+
+ /// Creates mmap region from Self.
+ pub fn mmap_region<B: NewBitmap>(&self, file: File) -> Result<MmapRegion<B>> {
+ MmapRegion::<B>::from_file(
+ FileOffset::new(file, self.mmap_offset),
+ self.memory_size as usize,
+ )
+ .map_err(MmapError::MmapRegion)
+ .map_err(|e| Error::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)))
+ }
+
+ fn is_valid(&self) -> bool {
+ self.is_valid_common()
+ }
+}
+
+#[cfg(feature = "xen")]
+impl VhostUserMemoryRegion {
+ /// Create a new instance.
+ pub fn with_xen(
+ guest_phys_addr: u64,
+ memory_size: u64,
+ user_addr: u64,
+ mmap_offset: u64,
+ xen_mmap_flags: u32,
+ xen_mmap_data: u32,
+ ) -> Self {
+ VhostUserMemoryRegion {
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ xen_mmap_flags,
+ xen_mmap_data,
+ }
+ }
+
+ /// Creates mmap region from Self.
+ pub fn mmap_region<B: NewBitmap>(&self, file: File) -> Result<MmapRegion<B>> {
+ let range = MmapRange::new(
+ self.memory_size as usize,
+ Some(FileOffset::new(file, self.mmap_offset)),
+ GuestAddress(self.guest_phys_addr),
+ self.xen_mmap_flags,
+ self.xen_mmap_data,
+ );
+
+ MmapRegion::<B>::from_range(range)
+ .map_err(MmapError::MmapRegion)
+ .map_err(|e| Error::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)))
+ }
+
+ fn is_valid(&self) -> bool {
+ if !self.is_valid_common() {
+ false
+ } else {
+ // Only one of FOREIGN or GRANT should be set.
+ match MmapXenFlags::from_bits(self.xen_mmap_flags) {
+ Some(flags) => flags.is_valid(),
+ None => false,
+ }
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserMemoryRegion {
+ fn is_valid(&self) -> bool {
+ self.is_valid()
+ }
+}
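+
+// A short illustrative sketch (non-xen build; addresses are arbitrary example
+// values): a region is valid as long as its size is non-zero and none of the
+// address + size sums overflow.
+//
+//     let region = VhostUserMemoryRegion::new(0x1000_0000, 0x10_0000, 0x7f00_0000_0000, 0);
+//     assert!(region.is_valid());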
+
+/// Payload of the VhostUserMemory message.
+pub type VhostUserMemoryPayload = Vec<VhostUserMemoryRegion>;
+
+/// Single memory region descriptor as payload for ADD_MEM_REG and REM_MEM_REG
+/// requests.
+#[repr(C)]
+#[derive(Default, Clone, Copy)]
+pub struct VhostUserSingleMemoryRegion {
+ /// Padding for correct alignment
+ padding: u64,
+ /// General memory region
+ region: VhostUserMemoryRegion,
+}
+
+impl Deref for VhostUserSingleMemoryRegion {
+ type Target = VhostUserMemoryRegion;
+
+ fn deref(&self) -> &VhostUserMemoryRegion {
+ &self.region
+ }
+}
+
+#[cfg(not(feature = "xen"))]
+impl VhostUserSingleMemoryRegion {
+ /// Create a new instance.
+ pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self {
+ VhostUserSingleMemoryRegion {
+ padding: 0,
+ region: VhostUserMemoryRegion::new(
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ ),
+ }
+ }
+}
+
+#[cfg(feature = "xen")]
+impl VhostUserSingleMemoryRegion {
+ /// Create a new instance.
+ pub fn new(
+ guest_phys_addr: u64,
+ memory_size: u64,
+ user_addr: u64,
+ mmap_offset: u64,
+ xen_mmap_flags: u32,
+ xen_mmap_data: u32,
+ ) -> Self {
+ VhostUserSingleMemoryRegion {
+ padding: 0,
+ region: VhostUserMemoryRegion::with_xen(
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ xen_mmap_flags,
+ xen_mmap_data,
+ ),
+ }
+ }
+}
+
+// SAFETY: Safe because all fields of VhostUserSingleMemoryRegion are POD.
+unsafe impl ByteValued for VhostUserSingleMemoryRegion {}
+impl VhostUserMsgValidator for VhostUserSingleMemoryRegion {}
+
+/// Vring state descriptor.
+#[repr(packed)]
+#[derive(Copy, Clone, Default)]
+pub struct VhostUserVringState {
+ /// Vring index.
+ pub index: u32,
+ /// A common 32-bit value to encapsulate vring state etc.
+ pub num: u32,
+}
+
+impl VhostUserVringState {
+ /// Create a new instance.
+ pub fn new(index: u32, num: u32) -> Self {
+ VhostUserVringState { index, num }
+ }
+}
+
+// SAFETY: Safe because all fields of VhostUserVringState are POD.
+unsafe impl ByteValued for VhostUserVringState {}
+
+impl VhostUserMsgValidator for VhostUserVringState {}
+
+// Bit mask for vring address flags.
+bitflags! {
+ /// Flags for vring address.
+ pub struct VhostUserVringAddrFlags: u32 {
+ /// Support log of vring operations.
+ /// Modifications to "used" vring should be logged.
+ const VHOST_VRING_F_LOG = 0x1;
+ }
+}
+
+/// Vring address descriptor.
+#[repr(packed)]
+#[derive(Copy, Clone, Default)]
+pub struct VhostUserVringAddr {
+ /// Vring index.
+ pub index: u32,
+ /// Vring flags defined by VhostUserVringAddrFlags.
+ pub flags: u32,
+ /// Ring address of the vring descriptor table.
+ pub descriptor: u64,
+ /// Ring address of the vring used ring.
+ pub used: u64,
+ /// Ring address of the vring available ring.
+ pub available: u64,
+ /// Guest address for logging.
+ pub log: u64,
+}
+
+impl VhostUserVringAddr {
+ /// Create a new instance.
+ pub fn new(
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Self {
+ VhostUserVringAddr {
+ index,
+ flags: flags.bits(),
+ descriptor,
+ used,
+ available,
+ log,
+ }
+ }
+
+ /// Create a new instance from `VringConfigData`.
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_conversion))]
+ pub fn from_config_data(index: u32, config_data: &VringConfigData) -> Self {
+ let log_addr = config_data.log_addr.unwrap_or(0);
+ VhostUserVringAddr {
+ index,
+ flags: config_data.flags,
+ descriptor: config_data.desc_table_addr,
+ used: config_data.used_ring_addr,
+ available: config_data.avail_ring_addr,
+ log: log_addr,
+ }
+ }
+}
+
+// SAFETY: Safe because all fields of VhostUserVringAddr are POD.
+unsafe impl ByteValued for VhostUserVringAddr {}
+
+impl VhostUserMsgValidator for VhostUserVringAddr {
+ #[allow(clippy::if_same_then_else)]
+ fn is_valid(&self) -> bool {
+ if (self.flags & !VhostUserVringAddrFlags::all().bits()) != 0 {
+ return false;
+ } else if self.descriptor & 0xf != 0 {
+ return false;
+ } else if self.available & 0x1 != 0 {
+ return false;
+ } else if self.used & 0x3 != 0 {
+ return false;
+ }
+ true
+ }
+}
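+
+// Note on the alignment checks above: the virtio specification requires the
+// descriptor table to be 16-byte aligned, the available ring 2-byte aligned
+// and the used ring 4-byte aligned, which is what the masks 0xf, 0x1 and 0x3
+// enforce.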
+
+// Bit mask for the vhost-user device configuration message.
+bitflags! {
+ /// Flags for the device configuration message.
+ pub struct VhostUserConfigFlags: u32 {
+ /// Vhost master messages used for writeable fields.
+ const WRITABLE = 0x1;
+ /// Vhost master messages used for live migration.
+ const LIVE_MIGRATION = 0x2;
+ }
+}
+
+/// Message to read/write device configuration space.
+#[repr(packed)]
+#[derive(Copy, Clone, Default)]
+pub struct VhostUserConfig {
+ /// Offset of virtio device's configuration space.
+ pub offset: u32,
+ /// Configuration space access size in bytes.
+ pub size: u32,
+ /// Flags for the device configuration operation.
+ pub flags: u32,
+}
+
+impl VhostUserConfig {
+ /// Create a new instance.
+ pub fn new(offset: u32, size: u32, flags: VhostUserConfigFlags) -> Self {
+ VhostUserConfig {
+ offset,
+ size,
+ flags: flags.bits(),
+ }
+ }
+}
+
+// SAFETY: Safe because all fields of VhostUserConfig are POD.
+unsafe impl ByteValued for VhostUserConfig {}
+
+impl VhostUserMsgValidator for VhostUserConfig {
+ #[allow(clippy::if_same_then_else)]
+ fn is_valid(&self) -> bool {
+ let end_addr = match self.size.checked_add(self.offset) {
+ Some(addr) => addr,
+ None => return false,
+ };
+ if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 {
+ return false;
+ } else if self.size == 0 || end_addr > VHOST_USER_CONFIG_SIZE {
+ return false;
+ }
+ true
+ }
+}
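+
+// For example, a request to read 4 bytes at offset 0x100 of the device configuration
+// space (the same access exercised by the master/slave tests in this crate) can be
+// built and validated as:
+//
+//     let cfg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+//     assert!(cfg.is_valid());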
+
+/// Payload for the VhostUserConfig message.
+pub type VhostUserConfigPayload = Vec<u8>;
+
+/// Inflight I/O tracking area descriptor, used as payload for the GET_INFLIGHT_FD and
+/// SET_INFLIGHT_FD requests.
+#[repr(C)]
+#[derive(Copy, Clone, Default)]
+pub struct VhostUserInflight {
+ /// Size of the area to track inflight I/O.
+ pub mmap_size: u64,
+ /// Offset of this area from the start of the supplied file descriptor.
+ pub mmap_offset: u64,
+ /// Number of virtqueues.
+ pub num_queues: u16,
+ /// Size of virtqueues.
+ pub queue_size: u16,
+}
+
+impl VhostUserInflight {
+ /// Create a new instance.
+ pub fn new(mmap_size: u64, mmap_offset: u64, num_queues: u16, queue_size: u16) -> Self {
+ VhostUserInflight {
+ mmap_size,
+ mmap_offset,
+ num_queues,
+ queue_size,
+ }
+ }
+}
+
+// SAFETY: Safe because all fields of VhostUserInflight are POD.
+unsafe impl ByteValued for VhostUserInflight {}
+
+impl VhostUserMsgValidator for VhostUserInflight {
+ fn is_valid(&self) -> bool {
+ if self.num_queues == 0 || self.queue_size == 0 {
+ return false;
+ }
+ true
+ }
+}
+
+/// Memory area descriptor for the dirty page log, used as payload for the SET_LOG_BASE request.
+#[repr(C)]
+#[derive(Copy, Clone, Default)]
+pub struct VhostUserLog {
+ /// Size of the area to log dirty pages.
+ pub mmap_size: u64,
+ /// Offset of this area from the start of the supplied file descriptor.
+ pub mmap_offset: u64,
+}
+
+impl VhostUserLog {
+ /// Create a new instance.
+ pub fn new(mmap_size: u64, mmap_offset: u64) -> Self {
+ VhostUserLog {
+ mmap_size,
+ mmap_offset,
+ }
+ }
+}
+
+// SAFETY: Safe because all fields of VhostUserLog are POD.
+unsafe impl ByteValued for VhostUserLog {}
+
+impl VhostUserMsgValidator for VhostUserLog {
+ fn is_valid(&self) -> bool {
+ if self.mmap_size == 0 || self.mmap_offset.checked_add(self.mmap_size).is_none() {
+ return false;
+ }
+ true
+ }
+}
+
+/*
+ * TODO: support dirty log, live migration and IOTLB operations.
+#[repr(packed)]
+pub struct VhostUserVringArea {
+ pub index: u32,
+ pub flags: u32,
+ pub size: u64,
+ pub offset: u64,
+}
+
+#[repr(packed)]
+pub struct VhostUserLog {
+ pub size: u64,
+ pub offset: u64,
+}
+
+#[repr(packed)]
+pub struct VhostUserIotlb {
+ pub iova: u64,
+ pub size: u64,
+ pub user_addr: u64,
+ pub permission: u8,
+ pub optype: u8,
+}
+*/
+
+// Bit mask for flags in virtio-fs slave messages
+bitflags! {
+ #[derive(Default)]
+ /// Flags for virtio-fs slave messages.
+ pub struct VhostUserFSSlaveMsgFlags: u64 {
+ /// Empty permission.
+ const EMPTY = 0x0;
+ /// Read permission.
+ const MAP_R = 0x1;
+ /// Write permission.
+ const MAP_W = 0x2;
+ }
+}
+
+/// Max entries in one virtio-fs slave request.
+pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8;
+
+/// Slave request message to update the MMIO window.
+#[repr(packed)]
+#[derive(Copy, Clone, Default)]
+pub struct VhostUserFSSlaveMsg {
+ /// File offset.
+ pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
+ /// Offset into the DAX window.
+ pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
+ /// Size of region to map.
+ pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
+ /// Flags for the mmap operation
+ pub flags: [VhostUserFSSlaveMsgFlags; VHOST_USER_FS_SLAVE_ENTRIES],
+}
+
+// SAFETY: Safe because all fields of VhostUserFSSlaveMsg are POD.
+unsafe impl ByteValued for VhostUserFSSlaveMsg {}
+
+impl VhostUserMsgValidator for VhostUserFSSlaveMsg {
+ fn is_valid(&self) -> bool {
+ for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
+ if ({ self.flags[i] }.bits() & !VhostUserFSSlaveMsgFlags::all().bits()) != 0
+ || self.fd_offset[i].checked_add(self.len[i]).is_none()
+ || self.cache_offset[i].checked_add(self.len[i]).is_none()
+ {
+ return false;
+ }
+ }
+ true
+ }
+}
+
+/// Inflight I/O descriptor state for split virtqueues
+#[repr(packed)]
+#[derive(Clone, Copy, Default)]
+pub struct DescStateSplit {
+ /// Indicate whether this descriptor (only head) is inflight or not.
+ pub inflight: u8,
+ /// Padding
+ padding: [u8; 5],
+ /// List of last batch of used descriptors, only when batching is used for submitting
+ pub next: u16,
+ /// Preserve order of fetching available descriptors, only for head descriptor
+ pub counter: u64,
+}
+
+impl DescStateSplit {
+ /// New instance of DescStateSplit struct
+ pub fn new() -> Self {
+ Self::default()
+ }
+}
+
+/// Inflight I/O queue region for split virtqueues
+#[repr(packed)]
+pub struct QueueRegionSplit {
+ /// Feature flags of this region
+ pub features: u64,
+ /// Version of this region
+ pub version: u16,
+ /// Number of DescStateSplit entries
+ pub desc_num: u16,
+ /// List to track last batch of used descriptors
+ pub last_batch_head: u16,
+ /// Idx value of used ring
+ pub used_idx: u16,
+ /// Pointer to an array of DescStateSplit entries
+ pub desc: u64,
+}
+
+impl QueueRegionSplit {
+ /// New instance of QueueRegionSplit struct
+ pub fn new(features: u64, queue_size: u16) -> Self {
+ QueueRegionSplit {
+ features,
+ version: 1,
+ desc_num: queue_size,
+ last_batch_head: 0,
+ used_idx: 0,
+ desc: 0,
+ }
+ }
+}
+
+/// Inflight I/O descriptor state for packed virtqueues
+#[repr(packed)]
+#[derive(Clone, Copy, Default)]
+pub struct DescStatePacked {
+ /// Indicate whether this descriptor (only head) is inflight or not.
+ pub inflight: u8,
+ /// Padding
+ padding: u8,
+ /// Link to next free entry
+ pub next: u16,
+ /// Link to last entry of descriptor list, only for head
+ pub last: u16,
+ /// Length of descriptor list, only for head
+ pub num: u16,
+ /// Preserve order of fetching avail descriptors, only for head
+ pub counter: u64,
+ /// Buffer ID
+ pub id: u16,
+ /// Descriptor flags
+ pub flags: u16,
+ /// Buffer length
+ pub len: u32,
+ /// Buffer address
+ pub addr: u64,
+}
+
+impl DescStatePacked {
+ /// New instance of DescStatePacked struct
+ pub fn new() -> Self {
+ Self::default()
+ }
+}
+
+/// Inflight I/O queue region for packed virtqueues
+#[repr(packed)]
+pub struct QueueRegionPacked {
+ /// Feature flags of this region
+ pub features: u64,
+ /// Version of this region
+ pub version: u16,
+ /// Size of descriptor state array
+ pub desc_num: u16,
+ /// Head of free DescStatePacked entry list
+ pub free_head: u16,
+ /// Old head of free DescStatePacked entry list
+ pub old_free_head: u16,
+ /// Used idx of descriptor ring
+ pub used_idx: u16,
+ /// Old used idx of descriptor ring
+ pub old_used_idx: u16,
+ /// Device ring wrap counter
+ pub used_wrap_counter: u8,
+ /// Old device ring wrap counter
+ pub old_used_wrap_counter: u8,
+ /// Padding
+ padding: [u8; 7],
+ /// Pointer to array tracking state of each descriptor from descriptor ring
+ pub desc: u64,
+}
+
+impl QueueRegionPacked {
+ /// New instance of QueueRegionPacked struct
+ pub fn new(features: u64, queue_size: u16) -> Self {
+ QueueRegionPacked {
+ features,
+ version: 1,
+ desc_num: queue_size,
+ free_head: 0,
+ old_free_head: 0,
+ used_idx: 0,
+ old_used_idx: 0,
+ used_wrap_counter: 0,
+ old_used_wrap_counter: 0,
+ padding: [0; 7],
+ desc: 0,
+ }
+ }
+}
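+
+// Rough sizing sketch (an assumption about the shared-memory layout, not something
+// defined in this file): the `desc` fields above point into the inflight area at an
+// array of descriptor-state entries, so a backend would reserve on the order of
+// `desc_num * mem::size_of::<DescStateSplit>()` (or `DescStatePacked`) bytes per
+// queue on top of the queue-region header, and report the total size via
+// `VhostUserInflight::mmap_size`. Real layouts may add alignment padding.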
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::mem;
+
+ #[cfg(feature = "xen")]
+ impl VhostUserMemoryRegion {
+ fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self {
+ Self::with_xen(
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ MmapXenFlags::FOREIGN.bits(),
+ 0,
+ )
+ }
+ }
+
+ #[test]
+ fn check_master_request_code() {
+ assert!(!MasterReq::is_valid(MasterReq::NOOP as _));
+ assert!(!MasterReq::is_valid(MasterReq::MAX_CMD as _));
+ assert!(MasterReq::MAX_CMD > MasterReq::NOOP);
+ let code = MasterReq::GET_FEATURES;
+ assert!(MasterReq::is_valid(code as _));
+ assert_eq!(code, code.clone());
+ assert!(!MasterReq::is_valid(10000));
+ }
+
+ #[test]
+ fn check_slave_request_code() {
+ assert!(!SlaveReq::is_valid(SlaveReq::NOOP as _));
+ assert!(!SlaveReq::is_valid(SlaveReq::MAX_CMD as _));
+ assert!(SlaveReq::MAX_CMD > SlaveReq::NOOP);
+ let code = SlaveReq::CONFIG_CHANGE_MSG;
+ assert!(SlaveReq::is_valid(code as _));
+ assert_eq!(code, code.clone());
+ assert!(!SlaveReq::is_valid(10000));
+ }
+
+ #[test]
+ fn msg_header_ops() {
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100);
+ assert_eq!(hdr.get_code().unwrap(), MasterReq::GET_FEATURES);
+ hdr.set_code(MasterReq::SET_FEATURES);
+ assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_FEATURES);
+
+ assert_eq!(hdr.get_version(), 0x1);
+
+ assert!(!hdr.is_reply());
+ hdr.set_reply(true);
+ assert!(hdr.is_reply());
+ hdr.set_reply(false);
+
+ assert!(!hdr.is_need_reply());
+ hdr.set_need_reply(true);
+ assert!(hdr.is_need_reply());
+ hdr.set_need_reply(false);
+
+ assert_eq!(hdr.get_size(), 0x100);
+ hdr.set_size(0x200);
+ assert_eq!(hdr.get_size(), 0x200);
+
+ assert!(!hdr.is_need_reply());
+ assert!(!hdr.is_reply());
+ assert_eq!(hdr.get_version(), 0x1);
+
+ // Check message length
+ assert!(hdr.is_valid());
+ hdr.set_size(0x2000);
+ assert!(!hdr.is_valid());
+ hdr.set_size(0x100);
+ assert_eq!(hdr.get_size(), 0x100);
+ assert!(hdr.is_valid());
+ hdr.set_size((MAX_MSG_SIZE - mem::size_of::<VhostUserMsgHeader<MasterReq>>()) as u32);
+ assert!(hdr.is_valid());
+ hdr.set_size(0x0);
+ assert!(hdr.is_valid());
+
+ // Check version
+ hdr.set_version(0x0);
+ assert!(!hdr.is_valid());
+ hdr.set_version(0x2);
+ assert!(!hdr.is_valid());
+ hdr.set_version(0x1);
+ assert!(hdr.is_valid());
+
+ // Test Debug, Clone, PartialEq traits
+ assert_eq!(hdr, hdr.clone());
+ assert_eq!(hdr.clone().get_code().unwrap(), hdr.get_code().unwrap());
+ assert_eq!(format!("{:?}", hdr.clone()), format!("{:?}", hdr));
+ }
+
+ #[test]
+ fn test_vhost_user_message_u64() {
+ let val = VhostUserU64::default();
+ let val1 = VhostUserU64::new(0);
+
+ let a = val.value;
+ let b = val1.value;
+ assert_eq!(a, b);
+ let a = VhostUserU64::new(1).value;
+ assert_eq!(a, 1);
+ }
+
+ #[test]
+ fn check_user_memory() {
+ let mut msg = VhostUserMemory::new(1);
+ assert!(msg.is_valid());
+ msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
+ assert!(msg.is_valid());
+
+ msg.num_regions += 1;
+ assert!(!msg.is_valid());
+ msg.num_regions = 0xFFFFFFFF;
+ assert!(!msg.is_valid());
+ msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
+ msg.padding1 = 1;
+ assert!(!msg.is_valid());
+ }
+
+ #[test]
+ fn check_user_memory_region() {
+ let mut msg = VhostUserMemoryRegion::new(0, 0x1000, 0, 0);
+ assert!(msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFFEFFF;
+ assert!(msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFFF000;
+ assert!(!msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFF0000;
+ msg.memory_size = 0;
+ assert!(!msg.is_valid());
+ let a = msg.guest_phys_addr;
+ let b = msg.guest_phys_addr;
+ assert_eq!(a, b);
+
+ let msg = VhostUserMemoryRegion::default();
+ let a = msg.guest_phys_addr;
+ assert_eq!(a, 0);
+ let a = msg.memory_size;
+ assert_eq!(a, 0);
+ let a = msg.user_addr;
+ assert_eq!(a, 0);
+ let a = msg.mmap_offset;
+ assert_eq!(a, 0);
+ }
+
+ #[test]
+ fn test_vhost_user_state() {
+ let state = VhostUserVringState::new(5, 8);
+
+ let a = state.index;
+ assert_eq!(a, 5);
+ let a = state.num;
+ assert_eq!(a, 8);
+ assert!(state.is_valid());
+
+ let state = VhostUserVringState::default();
+ let a = state.index;
+ assert_eq!(a, 0);
+ let a = state.num;
+ assert_eq!(a, 0);
+ assert!(state.is_valid());
+ }
+
+ #[test]
+ fn test_vhost_user_addr() {
+ let mut addr = VhostUserVringAddr::new(
+ 2,
+ VhostUserVringAddrFlags::VHOST_VRING_F_LOG,
+ 0x1000,
+ 0x2000,
+ 0x3000,
+ 0x4000,
+ );
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert!(addr.is_valid());
+
+ addr.descriptor = 0x1001;
+ assert!(!addr.is_valid());
+ addr.descriptor = 0x1000;
+
+ addr.available = 0x3001;
+ assert!(!addr.is_valid());
+ addr.available = 0x3000;
+
+ addr.used = 0x2001;
+ assert!(!addr.is_valid());
+ addr.used = 0x2000;
+ assert!(addr.is_valid());
+ }
+
+ #[test]
+ fn test_vhost_user_state_from_config() {
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+ flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ let addr = VhostUserVringAddr::from_config_data(2, &config);
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert!(addr.is_valid());
+ }
+
+ #[test]
+ fn check_user_vring_addr() {
+ let mut msg =
+ VhostUserVringAddr::new(0, VhostUserVringAddrFlags::all(), 0x0, 0x0, 0x0, 0x0);
+ assert!(msg.is_valid());
+
+ msg.descriptor = 1;
+ assert!(!msg.is_valid());
+ msg.descriptor = 0;
+
+ msg.available = 1;
+ assert!(!msg.is_valid());
+ msg.available = 0;
+
+ msg.used = 1;
+ assert!(!msg.is_valid());
+ msg.used = 0;
+
+ msg.flags |= 0x80000000;
+ assert!(!msg.is_valid());
+ msg.flags &= !0x80000000;
+ }
+
+ #[test]
+ fn check_user_config_msg() {
+ let mut msg =
+ VhostUserConfig::new(0, VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE);
+
+ assert!(msg.is_valid());
+ msg.size = 0;
+ assert!(!msg.is_valid());
+ msg.size = 1;
+ assert!(msg.is_valid());
+ msg.offset = u32::MAX;
+ assert!(!msg.is_valid());
+ msg.offset = VHOST_USER_CONFIG_SIZE;
+ assert!(!msg.is_valid());
+ msg.offset = VHOST_USER_CONFIG_SIZE - 1;
+ assert!(msg.is_valid());
+ msg.size = 2;
+ assert!(!msg.is_valid());
+ msg.size = 1;
+ msg.flags |= VhostUserConfigFlags::LIVE_MIGRATION.bits();
+ assert!(msg.is_valid());
+ msg.flags |= 0x4;
+ assert!(!msg.is_valid());
+ }
+
+ #[test]
+ fn test_vhost_user_fs_slave() {
+ let mut fs_slave = VhostUserFSSlaveMsg::default();
+
+ assert!(fs_slave.is_valid());
+
+ fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff;
+ fs_slave.len[0] = 0x1;
+ assert!(!fs_slave.is_valid());
+
+ assert_ne!(
+ VhostUserFSSlaveMsgFlags::MAP_R,
+ VhostUserFSSlaveMsgFlags::MAP_W
+ );
+ assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0);
+ }
+}
diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs
new file mode 100644
index 0000000..7df51f6
--- /dev/null
+++ b/src/vhost_user/mod.rs
@@ -0,0 +1,540 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! The protocol for vhost-user is based on the existing implementation of vhost for the Linux
+//! Kernel. The protocol defines two sides of the communication, master and slave. Master is
+//! the application that shares its virtqueues. Slave is the consumer of the virtqueues.
+//!
+//! The communication channel between the master and the slave includes two sub channels. One is
+//! used to send requests from the master to the slave and optional replies from the slave to the
+//! master. This sub channel is created on master startup by connecting to the slave service
+//! endpoint. The other is used to send requests from the slave to the master and optional replies
+//! from the master to the slave. This sub channel is created by the master issuing a
+//! VHOST_USER_SET_SLAVE_REQ_FD request to the slave with an auxiliary file descriptor.
+//!
+//! Unix domain socket is used as the underlying communication channel because the master needs to
+//! send file descriptors to the slave.
+//!
+//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an
+//! equivalent ioctl in the kernel vhost implementation.
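+//!
+//! As a minimal sketch (assuming the crate is built with both the `vhost-user-master`
+//! and `vhost-user-slave` features, and that `backend` is an `Arc` of a type
+//! implementing `VhostUserSlaveReqHandler`, e.g. a `Mutex`-wrapped handler as in the
+//! tests below), the initial handshake over a Unix domain socket looks roughly like:
+//!
+//! ```ignore
+//! // Slave side: create the listening socket and wait for the master.
+//! let listener = Listener::new("/path/to/socket", true)?;
+//! let mut slave_listener = SlaveListener::new(listener, backend)?;
+//! // Master side: connect to the slave service endpoint.
+//! let master = Master::connect("/path/to/socket", 1)?;
+//! let mut handler = slave_listener.accept()?.unwrap();
+//! // First request/response pair on the master-to-slave sub channel.
+//! master.set_owner()?;
+//! handler.handle_request()?;
+//! ```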
+
+use std::fs::File;
+use std::io::Error as IOError;
+
+pub mod message;
+pub use self::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
+
+mod connection;
+pub use self::connection::Listener;
+
+#[cfg(feature = "vhost-user-master")]
+mod master;
+#[cfg(feature = "vhost-user-master")]
+pub use self::master::{Master, VhostUserMaster};
+#[cfg(feature = "vhost-user")]
+mod master_req_handler;
+#[cfg(feature = "vhost-user")]
+pub use self::master_req_handler::{
+ MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut,
+};
+
+#[cfg(feature = "vhost-user-slave")]
+mod slave;
+#[cfg(feature = "vhost-user-slave")]
+pub use self::slave::SlaveListener;
+#[cfg(feature = "vhost-user-slave")]
+mod slave_req_handler;
+#[cfg(feature = "vhost-user-slave")]
+pub use self::slave_req_handler::{
+ SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut,
+};
+#[cfg(feature = "vhost-user-slave")]
+mod slave_req;
+#[cfg(feature = "vhost-user-slave")]
+pub use self::slave_req::Slave;
+
+/// Errors for vhost-user operations
+#[derive(Debug)]
+pub enum Error {
+ /// Invalid parameters.
+ InvalidParam,
+ /// Invalid operation, with a static description of the reason.
+ InvalidOperation(&'static str),
+ /// Unsupported operation because the required virtio feature hasn't been negotiated.
+ InactiveFeature(VhostUserVirtioFeatures),
+ /// Unsupported operation because the required protocol feature hasn't been negotiated.
+ InactiveOperation(VhostUserProtocolFeatures),
+ /// Invalid message format, flag or content.
+ InvalidMessage,
+ /// Only part of a message has been sent or received successfully
+ PartialMessage,
+ /// The peer disconnected from the socket.
+ Disconnected,
+ /// Message is too large
+ OversizedMsg,
+ /// Fd array in question is too big or too small
+ IncorrectFds,
+ /// Can't connect to peer.
+ SocketConnect(std::io::Error),
+ /// Generic socket errors.
+ SocketError(std::io::Error),
+ /// The socket is broken or has been closed.
+ SocketBroken(std::io::Error),
+ /// Should retry the socket operation again.
+ SocketRetry(std::io::Error),
+ /// Failure from the slave side.
+ SlaveInternalError,
+ /// Failure from the master side.
+ MasterInternalError,
+ /// Virtio/protocol features mismatch.
+ FeatureMismatch,
+ /// Error from request handler
+ ReqHandlerError(IOError),
+ /// memfd file creation error
+ MemFdCreateError,
+ /// File truncate error
+ FileTrucateError,
+ /// memfd file seal errors
+ MemFdSealError,
+}
+
+impl std::fmt::Display for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ match self {
+ Error::InvalidParam => write!(f, "invalid parameters"),
+ Error::InvalidOperation(reason) => write!(f, "invalid operation: {}", reason),
+ Error::InactiveFeature(bits) => write!(f, "inactive feature: {}", bits.bits()),
+ Error::InactiveOperation(bits) => {
+ write!(f, "inactive protocol operation: {}", bits.bits())
+ }
+ Error::InvalidMessage => write!(f, "invalid message"),
+ Error::PartialMessage => write!(f, "partial message"),
+ Error::Disconnected => write!(f, "peer disconnected"),
+ Error::OversizedMsg => write!(f, "oversized message"),
+ Error::IncorrectFds => write!(f, "wrong number of attached fds"),
+ Error::SocketError(e) => write!(f, "socket error: {}", e),
+ Error::SocketConnect(e) => write!(f, "can't connect to peer: {}", e),
+ Error::SocketBroken(e) => write!(f, "socket is broken: {}", e),
+ Error::SocketRetry(e) => write!(f, "temporary socket error: {}", e),
+ Error::SlaveInternalError => write!(f, "slave internal error"),
+ Error::MasterInternalError => write!(f, "master internal error"),
+ Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"),
+ Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {}", e),
+ Error::MemFdCreateError => {
+ write!(f, "handler failed to allocate memfd during get_inflight_fd")
+ }
+ Error::FileTrucateError => {
+ write!(f, "handler failed to truncate memfd during get_inflight_fd")
+ }
+ Error::MemFdSealError => write!(
+ f,
+ "handler failed to apply seals to memfd during get_inflight_fd"
+ ),
+ }
+ }
+}
+
+impl std::error::Error for Error {}
+
+impl Error {
+ /// Determine whether to rebuild the underlying communication channel.
+ pub fn should_reconnect(&self) -> bool {
+ match *self {
+ // Should reconnect because it may be caused by temporary network errors.
+ Error::PartialMessage => true,
+ // Should reconnect because the underlying socket is broken.
+ Error::SocketBroken(_) => true,
+ // Slave internal error, hope it recovers on reconnect.
+ Error::SlaveInternalError => true,
+ // Master internal error, hope it recovers on reconnect.
+ Error::MasterInternalError => true,
+ // Should just retry the IO operation instead of rebuilding the underlying connection.
+ Error::SocketRetry(_) => false,
+ // Looks like the peer deliberately disconnected the socket.
+ Error::Disconnected => false,
+ Error::InvalidParam | Error::InvalidOperation(_) => false,
+ Error::InactiveFeature(_) | Error::InactiveOperation(_) => false,
+ Error::InvalidMessage | Error::IncorrectFds | Error::OversizedMsg => false,
+ Error::SocketError(_) | Error::SocketConnect(_) => false,
+ Error::FeatureMismatch => false,
+ Error::ReqHandlerError(_) => false,
+ Error::MemFdCreateError | Error::FileTrucateError | Error::MemFdSealError => false,
+ }
+ }
+}
+
+impl std::convert::From<vmm_sys_util::errno::Error> for Error {
+ /// Convert raw socket errors into meaningful vhost-user errors.
+ ///
+ /// The vmm_sys_util::errno::Error is a simple wrapper over the raw errno, which doesn't mean
+ /// much to the vhost-user connection manager. So convert it into meaningful errors to simplify
+ /// the connection manager logic.
+ ///
+ /// # Return:
+ /// * - Error::SocketRetry: temporary error caused by signals or short of resources.
+ /// * - Error::SocketBroken: the underlying socket is broken.
+ /// * - Error::SocketError: other socket related errors.
+ #[allow(unreachable_patterns)] // EWOULDBLOCK equals EAGAIN on Linux
+ fn from(err: vmm_sys_util::errno::Error) -> Self {
+ match err.errno() {
+ // The socket is marked nonblocking and the requested operation would block.
+ libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)),
+ // The socket is marked nonblocking and the requested operation would block.
+ libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)),
+ // A signal occurred before any data was transmitted
+ libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)),
+ // The output queue for a network interface was full. This generally indicates
+ // that the interface has stopped sending, but may be caused by transient congestion.
+ libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)),
+ // No memory available.
+ libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)),
+ // Connection reset by peer.
+ libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)),
+ // The local end has been shut down on a connection oriented socket. In this case the
+ // process will also receive a SIGPIPE unless MSG_NOSIGNAL is set.
+ libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)),
+ // Write permission is denied on the destination socket file, or search permission is
+ // denied for one of the directories in the path prefix.
+ libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)),
+ // Catch all other errors
+ e => Error::SocketError(IOError::from_raw_os_error(e)),
+ }
+ }
+}
+
+/// Result of vhost-user operations
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Result of request handler.
+pub type HandlerResult<T> = std::result::Result<T, IOError>;
+
+/// Utility function to take the single file out of an optional vector of files.
+/// Returns `None` if the vector contains no file or more than one file.
+pub(crate) fn take_single_file(files: Option<Vec<File>>) -> Option<File> {
+ let mut files = files?;
+ if files.len() != 1 {
+ return None;
+ }
+ Some(files.swap_remove(0))
+}
+
+#[cfg(all(test, feature = "vhost-user-slave"))]
+mod dummy_slave;
+
+#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
+mod tests {
+ use std::fs::File;
+ use std::os::unix::io::AsRawFd;
+ use std::path::{Path, PathBuf};
+ use std::sync::{Arc, Barrier, Mutex};
+ use std::thread;
+ use vmm_sys_util::rand::rand_alphanumerics;
+ use vmm_sys_util::tempfile::TempFile;
+
+ use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES};
+ use super::message::*;
+ use super::*;
+ use crate::backend::VhostBackend;
+ use crate::{VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData};
+
+ fn temp_path() -> PathBuf {
+ PathBuf::from(format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ ))
+ }
+
+ fn create_slave<P, S>(path: P, backend: Arc<S>) -> (Master, SlaveReqHandler<S>)
+ where
+ P: AsRef<Path>,
+ S: VhostUserSlaveReqHandler,
+ {
+ let listener = Listener::new(&path, true).unwrap();
+ let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
+ let master = Master::connect(&path, 1).unwrap();
+ (master, slave_listener.accept().unwrap().unwrap())
+ }
+
+ #[test]
+ fn create_dummy_slave() {
+ let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+
+ slave.set_owner().unwrap();
+ assert!(slave.set_owner().is_err());
+ }
+
+ #[test]
+ fn test_set_owner() {
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let path = temp_path();
+ let (master, mut slave) = create_slave(path, slave_be.clone());
+
+ assert!(!slave_be.lock().unwrap().owned);
+ master.set_owner().unwrap();
+ slave.handle_request().unwrap();
+ assert!(slave_be.lock().unwrap().owned);
+ master.set_owner().unwrap();
+ assert!(slave.handle_request().is_err());
+ assert!(slave_be.lock().unwrap().owned);
+ }
+
+ #[test]
+ fn test_set_features() {
+ let mbar = Arc::new(Barrier::new(2));
+ let sbar = mbar.clone();
+ let path = temp_path();
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let (mut master, mut slave) = create_slave(path, slave_be.clone());
+
+ thread::spawn(move || {
+ slave.handle_request().unwrap();
+ assert!(slave_be.lock().unwrap().owned);
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_features,
+ VIRTIO_FEATURES & !0x1
+ );
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_protocol_features,
+ VhostUserProtocolFeatures::all().bits()
+ );
+
+ sbar.wait();
+ });
+
+ master.set_owner().unwrap();
+
+ // set virtio features
+ let features = master.get_features().unwrap();
+ assert_eq!(features, VIRTIO_FEATURES);
+ master.set_features(VIRTIO_FEATURES & !0x1).unwrap();
+
+ // set vhost protocol features
+ let features = master.get_protocol_features().unwrap();
+ assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
+ master.set_protocol_features(features).unwrap();
+
+ mbar.wait();
+ }
+
+ #[test]
+ fn test_master_slave_process() {
+ let mbar = Arc::new(Barrier::new(2));
+ let sbar = mbar.clone();
+ let path = temp_path();
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let (mut master, mut slave) = create_slave(path, slave_be.clone());
+
+ thread::spawn(move || {
+ // set_owner()
+ slave.handle_request().unwrap();
+ assert!(slave_be.lock().unwrap().owned);
+
+ // get/set_features()
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_features,
+ VIRTIO_FEATURES & !0x1
+ );
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
+ let mut features = VhostUserProtocolFeatures::all();
+
+ // Disable Xen mmap feature.
+ if !cfg!(feature = "xen") {
+ features.remove(VhostUserProtocolFeatures::XEN_MMAP);
+ }
+
+ assert_eq!(
+ slave_be.lock().unwrap().acked_protocol_features,
+ features.bits()
+ );
+
+ // get_inflight_fd()
+ slave.handle_request().unwrap();
+ // set_inflight_fd()
+ slave.handle_request().unwrap();
+
+ // get_queue_num()
+ slave.handle_request().unwrap();
+
+ // set_mem_table()
+ slave.handle_request().unwrap();
+
+ // get/set_config()
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
+ // set_slave_request_fd
+ slave.handle_request().unwrap();
+
+ // set_vring_enable
+ slave.handle_request().unwrap();
+
+ // set_log_base,set_log_fd()
+ slave.handle_request().unwrap_err();
+ slave.handle_request().unwrap_err();
+
+ // set_vring_xxx
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
+ // get_max_mem_slots()
+ slave.handle_request().unwrap();
+
+ // add_mem_region()
+ slave.handle_request().unwrap();
+
+ // remove_mem_region()
+ slave.handle_request().unwrap();
+
+ sbar.wait();
+ });
+
+ master.set_owner().unwrap();
+
+ // set virtio features
+ let features = master.get_features().unwrap();
+ assert_eq!(features, VIRTIO_FEATURES);
+ master.set_features(VIRTIO_FEATURES & !0x1).unwrap();
+
+ // set vhost protocol features
+ let mut features = master.get_protocol_features().unwrap();
+ assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
+
+ // Disable Xen mmap feature.
+ if !cfg!(feature = "xen") {
+ features.remove(VhostUserProtocolFeatures::XEN_MMAP);
+ }
+
+ master.set_protocol_features(features).unwrap();
+
+ // Retrieve inflight I/O tracking information
+ let (inflight_info, inflight_file) = master
+ .get_inflight_fd(&VhostUserInflight {
+ num_queues: 2,
+ queue_size: 256,
+ ..Default::default()
+ })
+ .unwrap();
+ // Set the buffer back to the backend
+ master
+ .set_inflight_fd(&inflight_info, inflight_file.as_raw_fd())
+ .unwrap();
+
+ let num = master.get_queue_num().unwrap();
+ assert_eq!(num, 2);
+
+ let eventfd = vmm_sys_util::eventfd::EventFd::new(0).unwrap();
+ let mem = [VhostUserMemoryRegionInfo::new(
+ 0,
+ 0x10_0000,
+ 0,
+ 0,
+ eventfd.as_raw_fd(),
+ )];
+ master.set_mem_table(&mem).unwrap();
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8; 4])
+ .unwrap();
+ let buf = [0x0u8; 4];
+ let (reply_body, reply_payload) = master
+ .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
+ .unwrap();
+ let offset = reply_body.offset;
+ assert_eq!(offset, 0x100);
+ assert_eq!(&reply_payload, &[0xa5; 4]);
+
+ master.set_slave_request_fd(&eventfd).unwrap();
+ master.set_vring_enable(0, true).unwrap();
+
+ master
+ .set_log_base(
+ 0,
+ Some(VhostUserDirtyLogRegion {
+ mmap_size: 0x1000,
+ mmap_offset: 0,
+ mmap_handle: eventfd.as_raw_fd(),
+ }),
+ )
+ .unwrap();
+ master.set_log_fd(eventfd.as_raw_fd()).unwrap();
+
+ master.set_vring_num(0, 256).unwrap();
+ master.set_vring_base(0, 0).unwrap();
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+ flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ master.set_vring_addr(0, &config).unwrap();
+ master.set_vring_call(0, &eventfd).unwrap();
+ master.set_vring_kick(0, &eventfd).unwrap();
+ master.set_vring_err(0, &eventfd).unwrap();
+
+ let max_mem_slots = master.get_max_mem_slots().unwrap();
+ assert_eq!(max_mem_slots, 32);
+
+ let region_file: File = TempFile::new().unwrap().into_file();
+ let region =
+ VhostUserMemoryRegionInfo::new(0x10_0000, 0x10_0000, 0, 0, region_file.as_raw_fd());
+ master.add_mem_region(&region).unwrap();
+
+ master.remove_mem_region(&region).unwrap();
+
+ mbar.wait();
+ }
+
+ #[test]
+ fn test_error_display() {
+ assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
+ assert_eq!(
+ format!("{}", Error::InvalidOperation("reason")),
+ "invalid operation: reason"
+ );
+ }
+
+ #[test]
+ fn test_should_reconnect() {
+ assert!(Error::PartialMessage.should_reconnect());
+ assert!(Error::SlaveInternalError.should_reconnect());
+ assert!(Error::MasterInternalError.should_reconnect());
+ assert!(!Error::InvalidParam.should_reconnect());
+ assert!(!Error::InvalidOperation("reason").should_reconnect());
+ assert!(
+ !Error::InactiveFeature(VhostUserVirtioFeatures::PROTOCOL_FEATURES).should_reconnect()
+ );
+ assert!(!Error::InactiveOperation(VhostUserProtocolFeatures::all()).should_reconnect());
+ assert!(!Error::InvalidMessage.should_reconnect());
+ assert!(!Error::IncorrectFds.should_reconnect());
+ assert!(!Error::OversizedMsg.should_reconnect());
+ assert!(!Error::FeatureMismatch.should_reconnect());
+ }
+
+ #[test]
+ fn test_error_from_sys_util_error() {
+ let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN).into();
+ if let Error::SocketRetry(e1) = e {
+ assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN);
+ } else {
+ panic!("invalid error code conversion!");
+ }
+ }
+}
diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs
new file mode 100644
index 0000000..fb65c41
--- /dev/null
+++ b/src/vhost_user/slave.rs
@@ -0,0 +1,86 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Traits and Structs for vhost-user slave.
+
+use std::sync::Arc;
+
+use super::connection::{Endpoint, Listener};
+use super::message::*;
+use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler};
+
+/// Vhost-user slave side connection listener.
+pub struct SlaveListener<S: VhostUserSlaveReqHandler> {
+ listener: Listener,
+ backend: Option<Arc<S>>,
+}
+
+/// Sets up a listener for incoming master connections, and handles construction
+/// of a `SlaveReqHandler` on success.
+impl<S: VhostUserSlaveReqHandler> SlaveListener<S> {
+ /// Create a unix domain socket for incoming master connections.
+ pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> {
+ Ok(SlaveListener {
+ listener,
+ backend: Some(backend),
+ })
+ }
+
+ /// Accept an incoming connection from the master, returning `Some(SlaveReqHandler)`
+ /// on success, or `None` if the socket is nonblocking and no incoming connection
+ /// was detected.
+ pub fn accept(&mut self) -> Result<Option<SlaveReqHandler<S>>> {
+ if let Some(fd) = self.listener.accept()? {
+ return Ok(Some(SlaveReqHandler::new(
+ Endpoint::<MasterReq>::from_stream(fd),
+ self.backend.take().unwrap(),
+ )));
+ }
+ Ok(None)
+ }
+
+ /// Change the blocking status of the listener; the flag is forwarded to `Listener::set_nonblocking`.
+ pub fn set_nonblocking(&self, block: bool) -> Result<()> {
+ self.listener.set_nonblocking(block)
+ }
+}
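+
+// A nonblocking accept sketch (the waiting strategy between retries is up to the
+// caller; this only illustrates the `Ok(None)` contract of `accept`):
+//
+//     slave_listener.set_nonblocking(true)?;
+//     let handler = loop {
+//         if let Some(h) = slave_listener.accept()? {
+//             break h;
+//         }
+//         // No pending master connection yet; poll or sleep before retrying.
+//     };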
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Mutex;
+
+ use super::*;
+ use crate::vhost_user::dummy_slave::DummySlaveReqHandler;
+
+ #[test]
+ fn test_slave_listener_set_nonblocking() {
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener =
+ Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap();
+ let slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-master")]
+ #[test]
+ fn test_slave_listener_accept() {
+ use super::super::Master;
+
+ let path = "/tmp/vhost_user_lib_unit_test_slave_accept";
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener = Listener::new(path, true).unwrap();
+ let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ slave_listener.set_nonblocking(true).unwrap();
+ assert!(slave_listener.accept().unwrap().is_none());
+ assert!(slave_listener.accept().unwrap().is_none());
+
+ let _master = Master::connect(path, 1).unwrap();
+ let _slave = slave_listener.accept().unwrap().unwrap();
+ }
+}
diff --git a/src/vhost_user/slave_req.rs b/src/vhost_user/slave_req.rs
new file mode 100644
index 0000000..ade1e91
--- /dev/null
+++ b/src/vhost_user/slave_req.rs
@@ -0,0 +1,219 @@
+// Copyright (C) 2020 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::io;
+use std::mem;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler};
+
+use vm_memory::ByteValued;
+
+struct SlaveInternal {
+ sock: Endpoint<SlaveReq>,
+
+ // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
+ reply_ack_negotiated: bool,
+
+ // whether the endpoint has encountered any failure
+ error: Option<i32>,
+}
+
+impl SlaveInternal {
+ fn check_state(&self) -> Result<u64> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(0),
+ }
+ }
+
+ fn send_message<T: ByteValued>(
+ &mut self,
+ request: SlaveReq,
+ body: &T,
+ fds: Option<&[RawFd]>,
+ ) -> Result<u64> {
+ self.check_state()?;
+
+ let len = mem::size_of::<T>();
+ let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32);
+ if self.reply_ack_negotiated {
+ hdr.set_need_reply(true);
+ }
+ self.sock.send_message(&hdr, body, fds)?;
+
+ self.wait_for_ack(&hdr)
+ }
+
+ fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> {
+ self.check_state()?;
+ if !self.reply_ack_negotiated {
+ return Ok(0);
+ }
+
+ let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?;
+ if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ if body.value != 0 {
+ return Err(Error::MasterInternalError);
+ }
+
+ Ok(body.value)
+ }
+}
+
+/// Request proxy to send vhost-user slave requests to the master through the slave
+/// communication channel.
+///
+/// The [Slave] acts as a message proxy to forward vhost-user slave requests to the
+/// master through the vhost-user slave communication channel. The forwarded messages will be
+/// handled by the [MasterReqHandler] server.
+///
+/// [Slave]: struct.Slave.html
+/// [MasterReqHandler]: struct.MasterReqHandler.html
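+///
+/// A rough usage sketch (assuming `stream` is a `UnixStream` already connected to the
+/// master's slave-request endpoint and `fd` is a file descriptor to hand over; both
+/// names are illustrative):
+///
+/// ```ignore
+/// let slave = Slave::from_stream(stream);
+/// // Request an ack for every message once REPLY_ACK has been negotiated.
+/// slave.set_reply_ack_flag(true);
+/// // Forward a virtio-fs map request, attaching the file descriptor to map.
+/// slave.fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd)?;
+/// ```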
+#[derive(Clone)]
+pub struct Slave {
+ // underlying Unix domain socket for communication
+ node: Arc<Mutex<SlaveInternal>>,
+}
+
+impl Slave {
+ fn new(ep: Endpoint<SlaveReq>) -> Self {
+ Slave {
+ node: Arc::new(Mutex::new(SlaveInternal {
+ sock: ep,
+ reply_ack_negotiated: false,
+ error: None,
+ })),
+ }
+ }
+
+ fn node(&self) -> MutexGuard<SlaveInternal> {
+ self.node.lock().unwrap()
+ }
+
+ fn send_message<T: ByteValued>(
+ &self,
+ request: SlaveReq,
+ body: &T,
+ fds: Option<&[RawFd]>,
+ ) -> io::Result<u64> {
+ self.node()
+ .send_message(request, body, fds)
+ .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))
+ }
+
+ /// Create a new instance from a `UnixStream` object.
+ pub fn from_stream(sock: UnixStream) -> Self {
+ Self::new(Endpoint::<SlaveReq>::from_stream(sock))
+ }
+
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
+ /// the "NEED_REPLY" flag will be set in the message header for every slave to master request
+ /// message.
+ pub fn set_reply_ack_flag(&self, enable: bool) {
+ self.node().reply_ack_negotiated = enable;
+ }
+
+ /// Mark endpoint as failed with specified error code.
+ pub fn set_failed(&self, error: i32) {
+ self.node().error = Some(error);
+ }
+}
+
+impl VhostUserMasterReqHandler for Slave {
+ /// Forward vhost-user-fs map file requests to the master.
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
+ self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd.as_raw_fd()]))
+ }
+
+ /// Forward vhost-user-fs unmap file requests to the master.
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.send_message(SlaveReq::FS_UNMAP, fs, None)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::os::unix::io::AsRawFd;
+
+ use super::*;
+
+ #[test]
+ fn test_slave_req_set_failed() {
+ let (p1, _p2) = UnixStream::pair().unwrap();
+ let slave = Slave::from_stream(p1);
+
+ assert!(slave.node().error.is_none());
+ slave.set_failed(libc::EAGAIN);
+ assert_eq!(slave.node().error, Some(libc::EAGAIN));
+ }
+
+ #[test]
+ fn test_slave_req_send_failure() {
+ let (p1, p2) = UnixStream::pair().unwrap();
+ let slave = Slave::from_stream(p1);
+
+ slave.set_failed(libc::ECONNRESET);
+ slave
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &p2)
+ .unwrap_err();
+ slave
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ slave.node().error = None;
+ }
+
+ #[test]
+ fn test_slave_req_recv_negative() {
+ let (p1, p2) = UnixStream::pair().unwrap();
+ let slave = Slave::from_stream(p1);
+ let mut master = Endpoint::<SlaveReq>::from_stream(p2);
+
+ let len = mem::size_of::<VhostUserFSSlaveMsg>();
+ let mut hdr = VhostUserMsgHeader::new(
+ SlaveReq::FS_MAP,
+ VhostUserHeaderFlag::REPLY.bits(),
+ len as u32,
+ );
+ let body = VhostUserU64::new(0);
+
+ master
+ .send_message(&hdr, &body, Some(&[master.as_raw_fd()]))
+ .unwrap();
+ slave
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
+ .unwrap();
+
+ slave.set_reply_ack_flag(true);
+ slave
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
+ .unwrap_err();
+
+ hdr.set_code(SlaveReq::FS_UNMAP);
+ master.send_message(&hdr, &body, None).unwrap();
+ slave
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
+ .unwrap_err();
+ hdr.set_code(SlaveReq::FS_MAP);
+
+ let body = VhostUserU64::new(1);
+ master.send_message(&hdr, &body, None).unwrap();
+ slave
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
+ .unwrap_err();
+
+ let body = VhostUserU64::new(0);
+ master.send_message(&hdr, &body, None).unwrap();
+ slave
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
+ .unwrap();
+ }
+}
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs
new file mode 100644
index 0000000..e998339
--- /dev/null
+++ b/src/vhost_user/slave_req_handler.rs
@@ -0,0 +1,833 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::fs::File;
+use std::mem;
+use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::slice;
+use std::sync::{Arc, Mutex};
+
+use vm_memory::ByteValued;
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::slave_req::Slave;
+use super::{take_single_file, Error, Result};
+
+/// Services provided to the master by the slave with interior mutability.
+///
+/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave.
+/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler],
+/// but without interior mutability.
+/// The vhost-user specification defines a master communication channel, by which masters could
+/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
+/// slaves, and it's used both on the master side and slave side.
+///
+/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
+/// service requests to slaves.
+/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
+/// implementing [VhostUserSlaveReqHandler].
+///
+/// The [VhostUserSlaveReqHandler] trait is designed with interior mutability to improve performance
+/// for multi-threading.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
+#[allow(missing_docs)]
+pub trait VhostUserSlaveReqHandler {
+ fn set_owner(&self) -> Result<()>;
+ fn reset_owner(&self) -> Result<()>;
+ fn get_features(&self) -> Result<u64>;
+ fn set_features(&self, features: u64) -> Result<()>;
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()>;
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
+ fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>;
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
+ fn set_protocol_features(&self, features: u64) -> Result<()>;
+ fn get_queue_num(&self) -> Result<u64>;
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
+ fn set_slave_req_fd(&self, _slave: Slave) {}
+ fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>;
+ fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>;
+ fn get_max_mem_slots(&self) -> Result<u64>;
+ fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
+ fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
+}
+
+/// Services provided to the master by the slave without interior mutability.
+///
+/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait.
+#[allow(missing_docs)]
+pub trait VhostUserSlaveReqHandlerMut {
+ fn set_owner(&mut self) -> Result<()>;
+ fn reset_owner(&mut self) -> Result<()>;
+ fn get_features(&mut self) -> Result<u64>;
+ fn set_features(&mut self, features: u64) -> Result<()>;
+ fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
+ fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
+ fn set_vring_addr(
+ &mut self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()>;
+ fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
+ fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
+ fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
+
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
+ fn set_protocol_features(&mut self, features: u64) -> Result<()>;
+ fn get_queue_num(&mut self) -> Result<u64>;
+ fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ flags: VhostUserConfigFlags,
+ ) -> Result<Vec<u8>>;
+ fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
+ fn set_slave_req_fd(&mut self, _slave: Slave) {}
+ fn get_inflight_fd(
+ &mut self,
+ inflight: &VhostUserInflight,
+ ) -> Result<(VhostUserInflight, File)>;
+ fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
+ fn get_max_mem_slots(&mut self) -> Result<u64>;
+ fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
+ fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
+}
+
+impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
+ fn set_owner(&self) -> Result<()> {
+ self.lock().unwrap().set_owner()
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ self.lock().unwrap().reset_owner()
+ }
+
+ fn get_features(&self) -> Result<u64> {
+ self.lock().unwrap().get_features()
+ }
+
+ fn set_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_features(features)
+ }
+
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
+ self.lock().unwrap().set_mem_table(ctx, files)
+ }
+
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_num(index, num)
+ }
+
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()> {
+ self.lock()
+ .unwrap()
+ .set_vring_addr(index, flags, descriptor, used, available, log)
+ }
+
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_base(index, base)
+ }
+
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
+ self.lock().unwrap().get_vring_base(index)
+ }
+
+ fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
+ self.lock().unwrap().set_vring_kick(index, fd)
+ }
+
+ fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
+ self.lock().unwrap().set_vring_call(index, fd)
+ }
+
+ fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
+ self.lock().unwrap().set_vring_err(index, fd)
+ }
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
+ self.lock().unwrap().get_protocol_features()
+ }
+
+ fn set_protocol_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_protocol_features(features)
+ }
+
+ fn get_queue_num(&self) -> Result<u64> {
+ self.lock().unwrap().get_queue_num()
+ }
+
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
+ self.lock().unwrap().set_vring_enable(index, enable)
+ }
+
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
+ self.lock().unwrap().get_config(offset, size, flags)
+ }
+
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
+ self.lock().unwrap().set_config(offset, buf, flags)
+ }
+
+ fn set_slave_req_fd(&self, slave: Slave) {
+ self.lock().unwrap().set_slave_req_fd(slave)
+ }
+
+ fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
+ self.lock().unwrap().get_inflight_fd(inflight)
+ }
+
+ fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> {
+ self.lock().unwrap().set_inflight_fd(inflight, file)
+ }
+
+ fn get_max_mem_slots(&self) -> Result<u64> {
+ self.lock().unwrap().get_max_mem_slots()
+ }
+
+ fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
+ self.lock().unwrap().add_mem_region(region, fd)
+ }
+
+ fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
+ self.lock().unwrap().remove_mem_region(region)
+ }
+}
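+
+// Thanks to the blanket impl above, a backend only needs to implement
+// `VhostUserSlaveReqHandlerMut`; wrapping it in a `Mutex` makes it usable wherever
+// `VhostUserSlaveReqHandler` is required. A rough sketch (`MyBackend` is a
+// hypothetical handler type, not part of this crate):
+//
+//     let backend = Arc::new(Mutex::new(MyBackend::new()));
+//     let mut handler = SlaveReqHandler::connect("/path/to/socket", backend)?;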
+
+/// Server to handle service requests from masters from the master communication channel.
+///
+/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
+/// masters on the master communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserSlaveReqHandler] to do the real work.
+///
+/// The lifetime of the SlaveReqHandler object should be the same as the underlying Unix domain
+/// socket, which makes it simpler to recover from a disconnect.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
+pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
+ // underlying Unix domain socket for communication
+ main_sock: Endpoint<MasterReq>,
+ // the vhost-user backend device object
+ backend: Arc<S>,
+
+ virtio_features: u64,
+ acked_virtio_features: u64,
+ protocol_features: VhostUserProtocolFeatures,
+ acked_protocol_features: u64,
+
+ // sending ack for messages without payload
+ reply_ack_enabled: bool,
+ // whether the endpoint has encountered any failure
+ error: Option<i32>,
+}
+
+impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
+ /// Create a vhost-user slave endpoint.
+ pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self {
+ SlaveReqHandler {
+ main_sock,
+ backend,
+ virtio_features: 0,
+ acked_virtio_features: 0,
+ protocol_features: VhostUserProtocolFeatures::empty(),
+ acked_protocol_features: 0,
+ reply_ack_enabled: false,
+ error: None,
+ }
+ }
+
+ fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> {
+ if self.acked_virtio_features & feat.bits() != 0 {
+ Ok(())
+ } else {
+ Err(Error::InactiveFeature(feat))
+ }
+ }
+
+ fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> {
+ if self.acked_protocol_features & feat.bits() != 0 {
+ Ok(())
+ } else {
+ Err(Error::InactiveOperation(feat))
+ }
+ }
+
+ /// Create a vhost-user slave endpoint from a connected socket.
+ pub fn from_stream(socket: UnixStream, backend: Arc<S>) -> Self {
+ Self::new(Endpoint::from_stream(socket), backend)
+ }
+
+ /// Create a new vhost-user slave endpoint.
+ ///
+ /// # Arguments
+ /// * - `path` - path of Unix domain socket listener to connect to
+ /// * - `backend` - handler for requests from the master to the slave
+ pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
+ Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
+ }
+
+ /// Mark endpoint as failed with specified error code.
+ pub fn set_failed(&mut self, error: i32) {
+ self.error = Some(error);
+ }
+
+ /// Main entry point for serving requests from the master over the master communication channel.
+ ///
+ /// Receive and handle one incoming request message from the master. The caller needs to:
+ /// - serialize calls to this function
+ /// - decide what to do when an error happens
+ /// - optionally recover from failure (e.g. as sketched below)
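+ ///
+ /// A serving-loop sketch (the error policy shown is only one possibility, using
+ /// `Error::should_reconnect` to decide whether to rebuild the connection):
+ ///
+ /// ```ignore
+ /// loop {
+ ///     match handler.handle_request() {
+ ///         Ok(()) => {}
+ ///         // Transient or recoverable failure: drop this handler and reconnect.
+ ///         Err(e) if e.should_reconnect() => break,
+ ///         // Anything else: give up on this endpoint.
+ ///         Err(_) => break,
+ ///     }
+ /// }
+ /// ```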
+ pub fn handle_request(&mut self) -> Result<()> {
+ // Return error if the endpoint is already in failed state.
+ self.check_state()?;
+
+ // The underlying communication channel is a Unix domain socket in
+ // stream mode, and recvmsg() is a little tricky here. To successfully
+ // receive attached file descriptors, we need to receive messages and
+ // corresponding attached file descriptors in this way:
+ // . recv message header and optional attached file
+ // . validate message header
+ // . recv optional message body and payload according to the size field in
+ // the message header
+ // . validate message body and optional payload
+ let (hdr, files) = self.main_sock.recv_header()?;
+ self.check_attached_files(&hdr, &files)?;
+
+ let (size, buf) = match hdr.get_size() {
+ 0 => (0, vec![0u8; 0]),
+ len => {
+ let (size2, rbuf) = self.main_sock.recv_data(len as usize)?;
+ if size2 != len as usize {
+ return Err(Error::InvalidMessage);
+ }
+ (size2, rbuf)
+ }
+ };
+
+ match hdr.get_code() {
+ Ok(MasterReq::SET_OWNER) => {
+ self.check_request_size(&hdr, size, 0)?;
+ let res = self.backend.set_owner();
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::RESET_OWNER) => {
+ self.check_request_size(&hdr, size, 0)?;
+ let res = self.backend.reset_owner();
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::GET_FEATURES) => {
+ self.check_request_size(&hdr, size, 0)?;
+ let features = self.backend.get_features()?;
+ let msg = VhostUserU64::new(features);
+ self.send_reply_message(&hdr, &msg)?;
+ self.virtio_features = features;
+ self.update_reply_ack_flag();
+ }
+ Ok(MasterReq::SET_FEATURES) => {
+ let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
+ let res = self.backend.set_features(msg.value);
+ self.acked_virtio_features = msg.value;
+ self.update_reply_ack_flag();
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::SET_MEM_TABLE) => {
+ let res = self.set_mem_table(&hdr, size, &buf, files);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::SET_VRING_NUM) => {
+ let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
+ let res = self.backend.set_vring_num(msg.index, msg.num);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::SET_VRING_ADDR) => {
+ let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
+ let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
+ Some(val) => val,
+ None => return Err(Error::InvalidMessage),
+ };
+ let res = self.backend.set_vring_addr(
+ msg.index,
+ flags,
+ msg.descriptor,
+ msg.used,
+ msg.available,
+ msg.log,
+ );
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::SET_VRING_BASE) => {
+ let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
+ let res = self.backend.set_vring_base(msg.index, msg.num);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::GET_VRING_BASE) => {
+ let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
+ let reply = self.backend.get_vring_base(msg.index)?;
+ self.send_reply_message(&hdr, &reply)?;
+ }
+ Ok(MasterReq::SET_VRING_CALL) => {
+ self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
+ let (index, file) = self.handle_vring_fd_request(&buf, files)?;
+ let res = self.backend.set_vring_call(index, file);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::SET_VRING_KICK) => {
+ self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
+ let (index, file) = self.handle_vring_fd_request(&buf, files)?;
+ let res = self.backend.set_vring_kick(index, file);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::SET_VRING_ERR) => {
+ self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
+ let (index, file) = self.handle_vring_fd_request(&buf, files)?;
+ let res = self.backend.set_vring_err(index, file);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::GET_PROTOCOL_FEATURES) => {
+ self.check_request_size(&hdr, size, 0)?;
+ let features = self.backend.get_protocol_features()?;
+
+                // Enable the `XEN_MMAP` protocol feature for backends if the `xen` crate feature is enabled.
+ #[cfg(feature = "xen")]
+ let features = features | VhostUserProtocolFeatures::XEN_MMAP;
+
+ let msg = VhostUserU64::new(features.bits());
+ self.send_reply_message(&hdr, &msg)?;
+ self.protocol_features = features;
+ self.update_reply_ack_flag();
+ }
+ Ok(MasterReq::SET_PROTOCOL_FEATURES) => {
+ let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
+ let res = self.backend.set_protocol_features(msg.value);
+ self.acked_protocol_features = msg.value;
+ self.update_reply_ack_flag();
+ self.send_ack_message(&hdr, res)?;
+
+ #[cfg(feature = "xen")]
+ self.check_proto_feature(VhostUserProtocolFeatures::XEN_MMAP)?;
+ }
+ Ok(MasterReq::GET_QUEUE_NUM) => {
+ self.check_proto_feature(VhostUserProtocolFeatures::MQ)?;
+ self.check_request_size(&hdr, size, 0)?;
+ let num = self.backend.get_queue_num()?;
+ let msg = VhostUserU64::new(num);
+ self.send_reply_message(&hdr, &msg)?;
+ }
+ Ok(MasterReq::SET_VRING_ENABLE) => {
+ let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
+ self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
+ let enable = match msg.num {
+ 1 => true,
+ 0 => false,
+ _ => return Err(Error::InvalidParam),
+ };
+
+ let res = self.backend.set_vring_enable(msg.index, enable);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::GET_CONFIG) => {
+ self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
+ self.get_config(&hdr, &buf)?;
+ }
+ Ok(MasterReq::SET_CONFIG) => {
+ self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
+ let res = self.set_config(size, &buf);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::SET_SLAVE_REQ_FD) => {
+ self.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?;
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
+ let res = self.set_slave_req_fd(files);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::GET_INFLIGHT_FD) => {
+ self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
+
+ let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
+ let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
+ let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
+ self.main_sock
+ .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?;
+ }
+ Ok(MasterReq::SET_INFLIGHT_FD) => {
+ self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
+ let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
+ let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
+ let res = self.backend.set_inflight_fd(&msg, file);
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::GET_MAX_MEM_SLOTS) => {
+ self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
+ self.check_request_size(&hdr, size, 0)?;
+ let num = self.backend.get_max_mem_slots()?;
+ let msg = VhostUserU64::new(num);
+ self.send_reply_message(&hdr, &msg)?;
+ }
+ Ok(MasterReq::ADD_MEM_REG) => {
+ self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
+ let mut files = files.ok_or(Error::InvalidParam)?;
+ if files.len() != 1 {
+ return Err(Error::InvalidParam);
+ }
+ let msg =
+ self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
+ let res = self.backend.add_mem_region(&msg, files.swap_remove(0));
+ self.send_ack_message(&hdr, res)?;
+ }
+ Ok(MasterReq::REM_MEM_REG) => {
+ self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
+
+ let msg =
+ self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
+ let res = self.backend.remove_mem_region(&msg);
+ self.send_ack_message(&hdr, res)?;
+ }
+ _ => {
+ return Err(Error::InvalidMessage);
+ }
+ }
+ Ok(())
+ }
+
+ fn set_mem_table(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ size: usize,
+ buf: &[u8],
+ files: Option<Vec<File>>,
+ ) -> Result<()> {
+ self.check_request_size(hdr, size, hdr.get_size() as usize)?;
+
+ // check message size is consistent
+ let hdrsize = mem::size_of::<VhostUserMemory>();
+ if size < hdrsize {
+ return Err(Error::InvalidMessage);
+ }
+ // SAFETY: Safe because we checked that `buf` size is at least that of
+ // VhostUserMemory.
+ let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() {
+ return Err(Error::InvalidMessage);
+ }
+
+ // validate number of fds matching number of memory regions
+ let files = files.ok_or(Error::InvalidMessage)?;
+ if files.len() != msg.num_regions as usize {
+ return Err(Error::InvalidMessage);
+ }
+
+ // Validate memory regions
+ //
+ // SAFETY: Safe because we checked that `buf` size is equal to that of
+ // VhostUserMemory, plus `msg.num_regions` elements of VhostUserMemoryRegion.
+ let regions = unsafe {
+ slice::from_raw_parts(
+ buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion,
+ msg.num_regions as usize,
+ )
+ };
+ for region in regions.iter() {
+ if !region.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ }
+
+ self.backend.set_mem_table(regions, files)
+ }
+
+ fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
+ let payload_offset = mem::size_of::<VhostUserConfig>();
+ if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
+ return Err(Error::InvalidMessage);
+ }
+ // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserConfig.
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ if buf.len() - payload_offset != msg.size as usize {
+ return Err(Error::InvalidMessage);
+ }
+ let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
+ Some(val) => val,
+ None => return Err(Error::InvalidMessage),
+ };
+ let res = self.backend.get_config(msg.offset, msg.size, flags);
+
+        // On success, the vhost-user slave's payload size MUST match the
+        // master's request; a zero-length payload indicates an error to the
+        // vhost-user master.
+ match res {
+ Ok(ref buf) if buf.len() == msg.size as usize => {
+ let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
+ self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
+ }
+ Ok(_) => {
+ let reply = VhostUserConfig::new(msg.offset, 0, flags);
+ self.send_reply_message(hdr, &reply)?;
+ }
+ Err(_) => {
+ let reply = VhostUserConfig::new(msg.offset, 0, flags);
+ self.send_reply_message(hdr, &reply)?;
+ }
+ }
+ Ok(())
+ }
+
+ fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> {
+ if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
+ return Err(Error::InvalidMessage);
+ }
+ // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserConfig.
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ if size - mem::size_of::<VhostUserConfig>() != msg.size as usize {
+ return Err(Error::InvalidMessage);
+ }
+ let flags = VhostUserConfigFlags::from_bits(msg.flags).ok_or(Error::InvalidMessage)?;
+
+ self.backend
+ .set_config(msg.offset, &buf[mem::size_of::<VhostUserConfig>()..], flags)
+ }
+
+ fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> {
+ let file = take_single_file(files).ok_or(Error::InvalidMessage)?;
+ // SAFETY: Safe because we have ownership of the files that were
+ // checked when received. We have to trust that they are Unix sockets
+ // since we have no way to check this. If not, it will fail later.
+ let sock = unsafe { UnixStream::from_raw_fd(file.into_raw_fd()) };
+ let slave = Slave::from_stream(sock);
+ self.backend.set_slave_req_fd(slave);
+ Ok(())
+ }
+
+ fn handle_vring_fd_request(
+ &mut self,
+ buf: &[u8],
+ files: Option<Vec<File>>,
+ ) -> Result<(u8, Option<File>)> {
+ if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
+ return Err(Error::InvalidMessage);
+ }
+ // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserU64.
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ // Bits (0-7) of the payload contain the vring index. Bit 8 is the
+ // invalid FD flag. This bit is set when there is no file descriptor
+ // in the ancillary data. This signals that polling will be used
+ // instead of waiting for the call.
+ // If Bit 8 is unset, the data must contain a file descriptor.
+ let has_fd = (msg.value & 0x100u64) == 0;
+
+ let file = take_single_file(files);
+
+ if has_fd && file.is_none() || !has_fd && file.is_some() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((msg.value as u8, file))
+ }
+
+ fn check_state(&self) -> Result<()> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(()),
+ }
+ }
+
+ fn check_request_size(
+ &self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ size: usize,
+ expected: usize,
+ ) -> Result<()> {
+ if hdr.get_size() as usize != expected
+ || hdr.is_reply()
+ || hdr.get_version() != 0x1
+ || size != expected
+ {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(())
+ }
+
+ fn check_attached_files(
+ &self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ files: &Option<Vec<File>>,
+ ) -> Result<()> {
+ match hdr.get_code() {
+ Ok(
+ MasterReq::SET_MEM_TABLE
+ | MasterReq::SET_VRING_CALL
+ | MasterReq::SET_VRING_KICK
+ | MasterReq::SET_VRING_ERR
+ | MasterReq::SET_LOG_BASE
+ | MasterReq::SET_LOG_FD
+ | MasterReq::SET_SLAVE_REQ_FD
+ | MasterReq::SET_INFLIGHT_FD
+ | MasterReq::ADD_MEM_REG,
+ ) => Ok(()),
+ _ if files.is_some() => Err(Error::InvalidMessage),
+ _ => Ok(()),
+ }
+ }
+
+ fn extract_request_body<T: Sized + VhostUserMsgValidator>(
+ &self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ size: usize,
+ buf: &[u8],
+ ) -> Result<T> {
+ self.check_request_size(hdr, size, mem::size_of::<T>())?;
+ // SAFETY: Safe because we checked that `buf` size is equal to T size.
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(msg)
+ }
+
+ fn update_reply_ack_flag(&mut self) {
+ let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let pflag = VhostUserProtocolFeatures::REPLY_ACK;
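+        // Reply-ack takes effect only when the backend offers
+        // VHOST_USER_F_PROTOCOL_FEATURES, offers the REPLY_ACK protocol
+        // feature, and the master has acknowledged REPLY_ACK.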
+ if (self.virtio_features & vflag) != 0
+ && self.protocol_features.contains(pflag)
+ && (self.acked_protocol_features & pflag.bits()) != 0
+ {
+ self.reply_ack_enabled = true;
+ } else {
+ self.reply_ack_enabled = false;
+ }
+ }
+
+ fn new_reply_header<T: Sized>(
+ &self,
+ req: &VhostUserMsgHeader<MasterReq>,
+ payload_size: usize,
+ ) -> Result<VhostUserMsgHeader<MasterReq>> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE
+ || payload_size > MAX_MSG_SIZE
+ || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE
+ {
+ return Err(Error::InvalidParam);
+ }
+ self.check_state()?;
+ Ok(VhostUserMsgHeader::new(
+ req.get_code()?,
+ VhostUserHeaderFlag::REPLY.bits(),
+ (mem::size_of::<T>() + payload_size) as u32,
+ ))
+ }
+
+ fn send_ack_message(
+ &mut self,
+ req: &VhostUserMsgHeader<MasterReq>,
+ res: Result<()>,
+ ) -> Result<()> {
+ if self.reply_ack_enabled && req.is_need_reply() {
+ let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
+ let val = match res {
+ Ok(_) => 0,
+ Err(_) => 1,
+ };
+ let msg = VhostUserU64::new(val);
+ self.main_sock.send_message(&hdr, &msg, None)?;
+ }
+ res
+ }
+
+ fn send_reply_message<T: ByteValued>(
+ &mut self,
+ req: &VhostUserMsgHeader<MasterReq>,
+ msg: &T,
+ ) -> Result<()> {
+ let hdr = self.new_reply_header::<T>(req, 0)?;
+ self.main_sock.send_message(&hdr, msg, None)?;
+ Ok(())
+ }
+
+ fn send_reply_with_payload<T: ByteValued>(
+ &mut self,
+ req: &VhostUserMsgHeader<MasterReq>,
+ msg: &T,
+ payload: &[u8],
+ ) -> Result<()> {
+ let hdr = self.new_reply_header::<T>(req, payload.len())?;
+ self.main_sock
+ .send_message_with_payload(&hdr, msg, payload, None)?;
+ Ok(())
+ }
+}
+
+impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.main_sock.as_raw_fd()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::os::unix::io::AsRawFd;
+
+ use super::*;
+ use crate::vhost_user::dummy_slave::DummySlaveReqHandler;
+
+ #[test]
+ fn test_slave_req_handler_new() {
+ let (p1, _p2) = UnixStream::pair().unwrap();
+ let endpoint = Endpoint::<MasterReq>::from_stream(p1);
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let mut handler = SlaveReqHandler::new(endpoint, backend);
+
+ handler.check_state().unwrap();
+ handler.set_failed(libc::EAGAIN);
+ handler.check_state().unwrap_err();
+ assert!(handler.as_raw_fd() >= 0);
+ }
+}
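The `handle_request()` documentation above leaves call serialization and error policy to the caller. A minimal, hypothetical serve loop (not part of this diff) sketches one way to drive the handler, assuming `SlaveReqHandler`, `VhostUserSlaveReqHandler`, `log` and `libc` are in scope:

fn serve<S: VhostUserSlaveReqHandler>(mut handler: SlaveReqHandler<S>) {
    loop {
        // Requests are handled one at a time; calls to handle_request()
        // must not be made concurrently on the same handler.
        if let Err(e) = handler.handle_request() {
            // Error policy is up to the caller: here we log the failure,
            // mark the endpoint as failed and stop serving this connection.
            log::error!("vhost-user request failed: {}", e);
            handler.set_failed(libc::EINVAL);
            return;
        }
    }
}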
diff --git a/src/vring.rs b/src/vring.rs
new file mode 100644
index 0000000..13e08ac
--- /dev/null
+++ b/src/vring.rs
@@ -0,0 +1,585 @@
+// Copyright 2019 Intel Corporation. All Rights Reserved.
+// Copyright 2021 Alibaba Cloud Computing. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+//! Struct to maintain state information and manipulate vhost-user queues.
+
+use std::fs::File;
+use std::io;
+use std::ops::{Deref, DerefMut};
+use std::os::unix::io::{FromRawFd, IntoRawFd};
+use std::result::Result;
+use std::sync::atomic::Ordering;
+use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
+
+use virtio_queue::{Error as VirtQueError, Queue, QueueT};
+use vm_memory::{GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap};
+use vmm_sys_util::eventfd::EventFd;
+
+/// Trait for objects returned by `VringT::get_ref()`.
+pub trait VringStateGuard<'a, M: GuestAddressSpace> {
+ /// Type for guard returned by `VringT::get_ref()`.
+ type G: Deref<Target = VringState<M>>;
+}
+
+/// Trait for objects returned by `VringT::get_mut()`.
+pub trait VringStateMutGuard<'a, M: GuestAddressSpace> {
+ /// Type for guard returned by `VringT::get_mut()`.
+ type G: DerefMut<Target = VringState<M>>;
+}
+
+/// Trait for vring objects used by vhost-user backends to access and manipulate a virtio queue.
+pub trait VringT<M: GuestAddressSpace>:
+ for<'a> VringStateGuard<'a, M> + for<'a> VringStateMutGuard<'a, M>
+{
+ /// Create a new instance of Vring.
+ fn new(mem: M, max_queue_size: u16) -> Result<Self, VirtQueError>
+ where
+ Self: Sized;
+
+    /// Get an immutable guard to the underlying `VringState` object.
+    fn get_ref(&self) -> <Self as VringStateGuard<M>>::G;
+
+    /// Get a mutable guard to the underlying `VringState` object.
+    fn get_mut(&self) -> <Self as VringStateMutGuard<M>>::G;
+
+    /// Add a used descriptor into the used queue.
+ fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError>;
+
+ /// Notify the vhost-user master that used descriptors have been put into the used queue.
+ fn signal_used_queue(&self) -> io::Result<()>;
+
+ /// Enable event notification for queue.
+ fn enable_notification(&self) -> Result<bool, VirtQueError>;
+
+ /// Disable event notification for queue.
+ fn disable_notification(&self) -> Result<(), VirtQueError>;
+
+ /// Check whether a notification to the guest is needed.
+ fn needs_notification(&self) -> Result<bool, VirtQueError>;
+
+ /// Set vring enabled state.
+ fn set_enabled(&self, enabled: bool);
+
+ /// Set queue addresses for descriptor table, available ring and used ring.
+ fn set_queue_info(
+ &self,
+ desc_table: u64,
+ avail_ring: u64,
+ used_ring: u64,
+ ) -> Result<(), VirtQueError>;
+
+ /// Get queue next avail head.
+ fn queue_next_avail(&self) -> u16;
+
+ /// Set queue next avail head.
+ fn set_queue_next_avail(&self, base: u16);
+
+ /// Set queue next used head.
+ fn set_queue_next_used(&self, idx: u16);
+
+ /// Get queue next used head index from the guest memory.
+ fn queue_used_idx(&self) -> Result<u16, VirtQueError>;
+
+ /// Set configured queue size.
+ fn set_queue_size(&self, num: u16);
+
+ /// Enable/disable queue event index feature.
+ fn set_queue_event_idx(&self, enabled: bool);
+
+ /// Set queue enabled state.
+ fn set_queue_ready(&self, ready: bool);
+
+ /// Set `EventFd` for kick.
+ fn set_kick(&self, file: Option<File>);
+
+ /// Read event from the kick `EventFd`.
+ fn read_kick(&self) -> io::Result<bool>;
+
+ /// Set `EventFd` for call.
+ fn set_call(&self, file: Option<File>);
+
+ /// Set `EventFd` for err.
+ fn set_err(&self, file: Option<File>);
+}
+
+/// Struct to maintain raw state information for a vhost-user queue.
+///
+/// This struct maintains all information of a virtio queue, and can be used as a `VringT`
+/// object in a single-threaded context.
+pub struct VringState<M: GuestAddressSpace = GuestMemoryAtomic<GuestMemoryMmap>> {
+ queue: Queue,
+ kick: Option<EventFd>,
+ call: Option<EventFd>,
+ err: Option<EventFd>,
+ enabled: bool,
+ mem: M,
+}
+
+impl<M: GuestAddressSpace> VringState<M> {
+ /// Create a new instance of Vring.
+ fn new(mem: M, max_queue_size: u16) -> Result<Self, VirtQueError> {
+ Ok(VringState {
+ queue: Queue::new(max_queue_size)?,
+ kick: None,
+ call: None,
+ err: None,
+ enabled: false,
+ mem,
+ })
+ }
+
+ /// Get an immutable reference to the underlying raw `Queue` object.
+ pub fn get_queue(&self) -> &Queue {
+ &self.queue
+ }
+
+ /// Get a mutable reference to the underlying raw `Queue` object.
+ pub fn get_queue_mut(&mut self) -> &mut Queue {
+ &mut self.queue
+ }
+
+    /// Add a used descriptor into the used queue.
+ pub fn add_used(&mut self, desc_index: u16, len: u32) -> Result<(), VirtQueError> {
+ self.queue
+ .add_used(self.mem.memory().deref(), desc_index, len)
+ }
+
+ /// Notify the vhost-user master that used descriptors have been put into the used queue.
+ pub fn signal_used_queue(&self) -> io::Result<()> {
+ if let Some(call) = self.call.as_ref() {
+ call.write(1)
+ } else {
+ Ok(())
+ }
+ }
+
+ /// Enable event notification for queue.
+ pub fn enable_notification(&mut self) -> Result<bool, VirtQueError> {
+ self.queue.enable_notification(self.mem.memory().deref())
+ }
+
+ /// Disable event notification for queue.
+ pub fn disable_notification(&mut self) -> Result<(), VirtQueError> {
+ self.queue.disable_notification(self.mem.memory().deref())
+ }
+
+ /// Check whether a notification to the guest is needed.
+ pub fn needs_notification(&mut self) -> Result<bool, VirtQueError> {
+ self.queue.needs_notification(self.mem.memory().deref())
+ }
+
+ /// Set vring enabled state.
+ pub fn set_enabled(&mut self, enabled: bool) {
+ self.enabled = enabled;
+ }
+
+ /// Set queue addresses for descriptor table, available ring and used ring.
+ pub fn set_queue_info(
+ &mut self,
+ desc_table: u64,
+ avail_ring: u64,
+ used_ring: u64,
+ ) -> Result<(), VirtQueError> {
+ self.queue
+ .try_set_desc_table_address(GuestAddress(desc_table))?;
+ self.queue
+ .try_set_avail_ring_address(GuestAddress(avail_ring))?;
+ self.queue
+ .try_set_used_ring_address(GuestAddress(used_ring))
+ }
+
+ /// Get queue next avail head.
+ fn queue_next_avail(&self) -> u16 {
+ self.queue.next_avail()
+ }
+
+ /// Set queue next avail head.
+ fn set_queue_next_avail(&mut self, base: u16) {
+ self.queue.set_next_avail(base);
+ }
+
+ /// Set queue next used head.
+ fn set_queue_next_used(&mut self, idx: u16) {
+ self.queue.set_next_used(idx);
+ }
+
+ /// Get queue next used head index from the guest memory.
+ fn queue_used_idx(&self) -> Result<u16, VirtQueError> {
+ self.queue
+ .used_idx(self.mem.memory().deref(), Ordering::Relaxed)
+ .map(|idx| idx.0)
+ }
+
+ /// Set configured queue size.
+ fn set_queue_size(&mut self, num: u16) {
+ self.queue.set_size(num);
+ }
+
+ /// Enable/disable queue event index feature.
+ fn set_queue_event_idx(&mut self, enabled: bool) {
+ self.queue.set_event_idx(enabled);
+ }
+
+ /// Set queue enabled state.
+ fn set_queue_ready(&mut self, ready: bool) {
+ self.queue.set_ready(ready);
+ }
+
+ /// Get the `EventFd` for kick.
+ pub fn get_kick(&self) -> &Option<EventFd> {
+ &self.kick
+ }
+
+ /// Set `EventFd` for kick.
+ fn set_kick(&mut self, file: Option<File>) {
+ // SAFETY:
+ // EventFd requires that it has sole ownership of its fd. So does File, so this is safe.
+ // Ideally, we'd have a generic way to refer to a uniquely-owned fd, such as that proposed
+ // by Rust RFC #3128.
+ self.kick = file.map(|f| unsafe { EventFd::from_raw_fd(f.into_raw_fd()) });
+ }
+
+ /// Read event from the kick `EventFd`.
+ fn read_kick(&self) -> io::Result<bool> {
+ if let Some(kick) = &self.kick {
+ kick.read()?;
+ }
+
+ Ok(self.enabled)
+ }
+
+ /// Set `EventFd` for call.
+ fn set_call(&mut self, file: Option<File>) {
+ // SAFETY: see comment in set_kick()
+ self.call = file.map(|f| unsafe { EventFd::from_raw_fd(f.into_raw_fd()) });
+ }
+
+ /// Get the `EventFd` for call.
+ pub fn get_call(&self) -> &Option<EventFd> {
+ &self.call
+ }
+
+ /// Set `EventFd` for err.
+ fn set_err(&mut self, file: Option<File>) {
+ // SAFETY: see comment in set_kick()
+ self.err = file.map(|f| unsafe { EventFd::from_raw_fd(f.into_raw_fd()) });
+ }
+}
+
+/// A `VringState` object protected by a `Mutex` for use in multi-threaded contexts.
+#[derive(Clone)]
+pub struct VringMutex<M: GuestAddressSpace = GuestMemoryAtomic<GuestMemoryMmap>> {
+ state: Arc<Mutex<VringState<M>>>,
+}
+
+impl<M: GuestAddressSpace> VringMutex<M> {
+ /// Get a mutable guard to the underlying raw `VringState` object.
+ fn lock(&self) -> MutexGuard<VringState<M>> {
+ self.state.lock().unwrap()
+ }
+}
+
+impl<'a, M: 'a + GuestAddressSpace> VringStateGuard<'a, M> for VringMutex<M> {
+ type G = MutexGuard<'a, VringState<M>>;
+}
+
+impl<'a, M: 'a + GuestAddressSpace> VringStateMutGuard<'a, M> for VringMutex<M> {
+ type G = MutexGuard<'a, VringState<M>>;
+}
+
+impl<M: 'static + GuestAddressSpace> VringT<M> for VringMutex<M> {
+ fn new(mem: M, max_queue_size: u16) -> Result<Self, VirtQueError> {
+ Ok(VringMutex {
+ state: Arc::new(Mutex::new(VringState::new(mem, max_queue_size)?)),
+ })
+ }
+
+ fn get_ref(&self) -> <Self as VringStateGuard<M>>::G {
+ self.state.lock().unwrap()
+ }
+
+ fn get_mut(&self) -> <Self as VringStateMutGuard<M>>::G {
+ self.lock()
+ }
+
+ fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError> {
+ self.lock().add_used(desc_index, len)
+ }
+
+ fn signal_used_queue(&self) -> io::Result<()> {
+ self.get_ref().signal_used_queue()
+ }
+
+ fn enable_notification(&self) -> Result<bool, VirtQueError> {
+ self.lock().enable_notification()
+ }
+
+ fn disable_notification(&self) -> Result<(), VirtQueError> {
+ self.lock().disable_notification()
+ }
+
+ fn needs_notification(&self) -> Result<bool, VirtQueError> {
+ self.lock().needs_notification()
+ }
+
+ fn set_enabled(&self, enabled: bool) {
+ self.lock().set_enabled(enabled)
+ }
+
+ fn set_queue_info(
+ &self,
+ desc_table: u64,
+ avail_ring: u64,
+ used_ring: u64,
+ ) -> Result<(), VirtQueError> {
+ self.lock()
+ .set_queue_info(desc_table, avail_ring, used_ring)
+ }
+
+ fn queue_next_avail(&self) -> u16 {
+ self.get_ref().queue_next_avail()
+ }
+
+ fn set_queue_next_avail(&self, base: u16) {
+ self.lock().set_queue_next_avail(base)
+ }
+
+ fn set_queue_next_used(&self, idx: u16) {
+ self.lock().set_queue_next_used(idx)
+ }
+
+ fn queue_used_idx(&self) -> Result<u16, VirtQueError> {
+ self.lock().queue_used_idx()
+ }
+
+ fn set_queue_size(&self, num: u16) {
+ self.lock().set_queue_size(num);
+ }
+
+ fn set_queue_event_idx(&self, enabled: bool) {
+ self.lock().set_queue_event_idx(enabled);
+ }
+
+ fn set_queue_ready(&self, ready: bool) {
+ self.lock().set_queue_ready(ready);
+ }
+
+ fn set_kick(&self, file: Option<File>) {
+ self.lock().set_kick(file);
+ }
+
+ fn read_kick(&self) -> io::Result<bool> {
+ self.get_ref().read_kick()
+ }
+
+ fn set_call(&self, file: Option<File>) {
+ self.lock().set_call(file)
+ }
+
+ fn set_err(&self, file: Option<File>) {
+ self.lock().set_err(file)
+ }
+}
+
+/// A `VringState` object protected by an `RwLock` for use in multi-threaded contexts.
+#[derive(Clone)]
+pub struct VringRwLock<M: GuestAddressSpace = GuestMemoryAtomic<GuestMemoryMmap>> {
+ state: Arc<RwLock<VringState<M>>>,
+}
+
+impl<M: GuestAddressSpace> VringRwLock<M> {
+ /// Get a mutable guard to the underlying raw `VringState` object.
+ fn write_lock(&self) -> RwLockWriteGuard<VringState<M>> {
+ self.state.write().unwrap()
+ }
+}
+
+impl<'a, M: 'a + GuestAddressSpace> VringStateGuard<'a, M> for VringRwLock<M> {
+ type G = RwLockReadGuard<'a, VringState<M>>;
+}
+
+impl<'a, M: 'a + GuestAddressSpace> VringStateMutGuard<'a, M> for VringRwLock<M> {
+ type G = RwLockWriteGuard<'a, VringState<M>>;
+}
+
+impl<M: 'static + GuestAddressSpace> VringT<M> for VringRwLock<M> {
+ fn new(mem: M, max_queue_size: u16) -> Result<Self, VirtQueError> {
+ Ok(VringRwLock {
+ state: Arc::new(RwLock::new(VringState::new(mem, max_queue_size)?)),
+ })
+ }
+
+ fn get_ref(&self) -> <Self as VringStateGuard<M>>::G {
+ self.state.read().unwrap()
+ }
+
+ fn get_mut(&self) -> <Self as VringStateMutGuard<M>>::G {
+ self.write_lock()
+ }
+
+ fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError> {
+ self.write_lock().add_used(desc_index, len)
+ }
+
+ fn signal_used_queue(&self) -> io::Result<()> {
+ self.get_ref().signal_used_queue()
+ }
+
+ fn enable_notification(&self) -> Result<bool, VirtQueError> {
+ self.write_lock().enable_notification()
+ }
+
+ fn disable_notification(&self) -> Result<(), VirtQueError> {
+ self.write_lock().disable_notification()
+ }
+
+ fn needs_notification(&self) -> Result<bool, VirtQueError> {
+ self.write_lock().needs_notification()
+ }
+
+ fn set_enabled(&self, enabled: bool) {
+ self.write_lock().set_enabled(enabled)
+ }
+
+ fn set_queue_info(
+ &self,
+ desc_table: u64,
+ avail_ring: u64,
+ used_ring: u64,
+ ) -> Result<(), VirtQueError> {
+ self.write_lock()
+ .set_queue_info(desc_table, avail_ring, used_ring)
+ }
+
+ fn queue_next_avail(&self) -> u16 {
+ self.get_ref().queue_next_avail()
+ }
+
+ fn set_queue_next_avail(&self, base: u16) {
+ self.write_lock().set_queue_next_avail(base)
+ }
+
+ fn set_queue_next_used(&self, idx: u16) {
+ self.write_lock().set_queue_next_used(idx)
+ }
+
+ fn queue_used_idx(&self) -> Result<u16, VirtQueError> {
+ self.get_ref().queue_used_idx()
+ }
+
+ fn set_queue_size(&self, num: u16) {
+ self.write_lock().set_queue_size(num);
+ }
+
+ fn set_queue_event_idx(&self, enabled: bool) {
+ self.write_lock().set_queue_event_idx(enabled);
+ }
+
+ fn set_queue_ready(&self, ready: bool) {
+ self.write_lock().set_queue_ready(ready);
+ }
+
+ fn set_kick(&self, file: Option<File>) {
+ self.write_lock().set_kick(file);
+ }
+
+ fn read_kick(&self) -> io::Result<bool> {
+ self.get_ref().read_kick()
+ }
+
+ fn set_call(&self, file: Option<File>) {
+ self.write_lock().set_call(file)
+ }
+
+ fn set_err(&self, file: Option<File>) {
+ self.write_lock().set_err(file)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::os::unix::io::AsRawFd;
+ use vm_memory::bitmap::AtomicBitmap;
+ use vmm_sys_util::eventfd::EventFd;
+
+ #[test]
+ fn test_new_vring() {
+ let mem = GuestMemoryAtomic::new(
+ GuestMemoryMmap::<AtomicBitmap>::from_ranges(&[(GuestAddress(0x100000), 0x10000)])
+ .unwrap(),
+ );
+ let vring = VringMutex::new(mem, 0x1000).unwrap();
+
+ assert!(vring.get_ref().get_kick().is_none());
+ assert!(!vring.get_mut().enabled);
+ assert!(!vring.lock().queue.ready());
+ assert!(!vring.lock().queue.event_idx_enabled());
+
+ vring.set_enabled(true);
+ assert!(vring.get_ref().enabled);
+
+ vring.set_queue_info(0x100100, 0x100200, 0x100300).unwrap();
+ assert_eq!(vring.lock().get_queue().desc_table(), 0x100100);
+ assert_eq!(vring.lock().get_queue().avail_ring(), 0x100200);
+ assert_eq!(vring.lock().get_queue().used_ring(), 0x100300);
+
+ assert_eq!(vring.queue_next_avail(), 0);
+ vring.set_queue_next_avail(0x20);
+ assert_eq!(vring.queue_next_avail(), 0x20);
+
+ vring.set_queue_size(0x200);
+ assert_eq!(vring.lock().queue.size(), 0x200);
+
+ vring.set_queue_event_idx(true);
+ assert!(vring.lock().queue.event_idx_enabled());
+
+ vring.set_queue_ready(true);
+ assert!(vring.lock().queue.ready());
+ }
+
+ #[test]
+ fn test_vring_set_fd() {
+ let mem = GuestMemoryAtomic::new(
+ GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
+ );
+ let vring = VringMutex::new(mem, 0x1000).unwrap();
+
+ vring.set_enabled(true);
+ assert!(vring.get_ref().enabled);
+
+ let eventfd = EventFd::new(0).unwrap();
+        // SAFETY: Safe because we would have panicked above if the eventfd were not valid.
+ let file = unsafe { File::from_raw_fd(eventfd.as_raw_fd()) };
+ assert!(vring.get_mut().kick.is_none());
+ assert!(vring.read_kick().unwrap());
+ vring.set_kick(Some(file));
+ eventfd.write(1).unwrap();
+ assert!(vring.read_kick().unwrap());
+ assert!(vring.get_ref().kick.is_some());
+ vring.set_kick(None);
+ assert!(vring.get_ref().kick.is_none());
+ std::mem::forget(eventfd);
+
+ let eventfd = EventFd::new(0).unwrap();
+        // SAFETY: Safe because we would have panicked above if the eventfd were not valid.
+ let file = unsafe { File::from_raw_fd(eventfd.as_raw_fd()) };
+ assert!(vring.get_ref().call.is_none());
+ vring.set_call(Some(file));
+ assert!(vring.get_ref().call.is_some());
+ vring.set_call(None);
+ assert!(vring.get_ref().call.is_none());
+ std::mem::forget(eventfd);
+
+ let eventfd = EventFd::new(0).unwrap();
+        // SAFETY: Safe because we would have panicked above if the eventfd were not valid.
+ let file = unsafe { File::from_raw_fd(eventfd.as_raw_fd()) };
+ assert!(vring.get_ref().err.is_none());
+ vring.set_err(Some(file));
+ assert!(vring.get_ref().err.is_some());
+ vring.set_err(None);
+ assert!(vring.get_ref().err.is_none());
+ std::mem::forget(eventfd);
+ }
+}
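As a usage illustration for the vring wrappers above, the following hypothetical sketch (not part of this diff) shows a backend completing a descriptor through `VringMutex`; the guest memory layout and descriptor index are placeholders, and the re-exports are assumed to match this crate's `lib.rs`:

use vhost_user_backend::{VringMutex, VringT};
use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap};

fn make_vring() -> VringMutex {
    // Back the vring with a small guest memory region; real backends use
    // the memory table supplied by the master.
    let mem = GuestMemoryAtomic::new(
        GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(),
    );
    VringMutex::new(mem, 256).unwrap()
}

fn complete_one(vring: &VringMutex) -> std::io::Result<()> {
    // Record one used descriptor (head index 0, 0 bytes written) ...
    vring
        .add_used(0, 0)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
    // ... then notify the master through the "call" eventfd, if one is set.
    vring.signal_used_queue()
}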
diff --git a/src/vsock.rs b/src/vsock.rs
new file mode 100644
index 0000000..1e1b0b9
--- /dev/null
+++ b/src/vsock.rs
@@ -0,0 +1,30 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+//
+// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-Google file.
+
+//! Trait to control vhost-vsock backend drivers.
+
+use crate::backend::VhostBackend;
+use crate::Result;
+
+/// Trait to control vhost-vsock backend drivers.
+pub trait VhostVsock: VhostBackend {
+ /// Set the CID for the guest.
+    /// This number is used for routing all data destined for the guest.
+    /// Each guest on a hypervisor must have a unique CID.
+ ///
+ /// # Arguments
+ /// * `cid` - CID to assign to the guest
+ fn set_guest_cid(&self, cid: u64) -> Result<()>;
+
+ /// Tell the VHOST driver to start performing data transfer.
+ fn start(&self) -> Result<()>;
+
+ /// Tell the VHOST driver to stop performing data transfer.
+ fn stop(&self) -> Result<()>;
+}
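For illustration, a hypothetical caller of the `VhostVsock` trait could look like the sketch below (not part of this diff); the helper name is made up, and `VhostVsock` plus the crate's `Result` alias are assumed to be in scope:

fn bring_up_vsock<T: VhostVsock>(dev: &T, guest_cid: u64) -> Result<()> {
    // Each guest needs a unique CID; vsock traffic is routed by it.
    dev.set_guest_cid(guest_cid)?;
    // Start data transfer once queues and memory have been configured.
    dev.start()
}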
diff --git a/tests/vhost-user-server.rs b/tests/vhost-user-server.rs
new file mode 100644
index 0000000..f6fdea7
--- /dev/null
+++ b/tests/vhost-user-server.rs
@@ -0,0 +1,292 @@
+use std::ffi::CString;
+use std::fs::File;
+use std::io::Result;
+use std::os::unix::io::{AsRawFd, FromRawFd};
+use std::os::unix::net::UnixStream;
+use std::path::Path;
+use std::sync::{Arc, Barrier, Mutex};
+use std::thread;
+
+use vhost::vhost_user::message::{
+ VhostUserConfigFlags, VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures,
+};
+use vhost::vhost_user::{Listener, Master, Slave, VhostUserMaster};
+use vhost::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
+use vhost_user_backend::{VhostUserBackendMut, VhostUserDaemon, VringRwLock};
+use vm_memory::{
+ FileOffset, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic, GuestMemoryMmap,
+};
+use vmm_sys_util::epoll::EventSet;
+use vmm_sys_util::eventfd::EventFd;
+
+struct MockVhostBackend {
+ events: u64,
+ event_idx: bool,
+ acked_features: u64,
+}
+
+impl MockVhostBackend {
+ fn new() -> Self {
+ MockVhostBackend {
+ events: 0,
+ event_idx: false,
+ acked_features: 0,
+ }
+ }
+}
+
+impl VhostUserBackendMut<VringRwLock, ()> for MockVhostBackend {
+ fn num_queues(&self) -> usize {
+ 2
+ }
+
+ fn max_queue_size(&self) -> usize {
+ 256
+ }
+
+ fn features(&self) -> u64 {
+ 0xffff_ffff_ffff_ffff
+ }
+
+ fn acked_features(&mut self, features: u64) {
+ self.acked_features = features;
+ }
+
+ fn protocol_features(&self) -> VhostUserProtocolFeatures {
+ VhostUserProtocolFeatures::all()
+ }
+
+ fn set_event_idx(&mut self, enabled: bool) {
+ self.event_idx = enabled;
+ }
+
+ fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
+ assert_eq!(offset, 0x200);
+ assert_eq!(size, 8);
+
+ vec![0xa5u8; 8]
+ }
+
+ fn set_config(&mut self, offset: u32, buf: &[u8]) -> Result<()> {
+ assert_eq!(offset, 0x200);
+ assert_eq!(buf, &[0xa5u8; 8]);
+
+ Ok(())
+ }
+
+ fn update_memory(&mut self, atomic_mem: GuestMemoryAtomic<GuestMemoryMmap>) -> Result<()> {
+ let mem = atomic_mem.memory();
+ let region = mem.find_region(GuestAddress(0x100000)).unwrap();
+ assert_eq!(region.size(), 0x100000);
+ Ok(())
+ }
+
+ fn set_slave_req_fd(&mut self, _slave: Slave) {}
+
+ fn queues_per_thread(&self) -> Vec<u64> {
+ vec![1, 1]
+ }
+
+ fn exit_event(&self, _thread_index: usize) -> Option<EventFd> {
+ let event_fd = EventFd::new(0).unwrap();
+
+ Some(event_fd)
+ }
+
+ fn handle_event(
+ &mut self,
+ _device_event: u16,
+ _evset: EventSet,
+ _vrings: &[VringRwLock],
+ _thread_id: usize,
+ ) -> Result<bool> {
+ self.events += 1;
+
+ Ok(false)
+ }
+}
+
+fn setup_master(path: &Path, barrier: Arc<Barrier>) -> Master {
+ barrier.wait();
+ let mut master = Master::connect(path, 1).unwrap();
+ master.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
+    // Wait before issuing service requests.
+ barrier.wait();
+
+ let features = master.get_features().unwrap();
+ let proto = master.get_protocol_features().unwrap();
+ master.set_features(features).unwrap();
+ master.set_protocol_features(proto).unwrap();
+ assert!(proto.contains(VhostUserProtocolFeatures::REPLY_ACK));
+
+ master
+}
+
+fn vhost_user_client(path: &Path, barrier: Arc<Barrier>) {
+ barrier.wait();
+ let mut master = Master::connect(path, 1).unwrap();
+ master.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY);
+    // Wait before issuing service requests.
+ barrier.wait();
+
+ let features = master.get_features().unwrap();
+ let proto = master.get_protocol_features().unwrap();
+ master.set_features(features).unwrap();
+ master.set_protocol_features(proto).unwrap();
+ assert!(proto.contains(VhostUserProtocolFeatures::REPLY_ACK));
+
+ let queue_num = master.get_queue_num().unwrap();
+ assert_eq!(queue_num, 2);
+
+ master.set_owner().unwrap();
+ //master.set_owner().unwrap_err();
+ master.reset_owner().unwrap();
+ master.reset_owner().unwrap();
+ master.set_owner().unwrap();
+
+ master.set_features(features).unwrap();
+ master.set_protocol_features(proto).unwrap();
+ assert!(proto.contains(VhostUserProtocolFeatures::REPLY_ACK));
+
+ let memfd = nix::sys::memfd::memfd_create(
+ &CString::new("test").unwrap(),
+ nix::sys::memfd::MemFdCreateFlag::empty(),
+ )
+ .unwrap();
+    // SAFETY: Safe because we would have panicked above if the memfd were not valid.
+ let file = unsafe { File::from_raw_fd(memfd) };
+ file.set_len(0x100000).unwrap();
+ let file_offset = FileOffset::new(file, 0);
+ let mem = GuestMemoryMmap::<()>::from_ranges_with_files(&[(
+ GuestAddress(0x100000),
+ 0x100000,
+ Some(file_offset),
+ )])
+ .unwrap();
+ let addr = mem.get_host_address(GuestAddress(0x100000)).unwrap() as u64;
+ let reg = mem.find_region(GuestAddress(0x100000)).unwrap();
+ let fd = reg.file_offset().unwrap();
+ let regions = [VhostUserMemoryRegionInfo::new(
+ 0x100000,
+ 0x100000,
+ addr,
+ 0,
+ fd.file().as_raw_fd(),
+ )];
+ master.set_mem_table(&regions).unwrap();
+
+ master.set_vring_num(0, 256).unwrap();
+
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 256,
+ flags: 0,
+ desc_table_addr: addr,
+ used_ring_addr: addr + 0x10000,
+ avail_ring_addr: addr + 0x20000,
+ log_addr: None,
+ };
+ master.set_vring_addr(0, &config).unwrap();
+
+ let eventfd = EventFd::new(0).unwrap();
+ master.set_vring_kick(0, &eventfd).unwrap();
+ master.set_vring_call(0, &eventfd).unwrap();
+ master.set_vring_err(0, &eventfd).unwrap();
+ master.set_vring_enable(0, true).unwrap();
+
+ let buf = [0u8; 8];
+ let (_cfg, data) = master
+ .get_config(0x200, 8, VhostUserConfigFlags::empty(), &buf)
+ .unwrap();
+ assert_eq!(&data, &[0xa5u8; 8]);
+ master
+ .set_config(0x200, VhostUserConfigFlags::empty(), &data)
+ .unwrap();
+
+ let (tx, _rx) = UnixStream::pair().unwrap();
+ master.set_slave_request_fd(&tx).unwrap();
+
+ let state = master.get_vring_base(0).unwrap();
+ master.set_vring_base(0, state as u16).unwrap();
+
+ assert_eq!(master.get_max_mem_slots().unwrap(), 32);
+ let region = VhostUserMemoryRegionInfo::new(0x800000, 0x100000, addr, 0, fd.file().as_raw_fd());
+ master.add_mem_region(&region).unwrap();
+ master.remove_mem_region(&region).unwrap();
+}
+
+fn vhost_user_server(cb: fn(&Path, Arc<Barrier>)) {
+ let mem = GuestMemoryAtomic::new(GuestMemoryMmap::<()>::new());
+ let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
+ let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap();
+
+ let barrier = Arc::new(Barrier::new(2));
+ let tmpdir = tempfile::tempdir().unwrap();
+ let mut path = tmpdir.path().to_path_buf();
+ path.push("socket");
+
+ let barrier2 = barrier.clone();
+ let path1 = path.clone();
+ let thread = thread::spawn(move || cb(&path1, barrier2));
+
+ let listener = Listener::new(&path, false).unwrap();
+ barrier.wait();
+ daemon.start(listener).unwrap();
+ barrier.wait();
+
+ // handle service requests from clients.
+ thread.join().unwrap();
+}
+
+#[test]
+fn test_vhost_user_server() {
+ vhost_user_server(vhost_user_client);
+}
+
+fn vhost_user_enable(path: &Path, barrier: Arc<Barrier>) {
+ let master = setup_master(path, barrier);
+ master.set_owner().unwrap();
+ master.set_owner().unwrap_err();
+}
+
+#[test]
+fn test_vhost_user_enable() {
+ vhost_user_server(vhost_user_enable);
+}
+
+fn vhost_user_set_inflight(path: &Path, barrier: Arc<Barrier>) {
+ let mut master = setup_master(path, barrier);
+ let eventfd = EventFd::new(0).unwrap();
+ // No implementation for inflight_fd yet.
+ let inflight = VhostUserInflight {
+ mmap_size: 0x100000,
+ mmap_offset: 0,
+ num_queues: 1,
+ queue_size: 256,
+ };
+ master
+ .set_inflight_fd(&inflight, eventfd.as_raw_fd())
+ .unwrap_err();
+}
+
+#[test]
+fn test_vhost_user_set_inflight() {
+ vhost_user_server(vhost_user_set_inflight);
+}
+
+fn vhost_user_get_inflight(path: &Path, barrier: Arc<Barrier>) {
+ let mut master = setup_master(path, barrier);
+ // No implementation for inflight_fd yet.
+ let inflight = VhostUserInflight {
+ mmap_size: 0x100000,
+ mmap_offset: 0,
+ num_queues: 1,
+ queue_size: 256,
+ };
+ assert!(master.get_inflight_fd(&inflight).is_err());
+}
+
+#[test]
+fn test_vhost_user_get_inflight() {
+ vhost_user_server(vhost_user_get_inflight);
+}
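The tests above exercise the typical daemon lifecycle. A standalone backend would use the same calls; the sketch below is hypothetical, reuses `MockVhostBackend` purely for illustration, and assumes a `VhostUserDaemon::wait()` method is available in this version of the crate:

use std::sync::{Arc, Mutex};

use vhost::vhost_user::Listener;
use vhost_user_backend::VhostUserDaemon;
use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap};

fn run_backend(socket_path: &str) {
    // A real device backend would implement VhostUserBackendMut for its own type.
    let backend = Arc::new(Mutex::new(MockVhostBackend::new()));
    let mem = GuestMemoryAtomic::new(GuestMemoryMmap::<()>::new());
    let mut daemon = VhostUserDaemon::new("example".to_owned(), backend, mem).unwrap();

    let listener = Listener::new(socket_path, true).unwrap();
    // start() spawns the worker threads that serve master requests;
    // wait() blocks until they exit (assumption: wait() exists here).
    daemon.start(listener).unwrap();
    daemon.wait().unwrap();
}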