aboutsummaryrefslogtreecommitdiff
path: root/cros_async/src/sys/windows/tokio_source.rs
diff options
context:
space:
mode:
Diffstat (limited to 'cros_async/src/sys/windows/tokio_source.rs')
-rw-r--r--cros_async/src/sys/windows/tokio_source.rs291
1 files changed, 291 insertions, 0 deletions
diff --git a/cros_async/src/sys/windows/tokio_source.rs b/cros_async/src/sys/windows/tokio_source.rs
new file mode 100644
index 000000000..676edb7c2
--- /dev/null
+++ b/cros_async/src/sys/windows/tokio_source.rs
@@ -0,0 +1,291 @@
+// Copyright 2022 The ChromiumOS Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::fs::File;
+use std::io;
+use std::io::Read;
+use std::io::Seek;
+use std::io::SeekFrom;
+use std::io::Write;
+use std::mem::ManuallyDrop;
+use std::sync::Arc;
+
+use base::AsRawDescriptor;
+use base::FileReadWriteAtVolatile;
+use base::FileReadWriteVolatile;
+use base::FromRawDescriptor;
+use base::PunchHole;
+use base::VolatileSlice;
+use base::WriteZeroesAt;
+use smallvec::SmallVec;
+use sync::Mutex;
+
+use crate::mem::MemRegion;
+use crate::AsyncError;
+use crate::AsyncResult;
+use crate::BackingMemory;
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("An error occurred trying to seek: {0}.")]
+ IoSeekError(io::Error),
+ #[error("An error occurred trying to read: {0}.")]
+ IoReadError(io::Error),
+ #[error("An error occurred trying to write: {0}.")]
+ IoWriteError(io::Error),
+ #[error("An error occurred trying to flush: {0}.")]
+ IoFlushError(io::Error),
+ #[error("An error occurred trying to punch hole: {0}.")]
+ IoPunchHoleError(io::Error),
+ #[error("An error occurred trying to write zeroes: {0}.")]
+ IoWriteZeroesError(io::Error),
+ #[error("Failed to join task: '{0}'")]
+ Join(tokio::task::JoinError),
+ #[error("An error occurred trying to duplicate source handles: {0}.")]
+ HandleDuplicationFailed(io::Error),
+ #[error("An error occurred trying to wait on source handles: {0}.")]
+ HandleWaitFailed(base::Error),
+ #[error("An error occurred trying to get a VolatileSlice into BackingMemory: {0}.")]
+ BackingMemoryVolatileSliceFetchFailed(crate::mem::Error),
+ #[error("TokioSource is gone, so no handles are available to fulfill the IO request.")]
+ NoTokioSource,
+ #[error("Operation on TokioSource is cancelled.")]
+ OperationCancelled,
+ #[error("Operation on TokioSource was aborted (unexpected).")]
+ OperationAborted,
+}
+
+impl From<Error> for AsyncError {
+ fn from(e: Error) -> AsyncError {
+ AsyncError::SysVariants(e.into())
+ }
+}
+
+impl From<Error> for io::Error {
+ fn from(e: Error) -> Self {
+ use Error::*;
+ match e {
+ IoSeekError(e) => e,
+ IoReadError(e) => e,
+ IoWriteError(e) => e,
+ IoFlushError(e) => e,
+ IoPunchHoleError(e) => e,
+ IoWriteZeroesError(e) => e,
+ Join(e) => io::Error::new(io::ErrorKind::Other, e),
+ HandleDuplicationFailed(e) => e,
+ HandleWaitFailed(e) => e.into(),
+ BackingMemoryVolatileSliceFetchFailed(e) => io::Error::new(io::ErrorKind::Other, e),
+ NoTokioSource => io::Error::new(io::ErrorKind::Other, NoTokioSource),
+ OperationCancelled => io::Error::new(io::ErrorKind::Interrupted, OperationCancelled),
+ OperationAborted => io::Error::new(io::ErrorKind::Interrupted, OperationAborted),
+ }
+ }
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+pub struct TokioSource<T: AsRawDescriptor> {
+ source: Option<T>,
+ source_file: Arc<Mutex<Option<ManuallyDrop<File>>>>,
+ runtime: tokio::runtime::Handle,
+}
+
+impl<T: AsRawDescriptor> TokioSource<T> {
+ pub(crate) fn new(source: T, runtime: tokio::runtime::Handle) -> Result<TokioSource<T>> {
+ let descriptor = source.as_raw_descriptor();
+ // SAFETY: The Drop implementation makes sure `source` outlives `source_file`.
+ let source_file = unsafe { ManuallyDrop::new(File::from_raw_descriptor(descriptor)) };
+ Ok(Self {
+ source: Some(source),
+ source_file: Arc::new(Mutex::new(Some(source_file))),
+ runtime,
+ })
+ }
+ #[inline]
+ fn get_slices(
+ mem: &Arc<dyn BackingMemory + Send + Sync>,
+ mem_offsets: Vec<MemRegion>,
+ ) -> Result<SmallVec<[VolatileSlice<'_>; 16]>> {
+ mem_offsets
+ .into_iter()
+ .map(|region| {
+ mem.get_volatile_slice(region)
+ .map_err(Error::BackingMemoryVolatileSliceFetchFailed)
+ })
+ .collect::<Result<SmallVec<[VolatileSlice; 16]>>>()
+ }
+ pub fn as_source(&self) -> &T {
+ self.source.as_ref().unwrap()
+ }
+ pub fn as_source_mut(&mut self) -> &mut T {
+ self.source.as_mut().unwrap()
+ }
+ pub async fn fdatasync(&self) -> AsyncResult<()> {
+ // TODO(b/282003931): Fall back to regular fsync.
+ self.fsync().await
+ }
+ pub async fn fsync(&self) -> AsyncResult<()> {
+ let source_file = self.source_file.clone();
+ Ok(self
+ .runtime
+ .spawn_blocking(move || {
+ source_file
+ .lock()
+ .as_mut()
+ .ok_or(Error::OperationCancelled)?
+ .flush()
+ .map_err(Error::IoFlushError)
+ })
+ .await
+ .map_err(Error::Join)??)
+ }
+ pub fn into_source(mut self) -> T {
+ self.source_file.lock().take();
+ self.source.take().unwrap()
+ }
+ pub async fn punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
+ let source_file = self.source_file.clone();
+ Ok(self
+ .runtime
+ .spawn_blocking(move || {
+ source_file
+ .lock()
+ .as_mut()
+ .ok_or(Error::OperationCancelled)?
+ .punch_hole(file_offset, len)
+ .map_err(Error::IoPunchHoleError)
+ })
+ .await
+ .map_err(Error::Join)??)
+ }
+ pub async fn read_to_mem(
+ &self,
+ file_offset: Option<u64>,
+ mem: Arc<dyn BackingMemory + Send + Sync>,
+ mem_offsets: impl IntoIterator<Item = MemRegion>,
+ ) -> AsyncResult<usize> {
+ let mem_offsets = mem_offsets.into_iter().collect();
+ let source_file = self.source_file.clone();
+ Ok(self
+ .runtime
+ .spawn_blocking(move || {
+ let mut file_lock = source_file.lock();
+ let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?;
+ let memory_slices = Self::get_slices(&mem, mem_offsets)?;
+ match file_offset {
+ Some(file_offset) => file
+ .read_vectored_at_volatile(memory_slices.as_slice(), file_offset)
+ .map_err(Error::IoReadError),
+ None => file
+ .read_vectored_volatile(memory_slices.as_slice())
+ .map_err(Error::IoReadError),
+ }
+ })
+ .await
+ .map_err(Error::Join)??)
+ }
+ pub async fn read_to_vec(
+ &self,
+ file_offset: Option<u64>,
+ mut vec: Vec<u8>,
+ ) -> AsyncResult<(usize, Vec<u8>)> {
+ let source_file = self.source_file.clone();
+ Ok(self
+ .runtime
+ .spawn_blocking(move || {
+ let mut file_lock = source_file.lock();
+ let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?;
+ if let Some(file_offset) = file_offset {
+ file.seek(SeekFrom::Start(file_offset))
+ .map_err(Error::IoSeekError)?;
+ }
+ Ok::<(usize, Vec<u8>), Error>((
+ file.read(vec.as_mut_slice()).map_err(Error::IoReadError)?,
+ vec,
+ ))
+ })
+ .await
+ .map_err(Error::Join)??)
+ }
+ pub async fn wait_readable(&self) -> AsyncResult<()> {
+ unimplemented!();
+ }
+ pub async fn wait_for_handle(&self) -> AsyncResult<()> {
+ let waiter = super::wait_for_handle::WaitForHandle::new(self.source.as_ref().unwrap());
+ Ok(waiter.await?)
+ }
+ pub async fn write_from_mem(
+ &self,
+ file_offset: Option<u64>,
+ mem: Arc<dyn BackingMemory + Send + Sync>,
+ mem_offsets: impl IntoIterator<Item = MemRegion>,
+ ) -> AsyncResult<usize> {
+ let mem_offsets = mem_offsets.into_iter().collect();
+ let source_file = self.source_file.clone();
+ Ok(self
+ .runtime
+ .spawn_blocking(move || {
+ let mut file_lock = source_file.lock();
+ let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?;
+ let memory_slices = Self::get_slices(&mem, mem_offsets)?;
+ match file_offset {
+ Some(file_offset) => file
+ .write_vectored_at_volatile(memory_slices.as_slice(), file_offset)
+ .map_err(Error::IoWriteError),
+ None => file
+ .write_vectored_volatile(memory_slices.as_slice())
+ .map_err(Error::IoWriteError),
+ }
+ })
+ .await
+ .map_err(Error::Join)??)
+ }
+ pub async fn write_from_vec(
+ &self,
+ file_offset: Option<u64>,
+ vec: Vec<u8>,
+ ) -> AsyncResult<(usize, Vec<u8>)> {
+ let source_file = self.source_file.clone();
+ Ok(self
+ .runtime
+ .spawn_blocking(move || {
+ let mut file_lock = source_file.lock();
+ let file = file_lock.as_mut().ok_or(Error::OperationCancelled)?;
+ if let Some(file_offset) = file_offset {
+ file.seek(SeekFrom::Start(file_offset))
+ .map_err(Error::IoSeekError)?;
+ }
+ Ok::<(usize, Vec<u8>), Error>((
+ file.write(vec.as_slice()).map_err(Error::IoWriteError)?,
+ vec,
+ ))
+ })
+ .await
+ .map_err(Error::Join)??)
+ }
+ pub async fn write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
+ let source_file = self.source_file.clone();
+ Ok(self
+ .runtime
+ .spawn_blocking(move || {
+ // ZeroRange calls `punch_hole` which doesn't extend the File size if it needs to.
+ // Will fix if it becomes a problem.
+ source_file
+ .lock()
+ .as_mut()
+ .ok_or(Error::OperationCancelled)?
+ .write_zeroes_at(file_offset, len as usize)
+ .map_err(Error::IoWriteZeroesError)
+ .map(|_| ())
+ })
+ .await
+ .map_err(Error::Join)??)
+ }
+}
+impl<T: AsRawDescriptor> Drop for TokioSource<T> {
+ fn drop(&mut self) {
+ let mut source_file = self.source_file.lock();
+ source_file.take();
+ }
+}