aboutsummaryrefslogtreecommitdiff
path: root/pw_stream/rust/pw_stream/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'pw_stream/rust/pw_stream/lib.rs')
-rw-r--r--pw_stream/rust/pw_stream/lib.rs360
1 files changed, 360 insertions, 0 deletions
diff --git a/pw_stream/rust/pw_stream/lib.rs b/pw_stream/rust/pw_stream/lib.rs
new file mode 100644
index 000000000..ba60c230f
--- /dev/null
+++ b/pw_stream/rust/pw_stream/lib.rs
@@ -0,0 +1,360 @@
+// Copyright 2023 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+//! `pw_stream` provides `no_std` versions of Rust's [`std::io::Read`],
+//! [`std::io::Write`], and [`std::io::Seek`] traits as well as a simplified
+//! version of [`std::io::Cursor`]. One notable difference is that
+//! [`pw_status::Error`] is used to avoid needing to do error conversion or
+//! encapsulation.
+#![deny(missing_docs)]
+// Allows docs to reference `std`
+#![cfg_attr(feature = "no_std", no_std)]
+
+use pw_status::{Error, Result};
+
+#[doc(hidden)]
+mod cursor;
+mod integer;
+
+pub use cursor::Cursor;
+pub use integer::{ReadInteger, ReadVarint, WriteInteger, WriteVarint};
+
+/// A trait for objects that provide streaming read capability.
+pub trait Read {
+ /// Read from a stream into a buffer.
+ ///
+ /// Semantics match [`std::io::Read::read()`].
+ fn read(&mut self, buf: &mut [u8]) -> Result<usize>;
+
+ /// Read exactly enough bytes to fill the buffer.
+ ///
+ /// Semantics match [`std::io::Read::read_exact()`].
+ fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<()> {
+ while !buf.is_empty() {
+ let len = self.read(buf)?;
+
+ // End of stream
+ if len == 0 {
+ break;
+ }
+
+ buf = &mut buf[len..];
+ }
+
+ if !buf.is_empty() {
+ Err(Error::OutOfRange)
+ } else {
+ Ok(())
+ }
+ }
+}
+
+/// A trait for objects that provide streaming write capability.
+pub trait Write {
+ /// Write a buffer to a stream.
+ ///
+ /// Semantics match [`std::io::Write::write()`].
+ fn write(&mut self, buf: &[u8]) -> Result<usize>;
+
+ /// Commit any outstanding buffered writes to underlying storage.
+ ///
+ /// Semantics match [`std::io::Write::flush()`].
+ fn flush(&mut self) -> Result<()>;
+
+ /// Writes entire buffer to stream.
+ ///
+ /// Semantics match [`std::io::Write::write_all()`].
+ fn write_all(&mut self, mut buf: &[u8]) -> Result<()> {
+ while !buf.is_empty() {
+ let len = self.write(buf)?;
+
+ // End of stream
+ if len == 0 {
+ break;
+ }
+
+ buf = &buf[len..];
+ }
+
+ if !buf.is_empty() {
+ Err(Error::OutOfRange)
+ } else {
+ Ok(())
+ }
+ }
+}
+
+/// A description of a seek operation in a stream.
+///
+/// While `pw_stream` targets embedded platforms which are often natively
+/// 32 bit, we believe that seek operation are relatively rare and the added
+/// overhead of using 64 bit values for seeks is balanced by the ability
+/// to support objects and operations over 4 GiB.
+pub enum SeekFrom {
+ /// Seek from the start of the stream.
+ Start(u64),
+
+ /// Seek from the end of the stream.
+ End(i64),
+
+ /// Seek from the current position of the stream.
+ Current(i64),
+}
+
+/// A trait for objects that provide the ability to seek withing a stream.
+pub trait Seek {
+ /// Adjust the current position of the stream.
+ ///
+ /// Semantics match [`std::io::Seek::seek()`].
+ fn seek(&mut self, pos: SeekFrom) -> Result<u64>;
+
+ /// Set the current position of the stream to its beginning.
+ ///
+ /// Semantics match [`std::io::Seek::rewind()`].
+ fn rewind(&mut self) -> Result<()> {
+ self.seek(SeekFrom::Start(0)).map(|_| ())
+ }
+
+ /// Returns the length of the stream.
+ ///
+ /// Semantics match [`std::io::Seek::stream_len()`].
+ fn stream_len(&mut self) -> Result<u64> {
+ // Save original position.
+ let orig_pos = self.seek(SeekFrom::Current(0))?;
+
+ // Seed to the end to discover stream length.
+ let end_pos = self.seek(SeekFrom::End(0))?;
+
+ // Go back to original position.
+ self.seek(SeekFrom::Start(orig_pos))?;
+
+ Ok(end_pos)
+ }
+
+ /// Returns the current position of the stream.
+ ///
+ /// Semantics match [`std::io::Seek::stream_position()`].
+ fn stream_position(&mut self) -> Result<u64> {
+ self.seek(SeekFrom::Current(0))
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use super::{Seek, SeekFrom};
+
+ pub(crate) fn test_rewind_resets_position_to_zero<const LEN: u64, T: Seek>(mut seeker: T) {
+ seeker.seek(SeekFrom::Current(LEN as i64 / 2)).unwrap();
+ assert_eq!(seeker.stream_position().unwrap(), LEN / 2);
+ seeker.rewind().unwrap();
+ assert_eq!(seeker.stream_position().unwrap(), 0);
+ }
+
+ pub(crate) fn test_stream_pos_reports_correct_position<const LEN: u64, T: Seek>(mut seeker: T) {
+ assert_eq!(seeker.stream_position().unwrap(), 0);
+ seeker.seek(SeekFrom::Current(1)).unwrap();
+ assert_eq!(seeker.stream_position().unwrap(), 1);
+ seeker.seek(SeekFrom::Current(LEN as i64 / 2 - 1)).unwrap();
+ assert_eq!(seeker.stream_position().unwrap(), LEN / 2);
+ seeker.seek(SeekFrom::Current(0)).unwrap();
+ assert_eq!(seeker.stream_position().unwrap(), LEN / 2);
+ seeker.seek(SeekFrom::End(0)).unwrap();
+ assert_eq!(seeker.stream_position().unwrap(), LEN);
+ }
+
+ pub(crate) fn test_stream_len_reports_correct_length<const LEN: u64, T: Seek>(mut seeker: T) {
+ assert_eq!(seeker.stream_len().unwrap(), LEN);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use core::cmp::min;
+
+ use pw_status::Error;
+
+ use super::test_utils::*;
+ use super::*;
+
+ struct TestSeeker {
+ len: u64,
+ pos: u64,
+ }
+
+ impl Seek for TestSeeker {
+ fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
+ let new_pos = match pos {
+ SeekFrom::Start(pos) => pos,
+ SeekFrom::Current(pos) => {
+ self.pos.checked_add_signed(pos).ok_or(Error::OutOfRange)?
+ }
+ SeekFrom::End(pos) => self.len.checked_add_signed(-pos).ok_or(Error::OutOfRange)?,
+ };
+
+ if new_pos > self.len {
+ Err(Error::OutOfRange)
+ } else {
+ self.pos = new_pos;
+ Ok(new_pos)
+ }
+ }
+ }
+
+ // A stream wrapper that limits reads and writes to a maximum chunk size.
+ struct ChunkedStreamAdapter<S: Read + Write + Seek> {
+ inner: S,
+ chunk_size: usize,
+ num_reads: u32,
+ num_writes: u32,
+ }
+
+ impl<S: Read + Write + Seek> ChunkedStreamAdapter<S> {
+ fn new(inner: S, chunk_size: usize) -> Self {
+ Self {
+ inner,
+ chunk_size,
+ num_reads: 0,
+ num_writes: 0,
+ }
+ }
+ }
+
+ impl<S: Read + Write + Seek> Read for ChunkedStreamAdapter<S> {
+ fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
+ let read_len = min(self.chunk_size, buf.len());
+ self.num_reads += 1;
+ self.inner.read(&mut buf[..read_len])
+ }
+ }
+
+ impl<S: Read + Write + Seek> Write for ChunkedStreamAdapter<S> {
+ fn write(&mut self, buf: &[u8]) -> Result<usize> {
+ let write_len = min(self.chunk_size, buf.len());
+ self.num_writes += 1;
+ self.inner.write(&buf[..write_len])
+ }
+
+ fn flush(&mut self) -> Result<()> {
+ self.inner.flush()
+ }
+ }
+
+ struct ErrorStream {
+ error: Error,
+ }
+
+ impl Read for ErrorStream {
+ fn read(&mut self, _buf: &mut [u8]) -> Result<usize> {
+ Err(self.error)
+ }
+ }
+
+ impl Write for ErrorStream {
+ fn write(&mut self, _buf: &[u8]) -> Result<usize> {
+ Err(self.error)
+ }
+
+ fn flush(&mut self) -> Result<()> {
+ Err(self.error)
+ }
+ }
+
+ #[test]
+ fn default_rewind_impl_resets_position_to_zero() {
+ test_rewind_resets_position_to_zero::<64, _>(TestSeeker { len: 64, pos: 0 });
+ }
+
+ #[test]
+ fn default_stream_pos_impl_reports_correct_position() {
+ test_stream_pos_reports_correct_position::<64, _>(TestSeeker { len: 64, pos: 0 });
+ }
+
+ #[test]
+ fn default_stream_len_impl_reports_correct_length() {
+ test_stream_len_reports_correct_length::<64, _>(TestSeeker { len: 64, pos: 32 });
+ }
+
+ #[test]
+ fn read_exact_reads_full_buffer_on_short_reads() {
+ let cursor = Cursor::new((0x0..=0xff).collect::<Vec<u8>>());
+ // Limit reads to 10 bytes per read.
+ let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
+ let mut read_buffer = vec![0u8; 256];
+
+ wrapper.read_exact(&mut read_buffer).unwrap();
+
+ // Ensure that the correct bytes were read.
+ assert_eq!(wrapper.inner.into_inner(), read_buffer);
+
+ // Verify that the read was broken up into the correct number of reads.
+ assert_eq!(wrapper.num_reads, 26);
+ }
+
+ #[test]
+ fn read_exact_returns_error_on_too_little_data() {
+ let cursor = Cursor::new((0x0..=0x7f).collect::<Vec<u8>>());
+ // Limit reads to 10 bytes per read.
+ let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
+ let mut read_buffer = vec![0u8; 256];
+
+ assert_eq!(wrapper.read_exact(&mut read_buffer), Err(Error::OutOfRange));
+ }
+
+ #[test]
+ fn read_exact_propagates_read_errors() {
+ let mut error_stream = ErrorStream {
+ error: Error::Internal,
+ };
+ let mut read_buffer = vec![0u8; 256];
+ assert_eq!(
+ error_stream.read_exact(&mut read_buffer),
+ Err(Error::Internal)
+ );
+ }
+
+ #[test]
+ fn write_all_writes_full_buffer_on_short_writes() {
+ let cursor = Cursor::new(vec![0u8; 256]);
+ // Limit writes to 10 bytes per write.
+ let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
+ let write_buffer = (0x0..=0xff).collect::<Vec<u8>>();
+
+ wrapper.write_all(&write_buffer).unwrap();
+
+ // Ensure that the correct bytes were written.
+ assert_eq!(wrapper.inner.into_inner(), write_buffer);
+
+ // Verify that the write was broken up into the correct number of writes.
+ assert_eq!(wrapper.num_writes, 26);
+ }
+
+ #[test]
+ fn write_all_returns_error_on_too_little_data() {
+ let cursor = Cursor::new(vec![0u8; 128]);
+ // Limit writes to 10 bytes per write.
+ let mut wrapper = ChunkedStreamAdapter::new(cursor, 10);
+ let write_buffer = (0x0..=0xff).collect::<Vec<u8>>();
+
+ assert_eq!(wrapper.write_all(&write_buffer), Err(Error::OutOfRange));
+ }
+
+ #[test]
+ fn write_all_propagates_write_errors() {
+ let mut error_stream = ErrorStream {
+ error: Error::Internal,
+ };
+ let write_buffer = (0x0..=0xff).collect::<Vec<u8>>();
+ assert_eq!(error_stream.write_all(&write_buffer), Err(Error::Internal));
+ }
+}