diff options
author | Jeongik Cha <jeongik@google.com> | 2023-09-27 08:11:51 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2023-09-27 08:11:51 +0000 |
commit | fd8b88fc46ac8d4465b4ea1c963622be2c7f857c (patch) | |
tree | 5cd4944617e8f2802a8fe259dcbde9a8cc6326bc | |
parent | 088f9ee4aac7215be65f1941c100f7e14362e2f4 (diff) | |
parent | 777a4130a112a67b398cf73537c055a15030da3f (diff) | |
download | vhost-user-backend-fd8b88fc46ac8d4465b4ea1c963622be2c7f857c.tar.gz |
Import vhost-user-backend am: 777a4130a1
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/vhost-user-backend/+/2752353
Change-Id: I2c8695bbda6410995972a5b2a72377b789bab551
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
38 files changed, 11576 insertions, 0 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json new file mode 100644 index 0000000..03acefd --- /dev/null +++ b/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "6ca88e160a8d34be54a158cc1bc702e9250e97dd" + }, + "path_in_vcs": "crates/vhost-user-backend" +}
\ No newline at end of file diff --git a/Android.bp b/Android.bp new file mode 100644 index 0000000..487cdd1 --- /dev/null +++ b/Android.bp @@ -0,0 +1,22 @@ +// This file is generated by cargo2android.py --config cargo2android.json. +// Do not modify this file as changes will be overridden on upgrade. + + + +rust_library_host { + name: "libvhost_user_backend", + crate_name: "vhost_user_backend", + cargo_env_compat: true, + cargo_pkg_version: "0.10.1", + srcs: ["src/lib.rs"], + edition: "2018", + rustlibs: [ + "liblibc", + "liblog_rust", + "libvhost_android", + "libvirtio_bindings", + "libvirtio_queue", + "libvm_memory_android", + "libvmm_sys_util", + ], +} diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..4ae24f3 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,104 @@ +# Changelog +## [Unreleased] + +### Added + +### Changed + +### Fixed + +### Deprecated + +## v0.10.1 + +### Fixed +- [[#180]](https://github.com/rust-vmm/vhost/pull/180) vhost-user-backend: fetch 'used' index from guest + +## v0.10.0 + +### Added +- [[#169]](https://github.com/rust-vmm/vhost/pull/169) vhost-user-backend: Add support for Xen memory mappings + +### Fixed
- [[#161]](https://github.com/rust-vmm/vhost/pull/161) get_vring_base should not reset the queue + +## v0.9.0 + +### Added +- [[#138]](https://github.com/rust-vmm/vhost/pull/138): vhost-user-backend: add repository metadata + +### Changed +- Updated dependency virtio-bindings 0.1.0 -> 0.2.0 +- Updated dependency virtio-queue 0.7.0 -> 0.8.0 +- Updated dependency vm-memory 0.10.0 -> 0.11.0 + +### Fixed +- [[#154]](https://github.com/rust-vmm/vhost/pull/154): Fix return value of GET_VRING_BASE message +- [[#142]](https://github.com/rust-vmm/vhost/pull/142): vhost_user: Slave requests aren't only FS specific + +## v0.8.0 + +### Added +- [[#120]](https://github.com/rust-vmm/vhost/pull/120): vhost_kern: vdpa: Add missing ioctls + +### Changed +- Updated dependency vhost 0.5 -> 0.6 +- Updated dependency virtio-queue 
0.6 -> 0.7.0 +- Updated dependency vm-memory 0.9 to 0.10.0 +- Updated dependency vmm-sys-util 0.10 to 0.11.0 + +## v0.7.0 + +### Changed + +- Started using caret dependencies +- Updated dependency nix 0.24 -> 0.25 +- Updated dependency log 0.4.6 -> 0.4.17 +- Updated dependency vhost 0.4 -> 0.5 +- Updated dependency virtio-queue 0.5.0 -> 0.6 +- Updated dependency vm-memory 0.7 -> 0.9 + +## v0.6.0 + +### Changed + +- Moved to rust-vmm/virtio-queue v0.5.0 + +### Fixed + +- Fixed vring initialization logic + +## v0.5.1 + +### Changed +- Moved to rust-vmm/vmm-sys-util 0.10.0 + +## v0.5.0 + +### Changed + +- Moved to rust-vmm/virtio-queue v0.4.0 + +## v0.4.0 + +### Changed + +- Moved to rust-vmm/virtio-queue v0.3.0 + +- Relaxed rust-vmm/vm-memory dependency to require ">=0.7" + +## v0.3.0 + +### Changed + +- Moved to rust-vmm/vhost v0.4.0 + +## v0.2.0 + +### Added + +- Ability to run the daemon as a client +- VringEpollHandler implements AsRawFd + +## v0.1.0 + +First release diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..e8709f7 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,78 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. 
+ +[package] +edition = "2018" +name = "vhost-user-backend" +version = "0.10.1" +authors = ["The Cloud Hypervisor Authors"] +description = "A framework to build vhost-user backend service daemon" +readme = "README.md" +keywords = [ + "vhost-user", + "virtio", +] +license = "Apache-2.0" +repository = "https://github.com/rust-vmm/vhost" + +[dependencies.libc] +version = "0.2.39" + +[dependencies.log] +version = "0.4.17" + +[dependencies.vhost] +version = "0.8" +features = ["vhost-user-slave"] + +[dependencies.virtio-bindings] +version = "0.2.1" + +[dependencies.virtio-queue] +version = "0.9.0" + +[dependencies.vm-memory] +version = "0.12.0" +features = [ + "backend-mmap", + "backend-atomic", +] + +[dependencies.vmm-sys-util] +version = "0.11.0" + +[dev-dependencies.nix] +version = "0.26" + +[dev-dependencies.tempfile] +version = "3.2.0" + +[dev-dependencies.vhost] +version = "0.8" +features = [ + "test-utils", + "vhost-user-master", + "vhost-user-slave", +] + +[dev-dependencies.vm-memory] +version = "0.12.0" +features = [ + "backend-mmap", + "backend-atomic", + "backend-bitmap", +] + +[features] +xen = [ + "vm-memory/xen", + "vhost/xen", +] diff --git a/Cargo.toml.orig b/Cargo.toml.orig new file mode 100644 index 0000000..a42ea78 --- /dev/null +++ b/Cargo.toml.orig @@ -0,0 +1,27 @@ +[package] +name = "vhost-user-backend" +version = "0.10.1" +authors = ["The Cloud Hypervisor Authors"] +keywords = ["vhost-user", "virtio"] +description = "A framework to build vhost-user backend service daemon" +repository = "https://github.com/rust-vmm/vhost" +edition = "2018" +license = "Apache-2.0" + +[features] +xen = ["vm-memory/xen", "vhost/xen"] + +[dependencies] +libc = "0.2.39" +log = "0.4.17" +vhost = { path = "../vhost", version = "0.8", features = ["vhost-user-slave"] } +virtio-bindings = "0.2.1" +virtio-queue = "0.9.0" +vm-memory = { version = "0.12.0", features = ["backend-mmap", "backend-atomic"] } +vmm-sys-util = "0.11.0" + +[dev-dependencies] +nix = "0.26" +vhost = { 
path = "../vhost", version = "0.8", features = ["test-utils", "vhost-user-master", "vhost-user-slave"] } +vm-memory = { version = "0.12.0", features = ["backend-mmap", "backend-atomic", "backend-bitmap"] } +tempfile = "3.2.0" @@ -0,0 +1,235 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +--- + +// Copyright (C) 2019 Alibaba Cloud. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Alibaba Inc. nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +--- + diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/LICENSE-BSD-3-Clause b/LICENSE-BSD-3-Clause new file mode 100644 index 0000000..1ff0cd7 --- /dev/null +++ b/LICENSE-BSD-3-Clause @@ -0,0 +1,27 @@ +// Copyright (C) 2019 Alibaba Cloud. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Alibaba Inc. nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/METADATA b/METADATA new file mode 100644 index 0000000..76caad4 --- /dev/null +++ b/METADATA @@ -0,0 +1,19 @@ +name: "vhost-user-backend" +description: "A framework to build vhost-user backend service daemon" +third_party { + identifier { + type: "crates.io" + value: "https://crates.io/crates/vhost-user-backend" + } + identifier { + type: "Archive" + value: "https://static.crates.io/crates/vhost-user-backend/vhost-user-backend-0.10.1.crate" + } + version: "0.10.1" + license_type: NOTICE + last_upgrade_date { + year: 2023 + month: 8 + day: 23 + } +} diff --git a/MODULE_LICENSE_APACHE2 b/MODULE_LICENSE_APACHE2 new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/MODULE_LICENSE_APACHE2 @@ -0,0 +1 @@ +include platform/prebuilts/rust:master:/OWNERS diff --git a/README.md b/README.md new file mode 100644 index 0000000..a526ab1 --- /dev/null +++ b/README.md @@ -0,0 +1,117 @@ +# vhost-user-backend + +## Design + +The `vhost-user-backend` crate provides a framework to implement `vhost-user` backend services, +which includes following external public APIs: +- A daemon control object (`VhostUserDaemon`) to start and stop the service daemon. +- A vhost-user backend trait (`VhostUserBackendMut`) to handle vhost-user control messages and virtio + messages. +- A vring access trait (`VringT`) to access virtio queues, and three implementations of the trait: + `VringState`, `VringMutex` and `VringRwLock`. 
+ +## Usage +The `vhost-user-backend` crate provides a framework to implement vhost-user backend services. The main interface provided by `vhost-user-backend` library is the `struct VhostUserDaemon`: +```rust +pub struct VhostUserDaemon<S, V, B = ()> +where + S: VhostUserBackend<V, B>, + V: VringT<GM<B>> + Clone + Send + Sync + 'static, + B: Bitmap + 'static, +{ + pub fn new(name: String, backend: S, atomic_mem: GuestMemoryAtomic<GuestMemoryMmap<B>>) -> Result<Self>; + pub fn start(&mut self, listener: Listener) -> Result<()>; + pub fn wait(&mut self) -> Result<()>; + pub fn get_epoll_handlers(&self) -> Vec<Arc<VringEpollHandler<S, V, B>>>; +} +``` + +### Create a `VhostUserDaemon` Instance +The `VhostUserDaemon::new()` creates an instance of `VhostUserDaemon` object. The client needs to +pass in an `VhostUserBackend` object, which will be used to configure the `VhostUserDaemon` +instance, handle control messages from the vhost-user master and handle virtio requests from +virtio queues. A group of working threads will be created to handle virtio requests from configured +virtio queues. + +### Start the `VhostUserDaemon` +The `VhostUserDaemon::start()` method waits for an incoming connection from the vhost-user masters +on the `listener`. Once a connection is ready, a main thread will be created to handle vhost-user +messages from the vhost-user master. + +### Stop the `VhostUserDaemon` +The `VhostUserDaemon::stop()` method waits for the main thread to exit. An exit event must be sent +to the main thread by writing to the `exit_event` EventFd before waiting for it to exit. + +### Threading Model +The main thread and virtio queue working threads will concurrently access the underlying virtio +queues, so all virtio queue in multi-threading model. But the main thread only accesses virtio +queues for configuration, so client could adopt locking policies to optimize for the virtio queue +working threads. 
+ +## Example +Example code to handle virtio messages from a virtio queue: +```rust +impl VhostUserBackendMut for VhostUserService { + fn process_queue(&mut self, vring: &VringMutex) -> Result<bool> { + let mut used_any = false; + let mem = match &self.mem { + Some(m) => m.memory(), + None => return Err(Error::NoMemoryConfigured), + }; + + let mut vring_state = vring.get_mut(); + + while let Some(avail_desc) = vring_state + .get_queue_mut() + .iter() + .map_err(|_| Error::IterateQueue)? + .next() + { + // Process the request... + + if self.event_idx { + if vring_state.add_used(head_index, 0).is_err() { + warn!("Couldn't return used descriptors to the ring"); + } + + match vring_state.needs_notification() { + Err(_) => { + warn!("Couldn't check if queue needs to be notified"); + vring_state.signal_used_queue().unwrap(); + } + Ok(needs_notification) => { + if needs_notification { + vring_state.signal_used_queue().unwrap(); + } + } + } + } else { + if vring_state.add_used(head_index, 0).is_err() { + warn!("Couldn't return used descriptors to the ring"); + } + vring_state.signal_used_queue().unwrap(); + } + } + + Ok(used_any) + } +} +``` + +## Xen support + +Supporting Xen requires special handling while mapping the guest memory. The +`vm-memory` crate implements xen memory mapping support via a separate feature +`xen`, and this crate uses the same feature name to enable Xen support. + +Also, for xen mappings, the memory regions passed by the frontend contains few +extra fields as described in the vhost-user protocol documentation. + +It was decided by the `rust-vmm` maintainers to keep the interface simple and +build the crate for either standard Unix memory mapping or Xen, and not both. 
+ +## License + +This project is licensed under + +- [Apache License](http://www.apache.org/licenses/LICENSE-2.0), Version 2.0 diff --git a/cargo2android.json b/cargo2android.json new file mode 100644 index 0000000..6f1f3ce --- /dev/null +++ b/cargo2android.json @@ -0,0 +1,7 @@ +{ + "run": true, + "dep-suffixes": { + "vhost": "_android", + "vm_memory": "_android" + } +}
\ No newline at end of file diff --git a/docs/vhost_architecture.drawio b/docs/vhost_architecture.drawio new file mode 100644 index 0000000..8c669d8 --- /dev/null +++ b/docs/vhost_architecture.drawio @@ -0,0 +1,171 @@ +<mxfile host="65bd71144e" modified="2021-02-22T05:37:26.833Z" agent="5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Code/1.53.0 Chrome/87.0.4280.141 Electron/11.2.1 Safari/537.36" etag="HWRXqybJYJqQhnlJWfmB" version="14.2.4" type="embed"> + <diagram id="xCgrIAQPDQM0eynUYBOE" name="Page-1"> + <mxGraphModel dx="3446" dy="1284" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0"> + <root> + <mxCell id="0"/> + <mxCell id="1" parent="0"/> + <mxCell id="46" value="<br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br>" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1"> + <mxGeometry x="1620" y="27" width="450" height="990" as="geometry"/> + </mxCell> + <mxCell id="47" value="" style="shape=hexagon;perimeter=hexagonPerimeter2;whiteSpace=wrap;html=1;fixedSize=1;rounded=0;labelBackgroundColor=none;sketch=0;fillColor=none;fontSize=25;dashed=1;strokeWidth=6;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="790" y="237" width="1260" height="750" as="geometry"/> + </mxCell> + <mxCell id="44" value="" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1"> + <mxGeometry x="-10" y="37" width="1250" height="670" as="geometry"/> + </mxCell> + <mxCell id="2" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;">MasterReqHandler</pre>" 
style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1"> + <mxGeometry x="830" y="477" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="4" value="<pre style="font-size: 16.5pt; font-weight: 700; font-family: &quot;jetbrains mono&quot;, monospace;">VhostUserMasterReqHandler</pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1"> + <mxGeometry x="840" y="597" width="360" height="60" as="geometry"/> + </mxCell> + <mxCell id="6" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;entryX=1;entryY=0.5;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="2"> + <mxGeometry relative="1" as="geometry"> + <Array as="points"> + <mxPoint x="1280" y="792"/> + <mxPoint x="1280" y="502"/> + </Array> + </mxGeometry> + </mxCell> + <mxCell id="5" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">Slave</pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1"> + <mxGeometry x="1715" y="767" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="7" value="<pre style="font-size: 16.5pt; font-weight: 700; font-family: &quot;jetbrains mono&quot;, monospace;">VhostUserMasterReqHandlerMut</pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1630" y="657" width="390" height="60" as="geometry"/> + </mxCell> + <mxCell id="8" 
style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="2" target="4"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="950" y="657" as="sourcePoint"/> + <mxPoint x="680" y="717" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="10" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">SlaveListener</pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1360" y="472" width="190" height="50" as="geometry"/> + </mxCell> + <mxCell id="11" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">SlaveReqHandler</pre></pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1712" y="387" width="210" height="50" as="geometry"/> + </mxCell> + <mxCell id="14" value="<pre style="font-size: 16.5pt; font-weight: 700; font-family: &quot;jetbrains mono&quot;, monospace;"><pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;">VhostUserSlaveReqHandler</pre></pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1652" y="537" width="330" height="60" as="geometry"/> + </mxCell> + <mxCell id="15" 
style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="11" target="14"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1202" y="567" as="sourcePoint"/> + <mxPoint x="1202" y="667" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="16" value="<pre style="font-size: 16.5pt; font-weight: 700; font-family: &quot;jetbrains mono&quot;, monospace;">VhostBackend</pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#00994D;strokeColor=#009900;" vertex="1" parent="1"> + <mxGeometry x="390" y="197" width="250" height="60" as="geometry"/> + </mxCell> + <mxCell id="17" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;">VhostKernBackend</pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1"> + <mxGeometry x="530" y="387" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="18" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;">VhostVdpaBackend</pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1"> + <mxGeometry x="270" y="387" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="19" value="<pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">Master</pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="820" y="387" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="20" value="<pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">VhostSoftBackend</pre>" 
style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1"> + <mxGeometry x="10" y="387" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="21" value="Handle virtque in VMM" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1"> + <mxGeometry x="10" y="557" width="220" height="120" as="geometry"/> + </mxCell> + <mxCell id="23" value="Handle virtque in hardware" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1"> + <mxGeometry x="270" y="807" width="220" height="120" as="geometry"/> + </mxCell> + <mxCell id="24" value="Handle virtque in kernel" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1"> + <mxGeometry x="530" y="807" width="220" height="120" as="geometry"/> + </mxCell> + <mxCell id="25" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#0000CC;" edge="1" parent="1" source="24" target="17"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="930" y="647" as="sourcePoint"/> + <mxPoint x="930" y="747" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="26" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="11"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="840" y="917" as="sourcePoint"/> + <mxPoint 
x="840" y="1017" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="27" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="23" target="18"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="420" y="807" as="sourcePoint"/> + <mxPoint x="420" y="907" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="28" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="21" target="20"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="240" y="857" as="sourcePoint"/> + <mxPoint x="240" y="957" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="30" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="20" target="16"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="910" y="647" as="sourcePoint"/> + <mxPoint x="910" y="747" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="31" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;strokeColor=#00994D;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="18" target="16"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1000" y="177" as="sourcePoint"/> + <mxPoint x="530" y="227" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="32" 
style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="17" target="16"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1010" y="127" as="sourcePoint"/> + <mxPoint x="1505" y="-73" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="35" value="<pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">Endpoint</pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1360" y="552" width="190" height="50" as="geometry"/> + </mxCell> + <mxCell id="36" value="<pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">Message</pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1360" y="632" width="190" height="50" as="geometry"/> + </mxCell> + <mxCell id="37" value="<pre style="font-size: 16.5pt ; font-weight: 700 ; font-family: &quot;jetbrains mono&quot; , monospace"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">VhostUserMaster</pre></pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="980" y="257" width="230" height="60" as="geometry"/> + </mxCell> + <mxCell id="38" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" 
parent="1" source="19" target="16"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1030" y="527" as="sourcePoint"/> + <mxPoint x="515" y="257" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="39" value="Handle virtque in remote process" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="850" y="807" width="220" height="120" as="geometry"/> + </mxCell> + <mxCell id="41" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="7"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1860" y="187" as="sourcePoint"/> + <mxPoint x="1860" y="267" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="43" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="37"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1430" y="187" as="sourcePoint"/> + <mxPoint x="2102" y="187" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="49" value="<pre style="font-size: 16.5pt ; font-weight: 700 ; font-family: &#34;jetbrains mono&#34; , monospace"><pre style="font-family: &#34;jetbrains mono&#34; , monospace ; font-size: 16.5pt">Trait</pre></pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="60" y="1017" width="130" height="60" as="geometry"/> + </mxCell> + <mxCell id="51" value="Vhost-user protocol" 
style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fontColor=#FF00FF;fillColor=none;strokeColor=none;" vertex="1" parent="1"> + <mxGeometry x="1220" y="817" width="330" height="150" as="geometry"/> + </mxCell> + <mxCell id="52" value="Vhost-user server" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1"> + <mxGeometry x="1680" y="57" width="330" height="150" as="geometry"/> + </mxCell> + <mxCell id="53" value="VMM" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1"> + <mxGeometry x="20" y="47" width="240" height="150" as="geometry"/> + </mxCell> + <mxCell id="54" value="<pre style="font-family: &#34;jetbrains mono&#34; , monospace ; font-size: 16.5pt">Struct</pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1"> + <mxGeometry x="240" y="1022" width="140" height="55" as="geometry"/> + </mxCell> + </root> + </mxGraphModel> + </diagram> +</mxfile>
\ No newline at end of file diff --git a/docs/vhost_architecture.png b/docs/vhost_architecture.png Binary files differnew file mode 100644 index 0000000..4d1e2bc --- /dev/null +++ b/docs/vhost_architecture.png diff --git a/src/backend.rs b/src/backend.rs new file mode 100644 index 0000000..43ab7b9 --- /dev/null +++ b/src/backend.rs @@ -0,0 +1,551 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// Copyright 2019-2021 Alibaba Cloud. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + +//! Traits for vhost user backend servers to implement virtio data plain services. +//! +//! Define two traits for vhost user backend servers to implement virtio data plane services. +//! The only difference between the two traits is mutability. The [VhostUserBackend] trait is +//! designed with interior mutability, so the implementor may choose the suitable way to protect +//! itself from concurrent accesses. The [VhostUserBackendMut] is designed without interior +//! mutability, and an implementation of: +//! ```ignore +//! impl<T: VhostUserBackendMut> VhostUserBackend for RwLock<T> { } +//! ``` +//! is provided for convenience. +//! +//! [VhostUserBackend]: trait.VhostUserBackend.html +//! [VhostUserBackendMut]: trait.VhostUserBackendMut.html + +use std::io::Result; +use std::ops::Deref; +use std::sync::{Arc, Mutex, RwLock}; + +use vhost::vhost_user::message::VhostUserProtocolFeatures; +use vhost::vhost_user::Slave; +use vm_memory::bitmap::Bitmap; +use vmm_sys_util::epoll::EventSet; +use vmm_sys_util::eventfd::EventFd; + +use super::vring::VringT; +use super::GM; + +/// Trait with interior mutability for vhost user backend servers to implement concrete services. +/// +/// To support multi-threading and asynchronous IO, we enforce `Send + Sync` bound. +pub trait VhostUserBackend<V, B = ()>: Send + Sync +where + V: VringT<GM<B>>, + B: Bitmap + 'static, +{ + /// Get number of queues supported. 
+ fn num_queues(&self) -> usize; + + /// Get maximum queue size supported. + fn max_queue_size(&self) -> usize; + + /// Get available virtio features. + fn features(&self) -> u64; + + /// Set acknowledged virtio features. + fn acked_features(&self, _features: u64) {} + + /// Get available vhost protocol features. + fn protocol_features(&self) -> VhostUserProtocolFeatures; + + /// Enable or disable the virtio EVENT_IDX feature + fn set_event_idx(&self, enabled: bool); + + /// Get virtio device configuration. + /// + /// A default implementation is provided as we cannot expect all backends to implement this + /// function. + fn get_config(&self, _offset: u32, _size: u32) -> Vec<u8> { + Vec::new() + } + + /// Set virtio device configuration. + /// + /// A default implementation is provided as we cannot expect all backends to implement this + /// function. + fn set_config(&self, _offset: u32, _buf: &[u8]) -> Result<()> { + Ok(()) + } + + /// Update guest memory regions. + fn update_memory(&self, mem: GM<B>) -> Result<()>; + + /// Set handler for communicating with the master by the slave communication channel. + /// + /// A default implementation is provided as we cannot expect all backends to implement this + /// function. + fn set_slave_req_fd(&self, _slave: Slave) {} + + /// Get the map to map queue index to worker thread index. + /// + /// A return value of [2, 2, 4] means: the first two queues will be handled by worker thread 0, + /// the following two queues will be handled by worker thread 1, and the last four queues will + /// be handled by worker thread 2. + fn queues_per_thread(&self) -> Vec<u64> { + vec![0xffff_ffff] + } + + /// Provide an optional exit EventFd for the specified worker thread. + /// + /// If an (`EventFd`, `token`) pair is returned, the returned `EventFd` will be monitored for IO + /// events by using epoll with the specified `token`. When the returned EventFd is written to, + /// the worker thread will exit. 
+ fn exit_event(&self, _thread_index: usize) -> Option<EventFd> { + None + } + + /// Handle IO events for backend registered file descriptors. + /// + /// This function gets called if the backend registered some additional listeners onto specific + /// file descriptors. The library can handle virtqueues on its own, but does not know what to + /// do with events happening on custom listeners. + fn handle_event( + &self, + device_event: u16, + evset: EventSet, + vrings: &[V], + thread_id: usize, + ) -> Result<bool>; +} + +/// Trait without interior mutability for vhost user backend servers to implement concrete services. +pub trait VhostUserBackendMut<V, B = ()>: Send + Sync +where + V: VringT<GM<B>>, + B: Bitmap + 'static, +{ + /// Get number of queues supported. + fn num_queues(&self) -> usize; + + /// Get maximum queue size supported. + fn max_queue_size(&self) -> usize; + + /// Get available virtio features. + fn features(&self) -> u64; + + /// Set acknowledged virtio features. + fn acked_features(&mut self, _features: u64) {} + + /// Get available vhost protocol features. + fn protocol_features(&self) -> VhostUserProtocolFeatures; + + /// Enable or disable the virtio EVENT_IDX feature + fn set_event_idx(&mut self, enabled: bool); + + /// Get virtio device configuration. + /// + /// A default implementation is provided as we cannot expect all backends to implement this + /// function. + fn get_config(&self, _offset: u32, _size: u32) -> Vec<u8> { + Vec::new() + } + + /// Set virtio device configuration. + /// + /// A default implementation is provided as we cannot expect all backends to implement this + /// function. + fn set_config(&mut self, _offset: u32, _buf: &[u8]) -> Result<()> { + Ok(()) + } + + /// Update guest memory regions. + fn update_memory(&mut self, mem: GM<B>) -> Result<()>; + + /// Set handler for communicating with the master by the slave communication channel. 
+ /// + /// A default implementation is provided as we cannot expect all backends to implement this + /// function. + fn set_slave_req_fd(&mut self, _slave: Slave) {} + + /// Get the map to map queue index to worker thread index. + /// + /// A return value of [2, 2, 4] means: the first two queues will be handled by worker thread 0, + /// the following two queues will be handled by worker thread 1, and the last four queues will + /// be handled by worker thread 2. + fn queues_per_thread(&self) -> Vec<u64> { + vec![0xffff_ffff] + } + + /// Provide an optional exit EventFd for the specified worker thread. + /// + /// If an (`EventFd`, `token`) pair is returned, the returned `EventFd` will be monitored for IO + /// events by using epoll with the specified `token`. When the returned EventFd is written to, + /// the worker thread will exit. + fn exit_event(&self, _thread_index: usize) -> Option<EventFd> { + None + } + + /// Handle IO events for backend registered file descriptors. + /// + /// This function gets called if the backend registered some additional listeners onto specific + /// file descriptors. The library can handle virtqueues on its own, but does not know what to + /// do with events happening on custom listeners. 
+ fn handle_event( + &mut self, + device_event: u16, + evset: EventSet, + vrings: &[V], + thread_id: usize, + ) -> Result<bool>; +} + +impl<T: VhostUserBackend<V, B>, V, B> VhostUserBackend<V, B> for Arc<T> +where + V: VringT<GM<B>>, + B: Bitmap + 'static, +{ + fn num_queues(&self) -> usize { + self.deref().num_queues() + } + + fn max_queue_size(&self) -> usize { + self.deref().max_queue_size() + } + + fn features(&self) -> u64 { + self.deref().features() + } + + fn acked_features(&self, features: u64) { + self.deref().acked_features(features) + } + + fn protocol_features(&self) -> VhostUserProtocolFeatures { + self.deref().protocol_features() + } + + fn set_event_idx(&self, enabled: bool) { + self.deref().set_event_idx(enabled) + } + + fn get_config(&self, offset: u32, size: u32) -> Vec<u8> { + self.deref().get_config(offset, size) + } + + fn set_config(&self, offset: u32, buf: &[u8]) -> Result<()> { + self.deref().set_config(offset, buf) + } + + fn update_memory(&self, mem: GM<B>) -> Result<()> { + self.deref().update_memory(mem) + } + + fn set_slave_req_fd(&self, slave: Slave) { + self.deref().set_slave_req_fd(slave) + } + + fn queues_per_thread(&self) -> Vec<u64> { + self.deref().queues_per_thread() + } + + fn exit_event(&self, thread_index: usize) -> Option<EventFd> { + self.deref().exit_event(thread_index) + } + + fn handle_event( + &self, + device_event: u16, + evset: EventSet, + vrings: &[V], + thread_id: usize, + ) -> Result<bool> { + self.deref() + .handle_event(device_event, evset, vrings, thread_id) + } +} + +impl<T: VhostUserBackendMut<V, B>, V, B> VhostUserBackend<V, B> for Mutex<T> +where + V: VringT<GM<B>>, + B: Bitmap + 'static, +{ + fn num_queues(&self) -> usize { + self.lock().unwrap().num_queues() + } + + fn max_queue_size(&self) -> usize { + self.lock().unwrap().max_queue_size() + } + + fn features(&self) -> u64 { + self.lock().unwrap().features() + } + + fn acked_features(&self, features: u64) { + self.lock().unwrap().acked_features(features) 
+ } + + fn protocol_features(&self) -> VhostUserProtocolFeatures { + self.lock().unwrap().protocol_features() + } + + fn set_event_idx(&self, enabled: bool) { + self.lock().unwrap().set_event_idx(enabled) + } + + fn get_config(&self, offset: u32, size: u32) -> Vec<u8> { + self.lock().unwrap().get_config(offset, size) + } + + fn set_config(&self, offset: u32, buf: &[u8]) -> Result<()> { + self.lock().unwrap().set_config(offset, buf) + } + + fn update_memory(&self, mem: GM<B>) -> Result<()> { + self.lock().unwrap().update_memory(mem) + } + + fn set_slave_req_fd(&self, slave: Slave) { + self.lock().unwrap().set_slave_req_fd(slave) + } + + fn queues_per_thread(&self) -> Vec<u64> { + self.lock().unwrap().queues_per_thread() + } + + fn exit_event(&self, thread_index: usize) -> Option<EventFd> { + self.lock().unwrap().exit_event(thread_index) + } + + fn handle_event( + &self, + device_event: u16, + evset: EventSet, + vrings: &[V], + thread_id: usize, + ) -> Result<bool> { + self.lock() + .unwrap() + .handle_event(device_event, evset, vrings, thread_id) + } +} + +impl<T: VhostUserBackendMut<V, B>, V, B> VhostUserBackend<V, B> for RwLock<T> +where + V: VringT<GM<B>>, + B: Bitmap + 'static, +{ + fn num_queues(&self) -> usize { + self.read().unwrap().num_queues() + } + + fn max_queue_size(&self) -> usize { + self.read().unwrap().max_queue_size() + } + + fn features(&self) -> u64 { + self.read().unwrap().features() + } + + fn acked_features(&self, features: u64) { + self.write().unwrap().acked_features(features) + } + + fn protocol_features(&self) -> VhostUserProtocolFeatures { + self.read().unwrap().protocol_features() + } + + fn set_event_idx(&self, enabled: bool) { + self.write().unwrap().set_event_idx(enabled) + } + + fn get_config(&self, offset: u32, size: u32) -> Vec<u8> { + self.read().unwrap().get_config(offset, size) + } + + fn set_config(&self, offset: u32, buf: &[u8]) -> Result<()> { + self.write().unwrap().set_config(offset, buf) + } + + fn update_memory(&self, 
mem: GM<B>) -> Result<()> { + self.write().unwrap().update_memory(mem) + } + + fn set_slave_req_fd(&self, slave: Slave) { + self.write().unwrap().set_slave_req_fd(slave) + } + + fn queues_per_thread(&self) -> Vec<u64> { + self.read().unwrap().queues_per_thread() + } + + fn exit_event(&self, thread_index: usize) -> Option<EventFd> { + self.read().unwrap().exit_event(thread_index) + } + + fn handle_event( + &self, + device_event: u16, + evset: EventSet, + vrings: &[V], + thread_id: usize, + ) -> Result<bool> { + self.write() + .unwrap() + .handle_event(device_event, evset, vrings, thread_id) + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + use crate::VringRwLock; + use std::sync::Mutex; + use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap}; + + pub struct MockVhostBackend { + events: u64, + event_idx: bool, + acked_features: u64, + } + + impl MockVhostBackend { + pub fn new() -> Self { + MockVhostBackend { + events: 0, + event_idx: false, + acked_features: 0, + } + } + } + + impl VhostUserBackendMut<VringRwLock, ()> for MockVhostBackend { + fn num_queues(&self) -> usize { + 2 + } + + fn max_queue_size(&self) -> usize { + 256 + } + + fn features(&self) -> u64 { + 0xffff_ffff_ffff_ffff + } + + fn acked_features(&mut self, features: u64) { + self.acked_features = features; + } + + fn protocol_features(&self) -> VhostUserProtocolFeatures { + VhostUserProtocolFeatures::all() + } + + fn set_event_idx(&mut self, enabled: bool) { + self.event_idx = enabled; + } + + fn get_config(&self, offset: u32, size: u32) -> Vec<u8> { + assert_eq!(offset, 0x200); + assert_eq!(size, 8); + + vec![0xa5u8; 8] + } + + fn set_config(&mut self, offset: u32, buf: &[u8]) -> Result<()> { + assert_eq!(offset, 0x200); + assert_eq!(buf.len(), 8); + assert_eq!(buf, &[0xa5u8; 8]); + + Ok(()) + } + + fn update_memory(&mut self, _atomic_mem: GuestMemoryAtomic<GuestMemoryMmap>) -> Result<()> { + Ok(()) + } + + fn set_slave_req_fd(&mut self, _slave: Slave) {} + + fn 
queues_per_thread(&self) -> Vec<u64> { + vec![1, 1] + } + + fn exit_event(&self, _thread_index: usize) -> Option<EventFd> { + let event_fd = EventFd::new(0).unwrap(); + + Some(event_fd) + } + + fn handle_event( + &mut self, + _device_event: u16, + _evset: EventSet, + _vrings: &[VringRwLock], + _thread_id: usize, + ) -> Result<bool> { + self.events += 1; + + Ok(false) + } + } + + #[test] + fn test_new_mock_backend_mutex() { + let backend = Arc::new(Mutex::new(MockVhostBackend::new())); + + assert_eq!(backend.num_queues(), 2); + assert_eq!(backend.max_queue_size(), 256); + assert_eq!(backend.features(), 0xffff_ffff_ffff_ffff); + assert_eq!( + backend.protocol_features(), + VhostUserProtocolFeatures::all() + ); + assert_eq!(backend.queues_per_thread(), [1, 1]); + + assert_eq!(backend.get_config(0x200, 8), vec![0xa5; 8]); + backend.set_config(0x200, &[0xa5; 8]).unwrap(); + + backend.acked_features(0xffff); + assert_eq!(backend.lock().unwrap().acked_features, 0xffff); + + backend.set_event_idx(true); + assert!(backend.lock().unwrap().event_idx); + + let _ = backend.exit_event(0).unwrap(); + + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(), + ); + backend.update_memory(mem).unwrap(); + } + + #[test] + fn test_new_mock_backend_rwlock() { + let backend = Arc::new(RwLock::new(MockVhostBackend::new())); + + assert_eq!(backend.num_queues(), 2); + assert_eq!(backend.max_queue_size(), 256); + assert_eq!(backend.features(), 0xffff_ffff_ffff_ffff); + assert_eq!( + backend.protocol_features(), + VhostUserProtocolFeatures::all() + ); + assert_eq!(backend.queues_per_thread(), [1, 1]); + + assert_eq!(backend.get_config(0x200, 8), vec![0xa5; 8]); + backend.set_config(0x200, &[0xa5; 8]).unwrap(); + + backend.acked_features(0xffff); + assert_eq!(backend.read().unwrap().acked_features, 0xffff); + + backend.set_event_idx(true); + assert!(backend.read().unwrap().event_idx); + + let _ = backend.exit_event(0).unwrap(); + 
+ let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(), + ); + backend.update_memory(mem.clone()).unwrap(); + + let vring = VringRwLock::new(mem, 0x1000).unwrap(); + backend + .handle_event(0x1, EventSet::IN, &[vring], 0) + .unwrap(); + } +} diff --git a/src/event_loop.rs b/src/event_loop.rs new file mode 100644 index 0000000..f10aad3 --- /dev/null +++ b/src/event_loop.rs @@ -0,0 +1,270 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// Copyright 2019-2021 Alibaba Cloud. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + +use std::fmt::{Display, Formatter}; +use std::io::{self, Result}; +use std::marker::PhantomData; +use std::os::unix::io::{AsRawFd, RawFd}; + +use vm_memory::bitmap::Bitmap; +use vmm_sys_util::epoll::{ControlOperation, Epoll, EpollEvent, EventSet}; +use vmm_sys_util::eventfd::EventFd; + +use super::backend::VhostUserBackend; +use super::vring::VringT; +use super::GM; + +/// Errors related to vring epoll event handling. +#[derive(Debug)] +pub enum VringEpollError { + /// Failed to create epoll file descriptor. + EpollCreateFd(io::Error), + /// Failed while waiting for events. + EpollWait(io::Error), + /// Could not register exit event + RegisterExitEvent(io::Error), + /// Failed to read the event from kick EventFd. + HandleEventReadKick(io::Error), + /// Failed to handle the event from the backend. 
+ HandleEventBackendHandling(io::Error), +} + +impl Display for VringEpollError { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + VringEpollError::EpollCreateFd(e) => write!(f, "cannot create epoll fd: {}", e), + VringEpollError::EpollWait(e) => write!(f, "failed to wait for epoll event: {}", e), + VringEpollError::RegisterExitEvent(e) => write!(f, "cannot register exit event: {}", e), + VringEpollError::HandleEventReadKick(e) => { + write!(f, "cannot read vring kick event: {}", e) + } + VringEpollError::HandleEventBackendHandling(e) => { + write!(f, "failed to handle epoll event: {}", e) + } + } + } +} + +impl std::error::Error for VringEpollError {} + +/// Result of vring epoll operations. +pub type VringEpollResult<T> = std::result::Result<T, VringEpollError>; + +/// Epoll event handler to manage and process epoll events for registered file descriptor. +/// +/// The `VringEpollHandler` structure provides interfaces to: +/// - add file descriptors to be monitored by the epoll fd +/// - remove registered file descriptors from the epoll fd +/// - run the event loop to handle pending events on the epoll fd +pub struct VringEpollHandler<S, V, B> { + epoll: Epoll, + backend: S, + vrings: Vec<V>, + thread_id: usize, + exit_event_fd: Option<EventFd>, + phantom: PhantomData<B>, +} + +impl<S, V, B> VringEpollHandler<S, V, B> { + /// Send `exit event` to break the event loop. + pub fn send_exit_event(&self) { + if let Some(eventfd) = self.exit_event_fd.as_ref() { + let _ = eventfd.write(1); + } + } +} + +impl<S, V, B> VringEpollHandler<S, V, B> +where + S: VhostUserBackend<V, B>, + V: VringT<GM<B>>, + B: Bitmap + 'static, +{ + /// Create a `VringEpollHandler` instance. 
+ pub(crate) fn new(backend: S, vrings: Vec<V>, thread_id: usize) -> VringEpollResult<Self> { + let epoll = Epoll::new().map_err(VringEpollError::EpollCreateFd)?; + let exit_event_fd = backend.exit_event(thread_id); + + if let Some(exit_event_fd) = &exit_event_fd { + let id = backend.num_queues(); + epoll + .ctl( + ControlOperation::Add, + exit_event_fd.as_raw_fd(), + EpollEvent::new(EventSet::IN, id as u64), + ) + .map_err(VringEpollError::RegisterExitEvent)?; + } + + Ok(VringEpollHandler { + epoll, + backend, + vrings, + thread_id, + exit_event_fd, + phantom: PhantomData, + }) + } + + /// Register an event into the epoll fd. + /// + /// When this event is later triggered, the backend implementation of `handle_event` will be + /// called. + pub fn register_listener(&self, fd: RawFd, ev_type: EventSet, data: u64) -> Result<()> { + // `data` range [0...num_queues] is reserved for queues and exit event. + if data <= self.backend.num_queues() as u64 { + Err(io::Error::from_raw_os_error(libc::EINVAL)) + } else { + self.register_event(fd, ev_type, data) + } + } + + /// Unregister an event from the epoll fd. + /// + /// If the event is triggered after this function has been called, the event will be silently + /// dropped. + pub fn unregister_listener(&self, fd: RawFd, ev_type: EventSet, data: u64) -> Result<()> { + // `data` range [0...num_queues] is reserved for queues and exit event. 
+ if data <= self.backend.num_queues() as u64 { + Err(io::Error::from_raw_os_error(libc::EINVAL)) + } else { + self.unregister_event(fd, ev_type, data) + } + } + + pub(crate) fn register_event(&self, fd: RawFd, ev_type: EventSet, data: u64) -> Result<()> { + self.epoll + .ctl(ControlOperation::Add, fd, EpollEvent::new(ev_type, data)) + } + + pub(crate) fn unregister_event(&self, fd: RawFd, ev_type: EventSet, data: u64) -> Result<()> { + self.epoll + .ctl(ControlOperation::Delete, fd, EpollEvent::new(ev_type, data)) + } + + /// Run the event poll loop to handle all pending events on registered fds. + /// + /// The event loop will be terminated once an event is received from the `exit event fd` + /// associated with the backend. + pub(crate) fn run(&self) -> VringEpollResult<()> { + const EPOLL_EVENTS_LEN: usize = 100; + let mut events = vec![EpollEvent::new(EventSet::empty(), 0); EPOLL_EVENTS_LEN]; + + 'epoll: loop { + let num_events = match self.epoll.wait(-1, &mut events[..]) { + Ok(res) => res, + Err(e) => { + if e.kind() == io::ErrorKind::Interrupted { + // It's well defined from the epoll_wait() syscall + // documentation that the epoll loop can be interrupted + // before any of the requested events occurred or the + // timeout expired. In both those cases, epoll_wait() + // returns an error of type EINTR, but this should not + // be considered as a regular error. Instead it is more + // appropriate to retry, by calling into epoll_wait(). + continue; + } + return Err(VringEpollError::EpollWait(e)); + } + }; + + for event in events.iter().take(num_events) { + let evset = match EventSet::from_bits(event.events) { + Some(evset) => evset, + None => { + let evbits = event.events; + println!("epoll: ignoring unknown event set: 0x{:x}", evbits); + continue; + } + }; + + let ev_type = event.data() as u16; + + // handle_event() returns true if an event is received from the exit event fd. + if self.handle_event(ev_type, evset)? 
{ + break 'epoll; + } + } + } + + Ok(()) + } + + fn handle_event(&self, device_event: u16, evset: EventSet) -> VringEpollResult<bool> { + if self.exit_event_fd.is_some() && device_event as usize == self.backend.num_queues() { + return Ok(true); + } + + if (device_event as usize) < self.vrings.len() { + let vring = &self.vrings[device_event as usize]; + let enabled = vring + .read_kick() + .map_err(VringEpollError::HandleEventReadKick)?; + + // If the vring is not enabled, it should not be processed. + if !enabled { + return Ok(false); + } + } + + self.backend + .handle_event(device_event, evset, &self.vrings, self.thread_id) + .map_err(VringEpollError::HandleEventBackendHandling) + } +} + +impl<S, V, B> AsRawFd for VringEpollHandler<S, V, B> { + fn as_raw_fd(&self) -> RawFd { + self.epoll.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use super::super::backend::tests::MockVhostBackend; + use super::super::vring::VringRwLock; + use super::*; + use std::sync::{Arc, Mutex}; + use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap}; + use vmm_sys_util::eventfd::EventFd; + + #[test] + fn test_vring_epoll_handler() { + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(), + ); + let vring = VringRwLock::new(mem, 0x1000).unwrap(); + let backend = Arc::new(Mutex::new(MockVhostBackend::new())); + + let handler = VringEpollHandler::new(backend, vec![vring], 0x1).unwrap(); + + let eventfd = EventFd::new(0).unwrap(); + handler + .register_listener(eventfd.as_raw_fd(), EventSet::IN, 3) + .unwrap(); + // Register an already registered fd. + handler + .register_listener(eventfd.as_raw_fd(), EventSet::IN, 3) + .unwrap_err(); + // Register an invalid data. + handler + .register_listener(eventfd.as_raw_fd(), EventSet::IN, 1) + .unwrap_err(); + + handler + .unregister_listener(eventfd.as_raw_fd(), EventSet::IN, 3) + .unwrap(); + // unregister an already unregistered fd. 
+ handler + .unregister_listener(eventfd.as_raw_fd(), EventSet::IN, 3) + .unwrap_err(); + // unregister an invalid data. + handler + .unregister_listener(eventfd.as_raw_fd(), EventSet::IN, 1) + .unwrap_err(); + // Check we retrieve the correct file descriptor + assert_eq!(handler.as_raw_fd(), handler.epoll.as_raw_fd()); + } +} diff --git a/src/handler.rs b/src/handler.rs new file mode 100644 index 0000000..262bf6c --- /dev/null +++ b/src/handler.rs @@ -0,0 +1,618 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// Copyright 2019-2021 Alibaba Cloud. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + +use std::error; +use std::fs::File; +use std::io; +use std::os::unix::io::AsRawFd; +use std::sync::Arc; +use std::thread; + +use vhost::vhost_user::message::{ + VhostUserConfigFlags, VhostUserMemoryRegion, VhostUserProtocolFeatures, + VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags, + VhostUserVringState, +}; +use vhost::vhost_user::{ + Error as VhostUserError, Result as VhostUserResult, Slave, VhostUserSlaveReqHandlerMut, +}; +use virtio_bindings::bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX; +use virtio_queue::{Error as VirtQueError, QueueT}; +use vm_memory::bitmap::Bitmap; +use vm_memory::mmap::NewBitmap; +use vm_memory::{GuestAddress, GuestAddressSpace, GuestMemoryMmap, GuestRegionMmap}; +use vmm_sys_util::epoll::EventSet; + +use super::backend::VhostUserBackend; +use super::event_loop::VringEpollHandler; +use super::event_loop::{VringEpollError, VringEpollResult}; +use super::vring::VringT; +use super::GM; + +const MAX_MEM_SLOTS: u64 = 32; + +#[derive(Debug)] +/// Errors related to vhost-user handler. +pub enum VhostUserHandlerError { + /// Failed to create a `Vring`. + CreateVring(VirtQueError), + /// Failed to create vring worker. + CreateEpollHandler(VringEpollError), + /// Failed to spawn vring worker. + SpawnVringWorker(io::Error), + /// Could not find the mapping from memory regions. 
+ MissingMemoryMapping, +} + +impl std::fmt::Display for VhostUserHandlerError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + VhostUserHandlerError::CreateVring(e) => { + write!(f, "failed to create vring: {}", e) + } + VhostUserHandlerError::CreateEpollHandler(e) => { + write!(f, "failed to create vring epoll handler: {}", e) + } + VhostUserHandlerError::SpawnVringWorker(e) => { + write!(f, "failed spawning the vring worker: {}", e) + } + VhostUserHandlerError::MissingMemoryMapping => write!(f, "Missing memory mapping"), + } + } +} + +impl error::Error for VhostUserHandlerError {} + +/// Result of vhost-user handler operations. +pub type VhostUserHandlerResult<T> = std::result::Result<T, VhostUserHandlerError>; + +struct AddrMapping { + vmm_addr: u64, + size: u64, + gpa_base: u64, +} + +pub struct VhostUserHandler<S, V, B: Bitmap + 'static> { + backend: S, + handlers: Vec<Arc<VringEpollHandler<S, V, B>>>, + owned: bool, + features_acked: bool, + acked_features: u64, + acked_protocol_features: u64, + num_queues: usize, + max_queue_size: usize, + queues_per_thread: Vec<u64>, + mappings: Vec<AddrMapping>, + atomic_mem: GM<B>, + vrings: Vec<V>, + worker_threads: Vec<thread::JoinHandle<VringEpollResult<()>>>, +} + +// Ensure VhostUserHandler: Clone + Send + Sync + 'static. 
+impl<S, V, B> VhostUserHandler<S, V, B> +where + S: VhostUserBackend<V, B> + Clone + 'static, + V: VringT<GM<B>> + Clone + Send + Sync + 'static, + B: Bitmap + Clone + Send + Sync + 'static, +{ + pub(crate) fn new(backend: S, atomic_mem: GM<B>) -> VhostUserHandlerResult<Self> { + let num_queues = backend.num_queues(); + let max_queue_size = backend.max_queue_size(); + let queues_per_thread = backend.queues_per_thread(); + + let mut vrings = Vec::new(); + for _ in 0..num_queues { + let vring = V::new(atomic_mem.clone(), max_queue_size as u16) + .map_err(VhostUserHandlerError::CreateVring)?; + vrings.push(vring); + } + + let mut handlers = Vec::new(); + let mut worker_threads = Vec::new(); + for (thread_id, queues_mask) in queues_per_thread.iter().enumerate() { + let mut thread_vrings = Vec::new(); + for (index, vring) in vrings.iter().enumerate() { + if (queues_mask >> index) & 1u64 == 1u64 { + thread_vrings.push(vring.clone()); + } + } + + let handler = Arc::new( + VringEpollHandler::new(backend.clone(), thread_vrings, thread_id) + .map_err(VhostUserHandlerError::CreateEpollHandler)?, + ); + let handler2 = handler.clone(); + let worker_thread = thread::Builder::new() + .name("vring_worker".to_string()) + .spawn(move || handler2.run()) + .map_err(VhostUserHandlerError::SpawnVringWorker)?; + + handlers.push(handler); + worker_threads.push(worker_thread); + } + + Ok(VhostUserHandler { + backend, + handlers, + owned: false, + features_acked: false, + acked_features: 0, + acked_protocol_features: 0, + num_queues, + max_queue_size, + queues_per_thread, + mappings: Vec::new(), + atomic_mem, + vrings, + worker_threads, + }) + } +} + +impl<S, V, B: Bitmap> VhostUserHandler<S, V, B> { + pub(crate) fn send_exit_event(&self) { + for handler in self.handlers.iter() { + handler.send_exit_event(); + } + } + + fn vmm_va_to_gpa(&self, vmm_va: u64) -> VhostUserHandlerResult<u64> { + for mapping in self.mappings.iter() { + if vmm_va >= mapping.vmm_addr && vmm_va < mapping.vmm_addr + 
mapping.size { + return Ok(vmm_va - mapping.vmm_addr + mapping.gpa_base); + } + } + + Err(VhostUserHandlerError::MissingMemoryMapping) + } +} + +impl<S, V, B> VhostUserHandler<S, V, B> +where + S: VhostUserBackend<V, B>, + V: VringT<GM<B>>, + B: Bitmap, +{ + pub(crate) fn get_epoll_handlers(&self) -> Vec<Arc<VringEpollHandler<S, V, B>>> { + self.handlers.clone() + } + + fn vring_needs_init(&self, vring: &V) -> bool { + let vring_state = vring.get_ref(); + + // If the vring wasn't initialized and we already have an EventFd for + // VRING_KICK, initialize it now. + !vring_state.get_queue().ready() && vring_state.get_kick().is_some() + } + + fn initialize_vring(&self, vring: &V, index: u8) -> VhostUserResult<()> { + assert!(vring.get_ref().get_kick().is_some()); + + if let Some(fd) = vring.get_ref().get_kick() { + for (thread_index, queues_mask) in self.queues_per_thread.iter().enumerate() { + let shifted_queues_mask = queues_mask >> index; + if shifted_queues_mask & 1u64 == 1u64 { + let evt_idx = queues_mask.count_ones() - shifted_queues_mask.count_ones(); + self.handlers[thread_index] + .register_event(fd.as_raw_fd(), EventSet::IN, u64::from(evt_idx)) + .map_err(VhostUserError::ReqHandlerError)?; + break; + } + } + } + + self.vrings[index as usize].set_queue_ready(true); + + Ok(()) + } + + /// Helper to check if VirtioFeature enabled + fn check_feature(&self, feat: VhostUserVirtioFeatures) -> VhostUserResult<()> { + if self.acked_features & feat.bits() != 0 { + Ok(()) + } else { + Err(VhostUserError::InactiveFeature(feat)) + } + } +} + +impl<S, V, B> VhostUserSlaveReqHandlerMut for VhostUserHandler<S, V, B> +where + S: VhostUserBackend<V, B>, + V: VringT<GM<B>>, + B: NewBitmap + Clone, +{ + fn set_owner(&mut self) -> VhostUserResult<()> { + if self.owned { + return Err(VhostUserError::InvalidOperation("already claimed")); + } + self.owned = true; + Ok(()) + } + + fn reset_owner(&mut self) -> VhostUserResult<()> { + self.owned = false; + self.features_acked = false; 
+ self.acked_features = 0; + self.acked_protocol_features = 0; + Ok(()) + } + + fn get_features(&mut self) -> VhostUserResult<u64> { + Ok(self.backend.features()) + } + + fn set_features(&mut self, features: u64) -> VhostUserResult<()> { + if (features & !self.backend.features()) != 0 { + return Err(VhostUserError::InvalidParam); + } + + self.acked_features = features; + self.features_acked = true; + + // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, + // the ring is initialized in an enabled state. + // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, + // the ring is initialized in a disabled state. Client must not + // pass data to/from the backend until ring is enabled by + // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has + // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. + let vring_enabled = + self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0; + for vring in self.vrings.iter_mut() { + vring.set_enabled(vring_enabled); + } + + self.backend.acked_features(self.acked_features); + + Ok(()) + } + + fn set_mem_table( + &mut self, + ctx: &[VhostUserMemoryRegion], + files: Vec<File>, + ) -> VhostUserResult<()> { + // We need to create tuple of ranges from the list of VhostUserMemoryRegion + // that we get from the caller. 
+ let mut regions = Vec::new(); + let mut mappings: Vec<AddrMapping> = Vec::new(); + + for (region, file) in ctx.iter().zip(files) { + regions.push( + GuestRegionMmap::new( + region.mmap_region(file)?, + GuestAddress(region.guest_phys_addr), + ) + .map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?, + ); + mappings.push(AddrMapping { + vmm_addr: region.user_addr, + size: region.memory_size, + gpa_base: region.guest_phys_addr, + }); + } + + let mem = GuestMemoryMmap::from_regions(regions).map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?; + + // Updating the inner GuestMemory object here will cause all our vrings to + // see the new one the next time they call to `atomic_mem.memory()`. + self.atomic_mem.lock().unwrap().replace(mem); + + self.backend + .update_memory(self.atomic_mem.clone()) + .map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?; + self.mappings = mappings; + + Ok(()) + } + + fn set_vring_num(&mut self, index: u32, num: u32) -> VhostUserResult<()> { + if index as usize >= self.num_queues || num == 0 || num as usize > self.max_queue_size { + return Err(VhostUserError::InvalidParam); + } + self.vrings[index as usize].set_queue_size(num as u16); + Ok(()) + } + + fn set_vring_addr( + &mut self, + index: u32, + _flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + _log: u64, + ) -> VhostUserResult<()> { + if index as usize >= self.num_queues { + return Err(VhostUserError::InvalidParam); + } + + if !self.mappings.is_empty() { + let desc_table = self.vmm_va_to_gpa(descriptor).map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?; + let avail_ring = self.vmm_va_to_gpa(available).map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?; + let used_ring = self.vmm_va_to_gpa(used).map_err(|e| { + 
VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?; + self.vrings[index as usize] + .set_queue_info(desc_table, avail_ring, used_ring) + .map_err(|_| VhostUserError::InvalidParam)?; + + // SET_VRING_BASE will only restore the 'avail' index, however, after the guest driver + // changes, for instance, after reboot, the 'used' index should be reset to 0. + // + // So let's fetch the used index from the vring as set by the guest here to keep + // compatibility with the QEMU's vhost-user library just in case, any implementation + // expects the 'used' index to be set when receiving a SET_VRING_ADDR message. + // + // Note: I'm not sure why QEMU's vhost-user library sets the 'user' index here, + // _probably_ to make sure that the VQ is already configured. A better solution would + // be to receive the 'used' index in SET_VRING_BASE, as is done when using packed VQs. + let idx = self.vrings[index as usize] + .queue_used_idx() + .map_err(|_| VhostUserError::SlaveInternalError)?; + self.vrings[index as usize].set_queue_next_used(idx); + + Ok(()) + } else { + Err(VhostUserError::InvalidParam) + } + } + + fn set_vring_base(&mut self, index: u32, base: u32) -> VhostUserResult<()> { + let event_idx: bool = (self.acked_features & (1 << VIRTIO_RING_F_EVENT_IDX)) != 0; + + self.vrings[index as usize].set_queue_next_avail(base as u16); + self.vrings[index as usize].set_queue_event_idx(event_idx); + self.backend.set_event_idx(event_idx); + + Ok(()) + } + + fn get_vring_base(&mut self, index: u32) -> VhostUserResult<VhostUserVringState> { + if index as usize >= self.num_queues { + return Err(VhostUserError::InvalidParam); + } + + // Quote from vhost-user specification: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. 
+ self.vrings[index as usize].set_queue_ready(false); + + if let Some(fd) = self.vrings[index as usize].get_ref().get_kick() { + for (thread_index, queues_mask) in self.queues_per_thread.iter().enumerate() { + let shifted_queues_mask = queues_mask >> index; + if shifted_queues_mask & 1u64 == 1u64 { + let evt_idx = queues_mask.count_ones() - shifted_queues_mask.count_ones(); + self.handlers[thread_index] + .unregister_event(fd.as_raw_fd(), EventSet::IN, u64::from(evt_idx)) + .map_err(VhostUserError::ReqHandlerError)?; + break; + } + } + } + + let next_avail = self.vrings[index as usize].queue_next_avail(); + + self.vrings[index as usize].set_kick(None); + self.vrings[index as usize].set_call(None); + + Ok(VhostUserVringState::new(index, u32::from(next_avail))) + } + + fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostUserResult<()> { + if index as usize >= self.num_queues { + return Err(VhostUserError::InvalidParam); + } + + // SAFETY: EventFd requires that it has sole ownership of its fd. So + // does File, so this is safe. + // Ideally, we'd have a generic way to refer to a uniquely-owned fd, + // such as that proposed by Rust RFC #3128. 
+ self.vrings[index as usize].set_kick(file); + + if self.vring_needs_init(&self.vrings[index as usize]) { + self.initialize_vring(&self.vrings[index as usize], index)?; + } + + Ok(()) + } + + fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostUserResult<()> { + if index as usize >= self.num_queues { + return Err(VhostUserError::InvalidParam); + } + + self.vrings[index as usize].set_call(file); + + if self.vring_needs_init(&self.vrings[index as usize]) { + self.initialize_vring(&self.vrings[index as usize], index)?; + } + + Ok(()) + } + + fn set_vring_err(&mut self, index: u8, file: Option<File>) -> VhostUserResult<()> { + if index as usize >= self.num_queues { + return Err(VhostUserError::InvalidParam); + } + + self.vrings[index as usize].set_err(file); + + Ok(()) + } + + fn get_protocol_features(&mut self) -> VhostUserResult<VhostUserProtocolFeatures> { + Ok(self.backend.protocol_features()) + } + + fn set_protocol_features(&mut self, features: u64) -> VhostUserResult<()> { + // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must + // support this message even before VHOST_USER_SET_FEATURES was + // called. + self.acked_protocol_features = features; + Ok(()) + } + + fn get_queue_num(&mut self) -> VhostUserResult<u64> { + Ok(self.num_queues as u64) + } + + fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostUserResult<()> { + // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES + // has been negotiated. + self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; + + if index as usize >= self.num_queues { + return Err(VhostUserError::InvalidParam); + } + + // Slave must not pass data to/from the backend until ring is + // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1, + // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE + // with parameter 0. 
+ self.vrings[index as usize].set_enabled(enable); + + Ok(()) + } + + fn get_config( + &mut self, + offset: u32, + size: u32, + _flags: VhostUserConfigFlags, + ) -> VhostUserResult<Vec<u8>> { + Ok(self.backend.get_config(offset, size)) + } + + fn set_config( + &mut self, + offset: u32, + buf: &[u8], + _flags: VhostUserConfigFlags, + ) -> VhostUserResult<()> { + self.backend + .set_config(offset, buf) + .map_err(VhostUserError::ReqHandlerError) + } + + fn set_slave_req_fd(&mut self, slave: Slave) { + if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() != 0 { + slave.set_reply_ack_flag(true); + } + + self.backend.set_slave_req_fd(slave); + } + + fn get_inflight_fd( + &mut self, + _inflight: &vhost::vhost_user::message::VhostUserInflight, + ) -> VhostUserResult<(vhost::vhost_user::message::VhostUserInflight, File)> { + // Assume the backend hasn't negotiated the inflight feature; it + // wouldn't be correct for the backend to do so, as we don't (yet) + // provide a way for it to handle such requests. 
+ Err(VhostUserError::InvalidOperation("not supported")) + } + + fn set_inflight_fd( + &mut self, + _inflight: &vhost::vhost_user::message::VhostUserInflight, + _file: File, + ) -> VhostUserResult<()> { + Err(VhostUserError::InvalidOperation("not supported")) + } + + fn get_max_mem_slots(&mut self) -> VhostUserResult<u64> { + Ok(MAX_MEM_SLOTS) + } + + fn add_mem_region( + &mut self, + region: &VhostUserSingleMemoryRegion, + file: File, + ) -> VhostUserResult<()> { + let guest_region = Arc::new( + GuestRegionMmap::new( + region.mmap_region(file)?, + GuestAddress(region.guest_phys_addr), + ) + .map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?, + ); + + let mem = self + .atomic_mem + .memory() + .insert_region(guest_region) + .map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?; + + self.atomic_mem.lock().unwrap().replace(mem); + + self.backend + .update_memory(self.atomic_mem.clone()) + .map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?; + + self.mappings.push(AddrMapping { + vmm_addr: region.user_addr, + size: region.memory_size, + gpa_base: region.guest_phys_addr, + }); + + Ok(()) + } + + fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> VhostUserResult<()> { + let (mem, _) = self + .atomic_mem + .memory() + .remove_region(GuestAddress(region.guest_phys_addr), region.memory_size) + .map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?; + + self.atomic_mem.lock().unwrap().replace(mem); + + self.backend + .update_memory(self.atomic_mem.clone()) + .map_err(|e| { + VhostUserError::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)) + })?; + + self.mappings + .retain(|mapping| mapping.gpa_base != region.guest_phys_addr); + + Ok(()) + } +} + +impl<S, V, B: Bitmap> Drop for VhostUserHandler<S, V, B> { + fn drop(&mut self) { + // Signal all working threads to exit. 
+ self.send_exit_event(); + + for thread in self.worker_threads.drain(..) { + if let Err(e) = thread.join() { + error!("Error in vring worker: {:?}", e); + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c65a19e --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,270 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// Copyright 2019-2021 Alibaba Cloud Computing. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + +//! A simple framework to run a vhost-user backend service. + +#[macro_use] +extern crate log; + +use std::fmt::{Display, Formatter}; +use std::sync::{Arc, Mutex}; +use std::thread; + +use vhost::vhost_user::{Error as VhostUserError, Listener, SlaveListener, SlaveReqHandler}; +use vm_memory::bitmap::Bitmap; +use vm_memory::mmap::NewBitmap; +use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap}; + +use self::handler::VhostUserHandler; + +mod backend; +pub use self::backend::{VhostUserBackend, VhostUserBackendMut}; + +mod event_loop; +pub use self::event_loop::VringEpollHandler; + +mod handler; +pub use self::handler::VhostUserHandlerError; + +mod vring; +pub use self::vring::{ + VringMutex, VringRwLock, VringState, VringStateGuard, VringStateMutGuard, VringT, +}; + +/// An alias for `GuestMemoryAtomic<GuestMemoryMmap<B>>` to simplify code. +type GM<B> = GuestMemoryAtomic<GuestMemoryMmap<B>>; + +#[derive(Debug)] +/// Errors related to vhost-user daemon. +pub enum Error { + /// Failed to create a new vhost-user handler. + NewVhostUserHandler(VhostUserHandlerError), + /// Failed creating vhost-user slave listener. + CreateSlaveListener(VhostUserError), + /// Failed creating vhost-user slave handler. + CreateSlaveReqHandler(VhostUserError), + /// Failed starting daemon thread. + StartDaemon(std::io::Error), + /// Failed waiting for daemon thread. + WaitDaemon(std::boxed::Box<dyn std::any::Any + std::marker::Send>), + /// Failed handling a vhost-user request. 
+ HandleRequest(VhostUserError), +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + Error::NewVhostUserHandler(e) => write!(f, "cannot create vhost user handler: {}", e), + Error::CreateSlaveListener(e) => write!(f, "cannot create slave listener: {}", e), + Error::CreateSlaveReqHandler(e) => write!(f, "cannot create slave req handler: {}", e), + Error::StartDaemon(e) => write!(f, "failed to start daemon: {}", e), + Error::WaitDaemon(_e) => write!(f, "failed to wait for daemon exit"), + Error::HandleRequest(e) => write!(f, "failed to handle request: {}", e), + } + } +} + +/// Result of vhost-user daemon operations. +pub type Result<T> = std::result::Result<T, Error>; + +/// Implement a simple framework to run a vhost-user service daemon. +/// +/// This structure is the public API the backend is allowed to interact with in order to run +/// a fully functional vhost-user daemon. +pub struct VhostUserDaemon<S, V, B: Bitmap + 'static = ()> { + name: String, + handler: Arc<Mutex<VhostUserHandler<S, V, B>>>, + main_thread: Option<thread::JoinHandle<Result<()>>>, +} + +impl<S, V, B> VhostUserDaemon<S, V, B> +where + S: VhostUserBackend<V, B> + Clone + 'static, + V: VringT<GM<B>> + Clone + Send + Sync + 'static, + B: NewBitmap + Clone + Send + Sync, +{ + /// Create the daemon instance, providing the backend implementation of `VhostUserBackend`. + /// + /// Under the hood, this will start a dedicated thread responsible for listening onto + /// registered event. Those events can be vring events or custom events from the backend, + /// but they get to be registered later during the sequence. 
+ pub fn new( + name: String, + backend: S, + atomic_mem: GuestMemoryAtomic<GuestMemoryMmap<B>>, + ) -> Result<Self> { + let handler = Arc::new(Mutex::new( + VhostUserHandler::new(backend, atomic_mem).map_err(Error::NewVhostUserHandler)?, + )); + + Ok(VhostUserDaemon { + name, + handler, + main_thread: None, + }) + } + + /// Run a dedicated thread handling all requests coming through the socket. + /// This runs in an infinite loop that should be terminating once the other + /// end of the socket (the VMM) hangs up. + /// + /// This function is the common code for starting a new daemon, no matter if + /// it acts as a client or a server. + fn start_daemon( + &mut self, + mut handler: SlaveReqHandler<Mutex<VhostUserHandler<S, V, B>>>, + ) -> Result<()> { + let handle = thread::Builder::new() + .name(self.name.clone()) + .spawn(move || loop { + handler.handle_request().map_err(Error::HandleRequest)?; + }) + .map_err(Error::StartDaemon)?; + + self.main_thread = Some(handle); + + Ok(()) + } + + /// Connect to the vhost-user socket and run a dedicated thread handling + /// all requests coming through this socket. This runs in an infinite loop + /// that should be terminating once the other end of the socket (the VMM) + /// hangs up. + pub fn start_client(&mut self, socket_path: &str) -> Result<()> { + let slave_handler = SlaveReqHandler::connect(socket_path, self.handler.clone()) + .map_err(Error::CreateSlaveReqHandler)?; + self.start_daemon(slave_handler) + } + + /// Listen to the vhost-user socket and run a dedicated thread handling all requests coming + /// through this socket. + /// + /// This runs in an infinite loop that should be terminating once the other end of the socket + /// (the VMM) disconnects. + // TODO: the current implementation has limitations that only one incoming connection will be + // handled from the listener. Should it be enhanced to support reconnection? 
+ pub fn start(&mut self, listener: Listener) -> Result<()> { + let mut slave_listener = SlaveListener::new(listener, self.handler.clone()) + .map_err(Error::CreateSlaveListener)?; + let slave_handler = self.accept(&mut slave_listener)?; + self.start_daemon(slave_handler) + } + + fn accept( + &self, + slave_listener: &mut SlaveListener<Mutex<VhostUserHandler<S, V, B>>>, + ) -> Result<SlaveReqHandler<Mutex<VhostUserHandler<S, V, B>>>> { + loop { + match slave_listener.accept() { + Err(e) => return Err(Error::CreateSlaveListener(e)), + Ok(Some(v)) => return Ok(v), + Ok(None) => continue, + } + } + } + + /// Wait for the thread handling the vhost-user socket connection to terminate. + pub fn wait(&mut self) -> Result<()> { + if let Some(handle) = self.main_thread.take() { + match handle.join().map_err(Error::WaitDaemon)? { + Ok(()) => Ok(()), + Err(Error::HandleRequest(VhostUserError::SocketBroken(_))) => Ok(()), + Err(e) => Err(e), + } + } else { + Ok(()) + } + } + + /// Retrieve the vring epoll handler. + /// + /// This is necessary to perform further actions like registering and unregistering some extra + /// event file descriptors. + pub fn get_epoll_handlers(&self) -> Vec<Arc<VringEpollHandler<S, V, B>>> { + // Do not expect poisoned lock. 
+ self.handler.lock().unwrap().get_epoll_handlers() + } +} + +#[cfg(test)] +mod tests { + use super::backend::tests::MockVhostBackend; + use super::*; + use std::os::unix::net::{UnixListener, UnixStream}; + use std::sync::Barrier; + use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap}; + + #[test] + fn test_new_daemon() { + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(), + ); + let backend = Arc::new(Mutex::new(MockVhostBackend::new())); + let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap(); + + let handlers = daemon.get_epoll_handlers(); + assert_eq!(handlers.len(), 2); + + let barrier = Arc::new(Barrier::new(2)); + let tmpdir = tempfile::tempdir().unwrap(); + let mut path = tmpdir.path().to_path_buf(); + path.push("socket"); + + let barrier2 = barrier.clone(); + let path1 = path.clone(); + let thread = thread::spawn(move || { + barrier2.wait(); + let socket = UnixStream::connect(&path1).unwrap(); + barrier2.wait(); + drop(socket) + }); + + let listener = Listener::new(&path, false).unwrap(); + barrier.wait(); + daemon.start(listener).unwrap(); + barrier.wait(); + // Above process generates a `HandleRequest(PartialMessage)` error. 
+ daemon.wait().unwrap_err(); + daemon.wait().unwrap(); + thread.join().unwrap(); + } + + #[test] + fn test_new_daemon_client() { + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(), + ); + let backend = Arc::new(Mutex::new(MockVhostBackend::new())); + let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap(); + + let handlers = daemon.get_epoll_handlers(); + assert_eq!(handlers.len(), 2); + + let barrier = Arc::new(Barrier::new(2)); + let tmpdir = tempfile::tempdir().unwrap(); + let mut path = tmpdir.path().to_path_buf(); + path.push("socket"); + + let barrier2 = barrier.clone(); + let path1 = path.clone(); + let thread = thread::spawn(move || { + let listener = UnixListener::bind(&path1).unwrap(); + barrier2.wait(); + let (stream, _) = listener.accept().unwrap(); + barrier2.wait(); + drop(stream) + }); + + barrier.wait(); + daemon + .start_client(path.as_path().to_str().unwrap()) + .unwrap(); + barrier.wait(); + // Above process generates a `HandleRequest(PartialMessage)` error. + daemon.wait().unwrap_err(); + daemon.wait().unwrap(); + thread.join().unwrap(); + } +} diff --git a/src/net.rs b/src/net.rs new file mode 100644 index 0000000..06a5e47 --- /dev/null +++ b/src/net.rs @@ -0,0 +1,19 @@ +// Copyright (C) 2021 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause + +//! Trait to control vhost-net backend drivers. + +use std::fs::File; + +use crate::backend::VhostBackend; +use crate::Result; + +/// Trait to control vhost-net backend drivers. +pub trait VhostNet: VhostBackend { + /// Set fd as VHOST_NET backend. 
+ /// + /// # Arguments + /// * `queue_index` - Index of the virtqueue + /// * `fd` - The file descriptor which servers as the backend + fn set_backend(&self, queue_idx: usize, fd: Option<&File>) -> Result<()>; +} diff --git a/src/vdpa.rs b/src/vdpa.rs new file mode 100644 index 0000000..6af01cf --- /dev/null +++ b/src/vdpa.rs @@ -0,0 +1,126 @@ +// Copyright (C) 2021 Red Hat, Inc. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause + +//! Trait to control vhost-vdpa backend drivers. + +use vmm_sys_util::eventfd::EventFd; + +use crate::backend::VhostBackend; +use crate::Result; + +/// vhost vdpa IOVA range +pub struct VhostVdpaIovaRange { + /// First address that can be mapped by vhost-vDPA. + pub first: u64, + /// Last address that can be mapped by vhost-vDPA. + pub last: u64, +} + +/// Trait to control vhost-vdpa backend drivers. +/// +/// vDPA (virtio Data Path Acceleration) devices has datapath compliant with the +/// virtio specification and the control path is vendor specific. +/// vDPA devices can be both physically located on the hardware or emulated +/// by software. +/// +/// Compared to vhost acceleration, vDPA offers more control over the device +/// lifecycle. +/// For this reason, the vhost-vdpa interface extends the vhost API, offering +/// additional APIs for controlling the device (e.g. changing the state or +/// accessing the configuration space +pub trait VhostVdpa: VhostBackend { + /// Get the device id. + /// The device ids follow the same definition of the device id defined in virtio-spec. + fn get_device_id(&self) -> Result<u32>; + + /// Get the status. + /// The status bits follow the same definition of the device status defined in virtio-spec. + fn get_status(&self) -> Result<u8>; + + /// Set the status. + /// The status bits follow the same definition of the device status defined in virtio-spec. 
+ /// + /// # Arguments + /// * `status` - Status bits to set + fn set_status(&self, status: u8) -> Result<()>; + + /// Get the device configuration. + /// + /// # Arguments + /// * `offset` - Offset in the device configuration space + /// * `buffer` - Buffer for configuration data + fn get_config(&self, offset: u32, buffer: &mut [u8]) -> Result<()>; + + /// Set the device configuration. + /// + /// # Arguments + /// * `offset` - Offset in the device configuration space + /// * `buffer` - Buffer for configuration data + fn set_config(&self, offset: u32, buffer: &[u8]) -> Result<()>; + + /// Set the status for a given vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to enable/disable. + /// * `enabled` - true to enable the vring, false to disable it. + fn set_vring_enable(&self, queue_index: usize, enabled: bool) -> Result<()>; + + /// Get the maximum number of descriptors in the vring supported by the device. + fn get_vring_num(&self) -> Result<u16>; + + /// Set the eventfd to trigger when device configuration change. + /// + /// # Arguments + /// * `fd` - EventFd to trigger. + fn set_config_call(&self, fd: &EventFd) -> Result<()>; + + /// Get the valid I/O virtual addresses range supported by the device. + fn get_iova_range(&self) -> Result<VhostVdpaIovaRange>; + + /// Get the config size + fn get_config_size(&self) -> Result<u32>; + + /// Get the count of all virtqueues + fn get_vqs_count(&self) -> Result<u32>; + + /// Get the number of virtqueue groups + fn get_group_num(&self) -> Result<u32>; + + /// Get the number of address spaces + fn get_as_num(&self) -> Result<u32>; + + /// Get the group for a virtqueue. + /// The virtqueue index is stored in the index field of + /// vhost_vring_state. The group for this specific virtqueue is + /// returned via num field of vhost_vring_state. + fn get_vring_group(&self, queue_index: u32) -> Result<u32>; + + /// Set the ASID for a virtqueue group. 
The group index is stored in + /// the index field of vhost_vring_state, the ASID associated with this + /// group is stored at num field of vhost_vring_state. + fn set_group_asid(&self, group_index: u32, asid: u32) -> Result<()>; + + /// Suspend a device so it does not process virtqueue requests anymore + /// + /// After the return of ioctl the device must preserve all the necessary state + /// (the virtqueue vring base plus the possible device specific states) that is + /// required for restoring in the future. The device must not change its + /// configuration after that point. + fn suspend(&self) -> Result<()>; + + /// Map DMA region. + /// + /// # Arguments + /// * `iova` - I/O virtual address. + /// * `size` - Size of the I/O mapping. + /// * `vaddr` - Virtual address in the current process. + /// * `readonly` - true if the region is read-only, false if reads and writes are allowed. + fn dma_map(&self, iova: u64, size: u64, vaddr: *const u8, readonly: bool) -> Result<()>; + + /// Unmap DMA region. + /// + /// # Arguments + /// * `iova` - I/O virtual address. + /// * `size` - Size of the I/O mapping. + fn dma_unmap(&self, iova: u64, size: u64) -> Result<()>; +} diff --git a/src/vhost_kern/mod.rs b/src/vhost_kern/mod.rs new file mode 100644 index 0000000..1fa5000 --- /dev/null +++ b/src/vhost_kern/mod.rs @@ -0,0 +1,467 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause +// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-Google file. + +//! Traits and structs to control Linux in-kernel vhost drivers. +//! +//! The initial vhost implementation is a part of the Linux kernel and uses ioctl interface to +//! communicate with userspace applications. This sub module provides ioctl based interfaces to +//! control the in-kernel net, scsi, vsock vhost drivers. 
+ +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; + +use libc::{c_void, ssize_t, write}; + +use vm_memory::{Address, GuestAddress, GuestAddressSpace, GuestMemory, GuestUsize}; +use vmm_sys_util::eventfd::EventFd; +use vmm_sys_util::ioctl::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref}; + +use super::{ + Error, Result, VhostBackend, VhostIotlbBackend, VhostIotlbMsg, VhostIotlbMsgParser, + VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData, VHOST_MAX_MEMORY_REGIONS, +}; + +pub mod vhost_binding; +use self::vhost_binding::*; + +#[cfg(feature = "vhost-net")] +pub mod net; +#[cfg(feature = "vhost-vdpa")] +pub mod vdpa; +#[cfg(feature = "vhost-vsock")] +pub mod vsock; + +#[inline] +fn ioctl_result<T>(rc: i32, res: T) -> Result<T> { + if rc < 0 { + Err(Error::IoctlError(std::io::Error::last_os_error())) + } else { + Ok(res) + } +} + +#[inline] +fn io_result<T>(rc: isize, res: T) -> Result<T> { + if rc < 0 { + Err(Error::IOError(std::io::Error::last_os_error())) + } else { + Ok(res) + } +} + +/// Represent an in-kernel vhost device backend. +pub trait VhostKernBackend: AsRawFd { + /// Associated type to access guest memory. + type AS: GuestAddressSpace; + + /// Get the object to access the guest's memory. + fn mem(&self) -> &Self::AS; + + /// Check whether the ring configuration is valid. 
+ fn is_valid(&self, config_data: &VringConfigData) -> bool { + let queue_size = config_data.queue_size; + if queue_size > config_data.queue_max_size + || queue_size == 0 + || (queue_size & (queue_size - 1)) != 0 + { + return false; + } + + let m = self.mem().memory(); + let desc_table_size = 16 * u64::from(queue_size) as GuestUsize; + let avail_ring_size = 6 + 2 * u64::from(queue_size) as GuestUsize; + let used_ring_size = 6 + 8 * u64::from(queue_size) as GuestUsize; + if GuestAddress(config_data.desc_table_addr) + .checked_add(desc_table_size) + .map_or(true, |v| !m.address_in_range(v)) + { + return false; + } + if GuestAddress(config_data.avail_ring_addr) + .checked_add(avail_ring_size) + .map_or(true, |v| !m.address_in_range(v)) + { + return false; + } + if GuestAddress(config_data.used_ring_addr) + .checked_add(used_ring_size) + .map_or(true, |v| !m.address_in_range(v)) + { + return false; + } + + config_data.is_log_addr_valid() + } +} + +impl<T: VhostKernBackend> VhostBackend for T { + /// Get a bitmask of supported virtio/vhost features. + fn get_features(&self) -> Result<u64> { + let mut avail_features: u64 = 0; + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_mut_ref(self, VHOST_GET_FEATURES(), &mut avail_features) }; + ioctl_result(ret, avail_features) + } + + /// Inform the vhost subsystem which features to enable. This should be a subset of + /// supported features from VHOST_GET_FEATURES. + /// + /// # Arguments + /// * `features` - Bitmask of features to set. + fn set_features(&self, features: u64) -> Result<()> { + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_FEATURES(), &features) }; + ioctl_result(ret, ()) + } + + /// Set the current process as the owner of this file descriptor. + /// This must be run before any other vhost ioctls. 
+ fn set_owner(&self) -> Result<()> { + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) }; + ioctl_result(ret, ()) + } + + fn reset_owner(&self) -> Result<()> { + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) }; + ioctl_result(ret, ()) + } + + /// Set the guest memory mappings for vhost to use. + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + if regions.is_empty() || regions.len() > VHOST_MAX_MEMORY_REGIONS { + return Err(Error::InvalidGuestMemory); + } + + let mut vhost_memory = VhostMemory::new(regions.len() as u16); + for (index, region) in regions.iter().enumerate() { + vhost_memory.set_region( + index as u32, + &vhost_memory_region { + guest_phys_addr: region.guest_phys_addr, + memory_size: region.memory_size, + userspace_addr: region.userspace_addr, + flags_padding: 0u64, + }, + )?; + } + + // SAFETY: This ioctl is called with a pointer that is valid for the lifetime + // of this function. The kernel will make its own copy of the memory + // tables. As always, check the return value. + let ret = unsafe { ioctl_with_ptr(self, VHOST_SET_MEM_TABLE(), vhost_memory.as_ptr()) }; + ioctl_result(ret, ()) + } + + /// Set base address for page modification logging. + /// + /// # Arguments + /// * `base` - Base address for page modification logging. + fn set_log_base(&self, base: u64, region: Option<VhostUserDirtyLogRegion>) -> Result<()> { + if region.is_some() { + return Err(Error::LogAddress); + } + + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_BASE(), &base) }; + ioctl_result(ret, ()) + } + + /// Specify an eventfd file descriptor to signal on log write. 
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> { + let val: i32 = fd; + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_FD(), &val) }; + ioctl_result(ret, ()) + } + + /// Set the number of descriptors in the vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set descriptor count for. + /// * `num` - Number of descriptors in the queue. + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + let vring_state = vhost_vring_state { + index: queue_index as u32, + num: u32::from(num), + }; + + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_NUM(), &vring_state) }; + ioctl_result(ret, ()) + } + + /// Set the addresses for a given vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set addresses for. + /// * `config_data` - Vring config data, addresses of desc_table, avail_ring + /// and used_ring are in the guest address space. + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + if !self.is_valid(config_data) { + return Err(Error::InvalidQueue); + } + + // The addresses are converted into the host address space. + let vring_addr = config_data.to_vhost_vring_addr(queue_index, self.mem())?; + + // SAFETY: This ioctl is called on a valid vhost fd and has its + // return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ADDR(), &vring_addr) }; + ioctl_result(ret, ()) + } + + /// Set the first index to look for available descriptors. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `num` - Index where available descriptors start. 
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + let vring_state = vhost_vring_state { + index: queue_index as u32, + num: u32::from(base), + }; + + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_BASE(), &vring_state) }; + ioctl_result(ret, ()) + } + + /// Get a bitmask of supported virtio/vhost features. + fn get_vring_base(&self, queue_index: usize) -> Result<u32> { + let vring_state = vhost_vring_state { + index: queue_index as u32, + num: 0, + }; + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_GET_VRING_BASE(), &vring_state) }; + ioctl_result(ret, vring_state.num) + } + + /// Set the eventfd to trigger when buffers have been used by the host. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd to trigger. + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let vring_file = vhost_vring_file { + index: queue_index as u32, + fd: fd.as_raw_fd(), + }; + + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_CALL(), &vring_file) }; + ioctl_result(ret, ()) + } + + /// Set the eventfd that will be signaled by the guest when buffers are + /// available for the host to process. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from guest. + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let vring_file = vhost_vring_file { + index: queue_index as u32, + fd: fd.as_raw_fd(), + }; + + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. 
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_KICK(), &vring_file) }; + ioctl_result(ret, ()) + } + + /// Set the eventfd to signal an error from the vhost backend. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from the backend. + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let vring_file = vhost_vring_file { + index: queue_index as u32, + fd: fd.as_raw_fd(), + }; + + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ERR(), &vring_file) }; + ioctl_result(ret, ()) + } +} + +/// Interface to handle in-kernel backend features. +pub trait VhostKernFeatures: Sized + AsRawFd { + /// Get features acked with the vhost backend. + fn get_backend_features_acked(&self) -> u64; + + /// Set features acked with the vhost backend. + fn set_backend_features_acked(&mut self, features: u64); + + /// Get a bitmask of supported vhost backend features. + fn get_backend_features(&self) -> Result<u64> { + let mut avail_features: u64 = 0; + + let ret = + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. + unsafe { ioctl_with_mut_ref(self, VHOST_GET_BACKEND_FEATURES(), &mut avail_features) }; + ioctl_result(ret, avail_features) + } + + /// Inform the vhost subsystem which backend features to enable. This should + /// be a subset of supported features from VHOST_GET_BACKEND_FEATURES. + /// + /// # Arguments + /// * `features` - Bitmask of features to set. + fn set_backend_features(&mut self, features: u64) -> Result<()> { + // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked. 
+ let ret = unsafe { ioctl_with_ref(self, VHOST_SET_BACKEND_FEATURES(), &features) }; + + if ret >= 0 { + self.set_backend_features_acked(features); + } + + ioctl_result(ret, ()) + } +} + +/// Handle IOTLB messeges for in-kernel vhost device backend. +impl<I: VhostKernBackend + VhostKernFeatures> VhostIotlbBackend for I { + /// Send an IOTLB message to the in-kernel vhost backend. + /// + /// # Arguments + /// * `msg` - IOTLB message to send. + fn send_iotlb_msg(&self, msg: &VhostIotlbMsg) -> Result<()> { + let ret: ssize_t; + + if self.get_backend_features_acked() & (1 << VHOST_BACKEND_F_IOTLB_MSG_V2) != 0 { + let mut msg_v2 = vhost_msg_v2 { + type_: VHOST_IOTLB_MSG_V2, + ..Default::default() + }; + + msg_v2.__bindgen_anon_1.iotlb.iova = msg.iova; + msg_v2.__bindgen_anon_1.iotlb.size = msg.size; + msg_v2.__bindgen_anon_1.iotlb.uaddr = msg.userspace_addr; + msg_v2.__bindgen_anon_1.iotlb.perm = msg.perm as u8; + msg_v2.__bindgen_anon_1.iotlb.type_ = msg.msg_type as u8; + + // SAFETY: This is safe because we are using a valid vhost fd, and + // a valid pointer and size to the vhost_msg_v2 structure. + ret = unsafe { + write( + self.as_raw_fd(), + &msg_v2 as *const vhost_msg_v2 as *const c_void, + mem::size_of::<vhost_msg_v2>(), + ) + }; + } else { + let mut msg_v1 = vhost_msg { + type_: VHOST_IOTLB_MSG, + ..Default::default() + }; + + msg_v1.__bindgen_anon_1.iotlb.iova = msg.iova; + msg_v1.__bindgen_anon_1.iotlb.size = msg.size; + msg_v1.__bindgen_anon_1.iotlb.uaddr = msg.userspace_addr; + msg_v1.__bindgen_anon_1.iotlb.perm = msg.perm as u8; + msg_v1.__bindgen_anon_1.iotlb.type_ = msg.msg_type as u8; + + // SAFETY: This is safe because we are using a valid vhost fd, and + // a valid pointer and size to the vhost_msg structure. 
+ ret = unsafe { + write( + self.as_raw_fd(), + &msg_v1 as *const vhost_msg as *const c_void, + mem::size_of::<vhost_msg>(), + ) + }; + } + + io_result(ret, ()) + } +} + +impl VhostIotlbMsgParser for vhost_msg { + fn parse(&self, msg: &mut VhostIotlbMsg) -> Result<()> { + if self.type_ != VHOST_IOTLB_MSG { + return Err(Error::InvalidIotlbMsg); + } + + // SAFETY: We trust the kernel to return a structure with the union + // fields properly initialized. We are sure it is a vhost_msg, because + // we checked that `self.type_` is VHOST_IOTLB_MSG. + unsafe { + if self.__bindgen_anon_1.iotlb.type_ == 0 { + return Err(Error::InvalidIotlbMsg); + } + + msg.iova = self.__bindgen_anon_1.iotlb.iova; + msg.size = self.__bindgen_anon_1.iotlb.size; + msg.userspace_addr = self.__bindgen_anon_1.iotlb.uaddr; + msg.perm = mem::transmute(self.__bindgen_anon_1.iotlb.perm); + msg.msg_type = mem::transmute(self.__bindgen_anon_1.iotlb.type_); + } + + Ok(()) + } +} + +impl VhostIotlbMsgParser for vhost_msg_v2 { + fn parse(&self, msg: &mut VhostIotlbMsg) -> Result<()> { + if self.type_ != VHOST_IOTLB_MSG_V2 { + return Err(Error::InvalidIotlbMsg); + } + + // SAFETY: We trust the kernel to return a structure with the union + // fields properly initialized. We are sure it is a vhost_msg_v2, because + // we checked that `self.type_` is VHOST_IOTLB_MSG_V2. + unsafe { + if self.__bindgen_anon_1.iotlb.type_ == 0 { + return Err(Error::InvalidIotlbMsg); + } + + msg.iova = self.__bindgen_anon_1.iotlb.iova; + msg.size = self.__bindgen_anon_1.iotlb.size; + msg.userspace_addr = self.__bindgen_anon_1.iotlb.uaddr; + msg.perm = mem::transmute(self.__bindgen_anon_1.iotlb.perm); + msg.msg_type = mem::transmute(self.__bindgen_anon_1.iotlb.type_); + } + + Ok(()) + } +} + +impl VringConfigData { + /// Convert the config (guest address space) into vhost_vring_addr + /// (host address space). 
+ pub fn to_vhost_vring_addr<AS: GuestAddressSpace>( + &self, + queue_index: usize, + mem: &AS, + ) -> Result<vhost_vring_addr> { + let desc_addr = mem + .memory() + .get_host_address(GuestAddress(self.desc_table_addr)) + .map_err(|_| Error::DescriptorTableAddress)?; + let avail_addr = mem + .memory() + .get_host_address(GuestAddress(self.avail_ring_addr)) + .map_err(|_| Error::AvailAddress)?; + let used_addr = mem + .memory() + .get_host_address(GuestAddress(self.used_ring_addr)) + .map_err(|_| Error::UsedAddress)?; + Ok(vhost_vring_addr { + index: queue_index as u32, + flags: self.flags, + desc_user_addr: desc_addr as u64, + used_user_addr: used_addr as u64, + avail_user_addr: avail_addr as u64, + log_guest_addr: self.get_log_addr(), + }) + } +} diff --git a/src/vhost_kern/net.rs b/src/vhost_kern/net.rs new file mode 100644 index 0000000..89f390c --- /dev/null +++ b/src/vhost_kern/net.rs @@ -0,0 +1,177 @@ +// Copyright (C) 2021 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause + +//! Kernel-based vhost-net backend + +use std::fs::{File, OpenOptions}; +use std::os::unix::fs::OpenOptionsExt; +use std::os::unix::io::{AsRawFd, RawFd}; + +use vm_memory::GuestAddressSpace; +use vmm_sys_util::ioctl::ioctl_with_ref; + +use super::vhost_binding::*; +use super::{ioctl_result, Error, Result, VhostKernBackend}; + +use crate::net::*; + +const VHOST_NET_PATH: &str = "/dev/vhost-net"; + +/// Handle for running VHOST_NET ioctls +pub struct Net<AS: GuestAddressSpace> { + fd: File, + mem: AS, +} + +impl<AS: GuestAddressSpace> Net<AS> { + /// Open a handle to a new VHOST-NET instance. 
+ pub fn new(mem: AS) -> Result<Self> { + Ok(Net { + fd: OpenOptions::new() + .read(true) + .write(true) + .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK) + .open(VHOST_NET_PATH) + .map_err(Error::VhostOpen)?, + mem, + }) + } +} + +impl<AS: GuestAddressSpace> VhostNet for Net<AS> { + fn set_backend(&self, queue_index: usize, fd: Option<&File>) -> Result<()> { + let vring_file = vhost_vring_file { + index: queue_index as u32, + fd: fd.map_or(-1, |v| v.as_raw_fd()), + }; + + // SAFETY: Safe because the vhost-net fd is valid and we check the return value + let ret = unsafe { ioctl_with_ref(self, VHOST_NET_SET_BACKEND(), &vring_file) }; + ioctl_result(ret, ()) + } +} + +impl<AS: GuestAddressSpace> VhostKernBackend for Net<AS> { + type AS = AS; + + fn mem(&self) -> &Self::AS { + &self.mem + } +} + +impl<AS: GuestAddressSpace> AsRawFd for Net<AS> { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap}; + use vmm_sys_util::eventfd::EventFd; + + use super::*; + use crate::{ + VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData, + }; + use serial_test::serial; + + #[test] + #[serial] + fn test_net_new_device() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let net = Net::new(&m).unwrap(); + + assert!(net.as_raw_fd() >= 0); + assert!(net.mem().find_region(GuestAddress(0x100)).is_some()); + assert!(net.mem().find_region(GuestAddress(0x10_0000)).is_none()); + } + + #[test] + #[serial] + fn test_net_is_valid() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let net = Net::new(&m).unwrap(); + + let mut config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + assert!(net.is_valid(&config)); + + config.queue_size = 0; + 
assert!(!net.is_valid(&config)); + config.queue_size = 31; + assert!(!net.is_valid(&config)); + config.queue_size = 33; + assert!(!net.is_valid(&config)); + } + + #[test] + #[serial] + fn test_net_ioctls() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let net = Net::new(&m).unwrap(); + let backend = OpenOptions::new() + .read(true) + .write(true) + .open("/dev/null") + .unwrap(); + + let features = net.get_features().unwrap(); + net.set_features(features).unwrap(); + + net.set_owner().unwrap(); + + net.set_mem_table(&[]).unwrap_err(); + + let region = VhostUserMemoryRegionInfo::new( + 0x0, + 0x10_0000, + m.get_host_address(GuestAddress(0x0)).unwrap() as u64, + 0, + -1, + ); + net.set_mem_table(&[region]).unwrap(); + + net.set_log_base( + 0x4000, + Some(VhostUserDirtyLogRegion { + mmap_size: 0x1000, + mmap_offset: 0x10, + mmap_handle: 1, + }), + ) + .unwrap_err(); + net.set_log_base(0x4000, None).unwrap(); + + let eventfd = EventFd::new(0).unwrap(); + net.set_log_fd(eventfd.as_raw_fd()).unwrap(); + + net.set_vring_num(0, 32).unwrap(); + + let config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + net.set_vring_addr(0, &config).unwrap(); + net.set_vring_base(0, 1).unwrap(); + net.set_vring_call(0, &eventfd).unwrap(); + net.set_vring_kick(0, &eventfd).unwrap(); + net.set_vring_err(0, &eventfd).unwrap(); + assert_eq!(net.get_vring_base(0).unwrap(), 1); + + net.set_backend(0, Some(&backend)).unwrap_err(); + net.set_backend(0, None).unwrap(); + } +} diff --git a/src/vhost_kern/vdpa.rs b/src/vhost_kern/vdpa.rs new file mode 100644 index 0000000..65e0123 --- /dev/null +++ b/src/vhost_kern/vdpa.rs @@ -0,0 +1,560 @@ +// Copyright (C) 2021 Red Hat, Inc. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause + +//! Kernel-based vhost-vdpa backend. 
+ +use std::fs::{File, OpenOptions}; +use std::io::Error as IOError; +use std::os::raw::{c_uchar, c_uint}; +use std::os::unix::fs::OpenOptionsExt; +use std::os::unix::io::{AsRawFd, RawFd}; + +use vm_memory::GuestAddressSpace; +use vmm_sys_util::eventfd::EventFd; +use vmm_sys_util::fam::*; +use vmm_sys_util::ioctl::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref}; + +use super::vhost_binding::*; +use super::{ioctl_result, Error, Result, VhostKernBackend, VhostKernFeatures}; +use crate::vdpa::*; +use crate::{VhostAccess, VhostIotlbBackend, VhostIotlbMsg, VhostIotlbType, VringConfigData}; + +// Implement the FamStruct trait for vhost_vdpa_config +generate_fam_struct_impl!( + vhost_vdpa_config, + c_uchar, + buf, + c_uint, + len, + c_uint::MAX as usize +); + +type VhostVdpaConfig = FamStructWrapper<vhost_vdpa_config>; + +/// Handle for running VHOST_VDPA ioctls. +pub struct VhostKernVdpa<AS: GuestAddressSpace> { + fd: File, + mem: AS, + backend_features_acked: u64, +} + +impl<AS: GuestAddressSpace> VhostKernVdpa<AS> { + /// Open a handle to a new VHOST-VDPA instance. + pub fn new(path: &str, mem: AS) -> Result<Self> { + Ok(VhostKernVdpa { + fd: OpenOptions::new() + .read(true) + .write(true) + .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK) + .open(path) + .map_err(Error::VhostOpen)?, + mem, + backend_features_acked: 0, + }) + } + + /// Create a `VhostKernVdpa` object with given content. + pub fn with(fd: File, mem: AS, backend_features_acked: u64) -> Self { + VhostKernVdpa { + fd, + mem, + backend_features_acked, + } + } + + /// Set the addresses for a given vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set addresses for. + /// * `config_data` - Vring config data, addresses of desc_table, avail_ring + /// and used_ring are in the guest address space. 
+ pub fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + if !self.is_valid(config_data) { + return Err(Error::InvalidQueue); + } + + // vDPA backends expect IOVA (that can be mapped 1:1 with + // GPA when no IOMMU is involved). + let vring_addr = vhost_vring_addr { + index: queue_index as u32, + flags: config_data.flags, + desc_user_addr: config_data.desc_table_addr, + used_user_addr: config_data.used_ring_addr, + avail_user_addr: config_data.avail_ring_addr, + log_guest_addr: config_data.get_log_addr(), + }; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ADDR(), &vring_addr) }; + ioctl_result(ret, ()) + } +} + +impl<AS: GuestAddressSpace> VhostVdpa for VhostKernVdpa<AS> { + fn get_device_id(&self) -> Result<u32> { + let mut device_id: u32 = 0; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_DEVICE_ID(), &mut device_id) }; + ioctl_result(ret, device_id) + } + + fn get_status(&self) -> Result<u8> { + let mut status: u8 = 0; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_STATUS(), &mut status) }; + ioctl_result(ret, status) + } + + fn set_status(&self, status: u8) -> Result<()> { + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. 
+ let ret = unsafe { ioctl_with_ref(self, VHOST_VDPA_SET_STATUS(), &status) }; + ioctl_result(ret, ()) + } + + fn get_config(&self, offset: u32, buffer: &mut [u8]) -> Result<()> { + let mut config = VhostVdpaConfig::new(buffer.len()) + .map_err(|_| Error::IoctlError(IOError::from_raw_os_error(libc::ENOMEM)))?; + + config.as_mut_fam_struct().off = offset; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + let ret = unsafe { + ioctl_with_ptr( + self, + VHOST_VDPA_GET_CONFIG(), + config.as_mut_fam_struct_ptr(), + ) + }; + + buffer.copy_from_slice(config.as_slice()); + + ioctl_result(ret, ()) + } + + fn set_config(&self, offset: u32, buffer: &[u8]) -> Result<()> { + let mut config = VhostVdpaConfig::new(buffer.len()) + .map_err(|_| Error::IoctlError(IOError::from_raw_os_error(libc::ENOMEM)))?; + + config.as_mut_fam_struct().off = offset; + config.as_mut_slice().copy_from_slice(buffer); + + let ret = + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + unsafe { ioctl_with_ptr(self, VHOST_VDPA_SET_CONFIG(), config.as_fam_struct_ptr()) }; + ioctl_result(ret, ()) + } + + fn set_vring_enable(&self, queue_index: usize, enabled: bool) -> Result<()> { + let vring_state = vhost_vring_state { + index: queue_index as u32, + num: enabled as u32, + }; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_VDPA_SET_VRING_ENABLE(), &vring_state) }; + ioctl_result(ret, ()) + } + + fn get_vring_num(&self) -> Result<u16> { + let mut vring_num: u16 = 0; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. 
+ let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_VRING_NUM(), &mut vring_num) }; + ioctl_result(ret, vring_num) + } + + fn set_config_call(&self, fd: &EventFd) -> Result<()> { + let event_fd: ::std::os::raw::c_int = fd.as_raw_fd(); + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_VDPA_SET_CONFIG_CALL(), &event_fd) }; + ioctl_result(ret, ()) + } + + fn get_iova_range(&self) -> Result<VhostVdpaIovaRange> { + let mut low_iova_range = vhost_vdpa_iova_range { first: 0, last: 0 }; + + let ret = + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_IOVA_RANGE(), &mut low_iova_range) }; + + let iova_range = VhostVdpaIovaRange { + first: low_iova_range.first, + last: low_iova_range.last, + }; + + ioctl_result(ret, iova_range) + } + + fn get_config_size(&self) -> Result<u32> { + let mut config_size: u32 = 0; + + let ret = + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_CONFIG_SIZE(), &mut config_size) }; + ioctl_result(ret, config_size) + } + + fn get_vqs_count(&self) -> Result<u32> { + let mut vqs_count: u32 = 0; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_VQS_COUNT(), &mut vqs_count) }; + ioctl_result(ret, vqs_count) + } + + fn get_group_num(&self) -> Result<u32> { + let mut group_num: u32 = 0; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. 
+ let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_GROUP_NUM(), &mut group_num) }; + ioctl_result(ret, group_num) + } + + fn get_as_num(&self) -> Result<u32> { + let mut as_num: u32 = 0; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + let ret = unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_AS_NUM(), &mut as_num) }; + ioctl_result(ret, as_num) + } + + fn get_vring_group(&self, queue_index: u32) -> Result<u32> { + let mut vring_state = vhost_vring_state { + index: queue_index, + ..Default::default() + }; + + let ret = + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + unsafe { ioctl_with_mut_ref(self, VHOST_VDPA_GET_VRING_GROUP(), &mut vring_state) }; + ioctl_result(ret, vring_state.num) + } + + fn set_group_asid(&self, group_index: u32, asid: u32) -> Result<()> { + let vring_state = vhost_vring_state { + index: group_index, + num: asid, + }; + + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_VDPA_GET_VRING_GROUP(), &vring_state) }; + ioctl_result(ret, ()) + } + + fn suspend(&self) -> Result<()> { + // SAFETY: This ioctl is called on a valid vhost-vdpa fd and has its + // return value checked. 
+ let ret = unsafe { ioctl(self, VHOST_VDPA_SUSPEND()) }; + ioctl_result(ret, ()) + } + + fn dma_map(&self, iova: u64, size: u64, vaddr: *const u8, readonly: bool) -> Result<()> { + let iotlb = VhostIotlbMsg { + iova, + size, + userspace_addr: vaddr as u64, + perm: match readonly { + true => VhostAccess::ReadOnly, + false => VhostAccess::ReadWrite, + }, + msg_type: VhostIotlbType::Update, + }; + + self.send_iotlb_msg(&iotlb) + } + + fn dma_unmap(&self, iova: u64, size: u64) -> Result<()> { + let iotlb = VhostIotlbMsg { + iova, + size, + msg_type: VhostIotlbType::Invalidate, + ..Default::default() + }; + + self.send_iotlb_msg(&iotlb) + } +} + +impl<AS: GuestAddressSpace> VhostKernBackend for VhostKernVdpa<AS> { + type AS = AS; + + fn mem(&self) -> &Self::AS { + &self.mem + } + + /// Check whether the ring configuration is valid. + fn is_valid(&self, config_data: &VringConfigData) -> bool { + let queue_size = config_data.queue_size; + if queue_size > config_data.queue_max_size + || queue_size == 0 + || (queue_size & (queue_size - 1)) != 0 + { + return false; + } + + // Since vDPA could be dealing with IOVAs corresponding to GVAs, it + // wouldn't make sense to go through the validation of the descriptor + // table address, available ring address and used ring address against + // the guest memory representation we have access to. 
+ + config_data.is_log_addr_valid() + } +} + +impl<AS: GuestAddressSpace> AsRawFd for VhostKernVdpa<AS> { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +impl<AS: GuestAddressSpace> VhostKernFeatures for VhostKernVdpa<AS> { + fn get_backend_features_acked(&self) -> u64 { + self.backend_features_acked + } + + fn set_backend_features_acked(&mut self, features: u64) { + self.backend_features_acked = features; + } +} + +#[cfg(test)] +mod tests { + const VHOST_VDPA_PATH: &str = "/dev/vhost-vdpa-0"; + + use std::alloc::{alloc, dealloc, Layout}; + use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap}; + use vmm_sys_util::eventfd::EventFd; + + use super::*; + use crate::{ + VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData, + }; + use serial_test::serial; + use std::io::ErrorKind; + + /// macro to skip test if vhost-vdpa device path is not found. + /// + /// vDPA simulators are available since Linux 5.7, but the CI may have + /// an older kernel, so for now we skip the test if we don't find + /// the device. + macro_rules! unwrap_not_found { + ( $e:expr ) => { + match $e { + Ok(v) => v, + Err(error) => match error { + Error::VhostOpen(ref e) if e.kind() == ErrorKind::NotFound => { + println!("Err: {:?} SKIPPED", e); + return; + } + e => panic!("Err: {:?}", e), + }, + } + }; + } + + macro_rules! 
validate_ioctl { + ( $e:expr, $ref_value:expr ) => { + match $e { + Ok(v) => assert_eq!(v, $ref_value), + Err(error) => match error { + Error::IoctlError(e) if e.raw_os_error().unwrap() == libc::ENOTTY => { + println!("Err: {:?} SKIPPED", e); + } + e => panic!("Err: {:?}", e), + }, + } + }; + } + + #[test] + #[serial] + fn test_vdpa_kern_new_device() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vdpa = unwrap_not_found!(VhostKernVdpa::new(VHOST_VDPA_PATH, &m)); + + assert!(vdpa.as_raw_fd() >= 0); + assert!(vdpa.mem().find_region(GuestAddress(0x100)).is_some()); + assert!(vdpa.mem().find_region(GuestAddress(0x10_0000)).is_none()); + } + + #[test] + #[serial] + fn test_vdpa_kern_is_valid() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vdpa = unwrap_not_found!(VhostKernVdpa::new(VHOST_VDPA_PATH, &m)); + + let mut config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + assert!(vdpa.is_valid(&config)); + + config.queue_size = 0; + assert!(!vdpa.is_valid(&config)); + config.queue_size = 31; + assert!(!vdpa.is_valid(&config)); + config.queue_size = 33; + assert!(!vdpa.is_valid(&config)); + } + + #[test] + #[serial] + fn test_vdpa_kern_ioctls() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vdpa = unwrap_not_found!(VhostKernVdpa::new(VHOST_VDPA_PATH, &m)); + + let features = vdpa.get_features().unwrap(); + // VIRTIO_F_VERSION_1 (bit 32) should be set + assert_ne!(features & (1 << 32), 0); + vdpa.set_features(features).unwrap(); + + vdpa.set_owner().unwrap(); + + vdpa.set_mem_table(&[]).unwrap_err(); + + let region = VhostUserMemoryRegionInfo::new( + 0x0, + 0x10_0000, + m.get_host_address(GuestAddress(0x0)).unwrap() as u64, + 0, + -1, + ); + vdpa.set_mem_table(&[region]).unwrap(); + + let device_id = 
vdpa.get_device_id().unwrap(); + assert!(device_id > 0); + + assert_eq!(vdpa.get_status().unwrap(), 0x0); + vdpa.set_status(0x1).unwrap(); + assert_eq!(vdpa.get_status().unwrap(), 0x1); + + let mut vec = vec![0u8; 8]; + vdpa.get_config(0, &mut vec).unwrap(); + vdpa.set_config(0, &vec).unwrap(); + + let eventfd = EventFd::new(0).unwrap(); + + // set_log_base() and set_log_fd() are not supported by vhost-vdpa + vdpa.set_log_base( + 0x4000, + Some(VhostUserDirtyLogRegion { + mmap_size: 0x1000, + mmap_offset: 0x10, + mmap_handle: 1, + }), + ) + .unwrap_err(); + vdpa.set_log_base(0x4000, None).unwrap_err(); + vdpa.set_log_fd(eventfd.as_raw_fd()).unwrap_err(); + + let max_queues = vdpa.get_vring_num().unwrap(); + vdpa.set_vring_num(0, max_queues + 1).unwrap_err(); + + vdpa.set_vring_num(0, 32).unwrap(); + + let config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + vdpa.set_vring_addr(0, &config).unwrap(); + vdpa.set_vring_base(0, 1).unwrap(); + vdpa.set_vring_call(0, &eventfd).unwrap(); + vdpa.set_vring_kick(0, &eventfd).unwrap(); + vdpa.set_vring_err(0, &eventfd).unwrap(); + + vdpa.set_config_call(&eventfd).unwrap(); + + let iova_range = vdpa.get_iova_range().unwrap(); + // vDPA-block simulator returns [0, u64::MAX] range + assert_eq!(iova_range.first, 0); + assert_eq!(iova_range.last, u64::MAX); + + let (config_size, vqs_count, group_num, as_num, vring_group) = if device_id == 1 { + (24, 3, 2, 2, 0) + } else if device_id == 2 { + (60, 1, 1, 1, 0) + } else { + panic!("Unexpected device id {}", device_id) + }; + + validate_ioctl!(vdpa.get_config_size(), config_size); + validate_ioctl!(vdpa.get_vqs_count(), vqs_count); + validate_ioctl!(vdpa.get_group_num(), group_num); + validate_ioctl!(vdpa.get_as_num(), as_num); + validate_ioctl!(vdpa.get_vring_group(0), vring_group); + validate_ioctl!(vdpa.set_group_asid(0, 12345), ()); + + if 
vdpa.get_backend_features().unwrap() & (1 << VHOST_BACKEND_F_SUSPEND) + == (1 << VHOST_BACKEND_F_SUSPEND) + { + validate_ioctl!(vdpa.suspend(), ()); + } + + assert_eq!(vdpa.get_vring_base(0).unwrap(), 1); + + vdpa.set_vring_enable(0, true).unwrap(); + vdpa.set_vring_enable(0, false).unwrap(); + } + + #[test] + #[serial] + fn test_vdpa_kern_dma() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let mut vdpa = unwrap_not_found!(VhostKernVdpa::new(VHOST_VDPA_PATH, &m)); + + let features = vdpa.get_features().unwrap(); + // VIRTIO_F_VERSION_1 (bit 32) should be set + assert_ne!(features & (1 << 32), 0); + vdpa.set_features(features).unwrap(); + + let backend_features = vdpa.get_backend_features().unwrap(); + assert_ne!(backend_features & (1 << VHOST_BACKEND_F_IOTLB_MSG_V2), 0); + vdpa.set_backend_features(backend_features).unwrap(); + + vdpa.set_owner().unwrap(); + + vdpa.dma_map(0xFFFF_0000, 0xFFFF, std::ptr::null::<u8>(), false) + .unwrap_err(); + + let layout = Layout::from_size_align(0xFFFF, 1).unwrap(); + + // SAFETY: Safe because layout has non-zero size. + let ptr = unsafe { alloc(layout) }; + + vdpa.dma_map(0xFFFF_0000, 0xFFFF, ptr, false).unwrap(); + vdpa.dma_unmap(0xFFFF_0000, 0xFFFF).unwrap(); + + // SAFETY: Safe because `ptr` is allocated with the same allocator + // using the same `layout`. + unsafe { dealloc(ptr, layout) }; + } +} diff --git a/src/vhost_kern/vhost_binding.rs b/src/vhost_kern/vhost_binding.rs new file mode 100644 index 0000000..5ebaa56 --- /dev/null +++ b/src/vhost_kern/vhost_binding.rs @@ -0,0 +1,545 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause +// +// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-Google file. + +/* Auto-generated by bindgen then manually edited for simplicity */ + +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(missing_docs)] +#![allow(clippy::missing_safety_doc)] + +use crate::{Error, Result}; +use std::os::raw; + +pub const VHOST: raw::c_uint = 0xaf; +pub const VHOST_VRING_F_LOG: raw::c_uint = 0; +pub const VHOST_ACCESS_RO: raw::c_uchar = 1; +pub const VHOST_ACCESS_WO: raw::c_uchar = 2; +pub const VHOST_ACCESS_RW: raw::c_uchar = 3; +pub const VHOST_IOTLB_MISS: raw::c_uchar = 1; +pub const VHOST_IOTLB_UPDATE: raw::c_uchar = 2; +pub const VHOST_IOTLB_INVALIDATE: raw::c_uchar = 3; +pub const VHOST_IOTLB_ACCESS_FAIL: raw::c_uchar = 4; +pub const VHOST_IOTLB_BATCH_BEGIN: raw::c_uchar = 5; +pub const VHOST_IOTLB_BATCH_END: raw::c_uchar = 6; +pub const VHOST_IOTLB_MSG: raw::c_int = 1; +pub const VHOST_IOTLB_MSG_V2: raw::c_uint = 2; +pub const VHOST_PAGE_SIZE: raw::c_uint = 4096; +pub const VHOST_VIRTIO: raw::c_uint = 175; +pub const VHOST_VRING_LITTLE_ENDIAN: raw::c_uint = 0; +pub const VHOST_VRING_BIG_ENDIAN: raw::c_uint = 1; +pub const VHOST_F_LOG_ALL: raw::c_uint = 26; +pub const VHOST_NET_F_VIRTIO_NET_HDR: raw::c_uint = 27; +pub const VHOST_SCSI_ABI_VERSION: raw::c_uint = 1; +pub const VHOST_BACKEND_F_IOTLB_MSG_V2: raw::c_ulonglong = 0x1; +pub const VHOST_BACKEND_F_IOTLB_BATCH: raw::c_ulonglong = 0x2; +pub const VHOST_BACKEND_F_IOTLB_ASID: raw::c_ulonglong = 0x3; +pub const VHOST_BACKEND_F_SUSPEND: raw::c_ulonglong = 0x4; + +ioctl_ior_nr!(VHOST_GET_FEATURES, VHOST, 0x00, raw::c_ulonglong); +ioctl_iow_nr!(VHOST_SET_FEATURES, VHOST, 0x00, raw::c_ulonglong); +ioctl_io_nr!(VHOST_SET_OWNER, VHOST, 0x01); +ioctl_io_nr!(VHOST_RESET_OWNER, VHOST, 0x02); +ioctl_iow_nr!(VHOST_SET_MEM_TABLE, VHOST, 0x03, vhost_memory); +ioctl_iow_nr!(VHOST_SET_LOG_BASE, VHOST, 0x04, raw::c_ulonglong); 
+ioctl_iow_nr!(VHOST_SET_LOG_FD, VHOST, 0x07, raw::c_int); +ioctl_iow_nr!(VHOST_SET_VRING_NUM, VHOST, 0x10, vhost_vring_state); +ioctl_iow_nr!(VHOST_SET_VRING_ADDR, VHOST, 0x11, vhost_vring_addr); +ioctl_iow_nr!(VHOST_SET_VRING_BASE, VHOST, 0x12, vhost_vring_state); +ioctl_iowr_nr!(VHOST_GET_VRING_BASE, VHOST, 0x12, vhost_vring_state); +ioctl_iow_nr!(VHOST_SET_VRING_KICK, VHOST, 0x20, vhost_vring_file); +ioctl_iow_nr!(VHOST_SET_VRING_CALL, VHOST, 0x21, vhost_vring_file); +ioctl_iow_nr!(VHOST_SET_VRING_ERR, VHOST, 0x22, vhost_vring_file); +ioctl_iow_nr!(VHOST_SET_BACKEND_FEATURES, VHOST, 0x25, raw::c_ulonglong); +ioctl_ior_nr!(VHOST_GET_BACKEND_FEATURES, VHOST, 0x26, raw::c_ulonglong); +ioctl_iow_nr!(VHOST_NET_SET_BACKEND, VHOST, 0x30, vhost_vring_file); +ioctl_iow_nr!(VHOST_SCSI_SET_ENDPOINT, VHOST, 0x40, vhost_scsi_target); +ioctl_iow_nr!(VHOST_SCSI_CLEAR_ENDPOINT, VHOST, 0x41, vhost_scsi_target); +ioctl_iow_nr!(VHOST_SCSI_GET_ABI_VERSION, VHOST, 0x42, raw::c_int); +ioctl_iow_nr!(VHOST_SCSI_SET_EVENTS_MISSED, VHOST, 0x43, raw::c_uint); +ioctl_iow_nr!(VHOST_SCSI_GET_EVENTS_MISSED, VHOST, 0x44, raw::c_uint); +ioctl_iow_nr!(VHOST_VSOCK_SET_GUEST_CID, VHOST, 0x60, raw::c_ulonglong); +ioctl_iow_nr!(VHOST_VSOCK_SET_RUNNING, VHOST, 0x61, raw::c_int); +ioctl_ior_nr!(VHOST_VDPA_GET_DEVICE_ID, VHOST, 0x70, raw::c_uint); +ioctl_ior_nr!(VHOST_VDPA_GET_STATUS, VHOST, 0x71, raw::c_uchar); +ioctl_iow_nr!(VHOST_VDPA_SET_STATUS, VHOST, 0x72, raw::c_uchar); +ioctl_ior_nr!(VHOST_VDPA_GET_CONFIG, VHOST, 0x73, vhost_vdpa_config); +ioctl_iow_nr!(VHOST_VDPA_SET_CONFIG, VHOST, 0x74, vhost_vdpa_config); +ioctl_iow_nr!(VHOST_VDPA_SET_VRING_ENABLE, VHOST, 0x75, vhost_vring_state); +ioctl_ior_nr!(VHOST_VDPA_GET_VRING_NUM, VHOST, 0x76, raw::c_ushort); +ioctl_iow_nr!(VHOST_VDPA_SET_CONFIG_CALL, VHOST, 0x77, raw::c_int); +ioctl_ior_nr!( + VHOST_VDPA_GET_IOVA_RANGE, + VHOST, + 0x78, + vhost_vdpa_iova_range +); +ioctl_ior_nr!(VHOST_VDPA_GET_CONFIG_SIZE, VHOST, 0x79, raw::c_uint); 
+ioctl_ior_nr!(VHOST_VDPA_GET_VQS_COUNT, VHOST, 0x80, raw::c_uint); +ioctl_ior_nr!(VHOST_VDPA_GET_GROUP_NUM, VHOST, 0x81, raw::c_uint); +ioctl_ior_nr!(VHOST_VDPA_GET_AS_NUM, VHOST, 0x7a, raw::c_uint); +ioctl_iowr_nr!(VHOST_VDPA_GET_VRING_GROUP, VHOST, 0x7b, vhost_vring_state); +ioctl_iow_nr!(VHOST_VDPA_SET_GROUP_ASID, VHOST, 0x7c, vhost_vring_state); +ioctl_io_nr!(VHOST_VDPA_SUSPEND, VHOST, 0x7d); + +#[repr(C)] +#[derive(Default)] +pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>); + +impl<T> __IncompleteArrayField<T> { + #[inline] + pub fn new() -> Self { + __IncompleteArrayField(::std::marker::PhantomData) + } + + #[inline] + #[allow(clippy::trivially_copy_pass_by_ref)] + #[allow(clippy::useless_transmute)] + pub unsafe fn as_ptr(&self) -> *const T { + ::std::mem::transmute(self) + } + + #[inline] + #[allow(clippy::useless_transmute)] + pub unsafe fn as_mut_ptr(&mut self) -> *mut T { + ::std::mem::transmute(self) + } + + #[inline] + pub unsafe fn as_slice(&self, len: usize) -> &[T] { + ::std::slice::from_raw_parts(self.as_ptr(), len) + } + + #[inline] + pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] { + ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len) + } +} + +impl<T> ::std::fmt::Debug for __IncompleteArrayField<T> { + fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + fmt.write_str("__IncompleteArrayField") + } +} + +impl<T> ::std::clone::Clone for __IncompleteArrayField<T> { + #[inline] + fn clone(&self) -> Self { + Self::new() + } +} + +impl<T> ::std::marker::Copy for __IncompleteArrayField<T> {} + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct vhost_vring_state { + pub index: raw::c_uint, + pub num: raw::c_uint, +} + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct vhost_vring_file { + pub index: raw::c_uint, + pub fd: raw::c_int, +} + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct vhost_vring_addr { + pub index: raw::c_uint, + pub flags: 
raw::c_uint, + pub desc_user_addr: raw::c_ulonglong, + pub used_user_addr: raw::c_ulonglong, + pub avail_user_addr: raw::c_ulonglong, + pub log_guest_addr: raw::c_ulonglong, +} + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct vhost_iotlb_msg { + pub iova: raw::c_ulonglong, + pub size: raw::c_ulonglong, + pub uaddr: raw::c_ulonglong, + pub perm: raw::c_uchar, + pub type_: raw::c_uchar, +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct vhost_msg { + pub type_: raw::c_int, + pub __bindgen_anon_1: vhost_msg__bindgen_ty_1, +} + +impl Default for vhost_msg { + fn default() -> Self { + // SAFETY: Zeroing all bytes is fine because they represent a valid + // value for all members of the structure + unsafe { ::std::mem::zeroed() } + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub union vhost_msg__bindgen_ty_1 { + pub iotlb: vhost_iotlb_msg, + pub padding: [raw::c_uchar; 64usize], + _bindgen_union_align: [u64; 8usize], +} + +impl Default for vhost_msg__bindgen_ty_1 { + fn default() -> Self { + // SAFETY: Zeroing all bytes is fine because they represent a valid + // value for all members of the structure + unsafe { ::std::mem::zeroed() } + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct vhost_msg_v2 { + pub type_: raw::c_uint, + pub reserved: raw::c_uint, + pub __bindgen_anon_1: vhost_msg_v2__bindgen_ty_1, +} + +impl Default for vhost_msg_v2 { + fn default() -> Self { + // SAFETY: Zeroing all bytes is fine because they represent a valid + // value for all members of the structure + unsafe { ::std::mem::zeroed() } + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub union vhost_msg_v2__bindgen_ty_1 { + pub iotlb: vhost_iotlb_msg, + pub padding: [raw::c_uchar; 64usize], + _bindgen_union_align: [u64; 8usize], +} + +impl Default for vhost_msg_v2__bindgen_ty_1 { + fn default() -> Self { + // SAFETY: Zeroing all bytes is fine because they represent a valid + // value for all members of the structure + unsafe { ::std::mem::zeroed() } + } +} + +#[repr(C)] 
+#[derive(Debug, Default, Copy, Clone)] +pub struct vhost_memory_region { + pub guest_phys_addr: raw::c_ulonglong, + pub memory_size: raw::c_ulonglong, + pub userspace_addr: raw::c_ulonglong, + pub flags_padding: raw::c_ulonglong, +} + +#[repr(C)] +#[derive(Debug, Default, Clone)] +pub struct vhost_memory { + pub nregions: raw::c_uint, + pub padding: raw::c_uint, + pub regions: __IncompleteArrayField<vhost_memory_region>, + __force_alignment: [u64; 0], +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct vhost_scsi_target { + pub abi_version: raw::c_int, + pub vhost_wwpn: [raw::c_char; 224usize], + pub vhost_tpgt: raw::c_ushort, + pub reserved: raw::c_ushort, +} + +impl Default for vhost_scsi_target { + fn default() -> Self { + // SAFETY: Zeroing all bytes is fine because they represent a valid + // value for all members of the structure + unsafe { ::std::mem::zeroed() } + } +} + +#[repr(C)] +#[derive(Debug, Default)] +pub struct vhost_vdpa_config { + pub off: raw::c_uint, + pub len: raw::c_uint, + pub buf: __IncompleteArrayField<raw::c_uchar>, +} + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct vhost_vdpa_iova_range { + pub first: raw::c_ulonglong, + pub last: raw::c_ulonglong, +} + +/// Helper to support vhost::set_mem_table() +pub struct VhostMemory { + buf: Vec<vhost_memory>, +} + +impl VhostMemory { + // Limit number of regions to u16 to simplify error handling + pub fn new(entries: u16) -> Self { + let size = std::mem::size_of::<vhost_memory_region>() * entries as usize; + let count = (size + 2 * std::mem::size_of::<vhost_memory>() - 1) + / std::mem::size_of::<vhost_memory>(); + let mut buf: Vec<vhost_memory> = vec![Default::default(); count]; + buf[0].nregions = u32::from(entries); + VhostMemory { buf } + } + + pub fn as_ptr(&self) -> *const char { + &self.buf[0] as *const vhost_memory as *const char + } + + pub fn get_header(&self) -> &vhost_memory { + &self.buf[0] + } + + pub fn get_region(&self, index: u32) -> Option<&vhost_memory_region> { + if 
index >= self.buf[0].nregions { + return None; + } + // SAFETY: Safe because we have allocated enough space nregions + let regions = unsafe { self.buf[0].regions.as_slice(self.buf[0].nregions as usize) }; + Some(®ions[index as usize]) + } + + pub fn set_region(&mut self, index: u32, region: &vhost_memory_region) -> Result<()> { + if index >= self.buf[0].nregions { + return Err(Error::InvalidGuestMemory); + } + // SAFETY: Safe because we have allocated enough space nregions and checked the index. + let regions = unsafe { self.buf[0].regions.as_mut_slice(index as usize + 1) }; + regions[index as usize] = *region; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bindgen_test_layout_vhost_vring_state() { + assert_eq!( + ::std::mem::size_of::<vhost_vring_state>(), + 8usize, + concat!("Size of: ", stringify!(vhost_vring_state)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_vring_state>(), + 4usize, + concat!("Alignment of ", stringify!(vhost_vring_state)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_vring_file() { + assert_eq!( + ::std::mem::size_of::<vhost_vring_file>(), + 8usize, + concat!("Size of: ", stringify!(vhost_vring_file)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_vring_file>(), + 4usize, + concat!("Alignment of ", stringify!(vhost_vring_file)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_vring_addr() { + assert_eq!( + ::std::mem::size_of::<vhost_vring_addr>(), + 40usize, + concat!("Size of: ", stringify!(vhost_vring_addr)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_vring_addr>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_vring_addr)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_msg__bindgen_ty_1() { + assert_eq!( + ::std::mem::size_of::<vhost_msg__bindgen_ty_1>(), + 64usize, + concat!("Size of: ", stringify!(vhost_msg__bindgen_ty_1)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_msg__bindgen_ty_1>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_msg__bindgen_ty_1)) + 
); + } + + #[test] + fn bindgen_test_layout_vhost_msg() { + assert_eq!( + ::std::mem::size_of::<vhost_msg>(), + 72usize, + concat!("Size of: ", stringify!(vhost_msg)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_msg>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_msg)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_msg_v2__bindgen_ty_1() { + assert_eq!( + ::std::mem::size_of::<vhost_msg_v2__bindgen_ty_1>(), + 64usize, + concat!("Size of: ", stringify!(vhost_msg_v2__bindgen_ty_1)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_msg_v2__bindgen_ty_1>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_msg_v2__bindgen_ty_1)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_msg_v2() { + assert_eq!( + ::std::mem::size_of::<vhost_msg_v2>(), + 72usize, + concat!("Size of: ", stringify!(vhost_msg_v2)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_msg_v2>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_msg_v2)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_memory_region() { + assert_eq!( + ::std::mem::size_of::<vhost_memory_region>(), + 32usize, + concat!("Size of: ", stringify!(vhost_memory_region)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_memory_region>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_memory_region)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_memory() { + assert_eq!( + ::std::mem::size_of::<vhost_memory>(), + 8usize, + concat!("Size of: ", stringify!(vhost_memory)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_memory>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_memory)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_iotlb_msg() { + assert_eq!( + ::std::mem::size_of::<vhost_iotlb_msg>(), + 32usize, + concat!("Size of: ", stringify!(vhost_iotlb_msg)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_iotlb_msg>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_iotlb_msg)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_scsi_target() { + assert_eq!( + 
::std::mem::size_of::<vhost_scsi_target>(), + 232usize, + concat!("Size of: ", stringify!(vhost_scsi_target)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_scsi_target>(), + 4usize, + concat!("Alignment of ", stringify!(vhost_scsi_target)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_vdpa_config() { + assert_eq!( + ::std::mem::size_of::<vhost_vdpa_config>(), + 8usize, + concat!("Size of: ", stringify!(vhost_vdpa_config)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_vdpa_config>(), + 4usize, + concat!("Alignment of ", stringify!(vhost_vdpa_config)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_vdpa_iova_range() { + assert_eq!( + ::std::mem::size_of::<vhost_vdpa_iova_range>(), + 16usize, + concat!("Size of: ", stringify!(vhost_vdpa_iova_range)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_vdpa_iova_range>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_vdpa_iova_range)) + ); + } + + #[test] + fn test_vhostmemory() { + let mut obj = VhostMemory::new(2); + let region = vhost_memory_region { + guest_phys_addr: 0x1000u64, + memory_size: 0x2000u64, + userspace_addr: 0x300000u64, + flags_padding: 0u64, + }; + assert!(obj.get_region(2).is_none()); + + { + let header = obj.get_header(); + assert_eq!(header.nregions, 2u32); + } + { + assert!(obj.set_region(0, ®ion).is_ok()); + assert!(obj.set_region(1, ®ion).is_ok()); + assert!(obj.set_region(2, ®ion).is_err()); + } + + let region1 = obj.get_region(1).unwrap(); + assert_eq!(region1.guest_phys_addr, 0x1000u64); + assert_eq!(region1.memory_size, 0x2000u64); + assert_eq!(region1.userspace_addr, 0x300000u64); + } +} diff --git a/src/vhost_kern/vsock.rs b/src/vhost_kern/vsock.rs new file mode 100644 index 0000000..9bc788e --- /dev/null +++ b/src/vhost_kern/vsock.rs @@ -0,0 +1,196 @@ +// Copyright (C) 2019 Alibaba Cloud. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause +// +// Copyright 2017 The Chromium OS Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-Google file. + +//! Kernel-based vhost-vsock backend. + +use std::fs::{File, OpenOptions}; +use std::os::unix::fs::OpenOptionsExt; +use std::os::unix::io::{AsRawFd, RawFd}; + +use vm_memory::GuestAddressSpace; +use vmm_sys_util::ioctl::ioctl_with_ref; + +use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING}; +use super::{ioctl_result, Error, Result, VhostKernBackend}; +use crate::vsock::VhostVsock; + +const VHOST_PATH: &str = "/dev/vhost-vsock"; + +/// Handle for running VHOST_VSOCK ioctls. +pub struct Vsock<AS: GuestAddressSpace> { + fd: File, + mem: AS, +} + +impl<AS: GuestAddressSpace> Vsock<AS> { + /// Open a handle to a new VHOST-VSOCK instance. + pub fn new(mem: AS) -> Result<Self> { + Ok(Vsock { + fd: OpenOptions::new() + .read(true) + .write(true) + .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK) + .open(VHOST_PATH) + .map_err(Error::VhostOpen)?, + mem, + }) + } + + fn set_running(&self, running: bool) -> Result<()> { + let on: ::std::os::raw::c_int = if running { 1 } else { 0 }; + + // SAFETY: This ioctl is called on a valid vhost-vsock fd and has its + // return value checked. + let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_RUNNING(), &on) }; + ioctl_result(ret, ()) + } +} + +impl<AS: GuestAddressSpace> VhostVsock for Vsock<AS> { + fn set_guest_cid(&self, cid: u64) -> Result<()> { + // SAFETY: This ioctl is called on a valid vhost-vsock fd and has its + // return value checked. 
+ let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_GUEST_CID(), &cid) }; + ioctl_result(ret, ()) + } + + fn start(&self) -> Result<()> { + self.set_running(true) + } + + fn stop(&self) -> Result<()> { + self.set_running(false) + } +} + +impl<AS: GuestAddressSpace> VhostKernBackend for Vsock<AS> { + type AS = AS; + + fn mem(&self) -> &Self::AS { + &self.mem + } +} + +impl<AS: GuestAddressSpace> AsRawFd for Vsock<AS> { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap}; + use vmm_sys_util::eventfd::EventFd; + + use super::*; + use crate::{ + VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData, + }; + + #[test] + fn test_vsock_new_device() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vsock = Vsock::new(&m).unwrap(); + + assert!(vsock.as_raw_fd() >= 0); + assert!(vsock.mem().find_region(GuestAddress(0x100)).is_some()); + assert!(vsock.mem().find_region(GuestAddress(0x10_0000)).is_none()); + } + + #[test] + fn test_vsock_is_valid() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vsock = Vsock::new(&m).unwrap(); + + let mut config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + assert!(vsock.is_valid(&config)); + + config.queue_size = 0; + assert!(!vsock.is_valid(&config)); + config.queue_size = 31; + assert!(!vsock.is_valid(&config)); + config.queue_size = 33; + assert!(!vsock.is_valid(&config)); + } + + #[test] + fn test_vsock_ioctls() { + let m = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vsock = Vsock::new(&m).unwrap(); + + let features = vsock.get_features().unwrap(); + vsock.set_features(features).unwrap(); + + vsock.set_owner().unwrap(); + + 
vsock.set_mem_table(&[]).unwrap_err(); + + /* + let region = VhostUserMemoryRegionInfo { + guest_phys_addr: 0x0, + memory_size: 0x10_0000, + userspace_addr: 0, + mmap_offset: 0, + mmap_handle: -1, + }; + vsock.set_mem_table(&[region]).unwrap_err(); + */ + + let region = VhostUserMemoryRegionInfo::new( + 0x0, + 0x10_0000, + m.get_host_address(GuestAddress(0x0)).unwrap() as u64, + 0, + -1, + ); + vsock.set_mem_table(&[region]).unwrap(); + + vsock + .set_log_base( + 0x4000, + Some(VhostUserDirtyLogRegion { + mmap_size: 0x1000, + mmap_offset: 0x10, + mmap_handle: 1, + }), + ) + .unwrap_err(); + vsock.set_log_base(0x4000, None).unwrap(); + + let eventfd = EventFd::new(0).unwrap(); + vsock.set_log_fd(eventfd.as_raw_fd()).unwrap(); + + vsock.set_vring_num(0, 32).unwrap(); + + let config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + vsock.set_vring_addr(0, &config).unwrap(); + vsock.set_vring_base(0, 1).unwrap(); + vsock.set_vring_call(0, &eventfd).unwrap(); + vsock.set_vring_kick(0, &eventfd).unwrap(); + vsock.set_vring_err(0, &eventfd).unwrap(); + assert_eq!(vsock.get_vring_base(0).unwrap(), 1); + vsock.set_guest_cid(0xdead).unwrap(); + //vsock.start().unwrap(); + //vsock.stop().unwrap(); + } +} diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs new file mode 100644 index 0000000..4a62e12 --- /dev/null +++ b/src/vhost_user/connection.rs @@ -0,0 +1,903 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Structs for Unix Domain Socket listener and endpoint. 
+ +#![allow(dead_code)] + +use std::fs::File; +use std::io::ErrorKind; +use std::marker::PhantomData; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::path::{Path, PathBuf}; +use std::{mem, slice}; + +use libc::{c_void, iovec}; +use vm_memory::ByteValued; +use vmm_sys_util::sock_ctrl_msg::ScmSocket; + +use super::message::*; +use super::{Error, Result}; + +/// Unix domain socket listener for accepting incoming connections. +pub struct Listener { + fd: UnixListener, + path: Option<PathBuf>, +} + +impl Listener { + /// Create a unix domain socket listener. + /// + /// # Return: + /// * - the new Listener object on success. + /// * - SocketError: failed to create listener socket. + pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> { + if unlink { + let _ = std::fs::remove_file(&path); + } + let fd = UnixListener::bind(&path).map_err(Error::SocketError)?; + Ok(Listener { + fd, + path: Some(path.as_ref().to_owned()), + }) + } + + /// Accept an incoming connection. + /// + /// # Return: + /// * - Some(UnixStream): new UnixStream object if new incoming connection is available. + /// * - None: no incoming connection available. + /// * - SocketError: errors from accept(). + pub fn accept(&self) -> Result<Option<UnixStream>> { + loop { + match self.fd.accept() { + Ok((socket, _addr)) => return Ok(Some(socket)), + Err(e) => { + match e.kind() { + // No incoming connection available. + ErrorKind::WouldBlock => return Ok(None), + // New connection closed by peer. + ErrorKind::ConnectionAborted => return Ok(None), + // Interrupted by signals, retry + ErrorKind::Interrupted => continue, + _ => return Err(Error::SocketError(e)), + } + } + } + } + } + + /// Change blocking status on the listener. + /// + /// # Return: + /// * - () on success. + /// * - SocketError: failure from set_nonblocking(). 
+ pub fn set_nonblocking(&self, block: bool) -> Result<()> { + self.fd.set_nonblocking(block).map_err(Error::SocketError) + } +} + +impl AsRawFd for Listener { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +impl FromRawFd for Listener { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + Listener { + fd: UnixListener::from_raw_fd(fd), + path: None, + } + } +} + +impl Drop for Listener { + fn drop(&mut self) { + if let Some(path) = &self.path { + let _ = std::fs::remove_file(path); + } + } +} + +/// Unix domain socket endpoint for vhost-user connection. +pub(super) struct Endpoint<R: Req> { + sock: UnixStream, + _r: PhantomData<R>, +} + +impl<R: Req> Endpoint<R> { + /// Create a new stream by connecting to server at `str`. + /// + /// # Return: + /// * - the new Endpoint object on success. + /// * - SocketConnect: failed to connect to peer. + pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> { + let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?; + Ok(Self::from_stream(sock)) + } + + /// Create an endpoint from a stream object. + pub fn from_stream(sock: UnixStream) -> Self { + Endpoint { + sock, + _r: PhantomData, + } + } + + /// Sends bytes from scatter-gather vectors over the socket with optional attached file + /// descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> { + let rfds = match fds { + Some(rfds) => rfds, + _ => &[], + }; + self.sock.send_with_fds(iovs, rfds).map_err(Into::into) + } + + /// Sends all bytes from scatter-gather vectors over the socket with optional attached file + /// descriptors. Will loop until all data has been transfered. 
+ /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> { + let mut data_sent = 0; + let mut data_total = 0; + let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect(); + for len in &iov_lens { + data_total += len; + } + + while (data_total - data_sent) > 0 { + let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent); + let iov = &iovs[nr_skip][offset..]; + + let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat(); + let sfds = if data_sent == 0 { fds } else { None }; + + let sent = self.send_iovec(data, sfds); + match sent { + Ok(0) => return Ok(data_sent), + Ok(n) => data_sent += n, + Err(e) => match e { + Error::SocketRetry(_) => {} + _ => return Err(e), + }, + } + } + Ok(data_sent) + } + + /// Sends bytes from a slice over the socket with optional attached file descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> { + self.send_iovec(&[data], fds) + } + + /// Sends a header-only message with optional attached file descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_header( + &mut self, + hdr: &VhostUserMsgHeader<R>, + fds: Option<&[RawFd]>, + ) -> Result<()> { + // SAFETY: Safe because there can't be other mutable referance to hdr. 
+ let iovs = unsafe { + [slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + )] + }; + let bytes = self.send_iovec_all(&iovs[..], fds)?; + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Send a message with header and body. Optional file descriptors may be attached to + /// the message. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_message<T: ByteValued>( + &mut self, + hdr: &VhostUserMsgHeader<R>, + body: &T, + fds: Option<&[RawFd]>, + ) -> Result<()> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } + let bytes = self.send_iovec_all(&[hdr.as_slice(), body.as_slice()], fds)?; + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Send a message with header, body and payload. Optional file descriptors + /// may also be attached to the message. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - OversizedMsg: message size is too big. + /// * - PartialMessage: received a partial message. + /// * - IncorrectFds: wrong number of attached fds. 
+ pub fn send_message_with_payload<T: ByteValued>( + &mut self, + hdr: &VhostUserMsgHeader<R>, + body: &T, + payload: &[u8], + fds: Option<&[RawFd]>, + ) -> Result<()> { + let len = payload.len(); + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } + if len > MAX_MSG_SIZE - mem::size_of::<T>() { + return Err(Error::OversizedMsg); + } + if let Some(fd_arr) = fds { + if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES { + return Err(Error::IncorrectFds); + } + } + + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len; + let len = self.send_iovec_all(&[hdr.as_slice(), body.as_slice(), payload], fds)?; + if len != total { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Reads bytes from the socket into the given scatter/gather vectors. + /// + /// # Return: + /// * - (number of bytes received, buf) on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> { + let mut rbuf = vec![0u8; len]; + let mut iovs = [iovec { + iov_base: rbuf.as_mut_ptr() as *mut c_void, + iov_len: len, + }]; + // SAFETY: Safe because we own rbuf and it's safe to fill a byte array with arbitrary data. + let (bytes, _) = unsafe { self.sock.recv_with_fds(&mut iovs, &mut [])? }; + Ok((bytes, rbuf)) + } + + /// Reads bytes from the socket into the given scatter/gather vectors with optional attached + /// file. + /// + /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little + /// tricky to pass file descriptors through such a communication channel. Let's assume that a + /// sender sending a message with some file descriptors attached. To successfully receive those + /// attached file descriptors, the receiver must obey following rules: + /// 1) file descriptors are attached to a message. 
+ /// 2) message(packet) boundaries must be respected on the receive side. + /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the + /// attached file descriptors will get lost. + /// Note that this function wraps received file descriptors as `File`. + /// + /// # Return: + /// * - (number of bytes received, [received files]) on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// + /// # Safety + /// + /// It is the callers responsibility to ensure it is safe for arbitrary data to be + /// written to the iovec pointers. + pub unsafe fn recv_into_iovec( + &mut self, + iovs: &mut [iovec], + ) -> Result<(usize, Option<Vec<File>>)> { + let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES]; + let (bytes, fds) = self.sock.recv_with_fds(iovs, &mut fd_array)?; + + let files = match fds { + 0 => None, + n => { + let files = fd_array + .iter() + .take(n) + .map(|fd| { + // Safe because we have the ownership of `fd`. + File::from_raw_fd(*fd) + }) + .collect(); + Some(files) + } + }; + + Ok((bytes, files)) + } + + /// Reads all bytes from the socket into the given scatter/gather vectors with optional + /// attached files. Will loop until all data has been transferred. + /// + /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little + /// tricky to pass file descriptors through such a communication channel. Let's assume that a + /// sender sending a message with some file descriptors attached. To successfully receive those + /// attached file descriptors, the receiver must obey following rules: + /// 1) file descriptors are attached to a message. + /// 2) message(packet) boundaries must be respected on the receive side. + /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the + /// attached file descriptors will get lost. 
+ /// Note that this function wraps received file descriptors as `File`. + /// + /// # Return: + /// * - (number of bytes received, [received fds]) on success + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// + /// # Safety + /// + /// It is the callers responsibility to ensure it is safe for arbitrary data to be + /// written to the iovec pointers. + pub unsafe fn recv_into_iovec_all( + &mut self, + iovs: &mut [iovec], + ) -> Result<(usize, Option<Vec<File>>)> { + let mut data_read = 0; + let mut data_total = 0; + let mut rfds = None; + let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect(); + for len in &iov_lens { + data_total += len; + } + + while (data_total - data_read) > 0 { + let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read); + let iov = &mut iovs[nr_skip]; + + let mut data = [ + &[iovec { + iov_base: (iov.iov_base as usize + offset) as *mut c_void, + iov_len: iov.iov_len - offset, + }], + &iovs[(nr_skip + 1)..], + ] + .concat(); + + let res = self.recv_into_iovec(&mut data); + match res { + Ok((0, _)) => return Ok((data_read, rfds)), + Ok((n, fds)) => { + if data_read == 0 { + rfds = fds; + } + data_read += n; + } + Err(e) => match e { + Error::SocketRetry(_) => {} + _ => return Err(e), + }, + } + } + Ok((data_read, rfds)) + } + + /// Reads bytes from the socket into a new buffer with optional attached + /// files. Received file descriptors are set close-on-exec and converted to `File`. + /// + /// # Return: + /// * - (number of bytes received, buf, [received files]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. 
+ pub fn recv_into_buf( + &mut self, + buf_size: usize, + ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> { + let mut buf = vec![0u8; buf_size]; + let (bytes, files) = { + let mut iovs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf_size, + }]; + // SAFETY: Safe because we own buf and it's safe to fill a byte array with arbitrary data. + unsafe { self.recv_into_iovec(&mut iovs)? } + }; + Ok((bytes, buf, files)) + } + + /// Receive a header-only message with optional attached files. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, [received files]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. + pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut iovs = [iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }]; + // SAFETY: Safe because we own hdr and it's ByteValued. + let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; + + if bytes == 0 { + return Err(Error::Disconnected); + } else if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, files)) + } + + /// Receive a message with optional attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. 
+ /// + /// # Return: + /// * - (message header, message body, [received files]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. + pub fn recv_body<T: ByteValued + Sized + VhostUserMsgValidator>( + &mut self, + ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::<T>(), + }, + ]; + // SAFETY: Safe because we own hdr and body and they're ByteValued. + let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; + + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); + if bytes != total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, files)) + } + + /// Receive a message with header and optional content. Callers need to + /// pre-allocate a big enough buffer to receive the message body and + /// optional payload. If there are attached file descriptor associated + /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors + /// will be accepted and all other file descriptor will be discard + /// silently. + /// + /// # Return: + /// * - (message header, message size, [received files]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. 
+ /// * - InvalidMessage: received a invalid message. + pub fn recv_body_into_buf( + &mut self, + buf: &mut [u8], + ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }, + ]; + // SAFETY: Safe because we own hdr and have a mutable borrow of buf, and hdr is ByteValued + // and it's safe to fill a byte slice with arbitrary data. + let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; + + if bytes < mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files)) + } + + /// Receive a message with optional payload and attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, message body, size of payload, [received files]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. 
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))] + pub fn recv_payload_into_buf<T: ByteValued + Sized + VhostUserMsgValidator>( + &mut self, + buf: &mut [u8], + ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::<T>(), + }, + iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }, + ]; + // SAFETY: Safe because we own hdr and body and have a mutable borrow of buf, and + // hdr and body are ByteValued, and it's safe to fill a byte slice with + // arbitrary data. + let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; + + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); + if bytes < total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, bytes - total, files)) + } +} + +impl<T: Req> AsRawFd for Endpoint<T> { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} + +// Given a slice of sizes and the `skip_size`, return the offset of `skip_size` in the slice. 
+// For example: +// let iov_lens = vec![4, 4, 5]; +// let size = 6; +// assert_eq!(get_sub_iovs_offset(&iov_len, size), (1, 2)); +fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { + let mut size = skip_size; + let mut nr_skip = 0; + + for len in iov_lens { + if size >= *len { + size -= *len; + nr_skip += 1; + } else { + break; + } + } + (nr_skip, size) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::{Read, Seek, SeekFrom, Write}; + use vmm_sys_util::rand::rand_alphanumerics; + use vmm_sys_util::tempfile::TempFile; + + fn temp_path() -> PathBuf { + PathBuf::from(format!( + "/tmp/vhost_test_{}", + rand_alphanumerics(8).to_str().unwrap() + )) + } + + #[test] + fn create_listener() { + let path = temp_path(); + let listener = Listener::new(path, true).unwrap(); + + assert!(listener.as_raw_fd() > 0); + } + + #[test] + fn create_listener_from_raw_fd() { + let path = temp_path(); + let file = File::create(path).unwrap(); + + // SAFETY: Safe because `file` contains a valid fd to a file just created. 
+ let listener = unsafe { Listener::from_raw_fd(file.as_raw_fd()) }; + + assert!(listener.as_raw_fd() > 0); + } + + #[test] + fn accept_connection() { + let path = temp_path(); + let listener = Listener::new(path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + + // accept on a fd without incoming connection + let conn = listener.accept().unwrap(); + assert!(conn.is_none()); + } + + #[test] + fn send_data() { + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let buf1 = vec![0x1, 0x2, 0x3, 0x4]; + let mut len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..bytes]); + + len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + } + + #[test] + fn send_fd() { + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let mut fd = TempFile::new().unwrap().into_file(); + write!(fd, "test").unwrap(); + + // Normal case for sending/receiving file descriptors + let buf1 = vec![0x1, 0x2, 0x3, 0x4]; + let len = master + .send_slice(&buf1[..], Some(&[fd.as_raw_fd()])) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap(); + 
assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..]); + assert!(files.is_some()); + let files = files.unwrap(); + { + assert_eq!(files.len(), 1); + let mut file = &files[0]; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + + // Following communication pattern should work: + // Sending side: data(header, body) with fds + // Receiving side: data(header) with fds, data(body) + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + assert!(files.is_some()); + let files = files.unwrap(); + { + assert_eq!(files.len(), 3); + let mut file = &files[1]; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(files.is_none()); + + // Following communication pattern should not work: + // Sending side: data(header, body) with fds + // Receiving side: data(header), data(body) with fds + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf4) = slave.recv_data(2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf4[..]); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(files.is_none()); + + // Following communication pattern should work: + // Sending side: data, data with fds + // Receiving side: data, data with fds + let len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let len = master + 
.send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..]); + assert!(files.is_none()); + + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + assert!(files.is_some()); + let files = files.unwrap(); + { + assert_eq!(files.len(), 3); + let mut file = &files[1]; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(files.is_none()); + + // Following communication pattern should not work: + // Sending side: data1, data2 with fds + // Receiving side: data + partial of data2, left of data2 with fds + let len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, _) = slave.recv_data(5).unwrap(); + assert_eq!(bytes, 5); + + let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 3); + assert!(files.is_none()); + + // If the target fd array is too small, extra file descriptors will get lost. 
+ let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 4); + assert!(files.is_some()); + } + + #[test] + fn send_recv() { + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let mut hdr1 = + VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32); + hdr1.set_need_reply(true); + let features1 = 0x1u64; + master.send_message(&hdr1, &features1, None).unwrap(); + + let mut features2 = 0u64; + + // SAFETY: Safe because features2 is valid and it's an `u64`. + let slice = unsafe { + slice::from_raw_parts_mut( + (&mut features2 as *mut u64) as *mut u8, + mem::size_of::<u64>(), + ) + }; + let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap(); + assert_eq!(hdr1, hdr2); + assert_eq!(bytes, 8); + assert_eq!(features1, features2); + assert!(files.is_none()); + + master.send_header(&hdr1, None).unwrap(); + let (hdr2, files) = slave.recv_header().unwrap(); + assert_eq!(hdr1, hdr2); + assert!(files.is_none()); + } + + #[test] + fn partial_message() { + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); + let mut master = UnixStream::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + write!(master, "a").unwrap(); + drop(master); + assert!(matches!(slave.recv_header(), Err(Error::PartialMessage))); + } + + #[test] + fn disconnected() { + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); + let _ = UnixStream::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + 
let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + assert!(matches!(slave.recv_header(), Err(Error::Disconnected))); + } +} diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs new file mode 100644 index 0000000..ae728a0 --- /dev/null +++ b/src/vhost_user/dummy_slave.rs @@ -0,0 +1,294 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::fs::File; + +use super::message::*; +use super::*; + +pub const MAX_QUEUE_NUM: usize = 2; +pub const MAX_VRING_NUM: usize = 256; +pub const MAX_MEM_SLOTS: usize = 32; +pub const VIRTIO_FEATURES: u64 = 0x40000003; + +#[derive(Default)] +pub struct DummySlaveReqHandler { + pub owned: bool, + pub features_acked: bool, + pub acked_features: u64, + pub acked_protocol_features: u64, + pub queue_num: usize, + pub vring_num: [u32; MAX_QUEUE_NUM], + pub vring_base: [u32; MAX_QUEUE_NUM], + pub call_fd: [Option<File>; MAX_QUEUE_NUM], + pub kick_fd: [Option<File>; MAX_QUEUE_NUM], + pub err_fd: [Option<File>; MAX_QUEUE_NUM], + pub vring_started: [bool; MAX_QUEUE_NUM], + pub vring_enabled: [bool; MAX_QUEUE_NUM], + pub inflight_file: Option<File>, +} + +impl DummySlaveReqHandler { + pub fn new() -> Self { + DummySlaveReqHandler { + queue_num: MAX_QUEUE_NUM, + ..Default::default() + } + } + + /// Helper to check if VirtioFeature enabled + fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> { + if self.acked_features & feat.bits() != 0 { + Ok(()) + } else { + Err(Error::InactiveFeature(feat)) + } + } + + /// Helper to check is VhostUserProtocolFeatures enabled + fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> { + if self.acked_protocol_features & feat.bits() != 0 { + Ok(()) + } else { + Err(Error::InactiveOperation(feat)) + } + } +} + +impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { + fn set_owner(&mut self) -> Result<()> { + if self.owned { + return Err(Error::InvalidOperation("already 
claimed")); + } + self.owned = true; + Ok(()) + } + + fn reset_owner(&mut self) -> Result<()> { + self.owned = false; + self.features_acked = false; + self.acked_features = 0; + self.acked_protocol_features = 0; + Ok(()) + } + + fn get_features(&mut self) -> Result<u64> { + Ok(VIRTIO_FEATURES) + } + + fn set_features(&mut self, features: u64) -> Result<()> { + if !self.owned { + return Err(Error::InvalidOperation("not owned")); + } else if self.features_acked { + return Err(Error::InvalidOperation("features already set")); + } else if (features & !VIRTIO_FEATURES) != 0 { + return Err(Error::InvalidParam); + } + + self.acked_features = features; + self.features_acked = true; + + // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, + // the ring is initialized in an enabled state. + // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, + // the ring is initialized in a disabled state. Client must not + // pass data to/from the backend until ring is enabled by + // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has + // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. 
+ let vring_enabled = + self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0; + for enabled in &mut self.vring_enabled { + *enabled = vring_enabled; + } + + Ok(()) + } + + fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _files: Vec<File>) -> Result<()> { + Ok(()) + } + + fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> { + if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM { + return Err(Error::InvalidParam); + } + self.vring_num[index as usize] = num; + Ok(()) + } + + fn set_vring_addr( + &mut self, + index: u32, + _flags: VhostUserVringAddrFlags, + _descriptor: u64, + _used: u64, + _available: u64, + _log: u64, + ) -> Result<()> { + if index as usize >= self.queue_num { + return Err(Error::InvalidParam); + } + Ok(()) + } + + fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> { + if index as usize >= self.queue_num || base as usize >= MAX_VRING_NUM { + return Err(Error::InvalidParam); + } + self.vring_base[index as usize] = base; + Ok(()) + } + + fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> { + if index as usize >= self.queue_num { + return Err(Error::InvalidParam); + } + // Quotation from vhost-user spec: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. 
+ self.vring_started[index as usize] = false; + Ok(VhostUserVringState::new( + index, + self.vring_base[index as usize], + )) + } + + fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + self.kick_fd[index as usize] = fd; + + // Quotation from vhost-user spec: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. + // + // So we should add fd to event monitor(select, poll, epoll) here. + self.vring_started[index as usize] = true; + Ok(()) + } + + fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + self.call_fd[index as usize] = fd; + Ok(()) + } + + fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + self.err_fd[index as usize] = fd; + Ok(()) + } + + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { + Ok(VhostUserProtocolFeatures::all()) + } + + fn set_protocol_features(&mut self, features: u64) -> Result<()> { + // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must + // support this message even before VHOST_USER_SET_FEATURES was + // called. + // What happens if the master calls set_features() with + // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this + // interface? 
+ self.acked_protocol_features = features; + Ok(()) + } + + fn get_queue_num(&mut self) -> Result<u64> { + Ok(MAX_QUEUE_NUM as u64) + } + + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> { + // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES + // has been negotiated. + self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; + + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + + // Slave must not pass data to/from the backend until ring is + // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1, + // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE + // with parameter 0. + self.vring_enabled[index as usize] = enable; + Ok(()) + } + + fn get_config( + &mut self, + offset: u32, + size: u32, + _flags: VhostUserConfigFlags, + ) -> Result<Vec<u8>> { + self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; + + if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) + || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET + || size + offset > VHOST_USER_CONFIG_SIZE + { + return Err(Error::InvalidParam); + } + assert_eq!(offset, 0x100); + assert_eq!(size, 4); + Ok(vec![0xa5; size as usize]) + } + + fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> { + let size = buf.len() as u32; + self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; + + if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) + || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET + || size + offset > VHOST_USER_CONFIG_SIZE + { + return Err(Error::InvalidParam); + } + assert_eq!(offset, 0x100); + assert_eq!(buf.len(), 4); + assert_eq!(buf, &[0xa5; 4]); + Ok(()) + } + + fn get_inflight_fd( + &mut self, + inflight: &VhostUserInflight, + ) -> Result<(VhostUserInflight, File)> { + let file = tempfile::tempfile().unwrap(); + self.inflight_file = 
Some(file.try_clone().unwrap()); + Ok(( + VhostUserInflight { + mmap_size: 0x1000, + mmap_offset: 0, + num_queues: inflight.num_queues, + queue_size: inflight.queue_size, + }, + file, + )) + } + + fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> Result<()> { + Ok(()) + } + + fn get_max_mem_slots(&mut self) -> Result<u64> { + Ok(MAX_MEM_SLOTS as u64) + } + + fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: File) -> Result<()> { + Ok(()) + } + + fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> Result<()> { + Ok(()) + } +} diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs new file mode 100644 index 0000000..feeb984 --- /dev/null +++ b/src/vhost_user/master.rs @@ -0,0 +1,1110 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Traits and Struct for vhost-user master. + +use std::fs::File; +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::path::Path; +use std::sync::{Arc, Mutex, MutexGuard}; + +use vm_memory::ByteValued; +use vmm_sys_util::eventfd::EventFd; + +use super::connection::Endpoint; +use super::message::*; +use super::{take_single_file, Error as VhostUserError, Result as VhostUserResult}; +use crate::backend::{ + VhostBackend, VhostUserDirtyLogRegion, VhostUserMemoryRegionInfo, VringConfigData, +}; +use crate::{Error, Result}; + +/// Trait for vhost-user master to provide extra methods not covered by the VhostBackend yet. +pub trait VhostUserMaster: VhostBackend { + /// Get the protocol feature bitmask from the underlying vhost implementation. + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; + + /// Enable protocol features in the underlying vhost implementation. + fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>; + + /// Query how many queues the backend supports. 
+ fn get_queue_num(&mut self) -> Result<u64>; + + /// Signal slave to enable or disable corresponding vring. + /// + /// Slave must not pass data to/from the backend until ring is enabled by + /// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been + /// disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. + fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()>; + + /// Fetch the contents of the virtio device configuration space. + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + buf: &[u8], + ) -> Result<(VhostUserConfig, VhostUserConfigPayload)>; + + /// Change the virtio device configuration space. It also can be used for live migration on the + /// destination host to set readonly configuration space fields. + fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>; + + /// Setup slave communication channel. + fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()>; + + /// Retrieve shared buffer for inflight I/O tracking. + fn get_inflight_fd( + &mut self, + inflight: &VhostUserInflight, + ) -> Result<(VhostUserInflight, File)>; + + /// Set shared buffer for inflight I/O tracking. + fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()>; + + /// Query the maximum amount of memory slots supported by the backend. + fn get_max_mem_slots(&mut self) -> Result<u64>; + + /// Add a new guest memory mapping for vhost to use. + fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>; + + /// Remove a guest memory mapping from vhost. + fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>; +} + +fn error_code<T>(err: VhostUserError) -> Result<T> { + Err(Error::VhostUserProtocol(err)) +} + +/// Struct for the vhost-user master endpoint. +#[derive(Clone)] +pub struct Master { + node: Arc<Mutex<MasterInternal>>, +} + +impl Master { + /// Create a new instance. 
+ fn new(ep: Endpoint<MasterReq>, max_queue_num: u64) -> Self { + Master { + node: Arc::new(Mutex::new(MasterInternal { + main_sock: ep, + virtio_features: 0, + acked_virtio_features: 0, + protocol_features: 0, + acked_protocol_features: 0, + protocol_features_ready: false, + max_queue_num, + error: None, + hdr_flags: VhostUserHeaderFlag::empty(), + })), + } + } + + fn node(&self) -> MutexGuard<MasterInternal> { + self.node.lock().unwrap() + } + + /// Create a new instance from a Unix stream socket. + pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self { + Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num) + } + + /// Create a new vhost-user master endpoint. + /// + /// Will retry as the backend may not be ready to accept the connection. + /// + /// # Arguments + /// * `path` - path of Unix domain socket listener to connect to + pub fn connect<P: AsRef<Path>>(path: P, max_queue_num: u64) -> Result<Self> { + let mut retry_count = 5; + let endpoint = loop { + match Endpoint::<MasterReq>::connect(&path) { + Ok(endpoint) => break Ok(endpoint), + Err(e) => match &e { + VhostUserError::SocketConnect(why) => { + if why.kind() == std::io::ErrorKind::ConnectionRefused && retry_count > 0 { + std::thread::sleep(std::time::Duration::from_millis(100)); + retry_count -= 1; + continue; + } else { + break Err(e); + } + } + _ => break Err(e), + }, + } + }?; + + Ok(Self::new(endpoint, max_queue_num)) + } + + /// Set the header flags that should be applied to all following messages. + pub fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) { + let mut node = self.node(); + node.hdr_flags = flags; + } +} + +impl VhostBackend for Master { + /// Get from the underlying vhost implementation the feature bitmask. 
+ fn get_features(&self) -> Result<u64> { + let mut node = self.node(); + let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + node.virtio_features = val.value; + Ok(node.virtio_features) + } + + /// Enable features in the underlying vhost implementation using a bitmask. + fn set_features(&self, features: u64) -> Result<()> { + let mut node = self.node(); + let val = VhostUserU64::new(features); + let hdr = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?; + node.acked_virtio_features = features & node.virtio_features; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Set the current Master as an owner of the session. + fn set_owner(&self) -> Result<()> { + // We unwrap() the return value to assert that we are not expecting threads to ever fail + // while holding the lock. + let mut node = self.node(); + let hdr = node.send_request_header(MasterReq::SET_OWNER, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn reset_owner(&self) -> Result<()> { + let mut node = self.node(); + let hdr = node.send_request_header(MasterReq::RESET_OWNER, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Set the memory map regions on the slave so it can translate the vring + /// addresses. In the ancillary data there is an array of file descriptors + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES { + return error_code(VhostUserError::InvalidParam); + } + + let mut ctx = VhostUserMemoryContext::new(); + for region in regions.iter() { + if region.memory_size == 0 || region.mmap_handle < 0 { + return error_code(VhostUserError::InvalidParam); + } + + ctx.append(®ion.to_region(), region.mmap_handle); + } + + let mut node = self.node(); + let body = VhostUserMemory::new(ctx.regions.len() as u32); + // SAFETY: Safe because ctx.regions is a valid Vec() at this point. 
+ let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() }; + let hdr = node.send_request_with_payload( + MasterReq::SET_MEM_TABLE, + &body, + payload, + Some(ctx.fds.as_slice()), + )?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + // Clippy doesn't seem to know that if let with && is still experimental + #[allow(clippy::unnecessary_unwrap)] + fn set_log_base(&self, base: u64, region: Option<VhostUserDirtyLogRegion>) -> Result<()> { + let mut node = self.node(); + let val = VhostUserU64::new(base); + + if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0 + && region.is_some() + { + let region = region.unwrap(); + let log = VhostUserLog { + mmap_size: region.mmap_size, + mmap_offset: region.mmap_offset, + }; + let hdr = node.send_request_with_body( + MasterReq::SET_LOG_BASE, + &log, + Some(&[region.mmap_handle]), + )?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } else { + let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, None)?; + Ok(()) + } + } + + fn set_log_fd(&self, fd: RawFd) -> Result<()> { + let mut node = self.node(); + let fds = [fd]; + let hdr = node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Set the size of the queue. + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringState::new(queue_index as u32, num.into()); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_NUM, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Sets the addresses of the different aspects of the vring. 
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num + || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0 + { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_ADDR, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Sets the base offset in the available vring. + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringState::new(queue_index as u32, base.into()); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_BASE, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn get_vring_base(&self, queue_index: usize) -> Result<u32> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let req = VhostUserVringState::new(queue_index as u32, 0); + let hdr = node.send_request_with_body(MasterReq::GET_VRING_BASE, &req, None)?; + let reply = node.recv_reply::<VhostUserVringState>(&hdr)?; + Ok(reply.num) + } + + /// Set the event file descriptor to signal when buffers are used. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. This signals that polling + /// will be used instead of waiting for the call. 
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Set the event file descriptor for adding buffers to the vring. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. This signals that polling + /// should be used instead of waiting for a kick. + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Set the event file descriptor to signal when error occurs. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. 
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } +} + +impl VhostUserMaster for Master { + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { + let mut node = self.node(); + node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; + let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + node.protocol_features = val.value; + // Should we support forward compatibility? + // If so just mask out unrecognized flags instead of return errors. + match VhostUserProtocolFeatures::from_bits(node.protocol_features) { + Some(val) => Ok(val), + None => error_code(VhostUserError::InvalidMessage), + } + } + + fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> { + let mut node = self.node(); + node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; + let val = VhostUserU64::new(features.bits()); + let hdr = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. 
+ node.acked_protocol_features = features.bits(); + node.protocol_features_ready = true; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn get_queue_num(&mut self) -> Result<u64> { + let mut node = self.node(); + node.check_proto_feature(VhostUserProtocolFeatures::MQ)?; + + let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + if val.value > VHOST_USER_MAX_VRINGS { + return error_code(VhostUserError::InvalidMessage); + } + node.max_queue_num = val.value; + Ok(node.max_queue_num) + } + + fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> { + let mut node = self.node(); + // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled. + if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 { + return error_code(VhostUserError::InactiveFeature( + VhostUserVirtioFeatures::PROTOCOL_FEATURES, + )); + } else if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let flag = if enable { 1 } else { 0 }; + let val = VhostUserVringState::new(queue_index as u32, flag); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_ENABLE, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + buf: &[u8], + ) -> Result<(VhostUserConfig, VhostUserConfigPayload)> { + let body = VhostUserConfig::new(offset, size, flags); + if !body.is_valid() { + return error_code(VhostUserError::InvalidParam); + } + + let mut node = self.node(); + // depends on VhostUserProtocolFeatures::CONFIG + node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; + + // vhost-user spec states that: + // "Master payload: virtio device config space" + // "Slave payload: virtio device config space" + let hdr = node.send_request_with_payload(MasterReq::GET_CONFIG, &body, buf, None)?; + let (body_reply, 
buf_reply, rfds) = + node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?; + if rfds.is_some() { + return error_code(VhostUserError::InvalidMessage); + } else if body_reply.size == 0 { + return error_code(VhostUserError::SlaveInternalError); + } else if body_reply.size != body.size + || body_reply.size as usize != buf.len() + || body_reply.offset != body.offset + { + return error_code(VhostUserError::InvalidMessage); + } + + Ok((body_reply, buf_reply)) + } + + fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()> { + if buf.len() > MAX_MSG_SIZE { + return error_code(VhostUserError::InvalidParam); + } + let body = VhostUserConfig::new(offset, buf.len() as u32, flags); + if !body.is_valid() { + return error_code(VhostUserError::InvalidParam); + } + + let mut node = self.node(); + // depends on VhostUserProtocolFeatures::CONFIG + node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; + + let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()> { + let mut node = self.node(); + node.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?; + let fds = [fd.as_raw_fd()]; + let hdr = node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn get_inflight_fd( + &mut self, + inflight: &VhostUserInflight, + ) -> Result<(VhostUserInflight, File)> { + let mut node = self.node(); + node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; + + let hdr = node.send_request_with_body(MasterReq::GET_INFLIGHT_FD, inflight, None)?; + let (inflight, files) = node.recv_reply_with_files::<VhostUserInflight>(&hdr)?; + + match take_single_file(files) { + Some(file) => Ok((inflight, file)), + None => error_code(VhostUserError::IncorrectFds), + } + } + + fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: 
RawFd) -> Result<()> { + let mut node = self.node(); + node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; + + if inflight.mmap_size == 0 || inflight.num_queues == 0 || inflight.queue_size == 0 || fd < 0 + { + return error_code(VhostUserError::InvalidParam); + } + + let hdr = node.send_request_with_body(MasterReq::SET_INFLIGHT_FD, inflight, Some(&[fd]))?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn get_max_mem_slots(&mut self) -> Result<u64> { + let mut node = self.node(); + node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; + + let hdr = node.send_request_header(MasterReq::GET_MAX_MEM_SLOTS, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + + Ok(val.value) + } + + fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> { + let mut node = self.node(); + node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; + if region.memory_size == 0 || region.mmap_handle < 0 { + return error_code(VhostUserError::InvalidParam); + } + + let body = region.to_single_region(); + let fds = [region.mmap_handle]; + let hdr = node.send_request_with_body(MasterReq::ADD_MEM_REG, &body, Some(&fds))?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> { + let mut node = self.node(); + node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; + if region.memory_size == 0 { + return error_code(VhostUserError::InvalidParam); + } + + let body = region.to_single_region(); + let hdr = node.send_request_with_body(MasterReq::REM_MEM_REG, &body, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } +} + +impl AsRawFd for Master { + fn as_raw_fd(&self) -> RawFd { + let node = self.node(); + node.main_sock.as_raw_fd() + } +} + +/// Context object to pass guest memory configuration to VhostUserMaster::set_mem_table(). 
+struct VhostUserMemoryContext { + regions: VhostUserMemoryPayload, + fds: Vec<RawFd>, +} + +impl VhostUserMemoryContext { + /// Create a context object. + pub fn new() -> Self { + VhostUserMemoryContext { + regions: VhostUserMemoryPayload::new(), + fds: Vec::new(), + } + } + + /// Append a user memory region and corresponding RawFd into the context object. + pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawFd) { + self.regions.push(*region); + self.fds.push(fd); + } +} + +struct MasterInternal { + // Used to send requests to the slave. + main_sock: Endpoint<MasterReq>, + // Cached virtio features from the slave. + virtio_features: u64, + // Cached acked virtio features from the driver. + acked_virtio_features: u64, + // Cached vhost-user protocol features from the slave. + protocol_features: u64, + // Cached vhost-user protocol features. + acked_protocol_features: u64, + // Cached vhost-user protocol features are ready to use. + protocol_features_ready: bool, + // Cached maxinum number of queues supported from the slave. + max_queue_num: u64, + // Internal flag to mark failure state. + error: Option<i32>, + // List of header flags. 
+ hdr_flags: VhostUserHeaderFlag, +} + +impl MasterInternal { + fn send_request_header( + &mut self, + code: MasterReq, + fds: Option<&[RawFd]>, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + self.check_state()?; + let hdr = self.new_request_header(code, 0); + self.main_sock.send_header(&hdr, fds)?; + Ok(hdr) + } + + fn send_request_with_body<T: ByteValued>( + &mut self, + code: MasterReq, + msg: &T, + fds: Option<&[RawFd]>, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let hdr = self.new_request_header(code, mem::size_of::<T>() as u32); + self.main_sock.send_message(&hdr, msg, fds)?; + Ok(hdr) + } + + fn send_request_with_payload<T: ByteValued>( + &mut self, + code: MasterReq, + msg: &T, + payload: &[u8], + fds: Option<&[RawFd]>, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + let len = mem::size_of::<T>() + payload.len(); + if len > MAX_MSG_SIZE { + return Err(VhostUserError::InvalidParam); + } + if let Some(fd_arr) = fds { + if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES { + return Err(VhostUserError::InvalidParam); + } + } + self.check_state()?; + + let hdr = self.new_request_header(code, len as u32); + self.main_sock + .send_message_with_payload(&hdr, msg, payload, fds)?; + Ok(hdr) + } + + fn send_fd_for_vring( + &mut self, + code: MasterReq, + queue_index: usize, + fd: RawFd, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + if queue_index as u64 >= self.max_queue_num { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + // Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. + // This flag is set when there is no file descriptor in the ancillary data. This signals + // that polling will be used instead of waiting for the call. 
+ let msg = VhostUserU64::new(queue_index as u64); + let hdr = self.new_request_header(code, mem::size_of::<VhostUserU64>() as u32); + self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?; + Ok(hdr) + } + + fn recv_reply<T: ByteValued + Sized + VhostUserMsgValidator>( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + ) -> VhostUserResult<T> { + if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let (reply, body, rfds) = self.main_sock.recv_body::<T>()?; + if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() { + return Err(VhostUserError::InvalidMessage); + } + Ok(body) + } + + fn recv_reply_with_files<T: ByteValued + Sized + VhostUserMsgValidator>( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + ) -> VhostUserResult<(T, Option<Vec<File>>)> { + if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let (reply, body, files) = self.main_sock.recv_body::<T>()?; + if !reply.is_reply_for(hdr) || files.is_none() || !body.is_valid() { + return Err(VhostUserError::InvalidMessage); + } + Ok((body, files)) + } + + fn recv_reply_with_payload<T: ByteValued + Sized + VhostUserMsgValidator>( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<File>>)> { + if mem::size_of::<T>() > MAX_MSG_SIZE + || hdr.get_size() as usize <= mem::size_of::<T>() + || hdr.get_size() as usize > MAX_MSG_SIZE + || hdr.is_reply() + { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()]; + let (reply, body, bytes, files) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?; + if !reply.is_reply_for(hdr) + || reply.get_size() as usize != mem::size_of::<T>() + bytes + || files.is_some() + || !body.is_valid() + || bytes != buf.len() + { + return 
Err(VhostUserError::InvalidMessage); + } + + Ok((body, buf, files)) + } + + fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> { + if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() == 0 + || !hdr.is_need_reply() + { + return Ok(()); + } + self.check_state()?; + + let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?; + if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() { + return Err(VhostUserError::InvalidMessage); + } + if body.value != 0 { + return Err(VhostUserError::SlaveInternalError); + } + Ok(()) + } + + fn check_feature(&self, feat: VhostUserVirtioFeatures) -> VhostUserResult<()> { + if self.virtio_features & feat.bits() != 0 { + Ok(()) + } else { + Err(VhostUserError::InactiveFeature(feat)) + } + } + + fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> VhostUserResult<()> { + if self.acked_protocol_features & feat.bits() != 0 { + Ok(()) + } else { + Err(VhostUserError::InactiveOperation(feat)) + } + } + + fn check_state(&self) -> VhostUserResult<()> { + match self.error { + Some(e) => Err(VhostUserError::SocketBroken( + std::io::Error::from_raw_os_error(e), + )), + None => Ok(()), + } + } + + #[inline] + fn new_request_header(&self, request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> { + VhostUserMsgHeader::new(request, self.hdr_flags.bits() | 0x1, size) + } +} + +#[cfg(test)] +mod tests { + use super::super::connection::Listener; + use super::*; + use vmm_sys_util::rand::rand_alphanumerics; + + use std::path::PathBuf; + + fn temp_path() -> PathBuf { + PathBuf::from(format!( + "/tmp/vhost_test_{}", + rand_alphanumerics(8).to_str().unwrap() + )) + } + + fn create_pair<P: AsRef<Path>>(path: P) -> (Master, Endpoint<MasterReq>) { + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let master = Master::connect(path, 2).unwrap(); + let slave = listener.accept().unwrap().unwrap(); + (master, 
Endpoint::from_stream(slave)) + } + + #[test] + fn create_master() { + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + + let master = Master::connect(&path, 1).unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap()); + + assert!(master.as_raw_fd() > 0); + // Send two messages continuously + master.set_owner().unwrap(); + master.reset_owner().unwrap(); + + let (hdr, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_size(), 0); + assert_eq!(hdr.get_version(), 0x1); + assert!(rfds.is_none()); + + let (hdr, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code().unwrap(), MasterReq::RESET_OWNER); + assert_eq!(hdr.get_size(), 0); + assert_eq!(hdr.get_version(), 0x1); + assert!(rfds.is_none()); + } + + #[test] + fn test_create_failure() { + let path = temp_path(); + let _ = Listener::new(&path, true).unwrap(); + let _ = Listener::new(&path, false).is_err(); + assert!(Master::connect(&path, 1).is_err()); + + let listener = Listener::new(&path, true).unwrap(); + assert!(Listener::new(&path, false).is_err()); + listener.set_nonblocking(true).unwrap(); + + let _master = Master::connect(&path, 1).unwrap(); + let _slave = listener.accept().unwrap().unwrap(); + } + + #[test] + fn test_features() { + let path = temp_path(); + let (master, mut peer) = create_pair(path); + + master.set_owner().unwrap(); + let (hdr, rfds) = peer.recv_header().unwrap(); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_size(), 0); + assert_eq!(hdr.get_version(), 0x1); + assert!(rfds.is_none()); + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(0x15); + peer.send_message(&hdr, &msg, None).unwrap(); + let features = master.get_features().unwrap(); + assert_eq!(features, 0x15u64); + let (_hdr, rfds) = 
peer.recv_header().unwrap(); + assert!(rfds.is_none()); + + let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(0x15); + peer.send_message(&hdr, &msg, None).unwrap(); + master.set_features(0x15).unwrap(); + let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); + assert!(rfds.is_none()); + let val = msg.value; + assert_eq!(val, 0x15); + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = 0x15u32; + peer.send_message(&hdr, &msg, None).unwrap(); + assert!(master.get_features().is_err()); + } + + #[test] + fn test_protocol_features() { + let path = temp_path(); + let (mut master, mut peer) = create_pair(path); + + master.set_owner().unwrap(); + let (hdr, rfds) = peer.recv_header().unwrap(); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER); + assert!(rfds.is_none()); + + assert!(master.get_protocol_features().is_err()); + assert!(master + .set_protocol_features(VhostUserProtocolFeatures::all()) + .is_err()); + + let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(vfeatures); + peer.send_message(&hdr, &msg, None).unwrap(); + let features = master.get_features().unwrap(); + assert_eq!(features, vfeatures); + let (_hdr, rfds) = peer.recv_header().unwrap(); + assert!(rfds.is_none()); + + master.set_features(vfeatures).unwrap(); + let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); + assert!(rfds.is_none()); + let val = msg.value; + assert_eq!(val, vfeatures); + + let pfeatures = VhostUserProtocolFeatures::all(); + let hdr = VhostUserMsgHeader::new(MasterReq::GET_PROTOCOL_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(pfeatures.bits()); + peer.send_message(&hdr, &msg, None).unwrap(); + let features = master.get_protocol_features().unwrap(); + assert_eq!(features, pfeatures); + let (_hdr, rfds) = peer.recv_header().unwrap(); + 
assert!(rfds.is_none()); + + master.set_protocol_features(pfeatures).unwrap(); + let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); + assert!(rfds.is_none()); + let val = msg.value; + assert_eq!(val, pfeatures.bits()); + + let hdr = VhostUserMsgHeader::new(MasterReq::SET_PROTOCOL_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(pfeatures.bits()); + peer.send_message(&hdr, &msg, None).unwrap(); + assert!(master.get_protocol_features().is_err()); + } + + #[test] + fn test_master_set_config_negative() { + let path = temp_path(); + let (mut master, _peer) = create_pair(path); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + + { + let mut node = master.node(); + node.virtio_features = 0xffff_ffff; + node.acked_virtio_features = 0xffff_ffff; + node.protocol_features = 0xffff_ffff; + node.acked_protocol_features = 0xffff_ffff; + } + + master + .set_config(0, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap(); + master + .set_config( + VHOST_USER_CONFIG_SIZE, + VhostUserConfigFlags::WRITABLE, + &buf[0..4], + ) + .unwrap_err(); + master + .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + master + .set_config( + 0x100, + // SAFETY: This is a negative test, so we are setting unexpected flags. 
+ unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) }, + &buf[0..4], + ) + .unwrap_err(); + master + .set_config(VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE, &buf) + .unwrap_err(); + master + .set_config(VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE, &[]) + .unwrap_err(); + } + + fn create_pair2() -> (Master, Endpoint<MasterReq>) { + let path = temp_path(); + let (master, peer) = create_pair(path); + + { + let mut node = master.node(); + node.virtio_features = 0xffff_ffff; + node.acked_virtio_features = 0xffff_ffff; + node.protocol_features = 0xffff_ffff; + node.acked_protocol_features = 0xffff_ffff; + } + + (master, peer) + } + + #[test] + fn test_master_get_config_negative0() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + hdr.set_code(MasterReq::GET_FEATURES); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + hdr.set_code(MasterReq::GET_CONFIG); + } + + #[test] + fn test_master_get_config_negative1() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + hdr.set_reply(false); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + 
.get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative2() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + } + + #[test] + fn test_master_get_config_negative3() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = 0; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative4() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = 0x101; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative5() { + let (mut master, mut peer) = create_pair2(); + let buf = 
vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = (MAX_MSG_SIZE + 1) as u32; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative6() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.size = 6; + peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_maset_set_mem_table_failure() { + let (master, _peer) = create_pair2(); + + master.set_mem_table(&[]).unwrap_err(); + let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1]; + master.set_mem_table(&tables).unwrap_err(); + } +} diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs new file mode 100644 index 0000000..c9c528b --- /dev/null +++ b/src/vhost_user/master_req_handler.rs @@ -0,0 +1,466 @@ +// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use std::fs::File; +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::sync::{Arc, Mutex}; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error, HandlerResult, Result}; + +/// Define services provided by masters for the slave communication channel. +/// +/// The vhost-user specification defines a slave communication channel, by which slaves could +/// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided +/// by masters, and it's used both on the master side and slave side. +/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy +/// service requests to masters. The [Slave] is an example stub forwarder. +/// - on the master side, the [MasterReqHandler] will forward service requests to a handler +/// implementing [VhostUserMasterReqHandler]. +/// +/// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance +/// for multi-threading. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [Slave]: struct.Slave.html +pub trait VhostUserMasterReqHandler { + /// Handle device configuration change notifications. + fn handle_config_change(&self) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs map file requests. + fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs unmap file requests. + fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs sync file requests. 
+ fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs file IO requests. + fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawFd); +} + +/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +pub trait VhostUserMasterReqHandlerMut { + /// Handle device configuration change notifications. + fn handle_config_change(&mut self) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs map file requests. + fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs unmap file requests. + fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs sync file requests. + fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs file IO requests. 
+ fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); +} + +impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { + fn handle_config_change(&self) -> HandlerResult<u64> { + self.lock().unwrap().handle_config_change() + } + + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_map(fs, fd) + } + + fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_unmap(fs) + } + + fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_sync(fs) + } + + fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_io(fs, fd) + } +} + +/// Server to handle service requests from slaves from the slave communication channel. +/// +/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from +/// slaves on the slave communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserMasterReqHandler] to do the real work. +/// +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +pub struct MasterReqHandler<S: VhostUserMasterReqHandler> { + // underlying Unix domain socket for communication + sub_sock: Endpoint<SlaveReq>, + tx_sock: UnixStream, + // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. 
+ reply_ack_negotiated: bool, + // the VirtIO backend device object + backend: Arc<S>, + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { + /// Create a server to handle service requests from slaves on the slave communication channel. + /// + /// This opens a pair of connected anonymous sockets to form the slave communication channel. + /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by + /// [VhostUserMaster::set_slave_request_fd()]. + /// + /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd + pub fn new(backend: Arc<S>) -> Result<Self> { + let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?; + + Ok(MasterReqHandler { + sub_sock: Endpoint::<SlaveReq>::from_stream(rx), + tx_sock: tx, + reply_ack_negotiated: false, + backend, + error: None, + }) + } + + /// Get the socket fd for the slave to communication with the master. + /// + /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()]. + /// + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd + pub fn get_tx_raw_fd(&self) -> RawFd { + self.tx_sock.as_raw_fd() + } + + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. + /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&mut self, enable: bool) { + self.reply_ack_negotiated = enable; + } + + /// Mark endpoint as failed or in normal state. 
+ pub fn set_failed(&mut self, error: i32) { + if error == 0 { + self.error = None; + } else { + self.error = Some(error); + } + } + + /// Main entrance to server slave request from the slave communication channel. + /// + /// The caller needs to: + /// - serialize calls to this function + /// - decide what to do when errer happens + /// - optional recover from failure + pub fn handle_request(&mut self) -> Result<u64> { + // Return error if the endpoint is already in failed state. + self.check_state()?; + + // The underlying communication channel is a Unix domain socket in + // stream mode, and recvmsg() is a little tricky here. To successfully + // receive attached file descriptors, we need to receive messages and + // corresponding attached file descriptors in this way: + // . recv messsage header and optional attached file + // . validate message header + // . recv optional message body and payload according size field in + // message header + // . validate message body and optional payload + let (hdr, files) = self.sub_sock.recv_header()?; + self.check_attached_files(&hdr, &files)?; + let (size, buf) = match hdr.get_size() { + 0 => (0, vec![0u8; 0]), + len => { + if len as usize > MAX_MSG_SIZE { + return Err(Error::InvalidMessage); + } + let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?; + if size2 != len as usize { + return Err(Error::InvalidMessage); + } + (size2, rbuf) + } + }; + + let res = match hdr.get_code() { + Ok(SlaveReq::CONFIG_CHANGE_MSG) => { + self.check_msg_size(&hdr, size, 0)?; + self.backend + .handle_config_change() + .map_err(Error::ReqHandlerError) + } + Ok(SlaveReq::FS_MAP) => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + // check_attached_files() has validated files + self.backend + .fs_slave_map(&msg, &files.unwrap()[0]) + .map_err(Error::ReqHandlerError) + } + Ok(SlaveReq::FS_UNMAP) => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + self.backend + 
.fs_slave_unmap(&msg) + .map_err(Error::ReqHandlerError) + } + Ok(SlaveReq::FS_SYNC) => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + self.backend + .fs_slave_sync(&msg) + .map_err(Error::ReqHandlerError) + } + Ok(SlaveReq::FS_IO) => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + // check_attached_files() has validated files + self.backend + .fs_slave_io(&msg, &files.unwrap()[0]) + .map_err(Error::ReqHandlerError) + } + _ => Err(Error::InvalidMessage), + }; + + self.send_ack_message(&hdr, &res)?; + + res + } + + fn check_state(&self) -> Result<()> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(()), + } + } + + fn check_msg_size( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + size: usize, + expected: usize, + ) -> Result<()> { + if hdr.get_size() as usize != expected + || hdr.is_reply() + || hdr.get_version() != 0x1 + || size != expected + { + return Err(Error::InvalidMessage); + } + Ok(()) + } + + fn check_attached_files( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + files: &Option<Vec<File>>, + ) -> Result<()> { + match hdr.get_code() { + Ok(SlaveReq::FS_MAP | SlaveReq::FS_IO) => { + // Expect a single file is passed. + match files { + Some(files) if files.len() == 1 => Ok(()), + _ => Err(Error::InvalidMessage), + } + } + _ if files.is_some() => Err(Error::InvalidMessage), + _ => Ok(()), + } + } + + fn extract_msg_body<T: Sized + VhostUserMsgValidator>( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + size: usize, + buf: &[u8], + ) -> Result<T> { + self.check_msg_size(hdr, size, mem::size_of::<T>())?; + // SAFETY: Safe because we checked that `buf` size is equal to T size. 
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + Ok(msg) + } + + fn new_reply_header<T: Sized>( + &self, + req: &VhostUserMsgHeader<SlaveReq>, + ) -> Result<VhostUserMsgHeader<SlaveReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::InvalidParam); + } + self.check_state()?; + Ok(VhostUserMsgHeader::new( + req.get_code()?, + VhostUserHeaderFlag::REPLY.bits(), + mem::size_of::<T>() as u32, + )) + } + + fn send_ack_message( + &mut self, + req: &VhostUserMsgHeader<SlaveReq>, + res: &Result<u64>, + ) -> Result<()> { + if self.reply_ack_negotiated && req.is_need_reply() { + let hdr = self.new_reply_header::<VhostUserU64>(req)?; + let def_err = libc::EINVAL; + let val = match res { + Ok(n) => *n, + Err(e) => match e { + Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() { + Some(rawerr) => -rawerr as u64, + None => -def_err as u64, + }, + _ => -def_err as u64, + }, + }; + let msg = VhostUserU64::new(val); + self.sub_sock.send_message(&hdr, &msg, None)?; + } + Ok(()) + } +} + +impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> { + fn as_raw_fd(&self) -> RawFd { + self.sub_sock.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "vhost-user-slave")] + use crate::vhost_user::Slave; + #[cfg(feature = "vhost-user-slave")] + use std::os::unix::io::FromRawFd; + + struct MockMasterReqHandler {} + + impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { + /// Handle virtio-fs map file requests from the slave. + fn fs_slave_map( + &mut self, + _fs: &VhostUserFSSlaveMsg, + _fd: &dyn AsRawFd, + ) -> HandlerResult<u64> { + Ok(0) + } + + /// Handle virtio-fs unmap file requests from the slave. 
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + } + + #[test] + fn test_new_master_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + assert!(handler.get_tx_raw_fd() >= 0); + assert!(handler.as_raw_fd() >= 0); + handler.check_state().unwrap(); + + assert_eq!(handler.error, None); + handler.set_failed(libc::EAGAIN); + assert_eq!(handler.error, Some(libc::EAGAIN)); + handler.check_state().unwrap_err(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + // SAFETY: Safe because `handler` contains valid fds, and we are + // checking if `dup` returns a valid fd. + let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + // SAFETY: Safe because we checked if fd is valid. + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let slave = Slave::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + slave + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd) + .unwrap(); + // When REPLY_ACK has not been negotiated, the master has no way to detect failure from + // slave side. + slave + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler_with_ack() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + handler.set_reply_ack_flag(true); + + // SAFETY: Safe because `handler` contains valid fds, and we are + // checking if `dup` returns a valid fd. 
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + // SAFETY: Safe because we checked if fd is valid. + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let slave = Slave::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + slave.set_reply_ack_flag(true); + slave + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd) + .unwrap(); + slave + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } +} diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs new file mode 100644 index 0000000..bbd8eb9 --- /dev/null +++ b/src/vhost_user/message.rs @@ -0,0 +1,1403 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Define communication messages for the vhost-user protocol. +//! +//! For message definition, please refer to the [vhost-user spec](https://qemu.readthedocs.io/en/latest/interop/vhost-user.html). + +#![allow(dead_code)] +#![allow(non_camel_case_types)] +#![allow(clippy::upper_case_acronyms)] + +use std::fmt::Debug; +use std::fs::File; +use std::io; +use std::marker::PhantomData; +use std::ops::Deref; + +use vm_memory::{mmap::NewBitmap, ByteValued, Error as MmapError, FileOffset, MmapRegion}; + +#[cfg(feature = "xen")] +use vm_memory::{GuestAddress, MmapRange, MmapXenFlags}; + +use super::{Error, Result}; +use crate::VringConfigData; + +/// The vhost-user specification uses a field of u32 to store message length. +/// On the other hand, preallocated buffers are needed to receive messages from the Unix domain +/// socket. To preallocating a 4GB buffer for each vhost-user message is really just an overhead. +/// Among all defined vhost-user messages, only the VhostUserConfig and VhostUserMemory has variable +/// message size. 
For the VhostUserConfig, a maximum size of 4K is enough because the user +/// configuration space for virtio devices is (4K - 0x100) bytes at most. For the VhostUserMemory, +/// 4K should be enough too because it can support 255 memory regions at most. +pub const MAX_MSG_SIZE: usize = 0x1000; + +/// The VhostUserMemory message has variable message size and variable number of attached file +/// descriptors. Each user memory region entry in the message payload occupies 32 bytes, +/// so setting maximum number of attached file descriptors based on the maximum message size. +/// But rust only implements Default and AsMut traits for arrays with 0 - 32 entries, so further +/// reduce the maximum number... +// pub const MAX_ATTACHED_FD_ENTRIES: usize = (MAX_MSG_SIZE - 8) / 32; +pub const MAX_ATTACHED_FD_ENTRIES: usize = 32; + +/// Starting position (inclusion) of the device configuration space in virtio devices. +pub const VHOST_USER_CONFIG_OFFSET: u32 = 0x100; + +/// Ending position (exclusion) of the device configuration space in virtio devices. +pub const VHOST_USER_CONFIG_SIZE: u32 = 0x1000; + +/// Maximum number of vrings supported. +pub const VHOST_USER_MAX_VRINGS: u64 = 0x8000u64; + +pub(super) trait Req: + Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Send + Sync + Into<u32> +{ + fn is_valid(value: u32) -> bool; +} + +/// Type of requests sending from masters to slaves. +#[repr(u32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum MasterReq { + /// Null operation. + NOOP = 0, + /// Get from the underlying vhost implementation the features bit mask. + GET_FEATURES = 1, + /// Enable features in the underlying vhost implementation using a bit mask. + SET_FEATURES = 2, + /// Set the current Master as an owner of the session. + SET_OWNER = 3, + /// No longer used. + RESET_OWNER = 4, + /// Set the memory map regions on the slave so it can translate the vring addresses. + SET_MEM_TABLE = 5, + /// Set logging shared memory space. 
+ SET_LOG_BASE = 6, + /// Set the logging file descriptor, which is passed as ancillary data. + SET_LOG_FD = 7, + /// Set the size of the queue. + SET_VRING_NUM = 8, + /// Set the addresses of the different aspects of the vring. + SET_VRING_ADDR = 9, + /// Set the base offset in the available vring. + SET_VRING_BASE = 10, + /// Get the available vring base offset. + GET_VRING_BASE = 11, + /// Set the event file descriptor for adding buffers to the vring. + SET_VRING_KICK = 12, + /// Set the event file descriptor to signal when buffers are used. + SET_VRING_CALL = 13, + /// Set the event file descriptor to signal when error occurs. + SET_VRING_ERR = 14, + /// Get the protocol feature bit mask from the underlying vhost implementation. + GET_PROTOCOL_FEATURES = 15, + /// Enable protocol features in the underlying vhost implementation. + SET_PROTOCOL_FEATURES = 16, + /// Query how many queues the backend supports. + GET_QUEUE_NUM = 17, + /// Signal slave to enable or disable corresponding vring. + SET_VRING_ENABLE = 18, + /// Ask vhost user backend to broadcast a fake RARP to notify the migration is terminated + /// for guest that does not support GUEST_ANNOUNCE. + SEND_RARP = 19, + /// Set host MTU value exposed to the guest. + NET_SET_MTU = 20, + /// Set the socket file descriptor for slave initiated requests. + SET_SLAVE_REQ_FD = 21, + /// Send IOTLB messages with struct vhost_iotlb_msg as payload. + IOTLB_MSG = 22, + /// Set the endianness of a VQ for legacy devices. + SET_VRING_ENDIAN = 23, + /// Fetch the contents of the virtio device configuration space. + GET_CONFIG = 24, + /// Change the contents of the virtio device configuration space. + SET_CONFIG = 25, + /// Create a session for crypto operation. + CREATE_CRYPTO_SESSION = 26, + /// Close a session for crypto operation. + CLOSE_CRYPTO_SESSION = 27, + /// Advise slave that a migration with postcopy enabled is underway. 
+ POSTCOPY_ADVISE = 28, + /// Advise slave that a transition to postcopy mode has happened. + POSTCOPY_LISTEN = 29, + /// Advise that postcopy migration has now completed. + POSTCOPY_END = 30, + /// Get a shared buffer from slave. + GET_INFLIGHT_FD = 31, + /// Send the shared inflight buffer back to slave. + SET_INFLIGHT_FD = 32, + /// Sets the GPU protocol socket file descriptor. + GPU_SET_SOCKET = 33, + /// Ask the vhost user backend to disable all rings and reset all internal + /// device state to the initial state. + RESET_DEVICE = 34, + /// Indicate that a buffer was added to the vring instead of signalling it + /// using the vring’s kick file descriptor. + VRING_KICK = 35, + /// Return a u64 payload containing the maximum number of memory slots. + GET_MAX_MEM_SLOTS = 36, + /// Update the memory tables by adding the region described. + ADD_MEM_REG = 37, + /// Update the memory tables by removing the region described. + REM_MEM_REG = 38, + /// Notify the backend with updated device status as defined in the VIRTIO + /// specification. + SET_STATUS = 39, + /// Query the backend for its device status as defined in the VIRTIO + /// specification. + GET_STATUS = 40, + /// Upper bound of valid commands. + MAX_CMD = 41, +} + +impl From<MasterReq> for u32 { + fn from(req: MasterReq) -> u32 { + req as u32 + } +} + +impl Req for MasterReq { + fn is_valid(value: u32) -> bool { + (value > MasterReq::NOOP as u32) && (value < MasterReq::MAX_CMD as u32) + } +} + +/// Type of requests sending from slaves to masters. +#[repr(u32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum SlaveReq { + /// Null operation. + NOOP = 0, + /// Send IOTLB messages with struct vhost_iotlb_msg as payload. + IOTLB_MSG = 1, + /// Notify that the virtio device's configuration space has changed. + CONFIG_CHANGE_MSG = 2, + /// Set host notifier for a specified queue. + VRING_HOST_NOTIFIER_MSG = 3, + /// Indicate that a buffer was used from the vring. 
+ VRING_CALL = 4, + /// Indicate that an error occurred on the specific vring. + VRING_ERR = 5, + /// Virtio-fs draft: map file content into the window. + FS_MAP = 6, + /// Virtio-fs draft: unmap file content from the window. + FS_UNMAP = 7, + /// Virtio-fs draft: sync file content. + FS_SYNC = 8, + /// Virtio-fs draft: perform a read/write from an fd directly to GPA. + FS_IO = 9, + /// Upper bound of valid commands. + MAX_CMD = 10, +} + +impl From<SlaveReq> for u32 { + fn from(req: SlaveReq) -> u32 { + req as u32 + } +} + +impl Req for SlaveReq { + fn is_valid(value: u32) -> bool { + (value > SlaveReq::NOOP as u32) && (value < SlaveReq::MAX_CMD as u32) + } +} + +/// Vhost message Validator. +pub trait VhostUserMsgValidator { + /// Validate message syntax only. + /// It doesn't validate message semantics such as protocol version number and dependency + /// on feature flags etc. + fn is_valid(&self) -> bool { + true + } +} + +// Bit mask for common message flags. +bitflags! { + /// Common message flags for vhost-user requests and replies. + pub struct VhostUserHeaderFlag: u32 { + /// Bits[0..2] is message version number. + const VERSION = 0x3; + /// Mark message as reply. + const REPLY = 0x4; + /// Sender anticipates a reply message from the peer. + const NEED_REPLY = 0x8; + /// All valid bits. + const ALL_FLAGS = 0xc; + /// All reserved bits. + const RESERVED_BITS = !0xf; + } +} + +/// Common message header for vhost-user requests and replies. +/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the +/// machine native byte order. 
+#[repr(packed)] +#[derive(Copy)] +pub(super) struct VhostUserMsgHeader<R: Req> { + request: u32, + flags: u32, + size: u32, + _r: PhantomData<R>, +} + +impl<R: Req> Debug for VhostUserMsgHeader<R> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("VhostUserMsgHeader") + .field("request", &{ self.request }) + .field("flags", &{ self.flags }) + .field("size", &{ self.size }) + .finish() + } +} + +impl<R: Req> Clone for VhostUserMsgHeader<R> { + fn clone(&self) -> VhostUserMsgHeader<R> { + *self + } +} + +impl<R: Req> PartialEq for VhostUserMsgHeader<R> { + fn eq(&self, other: &Self) -> bool { + self.request == other.request && self.flags == other.flags && self.size == other.size + } +} + +impl<R: Req> VhostUserMsgHeader<R> { + /// Create a new instance of `VhostUserMsgHeader`. + pub fn new(request: R, flags: u32, size: u32) -> Self { + // Default to protocol version 1 + let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1; + VhostUserMsgHeader { + request: request.into(), + flags: fl, + size, + _r: PhantomData, + } + } + + /// Get message type. + pub fn get_code(&self) -> Result<R> { + if R::is_valid(self.request) { + // SAFETY: It's safe because R is marked as repr(u32), and the value is valid. + Ok(unsafe { std::mem::transmute_copy::<u32, R>(&{ self.request }) }) + } else { + Err(Error::InvalidMessage) + } + } + + /// Set message type. + pub fn set_code(&mut self, request: R) { + self.request = request.into(); + } + + /// Get message version number. + pub fn get_version(&self) -> u32 { + self.flags & 0x3 + } + + /// Set message version number. + pub fn set_version(&mut self, ver: u32) { + self.flags &= !0x3; + self.flags |= ver & 0x3; + } + + /// Check whether it's a reply message. + pub fn is_reply(&self) -> bool { + (self.flags & VhostUserHeaderFlag::REPLY.bits()) != 0 + } + + /// Mark message as reply. 
+ pub fn set_reply(&mut self, is_reply: bool) { + if is_reply { + self.flags |= VhostUserHeaderFlag::REPLY.bits(); + } else { + self.flags &= !VhostUserHeaderFlag::REPLY.bits(); + } + } + + /// Check whether reply for this message is requested. + pub fn is_need_reply(&self) -> bool { + (self.flags & VhostUserHeaderFlag::NEED_REPLY.bits()) != 0 + } + + /// Mark that reply for this message is needed. + pub fn set_need_reply(&mut self, need_reply: bool) { + if need_reply { + self.flags |= VhostUserHeaderFlag::NEED_REPLY.bits(); + } else { + self.flags &= !VhostUserHeaderFlag::NEED_REPLY.bits(); + } + } + + /// Check whether it's the reply message for the request `req`. + pub fn is_reply_for(&self, req: &VhostUserMsgHeader<R>) -> bool { + if let (Ok(code1), Ok(code2)) = (self.get_code(), req.get_code()) { + self.is_reply() && !req.is_reply() && code1 == code2 + } else { + false + } + } + + /// Get message size. + pub fn get_size(&self) -> u32 { + self.size + } + + /// Set message size. + pub fn set_size(&mut self, size: u32) { + self.size = size; + } +} + +impl<R: Req> Default for VhostUserMsgHeader<R> { + fn default() -> Self { + VhostUserMsgHeader { + request: 0, + flags: 0x1, + size: 0, + _r: PhantomData, + } + } +} + +// SAFETY: Safe because all fields of VhostUserMsgHeader are POD. +unsafe impl<R: Req> ByteValued for VhostUserMsgHeader<R> {} + +impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if self.get_code().is_err() { + return false; + } else if self.size as usize > MAX_MSG_SIZE { + return false; + } else if self.get_version() != 0x1 { + return false; + } else if (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) != 0 { + return false; + } + true + } +} + +// Bit mask for transport specific flags in VirtIO feature set defined by vhost-user. +bitflags! { + /// Transport specific flags in VirtIO feature set defined by vhost-user. 
+ pub struct VhostUserVirtioFeatures: u64 { + /// Feature flag for the protocol feature. + const PROTOCOL_FEATURES = 0x4000_0000; + } +} + +// Bit mask for vhost-user protocol feature flags. +bitflags! { + /// Vhost-user protocol feature flags. + pub struct VhostUserProtocolFeatures: u64 { + /// Support multiple queues. + const MQ = 0x0000_0001; + /// Support logging through shared memory fd. + const LOG_SHMFD = 0x0000_0002; + /// Support broadcasting fake RARP packet. + const RARP = 0x0000_0004; + /// Support sending reply messages for requests with NEED_REPLY flag set. + const REPLY_ACK = 0x0000_0008; + /// Support setting MTU for virtio-net devices. + const MTU = 0x0000_0010; + /// Allow the slave to send requests to the master by an optional communication channel. + const SLAVE_REQ = 0x0000_0020; + /// Support setting slave endian by SET_VRING_ENDIAN. + const CROSS_ENDIAN = 0x0000_0040; + /// Support crypto operations. + const CRYPTO_SESSION = 0x0000_0080; + /// Support sending userfault_fd from slaves to masters. + const PAGEFAULT = 0x0000_0100; + /// Support Virtio device configuration. + const CONFIG = 0x0000_0200; + /// Allow the slave to send fds (at most 8 descriptors in each message) to the master. + const SLAVE_SEND_FD = 0x0000_0400; + /// Allow the slave to register a host notifier. + const HOST_NOTIFIER = 0x0000_0800; + /// Support inflight shmfd. + const INFLIGHT_SHMFD = 0x0000_1000; + /// Support resetting the device. + const RESET_DEVICE = 0x0000_2000; + /// Support inband notifications. + const INBAND_NOTIFICATIONS = 0x0000_4000; + /// Support configuring memory slots. + const CONFIGURE_MEM_SLOTS = 0x0000_8000; + /// Support reporting status. + const STATUS = 0x0001_0000; + /// Support Xen mmap. + const XEN_MMAP = 0x0002_0000; + } +} + +/// A generic message to encapsulate a 64-bit value. +#[repr(packed)] +#[derive(Copy, Clone, Default)] +pub struct VhostUserU64 { + /// The encapsulated 64-bit common value. 
+ pub value: u64, +} + +impl VhostUserU64 { + /// Create a new instance. + pub fn new(value: u64) -> Self { + VhostUserU64 { value } + } +} + +// SAFETY: Safe because all fields of VhostUserU64 are POD. +unsafe impl ByteValued for VhostUserU64 {} + +impl VhostUserMsgValidator for VhostUserU64 {} + +/// Memory region descriptor for the SET_MEM_TABLE request. +#[repr(packed)] +#[derive(Copy, Clone, Default)] +pub struct VhostUserMemory { + /// Number of memory regions in the payload. + pub num_regions: u32, + /// Padding for alignment. + pub padding1: u32, +} + +impl VhostUserMemory { + /// Create a new instance. + pub fn new(cnt: u32) -> Self { + VhostUserMemory { + num_regions: cnt, + padding1: 0, + } + } +} + +// SAFETY: Safe because all fields of VhostUserMemory are POD. +unsafe impl ByteValued for VhostUserMemory {} + +impl VhostUserMsgValidator for VhostUserMemory { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if self.padding1 != 0 { + return false; + } else if self.num_regions == 0 || self.num_regions > MAX_ATTACHED_FD_ENTRIES as u32 { + return false; + } + true + } +} + +/// Memory region descriptors as payload for the SET_MEM_TABLE request. +#[repr(packed)] +#[derive(Default, Clone, Copy)] +pub struct VhostUserMemoryRegion { + /// Guest physical address of the memory region. + pub guest_phys_addr: u64, + /// Size of the memory region. + pub memory_size: u64, + /// Virtual address in the current process. + pub user_addr: u64, + /// Offset where region starts in the mapped memory. + pub mmap_offset: u64, + + #[cfg(feature = "xen")] + /// Xen specific flags. + pub xen_mmap_flags: u32, + + #[cfg(feature = "xen")] + /// Xen specific data. 
+ pub xen_mmap_data: u32,
+}
+
+impl VhostUserMemoryRegion {
+ fn is_valid_common(&self) -> bool {
+ self.memory_size != 0
+ && self.guest_phys_addr.checked_add(self.memory_size).is_some()
+ && self.user_addr.checked_add(self.memory_size).is_some()
+ && self.mmap_offset.checked_add(self.memory_size).is_some()
+ }
+}
+
+#[cfg(not(feature = "xen"))]
+impl VhostUserMemoryRegion {
+ /// Create a new instance.
+ pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self {
+ VhostUserMemoryRegion {
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ }
+ }
+
+ /// Creates mmap region from Self.
+ pub fn mmap_region<B: NewBitmap>(&self, file: File) -> Result<MmapRegion<B>> {
+ MmapRegion::<B>::from_file(
+ FileOffset::new(file, self.mmap_offset),
+ self.memory_size as usize,
+ )
+ .map_err(MmapError::MmapRegion)
+ .map_err(|e| Error::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)))
+ }
+
+ fn is_valid(&self) -> bool {
+ self.is_valid_common()
+ }
+}
+
+#[cfg(feature = "xen")]
+impl VhostUserMemoryRegion {
+ /// Create a new instance.
+ pub fn with_xen(
+ guest_phys_addr: u64,
+ memory_size: u64,
+ user_addr: u64,
+ mmap_offset: u64,
+ xen_mmap_flags: u32,
+ xen_mmap_data: u32,
+ ) -> Self {
+ VhostUserMemoryRegion {
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ xen_mmap_flags,
+ xen_mmap_data,
+ }
+ }
+
+ /// Creates mmap region from Self.
+ pub fn mmap_region<B: NewBitmap>(&self, file: File) -> Result<MmapRegion<B>> {
+ let range = MmapRange::new(
+ self.memory_size as usize,
+ Some(FileOffset::new(file, self.mmap_offset)),
+ GuestAddress(self.guest_phys_addr),
+ self.xen_mmap_flags,
+ self.xen_mmap_data,
+ );
+
+ MmapRegion::<B>::from_range(range)
+ .map_err(MmapError::MmapRegion)
+ .map_err(|e| Error::ReqHandlerError(io::Error::new(io::ErrorKind::Other, e)))
+ }
+
+ fn is_valid(&self) -> bool {
+ if !self.is_valid_common() {
+ false
+ } else {
+ // Only one of FOREIGN or GRANT should be set. 
+ match MmapXenFlags::from_bits(self.xen_mmap_flags) { + Some(flags) => flags.is_valid(), + None => false, + } + } + } +} + +impl VhostUserMsgValidator for VhostUserMemoryRegion { + fn is_valid(&self) -> bool { + self.is_valid() + } +} + +/// Payload of the VhostUserMemory message. +pub type VhostUserMemoryPayload = Vec<VhostUserMemoryRegion>; + +/// Single memory region descriptor as payload for ADD_MEM_REG and REM_MEM_REG +/// requests. +#[repr(C)] +#[derive(Default, Clone, Copy)] +pub struct VhostUserSingleMemoryRegion { + /// Padding for correct alignment + padding: u64, + /// General memory region + region: VhostUserMemoryRegion, +} + +impl Deref for VhostUserSingleMemoryRegion { + type Target = VhostUserMemoryRegion; + + fn deref(&self) -> &VhostUserMemoryRegion { + &self.region + } +} + +#[cfg(not(feature = "xen"))] +impl VhostUserSingleMemoryRegion { + /// Create a new instance. + pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self { + VhostUserSingleMemoryRegion { + padding: 0, + region: VhostUserMemoryRegion::new( + guest_phys_addr, + memory_size, + user_addr, + mmap_offset, + ), + } + } +} + +#[cfg(feature = "xen")] +impl VhostUserSingleMemoryRegion { + /// Create a new instance. + pub fn new( + guest_phys_addr: u64, + memory_size: u64, + user_addr: u64, + mmap_offset: u64, + xen_mmap_flags: u32, + xen_mmap_data: u32, + ) -> Self { + VhostUserSingleMemoryRegion { + padding: 0, + region: VhostUserMemoryRegion::with_xen( + guest_phys_addr, + memory_size, + user_addr, + mmap_offset, + xen_mmap_flags, + xen_mmap_data, + ), + } + } +} + +// SAFETY: Safe because all fields of VhostUserSingleMemoryRegion are POD. +unsafe impl ByteValued for VhostUserSingleMemoryRegion {} +impl VhostUserMsgValidator for VhostUserSingleMemoryRegion {} + +/// Vring state descriptor. +#[repr(packed)] +#[derive(Copy, Clone, Default)] +pub struct VhostUserVringState { + /// Vring index. 
+ pub index: u32, + /// A common 32bit value to encapsulate vring state etc. + pub num: u32, +} + +impl VhostUserVringState { + /// Create a new instance. + pub fn new(index: u32, num: u32) -> Self { + VhostUserVringState { index, num } + } +} + +// SAFETY: Safe because all fields of VhostUserVringState are POD. +unsafe impl ByteValued for VhostUserVringState {} + +impl VhostUserMsgValidator for VhostUserVringState {} + +// Bit mask for vring address flags. +bitflags! { + /// Flags for vring address. + pub struct VhostUserVringAddrFlags: u32 { + /// Support log of vring operations. + /// Modifications to "used" vring should be logged. + const VHOST_VRING_F_LOG = 0x1; + } +} + +/// Vring address descriptor. +#[repr(packed)] +#[derive(Copy, Clone, Default)] +pub struct VhostUserVringAddr { + /// Vring index. + pub index: u32, + /// Vring flags defined by VhostUserVringAddrFlags. + pub flags: u32, + /// Ring address of the vring descriptor table. + pub descriptor: u64, + /// Ring address of the vring used ring. + pub used: u64, + /// Ring address of the vring available ring. + pub available: u64, + /// Guest address for logging. + pub log: u64, +} + +impl VhostUserVringAddr { + /// Create a new instance. + pub fn new( + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Self { + VhostUserVringAddr { + index, + flags: flags.bits(), + descriptor, + used, + available, + log, + } + } + + /// Create a new instance from `VringConfigData`. 
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_conversion))] + pub fn from_config_data(index: u32, config_data: &VringConfigData) -> Self { + let log_addr = config_data.log_addr.unwrap_or(0); + VhostUserVringAddr { + index, + flags: config_data.flags, + descriptor: config_data.desc_table_addr, + used: config_data.used_ring_addr, + available: config_data.avail_ring_addr, + log: log_addr, + } + } +} + +// SAFETY: Safe because all fields of VhostUserVringAddr are POD. +unsafe impl ByteValued for VhostUserVringAddr {} + +impl VhostUserMsgValidator for VhostUserVringAddr { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if (self.flags & !VhostUserVringAddrFlags::all().bits()) != 0 { + return false; + } else if self.descriptor & 0xf != 0 { + return false; + } else if self.available & 0x1 != 0 { + return false; + } else if self.used & 0x3 != 0 { + return false; + } + true + } +} + +// Bit mask for the vhost-user device configuration message. +bitflags! { + /// Flags for the device configuration message. + pub struct VhostUserConfigFlags: u32 { + /// Vhost master messages used for writeable fields. + const WRITABLE = 0x1; + /// Vhost master messages used for live migration. + const LIVE_MIGRATION = 0x2; + } +} + +/// Message to read/write device configuration space. +#[repr(packed)] +#[derive(Copy, Clone, Default)] +pub struct VhostUserConfig { + /// Offset of virtio device's configuration space. + pub offset: u32, + /// Configuration space access size in bytes. + pub size: u32, + /// Flags for the device configuration operation. + pub flags: u32, +} + +impl VhostUserConfig { + /// Create a new instance. + pub fn new(offset: u32, size: u32, flags: VhostUserConfigFlags) -> Self { + VhostUserConfig { + offset, + size, + flags: flags.bits(), + } + } +} + +// SAFETY: Safe because all fields of VhostUserConfig are POD. 
+unsafe impl ByteValued for VhostUserConfig {} + +impl VhostUserMsgValidator for VhostUserConfig { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + let end_addr = match self.size.checked_add(self.offset) { + Some(addr) => addr, + None => return false, + }; + if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 { + return false; + } else if self.size == 0 || end_addr > VHOST_USER_CONFIG_SIZE { + return false; + } + true + } +} + +/// Payload for the VhostUserConfig message. +pub type VhostUserConfigPayload = Vec<u8>; + +/// Single memory region descriptor as payload for ADD_MEM_REG and REM_MEM_REG +/// requests. +#[repr(C)] +#[derive(Copy, Clone, Default)] +pub struct VhostUserInflight { + /// Size of the area to track inflight I/O. + pub mmap_size: u64, + /// Offset of this area from the start of the supplied file descriptor. + pub mmap_offset: u64, + /// Number of virtqueues. + pub num_queues: u16, + /// Size of virtqueues. + pub queue_size: u16, +} + +impl VhostUserInflight { + /// Create a new instance. + pub fn new(mmap_size: u64, mmap_offset: u64, num_queues: u16, queue_size: u16) -> Self { + VhostUserInflight { + mmap_size, + mmap_offset, + num_queues, + queue_size, + } + } +} + +// SAFETY: Safe because all fields of VhostUserInflight are POD. +unsafe impl ByteValued for VhostUserInflight {} + +impl VhostUserMsgValidator for VhostUserInflight { + fn is_valid(&self) -> bool { + if self.num_queues == 0 || self.queue_size == 0 { + return false; + } + true + } +} + +/// Single memory region descriptor as payload for SET_LOG_BASE request. +#[repr(C)] +#[derive(Copy, Clone, Default)] +pub struct VhostUserLog { + /// Size of the area to log dirty pages. + pub mmap_size: u64, + /// Offset of this area from the start of the supplied file descriptor. + pub mmap_offset: u64, +} + +impl VhostUserLog { + /// Create a new instance. 
+ pub fn new(mmap_size: u64, mmap_offset: u64) -> Self { + VhostUserLog { + mmap_size, + mmap_offset, + } + } +} + +// SAFETY: Safe because all fields of VhostUserLog are POD. +unsafe impl ByteValued for VhostUserLog {} + +impl VhostUserMsgValidator for VhostUserLog { + fn is_valid(&self) -> bool { + if self.mmap_size == 0 || self.mmap_offset.checked_add(self.mmap_size).is_none() { + return false; + } + true + } +} + +/* + * TODO: support dirty log, live migration and IOTLB operations. +#[repr(packed)] +pub struct VhostUserVringArea { + pub index: u32, + pub flags: u32, + pub size: u64, + pub offset: u64, +} + +#[repr(packed)] +pub struct VhostUserLog { + pub size: u64, + pub offset: u64, +} + +#[repr(packed)] +pub struct VhostUserIotlb { + pub iova: u64, + pub size: u64, + pub user_addr: u64, + pub permission: u8, + pub optype: u8, +} +*/ + +// Bit mask for flags in virtio-fs slave messages +bitflags! { + #[derive(Default)] + /// Flags for virtio-fs slave messages. + pub struct VhostUserFSSlaveMsgFlags: u64 { + /// Empty permission. + const EMPTY = 0x0; + /// Read permission. + const MAP_R = 0x1; + /// Write permission. + const MAP_W = 0x2; + } +} + +/// Max entries in one virtio-fs slave request. +pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8; + +/// Slave request message to update the MMIO window. +#[repr(packed)] +#[derive(Copy, Clone, Default)] +pub struct VhostUserFSSlaveMsg { + /// File offset. + pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Offset into the DAX window. + pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Size of region to map. + pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Flags for the mmap operation + pub flags: [VhostUserFSSlaveMsgFlags; VHOST_USER_FS_SLAVE_ENTRIES], +} + +// SAFETY: Safe because all fields of VhostUserFSSlaveMsg are POD. 
+unsafe impl ByteValued for VhostUserFSSlaveMsg {} + +impl VhostUserMsgValidator for VhostUserFSSlaveMsg { + fn is_valid(&self) -> bool { + for i in 0..VHOST_USER_FS_SLAVE_ENTRIES { + if ({ self.flags[i] }.bits() & !VhostUserFSSlaveMsgFlags::all().bits()) != 0 + || self.fd_offset[i].checked_add(self.len[i]).is_none() + || self.cache_offset[i].checked_add(self.len[i]).is_none() + { + return false; + } + } + true + } +} + +/// Inflight I/O descriptor state for split virtqueues +#[repr(packed)] +#[derive(Clone, Copy, Default)] +pub struct DescStateSplit { + /// Indicate whether this descriptor (only head) is inflight or not. + pub inflight: u8, + /// Padding + padding: [u8; 5], + /// List of last batch of used descriptors, only when batching is used for submitting + pub next: u16, + /// Preserve order of fetching available descriptors, only for head descriptor + pub counter: u64, +} + +impl DescStateSplit { + /// New instance of DescStateSplit struct + pub fn new() -> Self { + Self::default() + } +} + +/// Inflight I/O queue region for split virtqueues +#[repr(packed)] +pub struct QueueRegionSplit { + /// Features flags of this region + pub features: u64, + /// Version of this region + pub version: u16, + /// Number of DescStateSplit entries + pub desc_num: u16, + /// List to track last batch of used descriptors + pub last_batch_head: u16, + /// Idx value of used ring + pub used_idx: u16, + /// Pointer to an array of DescStateSplit entries + pub desc: u64, +} + +impl QueueRegionSplit { + /// New instance of QueueRegionSplit struct + pub fn new(features: u64, queue_size: u16) -> Self { + QueueRegionSplit { + features, + version: 1, + desc_num: queue_size, + last_batch_head: 0, + used_idx: 0, + desc: 0, + } + } +} + +/// Inflight I/O descriptor state for packed virtqueues +#[repr(packed)] +#[derive(Clone, Copy, Default)] +pub struct DescStatePacked { + /// Indicate whether this descriptor (only head) is inflight or not. 
+ pub inflight: u8, + /// Padding + padding: u8, + /// Link to next free entry + pub next: u16, + /// Link to last entry of descriptor list, only for head + pub last: u16, + /// Length of descriptor list, only for head + pub num: u16, + /// Preserve order of fetching avail descriptors, only for head + pub counter: u64, + /// Buffer ID + pub id: u16, + /// Descriptor flags + pub flags: u16, + /// Buffer length + pub len: u32, + /// Buffer address + pub addr: u64, +} + +impl DescStatePacked { + /// New instance of DescStatePacked struct + pub fn new() -> Self { + Self::default() + } +} + +/// Inflight I/O queue region for packed virtqueues +#[repr(packed)] +pub struct QueueRegionPacked { + /// Features flags of this region + pub features: u64, + /// version of this region + pub version: u16, + /// size of descriptor state array + pub desc_num: u16, + /// head of free DescStatePacked entry list + pub free_head: u16, + /// old head of free DescStatePacked entry list + pub old_free_head: u16, + /// used idx of descriptor ring + pub used_idx: u16, + /// old used idx of descriptor ring + pub old_used_idx: u16, + /// device ring wrap counter + pub used_wrap_counter: u8, + /// old device ring wrap counter + pub old_used_wrap_counter: u8, + /// Padding + padding: [u8; 7], + /// Pointer to array tracking state of each descriptor from descriptor ring + pub desc: u64, +} + +impl QueueRegionPacked { + /// New instance of QueueRegionPacked struct + pub fn new(features: u64, queue_size: u16) -> Self { + QueueRegionPacked { + features, + version: 1, + desc_num: queue_size, + free_head: 0, + old_free_head: 0, + used_idx: 0, + old_used_idx: 0, + used_wrap_counter: 0, + old_used_wrap_counter: 0, + padding: [0; 7], + desc: 0, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem; + + #[cfg(feature = "xen")] + impl VhostUserMemoryRegion { + fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self { + Self::with_xen( + guest_phys_addr, + 
memory_size, + user_addr, + mmap_offset, + MmapXenFlags::FOREIGN.bits(), + 0, + ) + } + } + + #[test] + fn check_master_request_code() { + assert!(!MasterReq::is_valid(MasterReq::NOOP as _)); + assert!(!MasterReq::is_valid(MasterReq::MAX_CMD as _)); + assert!(MasterReq::MAX_CMD > MasterReq::NOOP); + let code = MasterReq::GET_FEATURES; + assert!(MasterReq::is_valid(code as _)); + assert_eq!(code, code.clone()); + assert!(!MasterReq::is_valid(10000)); + } + + #[test] + fn check_slave_request_code() { + assert!(!SlaveReq::is_valid(SlaveReq::NOOP as _)); + assert!(!SlaveReq::is_valid(SlaveReq::MAX_CMD as _)); + assert!(SlaveReq::MAX_CMD > SlaveReq::NOOP); + let code = SlaveReq::CONFIG_CHANGE_MSG; + assert!(SlaveReq::is_valid(code as _)); + assert_eq!(code, code.clone()); + assert!(!SlaveReq::is_valid(10000)); + } + + #[test] + fn msg_header_ops() { + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100); + assert_eq!(hdr.get_code().unwrap(), MasterReq::GET_FEATURES); + hdr.set_code(MasterReq::SET_FEATURES); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_FEATURES); + + assert_eq!(hdr.get_version(), 0x1); + + assert!(!hdr.is_reply()); + hdr.set_reply(true); + assert!(hdr.is_reply()); + hdr.set_reply(false); + + assert!(!hdr.is_need_reply()); + hdr.set_need_reply(true); + assert!(hdr.is_need_reply()); + hdr.set_need_reply(false); + + assert_eq!(hdr.get_size(), 0x100); + hdr.set_size(0x200); + assert_eq!(hdr.get_size(), 0x200); + + assert!(!hdr.is_need_reply()); + assert!(!hdr.is_reply()); + assert_eq!(hdr.get_version(), 0x1); + + // Check message length + assert!(hdr.is_valid()); + hdr.set_size(0x2000); + assert!(!hdr.is_valid()); + hdr.set_size(0x100); + assert_eq!(hdr.get_size(), 0x100); + assert!(hdr.is_valid()); + hdr.set_size((MAX_MSG_SIZE - mem::size_of::<VhostUserMsgHeader<MasterReq>>()) as u32); + assert!(hdr.is_valid()); + hdr.set_size(0x0); + assert!(hdr.is_valid()); + + // Check version + hdr.set_version(0x0); + 
assert!(!hdr.is_valid());
+ hdr.set_version(0x2);
+ assert!(!hdr.is_valid());
+ hdr.set_version(0x1);
+ assert!(hdr.is_valid());
+
+ // Test Debug, Clone, PartialEq trait
+ assert_eq!(hdr, hdr.clone());
+ assert_eq!(hdr.clone().get_code().unwrap(), hdr.get_code().unwrap());
+ assert_eq!(format!("{:?}", hdr.clone()), format!("{:?}", hdr));
+ }
+
+ #[test]
+ fn test_vhost_user_message_u64() {
+ let val = VhostUserU64::default();
+ let val1 = VhostUserU64::new(0);
+
+ let a = val.value;
+ let b = val1.value;
+ assert_eq!(a, b);
+ let a = VhostUserU64::new(1).value;
+ assert_eq!(a, 1);
+ }
+
+ #[test]
+ fn check_user_memory() {
+ let mut msg = VhostUserMemory::new(1);
+ assert!(msg.is_valid());
+ msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
+ assert!(msg.is_valid());
+
+ msg.num_regions += 1;
+ assert!(!msg.is_valid());
+ msg.num_regions = 0xFFFFFFFF;
+ assert!(!msg.is_valid());
+ msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
+ msg.padding1 = 1;
+ assert!(!msg.is_valid());
+ }
+
+ #[test]
+ fn check_user_memory_region() {
+ let mut msg = VhostUserMemoryRegion::new(0, 0x1000, 0, 0);
+ assert!(msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFFEFFF;
+ assert!(msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFFF000;
+ assert!(!msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFF0000;
+ msg.memory_size = 0;
+ assert!(!msg.is_valid());
+ let a = msg.guest_phys_addr;
+ let b = msg.guest_phys_addr;
+ assert_eq!(a, b);
+
+ let msg = VhostUserMemoryRegion::default();
+ let a = msg.guest_phys_addr;
+ assert_eq!(a, 0);
+ let a = msg.memory_size;
+ assert_eq!(a, 0);
+ let a = msg.user_addr;
+ assert_eq!(a, 0);
+ let a = msg.mmap_offset;
+ assert_eq!(a, 0);
+ }
+
+ #[test]
+ fn test_vhost_user_state() {
+ let state = VhostUserVringState::new(5, 8);
+
+ let a = state.index;
+ assert_eq!(a, 5);
+ let a = state.num;
+ assert_eq!(a, 8);
+ assert!(state.is_valid());
+
+ let state = VhostUserVringState::default();
+ let a = state.index;
+ assert_eq!(a, 0);
+ let a = 
state.num; + assert_eq!(a, 0); + assert!(state.is_valid()); + } + + #[test] + fn test_vhost_user_addr() { + let mut addr = VhostUserVringAddr::new( + 2, + VhostUserVringAddrFlags::VHOST_VRING_F_LOG, + 0x1000, + 0x2000, + 0x3000, + 0x4000, + ); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert!(addr.is_valid()); + + addr.descriptor = 0x1001; + assert!(!addr.is_valid()); + addr.descriptor = 0x1000; + + addr.available = 0x3001; + assert!(!addr.is_valid()); + addr.available = 0x3000; + + addr.used = 0x2001; + assert!(!addr.is_valid()); + addr.used = 0x2000; + assert!(addr.is_valid()); + } + + #[test] + fn test_vhost_user_state_from_config() { + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + let addr = VhostUserVringAddr::from_config_data(2, &config); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert!(addr.is_valid()); + } + + #[test] + fn check_user_vring_addr() { + let mut msg = + VhostUserVringAddr::new(0, VhostUserVringAddrFlags::all(), 0x0, 0x0, 0x0, 0x0); + assert!(msg.is_valid()); + + msg.descriptor = 1; + assert!(!msg.is_valid()); + msg.descriptor = 0; + + msg.available = 1; + assert!(!msg.is_valid()); + msg.available = 0; + + msg.used = 1; + assert!(!msg.is_valid()); + msg.used = 0; + + msg.flags |= 
0x80000000; + assert!(!msg.is_valid()); + msg.flags &= !0x80000000; + } + + #[test] + fn check_user_config_msg() { + let mut msg = + VhostUserConfig::new(0, VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE); + + assert!(msg.is_valid()); + msg.size = 0; + assert!(!msg.is_valid()); + msg.size = 1; + assert!(msg.is_valid()); + msg.offset = u32::MAX; + assert!(!msg.is_valid()); + msg.offset = VHOST_USER_CONFIG_SIZE; + assert!(!msg.is_valid()); + msg.offset = VHOST_USER_CONFIG_SIZE - 1; + assert!(msg.is_valid()); + msg.size = 2; + assert!(!msg.is_valid()); + msg.size = 1; + msg.flags |= VhostUserConfigFlags::LIVE_MIGRATION.bits(); + assert!(msg.is_valid()); + msg.flags |= 0x4; + assert!(!msg.is_valid()); + } + + #[test] + fn test_vhost_user_fs_slave() { + let mut fs_slave = VhostUserFSSlaveMsg::default(); + + assert!(fs_slave.is_valid()); + + fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff; + fs_slave.len[0] = 0x1; + assert!(!fs_slave.is_valid()); + + assert_ne!( + VhostUserFSSlaveMsgFlags::MAP_R, + VhostUserFSSlaveMsgFlags::MAP_W + ); + assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0); + } +} diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs new file mode 100644 index 0000000..7df51f6 --- /dev/null +++ b/src/vhost_user/mod.rs @@ -0,0 +1,540 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! The protocol for vhost-user is based on the existing implementation of vhost for the Linux +//! Kernel. The protocol defines two sides of the communication, master and slave. Master is +//! the application that shares its virtqueues. Slave is the consumer of the virtqueues. +//! +//! The communication channel between the master and the slave includes two sub channels. One is +//! used to send requests from the master to the slave and optional replies from the slave to the +//! master. This sub channel is created on master startup by connecting to the slave service +//! endpoint. 
The other is used to send requests from the slave to the master and optional replies +//! from the master to the slave. This sub channel is created by the master issuing a +//! VHOST_USER_SET_SLAVE_REQ_FD request to the slave with an auxiliary file descriptor. +//! +//! Unix domain socket is used as the underlying communication channel because the master needs to +//! send file descriptors to the slave. +//! +//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an +//! equivalent ioctl to the kernel implementation. + +use std::fs::File; +use std::io::Error as IOError; + +pub mod message; +pub use self::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; + +mod connection; +pub use self::connection::Listener; + +#[cfg(feature = "vhost-user-master")] +mod master; +#[cfg(feature = "vhost-user-master")] +pub use self::master::{Master, VhostUserMaster}; +#[cfg(feature = "vhost-user")] +mod master_req_handler; +#[cfg(feature = "vhost-user")] +pub use self::master_req_handler::{ + MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut, +}; + +#[cfg(feature = "vhost-user-slave")] +mod slave; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave::SlaveListener; +#[cfg(feature = "vhost-user-slave")] +mod slave_req_handler; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave_req_handler::{ + SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut, +}; +#[cfg(feature = "vhost-user-slave")] +mod slave_req; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave_req::Slave; + +/// Errors for vhost-user operations +#[derive(Debug)] +pub enum Error { + /// Invalid parameters. + InvalidParam, + /// Invalid operation due to some reason + InvalidOperation(&'static str), + /// Unsupported operation due to missing feature + InactiveFeature(VhostUserVirtioFeatures), + /// Unsupported operations due to that the protocol feature hasn't been negotiated. 
+ InactiveOperation(VhostUserProtocolFeatures), + /// Invalid message format, flag or content. + InvalidMessage, + /// Only part of a message have been sent or received successfully + PartialMessage, + /// The peer disconnected from the socket. + Disconnected, + /// Message is too large + OversizedMsg, + /// Fd array in question is too big or too small + IncorrectFds, + /// Can't connect to peer. + SocketConnect(std::io::Error), + /// Generic socket errors. + SocketError(std::io::Error), + /// The socket is broken or has been closed. + SocketBroken(std::io::Error), + /// Should retry the socket operation again. + SocketRetry(std::io::Error), + /// Failure from the slave side. + SlaveInternalError, + /// Failure from the master side. + MasterInternalError, + /// Virtio/protocol features mismatch. + FeatureMismatch, + /// Error from request handler + ReqHandlerError(IOError), + /// memfd file creation error + MemFdCreateError, + /// File truncate error + FileTrucateError, + /// memfd file seal errors + MemFdSealError, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Error::InvalidParam => write!(f, "invalid parameters"), + Error::InvalidOperation(reason) => write!(f, "invalid operation: {}", reason), + Error::InactiveFeature(bits) => write!(f, "inactive feature: {}", bits.bits()), + Error::InactiveOperation(bits) => { + write!(f, "inactive protocol operation: {}", bits.bits()) + } + Error::InvalidMessage => write!(f, "invalid message"), + Error::PartialMessage => write!(f, "partial message"), + Error::Disconnected => write!(f, "peer disconnected"), + Error::OversizedMsg => write!(f, "oversized message"), + Error::IncorrectFds => write!(f, "wrong number of attached fds"), + Error::SocketError(e) => write!(f, "socket error: {}", e), + Error::SocketConnect(e) => write!(f, "can't connect to peer: {}", e), + Error::SocketBroken(e) => write!(f, "socket is broken: {}", e), + Error::SocketRetry(e) => 
write!(f, "temporary socket error: {}", e), + Error::SlaveInternalError => write!(f, "slave internal error"), + Error::MasterInternalError => write!(f, "Master internal error"), + Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"), + Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {}", e), + Error::MemFdCreateError => { + write!(f, "handler failed to allocate memfd during get_inflight_fd") + } + Error::FileTrucateError => { + write!(f, "handler failed to trucate memfd during get_inflight_fd") + } + Error::MemFdSealError => write!( + f, + "handler failed to apply seals to memfd during get_inflight_fd" + ), + } + } +} + +impl std::error::Error for Error {} + +impl Error { + /// Determine whether to rebuild the underline communication channel. + pub fn should_reconnect(&self) -> bool { + match *self { + // Should reconnect because it may be caused by temporary network errors. + Error::PartialMessage => true, + // Should reconnect because the underline socket is broken. + Error::SocketBroken(_) => true, + // Slave internal error, hope it recovers on reconnect. + Error::SlaveInternalError => true, + // Master internal error, hope it recovers on reconnect. + Error::MasterInternalError => true, + // Should just retry the IO operation instead of rebuilding the underline connection. + Error::SocketRetry(_) => false, + // Looks like the peer deliberately disconnected the socket. 
+ Error::Disconnected => false, + Error::InvalidParam | Error::InvalidOperation(_) => false, + Error::InactiveFeature(_) | Error::InactiveOperation(_) => false, + Error::InvalidMessage | Error::IncorrectFds | Error::OversizedMsg => false, + Error::SocketError(_) | Error::SocketConnect(_) => false, + Error::FeatureMismatch => false, + Error::ReqHandlerError(_) => false, + Error::MemFdCreateError | Error::FileTrucateError | Error::MemFdSealError => false, + } + } +} + +impl std::convert::From<vmm_sys_util::errno::Error> for Error { + /// Convert raw socket errors into meaningful vhost-user errors. + /// + /// The vmm_sys_util::errno::Error is a simple wrapper over the raw errno, which doesn't means + /// much to the vhost-user connection manager. So convert it into meaningful errors to simplify + /// the connection manager logic. + /// + /// # Return: + /// * - Error::SocketRetry: temporary error caused by signals or short of resources. + /// * - Error::SocketBroken: the underline socket is broken. + /// * - Error::SocketError: other socket related errors. + #[allow(unreachable_patterns)] // EWOULDBLOCK equals to EGAIN on linux + fn from(err: vmm_sys_util::errno::Error) -> Self { + match err.errno() { + // The socket is marked nonblocking and the requested operation would block. + libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)), + // The socket is marked nonblocking and the requested operation would block. + libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)), + // A signal occurred before any data was transmitted + libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)), + // The output queue for a network interface was full. This generally indicates + // that the interface has stopped sending, but may be caused by transient congestion. + libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)), + // No memory available. 
+ libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)), + // Connection reset by peer. + libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)), + // The local end has been shut down on a connection oriented socket. In this case the + // process will also receive a SIGPIPE unless MSG_NOSIGNAL is set. + libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)), + // Write permission is denied on the destination socket file, or search permission is + // denied for one of the directories the path prefix. + libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)), + // Catch all other errors + e => Error::SocketError(IOError::from_raw_os_error(e)), + } + } +} + +/// Result of vhost-user operations +pub type Result<T> = std::result::Result<T, Error>; + +/// Result of request handler. +pub type HandlerResult<T> = std::result::Result<T, IOError>; + +/// Utility function to take the first element from option of a vector of files. +/// Returns `None` if the vector contains no file or more than one file. 
/// Utility function to take the first element from option of a vector of files.
/// Returns `None` if the vector contains no file or more than one file.
pub(crate) fn take_single_file(files: Option<Vec<File>>) -> Option<File> {
    // Succeed only when exactly one file was attached. `remove(0)` on a
    // one-element vector is O(1), same as the original `swap_remove(0)`.
    match files {
        Some(mut fds) if fds.len() == 1 => Some(fds.remove(0)),
        _ => None,
    }
}
test_set_features() { + let mbar = Arc::new(Barrier::new(2)); + let sbar = mbar.clone(); + let path = temp_path(); + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = create_slave(path, slave_be.clone()); + + thread::spawn(move || { + slave.handle_request().unwrap(); + assert!(slave_be.lock().unwrap().owned); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_features, + VIRTIO_FEATURES & !0x1 + ); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_protocol_features, + VhostUserProtocolFeatures::all().bits() + ); + + sbar.wait(); + }); + + master.set_owner().unwrap(); + + // set virtio features + let features = master.get_features().unwrap(); + assert_eq!(features, VIRTIO_FEATURES); + master.set_features(VIRTIO_FEATURES & !0x1).unwrap(); + + // set vhost protocol features + let features = master.get_protocol_features().unwrap(); + assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); + master.set_protocol_features(features).unwrap(); + + mbar.wait(); + } + + #[test] + fn test_master_slave_process() { + let mbar = Arc::new(Barrier::new(2)); + let sbar = mbar.clone(); + let path = temp_path(); + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = create_slave(path, slave_be.clone()); + + thread::spawn(move || { + // set_own() + slave.handle_request().unwrap(); + assert!(slave_be.lock().unwrap().owned); + + // get/set_features() + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_features, + VIRTIO_FEATURES & !0x1 + ); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + + let mut features = VhostUserProtocolFeatures::all(); + + // Disable Xen mmap feature. 
+ if !cfg!(feature = "xen") { + features.remove(VhostUserProtocolFeatures::XEN_MMAP); + } + + assert_eq!( + slave_be.lock().unwrap().acked_protocol_features, + features.bits() + ); + + // get_inflight_fd() + slave.handle_request().unwrap(); + // set_inflight_fd() + slave.handle_request().unwrap(); + + // get_queue_num() + slave.handle_request().unwrap(); + + // set_mem_table() + slave.handle_request().unwrap(); + + // get/set_config() + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + + // set_slave_request_fd + slave.handle_request().unwrap(); + + // set_vring_enable + slave.handle_request().unwrap(); + + // set_log_base,set_log_fd() + slave.handle_request().unwrap_err(); + slave.handle_request().unwrap_err(); + + // set_vring_xxx + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + + // get_max_mem_slots() + slave.handle_request().unwrap(); + + // add_mem_region() + slave.handle_request().unwrap(); + + // remove_mem_region() + slave.handle_request().unwrap(); + + sbar.wait(); + }); + + master.set_owner().unwrap(); + + // set virtio features + let features = master.get_features().unwrap(); + assert_eq!(features, VIRTIO_FEATURES); + master.set_features(VIRTIO_FEATURES & !0x1).unwrap(); + + // set vhost protocol features + let mut features = master.get_protocol_features().unwrap(); + assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); + + // Disable Xen mmap feature. 
+ if !cfg!(feature = "xen") { + features.remove(VhostUserProtocolFeatures::XEN_MMAP); + } + + master.set_protocol_features(features).unwrap(); + + // Retrieve inflight I/O tracking information + let (inflight_info, inflight_file) = master + .get_inflight_fd(&VhostUserInflight { + num_queues: 2, + queue_size: 256, + ..Default::default() + }) + .unwrap(); + // Set the buffer back to the backend + master + .set_inflight_fd(&inflight_info, inflight_file.as_raw_fd()) + .unwrap(); + + let num = master.get_queue_num().unwrap(); + assert_eq!(num, 2); + + let eventfd = vmm_sys_util::eventfd::EventFd::new(0).unwrap(); + let mem = [VhostUserMemoryRegionInfo::new( + 0, + 0x10_0000, + 0, + 0, + eventfd.as_raw_fd(), + )]; + master.set_mem_table(&mem).unwrap(); + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8; 4]) + .unwrap(); + let buf = [0x0u8; 4]; + let (reply_body, reply_payload) = master + .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf) + .unwrap(); + let offset = reply_body.offset; + assert_eq!(offset, 0x100); + assert_eq!(&reply_payload, &[0xa5; 4]); + + master.set_slave_request_fd(&eventfd).unwrap(); + master.set_vring_enable(0, true).unwrap(); + + master + .set_log_base( + 0, + Some(VhostUserDirtyLogRegion { + mmap_size: 0x1000, + mmap_offset: 0, + mmap_handle: eventfd.as_raw_fd(), + }), + ) + .unwrap(); + master.set_log_fd(eventfd.as_raw_fd()).unwrap(); + + master.set_vring_num(0, 256).unwrap(); + master.set_vring_base(0, 0).unwrap(); + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(), + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + master.set_vring_addr(0, &config).unwrap(); + master.set_vring_call(0, &eventfd).unwrap(); + master.set_vring_kick(0, &eventfd).unwrap(); + master.set_vring_err(0, &eventfd).unwrap(); + + let max_mem_slots = master.get_max_mem_slots().unwrap(); + 
assert_eq!(max_mem_slots, 32); + + let region_file: File = TempFile::new().unwrap().into_file(); + let region = + VhostUserMemoryRegionInfo::new(0x10_0000, 0x10_0000, 0, 0, region_file.as_raw_fd()); + master.add_mem_region(®ion).unwrap(); + + master.remove_mem_region(®ion).unwrap(); + + mbar.wait(); + } + + #[test] + fn test_error_display() { + assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters"); + assert_eq!( + format!("{}", Error::InvalidOperation("reason")), + "invalid operation: reason" + ); + } + + #[test] + fn test_should_reconnect() { + assert!(Error::PartialMessage.should_reconnect()); + assert!(Error::SlaveInternalError.should_reconnect()); + assert!(Error::MasterInternalError.should_reconnect()); + assert!(!Error::InvalidParam.should_reconnect()); + assert!(!Error::InvalidOperation("reason").should_reconnect()); + assert!( + !Error::InactiveFeature(VhostUserVirtioFeatures::PROTOCOL_FEATURES).should_reconnect() + ); + assert!(!Error::InactiveOperation(VhostUserProtocolFeatures::all()).should_reconnect()); + assert!(!Error::InvalidMessage.should_reconnect()); + assert!(!Error::IncorrectFds.should_reconnect()); + assert!(!Error::OversizedMsg.should_reconnect()); + assert!(!Error::FeatureMismatch.should_reconnect()); + } + + #[test] + fn test_error_from_sys_util_error() { + let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN).into(); + if let Error::SocketRetry(e1) = e { + assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN); + } else { + panic!("invalid error code conversion!"); + } + } +} diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs new file mode 100644 index 0000000..fb65c41 --- /dev/null +++ b/src/vhost_user/slave.rs @@ -0,0 +1,86 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Traits and Structs for vhost-user slave. 
+ +use std::sync::Arc; + +use super::connection::{Endpoint, Listener}; +use super::message::*; +use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler}; + +/// Vhost-user slave side connection listener. +pub struct SlaveListener<S: VhostUserSlaveReqHandler> { + listener: Listener, + backend: Option<Arc<S>>, +} + +/// Sets up a listener for incoming master connections, and handles construction +/// of a Slave on success. +impl<S: VhostUserSlaveReqHandler> SlaveListener<S> { + /// Create a unix domain socket for incoming master connections. + pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> { + Ok(SlaveListener { + listener, + backend: Some(backend), + }) + } + + /// Accept an incoming connection from the master, returning Some(Slave) on + /// success, or None if the socket is nonblocking and no incoming connection + /// was detected + pub fn accept(&mut self) -> Result<Option<SlaveReqHandler<S>>> { + if let Some(fd) = self.listener.accept()? { + return Ok(Some(SlaveReqHandler::new( + Endpoint::<MasterReq>::from_stream(fd), + self.backend.take().unwrap(), + ))); + } + Ok(None) + } + + /// Change blocking status on the listener. 
+ pub fn set_nonblocking(&self, block: bool) -> Result<()> { + self.listener.set_nonblocking(block) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_listener_set_nonblocking() { + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = + Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap(); + let slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + } + + #[cfg(feature = "vhost-user-master")] + #[test] + fn test_slave_listener_accept() { + use super::super::Master; + + let path = "/tmp/vhost_user_lib_unit_test_slave_accept"; + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = Listener::new(path, true).unwrap(); + let mut slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + assert!(slave_listener.accept().unwrap().is_none()); + assert!(slave_listener.accept().unwrap().is_none()); + + let _master = Master::connect(path, 1).unwrap(); + let _slave = slave_listener.accept().unwrap().unwrap(); + } +} diff --git a/src/vhost_user/slave_req.rs b/src/vhost_user/slave_req.rs new file mode 100644 index 0000000..ade1e91 --- /dev/null +++ b/src/vhost_user/slave_req.rs @@ -0,0 +1,219 @@ +// Copyright (C) 2020 Alibaba Cloud. All rights reserved. 
// Shared, mutex-protected state behind a `Slave` handle: the socket to the
// master plus per-connection negotiation and failure state.
struct SlaveInternal {
    sock: Endpoint<SlaveReq>,

    // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
    reply_ack_negotiated: bool,

    // whether the endpoint has encountered any failure
    error: Option<i32>,
}

impl SlaveInternal {
    // Fail fast with `SocketBroken` if the endpoint was previously marked as
    // failed (see `Slave::set_failed()`); otherwise return Ok(0).
    fn check_state(&self) -> Result<u64> {
        match self.error {
            Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
            None => Ok(0),
        }
    }

    // Send one slave-to-master request and, when REPLY_ACK was negotiated,
    // wait for the master's acknowledgement.
    //
    // `request` selects the message type, `body` is the fixed-size payload,
    // and `fds` optionally attaches file descriptors to the message.
    fn send_message<T: ByteValued>(
        &mut self,
        request: SlaveReq,
        body: &T,
        fds: Option<&[RawFd]>,
    ) -> Result<u64> {
        self.check_state()?;

        let len = mem::size_of::<T>();
        let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32);
        // Request an explicit reply from the master only when the REPLY_ACK
        // protocol feature has been negotiated.
        if self.reply_ack_negotiated {
            hdr.set_need_reply(true);
        }
        self.sock.send_message(&hdr, body, fds)?;

        self.wait_for_ack(&hdr)
    }

    // Wait for the master's ack to `hdr`. A no-op (Ok(0)) unless REPLY_ACK
    // was negotiated. The reply must match the request header, carry no file
    // descriptors, and hold a zero (success) status; a non-zero status maps
    // to `MasterInternalError`.
    fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> {
        self.check_state()?;
        if !self.reply_ack_negotiated {
            return Ok(0);
        }

        let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?;
        if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
            return Err(Error::InvalidMessage);
        }
        if body.value != 0 {
            return Err(Error::MasterInternalError);
        }

        // Always 0 here: any non-zero status was rejected just above.
        Ok(body.value)
    }
}
+/// +/// [Slave]: struct.Slave.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +#[derive(Clone)] +pub struct Slave { + // underlying Unix domain socket for communication + node: Arc<Mutex<SlaveInternal>>, +} + +impl Slave { + fn new(ep: Endpoint<SlaveReq>) -> Self { + Slave { + node: Arc::new(Mutex::new(SlaveInternal { + sock: ep, + reply_ack_negotiated: false, + error: None, + })), + } + } + + fn node(&self) -> MutexGuard<SlaveInternal> { + self.node.lock().unwrap() + } + + fn send_message<T: ByteValued>( + &self, + request: SlaveReq, + body: &T, + fds: Option<&[RawFd]>, + ) -> io::Result<u64> { + self.node() + .send_message(request, body, fds) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e))) + } + + /// Create a new instance from a `UnixStream` object. + pub fn from_stream(sock: UnixStream) -> Self { + Self::new(Endpoint::<SlaveReq>::from_stream(sock)) + } + + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. + /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&self, enable: bool) { + self.node().reply_ack_negotiated = enable; + } + + /// Mark endpoint as failed with specified error code. + pub fn set_failed(&self, error: i32) { + self.node().error = Some(error); + } +} + +impl VhostUserMasterReqHandler for Slave { + /// Forward vhost-user-fs map file requests to the slave. + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { + self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd.as_raw_fd()])) + } + + /// Forward vhost-user-fs unmap file requests to the master. 
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.send_message(SlaveReq::FS_UNMAP, fs, None) + } +} + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + + #[test] + fn test_slave_req_set_failed() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let slave = Slave::from_stream(p1); + + assert!(slave.node().error.is_none()); + slave.set_failed(libc::EAGAIN); + assert_eq!(slave.node().error, Some(libc::EAGAIN)); + } + + #[test] + fn test_slave_req_send_failure() { + let (p1, p2) = UnixStream::pair().unwrap(); + let slave = Slave::from_stream(p1); + + slave.set_failed(libc::ECONNRESET); + slave + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &p2) + .unwrap_err(); + slave + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + slave.node().error = None; + } + + #[test] + fn test_slave_req_recv_negative() { + let (p1, p2) = UnixStream::pair().unwrap(); + let slave = Slave::from_stream(p1); + let mut master = Endpoint::<SlaveReq>::from_stream(p2); + + let len = mem::size_of::<VhostUserFSSlaveMsg>(); + let mut hdr = VhostUserMsgHeader::new( + SlaveReq::FS_MAP, + VhostUserHeaderFlag::REPLY.bits(), + len as u32, + ); + let body = VhostUserU64::new(0); + + master + .send_message(&hdr, &body, Some(&[master.as_raw_fd()])) + .unwrap(); + slave + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) + .unwrap(); + + slave.set_reply_ack_flag(true); + slave + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) + .unwrap_err(); + + hdr.set_code(SlaveReq::FS_UNMAP); + master.send_message(&hdr, &body, None).unwrap(); + slave + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) + .unwrap_err(); + hdr.set_code(SlaveReq::FS_MAP); + + let body = VhostUserU64::new(1); + master.send_message(&hdr, &body, None).unwrap(); + slave + .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) + .unwrap_err(); + + let body = VhostUserU64::new(0); + master.send_message(&hdr, &body, None).unwrap(); + slave + 
.fs_slave_map(&VhostUserFSSlaveMsg::default(), &master) + .unwrap(); + } +} diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs new file mode 100644 index 0000000..e998339 --- /dev/null +++ b/src/vhost_user/slave_req_handler.rs @@ -0,0 +1,833 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::fs::File; +use std::mem; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::slice; +use std::sync::{Arc, Mutex}; + +use vm_memory::ByteValued; + +use super::connection::Endpoint; +use super::message::*; +use super::slave_req::Slave; +use super::{take_single_file, Error, Result}; + +/// Services provided to the master by the slave with interior mutability. +/// +/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave. +/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler], +/// but without interior mutability. +/// The vhost-user specification defines a master communication channel, by which masters could +/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by +/// slaves, and it's used both on the master side and slave side. +/// +/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy +/// service requests to slaves. +/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler +/// implementing [VhostUserSlaveReqHandler]. +/// +/// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance +/// for multi-threading. 
+/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html +#[allow(missing_docs)] +pub trait VhostUserSlaveReqHandler { + fn set_owner(&self) -> Result<()>; + fn reset_owner(&self) -> Result<()>; + fn get_features(&self) -> Result<u64>; + fn set_features(&self, features: u64) -> Result<()>; + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; + fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()>; + fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; + fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>; + fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>; + + fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>; + fn set_protocol_features(&self, features: u64) -> Result<()>; + fn get_queue_num(&self) -> Result<u64>; + fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; + fn set_slave_req_fd(&self, _slave: Slave) {} + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>; + fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>; + fn get_max_mem_slots(&self) -> Result<u64>; + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; + fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>; +} + +/// Services 
provided to the master by the slave without interior mutability. +/// +/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait. +#[allow(missing_docs)] +pub trait VhostUserSlaveReqHandlerMut { + fn set_owner(&mut self) -> Result<()>; + fn reset_owner(&mut self) -> Result<()>; + fn get_features(&mut self) -> Result<u64>; + fn set_features(&mut self, features: u64) -> Result<()>; + fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; + fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; + fn set_vring_addr( + &mut self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()>; + fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; + fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; + fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>; + fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>; + + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; + fn set_protocol_features(&mut self, features: u64) -> Result<()>; + fn get_queue_num(&mut self) -> Result<u64>; + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>; + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + ) -> Result<Vec<u8>>; + fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; + fn set_slave_req_fd(&mut self, _slave: Slave) {} + fn get_inflight_fd( + &mut self, + inflight: &VhostUserInflight, + ) -> Result<(VhostUserInflight, File)>; + fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>; + fn get_max_mem_slots(&mut self) -> Result<u64>; + fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; + fn remove_mem_region(&mut self, 
region: &VhostUserSingleMemoryRegion) -> Result<()>; +} + +impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { + fn set_owner(&self) -> Result<()> { + self.lock().unwrap().set_owner() + } + + fn reset_owner(&self) -> Result<()> { + self.lock().unwrap().reset_owner() + } + + fn get_features(&self) -> Result<u64> { + self.lock().unwrap().get_features() + } + + fn set_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_features(features) + } + + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> { + self.lock().unwrap().set_mem_table(ctx, files) + } + + fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { + self.lock().unwrap().set_vring_num(index, num) + } + + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()> { + self.lock() + .unwrap() + .set_vring_addr(index, flags, descriptor, used, available, log) + } + + fn set_vring_base(&self, index: u32, base: u32) -> Result<()> { + self.lock().unwrap().set_vring_base(index, base) + } + + fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> { + self.lock().unwrap().get_vring_base(index) + } + + fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> { + self.lock().unwrap().set_vring_kick(index, fd) + } + + fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> { + self.lock().unwrap().set_vring_call(index, fd) + } + + fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> { + self.lock().unwrap().set_vring_err(index, fd) + } + + fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> { + self.lock().unwrap().get_protocol_features() + } + + fn set_protocol_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_protocol_features(features) + } + + fn get_queue_num(&self) -> Result<u64> { + self.lock().unwrap().get_queue_num() + } + + fn 
set_vring_enable(&self, index: u32, enable: bool) -> Result<()> { + self.lock().unwrap().set_vring_enable(index, enable) + } + + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> { + self.lock().unwrap().get_config(offset, size, flags) + } + + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> { + self.lock().unwrap().set_config(offset, buf, flags) + } + + fn set_slave_req_fd(&self, slave: Slave) { + self.lock().unwrap().set_slave_req_fd(slave) + } + + fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> { + self.lock().unwrap().get_inflight_fd(inflight) + } + + fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> { + self.lock().unwrap().set_inflight_fd(inflight, file) + } + + fn get_max_mem_slots(&self) -> Result<u64> { + self.lock().unwrap().get_max_mem_slots() + } + + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> { + self.lock().unwrap().add_mem_region(region, fd) + } + + fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> { + self.lock().unwrap().remove_mem_region(region) + } +} + +/// Server to handle service requests from masters from the master communication channel. +/// +/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from +/// masters on the master communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserSlaveReqHandler] to do the real work. +/// +/// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain +/// Socket, so it gets simpler to recover from disconnect. 
+/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html +pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> { + // underlying Unix domain socket for communication + main_sock: Endpoint<MasterReq>, + // the vhost-user backend device object + backend: Arc<S>, + + virtio_features: u64, + acked_virtio_features: u64, + protocol_features: VhostUserProtocolFeatures, + acked_protocol_features: u64, + + // sending ack for messages without payload + reply_ack_enabled: bool, + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { + /// Create a vhost-user slave endpoint. + pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self { + SlaveReqHandler { + main_sock, + backend, + virtio_features: 0, + acked_virtio_features: 0, + protocol_features: VhostUserProtocolFeatures::empty(), + acked_protocol_features: 0, + reply_ack_enabled: false, + error: None, + } + } + + fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> { + if self.acked_virtio_features & feat.bits() != 0 { + Ok(()) + } else { + Err(Error::InactiveFeature(feat)) + } + } + + fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> { + if self.acked_protocol_features & feat.bits() != 0 { + Ok(()) + } else { + Err(Error::InactiveOperation(feat)) + } + } + + /// Create a vhost-user slave endpoint from a connected socket. + pub fn from_stream(socket: UnixStream, backend: Arc<S>) -> Self { + Self::new(Endpoint::from_stream(socket), backend) + } + + /// Create a new vhost-user slave endpoint. 
+ /// + /// # Arguments + /// * - `path` - path of Unix domain socket listener to connect to + /// * - `backend` - handler for requests from the master to the slave + pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> { + Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend)) + } + + /// Mark endpoint as failed with specified error code. + pub fn set_failed(&mut self, error: i32) { + self.error = Some(error); + } + + /// Main entrance to server slave request from the slave communication channel. + /// + /// Receive and handle one incoming request message from the master. The caller needs to: + /// - serialize calls to this function + /// - decide what to do when error happens + /// - optional recover from failure + pub fn handle_request(&mut self) -> Result<()> { + // Return error if the endpoint is already in failed state. + self.check_state()?; + + // The underlying communication channel is a Unix domain socket in + // stream mode, and recvmsg() is a little tricky here. To successfully + // receive attached file descriptors, we need to receive messages and + // corresponding attached file descriptors in this way: + // . recv messsage header and optional attached file + // . validate message header + // . recv optional message body and payload according size field in + // message header + // . 
validate message body and optional payload + let (hdr, files) = self.main_sock.recv_header()?; + self.check_attached_files(&hdr, &files)?; + + let (size, buf) = match hdr.get_size() { + 0 => (0, vec![0u8; 0]), + len => { + let (size2, rbuf) = self.main_sock.recv_data(len as usize)?; + if size2 != len as usize { + return Err(Error::InvalidMessage); + } + (size2, rbuf) + } + }; + + match hdr.get_code() { + Ok(MasterReq::SET_OWNER) => { + self.check_request_size(&hdr, size, 0)?; + let res = self.backend.set_owner(); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::RESET_OWNER) => { + self.check_request_size(&hdr, size, 0)?; + let res = self.backend.reset_owner(); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::GET_FEATURES) => { + self.check_request_size(&hdr, size, 0)?; + let features = self.backend.get_features()?; + let msg = VhostUserU64::new(features); + self.send_reply_message(&hdr, &msg)?; + self.virtio_features = features; + self.update_reply_ack_flag(); + } + Ok(MasterReq::SET_FEATURES) => { + let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; + let res = self.backend.set_features(msg.value); + self.acked_virtio_features = msg.value; + self.update_reply_ack_flag(); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::SET_MEM_TABLE) => { + let res = self.set_mem_table(&hdr, size, &buf, files); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::SET_VRING_NUM) => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let res = self.backend.set_vring_num(msg.index, msg.num); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::SET_VRING_ADDR) => { + let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?; + let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { + Some(val) => val, + None => return Err(Error::InvalidMessage), + }; + let res = self.backend.set_vring_addr( + msg.index, + flags, + msg.descriptor, + msg.used, + msg.available, + msg.log, + 
); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::SET_VRING_BASE) => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let res = self.backend.set_vring_base(msg.index, msg.num); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::GET_VRING_BASE) => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let reply = self.backend.get_vring_base(msg.index)?; + self.send_reply_message(&hdr, &reply)?; + } + Ok(MasterReq::SET_VRING_CALL) => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_call(index, file); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::SET_VRING_KICK) => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_kick(index, file); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::SET_VRING_ERR) => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, file) = self.handle_vring_fd_request(&buf, files)?; + let res = self.backend.set_vring_err(index, file); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::GET_PROTOCOL_FEATURES) => { + self.check_request_size(&hdr, size, 0)?; + let features = self.backend.get_protocol_features()?; + + // Enable the `XEN_MMAP` protocol feature for backends if xen feature is enabled. 
+ #[cfg(feature = "xen")] + let features = features | VhostUserProtocolFeatures::XEN_MMAP; + + let msg = VhostUserU64::new(features.bits()); + self.send_reply_message(&hdr, &msg)?; + self.protocol_features = features; + self.update_reply_ack_flag(); + } + Ok(MasterReq::SET_PROTOCOL_FEATURES) => { + let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; + let res = self.backend.set_protocol_features(msg.value); + self.acked_protocol_features = msg.value; + self.update_reply_ack_flag(); + self.send_ack_message(&hdr, res)?; + + #[cfg(feature = "xen")] + self.check_proto_feature(VhostUserProtocolFeatures::XEN_MMAP)?; + } + Ok(MasterReq::GET_QUEUE_NUM) => { + self.check_proto_feature(VhostUserProtocolFeatures::MQ)?; + self.check_request_size(&hdr, size, 0)?; + let num = self.backend.get_queue_num()?; + let msg = VhostUserU64::new(num); + self.send_reply_message(&hdr, &msg)?; + } + Ok(MasterReq::SET_VRING_ENABLE) => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; + let enable = match msg.num { + 1 => true, + 0 => false, + _ => return Err(Error::InvalidParam), + }; + + let res = self.backend.set_vring_enable(msg.index, enable); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::GET_CONFIG) => { + self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + self.get_config(&hdr, &buf)?; + } + Ok(MasterReq::SET_CONFIG) => { + self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + let res = self.set_config(size, &buf); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::SET_SLAVE_REQ_FD) => { + self.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?; + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + let res = self.set_slave_req_fd(files); + self.send_ack_message(&hdr, res)?; + } + 
Ok(MasterReq::GET_INFLIGHT_FD) => { + self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; + + let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; + let (inflight, file) = self.backend.get_inflight_fd(&msg)?; + let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?; + self.main_sock + .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?; + } + Ok(MasterReq::SET_INFLIGHT_FD) => { + self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; + let file = take_single_file(files).ok_or(Error::IncorrectFds)?; + let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; + let res = self.backend.set_inflight_fd(&msg, file); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::GET_MAX_MEM_SLOTS) => { + self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; + self.check_request_size(&hdr, size, 0)?; + let num = self.backend.get_max_mem_slots()?; + let msg = VhostUserU64::new(num); + self.send_reply_message(&hdr, &msg)?; + } + Ok(MasterReq::ADD_MEM_REG) => { + self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; + let mut files = files.ok_or(Error::InvalidParam)?; + if files.len() != 1 { + return Err(Error::InvalidParam); + } + let msg = + self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; + let res = self.backend.add_mem_region(&msg, files.swap_remove(0)); + self.send_ack_message(&hdr, res)?; + } + Ok(MasterReq::REM_MEM_REG) => { + self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; + + let msg = + self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; + let res = self.backend.remove_mem_region(&msg); + self.send_ack_message(&hdr, res)?; + } + _ => { + return Err(Error::InvalidMessage); + } + } + Ok(()) + } + + fn set_mem_table( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &[u8], + files: Option<Vec<File>>, + ) -> Result<()> { + 
self.check_request_size(hdr, size, hdr.get_size() as usize)?; + + // check message size is consistent + let hdrsize = mem::size_of::<VhostUserMemory>(); + if size < hdrsize { + return Err(Error::InvalidMessage); + } + // SAFETY: Safe because we checked that `buf` size is at least that of + // VhostUserMemory. + let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() { + return Err(Error::InvalidMessage); + } + + // validate number of fds matching number of memory regions + let files = files.ok_or(Error::InvalidMessage)?; + if files.len() != msg.num_regions as usize { + return Err(Error::InvalidMessage); + } + + // Validate memory regions + // + // SAFETY: Safe because we checked that `buf` size is equal to that of + // VhostUserMemory, plus `msg.num_regions` elements of VhostUserMemoryRegion. + let regions = unsafe { + slice::from_raw_parts( + buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion, + msg.num_regions as usize, + ) + }; + for region in regions.iter() { + if !region.is_valid() { + return Err(Error::InvalidMessage); + } + } + + self.backend.set_mem_table(regions, files) + } + + fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { + let payload_offset = mem::size_of::<VhostUserConfig>(); + if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset { + return Err(Error::InvalidMessage); + } + // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserConfig. 
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + if buf.len() - payload_offset != msg.size as usize { + return Err(Error::InvalidMessage); + } + let flags = match VhostUserConfigFlags::from_bits(msg.flags) { + Some(val) => val, + None => return Err(Error::InvalidMessage), + }; + let res = self.backend.get_config(msg.offset, msg.size, flags); + + // vhost-user slave's payload size MUST match master's request + // on success, uses zero length of payload to indicate an error + // to vhost-user master. + match res { + Ok(ref buf) if buf.len() == msg.size as usize => { + let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags); + self.send_reply_with_payload(hdr, &reply, buf.as_slice())?; + } + Ok(_) => { + let reply = VhostUserConfig::new(msg.offset, 0, flags); + self.send_reply_message(hdr, &reply)?; + } + Err(_) => { + let reply = VhostUserConfig::new(msg.offset, 0, flags); + self.send_reply_message(hdr, &reply)?; + } + } + Ok(()) + } + + fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> { + if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() { + return Err(Error::InvalidMessage); + } + // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserConfig. 
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + if size - mem::size_of::<VhostUserConfig>() != msg.size as usize { + return Err(Error::InvalidMessage); + } + let flags = VhostUserConfigFlags::from_bits(msg.flags).ok_or(Error::InvalidMessage)?; + + self.backend + .set_config(msg.offset, &buf[mem::size_of::<VhostUserConfig>()..], flags) + } + + fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> { + let file = take_single_file(files).ok_or(Error::InvalidMessage)?; + // SAFETY: Safe because we have ownership of the files that were + // checked when received. We have to trust that they are Unix sockets + // since we have no way to check this. If not, it will fail later. + let sock = unsafe { UnixStream::from_raw_fd(file.into_raw_fd()) }; + let slave = Slave::from_stream(sock); + self.backend.set_slave_req_fd(slave); + Ok(()) + } + + fn handle_vring_fd_request( + &mut self, + buf: &[u8], + files: Option<Vec<File>>, + ) -> Result<(u8, Option<File>)> { + if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { + return Err(Error::InvalidMessage); + } + // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserU64. + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + + // Bits (0-7) of the payload contain the vring index. Bit 8 is the + // invalid FD flag. This bit is set when there is no file descriptor + // in the ancillary data. This signals that polling will be used + // instead of waiting for the call. + // If Bit 8 is unset, the data must contain a file descriptor. 
+ let has_fd = (msg.value & 0x100u64) == 0; + + let file = take_single_file(files); + + if has_fd && file.is_none() || !has_fd && file.is_some() { + return Err(Error::InvalidMessage); + } + + Ok((msg.value as u8, file)) + } + + fn check_state(&self) -> Result<()> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(()), + } + } + + fn check_request_size( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + expected: usize, + ) -> Result<()> { + if hdr.get_size() as usize != expected + || hdr.is_reply() + || hdr.get_version() != 0x1 + || size != expected + { + return Err(Error::InvalidMessage); + } + Ok(()) + } + + fn check_attached_files( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + files: &Option<Vec<File>>, + ) -> Result<()> { + match hdr.get_code() { + Ok( + MasterReq::SET_MEM_TABLE + | MasterReq::SET_VRING_CALL + | MasterReq::SET_VRING_KICK + | MasterReq::SET_VRING_ERR + | MasterReq::SET_LOG_BASE + | MasterReq::SET_LOG_FD + | MasterReq::SET_SLAVE_REQ_FD + | MasterReq::SET_INFLIGHT_FD + | MasterReq::ADD_MEM_REG, + ) => Ok(()), + _ if files.is_some() => Err(Error::InvalidMessage), + _ => Ok(()), + } + } + + fn extract_request_body<T: Sized + VhostUserMsgValidator>( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &[u8], + ) -> Result<T> { + self.check_request_size(hdr, size, mem::size_of::<T>())?; + // SAFETY: Safe because we checked that `buf` size is equal to T size. 
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + Ok(msg) + } + + fn update_reply_ack_flag(&mut self) { + let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let pflag = VhostUserProtocolFeatures::REPLY_ACK; + if (self.virtio_features & vflag) != 0 + && self.protocol_features.contains(pflag) + && (self.acked_protocol_features & pflag.bits()) != 0 + { + self.reply_ack_enabled = true; + } else { + self.reply_ack_enabled = false; + } + } + + fn new_reply_header<T: Sized>( + &self, + req: &VhostUserMsgHeader<MasterReq>, + payload_size: usize, + ) -> Result<VhostUserMsgHeader<MasterReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE + || payload_size > MAX_MSG_SIZE + || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE + { + return Err(Error::InvalidParam); + } + self.check_state()?; + Ok(VhostUserMsgHeader::new( + req.get_code()?, + VhostUserHeaderFlag::REPLY.bits(), + (mem::size_of::<T>() + payload_size) as u32, + )) + } + + fn send_ack_message( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + res: Result<()>, + ) -> Result<()> { + if self.reply_ack_enabled && req.is_need_reply() { + let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?; + let val = match res { + Ok(_) => 0, + Err(_) => 1, + }; + let msg = VhostUserU64::new(val); + self.main_sock.send_message(&hdr, &msg, None)?; + } + res + } + + fn send_reply_message<T: ByteValued>( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + msg: &T, + ) -> Result<()> { + let hdr = self.new_reply_header::<T>(req, 0)?; + self.main_sock.send_message(&hdr, msg, None)?; + Ok(()) + } + + fn send_reply_with_payload<T: ByteValued>( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + msg: &T, + payload: &[u8], + ) -> Result<()> { + let hdr = self.new_reply_header::<T>(req, payload.len())?; + self.main_sock + .send_message_with_payload(&hdr, msg, payload, None)?; + Ok(()) + } +} + +impl<S: VhostUserSlaveReqHandler> AsRawFd 
for SlaveReqHandler<S> { + fn as_raw_fd(&self) -> RawFd { + self.main_sock.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_req_handler_new() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let endpoint = Endpoint::<MasterReq>::from_stream(p1); + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let mut handler = SlaveReqHandler::new(endpoint, backend); + + handler.check_state().unwrap(); + handler.set_failed(libc::EAGAIN); + handler.check_state().unwrap_err(); + assert!(handler.as_raw_fd() >= 0); + } +} diff --git a/src/vring.rs b/src/vring.rs new file mode 100644 index 0000000..13e08ac --- /dev/null +++ b/src/vring.rs @@ -0,0 +1,585 @@ +// Copyright 2019 Intel Corporation. All Rights Reserved. +// Copyright 2021 Alibaba Cloud Computing. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 + +//! Struct to maintain state information and manipulate vhost-user queues. + +use std::fs::File; +use std::io; +use std::ops::{Deref, DerefMut}; +use std::os::unix::io::{FromRawFd, IntoRawFd}; +use std::result::Result; +use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}; + +use virtio_queue::{Error as VirtQueError, Queue, QueueT}; +use vm_memory::{GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap}; +use vmm_sys_util::eventfd::EventFd; + +/// Trait for objects returned by `VringT::get_ref()`. +pub trait VringStateGuard<'a, M: GuestAddressSpace> { + /// Type for guard returned by `VringT::get_ref()`. + type G: Deref<Target = VringState<M>>; +} + +/// Trait for objects returned by `VringT::get_mut()`. +pub trait VringStateMutGuard<'a, M: GuestAddressSpace> { + /// Type for guard returned by `VringT::get_mut()`. 
+ type G: DerefMut<Target = VringState<M>>; +} + +pub trait VringT<M: GuestAddressSpace>: + for<'a> VringStateGuard<'a, M> + for<'a> VringStateMutGuard<'a, M> +{ + /// Create a new instance of Vring. + fn new(mem: M, max_queue_size: u16) -> Result<Self, VirtQueError> + where + Self: Sized; + + /// Get an immutable reference to the kick event fd. + fn get_ref(&self) -> <Self as VringStateGuard<M>>::G; + + /// Get a mutable reference to the kick event fd. + fn get_mut(&self) -> <Self as VringStateMutGuard<M>>::G; + + /// Add an used descriptor into the used queue. + fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError>; + + /// Notify the vhost-user master that used descriptors have been put into the used queue. + fn signal_used_queue(&self) -> io::Result<()>; + + /// Enable event notification for queue. + fn enable_notification(&self) -> Result<bool, VirtQueError>; + + /// Disable event notification for queue. + fn disable_notification(&self) -> Result<(), VirtQueError>; + + /// Check whether a notification to the guest is needed. + fn needs_notification(&self) -> Result<bool, VirtQueError>; + + /// Set vring enabled state. + fn set_enabled(&self, enabled: bool); + + /// Set queue addresses for descriptor table, available ring and used ring. + fn set_queue_info( + &self, + desc_table: u64, + avail_ring: u64, + used_ring: u64, + ) -> Result<(), VirtQueError>; + + /// Get queue next avail head. + fn queue_next_avail(&self) -> u16; + + /// Set queue next avail head. + fn set_queue_next_avail(&self, base: u16); + + /// Set queue next used head. + fn set_queue_next_used(&self, idx: u16); + + /// Get queue next used head index from the guest memory. + fn queue_used_idx(&self) -> Result<u16, VirtQueError>; + + /// Set configured queue size. + fn set_queue_size(&self, num: u16); + + /// Enable/disable queue event index feature. + fn set_queue_event_idx(&self, enabled: bool); + + /// Set queue enabled state. 
+ fn set_queue_ready(&self, ready: bool); + + /// Set `EventFd` for kick. + fn set_kick(&self, file: Option<File>); + + /// Read event from the kick `EventFd`. + fn read_kick(&self) -> io::Result<bool>; + + /// Set `EventFd` for call. + fn set_call(&self, file: Option<File>); + + /// Set `EventFd` for err. + fn set_err(&self, file: Option<File>); +} + +/// Struct to maintain raw state information for a vhost-user queue. +/// +/// This struct maintains all information of a virito queue, and could be used as an `VringT` +/// object for single-threaded context. +pub struct VringState<M: GuestAddressSpace = GuestMemoryAtomic<GuestMemoryMmap>> { + queue: Queue, + kick: Option<EventFd>, + call: Option<EventFd>, + err: Option<EventFd>, + enabled: bool, + mem: M, +} + +impl<M: GuestAddressSpace> VringState<M> { + /// Create a new instance of Vring. + fn new(mem: M, max_queue_size: u16) -> Result<Self, VirtQueError> { + Ok(VringState { + queue: Queue::new(max_queue_size)?, + kick: None, + call: None, + err: None, + enabled: false, + mem, + }) + } + + /// Get an immutable reference to the underlying raw `Queue` object. + pub fn get_queue(&self) -> &Queue { + &self.queue + } + + /// Get a mutable reference to the underlying raw `Queue` object. + pub fn get_queue_mut(&mut self) -> &mut Queue { + &mut self.queue + } + + /// Add an used descriptor into the used queue. + pub fn add_used(&mut self, desc_index: u16, len: u32) -> Result<(), VirtQueError> { + self.queue + .add_used(self.mem.memory().deref(), desc_index, len) + } + + /// Notify the vhost-user master that used descriptors have been put into the used queue. + pub fn signal_used_queue(&self) -> io::Result<()> { + if let Some(call) = self.call.as_ref() { + call.write(1) + } else { + Ok(()) + } + } + + /// Enable event notification for queue. + pub fn enable_notification(&mut self) -> Result<bool, VirtQueError> { + self.queue.enable_notification(self.mem.memory().deref()) + } + + /// Disable event notification for queue. 
+ pub fn disable_notification(&mut self) -> Result<(), VirtQueError> { + self.queue.disable_notification(self.mem.memory().deref()) + } + + /// Check whether a notification to the guest is needed. + pub fn needs_notification(&mut self) -> Result<bool, VirtQueError> { + self.queue.needs_notification(self.mem.memory().deref()) + } + + /// Set vring enabled state. + pub fn set_enabled(&mut self, enabled: bool) { + self.enabled = enabled; + } + + /// Set queue addresses for descriptor table, available ring and used ring. + pub fn set_queue_info( + &mut self, + desc_table: u64, + avail_ring: u64, + used_ring: u64, + ) -> Result<(), VirtQueError> { + self.queue + .try_set_desc_table_address(GuestAddress(desc_table))?; + self.queue + .try_set_avail_ring_address(GuestAddress(avail_ring))?; + self.queue + .try_set_used_ring_address(GuestAddress(used_ring)) + } + + /// Get queue next avail head. + fn queue_next_avail(&self) -> u16 { + self.queue.next_avail() + } + + /// Set queue next avail head. + fn set_queue_next_avail(&mut self, base: u16) { + self.queue.set_next_avail(base); + } + + /// Set queue next used head. + fn set_queue_next_used(&mut self, idx: u16) { + self.queue.set_next_used(idx); + } + + /// Get queue next used head index from the guest memory. + fn queue_used_idx(&self) -> Result<u16, VirtQueError> { + self.queue + .used_idx(self.mem.memory().deref(), Ordering::Relaxed) + .map(|idx| idx.0) + } + + /// Set configured queue size. + fn set_queue_size(&mut self, num: u16) { + self.queue.set_size(num); + } + + /// Enable/disable queue event index feature. + fn set_queue_event_idx(&mut self, enabled: bool) { + self.queue.set_event_idx(enabled); + } + + /// Set queue enabled state. + fn set_queue_ready(&mut self, ready: bool) { + self.queue.set_ready(ready); + } + + /// Get the `EventFd` for kick. + pub fn get_kick(&self) -> &Option<EventFd> { + &self.kick + } + + /// Set `EventFd` for kick. 
+ fn set_kick(&mut self, file: Option<File>) { + // SAFETY: + // EventFd requires that it has sole ownership of its fd. So does File, so this is safe. + // Ideally, we'd have a generic way to refer to a uniquely-owned fd, such as that proposed + // by Rust RFC #3128. + self.kick = file.map(|f| unsafe { EventFd::from_raw_fd(f.into_raw_fd()) }); + } + + /// Read event from the kick `EventFd`. + fn read_kick(&self) -> io::Result<bool> { + if let Some(kick) = &self.kick { + kick.read()?; + } + + Ok(self.enabled) + } + + /// Set `EventFd` for call. + fn set_call(&mut self, file: Option<File>) { + // SAFETY: see comment in set_kick() + self.call = file.map(|f| unsafe { EventFd::from_raw_fd(f.into_raw_fd()) }); + } + + /// Get the `EventFd` for call. + pub fn get_call(&self) -> &Option<EventFd> { + &self.call + } + + /// Set `EventFd` for err. + fn set_err(&mut self, file: Option<File>) { + // SAFETY: see comment in set_kick() + self.err = file.map(|f| unsafe { EventFd::from_raw_fd(f.into_raw_fd()) }); + } +} + +/// A `VringState` object protected by Mutex for multi-threading context. +#[derive(Clone)] +pub struct VringMutex<M: GuestAddressSpace = GuestMemoryAtomic<GuestMemoryMmap>> { + state: Arc<Mutex<VringState<M>>>, +} + +impl<M: GuestAddressSpace> VringMutex<M> { + /// Get a mutable guard to the underlying raw `VringState` object. 
+ fn lock(&self) -> MutexGuard<VringState<M>> { + self.state.lock().unwrap() + } +} + +impl<'a, M: 'a + GuestAddressSpace> VringStateGuard<'a, M> for VringMutex<M> { + type G = MutexGuard<'a, VringState<M>>; +} + +impl<'a, M: 'a + GuestAddressSpace> VringStateMutGuard<'a, M> for VringMutex<M> { + type G = MutexGuard<'a, VringState<M>>; +} + +impl<M: 'static + GuestAddressSpace> VringT<M> for VringMutex<M> { + fn new(mem: M, max_queue_size: u16) -> Result<Self, VirtQueError> { + Ok(VringMutex { + state: Arc::new(Mutex::new(VringState::new(mem, max_queue_size)?)), + }) + } + + fn get_ref(&self) -> <Self as VringStateGuard<M>>::G { + self.state.lock().unwrap() + } + + fn get_mut(&self) -> <Self as VringStateMutGuard<M>>::G { + self.lock() + } + + fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError> { + self.lock().add_used(desc_index, len) + } + + fn signal_used_queue(&self) -> io::Result<()> { + self.get_ref().signal_used_queue() + } + + fn enable_notification(&self) -> Result<bool, VirtQueError> { + self.lock().enable_notification() + } + + fn disable_notification(&self) -> Result<(), VirtQueError> { + self.lock().disable_notification() + } + + fn needs_notification(&self) -> Result<bool, VirtQueError> { + self.lock().needs_notification() + } + + fn set_enabled(&self, enabled: bool) { + self.lock().set_enabled(enabled) + } + + fn set_queue_info( + &self, + desc_table: u64, + avail_ring: u64, + used_ring: u64, + ) -> Result<(), VirtQueError> { + self.lock() + .set_queue_info(desc_table, avail_ring, used_ring) + } + + fn queue_next_avail(&self) -> u16 { + self.get_ref().queue_next_avail() + } + + fn set_queue_next_avail(&self, base: u16) { + self.lock().set_queue_next_avail(base) + } + + fn set_queue_next_used(&self, idx: u16) { + self.lock().set_queue_next_used(idx) + } + + fn queue_used_idx(&self) -> Result<u16, VirtQueError> { + self.lock().queue_used_idx() + } + + fn set_queue_size(&self, num: u16) { + self.lock().set_queue_size(num); + } + + 
fn set_queue_event_idx(&self, enabled: bool) { + self.lock().set_queue_event_idx(enabled); + } + + fn set_queue_ready(&self, ready: bool) { + self.lock().set_queue_ready(ready); + } + + fn set_kick(&self, file: Option<File>) { + self.lock().set_kick(file); + } + + fn read_kick(&self) -> io::Result<bool> { + self.get_ref().read_kick() + } + + fn set_call(&self, file: Option<File>) { + self.lock().set_call(file) + } + + fn set_err(&self, file: Option<File>) { + self.lock().set_err(file) + } +} + +/// A `VringState` object protected by RwLock for multi-threading context. +#[derive(Clone)] +pub struct VringRwLock<M: GuestAddressSpace = GuestMemoryAtomic<GuestMemoryMmap>> { + state: Arc<RwLock<VringState<M>>>, +} + +impl<M: GuestAddressSpace> VringRwLock<M> { + /// Get a mutable guard to the underlying raw `VringState` object. + fn write_lock(&self) -> RwLockWriteGuard<VringState<M>> { + self.state.write().unwrap() + } +} + +impl<'a, M: 'a + GuestAddressSpace> VringStateGuard<'a, M> for VringRwLock<M> { + type G = RwLockReadGuard<'a, VringState<M>>; +} + +impl<'a, M: 'a + GuestAddressSpace> VringStateMutGuard<'a, M> for VringRwLock<M> { + type G = RwLockWriteGuard<'a, VringState<M>>; +} + +impl<M: 'static + GuestAddressSpace> VringT<M> for VringRwLock<M> { + fn new(mem: M, max_queue_size: u16) -> Result<Self, VirtQueError> { + Ok(VringRwLock { + state: Arc::new(RwLock::new(VringState::new(mem, max_queue_size)?)), + }) + } + + fn get_ref(&self) -> <Self as VringStateGuard<M>>::G { + self.state.read().unwrap() + } + + fn get_mut(&self) -> <Self as VringStateMutGuard<M>>::G { + self.write_lock() + } + + fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError> { + self.write_lock().add_used(desc_index, len) + } + + fn signal_used_queue(&self) -> io::Result<()> { + self.get_ref().signal_used_queue() + } + + fn enable_notification(&self) -> Result<bool, VirtQueError> { + self.write_lock().enable_notification() + } + + fn disable_notification(&self) -> 
Result<(), VirtQueError> { + self.write_lock().disable_notification() + } + + fn needs_notification(&self) -> Result<bool, VirtQueError> { + self.write_lock().needs_notification() + } + + fn set_enabled(&self, enabled: bool) { + self.write_lock().set_enabled(enabled) + } + + fn set_queue_info( + &self, + desc_table: u64, + avail_ring: u64, + used_ring: u64, + ) -> Result<(), VirtQueError> { + self.write_lock() + .set_queue_info(desc_table, avail_ring, used_ring) + } + + fn queue_next_avail(&self) -> u16 { + self.get_ref().queue_next_avail() + } + + fn set_queue_next_avail(&self, base: u16) { + self.write_lock().set_queue_next_avail(base) + } + + fn set_queue_next_used(&self, idx: u16) { + self.write_lock().set_queue_next_used(idx) + } + + fn queue_used_idx(&self) -> Result<u16, VirtQueError> { + self.get_ref().queue_used_idx() + } + + fn set_queue_size(&self, num: u16) { + self.write_lock().set_queue_size(num); + } + + fn set_queue_event_idx(&self, enabled: bool) { + self.write_lock().set_queue_event_idx(enabled); + } + + fn set_queue_ready(&self, ready: bool) { + self.write_lock().set_queue_ready(ready); + } + + fn set_kick(&self, file: Option<File>) { + self.write_lock().set_kick(file); + } + + fn read_kick(&self) -> io::Result<bool> { + self.get_ref().read_kick() + } + + fn set_call(&self, file: Option<File>) { + self.write_lock().set_call(file) + } + + fn set_err(&self, file: Option<File>) { + self.write_lock().set_err(file) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::os::unix::io::AsRawFd; + use vm_memory::bitmap::AtomicBitmap; + use vmm_sys_util::eventfd::EventFd; + + #[test] + fn test_new_vring() { + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<AtomicBitmap>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]) + .unwrap(), + ); + let vring = VringMutex::new(mem, 0x1000).unwrap(); + + assert!(vring.get_ref().get_kick().is_none()); + assert!(!vring.get_mut().enabled); + assert!(!vring.lock().queue.ready()); + 
assert!(!vring.lock().queue.event_idx_enabled()); + + vring.set_enabled(true); + assert!(vring.get_ref().enabled); + + vring.set_queue_info(0x100100, 0x100200, 0x100300).unwrap(); + assert_eq!(vring.lock().get_queue().desc_table(), 0x100100); + assert_eq!(vring.lock().get_queue().avail_ring(), 0x100200); + assert_eq!(vring.lock().get_queue().used_ring(), 0x100300); + + assert_eq!(vring.queue_next_avail(), 0); + vring.set_queue_next_avail(0x20); + assert_eq!(vring.queue_next_avail(), 0x20); + + vring.set_queue_size(0x200); + assert_eq!(vring.lock().queue.size(), 0x200); + + vring.set_queue_event_idx(true); + assert!(vring.lock().queue.event_idx_enabled()); + + vring.set_queue_ready(true); + assert!(vring.lock().queue.ready()); + } + + #[test] + fn test_vring_set_fd() { + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(), + ); + let vring = VringMutex::new(mem, 0x1000).unwrap(); + + vring.set_enabled(true); + assert!(vring.get_ref().enabled); + + let eventfd = EventFd::new(0).unwrap(); + // SAFETY: Safe because we panic before if eventfd is not valid. + let file = unsafe { File::from_raw_fd(eventfd.as_raw_fd()) }; + assert!(vring.get_mut().kick.is_none()); + assert!(vring.read_kick().unwrap()); + vring.set_kick(Some(file)); + eventfd.write(1).unwrap(); + assert!(vring.read_kick().unwrap()); + assert!(vring.get_ref().kick.is_some()); + vring.set_kick(None); + assert!(vring.get_ref().kick.is_none()); + std::mem::forget(eventfd); + + let eventfd = EventFd::new(0).unwrap(); + // SAFETY: Safe because we panic before if eventfd is not valid. 
+ let file = unsafe { File::from_raw_fd(eventfd.as_raw_fd()) }; + assert!(vring.get_ref().call.is_none()); + vring.set_call(Some(file)); + assert!(vring.get_ref().call.is_some()); + vring.set_call(None); + assert!(vring.get_ref().call.is_none()); + std::mem::forget(eventfd); + + let eventfd = EventFd::new(0).unwrap(); + // SAFETY: Safe because we panic before if eventfd is not valid. + let file = unsafe { File::from_raw_fd(eventfd.as_raw_fd()) }; + assert!(vring.get_ref().err.is_none()); + vring.set_err(Some(file)); + assert!(vring.get_ref().err.is_some()); + vring.set_err(None); + assert!(vring.get_ref().err.is_none()); + std::mem::forget(eventfd); + } +} diff --git a/src/vsock.rs b/src/vsock.rs new file mode 100644 index 0000000..1e1b0b9 --- /dev/null +++ b/src/vsock.rs @@ -0,0 +1,30 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause +// +// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-Google file. + +//! Trait to control vhost-vsock backend drivers. + +use crate::backend::VhostBackend; +use crate::Result; + +/// Trait to control vhost-vsock backend drivers. +pub trait VhostVsock: VhostBackend { + /// Set the CID for the guest. + /// This number is used for routing all data destined for running in the guest. + /// Each guest on a hypervisor must have an unique CID. + /// + /// # Arguments + /// * `cid` - CID to assign to the guest + fn set_guest_cid(&self, cid: u64) -> Result<()>; + + /// Tell the VHOST driver to start performing data transfer. + fn start(&self) -> Result<()>; + + /// Tell the VHOST driver to stop performing data transfer. 
+ fn stop(&self) -> Result<()>; +} diff --git a/tests/vhost-user-server.rs b/tests/vhost-user-server.rs new file mode 100644 index 0000000..f6fdea7 --- /dev/null +++ b/tests/vhost-user-server.rs @@ -0,0 +1,292 @@ +use std::ffi::CString; +use std::fs::File; +use std::io::Result; +use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::os::unix::net::UnixStream; +use std::path::Path; +use std::sync::{Arc, Barrier, Mutex}; +use std::thread; + +use vhost::vhost_user::message::{ + VhostUserConfigFlags, VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures, +}; +use vhost::vhost_user::{Listener, Master, Slave, VhostUserMaster}; +use vhost::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; +use vhost_user_backend::{VhostUserBackendMut, VhostUserDaemon, VringRwLock}; +use vm_memory::{ + FileOffset, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic, GuestMemoryMmap, +}; +use vmm_sys_util::epoll::EventSet; +use vmm_sys_util::eventfd::EventFd; + +struct MockVhostBackend { + events: u64, + event_idx: bool, + acked_features: u64, +} + +impl MockVhostBackend { + fn new() -> Self { + MockVhostBackend { + events: 0, + event_idx: false, + acked_features: 0, + } + } +} + +impl VhostUserBackendMut<VringRwLock, ()> for MockVhostBackend { + fn num_queues(&self) -> usize { + 2 + } + + fn max_queue_size(&self) -> usize { + 256 + } + + fn features(&self) -> u64 { + 0xffff_ffff_ffff_ffff + } + + fn acked_features(&mut self, features: u64) { + self.acked_features = features; + } + + fn protocol_features(&self) -> VhostUserProtocolFeatures { + VhostUserProtocolFeatures::all() + } + + fn set_event_idx(&mut self, enabled: bool) { + self.event_idx = enabled; + } + + fn get_config(&self, offset: u32, size: u32) -> Vec<u8> { + assert_eq!(offset, 0x200); + assert_eq!(size, 8); + + vec![0xa5u8; 8] + } + + fn set_config(&mut self, offset: u32, buf: &[u8]) -> Result<()> { + assert_eq!(offset, 0x200); + assert_eq!(buf, &[0xa5u8; 8]); + + Ok(()) + } + + fn 
update_memory(&mut self, atomic_mem: GuestMemoryAtomic<GuestMemoryMmap>) -> Result<()> { + let mem = atomic_mem.memory(); + let region = mem.find_region(GuestAddress(0x100000)).unwrap(); + assert_eq!(region.size(), 0x100000); + Ok(()) + } + + fn set_slave_req_fd(&mut self, _slave: Slave) {} + + fn queues_per_thread(&self) -> Vec<u64> { + vec![1, 1] + } + + fn exit_event(&self, _thread_index: usize) -> Option<EventFd> { + let event_fd = EventFd::new(0).unwrap(); + + Some(event_fd) + } + + fn handle_event( + &mut self, + _device_event: u16, + _evset: EventSet, + _vrings: &[VringRwLock], + _thread_id: usize, + ) -> Result<bool> { + self.events += 1; + + Ok(false) + } +} + +fn setup_master(path: &Path, barrier: Arc<Barrier>) -> Master { + barrier.wait(); + let mut master = Master::connect(path, 1).unwrap(); + master.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); + // Wait before issue service requests. + barrier.wait(); + + let features = master.get_features().unwrap(); + let proto = master.get_protocol_features().unwrap(); + master.set_features(features).unwrap(); + master.set_protocol_features(proto).unwrap(); + assert!(proto.contains(VhostUserProtocolFeatures::REPLY_ACK)); + + master +} + +fn vhost_user_client(path: &Path, barrier: Arc<Barrier>) { + barrier.wait(); + let mut master = Master::connect(path, 1).unwrap(); + master.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); + // Wait before issue service requests. 
+ barrier.wait(); + + let features = master.get_features().unwrap(); + let proto = master.get_protocol_features().unwrap(); + master.set_features(features).unwrap(); + master.set_protocol_features(proto).unwrap(); + assert!(proto.contains(VhostUserProtocolFeatures::REPLY_ACK)); + + let queue_num = master.get_queue_num().unwrap(); + assert_eq!(queue_num, 2); + + master.set_owner().unwrap(); + //master.set_owner().unwrap_err(); + master.reset_owner().unwrap(); + master.reset_owner().unwrap(); + master.set_owner().unwrap(); + + master.set_features(features).unwrap(); + master.set_protocol_features(proto).unwrap(); + assert!(proto.contains(VhostUserProtocolFeatures::REPLY_ACK)); + + let memfd = nix::sys::memfd::memfd_create( + &CString::new("test").unwrap(), + nix::sys::memfd::MemFdCreateFlag::empty(), + ) + .unwrap(); + // SAFETY: Safe because we panic before if memfd is not valid. + let file = unsafe { File::from_raw_fd(memfd) }; + file.set_len(0x100000).unwrap(); + let file_offset = FileOffset::new(file, 0); + let mem = GuestMemoryMmap::<()>::from_ranges_with_files(&[( + GuestAddress(0x100000), + 0x100000, + Some(file_offset), + )]) + .unwrap(); + let addr = mem.get_host_address(GuestAddress(0x100000)).unwrap() as u64; + let reg = mem.find_region(GuestAddress(0x100000)).unwrap(); + let fd = reg.file_offset().unwrap(); + let regions = [VhostUserMemoryRegionInfo::new( + 0x100000, + 0x100000, + addr, + 0, + fd.file().as_raw_fd(), + )]; + master.set_mem_table(®ions).unwrap(); + + master.set_vring_num(0, 256).unwrap(); + + let config = VringConfigData { + queue_max_size: 256, + queue_size: 256, + flags: 0, + desc_table_addr: addr, + used_ring_addr: addr + 0x10000, + avail_ring_addr: addr + 0x20000, + log_addr: None, + }; + master.set_vring_addr(0, &config).unwrap(); + + let eventfd = EventFd::new(0).unwrap(); + master.set_vring_kick(0, &eventfd).unwrap(); + master.set_vring_call(0, &eventfd).unwrap(); + master.set_vring_err(0, &eventfd).unwrap(); + 
master.set_vring_enable(0, true).unwrap(); + + let buf = [0u8; 8]; + let (_cfg, data) = master + .get_config(0x200, 8, VhostUserConfigFlags::empty(), &buf) + .unwrap(); + assert_eq!(&data, &[0xa5u8; 8]); + master + .set_config(0x200, VhostUserConfigFlags::empty(), &data) + .unwrap(); + + let (tx, _rx) = UnixStream::pair().unwrap(); + master.set_slave_request_fd(&tx).unwrap(); + + let state = master.get_vring_base(0).unwrap(); + master.set_vring_base(0, state as u16).unwrap(); + + assert_eq!(master.get_max_mem_slots().unwrap(), 32); + let region = VhostUserMemoryRegionInfo::new(0x800000, 0x100000, addr, 0, fd.file().as_raw_fd()); + master.add_mem_region(®ion).unwrap(); + master.remove_mem_region(®ion).unwrap(); +} + +fn vhost_user_server(cb: fn(&Path, Arc<Barrier>)) { + let mem = GuestMemoryAtomic::new(GuestMemoryMmap::<()>::new()); + let backend = Arc::new(Mutex::new(MockVhostBackend::new())); + let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap(); + + let barrier = Arc::new(Barrier::new(2)); + let tmpdir = tempfile::tempdir().unwrap(); + let mut path = tmpdir.path().to_path_buf(); + path.push("socket"); + + let barrier2 = barrier.clone(); + let path1 = path.clone(); + let thread = thread::spawn(move || cb(&path1, barrier2)); + + let listener = Listener::new(&path, false).unwrap(); + barrier.wait(); + daemon.start(listener).unwrap(); + barrier.wait(); + + // handle service requests from clients. 
+ thread.join().unwrap(); +} + +#[test] +fn test_vhost_user_server() { + vhost_user_server(vhost_user_client); +} + +fn vhost_user_enable(path: &Path, barrier: Arc<Barrier>) { + let master = setup_master(path, barrier); + master.set_owner().unwrap(); + master.set_owner().unwrap_err(); +} + +#[test] +fn test_vhost_user_enable() { + vhost_user_server(vhost_user_enable); +} + +fn vhost_user_set_inflight(path: &Path, barrier: Arc<Barrier>) { + let mut master = setup_master(path, barrier); + let eventfd = EventFd::new(0).unwrap(); + // No implementation for inflight_fd yet. + let inflight = VhostUserInflight { + mmap_size: 0x100000, + mmap_offset: 0, + num_queues: 1, + queue_size: 256, + }; + master + .set_inflight_fd(&inflight, eventfd.as_raw_fd()) + .unwrap_err(); +} + +#[test] +fn test_vhost_user_set_inflight() { + vhost_user_server(vhost_user_set_inflight); +} + +fn vhost_user_get_inflight(path: &Path, barrier: Arc<Barrier>) { + let mut master = setup_master(path, barrier); + // No implementation for inflight_fd yet. + let inflight = VhostUserInflight { + mmap_size: 0x100000, + mmap_offset: 0, + num_queues: 1, + queue_size: 256, + }; + assert!(master.get_inflight_fd(&inflight).is_err()); +} + +#[test] +fn test_vhost_user_get_inflight() { + vhost_user_server(vhost_user_get_inflight); +} |