diff options
author | Roman Yepishev <ryepishev@google.com> | 2024-02-29 20:48:38 +0000 |
---|---|---|
committer | Roman Yepishev <ryepishev@google.com> | 2024-02-29 20:48:38 +0000 |
commit | 22ce981acab277a6809f53856ee2733d282bea31 (patch) | |
tree | 802782cc8230ccf2ea9ca80a0b4cf2cb2a430159 | |
parent | 051c16650263ae4e8305547e5a92d4fc441445e9 (diff) | |
parent | fabac8ea399f5da14f27b41eb543e75542904b78 (diff) | |
download | smoltcp-22ce981acab277a6809f53856ee2733d282bea31.tar.gz |
Merge remote-tracking branch 'origin/upstream'main
117 files changed, 54097 insertions, 0 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json new file mode 100644 index 0000000..4b22af8 --- /dev/null +++ b/.cargo_vcs_info.json @@ -0,0 +1,6 @@ +{ + "git": { + "sha1": "ce420118efff83b47767389500ef1562f5074b55" + }, + "path_in_vcs": "" +}
\ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..41ca801 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +*.pcap diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..4cf517e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,296 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +No unreleased changes yet. + +## [0.11.0] - 2023-12-23 + +### Additions + +- wire/ipsec: add basic IPsec parsing/emitting ([#821](https://github.com/smoltcp-rs/smoltcp/pull/821)). +- phy: add support for `TUNSETIFF` on MIPS, PPC and SPARC ([#839](https://github.com/smoltcp-rs/smoltcp/pull/839)). +- socket/tcp: accept FIN on zero window ([#845](https://github.com/smoltcp-rs/smoltcp/pull/845)). +- wire/ipv6: add `is_unique_local()` to IPv6 addresses ([#862](https://github.com/smoltcp-rs/smoltcp/pull/862)). +- wire/ipv6: add `is_global_unicast()` to IPv6 addresses ([#864](https://github.com/smoltcp-rs/smoltcp/pull/864)). +- iface/neigh: add `fill_with_expiration` ([#871](https://github.com/smoltcp-rs/smoltcp/pull/871)). + +### Fixes + +- icmpv6: truncate packet to MTU ([#807](https://github.com/smoltcp-rs/smoltcp/pull/807), [#808](https://github.com/smoltcp-rs/smoltcp/pull/810)). +- wire/rpl: DAO-ACK DODAG ID was wrongly read ([#824](https://github.com/smoltcp-rs/smoltcp/pull/824)). +- socket/tcp: don't panic when calling `listen` again on the same local endpoint ([#841](https://github.com/smoltcp-rs/smoltcp/pull/841)). +- wire/dhcpv4: don't panic when parsing addresses with incorrect amount of bytes ([#843](https://github.com/smoltcp-rs/smoltcp/pull/843)). +- iface/ndisc: prevent ndisc when the medium is IP ([#865](https://github.com/smoltcp-rs/smoltcp/pull/865)). +- wire/ieee802154: better parsing of security fields. Correctly parse frame type (3 bits instead of 2 bits) ([#868](https://github.com/smoltcp-rs/smoltcp/pull/864)). +- wire/ieee802154: better handle address fields for new frame version ([#870](https://github.com/smoltcp-rs/smoltcp/pull/870)). +- iface/tcp: don't send TCP RST with unspecified addresses ([#867](https://github.com/smoltcp-rs/smoltcp/pull/867)). +- iface: don't handle empty packets (this would panic when reading the IP version) ([#866](https://github.com/smoltcp-rs/smoltcp/pull/866)). +- socket/dhcp: Add an upper bound to the renew/rebind timeout in `RetryConfig` ([#835](https://github.com/smoltcp-rs/smoltcp/pull/835)). + +### Changes + +- iface: rewrite `IpPacket` such that IPv6 packets can contain owned extension headers ([#802](https://github.com/smoltcp-rs/smoltcp/pull/802)). +- iface: remove generic `T: [u8]` in functions. This reduced the server example by 10KB ([#810](https://github.com/smoltcp-rs/smoltcp/pull/810)). +- SocketSet: add comment about using static lifetime for SocketSets with owned storage ([#813](https://github.com/smoltcp-rs/smoltcp/pull/813)). +- phy/RawSocket: open raw socket with `O_NONBLOCK` ([#817](https://github.com/smoltcp-rs/smoltcp/pull/817)). +- tests/rstest: use rstest for fixture based testing ([#823](https://github.com/smoltcp-rs/smoltcp/pull/823)). +- docs/readme: update readme about IEEE802.15.4 and 6LoWPAN ([#826](https://github.com/smoltcp-rs/smoltcp/pull/826)). +- wire/ipv6-hbh: IPv6 HBH has owned options instead of references ([#827](https://github.com/smoltcp-rs/smoltcp/pull/827)). +- wire/sixlowpan: 6LoWPAN is split into multiple modules ([#828](https://github.com/smoltcp-rs/smoltcp/pull/828)). +- sockets: match the behaviour of `peek_slice` and `recv_slice` ([#834](https://github.com/smoltcp-rs/smoltcp/pull/834)). +- dependencies: update to headpless v0.8 ([#853](https://github.com/smoltcp-rs/smoltcp/pull/853)). +- config: make `config` constants public ([#855](https://github.com/smoltcp-rs/smoltcp/pull/855)). +- phy/ieee802154: clarify `mtu+=2` for IEEE802.15.4 ([#857](https://github.com/smoltcp-rs/smoltcp/pull/857)). +- sockets: `recv_slice` returns `RcvError::Truncated` when the length of the slice is smaller than the data received by the socket ([#859](https://github.com/smoltcp-rs/smoltcp/pull/859)). +- iface/ipv6: `get_source_address` uses [RFC 6724](https://www.rfc-editor.org/rfc/rfc6724) for address selection ([#864](https://github.com/smoltcp-rs/smoltcp/pull/864)). +- pcap: use IEEE 802.15.4 without FCS for PCAP link types ([#874](https://github.com/smoltcp-rs/smoltcp/pull/874)). +- iface: rename `IpPacket`/`Ipv4Packet`/`Ipv6Packet` to `Pacet`/`PacketV4`/`PacketV4`. This is to remove the ambiguity with `IpPacket` in `src/wire/` ([#873](https://github.com/smoltcp-rs/smoltcp/pull/873)). +- wire/ndisc: rewrite parse function (3.1KiB -> 1.9KiB) ([#878](https://github.com/smoltcp-rs/smoltcp/pull/878)) +- iface: Check IPv6 address after processing HBH ([#861](https://github.com/smoltcp-rs/smoltcp/pull/861)) + +## [0.10.0] - 2023-06-26 + +- Add optional packet metadata. Allows tracking packets by ID across the whole stack, between the `Device` impl and sockets. One application is timestamping packets with the PHY's collaboration, allowing implementing PTP (#628) +- Work-in-progress implementation of RPL (Routing Protocol for Low-Power and Lossy Networks), commonly used for IEEE 802.15.4 / 6LoWPAN networks. Wire is mostly complete, full functionality will be in 0.11 hopefully! (#627, #766, #767, #772, #773, #777, #790, #798, #804) +- dhcp: Add support for rebinding (#744) + +- iface: + - add support for sending to subnet-local broadcast addrs (like 192.168.1.255). (#801) + - Creating an interface requires passing in the time. (#799) + - fix wrong payload length of first IPv4 fragment (#791, #792) + - Don't discard from unspecified IPv4 src addresses (#787) + +- tcp: + - do not count window updates as duplicate acks. (#748) + - consider segments partially overlapping the window as acceptable (#749) + - Perform a reset() after an abort() (#788) + +- 6lowpan: + - Hop-by-Hop Header compression (#765) + - Routing Header compression (#770) + +- wire: + - reexport DNS opcode, rcode, flag. (#763, #806) + - refactor IPv6 Extension Headers to make them more consistent and easier to parse. (#781) + - check length field of NDISC redirected head (#784) + +- Modify `hardware_addr` and `neighbor_cache` to be not `Option`, add `HardwareAddress::Ip` (#745) +- Add file descriptor support for tuntap devices, needed for the Android VPN API. (#776) +- implement Display and Error for error types (#750, #756, #757) +- Better defmt for Instant, Duration and Ipv6Address (#754, #758) +- Add Hash trait for enum_with_unknown macro (#755) + +## [0.9.1] - 2023-02-08 + +- iface: make MulticastError public. (#747) +- Fix parsing of ieee802154 link layer address for NDISC options (#746) + +## [0.9.0] - 2023-02-06 + +- Minimum Supported Rust Version (MSRV) **bumped** from 1.56 to 1.65 +- Added DNS client support. + - Add DnsSocket (#465) + - Add support for one-shot mDNS resolution (#669) +- Added support for packet fragmentation and reassembly, both for IPv4 and 6LoWPAN. (#591, #580, #624, #634, #645, #653, #684) +- Major error handling overhaul. + - Previously, _smoltcp_ had a single `Error` enum that all methods returned. Now methods that can fail have their own error enums, with only the actual errors they can return. (#617, #667, #730) + - Consuming `phy::Device` tokens is now infallible. + - In the case of "buffer full", `phy::Device` implementations must return `None` from the `transmit`/`receive` methods. (Previously, they could either do that, or return tokens and then return `Error::Exhausted` when consuming them. The latter wasted computation since it'd make _smoltcp_ pointlessly spend effort preparing the packet, and is now disallowed). + - For all other phy errors, `phy::Device` implementations should drop the packet and handle the error themselves. (Either log it and forget it, or buffer/count it and offer methods to let the user retrieve the error queue/counts.) Returning the error to have it bubble up to `Interface::poll()` is no longer supported. +- phy: the `trait Device` now uses Generic Associated Types (GAT) for the TX and RX tokens. The main impact of this is `Device` impls can now borrow data (because previously, the`for<'a> T: Device<'a>` bounds required to workaround the lack of GATs essentially implied `T: 'static`.) (#572) +- iface: The `Interface` API has been significantly simplified and cleaned up. + - The builder has been removed (#736) + - SocketSet and Device are now borrowed in methods that need them, instead of owning them. (#619) + - `Interface` now owns the list of addresses (#719), routes, neighbor cache (#722), 6LoWPAN address contexts, and fragmentation buffers (#736) instead of borrowing them with `managed`. + - A new compile-time configuration mechanism has been added, to configure the size of the (now owned) buffers (#742) +- iface: Change neighbor discovery timeout from 3s to 1s, to match Linux's behavior. (#620) +- iface: Remove implicit sized bound on device generics (#679) +- iface/6lowpan: Add address context information for resolving 6LoWPAN addresses (#687) +- iface/6lowpan: fix incorrect SAM value in IPHC when address is not compressed (#630) +- iface/6lowpan: packet parsing fuzz fixes (#636) +- socket: Add send_with to udp, raw, and icmp sockets. These methods enable reserving a packet buffer with a greater size than you need, and then shrinking the size once you know it. (#625) +- socket: Make `trait AnySocket` object-safe (#718) +- socket/dhcpv4: add waker support (#623) +- socket/dhcpv4: indicate new config if there's a packet buffer provided (#685) +- socket/dhcpv4: Use renewal time from DHCP server ACK, if given (#683) +- socket/dhcpv4: allow for extra configuration + - setting arbitrary options in the request. (#650) + - retrieving arbitrary options from the response. (#650) + - setting custom parameter request list. (#650) + - setting custom timing for retries. (#650) + - Allow specifying different server/client DHCP ports (#738) +- socket/raw: Add `peek` and `peek_slice` methods (#734) +- socket/raw: When sending packets, send the source IP address unmodified (it was previously replaced with the interface's address if it was unspecified). (#616) +- socket/tcp: Do not reset socket-level settings, such as keepalive, on reset (#603) +- socket/tcp: ensure we always accept the segment at offset=0 even if the assembler is full. (#735, #452) +- socket/tcp: Refactored assembler, now more robust and faster (#726, #735) +- socket/udp: accept packets with checksum field set to `0`, since that means the checksum is not computed (#632) +- wire: make many functions const (#693) +- wire/dhcpv4: remove Option enum (#656) +- wire/dhcpv4: use heapless Vec for DNS server list (#678) +- wire/icmpv4: add support for TimeExceeded packets (#609) +- wire/ip: Remove `IpRepr::Unspecified`, `IpVersion::Unspecified`, `IpAddress::Unspecified` (#579, #616) +- wire/ip: support parsing unspecified IPv6 IpEndpoints from string (like `[::]:12345`) (#732) +- wire/ipv6: Make Public Ipv6RoutingType (#691) +- wire/ndisc: do not error on unrecognized options. (#737) +- Switch to Rust 2021 edition. (#729) +- Remove obsolete Cargo feature `rust-1_28` (#725) + +## [0.8.2] - 2022-11-27 + +- tcp: Fix return value of nagle_enable ([#642](https://github.com/smoltcp-rs/smoltcp/pull/642)) +- tcp: Only clear retransmit timer when all packets are acked ([#662](https://github.com/smoltcp-rs/smoltcp/pull/662)) +- tcp: Send incomplete fin packets even if nagle enabled ([#665](https://github.com/smoltcp-rs/smoltcp/pull/665)) +- phy: Fix mtu calculation for raw_socket ([#611](https://github.com/smoltcp-rs/smoltcp/pull/611)) +- wire: Fix ipv6 contains_addr function ([#605](https://github.com/smoltcp-rs/smoltcp/pull/605)) + +## [0.8.1] - 2022-05-12 + +- Remove unused `rand_core` dep. ([#589](https://github.com/smoltcp-rs/smoltcp/pull/589)) +- Use socklen_t instead of u32 for RawSocket bind() parameter. Fixes build on 32bit Android. ([#593](https://github.com/smoltcp-rs/smoltcp/pull/593)) +- Propagate phy::RawSocket send errors to caller ([#588](https://github.com/smoltcp-rs/smoltcp/pull/588)) +- Fix Interface set_hardware_addr, get_hardware_addr for ieee802154/6lowpan. ([#584](https://github.com/smoltcp-rs/smoltcp/pull/584)) + +## [0.8.0] - 2021-12-11 + +- Minimum Supported Rust Version (MSRV) **bumped** from 1.40 to 1.56 +- Add support for IEEE 802.15.4 + 6LoWPAN medium ([#469](https://github.com/smoltcp-rs/smoltcp/pull/469)) +- Add support for IP medium ([#401](https://github.com/smoltcp-rs/smoltcp/pull/401)) +- Add `defmt` logging support ([#455](https://github.com/smoltcp-rs/smoltcp/pull/455)) +- Add RNG infrastructure ([#547](https://github.com/smoltcp-rs/smoltcp/pull/547), [#573](https://github.com/smoltcp-rs/smoltcp/pull/573)) +- Add `Context` struct that must be passed to some socket methods ([#500](https://github.com/smoltcp-rs/smoltcp/pull/500)) +- Remove `SocketSet`, sockets are owned by `Interface` now. ([#557](https://github.com/smoltcp-rs/smoltcp/pull/557), [#571](https://github.com/smoltcp-rs/smoltcp/pull/571)) +- TCP: Add Nagle's Algorithm. ([#500](https://github.com/smoltcp-rs/smoltcp/pull/500)) +- TCP crash and correctness fixes: + - Add Nagle's Algorithm. ([#500](https://github.com/smoltcp-rs/smoltcp/pull/500)) + - Window scaling fixes. ([#500](https://github.com/smoltcp-rs/smoltcp/pull/500)) + - Fix delayed ack causing ack not to be sent after 3 packets. ([#530](https://github.com/smoltcp-rs/smoltcp/pull/530)) + - Fix RTT estimation for RTTs longer than 1 second ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix infinite loop when remote side sets a MSS of 0 ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix infinite loop when retransmit when remote window is 0 ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix crash when receiving a FIN in SYN_SENT state ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix overflow crash when receiving a wrong ACK seq in SYN_RECEIVED state ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix overflow crash when initial sequence number is u32::MAX ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix infinite loop on challenge ACKs ([#542](https://github.com/smoltcp-rs/smoltcp/pull/542)) + - Reply with RST to invalid packets in SynReceived state. ([#542](https://github.com/smoltcp-rs/smoltcp/pull/542)) + - Do not abort socket when receiving some invalid packets. ([#542](https://github.com/smoltcp-rs/smoltcp/pull/542)) + - Make initial sequence number random. ([#547](https://github.com/smoltcp-rs/smoltcp/pull/547)) + - Reply with RST to ACKs with invalid ackno in SYN_SENT. ([#522](https://github.com/smoltcp-rs/smoltcp/pull/522)) +- ARP fixes to deal better with broken networks: + - Fill cache only from ARP packets, not any packets. ([#544](https://github.com/smoltcp-rs/smoltcp/pull/544)) + - Fill cache only from ARP packets directed at us. ([#544](https://github.com/smoltcp-rs/smoltcp/pull/544)) + - Reject ARP packets with a source address not in the local network. ([#536](https://github.com/smoltcp-rs/smoltcp/pull/536), [#544](https://github.com/smoltcp-rs/smoltcp/pull/544)) + - Ignore unknown ARP packets. ([#544](https://github.com/smoltcp-rs/smoltcp/pull/544)) + - Flush neighbor cache on IP change ([#564](https://github.com/smoltcp-rs/smoltcp/pull/564)) +- UDP: Add `close()` method to unbind socket. ([#475](https://github.com/smoltcp-rs/smoltcp/pull/475), [#482](https://github.com/smoltcp-rs/smoltcp/pull/482)) +- DHCP client improvements: + - Refactored implementation to improve reliability and RFC compliance ([#459](https://github.com/smoltcp-rs/smoltcp/pull/459)) + - Convert to socket ([#459](https://github.com/smoltcp-rs/smoltcp/pull/459)) + - Added `max_lease_duration` option ([#459](https://github.com/smoltcp-rs/smoltcp/pull/459)) + - Do not set the BROADCAST flag ([#548](https://github.com/smoltcp-rs/smoltcp/pull/548)) + - Add option to ignore NAKs ([#548](https://github.com/smoltcp-rs/smoltcp/pull/548)) +- DHCP wire: + - Fix DhcpRepr::buffer_len not accounting for lease time, router and subnet options ([#478](https://github.com/smoltcp-rs/smoltcp/pull/478)) + - Emit DNS servers in DhcpRepr ([#510](https://github.com/smoltcp-rs/smoltcp/pull/510)) + - Fix incorrect bit for BROADCAST flag ([#548](https://github.com/smoltcp-rs/smoltcp/pull/548)) +- Improve resilience against packet ingress processing errors ([#281](https://github.com/smoltcp-rs/smoltcp/pull/281), [#483](https://github.com/smoltcp-rs/smoltcp/pull/483)) +- Implement `std::error::Error` for `smoltcp::Error` ([#485](https://github.com/smoltcp-rs/smoltcp/pull/485)) +- Update `managed` from 0.7 to 0.8 ([442](https://github.com/smoltcp-rs/smoltcp/pull/442)) +- Fix incorrect timestamp in PCAP captures ([#513](https://github.com/smoltcp-rs/smoltcp/pull/513)) +- Use microseconds instead of milliseconds in Instant and Duration ([#514](https://github.com/smoltcp-rs/smoltcp/pull/514)) +- Expose inner `Device` in `PcapWriter` ([#524](https://github.com/smoltcp-rs/smoltcp/pull/524)) +- Fix assert with any_ip + broadcast dst_addr. ([#533](https://github.com/smoltcp-rs/smoltcp/pull/533), [#534](https://github.com/smoltcp-rs/smoltcp/pull/534)) +- Simplify PcapSink trait ([#535](https://github.com/smoltcp-rs/smoltcp/pull/535)) +- Fix wrong operation order in FuzzInjector ([#525](https://github.com/smoltcp-rs/smoltcp/pull/525), [#535](https://github.com/smoltcp-rs/smoltcp/pull/535)) + +## [0.7.5] - 2021-06-28 + +- dhcpv4: emit DNS servers in repr ([#505](https://github.com/smoltcp-rs/smoltcp/pull/505)) + +## [0.7.4] - 2021-06-11 + +- tcp: fix "subtract sequence numbers with underflow" on remote window shrink. ([#490](https://github.com/smoltcp-rs/smoltcp/pull/490)) +- tcp: fix subtract with overflow when receiving a SYNACK with unincremented ACK number. ([#491](https://github.com/smoltcp-rs/smoltcp/pull/491)) +- tcp: use nonzero initial sequence number to workaround misbehaving servers. ([#492](https://github.com/smoltcp-rs/smoltcp/pull/492)) + +## [0.7.3] - 2021-05-29 + +- Fix "unused attribute" error in recent nightlies. + +## [0.7.2] - 2021-05-29 + +- iface: check for ipv4 subnet broadcast addrs everywhere ([#462](https://github.com/smoltcp-rs/smoltcp/pull/462)) +- dhcp: always send parameter_request_list. ([#456](https://github.com/smoltcp-rs/smoltcp/pull/456)) +- dhcp: Clear expiration time on reset. ([#456](https://github.com/smoltcp-rs/smoltcp/pull/456)) +- phy: fix FaultInjector returning a too big buffer when simulating a drop on tx ([#463](https://github.com/smoltcp-rs/smoltcp/pull/463)) +- tcp rtte: fix "attempt to multiply with overflow". ([#476](https://github.com/smoltcp-rs/smoltcp/pull/476)) +- tcp: LastAck should only change to Closed on ack of fin. ([#477](https://github.com/smoltcp-rs/smoltcp/pull/477)) +- wire/dhcpv4: account for lease time, router and subnet options in DhcpRepr::buffer_len ([#478](https://github.com/smoltcp-rs/smoltcp/pull/478)) + +## [0.7.1] - 2021-03-27 + +- ndisc: Fix NeighborSolicit incorrectly asking for src addr instead of dst addr ([419](https://github.com/smoltcp-rs/smoltcp/pull/419)) +- dhcpv4: respect lease time from the server instead of renewing every 60 seconds. ([437](https://github.com/smoltcp-rs/smoltcp/pull/437)) +- Fix build errors due to invalid combinations of features ([416](https://github.com/smoltcp-rs/smoltcp/pull/416), [447](https://github.com/smoltcp-rs/smoltcp/pull/447)) +- wire/ipv4: make some functions const ([420](https://github.com/smoltcp-rs/smoltcp/pull/420)) +- phy: fix BPF on OpenBSD ([421](https://github.com/smoltcp-rs/smoltcp/pull/421), [427](https://github.com/smoltcp-rs/smoltcp/pull/427)) +- phy: enable RawSocket, TapInterface on Android ([435](https://github.com/smoltcp-rs/smoltcp/pull/435)) +- phy: fix phy_wait for waits longer than 1 second ([449](https://github.com/smoltcp-rs/smoltcp/pull/449)) + +## [0.7.0] - 2021-01-20 + +- Minimum Supported Rust Version (MSRV) **bumped** from 1.36 to 1.40 + +### New features +- tcp: Allow distinguishing between graceful (FIN) and ungraceful (RST) close. On graceful close, `recv()` now returns `Error::Finished`. On ungraceful close, `Error::Illegal` is returned, as before. ([351](https://github.com/smoltcp-rs/smoltcp/pull/351)) +- sockets: Add support for attaching async/await Wakers to sockets. Wakers are woken on socket state changes. ([394](https://github.com/smoltcp-rs/smoltcp/pull/394)) +- tcp: Set retransmission timeout based on an RTT estimation, instead of the previously fixed 100ms. This improves performance on high-latency links, such as mobile networks. ([406](https://github.com/smoltcp-rs/smoltcp/pull/406)) +- tcp: add Delayed ACK support. On by default, with a 10ms delay. ([404](https://github.com/smoltcp-rs/smoltcp/pull/404)) +- ip: Process broadcast packets directed to the subnet's broadcast address, such as 192.168.1.255. Previously broadcast packets were +only processed when directed to the 255.255.255.255 address. ([377](https://github.com/smoltcp-rs/smoltcp/pull/377)) + +### Fixes +- udp,raw,icmp: Fix packet buffer panic caused by large payload ([332](https://github.com/smoltcp-rs/smoltcp/pull/332)) +- dhcpv4: use offered ip in requested ip option ([310](https://github.com/smoltcp-rs/smoltcp/pull/310)) +- dhcpv4: Re-export dhcp::clientv4::Config +- dhcpv4: Enable `proto-dhcpv4` feature by default. ([327](https://github.com/smoltcp-rs/smoltcp/pull/327)) +- ethernet,arp: Allow for ARP retry during egress ([368](https://github.com/smoltcp-rs/smoltcp/pull/368)) +- ethernet,arp: Only limit the neighbor cache rate after sending a request packet ([369](https://github.com/smoltcp-rs/smoltcp/pull/369)) +- tcp: use provided ip for TcpSocket::connect instead of 0.0.0.0 ([329](https://github.com/smoltcp-rs/smoltcp/pull/329)) +- tcp: Accept data packets in FIN_WAIT_2 state. ([350](https://github.com/smoltcp-rs/smoltcp/pull/350)) +- tcp: Always send updated ack number in `ack_reply()`. ([353](https://github.com/smoltcp-rs/smoltcp/pull/353)) +- tcp: allow sending ACKs in FinWait2 state. ([388](https://github.com/smoltcp-rs/smoltcp/pull/388)) +- tcp: fix racey simultaneous close not sending FIN. ([398](https://github.com/smoltcp-rs/smoltcp/pull/398)) +- tcp: Do not send window updates in states that shouldn't do so ([360](https://github.com/smoltcp-rs/smoltcp/pull/360)) +- tcp: Return RST to unexpected ACK in SYN-SENT state. ([367](https://github.com/smoltcp-rs/smoltcp/pull/367)) +- tcp: Take MTU into account during TcpSocket dispatch. ([384](https://github.com/smoltcp-rs/smoltcp/pull/384)) +- tcp: don't send data outside the remote window ([387](https://github.com/smoltcp-rs/smoltcp/pull/387)) +- phy: Take Ethernet header into account for MTU of RawSocket and TapInterface. ([393](https://github.com/smoltcp-rs/smoltcp/pull/393)) +- phy: add null terminator to c-string passed to libc API ([372](https://github.com/smoltcp-rs/smoltcp/pull/372)) + +### Quality of Life™ improvements +- Update to Rust 2018 edition ([396](https://github.com/smoltcp-rs/smoltcp/pull/396)) +- Migrate CI to Github Actions ([390](https://github.com/smoltcp-rs/smoltcp/pull/390)) +- Fix clippy lints, enforce clippy in CI ([395](https://github.com/smoltcp-rs/smoltcp/pull/395), [402](https://github.com/smoltcp-rs/smoltcp/pull/402), [403](https://github.com/smoltcp-rs/smoltcp/pull/403), [405](https://github.com/smoltcp-rs/smoltcp/pull/405), [407](https://github.com/smoltcp-rs/smoltcp/pull/407)) +- Use #[non_exhaustive] for enums and structs ([409](https://github.com/smoltcp-rs/smoltcp/pull/409), [411](https://github.com/smoltcp-rs/smoltcp/pull/411)) +- Simplify lifetime parameters of sockets, SocketSet, EthernetInterface ([410](https://github.com/smoltcp-rs/smoltcp/pull/410), [413](https://github.com/smoltcp-rs/smoltcp/pull/413)) + +[Unreleased]: https://github.com/smoltcp-rs/smoltcp/compare/v0.11.0...HEAD +[0.11.0]: https://github.com/smoltcp-rs/smoltcp/compare/v0.10.0...v0.11.0 +[0.10.0]: https://github.com/smoltcp-rs/smoltcp/compare/v0.9.1...v0.10.0 +[0.9.1]: https://github.com/smoltcp-rs/smoltcp/compare/v0.9.0...v0.9.1 +[0.9.0]: https://github.com/smoltcp-rs/smoltcp/compare/v0.8.2...v0.9.0 +[0.8.2]: https://github.com/smoltcp-rs/smoltcp/compare/v0.8.1...v0.8.2 +[0.8.1]: https://github.com/smoltcp-rs/smoltcp/compare/v0.8.0...v0.8.1 +[0.8.0]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.0...v0.8.0 +[0.7.5]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.4...v0.7.5 +[0.7.4]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.3...v0.7.4 +[0.7.3]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.2...v0.7.3 +[0.7.2]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.1...v0.7.2 +[0.7.1]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.0...v0.7.1 +[0.7.0]: https://github.com/smoltcp-rs/smoltcp/compare/v0.6.0...v0.7.0 diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..d67ab14 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,466 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +rust-version = "1.65" +name = "smoltcp" +version = "0.11.0" +authors = ["whitequark <whitequark@whitequark.org>"] +autoexamples = false +description = "A TCP/IP stack designed for bare-metal, real-time systems without a heap." +homepage = "https://github.com/smoltcp-rs/smoltcp" +documentation = "https://docs.rs/smoltcp/" +readme = "README.md" +keywords = [ + "ip", + "tcp", + "udp", + "ethernet", + "network", +] +categories = [ + "embedded", + "network-programming", +] +license = "0BSD" +repository = "https://github.com/smoltcp-rs/smoltcp.git" + +[profile.release] +debug = 2 + +[[example]] +name = "packet2pcap" +path = "utils/packet2pcap.rs" +required-features = ["std"] + +[[example]] +name = "tcpdump" +required-features = [ + "std", + "phy-raw_socket", + "proto-ipv4", +] + +[[example]] +name = "httpclient" +required-features = [ + "std", + "medium-ethernet", + "medium-ip", + "phy-tuntap_interface", + "proto-ipv4", + "proto-ipv6", + "socket-tcp", +] + +[[example]] +name = "ping" +required-features = [ + "std", + "medium-ethernet", + "medium-ip", + "phy-tuntap_interface", + "proto-ipv4", + "proto-ipv6", + "socket-icmp", +] + +[[example]] +name = "server" +required-features = [ + "std", + "medium-ethernet", + "medium-ip", + "phy-tuntap_interface", + "proto-ipv4", + "socket-tcp", + "socket-udp", +] + +[[example]] +name = "client" +required-features = [ + "std", + "medium-ethernet", + "medium-ip", + "phy-tuntap_interface", + "proto-ipv4", + "socket-tcp", + "socket-udp", +] + +[[example]] +name = "loopback" +required-features = [ + "log", + "medium-ethernet", + "proto-ipv4", + "socket-tcp", +] + +[[example]] +name = "multicast" +required-features = [ + "std", + "medium-ethernet", + "medium-ip", + "phy-tuntap_interface", + "proto-ipv4", + "proto-igmp", + "socket-udp", +] + +[[example]] +name = "benchmark" +required-features = [ + "std", + "medium-ethernet", + "medium-ip", + "phy-tuntap_interface", + "proto-ipv4", + "socket-raw", + "socket-udp", +] + +[[example]] +name = "dhcp_client" +required-features = [ + "std", + "medium-ethernet", + "medium-ip", + "phy-tuntap_interface", + "proto-ipv4", + "proto-dhcpv4", + "socket-raw", +] + +[[example]] +name = "sixlowpan" +required-features = [ + "std", + "medium-ieee802154", + "phy-raw_socket", + "proto-sixlowpan", + "proto-sixlowpan-fragmentation", + "socket-udp", +] + +[[example]] +name = "sixlowpan_benchmark" +required-features = [ + "std", + "medium-ieee802154", + "phy-raw_socket", + "proto-sixlowpan", + "proto-sixlowpan-fragmentation", + "socket-udp", +] + +[[example]] +name = "dns" +required-features = [ + "std", + "medium-ethernet", + "medium-ip", + "phy-tuntap_interface", + "proto-ipv4", + "socket-dns", +] + +[dependencies.bitflags] +version = "1.0" +default-features = false + +[dependencies.byteorder] +version = "1.0" +default-features = false + +[dependencies.cfg-if] +version = "1.0.0" + +[dependencies.defmt] +version = "0.3" +optional = true + +[dependencies.heapless] +version = "0.8" + +[dependencies.libc] +version = "0.2.18" +optional = true + +[dependencies.log] +version = "0.4.4" +optional = true +default-features = false + +[dependencies.managed] +version = "0.8" +features = ["map"] +default-features = false + +[dev-dependencies.env_logger] +version = "0.10" + +[dev-dependencies.getopts] +version = "0.2" + +[dev-dependencies.rand] +version = "0.8" + +[dev-dependencies.rstest] +version = "0.17" + +[dev-dependencies.url] +version = "2.0" + +[features] +_proto-fragmentation = [] +alloc = [ + "managed/alloc", + "defmt?/alloc", +] +assembler-max-segment-count-1 = [] +assembler-max-segment-count-16 = [] +assembler-max-segment-count-2 = [] +assembler-max-segment-count-3 = [] +assembler-max-segment-count-32 = [] +assembler-max-segment-count-4 = [] +assembler-max-segment-count-8 = [] +async = [] +default = [ + "std", + "log", + "medium-ethernet", + "medium-ip", + "medium-ieee802154", + "phy-raw_socket", + "phy-tuntap_interface", + "proto-ipv4", + "proto-igmp", + "proto-dhcpv4", + "proto-ipv6", + "proto-dns", + "proto-ipv4-fragmentation", + "proto-sixlowpan-fragmentation", + "socket-raw", + "socket-icmp", + "socket-udp", + "socket-tcp", + "socket-dhcpv4", + "socket-dns", + "socket-mdns", + "packetmeta-id", + "async", +] +defmt = [ + "dep:defmt", + "heapless/defmt-03", +] +dns-max-name-size-128 = [] +dns-max-name-size-255 = [] +dns-max-name-size-64 = [] +dns-max-result-count-1 = [] +dns-max-result-count-16 = [] +dns-max-result-count-2 = [] +dns-max-result-count-3 = [] +dns-max-result-count-32 = [] +dns-max-result-count-4 = [] +dns-max-result-count-8 = [] +dns-max-server-count-1 = [] +dns-max-server-count-16 = [] +dns-max-server-count-2 = [] +dns-max-server-count-3 = [] +dns-max-server-count-32 = [] +dns-max-server-count-4 = [] +dns-max-server-count-8 = [] +fragmentation-buffer-size-1024 = [] +fragmentation-buffer-size-1500 = [] +fragmentation-buffer-size-16384 = [] +fragmentation-buffer-size-2048 = [] +fragmentation-buffer-size-256 = [] +fragmentation-buffer-size-32768 = [] +fragmentation-buffer-size-4096 = [] +fragmentation-buffer-size-512 = [] +fragmentation-buffer-size-65536 = [] +fragmentation-buffer-size-8192 = [] +iface-max-addr-count-1 = [] +iface-max-addr-count-2 = [] +iface-max-addr-count-3 = [] +iface-max-addr-count-4 = [] +iface-max-addr-count-5 = [] +iface-max-addr-count-6 = [] +iface-max-addr-count-7 = [] +iface-max-addr-count-8 = [] +iface-max-multicast-group-count-1 = [] +iface-max-multicast-group-count-1024 = [] +iface-max-multicast-group-count-128 = [] +iface-max-multicast-group-count-16 = [] +iface-max-multicast-group-count-2 = [] +iface-max-multicast-group-count-256 = [] +iface-max-multicast-group-count-3 = [] +iface-max-multicast-group-count-32 = [] +iface-max-multicast-group-count-4 = [] +iface-max-multicast-group-count-5 = [] +iface-max-multicast-group-count-512 = [] +iface-max-multicast-group-count-6 = [] +iface-max-multicast-group-count-64 = [] +iface-max-multicast-group-count-7 = [] +iface-max-multicast-group-count-8 = [] +iface-max-route-count-1 = [] +iface-max-route-count-1024 = [] +iface-max-route-count-128 = [] +iface-max-route-count-16 = [] +iface-max-route-count-2 = [] +iface-max-route-count-256 = [] +iface-max-route-count-3 = [] +iface-max-route-count-32 = [] +iface-max-route-count-4 = [] +iface-max-route-count-5 = [] +iface-max-route-count-512 = [] +iface-max-route-count-6 = [] +iface-max-route-count-64 = [] +iface-max-route-count-7 = [] +iface-max-route-count-8 = [] +iface-max-sixlowpan-address-context-count-1 = [] +iface-max-sixlowpan-address-context-count-1024 = [] +iface-max-sixlowpan-address-context-count-128 = [] +iface-max-sixlowpan-address-context-count-16 = [] +iface-max-sixlowpan-address-context-count-2 = [] +iface-max-sixlowpan-address-context-count-256 = [] +iface-max-sixlowpan-address-context-count-3 = [] +iface-max-sixlowpan-address-context-count-32 = [] +iface-max-sixlowpan-address-context-count-4 = [] +iface-max-sixlowpan-address-context-count-5 = [] +iface-max-sixlowpan-address-context-count-512 = [] +iface-max-sixlowpan-address-context-count-6 = [] +iface-max-sixlowpan-address-context-count-64 = [] +iface-max-sixlowpan-address-context-count-7 = [] +iface-max-sixlowpan-address-context-count-8 = [] +iface-neighbor-cache-count-1 = [] +iface-neighbor-cache-count-1024 = [] +iface-neighbor-cache-count-128 = [] +iface-neighbor-cache-count-16 = [] +iface-neighbor-cache-count-2 = [] +iface-neighbor-cache-count-256 = [] +iface-neighbor-cache-count-3 = [] +iface-neighbor-cache-count-32 = [] +iface-neighbor-cache-count-4 = [] +iface-neighbor-cache-count-5 = [] +iface-neighbor-cache-count-512 = [] +iface-neighbor-cache-count-6 = [] +iface-neighbor-cache-count-64 = [] +iface-neighbor-cache-count-7 = [] +iface-neighbor-cache-count-8 = [] +ipv6-hbh-max-options-1 = [] +ipv6-hbh-max-options-16 = [] +ipv6-hbh-max-options-2 = [] +ipv6-hbh-max-options-3 = [] +ipv6-hbh-max-options-32 = [] +ipv6-hbh-max-options-4 = [] +ipv6-hbh-max-options-8 = [] +medium-ethernet = ["socket"] +medium-ieee802154 = [ + "socket", + "proto-sixlowpan", +] +medium-ip = ["socket"] +packetmeta-id = [] +phy-raw_socket = [ + "std", + "libc", +] +phy-tuntap_interface = [ + "std", + "libc", + "medium-ethernet", +] +proto-dhcpv4 = ["proto-ipv4"] +proto-dns = [] +proto-igmp = ["proto-ipv4"] +proto-ipsec = [ + "proto-ipsec-ah", + "proto-ipsec-esp", +] +proto-ipsec-ah = [] +proto-ipsec-esp = [] +proto-ipv4 = [] +proto-ipv4-fragmentation = [ + "proto-ipv4", + "_proto-fragmentation", +] +proto-ipv6 = [] +proto-ipv6-fragmentation = [ + "proto-ipv6", + "_proto-fragmentation", +] +proto-ipv6-hbh = ["proto-ipv6"] +proto-ipv6-routing = ["proto-ipv6"] +proto-rpl = [ + "proto-ipv6-hbh", + "proto-ipv6-routing", +] +proto-sixlowpan = ["proto-ipv6"] +proto-sixlowpan-fragmentation = [ + "proto-sixlowpan", + "_proto-fragmentation", +] +reassembly-buffer-count-1 = [] +reassembly-buffer-count-16 = [] +reassembly-buffer-count-2 = [] +reassembly-buffer-count-3 = [] +reassembly-buffer-count-32 = [] +reassembly-buffer-count-4 = [] +reassembly-buffer-count-8 = [] +reassembly-buffer-size-1024 = [] +reassembly-buffer-size-1500 = [] +reassembly-buffer-size-16384 = [] +reassembly-buffer-size-2048 = [] +reassembly-buffer-size-256 = [] +reassembly-buffer-size-32768 = [] +reassembly-buffer-size-4096 = [] +reassembly-buffer-size-512 = [] +reassembly-buffer-size-65536 = [] +reassembly-buffer-size-8192 = [] +rpl-parents-buffer-count-16 = [] +rpl-parents-buffer-count-2 = [] +rpl-parents-buffer-count-32 = [] +rpl-parents-buffer-count-4 = [] +rpl-parents-buffer-count-8 = [] +rpl-relations-buffer-count-1 = [] +rpl-relations-buffer-count-128 = [] +rpl-relations-buffer-count-16 = [] +rpl-relations-buffer-count-2 = [] +rpl-relations-buffer-count-32 = [] +rpl-relations-buffer-count-4 = [] +rpl-relations-buffer-count-64 = [] +rpl-relations-buffer-count-8 = [] +socket = [] +socket-dhcpv4 = [ + "socket", + "medium-ethernet", + "proto-dhcpv4", +] +socket-dns = [ + "socket", + "proto-dns", +] +socket-icmp = ["socket"] +socket-mdns = ["socket-dns"] +socket-raw = ["socket"] +socket-tcp = ["socket"] +socket-udp = ["socket"] +std = [ + "managed/std", + "alloc", +] +verbose = [] diff --git a/Cargo.toml.orig b/Cargo.toml.orig new file mode 100644 index 0000000..b3cd9c9 --- /dev/null +++ b/Cargo.toml.orig @@ -0,0 +1,305 @@ +[package] +name = "smoltcp" +version = "0.11.0" +edition = "2021" +rust-version = "1.65" +authors = ["whitequark <whitequark@whitequark.org>"] +description = "A TCP/IP stack designed for bare-metal, real-time systems without a heap." +documentation = "https://docs.rs/smoltcp/" +homepage = "https://github.com/smoltcp-rs/smoltcp" +repository = "https://github.com/smoltcp-rs/smoltcp.git" +readme = "README.md" +keywords = ["ip", "tcp", "udp", "ethernet", "network"] +categories = ["embedded", "network-programming"] +license = "0BSD" +# Each example should have an explicit `[[example]]` section here to +# ensure that the correct features are enabled. +autoexamples = false + +[dependencies] +managed = { version = "0.8", default-features = false, features = ["map"] } +byteorder = { version = "1.0", default-features = false } +log = { version = "0.4.4", default-features = false, optional = true } +libc = { version = "0.2.18", optional = true } +bitflags = { version = "1.0", default-features = false } +defmt = { version = "0.3", optional = true } +cfg-if = "1.0.0" +heapless = "0.8" + +[dev-dependencies] +env_logger = "0.10" +getopts = "0.2" +rand = "0.8" +url = "2.0" +rstest = "0.17" + +[features] +std = ["managed/std", "alloc"] +alloc = ["managed/alloc", "defmt?/alloc"] +verbose = [] +defmt = ["dep:defmt", "heapless/defmt-03"] +"medium-ethernet" = ["socket"] +"medium-ip" = ["socket"] +"medium-ieee802154" = ["socket", "proto-sixlowpan"] + +"phy-raw_socket" = ["std", "libc"] +"phy-tuntap_interface" = ["std", "libc", "medium-ethernet"] + +"proto-ipv4" = [] +"proto-ipv4-fragmentation" = ["proto-ipv4", "_proto-fragmentation"] +"proto-igmp" = ["proto-ipv4"] +"proto-dhcpv4" = ["proto-ipv4"] +"proto-ipv6" = [] +"proto-ipv6-hbh" = ["proto-ipv6"] +"proto-ipv6-fragmentation" = ["proto-ipv6", "_proto-fragmentation"] +"proto-ipv6-routing" = ["proto-ipv6"] +"proto-rpl" = ["proto-ipv6-hbh", "proto-ipv6-routing"] +"proto-sixlowpan" = ["proto-ipv6"] +"proto-sixlowpan-fragmentation" = ["proto-sixlowpan", "_proto-fragmentation"] +"proto-dns" = [] +"proto-ipsec" = ["proto-ipsec-ah", "proto-ipsec-esp"] +"proto-ipsec-ah" = [] +"proto-ipsec-esp" = [] + +"socket" = [] +"socket-raw" = ["socket"] +"socket-udp" = ["socket"] +"socket-tcp" = ["socket"] +"socket-icmp" = ["socket"] +"socket-dhcpv4" = ["socket", "medium-ethernet", "proto-dhcpv4"] +"socket-dns" = ["socket", "proto-dns"] +"socket-mdns" = ["socket-dns"] + +"packetmeta-id" = [] + +"async" = [] + +default = [ + "std", "log", # needed for `cargo test --no-default-features --features default` :/ + "medium-ethernet", "medium-ip", "medium-ieee802154", + "phy-raw_socket", "phy-tuntap_interface", + "proto-ipv4", "proto-igmp", "proto-dhcpv4", "proto-ipv6", "proto-dns", + "proto-ipv4-fragmentation", "proto-sixlowpan-fragmentation", + "socket-raw", "socket-icmp", "socket-udp", "socket-tcp", "socket-dhcpv4", "socket-dns", "socket-mdns", + "packetmeta-id", "async" +] + +# Private features +# Features starting with "_" are considered private. They should not be enabled by +# other crates, and they are not considered semver-stable. + +"_proto-fragmentation" = [] + +# BEGIN AUTOGENERATED CONFIG FEATURES +# Generated by gen_config.py. DO NOT EDIT. +iface-max-addr-count-1 = [] +iface-max-addr-count-2 = [] # Default +iface-max-addr-count-3 = [] +iface-max-addr-count-4 = [] +iface-max-addr-count-5 = [] +iface-max-addr-count-6 = [] +iface-max-addr-count-7 = [] +iface-max-addr-count-8 = [] + +iface-max-multicast-group-count-1 = [] +iface-max-multicast-group-count-2 = [] +iface-max-multicast-group-count-3 = [] +iface-max-multicast-group-count-4 = [] # Default +iface-max-multicast-group-count-5 = [] +iface-max-multicast-group-count-6 = [] +iface-max-multicast-group-count-7 = [] +iface-max-multicast-group-count-8 = [] +iface-max-multicast-group-count-16 = [] +iface-max-multicast-group-count-32 = [] +iface-max-multicast-group-count-64 = [] +iface-max-multicast-group-count-128 = [] +iface-max-multicast-group-count-256 = [] +iface-max-multicast-group-count-512 = [] +iface-max-multicast-group-count-1024 = [] + +iface-max-sixlowpan-address-context-count-1 = [] +iface-max-sixlowpan-address-context-count-2 = [] +iface-max-sixlowpan-address-context-count-3 = [] +iface-max-sixlowpan-address-context-count-4 = [] # Default +iface-max-sixlowpan-address-context-count-5 = [] +iface-max-sixlowpan-address-context-count-6 = [] +iface-max-sixlowpan-address-context-count-7 = [] +iface-max-sixlowpan-address-context-count-8 = [] +iface-max-sixlowpan-address-context-count-16 = [] +iface-max-sixlowpan-address-context-count-32 = [] +iface-max-sixlowpan-address-context-count-64 = [] +iface-max-sixlowpan-address-context-count-128 = [] +iface-max-sixlowpan-address-context-count-256 = [] +iface-max-sixlowpan-address-context-count-512 = [] +iface-max-sixlowpan-address-context-count-1024 = [] + +iface-neighbor-cache-count-1 = [] +iface-neighbor-cache-count-2 = [] +iface-neighbor-cache-count-3 = [] +iface-neighbor-cache-count-4 = [] # Default +iface-neighbor-cache-count-5 = [] +iface-neighbor-cache-count-6 = [] +iface-neighbor-cache-count-7 = [] +iface-neighbor-cache-count-8 = [] +iface-neighbor-cache-count-16 = [] +iface-neighbor-cache-count-32 = [] +iface-neighbor-cache-count-64 = [] +iface-neighbor-cache-count-128 = [] +iface-neighbor-cache-count-256 = [] +iface-neighbor-cache-count-512 = [] +iface-neighbor-cache-count-1024 = [] + +iface-max-route-count-1 = [] +iface-max-route-count-2 = [] # Default +iface-max-route-count-3 = [] +iface-max-route-count-4 = [] +iface-max-route-count-5 = [] +iface-max-route-count-6 = [] +iface-max-route-count-7 = [] +iface-max-route-count-8 = [] +iface-max-route-count-16 = [] +iface-max-route-count-32 = [] +iface-max-route-count-64 = [] +iface-max-route-count-128 = [] +iface-max-route-count-256 = [] +iface-max-route-count-512 = [] +iface-max-route-count-1024 = [] + +fragmentation-buffer-size-256 = [] +fragmentation-buffer-size-512 = [] +fragmentation-buffer-size-1024 = [] +fragmentation-buffer-size-1500 = [] # Default +fragmentation-buffer-size-2048 = [] +fragmentation-buffer-size-4096 = [] +fragmentation-buffer-size-8192 = [] +fragmentation-buffer-size-16384 = [] +fragmentation-buffer-size-32768 = [] +fragmentation-buffer-size-65536 = [] + +assembler-max-segment-count-1 = [] +assembler-max-segment-count-2 = [] +assembler-max-segment-count-3 = [] +assembler-max-segment-count-4 = [] # Default +assembler-max-segment-count-8 = [] +assembler-max-segment-count-16 = [] +assembler-max-segment-count-32 = [] + +reassembly-buffer-size-256 = [] +reassembly-buffer-size-512 = [] +reassembly-buffer-size-1024 = [] +reassembly-buffer-size-1500 = [] # Default +reassembly-buffer-size-2048 = [] +reassembly-buffer-size-4096 = [] +reassembly-buffer-size-8192 = [] +reassembly-buffer-size-16384 = [] +reassembly-buffer-size-32768 = [] +reassembly-buffer-size-65536 = [] + +reassembly-buffer-count-1 = [] # Default +reassembly-buffer-count-2 = [] +reassembly-buffer-count-3 = [] +reassembly-buffer-count-4 = [] +reassembly-buffer-count-8 = [] +reassembly-buffer-count-16 = [] +reassembly-buffer-count-32 = [] + +ipv6-hbh-max-options-1 = [] # Default +ipv6-hbh-max-options-2 = [] +ipv6-hbh-max-options-3 = [] +ipv6-hbh-max-options-4 = [] +ipv6-hbh-max-options-8 = [] +ipv6-hbh-max-options-16 = [] +ipv6-hbh-max-options-32 = [] + +dns-max-result-count-1 = [] # Default +dns-max-result-count-2 = [] +dns-max-result-count-3 = [] +dns-max-result-count-4 = [] +dns-max-result-count-8 = [] +dns-max-result-count-16 = [] +dns-max-result-count-32 = [] + +dns-max-server-count-1 = [] # Default +dns-max-server-count-2 = [] +dns-max-server-count-3 = [] +dns-max-server-count-4 = [] +dns-max-server-count-8 = [] +dns-max-server-count-16 = [] +dns-max-server-count-32 = [] + +dns-max-name-size-64 = [] +dns-max-name-size-128 = [] +dns-max-name-size-255 = [] # Default + +rpl-relations-buffer-count-1 = [] +rpl-relations-buffer-count-2 = [] +rpl-relations-buffer-count-4 = [] +rpl-relations-buffer-count-8 = [] +rpl-relations-buffer-count-16 = [] # Default +rpl-relations-buffer-count-32 = [] +rpl-relations-buffer-count-64 = [] +rpl-relations-buffer-count-128 = [] + +rpl-parents-buffer-count-2 = [] +rpl-parents-buffer-count-4 = [] +rpl-parents-buffer-count-8 = [] # Default +rpl-parents-buffer-count-16 = [] +rpl-parents-buffer-count-32 = [] + +# END AUTOGENERATED CONFIG FEATURES + +[[example]] +name = "packet2pcap" +path = "utils/packet2pcap.rs" +required-features = ["std"] + +[[example]] +name = "tcpdump" +required-features = ["std", "phy-raw_socket", "proto-ipv4"] + +[[example]] +name = "httpclient" +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "proto-ipv6", "socket-tcp"] + +[[example]] +name = "ping" +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "proto-ipv6", "socket-icmp"] + +[[example]] +name = "server" +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "socket-tcp", "socket-udp"] + +[[example]] +name = "client" +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "socket-tcp", "socket-udp"] + +[[example]] +name = "loopback" +required-features = ["log", "medium-ethernet", "proto-ipv4", "socket-tcp"] + +[[example]] +name = "multicast" +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "proto-igmp", "socket-udp"] + +[[example]] +name = "benchmark" +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "socket-raw", "socket-udp"] + +[[example]] +name = "dhcp_client" +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "proto-dhcpv4", "socket-raw"] + +[[example]] +name = "sixlowpan" +required-features = ["std", "medium-ieee802154", "phy-raw_socket", "proto-sixlowpan", "proto-sixlowpan-fragmentation", "socket-udp"] + +[[example]] +name = "sixlowpan_benchmark" +required-features = ["std", "medium-ieee802154", "phy-raw_socket", "proto-sixlowpan", "proto-sixlowpan-fragmentation", "socket-udp"] + +[[example]] +name = "dns" +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "socket-dns"] + +[profile.release] +debug = 2 @@ -0,0 +1 @@ +LICENSE-0BSD.txt
\ No newline at end of file diff --git a/LICENSE-0BSD.txt b/LICENSE-0BSD.txt new file mode 100644 index 0000000..427fafa --- /dev/null +++ b/LICENSE-0BSD.txt @@ -0,0 +1,13 @@ +Copyright (C) 2016 whitequark@whitequark.org + +Permission to use, copy, modify, and/or distribute this software for +any purpose with or without fee is hereby granted. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN +AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + diff --git a/METADATA b/METADATA new file mode 100644 index 0000000..ce996b4 --- /dev/null +++ b/METADATA @@ -0,0 +1,20 @@ +name: "smoltcp" +description: "A TCP/IP stack designed for bare-metal, real-time systems without a heap." +third_party { + identifier { + type: "crates.io" + value: "smoltcp" + } + identifier { + type: "Archive" + value: "https://static.crates.io/crates/smoltcp/smoltcp-0.11.0.crate" + primary_source: true + } + version: "0.11.0" + license_type: PERMISSIVE + last_upgrade_date { + year: 2024 + month: 2 + day: 17 + } +} diff --git a/MODULE_LICENSE_ZERO_BSD b/MODULE_LICENSE_ZERO_BSD new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/MODULE_LICENSE_ZERO_BSD @@ -0,0 +1,2 @@ +# Bug component: 688011 +include platform/prebuilts/rust:main:/OWNERS diff --git a/README.md b/README.md new file mode 100644 index 0000000..74fa161 --- /dev/null +++ b/README.md @@ -0,0 +1,560 @@ +# smoltcp + +[![docs.rs](https://docs.rs/smoltcp/badge.svg)](https://docs.rs/smoltcp) +[![crates.io](https://img.shields.io/crates/v/smoltcp.svg)](https://crates.io/crates/smoltcp) +[![crates.io](https://img.shields.io/crates/d/smoltcp.svg)](https://crates.io/crates/smoltcp) +[![crates.io](https://img.shields.io/matrix/smoltcp:matrix.org)](https://matrix.to/#/#smoltcp:matrix.org) +[![codecov](https://codecov.io/github/smoltcp-rs/smoltcp/branch/master/graph/badge.svg?token=3KbAR9xH1t)](https://codecov.io/github/smoltcp-rs/smoltcp) + +_smoltcp_ is a standalone, event-driven TCP/IP stack that is designed for bare-metal, +real-time systems. Its design goals are simplicity and robustness. Its design anti-goals +include complicated compile-time computations, such as macro or type tricks, even +at cost of performance degradation. + +_smoltcp_ does not need heap allocation *at all*, is [extensively documented][docs], +and compiles on stable Rust 1.65 and later. + +_smoltcp_ achieves [~Gbps of throughput](#examplesbenchmarkrs) when tested against +the Linux TCP stack in loopback mode. + +[docs]: https://docs.rs/smoltcp/ + +## Features + +_smoltcp_ is missing many widely deployed features, usually because no one implemented them yet. +To set expectations right, both implemented and omitted features are listed. + +### Media layer + +There are 3 supported mediums. + +* Ethernet + * Regular Ethernet II frames are supported. + * Unicast, broadcast and multicast packets are supported. + * ARP packets (including gratuitous requests and replies) are supported. + * ARP requests are sent at a rate not exceeding one per second. + * Cached ARP entries expire after one minute. + * 802.3 frames and 802.1Q are **not** supported. + * Jumbo frames are **not** supported. +* IP + * Unicast, broadcast and multicast packets are supported. +* IEEE 802.15.4 + * Only support for data frames. + +### IP layer + +#### IPv4 + + * IPv4 header checksum is generated and validated. + * IPv4 time-to-live value is configurable per socket, set to 64 by default. + * IPv4 default gateway is supported. + * Routing outgoing IPv4 packets is supported, through a default gateway or a CIDR route table. + * IPv4 fragmentation and reassembly is supported. + * IPv4 options are **not** supported and are silently ignored. + +#### IPv6 + + * IPv6 hop-limit value is configurable per socket, set to 64 by default. + * Routing outgoing IPv6 packets is supported, through a default gateway or a CIDR route table. + * IPv6 hop-by-hop header is supported. + * ICMPv6 parameter problem message is generated in response to an unrecognized IPv6 next header. + * ICMPv6 parameter problem message is **not** generated in response to an unknown IPv6 + hop-by-hop option. + +#### 6LoWPAN + + * Implementation of [RFC6282](https://tools.ietf.org/rfc/rfc6282.txt). + * Fragmentation is supported, as defined in [RFC4944](https://tools.ietf.org/rfc/rfc4944.txt). + * UDP header compression/decompression is supported. + * Extension header compression/decompression is supported. + * Uncompressed IPv6 Extension Headers are **not** supported. + +### IP multicast + +#### IGMP + +The IGMPv1 and IGMPv2 protocols are supported, and IPv4 multicast is available. + + * Membership reports are sent in response to membership queries at + equal intervals equal to the maximum response time divided by the + number of groups to be reported. + +### ICMP layer + +#### ICMPv4 + +The ICMPv4 protocol is supported, and ICMP sockets are available. + + * ICMPv4 header checksum is supported. + * ICMPv4 echo replies are generated in response to echo requests. + * ICMP sockets can listen to ICMPv4 Port Unreachable messages, or any ICMPv4 messages with + a given IPv4 identifier field. + * ICMPv4 protocol unreachable messages are **not** passed to higher layers when received. + * ICMPv4 parameter problem messages are **not** generated. + +#### ICMPv6 + +The ICMPv6 protocol is supported, and ICMP sockets are available. + + * ICMPv6 header checksum is supported. + * ICMPv6 echo replies are generated in response to echo requests. + * ICMPv6 protocol unreachable messages are **not** passed to higher layers when received. + +#### NDISC + + * Neighbor Advertisement messages are generated in response to Neighbor Solicitations. + * Router Advertisement messages are **not** generated or read. + * Router Solicitation messages are **not** generated or read. + * Redirected Header messages are **not** generated or read. + +### UDP layer + +The UDP protocol is supported over IPv4 and IPv6, and UDP sockets are available. + + * Header checksum is always generated and validated. + * In response to a packet arriving at a port without a listening socket, + an ICMP destination unreachable message is generated. + +### TCP layer + +The TCP protocol is supported over IPv4 and IPv6, and server and client TCP sockets are available. + + * Header checksum is generated and validated. + * Maximum segment size is negotiated. + * Window scaling is negotiated. + * Multiple packets are transmitted without waiting for an acknowledgement. + * Reassembly of out-of-order segments is supported, with no more than 4 or 32 gaps in sequence space. + * Keep-alive packets may be sent at a configurable interval. + * Retransmission timeout starts at at an estimate of RTT, and doubles every time. + * Time-wait timeout has a fixed interval of 10 s. + * User timeout has a configurable interval. + * Delayed acknowledgements are supported, with configurable delay. + * Nagle's algorithm is implemented. + * Selective acknowledgements are **not** implemented. + * Silly window syndrome avoidance is **not** implemented. + * Congestion control is **not** implemented. + * Timestamping is **not** supported. + * Urgent pointer is **ignored**. + * Probing Zero Windows is **not** implemented. + * Packetization Layer Path MTU Discovery [PLPMTU](https://tools.ietf.org/rfc/rfc4821.txt) is **not** implemented. + +## Installation + +To use the _smoltcp_ library in your project, add the following to `Cargo.toml`: + +```toml +[dependencies] +smoltcp = "0.10.0" +``` + +The default configuration assumes a hosted environment, for ease of evaluation. +You probably want to disable default features and configure them one by one: + +```toml +[dependencies] +smoltcp = { version = "0.10.0", default-features = false, features = ["log"] } +``` + +## Feature flags + +### Feature `std` + +The `std` feature enables use of objects and slices owned by the networking stack through a +dependency on `std::boxed::Box` and `std::vec::Vec`. + +This feature is enabled by default. + +### Feature `alloc` + +The `alloc` feature enables use of objects owned by the networking stack through a dependency +on collections from the `alloc` crate. This only works on nightly rustc. + +This feature is disabled by default. + +### Feature `log` + +The `log` feature enables logging of events within the networking stack through +the [log crate][log]. Normal events (e.g. buffer level or TCP state changes) are emitted with +the TRACE log level. Exceptional events (e.g. malformed packets) are emitted with +the DEBUG log level. + +[log]: https://crates.io/crates/log + +This feature is enabled by default. + +### Feature `defmt` + +The `defmt` feature enables logging of events with the [defmt crate][defmt]. + +[defmt]: https://crates.io/crates/defmt + +This feature is disabled by default, and cannot be used at the same time as `log`. + +### Feature `verbose` + +The `verbose` feature enables logging of events where the logging itself may incur very high +overhead. For example, emitting a log line every time an application reads or writes as little +as 1 octet from a socket is likely to overwhelm the application logic unless a `BufReader` +or `BufWriter` is used, which are of course not available on heap-less systems. + +This feature is disabled by default. + +### Features `phy-raw_socket` and `phy-tuntap_interface` + +Enable `smoltcp::phy::RawSocket` and `smoltcp::phy::TunTapInterface`, respectively. + +These features are enabled by default. + +### Features `socket-raw`, `socket-udp`, `socket-tcp`, `socket-icmp`, `socket-dhcpv4`, `socket-dns` + +Enable the corresponding socket type. + +These features are enabled by default. + +### Features `proto-ipv4`, `proto-ipv6` and `proto-sixlowpan` + +Enable [IPv4], [IPv6] and [6LoWPAN] respectively. + +[IPv4]: https://tools.ietf.org/rfc/rfc791.txt +[IPv6]: https://tools.ietf.org/rfc/rfc8200.txt +[6LoWPAN]: https://tools.ietf.org/rfc/rfc6282.txt + +## Configuration + +_smoltcp_ has some configuration settings that are set at compile time, affecting sizes +and counts of buffers. + +They can be set in two ways: + +- Via Cargo features: enable a feature like `<name>-<value>`. `name` must be in lowercase and +use dashes instead of underscores. For example. `iface-max-addr-count-3`. Only a selection of values +is available, check `Cargo.toml` for the list. +- Via environment variables at build time: set the variable named `SMOLTCP_<value>`. For example +`SMOLTCP_IFACE_MAX_ADDR_COUNT=3 cargo build`. You can also set them in the `[env]` section of `.cargo/config.toml`. +Any value can be set, unlike with Cargo features. + +Environment variables take precedence over Cargo features. If two Cargo features are enabled for the same setting +with different values, compilation fails. + +### `IFACE_MAX_ADDR_COUNT` + +Max amount of IP addresses that can be assigned to one interface (counting both IPv4 and IPv6 addresses). Default: 2. + +### `IFACE_MAX_MULTICAST_GROUP_COUNT` + +Max amount of multicast groups that can be joined by one interface. Default: 4. + +### `IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT` + +Max amount of 6LoWPAN address contexts that can be assigned to one interface. Default: 4. + +### `IFACE_NEIGHBOR_CACHE_COUNT` + +Amount of "IP address -> hardware address" entries the neighbor cache (also known as the "ARP cache" or the "ARP table") holds. Default: 4. + +### `IFACE_MAX_ROUTE_COUNT` + +Max amount of routes that can be added to one interface. Includes the default route. Includes both IPv4 and IPv6. Default: 2. + +### `FRAGMENTATION_BUFFER_SIZE` + +Size of the buffer used for fragmenting outgoing packets larger than the MTU. Packets larger than this setting will be dropped instead of fragmented. Default: 1500. + +### `ASSEMBLER_MAX_SEGMENT_COUNT` + +Maximum number of non-contiguous segments the assembler can hold. Used for both packet reassembly and TCP stream reassembly. Default: 4. + +### `REASSEMBLY_BUFFER_SIZE` + +Size of the buffer used for reassembling (de-fragmenting) incoming packets. If the reassembled packet is larger than this setting, it will be dropped instead of reassembled. Default: 1500. + +### `REASSEMBLY_BUFFER_COUNT` + +Number of reassembly buffers, i.e how many different incoming packets can be reassembled at the same time. Default: 1. + +### `DNS_MAX_RESULT_COUNT` + +Maximum amount of address results for a given DNS query that will be kept. For example, if this is set to 2 and the queried name has 4 `A` records, only the first 2 will be returned. Default: 1. + +### `DNS_MAX_SERVER_COUNT` + +Maximum amount of DNS servers that can be configured in one DNS socket. Default: 1. + +### `DNS_MAX_NAME_SIZE` + +Maximum length of DNS names that can be queried. Default: 255. + +### IPV6_HBH_MAX_OPTIONS + +The maximum amount of parsed options the IPv6 Hop-by-Hop header can hold. Default: 1. + +## Hosted usage examples + +_smoltcp_, being a freestanding networking stack, needs to be able to transmit and receive +raw frames. For testing purposes, we will use a regular OS, and run _smoltcp_ in +a userspace process. Only Linux is supported (right now). + +On \*nix OSes, transmitting and receiving raw frames normally requires superuser privileges, but +on Linux it is possible to create a _persistent tap interface_ that can be manipulated by +a specific user: + +```sh +sudo ip tuntap add name tap0 mode tap user $USER +sudo ip link set tap0 up +sudo ip addr add 192.168.69.100/24 dev tap0 +sudo ip -6 addr add fe80::100/64 dev tap0 +sudo ip -6 addr add fdaa::100/64 dev tap0 +sudo ip -6 route add fe80::/64 dev tap0 +sudo ip -6 route add fdaa::/64 dev tap0 +``` + +It's possible to let _smoltcp_ access Internet by enabling routing for the tap interface: + +```sh +sudo iptables -t nat -A POSTROUTING -s 192.168.69.0/24 -j MASQUERADE +sudo sysctl net.ipv4.ip_forward=1 +sudo ip6tables -t nat -A POSTROUTING -s fdaa::/64 -j MASQUERADE +sudo sysctl -w net.ipv6.conf.all.forwarding=1 + +# Some distros have a default policy of DROP. This allows the traffic. +sudo iptables -A FORWARD -i tap0 -s 192.168.69.0/24 -j ACCEPT +sudo iptables -A FORWARD -o tap0 -d 192.168.69.0/24 -j ACCEPT +``` + +### Bridged connection + +Instead of the routed connection above, you may also set up a bridged (switched) +connection. This will make smoltcp speak directly to your LAN, with real ARP, etc. +It is needed to run the DHCP example. + +NOTE: In this case, the examples' IP configuration must match your LAN's! + +NOTE: this ONLY works with actual wired Ethernet connections. It +will NOT work on a WiFi connection. + +```sh +# Replace with your wired Ethernet interface name +ETH=enp0s20f0u1u1 + +sudo modprobe bridge +sudo modprobe br_netfilter + +sudo sysctl -w net.bridge.bridge-nf-call-arptables=0 +sudo sysctl -w net.bridge.bridge-nf-call-ip6tables=0 +sudo sysctl -w net.bridge.bridge-nf-call-iptables=0 + +sudo ip tuntap add name tap0 mode tap user $USER +sudo brctl addbr br0 +sudo brctl addif br0 tap0 +sudo brctl addif br0 $ETH +sudo ip link set tap0 up +sudo ip link set $ETH up +sudo ip link set br0 up + +# This connects your host system to the internet, so you can use it +# at the same time you run the examples. +sudo dhcpcd br0 +``` + +To tear down: + +``` +sudo killall dhcpcd +sudo ip link set br0 down +sudo brctl delbr br0 +``` + +### Fault injection + +In order to demonstrate the response of _smoltcp_ to adverse network conditions, all examples +implement fault injection, available through command-line options: + + * The `--drop-chance` option randomly drops packets, with given probability in percents. + * The `--corrupt-chance` option randomly mutates one octet in a packet, with given + probability in percents. + * The `--size-limit` option drops packets larger than specified size. + * The `--tx-rate-limit` and `--rx-rate-limit` options set the amount of tokens for + a token bucket rate limiter, in packets per bucket. + * The `--shaping-interval` option sets the refill interval of a token bucket rate limiter, + in milliseconds. + +A good starting value for `--drop-chance` and `--corrupt-chance` is 15%. A good starting +value for `--?x-rate-limit` is 4 and `--shaping-interval` is 50 ms. + +Note that packets dropped by the fault injector still get traced; +the `rx: randomly dropping a packet` message indicates that the packet *above* it got dropped, +and the `tx: randomly dropping a packet` message indicates that the packet *below* it was. + +### Packet dumps + +All examples provide a `--pcap` option that writes a [libpcap] file containing a view of every +packet as it is seen by _smoltcp_. + +[libpcap]: https://wiki.wireshark.org/Development/LibpcapFileFormat + +### examples/tcpdump.rs + +_examples/tcpdump.rs_ is a tiny clone of the _tcpdump_ utility. + +Unlike the rest of the examples, it uses raw sockets, and so it can be used on regular interfaces, +e.g. `eth0` or `wlan0`, as well as the `tap0` interface we've created above. + +Read its [source code](/examples/tcpdump.rs), then run it as: + +```sh +cargo build --example tcpdump +sudo ./target/debug/examples/tcpdump eth0 +``` + +### examples/httpclient.rs + +_examples/httpclient.rs_ emulates a network host that can initiate HTTP requests. + +The host is assigned the hardware address `02-00-00-00-00-02`, IPv4 address `192.168.69.1`, and IPv6 address `fdaa::1`. + +Read its [source code](/examples/httpclient.rs), then run it as: + +```sh +cargo run --example httpclient -- --tap tap0 ADDRESS URL +``` + +For example: + +```sh +cargo run --example httpclient -- --tap tap0 93.184.216.34 http://example.org/ +``` + +or: + +```sh +cargo run --example httpclient -- --tap tap0 2606:2800:220:1:248:1893:25c8:1946 http://example.org/ +``` + +It connects to the given address (not a hostname) and URL, and prints any returned response data. +The TCP socket buffers are limited to 1024 bytes to make packet traces more interesting. + +### examples/ping.rs + +_examples/ping.rs_ implements a minimal version of the `ping` utility using raw sockets. + +The host is assigned the hardware address `02-00-00-00-00-02` and IPv4 address `192.168.69.1`. + +Read its [source code](/examples/ping.rs), then run it as: + +```sh +cargo run --example ping -- --tap tap0 ADDRESS +``` + +It sends a series of 4 ICMP ECHO\_REQUEST packets to the given address at one second intervals and +prints out a status line on each valid ECHO\_RESPONSE received. + +The first ECHO\_REQUEST packet is expected to be lost since arp\_cache is empty after startup; +the ECHO\_REQUEST packet is dropped and an ARP request is sent instead. + +Currently, netmasks are not implemented, and so the only address this example can reach +is the other endpoint of the tap interface, `192.168.69.100`. It cannot reach itself because +packets entering a tap interface do not loop back. + +### examples/server.rs + +_examples/server.rs_ emulates a network host that can respond to basic requests. + +The host is assigned the hardware address `02-00-00-00-00-01` and IPv4 address `192.168.69.1`. + +Read its [source code](/examples/server.rs), then run it as: + +```sh +cargo run --example server -- --tap tap0 +``` + +It responds to: + + * pings (`ping 192.168.69.1`); + * UDP packets on port 6969 (`socat stdio udp4-connect:192.168.69.1:6969 <<<"abcdefg"`), + where it will respond with reversed chunks of the input indefinitely; + * TCP connections on port 6969 (`socat stdio tcp4-connect:192.168.69.1:6969`), + where it will respond "hello" to any incoming connection and immediately close it; + * TCP connections on port 6970 (`socat stdio tcp4-connect:192.168.69.1:6970 <<<"abcdefg"`), + where it will respond with reversed chunks of the input indefinitely. + * TCP connections on port 6971 (`socat stdio tcp4-connect:192.168.69.1:6971 </dev/urandom`), + which will sink data. Also, keep-alive packets (every 1 s) and a user timeout (at 2 s) + are enabled on this port; try to trigger them using fault injection. + * TCP connections on port 6972 (`socat stdio tcp4-connect:192.168.69.1:6972 >/dev/null`), + which will source data. + +Except for the socket on port 6971. the buffers are only 64 bytes long, for convenience +of testing resource exhaustion conditions. + +### examples/client.rs + +_examples/client.rs_ emulates a network host that can initiate basic requests. + +The host is assigned the hardware address `02-00-00-00-00-02` and IPv4 address `192.168.69.2`. + +Read its [source code](/examples/client.rs), then run it as: + +```sh +cargo run --example client -- --tap tap0 ADDRESS PORT +``` + +It connects to the given address (not a hostname) and port (e.g. `socat stdio tcp4-listen:1234`), +and will respond with reversed chunks of the input indefinitely. + +### examples/benchmark.rs + +_examples/benchmark.rs_ implements a simple throughput benchmark. + +Read its [source code](/examples/benchmark.rs), then run it as: + +```sh +cargo run --release --example benchmark -- --tap tap0 [reader|writer] +``` + +It establishes a connection to itself from a different thread and reads or writes a large amount +of data in one direction. + +A typical result (achieved on a Intel Core i7-7500U CPU and a Linux 4.9.65 x86_64 kernel running +on a Dell XPS 13 9360 laptop) is as follows: + +``` +$ cargo run -q --release --example benchmark -- --tap tap0 reader +throughput: 2.556 Gbps +$ cargo run -q --release --example benchmark -- --tap tap0 writer +throughput: 5.301 Gbps +``` + +## Bare-metal usage examples + +Examples that use no services from the host OS are necessarily less illustrative than examples +that do. Because of this, only one such example is provided. + +### examples/loopback.rs + +_examples/loopback.rs_ sets up _smoltcp_ to talk with itself via a loopback interface. +Although it does not require `std`, this example still requires the `alloc` feature to run, as well as `log`, `proto-ipv4` and `socket-tcp`. + +Read its [source code](/examples/loopback.rs), then run it without `std`: + +```sh +cargo run --example loopback --no-default-features --features="log proto-ipv4 socket-tcp alloc" +``` + +... or with `std` (in this case the features don't have to be explicitly listed): + +```sh +cargo run --example loopback -- --pcap loopback.pcap +``` + +It opens a server and a client TCP socket, and transfers a chunk of data. You can examine +the packet exchange by opening `loopback.pcap` in [Wireshark]. + +If the `std` feature is enabled, it will print logs and packet dumps, and fault injection +is possible; otherwise, nothing at all will be displayed and no options are accepted. + +[wireshark]: https://wireshark.org + +## License + +_smoltcp_ is distributed under the terms of 0-clause BSD license. + +See [LICENSE-0BSD](LICENSE-0BSD.txt) for details. diff --git a/benches/bench.rs b/benches/bench.rs new file mode 100644 index 0000000..2738840 --- /dev/null +++ b/benches/bench.rs @@ -0,0 +1,117 @@ +#![feature(test)] + +mod wire { + use smoltcp::phy::ChecksumCapabilities; + use smoltcp::wire::{IpAddress, IpProtocol}; + #[cfg(feature = "proto-ipv4")] + use smoltcp::wire::{Ipv4Address, Ipv4Packet, Ipv4Repr}; + #[cfg(feature = "proto-ipv6")] + use smoltcp::wire::{Ipv6Address, Ipv6Packet, Ipv6Repr}; + use smoltcp::wire::{TcpControl, TcpPacket, TcpRepr, TcpSeqNumber}; + use smoltcp::wire::{UdpPacket, UdpRepr}; + + extern crate test; + + #[cfg(feature = "proto-ipv6")] + const SRC_ADDR: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ])); + #[cfg(feature = "proto-ipv6")] + const DST_ADDR: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + ])); + + #[cfg(all(not(feature = "proto-ipv6"), feature = "proto-ipv4"))] + const SRC_ADDR: IpAddress = IpAddress::Ipv4(Ipv4Address([192, 168, 1, 1])); + #[cfg(all(not(feature = "proto-ipv6"), feature = "proto-ipv4"))] + const DST_ADDR: IpAddress = IpAddress::Ipv4(Ipv4Address([192, 168, 1, 2])); + + #[bench] + #[cfg(any(feature = "proto-ipv6", feature = "proto-ipv4"))] + fn bench_emit_tcp(b: &mut test::Bencher) { + static PAYLOAD_BYTES: [u8; 400] = [0x2a; 400]; + let repr = TcpRepr { + src_port: 48896, + dst_port: 80, + control: TcpControl::Syn, + seq_number: TcpSeqNumber(0x01234567), + ack_number: None, + window_len: 0x0123, + window_scale: None, + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &PAYLOAD_BYTES, + }; + let mut bytes = vec![0xa5; repr.buffer_len()]; + + b.iter(|| { + let mut packet = TcpPacket::new_unchecked(&mut bytes); + repr.emit( + &mut packet, + &SRC_ADDR, + &DST_ADDR, + &ChecksumCapabilities::default(), + ); + }); + } + + #[bench] + #[cfg(any(feature = "proto-ipv6", feature = "proto-ipv4"))] + fn bench_emit_udp(b: &mut test::Bencher) { + static PAYLOAD_BYTES: [u8; 400] = [0x2a; 400]; + let repr = UdpRepr { + src_port: 48896, + dst_port: 80, + }; + let mut bytes = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()]; + + b.iter(|| { + let mut packet = UdpPacket::new_unchecked(&mut bytes); + repr.emit( + &mut packet, + &SRC_ADDR, + &DST_ADDR, + PAYLOAD_BYTES.len(), + |buf| buf.copy_from_slice(&PAYLOAD_BYTES), + &ChecksumCapabilities::default(), + ); + }); + } + + #[bench] + #[cfg(feature = "proto-ipv4")] + fn bench_emit_ipv4(b: &mut test::Bencher) { + let repr = Ipv4Repr { + src_addr: Ipv4Address([192, 168, 1, 1]), + dst_addr: Ipv4Address([192, 168, 1, 2]), + next_header: IpProtocol::Tcp, + payload_len: 100, + hop_limit: 64, + }; + let mut bytes = vec![0xa5; repr.buffer_len()]; + + b.iter(|| { + let mut packet = Ipv4Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet, &ChecksumCapabilities::default()); + }); + } + + #[bench] + #[cfg(feature = "proto-ipv6")] + fn bench_emit_ipv6(b: &mut test::Bencher) { + let repr = Ipv6Repr { + src_addr: Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), + dst_addr: Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]), + next_header: IpProtocol::Tcp, + payload_len: 100, + hop_limit: 64, + }; + let mut bytes = vec![0xa5; repr.buffer_len()]; + + b.iter(|| { + let mut packet = Ipv6Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet); + }); + } +} diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..54662ed --- /dev/null +++ b/build.rs @@ -0,0 +1,104 @@ +use std::collections::HashMap; +use std::fmt::Write; +use std::path::PathBuf; +use std::{env, fs}; + +static CONFIGS: &[(&str, usize)] = &[ + // BEGIN AUTOGENERATED CONFIG FEATURES + // Generated by gen_config.py. DO NOT EDIT. + ("IFACE_MAX_ADDR_COUNT", 2), + ("IFACE_MAX_MULTICAST_GROUP_COUNT", 4), + ("IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT", 4), + ("IFACE_NEIGHBOR_CACHE_COUNT", 4), + ("IFACE_MAX_ROUTE_COUNT", 2), + ("FRAGMENTATION_BUFFER_SIZE", 1500), + ("ASSEMBLER_MAX_SEGMENT_COUNT", 4), + ("REASSEMBLY_BUFFER_SIZE", 1500), + ("REASSEMBLY_BUFFER_COUNT", 1), + ("IPV6_HBH_MAX_OPTIONS", 1), + ("DNS_MAX_RESULT_COUNT", 1), + ("DNS_MAX_SERVER_COUNT", 1), + ("DNS_MAX_NAME_SIZE", 255), + ("RPL_RELATIONS_BUFFER_COUNT", 16), + ("RPL_PARENTS_BUFFER_COUNT", 8), + // END AUTOGENERATED CONFIG FEATURES +]; + +struct ConfigState { + value: usize, + seen_feature: bool, + seen_env: bool, +} + +fn main() { + // only rebuild if build.rs changed. Otherwise Cargo will rebuild if any + // other file changed. + println!("cargo:rerun-if-changed=build.rs"); + + // Rebuild if config envvar changed. + for (name, _) in CONFIGS { + println!("cargo:rerun-if-env-changed=SMOLTCP_{name}"); + } + + let mut configs = HashMap::new(); + for (name, default) in CONFIGS { + configs.insert( + *name, + ConfigState { + value: *default, + seen_env: false, + seen_feature: false, + }, + ); + } + + for (var, value) in env::vars() { + if let Some(name) = var.strip_prefix("SMOLTCP_") { + let Some(cfg) = configs.get_mut(name) else { + panic!("Unknown env var {name}") + }; + + let Ok(value) = value.parse::<usize>() else { + panic!("Invalid value for env var {name}: {value}") + }; + + cfg.value = value; + cfg.seen_env = true; + } + + if let Some(feature) = var.strip_prefix("CARGO_FEATURE_") { + if let Some(i) = feature.rfind('_') { + let name = &feature[..i]; + let value = &feature[i + 1..]; + if let Some(cfg) = configs.get_mut(name) { + let Ok(value) = value.parse::<usize>() else { + panic!("Invalid value for feature {name}: {value}") + }; + + // envvars take priority. + if !cfg.seen_env { + if cfg.seen_feature { + panic!( + "multiple values set for feature {}: {} and {}", + name, cfg.value, value + ); + } + + cfg.value = value; + cfg.seen_feature = true; + } + } + } + } + } + + let mut data = String::new(); + + for (name, cfg) in &configs { + writeln!(&mut data, "pub const {}: usize = {};", name, cfg.value).unwrap(); + } + + let out_dir = PathBuf::from(env::var_os("OUT_DIR").unwrap()); + let out_file = out_dir.join("config.rs").to_string_lossy().to_string(); + fs::write(out_file, data).unwrap(); +} @@ -0,0 +1,120 @@ +#!/usr/bin/env bash + +set -eox pipefail + +export DEFMT_LOG=trace + +MSRV="1.65.0" + +RUSTC_VERSIONS=( + $MSRV + "stable" + "nightly" +) + +FEATURES_TEST=( + "default" + "std,proto-ipv4" + "std,medium-ethernet,phy-raw_socket,proto-ipv6,socket-udp,socket-dns" + "std,medium-ethernet,phy-tuntap_interface,proto-ipv6,socket-udp" + "std,medium-ethernet,proto-ipv4,proto-ipv4-fragmentation,socket-raw,socket-dns" + "std,medium-ethernet,proto-ipv4,proto-igmp,socket-raw,socket-dns" + "std,medium-ethernet,proto-ipv4,socket-udp,socket-tcp,socket-dns" + "std,medium-ethernet,proto-ipv4,proto-dhcpv4,socket-udp" + "std,medium-ethernet,medium-ip,medium-ieee802154,proto-ipv6,socket-udp,socket-dns" + "std,medium-ethernet,proto-ipv6,socket-tcp" + "std,medium-ethernet,medium-ip,proto-ipv4,socket-icmp,socket-tcp" + "std,medium-ip,proto-ipv6,socket-icmp,socket-tcp" + "std,medium-ieee802154,proto-sixlowpan,socket-udp" + "std,medium-ieee802154,proto-sixlowpan,proto-sixlowpan-fragmentation,socket-udp" + "std,medium-ieee802154,proto-rpl,proto-sixlowpan,proto-sixlowpan-fragmentation,socket-udp" + "std,medium-ip,proto-ipv4,proto-ipv6,socket-tcp,socket-udp" + "std,medium-ethernet,medium-ip,medium-ieee802154,proto-ipv4,proto-ipv6,socket-raw,socket-udp,socket-tcp,socket-icmp,socket-dns,async" + "std,medium-ieee802154,medium-ip,proto-ipv4,socket-raw" + "std,medium-ethernet,proto-ipv4,proto-ipsec,socket-raw" +) + +FEATURES_TEST_NIGHTLY=( + "alloc,medium-ethernet,proto-ipv4,proto-ipv6,socket-raw,socket-udp,socket-tcp,socket-icmp" +) + +FEATURES_CHECK=( + "medium-ip,medium-ethernet,medium-ieee802154,proto-ipv6,proto-ipv6,proto-igmp,proto-dhcpv4,proto-ipsec,socket-raw,socket-udp,socket-tcp,socket-icmp,socket-dns,async" + "defmt,medium-ip,medium-ethernet,proto-ipv6,proto-ipv6,proto-igmp,proto-dhcpv4,socket-raw,socket-udp,socket-tcp,socket-icmp,socket-dns,async" + "defmt,alloc,medium-ip,medium-ethernet,proto-ipv6,proto-ipv6,proto-igmp,proto-dhcpv4,socket-raw,socket-udp,socket-tcp,socket-icmp,socket-dns,async" +) + +test() { + local version=$1 + rustup toolchain install $version + + for features in ${FEATURES_TEST[@]}; do + cargo +$version test --no-default-features --features "$features" + done + + if [[ $version == "nightly" ]]; then + for features in ${FEATURES_TEST_NIGHTLY[@]}; do + cargo +$version test --no-default-features --features "$features" + done + fi +} + +check() { + local version=$1 + rustup toolchain install $version + + export DEFMT_LOG="trace" + + for features in ${FEATURES_CHECK[@]}; do + cargo +$version check --no-default-features --features "$features" + done +} + +clippy() { + rustup toolchain install $MSRV + rustup component add clippy --toolchain=$MSRV + cargo +$MSRV clippy --tests --examples -- -D warnings +} + +coverage() { + for features in ${FEATURES_TEST[@]}; do + cargo llvm-cov --no-report --no-default-features --features "$features" + done + cargo llvm-cov report --lcov --output-path lcov.info +} + +if [[ $1 == "test" || $1 == "all" ]]; then + if [[ -n $2 ]]; then + if [[ $2 == "msrv" ]]; then + test $MSRV + else + test $2 + fi + else + for version in ${RUSTC_VERSIONS[@]}; do + test $version + done + fi +fi + +if [[ $1 == "check" || $1 == "all" ]]; then + if [[ -n $2 ]]; then + if [[ $2 == "msrv" ]]; then + check $MSRV + else + check $2 + fi + else + for version in ${RUSTC_VERSIONS[@]}; do + check $version + done + fi +fi + +if [[ $1 == "clippy" || $1 == "all" ]]; then + clippy +fi + +if [[ $1 == "coverage" || $1 == "all" ]]; then + coverage +fi diff --git a/examples/benchmark.rs b/examples/benchmark.rs new file mode 100644 index 0000000..ad2c6e1 --- /dev/null +++ b/examples/benchmark.rs @@ -0,0 +1,164 @@ +#![allow(clippy::collapsible_if)] + +mod utils; + +use std::cmp; +use std::io::{Read, Write}; +use std::net::TcpStream; +use std::os::unix::io::AsRawFd; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::thread; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::tcp; +use smoltcp::time::{Duration, Instant}; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr}; + +const AMOUNT: usize = 1_000_000_000; + +enum Client { + Reader, + Writer, +} + +fn client(kind: Client) { + let port = match kind { + Client::Reader => 1234, + Client::Writer => 1235, + }; + let mut stream = TcpStream::connect(("192.168.69.1", port)).unwrap(); + let mut buffer = vec![0; 1_000_000]; + + let start = Instant::now(); + + let mut processed = 0; + while processed < AMOUNT { + let length = cmp::min(buffer.len(), AMOUNT - processed); + let result = match kind { + Client::Reader => stream.read(&mut buffer[..length]), + Client::Writer => stream.write(&buffer[..length]), + }; + match result { + Ok(0) => break, + Ok(result) => { + // print!("(P:{})", result); + processed += result + } + Err(err) => panic!("cannot process: {err}"), + } + } + + let end = Instant::now(); + + let elapsed = (end - start).total_millis() as f64 / 1000.0; + + println!("throughput: {:.3} Gbps", AMOUNT as f64 / elapsed / 0.125e9); + + CLIENT_DONE.store(true, Ordering::SeqCst); +} + +static CLIENT_DONE: AtomicBool = AtomicBool::new(false); + +fn main() { + #[cfg(feature = "log")] + utils::setup_logging("info"); + + let (mut opts, mut free) = utils::create_options(); + utils::add_tuntap_options(&mut opts, &mut free); + utils::add_middleware_options(&mut opts, &mut free); + free.push("MODE"); + + let mut matches = utils::parse_options(&opts, free); + let device = utils::parse_tuntap_options(&mut matches); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + let mode = match matches.free[0].as_ref() { + "reader" => Client::Reader, + "writer" => Client::Writer, + _ => panic!("invalid mode"), + }; + + let tcp1_rx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp1_tx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp1_socket = tcp::Socket::new(tcp1_rx_buffer, tcp1_tx_buffer); + + let tcp2_rx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp2_tx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp2_socket = tcp::Socket::new(tcp2_rx_buffer, tcp2_tx_buffer); + + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + }); + + let mut sockets = SocketSet::new(vec![]); + let tcp1_handle = sockets.add(tcp1_socket); + let tcp2_handle = sockets.add(tcp2_socket); + let default_timeout = Some(Duration::from_millis(1000)); + + thread::spawn(move || client(mode)); + let mut processed = 0; + while !CLIENT_DONE.load(Ordering::SeqCst) { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + // tcp:1234: emit data + let socket = sockets.get_mut::<tcp::Socket>(tcp1_handle); + if !socket.is_open() { + socket.listen(1234).unwrap(); + } + + if socket.can_send() { + if processed < AMOUNT { + let length = socket + .send(|buffer| { + let length = cmp::min(buffer.len(), AMOUNT - processed); + (length, length) + }) + .unwrap(); + processed += length; + } + } + + // tcp:1235: sink data + let socket = sockets.get_mut::<tcp::Socket>(tcp2_handle); + if !socket.is_open() { + socket.listen(1235).unwrap(); + } + + if socket.can_recv() { + if processed < AMOUNT { + let length = socket + .recv(|buffer| { + let length = cmp::min(buffer.len(), AMOUNT - processed); + (length, length) + }) + .unwrap(); + processed += length; + } + } + + match iface.poll_at(timestamp, &sockets) { + Some(poll_at) if timestamp < poll_at => { + phy_wait(fd, Some(poll_at - timestamp)).expect("wait error"); + } + Some(_) => (), + None => { + phy_wait(fd, default_timeout).expect("wait error"); + } + } + } +} diff --git a/examples/client.rs b/examples/client.rs new file mode 100644 index 0000000..c18c08f --- /dev/null +++ b/examples/client.rs @@ -0,0 +1,118 @@ +mod utils; + +use log::debug; +use std::os::unix::io::AsRawFd; +use std::str::{self, FromStr}; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::tcp; +use smoltcp::time::Instant; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address}; + +fn main() { + utils::setup_logging(""); + + let (mut opts, mut free) = utils::create_options(); + utils::add_tuntap_options(&mut opts, &mut free); + utils::add_middleware_options(&mut opts, &mut free); + free.push("ADDRESS"); + free.push("PORT"); + + let mut matches = utils::parse_options(&opts, free); + let device = utils::parse_tuntap_options(&mut matches); + + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + let address = IpAddress::from_str(&matches.free[0]).expect("invalid address format"); + let port = u16::from_str(&matches.free[1]).expect("invalid port format"); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets + let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; 1500]); + let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; 1500]); + let tcp_socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); + let mut sockets = SocketSet::new(vec![]); + let tcp_handle = sockets.add(tcp_socket); + + let socket = sockets.get_mut::<tcp::Socket>(tcp_handle); + socket + .connect(iface.context(), (address, port), 49500) + .unwrap(); + + let mut tcp_active = false; + loop { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + let socket = sockets.get_mut::<tcp::Socket>(tcp_handle); + if socket.is_active() && !tcp_active { + debug!("connected"); + } else if !socket.is_active() && tcp_active { + debug!("disconnected"); + break; + } + tcp_active = socket.is_active(); + + if socket.may_recv() { + let data = socket + .recv(|data| { + let mut data = data.to_owned(); + if !data.is_empty() { + debug!( + "recv data: {:?}", + str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)") + ); + data = data.split(|&b| b == b'\n').collect::<Vec<_>>().concat(); + data.reverse(); + data.extend(b"\n"); + } + (data.len(), data) + }) + .unwrap(); + if socket.can_send() && !data.is_empty() { + debug!( + "send data: {:?}", + str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)") + ); + socket.send_slice(&data[..]).unwrap(); + } + } else if socket.may_send() { + debug!("close"); + socket.close(); + } + + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); + } +} diff --git a/examples/dhcp_client.rs b/examples/dhcp_client.rs new file mode 100644 index 0000000..9ef46c2 --- /dev/null +++ b/examples/dhcp_client.rs @@ -0,0 +1,94 @@ +#![allow(clippy::option_map_unit_fn)] +mod utils; + +use log::*; +use std::os::unix::io::AsRawFd; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::socket::dhcpv4; +use smoltcp::time::Instant; +use smoltcp::wire::{EthernetAddress, IpCidr, Ipv4Address, Ipv4Cidr}; +use smoltcp::{ + phy::{wait as phy_wait, Device, Medium}, + time::Duration, +}; + +fn main() { + #[cfg(feature = "log")] + utils::setup_logging(""); + + let (mut opts, mut free) = utils::create_options(); + utils::add_tuntap_options(&mut opts, &mut free); + utils::add_middleware_options(&mut opts, &mut free); + + let mut matches = utils::parse_options(&opts, free); + let device = utils::parse_tuntap_options(&mut matches); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + + // Create sockets + let mut dhcp_socket = dhcpv4::Socket::new(); + + // Set a ridiculously short max lease time to show DHCP renews work properly. + // This will cause the DHCP client to start renewing after 5 seconds, and give up the + // lease after 10 seconds if renew hasn't succeeded. + // IMPORTANT: This should be removed in production. + dhcp_socket.set_max_lease_duration(Some(Duration::from_secs(10))); + + let mut sockets = SocketSet::new(vec![]); + let dhcp_handle = sockets.add(dhcp_socket); + + loop { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + let event = sockets.get_mut::<dhcpv4::Socket>(dhcp_handle).poll(); + match event { + None => {} + Some(dhcpv4::Event::Configured(config)) => { + debug!("DHCP config acquired!"); + + debug!("IP address: {}", config.address); + set_ipv4_addr(&mut iface, config.address); + + if let Some(router) = config.router { + debug!("Default gateway: {}", router); + iface.routes_mut().add_default_ipv4_route(router).unwrap(); + } else { + debug!("Default gateway: None"); + iface.routes_mut().remove_default_ipv4_route(); + } + + for (i, s) in config.dns_servers.iter().enumerate() { + debug!("DNS server {}: {}", i, s); + } + } + Some(dhcpv4::Event::Deconfigured) => { + debug!("DHCP lost config!"); + set_ipv4_addr(&mut iface, Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0)); + iface.routes_mut().remove_default_ipv4_route(); + } + } + + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); + } +} + +fn set_ipv4_addr(iface: &mut Interface, cidr: Ipv4Cidr) { + iface.update_ip_addrs(|addrs| { + let dest = addrs.iter_mut().next().unwrap(); + *dest = IpCidr::Ipv4(cidr); + }); +} diff --git a/examples/dns.rs b/examples/dns.rs new file mode 100644 index 0000000..977f405 --- /dev/null +++ b/examples/dns.rs @@ -0,0 +1,92 @@ +mod utils; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::Device; +use smoltcp::phy::{wait as phy_wait, Medium}; +use smoltcp::socket::dns::{self, GetQueryResultError}; +use smoltcp::time::Instant; +use smoltcp::wire::{DnsQueryType, EthernetAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address}; +use std::os::unix::io::AsRawFd; + +fn main() { + utils::setup_logging("warn"); + + let (mut opts, mut free) = utils::create_options(); + utils::add_tuntap_options(&mut opts, &mut free); + utils::add_middleware_options(&mut opts, &mut free); + free.push("ADDRESS"); + + let mut matches = utils::parse_options(&opts, free); + let device = utils::parse_tuntap_options(&mut matches); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + let name = &matches.free[0]; + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets + let servers = &[ + Ipv4Address::new(8, 8, 4, 4).into(), + Ipv4Address::new(8, 8, 8, 8).into(), + ]; + let dns_socket = dns::Socket::new(servers, vec![]); + + let mut sockets = SocketSet::new(vec![]); + let dns_handle = sockets.add(dns_socket); + + let socket = sockets.get_mut::<dns::Socket>(dns_handle); + let query = socket + .start_query(iface.context(), name, DnsQueryType::A) + .unwrap(); + + loop { + let timestamp = Instant::now(); + log::debug!("timestamp {:?}", timestamp); + + iface.poll(timestamp, &mut device, &mut sockets); + + match sockets + .get_mut::<dns::Socket>(dns_handle) + .get_query_result(query) + { + Ok(addrs) => { + println!("Query done: {addrs:?}"); + break; + } + Err(GetQueryResultError::Pending) => {} // not done yet + Err(e) => panic!("query failed: {e:?}"), + } + + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); + } +} diff --git a/examples/httpclient.rs b/examples/httpclient.rs new file mode 100644 index 0000000..8f3a53a --- /dev/null +++ b/examples/httpclient.rs @@ -0,0 +1,123 @@ +mod utils; + +use log::debug; +use std::os::unix::io::AsRawFd; +use std::str::{self, FromStr}; +use url::Url; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::tcp; +use smoltcp::time::Instant; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address}; + +fn main() { + utils::setup_logging(""); + + let (mut opts, mut free) = utils::create_options(); + utils::add_tuntap_options(&mut opts, &mut free); + utils::add_middleware_options(&mut opts, &mut free); + free.push("ADDRESS"); + free.push("URL"); + + let mut matches = utils::parse_options(&opts, free); + let device = utils::parse_tuntap_options(&mut matches); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + let address = IpAddress::from_str(&matches.free[0]).expect("invalid address format"); + let url = Url::parse(&matches.free[1]).expect("invalid url format"); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets + let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; 1024]); + let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; 1024]); + let tcp_socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); + + let mut sockets = SocketSet::new(vec![]); + let tcp_handle = sockets.add(tcp_socket); + + enum State { + Connect, + Request, + Response, + } + let mut state = State::Connect; + + loop { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + let socket = sockets.get_mut::<tcp::Socket>(tcp_handle); + let cx = iface.context(); + + state = match state { + State::Connect if !socket.is_active() => { + debug!("connecting"); + let local_port = 49152 + rand::random::<u16>() % 16384; + socket + .connect(cx, (address, url.port().unwrap_or(80)), local_port) + .unwrap(); + State::Request + } + State::Request if socket.may_send() => { + debug!("sending request"); + let http_get = "GET ".to_owned() + url.path() + " HTTP/1.1\r\n"; + socket.send_slice(http_get.as_ref()).expect("cannot send"); + let http_host = "Host: ".to_owned() + url.host_str().unwrap() + "\r\n"; + socket.send_slice(http_host.as_ref()).expect("cannot send"); + socket + .send_slice(b"Connection: close\r\n") + .expect("cannot send"); + socket.send_slice(b"\r\n").expect("cannot send"); + State::Response + } + State::Response if socket.can_recv() => { + socket + .recv(|data| { + println!("{}", str::from_utf8(data).unwrap_or("(invalid utf8)")); + (data.len(), ()) + }) + .unwrap(); + State::Response + } + State::Response if !socket.may_recv() => { + debug!("received complete response"); + break; + } + _ => state, + }; + + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); + } +} diff --git a/examples/loopback.rs b/examples/loopback.rs new file mode 100644 index 0000000..7ca95b1 --- /dev/null +++ b/examples/loopback.rs @@ -0,0 +1,184 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![allow(unused_mut)] +#![allow(clippy::collapsible_if)] + +#[cfg(feature = "std")] +#[allow(dead_code)] +mod utils; + +use core::str; +use log::{debug, error, info}; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{Device, Loopback, Medium}; +use smoltcp::socket::tcp; +use smoltcp::time::{Duration, Instant}; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr}; + +#[cfg(not(feature = "std"))] +mod mock { + use core::cell::Cell; + use smoltcp::time::{Duration, Instant}; + + #[derive(Debug)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct Clock(Cell<Instant>); + + impl Clock { + pub fn new() -> Clock { + Clock(Cell::new(Instant::from_millis(0))) + } + + pub fn advance(&self, duration: Duration) { + self.0.set(self.0.get() + duration) + } + + pub fn elapsed(&self) -> Instant { + self.0.get() + } + } +} + +#[cfg(feature = "std")] +mod mock { + use smoltcp::time::{Duration, Instant}; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + // should be AtomicU64 but that's unstable + #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct Clock(Arc<AtomicUsize>); + + impl Clock { + pub fn new() -> Clock { + Clock(Arc::new(AtomicUsize::new(0))) + } + + pub fn advance(&self, duration: Duration) { + self.0 + .fetch_add(duration.total_millis() as usize, Ordering::SeqCst); + } + + pub fn elapsed(&self) -> Instant { + Instant::from_millis(self.0.load(Ordering::SeqCst) as i64) + } + } +} + +fn main() { + let clock = mock::Clock::new(); + let device = Loopback::new(Medium::Ethernet); + + #[cfg(feature = "std")] + let mut device = { + let clock = clock.clone(); + utils::setup_logging_with_clock("", move || clock.elapsed()); + + let (mut opts, mut free) = utils::create_options(); + utils::add_middleware_options(&mut opts, &mut free); + + let mut matches = utils::parse_options(&opts, free); + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ true) + }; + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8)) + .unwrap(); + }); + + // Create sockets + let server_socket = { + // It is not strictly necessary to use a `static mut` and unsafe code here, but + // on embedded systems that smoltcp targets it is far better to allocate the data + // statically to verify that it fits into RAM rather than get undefined behavior + // when stack overflows. + static mut TCP_SERVER_RX_DATA: [u8; 1024] = [0; 1024]; + static mut TCP_SERVER_TX_DATA: [u8; 1024] = [0; 1024]; + let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer) + }; + + let client_socket = { + static mut TCP_CLIENT_RX_DATA: [u8; 1024] = [0; 1024]; + static mut TCP_CLIENT_TX_DATA: [u8; 1024] = [0; 1024]; + let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] }); + let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] }); + tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer) + }; + + let mut sockets: [_; 2] = Default::default(); + let mut sockets = SocketSet::new(&mut sockets[..]); + let server_handle = sockets.add(server_socket); + let client_handle = sockets.add(client_socket); + + let mut did_listen = false; + let mut did_connect = false; + let mut done = false; + while !done && clock.elapsed() < Instant::from_millis(10_000) { + iface.poll(clock.elapsed(), &mut device, &mut sockets); + + let mut socket = sockets.get_mut::<tcp::Socket>(server_handle); + if !socket.is_active() && !socket.is_listening() { + if !did_listen { + debug!("listening"); + socket.listen(1234).unwrap(); + did_listen = true; + } + } + + if socket.can_recv() { + debug!( + "got {:?}", + socket.recv(|buffer| { (buffer.len(), str::from_utf8(buffer).unwrap()) }) + ); + socket.close(); + done = true; + } + + let mut socket = sockets.get_mut::<tcp::Socket>(client_handle); + let cx = iface.context(); + if !socket.is_open() { + if !did_connect { + debug!("connecting"); + socket + .connect(cx, (IpAddress::v4(127, 0, 0, 1), 1234), 65000) + .unwrap(); + did_connect = true; + } + } + + if socket.can_send() { + debug!("sending"); + socket.send_slice(b"0123456789abcdef").unwrap(); + socket.close(); + } + + match iface.poll_delay(clock.elapsed(), &sockets) { + Some(Duration::ZERO) => debug!("resuming"), + Some(delay) => { + debug!("sleeping for {} ms", delay); + clock.advance(delay) + } + None => clock.advance(Duration::from_millis(1)), + } + } + + if done { + info!("done") + } else { + error!("this is taking too long, bailing out") + } +} diff --git a/examples/multicast.rs b/examples/multicast.rs new file mode 100644 index 0000000..ea89a2e --- /dev/null +++ b/examples/multicast.rs @@ -0,0 +1,129 @@ +mod utils; + +use std::os::unix::io::AsRawFd; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::{raw, udp}; +use smoltcp::time::Instant; +use smoltcp::wire::{ + EthernetAddress, IgmpPacket, IgmpRepr, IpAddress, IpCidr, IpProtocol, IpVersion, Ipv4Address, + Ipv4Packet, Ipv6Address, +}; + +const MDNS_PORT: u16 = 5353; +const MDNS_GROUP: [u8; 4] = [224, 0, 0, 251]; + +fn main() { + utils::setup_logging("warn"); + + let (mut opts, mut free) = utils::create_options(); + utils::add_tuntap_options(&mut opts, &mut free); + utils::add_middleware_options(&mut opts, &mut free); + + let mut matches = utils::parse_options(&opts, free); + let device = utils::parse_tuntap_options(&mut matches); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets + let mut sockets = SocketSet::new(vec![]); + + // Must fit at least one IGMP packet + let raw_rx_buffer = raw::PacketBuffer::new(vec![raw::PacketMetadata::EMPTY; 2], vec![0; 512]); + // Will not send IGMP + let raw_tx_buffer = raw::PacketBuffer::new(vec![], vec![]); + let raw_socket = raw::Socket::new( + IpVersion::Ipv4, + IpProtocol::Igmp, + raw_rx_buffer, + raw_tx_buffer, + ); + let raw_handle = sockets.add(raw_socket); + + // Must fit mDNS payload of at least one packet + let udp_rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY; 4], vec![0; 1024]); + // Will not send mDNS + let udp_tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 0]); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); + let udp_handle = sockets.add(udp_socket); + + // Join a multicast group to receive mDNS traffic + iface + .join_multicast_group( + &mut device, + Ipv4Address::from_bytes(&MDNS_GROUP), + Instant::now(), + ) + .unwrap(); + + loop { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + let socket = sockets.get_mut::<raw::Socket>(raw_handle); + + if socket.can_recv() { + // For display purposes only - normally we wouldn't process incoming IGMP packets + // in the application layer + match socket.recv() { + Err(e) => println!("Recv IGMP error: {e:?}"), + Ok(buf) => { + Ipv4Packet::new_checked(buf) + .and_then(|ipv4_packet| IgmpPacket::new_checked(ipv4_packet.payload())) + .and_then(|igmp_packet| IgmpRepr::parse(&igmp_packet)) + .map(|igmp_repr| println!("IGMP packet: {igmp_repr:?}")) + .unwrap_or_else(|e| println!("parse IGMP error: {e:?}")); + } + } + } + + let socket = sockets.get_mut::<udp::Socket>(udp_handle); + if !socket.is_open() { + socket.bind(MDNS_PORT).unwrap() + } + + if socket.can_recv() { + socket + .recv() + .map(|(data, sender)| { + println!("mDNS traffic: {} UDP bytes from {}", data.len(), sender) + }) + .unwrap_or_else(|e| println!("Recv UDP error: {e:?}")); + } + + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); + } +} diff --git a/examples/ping.rs b/examples/ping.rs new file mode 100644 index 0000000..7e33a21 --- /dev/null +++ b/examples/ping.rs @@ -0,0 +1,281 @@ +mod utils; + +use byteorder::{ByteOrder, NetworkEndian}; +use smoltcp::iface::{Interface, SocketSet}; +use std::cmp; +use std::collections::HashMap; +use std::os::unix::io::AsRawFd; +use std::str::FromStr; + +use smoltcp::iface::Config; +use smoltcp::phy::wait as phy_wait; +use smoltcp::phy::Device; +use smoltcp::socket::icmp; +use smoltcp::wire::{ + EthernetAddress, Icmpv4Packet, Icmpv4Repr, Icmpv6Packet, Icmpv6Repr, IpAddress, IpCidr, + Ipv4Address, Ipv6Address, +}; +use smoltcp::{ + phy::Medium, + time::{Duration, Instant}, +}; + +macro_rules! send_icmp_ping { + ( $repr_type:ident, $packet_type:ident, $ident:expr, $seq_no:expr, + $echo_payload:expr, $socket:expr, $remote_addr:expr ) => {{ + let icmp_repr = $repr_type::EchoRequest { + ident: $ident, + seq_no: $seq_no, + data: &$echo_payload, + }; + + let icmp_payload = $socket.send(icmp_repr.buffer_len(), $remote_addr).unwrap(); + + let icmp_packet = $packet_type::new_unchecked(icmp_payload); + (icmp_repr, icmp_packet) + }}; +} + +macro_rules! get_icmp_pong { + ( $repr_type:ident, $repr:expr, $payload:expr, $waiting_queue:expr, $remote_addr:expr, + $timestamp:expr, $received:expr ) => {{ + if let $repr_type::EchoReply { seq_no, data, .. } = $repr { + if let Some(_) = $waiting_queue.get(&seq_no) { + let packet_timestamp_ms = NetworkEndian::read_i64(data); + println!( + "{} bytes from {}: icmp_seq={}, time={}ms", + data.len(), + $remote_addr, + seq_no, + $timestamp.total_millis() - packet_timestamp_ms + ); + $waiting_queue.remove(&seq_no); + $received += 1; + } + } + }}; +} + +fn main() { + utils::setup_logging("warn"); + + let (mut opts, mut free) = utils::create_options(); + utils::add_tuntap_options(&mut opts, &mut free); + utils::add_middleware_options(&mut opts, &mut free); + opts.optopt( + "c", + "count", + "Amount of echo request packets to send (default: 4)", + "COUNT", + ); + opts.optopt( + "i", + "interval", + "Interval between successive packets sent (seconds) (default: 1)", + "INTERVAL", + ); + opts.optopt( + "", + "timeout", + "Maximum wait duration for an echo response packet (seconds) (default: 5)", + "TIMEOUT", + ); + free.push("ADDRESS"); + + let mut matches = utils::parse_options(&opts, free); + let device = utils::parse_tuntap_options(&mut matches); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + let device_caps = device.capabilities(); + let remote_addr = IpAddress::from_str(&matches.free[0]).expect("invalid address format"); + let count = matches + .opt_str("count") + .map(|s| usize::from_str(&s).unwrap()) + .unwrap_or(4); + let interval = matches + .opt_str("interval") + .map(|s| Duration::from_secs(u64::from_str(&s).unwrap())) + .unwrap_or_else(|| Duration::from_secs(1)); + let timeout = Duration::from_secs( + matches + .opt_str("timeout") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(5), + ); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets + let icmp_rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 256]); + let icmp_tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 256]); + let icmp_socket = icmp::Socket::new(icmp_rx_buffer, icmp_tx_buffer); + let mut sockets = SocketSet::new(vec![]); + let icmp_handle = sockets.add(icmp_socket); + + let mut send_at = Instant::from_millis(0); + let mut seq_no = 0; + let mut received = 0; + let mut echo_payload = [0xffu8; 40]; + let mut waiting_queue = HashMap::new(); + let ident = 0x22b; + + loop { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + let timestamp = Instant::now(); + let socket = sockets.get_mut::<icmp::Socket>(icmp_handle); + if !socket.is_open() { + socket.bind(icmp::Endpoint::Ident(ident)).unwrap(); + send_at = timestamp; + } + + if socket.can_send() && seq_no < count as u16 && send_at <= timestamp { + NetworkEndian::write_i64(&mut echo_payload, timestamp.total_millis()); + + match remote_addr { + IpAddress::Ipv4(_) => { + let (icmp_repr, mut icmp_packet) = send_icmp_ping!( + Icmpv4Repr, + Icmpv4Packet, + ident, + seq_no, + echo_payload, + socket, + remote_addr + ); + icmp_repr.emit(&mut icmp_packet, &device_caps.checksum); + } + IpAddress::Ipv6(address) => { + let (icmp_repr, mut icmp_packet) = send_icmp_ping!( + Icmpv6Repr, + Icmpv6Packet, + ident, + seq_no, + echo_payload, + socket, + remote_addr + ); + icmp_repr.emit( + &iface + .get_source_address_ipv6(&address) + .unwrap() + .into_address(), + &remote_addr, + &mut icmp_packet, + &device_caps.checksum, + ); + } + } + + waiting_queue.insert(seq_no, timestamp); + seq_no += 1; + send_at += interval; + } + + if socket.can_recv() { + let (payload, _) = socket.recv().unwrap(); + + match remote_addr { + IpAddress::Ipv4(_) => { + let icmp_packet = Icmpv4Packet::new_checked(&payload).unwrap(); + let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &device_caps.checksum).unwrap(); + get_icmp_pong!( + Icmpv4Repr, + icmp_repr, + payload, + waiting_queue, + remote_addr, + timestamp, + received + ); + } + IpAddress::Ipv6(address) => { + let icmp_packet = Icmpv6Packet::new_checked(&payload).unwrap(); + let icmp_repr = Icmpv6Repr::parse( + &remote_addr, + &iface + .get_source_address_ipv6(&address) + .unwrap() + .into_address(), + &icmp_packet, + &device_caps.checksum, + ) + .unwrap(); + get_icmp_pong!( + Icmpv6Repr, + icmp_repr, + payload, + waiting_queue, + remote_addr, + timestamp, + received + ); + } + } + } + + waiting_queue.retain(|seq, from| { + if timestamp - *from < timeout { + true + } else { + println!("From {remote_addr} icmp_seq={seq} timeout"); + false + } + }); + + if seq_no == count as u16 && waiting_queue.is_empty() { + break; + } + + let timestamp = Instant::now(); + match iface.poll_at(timestamp, &sockets) { + Some(poll_at) if timestamp < poll_at => { + let resume_at = cmp::min(poll_at, send_at); + phy_wait(fd, Some(resume_at - timestamp)).expect("wait error"); + } + Some(_) => (), + None => { + phy_wait(fd, Some(send_at - timestamp)).expect("wait error"); + } + } + } + + println!("--- {remote_addr} ping statistics ---"); + println!( + "{} packets transmitted, {} received, {:.0}% packet loss", + seq_no, + received, + 100.0 * (seq_no - received) as f64 / seq_no as f64 + ); +} diff --git a/examples/server.rs b/examples/server.rs new file mode 100644 index 0000000..33d95c5 --- /dev/null +++ b/examples/server.rs @@ -0,0 +1,209 @@ +mod utils; + +use log::debug; +use std::fmt::Write; +use std::os::unix::io::AsRawFd; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::{tcp, udp}; +use smoltcp::time::{Duration, Instant}; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address}; + +fn main() { + utils::setup_logging(""); + + let (mut opts, mut free) = utils::create_options(); + utils::add_tuntap_options(&mut opts, &mut free); + utils::add_middleware_options(&mut opts, &mut free); + + let mut matches = utils::parse_options(&opts, free); + let device = utils::parse_tuntap_options(&mut matches); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets + let udp_rx_buffer = udp::PacketBuffer::new( + vec![udp::PacketMetadata::EMPTY, udp::PacketMetadata::EMPTY], + vec![0; 65535], + ); + let udp_tx_buffer = udp::PacketBuffer::new( + vec![udp::PacketMetadata::EMPTY, udp::PacketMetadata::EMPTY], + vec![0; 65535], + ); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); + + let tcp1_rx_buffer = tcp::SocketBuffer::new(vec![0; 64]); + let tcp1_tx_buffer = tcp::SocketBuffer::new(vec![0; 128]); + let tcp1_socket = tcp::Socket::new(tcp1_rx_buffer, tcp1_tx_buffer); + + let tcp2_rx_buffer = tcp::SocketBuffer::new(vec![0; 64]); + let tcp2_tx_buffer = tcp::SocketBuffer::new(vec![0; 128]); + let tcp2_socket = tcp::Socket::new(tcp2_rx_buffer, tcp2_tx_buffer); + + let tcp3_rx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp3_tx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp3_socket = tcp::Socket::new(tcp3_rx_buffer, tcp3_tx_buffer); + + let tcp4_rx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp4_tx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp4_socket = tcp::Socket::new(tcp4_rx_buffer, tcp4_tx_buffer); + + let mut sockets = SocketSet::new(vec![]); + let udp_handle = sockets.add(udp_socket); + let tcp1_handle = sockets.add(tcp1_socket); + let tcp2_handle = sockets.add(tcp2_socket); + let tcp3_handle = sockets.add(tcp3_socket); + let tcp4_handle = sockets.add(tcp4_socket); + + let mut tcp_6970_active = false; + loop { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + // udp:6969: respond "hello" + let socket = sockets.get_mut::<udp::Socket>(udp_handle); + if !socket.is_open() { + socket.bind(6969).unwrap() + } + + let client = match socket.recv() { + Ok((data, endpoint)) => { + debug!("udp:6969 recv data: {:?} from {}", data, endpoint); + let mut data = data.to_vec(); + data.reverse(); + Some((endpoint, data)) + } + Err(_) => None, + }; + if let Some((endpoint, data)) = client { + debug!("udp:6969 send data: {:?} to {}", data, endpoint,); + socket.send_slice(&data, endpoint).unwrap(); + } + + // tcp:6969: respond "hello" + let socket = sockets.get_mut::<tcp::Socket>(tcp1_handle); + if !socket.is_open() { + socket.listen(6969).unwrap(); + } + + if socket.can_send() { + debug!("tcp:6969 send greeting"); + writeln!(socket, "hello").unwrap(); + debug!("tcp:6969 close"); + socket.close(); + } + + // tcp:6970: echo with reverse + let socket = sockets.get_mut::<tcp::Socket>(tcp2_handle); + if !socket.is_open() { + socket.listen(6970).unwrap() + } + + if socket.is_active() && !tcp_6970_active { + debug!("tcp:6970 connected"); + } else if !socket.is_active() && tcp_6970_active { + debug!("tcp:6970 disconnected"); + } + tcp_6970_active = socket.is_active(); + + if socket.may_recv() { + let data = socket + .recv(|buffer| { + let recvd_len = buffer.len(); + let mut data = buffer.to_owned(); + if !data.is_empty() { + debug!("tcp:6970 recv data: {:?}", data); + data = data.split(|&b| b == b'\n').collect::<Vec<_>>().concat(); + data.reverse(); + data.extend(b"\n"); + } + (recvd_len, data) + }) + .unwrap(); + if socket.can_send() && !data.is_empty() { + debug!("tcp:6970 send data: {:?}", data); + socket.send_slice(&data[..]).unwrap(); + } + } else if socket.may_send() { + debug!("tcp:6970 close"); + socket.close(); + } + + // tcp:6971: sinkhole + let socket = sockets.get_mut::<tcp::Socket>(tcp3_handle); + if !socket.is_open() { + socket.listen(6971).unwrap(); + socket.set_keep_alive(Some(Duration::from_millis(1000))); + socket.set_timeout(Some(Duration::from_millis(2000))); + } + + if socket.may_recv() { + socket + .recv(|buffer| { + if !buffer.is_empty() { + debug!("tcp:6971 recv {:?} octets", buffer.len()); + } + (buffer.len(), ()) + }) + .unwrap(); + } else if socket.may_send() { + socket.close(); + } + + // tcp:6972: fountain + let socket = sockets.get_mut::<tcp::Socket>(tcp4_handle); + if !socket.is_open() { + socket.listen(6972).unwrap() + } + + if socket.may_send() { + socket + .send(|data| { + if !data.is_empty() { + debug!("tcp:6972 send {:?} octets", data.len()); + for (i, b) in data.iter_mut().enumerate() { + *b = (i % 256) as u8; + } + } + (data.len(), ()) + }) + .unwrap(); + } + + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); + } +} diff --git a/examples/sixlowpan.rs b/examples/sixlowpan.rs new file mode 100644 index 0000000..0d9ec21 --- /dev/null +++ b/examples/sixlowpan.rs @@ -0,0 +1,177 @@ +//! 6lowpan example +//! +//! This example is designed to run using the Linux ieee802154/6lowpan support, +//! using mac802154_hwsim. +//! +//! mac802154_hwsim allows you to create multiple "virtual" radios and specify +//! which is in range with which. This is very useful for testing without +//! needing real hardware. By default it creates two interfaces `wpan0` and +//! `wpan1` that are in range with each other. You can customize this with +//! the `wpan-hwsim` tool. +//! +//! We'll configure Linux to speak 6lowpan on `wpan0`, and leave `wpan1` +//! unconfigured so smoltcp can use it with a raw socket. +//! +//! # Setup +//! +//! modprobe mac802154_hwsim +//! +//! ip link set wpan0 down +//! ip link set wpan1 down +//! iwpan dev wpan0 set pan_id 0xbeef +//! iwpan dev wpan1 set pan_id 0xbeef +//! ip link add link wpan0 name lowpan0 type lowpan +//! ip link set wpan0 up +//! ip link set wpan1 up +//! ip link set lowpan0 up +//! +//! # Running +//! +//! Run it with `sudo ./target/debug/examples/sixlowpan`. +//! +//! You can set wireshark to sniff on interface `wpan0` to see the packets. +//! +//! Ping it with `ping fe80::180b:4242:4242:4242%lowpan0`. +//! +//! Speak UDP with `nc -uv fe80::180b:4242:4242:4242%lowpan0 6969`. +//! +//! # Teardown +//! +//! rmmod mac802154_hwsim +//! + +mod utils; + +use log::debug; +use std::os::unix::io::AsRawFd; +use std::str; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium, RawSocket}; +use smoltcp::socket::tcp; +use smoltcp::socket::udp; +use smoltcp::time::Instant; +use smoltcp::wire::{EthernetAddress, Ieee802154Address, Ieee802154Pan, IpAddress, IpCidr}; + +fn main() { + utils::setup_logging(""); + + let (mut opts, mut free) = utils::create_options(); + utils::add_middleware_options(&mut opts, &mut free); + + let mut matches = utils::parse_options(&opts, free); + + let device = RawSocket::new("wpan1", Medium::Ieee802154).unwrap(); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => Config::new( + Ieee802154Address::Extended([0x1a, 0x0b, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42]).into(), + ), + }; + config.random_seed = rand::random(); + config.pan_id = Some(Ieee802154Pan(0xbeef)); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new( + IpAddress::v6(0xfe80, 0, 0, 0, 0x180b, 0x4242, 0x4242, 0x4242), + 64, + )) + .unwrap(); + }); + + // Create sockets + let udp_rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 1280]); + let udp_tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 1280]); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); + + let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp_socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); + + let mut sockets = SocketSet::new(vec![]); + let udp_handle = sockets.add(udp_socket); + let tcp_handle = sockets.add(tcp_socket); + + let socket = sockets.get_mut::<tcp::Socket>(tcp_handle); + socket.listen(50000).unwrap(); + + let mut tcp_active = false; + + loop { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + // udp:6969: respond "hello" + let socket = sockets.get_mut::<udp::Socket>(udp_handle); + if !socket.is_open() { + socket.bind(6969).unwrap() + } + + let mut buffer = vec![0; 1500]; + let client = match socket.recv() { + Ok((data, endpoint)) => { + debug!( + "udp:6969 recv data: {:?} from {}", + str::from_utf8(data).unwrap(), + endpoint + ); + buffer[..data.len()].copy_from_slice(data); + Some((data.len(), endpoint)) + } + Err(_) => None, + }; + if let Some((len, endpoint)) = client { + debug!( + "udp:6969 send data: {:?}", + str::from_utf8(&buffer[..len]).unwrap() + ); + socket.send_slice(&buffer[..len], endpoint).unwrap(); + } + + let socket = sockets.get_mut::<tcp::Socket>(tcp_handle); + if socket.is_active() && !tcp_active { + debug!("connected"); + } else if !socket.is_active() && tcp_active { + debug!("disconnected"); + } + tcp_active = socket.is_active(); + + if socket.may_recv() { + let data = socket + .recv(|data| { + let data = data.to_owned(); + if !data.is_empty() { + debug!( + "recv data: {:?}", + str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)") + ); + } + (data.len(), data) + }) + .unwrap(); + + if socket.can_send() && !data.is_empty() { + debug!( + "send data: {:?}", + str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)") + ); + socket.send_slice(&data[..]).unwrap(); + } + } else if socket.may_send() { + debug!("close"); + socket.close(); + } + + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); + } +} diff --git a/examples/sixlowpan_benchmark.rs b/examples/sixlowpan_benchmark.rs new file mode 100644 index 0000000..4e61491 --- /dev/null +++ b/examples/sixlowpan_benchmark.rs @@ -0,0 +1,235 @@ +//! 6lowpan benchmark example +//! +//! This example runs a simple TCP throughput benchmark using the 6lowpan implementation in smoltcp +//! It is designed to run using the Linux ieee802154/6lowpan support, +//! using mac802154_hwsim. +//! +//! mac802154_hwsim allows you to create multiple "virtual" radios and specify +//! which is in range with which. This is very useful for testing without +//! needing real hardware. By default it creates two interfaces `wpan0` and +//! `wpan1` that are in range with each other. You can customize this with +//! the `wpan-hwsim` tool. +//! +//! We'll configure Linux to speak 6lowpan on `wpan0`, and leave `wpan1` +//! unconfigured so smoltcp can use it with a raw socket. +//! +//! +//! +//! +//! +//! # Setup +//! +//! modprobe mac802154_hwsim +//! +//! ip link set wpan0 down +//! ip link set wpan1 down +//! iwpan dev wpan0 set pan_id 0xbeef +//! iwpan dev wpan1 set pan_id 0xbeef +//! ip link add link wpan0 name lowpan0 type lowpan +//! ip link set wpan0 up +//! ip link set wpan1 up +//! ip link set lowpan0 up +//! +//! +//! # Running +//! +//! Compile with `cargo build --release --example sixlowpan_benchmark` +//! Run it with `sudo ./target/release/examples/sixlowpan_benchmark [reader|writer]`. +//! +//! # Teardown +//! +//! rmmod mac802154_hwsim +//! + +mod utils; + +use std::os::unix::io::AsRawFd; +use std::str; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium, RawSocket}; +use smoltcp::socket::tcp; +use smoltcp::wire::{EthernetAddress, Ieee802154Address, Ieee802154Pan, IpAddress, IpCidr}; + +//For benchmark +use smoltcp::time::{Duration, Instant}; +use std::cmp; +use std::io::{Read, Write}; +use std::net::SocketAddrV6; +use std::net::TcpStream; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::thread; + +use std::fs; + +fn if_nametoindex(ifname: &str) -> u32 { + let contents = fs::read_to_string(format!("/sys/devices/virtual/net/{ifname}/ifindex")) + .expect("couldn't read interface from \"/sys/devices/virtual/net\"") + .replace('\n', ""); + contents.parse::<u32>().unwrap() +} + +const AMOUNT: usize = 100_000_000; + +enum Client { + Reader, + Writer, +} + +fn client(kind: Client) { + let port: u16 = match kind { + Client::Reader => 1234, + Client::Writer => 1235, + }; + + let scope_id = if_nametoindex("lowpan0"); + + let socket_addr = SocketAddrV6::new( + "fe80:0:0:0:180b:4242:4242:4242".parse().unwrap(), + port, + 0, + scope_id, + ); + + let mut stream = TcpStream::connect(socket_addr).expect("failed to connect TLKAGMKA"); + let mut buffer = vec![0; 1_000_000]; + + let start = Instant::now(); + + let mut processed = 0; + while processed < AMOUNT { + let length = cmp::min(buffer.len(), AMOUNT - processed); + let result = match kind { + Client::Reader => stream.read(&mut buffer[..length]), + Client::Writer => stream.write(&buffer[..length]), + }; + match result { + Ok(0) => break, + Ok(result) => { + // print!("(P:{})", result); + processed += result + } + Err(err) => panic!("cannot process: {err}"), + } + } + + let end = Instant::now(); + + let elapsed = (end - start).total_millis() as f64 / 1000.0; + + println!("throughput: {:.3} Gbps", AMOUNT as f64 / elapsed / 0.125e9); + + CLIENT_DONE.store(true, Ordering::SeqCst); +} + +static CLIENT_DONE: AtomicBool = AtomicBool::new(false); + +fn main() { + #[cfg(feature = "log")] + utils::setup_logging("info"); + + let (mut opts, mut free) = utils::create_options(); + utils::add_middleware_options(&mut opts, &mut free); + free.push("MODE"); + + let mut matches = utils::parse_options(&opts, free); + + let device = RawSocket::new("wpan1", Medium::Ieee802154).unwrap(); + + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + + let mode = match matches.free[0].as_ref() { + "reader" => Client::Reader, + "writer" => Client::Writer, + _ => panic!("invalid mode"), + }; + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => Config::new( + Ieee802154Address::Extended([0x1a, 0x0b, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42]).into(), + ), + }; + config.random_seed = rand::random(); + config.pan_id = Some(Ieee802154Pan(0xbeef)); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new( + IpAddress::v6(0xfe80, 0, 0, 0, 0x180b, 0x4242, 0x4242, 0x4242), + 64, + )) + .unwrap(); + }); + + let tcp1_rx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp1_tx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp1_socket = tcp::Socket::new(tcp1_rx_buffer, tcp1_tx_buffer); + + let tcp2_rx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp2_tx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp2_socket = tcp::Socket::new(tcp2_rx_buffer, tcp2_tx_buffer); + + let mut sockets = SocketSet::new(vec![]); + let tcp1_handle = sockets.add(tcp1_socket); + let tcp2_handle = sockets.add(tcp2_socket); + + let default_timeout = Some(Duration::from_millis(1000)); + + thread::spawn(move || client(mode)); + let mut processed = 0; + + while !CLIENT_DONE.load(Ordering::SeqCst) { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + // tcp:1234: emit data + let socket = sockets.get_mut::<tcp::Socket>(tcp1_handle); + if !socket.is_open() { + socket.listen(1234).unwrap(); + } + + if socket.can_send() && processed < AMOUNT { + let length = socket + .send(|buffer| { + let length = cmp::min(buffer.len(), AMOUNT - processed); + (length, length) + }) + .unwrap(); + processed += length; + } + + // tcp:1235: sink data + let socket = sockets.get_mut::<tcp::Socket>(tcp2_handle); + if !socket.is_open() { + socket.listen(1235).unwrap(); + } + + if socket.can_recv() && processed < AMOUNT { + let length = socket + .recv(|buffer| { + let length = cmp::min(buffer.len(), AMOUNT - processed); + (length, length) + }) + .unwrap(); + processed += length; + } + + match iface.poll_at(timestamp, &sockets) { + Some(poll_at) if timestamp < poll_at => { + phy_wait(fd, Some(poll_at - timestamp)).expect("wait error"); + } + Some(_) => (), + None => { + phy_wait(fd, default_timeout).expect("wait error"); + } + } + } +} diff --git a/examples/tcpdump.rs b/examples/tcpdump.rs new file mode 100644 index 0000000..2baf376 --- /dev/null +++ b/examples/tcpdump.rs @@ -0,0 +1,21 @@ +use smoltcp::phy::wait as phy_wait; +use smoltcp::phy::{Device, RawSocket, RxToken}; +use smoltcp::time::Instant; +use smoltcp::wire::{EthernetFrame, PrettyPrinter}; +use std::env; +use std::os::unix::io::AsRawFd; + +fn main() { + let ifname = env::args().nth(1).unwrap(); + let mut socket = RawSocket::new(ifname.as_ref(), smoltcp::phy::Medium::Ethernet).unwrap(); + loop { + phy_wait(socket.as_raw_fd(), None).unwrap(); + let (rx_token, _) = socket.receive(Instant::now()).unwrap(); + rx_token.consume(|buffer| { + println!( + "{}", + PrettyPrinter::<EthernetFrame<&[u8]>>::new("", &buffer) + ); + }) + } +} diff --git a/examples/utils.rs b/examples/utils.rs new file mode 100644 index 0000000..dbe9076 --- /dev/null +++ b/examples/utils.rs @@ -0,0 +1,218 @@ +#![allow(dead_code)] + +#[cfg(feature = "log")] +use env_logger::Builder; +use getopts::{Matches, Options}; +#[cfg(feature = "log")] +use log::{trace, Level, LevelFilter}; +use std::env; +use std::fs::File; +use std::io::{self, Write}; +use std::process; +use std::str::{self, FromStr}; +use std::time::{SystemTime, UNIX_EPOCH}; + +#[cfg(feature = "phy-tuntap_interface")] +use smoltcp::phy::TunTapInterface; +use smoltcp::phy::{Device, FaultInjector, Medium, Tracer}; +use smoltcp::phy::{PcapMode, PcapWriter}; +use smoltcp::time::{Duration, Instant}; + +#[cfg(feature = "log")] +pub fn setup_logging_with_clock<F>(filter: &str, since_startup: F) +where + F: Fn() -> Instant + Send + Sync + 'static, +{ + Builder::new() + .format(move |buf, record| { + let elapsed = since_startup(); + let timestamp = format!("[{elapsed}]"); + if record.target().starts_with("smoltcp::") { + writeln!( + buf, + "\x1b[0m{} ({}): {}\x1b[0m", + timestamp, + record.target().replace("smoltcp::", ""), + record.args() + ) + } else if record.level() == Level::Trace { + let message = format!("{}", record.args()); + writeln!( + buf, + "\x1b[37m{} {}\x1b[0m", + timestamp, + message.replace('\n', "\n ") + ) + } else { + writeln!( + buf, + "\x1b[32m{} ({}): {}\x1b[0m", + timestamp, + record.target(), + record.args() + ) + } + }) + .filter(None, LevelFilter::Trace) + .parse_filters(filter) + .parse_env("RUST_LOG") + .init(); +} + +#[cfg(feature = "log")] +pub fn setup_logging(filter: &str) { + setup_logging_with_clock(filter, Instant::now) +} + +pub fn create_options() -> (Options, Vec<&'static str>) { + let mut opts = Options::new(); + opts.optflag("h", "help", "print this help menu"); + (opts, Vec::new()) +} + +pub fn parse_options(options: &Options, free: Vec<&str>) -> Matches { + match options.parse(env::args().skip(1)) { + Err(err) => { + println!("{err}"); + process::exit(1) + } + Ok(matches) => { + if matches.opt_present("h") || matches.free.len() != free.len() { + let brief = format!( + "Usage: {} [OPTION]... {}", + env::args().next().unwrap(), + free.join(" ") + ); + print!("{}", options.usage(&brief)); + process::exit((matches.free.len() != free.len()) as _); + } + matches + } + } +} + +pub fn add_tuntap_options(opts: &mut Options, _free: &mut [&str]) { + opts.optopt("", "tun", "TUN interface to use", "tun0"); + opts.optopt("", "tap", "TAP interface to use", "tap0"); +} + +#[cfg(feature = "phy-tuntap_interface")] +pub fn parse_tuntap_options(matches: &mut Matches) -> TunTapInterface { + let tun = matches.opt_str("tun"); + let tap = matches.opt_str("tap"); + match (tun, tap) { + (Some(tun), None) => TunTapInterface::new(&tun, Medium::Ip).unwrap(), + (None, Some(tap)) => TunTapInterface::new(&tap, Medium::Ethernet).unwrap(), + _ => panic!("You must specify exactly one of --tun or --tap"), + } +} + +pub fn add_middleware_options(opts: &mut Options, _free: &mut [&str]) { + opts.optopt("", "pcap", "Write a packet capture file", "FILE"); + opts.optopt( + "", + "drop-chance", + "Chance of dropping a packet (%)", + "CHANCE", + ); + opts.optopt( + "", + "corrupt-chance", + "Chance of corrupting a packet (%)", + "CHANCE", + ); + opts.optopt( + "", + "size-limit", + "Drop packets larger than given size (octets)", + "SIZE", + ); + opts.optopt( + "", + "tx-rate-limit", + "Drop packets after transmit rate exceeds given limit \ + (packets per interval)", + "RATE", + ); + opts.optopt( + "", + "rx-rate-limit", + "Drop packets after transmit rate exceeds given limit \ + (packets per interval)", + "RATE", + ); + opts.optopt( + "", + "shaping-interval", + "Sets the interval for rate limiting (ms)", + "RATE", + ); +} + +pub fn parse_middleware_options<D>( + matches: &mut Matches, + device: D, + loopback: bool, +) -> FaultInjector<Tracer<PcapWriter<D, Box<dyn io::Write>>>> +where + D: Device, +{ + let drop_chance = matches + .opt_str("drop-chance") + .map(|s| u8::from_str(&s).unwrap()) + .unwrap_or(0); + let corrupt_chance = matches + .opt_str("corrupt-chance") + .map(|s| u8::from_str(&s).unwrap()) + .unwrap_or(0); + let size_limit = matches + .opt_str("size-limit") + .map(|s| usize::from_str(&s).unwrap()) + .unwrap_or(0); + let tx_rate_limit = matches + .opt_str("tx-rate-limit") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(0); + let rx_rate_limit = matches + .opt_str("rx-rate-limit") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(0); + let shaping_interval = matches + .opt_str("shaping-interval") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(0); + + let pcap_writer: Box<dyn io::Write> = match matches.opt_str("pcap") { + Some(pcap_filename) => Box::new(File::create(pcap_filename).expect("cannot open file")), + None => Box::new(io::sink()), + }; + + let seed = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .subsec_nanos(); + + let device = PcapWriter::new( + device, + pcap_writer, + if loopback { + PcapMode::TxOnly + } else { + PcapMode::Both + }, + ); + + let device = Tracer::new(device, |_timestamp, _printer| { + #[cfg(feature = "log")] + trace!("{}", _printer); + }); + + let mut device = FaultInjector::new(device, seed); + device.set_drop_chance(drop_chance); + device.set_corrupt_chance(corrupt_chance); + device.set_max_packet_size(size_limit); + device.set_max_tx_rate(tx_rate_limit); + device.set_max_rx_rate(rx_rate_limit); + device.set_bucket_interval(Duration::from_millis(shaping_interval)); + device +} diff --git a/gen_config.py b/gen_config.py new file mode 100644 index 0000000..2569192 --- /dev/null +++ b/gen_config.py @@ -0,0 +1,86 @@ +import os + +abspath = os.path.abspath(__file__) +dname = os.path.dirname(abspath) +os.chdir(dname) + +features = [] + + +def feature(name, default, min, max, pow2=None): + vals = set() + val = min + while val <= max: + vals.add(val) + if pow2 == True or (isinstance(pow2, int) and val >= pow2): + val *= 2 + else: + val += 1 + vals.add(default) + + features.append( + { + "name": name, + "default": default, + "vals": sorted(list(vals)), + } + ) + + +feature("iface_max_addr_count", default=2, min=1, max=8) +feature("iface_max_multicast_group_count", default=4, min=1, max=1024, pow2=8) +feature("iface_max_sixlowpan_address_context_count", default=4, min=1, max=1024, pow2=8) +feature("iface_neighbor_cache_count", default=4, min=1, max=1024, pow2=8) +feature("iface_max_route_count", default=2, min=1, max=1024, pow2=8) +feature("fragmentation_buffer_size", default=1500, min=256, max=65536, pow2=True) +feature("assembler_max_segment_count", default=4, min=1, max=32, pow2=4) +feature("reassembly_buffer_size", default=1500, min=256, max=65536, pow2=True) +feature("reassembly_buffer_count", default=1, min=1, max=32, pow2=4) +feature("ipv6_hbh_max_options", default=1, min=1, max=32, pow2=4) +feature("dns_max_result_count", default=1, min=1, max=32, pow2=4) +feature("dns_max_server_count", default=1, min=1, max=32, pow2=4) +feature("dns_max_name_size", default=255, min=64, max=255, pow2=True) +feature("rpl_relations_buffer_count", default=16, min=1, max=128, pow2=True) +feature("rpl_parents_buffer_count", default=8, min=2, max=32, pow2=True) + +# ========= Update Cargo.toml + +things = "" +for f in features: + name = f["name"].replace("_", "-") + for val in f["vals"]: + things += f"{name}-{val} = []" + if val == f["default"]: + things += " # Default" + things += "\n" + things += "\n" + +SEPARATOR_START = "# BEGIN AUTOGENERATED CONFIG FEATURES\n" +SEPARATOR_END = "# END AUTOGENERATED CONFIG FEATURES\n" +HELP = "# Generated by gen_config.py. DO NOT EDIT.\n" +with open("Cargo.toml", "r") as f: + data = f.read() +before, data = data.split(SEPARATOR_START, maxsplit=1) +_, after = data.split(SEPARATOR_END, maxsplit=1) +data = before + SEPARATOR_START + HELP + things + SEPARATOR_END + after +with open("Cargo.toml", "w") as f: + f.write(data) + + +# ========= Update build.rs + +things = "" +for f in features: + name = f["name"].upper() + things += f' ("{name}", {f["default"]}),\n' + +SEPARATOR_START = "// BEGIN AUTOGENERATED CONFIG FEATURES\n" +SEPARATOR_END = "// END AUTOGENERATED CONFIG FEATURES\n" +HELP = " // Generated by gen_config.py. DO NOT EDIT.\n" +with open("build.rs", "r") as f: + data = f.read() +before, data = data.split(SEPARATOR_START, maxsplit=1) +_, after = data.split(SEPARATOR_END, maxsplit=1) +data = before + SEPARATOR_START + HELP + things + " " + SEPARATOR_END + after +with open("build.rs", "w") as f: + f.write(data) diff --git a/src/iface/fragmentation.rs b/src/iface/fragmentation.rs new file mode 100644 index 0000000..ed00f17 --- /dev/null +++ b/src/iface/fragmentation.rs @@ -0,0 +1,506 @@ +#![allow(unused)] + +use core::fmt; + +use managed::{ManagedMap, ManagedSlice}; + +use crate::config::{FRAGMENTATION_BUFFER_SIZE, REASSEMBLY_BUFFER_COUNT, REASSEMBLY_BUFFER_SIZE}; +use crate::storage::Assembler; +use crate::time::{Duration, Instant}; +use crate::wire::*; + +use core::result::Result; + +#[cfg(feature = "alloc")] +type Buffer = alloc::vec::Vec<u8>; +#[cfg(not(feature = "alloc"))] +type Buffer = [u8; REASSEMBLY_BUFFER_SIZE]; + +/// Problem when assembling: something was out of bounds. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct AssemblerError; + +impl fmt::Display for AssemblerError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "AssemblerError") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for AssemblerError {} + +/// Packet assembler is full +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct AssemblerFullError; + +impl fmt::Display for AssemblerFullError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "AssemblerFullError") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for AssemblerFullError {} + +/// Holds different fragments of one packet, used for assembling fragmented packets. +/// +/// The buffer used for the `PacketAssembler` should either be dynamically sized (ex: Vec<u8>) +/// or should be statically allocated based upon the MTU of the type of packet being +/// assembled (ex: 1280 for a IPv6 frame). +#[derive(Debug)] +pub struct PacketAssembler<K> { + key: Option<K>, + buffer: Buffer, + + assembler: Assembler, + total_size: Option<usize>, + expires_at: Instant, +} + +impl<K> PacketAssembler<K> { + /// Create a new empty buffer for fragments. + pub const fn new() -> Self { + Self { + key: None, + + #[cfg(feature = "alloc")] + buffer: Buffer::new(), + #[cfg(not(feature = "alloc"))] + buffer: [0u8; REASSEMBLY_BUFFER_SIZE], + + assembler: Assembler::new(), + total_size: None, + expires_at: Instant::ZERO, + } + } + + pub(crate) fn reset(&mut self) { + self.key = None; + self.assembler.clear(); + self.total_size = None; + self.expires_at = Instant::ZERO; + } + + /// Set the total size of the packet assembler. + pub(crate) fn set_total_size(&mut self, size: usize) -> Result<(), AssemblerError> { + if let Some(old_size) = self.total_size { + if old_size != size { + return Err(AssemblerError); + } + } + + #[cfg(not(feature = "alloc"))] + if self.buffer.len() < size { + return Err(AssemblerError); + } + + #[cfg(feature = "alloc")] + if self.buffer.len() < size { + self.buffer.resize(size, 0); + } + + self.total_size = Some(size); + Ok(()) + } + + /// Return the instant when the assembler expires. + pub(crate) fn expires_at(&self) -> Instant { + self.expires_at + } + + pub(crate) fn add_with( + &mut self, + offset: usize, + f: impl Fn(&mut [u8]) -> Result<usize, AssemblerError>, + ) -> Result<(), AssemblerError> { + if self.buffer.len() < offset { + return Err(AssemblerError); + } + + let len = f(&mut self.buffer[offset..])?; + assert!(offset + len <= self.buffer.len()); + + net_debug!( + "frag assembler: receiving {} octets at offset {}", + len, + offset + ); + + self.assembler.add(offset, len); + Ok(()) + } + + /// Add a fragment into the packet that is being reassembled. + /// + /// # Errors + /// + /// - Returns [`Error::PacketAssemblerBufferTooSmall`] when trying to add data into the buffer at a non-existing + /// place. + pub(crate) fn add(&mut self, data: &[u8], offset: usize) -> Result<(), AssemblerError> { + #[cfg(not(feature = "alloc"))] + if self.buffer.len() < offset + data.len() { + return Err(AssemblerError); + } + + #[cfg(feature = "alloc")] + if self.buffer.len() < offset + data.len() { + self.buffer.resize(offset + data.len(), 0); + } + + let len = data.len(); + self.buffer[offset..][..len].copy_from_slice(data); + + net_debug!( + "frag assembler: receiving {} octets at offset {}", + len, + offset + ); + + self.assembler.add(offset, data.len()); + Ok(()) + } + + /// Get an immutable slice of the underlying packet data, if reassembly complete. + /// This will mark the assembler as empty, so that it can be reused. + pub(crate) fn assemble(&mut self) -> Option<&'_ [u8]> { + if !self.is_complete() { + return None; + } + + // NOTE: we can unwrap because `is_complete` already checks this. + let total_size = self.total_size.unwrap(); + self.reset(); + Some(&self.buffer[..total_size]) + } + + /// Returns `true` when all fragments have been received, otherwise `false`. + pub(crate) fn is_complete(&self) -> bool { + self.total_size == Some(self.assembler.peek_front()) + } + + /// Returns `true` when the packet assembler is free to use. + fn is_free(&self) -> bool { + self.key.is_none() + } +} + +/// Set holding multiple [`PacketAssembler`]. +#[derive(Debug)] +pub struct PacketAssemblerSet<K: Eq + Copy> { + assemblers: [PacketAssembler<K>; REASSEMBLY_BUFFER_COUNT], +} + +impl<K: Eq + Copy> PacketAssemblerSet<K> { + const NEW_PA: PacketAssembler<K> = PacketAssembler::new(); + + /// Create a new set of packet assemblers. + pub fn new() -> Self { + Self { + assemblers: [Self::NEW_PA; REASSEMBLY_BUFFER_COUNT], + } + } + + /// Get a [`PacketAssembler`] for a specific key. + /// + /// If it doesn't exist, it is created, with the `expires_at` timestamp. + /// + /// If the assembler set is full, in which case an error is returned. + pub(crate) fn get( + &mut self, + key: &K, + expires_at: Instant, + ) -> Result<&mut PacketAssembler<K>, AssemblerFullError> { + let mut empty_slot = None; + for slot in &mut self.assemblers { + if slot.key.as_ref() == Some(key) { + return Ok(slot); + } + if slot.is_free() { + empty_slot = Some(slot) + } + } + + let slot = empty_slot.ok_or(AssemblerFullError)?; + slot.key = Some(*key); + slot.expires_at = expires_at; + Ok(slot) + } + + /// Remove all [`PacketAssembler`]s that are expired. + pub fn remove_expired(&mut self, timestamp: Instant) { + for frag in &mut self.assemblers { + if !frag.is_free() && frag.expires_at < timestamp { + frag.reset(); + } + } + } +} + +// Max len of non-fragmented packets after decompression (including ipv6 header and payload) +// TODO: lower. Should be (6lowpan mtu) - (min 6lowpan header size) + (max ipv6 header size) +pub(crate) const MAX_DECOMPRESSED_LEN: usize = 1500; + +#[cfg(feature = "_proto-fragmentation")] +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) enum FragKey { + #[cfg(feature = "proto-ipv4-fragmentation")] + Ipv4(Ipv4FragKey), + #[cfg(feature = "proto-sixlowpan-fragmentation")] + Sixlowpan(SixlowpanFragKey), +} + +pub(crate) struct FragmentsBuffer { + #[cfg(feature = "proto-sixlowpan")] + pub decompress_buf: [u8; MAX_DECOMPRESSED_LEN], + + #[cfg(feature = "_proto-fragmentation")] + pub assembler: PacketAssemblerSet<FragKey>, + + #[cfg(feature = "_proto-fragmentation")] + pub reassembly_timeout: Duration, +} + +#[cfg(not(feature = "_proto-fragmentation"))] +pub(crate) struct Fragmenter {} + +#[cfg(not(feature = "_proto-fragmentation"))] +impl Fragmenter { + pub(crate) fn new() -> Self { + Self {} + } +} + +#[cfg(feature = "_proto-fragmentation")] +pub(crate) struct Fragmenter { + /// The buffer that holds the unfragmented 6LoWPAN packet. + pub buffer: [u8; FRAGMENTATION_BUFFER_SIZE], + /// The size of the packet without the IEEE802.15.4 header and the fragmentation headers. + pub packet_len: usize, + /// The amount of bytes that already have been transmitted. + pub sent_bytes: usize, + + #[cfg(feature = "proto-ipv4-fragmentation")] + pub ipv4: Ipv4Fragmenter, + #[cfg(feature = "proto-sixlowpan-fragmentation")] + pub sixlowpan: SixlowpanFragmenter, +} + +#[cfg(feature = "proto-ipv4-fragmentation")] +pub(crate) struct Ipv4Fragmenter { + /// The IPv4 representation. + pub repr: Ipv4Repr, + /// The destination hardware address. + #[cfg(feature = "medium-ethernet")] + pub dst_hardware_addr: EthernetAddress, + /// The offset of the next fragment. + pub frag_offset: u16, + /// The identifier of the stream. + pub ident: u16, +} + +#[cfg(feature = "proto-sixlowpan-fragmentation")] +pub(crate) struct SixlowpanFragmenter { + /// The datagram size that is used for the fragmentation headers. + pub datagram_size: u16, + /// The datagram tag that is used for the fragmentation headers. + pub datagram_tag: u16, + pub datagram_offset: usize, + + /// The size of the FRAG_N packets. + pub fragn_size: usize, + + /// The link layer IEEE802.15.4 source address. + pub ll_dst_addr: Ieee802154Address, + /// The link layer IEEE802.15.4 source address. + pub ll_src_addr: Ieee802154Address, +} + +#[cfg(feature = "_proto-fragmentation")] +impl Fragmenter { + pub(crate) fn new() -> Self { + Self { + buffer: [0u8; FRAGMENTATION_BUFFER_SIZE], + packet_len: 0, + sent_bytes: 0, + + #[cfg(feature = "proto-ipv4-fragmentation")] + ipv4: Ipv4Fragmenter { + repr: Ipv4Repr { + src_addr: Ipv4Address::default(), + dst_addr: Ipv4Address::default(), + next_header: IpProtocol::Unknown(0), + payload_len: 0, + hop_limit: 0, + }, + #[cfg(feature = "medium-ethernet")] + dst_hardware_addr: EthernetAddress::default(), + frag_offset: 0, + ident: 0, + }, + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + sixlowpan: SixlowpanFragmenter { + datagram_size: 0, + datagram_tag: 0, + datagram_offset: 0, + fragn_size: 0, + ll_dst_addr: Ieee802154Address::Absent, + ll_src_addr: Ieee802154Address::Absent, + }, + } + } + + /// Return `true` when everything is transmitted. + #[inline] + pub(crate) fn finished(&self) -> bool { + self.packet_len == self.sent_bytes + } + + /// Returns `true` when there is nothing to transmit. + #[inline] + pub(crate) fn is_empty(&self) -> bool { + self.packet_len == 0 + } + + // Reset the buffer. + pub(crate) fn reset(&mut self) { + self.packet_len = 0; + self.sent_bytes = 0; + + #[cfg(feature = "proto-ipv4-fragmentation")] + { + self.ipv4.repr = Ipv4Repr { + src_addr: Ipv4Address::default(), + dst_addr: Ipv4Address::default(), + next_header: IpProtocol::Unknown(0), + payload_len: 0, + hop_limit: 0, + }; + #[cfg(feature = "medium-ethernet")] + { + self.ipv4.dst_hardware_addr = EthernetAddress::default(); + } + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + { + self.sixlowpan.datagram_size = 0; + self.sixlowpan.datagram_tag = 0; + self.sixlowpan.fragn_size = 0; + self.sixlowpan.ll_dst_addr = Ieee802154Address::Absent; + self.sixlowpan.ll_src_addr = Ieee802154Address::Absent; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] + struct Key { + id: usize, + } + + #[test] + fn packet_assembler_overlap() { + let mut p_assembler = PacketAssembler::<Key>::new(); + + p_assembler.set_total_size(5).unwrap(); + + let data = b"Rust"; + p_assembler.add(&data[..], 0); + p_assembler.add(&data[..], 1); + + assert_eq!(p_assembler.assemble(), Some(&b"RRust"[..])) + } + + #[test] + fn packet_assembler_assemble() { + let mut p_assembler = PacketAssembler::<Key>::new(); + + let data = b"Hello World!"; + + p_assembler.set_total_size(data.len()).unwrap(); + + p_assembler.add(b"Hello ", 0).unwrap(); + assert_eq!(p_assembler.assemble(), None); + + p_assembler.add(b"World!", b"Hello ".len()).unwrap(); + + assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..])); + } + + #[test] + fn packet_assembler_out_of_order_assemble() { + let mut p_assembler = PacketAssembler::<Key>::new(); + + let data = b"Hello World!"; + + p_assembler.set_total_size(data.len()).unwrap(); + + p_assembler.add(b"World!", b"Hello ".len()).unwrap(); + assert_eq!(p_assembler.assemble(), None); + + p_assembler.add(b"Hello ", 0).unwrap(); + + assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..])); + } + + #[test] + fn packet_assembler_set() { + let key = Key { id: 1 }; + + let mut set = PacketAssemblerSet::new(); + + assert!(set.get(&key, Instant::ZERO).is_ok()); + } + + #[test] + fn packet_assembler_set_full() { + let mut set = PacketAssemblerSet::new(); + for i in 0..REASSEMBLY_BUFFER_COUNT { + set.get(&Key { id: i }, Instant::ZERO).unwrap(); + } + assert!(set.get(&Key { id: 4 }, Instant::ZERO).is_err()); + } + + #[test] + fn packet_assembler_set_assembling_many() { + let mut set = PacketAssemblerSet::new(); + + let key = Key { id: 0 }; + let assr = set.get(&key, Instant::ZERO).unwrap(); + assert_eq!(assr.assemble(), None); + assr.set_total_size(0).unwrap(); + assr.assemble().unwrap(); + + // Test that `.assemble()` effectively deletes it. + let assr = set.get(&key, Instant::ZERO).unwrap(); + assert_eq!(assr.assemble(), None); + assr.set_total_size(0).unwrap(); + assr.assemble().unwrap(); + + let key = Key { id: 1 }; + let assr = set.get(&key, Instant::ZERO).unwrap(); + assr.set_total_size(0).unwrap(); + assr.assemble().unwrap(); + + let key = Key { id: 2 }; + let assr = set.get(&key, Instant::ZERO).unwrap(); + assr.set_total_size(0).unwrap(); + assr.assemble().unwrap(); + + let key = Key { id: 2 }; + let assr = set.get(&key, Instant::ZERO).unwrap(); + assr.set_total_size(2).unwrap(); + assr.add(&[0x00], 0).unwrap(); + assert_eq!(assr.assemble(), None); + let assr = set.get(&key, Instant::ZERO).unwrap(); + assr.add(&[0x01], 1).unwrap(); + assert_eq!(assr.assemble(), Some(&[0x00, 0x01][..])); + } +} diff --git a/src/iface/interface/ethernet.rs b/src/iface/interface/ethernet.rs new file mode 100644 index 0000000..e2555a1 --- /dev/null +++ b/src/iface/interface/ethernet.rs @@ -0,0 +1,76 @@ +use super::check; +use super::DispatchError; +use super::EthernetPacket; +use super::FragmentsBuffer; +use super::InterfaceInner; +use super::SocketSet; +use core::result::Result; + +use crate::phy::TxToken; +use crate::wire::*; + +impl InterfaceInner { + #[cfg(feature = "medium-ethernet")] + pub(super) fn process_ethernet<'frame>( + &mut self, + sockets: &mut SocketSet, + meta: crate::phy::PacketMeta, + frame: &'frame [u8], + fragments: &'frame mut FragmentsBuffer, + ) -> Option<EthernetPacket<'frame>> { + let eth_frame = check!(EthernetFrame::new_checked(frame)); + + // Ignore any packets not directed to our hardware address or any of the multicast groups. + if !eth_frame.dst_addr().is_broadcast() + && !eth_frame.dst_addr().is_multicast() + && HardwareAddress::Ethernet(eth_frame.dst_addr()) != self.hardware_addr + { + return None; + } + + match eth_frame.ethertype() { + #[cfg(feature = "proto-ipv4")] + EthernetProtocol::Arp => self.process_arp(self.now, ð_frame), + #[cfg(feature = "proto-ipv4")] + EthernetProtocol::Ipv4 => { + let ipv4_packet = check!(Ipv4Packet::new_checked(eth_frame.payload())); + + self.process_ipv4(sockets, meta, &ipv4_packet, fragments) + .map(EthernetPacket::Ip) + } + #[cfg(feature = "proto-ipv6")] + EthernetProtocol::Ipv6 => { + let ipv6_packet = check!(Ipv6Packet::new_checked(eth_frame.payload())); + self.process_ipv6(sockets, meta, &ipv6_packet) + .map(EthernetPacket::Ip) + } + // Drop all other traffic. + _ => None, + } + } + + #[cfg(feature = "medium-ethernet")] + pub(super) fn dispatch_ethernet<Tx, F>( + &mut self, + tx_token: Tx, + buffer_len: usize, + f: F, + ) -> Result<(), DispatchError> + where + Tx: TxToken, + F: FnOnce(EthernetFrame<&mut [u8]>), + { + let tx_len = EthernetFrame::<&[u8]>::buffer_len(buffer_len); + tx_token.consume(tx_len, |tx_buffer| { + debug_assert!(tx_buffer.as_ref().len() == tx_len); + let mut frame = EthernetFrame::new_unchecked(tx_buffer); + + let src_addr = self.hardware_addr.ethernet_or_panic(); + frame.set_src_addr(src_addr); + + f(frame); + + Ok(()) + }) + } +} diff --git a/src/iface/interface/ieee802154.rs b/src/iface/interface/ieee802154.rs new file mode 100644 index 0000000..0feca5e --- /dev/null +++ b/src/iface/interface/ieee802154.rs @@ -0,0 +1,94 @@ +use super::*; + +use crate::phy::TxToken; +use crate::wire::*; + +impl InterfaceInner { + pub(super) fn process_ieee802154<'output, 'payload: 'output>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + sixlowpan_payload: &'payload [u8], + _fragments: &'output mut FragmentsBuffer, + ) -> Option<Packet<'output>> { + let ieee802154_frame = check!(Ieee802154Frame::new_checked(sixlowpan_payload)); + let ieee802154_repr = check!(Ieee802154Repr::parse(&ieee802154_frame)); + + if ieee802154_repr.frame_type != Ieee802154FrameType::Data { + return None; + } + + // Drop frames when the user has set a PAN id and the PAN id from frame is not equal to this + // When the user didn't set a PAN id (so it is None), then we accept all PAN id's. + // We always accept the broadcast PAN id. + if self.pan_id.is_some() + && ieee802154_repr.dst_pan_id != self.pan_id + && ieee802154_repr.dst_pan_id != Some(Ieee802154Pan::BROADCAST) + { + net_debug!( + "IEEE802.15.4: dropping {:?} because not our PAN id (or not broadcast)", + ieee802154_repr + ); + return None; + } + + match ieee802154_frame.payload() { + Some(payload) => { + self.process_sixlowpan(sockets, meta, &ieee802154_repr, payload, _fragments) + } + None => None, + } + } + + pub(super) fn dispatch_ieee802154<Tx: TxToken>( + &mut self, + ll_dst_a: Ieee802154Address, + tx_token: Tx, + meta: PacketMeta, + packet: Packet, + frag: &mut Fragmenter, + ) { + let ll_src_a = self.hardware_addr.ieee802154_or_panic(); + + // Create the IEEE802.15.4 header. + let ieee_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: false, + sequence_number: Some(self.get_sequence_number()), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2003, + dst_pan_id: self.pan_id, + dst_addr: Some(ll_dst_a), + src_pan_id: self.pan_id, + src_addr: Some(ll_src_a), + }; + + self.dispatch_sixlowpan(tx_token, meta, packet, ieee_repr, frag); + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + pub(super) fn dispatch_ieee802154_frag<Tx: TxToken>( + &mut self, + tx_token: Tx, + frag: &mut Fragmenter, + ) { + // Create the IEEE802.15.4 header. + let ieee_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: false, + sequence_number: Some(self.get_sequence_number()), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2003, + dst_pan_id: self.pan_id, + dst_addr: Some(frag.sixlowpan.ll_dst_addr), + src_pan_id: self.pan_id, + src_addr: Some(frag.sixlowpan.ll_src_addr), + }; + + self.dispatch_sixlowpan_frag(tx_token, ieee_repr, frag); + } +} diff --git a/src/iface/interface/igmp.rs b/src/iface/interface/igmp.rs new file mode 100644 index 0000000..14856ca --- /dev/null +++ b/src/iface/interface/igmp.rs @@ -0,0 +1,275 @@ +use super::*; + +use crate::phy::{Device, PacketMeta}; +use crate::time::{Duration, Instant}; +use crate::wire::*; + +use core::result::Result; + +/// Error type for `join_multicast_group`, `leave_multicast_group`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum MulticastError { + /// The hardware device transmit buffer is full. Try again later. + Exhausted, + /// The table of joined multicast groups is already full. + GroupTableFull, + /// IPv6 multicast is not yet supported. + Ipv6NotSupported, +} + +impl core::fmt::Display for MulticastError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + MulticastError::Exhausted => write!(f, "Exhausted"), + MulticastError::GroupTableFull => write!(f, "GroupTableFull"), + MulticastError::Ipv6NotSupported => write!(f, "Ipv6NotSupported"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for MulticastError {} + +impl Interface { + /// Add an address to a list of subscribed multicast IP addresses. + /// + /// Returns `Ok(announce_sent)` if the address was added successfully, where `announce_sent` + /// indicates whether an initial immediate announcement has been sent. + pub fn join_multicast_group<D, T: Into<IpAddress>>( + &mut self, + device: &mut D, + addr: T, + timestamp: Instant, + ) -> Result<bool, MulticastError> + where + D: Device + ?Sized, + { + self.inner.now = timestamp; + + match addr.into() { + IpAddress::Ipv4(addr) => { + let is_not_new = self + .inner + .ipv4_multicast_groups + .insert(addr, ()) + .map_err(|_| MulticastError::GroupTableFull)? + .is_some(); + if is_not_new { + Ok(false) + } else if let Some(pkt) = self.inner.igmp_report_packet(IgmpVersion::Version2, addr) + { + // Send initial membership report + let tx_token = device + .transmit(timestamp) + .ok_or(MulticastError::Exhausted)?; + + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + + Ok(true) + } else { + Ok(false) + } + } + // Multicast is not yet implemented for other address families + #[allow(unreachable_patterns)] + _ => Err(MulticastError::Ipv6NotSupported), + } + } + + /// Remove an address from the subscribed multicast IP addresses. + /// + /// Returns `Ok(leave_sent)` if the address was removed successfully, where `leave_sent` + /// indicates whether an immediate leave packet has been sent. + pub fn leave_multicast_group<D, T: Into<IpAddress>>( + &mut self, + device: &mut D, + addr: T, + timestamp: Instant, + ) -> Result<bool, MulticastError> + where + D: Device + ?Sized, + { + self.inner.now = timestamp; + + match addr.into() { + IpAddress::Ipv4(addr) => { + let was_not_present = self.inner.ipv4_multicast_groups.remove(&addr).is_none(); + if was_not_present { + Ok(false) + } else if let Some(pkt) = self.inner.igmp_leave_packet(addr) { + // Send group leave packet + let tx_token = device + .transmit(timestamp) + .ok_or(MulticastError::Exhausted)?; + + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + + Ok(true) + } else { + Ok(false) + } + } + // Multicast is not yet implemented for other address families + #[allow(unreachable_patterns)] + _ => Err(MulticastError::Ipv6NotSupported), + } + } + + /// Check whether the interface listens to given destination multicast IP address. + pub fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool { + self.inner.has_multicast_group(addr) + } + + /// Depending on `igmp_report_state` and the therein contained + /// timeouts, send IGMP membership reports. + pub(crate) fn igmp_egress<D>(&mut self, device: &mut D) -> bool + where + D: Device + ?Sized, + { + match self.inner.igmp_report_state { + IgmpReportState::ToSpecificQuery { + version, + timeout, + group, + } if self.inner.now >= timeout => { + if let Some(pkt) = self.inner.igmp_report_packet(version, group) { + // Send initial membership report + if let Some(tx_token) = device.transmit(self.inner.now) { + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + } else { + return false; + } + } + + self.inner.igmp_report_state = IgmpReportState::Inactive; + true + } + IgmpReportState::ToGeneralQuery { + version, + timeout, + interval, + next_index, + } if self.inner.now >= timeout => { + let addr = self + .inner + .ipv4_multicast_groups + .iter() + .nth(next_index) + .map(|(addr, ())| *addr); + + match addr { + Some(addr) => { + if let Some(pkt) = self.inner.igmp_report_packet(version, addr) { + // Send initial membership report + if let Some(tx_token) = device.transmit(self.inner.now) { + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip( + tx_token, + PacketMeta::default(), + pkt, + &mut self.fragmenter, + ) + .unwrap(); + } else { + return false; + } + } + + let next_timeout = (timeout + interval).max(self.inner.now); + self.inner.igmp_report_state = IgmpReportState::ToGeneralQuery { + version, + timeout: next_timeout, + interval, + next_index: next_index + 1, + }; + true + } + + None => { + self.inner.igmp_report_state = IgmpReportState::Inactive; + false + } + } + } + _ => false, + } + } +} + +impl InterfaceInner { + /// Host duties of the **IGMPv2** protocol. + /// + /// Sets up `igmp_report_state` for responding to IGMP general/specific membership queries. + /// Membership must not be reported immediately in order to avoid flooding the network + /// after a query is broadcasted by a router; this is not currently done. + pub(super) fn process_igmp<'frame>( + &mut self, + ipv4_repr: Ipv4Repr, + ip_payload: &'frame [u8], + ) -> Option<Packet<'frame>> { + let igmp_packet = check!(IgmpPacket::new_checked(ip_payload)); + let igmp_repr = check!(IgmpRepr::parse(&igmp_packet)); + + // FIXME: report membership after a delay + match igmp_repr { + IgmpRepr::MembershipQuery { + group_addr, + version, + max_resp_time, + } => { + // General query + if group_addr.is_unspecified() + && ipv4_repr.dst_addr == Ipv4Address::MULTICAST_ALL_SYSTEMS + { + // Are we member in any groups? + if self.ipv4_multicast_groups.iter().next().is_some() { + let interval = match version { + IgmpVersion::Version1 => Duration::from_millis(100), + IgmpVersion::Version2 => { + // No dependence on a random generator + // (see [#24](https://github.com/m-labs/smoltcp/issues/24)) + // but at least spread reports evenly across max_resp_time. + let intervals = self.ipv4_multicast_groups.len() as u32 + 1; + max_resp_time / intervals + } + }; + self.igmp_report_state = IgmpReportState::ToGeneralQuery { + version, + timeout: self.now + interval, + interval, + next_index: 0, + }; + } + } else { + // Group-specific query + if self.has_multicast_group(group_addr) && ipv4_repr.dst_addr == group_addr { + // Don't respond immediately + let timeout = max_resp_time / 4; + self.igmp_report_state = IgmpReportState::ToSpecificQuery { + version, + timeout: self.now + timeout, + group: group_addr, + }; + } + } + } + // Ignore membership reports + IgmpRepr::MembershipReport { .. } => (), + // Ignore hosts leaving groups + IgmpRepr::LeaveGroup { .. } => (), + } + + None + } +} diff --git a/src/iface/interface/ipv4.rs b/src/iface/interface/ipv4.rs new file mode 100644 index 0000000..7569572 --- /dev/null +++ b/src/iface/interface/ipv4.rs @@ -0,0 +1,445 @@ +use super::*; + +#[cfg(feature = "socket-dhcpv4")] +use crate::socket::dhcpv4; +#[cfg(feature = "socket-icmp")] +use crate::socket::icmp; +use crate::socket::AnySocket; + +use crate::phy::{Medium, TxToken}; +use crate::time::Instant; +use crate::wire::*; + +impl InterfaceInner { + pub(super) fn process_ipv4<'a>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ipv4_packet: &Ipv4Packet<&'a [u8]>, + frag: &'a mut FragmentsBuffer, + ) -> Option<Packet<'a>> { + let ipv4_repr = check!(Ipv4Repr::parse(ipv4_packet, &self.caps.checksum)); + if !self.is_unicast_v4(ipv4_repr.src_addr) && !ipv4_repr.src_addr.is_unspecified() { + // Discard packets with non-unicast source addresses but allow unspecified + net_debug!("non-unicast or unspecified source address"); + return None; + } + + #[cfg(feature = "proto-ipv4-fragmentation")] + let ip_payload = { + if ipv4_packet.more_frags() || ipv4_packet.frag_offset() != 0 { + let key = FragKey::Ipv4(ipv4_packet.get_key()); + + let f = match frag.assembler.get(&key, self.now + frag.reassembly_timeout) { + Ok(f) => f, + Err(_) => { + net_debug!("No available packet assembler for fragmented packet"); + return None; + } + }; + + if !ipv4_packet.more_frags() { + // This is the last fragment, so we know the total size + check!(f.set_total_size( + ipv4_packet.total_len() as usize - ipv4_packet.header_len() as usize + + ipv4_packet.frag_offset() as usize, + )); + } + + if let Err(e) = f.add(ipv4_packet.payload(), ipv4_packet.frag_offset() as usize) { + net_debug!("fragmentation error: {:?}", e); + return None; + } + + // NOTE: according to the standard, the total length needs to be + // recomputed, as well as the checksum. However, we don't really use + // the IPv4 header after the packet is reassembled. + match f.assemble() { + Some(payload) => payload, + None => return None, + } + } else { + ipv4_packet.payload() + } + }; + + #[cfg(not(feature = "proto-ipv4-fragmentation"))] + let ip_payload = ipv4_packet.payload(); + + let ip_repr = IpRepr::Ipv4(ipv4_repr); + + #[cfg(feature = "socket-raw")] + let handled_by_raw_socket = self.raw_socket_filter(sockets, &ip_repr, ip_payload); + #[cfg(not(feature = "socket-raw"))] + let handled_by_raw_socket = false; + + #[cfg(feature = "socket-dhcpv4")] + { + if ipv4_repr.next_header == IpProtocol::Udp + && matches!(self.caps.medium, Medium::Ethernet) + { + let udp_packet = check!(UdpPacket::new_checked(ip_payload)); + if let Some(dhcp_socket) = sockets + .items_mut() + .find_map(|i| dhcpv4::Socket::downcast_mut(&mut i.socket)) + { + // First check for source and dest ports, then do `UdpRepr::parse` if they match. + // This way we avoid validating the UDP checksum twice for all non-DHCP UDP packets (one here, one in `process_udp`) + if udp_packet.src_port() == dhcp_socket.server_port + && udp_packet.dst_port() == dhcp_socket.client_port + { + let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); + let udp_repr = check!(UdpRepr::parse( + &udp_packet, + &src_addr, + &dst_addr, + &self.caps.checksum + )); + let udp_payload = udp_packet.payload(); + + dhcp_socket.process(self, &ipv4_repr, &udp_repr, udp_payload); + return None; + } + } + } + } + + if !self.has_ip_addr(ipv4_repr.dst_addr) + && !self.has_multicast_group(ipv4_repr.dst_addr) + && !self.is_broadcast_v4(ipv4_repr.dst_addr) + { + // Ignore IP packets not directed at us, or broadcast, or any of the multicast groups. + // If AnyIP is enabled, also check if the packet is routed locally. + if !self.any_ip + || !ipv4_repr.dst_addr.is_unicast() + || self + .routes + .lookup(&IpAddress::Ipv4(ipv4_repr.dst_addr), self.now) + .map_or(true, |router_addr| !self.has_ip_addr(router_addr)) + { + return None; + } + } + + match ipv4_repr.next_header { + IpProtocol::Icmp => self.process_icmpv4(sockets, ip_repr, ip_payload), + + #[cfg(feature = "proto-igmp")] + IpProtocol::Igmp => self.process_igmp(ipv4_repr, ip_payload), + + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + IpProtocol::Udp => { + let udp_packet = check!(UdpPacket::new_checked(ip_payload)); + let udp_repr = check!(UdpRepr::parse( + &udp_packet, + &ipv4_repr.src_addr.into(), + &ipv4_repr.dst_addr.into(), + &self.checksum_caps(), + )); + + self.process_udp( + sockets, + meta, + ip_repr, + udp_repr, + handled_by_raw_socket, + udp_packet.payload(), + ip_payload, + ) + } + + #[cfg(feature = "socket-tcp")] + IpProtocol::Tcp => self.process_tcp(sockets, ip_repr, ip_payload), + + _ if handled_by_raw_socket => None, + + _ => { + // Send back as much of the original payload as we can. + let payload_len = + icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, ipv4_repr.buffer_len()); + let icmp_reply_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::ProtoUnreachable, + header: ipv4_repr, + data: &ip_payload[0..payload_len], + }; + self.icmpv4_reply(ipv4_repr, icmp_reply_repr) + } + } + } + + #[cfg(feature = "medium-ethernet")] + pub(super) fn process_arp<'frame>( + &mut self, + timestamp: Instant, + eth_frame: &EthernetFrame<&'frame [u8]>, + ) -> Option<EthernetPacket<'frame>> { + let arp_packet = check!(ArpPacket::new_checked(eth_frame.payload())); + let arp_repr = check!(ArpRepr::parse(&arp_packet)); + + match arp_repr { + ArpRepr::EthernetIpv4 { + operation, + source_hardware_addr, + source_protocol_addr, + target_protocol_addr, + .. + } => { + // Only process ARP packets for us. + if !self.has_ip_addr(target_protocol_addr) { + return None; + } + + // Only process REQUEST and RESPONSE. + if let ArpOperation::Unknown(_) = operation { + net_debug!("arp: unknown operation code"); + return None; + } + + // Discard packets with non-unicast source addresses. + if !source_protocol_addr.is_unicast() || !source_hardware_addr.is_unicast() { + net_debug!("arp: non-unicast source address"); + return None; + } + + if !self.in_same_network(&IpAddress::Ipv4(source_protocol_addr)) { + net_debug!("arp: source IP address not in same network as us"); + return None; + } + + // Fill the ARP cache from any ARP packet aimed at us (both request or response). + // We fill from requests too because if someone is requesting our address they + // are probably going to talk to us, so we avoid having to request their address + // when we later reply to them. + self.neighbor_cache.fill( + source_protocol_addr.into(), + source_hardware_addr.into(), + timestamp, + ); + + if operation == ArpOperation::Request { + let src_hardware_addr = self.hardware_addr.ethernet_or_panic(); + + Some(EthernetPacket::Arp(ArpRepr::EthernetIpv4 { + operation: ArpOperation::Reply, + source_hardware_addr: src_hardware_addr, + source_protocol_addr: target_protocol_addr, + target_hardware_addr: source_hardware_addr, + target_protocol_addr: source_protocol_addr, + })) + } else { + None + } + } + } + } + + pub(super) fn process_icmpv4<'frame>( + &mut self, + _sockets: &mut SocketSet, + ip_repr: IpRepr, + ip_payload: &'frame [u8], + ) -> Option<Packet<'frame>> { + let icmp_packet = check!(Icmpv4Packet::new_checked(ip_payload)); + let icmp_repr = check!(Icmpv4Repr::parse(&icmp_packet, &self.caps.checksum)); + + #[cfg(feature = "socket-icmp")] + let mut handled_by_icmp_socket = false; + + #[cfg(all(feature = "socket-icmp", feature = "proto-ipv4"))] + for icmp_socket in _sockets + .items_mut() + .filter_map(|i| icmp::Socket::downcast_mut(&mut i.socket)) + { + if icmp_socket.accepts(self, &ip_repr, &icmp_repr.into()) { + icmp_socket.process(self, &ip_repr, &icmp_repr.into()); + handled_by_icmp_socket = true; + } + } + + match icmp_repr { + // Respond to echo requests. + #[cfg(feature = "proto-ipv4")] + Icmpv4Repr::EchoRequest { + ident, + seq_no, + data, + } => { + let icmp_reply_repr = Icmpv4Repr::EchoReply { + ident, + seq_no, + data, + }; + match ip_repr { + IpRepr::Ipv4(ipv4_repr) => self.icmpv4_reply(ipv4_repr, icmp_reply_repr), + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + } + + // Ignore any echo replies. + Icmpv4Repr::EchoReply { .. } => None, + + // Don't report an error if a packet with unknown type + // has been handled by an ICMP socket + #[cfg(feature = "socket-icmp")] + _ if handled_by_icmp_socket => None, + + // FIXME: do something correct here? + _ => None, + } + } + + pub(super) fn icmpv4_reply<'frame, 'icmp: 'frame>( + &self, + ipv4_repr: Ipv4Repr, + icmp_repr: Icmpv4Repr<'icmp>, + ) -> Option<Packet<'frame>> { + if !self.is_unicast_v4(ipv4_repr.src_addr) { + // Do not send ICMP replies to non-unicast sources + None + } else if self.is_unicast_v4(ipv4_repr.dst_addr) { + // Reply as normal when src_addr and dst_addr are both unicast + let ipv4_reply_repr = Ipv4Repr { + src_addr: ipv4_repr.dst_addr, + dst_addr: ipv4_repr.src_addr, + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }; + Some(Packet::new_ipv4( + ipv4_reply_repr, + IpPayload::Icmpv4(icmp_repr), + )) + } else if self.is_broadcast_v4(ipv4_repr.dst_addr) { + // Only reply to broadcasts for echo replies and not other ICMP messages + match icmp_repr { + Icmpv4Repr::EchoReply { .. } => match self.ipv4_addr() { + Some(src_addr) => { + let ipv4_reply_repr = Ipv4Repr { + src_addr, + dst_addr: ipv4_repr.src_addr, + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }; + Some(Packet::new_ipv4( + ipv4_reply_repr, + IpPayload::Icmpv4(icmp_repr), + )) + } + None => None, + }, + _ => None, + } + } else { + None + } + } + + #[cfg(feature = "proto-ipv4-fragmentation")] + pub(super) fn dispatch_ipv4_frag<Tx: TxToken>(&mut self, tx_token: Tx, frag: &mut Fragmenter) { + let caps = self.caps.clone(); + + let mtu_max = self.ip_mtu(); + let ip_len = (frag.packet_len - frag.sent_bytes + frag.ipv4.repr.buffer_len()).min(mtu_max); + let payload_len = ip_len - frag.ipv4.repr.buffer_len(); + + let more_frags = (frag.packet_len - frag.sent_bytes) != payload_len; + frag.ipv4.repr.payload_len = payload_len; + frag.sent_bytes += payload_len; + + let mut tx_len = ip_len; + #[cfg(feature = "medium-ethernet")] + if matches!(caps.medium, Medium::Ethernet) { + tx_len += EthernetFrame::<&[u8]>::header_len(); + } + + // Emit function for the Ethernet header. + #[cfg(feature = "medium-ethernet")] + let emit_ethernet = |repr: &IpRepr, tx_buffer: &mut [u8]| { + let mut frame = EthernetFrame::new_unchecked(tx_buffer); + + let src_addr = self.hardware_addr.ethernet_or_panic(); + frame.set_src_addr(src_addr); + frame.set_dst_addr(frag.ipv4.dst_hardware_addr); + + match repr.version() { + #[cfg(feature = "proto-ipv4")] + IpVersion::Ipv4 => frame.set_ethertype(EthernetProtocol::Ipv4), + #[cfg(feature = "proto-ipv6")] + IpVersion::Ipv6 => frame.set_ethertype(EthernetProtocol::Ipv6), + } + }; + + tx_token.consume(tx_len, |mut tx_buffer| { + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + emit_ethernet(&IpRepr::Ipv4(frag.ipv4.repr), tx_buffer); + tx_buffer = &mut tx_buffer[EthernetFrame::<&[u8]>::header_len()..]; + } + + let mut packet = + Ipv4Packet::new_unchecked(&mut tx_buffer[..frag.ipv4.repr.buffer_len()]); + frag.ipv4.repr.emit(&mut packet, &caps.checksum); + packet.set_ident(frag.ipv4.ident); + packet.set_more_frags(more_frags); + packet.set_dont_frag(false); + packet.set_frag_offset(frag.ipv4.frag_offset); + + if caps.checksum.ipv4.tx() { + packet.fill_checksum(); + } + + tx_buffer[frag.ipv4.repr.buffer_len()..][..payload_len].copy_from_slice( + &frag.buffer[frag.ipv4.frag_offset as usize + frag.ipv4.repr.buffer_len()..] + [..payload_len], + ); + + // Update the frag offset for the next fragment. + frag.ipv4.frag_offset += payload_len as u16; + }) + } + + #[cfg(feature = "proto-igmp")] + pub(super) fn igmp_report_packet<'any>( + &self, + version: IgmpVersion, + group_addr: Ipv4Address, + ) -> Option<Packet<'any>> { + let iface_addr = self.ipv4_addr()?; + let igmp_repr = IgmpRepr::MembershipReport { + group_addr, + version, + }; + let pkt = Packet::new_ipv4( + Ipv4Repr { + src_addr: iface_addr, + // Send to the group being reported + dst_addr: group_addr, + next_header: IpProtocol::Igmp, + payload_len: igmp_repr.buffer_len(), + hop_limit: 1, + // [#183](https://github.com/m-labs/smoltcp/issues/183). + }, + IpPayload::Igmp(igmp_repr), + ); + Some(pkt) + } + + #[cfg(feature = "proto-igmp")] + pub(super) fn igmp_leave_packet<'any>(&self, group_addr: Ipv4Address) -> Option<Packet<'any>> { + self.ipv4_addr().map(|iface_addr| { + let igmp_repr = IgmpRepr::LeaveGroup { group_addr }; + Packet::new_ipv4( + Ipv4Repr { + src_addr: iface_addr, + dst_addr: Ipv4Address::MULTICAST_ALL_ROUTERS, + next_header: IpProtocol::Igmp, + payload_len: igmp_repr.buffer_len(), + hop_limit: 1, + }, + IpPayload::Igmp(igmp_repr), + ) + }) + } +} diff --git a/src/iface/interface/ipv6.rs b/src/iface/interface/ipv6.rs new file mode 100644 index 0000000..22e65b9 --- /dev/null +++ b/src/iface/interface/ipv6.rs @@ -0,0 +1,355 @@ +use super::*; + +#[cfg(feature = "socket-icmp")] +use crate::socket::icmp; +use crate::socket::AnySocket; + +use crate::phy::PacketMeta; +use crate::wire::*; + +/// Enum used for the process_hopbyhop function. In some cases, when discarding a packet, an ICMMP +/// parameter problem message needs to be transmitted to the source of the address. In other cases, +/// the processing of the IP packet can continue. +#[allow(clippy::large_enum_variant)] +enum HopByHopResponse<'frame> { + /// Continue processing the IPv6 packet. + Continue((IpProtocol, &'frame [u8])), + /// Discard the packet and maybe send back an ICMPv6 packet. + Discard(Option<Packet<'frame>>), +} + +// We implement `Default` such that we can use the check! macro. +impl Default for HopByHopResponse<'_> { + fn default() -> Self { + Self::Discard(None) + } +} + +impl InterfaceInner { + pub(super) fn process_ipv6<'frame>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ipv6_packet: &Ipv6Packet<&'frame [u8]>, + ) -> Option<Packet<'frame>> { + let ipv6_repr = check!(Ipv6Repr::parse(ipv6_packet)); + + if !ipv6_repr.src_addr.is_unicast() { + // Discard packets with non-unicast source addresses. + net_debug!("non-unicast source address"); + return None; + } + + let (next_header, ip_payload) = if ipv6_repr.next_header == IpProtocol::HopByHop { + match self.process_hopbyhop(ipv6_repr, ipv6_packet.payload()) { + HopByHopResponse::Discard(e) => return e, + HopByHopResponse::Continue(next) => next, + } + } else { + (ipv6_repr.next_header, ipv6_packet.payload()) + }; + + if !self.has_ip_addr(ipv6_repr.dst_addr) + && !self.has_multicast_group(ipv6_repr.dst_addr) + && !ipv6_repr.dst_addr.is_loopback() + { + net_trace!("packet IP address not for this interface"); + return None; + } + + #[cfg(feature = "socket-raw")] + let handled_by_raw_socket = self.raw_socket_filter(sockets, &ipv6_repr.into(), ip_payload); + #[cfg(not(feature = "socket-raw"))] + let handled_by_raw_socket = false; + + self.process_nxt_hdr( + sockets, + meta, + ipv6_repr, + next_header, + handled_by_raw_socket, + ip_payload, + ) + } + + fn process_hopbyhop<'frame>( + &mut self, + ipv6_repr: Ipv6Repr, + ip_payload: &'frame [u8], + ) -> HopByHopResponse<'frame> { + let param_problem = || { + let payload_len = + icmp_reply_payload_len(ip_payload.len(), IPV6_MIN_MTU, ipv6_repr.buffer_len()); + self.icmpv6_reply( + ipv6_repr, + Icmpv6Repr::ParamProblem { + reason: Icmpv6ParamProblem::UnrecognizedOption, + pointer: ipv6_repr.buffer_len() as u32, + header: ipv6_repr, + data: &ip_payload[0..payload_len], + }, + ) + }; + + let ext_hdr = check!(Ipv6ExtHeader::new_checked(ip_payload)); + let ext_repr = check!(Ipv6ExtHeaderRepr::parse(&ext_hdr)); + let hbh_hdr = check!(Ipv6HopByHopHeader::new_checked(ext_repr.data)); + let hbh_repr = check!(Ipv6HopByHopRepr::parse(&hbh_hdr)); + + for opt_repr in &hbh_repr.options { + match opt_repr { + Ipv6OptionRepr::Pad1 | Ipv6OptionRepr::PadN(_) => (), + #[cfg(feature = "proto-rpl")] + Ipv6OptionRepr::Rpl(_) => {} + + Ipv6OptionRepr::Unknown { type_, .. } => { + match Ipv6OptionFailureType::from(*type_) { + Ipv6OptionFailureType::Skip => (), + Ipv6OptionFailureType::Discard => { + return HopByHopResponse::Discard(None); + } + Ipv6OptionFailureType::DiscardSendAll => { + return HopByHopResponse::Discard(param_problem()); + } + Ipv6OptionFailureType::DiscardSendUnicast + if !ipv6_repr.dst_addr.is_multicast() => + { + return HopByHopResponse::Discard(param_problem()); + } + _ => unreachable!(), + } + } + } + } + + HopByHopResponse::Continue(( + ext_repr.next_header, + &ip_payload[ext_repr.header_len() + ext_repr.data.len()..], + )) + } + + /// Given the next header value forward the payload onto the correct process + /// function. + fn process_nxt_hdr<'frame>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ipv6_repr: Ipv6Repr, + nxt_hdr: IpProtocol, + handled_by_raw_socket: bool, + ip_payload: &'frame [u8], + ) -> Option<Packet<'frame>> { + match nxt_hdr { + IpProtocol::Icmpv6 => self.process_icmpv6(sockets, ipv6_repr.into(), ip_payload), + + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + IpProtocol::Udp => { + let udp_packet = check!(UdpPacket::new_checked(ip_payload)); + let udp_repr = check!(UdpRepr::parse( + &udp_packet, + &ipv6_repr.src_addr.into(), + &ipv6_repr.dst_addr.into(), + &self.checksum_caps(), + )); + + self.process_udp( + sockets, + meta, + ipv6_repr.into(), + udp_repr, + handled_by_raw_socket, + udp_packet.payload(), + ip_payload, + ) + } + + #[cfg(feature = "socket-tcp")] + IpProtocol::Tcp => self.process_tcp(sockets, ipv6_repr.into(), ip_payload), + + #[cfg(feature = "socket-raw")] + _ if handled_by_raw_socket => None, + + _ => { + // Send back as much of the original payload as we can. + let payload_len = + icmp_reply_payload_len(ip_payload.len(), IPV6_MIN_MTU, ipv6_repr.buffer_len()); + let icmp_reply_repr = Icmpv6Repr::ParamProblem { + reason: Icmpv6ParamProblem::UnrecognizedNxtHdr, + // The offending packet is after the IPv6 header. + pointer: ipv6_repr.buffer_len() as u32, + header: ipv6_repr, + data: &ip_payload[0..payload_len], + }; + self.icmpv6_reply(ipv6_repr, icmp_reply_repr) + } + } + } + + pub(super) fn process_icmpv6<'frame>( + &mut self, + _sockets: &mut SocketSet, + ip_repr: IpRepr, + ip_payload: &'frame [u8], + ) -> Option<Packet<'frame>> { + let icmp_packet = check!(Icmpv6Packet::new_checked(ip_payload)); + let icmp_repr = check!(Icmpv6Repr::parse( + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + &icmp_packet, + &self.caps.checksum, + )); + + #[cfg(feature = "socket-icmp")] + let mut handled_by_icmp_socket = false; + + #[cfg(feature = "socket-icmp")] + for icmp_socket in _sockets + .items_mut() + .filter_map(|i| icmp::Socket::downcast_mut(&mut i.socket)) + { + if icmp_socket.accepts(self, &ip_repr, &icmp_repr.into()) { + icmp_socket.process(self, &ip_repr, &icmp_repr.into()); + handled_by_icmp_socket = true; + } + } + + match icmp_repr { + // Respond to echo requests. + Icmpv6Repr::EchoRequest { + ident, + seq_no, + data, + } => match ip_repr { + IpRepr::Ipv6(ipv6_repr) => { + let icmp_reply_repr = Icmpv6Repr::EchoReply { + ident, + seq_no, + data, + }; + self.icmpv6_reply(ipv6_repr, icmp_reply_repr) + } + #[allow(unreachable_patterns)] + _ => unreachable!(), + }, + + // Ignore any echo replies. + Icmpv6Repr::EchoReply { .. } => None, + + // Forward any NDISC packets to the ndisc packet handler + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + Icmpv6Repr::Ndisc(repr) if ip_repr.hop_limit() == 0xff => match ip_repr { + IpRepr::Ipv6(ipv6_repr) => match self.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => self.process_ndisc(ipv6_repr, repr), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => self.process_ndisc(ipv6_repr, repr), + #[cfg(feature = "medium-ip")] + Medium::Ip => None, + }, + #[allow(unreachable_patterns)] + _ => unreachable!(), + }, + + // Don't report an error if a packet with unknown type + // has been handled by an ICMP socket + #[cfg(feature = "socket-icmp")] + _ if handled_by_icmp_socket => None, + + // FIXME: do something correct here? + _ => None, + } + } + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + pub(super) fn process_ndisc<'frame>( + &mut self, + ip_repr: Ipv6Repr, + repr: NdiscRepr<'frame>, + ) -> Option<Packet<'frame>> { + match repr { + NdiscRepr::NeighborAdvert { + lladdr, + target_addr, + flags, + } => { + let ip_addr = ip_repr.src_addr.into(); + if let Some(lladdr) = lladdr { + let lladdr = check!(lladdr.parse(self.caps.medium)); + if !lladdr.is_unicast() || !target_addr.is_unicast() { + return None; + } + if flags.contains(NdiscNeighborFlags::OVERRIDE) + || !self.neighbor_cache.lookup(&ip_addr, self.now).found() + { + self.neighbor_cache.fill(ip_addr, lladdr, self.now) + } + } + None + } + NdiscRepr::NeighborSolicit { + target_addr, + lladdr, + .. + } => { + if let Some(lladdr) = lladdr { + let lladdr = check!(lladdr.parse(self.caps.medium)); + if !lladdr.is_unicast() || !target_addr.is_unicast() { + return None; + } + self.neighbor_cache + .fill(ip_repr.src_addr.into(), lladdr, self.now); + } + + if self.has_solicited_node(ip_repr.dst_addr) && self.has_ip_addr(target_addr) { + let advert = Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr, + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + lladdr: Some(self.hardware_addr.into()), + }); + let ip_repr = Ipv6Repr { + src_addr: target_addr, + dst_addr: ip_repr.src_addr, + next_header: IpProtocol::Icmpv6, + hop_limit: 0xff, + payload_len: advert.buffer_len(), + }; + Some(Packet::new_ipv6(ip_repr, IpPayload::Icmpv6(advert))) + } else { + None + } + } + _ => None, + } + } + + pub(super) fn icmpv6_reply<'frame, 'icmp: 'frame>( + &self, + ipv6_repr: Ipv6Repr, + icmp_repr: Icmpv6Repr<'icmp>, + ) -> Option<Packet<'frame>> { + let src_addr = ipv6_repr.dst_addr; + let dst_addr = ipv6_repr.src_addr; + + let src_addr = if src_addr.is_unicast() { + src_addr + } else if let Some(addr) = self.get_source_address_ipv6(&dst_addr) { + addr + } else { + net_debug!("no suitable source address found"); + return None; + }; + + let ipv6_reply_repr = Ipv6Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Icmpv6, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }; + Some(Packet::new_ipv6( + ipv6_reply_repr, + IpPayload::Icmpv6(icmp_repr), + )) + } +} diff --git a/src/iface/interface/mod.rs b/src/iface/interface/mod.rs new file mode 100644 index 0000000..7d6bdc8 --- /dev/null +++ b/src/iface/interface/mod.rs @@ -0,0 +1,1644 @@ +// Heads up! Before working on this file you should read the parts +// of RFC 1122 that discuss Ethernet, ARP and IP for any IPv4 work +// and RFCs 8200 and 4861 for any IPv6 and NDISC work. + +#[cfg(test)] +mod tests; + +#[cfg(feature = "medium-ethernet")] +mod ethernet; +#[cfg(feature = "medium-ieee802154")] +mod ieee802154; + +#[cfg(feature = "proto-ipv4")] +mod ipv4; +#[cfg(feature = "proto-ipv6")] +mod ipv6; +#[cfg(feature = "proto-sixlowpan")] +mod sixlowpan; + +#[cfg(feature = "proto-igmp")] +mod igmp; + +#[cfg(feature = "proto-igmp")] +pub use igmp::MulticastError; + +use super::packet::*; + +use core::result::Result; +use heapless::{LinearMap, Vec}; + +#[cfg(feature = "_proto-fragmentation")] +use super::fragmentation::FragKey; +#[cfg(any(feature = "proto-ipv4", feature = "proto-sixlowpan"))] +use super::fragmentation::PacketAssemblerSet; +use super::fragmentation::{Fragmenter, FragmentsBuffer}; + +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +use super::neighbor::{Answer as NeighborAnswer, Cache as NeighborCache}; +use super::socket_set::SocketSet; +use crate::config::{ + IFACE_MAX_ADDR_COUNT, IFACE_MAX_MULTICAST_GROUP_COUNT, + IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT, +}; +use crate::iface::Routes; +use crate::phy::PacketMeta; +use crate::phy::{ChecksumCapabilities, Device, DeviceCapabilities, Medium, RxToken, TxToken}; +use crate::rand::Rand; +#[cfg(feature = "socket-dns")] +use crate::socket::dns; +use crate::socket::*; +use crate::time::{Duration, Instant}; + +use crate::wire::*; + +macro_rules! check { + ($e:expr) => { + match $e { + Ok(x) => x, + Err(_) => { + // concat!/stringify! doesn't work with defmt macros + #[cfg(not(feature = "defmt"))] + net_trace!(concat!("iface: malformed ", stringify!($e))); + #[cfg(feature = "defmt")] + net_trace!("iface: malformed"); + return Default::default(); + } + } + }; +} +use check; + +/// A network interface. +/// +/// The network interface logically owns a number of other data structures; to avoid +/// a dependency on heap allocation, it instead owns a `BorrowMut<[T]>`, which can be +/// a `&mut [T]`, or `Vec<T>` if a heap is available. +pub struct Interface { + pub(crate) inner: InterfaceInner, + fragments: FragmentsBuffer, + fragmenter: Fragmenter, +} + +/// The device independent part of an Ethernet network interface. +/// +/// Separating the device from the data required for processing and dispatching makes +/// it possible to borrow them independently. For example, the tx and rx tokens borrow +/// the `device` mutably until they're used, which makes it impossible to call other +/// methods on the `Interface` in this time (since its `device` field is borrowed +/// exclusively). However, it is still possible to call methods on its `inner` field. +pub struct InterfaceInner { + caps: DeviceCapabilities, + now: Instant, + rand: Rand, + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + neighbor_cache: NeighborCache, + hardware_addr: HardwareAddress, + #[cfg(feature = "medium-ieee802154")] + sequence_no: u8, + #[cfg(feature = "medium-ieee802154")] + pan_id: Option<Ieee802154Pan>, + #[cfg(feature = "proto-ipv4-fragmentation")] + ipv4_id: u16, + #[cfg(feature = "proto-sixlowpan")] + sixlowpan_address_context: + Vec<SixlowpanAddressContext, IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT>, + #[cfg(feature = "proto-sixlowpan-fragmentation")] + tag: u16, + ip_addrs: Vec<IpCidr, IFACE_MAX_ADDR_COUNT>, + #[cfg(feature = "proto-ipv4")] + any_ip: bool, + routes: Routes, + #[cfg(feature = "proto-igmp")] + ipv4_multicast_groups: LinearMap<Ipv4Address, (), IFACE_MAX_MULTICAST_GROUP_COUNT>, + /// When to report for (all or) the next multicast group membership via IGMP + #[cfg(feature = "proto-igmp")] + igmp_report_state: IgmpReportState, +} + +/// Configuration structure used for creating a network interface. +#[non_exhaustive] +pub struct Config { + /// Random seed. + /// + /// It is strongly recommended that the random seed is different on each boot, + /// to avoid problems with TCP port/sequence collisions. + /// + /// The seed doesn't have to be cryptographically secure. + pub random_seed: u64, + + /// Set the Hardware address the interface will use. + /// + /// # Panics + /// Creating the interface panics if the address is not unicast. + pub hardware_addr: HardwareAddress, + + /// Set the IEEE802.15.4 PAN ID the interface will use. + /// + /// **NOTE**: we use the same PAN ID for destination and source. + #[cfg(feature = "medium-ieee802154")] + pub pan_id: Option<Ieee802154Pan>, +} + +impl Config { + pub fn new(hardware_addr: HardwareAddress) -> Self { + Config { + random_seed: 0, + hardware_addr, + #[cfg(feature = "medium-ieee802154")] + pan_id: None, + } + } +} + +impl Interface { + /// Create a network interface using the previously provided configuration. + /// + /// # Panics + /// This function panics if the [`Config::hardware_address`] does not match + /// the medium of the device. + pub fn new<D>(config: Config, device: &mut D, now: Instant) -> Self + where + D: Device + ?Sized, + { + let caps = device.capabilities(); + assert_eq!( + config.hardware_addr.medium(), + caps.medium, + "The hardware address does not match the medium of the interface." + ); + + let mut rand = Rand::new(config.random_seed); + + #[cfg(feature = "medium-ieee802154")] + let mut sequence_no; + #[cfg(feature = "medium-ieee802154")] + loop { + sequence_no = (rand.rand_u32() & 0xff) as u8; + if sequence_no != 0 { + break; + } + } + + #[cfg(feature = "proto-sixlowpan")] + let mut tag; + + #[cfg(feature = "proto-sixlowpan")] + loop { + tag = rand.rand_u16(); + if tag != 0 { + break; + } + } + + #[cfg(feature = "proto-ipv4")] + let mut ipv4_id; + + #[cfg(feature = "proto-ipv4")] + loop { + ipv4_id = rand.rand_u16(); + if ipv4_id != 0 { + break; + } + } + + Interface { + fragments: FragmentsBuffer { + #[cfg(feature = "proto-sixlowpan")] + decompress_buf: [0u8; sixlowpan::MAX_DECOMPRESSED_LEN], + + #[cfg(feature = "_proto-fragmentation")] + assembler: PacketAssemblerSet::new(), + #[cfg(feature = "_proto-fragmentation")] + reassembly_timeout: Duration::from_secs(60), + }, + fragmenter: Fragmenter::new(), + inner: InterfaceInner { + now, + caps, + hardware_addr: config.hardware_addr, + ip_addrs: Vec::new(), + #[cfg(feature = "proto-ipv4")] + any_ip: false, + routes: Routes::new(), + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + neighbor_cache: NeighborCache::new(), + #[cfg(feature = "proto-igmp")] + ipv4_multicast_groups: LinearMap::new(), + #[cfg(feature = "proto-igmp")] + igmp_report_state: IgmpReportState::Inactive, + #[cfg(feature = "medium-ieee802154")] + sequence_no, + #[cfg(feature = "medium-ieee802154")] + pan_id: config.pan_id, + #[cfg(feature = "proto-sixlowpan-fragmentation")] + tag, + #[cfg(feature = "proto-ipv4-fragmentation")] + ipv4_id, + #[cfg(feature = "proto-sixlowpan")] + sixlowpan_address_context: Vec::new(), + rand, + }, + } + } + + /// Get the socket context. + /// + /// The context is needed for some socket methods. + pub fn context(&mut self) -> &mut InterfaceInner { + &mut self.inner + } + + /// Get the HardwareAddress address of the interface. + /// + /// # Panics + /// This function panics if the medium is not Ethernet or Ieee802154. + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + pub fn hardware_addr(&self) -> HardwareAddress { + #[cfg(all(feature = "medium-ethernet", not(feature = "medium-ieee802154")))] + assert!(self.inner.caps.medium == Medium::Ethernet); + #[cfg(all(feature = "medium-ieee802154", not(feature = "medium-ethernet")))] + assert!(self.inner.caps.medium == Medium::Ieee802154); + + #[cfg(all(feature = "medium-ieee802154", feature = "medium-ethernet"))] + assert!( + self.inner.caps.medium == Medium::Ethernet + || self.inner.caps.medium == Medium::Ieee802154 + ); + + self.inner.hardware_addr + } + + /// Set the HardwareAddress address of the interface. + /// + /// # Panics + /// This function panics if the address is not unicast, and if the medium is not Ethernet or + /// Ieee802154. + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + pub fn set_hardware_addr(&mut self, addr: HardwareAddress) { + #[cfg(all(feature = "medium-ethernet", not(feature = "medium-ieee802154")))] + assert!(self.inner.caps.medium == Medium::Ethernet); + #[cfg(all(feature = "medium-ieee802154", not(feature = "medium-ethernet")))] + assert!(self.inner.caps.medium == Medium::Ieee802154); + + #[cfg(all(feature = "medium-ieee802154", feature = "medium-ethernet"))] + assert!( + self.inner.caps.medium == Medium::Ethernet + || self.inner.caps.medium == Medium::Ieee802154 + ); + + InterfaceInner::check_hardware_addr(&addr); + self.inner.hardware_addr = addr; + } + + /// Get the IP addresses of the interface. + pub fn ip_addrs(&self) -> &[IpCidr] { + self.inner.ip_addrs.as_ref() + } + + /// Get the first IPv4 address if present. + #[cfg(feature = "proto-ipv4")] + pub fn ipv4_addr(&self) -> Option<Ipv4Address> { + self.inner.ipv4_addr() + } + + /// Get the first IPv6 address if present. + #[cfg(feature = "proto-ipv6")] + pub fn ipv6_addr(&self) -> Option<Ipv6Address> { + self.inner.ipv6_addr() + } + + /// Get an address from the interface that could be used as source address. For IPv4, this is + /// the first IPv4 address from the list of addresses. For IPv6, the address is based on the + /// destination address and uses RFC6724 for selecting the source address. + pub fn get_source_address(&self, dst_addr: &IpAddress) -> Option<IpAddress> { + self.inner.get_source_address(dst_addr) + } + + /// Get an address from the interface that could be used as source address. This is the first + /// IPv4 address from the list of addresses in the interface. + #[cfg(feature = "proto-ipv4")] + pub fn get_source_address_ipv4(&self, dst_addr: &Ipv4Address) -> Option<Ipv4Address> { + self.inner.get_source_address_ipv4(dst_addr) + } + + /// Get an address from the interface that could be used as source address. The selection is + /// based on RFC6724. + #[cfg(feature = "proto-ipv6")] + pub fn get_source_address_ipv6(&self, dst_addr: &Ipv6Address) -> Option<Ipv6Address> { + self.inner.get_source_address_ipv6(dst_addr) + } + + /// Update the IP addresses of the interface. + /// + /// # Panics + /// This function panics if any of the addresses are not unicast. + pub fn update_ip_addrs<F: FnOnce(&mut Vec<IpCidr, IFACE_MAX_ADDR_COUNT>)>(&mut self, f: F) { + f(&mut self.inner.ip_addrs); + InterfaceInner::flush_cache(&mut self.inner); + InterfaceInner::check_ip_addrs(&self.inner.ip_addrs) + } + + /// Check whether the interface has the given IP address assigned. + pub fn has_ip_addr<T: Into<IpAddress>>(&self, addr: T) -> bool { + self.inner.has_ip_addr(addr) + } + + pub fn routes(&self) -> &Routes { + &self.inner.routes + } + + pub fn routes_mut(&mut self) -> &mut Routes { + &mut self.inner.routes + } + + /// Enable or disable the AnyIP capability. + /// + /// AnyIP allowins packets to be received + /// locally on IPv4 addresses other than the interface's configured [ip_addrs]. + /// When AnyIP is enabled and a route prefix in [`routes`](Self::routes) specifies one of + /// the interface's [`ip_addrs`](Self::ip_addrs) as its gateway, the interface will accept + /// packets addressed to that prefix. + /// + /// # IPv6 + /// + /// This option is not available or required for IPv6 as packets sent to + /// the interface are not filtered by IPv6 address. + #[cfg(feature = "proto-ipv4")] + pub fn set_any_ip(&mut self, any_ip: bool) { + self.inner.any_ip = any_ip; + } + + /// Get whether AnyIP is enabled. + /// + /// See [`set_any_ip`](Self::set_any_ip) for details on AnyIP + #[cfg(feature = "proto-ipv4")] + pub fn any_ip(&self) -> bool { + self.inner.any_ip + } + + /// Get the 6LoWPAN address contexts. + #[cfg(feature = "proto-sixlowpan")] + pub fn sixlowpan_address_context( + &self, + ) -> &Vec<SixlowpanAddressContext, IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT> { + &self.inner.sixlowpan_address_context + } + + /// Get a mutable reference to the 6LoWPAN address contexts. + #[cfg(feature = "proto-sixlowpan")] + pub fn sixlowpan_address_context_mut( + &mut self, + ) -> &mut Vec<SixlowpanAddressContext, IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT> { + &mut self.inner.sixlowpan_address_context + } + + /// Get the packet reassembly timeout. + #[cfg(feature = "_proto-fragmentation")] + pub fn reassembly_timeout(&self) -> Duration { + self.fragments.reassembly_timeout + } + + /// Set the packet reassembly timeout. + #[cfg(feature = "_proto-fragmentation")] + pub fn set_reassembly_timeout(&mut self, timeout: Duration) { + if timeout > Duration::from_secs(60) { + net_debug!("RFC 4944 specifies that the reassembly timeout MUST be set to a maximum of 60 seconds"); + } + self.fragments.reassembly_timeout = timeout; + } + + /// Transmit packets queued in the given sockets, and receive packets queued + /// in the device. + /// + /// This function returns a boolean value indicating whether any packets were + /// processed or emitted, and thus, whether the readiness of any socket might + /// have changed. + pub fn poll<D>( + &mut self, + timestamp: Instant, + device: &mut D, + sockets: &mut SocketSet<'_>, + ) -> bool + where + D: Device + ?Sized, + { + self.inner.now = timestamp; + + #[cfg(feature = "_proto-fragmentation")] + self.fragments.assembler.remove_expired(timestamp); + + match self.inner.caps.medium { + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => + { + #[cfg(feature = "proto-sixlowpan-fragmentation")] + if self.sixlowpan_egress(device) { + return true; + } + } + #[cfg(any(feature = "medium-ethernet", feature = "medium-ip"))] + _ => + { + #[cfg(feature = "proto-ipv4-fragmentation")] + if self.ipv4_egress(device) { + return true; + } + } + } + + let mut readiness_may_have_changed = false; + + loop { + let mut did_something = false; + did_something |= self.socket_ingress(device, sockets); + did_something |= self.socket_egress(device, sockets); + + #[cfg(feature = "proto-igmp")] + { + did_something |= self.igmp_egress(device); + } + + if did_something { + readiness_may_have_changed = true; + } else { + break; + } + } + + readiness_may_have_changed + } + + /// Return a _soft deadline_ for calling [poll] the next time. + /// The [Instant] returned is the time at which you should call [poll] next. + /// It is harmless (but wastes energy) to call it before the [Instant], and + /// potentially harmful (impacting quality of service) to call it after the + /// [Instant] + /// + /// [poll]: #method.poll + /// [Instant]: struct.Instant.html + pub fn poll_at(&mut self, timestamp: Instant, sockets: &SocketSet<'_>) -> Option<Instant> { + self.inner.now = timestamp; + + #[cfg(feature = "_proto-fragmentation")] + if !self.fragmenter.is_empty() { + return Some(Instant::from_millis(0)); + } + + let inner = &mut self.inner; + + sockets + .items() + .filter_map(move |item| { + let socket_poll_at = item.socket.poll_at(inner); + match item + .meta + .poll_at(socket_poll_at, |ip_addr| inner.has_neighbor(&ip_addr)) + { + PollAt::Ingress => None, + PollAt::Time(instant) => Some(instant), + PollAt::Now => Some(Instant::from_millis(0)), + } + }) + .min() + } + + /// Return an _advisory wait time_ for calling [poll] the next time. + /// The [Duration] returned is the time left to wait before calling [poll] next. + /// It is harmless (but wastes energy) to call it before the [Duration] has passed, + /// and potentially harmful (impacting quality of service) to call it after the + /// [Duration] has passed. + /// + /// [poll]: #method.poll + /// [Duration]: struct.Duration.html + pub fn poll_delay(&mut self, timestamp: Instant, sockets: &SocketSet<'_>) -> Option<Duration> { + match self.poll_at(timestamp, sockets) { + Some(poll_at) if timestamp < poll_at => Some(poll_at - timestamp), + Some(_) => Some(Duration::from_millis(0)), + _ => None, + } + } + + fn socket_ingress<D>(&mut self, device: &mut D, sockets: &mut SocketSet<'_>) -> bool + where + D: Device + ?Sized, + { + let mut processed_any = false; + + while let Some((rx_token, tx_token)) = device.receive(self.inner.now) { + let rx_meta = rx_token.meta(); + rx_token.consume(|frame| { + if frame.is_empty() { + return; + } + + match self.inner.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + if let Some(packet) = self.inner.process_ethernet( + sockets, + rx_meta, + frame, + &mut self.fragments, + ) { + if let Err(err) = + self.inner.dispatch(tx_token, packet, &mut self.fragmenter) + { + net_debug!("Failed to send response: {:?}", err); + } + } + } + #[cfg(feature = "medium-ip")] + Medium::Ip => { + if let Some(packet) = + self.inner + .process_ip(sockets, rx_meta, frame, &mut self.fragments) + { + if let Err(err) = self.inner.dispatch_ip( + tx_token, + PacketMeta::default(), + packet, + &mut self.fragmenter, + ) { + net_debug!("Failed to send response: {:?}", err); + } + } + } + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => { + if let Some(packet) = self.inner.process_ieee802154( + sockets, + rx_meta, + frame, + &mut self.fragments, + ) { + if let Err(err) = self.inner.dispatch_ip( + tx_token, + PacketMeta::default(), + packet, + &mut self.fragmenter, + ) { + net_debug!("Failed to send response: {:?}", err); + } + } + } + } + processed_any = true; + }); + } + + processed_any + } + + fn socket_egress<D>(&mut self, device: &mut D, sockets: &mut SocketSet<'_>) -> bool + where + D: Device + ?Sized, + { + let _caps = device.capabilities(); + + enum EgressError { + Exhausted, + Dispatch(DispatchError), + } + + let mut emitted_any = false; + for item in sockets.items_mut() { + if !item + .meta + .egress_permitted(self.inner.now, |ip_addr| self.inner.has_neighbor(&ip_addr)) + { + continue; + } + + let mut neighbor_addr = None; + let mut respond = |inner: &mut InterfaceInner, meta: PacketMeta, response: Packet| { + neighbor_addr = Some(response.ip_repr().dst_addr()); + let t = device.transmit(inner.now).ok_or_else(|| { + net_debug!("failed to transmit IP: device exhausted"); + EgressError::Exhausted + })?; + + inner + .dispatch_ip(t, meta, response, &mut self.fragmenter) + .map_err(EgressError::Dispatch)?; + + emitted_any = true; + + Ok(()) + }; + + let result = match &mut item.socket { + #[cfg(feature = "socket-raw")] + Socket::Raw(socket) => socket.dispatch(&mut self.inner, |inner, (ip, raw)| { + respond( + inner, + PacketMeta::default(), + Packet::new(ip, IpPayload::Raw(raw)), + ) + }), + #[cfg(feature = "socket-icmp")] + Socket::Icmp(socket) => { + socket.dispatch(&mut self.inner, |inner, response| match response { + #[cfg(feature = "proto-ipv4")] + (IpRepr::Ipv4(ipv4_repr), IcmpRepr::Ipv4(icmpv4_repr)) => respond( + inner, + PacketMeta::default(), + Packet::new_ipv4(ipv4_repr, IpPayload::Icmpv4(icmpv4_repr)), + ), + #[cfg(feature = "proto-ipv6")] + (IpRepr::Ipv6(ipv6_repr), IcmpRepr::Ipv6(icmpv6_repr)) => respond( + inner, + PacketMeta::default(), + Packet::new_ipv6(ipv6_repr, IpPayload::Icmpv6(icmpv6_repr)), + ), + #[allow(unreachable_patterns)] + _ => unreachable!(), + }) + } + #[cfg(feature = "socket-udp")] + Socket::Udp(socket) => { + socket.dispatch(&mut self.inner, |inner, meta, (ip, udp, payload)| { + respond(inner, meta, Packet::new(ip, IpPayload::Udp(udp, payload))) + }) + } + #[cfg(feature = "socket-tcp")] + Socket::Tcp(socket) => socket.dispatch(&mut self.inner, |inner, (ip, tcp)| { + respond( + inner, + PacketMeta::default(), + Packet::new(ip, IpPayload::Tcp(tcp)), + ) + }), + #[cfg(feature = "socket-dhcpv4")] + Socket::Dhcpv4(socket) => { + socket.dispatch(&mut self.inner, |inner, (ip, udp, dhcp)| { + respond( + inner, + PacketMeta::default(), + Packet::new_ipv4(ip, IpPayload::Dhcpv4(udp, dhcp)), + ) + }) + } + #[cfg(feature = "socket-dns")] + Socket::Dns(socket) => socket.dispatch(&mut self.inner, |inner, (ip, udp, dns)| { + respond( + inner, + PacketMeta::default(), + Packet::new(ip, IpPayload::Udp(udp, dns)), + ) + }), + }; + + match result { + Err(EgressError::Exhausted) => break, // Device buffer full. + Err(EgressError::Dispatch(_)) => { + // `NeighborCache` already takes care of rate limiting the neighbor discovery + // requests from the socket. However, without an additional rate limiting + // mechanism, we would spin on every socket that has yet to discover its + // neighbor. + item.meta.neighbor_missing( + self.inner.now, + neighbor_addr.expect("non-IP response packet"), + ); + } + Ok(()) => {} + } + } + emitted_any + } + + /// Process fragments that still need to be sent for IPv4 packets. + /// + /// This function returns a boolean value indicating whether any packets were + /// processed or emitted, and thus, whether the readiness of any socket might + /// have changed. + #[cfg(feature = "proto-ipv4-fragmentation")] + fn ipv4_egress<D>(&mut self, device: &mut D) -> bool + where + D: Device + ?Sized, + { + // Reset the buffer when we transmitted everything. + if self.fragmenter.finished() { + self.fragmenter.reset(); + } + + if self.fragmenter.is_empty() { + return false; + } + + let pkt = &self.fragmenter; + if pkt.packet_len > pkt.sent_bytes { + if let Some(tx_token) = device.transmit(self.inner.now) { + self.inner + .dispatch_ipv4_frag(tx_token, &mut self.fragmenter); + return true; + } + } + false + } + + /// Process fragments that still need to be sent for 6LoWPAN packets. + /// + /// This function returns a boolean value indicating whether any packets were + /// processed or emitted, and thus, whether the readiness of any socket might + /// have changed. + #[cfg(feature = "proto-sixlowpan-fragmentation")] + fn sixlowpan_egress<D>(&mut self, device: &mut D) -> bool + where + D: Device + ?Sized, + { + // Reset the buffer when we transmitted everything. + if self.fragmenter.finished() { + self.fragmenter.reset(); + } + + if self.fragmenter.is_empty() { + return false; + } + + let pkt = &self.fragmenter; + if pkt.packet_len > pkt.sent_bytes { + if let Some(tx_token) = device.transmit(self.inner.now) { + self.inner + .dispatch_ieee802154_frag(tx_token, &mut self.fragmenter); + return true; + } + } + false + } +} + +impl InterfaceInner { + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn now(&self) -> Instant { + self.now + } + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn hardware_addr(&self) -> HardwareAddress { + self.hardware_addr + } + + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn checksum_caps(&self) -> ChecksumCapabilities { + self.caps.checksum.clone() + } + + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn ip_mtu(&self) -> usize { + self.caps.ip_mtu() + } + + #[allow(unused)] // unused depending on which sockets are enabled, and in tests + pub(crate) fn rand(&mut self) -> &mut Rand { + &mut self.rand + } + + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn get_source_address(&self, dst_addr: &IpAddress) -> Option<IpAddress> { + match dst_addr { + #[cfg(feature = "proto-ipv4")] + IpAddress::Ipv4(addr) => self.get_source_address_ipv4(addr).map(|a| a.into()), + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(addr) => self.get_source_address_ipv6(addr).map(|a| a.into()), + } + } + + #[cfg(feature = "proto-ipv4")] + #[allow(unused)] + pub(crate) fn get_source_address_ipv4(&self, _dst_addr: &Ipv4Address) -> Option<Ipv4Address> { + for cidr in self.ip_addrs.iter() { + #[allow(irrefutable_let_patterns)] // if only ipv4 is enabled + if let IpCidr::Ipv4(cidr) = cidr { + return Some(cidr.address()); + } + } + None + } + + #[cfg(feature = "proto-ipv6")] + #[allow(unused)] + pub(crate) fn get_source_address_ipv6(&self, dst_addr: &Ipv6Address) -> Option<Ipv6Address> { + // RFC 6724 describes how to select the correct source address depending on the destination + // address. + + // See RFC 6724 Section 4: Candidate source address + fn is_candidate_source_address(dst_addr: &Ipv6Address, src_addr: &Ipv6Address) -> bool { + // For all multicast and link-local destination addresses, the candidate address MUST + // only be an address from the same link. + if dst_addr.is_link_local() && !src_addr.is_link_local() { + return false; + } + + if dst_addr.is_multicast() + && matches!(dst_addr.scope(), Ipv6AddressScope::LinkLocal) + && src_addr.is_multicast() + && !matches!(src_addr.scope(), Ipv6AddressScope::LinkLocal) + { + return false; + } + + // Loopback addresses and multicast address can not be in the candidate source address + // list. Except when the destination multicast address has a link-local scope, then the + // source address can also be link-local multicast. + if src_addr.is_loopback() || src_addr.is_multicast() { + return false; + } + + true + } + + // See RFC 6724 Section 2.2: Common Prefix Length + fn common_prefix_length(dst_addr: &Ipv6Cidr, src_addr: &Ipv6Address) -> usize { + let addr = dst_addr.address(); + let mut bits = 0; + for (l, r) in addr.as_bytes().iter().zip(src_addr.as_bytes().iter()) { + if l == r { + bits += 8; + } else { + bits += (l ^ r).leading_zeros(); + break; + } + } + + bits = bits.min(dst_addr.prefix_len() as u32); + + bits as usize + } + + // Get the first address that is a candidate address. + let mut candidate = self + .ip_addrs + .iter() + .filter_map(|a| match a { + #[cfg(feature = "proto-ipv4")] + IpCidr::Ipv4(_) => None, + #[cfg(feature = "proto-ipv6")] + IpCidr::Ipv6(a) => Some(a), + }) + .find(|a| is_candidate_source_address(dst_addr, &a.address())) + .unwrap(); + + for addr in self.ip_addrs.iter().filter_map(|a| match a { + #[cfg(feature = "proto-ipv4")] + IpCidr::Ipv4(_) => None, + #[cfg(feature = "proto-ipv6")] + IpCidr::Ipv6(a) => Some(a), + }) { + if !is_candidate_source_address(dst_addr, &addr.address()) { + continue; + } + + // Rule 1: prefer the address that is the same as the output destination address. + if candidate.address() != *dst_addr && addr.address() == *dst_addr { + candidate = addr; + } + + // Rule 2: prefer appropriate scope. + if (candidate.address().scope() as u8) < (addr.address().scope() as u8) { + if (candidate.address().scope() as u8) < (dst_addr.scope() as u8) { + candidate = addr; + } + } else if (addr.address().scope() as u8) > (dst_addr.scope() as u8) { + candidate = addr; + } + + // Rule 3: avoid deprecated addresses (TODO) + // Rule 4: prefer home addresses (TODO) + // Rule 5: prefer outgoing interfaces (TODO) + // Rule 5.5: prefer addresses in a prefix advertises by the next-hop (TODO). + // Rule 6: prefer matching label (TODO) + // Rule 7: prefer temporary addresses (TODO) + // Rule 8: use longest matching prefix + if common_prefix_length(candidate, dst_addr) < common_prefix_length(addr, dst_addr) { + candidate = addr; + } + } + + Some(candidate.address()) + } + + #[cfg(test)] + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn set_now(&mut self, now: Instant) { + self.now = now + } + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + fn check_hardware_addr(addr: &HardwareAddress) { + if !addr.is_unicast() { + panic!("Hardware address {addr} is not unicast") + } + } + + fn check_ip_addrs(addrs: &[IpCidr]) { + for cidr in addrs { + if !cidr.address().is_unicast() && !cidr.address().is_unspecified() { + panic!("IP address {} is not unicast", cidr.address()) + } + } + } + + #[cfg(feature = "medium-ieee802154")] + fn get_sequence_number(&mut self) -> u8 { + let no = self.sequence_no; + self.sequence_no = self.sequence_no.wrapping_add(1); + no + } + + #[cfg(feature = "proto-ipv4-fragmentation")] + fn get_ipv4_ident(&mut self) -> u16 { + let ipv4_id = self.ipv4_id; + self.ipv4_id = self.ipv4_id.wrapping_add(1); + ipv4_id + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + fn get_sixlowpan_fragment_tag(&mut self) -> u16 { + let tag = self.tag; + self.tag = self.tag.wrapping_add(1); + tag + } + + /// Determine if the given `Ipv6Address` is the solicited node + /// multicast address for a IPv6 addresses assigned to the interface. + /// See [RFC 4291 § 2.7.1] for more details. + /// + /// [RFC 4291 § 2.7.1]: https://tools.ietf.org/html/rfc4291#section-2.7.1 + #[cfg(feature = "proto-ipv6")] + pub fn has_solicited_node(&self, addr: Ipv6Address) -> bool { + self.ip_addrs.iter().any(|cidr| { + match *cidr { + IpCidr::Ipv6(cidr) if cidr.address() != Ipv6Address::LOOPBACK => { + // Take the lower order 24 bits of the IPv6 address and + // append those bits to FF02:0:0:0:0:1:FF00::/104. + addr.as_bytes()[14..] == cidr.address().as_bytes()[14..] + } + _ => false, + } + }) + } + + /// Check whether the interface has the given IP address assigned. + fn has_ip_addr<T: Into<IpAddress>>(&self, addr: T) -> bool { + let addr = addr.into(); + self.ip_addrs.iter().any(|probe| probe.address() == addr) + } + + /// Get the first IPv4 address of the interface. + #[cfg(feature = "proto-ipv4")] + pub fn ipv4_addr(&self) -> Option<Ipv4Address> { + self.ip_addrs.iter().find_map(|addr| match *addr { + IpCidr::Ipv4(cidr) => Some(cidr.address()), + #[allow(unreachable_patterns)] + _ => None, + }) + } + + /// Get the first IPv6 address if present. + #[cfg(feature = "proto-ipv6")] + pub fn ipv6_addr(&self) -> Option<Ipv6Address> { + self.ip_addrs.iter().find_map(|addr| match *addr { + IpCidr::Ipv6(cidr) => Some(cidr.address()), + #[allow(unreachable_patterns)] + _ => None, + }) + } + + /// Check whether the interface listens to given destination multicast IP address. + /// + /// If built without feature `proto-igmp` this function will + /// always return `false` when using IPv4. + fn has_multicast_group<T: Into<IpAddress>>(&self, addr: T) -> bool { + match addr.into() { + #[cfg(feature = "proto-igmp")] + IpAddress::Ipv4(key) => { + key == Ipv4Address::MULTICAST_ALL_SYSTEMS + || self.ipv4_multicast_groups.get(&key).is_some() + } + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(Ipv6Address::LINK_LOCAL_ALL_NODES) => true, + #[cfg(feature = "proto-rpl")] + IpAddress::Ipv6(Ipv6Address::LINK_LOCAL_ALL_RPL_NODES) => true, + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(addr) => self.has_solicited_node(addr), + #[allow(unreachable_patterns)] + _ => false, + } + } + + #[cfg(feature = "medium-ip")] + fn process_ip<'frame>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ip_payload: &'frame [u8], + frag: &'frame mut FragmentsBuffer, + ) -> Option<Packet<'frame>> { + match IpVersion::of_packet(ip_payload) { + #[cfg(feature = "proto-ipv4")] + Ok(IpVersion::Ipv4) => { + let ipv4_packet = check!(Ipv4Packet::new_checked(ip_payload)); + + self.process_ipv4(sockets, meta, &ipv4_packet, frag) + } + #[cfg(feature = "proto-ipv6")] + Ok(IpVersion::Ipv6) => { + let ipv6_packet = check!(Ipv6Packet::new_checked(ip_payload)); + self.process_ipv6(sockets, meta, &ipv6_packet) + } + // Drop all other traffic. + _ => None, + } + } + + #[cfg(feature = "socket-raw")] + fn raw_socket_filter( + &mut self, + sockets: &mut SocketSet, + ip_repr: &IpRepr, + ip_payload: &[u8], + ) -> bool { + let mut handled_by_raw_socket = false; + + // Pass every IP packet to all raw sockets we have registered. + for raw_socket in sockets + .items_mut() + .filter_map(|i| raw::Socket::downcast_mut(&mut i.socket)) + { + if raw_socket.accepts(ip_repr) { + raw_socket.process(self, ip_repr, ip_payload); + handled_by_raw_socket = true; + } + } + handled_by_raw_socket + } + + /// Checks if an address is broadcast, taking into account ipv4 subnet-local + /// broadcast addresses. + pub(crate) fn is_broadcast(&self, address: &IpAddress) -> bool { + match address { + #[cfg(feature = "proto-ipv4")] + IpAddress::Ipv4(address) => self.is_broadcast_v4(*address), + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(_) => false, + } + } + + /// Checks if an address is broadcast, taking into account ipv4 subnet-local + /// broadcast addresses. + #[cfg(feature = "proto-ipv4")] + pub(crate) fn is_broadcast_v4(&self, address: Ipv4Address) -> bool { + if address.is_broadcast() { + return true; + } + + self.ip_addrs + .iter() + .filter_map(|own_cidr| match own_cidr { + IpCidr::Ipv4(own_ip) => Some(own_ip.broadcast()?), + #[cfg(feature = "proto-ipv6")] + IpCidr::Ipv6(_) => None, + }) + .any(|broadcast_address| address == broadcast_address) + } + + /// Checks if an ipv4 address is unicast, taking into account subnet broadcast addresses + #[cfg(feature = "proto-ipv4")] + fn is_unicast_v4(&self, address: Ipv4Address) -> bool { + address.is_unicast() && !self.is_broadcast_v4(address) + } + + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + #[allow(clippy::too_many_arguments)] + fn process_udp<'frame>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ip_repr: IpRepr, + udp_repr: UdpRepr, + handled_by_raw_socket: bool, + udp_payload: &'frame [u8], + ip_payload: &'frame [u8], + ) -> Option<Packet<'frame>> { + #[cfg(feature = "socket-udp")] + for udp_socket in sockets + .items_mut() + .filter_map(|i| udp::Socket::downcast_mut(&mut i.socket)) + { + if udp_socket.accepts(self, &ip_repr, &udp_repr) { + udp_socket.process(self, meta, &ip_repr, &udp_repr, udp_payload); + return None; + } + } + + #[cfg(feature = "socket-dns")] + for dns_socket in sockets + .items_mut() + .filter_map(|i| dns::Socket::downcast_mut(&mut i.socket)) + { + if dns_socket.accepts(&ip_repr, &udp_repr) { + dns_socket.process(self, &ip_repr, &udp_repr, udp_payload); + return None; + } + } + + // The packet wasn't handled by a socket, send an ICMP port unreachable packet. + match ip_repr { + #[cfg(feature = "proto-ipv4")] + IpRepr::Ipv4(_) if handled_by_raw_socket => None, + #[cfg(feature = "proto-ipv6")] + IpRepr::Ipv6(_) if handled_by_raw_socket => None, + #[cfg(feature = "proto-ipv4")] + IpRepr::Ipv4(ipv4_repr) => { + let payload_len = + icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, ipv4_repr.buffer_len()); + let icmpv4_reply_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::PortUnreachable, + header: ipv4_repr, + data: &ip_payload[0..payload_len], + }; + self.icmpv4_reply(ipv4_repr, icmpv4_reply_repr) + } + #[cfg(feature = "proto-ipv6")] + IpRepr::Ipv6(ipv6_repr) => { + let payload_len = + icmp_reply_payload_len(ip_payload.len(), IPV6_MIN_MTU, ipv6_repr.buffer_len()); + let icmpv6_reply_repr = Icmpv6Repr::DstUnreachable { + reason: Icmpv6DstUnreachable::PortUnreachable, + header: ipv6_repr, + data: &ip_payload[0..payload_len], + }; + self.icmpv6_reply(ipv6_repr, icmpv6_reply_repr) + } + } + } + + #[cfg(feature = "socket-tcp")] + pub(crate) fn process_tcp<'frame>( + &mut self, + sockets: &mut SocketSet, + ip_repr: IpRepr, + ip_payload: &'frame [u8], + ) -> Option<Packet<'frame>> { + let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); + let tcp_packet = check!(TcpPacket::new_checked(ip_payload)); + let tcp_repr = check!(TcpRepr::parse( + &tcp_packet, + &src_addr, + &dst_addr, + &self.caps.checksum + )); + + for tcp_socket in sockets + .items_mut() + .filter_map(|i| tcp::Socket::downcast_mut(&mut i.socket)) + { + if tcp_socket.accepts(self, &ip_repr, &tcp_repr) { + return tcp_socket + .process(self, &ip_repr, &tcp_repr) + .map(|(ip, tcp)| Packet::new(ip, IpPayload::Tcp(tcp))); + } + } + + if tcp_repr.control == TcpControl::Rst + || ip_repr.dst_addr().is_unspecified() + || ip_repr.src_addr().is_unspecified() + { + // Never reply to a TCP RST packet with another TCP RST packet. We also never want to + // send a TCP RST packet with unspecified addresses. + None + } else { + // The packet wasn't handled by a socket, send a TCP RST packet. + let (ip, tcp) = tcp::Socket::rst_reply(&ip_repr, &tcp_repr); + Some(Packet::new(ip, IpPayload::Tcp(tcp))) + } + } + + #[cfg(feature = "medium-ethernet")] + fn dispatch<Tx>( + &mut self, + tx_token: Tx, + packet: EthernetPacket, + frag: &mut Fragmenter, + ) -> Result<(), DispatchError> + where + Tx: TxToken, + { + match packet { + #[cfg(feature = "proto-ipv4")] + EthernetPacket::Arp(arp_repr) => { + let dst_hardware_addr = match arp_repr { + ArpRepr::EthernetIpv4 { + target_hardware_addr, + .. + } => target_hardware_addr, + }; + + self.dispatch_ethernet(tx_token, arp_repr.buffer_len(), |mut frame| { + frame.set_dst_addr(dst_hardware_addr); + frame.set_ethertype(EthernetProtocol::Arp); + + let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); + arp_repr.emit(&mut packet); + }) + } + EthernetPacket::Ip(packet) => { + self.dispatch_ip(tx_token, PacketMeta::default(), packet, frag) + } + } + } + + fn in_same_network(&self, addr: &IpAddress) -> bool { + self.ip_addrs.iter().any(|cidr| cidr.contains_addr(addr)) + } + + fn route(&self, addr: &IpAddress, timestamp: Instant) -> Option<IpAddress> { + // Send directly. + // note: no need to use `self.is_broadcast()` to check for subnet-local broadcast addrs + // here because `in_same_network` will already return true. + if self.in_same_network(addr) || addr.is_broadcast() { + return Some(*addr); + } + + // Route via a router. + self.routes.lookup(addr, timestamp) + } + + fn has_neighbor(&self, addr: &IpAddress) -> bool { + match self.route(addr, self.now) { + Some(_routed_addr) => match self.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => self.neighbor_cache.lookup(&_routed_addr, self.now).found(), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => self.neighbor_cache.lookup(&_routed_addr, self.now).found(), + #[cfg(feature = "medium-ip")] + Medium::Ip => true, + }, + None => false, + } + } + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + fn lookup_hardware_addr<Tx>( + &mut self, + tx_token: Tx, + src_addr: &IpAddress, + dst_addr: &IpAddress, + fragmenter: &mut Fragmenter, + ) -> Result<(HardwareAddress, Tx), DispatchError> + where + Tx: TxToken, + { + if self.is_broadcast(dst_addr) { + let hardware_addr = match self.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => HardwareAddress::Ethernet(EthernetAddress::BROADCAST), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => HardwareAddress::Ieee802154(Ieee802154Address::BROADCAST), + #[cfg(feature = "medium-ip")] + Medium::Ip => unreachable!(), + }; + + return Ok((hardware_addr, tx_token)); + } + + if dst_addr.is_multicast() { + let b = dst_addr.as_bytes(); + let hardware_addr = match *dst_addr { + #[cfg(feature = "proto-ipv4")] + IpAddress::Ipv4(_addr) => match self.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => HardwareAddress::Ethernet(EthernetAddress::from_bytes(&[ + 0x01, + 0x00, + 0x5e, + b[1] & 0x7F, + b[2], + b[3], + ])), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => unreachable!(), + #[cfg(feature = "medium-ip")] + Medium::Ip => unreachable!(), + }, + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(_addr) => match self.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => HardwareAddress::Ethernet(EthernetAddress::from_bytes(&[ + 0x33, 0x33, b[12], b[13], b[14], b[15], + ])), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => { + // Not sure if this is correct + HardwareAddress::Ieee802154(Ieee802154Address::BROADCAST) + } + #[cfg(feature = "medium-ip")] + Medium::Ip => unreachable!(), + }, + }; + + return Ok((hardware_addr, tx_token)); + } + + let dst_addr = self + .route(dst_addr, self.now) + .ok_or(DispatchError::NoRoute)?; + + match self.neighbor_cache.lookup(&dst_addr, self.now) { + NeighborAnswer::Found(hardware_addr) => return Ok((hardware_addr, tx_token)), + NeighborAnswer::RateLimited => return Err(DispatchError::NeighborPending), + _ => (), // XXX + } + + match (src_addr, dst_addr) { + #[cfg(all(feature = "medium-ethernet", feature = "proto-ipv4"))] + (&IpAddress::Ipv4(src_addr), IpAddress::Ipv4(dst_addr)) + if matches!(self.caps.medium, Medium::Ethernet) => + { + net_debug!( + "address {} not in neighbor cache, sending ARP request", + dst_addr + ); + let src_hardware_addr = self.hardware_addr.ethernet_or_panic(); + + let arp_repr = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: src_hardware_addr, + source_protocol_addr: src_addr, + target_hardware_addr: EthernetAddress::BROADCAST, + target_protocol_addr: dst_addr, + }; + + if let Err(e) = + self.dispatch_ethernet(tx_token, arp_repr.buffer_len(), |mut frame| { + frame.set_dst_addr(EthernetAddress::BROADCAST); + frame.set_ethertype(EthernetProtocol::Arp); + + arp_repr.emit(&mut ArpPacket::new_unchecked(frame.payload_mut())) + }) + { + net_debug!("Failed to dispatch ARP request: {:?}", e); + return Err(DispatchError::NeighborPending); + } + } + + #[cfg(feature = "proto-ipv6")] + (&IpAddress::Ipv6(src_addr), IpAddress::Ipv6(dst_addr)) => { + net_debug!( + "address {} not in neighbor cache, sending Neighbor Solicitation", + dst_addr + ); + + let solicit = Icmpv6Repr::Ndisc(NdiscRepr::NeighborSolicit { + target_addr: dst_addr, + lladdr: Some(self.hardware_addr.into()), + }); + + let packet = Packet::new_ipv6( + Ipv6Repr { + src_addr, + dst_addr: dst_addr.solicited_node(), + next_header: IpProtocol::Icmpv6, + payload_len: solicit.buffer_len(), + hop_limit: 0xff, + }, + IpPayload::Icmpv6(solicit), + ); + + if let Err(e) = + self.dispatch_ip(tx_token, PacketMeta::default(), packet, fragmenter) + { + net_debug!("Failed to dispatch NDISC solicit: {:?}", e); + return Err(DispatchError::NeighborPending); + } + } + + #[allow(unreachable_patterns)] + _ => (), + } + + // The request got dispatched, limit the rate on the cache. + self.neighbor_cache.limit_rate(self.now); + Err(DispatchError::NeighborPending) + } + + fn flush_cache(&mut self) { + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + self.neighbor_cache.flush() + } + + fn dispatch_ip<Tx: TxToken>( + &mut self, + // NOTE(unused_mut): tx_token isn't always mutated, depending on + // the feature set that is used. + #[allow(unused_mut)] mut tx_token: Tx, + meta: PacketMeta, + packet: Packet, + frag: &mut Fragmenter, + ) -> Result<(), DispatchError> { + let mut ip_repr = packet.ip_repr(); + assert!(!ip_repr.dst_addr().is_unspecified()); + + // Dispatch IEEE802.15.4: + + #[cfg(feature = "medium-ieee802154")] + if matches!(self.caps.medium, Medium::Ieee802154) { + let (addr, tx_token) = self.lookup_hardware_addr( + tx_token, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + frag, + )?; + let addr = addr.ieee802154_or_panic(); + + self.dispatch_ieee802154(addr, tx_token, meta, packet, frag); + return Ok(()); + } + + // Dispatch IP/Ethernet: + + let caps = self.caps.clone(); + + #[cfg(feature = "proto-ipv4-fragmentation")] + let ipv4_id = self.get_ipv4_ident(); + + // First we calculate the total length that we will have to emit. + let mut total_len = ip_repr.buffer_len(); + + // Add the size of the Ethernet header if the medium is Ethernet. + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + total_len = EthernetFrame::<&[u8]>::buffer_len(total_len); + } + + // If the medium is Ethernet, then we need to retrieve the destination hardware address. + #[cfg(feature = "medium-ethernet")] + let (dst_hardware_addr, mut tx_token) = match self.caps.medium { + Medium::Ethernet => { + match self.lookup_hardware_addr( + tx_token, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + frag, + )? { + (HardwareAddress::Ethernet(addr), tx_token) => (addr, tx_token), + (_, _) => unreachable!(), + } + } + _ => (EthernetAddress([0; 6]), tx_token), + }; + + // Emit function for the Ethernet header. + #[cfg(feature = "medium-ethernet")] + let emit_ethernet = |repr: &IpRepr, tx_buffer: &mut [u8]| { + let mut frame = EthernetFrame::new_unchecked(tx_buffer); + + let src_addr = self.hardware_addr.ethernet_or_panic(); + frame.set_src_addr(src_addr); + frame.set_dst_addr(dst_hardware_addr); + + match repr.version() { + #[cfg(feature = "proto-ipv4")] + IpVersion::Ipv4 => frame.set_ethertype(EthernetProtocol::Ipv4), + #[cfg(feature = "proto-ipv6")] + IpVersion::Ipv6 => frame.set_ethertype(EthernetProtocol::Ipv6), + } + + Ok(()) + }; + + // Emit function for the IP header and payload. + let emit_ip = |repr: &IpRepr, mut tx_buffer: &mut [u8]| { + repr.emit(&mut tx_buffer, &self.caps.checksum); + + let payload = &mut tx_buffer[repr.header_len()..]; + packet.emit_payload(repr, payload, &caps) + }; + + let total_ip_len = ip_repr.buffer_len(); + + match &mut ip_repr { + #[cfg(feature = "proto-ipv4")] + IpRepr::Ipv4(repr) => { + // If we have an IPv4 packet, then we need to check if we need to fragment it. + if total_ip_len > self.caps.max_transmission_unit { + #[cfg(feature = "proto-ipv4-fragmentation")] + { + net_debug!("start fragmentation"); + + // Calculate how much we will send now (including the Ethernet header). + let tx_len = self.caps.max_transmission_unit; + + let ip_header_len = repr.buffer_len(); + let first_frag_ip_len = self.caps.ip_mtu(); + + if frag.buffer.len() < total_ip_len { + net_debug!( + "Fragmentation buffer is too small, at least {} needed. Dropping", + total_ip_len + ); + return Ok(()); + } + + #[cfg(feature = "medium-ethernet")] + { + frag.ipv4.dst_hardware_addr = dst_hardware_addr; + } + + // Save the total packet len (without the Ethernet header, but with the first + // IP header). + frag.packet_len = total_ip_len; + + // Save the IP header for other fragments. + frag.ipv4.repr = *repr; + + // Save how much bytes we will send now. + frag.sent_bytes = first_frag_ip_len; + + // Modify the IP header + repr.payload_len = first_frag_ip_len - repr.buffer_len(); + + // Emit the IP header to the buffer. + emit_ip(&ip_repr, &mut frag.buffer); + + let mut ipv4_packet = Ipv4Packet::new_unchecked(&mut frag.buffer[..]); + frag.ipv4.ident = ipv4_id; + ipv4_packet.set_ident(ipv4_id); + ipv4_packet.set_more_frags(true); + ipv4_packet.set_dont_frag(false); + ipv4_packet.set_frag_offset(0); + + if caps.checksum.ipv4.tx() { + ipv4_packet.fill_checksum(); + } + + // Transmit the first packet. + tx_token.consume(tx_len, |mut tx_buffer| { + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + emit_ethernet(&ip_repr, tx_buffer)?; + tx_buffer = &mut tx_buffer[EthernetFrame::<&[u8]>::header_len()..]; + } + + // Change the offset for the next packet. + frag.ipv4.frag_offset = (first_frag_ip_len - ip_header_len) as u16; + + // Copy the IP header and the payload. + tx_buffer[..first_frag_ip_len] + .copy_from_slice(&frag.buffer[..first_frag_ip_len]); + + Ok(()) + }) + } + + #[cfg(not(feature = "proto-ipv4-fragmentation"))] + { + net_debug!("Enable the `proto-ipv4-fragmentation` feature for fragmentation support."); + Ok(()) + } + } else { + tx_token.set_meta(meta); + + // No fragmentation is required. + tx_token.consume(total_len, |mut tx_buffer| { + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + emit_ethernet(&ip_repr, tx_buffer)?; + tx_buffer = &mut tx_buffer[EthernetFrame::<&[u8]>::header_len()..]; + } + + emit_ip(&ip_repr, tx_buffer); + Ok(()) + }) + } + } + // We don't support IPv6 fragmentation yet. + #[cfg(feature = "proto-ipv6")] + IpRepr::Ipv6(_) => tx_token.consume(total_len, |mut tx_buffer| { + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + emit_ethernet(&ip_repr, tx_buffer)?; + tx_buffer = &mut tx_buffer[EthernetFrame::<&[u8]>::header_len()..]; + } + + emit_ip(&ip_repr, tx_buffer); + Ok(()) + }), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +enum DispatchError { + /// No route to dispatch this packet. Retrying won't help unless + /// configuration is changed. + NoRoute, + /// We do have a route to dispatch this packet, but we haven't discovered + /// the neighbor for it yet. Discovery has been initiated, dispatch + /// should be retried later. + NeighborPending, +} diff --git a/src/iface/interface/sixlowpan.rs b/src/iface/interface/sixlowpan.rs new file mode 100644 index 0000000..de2c3e9 --- /dev/null +++ b/src/iface/interface/sixlowpan.rs @@ -0,0 +1,922 @@ +use super::*; + +use crate::phy::ChecksumCapabilities; +use crate::wire::*; + +// Max len of non-fragmented packets after decompression (including ipv6 header and payload) +// TODO: lower. Should be (6lowpan mtu) - (min 6lowpan header size) + (max ipv6 header size) +pub(crate) const MAX_DECOMPRESSED_LEN: usize = 1500; + +impl InterfaceInner { + pub(super) fn process_sixlowpan<'output, 'payload: 'output>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ieee802154_repr: &Ieee802154Repr, + payload: &'payload [u8], + f: &'output mut FragmentsBuffer, + ) -> Option<Packet<'output>> { + let payload = match check!(SixlowpanPacket::dispatch(payload)) { + #[cfg(not(feature = "proto-sixlowpan-fragmentation"))] + SixlowpanPacket::FragmentHeader => { + net_debug!( + "Fragmentation is not supported, \ + use the `proto-sixlowpan-fragmentation` feature to add support." + ); + return None; + } + #[cfg(feature = "proto-sixlowpan-fragmentation")] + SixlowpanPacket::FragmentHeader => { + match self.process_sixlowpan_fragment(ieee802154_repr, payload, f) { + Some(payload) => payload, + None => return None, + } + } + SixlowpanPacket::IphcHeader => { + match Self::sixlowpan_to_ipv6( + &self.sixlowpan_address_context, + ieee802154_repr, + payload, + None, + &mut f.decompress_buf, + ) { + Ok(len) => &f.decompress_buf[..len], + Err(e) => { + net_debug!("sixlowpan decompress failed: {:?}", e); + return None; + } + } + } + }; + + self.process_ipv6(sockets, meta, &check!(Ipv6Packet::new_checked(payload))) + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + fn process_sixlowpan_fragment<'output, 'payload: 'output>( + &mut self, + ieee802154_repr: &Ieee802154Repr, + payload: &'payload [u8], + f: &'output mut FragmentsBuffer, + ) -> Option<&'output [u8]> { + use crate::iface::fragmentation::{AssemblerError, AssemblerFullError}; + + // We have a fragment header, which means we cannot process the 6LoWPAN packet, + // unless we have a complete one after processing this fragment. + let frag = check!(SixlowpanFragPacket::new_checked(payload)); + + // The key specifies to which 6LoWPAN fragment it belongs too. + // It is based on the link layer addresses, the tag and the size. + let key = FragKey::Sixlowpan(frag.get_key(ieee802154_repr)); + + // The offset of this fragment in increments of 8 octets. + let offset = frag.datagram_offset() as usize * 8; + + // We reserve a spot in the packet assembler set and add the required + // information to the packet assembler. + // This information is the total size of the packet when it is fully assmbled. + // We also pass the header size, since this is needed when other fragments + // (other than the first one) are added. + let frag_slot = match f.assembler.get(&key, self.now + f.reassembly_timeout) { + Ok(frag) => frag, + Err(AssemblerFullError) => { + net_debug!("No available packet assembler for fragmented packet"); + return None; + } + }; + + if frag.is_first_fragment() { + // The first fragment contains the total size of the IPv6 packet. + // However, we received a packet that is compressed following the 6LoWPAN + // standard. This means we need to convert the IPv6 packet size to a 6LoWPAN + // packet size. The packet size can be different because of first the + // compression of the IP header and when UDP is used (because the UDP header + // can also be compressed). Other headers are not compressed by 6LoWPAN. + + // First segment tells us the total size. + let total_size = frag.datagram_size() as usize; + if frag_slot.set_total_size(total_size).is_err() { + net_debug!("No available packet assembler for fragmented packet"); + return None; + } + + // Decompress headers+payload into the assembler. + if let Err(e) = frag_slot.add_with(0, |buffer| { + Self::sixlowpan_to_ipv6( + &self.sixlowpan_address_context, + ieee802154_repr, + frag.payload(), + Some(total_size), + buffer, + ) + .map_err(|_| AssemblerError) + }) { + net_debug!("fragmentation error: {:?}", e); + return None; + } + } else { + // Add the fragment to the packet assembler. + if let Err(e) = frag_slot.add(frag.payload(), offset) { + net_debug!("fragmentation error: {:?}", e); + return None; + } + } + + match frag_slot.assemble() { + Some(payload) => { + net_trace!("6LoWPAN: fragmented packet now complete"); + Some(payload) + } + None => None, + } + } + + fn sixlowpan_to_ipv6( + address_context: &[SixlowpanAddressContext], + ieee802154_repr: &Ieee802154Repr, + iphc_payload: &[u8], + total_size: Option<usize>, + buffer: &mut [u8], + ) -> core::result::Result<usize, crate::wire::Error> { + let iphc = SixlowpanIphcPacket::new_checked(iphc_payload)?; + let iphc_repr = SixlowpanIphcRepr::parse( + &iphc, + ieee802154_repr.src_addr, + ieee802154_repr.dst_addr, + address_context, + )?; + + let first_next_header = match iphc_repr.next_header { + SixlowpanNextHeader::Compressed => { + match SixlowpanNhcPacket::dispatch(iphc.payload())? { + SixlowpanNhcPacket::ExtHeader => { + SixlowpanExtHeaderPacket::new_checked(iphc.payload())? + .extension_header_id() + .into() + } + SixlowpanNhcPacket::UdpHeader => IpProtocol::Udp, + } + } + SixlowpanNextHeader::Uncompressed(proto) => proto, + }; + + let mut decompressed_size = 40 + iphc.payload().len(); + let mut next_header = Some(iphc_repr.next_header); + let mut data = iphc.payload(); + + while let Some(nh) = next_header { + match nh { + SixlowpanNextHeader::Compressed => match SixlowpanNhcPacket::dispatch(data)? { + SixlowpanNhcPacket::ExtHeader => { + let ext_hdr = SixlowpanExtHeaderPacket::new_checked(data)?; + let ext_repr = SixlowpanExtHeaderRepr::parse(&ext_hdr)?; + decompressed_size += 2; + decompressed_size -= ext_repr.buffer_len(); + next_header = Some(ext_repr.next_header); + + if ext_repr.buffer_len() + ext_repr.length as usize > data.len() { + return Err(Error); + } + + data = &data[ext_repr.buffer_len() + ext_repr.length as usize..]; + } + SixlowpanNhcPacket::UdpHeader => { + let udp_packet = SixlowpanUdpNhcPacket::new_checked(data)?; + let udp_repr = SixlowpanUdpNhcRepr::parse( + &udp_packet, + &iphc_repr.src_addr, + &iphc_repr.dst_addr, + &crate::phy::ChecksumCapabilities::ignored(), + )?; + + decompressed_size += 8; + decompressed_size -= udp_repr.header_len(); + break; + } + }, + SixlowpanNextHeader::Uncompressed(proto) => match proto { + IpProtocol::Tcp => break, + IpProtocol::Udp => break, + IpProtocol::Icmpv6 => break, + proto => { + net_debug!("unable to decompress Uncompressed({})", proto); + return Err(Error); + } + }, + } + } + + if buffer.len() < decompressed_size { + net_debug!("sixlowpan decompress: buffer too short"); + return Err(crate::wire::Error); + } + let buffer = &mut buffer[..decompressed_size]; + + let total_size = if let Some(size) = total_size { + size + } else { + decompressed_size + }; + + let mut rest_size = total_size; + + let ipv6_repr = Ipv6Repr { + src_addr: iphc_repr.src_addr, + dst_addr: iphc_repr.dst_addr, + next_header: first_next_header, + payload_len: total_size - 40, + hop_limit: iphc_repr.hop_limit, + }; + rest_size -= 40; + + // Emit the decompressed IPHC header (decompressed to an IPv6 header). + let mut ipv6_packet = Ipv6Packet::new_unchecked(&mut buffer[..ipv6_repr.buffer_len()]); + ipv6_repr.emit(&mut ipv6_packet); + let mut buffer = &mut buffer[ipv6_repr.buffer_len()..]; + + let mut next_header = Some(iphc_repr.next_header); + let mut data = iphc.payload(); + + while let Some(nh) = next_header { + match nh { + SixlowpanNextHeader::Compressed => match SixlowpanNhcPacket::dispatch(data)? { + SixlowpanNhcPacket::ExtHeader => { + let ext_hdr = SixlowpanExtHeaderPacket::new_checked(data)?; + let ext_repr = SixlowpanExtHeaderRepr::parse(&ext_hdr)?; + + let nh = match ext_repr.next_header { + SixlowpanNextHeader::Compressed => { + let d = &data[ext_repr.length as usize + ext_repr.buffer_len()..]; + match SixlowpanNhcPacket::dispatch(d)? { + SixlowpanNhcPacket::ExtHeader => { + SixlowpanExtHeaderPacket::new_checked(d)? + .extension_header_id() + .into() + } + SixlowpanNhcPacket::UdpHeader => IpProtocol::Udp, + } + } + SixlowpanNextHeader::Uncompressed(proto) => proto, + }; + next_header = Some(ext_repr.next_header); + + let ipv6_ext_hdr = Ipv6ExtHeaderRepr { + next_header: nh, + length: ext_repr.length / 8, + data: &ext_hdr.payload()[..ext_repr.length as usize], + }; + + ipv6_ext_hdr.emit(&mut Ipv6ExtHeader::new_unchecked( + &mut buffer[..ipv6_ext_hdr.header_len()], + )); + buffer[ipv6_ext_hdr.header_len()..][..ipv6_ext_hdr.data.len()] + .copy_from_slice(ipv6_ext_hdr.data); + + buffer = &mut buffer[ipv6_ext_hdr.header_len() + ipv6_ext_hdr.data.len()..]; + + rest_size -= ipv6_ext_hdr.header_len() + ipv6_ext_hdr.data.len(); + data = &data[ext_repr.buffer_len() + ext_repr.length as usize..]; + } + SixlowpanNhcPacket::UdpHeader => { + let udp_packet = SixlowpanUdpNhcPacket::new_checked(data)?; + let payload = udp_packet.payload(); + let udp_repr = SixlowpanUdpNhcRepr::parse( + &udp_packet, + &iphc_repr.src_addr, + &iphc_repr.dst_addr, + &ChecksumCapabilities::ignored(), + )?; + + if payload.len() + 8 > buffer.len() { + return Err(Error); + } + + let mut udp = UdpPacket::new_unchecked(&mut buffer[..payload.len() + 8]); + udp_repr + .0 + .emit_header(&mut udp, rest_size - udp_repr.0.header_len()); + buffer[8..][..payload.len()].copy_from_slice(payload); + + break; + } + }, + SixlowpanNextHeader::Uncompressed(proto) => match proto { + IpProtocol::HopByHop => unreachable!(), + IpProtocol::Tcp => { + buffer.copy_from_slice(data); + break; + } + IpProtocol::Udp => { + buffer.copy_from_slice(data); + break; + } + IpProtocol::Icmpv6 => { + buffer.copy_from_slice(data); + break; + } + _ => unreachable!(), + }, + } + } + + Ok(decompressed_size) + } + + pub(super) fn dispatch_sixlowpan<Tx: TxToken>( + &mut self, + mut tx_token: Tx, + meta: PacketMeta, + packet: Packet, + ieee_repr: Ieee802154Repr, + frag: &mut Fragmenter, + ) { + let packet = match packet { + #[cfg(feature = "proto-ipv4")] + Packet::Ipv4(_) => unreachable!(), + Packet::Ipv6(packet) => packet, + }; + + // First we calculate the size we are going to need. If the size is bigger than the MTU, + // then we use fragmentation. + let (total_size, compressed_size, uncompressed_size) = + Self::compressed_packet_size(&packet, &ieee_repr); + + let ieee_len = ieee_repr.buffer_len(); + + // TODO(thvdveld): use the MTU of the device. + if total_size + ieee_len > 125 { + #[cfg(feature = "proto-sixlowpan-fragmentation")] + { + // The packet does not fit in one Ieee802154 frame, so we need fragmentation. + // We do this by emitting everything in the `frag.buffer` from the interface. + // After emitting everything into that buffer, we send the first fragment heere. + // When `poll` is called again, we check if frag was fully sent, otherwise we + // call `dispatch_ieee802154_frag`, which will transmit the other fragments. + + // `dispatch_ieee802154_frag` requires some information about the total packet size, + // the link local source and destination address... + + let pkt = frag; + if pkt.buffer.len() < total_size { + net_debug!( + "dispatch_ieee802154: dropping, \ + fragmentation buffer is too small, at least {} needed", + total_size + ); + return; + } + + let payload_length = packet.header.payload_len; + + Self::ipv6_to_sixlowpan( + &self.checksum_caps(), + packet, + &ieee_repr, + &mut pkt.buffer[..], + ); + + pkt.sixlowpan.ll_dst_addr = ieee_repr.dst_addr.unwrap(); + pkt.sixlowpan.ll_src_addr = ieee_repr.src_addr.unwrap(); + pkt.packet_len = total_size; + + // The datagram size that we need to set in the first fragment header is equal to the + // IPv6 payload length + 40. + pkt.sixlowpan.datagram_size = (payload_length + 40) as u16; + + let tag = self.get_sixlowpan_fragment_tag(); + // We save the tag for the other fragments that will be created when calling `poll` + // multiple times. + pkt.sixlowpan.datagram_tag = tag; + + let frag1 = SixlowpanFragRepr::FirstFragment { + size: pkt.sixlowpan.datagram_size, + tag, + }; + let fragn = SixlowpanFragRepr::Fragment { + size: pkt.sixlowpan.datagram_size, + tag, + offset: 0, + }; + + // We calculate how much data we can send in the first fragment and the other + // fragments. The eventual IPv6 sizes of these fragments need to be a multiple of eight + // (except for the last fragment) since the offset field in the fragment is an offset + // in multiples of 8 octets. This is explained in [RFC 4944 § 5.3]. + // + // [RFC 4944 § 5.3]: https://datatracker.ietf.org/doc/html/rfc4944#section-5.3 + + let header_diff = uncompressed_size - compressed_size; + let frag1_size = + (125 - ieee_len - frag1.buffer_len() + header_diff) / 8 * 8 - header_diff; + + pkt.sixlowpan.fragn_size = (125 - ieee_len - fragn.buffer_len()) / 8 * 8; + pkt.sent_bytes = frag1_size; + pkt.sixlowpan.datagram_offset = frag1_size + header_diff; + + tx_token.set_meta(meta); + tx_token.consume(ieee_len + frag1.buffer_len() + frag1_size, |mut tx_buf| { + // Add the IEEE header. + let mut ieee_packet = Ieee802154Frame::new_unchecked(&mut tx_buf[..ieee_len]); + ieee_repr.emit(&mut ieee_packet); + tx_buf = &mut tx_buf[ieee_len..]; + + // Add the first fragment header + let mut frag1_packet = SixlowpanFragPacket::new_unchecked(&mut tx_buf); + frag1.emit(&mut frag1_packet); + tx_buf = &mut tx_buf[frag1.buffer_len()..]; + + // Add the buffer part. + tx_buf[..frag1_size].copy_from_slice(&pkt.buffer[..frag1_size]); + }); + } + + #[cfg(not(feature = "proto-sixlowpan-fragmentation"))] + { + net_debug!( + "Enable the `proto-sixlowpan-fragmentation` feature for fragmentation support." + ); + return; + } + } else { + tx_token.set_meta(meta); + + // We don't need fragmentation, so we emit everything to the TX token. + tx_token.consume(total_size + ieee_len, |mut tx_buf| { + let mut ieee_packet = Ieee802154Frame::new_unchecked(&mut tx_buf[..ieee_len]); + ieee_repr.emit(&mut ieee_packet); + tx_buf = &mut tx_buf[ieee_len..]; + + Self::ipv6_to_sixlowpan(&self.checksum_caps(), packet, &ieee_repr, tx_buf); + }); + } + } + + fn ipv6_to_sixlowpan( + checksum_caps: &ChecksumCapabilities, + mut packet: PacketV6, + ieee_repr: &Ieee802154Repr, + mut buffer: &mut [u8], + ) { + let last_header = packet.payload.as_sixlowpan_next_header(); + let next_header = last_header; + + #[cfg(feature = "proto-ipv6-hbh")] + let next_header = if packet.hop_by_hop.is_some() { + SixlowpanNextHeader::Compressed + } else { + next_header + }; + + #[cfg(feature = "proto-ipv6-routing")] + let next_header = if packet.routing.is_some() { + SixlowpanNextHeader::Compressed + } else { + next_header + }; + + let iphc_repr = SixlowpanIphcRepr { + src_addr: packet.header.src_addr, + ll_src_addr: ieee_repr.src_addr, + dst_addr: packet.header.dst_addr, + ll_dst_addr: ieee_repr.dst_addr, + next_header, + hop_limit: packet.header.hop_limit, + ecn: None, + dscp: None, + flow_label: None, + }; + + iphc_repr.emit(&mut SixlowpanIphcPacket::new_unchecked( + &mut buffer[..iphc_repr.buffer_len()], + )); + buffer = &mut buffer[iphc_repr.buffer_len()..]; + + // Emit the Hop-by-Hop header + #[cfg(feature = "proto-ipv6-hbh")] + if let Some(hbh) = packet.hop_by_hop { + #[allow(unused)] + let next_header = last_header; + + #[cfg(feature = "proto-ipv6-routing")] + let next_header = if packet.routing.is_some() { + SixlowpanNextHeader::Compressed + } else { + last_header + }; + + let ext_hdr = SixlowpanExtHeaderRepr { + ext_header_id: SixlowpanExtHeaderId::HopByHopHeader, + next_header, + length: hbh.options.iter().map(|o| o.buffer_len()).sum::<usize>() as u8, + }; + ext_hdr.emit(&mut SixlowpanExtHeaderPacket::new_unchecked( + &mut buffer[..ext_hdr.buffer_len()], + )); + buffer = &mut buffer[ext_hdr.buffer_len()..]; + + for opt in &hbh.options { + opt.emit(&mut Ipv6Option::new_unchecked( + &mut buffer[..opt.buffer_len()], + )); + + buffer = &mut buffer[opt.buffer_len()..]; + } + } + + // Emit the Routing header + #[cfg(feature = "proto-ipv6-routing")] + if let Some(routing) = &packet.routing { + let ext_hdr = SixlowpanExtHeaderRepr { + ext_header_id: SixlowpanExtHeaderId::RoutingHeader, + next_header, + length: routing.buffer_len() as u8, + }; + ext_hdr.emit(&mut SixlowpanExtHeaderPacket::new_unchecked( + &mut buffer[..ext_hdr.buffer_len()], + )); + buffer = &mut buffer[ext_hdr.buffer_len()..]; + + routing.emit(&mut Ipv6RoutingHeader::new_unchecked( + &mut buffer[..routing.buffer_len()], + )); + buffer = &mut buffer[routing.buffer_len()..]; + } + + match &mut packet.payload { + IpPayload::Icmpv6(icmp_repr) => { + icmp_repr.emit( + &packet.header.src_addr.into(), + &packet.header.dst_addr.into(), + &mut Icmpv6Packet::new_unchecked(&mut buffer[..icmp_repr.buffer_len()]), + checksum_caps, + ); + } + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + IpPayload::Udp(udp_repr, payload) => { + let udp_repr = SixlowpanUdpNhcRepr(*udp_repr); + udp_repr.emit( + &mut SixlowpanUdpNhcPacket::new_unchecked( + &mut buffer[..udp_repr.header_len() + payload.len()], + ), + &iphc_repr.src_addr, + &iphc_repr.dst_addr, + payload.len(), + |buf| buf.copy_from_slice(payload), + checksum_caps, + ); + } + #[cfg(feature = "socket-tcp")] + IpPayload::Tcp(tcp_repr) => { + tcp_repr.emit( + &mut TcpPacket::new_unchecked(&mut buffer[..tcp_repr.buffer_len()]), + &packet.header.src_addr.into(), + &packet.header.dst_addr.into(), + checksum_caps, + ); + } + #[cfg(feature = "socket-raw")] + IpPayload::Raw(_raw) => todo!(), + + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + } + + /// Calculates three sizes: + /// - total size: the size of a compressed IPv6 packet + /// - compressed header size: the size of the compressed headers + /// - uncompressed header size: the size of the headers that are not compressed + /// They are returned as a tuple in the same order. + fn compressed_packet_size( + packet: &PacketV6, + ieee_repr: &Ieee802154Repr, + ) -> (usize, usize, usize) { + let last_header = packet.payload.as_sixlowpan_next_header(); + let next_header = last_header; + + #[cfg(feature = "proto-ipv6-hbh")] + let next_header = if packet.hop_by_hop.is_some() { + SixlowpanNextHeader::Compressed + } else { + next_header + }; + + #[cfg(feature = "proto-ipv6-routing")] + let next_header = if packet.routing.is_some() { + SixlowpanNextHeader::Compressed + } else { + next_header + }; + + let iphc = SixlowpanIphcRepr { + src_addr: packet.header.src_addr, + ll_src_addr: ieee_repr.src_addr, + dst_addr: packet.header.dst_addr, + ll_dst_addr: ieee_repr.dst_addr, + next_header, + hop_limit: packet.header.hop_limit, + ecn: None, + dscp: None, + flow_label: None, + }; + + let mut total_size = iphc.buffer_len(); + let mut compressed_hdr_size = iphc.buffer_len(); + let mut uncompressed_hdr_size = packet.header.buffer_len(); + + // Add the hop-by-hop to the sizes. + #[cfg(feature = "proto-ipv6-hbh")] + if let Some(hbh) = &packet.hop_by_hop { + #[allow(unused)] + let next_header = last_header; + + #[cfg(feature = "proto-ipv6-routing")] + let next_header = if packet.routing.is_some() { + SixlowpanNextHeader::Compressed + } else { + last_header + }; + + let options_size = hbh.options.iter().map(|o| o.buffer_len()).sum::<usize>(); + + let ext_hdr = SixlowpanExtHeaderRepr { + ext_header_id: SixlowpanExtHeaderId::HopByHopHeader, + next_header, + length: hbh.buffer_len() as u8 + options_size as u8, + }; + + total_size += ext_hdr.buffer_len() + options_size; + compressed_hdr_size += ext_hdr.buffer_len() + options_size; + uncompressed_hdr_size += hbh.buffer_len() + options_size; + } + + // Add the routing header to the sizes. + #[cfg(feature = "proto-ipv6-routing")] + if let Some(routing) = &packet.routing { + let ext_hdr = SixlowpanExtHeaderRepr { + ext_header_id: SixlowpanExtHeaderId::RoutingHeader, + next_header, + length: routing.buffer_len() as u8, + }; + total_size += ext_hdr.buffer_len() + routing.buffer_len(); + compressed_hdr_size += ext_hdr.buffer_len() + routing.buffer_len(); + uncompressed_hdr_size += routing.buffer_len(); + } + + match packet.payload { + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + IpPayload::Udp(udp_hdr, payload) => { + uncompressed_hdr_size += udp_hdr.header_len(); + + let udp_hdr = SixlowpanUdpNhcRepr(udp_hdr); + compressed_hdr_size += udp_hdr.header_len(); + + total_size += udp_hdr.header_len() + payload.len(); + } + _ => { + total_size += packet.header.payload_len; + } + } + + (total_size, compressed_hdr_size, uncompressed_hdr_size) + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + pub(super) fn dispatch_sixlowpan_frag<Tx: TxToken>( + &mut self, + tx_token: Tx, + ieee_repr: Ieee802154Repr, + frag: &mut Fragmenter, + ) { + // Create the FRAG_N header. + let fragn = SixlowpanFragRepr::Fragment { + size: frag.sixlowpan.datagram_size, + tag: frag.sixlowpan.datagram_tag, + offset: (frag.sixlowpan.datagram_offset / 8) as u8, + }; + + let ieee_len = ieee_repr.buffer_len(); + let frag_size = (frag.packet_len - frag.sent_bytes).min(frag.sixlowpan.fragn_size); + + tx_token.consume( + ieee_repr.buffer_len() + fragn.buffer_len() + frag_size, + |mut tx_buf| { + let mut ieee_packet = Ieee802154Frame::new_unchecked(&mut tx_buf[..ieee_len]); + ieee_repr.emit(&mut ieee_packet); + tx_buf = &mut tx_buf[ieee_len..]; + + let mut frag_packet = + SixlowpanFragPacket::new_unchecked(&mut tx_buf[..fragn.buffer_len()]); + fragn.emit(&mut frag_packet); + tx_buf = &mut tx_buf[fragn.buffer_len()..]; + + // Add the buffer part + tx_buf[..frag_size].copy_from_slice(&frag.buffer[frag.sent_bytes..][..frag_size]); + + frag.sent_bytes += frag_size; + frag.sixlowpan.datagram_offset += frag_size; + }, + ); + } +} + +#[cfg(test)] +#[cfg(all(feature = "proto-rpl", feature = "proto-ipv6-hbh"))] +mod tests { + use super::*; + + static SIXLOWPAN_COMPRESSED_RPL_DAO: [u8; 99] = [ + 0x61, 0xdc, 0x45, 0xcd, 0xab, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x03, 0x00, + 0x03, 0x00, 0x03, 0x00, 0x03, 0x00, 0x7e, 0xf7, 0x00, 0xe0, 0x3a, 0x06, 0x63, 0x04, 0x00, + 0x1e, 0x08, 0x00, 0x9b, 0x02, 0x3e, 0x63, 0x1e, 0x40, 0x00, 0xf1, 0xfd, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x05, 0x12, 0x00, + 0x80, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x03, 0x00, 0x03, 0x00, 0x03, + 0x00, 0x03, 0x06, 0x14, 0x00, 0x00, 0x00, 0x1e, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, + ]; + + static SIXLOWPAN_UNCOMPRESSED_RPL_DAO: [u8; 114] = [ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x40, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, 0x03, 0x00, 0x03, 0x00, 0x03, 0x00, 0x03, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x3a, 0x00, 0x63, 0x04, 0x00, + 0x1e, 0x08, 0x00, 0x9b, 0x02, 0x3e, 0x63, 0x1e, 0x40, 0x00, 0xf1, 0xfd, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x05, 0x12, 0x00, + 0x80, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x03, 0x00, 0x03, 0x00, 0x03, + 0x00, 0x03, 0x06, 0x14, 0x00, 0x00, 0x00, 0x1e, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, + ]; + + #[test] + fn test_sixlowpan_decompress_hop_by_hop_with_icmpv6() { + let address_context = [SixlowpanAddressContext([ + 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + ])]; + + let ieee_frame = Ieee802154Frame::new_checked(&SIXLOWPAN_COMPRESSED_RPL_DAO).unwrap(); + let ieee_repr = Ieee802154Repr::parse(&ieee_frame).unwrap(); + + let mut buffer = [0u8; 256]; + let len = InterfaceInner::sixlowpan_to_ipv6( + &address_context, + &ieee_repr, + ieee_frame.payload().unwrap(), + None, + &mut buffer[..], + ) + .unwrap(); + + assert_eq!(&buffer[..len], &SIXLOWPAN_UNCOMPRESSED_RPL_DAO); + } + + #[test] + fn test_sixlowpan_compress_hop_by_hop_with_icmpv6() { + let ieee_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: true, + sequence_number: Some(69), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2006, + dst_pan_id: Some(Ieee802154Pan(43981)), + dst_addr: Some(Ieee802154Address::Extended([0, 1, 0, 1, 0, 1, 0, 1])), + src_pan_id: None, + src_addr: Some(Ieee802154Address::Extended([0, 3, 0, 3, 0, 3, 0, 3])), + }; + + let mut ip_packet = PacketV6 { + header: Ipv6Repr { + src_addr: Ipv6Address::from_bytes(&[ + 253, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 3, 0, 3, 0, 3, + ]), + dst_addr: Ipv6Address::from_bytes(&[ + 253, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 1, 0, 1, + ]), + next_header: IpProtocol::Icmpv6, + payload_len: 66, + hop_limit: 64, + }, + #[cfg(feature = "proto-ipv6-hbh")] + hop_by_hop: None, + #[cfg(feature = "proto-ipv6-fragmentation")] + fragment: None, + #[cfg(feature = "proto-ipv6-routing")] + routing: None, + payload: IpPayload::Icmpv6(Icmpv6Repr::Rpl(RplRepr::DestinationAdvertisementObject { + rpl_instance_id: RplInstanceId::Global(30), + expect_ack: false, + sequence: 241, + dodag_id: Some(Ipv6Address::from_bytes(&[ + 253, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 1, 0, 1, + ])), + options: &[], + })), + }; + + let (total_size, _, _) = InterfaceInner::compressed_packet_size(&mut ip_packet, &ieee_repr); + let mut buffer = vec![0u8; total_size]; + + InterfaceInner::ipv6_to_sixlowpan( + &ChecksumCapabilities::default(), + ip_packet, + &ieee_repr, + &mut buffer[..total_size], + ); + + let result = [ + 0x7e, 0x0, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x3, 0x0, 0x3, 0x0, 0x3, 0x0, + 0x3, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x1, 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, + 0xe0, 0x3a, 0x6, 0x63, 0x4, 0x0, 0x1e, 0x3, 0x0, 0x9b, 0x2, 0x3e, 0x63, 0x1e, 0x40, + 0x0, 0xf1, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x1, 0x0, 0x1, 0x0, 0x1, 0x0, + 0x1, 0x5, 0x12, 0x0, 0x80, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x3, 0x0, 0x3, + 0x0, 0x3, 0x0, 0x3, 0x6, 0x14, 0x0, 0x0, 0x0, 0x1e, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x2, 0x1, 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, + ]; + + assert_eq!(&result, &result); + } + + #[test] + fn test_sixlowpan_compress_hop_by_hop_with_udp() { + let ieee_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: true, + sequence_number: Some(69), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2006, + dst_pan_id: Some(Ieee802154Pan(43981)), + dst_addr: Some(Ieee802154Address::Extended([0, 1, 0, 1, 0, 1, 0, 1])), + src_pan_id: None, + src_addr: Some(Ieee802154Address::Extended([0, 3, 0, 3, 0, 3, 0, 3])), + }; + + let addr = Ipv6Address::from_bytes(&[253, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 3, 0, 3, 0, 3]); + let parent_address = + Ipv6Address::from_bytes(&[253, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 1, 0, 1]); + + let mut hbh_options = heapless::Vec::new(); + hbh_options + .push(Ipv6OptionRepr::Rpl(RplHopByHopRepr { + down: false, + rank_error: false, + forwarding_error: false, + instance_id: RplInstanceId::from(0x1e), + sender_rank: 0x300, + })) + .unwrap(); + + let mut ip_packet = PacketV6 { + header: Ipv6Repr { + src_addr: addr, + dst_addr: parent_address, + next_header: IpProtocol::Icmpv6, + payload_len: 66, + hop_limit: 64, + }, + #[cfg(feature = "proto-ipv6-hbh")] + hop_by_hop: Some(Ipv6HopByHopRepr { + options: hbh_options, + }), + #[cfg(feature = "proto-ipv6-fragmentation")] + fragment: None, + #[cfg(feature = "proto-ipv6-routing")] + routing: None, + payload: IpPayload::Icmpv6(Icmpv6Repr::Rpl(RplRepr::DestinationAdvertisementObject { + rpl_instance_id: RplInstanceId::Global(30), + expect_ack: false, + sequence: 241, + dodag_id: Some(Ipv6Address::from_bytes(&[ + 253, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 1, 0, 1, + ])), + options: &[ + 5, 18, 0, 128, 253, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 3, 0, 3, 0, 3, 6, 20, 0, 0, + 0, 30, 253, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 1, 0, 1, 0, 1, + ], + })), + }; + + let (total_size, _, _) = InterfaceInner::compressed_packet_size(&mut ip_packet, &ieee_repr); + let mut buffer = vec![0u8; total_size]; + + InterfaceInner::ipv6_to_sixlowpan( + &ChecksumCapabilities::default(), + ip_packet, + &ieee_repr, + &mut buffer[..total_size], + ); + + let result = [ + 0x7e, 0x0, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x3, 0x0, 0x3, 0x0, 0x3, 0x0, + 0x3, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x1, 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, + 0xe0, 0x3a, 0x6, 0x63, 0x4, 0x0, 0x1e, 0x3, 0x0, 0x9b, 0x2, 0x3e, 0x63, 0x1e, 0x40, + 0x0, 0xf1, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x1, 0x0, 0x1, 0x0, 0x1, 0x0, + 0x1, 0x5, 0x12, 0x0, 0x80, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x3, 0x0, 0x3, + 0x0, 0x3, 0x0, 0x3, 0x6, 0x14, 0x0, 0x0, 0x0, 0x1e, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x2, 0x1, 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, + ]; + + assert_eq!(&buffer[..total_size], &result); + } +} diff --git a/src/iface/interface/tests/ipv4.rs b/src/iface/interface/tests/ipv4.rs new file mode 100644 index 0000000..d685f37 --- /dev/null +++ b/src/iface/interface/tests/ipv4.rs @@ -0,0 +1,968 @@ +use super::*; + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_no_icmp_no_unicast(#[case] medium: Medium) { + let (mut iface, mut sockets, _) = setup(medium); + + // Unknown Ipv4 Protocol + // + // Because the destination is the broadcast address + // this should not trigger and Destination Unreachable + // response. See RFC 1122 § 3.2.2. + let repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Unknown(0x0c), + payload_len: 0, + hop_limit: 0x40, + }); + + let mut bytes = vec![0u8; 54]; + repr.emit(&mut bytes, &ChecksumCapabilities::default()); + let frame = Ipv4Packet::new_unchecked(&bytes[..]); + + // Ensure that the unknown protocol frame does not trigger an + // ICMP error response when the destination address is a + // broadcast address + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + None + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_icmp_error_no_payload(#[case] medium: Medium) { + static NO_BYTES: [u8; 0] = []; + let (mut iface, mut sockets, _device) = setup(medium); + + // Unknown Ipv4 Protocol with no payload + let repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + next_header: IpProtocol::Unknown(0x0c), + payload_len: 0, + hop_limit: 0x40, + }); + + let mut bytes = vec![0u8; 34]; + repr.emit(&mut bytes, &ChecksumCapabilities::default()); + let frame = Ipv4Packet::new_unchecked(&bytes[..]); + + // The expected Destination Unreachable response due to the + // unknown protocol + let icmp_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::ProtoUnreachable, + header: Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + next_header: IpProtocol::Unknown(12), + payload_len: 0, + hop_limit: 64, + }, + data: &NO_BYTES, + }; + + let expected_repr = Packet::new_ipv4( + Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }, + IpPayload::Icmpv4(icmp_repr), + ); + + // Ensure that the unknown protocol triggers an error response. + // And we correctly handle no payload. + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + Some(expected_repr) + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_local_subnet_broadcasts(#[case] medium: Medium) { + let (mut iface, _, _device) = setup(medium); + iface.update_ip_addrs(|addrs| { + addrs.iter_mut().next().map(|addr| { + *addr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address([192, 168, 1, 23]), 24)); + }); + }); + + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 255]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 254]))); + assert!(iface.inner.is_broadcast_v4(Ipv4Address([192, 168, 1, 255]))); + assert!(!iface.inner.is_broadcast_v4(Ipv4Address([192, 168, 1, 254]))); + + iface.update_ip_addrs(|addrs| { + addrs.iter_mut().next().map(|addr| { + *addr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address([192, 168, 23, 24]), 16)); + }); + }); + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 255]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 254]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([192, 168, 23, 255]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([192, 168, 23, 254]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([192, 168, 255, 254]))); + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([192, 168, 255, 255]))); + + iface.update_ip_addrs(|addrs| { + addrs.iter_mut().next().map(|addr| { + *addr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address([192, 168, 23, 24]), 8)); + }); + }); + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 255]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 254]))); + assert!(!iface.inner.is_broadcast_v4(Ipv4Address([192, 23, 1, 255]))); + assert!(!iface.inner.is_broadcast_v4(Ipv4Address([192, 23, 1, 254]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([192, 255, 255, 254]))); + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([192, 255, 255, 255]))); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "medium-ip", feature = "socket-udp"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "medium-ethernet", feature = "socket-udp"))] +fn test_icmp_error_port_unreachable(#[case] medium: Medium) { + static UDP_PAYLOAD: [u8; 12] = [ + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x57, 0x6f, 0x6c, 0x64, 0x21, + ]; + let (mut iface, mut sockets, _device) = setup(medium); + + let mut udp_bytes_unicast = vec![0u8; 20]; + let mut udp_bytes_broadcast = vec![0u8; 20]; + let mut packet_unicast = UdpPacket::new_unchecked(&mut udp_bytes_unicast); + let mut packet_broadcast = UdpPacket::new_unchecked(&mut udp_bytes_broadcast); + + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + + let ip_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 64, + }); + + // Emit the representations to a packet + udp_repr.emit( + &mut packet_unicast, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + + let data = packet_unicast.into_inner(); + + // The expected Destination Unreachable ICMPv4 error response due + // to no sockets listening on the destination port. + let icmp_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::PortUnreachable, + header: Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 64, + }, + data, + }; + let expected_repr = Packet::new_ipv4( + Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }, + IpPayload::Icmpv4(icmp_repr), + ); + + // Ensure that the unknown protocol triggers an error response. + // And we correctly handle no payload. + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr, + udp_repr, + false, + &UDP_PAYLOAD, + data + ), + Some(expected_repr) + ); + + let ip_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 64, + }); + + // Emit the representations to a packet + udp_repr.emit( + &mut packet_broadcast, + &ip_repr.src_addr(), + &IpAddress::Ipv4(Ipv4Address::BROADCAST), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + + // Ensure that the port unreachable error does not trigger an + // ICMP error response when the destination address is a + // broadcast address and no socket is bound to the port. + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr, + udp_repr, + false, + &UDP_PAYLOAD, + packet_broadcast.into_inner(), + ), + None + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_handle_ipv4_broadcast(#[case] medium: Medium) { + use crate::wire::{Icmpv4Packet, Icmpv4Repr}; + + let (mut iface, mut sockets, _device) = setup(medium); + + let our_ipv4_addr = iface.ipv4_addr().unwrap(); + let src_ipv4_addr = Ipv4Address([127, 0, 0, 2]); + + // ICMPv4 echo request + let icmpv4_data: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + let icmpv4_repr = Icmpv4Repr::EchoRequest { + ident: 0x1234, + seq_no: 0xabcd, + data: &icmpv4_data, + }; + + // Send to IPv4 broadcast address + let ipv4_repr = Ipv4Repr { + src_addr: src_ipv4_addr, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Icmp, + hop_limit: 64, + payload_len: icmpv4_repr.buffer_len(), + }; + + // Emit to ip frame + let mut bytes = vec![0u8; ipv4_repr.buffer_len() + icmpv4_repr.buffer_len()]; + let frame = { + ipv4_repr.emit( + &mut Ipv4Packet::new_unchecked(&mut bytes[..]), + &ChecksumCapabilities::default(), + ); + icmpv4_repr.emit( + &mut Icmpv4Packet::new_unchecked(&mut bytes[ipv4_repr.buffer_len()..]), + &ChecksumCapabilities::default(), + ); + Ipv4Packet::new_unchecked(&bytes[..]) + }; + + // Expected ICMPv4 echo reply + let expected_icmpv4_repr = Icmpv4Repr::EchoReply { + ident: 0x1234, + seq_no: 0xabcd, + data: &icmpv4_data, + }; + let expected_ipv4_repr = Ipv4Repr { + src_addr: our_ipv4_addr, + dst_addr: src_ipv4_addr, + next_header: IpProtocol::Icmp, + hop_limit: 64, + payload_len: expected_icmpv4_repr.buffer_len(), + }; + let expected_packet = + Packet::new_ipv4(expected_ipv4_repr, IpPayload::Icmpv4(expected_icmpv4_repr)); + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + Some(expected_packet) + ); +} + +#[rstest] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_handle_valid_arp_request(#[case] medium: Medium) { + let (mut iface, mut sockets, _device) = setup(medium); + + let mut eth_bytes = vec![0u8; 42]; + + let local_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x01]); + let remote_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x02]); + let local_hw_addr = EthernetAddress([0x02, 0x02, 0x02, 0x02, 0x02, 0x02]); + let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); + + let repr = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: remote_hw_addr, + source_protocol_addr: remote_ip_addr, + target_hardware_addr: EthernetAddress::default(), + target_protocol_addr: local_ip_addr, + }; + + let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); + frame.set_dst_addr(EthernetAddress::BROADCAST); + frame.set_src_addr(remote_hw_addr); + frame.set_ethertype(EthernetProtocol::Arp); + let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); + repr.emit(&mut packet); + + // Ensure an ARP Request for us triggers an ARP Reply + assert_eq!( + iface.inner.process_ethernet( + &mut sockets, + PacketMeta::default(), + frame.into_inner(), + &mut iface.fragments + ), + Some(EthernetPacket::Arp(ArpRepr::EthernetIpv4 { + operation: ArpOperation::Reply, + source_hardware_addr: local_hw_addr, + source_protocol_addr: local_ip_addr, + target_hardware_addr: remote_hw_addr, + target_protocol_addr: remote_ip_addr + })) + ); + + // Ensure the address of the requester was entered in the cache + assert_eq!( + iface.inner.lookup_hardware_addr( + MockTxToken, + &IpAddress::Ipv4(local_ip_addr), + &IpAddress::Ipv4(remote_ip_addr), + &mut iface.fragmenter, + ), + Ok((HardwareAddress::Ethernet(remote_hw_addr), MockTxToken)) + ); +} + +#[rstest] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_handle_other_arp_request(#[case] medium: Medium) { + let (mut iface, mut sockets, _device) = setup(medium); + + let mut eth_bytes = vec![0u8; 42]; + + let remote_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x02]); + let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); + + let repr = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: remote_hw_addr, + source_protocol_addr: remote_ip_addr, + target_hardware_addr: EthernetAddress::default(), + target_protocol_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x03]), + }; + + let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); + frame.set_dst_addr(EthernetAddress::BROADCAST); + frame.set_src_addr(remote_hw_addr); + frame.set_ethertype(EthernetProtocol::Arp); + let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); + repr.emit(&mut packet); + + // Ensure an ARP Request for someone else does not trigger an ARP Reply + assert_eq!( + iface.inner.process_ethernet( + &mut sockets, + PacketMeta::default(), + frame.into_inner(), + &mut iface.fragments + ), + None + ); + + // Ensure the address of the requester was NOT entered in the cache + assert_eq!( + iface.inner.lookup_hardware_addr( + MockTxToken, + &IpAddress::Ipv4(Ipv4Address([0x7f, 0x00, 0x00, 0x01])), + &IpAddress::Ipv4(remote_ip_addr), + &mut iface.fragmenter, + ), + Err(DispatchError::NeighborPending) + ); +} + +#[rstest] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_arp_flush_after_update_ip(#[case] medium: Medium) { + let (mut iface, mut sockets, _device) = setup(medium); + + let mut eth_bytes = vec![0u8; 42]; + + let local_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x01]); + let remote_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x02]); + let local_hw_addr = EthernetAddress([0x02, 0x02, 0x02, 0x02, 0x02, 0x02]); + let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); + + let repr = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: remote_hw_addr, + source_protocol_addr: remote_ip_addr, + target_hardware_addr: EthernetAddress::default(), + target_protocol_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + }; + + let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); + frame.set_dst_addr(EthernetAddress::BROADCAST); + frame.set_src_addr(remote_hw_addr); + frame.set_ethertype(EthernetProtocol::Arp); + { + let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); + repr.emit(&mut packet); + } + + // Ensure an ARP Request for us triggers an ARP Reply + assert_eq!( + iface.inner.process_ethernet( + &mut sockets, + PacketMeta::default(), + frame.into_inner(), + &mut iface.fragments + ), + Some(EthernetPacket::Arp(ArpRepr::EthernetIpv4 { + operation: ArpOperation::Reply, + source_hardware_addr: local_hw_addr, + source_protocol_addr: local_ip_addr, + target_hardware_addr: remote_hw_addr, + target_protocol_addr: remote_ip_addr + })) + ); + + // Ensure the address of the requester was entered in the cache + assert_eq!( + iface.inner.lookup_hardware_addr( + MockTxToken, + &IpAddress::Ipv4(local_ip_addr), + &IpAddress::Ipv4(remote_ip_addr), + &mut iface.fragmenter, + ), + Ok((HardwareAddress::Ethernet(remote_hw_addr), MockTxToken)) + ); + + // Update IP addrs to trigger ARP cache flush + let local_ip_addr_new = Ipv4Address([0x7f, 0x00, 0x00, 0x01]); + iface.update_ip_addrs(|addrs| { + addrs.iter_mut().next().map(|addr| { + *addr = IpCidr::Ipv4(Ipv4Cidr::new(local_ip_addr_new, 24)); + }); + }); + + // ARP cache flush after address change + assert!(!iface.inner.has_neighbor(&IpAddress::Ipv4(remote_ip_addr))); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-icmp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "socket-icmp", feature = "medium-ethernet"))] +fn test_icmpv4_socket(#[case] medium: Medium) { + use crate::wire::Icmpv4Packet; + + let (mut iface, mut sockets, _device) = setup(medium); + + let rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 24]); + let tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 24]); + + let icmpv4_socket = icmp::Socket::new(rx_buffer, tx_buffer); + + let socket_handle = sockets.add(icmpv4_socket); + + let ident = 0x1234; + let seq_no = 0x5432; + let echo_data = &[0xff; 16]; + + let socket = sockets.get_mut::<icmp::Socket>(socket_handle); + // Bind to the ID 0x1234 + assert_eq!(socket.bind(icmp::Endpoint::Ident(ident)), Ok(())); + + // Ensure the ident we bound to and the ident of the packet are the same. + let mut bytes = [0xff; 24]; + let mut packet = Icmpv4Packet::new_unchecked(&mut bytes[..]); + let echo_repr = Icmpv4Repr::EchoRequest { + ident, + seq_no, + data: echo_data, + }; + echo_repr.emit(&mut packet, &ChecksumCapabilities::default()); + let icmp_data = &*packet.into_inner(); + + let ipv4_repr = Ipv4Repr { + src_addr: Ipv4Address::new(0x7f, 0x00, 0x00, 0x02), + dst_addr: Ipv4Address::new(0x7f, 0x00, 0x00, 0x01), + next_header: IpProtocol::Icmp, + payload_len: 24, + hop_limit: 64, + }; + let ip_repr = IpRepr::Ipv4(ipv4_repr); + + // Open a socket and ensure the packet is handled due to the listening + // socket. + assert!(!sockets.get_mut::<icmp::Socket>(socket_handle).can_recv()); + + // Confirm we still get EchoReply from `smoltcp` even with the ICMP socket listening + let echo_reply = Icmpv4Repr::EchoReply { + ident, + seq_no, + data: echo_data, + }; + let ipv4_reply = Ipv4Repr { + src_addr: ipv4_repr.dst_addr, + dst_addr: ipv4_repr.src_addr, + ..ipv4_repr + }; + assert_eq!( + iface.inner.process_icmpv4(&mut sockets, ip_repr, icmp_data), + Some(Packet::new_ipv4(ipv4_reply, IpPayload::Icmpv4(echo_reply))) + ); + + let socket = sockets.get_mut::<icmp::Socket>(socket_handle); + assert!(socket.can_recv()); + assert_eq!( + socket.recv(), + Ok(( + icmp_data, + IpAddress::Ipv4(Ipv4Address::new(0x7f, 0x00, 0x00, 0x02)) + )) + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "proto-igmp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "proto-igmp", feature = "medium-ethernet"))] +fn test_handle_igmp(#[case] medium: Medium) { + fn recv_igmp( + device: &mut crate::tests::TestingDevice, + timestamp: Instant, + ) -> Vec<(Ipv4Repr, IgmpRepr)> { + let caps = device.capabilities(); + let checksum_caps = &caps.checksum; + recv_all(device, timestamp) + .iter() + .filter_map(|frame| { + let ipv4_packet = match caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + let eth_frame = EthernetFrame::new_checked(frame).ok()?; + Ipv4Packet::new_checked(eth_frame.payload()).ok()? + } + #[cfg(feature = "medium-ip")] + Medium::Ip => Ipv4Packet::new_checked(&frame[..]).ok()?, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => todo!(), + }; + let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, checksum_caps).ok()?; + let ip_payload = ipv4_packet.payload(); + let igmp_packet = IgmpPacket::new_checked(ip_payload).ok()?; + let igmp_repr = IgmpRepr::parse(&igmp_packet).ok()?; + Some((ipv4_repr, igmp_repr)) + }) + .collect::<Vec<_>>() + } + + let groups = [ + Ipv4Address::new(224, 0, 0, 22), + Ipv4Address::new(224, 0, 0, 56), + ]; + + let (mut iface, mut sockets, mut device) = setup(medium); + + // Join multicast groups + let timestamp = Instant::ZERO; + for group in &groups { + iface + .join_multicast_group(&mut device, *group, timestamp) + .unwrap(); + } + + let reports = recv_igmp(&mut device, timestamp); + assert_eq!(reports.len(), 2); + for (i, group_addr) in groups.iter().enumerate() { + assert_eq!(reports[i].0.next_header, IpProtocol::Igmp); + assert_eq!(reports[i].0.dst_addr, *group_addr); + assert_eq!( + reports[i].1, + IgmpRepr::MembershipReport { + group_addr: *group_addr, + version: IgmpVersion::Version2, + } + ); + } + + // General query + let timestamp = Instant::ZERO; + const GENERAL_QUERY_BYTES: &[u8] = &[ + 0x46, 0xc0, 0x00, 0x24, 0xed, 0xb4, 0x00, 0x00, 0x01, 0x02, 0x47, 0x43, 0xac, 0x16, 0x63, + 0x04, 0xe0, 0x00, 0x00, 0x01, 0x94, 0x04, 0x00, 0x00, 0x11, 0x64, 0xec, 0x8f, 0x00, 0x00, + 0x00, 0x00, 0x02, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + ]; + { + // Transmit GENERAL_QUERY_BYTES into loopback + let tx_token = device.transmit(timestamp).unwrap(); + tx_token.consume(GENERAL_QUERY_BYTES.len(), |buffer| { + buffer.copy_from_slice(GENERAL_QUERY_BYTES); + }); + } + // Trigger processing until all packets received through the + // loopback have been processed, including responses to + // GENERAL_QUERY_BYTES. Therefore `recv_all()` would return 0 + // pkts that could be checked. + iface.socket_ingress(&mut device, &mut sockets); + + // Leave multicast groups + let timestamp = Instant::ZERO; + for group in &groups { + iface + .leave_multicast_group(&mut device, *group, timestamp) + .unwrap(); + } + + let leaves = recv_igmp(&mut device, timestamp); + assert_eq!(leaves.len(), 2); + for (i, group_addr) in groups.iter().cloned().enumerate() { + assert_eq!(leaves[i].0.next_header, IpProtocol::Igmp); + assert_eq!(leaves[i].0.dst_addr, Ipv4Address::MULTICAST_ALL_ROUTERS); + assert_eq!(leaves[i].1, IgmpRepr::LeaveGroup { group_addr }); + } +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-raw", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "socket-raw", feature = "medium-ethernet"))] +fn test_raw_socket_no_reply(#[case] medium: Medium) { + use crate::wire::{IpVersion, UdpPacket, UdpRepr}; + + let (mut iface, mut sockets, _) = setup(medium); + + let packets = 1; + let rx_buffer = + raw::PacketBuffer::new(vec![raw::PacketMetadata::EMPTY; packets], vec![0; 48 * 1]); + let tx_buffer = raw::PacketBuffer::new( + vec![raw::PacketMetadata::EMPTY; packets], + vec![0; 48 * packets], + ); + let raw_socket = raw::Socket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer); + sockets.add(raw_socket); + + let src_addr = Ipv4Address([127, 0, 0, 2]); + let dst_addr = Ipv4Address([127, 0, 0, 1]); + + const PAYLOAD_LEN: usize = 10; + + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + let mut bytes = vec![0xff; udp_repr.header_len() + PAYLOAD_LEN]; + let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); + udp_repr.emit( + &mut packet, + &src_addr.into(), + &dst_addr.into(), + PAYLOAD_LEN, + |buf| fill_slice(buf, 0x2a), + &ChecksumCapabilities::default(), + ); + let ipv4_repr = Ipv4Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Udp, + hop_limit: 64, + payload_len: udp_repr.header_len() + PAYLOAD_LEN, + }; + + // Emit to frame + let mut bytes = vec![0u8; ipv4_repr.buffer_len() + udp_repr.header_len() + PAYLOAD_LEN]; + let frame = { + ipv4_repr.emit( + &mut Ipv4Packet::new_unchecked(&mut bytes), + &ChecksumCapabilities::default(), + ); + udp_repr.emit( + &mut UdpPacket::new_unchecked(&mut bytes[ipv4_repr.buffer_len()..]), + &src_addr.into(), + &dst_addr.into(), + PAYLOAD_LEN, + |buf| fill_slice(buf, 0x2a), + &ChecksumCapabilities::default(), + ); + Ipv4Packet::new_unchecked(&bytes[..]) + }; + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + None + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-raw", feature = "socket-udp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all( + feature = "socket-raw", + feature = "socket-udp", + feature = "medium-ethernet" +))] +fn test_raw_socket_with_udp_socket(#[case] medium: Medium) { + use crate::wire::{IpEndpoint, IpVersion, UdpPacket, UdpRepr}; + + static UDP_PAYLOAD: [u8; 5] = [0x48, 0x65, 0x6c, 0x6c, 0x6f]; + + let (mut iface, mut sockets, _) = setup(medium); + + let udp_rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 15]); + let udp_tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 15]); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); + let udp_socket_handle = sockets.add(udp_socket); + + // Bind the socket to port 68 + let socket = sockets.get_mut::<udp::Socket>(udp_socket_handle); + assert_eq!(socket.bind(68), Ok(())); + assert!(!socket.can_recv()); + assert!(socket.can_send()); + + let packets = 1; + let raw_rx_buffer = + raw::PacketBuffer::new(vec![raw::PacketMetadata::EMPTY; packets], vec![0; 48 * 1]); + let raw_tx_buffer = raw::PacketBuffer::new( + vec![raw::PacketMetadata::EMPTY; packets], + vec![0; 48 * packets], + ); + let raw_socket = raw::Socket::new( + IpVersion::Ipv4, + IpProtocol::Udp, + raw_rx_buffer, + raw_tx_buffer, + ); + sockets.add(raw_socket); + + let src_addr = Ipv4Address([127, 0, 0, 2]); + let dst_addr = Ipv4Address([127, 0, 0, 1]); + + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + let mut bytes = vec![0xff; udp_repr.header_len() + UDP_PAYLOAD.len()]; + let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); + udp_repr.emit( + &mut packet, + &src_addr.into(), + &dst_addr.into(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + let ipv4_repr = Ipv4Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Udp, + hop_limit: 64, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + }; + + // Emit to frame + let mut bytes = vec![0u8; ipv4_repr.buffer_len() + udp_repr.header_len() + UDP_PAYLOAD.len()]; + let frame = { + ipv4_repr.emit( + &mut Ipv4Packet::new_unchecked(&mut bytes), + &ChecksumCapabilities::default(), + ); + udp_repr.emit( + &mut UdpPacket::new_unchecked(&mut bytes[ipv4_repr.buffer_len()..]), + &src_addr.into(), + &dst_addr.into(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + Ipv4Packet::new_unchecked(&bytes[..]) + }; + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + None + ); + + // Make sure the UDP socket can still receive in presence of a Raw socket that handles UDP + let socket = sockets.get_mut::<udp::Socket>(udp_socket_handle); + assert!(socket.can_recv()); + assert_eq!( + socket.recv(), + Ok(( + &UDP_PAYLOAD[..], + IpEndpoint::new(src_addr.into(), 67).into() + )) + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-udp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "socket-udp", feature = "medium-ethernet"))] +fn test_icmp_reply_size(#[case] medium: Medium) { + use crate::wire::IPV4_MIN_MTU as MIN_MTU; + const MAX_PAYLOAD_LEN: usize = 528; + + let (mut iface, mut sockets, _device) = setup(medium); + + let src_addr = Ipv4Address([192, 168, 1, 1]); + let dst_addr = Ipv4Address([192, 168, 1, 2]); + + // UDP packet that if not tructated will cause a icmp port unreachable reply + // to exceed the minimum mtu bytes in length. + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + let mut bytes = vec![0xff; udp_repr.header_len() + MAX_PAYLOAD_LEN]; + let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); + udp_repr.emit( + &mut packet, + &src_addr.into(), + &dst_addr.into(), + MAX_PAYLOAD_LEN, + |buf| fill_slice(buf, 0x2a), + &ChecksumCapabilities::default(), + ); + + let ip_repr = Ipv4Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Udp, + hop_limit: 64, + payload_len: udp_repr.header_len() + MAX_PAYLOAD_LEN, + }; + let payload = packet.into_inner(); + + let expected_icmp_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::PortUnreachable, + header: ip_repr, + data: &payload[..MAX_PAYLOAD_LEN], + }; + + let expected_ip_repr = Ipv4Repr { + src_addr: dst_addr, + dst_addr: src_addr, + next_header: IpProtocol::Icmp, + hop_limit: 64, + payload_len: expected_icmp_repr.buffer_len(), + }; + + assert_eq!( + expected_ip_repr.buffer_len() + expected_icmp_repr.buffer_len(), + MIN_MTU + ); + + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr.into(), + udp_repr, + false, + &vec![0x2a; MAX_PAYLOAD_LEN], + payload, + ), + Some(Packet::new_ipv4( + expected_ip_repr, + IpPayload::Icmpv4(expected_icmp_repr) + )) + ); +} diff --git a/src/iface/interface/tests/ipv6.rs b/src/iface/interface/tests/ipv6.rs new file mode 100644 index 0000000..9c5f099 --- /dev/null +++ b/src/iface/interface/tests/ipv6.rs @@ -0,0 +1,988 @@ +use super::*; + +fn parse_ipv6(data: &[u8]) -> crate::wire::Result<Packet<'_>> { + let ipv6_header = Ipv6Packet::new_checked(data)?; + let ipv6 = Ipv6Repr::parse(&ipv6_header)?; + + match ipv6.next_header { + IpProtocol::HopByHop => todo!(), + IpProtocol::Icmp => todo!(), + IpProtocol::Igmp => todo!(), + IpProtocol::Tcp => todo!(), + IpProtocol::Udp => todo!(), + IpProtocol::Ipv6Route => todo!(), + IpProtocol::Ipv6Frag => todo!(), + IpProtocol::IpSecEsp => todo!(), + IpProtocol::IpSecAh => todo!(), + IpProtocol::Icmpv6 => { + let icmp = Icmpv6Repr::parse( + &ipv6.src_addr.into(), + &ipv6.dst_addr.into(), + &Icmpv6Packet::new_checked(ipv6_header.payload())?, + &Default::default(), + )?; + Ok(Packet::new_ipv6(ipv6, IpPayload::Icmpv6(icmp))) + } + IpProtocol::Ipv6NoNxt => todo!(), + IpProtocol::Ipv6Opts => todo!(), + IpProtocol::Unknown(_) => todo!(), + } +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn multicast_source_address(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x40, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, + ]; + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn hop_by_hop_skip_with_icmp(#[case] medium: Medium) { + // The following contains: + // - IPv6 header + // - Hop-by-hop, with options: + // - PADN (skipped) + // - Unknown option (skipped) + // - ICMP echo request + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x0, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x3a, 0x0, 0x1, 0x0, 0xf, 0x0, 0x1, 0x0, 0x80, 0x0, 0x2c, 0x88, + 0x0, 0x2a, 0x1, 0xa4, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + let response = Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 19, + }, + IpPayload::Icmpv6(Icmpv6Repr::EchoReply { + ident: 42, + seq_no: 420, + data: b"Lorem Ipsum", + }), + )); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn hop_by_hop_discard_with_icmp(#[case] medium: Medium) { + // The following contains: + // - IPv6 header + // - Hop-by-hop, with options: + // - PADN (skipped) + // - Unknown option (discard) + // - ICMP echo request + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x0, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x3a, 0x0, 0x1, 0x0, 0x40, 0x0, 0x1, 0x0, 0x80, 0x0, 0x2c, 0x88, + 0x0, 0x2a, 0x1, 0xa4, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +fn hop_by_hop_discard_param_problem(#[case] medium: Medium) { + // The following contains: + // - IPv6 header + // - Hop-by-hop, with options: + // - PADN (skipped) + // - Unknown option (discard + ParamProblem) + // - ICMP echo request + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x0, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x3a, 0x0, 0xC0, 0x0, 0x40, 0x0, 0x1, 0x0, 0x80, 0x0, 0x2c, 0x88, + 0x0, 0x2a, 0x1, 0xa4, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + let response = Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 1), + dst_addr: Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 2), + next_header: IpProtocol::Icmpv6, + payload_len: 75, + hop_limit: 64, + }, + IpPayload::Icmpv6(Icmpv6Repr::ParamProblem { + reason: Icmpv6ParamProblem::UnrecognizedOption, + pointer: 40, + header: Ipv6Repr { + src_addr: Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 2), + dst_addr: Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 1), + next_header: IpProtocol::HopByHop, + payload_len: 27, + hop_limit: 64, + }, + data: &[ + 0x3a, 0x0, 0xC0, 0x0, 0x40, 0x0, 0x1, 0x0, 0x80, 0x0, 0x2c, 0x88, 0x0, 0x2a, 0x1, + 0xa4, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ], + }), + )); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +fn hop_by_hop_discard_with_multicast(#[case] medium: Medium) { + // The following contains: + // - IPv6 header + // - Hop-by-hop, with options: + // - PADN (skipped) + // - Unknown option (discard (0b11) + ParamProblem) + // - ICMP echo request + // + // In this case, even if the destination address is a multicast address, an ICMPv6 ParamProblem + // should be transmitted. + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x0, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xff, 0x02, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x3a, 0x0, 0x80, 0x0, 0x40, 0x0, 0x1, 0x0, 0x80, 0x0, 0x2c, 0x88, + 0x0, 0x2a, 0x1, 0xa4, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + let response = Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 1), + dst_addr: Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 2), + next_header: IpProtocol::Icmpv6, + payload_len: 75, + hop_limit: 64, + }, + IpPayload::Icmpv6(Icmpv6Repr::ParamProblem { + reason: Icmpv6ParamProblem::UnrecognizedOption, + pointer: 40, + header: Ipv6Repr { + src_addr: Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 2), + dst_addr: Ipv6Address::new(0xff02, 0, 0, 0, 0, 0, 0, 1), + next_header: IpProtocol::HopByHop, + payload_len: 27, + hop_limit: 64, + }, + data: &[ + 0x3a, 0x0, 0x80, 0x0, 0x40, 0x0, 0x1, 0x0, 0x80, 0x0, 0x2c, 0x88, 0x0, 0x2a, 0x1, + 0xa4, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ], + }), + )); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn imcp_empty_echo_request(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x8, 0x3a, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x84, 0x3c, 0x0, 0x0, 0x0, 0x0, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 8, + }, + IpPayload::Icmpv6(Icmpv6Repr::EchoRequest { + ident: 0, + seq_no: 0, + data: b"", + }) + )) + ); + + let response = Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 8, + }, + IpPayload::Icmpv6(Icmpv6Repr::EchoReply { + ident: 0, + seq_no: 0, + data: b"", + }), + )); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn icmp_echo_request(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x13, 0x3a, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x2c, 0x88, 0x0, 0x2a, 0x1, 0xa4, 0x4c, 0x6f, 0x72, + 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 19, + }, + IpPayload::Icmpv6(Icmpv6Repr::EchoRequest { + ident: 42, + seq_no: 420, + data: b"Lorem Ipsum", + }) + )) + ); + + let response = Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 19, + }, + IpPayload::Icmpv6(Icmpv6Repr::EchoReply { + ident: 42, + seq_no: 420, + data: b"Lorem Ipsum", + }), + )); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn icmp_echo_reply_as_input(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x13, 0x3a, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x81, 0x0, 0x2d, 0x56, 0x0, 0x0, 0x0, 0x0, 0x4c, 0x6f, 0x72, 0x65, + 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 19, + }, + IpPayload::Icmpv6(Icmpv6Repr::EchoReply { + ident: 0, + seq_no: 0, + data: b"Lorem Ipsum", + }) + )) + ); + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn unknown_proto_with_multicast_dst_address(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xff, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, + ]; + + let response = Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 48, + }, + IpPayload::Icmpv6(Icmpv6Repr::ParamProblem { + reason: Icmpv6ParamProblem::UnrecognizedNxtHdr, + pointer: 40, + header: Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xff02, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 64, + next_header: IpProtocol::Unknown(0x0c), + payload_len: 0, + }, + data: &[], + }), + )); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn unknown_proto(#[case] medium: Medium) { + // Since the destination address is multicast, we should answer with an ICMPv6 message. + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, + ]; + + let response = Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 48, + }, + IpPayload::Icmpv6(Icmpv6Repr::ParamProblem { + reason: Icmpv6ParamProblem::UnrecognizedNxtHdr, + pointer: 40, + header: Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 64, + next_header: IpProtocol::Unknown(0x0c), + payload_len: 0, + }, + data: &[], + }), + )); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn ndsic_neighbor_advertisement_ethernet(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x20, 0x3a, 0xff, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x88, 0x0, 0x3b, 0x9f, 0x40, 0x0, 0x0, 0x0, 0xfe, 0x80, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x1, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x1, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 255, + next_header: IpProtocol::Icmpv6, + payload_len: 32, + }, + IpPayload::Icmpv6(Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0, 0, 0, 0x0002]), + lladdr: Some(RawHardwareAddress::from_bytes(&[0, 0, 0, 0, 0, 1])), + })) + )) + ); + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); + + assert_eq!( + iface.inner.neighbor_cache.lookup( + &IpAddress::Ipv6(Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002])), + iface.inner.now, + ), + NeighborAnswer::Found(HardwareAddress::Ethernet(EthernetAddress::from_bytes(&[ + 0, 0, 0, 0, 0, 1 + ]))), + ); +} + +#[rstest] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn ndsic_neighbor_advertisement_ethernet_multicast_addr(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x20, 0x3a, 0xff, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x88, 0x0, 0x3b, 0xa0, 0x40, 0x0, 0x0, 0x0, 0xfe, 0x80, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x1, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 255, + next_header: IpProtocol::Icmpv6, + payload_len: 32, + }, + IpPayload::Icmpv6(Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0, 0, 0, 0x0002]), + lladdr: Some(RawHardwareAddress::from_bytes(&[ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + ])), + })) + )) + ); + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); + + assert_eq!( + iface.inner.neighbor_cache.lookup( + &IpAddress::Ipv6(Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002])), + iface.inner.now, + ), + NeighborAnswer::NotFound, + ); +} + +#[rstest] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn ndsic_neighbor_advertisement_ieee802154(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x28, 0x3a, 0xff, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x88, 0x0, 0x3b, 0x96, 0x40, 0x0, 0x0, 0x0, 0xfe, 0x80, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x2, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 255, + next_header: IpProtocol::Icmpv6, + payload_len: 40, + }, + IpPayload::Icmpv6(Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0, 0, 0, 0x0002]), + lladdr: Some(RawHardwareAddress::from_bytes(&[0, 0, 0, 0, 0, 0, 0, 1])), + })) + )) + ); + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data[..]).unwrap() + ), + response + ); + + assert_eq!( + iface.inner.neighbor_cache.lookup( + &IpAddress::Ipv6(Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002])), + iface.inner.now, + ), + NeighborAnswer::Found(HardwareAddress::Ieee802154(Ieee802154Address::from_bytes( + &[0, 0, 0, 0, 0, 0, 0, 1] + ))), + ); +} + +#[rstest] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_handle_valid_ndisc_request(#[case] medium: Medium) { + let (mut iface, mut sockets, _device) = setup(medium); + + let mut eth_bytes = vec![0u8; 86]; + + let local_ip_addr = Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 1); + let remote_ip_addr = Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 2); + let local_hw_addr = EthernetAddress([0x02, 0x02, 0x02, 0x02, 0x02, 0x02]); + let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); + + let solicit = Icmpv6Repr::Ndisc(NdiscRepr::NeighborSolicit { + target_addr: local_ip_addr, + lladdr: Some(remote_hw_addr.into()), + }); + let ip_repr = IpRepr::Ipv6(Ipv6Repr { + src_addr: remote_ip_addr, + dst_addr: local_ip_addr.solicited_node(), + next_header: IpProtocol::Icmpv6, + hop_limit: 0xff, + payload_len: solicit.buffer_len(), + }); + + let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); + frame.set_dst_addr(EthernetAddress([0x33, 0x33, 0x00, 0x00, 0x00, 0x00])); + frame.set_src_addr(remote_hw_addr); + frame.set_ethertype(EthernetProtocol::Ipv6); + ip_repr.emit(frame.payload_mut(), &ChecksumCapabilities::default()); + solicit.emit( + &remote_ip_addr.into(), + &local_ip_addr.solicited_node().into(), + &mut Icmpv6Packet::new_unchecked(&mut frame.payload_mut()[ip_repr.header_len()..]), + &ChecksumCapabilities::default(), + ); + + let icmpv6_expected = Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr: local_ip_addr, + lladdr: Some(local_hw_addr.into()), + }); + + let ipv6_expected = Ipv6Repr { + src_addr: local_ip_addr, + dst_addr: remote_ip_addr, + next_header: IpProtocol::Icmpv6, + hop_limit: 0xff, + payload_len: icmpv6_expected.buffer_len(), + }; + + // Ensure an Neighbor Solicitation triggers a Neighbor Advertisement + assert_eq!( + iface.inner.process_ethernet( + &mut sockets, + PacketMeta::default(), + frame.into_inner(), + &mut iface.fragments + ), + Some(EthernetPacket::Ip(Packet::new_ipv6( + ipv6_expected, + IpPayload::Icmpv6(icmpv6_expected) + ))) + ); + + // Ensure the address of the requester was entered in the cache + assert_eq!( + iface.inner.lookup_hardware_addr( + MockTxToken, + &IpAddress::Ipv6(local_ip_addr), + &IpAddress::Ipv6(remote_ip_addr), + &mut iface.fragmenter, + ), + Ok((HardwareAddress::Ethernet(remote_hw_addr), MockTxToken)) + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn test_solicited_node_addrs(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let mut new_addrs = heapless::Vec::<IpCidr, IFACE_MAX_ADDR_COUNT>::new(); + new_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 1, 2, 0, 2), 64)) + .unwrap(); + new_addrs + .push(IpCidr::new( + IpAddress::v6(0xfe80, 0, 0, 0, 3, 4, 0, 0xffff), + 64, + )) + .unwrap(); + iface.update_ip_addrs(|addrs| { + new_addrs.extend(addrs.to_vec()); + *addrs = new_addrs; + }); + assert!(iface + .inner + .has_solicited_node(Ipv6Address::new(0xff02, 0, 0, 0, 0, 1, 0xff00, 0x0002))); + assert!(iface + .inner + .has_solicited_node(Ipv6Address::new(0xff02, 0, 0, 0, 0, 1, 0xff00, 0xffff))); + assert!(!iface + .inner + .has_solicited_node(Ipv6Address::new(0xff02, 0, 0, 0, 0, 1, 0xff00, 0x0003))); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-udp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "socket-udp", feature = "medium-ethernet"))] +#[case(Medium::Ieee802154)] +#[cfg(all(feature = "socket-udp", feature = "medium-ieee802154"))] +fn test_icmp_reply_size(#[case] medium: Medium) { + use crate::wire::Icmpv6DstUnreachable; + use crate::wire::IPV6_MIN_MTU as MIN_MTU; + const MAX_PAYLOAD_LEN: usize = 1192; + + let (mut iface, mut sockets, _device) = setup(medium); + + let src_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1); + let dst_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2); + + // UDP packet that if not tructated will cause a icmp port unreachable reply + // to exceed the minimum mtu bytes in length. + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + let mut bytes = vec![0xff; udp_repr.header_len() + MAX_PAYLOAD_LEN]; + let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); + udp_repr.emit( + &mut packet, + &src_addr.into(), + &dst_addr.into(), + MAX_PAYLOAD_LEN, + |buf| fill_slice(buf, 0x2a), + &ChecksumCapabilities::default(), + ); + + let ip_repr = Ipv6Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Udp, + hop_limit: 64, + payload_len: udp_repr.header_len() + MAX_PAYLOAD_LEN, + }; + let payload = packet.into_inner(); + + let expected_icmp_repr = Icmpv6Repr::DstUnreachable { + reason: Icmpv6DstUnreachable::PortUnreachable, + header: ip_repr, + data: &payload[..MAX_PAYLOAD_LEN], + }; + + let expected_ip_repr = Ipv6Repr { + src_addr: dst_addr, + dst_addr: src_addr, + next_header: IpProtocol::Icmpv6, + hop_limit: 64, + payload_len: expected_icmp_repr.buffer_len(), + }; + + assert_eq!( + expected_ip_repr.buffer_len() + expected_icmp_repr.buffer_len(), + MIN_MTU + ); + + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr.into(), + udp_repr, + false, + &vec![0x2a; MAX_PAYLOAD_LEN], + payload, + ), + Some(Packet::new_ipv6( + expected_ip_repr, + IpPayload::Icmpv6(expected_icmp_repr) + )) + ); +} + +#[cfg(feature = "medium-ip")] +#[test] +fn get_source_address() { + let (mut iface, _, _) = setup(Medium::Ip); + + const OWN_LINK_LOCAL_ADDR: Ipv6Address = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1); + const OWN_UNIQUE_LOCAL_ADDR1: Ipv6Address = Ipv6Address::new(0xfd00, 0, 0, 201, 1, 1, 1, 2); + const OWN_UNIQUE_LOCAL_ADDR2: Ipv6Address = Ipv6Address::new(0xfd01, 0, 0, 201, 1, 1, 1, 2); + const OWN_GLOBAL_UNICAST_ADDR1: Ipv6Address = + Ipv6Address::new(0x2001, 0x0db8, 0x0003, 0, 0, 0, 0, 1); + + // List of addresses of the interface: + // fe80::1/64 + // fd00::201:1:1:1:2/64 + // fd01::201:1:1:1:2/64 + // 2001:db8:3::1/64 + // ::1/128 + // ::/128 + iface.update_ip_addrs(|addrs| { + addrs.clear(); + + addrs + .push(IpCidr::Ipv6(Ipv6Cidr::new(OWN_LINK_LOCAL_ADDR, 64))) + .unwrap(); + addrs + .push(IpCidr::Ipv6(Ipv6Cidr::new(OWN_UNIQUE_LOCAL_ADDR1, 64))) + .unwrap(); + addrs + .push(IpCidr::Ipv6(Ipv6Cidr::new(OWN_UNIQUE_LOCAL_ADDR2, 64))) + .unwrap(); + addrs + .push(IpCidr::Ipv6(Ipv6Cidr::new(OWN_GLOBAL_UNICAST_ADDR1, 64))) + .unwrap(); + + // These should never be used: + addrs + .push(IpCidr::Ipv6(Ipv6Cidr::new(Ipv6Address::LOOPBACK, 128))) + .unwrap(); + addrs + .push(IpCidr::Ipv6(Ipv6Cidr::new(Ipv6Address::UNSPECIFIED, 128))) + .unwrap(); + }); + + // List of addresses we test: + // fe80::42 -> fe80::1 + // fd00::201:1:1:1:1 -> fd00::201:1:1:1:2 + // fd01::201:1:1:1:1 -> fd01::201:1:1:1:2 + // fd02::201:1:1:1:1 -> fd00::201:1:1:1:2 (because first added in the list) + // ff02::1 -> fe80::1 (same scope) + // 2001:db8:3::2 -> 2001:db8:3::1 + // 2001:db9:3::2 -> 2001:db8:3::1 + const LINK_LOCAL_ADDR: Ipv6Address = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 42); + const UNIQUE_LOCAL_ADDR1: Ipv6Address = Ipv6Address::new(0xfd00, 0, 0, 201, 1, 1, 1, 1); + const UNIQUE_LOCAL_ADDR2: Ipv6Address = Ipv6Address::new(0xfd01, 0, 0, 201, 1, 1, 1, 1); + const UNIQUE_LOCAL_ADDR3: Ipv6Address = Ipv6Address::new(0xfd02, 0, 0, 201, 1, 1, 1, 1); + const GLOBAL_UNICAST_ADDR1: Ipv6Address = + Ipv6Address::new(0x2001, 0x0db8, 0x0003, 0, 0, 0, 0, 2); + const GLOBAL_UNICAST_ADDR2: Ipv6Address = + Ipv6Address::new(0x2001, 0x0db9, 0x0003, 0, 0, 0, 0, 2); + + assert_eq!( + iface.inner.get_source_address_ipv6(&LINK_LOCAL_ADDR), + Some(OWN_LINK_LOCAL_ADDR) + ); + assert_eq!( + iface.inner.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR1), + Some(OWN_UNIQUE_LOCAL_ADDR1) + ); + assert_eq!( + iface.inner.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR2), + Some(OWN_UNIQUE_LOCAL_ADDR2) + ); + assert_eq!( + iface.inner.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR3), + Some(OWN_UNIQUE_LOCAL_ADDR1) + ); + assert_eq!( + iface + .inner + .get_source_address_ipv6(&Ipv6Address::LINK_LOCAL_ALL_NODES), + Some(OWN_LINK_LOCAL_ADDR) + ); + assert_eq!( + iface.inner.get_source_address_ipv6(&GLOBAL_UNICAST_ADDR1), + Some(OWN_GLOBAL_UNICAST_ADDR1) + ); + assert_eq!( + iface.inner.get_source_address_ipv6(&GLOBAL_UNICAST_ADDR2), + Some(OWN_GLOBAL_UNICAST_ADDR1) + ); + + assert_eq!( + iface.get_source_address_ipv6(&LINK_LOCAL_ADDR), + Some(OWN_LINK_LOCAL_ADDR) + ); + assert_eq!( + iface.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR1), + Some(OWN_UNIQUE_LOCAL_ADDR1) + ); + assert_eq!( + iface.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR2), + Some(OWN_UNIQUE_LOCAL_ADDR2) + ); + assert_eq!( + iface.get_source_address_ipv6(&UNIQUE_LOCAL_ADDR3), + Some(OWN_UNIQUE_LOCAL_ADDR1) + ); + assert_eq!( + iface.get_source_address_ipv6(&Ipv6Address::LINK_LOCAL_ALL_NODES), + Some(OWN_LINK_LOCAL_ADDR) + ); + assert_eq!( + iface.get_source_address_ipv6(&GLOBAL_UNICAST_ADDR1), + Some(OWN_GLOBAL_UNICAST_ADDR1) + ); + assert_eq!( + iface.get_source_address_ipv6(&GLOBAL_UNICAST_ADDR2), + Some(OWN_GLOBAL_UNICAST_ADDR1) + ); +} diff --git a/src/iface/interface/tests/mod.rs b/src/iface/interface/tests/mod.rs new file mode 100644 index 0000000..b4b4416 --- /dev/null +++ b/src/iface/interface/tests/mod.rs @@ -0,0 +1,235 @@ +#[cfg(feature = "proto-ipv4")] +mod ipv4; +#[cfg(feature = "proto-ipv6")] +mod ipv6; +#[cfg(feature = "proto-sixlowpan")] +mod sixlowpan; + +#[cfg(feature = "proto-igmp")] +use std::vec::Vec; + +use crate::tests::setup; + +use rstest::*; + +use super::*; + +use crate::iface::Interface; +use crate::phy::ChecksumCapabilities; +#[cfg(feature = "alloc")] +use crate::phy::Loopback; +use crate::time::Instant; + +#[allow(unused)] +fn fill_slice(s: &mut [u8], val: u8) { + for x in s.iter_mut() { + *x = val + } +} + +#[cfg(feature = "proto-igmp")] +fn recv_all(device: &mut crate::tests::TestingDevice, timestamp: Instant) -> Vec<Vec<u8>> { + let mut pkts = Vec::new(); + while let Some((rx, _tx)) = device.receive(timestamp) { + rx.consume(|pkt| { + pkts.push(pkt.to_vec()); + }); + } + pkts +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct MockTxToken; + +impl TxToken for MockTxToken { + fn consume<R, F>(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut junk = [0; 1536]; + f(&mut junk[..len]) + } +} + +#[test] +#[should_panic(expected = "The hardware address does not match the medium of the interface.")] +#[cfg(all(feature = "medium-ip", feature = "medium-ethernet", feature = "alloc"))] +fn test_new_panic() { + let mut device = Loopback::new(Medium::Ethernet); + let config = Config::new(HardwareAddress::Ip); + Interface::new(config, &mut device, Instant::ZERO); +} + +#[rstest] +#[cfg(feature = "default")] +fn test_handle_udp_broadcast( + #[values(Medium::Ip, Medium::Ethernet, Medium::Ieee802154)] medium: Medium, +) { + use crate::wire::IpEndpoint; + + static UDP_PAYLOAD: [u8; 5] = [0x48, 0x65, 0x6c, 0x6c, 0x6f]; + + let (mut iface, mut sockets, _device) = setup(medium); + + let rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 15]); + let tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 15]); + + let udp_socket = udp::Socket::new(rx_buffer, tx_buffer); + + let mut udp_bytes = vec![0u8; 13]; + let mut packet = UdpPacket::new_unchecked(&mut udp_bytes); + + let socket_handle = sockets.add(udp_socket); + + #[cfg(feature = "proto-ipv6")] + let src_ip = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1); + #[cfg(all(not(feature = "proto-ipv6"), feature = "proto-ipv4"))] + let src_ip = Ipv4Address::new(0x7f, 0x00, 0x00, 0x02); + + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + + #[cfg(feature = "proto-ipv6")] + let ip_repr = IpRepr::Ipv6(Ipv6Repr { + src_addr: src_ip, + dst_addr: Ipv6Address::LINK_LOCAL_ALL_NODES, + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 0x40, + }); + #[cfg(all(not(feature = "proto-ipv6"), feature = "proto-ipv4"))] + let ip_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: src_ip, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 0x40, + }); + + // Bind the socket to port 68 + let socket = sockets.get_mut::<udp::Socket>(socket_handle); + assert_eq!(socket.bind(68), Ok(())); + assert!(!socket.can_recv()); + assert!(socket.can_send()); + + udp_repr.emit( + &mut packet, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + + // Packet should be handled by bound UDP socket + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr, + udp_repr, + false, + &UDP_PAYLOAD, + packet.into_inner(), + ), + None + ); + + // Make sure the payload to the UDP packet processed by process_udp is + // appended to the bound sockets rx_buffer + let socket = sockets.get_mut::<udp::Socket>(socket_handle); + assert!(socket.can_recv()); + assert_eq!( + socket.recv(), + Ok((&UDP_PAYLOAD[..], IpEndpoint::new(src_ip.into(), 67).into())) + ); +} + +#[test] +#[cfg(all(feature = "medium-ip", feature = "socket-tcp", feature = "proto-ipv6"))] +pub fn tcp_not_accepted() { + let (mut iface, mut sockets, _) = setup(Medium::Ip); + let tcp = TcpRepr { + src_port: 4242, + dst_port: 4243, + control: TcpControl::Syn, + seq_number: TcpSeqNumber(-10001), + ack_number: None, + window_len: 256, + window_scale: None, + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &[], + }; + + let mut tcp_bytes = vec![0u8; tcp.buffer_len()]; + + tcp.emit( + &mut TcpPacket::new_unchecked(&mut tcp_bytes), + &Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2).into(), + &Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1).into(), + &ChecksumCapabilities::default(), + ); + + assert_eq!( + iface.inner.process_tcp( + &mut sockets, + IpRepr::Ipv6(Ipv6Repr { + src_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2), + dst_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), + next_header: IpProtocol::Tcp, + payload_len: tcp.buffer_len(), + hop_limit: 64, + }), + &tcp_bytes, + ), + Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), + dst_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2), + next_header: IpProtocol::Tcp, + payload_len: tcp.buffer_len(), + hop_limit: 64, + }, + IpPayload::Tcp(TcpRepr { + src_port: 4243, + dst_port: 4242, + control: TcpControl::Rst, + seq_number: TcpSeqNumber(0), + ack_number: Some(TcpSeqNumber(-10000)), + window_len: 0, + window_scale: None, + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &[], + }) + )) + ); + // Unspecified destination address. + tcp.emit( + &mut TcpPacket::new_unchecked(&mut tcp_bytes), + &Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2).into(), + &Ipv6Address::UNSPECIFIED.into(), + &ChecksumCapabilities::default(), + ); + + assert_eq!( + iface.inner.process_tcp( + &mut sockets, + IpRepr::Ipv6(Ipv6Repr { + src_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2), + dst_addr: Ipv6Address::UNSPECIFIED, + next_header: IpProtocol::Tcp, + payload_len: tcp.buffer_len(), + hop_limit: 64, + }), + &tcp_bytes, + ), + None, + ); +} diff --git a/src/iface/interface/tests/sixlowpan.rs b/src/iface/interface/tests/sixlowpan.rs new file mode 100644 index 0000000..676835e --- /dev/null +++ b/src/iface/interface/tests/sixlowpan.rs @@ -0,0 +1,434 @@ +use super::*; + +#[rstest] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn ieee802154_wrong_pan_id(#[case] medium: Medium) { + let data = [ + 0x41, 0xcc, 0x3b, 0xff, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0x62, 0x3a, + 0xa6, 0x34, 0x57, 0x29, 0x1c, 0x26, + ]; + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ieee802154( + &mut sockets, + PacketMeta::default(), + &data[..], + &mut iface.fragments + ), + response, + ); +} + +#[rstest] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn icmp_echo_request(#[case] medium: Medium) { + let data = [ + 0x41, 0xcc, 0x3b, 0xef, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0x62, 0x3a, + 0xa6, 0x34, 0x57, 0x29, 0x1c, 0x26, 0x6a, 0x33, 0x0a, 0x62, 0x17, 0x3a, 0x80, 0x00, 0xb0, + 0xe3, 0x00, 0x04, 0x00, 0x01, 0x82, 0xf2, 0x82, 0x64, 0x00, 0x00, 0x00, 0x00, 0x66, 0x23, + 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, + 0x37, + ]; + + let response = Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0x180b, 0x4242, 0x4242, 0x4242]), + dst_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0x241c, 0x2957, 0x34a6, 0x3a62]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 64, + }, + IpPayload::Icmpv6(Icmpv6Repr::EchoReply { + ident: 4, + seq_no: 1, + data: &[ + 0x82, 0xf2, 0x82, 0x64, 0x00, 0x00, 0x00, 0x00, 0x66, 0x23, 0x0c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + ], + }), + )); + + let (mut iface, mut sockets, _device) = setup(medium); + iface.update_ip_addrs(|ips| { + ips.push(IpCidr::Ipv6(Ipv6Cidr::new( + Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0x180b, 0x4242, 0x4242, 0x4242]), + 10, + ))) + .unwrap(); + }); + + assert_eq!( + iface.inner.process_ieee802154( + &mut sockets, + PacketMeta::default(), + &data[..], + &mut iface.fragments + ), + response, + ); +} + +#[test] +#[cfg(feature = "proto-sixlowpan-fragmentation")] +fn test_echo_request_sixlowpan_128_bytes() { + use crate::phy::Checksum; + + let (mut iface, mut sockets, mut device) = setup(Medium::Ieee802154); + iface.update_ip_addrs(|ips| { + ips.push(IpCidr::Ipv6(Ipv6Cidr::new( + Ipv6Address::new(0xfe80, 0x0, 0x0, 0x0, 0x92fc, 0x48c2, 0xa441, 0xfc76), + 10, + ))) + .unwrap(); + }); + // TODO: modify the example, such that we can also test if the checksum is correctly + // computed. + iface.inner.caps.checksum.icmpv6 = Checksum::None; + + assert_eq!(iface.inner.caps.medium, Medium::Ieee802154); + let now = iface.inner.now(); + + iface.inner.neighbor_cache.fill( + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0x2, 0, 0, 0, 0, 0, 0, 0]).into(), + HardwareAddress::Ieee802154(Ieee802154Address::default()), + now, + ); + + let mut ieee802154_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: false, + sequence_number: Some(5), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2003, + dst_pan_id: Some(Ieee802154Pan(0xbeef)), + dst_addr: Some(Ieee802154Address::Extended([ + 0x90, 0xfc, 0x48, 0xc2, 0xa4, 0x41, 0xfc, 0x76, + ])), + src_pan_id: Some(Ieee802154Pan(0xbeef)), + src_addr: Some(Ieee802154Address::Extended([ + 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, + ])), + }; + + // NOTE: this data is retrieved from tests with Contiki-NG + + let request_first_part_packet = SixlowpanFragPacket::new_checked(&[ + 0xc0, 0xb0, 0x00, 0x8e, 0x6a, 0x33, 0x05, 0x25, 0x2c, 0x3a, 0x80, 0x00, 0xe0, 0x71, 0x00, + 0x27, 0x00, 0x02, 0xa2, 0xc2, 0x2d, 0x63, 0x00, 0x00, 0x00, 0x00, 0xd9, 0x5e, 0x0c, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, + 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, + 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, + ]) + .unwrap(); + + let request_first_part_iphc_packet = + SixlowpanIphcPacket::new_checked(request_first_part_packet.payload()).unwrap(); + + let request_first_part_iphc_repr = SixlowpanIphcRepr::parse( + &request_first_part_iphc_packet, + ieee802154_repr.src_addr, + ieee802154_repr.dst_addr, + &iface.inner.sixlowpan_address_context, + ) + .unwrap(); + + assert_eq!( + request_first_part_iphc_repr.src_addr, + Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, 0xb, + 0x1a, + ]), + ); + assert_eq!( + request_first_part_iphc_repr.dst_addr, + Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x92, 0xfc, 0x48, 0xc2, 0xa4, 0x41, 0xfc, + 0x76, + ]), + ); + + let request_second_part = [ + 0xe0, 0xb0, 0x00, 0x8e, 0x10, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, + 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, + 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, + 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + ]; + + assert_eq!( + iface.inner.process_sixlowpan( + &mut sockets, + PacketMeta::default(), + &ieee802154_repr, + &request_first_part_packet.into_inner()[..], + &mut iface.fragments + ), + None + ); + + ieee802154_repr.sequence_number = Some(6); + + // data that was generated when using `ping -s 128` + let data = &[ + 0xa2, 0xc2, 0x2d, 0x63, 0x00, 0x00, 0x00, 0x00, 0xd9, 0x5e, 0x0c, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, + 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, + 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, + 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, + 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, + 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, + 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + ]; + + let result = iface.inner.process_sixlowpan( + &mut sockets, + PacketMeta::default(), + &ieee802154_repr, + &request_second_part, + &mut iface.fragments, + ); + + assert_eq!( + result, + Some(Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x92, 0xfc, 0x48, 0xc2, 0xa4, 0x41, + 0xfc, 0x76, + ]), + dst_addr: Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, + 0xb, 0x1a, + ]), + next_header: IpProtocol::Icmpv6, + payload_len: 136, + hop_limit: 64, + }, + IpPayload::Icmpv6(Icmpv6Repr::EchoReply { + ident: 39, + seq_no: 2, + data, + }) + )) + ); + + iface.inner.neighbor_cache.fill( + IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, 0xb, 0x1a, + ])), + HardwareAddress::Ieee802154(Ieee802154Address::default()), + Instant::now(), + ); + + let tx_token = device.transmit(Instant::now()).unwrap(); + iface.inner.dispatch_ieee802154( + Ieee802154Address::default(), + tx_token, + PacketMeta::default(), + result.unwrap(), + &mut iface.fragmenter, + ); + + assert_eq!( + device.queue.pop_front().unwrap(), + &[ + 0x41, 0xcc, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x2, 0x2, + 0x2, 0x2, 0x2, 0x2, 0xc0, 0xb0, 0x5, 0x4e, 0x7a, 0x11, 0x3a, 0x92, 0xfc, 0x48, 0xc2, + 0xa4, 0x41, 0xfc, 0x76, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, 0xb, 0x1a, 0x81, 0x0, 0x0, + 0x0, 0x0, 0x27, 0x0, 0x2, 0xa2, 0xc2, 0x2d, 0x63, 0x0, 0x0, 0x0, 0x0, 0xd9, 0x5e, 0xc, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, + 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, + 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, + 0x44, 0x45, 0x46, 0x47, + ] + ); + + iface.poll(Instant::now(), &mut device, &mut sockets); + + assert_eq!( + device.queue.pop_front().unwrap(), + &[ + 0x41, 0xcc, 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x2, 0x2, + 0x2, 0x2, 0x2, 0x2, 0xe0, 0xb0, 0x5, 0x4e, 0xf, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, + 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, + 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, + 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, + 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + ] + ); +} + +#[test] +#[cfg(feature = "proto-sixlowpan-fragmentation")] +fn test_sixlowpan_udp_with_fragmentation() { + use crate::phy::Checksum; + + let mut ieee802154_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: false, + sequence_number: Some(5), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2003, + dst_pan_id: Some(Ieee802154Pan(0xbeef)), + dst_addr: Some(Ieee802154Address::Extended([ + 0x90, 0xfc, 0x48, 0xc2, 0xa4, 0x41, 0xfc, 0x76, + ])), + src_pan_id: Some(Ieee802154Pan(0xbeef)), + src_addr: Some(Ieee802154Address::Extended([ + 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, + ])), + }; + + let (mut iface, mut sockets, mut device) = setup(Medium::Ieee802154); + iface.update_ip_addrs(|ips| { + ips.push(IpCidr::Ipv6(Ipv6Cidr::new( + Ipv6Address::new(0xfe80, 0x0, 0x0, 0x0, 0x92fc, 0x48c2, 0xa441, 0xfc76), + 10, + ))) + .unwrap(); + }); + iface.inner.caps.checksum.udp = Checksum::None; + + let udp_rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 1024 * 4]); + let udp_tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 1024 * 4]); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); + let udp_socket_handle = sockets.add(udp_socket); + + { + let socket = sockets.get_mut::<udp::Socket>(udp_socket_handle); + assert_eq!(socket.bind(6969), Ok(())); + assert!(!socket.can_recv()); + assert!(socket.can_send()); + } + + let udp_first_part = &[ + 0xc0, 0xbc, 0x00, 0x92, 0x6e, 0x33, 0x07, 0xe7, 0xdc, 0xf0, 0xd3, 0xc9, 0x1b, 0x39, 0xbf, + 0xa0, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x69, 0x70, 0x73, 0x75, 0x6d, 0x20, 0x64, 0x6f, + 0x6c, 0x6f, 0x72, 0x20, 0x73, 0x69, 0x74, 0x20, 0x61, 0x6d, 0x65, 0x74, 0x2c, 0x20, 0x63, + 0x6f, 0x6e, 0x73, 0x65, 0x63, 0x74, 0x65, 0x74, 0x75, 0x72, 0x20, 0x61, 0x64, 0x69, 0x70, + 0x69, 0x73, 0x63, 0x69, 0x6e, 0x67, 0x20, 0x65, 0x6c, 0x69, 0x74, 0x2e, 0x20, 0x49, 0x6e, + 0x20, 0x61, 0x74, 0x20, 0x72, 0x68, 0x6f, 0x6e, 0x63, 0x75, 0x73, 0x20, 0x74, 0x6f, 0x72, + 0x74, 0x6f, 0x72, 0x2e, 0x20, 0x43, 0x72, 0x61, 0x73, 0x20, 0x62, 0x6c, 0x61, 0x6e, + ]; + + assert_eq!( + iface.inner.process_sixlowpan( + &mut sockets, + PacketMeta::default(), + &ieee802154_repr, + udp_first_part, + &mut iface.fragments + ), + None + ); + + ieee802154_repr.sequence_number = Some(6); + + let udp_second_part = &[ + 0xe0, 0xbc, 0x00, 0x92, 0x11, 0x64, 0x69, 0x74, 0x20, 0x74, 0x65, 0x6c, 0x6c, 0x75, 0x73, + 0x20, 0x64, 0x69, 0x61, 0x6d, 0x2c, 0x20, 0x76, 0x61, 0x72, 0x69, 0x75, 0x73, 0x20, 0x76, + 0x65, 0x73, 0x74, 0x69, 0x62, 0x75, 0x6c, 0x75, 0x6d, 0x20, 0x6e, 0x69, 0x62, 0x68, 0x20, + 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x64, 0x6f, 0x20, 0x6e, 0x65, 0x63, 0x2e, + ]; + + assert_eq!( + iface.inner.process_sixlowpan( + &mut sockets, + PacketMeta::default(), + &ieee802154_repr, + udp_second_part, + &mut iface.fragments + ), + None + ); + + let socket = sockets.get_mut::<udp::Socket>(udp_socket_handle); + + let udp_data = b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. \ +In at rhoncus tortor. Cras blandit tellus diam, varius vestibulum nibh commodo nec."; + assert_eq!( + socket.recv(), + Ok(( + &udp_data[..], + IpEndpoint { + addr: IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, + 0xb, 0x1a, + ])), + port: 54217, + } + .into() + )) + ); + + let tx_token = device.transmit(Instant::now()).unwrap(); + iface.inner.dispatch_ieee802154( + Ieee802154Address::default(), + tx_token, + PacketMeta::default(), + Packet::new_ipv6( + Ipv6Repr { + src_addr: Ipv6Address::default(), + dst_addr: Ipv6Address::default(), + next_header: IpProtocol::Udp, + payload_len: udp_data.len(), + hop_limit: 64, + }, + IpPayload::Udp( + UdpRepr { + src_port: 1234, + dst_port: 1234, + }, + udp_data, + ), + ), + &mut iface.fragmenter, + ); + + iface.poll(Instant::now(), &mut device, &mut sockets); + + assert_eq!( + device.queue.pop_front().unwrap(), + &[ + 0x41, 0xcc, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x2, 0x2, + 0x2, 0x2, 0x2, 0x2, 0xc0, 0xb4, 0x5, 0x4e, 0x7e, 0x40, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf0, 0x4, 0xd2, 0x4, 0xd2, 0x0, 0x0, + 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x69, 0x70, 0x73, 0x75, 0x6d, 0x20, 0x64, 0x6f, + 0x6c, 0x6f, 0x72, 0x20, 0x73, 0x69, 0x74, 0x20, 0x61, 0x6d, 0x65, 0x74, 0x2c, 0x20, + 0x63, 0x6f, 0x6e, 0x73, 0x65, 0x63, 0x74, 0x65, 0x74, 0x75, 0x72, 0x20, 0x61, 0x64, + 0x69, 0x70, 0x69, 0x73, 0x63, 0x69, 0x6e, 0x67, 0x20, 0x65, 0x6c, 0x69, 0x74, 0x2e, + 0x20, 0x49, 0x6e, 0x20, 0x61, 0x74, 0x20, 0x72, 0x68, 0x6f, 0x6e, 0x63, 0x75, 0x73, + 0x20, 0x74, + ], + ); + + assert_eq!( + device.queue.pop_front().unwrap(), + &[ + 0x41, 0xcc, 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x2, 0x2, + 0x2, 0x2, 0x2, 0x2, 0xe0, 0xb4, 0x5, 0x4e, 0xf, 0x6f, 0x72, 0x74, 0x6f, 0x72, 0x2e, + 0x20, 0x43, 0x72, 0x61, 0x73, 0x20, 0x62, 0x6c, 0x61, 0x6e, 0x64, 0x69, 0x74, 0x20, + 0x74, 0x65, 0x6c, 0x6c, 0x75, 0x73, 0x20, 0x64, 0x69, 0x61, 0x6d, 0x2c, 0x20, 0x76, + 0x61, 0x72, 0x69, 0x75, 0x73, 0x20, 0x76, 0x65, 0x73, 0x74, 0x69, 0x62, 0x75, 0x6c, + 0x75, 0x6d, 0x20, 0x6e, 0x69, 0x62, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x64, + 0x6f, 0x20, 0x6e, 0x65, 0x63, 0x2e, + ] + ); +} diff --git a/src/iface/mod.rs b/src/iface/mod.rs new file mode 100644 index 0000000..3076088 --- /dev/null +++ b/src/iface/mod.rs @@ -0,0 +1,24 @@ +/*! Network interface logic. + +The `iface` module deals with the *network interfaces*. It filters incoming frames, +provides lookup and caching of hardware addresses, and handles management packets. +*/ + +mod fragmentation; +mod interface; +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +mod neighbor; +mod route; +#[cfg(feature = "proto-rpl")] +mod rpl; +mod socket_meta; +mod socket_set; + +mod packet; + +#[cfg(feature = "proto-igmp")] +pub use self::interface::MulticastError; +pub use self::interface::{Config, Interface, InterfaceInner as Context}; + +pub use self::route::{Route, RouteTableFull, Routes}; +pub use self::socket_set::{SocketHandle, SocketSet, SocketStorage}; diff --git a/src/iface/neighbor.rs b/src/iface/neighbor.rs new file mode 100644 index 0000000..0c451fa --- /dev/null +++ b/src/iface/neighbor.rs @@ -0,0 +1,306 @@ +// Heads up! Before working on this file you should read, at least, +// the parts of RFC 1122 that discuss ARP. + +use heapless::LinearMap; + +use crate::config::IFACE_NEIGHBOR_CACHE_COUNT; +use crate::time::{Duration, Instant}; +use crate::wire::{HardwareAddress, IpAddress}; + +/// A cached neighbor. +/// +/// A neighbor mapping translates from a protocol address to a hardware address, +/// and contains the timestamp past which the mapping should be discarded. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Neighbor { + hardware_addr: HardwareAddress, + expires_at: Instant, +} + +/// An answer to a neighbor cache lookup. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) enum Answer { + /// The neighbor address is in the cache and not expired. + Found(HardwareAddress), + /// The neighbor address is not in the cache, or has expired. + NotFound, + /// The neighbor address is not in the cache, or has expired, + /// and a lookup has been made recently. + RateLimited, +} + +impl Answer { + /// Returns whether a valid address was found. + pub(crate) fn found(&self) -> bool { + match self { + Answer::Found(_) => true, + _ => false, + } + } +} + +/// A neighbor cache backed by a map. +#[derive(Debug)] +pub struct Cache { + storage: LinearMap<IpAddress, Neighbor, IFACE_NEIGHBOR_CACHE_COUNT>, + silent_until: Instant, +} + +impl Cache { + /// Minimum delay between discovery requests, in milliseconds. + pub(crate) const SILENT_TIME: Duration = Duration::from_millis(1_000); + + /// Neighbor entry lifetime, in milliseconds. + pub(crate) const ENTRY_LIFETIME: Duration = Duration::from_millis(60_000); + + /// Create a cache. + pub fn new() -> Self { + Self { + storage: LinearMap::new(), + silent_until: Instant::from_millis(0), + } + } + + pub fn fill( + &mut self, + protocol_addr: IpAddress, + hardware_addr: HardwareAddress, + timestamp: Instant, + ) { + debug_assert!(protocol_addr.is_unicast()); + debug_assert!(hardware_addr.is_unicast()); + + let expires_at = timestamp + Self::ENTRY_LIFETIME; + self.fill_with_expiration(protocol_addr, hardware_addr, expires_at); + } + + pub fn fill_with_expiration( + &mut self, + protocol_addr: IpAddress, + hardware_addr: HardwareAddress, + expires_at: Instant, + ) { + debug_assert!(protocol_addr.is_unicast()); + debug_assert!(hardware_addr.is_unicast()); + + let neighbor = Neighbor { + expires_at, + hardware_addr, + }; + match self.storage.insert(protocol_addr, neighbor) { + Ok(Some(old_neighbor)) => { + if old_neighbor.hardware_addr != hardware_addr { + net_trace!( + "replaced {} => {} (was {})", + protocol_addr, + hardware_addr, + old_neighbor.hardware_addr + ); + } + } + Ok(None) => { + net_trace!("filled {} => {} (was empty)", protocol_addr, hardware_addr); + } + Err((protocol_addr, neighbor)) => { + // If we're going down this branch, it means the cache is full, and we need to evict an entry. + let old_protocol_addr = *self + .storage + .iter() + .min_by_key(|(_, neighbor)| neighbor.expires_at) + .expect("empty neighbor cache storage") + .0; + + let _old_neighbor = self.storage.remove(&old_protocol_addr).unwrap(); + match self.storage.insert(protocol_addr, neighbor) { + Ok(None) => { + net_trace!( + "filled {} => {} (evicted {} => {})", + protocol_addr, + hardware_addr, + old_protocol_addr, + _old_neighbor.hardware_addr + ); + } + // We've covered everything else above. + _ => unreachable!(), + } + } + } + } + + pub(crate) fn lookup(&self, protocol_addr: &IpAddress, timestamp: Instant) -> Answer { + assert!(protocol_addr.is_unicast()); + + if let Some(&Neighbor { + expires_at, + hardware_addr, + }) = self.storage.get(protocol_addr) + { + if timestamp < expires_at { + return Answer::Found(hardware_addr); + } + } + + if timestamp < self.silent_until { + Answer::RateLimited + } else { + Answer::NotFound + } + } + + pub(crate) fn limit_rate(&mut self, timestamp: Instant) { + self.silent_until = timestamp + Self::SILENT_TIME; + } + + pub(crate) fn flush(&mut self) { + self.storage.clear() + } +} + +#[cfg(feature = "medium-ethernet")] +#[cfg(test)] +mod test { + use super::*; + use crate::wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2, MOCK_IP_ADDR_3, MOCK_IP_ADDR_4}; + + use crate::wire::EthernetAddress; + + const HADDR_A: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([0, 0, 0, 0, 0, 1])); + const HADDR_B: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([0, 0, 0, 0, 0, 2])); + const HADDR_C: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([0, 0, 0, 0, 0, 3])); + const HADDR_D: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([0, 0, 0, 0, 0, 4])); + + #[test] + fn test_fill() { + let mut cache = Cache::new(); + + assert!(!cache + .lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)) + .found()); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(0)) + .found()); + + cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_A) + ); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(0)) + .found()); + assert!(!cache + .lookup( + &MOCK_IP_ADDR_1, + Instant::from_millis(0) + Cache::ENTRY_LIFETIME * 2 + ) + .found(),); + + cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(0)) + .found()); + } + + #[test] + fn test_expire() { + let mut cache = Cache::new(); + + cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_A) + ); + assert!(!cache + .lookup( + &MOCK_IP_ADDR_1, + Instant::from_millis(0) + Cache::ENTRY_LIFETIME * 2 + ) + .found(),); + } + + #[test] + fn test_replace() { + let mut cache = Cache::new(); + + cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_A) + ); + cache.fill(MOCK_IP_ADDR_1, HADDR_B, Instant::from_millis(0)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_B) + ); + } + + #[test] + fn test_evict() { + let mut cache = Cache::new(); + + cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(100)); + cache.fill(MOCK_IP_ADDR_2, HADDR_B, Instant::from_millis(50)); + cache.fill(MOCK_IP_ADDR_3, HADDR_C, Instant::from_millis(200)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_2, Instant::from_millis(1000)), + Answer::Found(HADDR_B) + ); + assert!(!cache + .lookup(&MOCK_IP_ADDR_4, Instant::from_millis(1000)) + .found()); + + cache.fill(MOCK_IP_ADDR_4, HADDR_D, Instant::from_millis(300)); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(1000)) + .found()); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_4, Instant::from_millis(1000)), + Answer::Found(HADDR_D) + ); + } + + #[test] + fn test_hush() { + let mut cache = Cache::new(); + + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::NotFound + ); + + cache.limit_rate(Instant::from_millis(0)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(100)), + Answer::RateLimited + ); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(2000)), + Answer::NotFound + ); + } + + #[test] + fn test_flush() { + let mut cache = Cache::new(); + + cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_A) + ); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(0)) + .found()); + + cache.flush(); + assert!(!cache + .lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)) + .found()); + assert!(!cache + .lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)) + .found()); + } +} diff --git a/src/iface/packet.rs b/src/iface/packet.rs new file mode 100644 index 0000000..4fdb19d --- /dev/null +++ b/src/iface/packet.rs @@ -0,0 +1,234 @@ +use crate::phy::DeviceCapabilities; +use crate::wire::*; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[cfg(feature = "medium-ethernet")] +pub(crate) enum EthernetPacket<'a> { + #[cfg(feature = "proto-ipv4")] + Arp(ArpRepr), + Ip(Packet<'a>), +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) enum Packet<'p> { + #[cfg(feature = "proto-ipv4")] + Ipv4(PacketV4<'p>), + #[cfg(feature = "proto-ipv6")] + Ipv6(PacketV6<'p>), +} + +impl<'p> Packet<'p> { + pub(crate) fn new(ip_repr: IpRepr, payload: IpPayload<'p>) -> Self { + match ip_repr { + #[cfg(feature = "proto-ipv4")] + IpRepr::Ipv4(header) => Self::new_ipv4(header, payload), + #[cfg(feature = "proto-ipv6")] + IpRepr::Ipv6(header) => Self::new_ipv6(header, payload), + } + } + + #[cfg(feature = "proto-ipv4")] + pub(crate) fn new_ipv4(ip_repr: Ipv4Repr, payload: IpPayload<'p>) -> Self { + Self::Ipv4(PacketV4 { + header: ip_repr, + payload, + }) + } + + #[cfg(feature = "proto-ipv6")] + pub(crate) fn new_ipv6(ip_repr: Ipv6Repr, payload: IpPayload<'p>) -> Self { + Self::Ipv6(PacketV6 { + header: ip_repr, + #[cfg(feature = "proto-ipv6-hbh")] + hop_by_hop: None, + #[cfg(feature = "proto-ipv6-fragmentation")] + fragment: None, + #[cfg(feature = "proto-ipv6-routing")] + routing: None, + payload, + }) + } + + pub(crate) fn ip_repr(&self) -> IpRepr { + match self { + #[cfg(feature = "proto-ipv4")] + Packet::Ipv4(p) => IpRepr::Ipv4(p.header), + #[cfg(feature = "proto-ipv6")] + Packet::Ipv6(p) => IpRepr::Ipv6(p.header), + } + } + + pub(crate) fn payload(&self) -> &IpPayload<'p> { + match self { + #[cfg(feature = "proto-ipv4")] + Packet::Ipv4(p) => &p.payload, + #[cfg(feature = "proto-ipv6")] + Packet::Ipv6(p) => &p.payload, + } + } + + pub(crate) fn emit_payload( + &self, + _ip_repr: &IpRepr, + payload: &mut [u8], + caps: &DeviceCapabilities, + ) { + match self.payload() { + #[cfg(feature = "proto-ipv4")] + IpPayload::Icmpv4(icmpv4_repr) => { + icmpv4_repr.emit(&mut Icmpv4Packet::new_unchecked(payload), &caps.checksum) + } + #[cfg(feature = "proto-igmp")] + IpPayload::Igmp(igmp_repr) => igmp_repr.emit(&mut IgmpPacket::new_unchecked(payload)), + #[cfg(feature = "proto-ipv6")] + IpPayload::Icmpv6(icmpv6_repr) => icmpv6_repr.emit( + &_ip_repr.src_addr(), + &_ip_repr.dst_addr(), + &mut Icmpv6Packet::new_unchecked(payload), + &caps.checksum, + ), + #[cfg(feature = "socket-raw")] + IpPayload::Raw(raw_packet) => payload.copy_from_slice(raw_packet), + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + IpPayload::Udp(udp_repr, inner_payload) => udp_repr.emit( + &mut UdpPacket::new_unchecked(payload), + &_ip_repr.src_addr(), + &_ip_repr.dst_addr(), + inner_payload.len(), + |buf| buf.copy_from_slice(inner_payload), + &caps.checksum, + ), + #[cfg(feature = "socket-tcp")] + IpPayload::Tcp(mut tcp_repr) => { + // This is a terrible hack to make TCP performance more acceptable on systems + // where the TCP buffers are significantly larger than network buffers, + // e.g. a 64 kB TCP receive buffer (and so, when empty, a 64k window) + // together with four 1500 B Ethernet receive buffers. If left untreated, + // this would result in our peer pushing our window and sever packet loss. + // + // I'm really not happy about this "solution" but I don't know what else to do. + if let Some(max_burst_size) = caps.max_burst_size { + let mut max_segment_size = caps.max_transmission_unit; + max_segment_size -= _ip_repr.header_len(); + max_segment_size -= tcp_repr.header_len(); + + let max_window_size = max_burst_size * max_segment_size; + if tcp_repr.window_len as usize > max_window_size { + tcp_repr.window_len = max_window_size as u16; + } + } + + tcp_repr.emit( + &mut TcpPacket::new_unchecked(payload), + &_ip_repr.src_addr(), + &_ip_repr.dst_addr(), + &caps.checksum, + ); + } + #[cfg(feature = "socket-dhcpv4")] + IpPayload::Dhcpv4(udp_repr, dhcp_repr) => udp_repr.emit( + &mut UdpPacket::new_unchecked(payload), + &_ip_repr.src_addr(), + &_ip_repr.dst_addr(), + dhcp_repr.buffer_len(), + |buf| dhcp_repr.emit(&mut DhcpPacket::new_unchecked(buf)).unwrap(), + &caps.checksum, + ), + } + } +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[cfg(feature = "proto-ipv4")] +pub(crate) struct PacketV4<'p> { + header: Ipv4Repr, + payload: IpPayload<'p>, +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[cfg(feature = "proto-ipv6")] +pub(crate) struct PacketV6<'p> { + pub(crate) header: Ipv6Repr, + #[cfg(feature = "proto-ipv6-hbh")] + pub(crate) hop_by_hop: Option<Ipv6HopByHopRepr<'p>>, + #[cfg(feature = "proto-ipv6-fragmentation")] + pub(crate) fragment: Option<Ipv6FragmentRepr>, + #[cfg(feature = "proto-ipv6-routing")] + pub(crate) routing: Option<Ipv6RoutingRepr<'p>>, + pub(crate) payload: IpPayload<'p>, +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) enum IpPayload<'p> { + #[cfg(feature = "proto-ipv4")] + Icmpv4(Icmpv4Repr<'p>), + #[cfg(feature = "proto-igmp")] + Igmp(IgmpRepr), + #[cfg(feature = "proto-ipv6")] + Icmpv6(Icmpv6Repr<'p>), + #[cfg(feature = "socket-raw")] + Raw(&'p [u8]), + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + Udp(UdpRepr, &'p [u8]), + #[cfg(feature = "socket-tcp")] + Tcp(TcpRepr<'p>), + #[cfg(feature = "socket-dhcpv4")] + Dhcpv4(UdpRepr, DhcpRepr<'p>), +} + +impl<'p> IpPayload<'p> { + #[cfg(feature = "proto-sixlowpan")] + pub(crate) fn as_sixlowpan_next_header(&self) -> SixlowpanNextHeader { + match self { + #[cfg(feature = "proto-ipv4")] + Self::Icmpv4(_) => unreachable!(), + #[cfg(feature = "socket-dhcpv4")] + Self::Dhcpv4(..) => unreachable!(), + #[cfg(feature = "proto-ipv6")] + Self::Icmpv6(_) => SixlowpanNextHeader::Uncompressed(IpProtocol::Icmpv6), + #[cfg(feature = "proto-igmp")] + Self::Igmp(_) => unreachable!(), + #[cfg(feature = "socket-tcp")] + Self::Tcp(_) => SixlowpanNextHeader::Uncompressed(IpProtocol::Tcp), + #[cfg(feature = "socket-udp")] + Self::Udp(..) => SixlowpanNextHeader::Compressed, + #[cfg(feature = "socket-raw")] + Self::Raw(_) => todo!(), + } + } +} + +#[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] +pub(crate) fn icmp_reply_payload_len(len: usize, mtu: usize, header_len: usize) -> usize { + // Send back as much of the original payload as will fit within + // the minimum MTU required by IPv4. See RFC 1812 § 4.3.2.3 for + // more details. + // + // Since the entire network layer packet must fit within the minimum + // MTU supported, the payload must not exceed the following: + // + // <min mtu> - IP Header Size * 2 - ICMPv4 DstUnreachable hdr size + len.min(mtu - header_len * 2 - 8) +} + +#[cfg(feature = "proto-igmp")] +pub(crate) enum IgmpReportState { + Inactive, + ToGeneralQuery { + version: IgmpVersion, + timeout: crate::time::Instant, + interval: crate::time::Duration, + next_index: usize, + }, + ToSpecificQuery { + version: IgmpVersion, + timeout: crate::time::Instant, + group: Ipv4Address, + }, +} diff --git a/src/iface/route.rs b/src/iface/route.rs new file mode 100644 index 0000000..123c695 --- /dev/null +++ b/src/iface/route.rs @@ -0,0 +1,327 @@ +use heapless::Vec; + +use crate::config::IFACE_MAX_ROUTE_COUNT; +use crate::time::Instant; +use crate::wire::{IpAddress, IpCidr}; +#[cfg(feature = "proto-ipv4")] +use crate::wire::{Ipv4Address, Ipv4Cidr}; +#[cfg(feature = "proto-ipv6")] +use crate::wire::{Ipv6Address, Ipv6Cidr}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct RouteTableFull; + +impl core::fmt::Display for RouteTableFull { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Route table full") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RouteTableFull {} + +/// A prefix of addresses that should be routed via a router +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Route { + pub cidr: IpCidr, + pub via_router: IpAddress, + /// `None` means "forever". + pub preferred_until: Option<Instant>, + /// `None` means "forever". + pub expires_at: Option<Instant>, +} + +#[cfg(feature = "proto-ipv4")] +const IPV4_DEFAULT: IpCidr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address::new(0, 0, 0, 0), 0)); +#[cfg(feature = "proto-ipv6")] +const IPV6_DEFAULT: IpCidr = + IpCidr::Ipv6(Ipv6Cidr::new(Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 0), 0)); + +impl Route { + /// Returns a route to 0.0.0.0/0 via the `gateway`, with no expiry. + #[cfg(feature = "proto-ipv4")] + pub fn new_ipv4_gateway(gateway: Ipv4Address) -> Route { + Route { + cidr: IPV4_DEFAULT, + via_router: gateway.into(), + preferred_until: None, + expires_at: None, + } + } + + /// Returns a route to ::/0 via the `gateway`, with no expiry. + #[cfg(feature = "proto-ipv6")] + pub fn new_ipv6_gateway(gateway: Ipv6Address) -> Route { + Route { + cidr: IPV6_DEFAULT, + via_router: gateway.into(), + preferred_until: None, + expires_at: None, + } + } +} + +/// A routing table. +#[derive(Debug)] +pub struct Routes { + storage: Vec<Route, IFACE_MAX_ROUTE_COUNT>, +} + +impl Routes { + /// Creates a new empty routing table. + pub fn new() -> Self { + Self { + storage: Vec::new(), + } + } + + /// Update the routes of this node. + pub fn update<F: FnOnce(&mut Vec<Route, IFACE_MAX_ROUTE_COUNT>)>(&mut self, f: F) { + f(&mut self.storage); + } + + /// Add a default ipv4 gateway (ie. "ip route add 0.0.0.0/0 via `gateway`"). + /// + /// On success, returns the previous default route, if any. + #[cfg(feature = "proto-ipv4")] + pub fn add_default_ipv4_route( + &mut self, + gateway: Ipv4Address, + ) -> Result<Option<Route>, RouteTableFull> { + let old = self.remove_default_ipv4_route(); + self.storage + .push(Route::new_ipv4_gateway(gateway)) + .map_err(|_| RouteTableFull)?; + Ok(old) + } + + /// Add a default ipv6 gateway (ie. "ip -6 route add ::/0 via `gateway`"). + /// + /// On success, returns the previous default route, if any. + #[cfg(feature = "proto-ipv6")] + pub fn add_default_ipv6_route( + &mut self, + gateway: Ipv6Address, + ) -> Result<Option<Route>, RouteTableFull> { + let old = self.remove_default_ipv6_route(); + self.storage + .push(Route::new_ipv6_gateway(gateway)) + .map_err(|_| RouteTableFull)?; + Ok(old) + } + + /// Remove the default ipv4 gateway + /// + /// On success, returns the previous default route, if any. + #[cfg(feature = "proto-ipv4")] + pub fn remove_default_ipv4_route(&mut self) -> Option<Route> { + if let Some((i, _)) = self + .storage + .iter() + .enumerate() + .find(|(_, r)| r.cidr == IPV4_DEFAULT) + { + Some(self.storage.remove(i)) + } else { + None + } + } + + /// Remove the default ipv6 gateway + /// + /// On success, returns the previous default route, if any. + #[cfg(feature = "proto-ipv6")] + pub fn remove_default_ipv6_route(&mut self) -> Option<Route> { + if let Some((i, _)) = self + .storage + .iter() + .enumerate() + .find(|(_, r)| r.cidr == IPV6_DEFAULT) + { + Some(self.storage.remove(i)) + } else { + None + } + } + + pub(crate) fn lookup(&self, addr: &IpAddress, timestamp: Instant) -> Option<IpAddress> { + assert!(addr.is_unicast()); + + self.storage + .iter() + // Keep only matching routes + .filter(|route| { + if let Some(expires_at) = route.expires_at { + if timestamp > expires_at { + return false; + } + } + route.cidr.contains_addr(addr) + }) + // pick the most specific one (highest prefix_len) + .max_by_key(|route| route.cidr.prefix_len()) + .map(|route| route.via_router) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(feature = "proto-ipv6")] + mod mock { + use super::super::*; + pub const ADDR_1A: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1]); + pub const ADDR_1B: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 13]); + pub const ADDR_1C: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 42]); + pub fn cidr_1() -> Ipv6Cidr { + Ipv6Cidr::new( + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0]), + 64, + ) + } + + pub const ADDR_2A: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 51, 100, 0, 0, 0, 0, 0, 0, 0, 1]); + pub const ADDR_2B: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 51, 100, 0, 0, 0, 0, 0, 0, 0, 21]); + pub fn cidr_2() -> Ipv6Cidr { + Ipv6Cidr::new( + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 51, 100, 0, 0, 0, 0, 0, 0, 0, 0]), + 64, + ) + } + } + + #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] + mod mock { + use super::super::*; + pub const ADDR_1A: Ipv4Address = Ipv4Address([192, 0, 2, 1]); + pub const ADDR_1B: Ipv4Address = Ipv4Address([192, 0, 2, 13]); + pub const ADDR_1C: Ipv4Address = Ipv4Address([192, 0, 2, 42]); + pub fn cidr_1() -> Ipv4Cidr { + Ipv4Cidr::new(Ipv4Address([192, 0, 2, 0]), 24) + } + + pub const ADDR_2A: Ipv4Address = Ipv4Address([198, 51, 100, 1]); + pub const ADDR_2B: Ipv4Address = Ipv4Address([198, 51, 100, 21]); + pub fn cidr_2() -> Ipv4Cidr { + Ipv4Cidr::new(Ipv4Address([198, 51, 100, 0]), 24) + } + } + + use self::mock::*; + + #[test] + fn test_fill() { + let mut routes = Routes::new(); + + assert_eq!( + routes.lookup(&ADDR_1A.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_1B.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_1C.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_2A.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_2B.into(), Instant::from_millis(0)), + None + ); + + let route = Route { + cidr: cidr_1().into(), + via_router: ADDR_1A.into(), + preferred_until: None, + expires_at: None, + }; + routes.update(|storage| { + storage.push(route).unwrap(); + }); + + assert_eq!( + routes.lookup(&ADDR_1A.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1B.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1C.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2A.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_2B.into(), Instant::from_millis(0)), + None + ); + + let route2 = Route { + cidr: cidr_2().into(), + via_router: ADDR_2A.into(), + preferred_until: Some(Instant::from_millis(10)), + expires_at: Some(Instant::from_millis(10)), + }; + routes.update(|storage| { + storage.push(route2).unwrap(); + }); + + assert_eq!( + routes.lookup(&ADDR_1A.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1B.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1C.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2A.into(), Instant::from_millis(0)), + Some(ADDR_2A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2B.into(), Instant::from_millis(0)), + Some(ADDR_2A.into()) + ); + + assert_eq!( + routes.lookup(&ADDR_1A.into(), Instant::from_millis(10)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1B.into(), Instant::from_millis(10)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1C.into(), Instant::from_millis(10)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2A.into(), Instant::from_millis(10)), + Some(ADDR_2A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2B.into(), Instant::from_millis(10)), + Some(ADDR_2A.into()) + ); + } +} diff --git a/src/iface/rpl/consts.rs b/src/iface/rpl/consts.rs new file mode 100644 index 0000000..70a6613 --- /dev/null +++ b/src/iface/rpl/consts.rs @@ -0,0 +1,8 @@ +pub const SEQUENCE_WINDOW: u8 = 16; + +pub const DEFAULT_MIN_HOP_RANK_INCREASE: u16 = 256; + +pub const DEFAULT_DIO_INTERVAL_MIN: u32 = 12; +pub const DEFAULT_DIO_REDUNDANCY_CONSTANT: usize = 10; +/// This is 20 in the standard, but in Contiki they use: +pub const DEFAULT_DIO_INTERVAL_DOUBLINGS: u32 = 8; diff --git a/src/iface/rpl/lollipop.rs b/src/iface/rpl/lollipop.rs new file mode 100644 index 0000000..4785c77 --- /dev/null +++ b/src/iface/rpl/lollipop.rs @@ -0,0 +1,189 @@ +//! Implementation of sequence counters defined in [RFC 6550 § 7.2]. Values from 128 and greater +//! are used as a linear sequence to indicate a restart and bootstrap the counter. Values less than +//! or equal to 127 are used as a circular sequence number space of size 128. When operating in the +//! circular region, if sequence numbers are detected to be too far apart, then they are not +//! comparable. +//! +//! [RFC 6550 § 7.2]: https://datatracker.ietf.org/doc/html/rfc6550#section-7.2 + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct SequenceCounter(u8); + +impl Default for SequenceCounter { + fn default() -> Self { + // RFC6550 7.2 recommends 240 (256 - SEQUENCE_WINDOW) as the initialization value of the + // counter. + Self(240) + } +} + +impl SequenceCounter { + /// Create a new sequence counter. + /// + /// Use `Self::default()` when a new sequence counter needs to be created with a value that is + /// recommended in RFC6550 7.2, being 240. + pub fn new(value: u8) -> Self { + Self(value) + } + + /// Return the value of the sequence counter. + pub fn value(&self) -> u8 { + self.0 + } + + /// Increment the sequence counter. + /// + /// When the sequence counter is greater than or equal to 128, the maximum value is 255. + /// When the sequence counter is less than 128, the maximum value is 127. + /// + /// When an increment of the sequence counter would cause the counter to increment beyond its + /// maximum value, the counter MUST wrap back to zero. + pub fn increment(&mut self) { + let max = if self.0 >= 128 { 255 } else { 127 }; + + self.0 = match self.0.checked_add(1) { + Some(val) if val <= max => val, + _ => 0, + }; + } +} + +impl PartialEq for SequenceCounter { + fn eq(&self, other: &Self) -> bool { + let a = self.value() as usize; + let b = other.value() as usize; + + if ((128..=255).contains(&a) && (0..=127).contains(&b)) + || ((128..=255).contains(&b) && (0..=127).contains(&a)) + { + false + } else { + let result = if a > b { a - b } else { b - a }; + + if result <= super::consts::SEQUENCE_WINDOW as usize { + // RFC1982 + a == b + } else { + // This case is actually not comparable. + false + } + } + } +} + +impl PartialOrd for SequenceCounter { + fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> { + use super::consts::SEQUENCE_WINDOW; + use core::cmp::Ordering; + + let a = self.value() as usize; + let b = other.value() as usize; + + if (128..256).contains(&a) && (0..128).contains(&b) { + if 256 + b - a <= SEQUENCE_WINDOW as usize { + Some(Ordering::Less) + } else { + Some(Ordering::Greater) + } + } else if (128..256).contains(&b) && (0..128).contains(&a) { + if 256 + a - b <= SEQUENCE_WINDOW as usize { + Some(Ordering::Greater) + } else { + Some(Ordering::Less) + } + } else if ((0..128).contains(&a) && (0..128).contains(&b)) + || ((128..256).contains(&a) && (128..256).contains(&b)) + { + let result = if a > b { a - b } else { b - a }; + + if result <= SEQUENCE_WINDOW as usize { + // RFC1982 + a.partial_cmp(&b) + } else { + // This case is not comparable. + None + } + } else { + unreachable!(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sequence_counter_increment() { + let mut seq = SequenceCounter::new(253); + seq.increment(); + assert_eq!(seq.value(), 254); + seq.increment(); + assert_eq!(seq.value(), 255); + seq.increment(); + assert_eq!(seq.value(), 0); + + let mut seq = SequenceCounter::new(126); + seq.increment(); + assert_eq!(seq.value(), 127); + seq.increment(); + assert_eq!(seq.value(), 0); + } + + #[test] + fn sequence_counter_comparison() { + use core::cmp::Ordering; + + assert!(SequenceCounter::new(240) != SequenceCounter::new(1)); + assert!(SequenceCounter::new(1) != SequenceCounter::new(240)); + assert!(SequenceCounter::new(1) != SequenceCounter::new(240)); + assert!(SequenceCounter::new(240) == SequenceCounter::new(240)); + assert!(SequenceCounter::new(240 - 17) != SequenceCounter::new(240)); + + assert_eq!( + SequenceCounter::new(240).partial_cmp(&SequenceCounter::new(5)), + Some(Ordering::Greater) + ); + assert_eq!( + SequenceCounter::new(250).partial_cmp(&SequenceCounter::new(5)), + Some(Ordering::Less) + ); + assert_eq!( + SequenceCounter::new(5).partial_cmp(&SequenceCounter::new(250)), + Some(Ordering::Greater) + ); + assert_eq!( + SequenceCounter::new(127).partial_cmp(&SequenceCounter::new(129)), + Some(Ordering::Less) + ); + assert_eq!( + SequenceCounter::new(120).partial_cmp(&SequenceCounter::new(121)), + Some(Ordering::Less) + ); + assert_eq!( + SequenceCounter::new(121).partial_cmp(&SequenceCounter::new(120)), + Some(Ordering::Greater) + ); + assert_eq!( + SequenceCounter::new(240).partial_cmp(&SequenceCounter::new(241)), + Some(Ordering::Less) + ); + assert_eq!( + SequenceCounter::new(241).partial_cmp(&SequenceCounter::new(240)), + Some(Ordering::Greater) + ); + assert_eq!( + SequenceCounter::new(120).partial_cmp(&SequenceCounter::new(120)), + Some(Ordering::Equal) + ); + assert_eq!( + SequenceCounter::new(240).partial_cmp(&SequenceCounter::new(240)), + Some(Ordering::Equal) + ); + assert_eq!( + SequenceCounter::new(130).partial_cmp(&SequenceCounter::new(241)), + None + ); + } +} diff --git a/src/iface/rpl/mod.rs b/src/iface/rpl/mod.rs new file mode 100644 index 0000000..69aa9ae --- /dev/null +++ b/src/iface/rpl/mod.rs @@ -0,0 +1,9 @@ +#![allow(unused)] + +mod consts; +mod lollipop; +mod of0; +mod parents; +mod rank; +mod relations; +mod trickle; diff --git a/src/iface/rpl/of0.rs b/src/iface/rpl/of0.rs new file mode 100644 index 0000000..99e4d1f --- /dev/null +++ b/src/iface/rpl/of0.rs @@ -0,0 +1,129 @@ +use super::parents::*; +use super::rank::Rank; + +pub struct ObjectiveFunction0; + +pub(crate) trait ObjectiveFunction { + const OCP: u16; + + /// Return the new calculated Rank, based on information from the parent. + fn rank(current_rank: Rank, parent_rank: Rank) -> Rank; + + /// Return the preferred parent from a given parent set. + fn preferred_parent(parent_set: &ParentSet) -> Option<&Parent>; +} + +impl ObjectiveFunction0 { + const OCP: u16 = 0; + + const RANK_STRETCH: u16 = 0; + const RANK_FACTOR: u16 = 1; + const RANK_STEP: u16 = 3; + + fn rank_increase(parent_rank: Rank) -> u16 { + (Self::RANK_FACTOR * Self::RANK_STEP + Self::RANK_STRETCH) + * parent_rank.min_hop_rank_increase + } +} + +impl ObjectiveFunction for ObjectiveFunction0 { + const OCP: u16 = 0; + + fn rank(_: Rank, parent_rank: Rank) -> Rank { + assert_ne!(parent_rank, Rank::INFINITE); + + Rank::new( + parent_rank.value + Self::rank_increase(parent_rank), + parent_rank.min_hop_rank_increase, + ) + } + + fn preferred_parent(parent_set: &ParentSet) -> Option<&Parent> { + let mut pref_parent: Option<&Parent> = None; + + for (_, parent) in parent_set.parents() { + if pref_parent.is_none() || parent.rank() < pref_parent.unwrap().rank() { + pref_parent = Some(parent); + } + } + + pref_parent + } +} + +#[cfg(test)] +mod tests { + use crate::iface::rpl::consts::DEFAULT_MIN_HOP_RANK_INCREASE; + + use super::*; + + #[test] + fn rank_increase() { + // 256 (root) + 3 * 256 + assert_eq!( + ObjectiveFunction0::rank(Rank::INFINITE, Rank::ROOT), + Rank::new(256 + 3 * 256, DEFAULT_MIN_HOP_RANK_INCREASE) + ); + + // 1024 + 3 * 256 + assert_eq!( + ObjectiveFunction0::rank( + Rank::INFINITE, + Rank::new(1024, DEFAULT_MIN_HOP_RANK_INCREASE) + ), + Rank::new(1024 + 3 * 256, DEFAULT_MIN_HOP_RANK_INCREASE) + ); + } + + #[test] + #[should_panic] + fn rank_increase_infinite() { + assert_eq!( + ObjectiveFunction0::rank(Rank::INFINITE, Rank::INFINITE), + Rank::INFINITE + ); + } + + #[test] + fn empty_set() { + assert_eq!( + ObjectiveFunction0::preferred_parent(&ParentSet::default()), + None + ); + } + + #[test] + fn non_empty_set() { + use crate::wire::Ipv6Address; + + let mut parents = ParentSet::default(); + + parents.add( + Ipv6Address::default(), + Parent::new(0, Rank::ROOT, Default::default(), Ipv6Address::default()), + ); + + let mut address = Ipv6Address::default(); + address.0[15] = 1; + + parents.add( + address, + Parent::new( + 0, + Rank::new(1024, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + Ipv6Address::default(), + ), + ); + + assert_eq!( + ObjectiveFunction0::preferred_parent(&parents), + Some(&Parent::new( + 0, + Rank::ROOT, + Default::default(), + Ipv6Address::default(), + )) + ); + } +} diff --git a/src/iface/rpl/parents.rs b/src/iface/rpl/parents.rs new file mode 100644 index 0000000..70d5a5e --- /dev/null +++ b/src/iface/rpl/parents.rs @@ -0,0 +1,176 @@ +use crate::wire::Ipv6Address; + +use super::{lollipop::SequenceCounter, rank::Rank}; +use crate::config::RPL_PARENTS_BUFFER_COUNT; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) struct Parent { + rank: Rank, + preference: u8, + version_number: SequenceCounter, + dodag_id: Ipv6Address, +} + +impl Parent { + /// Create a new parent. + pub(crate) fn new( + preference: u8, + rank: Rank, + version_number: SequenceCounter, + dodag_id: Ipv6Address, + ) -> Self { + Self { + rank, + preference, + version_number, + dodag_id, + } + } + + /// Return the Rank of the parent. + pub(crate) fn rank(&self) -> &Rank { + &self.rank + } +} + +#[derive(Debug, Default)] +pub(crate) struct ParentSet { + parents: heapless::LinearMap<Ipv6Address, Parent, { RPL_PARENTS_BUFFER_COUNT }>, +} + +impl ParentSet { + /// Add a new parent to the parent set. The Rank of the new parent should be lower than the + /// Rank of the node that holds this parent set. + pub(crate) fn add(&mut self, address: Ipv6Address, parent: Parent) { + if let Some(p) = self.parents.get_mut(&address) { + *p = parent; + } else if let Err(p) = self.parents.insert(address, parent) { + if let Some((w_a, w_p)) = self.worst_parent() { + if w_p.rank.dag_rank() > parent.rank.dag_rank() { + self.parents.remove(&w_a.clone()).unwrap(); + self.parents.insert(address, parent).unwrap(); + } else { + net_debug!("could not add {} to parent set, buffer is full", address); + } + } else { + unreachable!() + } + } + } + + /// Find a parent based on its address. + pub(crate) fn find(&self, address: &Ipv6Address) -> Option<&Parent> { + self.parents.get(address) + } + + /// Find a mutable parent based on its address. + pub(crate) fn find_mut(&mut self, address: &Ipv6Address) -> Option<&mut Parent> { + self.parents.get_mut(address) + } + + /// Return a slice to the parent set. + pub(crate) fn parents(&self) -> impl Iterator<Item = (&Ipv6Address, &Parent)> { + self.parents.iter() + } + + /// Find the worst parent that is currently in the parent set. + fn worst_parent(&self) -> Option<(&Ipv6Address, &Parent)> { + self.parents.iter().max_by_key(|(k, v)| v.rank.dag_rank()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn add_parent() { + let mut set = ParentSet::default(); + set.add( + Default::default(), + Parent::new(0, Rank::ROOT, Default::default(), Default::default()), + ); + + assert_eq!( + set.find(&Default::default()), + Some(&Parent::new( + 0, + Rank::ROOT, + Default::default(), + Default::default() + )) + ); + } + + #[test] + fn add_more_parents() { + use super::super::consts::DEFAULT_MIN_HOP_RANK_INCREASE; + let mut set = ParentSet::default(); + + let mut last_address = Default::default(); + for i in 0..RPL_PARENTS_BUFFER_COUNT { + let i = i as u16; + let mut address = Ipv6Address::default(); + address.0[15] = i as u8; + last_address = address; + + set.add( + address, + Parent::new( + 0, + Rank::new(256 * i, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address, + ), + ); + + assert_eq!( + set.find(&address), + Some(&Parent::new( + 0, + Rank::new(256 * i, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address, + )) + ); + } + + // This one is not added to the set, because its Rank is worse than any other parent in the + // set. + let mut address = Ipv6Address::default(); + address.0[15] = 8; + set.add( + address, + Parent::new( + 0, + Rank::new(256 * 8, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address, + ), + ); + assert_eq!(set.find(&address), None); + + /// This Parent has a better rank than the last one in the set. + let mut address = Ipv6Address::default(); + address.0[15] = 9; + set.add( + address, + Parent::new( + 0, + Rank::new(0, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address, + ), + ); + assert_eq!( + set.find(&address), + Some(&Parent::new( + 0, + Rank::new(0, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address + )) + ); + assert_eq!(set.find(&last_address), None); + } +} diff --git a/src/iface/rpl/rank.rs b/src/iface/rpl/rank.rs new file mode 100644 index 0000000..02a5ecf --- /dev/null +++ b/src/iface/rpl/rank.rs @@ -0,0 +1,104 @@ +//! Implementation of the Rank comparison in RPL. +//! +//! A Rank can be thought of as a fixed-point number, where the position of the radix point between +//! the integer part and the fractional part is determined by `MinHopRankIncrease`. +//! `MinHopRankIncrease` is the minimum increase in Rank between a node and any of its DODAG +//! parents. +//! This value is provisined by the DODAG root. +//! +//! When Rank is compared, the integer portion of the Rank is to be used. +//! +//! Meaning of the comparison: +//! - **Rank M is less than Rank N**: the position of M is closer to the DODAG root than the position +//! of N. Node M may safely be a DODAG parent for node N. +//! - **Ranks are equal**: the positions of both nodes within the DODAG and with respect to the DODAG +//! are similar or identical. Routing through a node with equal Rank may cause a routing loop. +//! - **Rank M is greater than Rank N**: the position of node M is farther from the DODAG root +//! than the position of N. Node M may in fact be in the sub-DODAG of node N. If node N selects +//! node M as a DODAG parent, there is a risk of creating a loop. + +use super::consts::DEFAULT_MIN_HOP_RANK_INCREASE; + +/// The Rank is the expression of the relative position within a DODAG Version with regard to +/// neighbors, and it is not necessarily a good indication or a proper expression of a distance or +/// a path cost to the root. +#[derive(Debug, Clone, Copy, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Rank { + pub(super) value: u16, + pub(super) min_hop_rank_increase: u16, +} + +impl core::fmt::Display for Rank { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Rank({})", self.dag_rank()) + } +} + +impl Rank { + pub const INFINITE: Self = Rank::new(0xffff, DEFAULT_MIN_HOP_RANK_INCREASE); + + /// The ROOT_RANK is the smallest rank possible. + /// DAG_RANK(ROOT_RANK) should be 1. See RFC6550 § 17. + pub const ROOT: Self = Rank::new(DEFAULT_MIN_HOP_RANK_INCREASE, DEFAULT_MIN_HOP_RANK_INCREASE); + + /// Create a new Rank from some value and a `MinHopRankIncrease`. + /// The `MinHopRankIncrease` is used for calculating the integer part for comparing to other + /// Ranks. + pub const fn new(value: u16, min_hop_rank_increase: u16) -> Self { + assert!(min_hop_rank_increase > 0); + + Self { + value, + min_hop_rank_increase, + } + } + + /// Return the integer part of the Rank. + pub fn dag_rank(&self) -> u16 { + self.value / self.min_hop_rank_increase + } + + /// Return the raw Rank value. + pub fn raw_value(&self) -> u16 { + self.value + } +} + +impl PartialEq for Rank { + fn eq(&self, other: &Self) -> bool { + self.dag_rank() == other.dag_rank() + } +} + +impl PartialOrd for Rank { + fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> { + self.dag_rank().partial_cmp(&other.dag_rank()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn calculate_rank() { + let r = Rank::new(27, 16); + assert_eq!(r.dag_rank(), 1) + } + + #[test] + fn comparison() { + let r1 = Rank::ROOT; + let r2 = Rank::new(16, 16); + assert!(r1 == r2); + + let r1 = Rank::new(16, 16); + let r2 = Rank::new(32, 16); + assert!(r1 < r2); + + let r1 = Rank::ROOT; + let r2 = Rank::INFINITE; + assert!(r1 < r2); + } +} diff --git a/src/iface/rpl/relations.rs b/src/iface/rpl/relations.rs new file mode 100644 index 0000000..da02a3c --- /dev/null +++ b/src/iface/rpl/relations.rs @@ -0,0 +1,162 @@ +use crate::time::Instant; +use crate::wire::Ipv6Address; + +use crate::config::RPL_RELATIONS_BUFFER_COUNT; + +#[derive(Debug)] +pub struct Relation { + destination: Ipv6Address, + next_hop: Ipv6Address, + expiration: Instant, +} + +#[derive(Default, Debug)] +pub struct Relations { + relations: heapless::Vec<Relation, { RPL_RELATIONS_BUFFER_COUNT }>, +} + +impl Relations { + /// Add a new relation to the buffer. If there was already a relation in the buffer, then + /// update it. + pub fn add_relation( + &mut self, + destination: Ipv6Address, + next_hop: Ipv6Address, + expiration: Instant, + ) { + if let Some(r) = self + .relations + .iter_mut() + .find(|r| r.destination == destination) + { + r.next_hop = next_hop; + r.expiration = expiration; + } else { + let relation = Relation { + destination, + next_hop, + expiration, + }; + + if let Err(e) = self.relations.push(relation) { + net_debug!("Unable to add relation, buffer is full"); + } + } + } + + /// Remove all relation entries for a specific destination. + pub fn remove_relation(&mut self, destination: Ipv6Address) { + self.relations.retain(|r| r.destination != destination) + } + + /// Return the next hop for a specific IPv6 address, if there is one. + pub fn find_next_hop(&mut self, destination: Ipv6Address) -> Option<Ipv6Address> { + self.relations.iter().find_map(|r| { + if r.destination == destination { + Some(r.next_hop) + } else { + None + } + }) + } + + /// Purge expired relations. + pub fn purge(&mut self, now: Instant) { + self.relations.retain(|r| r.expiration > now) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::time::Duration; + + fn addresses(count: usize) -> Vec<Ipv6Address> { + (0..count) + .map(|i| { + let mut ip = Ipv6Address::default(); + ip.0[0] = i as u8; + ip + }) + .collect() + } + + #[test] + fn add_relation() { + let addrs = addresses(2); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now()); + assert_eq!(relations.relations.len(), 1); + } + + #[test] + fn add_relations_full_buffer() { + let addrs = addresses(crate::config::RPL_RELATIONS_BUFFER_COUNT + 1); + + // Try to add RPL_RELATIONS_BUFFER_COUNT + 1 to the buffer. + // The size of the buffer should still be RPL_RELATIONS_BUFFER_COUNT. + let mut relations = Relations::default(); + for a in addrs { + relations.add_relation(a, a, Instant::now()); + } + + assert_eq!(relations.relations.len(), RPL_RELATIONS_BUFFER_COUNT); + } + + #[test] + fn update_relation() { + let addrs = addresses(3); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now()); + assert_eq!(relations.relations.len(), 1); + + relations.add_relation(addrs[0], addrs[2], Instant::now()); + assert_eq!(relations.relations.len(), 1); + + assert_eq!(relations.find_next_hop(addrs[0]), Some(addrs[2])); + } + + #[test] + fn find_next_hop() { + let addrs = addresses(3); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now()); + assert_eq!(relations.relations.len(), 1); + assert_eq!(relations.find_next_hop(addrs[0]), Some(addrs[1])); + + relations.add_relation(addrs[0], addrs[2], Instant::now()); + assert_eq!(relations.relations.len(), 1); + assert_eq!(relations.find_next_hop(addrs[0]), Some(addrs[2])); + + // Find the next hop of a destination not in the buffer. + assert_eq!(relations.find_next_hop(addrs[1]), None); + } + + #[test] + fn remove_relation() { + let addrs = addresses(2); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now()); + assert_eq!(relations.relations.len(), 1); + + relations.remove_relation(addrs[0]); + assert!(relations.relations.is_empty()); + } + + #[test] + fn purge_relation() { + let addrs = addresses(2); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now() - Duration::from_secs(1)); + + assert_eq!(relations.relations.len(), 1); + + relations.purge(Instant::now()); + assert!(relations.relations.is_empty()); + } +} diff --git a/src/iface/rpl/trickle.rs b/src/iface/rpl/trickle.rs new file mode 100644 index 0000000..a5b3b97 --- /dev/null +++ b/src/iface/rpl/trickle.rs @@ -0,0 +1,266 @@ +//! Implementation of the Trickle timer defined in [RFC 6206]. The algorithm allows node in a lossy +//! shared medium to exchange information in a highly robust, energy efficient, simple, and +//! scalable manner. Dynamically adjusting transmission windows allows Trickle to spread new +//! information fast while sending only a few messages per hour when information does not change. +//! +//! **NOTE**: the constants used for the default Trickle timer are the ones from the [Enhanced +//! Trickle]. +//! +//! [RFC 6206]: https://datatracker.ietf.org/doc/html/rfc6206 +//! [Enhanced Trickle]: https://d1wqtxts1xzle7.cloudfront.net/71402623/E-Trickle_Enhanced_Trickle_Algorithm_for20211005-2078-1ckh34a.pdf?1633439582=&response-content-disposition=inline%3B+filename%3DE_Trickle_Enhanced_Trickle_Algorithm_for.pdf&Expires=1681472005&Signature=cC7l-Pyr5r64XBNCDeSJ2ha6oqWUtO6A-KlDOyC0UVaHxDV3h3FuVHRtcNp3O9BUfRK8jeuWCYGBkCZgQT4Zgb6XwgVB-3z4TF9o3qBRMteRyYO5vjVkpPBeN7mz4Tl746SsSCHDm2NMtr7UVtLYamriU3D0rryoqLqJXmnkNoJpn~~wJe2H5PmPgIwixTwSvDkfFLSVoESaYS9ZWHZwbW-7G7OxIw8oSYhx9xMBnzkpdmT7sJNmvDzTUhoOjYrHTRM23cLVS9~oOSpT7hKtKD4h5CSmrNK4st07KnT9~tUqEcvGO3aXdd4quRZeKUcCkCbTLvhOEYg9~QqgD8xwhA__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA + +use crate::{ + rand::Rand, + time::{Duration, Instant}, +}; + +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) struct TrickleTimer { + i_min: u32, + i_max: u32, + k: usize, + + i: Duration, + t: Duration, + t_exp: Instant, + i_exp: Instant, + counter: usize, +} + +impl TrickleTimer { + /// Creat a new Trickle timer using the default values. + /// + /// **NOTE**: the standard defines I as a random value between [Imin, Imax]. However, this + /// could result in a t value that is very close to Imax. Therefore, sending DIO messages will + /// be sporadic, which is not ideal when a network is started. It might take a long time before + /// the network is actually stable. Therefore, we don't draw a random numberm but just use Imin + /// for I. This only affects the start of the RPL tree and speeds up building it. Also, we + /// don't use the default values from the standard, but the values from the _Enhanced Trickle + /// Algorithm for Low-Power and Lossy Networks_ from Baraq Ghaleb et al. This is also what the + /// Contiki Trickle timer does. + pub(crate) fn default(now: Instant, rand: &mut Rand) -> Self { + use super::consts::{ + DEFAULT_DIO_INTERVAL_DOUBLINGS, DEFAULT_DIO_INTERVAL_MIN, + DEFAULT_DIO_REDUNDANCY_CONSTANT, + }; + + Self::new( + DEFAULT_DIO_INTERVAL_MIN, + DEFAULT_DIO_INTERVAL_MIN + DEFAULT_DIO_INTERVAL_DOUBLINGS, + DEFAULT_DIO_REDUNDANCY_CONSTANT, + now, + rand, + ) + } + + /// Create a new Trickle timer. + pub(crate) fn new(i_min: u32, i_max: u32, k: usize, now: Instant, rand: &mut Rand) -> Self { + let mut timer = Self { + i_min, + i_max, + k, + i: Duration::ZERO, + t: Duration::ZERO, + t_exp: Instant::ZERO, + i_exp: Instant::ZERO, + counter: 0, + }; + + timer.i = Duration::from_millis(2u32.pow(timer.i_min) as u64); + timer.i_exp = now + timer.i; + timer.counter = 0; + + timer.set_t(now, rand); + + timer + } + + /// Poll the Trickle timer. Returns `true` when the Trickle timer signals that a message can be + /// transmitted. This happens when the Trickle timer expires. + pub(crate) fn poll(&mut self, now: Instant, rand: &mut Rand) -> bool { + let can_transmit = self.can_transmit() && self.t_expired(now); + + if can_transmit { + self.set_t(now, rand); + } + + if self.i_expired(now) { + self.expire(now, rand); + } + + can_transmit + } + + /// Returns the Instant at which the Trickle timer should be polled again. Polling the Trickle + /// timer before this Instant is not harmfull, however, polling after it is not correct. + pub(crate) fn poll_at(&self) -> Instant { + self.t_exp.min(self.i_exp) + } + + /// Signal the Trickle timer that a consistency has been heard, and thus increasing it's + /// counter. + pub(crate) fn hear_consistent(&mut self) { + self.counter += 1; + } + + /// Signal the Trickle timer that an inconsistency has been heard. This resets the Trickle + /// timer when the current interval is not the smallest possible. + pub(crate) fn hear_inconsistency(&mut self, now: Instant, rand: &mut Rand) { + let i = Duration::from_millis(2u32.pow(self.i_min) as u64); + if self.i > i { + self.reset(i, now, rand); + } + } + + /// Check if the Trickle timer can transmit or not. Returns `false` when the consistency + /// counter is bigger or equal to the default consistency constant. + pub(crate) fn can_transmit(&self) -> bool { + self.k != 0 && self.counter < self.k + } + + /// Reset the Trickle timer when the interval has expired. + fn expire(&mut self, now: Instant, rand: &mut Rand) { + let max_interval = Duration::from_millis(2u32.pow(self.i_max) as u64); + let i = if self.i >= max_interval { + max_interval + } else { + self.i + self.i + }; + + self.reset(i, now, rand); + } + + pub(crate) fn reset(&mut self, i: Duration, now: Instant, rand: &mut Rand) { + self.i = i; + self.i_exp = now + self.i; + self.counter = 0; + self.set_t(now, rand); + } + + pub(crate) const fn max_expiration(&self) -> Duration { + Duration::from_millis(2u32.pow(self.i_max) as u64) + } + + pub(crate) const fn min_expiration(&self) -> Duration { + Duration::from_millis(2u32.pow(self.i_min) as u64) + } + + fn set_t(&mut self, now: Instant, rand: &mut Rand) { + let t = Duration::from_micros( + self.i.total_micros() / 2 + + (rand.rand_u32() as u64 + % (self.i.total_micros() - self.i.total_micros() / 2 + 1)), + ); + + self.t = t; + self.t_exp = now + t; + } + + fn t_expired(&self, now: Instant) -> bool { + now >= self.t_exp + } + + fn i_expired(&self, now: Instant) -> bool { + now >= self.i_exp + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn trickle_timer_intervals() { + let mut rand = Rand::new(1234); + let mut now = Instant::ZERO; + let mut trickle = TrickleTimer::default(now, &mut rand); + + let mut previous_i = trickle.i; + + while now <= Instant::from_secs(100_000) { + trickle.poll(now, &mut rand); + + if now < Instant::ZERO + trickle.max_expiration() { + // t should always be inbetween I/2 and I. + assert!(trickle.i / 2 < trickle.t); + assert!(trickle.i > trickle.t); + } + + if previous_i != trickle.i { + // When a new Interval is selected, this should be double the previous one. + assert_eq!(previous_i * 2, trickle.i); + assert_eq!(trickle.counter, 0); + previous_i = trickle.i; + } + + now += Duration::from_millis(100); + } + } + + #[test] + fn trickle_timer_hear_inconsistency() { + let mut rand = Rand::new(1234); + let mut now = Instant::ZERO; + let mut trickle = TrickleTimer::default(now, &mut rand); + + trickle.counter = 1; + + while now <= Instant::from_secs(10_000) { + trickle.poll(now, &mut rand); + + if now < trickle.i_exp && now < Instant::ZERO + trickle.min_expiration() { + assert_eq!(trickle.counter, 1); + } else { + // The first interval expired, so the counter is reset. + assert_eq!(trickle.counter, 0); + } + + if now == Instant::from_secs(10) { + // We set the counter to 1 such that we can test the `hear_inconsistency`. + trickle.counter = 1; + + assert_eq!(trickle.counter, 1); + + trickle.hear_inconsistency(now, &mut rand); + + assert_eq!(trickle.counter, 0); + assert_eq!(trickle.i, trickle.min_expiration()); + } + + now += Duration::from_millis(100); + } + } + + #[test] + fn trickle_timer_hear_consistency() { + let mut rand = Rand::new(1234); + let mut now = Instant::ZERO; + let mut trickle = TrickleTimer::default(now, &mut rand); + + trickle.counter = 1; + + let mut transmit_counter = 0; + + while now <= Instant::from_secs(10_000) { + trickle.hear_consistent(); + + if trickle.poll(now, &mut rand) { + transmit_counter += 1; + } + + if now == Instant::from_secs(10_000) { + use super::super::consts::{ + DEFAULT_DIO_INTERVAL_DOUBLINGS, DEFAULT_DIO_REDUNDANCY_CONSTANT, + }; + assert!(!trickle.poll(now, &mut rand)); + assert!(trickle.counter > DEFAULT_DIO_REDUNDANCY_CONSTANT); + // We should never have transmitted since the counter was higher than the default + // redundancy constant. + assert_eq!(transmit_counter, 0); + } + + now += Duration::from_millis(100); + } + } +} diff --git a/src/iface/socket_meta.rs b/src/iface/socket_meta.rs new file mode 100644 index 0000000..82c9908 --- /dev/null +++ b/src/iface/socket_meta.rs @@ -0,0 +1,103 @@ +use super::SocketHandle; +use crate::{ + socket::PollAt, + time::{Duration, Instant}, + wire::IpAddress, +}; + +/// Neighbor dependency. +/// +/// This enum tracks whether the socket should be polled based on the neighbor +/// it is going to send packets to. +#[derive(Debug, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +enum NeighborState { + /// Socket can be polled immediately. + #[default] + Active, + /// Socket should not be polled until either `silent_until` passes or + /// `neighbor` appears in the neighbor cache. + Waiting { + neighbor: IpAddress, + silent_until: Instant, + }, +} + +/// Network socket metadata. +/// +/// This includes things that only external (to the socket, that is) code +/// is interested in, but which are more conveniently stored inside the socket +/// itself. +#[derive(Debug, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) struct Meta { + /// Handle of this socket within its enclosing `SocketSet`. + /// Mainly useful for debug output. + pub(crate) handle: SocketHandle, + /// See [NeighborState](struct.NeighborState.html). + neighbor_state: NeighborState, +} + +impl Meta { + /// Minimum delay between neighbor discovery requests for this particular + /// socket, in milliseconds. + /// + /// See also `iface::NeighborCache::SILENT_TIME`. + pub(crate) const DISCOVERY_SILENT_TIME: Duration = Duration::from_millis(1_000); + + pub(crate) fn poll_at<F>(&self, socket_poll_at: PollAt, has_neighbor: F) -> PollAt + where + F: Fn(IpAddress) -> bool, + { + match self.neighbor_state { + NeighborState::Active => socket_poll_at, + NeighborState::Waiting { neighbor, .. } if has_neighbor(neighbor) => socket_poll_at, + NeighborState::Waiting { silent_until, .. } => PollAt::Time(silent_until), + } + } + + pub(crate) fn egress_permitted<F>(&mut self, timestamp: Instant, has_neighbor: F) -> bool + where + F: Fn(IpAddress) -> bool, + { + match self.neighbor_state { + NeighborState::Active => true, + NeighborState::Waiting { + neighbor, + silent_until, + } => { + if has_neighbor(neighbor) { + net_trace!( + "{}: neighbor {} discovered, unsilencing", + self.handle, + neighbor + ); + self.neighbor_state = NeighborState::Active; + true + } else if timestamp >= silent_until { + net_trace!( + "{}: neighbor {} silence timer expired, rediscovering", + self.handle, + neighbor + ); + true + } else { + false + } + } + } + } + + pub(crate) fn neighbor_missing(&mut self, timestamp: Instant, neighbor: IpAddress) { + net_trace!( + "{}: neighbor {} missing, silencing until t+{}", + self.handle, + neighbor, + Self::DISCOVERY_SILENT_TIME + ); + self.neighbor_state = NeighborState::Waiting { + neighbor, + silent_until: timestamp + Self::DISCOVERY_SILENT_TIME, + }; + } +} diff --git a/src/iface/socket_set.rs b/src/iface/socket_set.rs new file mode 100644 index 0000000..be55fef --- /dev/null +++ b/src/iface/socket_set.rs @@ -0,0 +1,151 @@ +use core::fmt; +use managed::ManagedSlice; + +use super::socket_meta::Meta; +use crate::socket::{AnySocket, Socket}; + +/// Opaque struct with space for storing one socket. +/// +/// This is public so you can use it to allocate space for storing +/// sockets when creating an Interface. +#[derive(Debug, Default)] +pub struct SocketStorage<'a> { + inner: Option<Item<'a>>, +} + +impl<'a> SocketStorage<'a> { + pub const EMPTY: Self = Self { inner: None }; +} + +/// An item of a socket set. +#[derive(Debug)] +pub(crate) struct Item<'a> { + pub(crate) meta: Meta, + pub(crate) socket: Socket<'a>, +} + +/// A handle, identifying a socket in an Interface. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Hash)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct SocketHandle(usize); + +impl fmt::Display for SocketHandle { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "#{}", self.0) + } +} + +/// An extensible set of sockets. +/// +/// The lifetime `'a` is used when storing a `Socket<'a>`. If you're using +/// owned buffers for your sockets (passed in as `Vec`s) you can use +/// `SocketSet<'static>`. +#[derive(Debug)] +pub struct SocketSet<'a> { + sockets: ManagedSlice<'a, SocketStorage<'a>>, +} + +impl<'a> SocketSet<'a> { + /// Create a socket set using the provided storage. + pub fn new<SocketsT>(sockets: SocketsT) -> SocketSet<'a> + where + SocketsT: Into<ManagedSlice<'a, SocketStorage<'a>>>, + { + let sockets = sockets.into(); + SocketSet { sockets } + } + + /// Add a socket to the set, and return its handle. + /// + /// # Panics + /// This function panics if the storage is fixed-size (not a `Vec`) and is full. + pub fn add<T: AnySocket<'a>>(&mut self, socket: T) -> SocketHandle { + fn put<'a>(index: usize, slot: &mut SocketStorage<'a>, socket: Socket<'a>) -> SocketHandle { + net_trace!("[{}]: adding", index); + let handle = SocketHandle(index); + let mut meta = Meta::default(); + meta.handle = handle; + *slot = SocketStorage { + inner: Some(Item { meta, socket }), + }; + handle + } + + let socket = socket.upcast(); + + for (index, slot) in self.sockets.iter_mut().enumerate() { + if slot.inner.is_none() { + return put(index, slot, socket); + } + } + + match &mut self.sockets { + ManagedSlice::Borrowed(_) => panic!("adding a socket to a full SocketSet"), + #[cfg(feature = "alloc")] + ManagedSlice::Owned(sockets) => { + sockets.push(SocketStorage { inner: None }); + let index = sockets.len() - 1; + put(index, &mut sockets[index], socket) + } + } + } + + /// Get a socket from the set by its handle, as mutable. + /// + /// # Panics + /// This function may panic if the handle does not belong to this socket set + /// or the socket has the wrong type. + pub fn get<T: AnySocket<'a>>(&self, handle: SocketHandle) -> &T { + match self.sockets[handle.0].inner.as_ref() { + Some(item) => { + T::downcast(&item.socket).expect("handle refers to a socket of a wrong type") + } + None => panic!("handle does not refer to a valid socket"), + } + } + + /// Get a mutable socket from the set by its handle, as mutable. + /// + /// # Panics + /// This function may panic if the handle does not belong to this socket set + /// or the socket has the wrong type. + pub fn get_mut<T: AnySocket<'a>>(&mut self, handle: SocketHandle) -> &mut T { + match self.sockets[handle.0].inner.as_mut() { + Some(item) => T::downcast_mut(&mut item.socket) + .expect("handle refers to a socket of a wrong type"), + None => panic!("handle does not refer to a valid socket"), + } + } + + /// Remove a socket from the set, without changing its state. + /// + /// # Panics + /// This function may panic if the handle does not belong to this socket set. + pub fn remove(&mut self, handle: SocketHandle) -> Socket<'a> { + net_trace!("[{}]: removing", handle.0); + match self.sockets[handle.0].inner.take() { + Some(item) => item.socket, + None => panic!("handle does not refer to a valid socket"), + } + } + + /// Get an iterator to the inner sockets. + pub fn iter(&self) -> impl Iterator<Item = (SocketHandle, &Socket<'a>)> { + self.items().map(|i| (i.meta.handle, &i.socket)) + } + + /// Get a mutable iterator to the inner sockets. + pub fn iter_mut(&mut self) -> impl Iterator<Item = (SocketHandle, &mut Socket<'a>)> { + self.items_mut().map(|i| (i.meta.handle, &mut i.socket)) + } + + /// Iterate every socket in this set. + pub(crate) fn items(&self) -> impl Iterator<Item = &Item<'a>> + '_ { + self.sockets.iter().filter_map(|x| x.inner.as_ref()) + } + + /// Iterate every socket in this set. + pub(crate) fn items_mut(&mut self) -> impl Iterator<Item = &mut Item<'a>> + '_ { + self.sockets.iter_mut().filter_map(|x| x.inner.as_mut()) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..040ff57 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,180 @@ +#![cfg_attr(not(any(test, feature = "std")), no_std)] +#![deny(unsafe_code)] + +//! The _smoltcp_ library is built in a layered structure, with the layers corresponding +//! to the levels of API abstraction. Only the highest layers would be used by a typical +//! application; however, the goal of _smoltcp_ is not just to provide a simple interface +//! for writing applications but also to be a toolbox of networking primitives, so +//! every layer is fully exposed and documented. +//! +//! When discussing networking stacks and layering, often the [OSI model][osi] is invoked. +//! _smoltcp_ makes no effort to conform to the OSI model as it is not applicable to TCP/IP. +//! +//! # The socket layer +//! The socket layer APIs are provided in the module [socket](socket/index.html); currently, +//! raw, ICMP, TCP, and UDP sockets are provided. The socket API provides the usual primitives, +//! but necessarily differs in many from the [Berkeley socket API][berk], as the latter was +//! not designed to be used without heap allocation. +//! +//! The socket layer provides the buffering, packet construction and validation, and (for +//! stateful sockets) the state machines, but it is interface-agnostic. An application must +//! use sockets together with a network interface. +//! +//! # The interface layer +//! The interface layer APIs are provided in the module [iface](iface/index.html); currently, +//! Ethernet interface is provided. +//! +//! The interface layer handles the control messages, physical addressing and neighbor discovery. +//! It routes packets to and from sockets. +//! +//! # The physical layer +//! The physical layer APIs are provided in the module [phy](phy/index.html); currently, +//! raw socket and TAP interface are provided. In addition, two _middleware_ interfaces +//! are provided: the _tracer device_, which prints a human-readable representation of packets, +//! and the _fault injector device_, which randomly introduces errors into the transmitted +//! and received packet sequences. +//! +//! The physical layer handles interaction with a platform-specific network device. +//! +//! # The wire layers +//! Unlike the higher layers, the wire layer APIs will not be used by a typical application. +//! They however are the bedrock of _smoltcp_, and everything else is built on top of them. +//! +//! The wire layer APIs are designed by the principle "make illegal states ir-representable". +//! If a wire layer object can be constructed, then it can also be parsed from or emitted to +//! a lower level. +//! +//! The wire layer APIs also provide _tcpdump_-like pretty printing. +//! +//! ## The representation layer +//! The representation layer APIs are provided in the module [wire]. +//! +//! The representation layer exists to reduce the state space of raw packets. Raw packets +//! may be nonsensical in a multitude of ways: invalid checksums, impossible combinations of flags, +//! pointers to fields out of bounds, meaningless options... Representations shed all that, +//! as well as any features not supported by _smoltcp_. +//! +//! ## The packet layer +//! The packet layer APIs are also provided in the module [wire]. +//! +//! The packet layer exists to provide a more structured way to work with packets than +//! treating them as sequences of octets. It makes no judgement as to content of the packets, +//! except where necessary to provide safe access to fields, and strives to implement every +//! feature ever defined, to ensure that, when the representation layer is unable to make sense +//! of a packet, it is still logged correctly and in full. +//! +//! # Minimum Supported Rust Version (MSRV) +//! +//! This crate is guaranteed to compile on stable Rust 1.65 and up with any valid set of features. +//! It *might* compile on older versions but that may change in any new patch release. +//! +//! The exception is when using the `defmt` feature, in which case `defmt`'s MSRV applies, which +//! is higher. +//! +//! [wire]: wire/index.html +//! [osi]: https://en.wikipedia.org/wiki/OSI_model +//! [berk]: https://en.wikipedia.org/wiki/Berkeley_sockets + +/* XXX compiler bug +#![cfg(not(any(feature = "socket-raw", + feature = "socket-udp", + feature = "socket-tcp")))] +compile_error!("at least one socket needs to be enabled"); */ + +#![allow(clippy::match_like_matches_macro)] +#![allow(clippy::redundant_field_names)] +#![allow(clippy::identity_op)] +#![allow(clippy::option_map_unit_fn)] +#![allow(clippy::unit_arg)] +#![allow(clippy::new_without_default)] + +#[cfg(feature = "alloc")] +extern crate alloc; + +#[cfg(not(any( + feature = "proto-ipv4", + feature = "proto-ipv6", + feature = "proto-sixlowpan" +)))] +compile_error!("You must enable at least one of the following features: proto-ipv4, proto-ipv6, proto-sixlowpan"); + +#[cfg(all( + feature = "socket", + not(any( + feature = "socket-raw", + feature = "socket-udp", + feature = "socket-tcp", + feature = "socket-icmp", + feature = "socket-dhcpv4", + feature = "socket-dns", + )) +))] +compile_error!("If you enable the socket feature, you must enable at least one of the following features: socket-raw, socket-udp, socket-tcp, socket-icmp, socket-dhcpv4, socket-dns"); + +#[cfg(all( + feature = "socket", + not(any( + feature = "medium-ethernet", + feature = "medium-ip", + feature = "medium-ieee802154", + )) +))] +compile_error!("If you enable the socket feature, you must enable at least one of the following features: medium-ip, medium-ethernet, medium-ieee802154"); + +#[cfg(all(feature = "defmt", feature = "log"))] +compile_error!("You must enable at most one of the following features: defmt, log"); + +#[macro_use] +mod macros; +mod parsers; +mod rand; + +#[cfg(test)] +pub mod config { + #![allow(unused)] + pub const ASSEMBLER_MAX_SEGMENT_COUNT: usize = 4; + pub const DNS_MAX_NAME_SIZE: usize = 255; + pub const DNS_MAX_RESULT_COUNT: usize = 1; + pub const DNS_MAX_SERVER_COUNT: usize = 1; + pub const FRAGMENTATION_BUFFER_SIZE: usize = 1500; + pub const IFACE_MAX_ADDR_COUNT: usize = 8; + pub const IFACE_MAX_MULTICAST_GROUP_COUNT: usize = 4; + pub const IFACE_MAX_ROUTE_COUNT: usize = 4; + pub const IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT: usize = 4; + pub const IFACE_NEIGHBOR_CACHE_COUNT: usize = 3; + pub const REASSEMBLY_BUFFER_COUNT: usize = 4; + pub const REASSEMBLY_BUFFER_SIZE: usize = 1500; + pub const RPL_RELATIONS_BUFFER_COUNT: usize = 16; + pub const RPL_PARENTS_BUFFER_COUNT: usize = 8; + pub const IPV6_HBH_MAX_OPTIONS: usize = 2; +} + +#[cfg(not(test))] +pub mod config { + #![allow(unused)] + include!(concat!(env!("OUT_DIR"), "/config.rs")); +} + +#[cfg(any( + feature = "medium-ethernet", + feature = "medium-ip", + feature = "medium-ieee802154" +))] +pub mod iface; + +pub mod phy; +#[cfg(feature = "socket")] +pub mod socket; +pub mod storage; +pub mod time; +pub mod wire; + +#[cfg(all( + test, + any( + feature = "medium-ethernet", + feature = "medium-ip", + feature = "medium-ieee802154" + ) +))] +mod tests; diff --git a/src/macros.rs b/src/macros.rs new file mode 100644 index 0000000..e899d24 --- /dev/null +++ b/src/macros.rs @@ -0,0 +1,169 @@ +#[cfg(not(test))] +#[cfg(feature = "log")] +macro_rules! net_log { + (trace, $($arg:expr),*) => { log::trace!($($arg),*) }; + (debug, $($arg:expr),*) => { log::debug!($($arg),*) }; +} + +#[cfg(test)] +#[cfg(feature = "log")] +macro_rules! net_log { + (trace, $($arg:expr),*) => { println!($($arg),*) }; + (debug, $($arg:expr),*) => { println!($($arg),*) }; +} + +#[cfg(feature = "defmt")] +macro_rules! net_log { + (trace, $($arg:expr),*) => { defmt::trace!($($arg),*) }; + (debug, $($arg:expr),*) => { defmt::debug!($($arg),*) }; +} + +#[cfg(not(any(feature = "log", feature = "defmt")))] +macro_rules! net_log { + ($level:ident, $($arg:expr),*) => {{ $( let _ = $arg; )* }} +} + +macro_rules! net_trace { + ($($arg:expr),*) => (net_log!(trace, $($arg),*)); +} + +macro_rules! net_debug { + ($($arg:expr),*) => (net_log!(debug, $($arg),*)); +} + +macro_rules! enum_with_unknown { + ( + $( #[$enum_attr:meta] )* + pub enum $name:ident($ty:ty) { + $( + $( #[$variant_attr:meta] )* + $variant:ident = $value:expr + ),+ $(,)? + } + ) => { + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + $( #[$enum_attr] )* + pub enum $name { + $( + $( #[$variant_attr] )* + $variant + ),*, + Unknown($ty) + } + + impl ::core::convert::From<$ty> for $name { + fn from(value: $ty) -> Self { + match value { + $( $value => $name::$variant ),*, + other => $name::Unknown(other) + } + } + } + + impl ::core::convert::From<$name> for $ty { + fn from(value: $name) -> Self { + match value { + $( $name::$variant => $value ),*, + $name::Unknown(other) => other + } + } + } + } +} + +#[cfg(feature = "proto-rpl")] +macro_rules! get { + ($buffer:expr, into: $into:ty, fun: $fun:ident, field: $field:expr $(,)?) => { + { + <$into>::$fun(&$buffer.as_ref()[$field]) + } + }; + + ($buffer:expr, into: $into:ty, field: $field:expr $(,)?) => { + get!($buffer, into: $into, field: $field, shift: 0, mask: 0b1111_1111) + }; + + ($buffer:expr, into: $into:ty, field: $field:expr, mask: $bit_mask:expr $(,)?) => { + get!($buffer, into: $into, field: $field, shift: 0, mask: $bit_mask) + }; + + ($buffer:expr, into: $into:ty, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) => { + { + <$into>::from((&$buffer.as_ref()[$field] >> $bit_shift) & $bit_mask) + } + }; + + ($buffer:expr, field: $field:expr $(,)?) => { + get!($buffer, field: $field, shift: 0, mask: 0b1111_1111) + }; + + ($buffer:expr, field: $field:expr, mask: $bit_mask:expr $(,)?) => { + get!($buffer, field: $field, shift: 0, mask: $bit_mask) + }; + + ($buffer:expr, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) + => + { + { + (&$buffer.as_ref()[$field] >> $bit_shift) & $bit_mask + } + }; + + ($buffer:expr, u16, field: $field:expr $(,)?) => { + { + NetworkEndian::read_u16(&$buffer.as_ref()[$field]) + } + }; + + ($buffer:expr, bool, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) => { + { + (($buffer.as_ref()[$field] >> $bit_shift) & $bit_mask) == 0b1 + } + }; + + ($buffer:expr, u32, field: $field:expr $(,)?) => { + { + NetworkEndian::read_u32(&$buffer.as_ref()[$field]) + } + }; +} + +#[cfg(feature = "proto-rpl")] +macro_rules! set { + ($buffer:expr, address: $address:ident, field: $field:expr $(,)?) => {{ + $buffer.as_mut()[$field].copy_from_slice($address.as_bytes()); + }}; + + ($buffer:expr, $value:ident, field: $field:expr $(,)?) => { + set!($buffer, $value, field: $field, shift: 0, mask: 0b1111_1111) + }; + + ($buffer:expr, $value:ident, field: $field:expr, mask: $bit_mask:expr $(,)?) => { + set!($buffer, $value, field: $field, shift: 0, mask: $bit_mask) + }; + + ($buffer:expr, $value:ident, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) => {{ + let raw = + ($buffer.as_ref()[$field] & !($bit_mask << $bit_shift)) | ($value << $bit_shift); + $buffer.as_mut()[$field] = raw; + }}; + + ($buffer:expr, $value:ident, bool, field: $field:expr, mask: $bit_mask:expr $(,)?) => { + set!($buffer, $value, bool, field: $field, shift: 0, mask: $bit_mask); + }; + + ($buffer:expr, $value:ident, bool, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) => {{ + let raw = ($buffer.as_ref()[$field] & !($bit_mask << $bit_shift)) + | (if $value { 0b1 } else { 0b0 } << $bit_shift); + $buffer.as_mut()[$field] = raw; + }}; + + ($buffer:expr, $value:ident, u16, field: $field:expr $(,)?) => {{ + NetworkEndian::write_u16(&mut $buffer.as_mut()[$field], $value); + }}; + + ($buffer:expr, $value:ident, u32, field: $field:expr $(,)?) => {{ + NetworkEndian::write_u32(&mut $buffer.as_mut()[$field], $value); + }}; +} diff --git a/src/parsers.rs b/src/parsers.rs new file mode 100644 index 0000000..16419ab --- /dev/null +++ b/src/parsers.rs @@ -0,0 +1,765 @@ +#![cfg_attr( + not(all(feature = "proto-ipv6", feature = "proto-ipv4")), + allow(dead_code) +)] + +use core::result; +use core::str::FromStr; + +#[cfg(feature = "medium-ethernet")] +use crate::wire::EthernetAddress; +use crate::wire::{IpAddress, IpCidr, IpEndpoint}; +#[cfg(feature = "proto-ipv4")] +use crate::wire::{Ipv4Address, Ipv4Cidr}; +#[cfg(feature = "proto-ipv6")] +use crate::wire::{Ipv6Address, Ipv6Cidr}; + +type Result<T> = result::Result<T, ()>; + +struct Parser<'a> { + data: &'a [u8], + pos: usize, +} + +impl<'a> Parser<'a> { + fn new(data: &'a str) -> Parser<'a> { + Parser { + data: data.as_bytes(), + pos: 0, + } + } + + fn lookahead_char(&self, ch: u8) -> bool { + if self.pos < self.data.len() { + self.data[self.pos] == ch + } else { + false + } + } + + fn advance(&mut self) -> Result<u8> { + match self.data.get(self.pos) { + Some(&chr) => { + self.pos += 1; + Ok(chr) + } + None => Err(()), + } + } + + fn try_do<F, T>(&mut self, f: F) -> Option<T> + where + F: FnOnce(&mut Parser<'a>) -> Result<T>, + { + let pos = self.pos; + match f(self) { + Ok(res) => Some(res), + Err(()) => { + self.pos = pos; + None + } + } + } + + fn accept_eof(&mut self) -> Result<()> { + if self.data.len() == self.pos { + Ok(()) + } else { + Err(()) + } + } + + fn until_eof<F, T>(&mut self, f: F) -> Result<T> + where + F: FnOnce(&mut Parser<'a>) -> Result<T>, + { + let res = f(self)?; + self.accept_eof()?; + Ok(res) + } + + fn accept_char(&mut self, chr: u8) -> Result<()> { + if self.advance()? == chr { + Ok(()) + } else { + Err(()) + } + } + + fn accept_str(&mut self, string: &[u8]) -> Result<()> { + for byte in string.iter() { + self.accept_char(*byte)?; + } + Ok(()) + } + + fn accept_digit(&mut self, hex: bool) -> Result<u8> { + let digit = self.advance()?; + if digit.is_ascii_digit() { + Ok(digit - b'0') + } else if hex && (b'a'..=b'f').contains(&digit) { + Ok(digit - b'a' + 10) + } else if hex && (b'A'..=b'F').contains(&digit) { + Ok(digit - b'A' + 10) + } else { + Err(()) + } + } + + fn accept_number(&mut self, max_digits: usize, max_value: u32, hex: bool) -> Result<u32> { + let mut value = self.accept_digit(hex)? as u32; + for _ in 1..max_digits { + match self.try_do(|p| p.accept_digit(hex)) { + Some(digit) => { + value *= if hex { 16 } else { 10 }; + value += digit as u32; + } + None => break, + } + } + if value < max_value { + Ok(value) + } else { + Err(()) + } + } + + #[cfg(feature = "medium-ethernet")] + fn accept_mac_joined_with(&mut self, separator: u8) -> Result<EthernetAddress> { + let mut octets = [0u8; 6]; + for (n, octet) in octets.iter_mut().enumerate() { + *octet = self.accept_number(2, 0x100, true)? as u8; + if n != 5 { + self.accept_char(separator)?; + } + } + Ok(EthernetAddress(octets)) + } + + #[cfg(feature = "medium-ethernet")] + fn accept_mac(&mut self) -> Result<EthernetAddress> { + if let Some(mac) = self.try_do(|p| p.accept_mac_joined_with(b'-')) { + return Ok(mac); + } + if let Some(mac) = self.try_do(|p| p.accept_mac_joined_with(b':')) { + return Ok(mac); + } + Err(()) + } + + #[cfg(feature = "proto-ipv6")] + fn accept_ipv4_mapped_ipv6_part(&mut self, parts: &mut [u16], idx: &mut usize) -> Result<()> { + let octets = self.accept_ipv4_octets()?; + + parts[*idx] = ((octets[0] as u16) << 8) | (octets[1] as u16); + *idx += 1; + parts[*idx] = ((octets[2] as u16) << 8) | (octets[3] as u16); + *idx += 1; + + Ok(()) + } + + #[cfg(feature = "proto-ipv6")] + fn accept_ipv6_part( + &mut self, + (head, tail): (&mut [u16; 8], &mut [u16; 6]), + (head_idx, tail_idx): (&mut usize, &mut usize), + mut use_tail: bool, + ) -> Result<()> { + let double_colon = match self.try_do(|p| p.accept_str(b"::")) { + Some(_) if !use_tail && *head_idx < 7 => { + // Found a double colon. Start filling out the + // tail and set the double colon flag in case + // this is the last character we can parse. + use_tail = true; + true + } + Some(_) => { + // This is a bad address. Only one double colon is + // allowed and an address is only 128 bits. + return Err(()); + } + None => { + if *head_idx != 0 || use_tail && *tail_idx != 0 { + // If this is not the first number or the position following + // a double colon, we expect there to be a single colon. + self.accept_char(b':')?; + } + false + } + }; + + match self.try_do(|p| p.accept_number(4, 0x10000, true)) { + Some(part) if !use_tail && *head_idx < 8 => { + // Valid u16 to be added to the address + head[*head_idx] = part as u16; + *head_idx += 1; + + if *head_idx == 6 && head[0..*head_idx] == [0, 0, 0, 0, 0, 0xffff] { + self.try_do(|p| { + p.accept_char(b':')?; + p.accept_ipv4_mapped_ipv6_part(head, head_idx) + }); + } + Ok(()) + } + Some(part) if *tail_idx < 6 => { + // Valid u16 to be added to the address + tail[*tail_idx] = part as u16; + *tail_idx += 1; + + if *tail_idx == 1 && tail[0] == 0xffff && head[0..8] == [0, 0, 0, 0, 0, 0, 0, 0] { + self.try_do(|p| { + p.accept_char(b':')?; + p.accept_ipv4_mapped_ipv6_part(tail, tail_idx) + }); + } + Ok(()) + } + Some(_) => { + // Tail or head section is too long + Err(()) + } + None if double_colon => { + // The address ends with "::". E.g. 1234:: or :: + Ok(()) + } + None => { + // Invalid address + Err(()) + } + }?; + + if *head_idx + *tail_idx > 8 { + // The head and tail indexes add up to a bad address length. + Err(()) + } else if !self.lookahead_char(b':') { + if *head_idx < 8 && !use_tail { + // There was no double colon found, and the head is too short + return Err(()); + } + Ok(()) + } else { + // Continue recursing + self.accept_ipv6_part((head, tail), (head_idx, tail_idx), use_tail) + } + } + + #[cfg(feature = "proto-ipv6")] + fn accept_ipv6(&mut self) -> Result<Ipv6Address> { + // IPv6 addresses may contain a "::" to indicate a series of + // 16 bit sections that evaluate to 0. E.g. + // + // fe80:0000:0000:0000:0000:0000:0000:0001 + // + // May be written as + // + // fe80::1 + // + // As a result, we need to find the first section of colon + // delimited u16's before a possible "::", then the + // possible second section after the "::", and finally + // combine the second optional section to the end of the + // final address. + // + // See https://tools.ietf.org/html/rfc4291#section-2.2 + // for details. + let (mut addr, mut tail) = ([0u16; 8], [0u16; 6]); + let (mut head_idx, mut tail_idx) = (0, 0); + + self.accept_ipv6_part( + (&mut addr, &mut tail), + (&mut head_idx, &mut tail_idx), + false, + )?; + + // We need to copy the tail portion (the portion following the "::") to the + // end of the address. + addr[8 - tail_idx..].copy_from_slice(&tail[..tail_idx]); + + Ok(Ipv6Address::from_parts(&addr)) + } + + fn accept_ipv4_octets(&mut self) -> Result<[u8; 4]> { + let mut octets = [0u8; 4]; + for (n, octet) in octets.iter_mut().enumerate() { + *octet = self.accept_number(3, 0x100, false)? as u8; + if n != 3 { + self.accept_char(b'.')?; + } + } + Ok(octets) + } + + #[cfg(feature = "proto-ipv4")] + fn accept_ipv4(&mut self) -> Result<Ipv4Address> { + let octets = self.accept_ipv4_octets()?; + Ok(Ipv4Address(octets)) + } + + fn accept_ip(&mut self) -> Result<IpAddress> { + #[cfg(feature = "proto-ipv4")] + #[allow(clippy::single_match)] + match self.try_do(|p| p.accept_ipv4()) { + Some(ipv4) => return Ok(IpAddress::Ipv4(ipv4)), + None => (), + } + + #[cfg(feature = "proto-ipv6")] + #[allow(clippy::single_match)] + match self.try_do(|p| p.accept_ipv6()) { + Some(ipv6) => return Ok(IpAddress::Ipv6(ipv6)), + None => (), + } + + Err(()) + } + + #[cfg(feature = "proto-ipv4")] + fn accept_ipv4_endpoint(&mut self) -> Result<IpEndpoint> { + let ip = self.accept_ipv4()?; + + let port = if self.accept_eof().is_ok() { + 0 + } else { + self.accept_char(b':')?; + self.accept_number(5, 65535, false)? + }; + + Ok(IpEndpoint { + addr: IpAddress::Ipv4(ip), + port: port as u16, + }) + } + + #[cfg(feature = "proto-ipv6")] + fn accept_ipv6_endpoint(&mut self) -> Result<IpEndpoint> { + if self.lookahead_char(b'[') { + self.accept_char(b'[')?; + let ip = self.accept_ipv6()?; + self.accept_char(b']')?; + self.accept_char(b':')?; + let port = self.accept_number(5, 65535, false)?; + + Ok(IpEndpoint { + addr: IpAddress::Ipv6(ip), + port: port as u16, + }) + } else { + let ip = self.accept_ipv6()?; + Ok(IpEndpoint { + addr: IpAddress::Ipv6(ip), + port: 0, + }) + } + } + + fn accept_ip_endpoint(&mut self) -> Result<IpEndpoint> { + #[cfg(feature = "proto-ipv4")] + #[allow(clippy::single_match)] + match self.try_do(|p| p.accept_ipv4_endpoint()) { + Some(ipv4) => return Ok(ipv4), + None => (), + } + + #[cfg(feature = "proto-ipv6")] + #[allow(clippy::single_match)] + match self.try_do(|p| p.accept_ipv6_endpoint()) { + Some(ipv6) => return Ok(ipv6), + None => (), + } + + Err(()) + } +} + +#[cfg(feature = "medium-ethernet")] +impl FromStr for EthernetAddress { + type Err = (); + + /// Parse a string representation of an Ethernet address. + fn from_str(s: &str) -> Result<EthernetAddress> { + Parser::new(s).until_eof(|p| p.accept_mac()) + } +} + +#[cfg(feature = "proto-ipv4")] +impl FromStr for Ipv4Address { + type Err = (); + + /// Parse a string representation of an IPv4 address. + fn from_str(s: &str) -> Result<Ipv4Address> { + Parser::new(s).until_eof(|p| p.accept_ipv4()) + } +} + +#[cfg(feature = "proto-ipv6")] +impl FromStr for Ipv6Address { + type Err = (); + + /// Parse a string representation of an IPv6 address. + fn from_str(s: &str) -> Result<Ipv6Address> { + Parser::new(s).until_eof(|p| p.accept_ipv6()) + } +} + +impl FromStr for IpAddress { + type Err = (); + + /// Parse a string representation of an IP address. + fn from_str(s: &str) -> Result<IpAddress> { + Parser::new(s).until_eof(|p| p.accept_ip()) + } +} + +#[cfg(feature = "proto-ipv4")] +impl FromStr for Ipv4Cidr { + type Err = (); + + /// Parse a string representation of an IPv4 CIDR. + fn from_str(s: &str) -> Result<Ipv4Cidr> { + Parser::new(s).until_eof(|p| { + let ip = p.accept_ipv4()?; + p.accept_char(b'/')?; + let prefix_len = p.accept_number(2, 33, false)? as u8; + Ok(Ipv4Cidr::new(ip, prefix_len)) + }) + } +} + +#[cfg(feature = "proto-ipv6")] +impl FromStr for Ipv6Cidr { + type Err = (); + + /// Parse a string representation of an IPv6 CIDR. + fn from_str(s: &str) -> Result<Ipv6Cidr> { + // https://tools.ietf.org/html/rfc4291#section-2.3 + Parser::new(s).until_eof(|p| { + let ip = p.accept_ipv6()?; + p.accept_char(b'/')?; + let prefix_len = p.accept_number(3, 129, false)? as u8; + Ok(Ipv6Cidr::new(ip, prefix_len)) + }) + } +} + +impl FromStr for IpCidr { + type Err = (); + + /// Parse a string representation of an IP CIDR. + fn from_str(s: &str) -> Result<IpCidr> { + #[cfg(feature = "proto-ipv4")] + #[allow(clippy::single_match)] + match Ipv4Cidr::from_str(s) { + Ok(cidr) => return Ok(IpCidr::Ipv4(cidr)), + Err(_) => (), + } + + #[cfg(feature = "proto-ipv6")] + #[allow(clippy::single_match)] + match Ipv6Cidr::from_str(s) { + Ok(cidr) => return Ok(IpCidr::Ipv6(cidr)), + Err(_) => (), + } + + Err(()) + } +} + +impl FromStr for IpEndpoint { + type Err = (); + + fn from_str(s: &str) -> Result<IpEndpoint> { + Parser::new(s).until_eof(|p| p.accept_ip_endpoint()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + macro_rules! check_cidr_test_array { + ($tests:expr, $from_str:path, $variant:path) => { + for &(s, cidr) in &$tests { + assert_eq!($from_str(s), cidr); + assert_eq!(IpCidr::from_str(s), cidr.map($variant)); + + if let Ok(cidr) = cidr { + assert_eq!($from_str(&format!("{}", cidr)), Ok(cidr)); + assert_eq!(IpCidr::from_str(&format!("{}", cidr)), Ok($variant(cidr))); + } + } + }; + } + + #[test] + #[cfg(all(feature = "proto-ipv4", feature = "medium-ethernet"))] + fn test_mac() { + assert_eq!(EthernetAddress::from_str(""), Err(())); + assert_eq!( + EthernetAddress::from_str("02:00:00:00:00:00"), + Ok(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x00])) + ); + assert_eq!( + EthernetAddress::from_str("01:23:45:67:89:ab"), + Ok(EthernetAddress([0x01, 0x23, 0x45, 0x67, 0x89, 0xab])) + ); + assert_eq!( + EthernetAddress::from_str("cd:ef:10:00:00:00"), + Ok(EthernetAddress([0xcd, 0xef, 0x10, 0x00, 0x00, 0x00])) + ); + assert_eq!( + EthernetAddress::from_str("00:00:00:ab:cd:ef"), + Ok(EthernetAddress([0x00, 0x00, 0x00, 0xab, 0xcd, 0xef])) + ); + assert_eq!( + EthernetAddress::from_str("00-00-00-ab-cd-ef"), + Ok(EthernetAddress([0x00, 0x00, 0x00, 0xab, 0xcd, 0xef])) + ); + assert_eq!( + EthernetAddress::from_str("AB-CD-EF-00-00-00"), + Ok(EthernetAddress([0xab, 0xcd, 0xef, 0x00, 0x00, 0x00])) + ); + assert_eq!(EthernetAddress::from_str("100:00:00:00:00:00"), Err(())); + assert_eq!(EthernetAddress::from_str("002:00:00:00:00:00"), Err(())); + assert_eq!(EthernetAddress::from_str("02:00:00:00:00:000"), Err(())); + assert_eq!(EthernetAddress::from_str("02:00:00:00:00:0x"), Err(())); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_ipv4() { + assert_eq!(Ipv4Address::from_str(""), Err(())); + assert_eq!( + Ipv4Address::from_str("1.2.3.4"), + Ok(Ipv4Address([1, 2, 3, 4])) + ); + assert_eq!( + Ipv4Address::from_str("001.2.3.4"), + Ok(Ipv4Address([1, 2, 3, 4])) + ); + assert_eq!(Ipv4Address::from_str("0001.2.3.4"), Err(())); + assert_eq!(Ipv4Address::from_str("999.2.3.4"), Err(())); + assert_eq!(Ipv4Address::from_str("1.2.3.4.5"), Err(())); + assert_eq!(Ipv4Address::from_str("1.2.3"), Err(())); + assert_eq!(Ipv4Address::from_str("1.2.3."), Err(())); + assert_eq!(Ipv4Address::from_str("1.2.3.4."), Err(())); + } + + #[test] + #[cfg(feature = "proto-ipv6")] + fn test_ipv6() { + // Obviously not valid + assert_eq!(Ipv6Address::from_str(""), Err(())); + assert_eq!( + Ipv6Address::from_str("fe80:0:0:0:0:0:0:1"), + Ok(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)) + ); + assert_eq!(Ipv6Address::from_str("::1"), Ok(Ipv6Address::LOOPBACK)); + assert_eq!(Ipv6Address::from_str("::"), Ok(Ipv6Address::UNSPECIFIED)); + assert_eq!( + Ipv6Address::from_str("fe80::1"), + Ok(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)) + ); + assert_eq!( + Ipv6Address::from_str("1234:5678::"), + Ok(Ipv6Address::new(0x1234, 0x5678, 0, 0, 0, 0, 0, 0)) + ); + assert_eq!( + Ipv6Address::from_str("1234:5678::8765:4321"), + Ok(Ipv6Address::new(0x1234, 0x5678, 0, 0, 0, 0, 0x8765, 0x4321)) + ); + // Two double colons in address + assert_eq!(Ipv6Address::from_str("1234:5678::1::1"), Err(())); + assert_eq!( + Ipv6Address::from_str("4444:333:22:1::4"), + Ok(Ipv6Address::new(0x4444, 0x0333, 0x0022, 0x0001, 0, 0, 0, 4)) + ); + assert_eq!( + Ipv6Address::from_str("1:1:1:1:1:1::"), + Ok(Ipv6Address::new(1, 1, 1, 1, 1, 1, 0, 0)) + ); + assert_eq!( + Ipv6Address::from_str("::1:1:1:1:1:1"), + Ok(Ipv6Address::new(0, 0, 1, 1, 1, 1, 1, 1)) + ); + assert_eq!(Ipv6Address::from_str("::1:1:1:1:1:1:1"), Err(())); + // Double colon appears too late indicating an address that is too long + assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1:1::"), Err(())); + // Section after double colon is too long for a valid address + assert_eq!(Ipv6Address::from_str("::1:1:1:1:1:1:1"), Err(())); + // Obviously too long + assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1:1:1:1"), Err(())); + // Address is too short + assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1:1"), Err(())); + // Long number + assert_eq!(Ipv6Address::from_str("::000001"), Err(())); + // IPv4-Mapped address + assert_eq!( + Ipv6Address::from_str("::ffff:192.168.1.1"), + Ok(Ipv6Address([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1 + ])) + ); + assert_eq!( + Ipv6Address::from_str("0:0:0:0:0:ffff:192.168.1.1"), + Ok(Ipv6Address([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1 + ])) + ); + assert_eq!( + Ipv6Address::from_str("0::ffff:192.168.1.1"), + Ok(Ipv6Address([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1 + ])) + ); + // Only ffff is allowed in position 6 when IPv4 mapped + assert_eq!(Ipv6Address::from_str("0:0:0:0:0:eeee:192.168.1.1"), Err(())); + // Positions 1-5 must be 0 when IPv4 mapped + assert_eq!(Ipv6Address::from_str("0:0:0:0:1:ffff:192.168.1.1"), Err(())); + assert_eq!(Ipv6Address::from_str("1::ffff:192.168.1.1"), Err(())); + // Out of range ipv4 octet + assert_eq!(Ipv6Address::from_str("0:0:0:0:0:ffff:256.168.1.1"), Err(())); + // Invalid hex in ipv4 octet + assert_eq!(Ipv6Address::from_str("0:0:0:0:0:ffff:c0.168.1.1"), Err(())); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_ip_ipv4() { + assert_eq!(IpAddress::from_str(""), Err(())); + assert_eq!( + IpAddress::from_str("1.2.3.4"), + Ok(IpAddress::Ipv4(Ipv4Address([1, 2, 3, 4]))) + ); + assert_eq!(IpAddress::from_str("x"), Err(())); + } + + #[test] + #[cfg(feature = "proto-ipv6")] + fn test_ip_ipv6() { + assert_eq!(IpAddress::from_str(""), Err(())); + assert_eq!( + IpAddress::from_str("fe80::1"), + Ok(IpAddress::Ipv6(Ipv6Address::new( + 0xfe80, 0, 0, 0, 0, 0, 0, 1 + ))) + ); + assert_eq!(IpAddress::from_str("x"), Err(())); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_cidr_ipv4() { + let tests = [ + ( + "127.0.0.1/8", + Ok(Ipv4Cidr::new(Ipv4Address([127, 0, 0, 1]), 8u8)), + ), + ( + "192.168.1.1/24", + Ok(Ipv4Cidr::new(Ipv4Address([192, 168, 1, 1]), 24u8)), + ), + ( + "8.8.8.8/32", + Ok(Ipv4Cidr::new(Ipv4Address([8, 8, 8, 8]), 32u8)), + ), + ( + "8.8.8.8/0", + Ok(Ipv4Cidr::new(Ipv4Address([8, 8, 8, 8]), 0u8)), + ), + ("", Err(())), + ("1", Err(())), + ("127.0.0.1", Err(())), + ("127.0.0.1/", Err(())), + ("127.0.0.1/33", Err(())), + ("127.0.0.1/111", Err(())), + ("/32", Err(())), + ]; + + check_cidr_test_array!(tests, Ipv4Cidr::from_str, IpCidr::Ipv4); + } + + #[test] + #[cfg(feature = "proto-ipv6")] + fn test_cidr_ipv6() { + let tests = [ + ( + "fe80::1/64", + Ok(Ipv6Cidr::new( + Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), + 64u8, + )), + ), + ( + "fe80::/64", + Ok(Ipv6Cidr::new( + Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0), + 64u8, + )), + ), + ("::1/128", Ok(Ipv6Cidr::new(Ipv6Address::LOOPBACK, 128u8))), + ("::/128", Ok(Ipv6Cidr::new(Ipv6Address::UNSPECIFIED, 128u8))), + ( + "fe80:0:0:0:0:0:0:1/64", + Ok(Ipv6Cidr::new( + Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), + 64u8, + )), + ), + ("fe80:0:0:0:0:0:0:1|64", Err(())), + ("fe80::|64", Err(())), + ("fe80::1::/64", Err(())), + ]; + check_cidr_test_array!(tests, Ipv6Cidr::from_str, IpCidr::Ipv6); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_endpoint_ipv4() { + assert_eq!(IpEndpoint::from_str(""), Err(())); + assert_eq!(IpEndpoint::from_str("x"), Err(())); + assert_eq!( + IpEndpoint::from_str("127.0.0.1"), + Ok(IpEndpoint { + addr: IpAddress::v4(127, 0, 0, 1), + port: 0 + }) + ); + assert_eq!( + IpEndpoint::from_str("127.0.0.1:12345"), + Ok(IpEndpoint { + addr: IpAddress::v4(127, 0, 0, 1), + port: 12345 + }) + ); + } + + #[test] + #[cfg(feature = "proto-ipv6")] + fn test_endpoint_ipv6() { + assert_eq!(IpEndpoint::from_str(""), Err(())); + assert_eq!(IpEndpoint::from_str("x"), Err(())); + assert_eq!( + IpEndpoint::from_str("fe80::1"), + Ok(IpEndpoint { + addr: IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), + port: 0 + }) + ); + assert_eq!( + IpEndpoint::from_str("[fe80::1]:12345"), + Ok(IpEndpoint { + addr: IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), + port: 12345 + }) + ); + assert_eq!( + IpEndpoint::from_str("[::]:12345"), + Ok(IpEndpoint { + addr: IpAddress::v6(0, 0, 0, 0, 0, 0, 0, 0), + port: 12345 + }) + ); + } +} diff --git a/src/phy/fault_injector.rs b/src/phy/fault_injector.rs new file mode 100644 index 0000000..fffe11a --- /dev/null +++ b/src/phy/fault_injector.rs @@ -0,0 +1,330 @@ +use crate::phy::{self, Device, DeviceCapabilities}; +use crate::time::{Duration, Instant}; + +use super::PacketMeta; + +// We use our own RNG to stay compatible with #![no_std]. +// The use of the RNG below has a slight bias, but it doesn't matter. +fn xorshift32(state: &mut u32) -> u32 { + let mut x = *state; + x ^= x << 13; + x ^= x >> 17; + x ^= x << 5; + *state = x; + x +} + +// This could be fixed once associated consts are stable. +const MTU: usize = 1536; + +#[derive(Debug, Default, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct Config { + corrupt_pct: u8, + drop_pct: u8, + max_size: usize, + max_tx_rate: u64, + max_rx_rate: u64, + interval: Duration, +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct State { + rng_seed: u32, + refilled_at: Instant, + tx_bucket: u64, + rx_bucket: u64, +} + +impl State { + fn maybe(&mut self, pct: u8) -> bool { + xorshift32(&mut self.rng_seed) % 100 < pct as u32 + } + + fn corrupt<T: AsMut<[u8]>>(&mut self, mut buffer: T) { + let buffer = buffer.as_mut(); + // We introduce a single bitflip, as the most likely, and the hardest to detect, error. + let index = (xorshift32(&mut self.rng_seed) as usize) % buffer.len(); + let bit = 1 << (xorshift32(&mut self.rng_seed) % 8) as u8; + buffer[index] ^= bit; + } + + fn refill(&mut self, config: &Config, timestamp: Instant) { + if timestamp - self.refilled_at > config.interval { + self.tx_bucket = config.max_tx_rate; + self.rx_bucket = config.max_rx_rate; + self.refilled_at = timestamp; + } + } + + fn maybe_transmit(&mut self, config: &Config, timestamp: Instant) -> bool { + if config.max_tx_rate == 0 { + return true; + } + + self.refill(config, timestamp); + if self.tx_bucket > 0 { + self.tx_bucket -= 1; + true + } else { + false + } + } + + fn maybe_receive(&mut self, config: &Config, timestamp: Instant) -> bool { + if config.max_rx_rate == 0 { + return true; + } + + self.refill(config, timestamp); + if self.rx_bucket > 0 { + self.rx_bucket -= 1; + true + } else { + false + } + } +} + +/// A fault injector device. +/// +/// A fault injector is a device that alters packets traversing through it to simulate +/// adverse network conditions (such as random packet loss or corruption), or software +/// or hardware limitations (such as a limited number or size of usable network buffers). +#[derive(Debug)] +pub struct FaultInjector<D: Device> { + inner: D, + state: State, + config: Config, + rx_buf: [u8; MTU], +} + +impl<D: Device> FaultInjector<D> { + /// Create a fault injector device, using the given random number generator seed. + pub fn new(inner: D, seed: u32) -> FaultInjector<D> { + FaultInjector { + inner, + state: State { + rng_seed: seed, + refilled_at: Instant::from_millis(0), + tx_bucket: 0, + rx_bucket: 0, + }, + config: Config::default(), + rx_buf: [0u8; MTU], + } + } + + /// Return the underlying device, consuming the fault injector. + pub fn into_inner(self) -> D { + self.inner + } + + /// Return the probability of corrupting a packet, in percents. + pub fn corrupt_chance(&self) -> u8 { + self.config.corrupt_pct + } + + /// Return the probability of dropping a packet, in percents. + pub fn drop_chance(&self) -> u8 { + self.config.drop_pct + } + + /// Return the maximum packet size, in octets. + pub fn max_packet_size(&self) -> usize { + self.config.max_size + } + + /// Return the maximum packet transmission rate, in packets per second. + pub fn max_tx_rate(&self) -> u64 { + self.config.max_tx_rate + } + + /// Return the maximum packet reception rate, in packets per second. + pub fn max_rx_rate(&self) -> u64 { + self.config.max_rx_rate + } + + /// Return the interval for packet rate limiting, in milliseconds. + pub fn bucket_interval(&self) -> Duration { + self.config.interval + } + + /// Set the probability of corrupting a packet, in percents. + /// + /// # Panics + /// This function panics if the probability is not between 0% and 100%. + pub fn set_corrupt_chance(&mut self, pct: u8) { + if pct > 100 { + panic!("percentage out of range") + } + self.config.corrupt_pct = pct + } + + /// Set the probability of dropping a packet, in percents. + /// + /// # Panics + /// This function panics if the probability is not between 0% and 100%. + pub fn set_drop_chance(&mut self, pct: u8) { + if pct > 100 { + panic!("percentage out of range") + } + self.config.drop_pct = pct + } + + /// Set the maximum packet size, in octets. + pub fn set_max_packet_size(&mut self, size: usize) { + self.config.max_size = size + } + + /// Set the maximum packet transmission rate, in packets per interval. + pub fn set_max_tx_rate(&mut self, rate: u64) { + self.config.max_tx_rate = rate + } + + /// Set the maximum packet reception rate, in packets per interval. + pub fn set_max_rx_rate(&mut self, rate: u64) { + self.config.max_rx_rate = rate + } + + /// Set the interval for packet rate limiting, in milliseconds. + pub fn set_bucket_interval(&mut self, interval: Duration) { + self.state.refilled_at = Instant::from_millis(0); + self.config.interval = interval + } +} + +impl<D: Device> Device for FaultInjector<D> { + type RxToken<'a> = RxToken<'a> + where + Self: 'a; + type TxToken<'a> = TxToken<'a, D::TxToken<'a>> + where + Self: 'a; + + fn capabilities(&self) -> DeviceCapabilities { + let mut caps = self.inner.capabilities(); + if caps.max_transmission_unit > MTU { + caps.max_transmission_unit = MTU; + } + caps + } + + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let (rx_token, tx_token) = self.inner.receive(timestamp)?; + let rx_meta = <D::RxToken<'_> as phy::RxToken>::meta(&rx_token); + + let len = super::RxToken::consume(rx_token, |buffer| { + if (self.config.max_size > 0 && buffer.len() > self.config.max_size) + || buffer.len() > self.rx_buf.len() + { + net_trace!("rx: dropping a packet that is too large"); + return None; + } + self.rx_buf[..buffer.len()].copy_from_slice(buffer); + Some(buffer.len()) + })?; + + let buf = &mut self.rx_buf[..len]; + + if self.state.maybe(self.config.drop_pct) { + net_trace!("rx: randomly dropping a packet"); + return None; + } + + if !self.state.maybe_receive(&self.config, timestamp) { + net_trace!("rx: dropping a packet because of rate limiting"); + return None; + } + + if self.state.maybe(self.config.corrupt_pct) { + net_trace!("rx: randomly corrupting a packet"); + self.state.corrupt(&mut buf[..]); + } + + let rx = RxToken { buf, meta: rx_meta }; + let tx = TxToken { + state: &mut self.state, + config: self.config, + token: tx_token, + junk: [0; MTU], + timestamp, + }; + Some((rx, tx)) + } + + fn transmit(&mut self, timestamp: Instant) -> Option<Self::TxToken<'_>> { + self.inner.transmit(timestamp).map(|token| TxToken { + state: &mut self.state, + config: self.config, + token, + junk: [0; MTU], + timestamp, + }) + } +} + +#[doc(hidden)] +pub struct RxToken<'a> { + buf: &'a mut [u8], + meta: PacketMeta, +} + +impl<'a> phy::RxToken for RxToken<'a> { + fn consume<R, F>(self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + f(self.buf) + } + + fn meta(&self) -> phy::PacketMeta { + self.meta + } +} + +#[doc(hidden)] +pub struct TxToken<'a, Tx: phy::TxToken> { + state: &'a mut State, + config: Config, + token: Tx, + junk: [u8; MTU], + timestamp: Instant, +} + +impl<'a, Tx: phy::TxToken> phy::TxToken for TxToken<'a, Tx> { + fn consume<R, F>(mut self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let drop = if self.state.maybe(self.config.drop_pct) { + net_trace!("tx: randomly dropping a packet"); + true + } else if self.config.max_size > 0 && len > self.config.max_size { + net_trace!("tx: dropping a packet that is too large"); + true + } else if !self.state.maybe_transmit(&self.config, self.timestamp) { + net_trace!("tx: dropping a packet because of rate limiting"); + true + } else { + false + }; + + if drop { + return f(&mut self.junk[..len]); + } + + self.token.consume(len, |mut buf| { + if self.state.maybe(self.config.corrupt_pct) { + net_trace!("tx: corrupting a packet"); + self.state.corrupt(&mut buf) + } + f(buf) + }) + } + + fn set_meta(&mut self, meta: PacketMeta) { + self.token.set_meta(meta); + } +} diff --git a/src/phy/fuzz_injector.rs b/src/phy/fuzz_injector.rs new file mode 100644 index 0000000..6769d8e --- /dev/null +++ b/src/phy/fuzz_injector.rs @@ -0,0 +1,129 @@ +use crate::phy::{self, Device, DeviceCapabilities}; +use crate::time::Instant; + +// This could be fixed once associated consts are stable. +const MTU: usize = 1536; + +/// Represents a fuzzer. It is expected to replace bytes in the packet with fuzzed data. +pub trait Fuzzer { + /// Modify a single packet with fuzzed data. + fn fuzz_packet(&self, packet_data: &mut [u8]); +} + +/// A fuzz injector device. +/// +/// A fuzz injector is a device that alters packets traversing through it according to the +/// directions of a guided fuzzer. It is designed to support fuzzing internal state machines inside +/// smoltcp, and is not for production use. +#[allow(unused)] +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct FuzzInjector<D: Device, FTx: Fuzzer, FRx: Fuzzer> { + inner: D, + fuzz_tx: FTx, + fuzz_rx: FRx, +} + +#[allow(unused)] +impl<D: Device, FTx: Fuzzer, FRx: Fuzzer> FuzzInjector<D, FTx, FRx> { + /// Create a fuzz injector device. + pub fn new(inner: D, fuzz_tx: FTx, fuzz_rx: FRx) -> FuzzInjector<D, FTx, FRx> { + FuzzInjector { + inner, + fuzz_tx, + fuzz_rx, + } + } + + /// Return the underlying device, consuming the fuzz injector. + pub fn into_inner(self) -> D { + self.inner + } +} + +impl<D: Device, FTx, FRx> Device for FuzzInjector<D, FTx, FRx> +where + FTx: Fuzzer, + FRx: Fuzzer, +{ + type RxToken<'a> = RxToken<'a, D::RxToken<'a>, FRx> + where + Self: 'a; + type TxToken<'a> = TxToken<'a, D::TxToken<'a>, FTx> + where + Self: 'a; + + fn capabilities(&self) -> DeviceCapabilities { + let mut caps = self.inner.capabilities(); + if caps.max_transmission_unit > MTU { + caps.max_transmission_unit = MTU; + } + caps + } + + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + self.inner.receive(timestamp).map(|(rx_token, tx_token)| { + let rx = RxToken { + fuzzer: &mut self.fuzz_rx, + token: rx_token, + }; + let tx = TxToken { + fuzzer: &mut self.fuzz_tx, + token: tx_token, + }; + (rx, tx) + }) + } + + fn transmit(&mut self, timestamp: Instant) -> Option<Self::TxToken<'_>> { + self.inner.transmit(timestamp).map(|token| TxToken { + fuzzer: &mut self.fuzz_tx, + token: token, + }) + } +} + +#[doc(hidden)] +pub struct RxToken<'a, Rx: phy::RxToken, F: Fuzzer + 'a> { + fuzzer: &'a F, + token: Rx, +} + +impl<'a, Rx: phy::RxToken, FRx: Fuzzer> phy::RxToken for RxToken<'a, Rx, FRx> { + fn consume<R, F>(self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + self.token.consume(|buffer| { + self.fuzzer.fuzz_packet(buffer); + f(buffer) + }) + } + + fn meta(&self) -> phy::PacketMeta { + self.token.meta() + } +} + +#[doc(hidden)] +pub struct TxToken<'a, Tx: phy::TxToken, F: Fuzzer + 'a> { + fuzzer: &'a F, + token: Tx, +} + +impl<'a, Tx: phy::TxToken, FTx: Fuzzer> phy::TxToken for TxToken<'a, Tx, FTx> { + fn consume<R, F>(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + self.token.consume(len, |buf| { + let result = f(buf); + self.fuzzer.fuzz_packet(buf); + result + }) + } + + fn set_meta(&mut self, meta: phy::PacketMeta) { + self.token.set_meta(meta) + } +} diff --git a/src/phy/loopback.rs b/src/phy/loopback.rs new file mode 100644 index 0000000..1f57c0c --- /dev/null +++ b/src/phy/loopback.rs @@ -0,0 +1,88 @@ +use alloc::collections::VecDeque; +use alloc::vec::Vec; + +use crate::phy::{self, Device, DeviceCapabilities, Medium}; +use crate::time::Instant; + +/// A loopback device. +#[derive(Debug)] +pub struct Loopback { + pub(crate) queue: VecDeque<Vec<u8>>, + medium: Medium, +} + +#[allow(clippy::new_without_default)] +impl Loopback { + /// Creates a loopback device. + /// + /// Every packet transmitted through this device will be received through it + /// in FIFO order. + pub fn new(medium: Medium) -> Loopback { + Loopback { + queue: VecDeque::new(), + medium, + } + } +} + +impl Device for Loopback { + type RxToken<'a> = RxToken; + type TxToken<'a> = TxToken<'a>; + + fn capabilities(&self) -> DeviceCapabilities { + DeviceCapabilities { + max_transmission_unit: 65535, + medium: self.medium, + ..DeviceCapabilities::default() + } + } + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + self.queue.pop_front().map(move |buffer| { + let rx = RxToken { buffer }; + let tx = TxToken { + queue: &mut self.queue, + }; + (rx, tx) + }) + } + + fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> { + Some(TxToken { + queue: &mut self.queue, + }) + } +} + +#[doc(hidden)] +pub struct RxToken { + buffer: Vec<u8>, +} + +impl phy::RxToken for RxToken { + fn consume<R, F>(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + f(&mut self.buffer) + } +} + +#[doc(hidden)] +#[derive(Debug)] +pub struct TxToken<'a> { + queue: &'a mut VecDeque<Vec<u8>>, +} + +impl<'a> phy::TxToken for TxToken<'a> { + fn consume<R, F>(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut buffer = Vec::new(); + buffer.resize(len, 0); + let result = f(&mut buffer); + self.queue.push_back(buffer); + result + } +} diff --git a/src/phy/mod.rs b/src/phy/mod.rs new file mode 100644 index 0000000..c3845d9 --- /dev/null +++ b/src/phy/mod.rs @@ -0,0 +1,398 @@ +/*! Access to networking hardware. + +The `phy` module deals with the *network devices*. It provides a trait +for transmitting and receiving frames, [Device](trait.Device.html) +and implementations of it: + + * the [_loopback_](struct.Loopback.html), for zero dependency testing; + * _middleware_ [Tracer](struct.Tracer.html) and + [FaultInjector](struct.FaultInjector.html), to facilitate debugging; + * _adapters_ [RawSocket](struct.RawSocket.html) and + [TunTapInterface](struct.TunTapInterface.html), to transmit and receive frames + on the host OS. +*/ +#![cfg_attr( + feature = "medium-ethernet", + doc = r##" +# Examples + +An implementation of the [Device](trait.Device.html) trait for a simple hardware +Ethernet controller could look as follows: + +```rust +use smoltcp::phy::{self, DeviceCapabilities, Device, Medium}; +use smoltcp::time::Instant; + +struct StmPhy { + rx_buffer: [u8; 1536], + tx_buffer: [u8; 1536], +} + +impl<'a> StmPhy { + fn new() -> StmPhy { + StmPhy { + rx_buffer: [0; 1536], + tx_buffer: [0; 1536], + } + } +} + +impl phy::Device for StmPhy { + type RxToken<'a> = StmPhyRxToken<'a> where Self: 'a; + type TxToken<'a> = StmPhyTxToken<'a> where Self: 'a; + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + Some((StmPhyRxToken(&mut self.rx_buffer[..]), + StmPhyTxToken(&mut self.tx_buffer[..]))) + } + + fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> { + Some(StmPhyTxToken(&mut self.tx_buffer[..])) + } + + fn capabilities(&self) -> DeviceCapabilities { + let mut caps = DeviceCapabilities::default(); + caps.max_transmission_unit = 1536; + caps.max_burst_size = Some(1); + caps.medium = Medium::Ethernet; + caps + } +} + +struct StmPhyRxToken<'a>(&'a mut [u8]); + +impl<'a> phy::RxToken for StmPhyRxToken<'a> { + fn consume<R, F>(mut self, f: F) -> R + where F: FnOnce(&mut [u8]) -> R + { + // TODO: receive packet into buffer + let result = f(&mut self.0); + println!("rx called"); + result + } +} + +struct StmPhyTxToken<'a>(&'a mut [u8]); + +impl<'a> phy::TxToken for StmPhyTxToken<'a> { + fn consume<R, F>(self, len: usize, f: F) -> R + where F: FnOnce(&mut [u8]) -> R + { + let result = f(&mut self.0[..len]); + println!("tx called {}", len); + // TODO: send packet out + result + } +} +``` +"## +)] + +use crate::time::Instant; + +#[cfg(all( + any(feature = "phy-raw_socket", feature = "phy-tuntap_interface"), + unix +))] +mod sys; + +mod fault_injector; +mod fuzz_injector; +#[cfg(feature = "alloc")] +mod loopback; +mod pcap_writer; +#[cfg(all(feature = "phy-raw_socket", unix))] +mod raw_socket; +mod tracer; +#[cfg(all( + feature = "phy-tuntap_interface", + any(target_os = "linux", target_os = "android") +))] +mod tuntap_interface; + +#[cfg(all( + any(feature = "phy-raw_socket", feature = "phy-tuntap_interface"), + unix +))] +pub use self::sys::wait; + +pub use self::fault_injector::FaultInjector; +pub use self::fuzz_injector::{FuzzInjector, Fuzzer}; +#[cfg(feature = "alloc")] +pub use self::loopback::Loopback; +pub use self::pcap_writer::{PcapLinkType, PcapMode, PcapSink, PcapWriter}; +#[cfg(all(feature = "phy-raw_socket", unix))] +pub use self::raw_socket::RawSocket; +pub use self::tracer::Tracer; +#[cfg(all( + feature = "phy-tuntap_interface", + any(target_os = "linux", target_os = "android") +))] +pub use self::tuntap_interface::TunTapInterface; + +/// Metadata associated to a packet. +/// +/// The packet metadata is a set of attributes associated to network packets +/// as they travel up or down the stack. The metadata is get/set by the +/// [`Device`] implementations or by the user when sending/receiving packets from a +/// socket. +/// +/// Metadata fields are enabled via Cargo features. If no field is enabled, this +/// struct becomes zero-sized, which allows the compiler to optimize it out as if +/// the packet metadata mechanism didn't exist at all. +/// +/// Currently only UDP sockets allow setting/retrieving packet metadata. The metadata +/// for packets emitted with other sockets will be all default values. +/// +/// This struct is marked as `#[non_exhaustive]`. This means it is not possible to +/// create it directly by specifying all fields. You have to instead create it with +/// default values and then set the fields you want. This makes adding metadata +/// fields a non-breaking change. +/// +/// ```rust +/// let mut meta = smoltcp::phy::PacketMeta::default(); +/// #[cfg(feature = "packetmeta-id")] +/// { +/// meta.id = 15; +/// } +/// ``` +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Default)] +#[non_exhaustive] +pub struct PacketMeta { + #[cfg(feature = "packetmeta-id")] + pub id: u32, +} + +/// A description of checksum behavior for a particular protocol. +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Checksum { + /// Verify checksum when receiving and compute checksum when sending. + #[default] + Both, + /// Verify checksum when receiving. + Rx, + /// Compute checksum before sending. + Tx, + /// Ignore checksum completely. + None, +} + +impl Checksum { + /// Returns whether checksum should be verified when receiving. + pub fn rx(&self) -> bool { + match *self { + Checksum::Both | Checksum::Rx => true, + _ => false, + } + } + + /// Returns whether checksum should be verified when sending. + pub fn tx(&self) -> bool { + match *self { + Checksum::Both | Checksum::Tx => true, + _ => false, + } + } +} + +/// A description of checksum behavior for every supported protocol. +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] +pub struct ChecksumCapabilities { + pub ipv4: Checksum, + pub udp: Checksum, + pub tcp: Checksum, + #[cfg(feature = "proto-ipv4")] + pub icmpv4: Checksum, + #[cfg(feature = "proto-ipv6")] + pub icmpv6: Checksum, +} + +impl ChecksumCapabilities { + /// Checksum behavior that results in not computing or verifying checksums + /// for any of the supported protocols. + pub fn ignored() -> Self { + ChecksumCapabilities { + ipv4: Checksum::None, + udp: Checksum::None, + tcp: Checksum::None, + #[cfg(feature = "proto-ipv4")] + icmpv4: Checksum::None, + #[cfg(feature = "proto-ipv6")] + icmpv6: Checksum::None, + } + } +} + +/// A description of device capabilities. +/// +/// Higher-level protocols may achieve higher throughput or lower latency if they consider +/// the bandwidth or packet size limitations. +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] +pub struct DeviceCapabilities { + /// Medium of the device. + /// + /// This indicates what kind of packet the sent/received bytes are, and determines + /// some behaviors of Interface. For example, ARP/NDISC address resolution is only done + /// for Ethernet mediums. + pub medium: Medium, + + /// Maximum transmission unit. + /// + /// The network device is unable to send or receive frames larger than the value returned + /// by this function. + /// + /// For Ethernet devices, this is the maximum Ethernet frame size, including the Ethernet header (14 octets), but + /// *not* including the Ethernet FCS (4 octets). Therefore, Ethernet MTU = IP MTU + 14. + /// + /// Note that in Linux and other OSes, "MTU" is the IP MTU, not the Ethernet MTU, even for Ethernet + /// devices. This is a common source of confusion. + /// + /// Most common IP MTU is 1500. Minimum is 576 (for IPv4) or 1280 (for IPv6). Maximum is 9216 octets. + pub max_transmission_unit: usize, + + /// Maximum burst size, in terms of MTU. + /// + /// The network device is unable to send or receive bursts large than the value returned + /// by this function. + /// + /// If `None`, there is no fixed limit on burst size, e.g. if network buffers are + /// dynamically allocated. + pub max_burst_size: Option<usize>, + + /// Checksum behavior. + /// + /// If the network device is capable of verifying or computing checksums for some protocols, + /// it can request that the stack not do so in software to improve performance. + pub checksum: ChecksumCapabilities, +} + +impl DeviceCapabilities { + pub fn ip_mtu(&self) -> usize { + match self.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + self.max_transmission_unit - crate::wire::EthernetFrame::<&[u8]>::header_len() + } + #[cfg(feature = "medium-ip")] + Medium::Ip => self.max_transmission_unit, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => self.max_transmission_unit, // TODO(thvdveld): what is the MTU for Medium::IEEE802 + } + } +} + +/// Type of medium of a device. +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Medium { + /// Ethernet medium. Devices of this type send and receive Ethernet frames, + /// and interfaces using it must do neighbor discovery via ARP or NDISC. + /// + /// Examples of devices of this type are Ethernet, WiFi (802.11), Linux `tap`, and VPNs in tap (layer 2) mode. + #[cfg(feature = "medium-ethernet")] + Ethernet, + + /// IP medium. Devices of this type send and receive IP frames, without an + /// Ethernet header. MAC addresses are not used, and no neighbor discovery (ARP, NDISC) is done. + /// + /// Examples of devices of this type are the Linux `tun`, PPP interfaces, VPNs in tun (layer 3) mode. + #[cfg(feature = "medium-ip")] + Ip, + + #[cfg(feature = "medium-ieee802154")] + Ieee802154, +} + +impl Default for Medium { + fn default() -> Medium { + #[cfg(feature = "medium-ethernet")] + return Medium::Ethernet; + #[cfg(all(feature = "medium-ip", not(feature = "medium-ethernet")))] + return Medium::Ip; + #[cfg(all( + feature = "medium-ieee802154", + not(feature = "medium-ip"), + not(feature = "medium-ethernet") + ))] + return Medium::Ieee802154; + #[cfg(all( + not(feature = "medium-ip"), + not(feature = "medium-ethernet"), + not(feature = "medium-ieee802154") + ))] + return panic!("No medium enabled"); + } +} + +/// An interface for sending and receiving raw network frames. +/// +/// The interface is based on _tokens_, which are types that allow to receive/transmit a +/// single packet. The `receive` and `transmit` functions only construct such tokens, the +/// real sending/receiving operation are performed when the tokens are consumed. +pub trait Device { + type RxToken<'a>: RxToken + where + Self: 'a; + type TxToken<'a>: TxToken + where + Self: 'a; + + /// Construct a token pair consisting of one receive token and one transmit token. + /// + /// The additional transmit token makes it possible to generate a reply packet based + /// on the contents of the received packet. For example, this makes it possible to + /// handle arbitrarily large ICMP echo ("ping") requests, where the all received bytes + /// need to be sent back, without heap allocation. + /// + /// The timestamp must be a number of milliseconds, monotonically increasing since an + /// arbitrary moment in time, such as system startup. + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)>; + + /// Construct a transmit token. + /// + /// The timestamp must be a number of milliseconds, monotonically increasing since an + /// arbitrary moment in time, such as system startup. + fn transmit(&mut self, timestamp: Instant) -> Option<Self::TxToken<'_>>; + + /// Get a description of device capabilities. + fn capabilities(&self) -> DeviceCapabilities; +} + +/// A token to receive a single network packet. +pub trait RxToken { + /// Consumes the token to receive a single network packet. + /// + /// This method receives a packet and then calls the given closure `f` with the raw + /// packet bytes as argument. + fn consume<R, F>(self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R; + + /// The Packet ID associated with the frame received by this [`RxToken`] + fn meta(&self) -> PacketMeta { + PacketMeta::default() + } +} + +/// A token to transmit a single network packet. +pub trait TxToken { + /// Consumes the token to send a single network packet. + /// + /// This method constructs a transmit buffer of size `len` and calls the passed + /// closure `f` with a mutable reference to that buffer. The closure should construct + /// a valid network packet (e.g. an ethernet packet) in the buffer. When the closure + /// returns, the transmit buffer is sent out. + fn consume<R, F>(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R; + + /// The Packet ID to be associated with the frame to be transmitted by this [`TxToken`]. + #[allow(unused_variables)] + fn set_meta(&mut self, meta: PacketMeta) {} +} diff --git a/src/phy/pcap_writer.rs b/src/phy/pcap_writer.rs new file mode 100644 index 0000000..aadf2a2 --- /dev/null +++ b/src/phy/pcap_writer.rs @@ -0,0 +1,268 @@ +use byteorder::{ByteOrder, NativeEndian}; +use core::cell::RefCell; +use phy::Medium; +#[cfg(feature = "std")] +use std::io::Write; + +use crate::phy::{self, Device, DeviceCapabilities}; +use crate::time::Instant; + +enum_with_unknown! { + /// Captured packet header type. + pub enum PcapLinkType(u32) { + /// Ethernet frames + Ethernet = 1, + /// IPv4 or IPv6 packets (depending on the version field) + Ip = 101, + /// IEEE 802.15.4 packets without FCS. + Ieee802154WithoutFcs = 230, + } +} + +/// Packet capture mode. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum PcapMode { + /// Capture both received and transmitted packets. + Both, + /// Capture only received packets. + RxOnly, + /// Capture only transmitted packets. + TxOnly, +} + +/// A packet capture sink. +pub trait PcapSink { + /// Write data into the sink. + fn write(&mut self, data: &[u8]); + + /// Flush data written into the sync. + fn flush(&mut self) {} + + /// Write an `u16` into the sink, in native byte order. + fn write_u16(&mut self, value: u16) { + let mut bytes = [0u8; 2]; + NativeEndian::write_u16(&mut bytes, value); + self.write(&bytes[..]) + } + + /// Write an `u32` into the sink, in native byte order. + fn write_u32(&mut self, value: u32) { + let mut bytes = [0u8; 4]; + NativeEndian::write_u32(&mut bytes, value); + self.write(&bytes[..]) + } + + /// Write the libpcap global header into the sink. + /// + /// This method may be overridden e.g. if special synchronization is necessary. + fn global_header(&mut self, link_type: PcapLinkType) { + self.write_u32(0xa1b2c3d4); // magic number + self.write_u16(2); // major version + self.write_u16(4); // minor version + self.write_u32(0); // timezone (= UTC) + self.write_u32(0); // accuracy (not used) + self.write_u32(65535); // maximum packet length + self.write_u32(link_type.into()); // link-layer header type + } + + /// Write the libpcap packet header into the sink. + /// + /// See also the note for [global_header](#method.global_header). + /// + /// # Panics + /// This function panics if `length` is greater than 65535. + fn packet_header(&mut self, timestamp: Instant, length: usize) { + assert!(length <= 65535); + + self.write_u32(timestamp.secs() as u32); // timestamp seconds + self.write_u32(timestamp.micros() as u32); // timestamp microseconds + self.write_u32(length as u32); // captured length + self.write_u32(length as u32); // original length + } + + /// Write the libpcap packet header followed by packet data into the sink. + /// + /// See also the note for [global_header](#method.global_header). + fn packet(&mut self, timestamp: Instant, packet: &[u8]) { + self.packet_header(timestamp, packet.len()); + self.write(packet); + self.flush(); + } +} + +#[cfg(feature = "std")] +impl<T: Write> PcapSink for T { + fn write(&mut self, data: &[u8]) { + T::write_all(self, data).expect("cannot write") + } + + fn flush(&mut self) { + T::flush(self).expect("cannot flush") + } +} + +/// A packet capture writer device. +/// +/// Every packet transmitted or received through this device is timestamped +/// and written (in the [libpcap] format) using the provided [sink]. +/// Note that writes are fine-grained, and buffering is recommended. +/// +/// The packet sink should be cheaply cloneable, as it is cloned on every +/// transmitted packet. For example, `&'a mut Vec<u8>` is cheaply cloneable +/// but `&std::io::File` +/// +/// [libpcap]: https://wiki.wireshark.org/Development/LibpcapFileFormat +/// [sink]: trait.PcapSink.html +#[derive(Debug)] +pub struct PcapWriter<D, S> +where + D: Device, + S: PcapSink, +{ + lower: D, + sink: RefCell<S>, + mode: PcapMode, +} + +impl<D: Device, S: PcapSink> PcapWriter<D, S> { + /// Creates a packet capture writer. + pub fn new(lower: D, mut sink: S, mode: PcapMode) -> PcapWriter<D, S> { + let medium = lower.capabilities().medium; + let link_type = match medium { + #[cfg(feature = "medium-ip")] + Medium::Ip => PcapLinkType::Ip, + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => PcapLinkType::Ethernet, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => PcapLinkType::Ieee802154WithoutFcs, + }; + sink.global_header(link_type); + PcapWriter { + lower, + sink: RefCell::new(sink), + mode, + } + } + + /// Get a reference to the underlying device. + /// + /// Even if the device offers reading through a standard reference, it is inadvisable to + /// directly read from the device as doing so will circumvent the packet capture. + pub fn get_ref(&self) -> &D { + &self.lower + } + + /// Get a mutable reference to the underlying device. + /// + /// It is inadvisable to directly read from the device as doing so will circumvent the packet capture. + pub fn get_mut(&mut self) -> &mut D { + &mut self.lower + } +} + +impl<D: Device, S> Device for PcapWriter<D, S> +where + S: PcapSink, +{ + type RxToken<'a> = RxToken<'a, D::RxToken<'a>, S> + where + Self: 'a; + type TxToken<'a> = TxToken<'a, D::TxToken<'a>, S> + where + Self: 'a; + + fn capabilities(&self) -> DeviceCapabilities { + self.lower.capabilities() + } + + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let sink = &self.sink; + let mode = self.mode; + self.lower + .receive(timestamp) + .map(move |(rx_token, tx_token)| { + let rx = RxToken { + token: rx_token, + sink, + mode, + timestamp, + }; + let tx = TxToken { + token: tx_token, + sink, + mode, + timestamp, + }; + (rx, tx) + }) + } + + fn transmit(&mut self, timestamp: Instant) -> Option<Self::TxToken<'_>> { + let sink = &self.sink; + let mode = self.mode; + self.lower.transmit(timestamp).map(move |token| TxToken { + token, + sink, + mode, + timestamp, + }) + } +} + +#[doc(hidden)] +pub struct RxToken<'a, Rx: phy::RxToken, S: PcapSink> { + token: Rx, + sink: &'a RefCell<S>, + mode: PcapMode, + timestamp: Instant, +} + +impl<'a, Rx: phy::RxToken, S: PcapSink> phy::RxToken for RxToken<'a, Rx, S> { + fn consume<R, F: FnOnce(&mut [u8]) -> R>(self, f: F) -> R { + self.token.consume(|buffer| { + match self.mode { + PcapMode::Both | PcapMode::RxOnly => self + .sink + .borrow_mut() + .packet(self.timestamp, buffer.as_ref()), + PcapMode::TxOnly => (), + } + f(buffer) + }) + } + + fn meta(&self) -> phy::PacketMeta { + self.token.meta() + } +} + +#[doc(hidden)] +pub struct TxToken<'a, Tx: phy::TxToken, S: PcapSink> { + token: Tx, + sink: &'a RefCell<S>, + mode: PcapMode, + timestamp: Instant, +} + +impl<'a, Tx: phy::TxToken, S: PcapSink> phy::TxToken for TxToken<'a, Tx, S> { + fn consume<R, F>(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + self.token.consume(len, |buffer| { + let result = f(buffer); + match self.mode { + PcapMode::Both | PcapMode::TxOnly => { + self.sink.borrow_mut().packet(self.timestamp, buffer) + } + PcapMode::RxOnly => (), + }; + result + }) + } + + fn set_meta(&mut self, meta: phy::PacketMeta) { + self.token.set_meta(meta) + } +} diff --git a/src/phy/raw_socket.rs b/src/phy/raw_socket.rs new file mode 100644 index 0000000..19c5b98 --- /dev/null +++ b/src/phy/raw_socket.rs @@ -0,0 +1,137 @@ +use std::cell::RefCell; +use std::io; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::rc::Rc; +use std::vec::Vec; + +use crate::phy::{self, sys, Device, DeviceCapabilities, Medium}; +use crate::time::Instant; + +/// A socket that captures or transmits the complete frame. +#[derive(Debug)] +pub struct RawSocket { + medium: Medium, + lower: Rc<RefCell<sys::RawSocketDesc>>, + mtu: usize, +} + +impl AsRawFd for RawSocket { + fn as_raw_fd(&self) -> RawFd { + self.lower.borrow().as_raw_fd() + } +} + +impl RawSocket { + /// Creates a raw socket, bound to the interface called `name`. + /// + /// This requires superuser privileges or a corresponding capability bit + /// set on the executable. + pub fn new(name: &str, medium: Medium) -> io::Result<RawSocket> { + let mut lower = sys::RawSocketDesc::new(name, medium)?; + lower.bind_interface()?; + + let mut mtu = lower.interface_mtu()?; + + #[cfg(feature = "medium-ieee802154")] + if medium == Medium::Ieee802154 { + // SIOCGIFMTU returns 127 - (ACK_PSDU - FCS - 1) - FCS. + // 127 - (5 - 2 - 1) - 2 = 123 + // For IEEE802154, we want to add (ACK_PSDU - FCS - 1), since that is what SIOCGIFMTU + // uses as the size of the link layer header. + // + // https://github.com/torvalds/linux/blob/7475e51b87969e01a6812eac713a1c8310372e8a/net/mac802154/iface.c#L541 + mtu += 2; + } + + #[cfg(feature = "medium-ethernet")] + if medium == Medium::Ethernet { + // SIOCGIFMTU returns the IP MTU (typically 1500 bytes.) + // smoltcp counts the entire Ethernet packet in the MTU, so add the Ethernet header size to it. + mtu += crate::wire::EthernetFrame::<&[u8]>::header_len() + } + + Ok(RawSocket { + medium, + lower: Rc::new(RefCell::new(lower)), + mtu, + }) + } +} + +impl Device for RawSocket { + type RxToken<'a> = RxToken + where + Self: 'a; + type TxToken<'a> = TxToken + where + Self: 'a; + + fn capabilities(&self) -> DeviceCapabilities { + DeviceCapabilities { + max_transmission_unit: self.mtu, + medium: self.medium, + ..DeviceCapabilities::default() + } + } + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let mut lower = self.lower.borrow_mut(); + let mut buffer = vec![0; self.mtu]; + match lower.recv(&mut buffer[..]) { + Ok(size) => { + buffer.resize(size, 0); + let rx = RxToken { buffer }; + let tx = TxToken { + lower: self.lower.clone(), + }; + Some((rx, tx)) + } + Err(err) if err.kind() == io::ErrorKind::WouldBlock => None, + Err(err) => panic!("{}", err), + } + } + + fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> { + Some(TxToken { + lower: self.lower.clone(), + }) + } +} + +#[doc(hidden)] +pub struct RxToken { + buffer: Vec<u8>, +} + +impl phy::RxToken for RxToken { + fn consume<R, F>(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + f(&mut self.buffer[..]) + } +} + +#[doc(hidden)] +pub struct TxToken { + lower: Rc<RefCell<sys::RawSocketDesc>>, +} + +impl phy::TxToken for TxToken { + fn consume<R, F>(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut lower = self.lower.borrow_mut(); + let mut buffer = vec![0; len]; + let result = f(&mut buffer); + match lower.send(&buffer[..]) { + Ok(_) => {} + Err(err) if err.kind() == io::ErrorKind::WouldBlock => { + net_debug!("phy: tx failed due to WouldBlock") + } + Err(err) => panic!("{}", err), + } + result + } +} diff --git a/src/phy/sys/bpf.rs b/src/phy/sys/bpf.rs new file mode 100644 index 0000000..7e65b98 --- /dev/null +++ b/src/phy/sys/bpf.rs @@ -0,0 +1,180 @@ +use std::io; +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; + +use libc; + +use super::{ifreq, ifreq_for}; +use crate::phy::Medium; +use crate::wire::ETHERNET_HEADER_LEN; + +/// set interface +#[cfg(any(target_os = "macos", target_os = "openbsd"))] +const BIOCSETIF: libc::c_ulong = 0x8020426c; +/// get buffer length +#[cfg(any(target_os = "macos", target_os = "openbsd"))] +const BIOCGBLEN: libc::c_ulong = 0x40044266; +/// set immediate/nonblocking read +#[cfg(any(target_os = "macos", target_os = "openbsd"))] +const BIOCIMMEDIATE: libc::c_ulong = 0x80044270; +/// set bpf_hdr struct size +#[cfg(target_os = "macos")] +const SIZEOF_BPF_HDR: usize = 18; +/// set bpf_hdr struct size +#[cfg(target_os = "openbsd")] +const SIZEOF_BPF_HDR: usize = 24; +/// The actual header length may be larger than the bpf_hdr struct due to aligning +/// see https://github.com/openbsd/src/blob/37ecb4d066e5566411cc16b362d3960c93b1d0be/sys/net/bpf.c#L1649 +/// and https://github.com/apple/darwin-xnu/blob/8f02f2a044b9bb1ad951987ef5bab20ec9486310/bsd/net/bpf.c#L3580 +#[cfg(any(target_os = "macos", target_os = "openbsd"))] +const BPF_HDRLEN: usize = (((SIZEOF_BPF_HDR + ETHERNET_HEADER_LEN) + mem::align_of::<u32>() - 1) + & !(mem::align_of::<u32>() - 1)) + - ETHERNET_HEADER_LEN; + +macro_rules! try_ioctl { + ($fd:expr,$cmd:expr,$req:expr) => { + unsafe { + if libc::ioctl($fd, $cmd, $req) == -1 { + return Err(io::Error::last_os_error()); + } + } + }; +} + +#[derive(Debug)] +pub struct BpfDevice { + fd: libc::c_int, + ifreq: ifreq, +} + +impl AsRawFd for BpfDevice { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +fn open_device() -> io::Result<libc::c_int> { + unsafe { + for i in 0..256 { + let dev = format!("/dev/bpf{}\0", i); + match libc::open( + dev.as_ptr() as *const libc::c_char, + libc::O_RDWR | libc::O_NONBLOCK, + ) { + -1 => continue, + fd => return Ok(fd), + }; + } + } + // at this point, all 256 BPF devices were busy and we weren't able to open any + Err(io::Error::last_os_error()) +} + +impl BpfDevice { + pub fn new(name: &str, _medium: Medium) -> io::Result<BpfDevice> { + Ok(BpfDevice { + fd: open_device()?, + ifreq: ifreq_for(name), + }) + } + + pub fn bind_interface(&mut self) -> io::Result<()> { + try_ioctl!(self.fd, BIOCSETIF, &mut self.ifreq); + + Ok(()) + } + + /// This in fact does not return the interface's mtu, + /// but it returns the size of the buffer that the app needs to allocate + /// for the BPF device + /// + /// The `SIOGIFMTU` cant be called on a BPF descriptor. There is a workaround + /// to get the actual interface mtu, but this should work better + /// + /// To get the interface MTU, you would need to create a raw socket first, + /// and then call `SIOGIFMTU` for the same interface your BPF device is "bound" to. + /// This MTU that you would get would not include the length of `struct bpf_hdr` + /// which gets prepended to every packet by BPF, + /// and your packet will be truncated if it has the length of the MTU. + /// + /// The buffer size for BPF is usually 4096 bytes, MTU is typically 1500 bytes. + /// You could do something like `mtu += BPF_HDRLEN`, + /// but you must change the buffer size the BPF device expects using `BIOCSBLEN` accordingly, + /// and you must set it before setting the interface with the `BIOCSETIF` ioctl. + /// + /// The reason I said this should work better is because you might see some unexpected behavior, + /// truncated/unaligned packets, I/O errors on read() + /// if you change the buffer size to the actual MTU of the interface. + pub fn interface_mtu(&mut self) -> io::Result<usize> { + let mut bufsize: libc::c_int = 1; + try_ioctl!(self.fd, BIOCIMMEDIATE, &mut bufsize as *mut libc::c_int); + try_ioctl!(self.fd, BIOCGBLEN, &mut bufsize as *mut libc::c_int); + + Ok(bufsize as usize) + } + + pub fn recv(&mut self, buffer: &mut [u8]) -> io::Result<usize> { + unsafe { + let len = libc::read( + self.fd, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + ); + + if len == -1 || len < BPF_HDRLEN as isize { + return Err(io::Error::last_os_error()); + } + + let len = len as usize; + + libc::memmove( + buffer.as_mut_ptr() as *mut libc::c_void, + &buffer[BPF_HDRLEN] as *const u8 as *const libc::c_void, + len - BPF_HDRLEN, + ); + + Ok(len) + } + } + + pub fn send(&mut self, buffer: &[u8]) -> io::Result<usize> { + unsafe { + let len = libc::write( + self.fd, + buffer.as_ptr() as *const libc::c_void, + buffer.len(), + ); + + if len == -1 { + Err(io::Error::last_os_error()).unwrap() + } + + Ok(len as usize) + } + } +} + +impl Drop for BpfDevice { + fn drop(&mut self) { + unsafe { + libc::close(self.fd); + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[cfg(target_os = "macos")] + fn test_aligned_bpf_hdr_len() { + assert_eq!(18, BPF_HDRLEN); + } + + #[test] + #[cfg(target_os = "openbsd")] + fn test_aligned_bpf_hdr_len() { + assert_eq!(26, BPF_HDRLEN); + } +} diff --git a/src/phy/sys/linux.rs b/src/phy/sys/linux.rs new file mode 100644 index 0000000..c73eb4f --- /dev/null +++ b/src/phy/sys/linux.rs @@ -0,0 +1,26 @@ +#![allow(unused)] + +pub const SIOCGIFMTU: libc::c_ulong = 0x8921; +pub const SIOCGIFINDEX: libc::c_ulong = 0x8933; +pub const ETH_P_ALL: libc::c_short = 0x0003; +pub const ETH_P_IEEE802154: libc::c_short = 0x00F6; + +// Constant definition as per +// https://github.com/golang/sys/blob/master/unix/zerrors_linux_<arch>.go +pub const TUNSETIFF: libc::c_ulong = if cfg!(any( + target_arch = "mips", + target_arch = "mips64", + target_arch = "mips64el", + target_arch = "mipsel", + target_arch = "powerpc", + target_arch = "powerpc64", + target_arch = "powerpc64le", + target_arch = "sparc64" +)) { + 0x800454CA +} else { + 0x400454CA +}; +pub const IFF_TUN: libc::c_int = 0x0001; +pub const IFF_TAP: libc::c_int = 0x0002; +pub const IFF_NO_PI: libc::c_int = 0x1000; diff --git a/src/phy/sys/mod.rs b/src/phy/sys/mod.rs new file mode 100644 index 0000000..3f42301 --- /dev/null +++ b/src/phy/sys/mod.rs @@ -0,0 +1,136 @@ +#![allow(unsafe_code)] + +use crate::time::Duration; +use std::os::unix::io::RawFd; +use std::{io, mem, ptr}; + +#[cfg(any(target_os = "linux", target_os = "android"))] +#[path = "linux.rs"] +mod imp; + +#[cfg(all( + feature = "phy-raw_socket", + not(any(target_os = "linux", target_os = "android")), + unix +))] +pub mod bpf; +#[cfg(all( + feature = "phy-raw_socket", + any(target_os = "linux", target_os = "android") +))] +pub mod raw_socket; +#[cfg(all( + feature = "phy-tuntap_interface", + any(target_os = "linux", target_os = "android") +))] +pub mod tuntap_interface; + +#[cfg(all( + feature = "phy-raw_socket", + not(any(target_os = "linux", target_os = "android")), + unix +))] +pub use self::bpf::BpfDevice as RawSocketDesc; +#[cfg(all( + feature = "phy-raw_socket", + any(target_os = "linux", target_os = "android") +))] +pub use self::raw_socket::RawSocketDesc; +#[cfg(all( + feature = "phy-tuntap_interface", + any(target_os = "linux", target_os = "android") +))] +pub use self::tuntap_interface::TunTapInterfaceDesc; + +/// Wait until given file descriptor becomes readable, but no longer than given timeout. +pub fn wait(fd: RawFd, duration: Option<Duration>) -> io::Result<()> { + unsafe { + let mut readfds = { + let mut readfds = mem::MaybeUninit::<libc::fd_set>::uninit(); + libc::FD_ZERO(readfds.as_mut_ptr()); + libc::FD_SET(fd, readfds.as_mut_ptr()); + readfds.assume_init() + }; + + let mut writefds = { + let mut writefds = mem::MaybeUninit::<libc::fd_set>::uninit(); + libc::FD_ZERO(writefds.as_mut_ptr()); + writefds.assume_init() + }; + + let mut exceptfds = { + let mut exceptfds = mem::MaybeUninit::<libc::fd_set>::uninit(); + libc::FD_ZERO(exceptfds.as_mut_ptr()); + exceptfds.assume_init() + }; + + let mut timeout = libc::timeval { + tv_sec: 0, + tv_usec: 0, + }; + let timeout_ptr = if let Some(duration) = duration { + timeout.tv_sec = duration.secs() as libc::time_t; + timeout.tv_usec = (duration.millis() * 1_000) as libc::suseconds_t; + &mut timeout as *mut _ + } else { + ptr::null_mut() + }; + + let res = libc::select( + fd + 1, + &mut readfds, + &mut writefds, + &mut exceptfds, + timeout_ptr, + ); + if res == -1 { + return Err(io::Error::last_os_error()); + } + Ok(()) + } +} + +#[cfg(all( + any(feature = "phy-tuntap_interface", feature = "phy-raw_socket"), + unix +))] +#[repr(C)] +#[derive(Debug)] +struct ifreq { + ifr_name: [libc::c_char; libc::IF_NAMESIZE], + ifr_data: libc::c_int, /* ifr_ifindex or ifr_mtu */ +} + +#[cfg(all( + any(feature = "phy-tuntap_interface", feature = "phy-raw_socket"), + unix +))] +fn ifreq_for(name: &str) -> ifreq { + let mut ifreq = ifreq { + ifr_name: [0; libc::IF_NAMESIZE], + ifr_data: 0, + }; + for (i, byte) in name.as_bytes().iter().enumerate() { + ifreq.ifr_name[i] = *byte as libc::c_char + } + ifreq +} + +#[cfg(all( + any(target_os = "linux", target_os = "android"), + any(feature = "phy-tuntap_interface", feature = "phy-raw_socket") +))] +fn ifreq_ioctl( + lower: libc::c_int, + ifreq: &mut ifreq, + cmd: libc::c_ulong, +) -> io::Result<libc::c_int> { + unsafe { + let res = libc::ioctl(lower, cmd as _, ifreq as *mut ifreq); + if res == -1 { + return Err(io::Error::last_os_error()); + } + } + + Ok(ifreq.ifr_data) +} diff --git a/src/phy/sys/raw_socket.rs b/src/phy/sys/raw_socket.rs new file mode 100644 index 0000000..f37fe96 --- /dev/null +++ b/src/phy/sys/raw_socket.rs @@ -0,0 +1,115 @@ +use super::*; +use crate::phy::Medium; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::{io, mem}; + +#[derive(Debug)] +pub struct RawSocketDesc { + protocol: libc::c_short, + lower: libc::c_int, + ifreq: ifreq, +} + +impl AsRawFd for RawSocketDesc { + fn as_raw_fd(&self) -> RawFd { + self.lower + } +} + +impl RawSocketDesc { + pub fn new(name: &str, medium: Medium) -> io::Result<RawSocketDesc> { + let protocol = match medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => imp::ETH_P_ALL, + #[cfg(feature = "medium-ip")] + Medium::Ip => imp::ETH_P_ALL, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => imp::ETH_P_IEEE802154, + }; + + let lower = unsafe { + let lower = libc::socket( + libc::AF_PACKET, + libc::SOCK_RAW | libc::SOCK_NONBLOCK, + protocol.to_be() as i32, + ); + if lower == -1 { + return Err(io::Error::last_os_error()); + } + lower + }; + + Ok(RawSocketDesc { + protocol, + lower, + ifreq: ifreq_for(name), + }) + } + + pub fn interface_mtu(&mut self) -> io::Result<usize> { + ifreq_ioctl(self.lower, &mut self.ifreq, imp::SIOCGIFMTU).map(|mtu| mtu as usize) + } + + pub fn bind_interface(&mut self) -> io::Result<()> { + let sockaddr = libc::sockaddr_ll { + sll_family: libc::AF_PACKET as u16, + sll_protocol: self.protocol.to_be() as u16, + sll_ifindex: ifreq_ioctl(self.lower, &mut self.ifreq, imp::SIOCGIFINDEX)?, + sll_hatype: 1, + sll_pkttype: 0, + sll_halen: 6, + sll_addr: [0; 8], + }; + + unsafe { + let res = libc::bind( + self.lower, + &sockaddr as *const libc::sockaddr_ll as *const libc::sockaddr, + mem::size_of::<libc::sockaddr_ll>() as libc::socklen_t, + ); + if res == -1 { + return Err(io::Error::last_os_error()); + } + } + + Ok(()) + } + + pub fn recv(&mut self, buffer: &mut [u8]) -> io::Result<usize> { + unsafe { + let len = libc::recv( + self.lower, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + 0, + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } + Ok(len as usize) + } + } + + pub fn send(&mut self, buffer: &[u8]) -> io::Result<usize> { + unsafe { + let len = libc::send( + self.lower, + buffer.as_ptr() as *const libc::c_void, + buffer.len(), + 0, + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } + Ok(len as usize) + } + } +} + +impl Drop for RawSocketDesc { + fn drop(&mut self) { + unsafe { + libc::close(self.lower); + } + } +} diff --git a/src/phy/sys/tuntap_interface.rs b/src/phy/sys/tuntap_interface.rs new file mode 100644 index 0000000..3019cad --- /dev/null +++ b/src/phy/sys/tuntap_interface.rs @@ -0,0 +1,130 @@ +use super::*; +use crate::{phy::Medium, wire::EthernetFrame}; +use std::io; +use std::os::unix::io::{AsRawFd, RawFd}; + +#[derive(Debug)] +pub struct TunTapInterfaceDesc { + lower: libc::c_int, + mtu: usize, +} + +impl AsRawFd for TunTapInterfaceDesc { + fn as_raw_fd(&self) -> RawFd { + self.lower + } +} + +impl TunTapInterfaceDesc { + pub fn new(name: &str, medium: Medium) -> io::Result<TunTapInterfaceDesc> { + let lower = unsafe { + let lower = libc::open( + "/dev/net/tun\0".as_ptr() as *const libc::c_char, + libc::O_RDWR | libc::O_NONBLOCK, + ); + if lower == -1 { + return Err(io::Error::last_os_error()); + } + lower + }; + + let mut ifreq = ifreq_for(name); + Self::attach_interface_ifreq(lower, medium, &mut ifreq)?; + let mtu = Self::mtu_ifreq(medium, &mut ifreq)?; + + Ok(TunTapInterfaceDesc { lower, mtu }) + } + + pub fn from_fd(fd: RawFd, mtu: usize) -> io::Result<TunTapInterfaceDesc> { + Ok(TunTapInterfaceDesc { lower: fd, mtu }) + } + + fn attach_interface_ifreq( + lower: libc::c_int, + medium: Medium, + ifr: &mut ifreq, + ) -> io::Result<()> { + let mode = match medium { + #[cfg(feature = "medium-ip")] + Medium::Ip => imp::IFF_TUN, + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => imp::IFF_TAP, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => todo!(), + }; + ifr.ifr_data = mode | imp::IFF_NO_PI; + ifreq_ioctl(lower, ifr, imp::TUNSETIFF).map(|_| ()) + } + + fn mtu_ifreq(medium: Medium, ifr: &mut ifreq) -> io::Result<usize> { + let lower = unsafe { + let lower = libc::socket(libc::AF_INET, libc::SOCK_DGRAM, libc::IPPROTO_IP); + if lower == -1 { + return Err(io::Error::last_os_error()); + } + lower + }; + + let ip_mtu = ifreq_ioctl(lower, ifr, imp::SIOCGIFMTU).map(|mtu| mtu as usize); + + unsafe { + libc::close(lower); + } + + // Propagate error after close, to ensure we always close. + let ip_mtu = ip_mtu?; + + // SIOCGIFMTU returns the IP MTU (typically 1500 bytes.) + // smoltcp counts the entire Ethernet packet in the MTU, so add the Ethernet header size to it. + let mtu = match medium { + #[cfg(feature = "medium-ip")] + Medium::Ip => ip_mtu, + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => ip_mtu + EthernetFrame::<&[u8]>::header_len(), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => todo!(), + }; + + Ok(mtu) + } + + pub fn interface_mtu(&self) -> io::Result<usize> { + Ok(self.mtu) + } + + pub fn recv(&mut self, buffer: &mut [u8]) -> io::Result<usize> { + unsafe { + let len = libc::read( + self.lower, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } + Ok(len as usize) + } + } + + pub fn send(&mut self, buffer: &[u8]) -> io::Result<usize> { + unsafe { + let len = libc::write( + self.lower, + buffer.as_ptr() as *const libc::c_void, + buffer.len(), + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } + Ok(len as usize) + } + } +} + +impl Drop for TunTapInterfaceDesc { + fn drop(&mut self) { + unsafe { + libc::close(self.lower); + } + } +} diff --git a/src/phy/tracer.rs b/src/phy/tracer.rs new file mode 100644 index 0000000..48e60ec --- /dev/null +++ b/src/phy/tracer.rs @@ -0,0 +1,189 @@ +use core::fmt; + +use crate::phy::{self, Device, DeviceCapabilities, Medium}; +use crate::time::Instant; +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +/// A tracer device. +/// +/// A tracer is a device that pretty prints all packets traversing it +/// using the provided writer function, and then passes them to another +/// device. +pub struct Tracer<D: Device> { + inner: D, + writer: fn(Instant, Packet), +} + +impl<D: Device> Tracer<D> { + /// Create a tracer device. + pub fn new(inner: D, writer: fn(timestamp: Instant, packet: Packet)) -> Tracer<D> { + Tracer { inner, writer } + } + + /// Get a reference to the underlying device. + /// + /// Even if the device offers reading through a standard reference, it is inadvisable to + /// directly read from the device as doing so will circumvent the tracing. + pub fn get_ref(&self) -> &D { + &self.inner + } + + /// Get a mutable reference to the underlying device. + /// + /// It is inadvisable to directly read from the device as doing so will circumvent the tracing. + pub fn get_mut(&mut self) -> &mut D { + &mut self.inner + } + + /// Return the underlying device, consuming the tracer. + pub fn into_inner(self) -> D { + self.inner + } +} + +impl<D: Device> Device for Tracer<D> { + type RxToken<'a> = RxToken<D::RxToken<'a>> + where + Self: 'a; + type TxToken<'a> = TxToken<D::TxToken<'a>> + where + Self: 'a; + + fn capabilities(&self) -> DeviceCapabilities { + self.inner.capabilities() + } + + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let medium = self.inner.capabilities().medium; + self.inner.receive(timestamp).map(|(rx_token, tx_token)| { + let rx = RxToken { + token: rx_token, + writer: self.writer, + medium, + timestamp, + }; + let tx = TxToken { + token: tx_token, + writer: self.writer, + medium, + timestamp, + }; + (rx, tx) + }) + } + + fn transmit(&mut self, timestamp: Instant) -> Option<Self::TxToken<'_>> { + let medium = self.inner.capabilities().medium; + self.inner.transmit(timestamp).map(|tx_token| TxToken { + token: tx_token, + medium, + writer: self.writer, + timestamp, + }) + } +} + +#[doc(hidden)] +pub struct RxToken<Rx: phy::RxToken> { + token: Rx, + writer: fn(Instant, Packet), + medium: Medium, + timestamp: Instant, +} + +impl<Rx: phy::RxToken> phy::RxToken for RxToken<Rx> { + fn consume<R, F>(self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + self.token.consume(|buffer| { + (self.writer)( + self.timestamp, + Packet { + buffer, + medium: self.medium, + prefix: "<- ", + }, + ); + f(buffer) + }) + } + + fn meta(&self) -> phy::PacketMeta { + self.token.meta() + } +} + +#[doc(hidden)] +pub struct TxToken<Tx: phy::TxToken> { + token: Tx, + writer: fn(Instant, Packet), + medium: Medium, + timestamp: Instant, +} + +impl<Tx: phy::TxToken> phy::TxToken for TxToken<Tx> { + fn consume<R, F>(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + self.token.consume(len, |buffer| { + let result = f(buffer); + (self.writer)( + self.timestamp, + Packet { + buffer, + medium: self.medium, + prefix: "-> ", + }, + ); + result + }) + } + + fn set_meta(&mut self, meta: phy::PacketMeta) { + self.token.set_meta(meta) + } +} + +pub struct Packet<'a> { + buffer: &'a [u8], + medium: Medium, + prefix: &'static str, +} + +impl<'a> fmt::Display for Packet<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut indent = PrettyIndent::new(self.prefix); + match self.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => crate::wire::EthernetFrame::<&'static [u8]>::pretty_print( + &self.buffer, + f, + &mut indent, + ), + #[cfg(feature = "medium-ip")] + Medium::Ip => match crate::wire::IpVersion::of_packet(self.buffer) { + #[cfg(feature = "proto-ipv4")] + Ok(crate::wire::IpVersion::Ipv4) => { + crate::wire::Ipv4Packet::<&'static [u8]>::pretty_print( + &self.buffer, + f, + &mut indent, + ) + } + #[cfg(feature = "proto-ipv6")] + Ok(crate::wire::IpVersion::Ipv6) => { + crate::wire::Ipv6Packet::<&'static [u8]>::pretty_print( + &self.buffer, + f, + &mut indent, + ) + } + _ => f.write_str("unrecognized IP version"), + }, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => Ok(()), // XXX + } + } +} diff --git a/src/phy/tuntap_interface.rs b/src/phy/tuntap_interface.rs new file mode 100644 index 0000000..32a28db --- /dev/null +++ b/src/phy/tuntap_interface.rs @@ -0,0 +1,126 @@ +use std::cell::RefCell; +use std::io; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::rc::Rc; +use std::vec::Vec; + +use crate::phy::{self, sys, Device, DeviceCapabilities, Medium}; +use crate::time::Instant; + +/// A virtual TUN (IP) or TAP (Ethernet) interface. +#[derive(Debug)] +pub struct TunTapInterface { + lower: Rc<RefCell<sys::TunTapInterfaceDesc>>, + mtu: usize, + medium: Medium, +} + +impl AsRawFd for TunTapInterface { + fn as_raw_fd(&self) -> RawFd { + self.lower.borrow().as_raw_fd() + } +} + +impl TunTapInterface { + /// Attaches to a TUN/TAP interface called `name`, or creates it if it does not exist. + /// + /// If `name` is a persistent interface configured with UID of the current user, + /// no special privileges are needed. Otherwise, this requires superuser privileges + /// or a corresponding capability set on the executable. + pub fn new(name: &str, medium: Medium) -> io::Result<TunTapInterface> { + let lower = sys::TunTapInterfaceDesc::new(name, medium)?; + let mtu = lower.interface_mtu()?; + Ok(TunTapInterface { + lower: Rc::new(RefCell::new(lower)), + mtu, + medium, + }) + } + + /// Attaches to a TUN/TAP interface specified by file descriptor `fd`. + /// + /// On platforms like Android, a file descriptor to a tun interface is exposed. + /// On these platforms, a TunTapInterface cannot be instantiated with a name. + pub fn from_fd(fd: RawFd, medium: Medium, mtu: usize) -> io::Result<TunTapInterface> { + let lower = sys::TunTapInterfaceDesc::from_fd(fd, mtu)?; + Ok(TunTapInterface { + lower: Rc::new(RefCell::new(lower)), + mtu, + medium, + }) + } +} + +impl Device for TunTapInterface { + type RxToken<'a> = RxToken; + type TxToken<'a> = TxToken; + + fn capabilities(&self) -> DeviceCapabilities { + DeviceCapabilities { + max_transmission_unit: self.mtu, + medium: self.medium, + ..DeviceCapabilities::default() + } + } + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let mut lower = self.lower.borrow_mut(); + let mut buffer = vec![0; self.mtu]; + match lower.recv(&mut buffer[..]) { + Ok(size) => { + buffer.resize(size, 0); + let rx = RxToken { buffer }; + let tx = TxToken { + lower: self.lower.clone(), + }; + Some((rx, tx)) + } + Err(err) if err.kind() == io::ErrorKind::WouldBlock => None, + Err(err) => panic!("{}", err), + } + } + + fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> { + Some(TxToken { + lower: self.lower.clone(), + }) + } +} + +#[doc(hidden)] +pub struct RxToken { + buffer: Vec<u8>, +} + +impl phy::RxToken for RxToken { + fn consume<R, F>(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + f(&mut self.buffer[..]) + } +} + +#[doc(hidden)] +pub struct TxToken { + lower: Rc<RefCell<sys::TunTapInterfaceDesc>>, +} + +impl phy::TxToken for TxToken { + fn consume<R, F>(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut lower = self.lower.borrow_mut(); + let mut buffer = vec![0; len]; + let result = f(&mut buffer); + match lower.send(&buffer[..]) { + Ok(_) => {} + Err(err) if err.kind() == io::ErrorKind::WouldBlock => { + net_debug!("phy: tx failed due to WouldBlock") + } + Err(err) => panic!("{}", err), + } + result + } +} diff --git a/src/rand.rs b/src/rand.rs new file mode 100644 index 0000000..15d88f7 --- /dev/null +++ b/src/rand.rs @@ -0,0 +1,40 @@ +#![allow(unsafe_code)] +#![allow(unused)] + +#[derive(Debug)] +pub(crate) struct Rand { + state: u64, +} + +impl Rand { + pub(crate) const fn new(seed: u64) -> Self { + Self { state: seed } + } + + pub(crate) fn rand_u32(&mut self) -> u32 { + // sPCG32 from https://www.pcg-random.org/paper.html + // see also https://nullprogram.com/blog/2017/09/21/ + const M: u64 = 0xbb2efcec3c39611d; + const A: u64 = 0x7590ef39; + + let s = self.state.wrapping_mul(M).wrapping_add(A); + self.state = s; + + let shift = 29 - (s >> 61); + (s >> shift) as u32 + } + + pub(crate) fn rand_u16(&mut self) -> u16 { + let n = self.rand_u32(); + (n ^ (n >> 16)) as u16 + } + + pub(crate) fn rand_source_port(&mut self) -> u16 { + loop { + let res = self.rand_u16(); + if res > 1024 { + return res; + } + } + } +} diff --git a/src/socket/dhcpv4.rs b/src/socket/dhcpv4.rs new file mode 100644 index 0000000..13ecbd3 --- /dev/null +++ b/src/socket/dhcpv4.rs @@ -0,0 +1,1417 @@ +#[cfg(feature = "async")] +use core::task::Waker; + +use crate::iface::Context; +use crate::time::{Duration, Instant}; +use crate::wire::dhcpv4::field as dhcpv4_field; +use crate::wire::{ + DhcpMessageType, DhcpPacket, DhcpRepr, IpAddress, IpProtocol, Ipv4Address, Ipv4Cidr, Ipv4Repr, + UdpRepr, DHCP_CLIENT_PORT, DHCP_MAX_DNS_SERVER_COUNT, DHCP_SERVER_PORT, UDP_HEADER_LEN, +}; +use crate::wire::{DhcpOption, HardwareAddress}; +use heapless::Vec; + +#[cfg(feature = "async")] +use super::WakerRegistration; + +use super::PollAt; + +const DEFAULT_LEASE_DURATION: Duration = Duration::from_secs(120); + +const DEFAULT_PARAMETER_REQUEST_LIST: &[u8] = &[ + dhcpv4_field::OPT_SUBNET_MASK, + dhcpv4_field::OPT_ROUTER, + dhcpv4_field::OPT_DOMAIN_NAME_SERVER, +]; + +/// IPv4 configuration data provided by the DHCP server. +#[derive(Debug, Eq, PartialEq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Config<'a> { + /// Information on how to reach the DHCP server that responded with DHCP + /// configuration. + pub server: ServerInfo, + /// IP address + pub address: Ipv4Cidr, + /// Router address, also known as default gateway. Does not necessarily + /// match the DHCP server's address. + pub router: Option<Ipv4Address>, + /// DNS servers + pub dns_servers: Vec<Ipv4Address, DHCP_MAX_DNS_SERVER_COUNT>, + /// Received DHCP packet + pub packet: Option<DhcpPacket<&'a [u8]>>, +} + +/// Information on how to reach a DHCP server. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ServerInfo { + /// IP address to use as destination in outgoing packets + pub address: Ipv4Address, + /// Server identifier to use in outgoing packets. Usually equal to server_address, + /// but may differ in some situations (eg DHCP relays) + pub identifier: Ipv4Address, +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct DiscoverState { + /// When to send next request + retry_at: Instant, +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct RequestState { + /// When to send next request + retry_at: Instant, + /// How many retries have been done + retry: u16, + /// Server we're trying to request from + server: ServerInfo, + /// IP address that we're trying to request. + requested_ip: Ipv4Address, +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct RenewState { + /// Active network config + config: Config<'static>, + + /// Renew timer. When reached, we will start attempting + /// to renew this lease with the DHCP server. + /// + /// Must be less or equal than `rebind_at`. + renew_at: Instant, + + /// Rebind timer. When reached, we will start broadcasting to renew + /// this lease with any DHCP server. + /// + /// Must be greater than or equal to `renew_at`, and less than or + /// equal to `expires_at`. + rebind_at: Instant, + + /// Whether the T2 time has elapsed + rebinding: bool, + + /// Expiration timer. When reached, this lease is no longer valid, so it must be + /// thrown away and the ethernet interface deconfigured. + expires_at: Instant, +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +enum ClientState { + /// Discovering the DHCP server + Discovering(DiscoverState), + /// Requesting an address + Requesting(RequestState), + /// Having an address, refresh it periodically. + Renewing(RenewState), +} + +/// Timeout and retry configuration. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] +pub struct RetryConfig { + pub discover_timeout: Duration, + /// The REQUEST timeout doubles every 2 tries. + pub initial_request_timeout: Duration, + pub request_retries: u16, + pub min_renew_timeout: Duration, + /// An upper bound on how long to wait between retrying a renew or rebind. + /// + /// Set this to [`Duration::MAX`] if you don't want to impose an upper bound. + pub max_renew_timeout: Duration, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + discover_timeout: Duration::from_secs(10), + initial_request_timeout: Duration::from_secs(5), + request_retries: 5, + min_renew_timeout: Duration::from_secs(60), + max_renew_timeout: Duration::MAX, + } + } +} + +/// Return value for the `Dhcpv4Socket::poll` function +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Event<'a> { + /// Configuration has been lost (for example, the lease has expired) + Deconfigured, + /// Configuration has been newly acquired, or modified. + Configured(Config<'a>), +} + +#[derive(Debug)] +pub struct Socket<'a> { + /// State of the DHCP client. + state: ClientState, + /// Set to true on config/state change, cleared back to false by the `config` function. + config_changed: bool, + /// xid of the last sent message. + transaction_id: u32, + + /// Max lease duration. If set, it sets a maximum cap to the server-provided lease duration. + /// Useful to react faster to IP configuration changes and to test whether renews work correctly. + max_lease_duration: Option<Duration>, + + retry_config: RetryConfig, + + /// Ignore NAKs. + ignore_naks: bool, + + /// Server port config + pub(crate) server_port: u16, + + /// Client port config + pub(crate) client_port: u16, + + /// A buffer contains options additional to be added to outgoing DHCP + /// packets. + outgoing_options: &'a [DhcpOption<'a>], + /// A buffer containing all requested parameters. + parameter_request_list: Option<&'a [u8]>, + + /// Incoming DHCP packets are copied into this buffer, overwriting the previous. + receive_packet_buffer: Option<&'a mut [u8]>, + + /// Waker registration + #[cfg(feature = "async")] + waker: WakerRegistration, +} + +/// DHCP client socket. +/// +/// The socket acquires an IP address configuration through DHCP autonomously. +/// You must query the configuration with `.poll()` after every call to `Interface::poll()`, +/// and apply the configuration to the `Interface`. +impl<'a> Socket<'a> { + /// Create a DHCPv4 socket + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Socket { + state: ClientState::Discovering(DiscoverState { + retry_at: Instant::from_millis(0), + }), + config_changed: true, + transaction_id: 1, + max_lease_duration: None, + retry_config: RetryConfig::default(), + ignore_naks: false, + outgoing_options: &[], + parameter_request_list: None, + receive_packet_buffer: None, + #[cfg(feature = "async")] + waker: WakerRegistration::new(), + server_port: DHCP_SERVER_PORT, + client_port: DHCP_CLIENT_PORT, + } + } + + /// Set the retry/timeouts configuration. + pub fn set_retry_config(&mut self, config: RetryConfig) { + self.retry_config = config; + } + + /// Gets the current retry/timeouts configuration + pub fn get_retry_config(&self) -> RetryConfig { + self.retry_config + } + + /// Set the outgoing options. + pub fn set_outgoing_options(&mut self, options: &'a [DhcpOption<'a>]) { + self.outgoing_options = options; + } + + /// Set the buffer into which incoming DHCP packets are copied into. + pub fn set_receive_packet_buffer(&mut self, buffer: &'a mut [u8]) { + self.receive_packet_buffer = Some(buffer); + } + + /// Set the parameter request list. + /// + /// This should contain at least `OPT_SUBNET_MASK` (`1`), `OPT_ROUTER` + /// (`3`), and `OPT_DOMAIN_NAME_SERVER` (`6`). + pub fn set_parameter_request_list(&mut self, parameter_request_list: &'a [u8]) { + self.parameter_request_list = Some(parameter_request_list); + } + + /// Get the configured max lease duration. + /// + /// See also [`Self::set_max_lease_duration()`] + pub fn max_lease_duration(&self) -> Option<Duration> { + self.max_lease_duration + } + + /// Set the max lease duration. + /// + /// When set, the lease duration will be capped at the configured duration if the + /// DHCP server gives us a longer lease. This is generally not recommended, but + /// can be useful for debugging or reacting faster to network configuration changes. + /// + /// If None, no max is applied (the lease duration from the DHCP server is used.) + pub fn set_max_lease_duration(&mut self, max_lease_duration: Option<Duration>) { + self.max_lease_duration = max_lease_duration; + } + + /// Get whether to ignore NAKs. + /// + /// See also [`Self::set_ignore_naks()`] + pub fn ignore_naks(&self) -> bool { + self.ignore_naks + } + + /// Set whether to ignore NAKs. + /// + /// This is not compliant with the DHCP RFCs, since theoretically + /// we must stop using the assigned IP when receiving a NAK. This + /// can increase reliability on broken networks with buggy routers + /// or rogue DHCP servers, however. + pub fn set_ignore_naks(&mut self, ignore_naks: bool) { + self.ignore_naks = ignore_naks; + } + + /// Set the server/client port + /// + /// Allows you to specify the ports used by DHCP. + /// This is meant to support esoteric usecases allowed by the dhclient program. + pub fn set_ports(&mut self, server_port: u16, client_port: u16) { + self.server_port = server_port; + self.client_port = client_port; + } + + pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt { + let t = match &self.state { + ClientState::Discovering(state) => state.retry_at, + ClientState::Requesting(state) => state.retry_at, + ClientState::Renewing(state) => if state.rebinding { + state.rebind_at + } else { + state.renew_at.min(state.rebind_at) + } + .min(state.expires_at), + }; + PollAt::Time(t) + } + + pub(crate) fn process( + &mut self, + cx: &mut Context, + ip_repr: &Ipv4Repr, + repr: &UdpRepr, + payload: &[u8], + ) { + let src_ip = ip_repr.src_addr; + + // This is enforced in interface.rs. + assert!(repr.src_port == self.server_port && repr.dst_port == self.client_port); + + let dhcp_packet = match DhcpPacket::new_checked(payload) { + Ok(dhcp_packet) => dhcp_packet, + Err(e) => { + net_debug!("DHCP invalid pkt from {}: {:?}", src_ip, e); + return; + } + }; + let dhcp_repr = match DhcpRepr::parse(&dhcp_packet) { + Ok(dhcp_repr) => dhcp_repr, + Err(e) => { + net_debug!("DHCP error parsing pkt from {}: {:?}", src_ip, e); + return; + } + }; + + let HardwareAddress::Ethernet(ethernet_addr) = cx.hardware_addr() else { + panic!("using DHCPv4 socket with a non-ethernet hardware address."); + }; + + if dhcp_repr.client_hardware_address != ethernet_addr { + return; + } + if dhcp_repr.transaction_id != self.transaction_id { + return; + } + let server_identifier = match dhcp_repr.server_identifier { + Some(server_identifier) => server_identifier, + None => { + net_debug!( + "DHCP ignoring {:?} because missing server_identifier", + dhcp_repr.message_type + ); + return; + } + }; + + net_debug!( + "DHCP recv {:?} from {}: {:?}", + dhcp_repr.message_type, + src_ip, + dhcp_repr + ); + + // Copy over the payload into the receive packet buffer. + if let Some(buffer) = self.receive_packet_buffer.as_mut() { + if let Some(buffer) = buffer.get_mut(..payload.len()) { + buffer.copy_from_slice(payload); + } + } + + match (&mut self.state, dhcp_repr.message_type) { + (ClientState::Discovering(_state), DhcpMessageType::Offer) => { + if !dhcp_repr.your_ip.is_unicast() { + net_debug!("DHCP ignoring OFFER because your_ip is not unicast"); + return; + } + + self.state = ClientState::Requesting(RequestState { + retry_at: cx.now(), + retry: 0, + server: ServerInfo { + address: src_ip, + identifier: server_identifier, + }, + requested_ip: dhcp_repr.your_ip, // use the offered ip + }); + } + (ClientState::Requesting(state), DhcpMessageType::Ack) => { + if let Some((config, renew_at, rebind_at, expires_at)) = + Self::parse_ack(cx.now(), &dhcp_repr, self.max_lease_duration, state.server) + { + self.state = ClientState::Renewing(RenewState { + config, + renew_at, + rebind_at, + expires_at, + rebinding: false, + }); + self.config_changed(); + } + } + (ClientState::Requesting(_), DhcpMessageType::Nak) => { + if !self.ignore_naks { + self.reset(); + } + } + (ClientState::Renewing(state), DhcpMessageType::Ack) => { + if let Some((config, renew_at, rebind_at, expires_at)) = Self::parse_ack( + cx.now(), + &dhcp_repr, + self.max_lease_duration, + state.config.server, + ) { + state.renew_at = renew_at; + state.rebind_at = rebind_at; + state.rebinding = false; + state.expires_at = expires_at; + // The `receive_packet_buffer` field isn't populated until + // the client asks for the state, but receiving any packet + // will change it, so we indicate that the config has + // changed every time if the receive packet buffer is set, + // but we only write changes to the rest of the config now. + let config_changed = + state.config != config || self.receive_packet_buffer.is_some(); + if state.config != config { + state.config = config; + } + if config_changed { + self.config_changed(); + } + } + } + (ClientState::Renewing(_), DhcpMessageType::Nak) => { + if !self.ignore_naks { + self.reset(); + } + } + _ => { + net_debug!( + "DHCP ignoring {:?}: unexpected in current state", + dhcp_repr.message_type + ); + } + } + } + + fn parse_ack( + now: Instant, + dhcp_repr: &DhcpRepr, + max_lease_duration: Option<Duration>, + server: ServerInfo, + ) -> Option<(Config<'static>, Instant, Instant, Instant)> { + let subnet_mask = match dhcp_repr.subnet_mask { + Some(subnet_mask) => subnet_mask, + None => { + net_debug!("DHCP ignoring ACK because missing subnet_mask"); + return None; + } + }; + + let prefix_len = match IpAddress::Ipv4(subnet_mask).prefix_len() { + Some(prefix_len) => prefix_len, + None => { + net_debug!("DHCP ignoring ACK because subnet_mask is not a valid mask"); + return None; + } + }; + + if !dhcp_repr.your_ip.is_unicast() { + net_debug!("DHCP ignoring ACK because your_ip is not unicast"); + return None; + } + + let mut lease_duration = dhcp_repr + .lease_duration + .map(|d| Duration::from_secs(d as _)) + .unwrap_or(DEFAULT_LEASE_DURATION); + if let Some(max_lease_duration) = max_lease_duration { + lease_duration = lease_duration.min(max_lease_duration); + } + + // Cleanup the DNS servers list, keeping only unicasts/ + // TP-Link TD-W8970 sends 0.0.0.0 as second DNS server if there's only one configured :( + let mut dns_servers = Vec::new(); + + dhcp_repr + .dns_servers + .iter() + .flatten() + .filter(|s| s.is_unicast()) + .for_each(|a| { + // This will never produce an error, as both the arrays and `dns_servers` + // have length DHCP_MAX_DNS_SERVER_COUNT + dns_servers.push(*a).ok(); + }); + + let config = Config { + server, + address: Ipv4Cidr::new(dhcp_repr.your_ip, prefix_len), + router: dhcp_repr.router, + dns_servers, + packet: None, + }; + + // Set renew and rebind times as per RFC 2131: + // Times T1 and T2 are configurable by the server through + // options. T1 defaults to (0.5 * duration_of_lease). T2 + // defaults to (0.875 * duration_of_lease). + let (renew_duration, rebind_duration) = match ( + dhcp_repr + .renew_duration + .map(|d| Duration::from_secs(d as u64)), + dhcp_repr + .rebind_duration + .map(|d| Duration::from_secs(d as u64)), + ) { + (Some(renew_duration), Some(rebind_duration)) => (renew_duration, rebind_duration), + (None, None) => (lease_duration / 2, lease_duration * 7 / 8), + // RFC 2131 does not say what to do if only one value is + // provided, so: + + // If only T1 is provided, set T2 to be 0.75 through the gap + // between T1 and the duration of the lease. If T1 is set to + // the default (0.5 * duration_of_lease), then T2 will also + // be set to the default (0.875 * duration_of_lease). + (Some(renew_duration), None) => ( + renew_duration, + renew_duration + (lease_duration - renew_duration) * 3 / 4, + ), + + // If only T2 is provided, then T1 will be set to be + // whichever is smaller of the default (0.5 * + // duration_of_lease) or T2. + (None, Some(rebind_duration)) => { + ((lease_duration / 2).min(rebind_duration), rebind_duration) + } + }; + let renew_at = now + renew_duration; + let rebind_at = now + rebind_duration; + let expires_at = now + lease_duration; + + Some((config, renew_at, rebind_at, expires_at)) + } + + #[cfg(not(test))] + fn random_transaction_id(cx: &mut Context) -> u32 { + cx.rand().rand_u32() + } + + #[cfg(test)] + fn random_transaction_id(_cx: &mut Context) -> u32 { + 0x12345678 + } + + pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (Ipv4Repr, UdpRepr, DhcpRepr)) -> Result<(), E>, + { + // note: Dhcpv4Socket is only usable in ethernet mediums, so the + // unwrap can never fail. + let HardwareAddress::Ethernet(ethernet_addr) = cx.hardware_addr() else { + panic!("using DHCPv4 socket with a non-ethernet hardware address."); + }; + + // Worst case biggest IPv4 header length. + // 0x0f * 4 = 60 bytes. + const MAX_IPV4_HEADER_LEN: usize = 60; + + // We don't directly modify self.transaction_id because sending the packet + // may fail. We only want to update state after successfully sending. + let next_transaction_id = Self::random_transaction_id(cx); + + let mut dhcp_repr = DhcpRepr { + message_type: DhcpMessageType::Discover, + transaction_id: next_transaction_id, + secs: 0, + client_hardware_address: ethernet_addr, + client_ip: Ipv4Address::UNSPECIFIED, + your_ip: Ipv4Address::UNSPECIFIED, + server_ip: Ipv4Address::UNSPECIFIED, + router: None, + subnet_mask: None, + relay_agent_ip: Ipv4Address::UNSPECIFIED, + broadcast: false, + requested_ip: None, + client_identifier: Some(ethernet_addr), + server_identifier: None, + parameter_request_list: Some( + self.parameter_request_list + .unwrap_or(DEFAULT_PARAMETER_REQUEST_LIST), + ), + max_size: Some((cx.ip_mtu() - MAX_IPV4_HEADER_LEN - UDP_HEADER_LEN) as u16), + lease_duration: None, + renew_duration: None, + rebind_duration: None, + dns_servers: None, + additional_options: self.outgoing_options, + }; + + let udp_repr = UdpRepr { + src_port: self.client_port, + dst_port: self.server_port, + }; + + let mut ipv4_repr = Ipv4Repr { + src_addr: Ipv4Address::UNSPECIFIED, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: 0, // filled right before emit + hop_limit: 64, + }; + + match &mut self.state { + ClientState::Discovering(state) => { + if cx.now() < state.retry_at { + return Ok(()); + } + + // send packet + net_debug!( + "DHCP send DISCOVER to {}: {:?}", + ipv4_repr.dst_addr, + dhcp_repr + ); + ipv4_repr.payload_len = udp_repr.header_len() + dhcp_repr.buffer_len(); + emit(cx, (ipv4_repr, udp_repr, dhcp_repr))?; + + // Update state AFTER the packet has been successfully sent. + state.retry_at = cx.now() + self.retry_config.discover_timeout; + self.transaction_id = next_transaction_id; + Ok(()) + } + ClientState::Requesting(state) => { + if cx.now() < state.retry_at { + return Ok(()); + } + + if state.retry >= self.retry_config.request_retries { + net_debug!("DHCP request retries exceeded, restarting discovery"); + self.reset(); + return Ok(()); + } + + dhcp_repr.message_type = DhcpMessageType::Request; + dhcp_repr.requested_ip = Some(state.requested_ip); + dhcp_repr.server_identifier = Some(state.server.identifier); + + net_debug!( + "DHCP send request to {}: {:?}", + ipv4_repr.dst_addr, + dhcp_repr + ); + ipv4_repr.payload_len = udp_repr.header_len() + dhcp_repr.buffer_len(); + emit(cx, (ipv4_repr, udp_repr, dhcp_repr))?; + + // Exponential backoff: Double every 2 retries. + state.retry_at = cx.now() + + (self.retry_config.initial_request_timeout << (state.retry as u32 / 2)); + state.retry += 1; + + self.transaction_id = next_transaction_id; + Ok(()) + } + ClientState::Renewing(state) => { + let now = cx.now(); + if state.expires_at <= now { + net_debug!("DHCP lease expired"); + self.reset(); + // return Ok so we get polled again + return Ok(()); + } + + if now < state.renew_at || state.rebinding && now < state.rebind_at { + return Ok(()); + } + + state.rebinding |= now >= state.rebind_at; + + ipv4_repr.src_addr = state.config.address.address(); + // Renewing is unicast to the original server, rebinding is broadcast + if !state.rebinding { + ipv4_repr.dst_addr = state.config.server.address; + } + dhcp_repr.message_type = DhcpMessageType::Request; + dhcp_repr.client_ip = state.config.address.address(); + + net_debug!("DHCP send renew to {}: {:?}", ipv4_repr.dst_addr, dhcp_repr); + ipv4_repr.payload_len = udp_repr.header_len() + dhcp_repr.buffer_len(); + emit(cx, (ipv4_repr, udp_repr, dhcp_repr))?; + + // In both RENEWING and REBINDING states, if the client receives no + // response to its DHCPREQUEST message, the client SHOULD wait one-half + // of the remaining time until T2 (in RENEWING state) and one-half of + // the remaining lease time (in REBINDING state), down to a minimum of + // 60 seconds, before retransmitting the DHCPREQUEST message. + if state.rebinding { + state.rebind_at = now + + self + .retry_config + .min_renew_timeout + .max((state.expires_at - now) / 2) + .min(self.retry_config.max_renew_timeout); + } else { + state.renew_at = now + + self + .retry_config + .min_renew_timeout + .max((state.rebind_at - now) / 2) + .min(state.rebind_at - now) + .min(self.retry_config.max_renew_timeout); + } + + self.transaction_id = next_transaction_id; + Ok(()) + } + } + } + + /// Reset state and restart discovery phase. + /// + /// Use this to speed up acquisition of an address in a new + /// network if a link was down and it is now back up. + pub fn reset(&mut self) { + net_trace!("DHCP reset"); + if let ClientState::Renewing(_) = &self.state { + self.config_changed(); + } + self.state = ClientState::Discovering(DiscoverState { + retry_at: Instant::from_millis(0), + }); + } + + /// Query the socket for configuration changes. + /// + /// The socket has an internal "configuration changed" flag. If + /// set, this function returns the configuration and resets the flag. + pub fn poll(&mut self) -> Option<Event> { + if !self.config_changed { + None + } else if let ClientState::Renewing(state) = &self.state { + self.config_changed = false; + Some(Event::Configured(Config { + server: state.config.server, + address: state.config.address, + router: state.config.router, + dns_servers: state.config.dns_servers.clone(), + packet: self + .receive_packet_buffer + .as_deref() + .map(DhcpPacket::new_unchecked), + })) + } else { + self.config_changed = false; + Some(Event::Deconfigured) + } + } + + /// This function _must_ be called when the configuration provided to the + /// interface, by this DHCP socket, changes. It will update the `config_changed` field + /// so that a subsequent call to `poll` will yield an event, and wake a possible waker. + pub(crate) fn config_changed(&mut self) { + self.config_changed = true; + #[cfg(feature = "async")] + self.waker.wake(); + } + + /// Register a waker. + /// + /// The waker is woken on state changes that might affect the return value + /// of `poll` method calls, which indicates a new state in the DHCP configuration + /// provided by this DHCP socket. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + #[cfg(feature = "async")] + pub fn register_waker(&mut self, waker: &Waker) { + self.waker.register(waker) + } +} + +#[cfg(test)] +mod test { + + use std::ops::{Deref, DerefMut}; + + use super::*; + use crate::wire::EthernetAddress; + + // =========================================================================================// + // Helper functions + + struct TestSocket { + socket: Socket<'static>, + cx: Context, + } + + impl Deref for TestSocket { + type Target = Socket<'static>; + fn deref(&self) -> &Self::Target { + &self.socket + } + } + + impl DerefMut for TestSocket { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.socket + } + } + + fn send( + s: &mut TestSocket, + timestamp: Instant, + (ip_repr, udp_repr, dhcp_repr): (Ipv4Repr, UdpRepr, DhcpRepr), + ) { + s.cx.set_now(timestamp); + + net_trace!("send: {:?}", ip_repr); + net_trace!(" {:?}", udp_repr); + net_trace!(" {:?}", dhcp_repr); + + let mut payload = vec![0; dhcp_repr.buffer_len()]; + dhcp_repr + .emit(&mut DhcpPacket::new_unchecked(&mut payload)) + .unwrap(); + + s.socket.process(&mut s.cx, &ip_repr, &udp_repr, &payload) + } + + fn recv(s: &mut TestSocket, timestamp: Instant, reprs: &[(Ipv4Repr, UdpRepr, DhcpRepr)]) { + s.cx.set_now(timestamp); + + let mut i = 0; + + while s.socket.poll_at(&mut s.cx) <= PollAt::Time(timestamp) { + let _ = s + .socket + .dispatch(&mut s.cx, |_, (mut ip_repr, udp_repr, dhcp_repr)| { + assert_eq!(ip_repr.next_header, IpProtocol::Udp); + assert_eq!( + ip_repr.payload_len, + udp_repr.header_len() + dhcp_repr.buffer_len() + ); + + // We validated the payload len, change it to 0 to make equality testing easier + ip_repr.payload_len = 0; + + net_trace!("recv: {:?}", ip_repr); + net_trace!(" {:?}", udp_repr); + net_trace!(" {:?}", dhcp_repr); + + let got_repr = (ip_repr, udp_repr, dhcp_repr); + match reprs.get(i) { + Some(want_repr) => assert_eq!(want_repr, &got_repr), + None => panic!("Too many reprs emitted"), + } + i += 1; + Ok::<_, ()>(()) + }); + } + + assert_eq!(i, reprs.len()); + } + + macro_rules! send { + ($socket:ident, $repr:expr) => + (send!($socket, time 0, $repr)); + ($socket:ident, time $time:expr, $repr:expr) => + (send(&mut $socket, Instant::from_millis($time), $repr)); + } + + macro_rules! recv { + ($socket:ident, $reprs:expr) => ({ + recv!($socket, time 0, $reprs); + }); + ($socket:ident, time $time:expr, $reprs:expr) => ({ + recv(&mut $socket, Instant::from_millis($time), &$reprs); + }); + } + + // =========================================================================================// + // Constants + + const TXID: u32 = 0x12345678; + + const MY_IP: Ipv4Address = Ipv4Address([192, 168, 1, 42]); + const SERVER_IP: Ipv4Address = Ipv4Address([192, 168, 1, 1]); + const DNS_IP_1: Ipv4Address = Ipv4Address([1, 1, 1, 1]); + const DNS_IP_2: Ipv4Address = Ipv4Address([1, 1, 1, 2]); + const DNS_IP_3: Ipv4Address = Ipv4Address([1, 1, 1, 3]); + const DNS_IPS: &[Ipv4Address] = &[DNS_IP_1, DNS_IP_2, DNS_IP_3]; + + const MASK_24: Ipv4Address = Ipv4Address([255, 255, 255, 0]); + + const MY_MAC: EthernetAddress = EthernetAddress([0x02, 0x02, 0x02, 0x02, 0x02, 0x02]); + + const IP_BROADCAST: Ipv4Repr = Ipv4Repr { + src_addr: Ipv4Address::UNSPECIFIED, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_BROADCAST_ADDRESSED: Ipv4Repr = Ipv4Repr { + src_addr: MY_IP, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_SERVER_BROADCAST: Ipv4Repr = Ipv4Repr { + src_addr: SERVER_IP, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_RECV: Ipv4Repr = Ipv4Repr { + src_addr: SERVER_IP, + dst_addr: MY_IP, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_SEND: Ipv4Repr = Ipv4Repr { + src_addr: MY_IP, + dst_addr: SERVER_IP, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const UDP_SEND: UdpRepr = UdpRepr { + src_port: DHCP_CLIENT_PORT, + dst_port: DHCP_SERVER_PORT, + }; + const UDP_RECV: UdpRepr = UdpRepr { + src_port: DHCP_SERVER_PORT, + dst_port: DHCP_CLIENT_PORT, + }; + + const DIFFERENT_CLIENT_PORT: u16 = 6800; + const DIFFERENT_SERVER_PORT: u16 = 6700; + + const UDP_SEND_DIFFERENT_PORT: UdpRepr = UdpRepr { + src_port: DIFFERENT_CLIENT_PORT, + dst_port: DIFFERENT_SERVER_PORT, + }; + const UDP_RECV_DIFFERENT_PORT: UdpRepr = UdpRepr { + src_port: DIFFERENT_SERVER_PORT, + dst_port: DIFFERENT_CLIENT_PORT, + }; + + const DHCP_DEFAULT: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Unknown(99), + transaction_id: TXID, + secs: 0, + client_hardware_address: MY_MAC, + client_ip: Ipv4Address::UNSPECIFIED, + your_ip: Ipv4Address::UNSPECIFIED, + server_ip: Ipv4Address::UNSPECIFIED, + router: None, + subnet_mask: None, + relay_agent_ip: Ipv4Address::UNSPECIFIED, + broadcast: false, + requested_ip: None, + client_identifier: None, + server_identifier: None, + parameter_request_list: None, + dns_servers: None, + max_size: None, + renew_duration: None, + rebind_duration: None, + lease_duration: None, + additional_options: &[], + }; + + const DHCP_DISCOVER: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Discover, + client_identifier: Some(MY_MAC), + parameter_request_list: Some(&[1, 3, 6]), + max_size: Some(1432), + ..DHCP_DEFAULT + }; + + fn dhcp_offer() -> DhcpRepr<'static> { + DhcpRepr { + message_type: DhcpMessageType::Offer, + server_ip: SERVER_IP, + server_identifier: Some(SERVER_IP), + + your_ip: MY_IP, + router: Some(SERVER_IP), + subnet_mask: Some(MASK_24), + dns_servers: Some(Vec::from_slice(DNS_IPS).unwrap()), + lease_duration: Some(1000), + + ..DHCP_DEFAULT + } + } + + const DHCP_REQUEST: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Request, + client_identifier: Some(MY_MAC), + server_identifier: Some(SERVER_IP), + max_size: Some(1432), + + requested_ip: Some(MY_IP), + parameter_request_list: Some(&[1, 3, 6]), + ..DHCP_DEFAULT + }; + + fn dhcp_ack() -> DhcpRepr<'static> { + DhcpRepr { + message_type: DhcpMessageType::Ack, + server_ip: SERVER_IP, + server_identifier: Some(SERVER_IP), + + your_ip: MY_IP, + router: Some(SERVER_IP), + subnet_mask: Some(MASK_24), + dns_servers: Some(Vec::from_slice(DNS_IPS).unwrap()), + lease_duration: Some(1000), + + ..DHCP_DEFAULT + } + } + + const DHCP_NAK: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Nak, + server_ip: SERVER_IP, + server_identifier: Some(SERVER_IP), + ..DHCP_DEFAULT + }; + + const DHCP_RENEW: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Request, + client_identifier: Some(MY_MAC), + // NO server_identifier in renew requests, only in first one! + client_ip: MY_IP, + max_size: Some(1432), + + requested_ip: None, + parameter_request_list: Some(&[1, 3, 6]), + ..DHCP_DEFAULT + }; + + const DHCP_REBIND: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Request, + client_identifier: Some(MY_MAC), + // NO server_identifier in renew requests, only in first one! + client_ip: MY_IP, + max_size: Some(1432), + + requested_ip: None, + parameter_request_list: Some(&[1, 3, 6]), + ..DHCP_DEFAULT + }; + + // =========================================================================================// + // Tests + + use crate::phy::Medium; + use crate::tests::setup; + use rstest::*; + + fn socket(medium: Medium) -> TestSocket { + let (iface, _, _) = setup(medium); + let mut s = Socket::new(); + assert_eq!(s.poll(), Some(Event::Deconfigured)); + TestSocket { + socket: s, + cx: iface.inner, + } + } + + fn socket_different_port(medium: Medium) -> TestSocket { + let (iface, _, _) = setup(medium); + let mut s = Socket::new(); + s.set_ports(DIFFERENT_SERVER_PORT, DIFFERENT_CLIENT_PORT); + + assert_eq!(s.poll(), Some(Event::Deconfigured)); + TestSocket { + socket: s, + cx: iface.inner, + } + } + + fn socket_bound(medium: Medium) -> TestSocket { + let mut s = socket(medium); + s.state = ClientState::Renewing(RenewState { + config: Config { + server: ServerInfo { + address: SERVER_IP, + identifier: SERVER_IP, + }, + address: Ipv4Cidr::new(MY_IP, 24), + dns_servers: Vec::from_slice(DNS_IPS).unwrap(), + router: Some(SERVER_IP), + packet: None, + }, + renew_at: Instant::from_secs(500), + rebind_at: Instant::from_secs(875), + rebinding: false, + expires_at: Instant::from_secs(1000), + }); + + s + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_bind(#[case] medium: Medium) { + let mut s = socket(medium); + + recv!(s, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV, dhcp_offer())); + assert_eq!(s.poll(), None); + recv!(s, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV, dhcp_ack())); + + assert_eq!( + s.poll(), + Some(Event::Configured(Config { + server: ServerInfo { + address: SERVER_IP, + identifier: SERVER_IP, + }, + address: Ipv4Cidr::new(MY_IP, 24), + dns_servers: Vec::from_slice(DNS_IPS).unwrap(), + router: Some(SERVER_IP), + packet: None, + })) + ); + + match &s.state { + ClientState::Renewing(r) => { + assert_eq!(r.renew_at, Instant::from_secs(500)); + assert_eq!(r.rebind_at, Instant::from_secs(875)); + assert_eq!(r.expires_at, Instant::from_secs(1000)); + } + _ => panic!("Invalid state"), + } + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_bind_different_ports(#[case] medium: Medium) { + let mut s = socket_different_port(medium); + + recv!(s, [(IP_BROADCAST, UDP_SEND_DIFFERENT_PORT, DHCP_DISCOVER)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV_DIFFERENT_PORT, dhcp_offer())); + assert_eq!(s.poll(), None); + recv!(s, [(IP_BROADCAST, UDP_SEND_DIFFERENT_PORT, DHCP_REQUEST)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV_DIFFERENT_PORT, dhcp_ack())); + + assert_eq!( + s.poll(), + Some(Event::Configured(Config { + server: ServerInfo { + address: SERVER_IP, + identifier: SERVER_IP, + }, + address: Ipv4Cidr::new(MY_IP, 24), + dns_servers: Vec::from_slice(DNS_IPS).unwrap(), + router: Some(SERVER_IP), + packet: None, + })) + ); + + match &s.state { + ClientState::Renewing(r) => { + assert_eq!(r.renew_at, Instant::from_secs(500)); + assert_eq!(r.rebind_at, Instant::from_secs(875)); + assert_eq!(r.expires_at, Instant::from_secs(1000)); + } + _ => panic!("Invalid state"), + } + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_discover_retransmit(#[case] medium: Medium) { + let mut s = socket(medium); + + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + recv!(s, time 1_000, []); + recv!(s, time 10_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + recv!(s, time 11_000, []); + recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + + // check after retransmits it still works + send!(s, time 20_000, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_request_retransmit(#[case] medium: Medium) { + let mut s = socket(medium); + + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + send!(s, time 0, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 1_000, []); + recv!(s, time 5_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 6_000, []); + recv!(s, time 10_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 15_000, []); + recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + + // check after retransmits it still works + send!(s, time 20_000, (IP_RECV, UDP_RECV, dhcp_ack())); + + match &s.state { + ClientState::Renewing(r) => { + assert_eq!(r.renew_at, Instant::from_secs(20 + 500)); + assert_eq!(r.expires_at, Instant::from_secs(20 + 1000)); + } + _ => panic!("Invalid state"), + } + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_request_timeout(#[case] medium: Medium) { + let mut s = socket(medium); + + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + send!(s, time 0, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 5_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 10_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 30_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + + // After 5 tries and 70 seconds, it gives up. + // 5 + 5 + 10 + 10 + 20 = 70 + recv!(s, time 70_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + + // check it still works + send!(s, time 60_000, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 60_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_request_nak(#[case] medium: Medium) { + let mut s = socket(medium); + + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + send!(s, time 0, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + send!(s, time 0, (IP_SERVER_BROADCAST, UDP_RECV, DHCP_NAK)); + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_renew(#[case] medium: Medium) { + let mut s = socket_bound(medium); + + recv!(s, []); + assert_eq!(s.poll(), None); + recv!(s, time 500_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + assert_eq!(s.poll(), None); + + match &s.state { + ClientState::Renewing(r) => { + // the expiration still hasn't been bumped, because + // we haven't received the ACK yet + assert_eq!(r.expires_at, Instant::from_secs(1000)); + } + _ => panic!("Invalid state"), + } + + send!(s, time 500_000, (IP_RECV, UDP_RECV, dhcp_ack())); + assert_eq!(s.poll(), None); + + match &s.state { + ClientState::Renewing(r) => { + // NOW the expiration gets bumped + assert_eq!(r.renew_at, Instant::from_secs(500 + 500)); + assert_eq!(r.expires_at, Instant::from_secs(500 + 1000)); + } + _ => panic!("Invalid state"), + } + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_renew_rebind_retransmit(#[case] medium: Medium) { + let mut s = socket_bound(medium); + + recv!(s, []); + // First renew attempt at T1 + recv!(s, time 499_000, []); + recv!(s, time 500_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt at half way to T2 + recv!(s, time 687_000, []); + recv!(s, time 687_500, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt at half way again to T2 + recv!(s, time 781_000, []); + recv!(s, time 781_250, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt 60s later (minimum interval) + recv!(s, time 841_000, []); + recv!(s, time 841_250, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // No more renews due to minimum interval + recv!(s, time 874_000, []); + // First rebind attempt + recv!(s, time 875_000, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // Next rebind attempt half way to expiry + recv!(s, time 937_000, []); + recv!(s, time 937_500, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // Next rebind attempt 60s later (minimum interval) + recv!(s, time 997_000, []); + recv!(s, time 997_500, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + + // check it still works + send!(s, time 999_000, (IP_RECV, UDP_RECV, dhcp_ack())); + match &s.state { + ClientState::Renewing(r) => { + // NOW the expiration gets bumped + assert_eq!(r.renew_at, Instant::from_secs(999 + 500)); + assert_eq!(r.expires_at, Instant::from_secs(999 + 1000)); + } + _ => panic!("Invalid state"), + } + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_renew_rebind_timeout(#[case] medium: Medium) { + let mut s = socket_bound(medium); + + recv!(s, []); + // First renew attempt at T1 + recv!(s, time 500_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt at half way to T2 + recv!(s, time 687_500, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt at half way again to T2 + recv!(s, time 781_250, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt 60s later (minimum interval) + recv!(s, time 841_250, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // TODO uncomment below part of test + // // First rebind attempt + // recv!(s, time 875_000, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // // Next rebind attempt half way to expiry + // recv!(s, time 937_500, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // // Next rebind attempt 60s later (minimum interval) + // recv!(s, time 997_500, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // No more rebinds due to minimum interval + recv!(s, time 1_000_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + match &s.state { + ClientState::Discovering(_) => {} + _ => panic!("Invalid state"), + } + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_min_max_renew_timeout(#[case] medium: Medium) { + let mut s = socket_bound(medium); + // Set a minimum of 45s and a maximum of 120s + let config = RetryConfig { + max_renew_timeout: Duration::from_secs(120), + min_renew_timeout: Duration::from_secs(45), + ..s.get_retry_config() + }; + s.set_retry_config(config); + recv!(s, []); + // First renew attempt at T1 + recv!(s, time 499_999, []); + recv!(s, time 500_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt 120s after T1 because we hit the max + recv!(s, time 619_999, []); + recv!(s, time 620_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt 120s after previous because we hit the max again + recv!(s, time 739_999, []); + recv!(s, time 740_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt half way to T2 + recv!(s, time 807_499, []); + recv!(s, time 807_500, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt 45s after previous because we hit the min + recv!(s, time 852_499, []); + recv!(s, time 852_500, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next is a rebind, because the min puts us after T2 + recv!(s, time 874_999, []); + recv!(s, time 875_000, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + } + + #[rstest] + #[case::ip(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_renew_nak(#[case] medium: Medium) { + let mut s = socket_bound(medium); + + recv!(s, time 500_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + send!(s, time 500_000, (IP_SERVER_BROADCAST, UDP_RECV, DHCP_NAK)); + recv!(s, time 500_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + } +} diff --git a/src/socket/dns.rs b/src/socket/dns.rs new file mode 100644 index 0000000..610d5c6 --- /dev/null +++ b/src/socket/dns.rs @@ -0,0 +1,699 @@ +#[cfg(feature = "async")] +use core::task::Waker; + +use heapless::Vec; +use managed::ManagedSlice; + +use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT}; +use crate::socket::{Context, PollAt}; +use crate::time::{Duration, Instant}; +use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type}; +use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr}; + +#[cfg(feature = "async")] +use super::WakerRegistration; + +const DNS_PORT: u16 = 53; +const MDNS_DNS_PORT: u16 = 5353; +const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000); +const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000); +const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs + +#[cfg(feature = "proto-ipv6")] +const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb, +])); + +#[cfg(feature = "proto-ipv4")] +const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251])); + +/// Error returned by [`Socket::start_query`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum StartQueryError { + NoFreeSlot, + InvalidName, + NameTooLong, +} + +impl core::fmt::Display for StartQueryError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + StartQueryError::NoFreeSlot => write!(f, "No free slot"), + StartQueryError::InvalidName => write!(f, "Invalid name"), + StartQueryError::NameTooLong => write!(f, "Name too long"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for StartQueryError {} + +/// Error returned by [`Socket::get_query_result`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum GetQueryResultError { + /// Query is not done yet. + Pending, + /// Query failed. + Failed, +} + +impl core::fmt::Display for GetQueryResultError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + GetQueryResultError::Pending => write!(f, "Query is not done yet"), + GetQueryResultError::Failed => write!(f, "Query failed"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for GetQueryResultError {} + +/// State for an in-progress DNS query. +/// +/// The only reason this struct is public is to allow the socket state +/// to be allocated externally. +#[derive(Debug)] +pub struct DnsQuery { + state: State, + + #[cfg(feature = "async")] + waker: WakerRegistration, +} + +impl DnsQuery { + fn set_state(&mut self, state: State) { + self.state = state; + #[cfg(feature = "async")] + self.waker.wake(); + } +} + +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +enum State { + Pending(PendingQuery), + Completed(CompletedQuery), + Failure, +} + +#[derive(Debug)] +struct PendingQuery { + name: Vec<u8, DNS_MAX_NAME_SIZE>, + type_: Type, + + port: u16, // UDP port (src for request, dst for response) + txid: u16, // transaction ID + + timeout_at: Option<Instant>, + retransmit_at: Instant, + delay: Duration, + + server_idx: usize, + mdns: MulticastDns, +} + +#[derive(Debug)] +pub enum MulticastDns { + Disabled, + #[cfg(feature = "socket-mdns")] + Enabled, +} + +#[derive(Debug)] +struct CompletedQuery { + addresses: Vec<IpAddress, DNS_MAX_RESULT_COUNT>, +} + +/// A handle to an in-progress DNS query. +#[derive(Clone, Copy)] +pub struct QueryHandle(usize); + +/// A Domain Name System socket. +/// +/// A UDP socket is bound to a specific endpoint, and owns transmit and receive +/// packet buffers. +#[derive(Debug)] +pub struct Socket<'a> { + servers: Vec<IpAddress, DNS_MAX_SERVER_COUNT>, + queries: ManagedSlice<'a, Option<DnsQuery>>, + + /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + hop_limit: Option<u8>, +} + +impl<'a> Socket<'a> { + /// Create a DNS socket. + /// + /// # Panics + /// + /// Panics if `servers.len() > MAX_SERVER_COUNT` + pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a> + where + Q: Into<ManagedSlice<'a, Option<DnsQuery>>>, + { + Socket { + servers: Vec::from_slice(servers).unwrap(), + queries: queries.into(), + hop_limit: None, + } + } + + /// Update the list of DNS servers, will replace all existing servers + /// + /// # Panics + /// + /// Panics if `servers.len() > MAX_SERVER_COUNT` + pub fn update_servers(&mut self, servers: &[IpAddress]) { + self.servers = Vec::from_slice(servers).unwrap(); + } + + /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// See also the [set_hop_limit](#method.set_hop_limit) method + pub fn hop_limit(&self) -> Option<u8> { + self.hop_limit + } + + /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// A socket without an explicitly set hop limit value uses the default [IANA recommended] + /// value (64). + /// + /// # Panics + /// + /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7]. + /// + /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml + /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7 + pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { + // A host MUST NOT send a datagram with a hop limit value of 0 + if let Some(0) = hop_limit { + panic!("the time-to-live value of a packet must not be zero") + } + + self.hop_limit = hop_limit + } + + fn find_free_query(&mut self) -> Option<QueryHandle> { + for (i, q) in self.queries.iter().enumerate() { + if q.is_none() { + return Some(QueryHandle(i)); + } + } + + match &mut self.queries { + ManagedSlice::Borrowed(_) => None, + #[cfg(feature = "alloc")] + ManagedSlice::Owned(queries) => { + queries.push(None); + let index = queries.len() - 1; + Some(QueryHandle(index)) + } + } + } + + /// Start a query. + /// + /// `name` is specified in human-friendly format, such as `"rust-lang.org"`. + /// It accepts names both with and without trailing dot, and they're treated + /// the same (there's no support for DNS search path). + pub fn start_query( + &mut self, + cx: &mut Context, + name: &str, + query_type: Type, + ) -> Result<QueryHandle, StartQueryError> { + let mut name = name.as_bytes(); + + if name.is_empty() { + net_trace!("invalid name: zero length"); + return Err(StartQueryError::InvalidName); + } + + // Remove trailing dot, if any + if name[name.len() - 1] == b'.' { + name = &name[..name.len() - 1]; + } + + let mut raw_name: Vec<u8, DNS_MAX_NAME_SIZE> = Vec::new(); + + let mut mdns = MulticastDns::Disabled; + #[cfg(feature = "socket-mdns")] + if name.split(|&c| c == b'.').last().unwrap() == b"local" { + net_trace!("Starting a mDNS query"); + mdns = MulticastDns::Enabled; + } + + for s in name.split(|&c| c == b'.') { + if s.len() > 63 { + net_trace!("invalid name: too long label"); + return Err(StartQueryError::InvalidName); + } + if s.is_empty() { + net_trace!("invalid name: zero length label"); + return Err(StartQueryError::InvalidName); + } + + // Push label + raw_name + .push(s.len() as u8) + .map_err(|_| StartQueryError::NameTooLong)?; + raw_name + .extend_from_slice(s) + .map_err(|_| StartQueryError::NameTooLong)?; + } + + // Push terminator. + raw_name + .push(0x00) + .map_err(|_| StartQueryError::NameTooLong)?; + + self.start_query_raw(cx, &raw_name, query_type, mdns) + } + + /// Start a query with a raw (wire-format) DNS name. + /// `b"\x09rust-lang\x03org\x00"` + /// + /// You probably want to use [`start_query`] instead. + pub fn start_query_raw( + &mut self, + cx: &mut Context, + raw_name: &[u8], + query_type: Type, + mdns: MulticastDns, + ) -> Result<QueryHandle, StartQueryError> { + let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?; + + self.queries[handle.0] = Some(DnsQuery { + state: State::Pending(PendingQuery { + name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?, + type_: query_type, + txid: cx.rand().rand_u16(), + port: cx.rand().rand_source_port(), + delay: RETRANSMIT_DELAY, + timeout_at: None, + retransmit_at: Instant::ZERO, + server_idx: 0, + mdns, + }), + #[cfg(feature = "async")] + waker: WakerRegistration::new(), + }); + Ok(handle) + } + + /// Get the result of a query. + /// + /// If the query is completed, the query slot is automatically freed. + /// + /// # Panics + /// Panics if the QueryHandle corresponds to a free slot. + pub fn get_query_result( + &mut self, + handle: QueryHandle, + ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError> { + let slot = &mut self.queries[handle.0]; + let q = slot.as_mut().unwrap(); + match &mut q.state { + // Query is not done yet. + State::Pending(_) => Err(GetQueryResultError::Pending), + // Query is done + State::Completed(q) => { + let res = q.addresses.clone(); + *slot = None; // Free up the slot for recycling. + Ok(res) + } + State::Failure => { + *slot = None; // Free up the slot for recycling. + Err(GetQueryResultError::Failed) + } + } + } + + /// Cancels a query, freeing the slot. + /// + /// # Panics + /// + /// Panics if the QueryHandle corresponds to an already free slot. + pub fn cancel_query(&mut self, handle: QueryHandle) { + let slot = &mut self.queries[handle.0]; + if slot.is_none() { + panic!("Canceling query in a free slot.") + } + *slot = None; // Free up the slot for recycling. + } + + /// Assign a waker to a query slot + /// + /// The waker will be woken when the query completes, either successfully or failed. + /// + /// # Panics + /// + /// Panics if the QueryHandle corresponds to an already free slot. + #[cfg(feature = "async")] + pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) { + self.queries[handle.0] + .as_mut() + .unwrap() + .waker + .register(waker); + } + + pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool { + (udp_repr.src_port == DNS_PORT + && self + .servers + .iter() + .any(|server| *server == ip_repr.src_addr())) + || (udp_repr.src_port == MDNS_DNS_PORT) + } + + pub(crate) fn process( + &mut self, + _cx: &mut Context, + ip_repr: &IpRepr, + udp_repr: &UdpRepr, + payload: &[u8], + ) { + debug_assert!(self.accepts(ip_repr, udp_repr)); + + let size = payload.len(); + + net_trace!( + "receiving {} octets from {:?}:{}", + size, + ip_repr.src_addr(), + udp_repr.dst_port + ); + + let p = match Packet::new_checked(payload) { + Ok(x) => x, + Err(_) => { + net_trace!("dns packet malformed"); + return; + } + }; + if p.opcode() != Opcode::Query { + net_trace!("unwanted opcode {:?}", p.opcode()); + return; + } + + if !p.flags().contains(Flags::RESPONSE) { + net_trace!("packet doesn't have response bit set"); + return; + } + + if p.question_count() != 1 { + net_trace!("bad question count {:?}", p.question_count()); + return; + } + + // Find pending query + for q in self.queries.iter_mut().flatten() { + if let State::Pending(pq) = &mut q.state { + if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid { + continue; + } + + if p.rcode() == Rcode::NXDomain { + net_trace!("rcode NXDomain"); + q.set_state(State::Failure); + continue; + } + + let payload = p.payload(); + let (mut payload, question) = match Question::parse(payload) { + Ok(x) => x, + Err(_) => { + net_trace!("question malformed"); + return; + } + }; + + if question.type_ != pq.type_ { + net_trace!("question type mismatch"); + return; + } + + match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) { + Ok(true) => {} + Ok(false) => { + net_trace!("question name mismatch"); + return; + } + Err(_) => { + net_trace!("dns question name malformed"); + return; + } + } + + let mut addresses = Vec::new(); + + for _ in 0..p.answer_record_count() { + let (payload2, r) = match Record::parse(payload) { + Ok(x) => x, + Err(_) => { + net_trace!("dns answer record malformed"); + return; + } + }; + payload = payload2; + + match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) { + Ok(true) => {} + Ok(false) => { + net_trace!("answer name mismatch: {:?}", r); + continue; + } + Err(_) => { + net_trace!("dns answer record name malformed"); + return; + } + } + + match r.data { + #[cfg(feature = "proto-ipv4")] + RecordData::A(addr) => { + net_trace!("A: {:?}", addr); + if addresses.push(addr.into()).is_err() { + net_trace!("too many addresses in response, ignoring {:?}", addr); + } + } + #[cfg(feature = "proto-ipv6")] + RecordData::Aaaa(addr) => { + net_trace!("AAAA: {:?}", addr); + if addresses.push(addr.into()).is_err() { + net_trace!("too many addresses in response, ignoring {:?}", addr); + } + } + RecordData::Cname(name) => { + net_trace!("CNAME: {:?}", name); + + // When faced with a CNAME, recursive resolvers are supposed to + // resolve the CNAME and append the results for it. + // + // We update the query with the new name, so that we pick up the A/AAAA + // records for the CNAME when we parse them later. + // I believe it's mandatory the CNAME results MUST come *after* in the + // packet, so it's enough to do one linear pass over it. + if copy_name(&mut pq.name, p.parse_name(name)).is_err() { + net_trace!("dns answer cname malformed"); + return; + } + } + RecordData::Other(type_, data) => { + net_trace!("unknown: {:?} {:?}", type_, data) + } + } + } + + q.set_state(if addresses.is_empty() { + State::Failure + } else { + State::Completed(CompletedQuery { addresses }) + }); + + // If we get here, packet matched the current query, stop processing. + return; + } + } + + // If we get here, packet matched with no query. + net_trace!("no query matched"); + } + + pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>, + { + let hop_limit = self.hop_limit.unwrap_or(64); + + for q in self.queries.iter_mut().flatten() { + if let State::Pending(pq) = &mut q.state { + // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns + // so we internally overwrite the servers for any of those queries + // in this function. + let servers = match pq.mdns { + #[cfg(feature = "socket-mdns")] + MulticastDns::Enabled => &[ + #[cfg(feature = "proto-ipv6")] + MDNS_IPV6_ADDR, + #[cfg(feature = "proto-ipv4")] + MDNS_IPV4_ADDR, + ], + MulticastDns::Disabled => self.servers.as_slice(), + }; + + let timeout = if let Some(timeout) = pq.timeout_at { + timeout + } else { + let v = cx.now() + RETRANSMIT_TIMEOUT; + pq.timeout_at = Some(v); + v + }; + + // Check timeout + if timeout < cx.now() { + // DNS timeout + pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT); + pq.retransmit_at = Instant::ZERO; + pq.delay = RETRANSMIT_DELAY; + + // Try next server. We check below whether we've tried all servers. + pq.server_idx += 1; + } + // Check if we've run out of servers to try. + if pq.server_idx >= servers.len() { + net_trace!("already tried all servers."); + q.set_state(State::Failure); + continue; + } + + // Check so the IP address is valid + if servers[pq.server_idx].is_unspecified() { + net_trace!("invalid unspecified DNS server addr."); + q.set_state(State::Failure); + continue; + } + + if pq.retransmit_at > cx.now() { + // query is waiting for retransmit + continue; + } + + let repr = Repr { + transaction_id: pq.txid, + flags: Flags::RECURSION_DESIRED, + opcode: Opcode::Query, + question: Question { + name: &pq.name, + type_: pq.type_, + }, + }; + + let mut payload = [0u8; 512]; + let payload = &mut payload[..repr.buffer_len()]; + repr.emit(&mut Packet::new_unchecked(payload)); + + let dst_port = match pq.mdns { + #[cfg(feature = "socket-mdns")] + MulticastDns::Enabled => MDNS_DNS_PORT, + MulticastDns::Disabled => DNS_PORT, + }; + + let udp_repr = UdpRepr { + src_port: pq.port, + dst_port, + }; + + let dst_addr = servers[pq.server_idx]; + let src_addr = cx.get_source_address(&dst_addr).unwrap(); // TODO remove unwrap + let ip_repr = IpRepr::new( + src_addr, + dst_addr, + IpProtocol::Udp, + udp_repr.header_len() + payload.len(), + hop_limit, + ); + + net_trace!( + "sending {} octets to {} from port {}", + payload.len(), + ip_repr.dst_addr(), + udp_repr.src_port + ); + + emit(cx, (ip_repr, udp_repr, payload))?; + + pq.retransmit_at = cx.now() + pq.delay; + pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2); + + return Ok(()); + } + } + + // Nothing to dispatch + Ok(()) + } + + pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt { + self.queries + .iter() + .flatten() + .filter_map(|q| match &q.state { + State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)), + State::Completed(_) => None, + State::Failure => None, + }) + .min() + .unwrap_or(PollAt::Ingress) + } +} + +fn eq_names<'a>( + mut a: impl Iterator<Item = wire::Result<&'a [u8]>>, + mut b: impl Iterator<Item = wire::Result<&'a [u8]>>, +) -> wire::Result<bool> { + loop { + match (a.next(), b.next()) { + // Handle errors + (Some(Err(e)), _) => return Err(e), + (_, Some(Err(e))) => return Err(e), + + // Both finished -> equal + (None, None) => return Ok(true), + + // One finished before the other -> not equal + (None, _) => return Ok(false), + (_, None) => return Ok(false), + + // Got two labels, check if they're equal + (Some(Ok(la)), Some(Ok(lb))) => { + if la != lb { + return Ok(false); + } + } + } + } +} + +fn copy_name<'a, const N: usize>( + dest: &mut Vec<u8, N>, + name: impl Iterator<Item = wire::Result<&'a [u8]>>, +) -> Result<(), wire::Error> { + dest.truncate(0); + + for label in name { + let label = label?; + dest.push(label.len() as u8).map_err(|_| wire::Error)?; + dest.extend_from_slice(label).map_err(|_| wire::Error)?; + } + + // Write terminator 0x00 + dest.push(0).map_err(|_| wire::Error)?; + + Ok(()) +} diff --git a/src/socket/icmp.rs b/src/socket/icmp.rs new file mode 100644 index 0000000..c18b754 --- /dev/null +++ b/src/socket/icmp.rs @@ -0,0 +1,1240 @@ +use core::cmp; +#[cfg(feature = "async")] +use core::task::Waker; + +use crate::phy::ChecksumCapabilities; +#[cfg(feature = "async")] +use crate::socket::WakerRegistration; +use crate::socket::{Context, PollAt}; + +use crate::storage::Empty; +use crate::wire::IcmpRepr; +#[cfg(feature = "proto-ipv4")] +use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Repr}; +#[cfg(feature = "proto-ipv6")] +use crate::wire::{Icmpv6Packet, Icmpv6Repr, Ipv6Repr}; +use crate::wire::{IpAddress, IpListenEndpoint, IpProtocol, IpRepr}; +use crate::wire::{UdpPacket, UdpRepr}; + +/// Error returned by [`Socket::bind`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum BindError { + InvalidState, + Unaddressable, +} + +impl core::fmt::Display for BindError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + BindError::InvalidState => write!(f, "invalid state"), + BindError::Unaddressable => write!(f, "unaddressable"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for BindError {} + +/// Error returned by [`Socket::send`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SendError { + Unaddressable, + BufferFull, +} + +impl core::fmt::Display for SendError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + SendError::Unaddressable => write!(f, "unaddressable"), + SendError::BufferFull => write!(f, "buffer full"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for SendError {} + +/// Error returned by [`Socket::recv`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecvError { + Exhausted, + Truncated, +} + +impl core::fmt::Display for RecvError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated => write!(f, "truncated"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RecvError {} + +/// Type of endpoint to bind the ICMP socket to. See [IcmpSocket::bind] for +/// more details. +/// +/// [IcmpSocket::bind]: struct.IcmpSocket.html#method.bind +#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Endpoint { + #[default] + Unspecified, + Ident(u16), + Udp(IpListenEndpoint), +} + +impl Endpoint { + pub fn is_specified(&self) -> bool { + match *self { + Endpoint::Ident(_) => true, + Endpoint::Udp(endpoint) => endpoint.port != 0, + Endpoint::Unspecified => false, + } + } +} + +/// An ICMP packet metadata. +pub type PacketMetadata = crate::storage::PacketMetadata<IpAddress>; + +/// An ICMP packet ring buffer. +pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, IpAddress>; + +/// A ICMP socket +/// +/// An ICMP socket is bound to a specific [IcmpEndpoint] which may +/// be a specific UDP port to listen for ICMP error messages related +/// to the port or a specific ICMP identifier value. See [bind] for +/// more details. +/// +/// [IcmpEndpoint]: enum.IcmpEndpoint.html +/// [bind]: #method.bind +#[derive(Debug)] +pub struct Socket<'a> { + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, + /// The endpoint this socket is communicating with + endpoint: Endpoint, + /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + hop_limit: Option<u8>, + #[cfg(feature = "async")] + rx_waker: WakerRegistration, + #[cfg(feature = "async")] + tx_waker: WakerRegistration, +} + +impl<'a> Socket<'a> { + /// Create an ICMP socket with the given buffers. + pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> { + Socket { + rx_buffer, + tx_buffer, + endpoint: Default::default(), + hop_limit: None, + #[cfg(feature = "async")] + rx_waker: WakerRegistration::new(), + #[cfg(feature = "async")] + tx_waker: WakerRegistration::new(), + } + } + + /// Register a waker for receive operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `recv` method calls, such as receiving data, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_recv_waker(&mut self, waker: &Waker) { + self.rx_waker.register(waker) + } + + /// Register a waker for send operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `send` method calls, such as space becoming available in the transmit + /// buffer, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_send_waker(&mut self, waker: &Waker) { + self.tx_waker.register(waker) + } + + /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// See also the [set_hop_limit](#method.set_hop_limit) method + pub fn hop_limit(&self) -> Option<u8> { + self.hop_limit + } + + /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// A socket without an explicitly set hop limit value uses the default [IANA recommended] + /// value (64). + /// + /// # Panics + /// + /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7]. + /// + /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml + /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7 + pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { + // A host MUST NOT send a datagram with a hop limit value of 0 + if let Some(0) = hop_limit { + panic!("the time-to-live value of a packet must not be zero") + } + + self.hop_limit = hop_limit + } + + /// Bind the socket to the given endpoint. + /// + /// This function returns `Err(Error::Illegal)` if the socket was open + /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)` + /// if `endpoint` is unspecified (see [is_specified]). + /// + /// # Examples + /// + /// ## Bind to ICMP Error messages associated with a specific UDP port: + /// + /// To [recv] ICMP error messages that are associated with a specific local + /// UDP port, the socket may be bound to a given port using [IcmpEndpoint::Udp]. + /// This may be useful for applications using UDP attempting to detect and/or + /// diagnose connection problems. + /// + /// ``` + /// use smoltcp::wire::IpListenEndpoint; + /// use smoltcp::socket::icmp; + /// # let rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]); + /// # let tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]); + /// + /// let mut icmp_socket = // ... + /// # icmp::Socket::new(rx_buffer, tx_buffer); + /// + /// // Bind to ICMP error responses for UDP packets sent from port 53. + /// let endpoint = IpListenEndpoint::from(53); + /// icmp_socket.bind(icmp::Endpoint::Udp(endpoint)).unwrap(); + /// ``` + /// + /// ## Bind to a specific ICMP identifier: + /// + /// To [send] and [recv] ICMP packets that are not associated with a specific UDP + /// port, the socket may be bound to a specific ICMP identifier using + /// [IcmpEndpoint::Ident]. This is useful for sending and receiving Echo Request/Reply + /// messages. + /// + /// ``` + /// use smoltcp::wire::IpListenEndpoint; + /// use smoltcp::socket::icmp; + /// # let rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]); + /// # let tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]); + /// + /// let mut icmp_socket = // ... + /// # icmp::Socket::new(rx_buffer, tx_buffer); + /// + /// // Bind to ICMP messages with the ICMP identifier 0x1234 + /// icmp_socket.bind(icmp::Endpoint::Ident(0x1234)).unwrap(); + /// ``` + /// + /// [is_specified]: enum.IcmpEndpoint.html#method.is_specified + /// [IcmpEndpoint::Ident]: enum.IcmpEndpoint.html#variant.Ident + /// [IcmpEndpoint::Udp]: enum.IcmpEndpoint.html#variant.Udp + /// [send]: #method.send + /// [recv]: #method.recv + pub fn bind<T: Into<Endpoint>>(&mut self, endpoint: T) -> Result<(), BindError> { + let endpoint = endpoint.into(); + if !endpoint.is_specified() { + return Err(BindError::Unaddressable); + } + + if self.is_open() { + return Err(BindError::InvalidState); + } + + self.endpoint = endpoint; + + #[cfg(feature = "async")] + { + self.rx_waker.wake(); + self.tx_waker.wake(); + } + + Ok(()) + } + + /// Check whether the transmit buffer is full. + #[inline] + pub fn can_send(&self) -> bool { + !self.tx_buffer.is_full() + } + + /// Check whether the receive buffer is not empty. + #[inline] + pub fn can_recv(&self) -> bool { + !self.rx_buffer.is_empty() + } + + /// Return the maximum number packets the socket can receive. + #[inline] + pub fn packet_recv_capacity(&self) -> usize { + self.rx_buffer.packet_capacity() + } + + /// Return the maximum number packets the socket can transmit. + #[inline] + pub fn packet_send_capacity(&self) -> usize { + self.tx_buffer.packet_capacity() + } + + /// Return the maximum number of bytes inside the recv buffer. + #[inline] + pub fn payload_recv_capacity(&self) -> usize { + self.rx_buffer.payload_capacity() + } + + /// Return the maximum number of bytes inside the transmit buffer. + #[inline] + pub fn payload_send_capacity(&self) -> usize { + self.tx_buffer.payload_capacity() + } + + /// Check whether the socket is open. + #[inline] + pub fn is_open(&self) -> bool { + self.endpoint != Endpoint::Unspecified + } + + /// Enqueue a packet to be sent to a given remote address, and return a pointer + /// to its payload. + /// + /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full, + /// `Err(Error::Truncated)` if the requested size is larger than the packet buffer + /// size, and `Err(Error::Unaddressable)` if the remote address is unspecified. + pub fn send(&mut self, size: usize, endpoint: IpAddress) -> Result<&mut [u8], SendError> { + if endpoint.is_unspecified() { + return Err(SendError::Unaddressable); + } + + let packet_buf = self + .tx_buffer + .enqueue(size, endpoint) + .map_err(|_| SendError::BufferFull)?; + + net_trace!("icmp:{}: buffer to send {} octets", endpoint, size); + Ok(packet_buf) + } + + /// Enqueue a packet to be send to a given remote address and pass the buffer + /// to the provided closure. The closure then returns the size of the data written + /// into the buffer. + /// + /// Also see [send](#method.send). + pub fn send_with<F>( + &mut self, + max_size: usize, + endpoint: IpAddress, + f: F, + ) -> Result<usize, SendError> + where + F: FnOnce(&mut [u8]) -> usize, + { + if endpoint.is_unspecified() { + return Err(SendError::Unaddressable); + } + + let size = self + .tx_buffer + .enqueue_with_infallible(max_size, endpoint, f) + .map_err(|_| SendError::BufferFull)?; + + net_trace!("icmp:{}: buffer to send {} octets", endpoint, size); + Ok(size) + } + + /// Enqueue a packet to be sent to a given remote address, and fill it from a slice. + /// + /// See also [send](#method.send). + pub fn send_slice(&mut self, data: &[u8], endpoint: IpAddress) -> Result<(), SendError> { + let packet_buf = self.send(data.len(), endpoint)?; + packet_buf.copy_from_slice(data); + Ok(()) + } + + /// Dequeue a packet received from a remote endpoint, and return the `IpAddress` as well + /// as a pointer to the payload. + /// + /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty. + pub fn recv(&mut self) -> Result<(&[u8], IpAddress), RecvError> { + let (endpoint, packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; + + net_trace!( + "icmp:{}: receive {} buffered octets", + endpoint, + packet_buf.len() + ); + Ok((packet_buf, endpoint)) + } + + /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, + /// and return the amount of octets copied as well as the `IpAddress` + /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// the packet is dropped and a `RecvError::Truncated` error is returned. + /// + /// See also [recv](#method.recv). + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpAddress), RecvError> { + let (buffer, endpoint) = self.recv()?; + + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + + let length = cmp::min(data.len(), buffer.len()); + data[..length].copy_from_slice(&buffer[..length]); + Ok((length, endpoint)) + } + + /// Filter determining which packets received by the interface are appended to + /// the given sockets received buffer. + pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, icmp_repr: &IcmpRepr) -> bool { + match (&self.endpoint, icmp_repr) { + // If we are bound to ICMP errors associated to a UDP port, only + // accept Destination Unreachable or Time Exceeded messages with + // the data containing a UDP packet send from the local port we + // are bound to. + #[cfg(feature = "proto-ipv4")] + ( + &Endpoint::Udp(endpoint), + &IcmpRepr::Ipv4( + Icmpv4Repr::DstUnreachable { data, header, .. } + | Icmpv4Repr::TimeExceeded { data, header, .. }, + ), + ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr()) => { + let packet = UdpPacket::new_unchecked(data); + match UdpRepr::parse( + &packet, + &header.src_addr.into(), + &header.dst_addr.into(), + &cx.checksum_caps(), + ) { + Ok(repr) => endpoint.port == repr.src_port, + Err(_) => false, + } + } + #[cfg(feature = "proto-ipv6")] + ( + &Endpoint::Udp(endpoint), + &IcmpRepr::Ipv6( + Icmpv6Repr::DstUnreachable { data, header, .. } + | Icmpv6Repr::TimeExceeded { data, header, .. }, + ), + ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr()) => { + let packet = UdpPacket::new_unchecked(data); + match UdpRepr::parse( + &packet, + &header.src_addr.into(), + &header.dst_addr.into(), + &cx.checksum_caps(), + ) { + Ok(repr) => endpoint.port == repr.src_port, + Err(_) => false, + } + } + // If we are bound to a specific ICMP identifier value, only accept an + // Echo Request/Reply with the identifier field matching the endpoint + // port. + #[cfg(feature = "proto-ipv4")] + ( + &Endpoint::Ident(bound_ident), + &IcmpRepr::Ipv4(Icmpv4Repr::EchoRequest { ident, .. }), + ) + | ( + &Endpoint::Ident(bound_ident), + &IcmpRepr::Ipv4(Icmpv4Repr::EchoReply { ident, .. }), + ) => ident == bound_ident, + #[cfg(feature = "proto-ipv6")] + ( + &Endpoint::Ident(bound_ident), + &IcmpRepr::Ipv6(Icmpv6Repr::EchoRequest { ident, .. }), + ) + | ( + &Endpoint::Ident(bound_ident), + &IcmpRepr::Ipv6(Icmpv6Repr::EchoReply { ident, .. }), + ) => ident == bound_ident, + _ => false, + } + } + + pub(crate) fn process(&mut self, _cx: &mut Context, ip_repr: &IpRepr, icmp_repr: &IcmpRepr) { + match icmp_repr { + #[cfg(feature = "proto-ipv4")] + IcmpRepr::Ipv4(icmp_repr) => { + net_trace!("icmp: receiving {} octets", icmp_repr.buffer_len()); + + match self + .rx_buffer + .enqueue(icmp_repr.buffer_len(), ip_repr.src_addr()) + { + Ok(packet_buf) => { + icmp_repr.emit( + &mut Icmpv4Packet::new_unchecked(packet_buf), + &ChecksumCapabilities::default(), + ); + } + Err(_) => net_trace!("icmp: buffer full, dropped incoming packet"), + } + } + #[cfg(feature = "proto-ipv6")] + IcmpRepr::Ipv6(icmp_repr) => { + net_trace!("icmp: receiving {} octets", icmp_repr.buffer_len()); + + match self + .rx_buffer + .enqueue(icmp_repr.buffer_len(), ip_repr.src_addr()) + { + Ok(packet_buf) => icmp_repr.emit( + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + &mut Icmpv6Packet::new_unchecked(packet_buf), + &ChecksumCapabilities::default(), + ), + Err(_) => net_trace!("icmp: buffer full, dropped incoming packet"), + } + } + } + + #[cfg(feature = "async")] + self.rx_waker.wake(); + } + + pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (IpRepr, IcmpRepr)) -> Result<(), E>, + { + let hop_limit = self.hop_limit.unwrap_or(64); + let res = self.tx_buffer.dequeue_with(|remote_endpoint, packet_buf| { + net_trace!( + "icmp:{}: sending {} octets", + remote_endpoint, + packet_buf.len() + ); + match *remote_endpoint { + #[cfg(feature = "proto-ipv4")] + IpAddress::Ipv4(dst_addr) => { + let src_addr = match cx.get_source_address_ipv4(&dst_addr) { + Some(addr) => addr, + None => { + net_trace!( + "icmp:{}: not find suitable source address, dropping", + remote_endpoint + ); + return Ok(()); + } + }; + let packet = Icmpv4Packet::new_unchecked(&*packet_buf); + let repr = match Icmpv4Repr::parse(&packet, &ChecksumCapabilities::ignored()) { + Ok(x) => x, + Err(_) => { + net_trace!( + "icmp:{}: malformed packet in queue, dropping", + remote_endpoint + ); + return Ok(()); + } + }; + let ip_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Icmp, + payload_len: repr.buffer_len(), + hop_limit, + }); + emit(cx, (ip_repr, IcmpRepr::Ipv4(repr))) + } + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(dst_addr) => { + let src_addr = match cx.get_source_address_ipv6(&dst_addr) { + Some(addr) => addr, + None => { + net_trace!( + "icmp:{}: not find suitable source address, dropping", + remote_endpoint + ); + return Ok(()); + } + }; + let packet = Icmpv6Packet::new_unchecked(&*packet_buf); + let repr = match Icmpv6Repr::parse( + &src_addr.into(), + &dst_addr.into(), + &packet, + &ChecksumCapabilities::ignored(), + ) { + Ok(x) => x, + Err(_) => { + net_trace!( + "icmp:{}: malformed packet in queue, dropping", + remote_endpoint + ); + return Ok(()); + } + }; + let ip_repr = IpRepr::Ipv6(Ipv6Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Icmpv6, + payload_len: repr.buffer_len(), + hop_limit, + }); + emit(cx, (ip_repr, IcmpRepr::Ipv6(repr))) + } + } + }); + match res { + Err(Empty) => Ok(()), + Ok(Err(e)) => Err(e), + Ok(Ok(())) => { + #[cfg(feature = "async")] + self.tx_waker.wake(); + Ok(()) + } + } + } + + pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt { + if self.tx_buffer.is_empty() { + PollAt::Ingress + } else { + PollAt::Now + } + } +} + +#[cfg(test)] +mod tests_common { + pub use super::*; + pub use crate::wire::IpAddress; + + pub fn buffer(packets: usize) -> PacketBuffer<'static> { + PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 66 * packets]) + } + + pub fn socket( + rx_buffer: PacketBuffer<'static>, + tx_buffer: PacketBuffer<'static>, + ) -> Socket<'static> { + Socket::new(rx_buffer, tx_buffer) + } + + pub const LOCAL_PORT: u16 = 53; + + pub static UDP_REPR: UdpRepr = UdpRepr { + src_port: 53, + dst_port: 9090, + }; + + pub static UDP_PAYLOAD: &[u8] = &[0xff; 10]; +} + +#[cfg(all(test, feature = "proto-ipv4"))] +mod test_ipv4 { + use crate::phy::Medium; + use crate::tests::setup; + use rstest::*; + + use super::tests_common::*; + use crate::wire::{Icmpv4DstUnreachable, IpEndpoint, Ipv4Address}; + + const REMOTE_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 2]); + const LOCAL_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 1]); + const LOCAL_END_V4: IpEndpoint = IpEndpoint { + addr: IpAddress::Ipv4(LOCAL_IPV4), + port: LOCAL_PORT, + }; + + static ECHOV4_REPR: Icmpv4Repr = Icmpv4Repr::EchoRequest { + ident: 0x1234, + seq_no: 0x5678, + data: &[0xff; 16], + }; + + static LOCAL_IPV4_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr { + src_addr: LOCAL_IPV4, + dst_addr: REMOTE_IPV4, + next_header: IpProtocol::Icmp, + payload_len: 24, + hop_limit: 0x40, + }); + + static REMOTE_IPV4_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr { + src_addr: REMOTE_IPV4, + dst_addr: LOCAL_IPV4, + next_header: IpProtocol::Icmp, + payload_len: 24, + hop_limit: 0x40, + }); + + #[test] + fn test_send_unaddressable() { + let mut socket = socket(buffer(0), buffer(1)); + assert_eq!( + socket.send_slice(b"abcdef", IpAddress::Ipv4(Ipv4Address::default())), + Err(SendError::Unaddressable) + ); + assert_eq!(socket.send_slice(b"abcdef", REMOTE_IPV4.into()), Ok(())); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_send_dispatch(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(0), buffer(1)); + let checksum = ChecksumCapabilities::default(); + + assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(())); + + // This buffer is too long + assert_eq!( + socket.send_slice(&[0xff; 67], REMOTE_IPV4.into()), + Err(SendError::BufferFull) + ); + assert!(socket.can_send()); + + let mut bytes = [0xff; 24]; + let mut packet = Icmpv4Packet::new_unchecked(&mut bytes); + ECHOV4_REPR.emit(&mut packet, &checksum); + + assert_eq!( + socket.send_slice(&*packet.into_inner(), REMOTE_IPV4.into()), + Ok(()) + ); + assert_eq!( + socket.send_slice(b"123456", REMOTE_IPV4.into()), + Err(SendError::BufferFull) + ); + assert!(!socket.can_send()); + + assert_eq!( + socket.dispatch(cx, |_, (ip_repr, icmp_repr)| { + assert_eq!(ip_repr, LOCAL_IPV4_REPR); + assert_eq!(icmp_repr, ECHOV4_REPR.into()); + Err(()) + }), + Err(()) + ); + // buffer is not taken off of the tx queue due to the error + assert!(!socket.can_send()); + + assert_eq!( + socket.dispatch(cx, |_, (ip_repr, icmp_repr)| { + assert_eq!(ip_repr, LOCAL_IPV4_REPR); + assert_eq!(icmp_repr, ECHOV4_REPR.into()); + Ok::<_, ()>(()) + }), + Ok(()) + ); + // buffer is taken off of the queue this time + assert!(socket.can_send()); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_set_hop_limit_v4(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut s = socket(buffer(0), buffer(1)); + let checksum = ChecksumCapabilities::default(); + + let mut bytes = [0xff; 24]; + let mut packet = Icmpv4Packet::new_unchecked(&mut bytes); + ECHOV4_REPR.emit(&mut packet, &checksum); + + s.set_hop_limit(Some(0x2a)); + + assert_eq!( + s.send_slice(&*packet.into_inner(), REMOTE_IPV4.into()), + Ok(()) + ); + assert_eq!( + s.dispatch(cx, |_, (ip_repr, _)| { + assert_eq!( + ip_repr, + IpRepr::Ipv4(Ipv4Repr { + src_addr: LOCAL_IPV4, + dst_addr: REMOTE_IPV4, + next_header: IpProtocol::Icmp, + payload_len: ECHOV4_REPR.buffer_len(), + hop_limit: 0x2a, + }) + ); + Ok::<_, ()>(()) + }), + Ok(()) + ); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_recv_process(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(1)); + assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); + + assert!(!socket.can_recv()); + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); + + let checksum = ChecksumCapabilities::default(); + + let mut bytes = [0xff; 24]; + let mut packet = Icmpv4Packet::new_unchecked(&mut bytes[..]); + ECHOV4_REPR.emit(&mut packet, &checksum); + let data = &*packet.into_inner(); + + assert!(socket.accepts(cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into())); + socket.process(cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into()); + assert!(socket.can_recv()); + + assert!(socket.accepts(cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into())); + socket.process(cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into()); + + assert_eq!(socket.recv(), Ok((data, REMOTE_IPV4.into()))); + assert!(!socket.can_recv()); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_accept_bad_id(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(1)); + assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); + + let checksum = ChecksumCapabilities::default(); + let mut bytes = [0xff; 20]; + let mut packet = Icmpv4Packet::new_unchecked(&mut bytes); + let icmp_repr = Icmpv4Repr::EchoRequest { + ident: 0x4321, + seq_no: 0x5678, + data: &[0xff; 16], + }; + icmp_repr.emit(&mut packet, &checksum); + + // Ensure that a packet with an identifier that isn't the bound + // ID is not accepted + assert!(!socket.accepts(cx, &REMOTE_IPV4_REPR, &icmp_repr.into())); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_accepts_udp(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(1)); + assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V4.into())), Ok(())); + + let checksum = ChecksumCapabilities::default(); + + let mut bytes = [0xff; 18]; + let mut packet = UdpPacket::new_unchecked(&mut bytes); + UDP_REPR.emit( + &mut packet, + &REMOTE_IPV4.into(), + &LOCAL_IPV4.into(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(UDP_PAYLOAD), + &checksum, + ); + + let data = &*packet.into_inner(); + + let icmp_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::PortUnreachable, + header: Ipv4Repr { + src_addr: LOCAL_IPV4, + dst_addr: REMOTE_IPV4, + next_header: IpProtocol::Icmp, + payload_len: 12, + hop_limit: 0x40, + }, + data, + }; + let ip_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: REMOTE_IPV4, + dst_addr: LOCAL_IPV4, + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 0x40, + }); + + assert!(!socket.can_recv()); + + // Ensure we can accept ICMP error response to the bound + // UDP port + assert!(socket.accepts(cx, &ip_repr, &icmp_repr.into())); + socket.process(cx, &ip_repr, &icmp_repr.into()); + assert!(socket.can_recv()); + + let mut bytes = [0x00; 46]; + let mut packet = Icmpv4Packet::new_unchecked(&mut bytes[..]); + icmp_repr.emit(&mut packet, &checksum); + assert_eq!( + socket.recv(), + Ok((&*packet.into_inner(), REMOTE_IPV4.into())) + ); + assert!(!socket.can_recv()); + } +} + +#[cfg(all(test, feature = "proto-ipv6"))] +mod test_ipv6 { + use crate::phy::Medium; + use crate::tests::setup; + use rstest::*; + + use super::tests_common::*; + + use crate::wire::{Icmpv6DstUnreachable, IpEndpoint, Ipv6Address}; + + const REMOTE_IPV6: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]); + const LOCAL_IPV6: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); + const LOCAL_END_V6: IpEndpoint = IpEndpoint { + addr: IpAddress::Ipv6(LOCAL_IPV6), + port: LOCAL_PORT, + }; + static ECHOV6_REPR: Icmpv6Repr = Icmpv6Repr::EchoRequest { + ident: 0x1234, + seq_no: 0x5678, + data: &[0xff; 16], + }; + + static LOCAL_IPV6_REPR: IpRepr = IpRepr::Ipv6(Ipv6Repr { + src_addr: LOCAL_IPV6, + dst_addr: REMOTE_IPV6, + next_header: IpProtocol::Icmpv6, + payload_len: 24, + hop_limit: 0x40, + }); + + static REMOTE_IPV6_REPR: IpRepr = IpRepr::Ipv6(Ipv6Repr { + src_addr: REMOTE_IPV6, + dst_addr: LOCAL_IPV6, + next_header: IpProtocol::Icmpv6, + payload_len: 24, + hop_limit: 0x40, + }); + + #[test] + fn test_send_unaddressable() { + let mut socket = socket(buffer(0), buffer(1)); + assert_eq!( + socket.send_slice(b"abcdef", IpAddress::Ipv6(Ipv6Address::default())), + Err(SendError::Unaddressable) + ); + assert_eq!(socket.send_slice(b"abcdef", REMOTE_IPV6.into()), Ok(())); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_send_dispatch(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(0), buffer(1)); + let checksum = ChecksumCapabilities::default(); + + assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(())); + + // This buffer is too long + assert_eq!( + socket.send_slice(&[0xff; 67], REMOTE_IPV6.into()), + Err(SendError::BufferFull) + ); + assert!(socket.can_send()); + + let mut bytes = vec![0xff; 24]; + let mut packet = Icmpv6Packet::new_unchecked(&mut bytes); + ECHOV6_REPR.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + + assert_eq!( + socket.send_slice(&*packet.into_inner(), REMOTE_IPV6.into()), + Ok(()) + ); + assert_eq!( + socket.send_slice(b"123456", REMOTE_IPV6.into()), + Err(SendError::BufferFull) + ); + assert!(!socket.can_send()); + + assert_eq!( + socket.dispatch(cx, |_, (ip_repr, icmp_repr)| { + assert_eq!(ip_repr, LOCAL_IPV6_REPR); + assert_eq!(icmp_repr, ECHOV6_REPR.into()); + Err(()) + }), + Err(()) + ); + // buffer is not taken off of the tx queue due to the error + assert!(!socket.can_send()); + + assert_eq!( + socket.dispatch(cx, |_, (ip_repr, icmp_repr)| { + assert_eq!(ip_repr, LOCAL_IPV6_REPR); + assert_eq!(icmp_repr, ECHOV6_REPR.into()); + Ok::<_, ()>(()) + }), + Ok(()) + ); + // buffer is taken off of the queue this time + assert!(socket.can_send()); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_set_hop_limit(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut s = socket(buffer(0), buffer(1)); + let checksum = ChecksumCapabilities::default(); + + let mut bytes = vec![0xff; 24]; + let mut packet = Icmpv6Packet::new_unchecked(&mut bytes); + ECHOV6_REPR.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + + s.set_hop_limit(Some(0x2a)); + + assert_eq!( + s.send_slice(&*packet.into_inner(), REMOTE_IPV6.into()), + Ok(()) + ); + assert_eq!( + s.dispatch(cx, |_, (ip_repr, _)| { + assert_eq!( + ip_repr, + IpRepr::Ipv6(Ipv6Repr { + src_addr: LOCAL_IPV6, + dst_addr: REMOTE_IPV6, + next_header: IpProtocol::Icmpv6, + payload_len: ECHOV6_REPR.buffer_len(), + hop_limit: 0x2a, + }) + ); + Ok::<_, ()>(()) + }), + Ok(()) + ); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_recv_process(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(1)); + assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); + + assert!(!socket.can_recv()); + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); + + let checksum = ChecksumCapabilities::default(); + + let mut bytes = [0xff; 24]; + let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]); + ECHOV6_REPR.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + let data = &*packet.into_inner(); + + assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); + assert!(socket.can_recv()); + + assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); + + assert_eq!(socket.recv(), Ok((data, REMOTE_IPV6.into()))); + assert!(!socket.can_recv()); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_truncated_recv_slice(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(1)); + assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); + + let checksum = ChecksumCapabilities::default(); + + let mut bytes = [0xff; 24]; + let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]); + ECHOV6_REPR.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + + assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); + assert!(socket.can_recv()); + + assert!(socket.accepts(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); + + let mut buffer = [0u8; 1]; + assert_eq!( + socket.recv_slice(&mut buffer[..]), + Err(RecvError::Truncated) + ); + assert!(!socket.can_recv()); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_accept_bad_id(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(1)); + assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); + + let checksum = ChecksumCapabilities::default(); + let mut bytes = [0xff; 20]; + let mut packet = Icmpv6Packet::new_unchecked(&mut bytes); + let icmp_repr = Icmpv6Repr::EchoRequest { + ident: 0x4321, + seq_no: 0x5678, + data: &[0xff; 16], + }; + icmp_repr.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + + // Ensure that a packet with an identifier that isn't the bound + // ID is not accepted + assert!(!socket.accepts(cx, &REMOTE_IPV6_REPR, &icmp_repr.into())); + } + + #[rstest] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + fn test_accepts_udp(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(1)); + assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V6.into())), Ok(())); + + let checksum = ChecksumCapabilities::default(); + + let mut bytes = [0xff; 18]; + let mut packet = UdpPacket::new_unchecked(&mut bytes); + UDP_REPR.emit( + &mut packet, + &REMOTE_IPV6.into(), + &LOCAL_IPV6.into(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(UDP_PAYLOAD), + &checksum, + ); + + let data = &*packet.into_inner(); + + let icmp_repr = Icmpv6Repr::DstUnreachable { + reason: Icmpv6DstUnreachable::PortUnreachable, + header: Ipv6Repr { + src_addr: LOCAL_IPV6, + dst_addr: REMOTE_IPV6, + next_header: IpProtocol::Icmpv6, + payload_len: 12, + hop_limit: 0x40, + }, + data, + }; + let ip_repr = IpRepr::Ipv6(Ipv6Repr { + src_addr: REMOTE_IPV6, + dst_addr: LOCAL_IPV6, + next_header: IpProtocol::Icmpv6, + payload_len: icmp_repr.buffer_len(), + hop_limit: 0x40, + }); + + assert!(!socket.can_recv()); + + // Ensure we can accept ICMP error response to the bound + // UDP port + assert!(socket.accepts(cx, &ip_repr, &icmp_repr.into())); + socket.process(cx, &ip_repr, &icmp_repr.into()); + assert!(socket.can_recv()); + + let mut bytes = [0x00; 66]; + let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]); + icmp_repr.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + assert_eq!( + socket.recv(), + Ok((&*packet.into_inner(), REMOTE_IPV6.into())) + ); + assert!(!socket.can_recv()); + } +} diff --git a/src/socket/mod.rs b/src/socket/mod.rs new file mode 100644 index 0000000..7d48b42 --- /dev/null +++ b/src/socket/mod.rs @@ -0,0 +1,141 @@ +/*! Communication between endpoints. + +The `socket` module deals with *network endpoints* and *buffering*. +It provides interfaces for accessing buffers of data, and protocol state machines +for filling and emptying these buffers. + +The programming interface implemented here differs greatly from the common Berkeley socket +interface. Specifically, in the Berkeley interface the buffering is implicit: +the operating system decides on the good size for a buffer and manages it. +The interface implemented by this module uses explicit buffering: you decide on the good +size for a buffer, allocate it, and let the networking stack use it. +*/ + +use crate::iface::Context; +use crate::time::Instant; + +#[cfg(feature = "socket-dhcpv4")] +pub mod dhcpv4; +#[cfg(feature = "socket-dns")] +pub mod dns; +#[cfg(feature = "socket-icmp")] +pub mod icmp; +#[cfg(feature = "socket-raw")] +pub mod raw; +#[cfg(feature = "socket-tcp")] +pub mod tcp; +#[cfg(feature = "socket-udp")] +pub mod udp; + +#[cfg(feature = "async")] +mod waker; + +#[cfg(feature = "async")] +pub(crate) use self::waker::WakerRegistration; + +/// Gives an indication on the next time the socket should be polled. +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) enum PollAt { + /// The socket needs to be polled immediately. + Now, + /// The socket needs to be polled at given [Instant][struct.Instant]. + Time(Instant), + /// The socket does not need to be polled unless there are external changes. + Ingress, +} + +/// A network socket. +/// +/// This enumeration abstracts the various types of sockets based on the IP protocol. +/// To downcast a `Socket` value to a concrete socket, use the [AnySocket] trait, +/// e.g. to get `udp::Socket`, call `udp::Socket::downcast(socket)`. +/// +/// It is usually more convenient to use [SocketSet::get] instead. +/// +/// [AnySocket]: trait.AnySocket.html +/// [SocketSet::get]: struct.SocketSet.html#method.get +#[derive(Debug)] +pub enum Socket<'a> { + #[cfg(feature = "socket-raw")] + Raw(raw::Socket<'a>), + #[cfg(feature = "socket-icmp")] + Icmp(icmp::Socket<'a>), + #[cfg(feature = "socket-udp")] + Udp(udp::Socket<'a>), + #[cfg(feature = "socket-tcp")] + Tcp(tcp::Socket<'a>), + #[cfg(feature = "socket-dhcpv4")] + Dhcpv4(dhcpv4::Socket<'a>), + #[cfg(feature = "socket-dns")] + Dns(dns::Socket<'a>), +} + +impl<'a> Socket<'a> { + pub(crate) fn poll_at(&self, cx: &mut Context) -> PollAt { + match self { + #[cfg(feature = "socket-raw")] + Socket::Raw(s) => s.poll_at(cx), + #[cfg(feature = "socket-icmp")] + Socket::Icmp(s) => s.poll_at(cx), + #[cfg(feature = "socket-udp")] + Socket::Udp(s) => s.poll_at(cx), + #[cfg(feature = "socket-tcp")] + Socket::Tcp(s) => s.poll_at(cx), + #[cfg(feature = "socket-dhcpv4")] + Socket::Dhcpv4(s) => s.poll_at(cx), + #[cfg(feature = "socket-dns")] + Socket::Dns(s) => s.poll_at(cx), + } + } +} + +/// A conversion trait for network sockets. +pub trait AnySocket<'a> { + fn upcast(self) -> Socket<'a>; + fn downcast<'c>(socket: &'c Socket<'a>) -> Option<&'c Self> + where + Self: Sized; + fn downcast_mut<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self> + where + Self: Sized; +} + +macro_rules! from_socket { + ($socket:ty, $variant:ident) => { + impl<'a> AnySocket<'a> for $socket { + fn upcast(self) -> Socket<'a> { + Socket::$variant(self) + } + + fn downcast<'c>(socket: &'c Socket<'a>) -> Option<&'c Self> { + #[allow(unreachable_patterns)] + match socket { + Socket::$variant(socket) => Some(socket), + _ => None, + } + } + + fn downcast_mut<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self> { + #[allow(unreachable_patterns)] + match socket { + Socket::$variant(socket) => Some(socket), + _ => None, + } + } + } + }; +} + +#[cfg(feature = "socket-raw")] +from_socket!(raw::Socket<'a>, Raw); +#[cfg(feature = "socket-icmp")] +from_socket!(icmp::Socket<'a>, Icmp); +#[cfg(feature = "socket-udp")] +from_socket!(udp::Socket<'a>, Udp); +#[cfg(feature = "socket-tcp")] +from_socket!(tcp::Socket<'a>, Tcp); +#[cfg(feature = "socket-dhcpv4")] +from_socket!(dhcpv4::Socket<'a>, Dhcpv4); +#[cfg(feature = "socket-dns")] +from_socket!(dns::Socket<'a>, Dns); diff --git a/src/socket/raw.rs b/src/socket/raw.rs new file mode 100644 index 0000000..bb3a204 --- /dev/null +++ b/src/socket/raw.rs @@ -0,0 +1,848 @@ +use core::cmp::min; +#[cfg(feature = "async")] +use core::task::Waker; + +use crate::iface::Context; +use crate::socket::PollAt; +#[cfg(feature = "async")] +use crate::socket::WakerRegistration; + +use crate::storage::Empty; +use crate::wire::{IpProtocol, IpRepr, IpVersion}; +#[cfg(feature = "proto-ipv4")] +use crate::wire::{Ipv4Packet, Ipv4Repr}; +#[cfg(feature = "proto-ipv6")] +use crate::wire::{Ipv6Packet, Ipv6Repr}; + +/// Error returned by [`Socket::bind`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum BindError { + InvalidState, + Unaddressable, +} + +impl core::fmt::Display for BindError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + BindError::InvalidState => write!(f, "invalid state"), + BindError::Unaddressable => write!(f, "unaddressable"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for BindError {} + +/// Error returned by [`Socket::send`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SendError { + BufferFull, +} + +impl core::fmt::Display for SendError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + SendError::BufferFull => write!(f, "buffer full"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for SendError {} + +/// Error returned by [`Socket::recv`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecvError { + Exhausted, + Truncated, +} + +impl core::fmt::Display for RecvError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated => write!(f, "truncated"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RecvError {} + +/// A UDP packet metadata. +pub type PacketMetadata = crate::storage::PacketMetadata<()>; + +/// A UDP packet ring buffer. +pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ()>; + +/// A raw IP socket. +/// +/// A raw socket is bound to a specific IP protocol, and owns +/// transmit and receive packet buffers. +#[derive(Debug)] +pub struct Socket<'a> { + ip_version: IpVersion, + ip_protocol: IpProtocol, + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, + #[cfg(feature = "async")] + rx_waker: WakerRegistration, + #[cfg(feature = "async")] + tx_waker: WakerRegistration, +} + +impl<'a> Socket<'a> { + /// Create a raw IP socket bound to the given IP version and datagram protocol, + /// with the given buffers. + pub fn new( + ip_version: IpVersion, + ip_protocol: IpProtocol, + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, + ) -> Socket<'a> { + Socket { + ip_version, + ip_protocol, + rx_buffer, + tx_buffer, + #[cfg(feature = "async")] + rx_waker: WakerRegistration::new(), + #[cfg(feature = "async")] + tx_waker: WakerRegistration::new(), + } + } + + /// Register a waker for receive operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `recv` method calls, such as receiving data, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_recv_waker(&mut self, waker: &Waker) { + self.rx_waker.register(waker) + } + + /// Register a waker for send operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `send` method calls, such as space becoming available in the transmit + /// buffer, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_send_waker(&mut self, waker: &Waker) { + self.tx_waker.register(waker) + } + + /// Return the IP version the socket is bound to. + #[inline] + pub fn ip_version(&self) -> IpVersion { + self.ip_version + } + + /// Return the IP protocol the socket is bound to. + #[inline] + pub fn ip_protocol(&self) -> IpProtocol { + self.ip_protocol + } + + /// Check whether the transmit buffer is full. + #[inline] + pub fn can_send(&self) -> bool { + !self.tx_buffer.is_full() + } + + /// Check whether the receive buffer is not empty. + #[inline] + pub fn can_recv(&self) -> bool { + !self.rx_buffer.is_empty() + } + + /// Return the maximum number packets the socket can receive. + #[inline] + pub fn packet_recv_capacity(&self) -> usize { + self.rx_buffer.packet_capacity() + } + + /// Return the maximum number packets the socket can transmit. + #[inline] + pub fn packet_send_capacity(&self) -> usize { + self.tx_buffer.packet_capacity() + } + + /// Return the maximum number of bytes inside the recv buffer. + #[inline] + pub fn payload_recv_capacity(&self) -> usize { + self.rx_buffer.payload_capacity() + } + + /// Return the maximum number of bytes inside the transmit buffer. + #[inline] + pub fn payload_send_capacity(&self) -> usize { + self.tx_buffer.payload_capacity() + } + + /// Enqueue a packet to send, and return a pointer to its payload. + /// + /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full, + /// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity + /// to ever send this packet. + /// + /// If the buffer is filled in a way that does not match the socket's + /// IP version or protocol, the packet will be silently dropped. + /// + /// **Note:** The IP header is parsed and re-serialized, and may not match + /// the header actually transmitted bit for bit. + pub fn send(&mut self, size: usize) -> Result<&mut [u8], SendError> { + let packet_buf = self + .tx_buffer + .enqueue(size, ()) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "raw:{}:{}: buffer to send {} octets", + self.ip_version, + self.ip_protocol, + packet_buf.len() + ); + Ok(packet_buf) + } + + /// Enqueue a packet to be send and pass the buffer to the provided closure. + /// The closure then returns the size of the data written into the buffer. + /// + /// Also see [send](#method.send). + pub fn send_with<F>(&mut self, max_size: usize, f: F) -> Result<usize, SendError> + where + F: FnOnce(&mut [u8]) -> usize, + { + let size = self + .tx_buffer + .enqueue_with_infallible(max_size, (), f) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "raw:{}:{}: buffer to send {} octets", + self.ip_version, + self.ip_protocol, + size + ); + + Ok(size) + } + + /// Enqueue a packet to send, and fill it from a slice. + /// + /// See also [send](#method.send). + pub fn send_slice(&mut self, data: &[u8]) -> Result<(), SendError> { + self.send(data.len())?.copy_from_slice(data); + Ok(()) + } + + /// Dequeue a packet, and return a pointer to the payload. + /// + /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty. + /// + /// **Note:** The IP header is parsed and re-serialized, and may not match + /// the header actually received bit for bit. + pub fn recv(&mut self) -> Result<&[u8], RecvError> { + let ((), packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; + + net_trace!( + "raw:{}:{}: receive {} buffered octets", + self.ip_version, + self.ip_protocol, + packet_buf.len() + ); + Ok(packet_buf) + } + + /// Dequeue a packet, and copy the payload into the given slice. + /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// the packet is dropped and a `RecvError::Truncated` error is returned. + /// + /// See also [recv](#method.recv). + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> { + let buffer = self.recv()?; + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + + let length = min(data.len(), buffer.len()); + data[..length].copy_from_slice(&buffer[..length]); + Ok(length) + } + + /// Peek at a packet in the receive buffer and return a pointer to the + /// payload without removing the packet from the receive buffer. + /// This function otherwise behaves identically to [recv](#method.recv). + /// + /// It returns `Err(Error::Exhausted)` if the receive buffer is empty. + pub fn peek(&mut self) -> Result<&[u8], RecvError> { + let ((), packet_buf) = self.rx_buffer.peek().map_err(|_| RecvError::Exhausted)?; + + net_trace!( + "raw:{}:{}: receive {} buffered octets", + self.ip_version, + self.ip_protocol, + packet_buf.len() + ); + + Ok(packet_buf) + } + + /// Peek at a packet in the receive buffer, copy the payload into the given slice, + /// and return the amount of octets copied without removing the packet from the receive buffer. + /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). + /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// no data is copied into the provided buffer and a `RecvError::Truncated` error is returned. + /// + /// See also [peek](#method.peek). + pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> { + let buffer = self.peek()?; + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + + let length = min(data.len(), buffer.len()); + data[..length].copy_from_slice(&buffer[..length]); + Ok(length) + } + + pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool { + if ip_repr.version() != self.ip_version { + return false; + } + if ip_repr.next_header() != self.ip_protocol { + return false; + } + + true + } + + pub(crate) fn process(&mut self, cx: &mut Context, ip_repr: &IpRepr, payload: &[u8]) { + debug_assert!(self.accepts(ip_repr)); + + let header_len = ip_repr.header_len(); + let total_len = header_len + payload.len(); + + net_trace!( + "raw:{}:{}: receiving {} octets", + self.ip_version, + self.ip_protocol, + total_len + ); + + match self.rx_buffer.enqueue(total_len, ()) { + Ok(buf) => { + ip_repr.emit(&mut buf[..header_len], &cx.checksum_caps()); + buf[header_len..].copy_from_slice(payload); + } + Err(_) => net_trace!( + "raw:{}:{}: buffer full, dropped incoming packet", + self.ip_version, + self.ip_protocol + ), + } + + #[cfg(feature = "async")] + self.rx_waker.wake(); + } + + pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (IpRepr, &[u8])) -> Result<(), E>, + { + let ip_protocol = self.ip_protocol; + let ip_version = self.ip_version; + let _checksum_caps = &cx.checksum_caps(); + let res = self.tx_buffer.dequeue_with(|&mut (), buffer| { + match IpVersion::of_packet(buffer) { + #[cfg(feature = "proto-ipv4")] + Ok(IpVersion::Ipv4) => { + let mut packet = match Ipv4Packet::new_checked(buffer) { + Ok(x) => x, + Err(_) => { + net_trace!("raw: malformed ipv6 packet in queue, dropping."); + return Ok(()); + } + }; + if packet.next_header() != ip_protocol { + net_trace!("raw: sent packet with wrong ip protocol, dropping."); + return Ok(()); + } + if _checksum_caps.ipv4.tx() { + packet.fill_checksum(); + } else { + // make sure we get a consistently zeroed checksum, + // since implementations might rely on it + packet.set_checksum(0); + } + + let packet = Ipv4Packet::new_unchecked(&*packet.into_inner()); + let ipv4_repr = match Ipv4Repr::parse(&packet, _checksum_caps) { + Ok(x) => x, + Err(_) => { + net_trace!("raw: malformed ipv4 packet in queue, dropping."); + return Ok(()); + } + }; + net_trace!("raw:{}:{}: sending", ip_version, ip_protocol); + emit(cx, (IpRepr::Ipv4(ipv4_repr), packet.payload())) + } + #[cfg(feature = "proto-ipv6")] + Ok(IpVersion::Ipv6) => { + let packet = match Ipv6Packet::new_checked(buffer) { + Ok(x) => x, + Err(_) => { + net_trace!("raw: malformed ipv6 packet in queue, dropping."); + return Ok(()); + } + }; + if packet.next_header() != ip_protocol { + net_trace!("raw: sent ipv6 packet with wrong ip protocol, dropping."); + return Ok(()); + } + let packet = Ipv6Packet::new_unchecked(&*packet.into_inner()); + let ipv6_repr = match Ipv6Repr::parse(&packet) { + Ok(x) => x, + Err(_) => { + net_trace!("raw: malformed ipv6 packet in queue, dropping."); + return Ok(()); + } + }; + + net_trace!("raw:{}:{}: sending", ip_version, ip_protocol); + emit(cx, (IpRepr::Ipv6(ipv6_repr), packet.payload())) + } + Err(_) => { + net_trace!("raw: sent packet with invalid IP version, dropping."); + Ok(()) + } + } + }); + match res { + Err(Empty) => Ok(()), + Ok(Err(e)) => Err(e), + Ok(Ok(())) => { + #[cfg(feature = "async")] + self.tx_waker.wake(); + Ok(()) + } + } + } + + pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt { + if self.tx_buffer.is_empty() { + PollAt::Ingress + } else { + PollAt::Now + } + } +} + +#[cfg(test)] +mod test { + use crate::phy::Medium; + use crate::tests::setup; + use rstest::*; + + use super::*; + use crate::wire::IpRepr; + #[cfg(feature = "proto-ipv4")] + use crate::wire::{Ipv4Address, Ipv4Repr}; + #[cfg(feature = "proto-ipv6")] + use crate::wire::{Ipv6Address, Ipv6Repr}; + + fn buffer(packets: usize) -> PacketBuffer<'static> { + PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 48 * packets]) + } + + #[cfg(feature = "proto-ipv4")] + mod ipv4_locals { + use super::*; + + pub fn socket( + rx_buffer: PacketBuffer<'static>, + tx_buffer: PacketBuffer<'static>, + ) -> Socket<'static> { + Socket::new( + IpVersion::Ipv4, + IpProtocol::Unknown(IP_PROTO), + rx_buffer, + tx_buffer, + ) + } + + pub const IP_PROTO: u8 = 63; + pub const HEADER_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address([10, 0, 0, 1]), + dst_addr: Ipv4Address([10, 0, 0, 2]), + next_header: IpProtocol::Unknown(IP_PROTO), + payload_len: 4, + hop_limit: 64, + }); + pub const PACKET_BYTES: [u8; 24] = [ + 0x45, 0x00, 0x00, 0x18, 0x00, 0x00, 0x40, 0x00, 0x40, 0x3f, 0x00, 0x00, 0x0a, 0x00, + 0x00, 0x01, 0x0a, 0x00, 0x00, 0x02, 0xaa, 0x00, 0x00, 0xff, + ]; + pub const PACKET_PAYLOAD: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + } + + #[cfg(feature = "proto-ipv6")] + mod ipv6_locals { + use super::*; + + pub fn socket( + rx_buffer: PacketBuffer<'static>, + tx_buffer: PacketBuffer<'static>, + ) -> Socket<'static> { + Socket::new( + IpVersion::Ipv6, + IpProtocol::Unknown(IP_PROTO), + rx_buffer, + tx_buffer, + ) + } + + pub const IP_PROTO: u8 = 63; + pub const HEADER_REPR: IpRepr = IpRepr::Ipv6(Ipv6Repr { + src_addr: Ipv6Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ]), + dst_addr: Ipv6Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, + ]), + next_header: IpProtocol::Unknown(IP_PROTO), + payload_len: 4, + hop_limit: 64, + }); + + pub const PACKET_BYTES: [u8; 44] = [ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x04, 0x3f, 0x40, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xaa, 0x00, + 0x00, 0xff, + ]; + + pub const PACKET_PAYLOAD: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + } + + macro_rules! reusable_ip_specific_tests { + ($module:ident, $socket:path, $hdr:path, $packet:path, $payload:path) => { + mod $module { + use super::*; + + #[test] + fn test_send_truncated() { + let mut socket = $socket(buffer(0), buffer(1)); + assert_eq!(socket.send_slice(&[0; 56][..]), Err(SendError::BufferFull)); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_send_dispatch(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let mut cx = iface.context(); + let mut socket = $socket(buffer(0), buffer(1)); + + assert!(socket.can_send()); + assert_eq!( + socket.dispatch(&mut cx, |_, _| unreachable!()), + Ok::<_, ()>(()) + ); + + assert_eq!(socket.send_slice(&$packet[..]), Ok(())); + assert_eq!(socket.send_slice(b""), Err(SendError::BufferFull)); + assert!(!socket.can_send()); + + assert_eq!( + socket.dispatch(&mut cx, |_, (ip_repr, ip_payload)| { + assert_eq!(ip_repr, $hdr); + assert_eq!(ip_payload, &$payload); + Err(()) + }), + Err(()) + ); + assert!(!socket.can_send()); + + assert_eq!( + socket.dispatch(&mut cx, |_, (ip_repr, ip_payload)| { + assert_eq!(ip_repr, $hdr); + assert_eq!(ip_payload, &$payload); + Ok::<_, ()>(()) + }), + Ok(()) + ); + assert!(socket.can_send()); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_recv_truncated_slice(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let mut cx = iface.context(); + let mut socket = $socket(buffer(1), buffer(0)); + + assert!(socket.accepts(&$hdr)); + socket.process(&mut cx, &$hdr, &$payload); + + let mut slice = [0; 4]; + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_recv_truncated_packet(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let mut cx = iface.context(); + let mut socket = $socket(buffer(1), buffer(0)); + + let mut buffer = vec![0; 128]; + buffer[..$packet.len()].copy_from_slice(&$packet[..]); + + assert!(socket.accepts(&$hdr)); + socket.process(&mut cx, &$hdr, &buffer); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_peek_truncated_slice(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let mut cx = iface.context(); + let mut socket = $socket(buffer(1), buffer(0)); + + assert!(socket.accepts(&$hdr)); + socket.process(&mut cx, &$hdr, &$payload); + + let mut slice = [0; 4]; + assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated)); + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); + assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted)); + } + } + }; + } + + #[cfg(feature = "proto-ipv4")] + reusable_ip_specific_tests!( + ipv4, + ipv4_locals::socket, + ipv4_locals::HEADER_REPR, + ipv4_locals::PACKET_BYTES, + ipv4_locals::PACKET_PAYLOAD + ); + + #[cfg(feature = "proto-ipv6")] + reusable_ip_specific_tests!( + ipv6, + ipv6_locals::socket, + ipv6_locals::HEADER_REPR, + ipv6_locals::PACKET_BYTES, + ipv6_locals::PACKET_PAYLOAD + ); + + #[rstest] + #[case::ip(Medium::Ip)] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_send_illegal(#[case] medium: Medium) { + #[cfg(feature = "proto-ipv4")] + { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + let mut socket = ipv4_locals::socket(buffer(0), buffer(2)); + + let mut wrong_version = ipv4_locals::PACKET_BYTES; + Ipv4Packet::new_unchecked(&mut wrong_version).set_version(6); + + assert_eq!(socket.send_slice(&wrong_version[..]), Ok(())); + assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(())); + + let mut wrong_protocol = ipv4_locals::PACKET_BYTES; + Ipv4Packet::new_unchecked(&mut wrong_protocol).set_next_header(IpProtocol::Tcp); + + assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(())); + assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(())); + } + #[cfg(feature = "proto-ipv6")] + { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + let mut socket = ipv6_locals::socket(buffer(0), buffer(2)); + + let mut wrong_version = ipv6_locals::PACKET_BYTES; + Ipv6Packet::new_unchecked(&mut wrong_version[..]).set_version(4); + + assert_eq!(socket.send_slice(&wrong_version[..]), Ok(())); + assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(())); + + let mut wrong_protocol = ipv6_locals::PACKET_BYTES; + Ipv6Packet::new_unchecked(&mut wrong_protocol[..]).set_next_header(IpProtocol::Tcp); + + assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(())); + assert_eq!(socket.dispatch(cx, |_, _| unreachable!()), Ok::<_, ()>(())); + } + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_recv_process(#[case] medium: Medium) { + #[cfg(feature = "proto-ipv4")] + { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + let mut socket = ipv4_locals::socket(buffer(1), buffer(0)); + assert!(!socket.can_recv()); + + let mut cksumd_packet = ipv4_locals::PACKET_BYTES; + Ipv4Packet::new_unchecked(&mut cksumd_packet).fill_checksum(); + + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); + assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); + socket.process(cx, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD); + assert!(socket.can_recv()); + + assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); + socket.process(cx, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD); + assert_eq!(socket.recv(), Ok(&cksumd_packet[..])); + assert!(!socket.can_recv()); + } + #[cfg(feature = "proto-ipv6")] + { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + let mut socket = ipv6_locals::socket(buffer(1), buffer(0)); + assert!(!socket.can_recv()); + + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); + assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); + socket.process(cx, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD); + assert!(socket.can_recv()); + + assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); + socket.process(cx, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD); + assert_eq!(socket.recv(), Ok(&ipv6_locals::PACKET_BYTES[..])); + assert!(!socket.can_recv()); + } + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_peek_process(#[case] medium: Medium) { + #[cfg(feature = "proto-ipv4")] + { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + let mut socket = ipv4_locals::socket(buffer(1), buffer(0)); + + let mut cksumd_packet = ipv4_locals::PACKET_BYTES; + Ipv4Packet::new_unchecked(&mut cksumd_packet).fill_checksum(); + + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); + socket.process(cx, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD); + + assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); + socket.process(cx, &ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD); + assert_eq!(socket.peek(), Ok(&cksumd_packet[..])); + assert_eq!(socket.recv(), Ok(&cksumd_packet[..])); + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + } + #[cfg(feature = "proto-ipv6")] + { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + let mut socket = ipv6_locals::socket(buffer(1), buffer(0)); + + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); + socket.process(cx, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD); + + assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); + socket.process(cx, &ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD); + assert_eq!(socket.peek(), Ok(&ipv6_locals::PACKET_BYTES[..])); + assert_eq!(socket.recv(), Ok(&ipv6_locals::PACKET_BYTES[..])); + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + } + } + + #[test] + fn test_doesnt_accept_wrong_proto() { + #[cfg(feature = "proto-ipv4")] + { + let socket = Socket::new( + IpVersion::Ipv4, + IpProtocol::Unknown(ipv4_locals::IP_PROTO + 1), + buffer(1), + buffer(1), + ); + assert!(!socket.accepts(&ipv4_locals::HEADER_REPR)); + #[cfg(feature = "proto-ipv6")] + assert!(!socket.accepts(&ipv6_locals::HEADER_REPR)); + } + #[cfg(feature = "proto-ipv6")] + { + let socket = Socket::new( + IpVersion::Ipv6, + IpProtocol::Unknown(ipv6_locals::IP_PROTO + 1), + buffer(1), + buffer(1), + ); + assert!(!socket.accepts(&ipv6_locals::HEADER_REPR)); + #[cfg(feature = "proto-ipv4")] + assert!(!socket.accepts(&ipv4_locals::HEADER_REPR)); + } + } +} diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs new file mode 100644 index 0000000..d7b85ab --- /dev/null +++ b/src/socket/tcp.rs @@ -0,0 +1,7312 @@ +// Heads up! Before working on this file you should read, at least, RFC 793 and +// the parts of RFC 1122 that discuss TCP. Consult RFC 7414 when implementing +// a new feature. + +use core::fmt::Display; +#[cfg(feature = "async")] +use core::task::Waker; +use core::{cmp, fmt, mem}; + +#[cfg(feature = "async")] +use crate::socket::WakerRegistration; +use crate::socket::{Context, PollAt}; +use crate::storage::{Assembler, RingBuffer}; +use crate::time::{Duration, Instant}; +use crate::wire::{ + IpAddress, IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, TcpControl, TcpRepr, TcpSeqNumber, + TCP_HEADER_LEN, +}; + +macro_rules! tcp_trace { + ($($arg:expr),*) => (net_log!(trace, $($arg),*)); +} + +/// Error returned by [`Socket::listen`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ListenError { + InvalidState, + Unaddressable, +} + +impl Display for ListenError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ListenError::InvalidState => write!(f, "invalid state"), + ListenError::Unaddressable => write!(f, "unaddressable destination"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for ListenError {} + +/// Error returned by [`Socket::connect`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ConnectError { + InvalidState, + Unaddressable, +} + +impl Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ConnectError::InvalidState => write!(f, "invalid state"), + ConnectError::Unaddressable => write!(f, "unaddressable destination"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for ConnectError {} + +/// Error returned by [`Socket::send`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SendError { + InvalidState, +} + +impl Display for SendError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + SendError::InvalidState => write!(f, "invalid state"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for SendError {} + +/// Error returned by [`Socket::recv`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecvError { + InvalidState, + Finished, +} + +impl Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + RecvError::InvalidState => write!(f, "invalid state"), + RecvError::Finished => write!(f, "operation finished"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RecvError {} + +/// A TCP socket ring buffer. +pub type SocketBuffer<'a> = RingBuffer<'a, u8>; + +/// The state of a TCP socket, according to [RFC 793]. +/// +/// [RFC 793]: https://tools.ietf.org/html/rfc793 +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum State { + Closed, + Listen, + SynSent, + SynReceived, + Established, + FinWait1, + FinWait2, + CloseWait, + Closing, + LastAck, + TimeWait, +} + +impl fmt::Display for State { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + State::Closed => write!(f, "CLOSED"), + State::Listen => write!(f, "LISTEN"), + State::SynSent => write!(f, "SYN-SENT"), + State::SynReceived => write!(f, "SYN-RECEIVED"), + State::Established => write!(f, "ESTABLISHED"), + State::FinWait1 => write!(f, "FIN-WAIT-1"), + State::FinWait2 => write!(f, "FIN-WAIT-2"), + State::CloseWait => write!(f, "CLOSE-WAIT"), + State::Closing => write!(f, "CLOSING"), + State::LastAck => write!(f, "LAST-ACK"), + State::TimeWait => write!(f, "TIME-WAIT"), + } + } +} + +// Conservative initial RTT estimate. +const RTTE_INITIAL_RTT: u32 = 300; +const RTTE_INITIAL_DEV: u32 = 100; + +// Minimum "safety margin" for the RTO that kicks in when the +// variance gets very low. +const RTTE_MIN_MARGIN: u32 = 5; + +const RTTE_MIN_RTO: u32 = 10; +const RTTE_MAX_RTO: u32 = 10000; + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct RttEstimator { + // Using u32 instead of Duration to save space (Duration is i64) + rtt: u32, + deviation: u32, + timestamp: Option<(Instant, TcpSeqNumber)>, + max_seq_sent: Option<TcpSeqNumber>, + rto_count: u8, +} + +impl Default for RttEstimator { + fn default() -> Self { + Self { + rtt: RTTE_INITIAL_RTT, + deviation: RTTE_INITIAL_DEV, + timestamp: None, + max_seq_sent: None, + rto_count: 0, + } + } +} + +impl RttEstimator { + fn retransmission_timeout(&self) -> Duration { + let margin = RTTE_MIN_MARGIN.max(self.deviation * 4); + let ms = (self.rtt + margin).clamp(RTTE_MIN_RTO, RTTE_MAX_RTO); + Duration::from_millis(ms as u64) + } + + fn sample(&mut self, new_rtt: u32) { + // "Congestion Avoidance and Control", Van Jacobson, Michael J. Karels, 1988 + self.rtt = (self.rtt * 7 + new_rtt + 7) / 8; + let diff = (self.rtt as i32 - new_rtt as i32).unsigned_abs(); + self.deviation = (self.deviation * 3 + diff + 3) / 4; + + self.rto_count = 0; + + let rto = self.retransmission_timeout().total_millis(); + tcp_trace!( + "rtte: sample={:?} rtt={:?} dev={:?} rto={:?}", + new_rtt, + self.rtt, + self.deviation, + rto + ); + } + + fn on_send(&mut self, timestamp: Instant, seq: TcpSeqNumber) { + if self + .max_seq_sent + .map(|max_seq_sent| seq > max_seq_sent) + .unwrap_or(true) + { + self.max_seq_sent = Some(seq); + if self.timestamp.is_none() { + self.timestamp = Some((timestamp, seq)); + tcp_trace!("rtte: sampling at seq={:?}", seq); + } + } + } + + fn on_ack(&mut self, timestamp: Instant, seq: TcpSeqNumber) { + if let Some((sent_timestamp, sent_seq)) = self.timestamp { + if seq >= sent_seq { + self.sample((timestamp - sent_timestamp).total_millis() as u32); + self.timestamp = None; + } + } + } + + fn on_retransmit(&mut self) { + if self.timestamp.is_some() { + tcp_trace!("rtte: abort sampling due to retransmit"); + } + self.timestamp = None; + self.rto_count = self.rto_count.saturating_add(1); + if self.rto_count >= 3 { + // This happens in 2 scenarios: + // - The RTT is higher than the initial estimate + // - The network conditions change, suddenly making the RTT much higher + // In these cases, the estimator can get stuck, because it can't sample because + // all packets sent would incur a retransmit. To avoid this, force an estimate + // increase if we see 3 consecutive retransmissions without any successful sample. + self.rto_count = 0; + self.rtt = RTTE_MAX_RTO.min(self.rtt * 2); + let rto = self.retransmission_timeout().total_millis(); + tcp_trace!( + "rtte: too many retransmissions, increasing: rtt={:?} dev={:?} rto={:?}", + self.rtt, + self.deviation, + rto + ); + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +enum Timer { + Idle { + keep_alive_at: Option<Instant>, + }, + Retransmit { + expires_at: Instant, + delay: Duration, + }, + FastRetransmit, + Close { + expires_at: Instant, + }, +} + +const ACK_DELAY_DEFAULT: Duration = Duration::from_millis(10); +const CLOSE_DELAY: Duration = Duration::from_millis(10_000); + +impl Timer { + fn new() -> Timer { + Timer::Idle { + keep_alive_at: None, + } + } + + fn should_keep_alive(&self, timestamp: Instant) -> bool { + match *self { + Timer::Idle { + keep_alive_at: Some(keep_alive_at), + } if timestamp >= keep_alive_at => true, + _ => false, + } + } + + fn should_retransmit(&self, timestamp: Instant) -> Option<Duration> { + match *self { + Timer::Retransmit { expires_at, delay } if timestamp >= expires_at => { + Some(timestamp - expires_at + delay) + } + Timer::FastRetransmit => Some(Duration::from_millis(0)), + _ => None, + } + } + + fn should_close(&self, timestamp: Instant) -> bool { + match *self { + Timer::Close { expires_at } if timestamp >= expires_at => true, + _ => false, + } + } + + fn poll_at(&self) -> PollAt { + match *self { + Timer::Idle { + keep_alive_at: Some(keep_alive_at), + } => PollAt::Time(keep_alive_at), + Timer::Idle { + keep_alive_at: None, + } => PollAt::Ingress, + Timer::Retransmit { expires_at, .. } => PollAt::Time(expires_at), + Timer::FastRetransmit => PollAt::Now, + Timer::Close { expires_at } => PollAt::Time(expires_at), + } + } + + fn set_for_idle(&mut self, timestamp: Instant, interval: Option<Duration>) { + *self = Timer::Idle { + keep_alive_at: interval.map(|interval| timestamp + interval), + } + } + + fn set_keep_alive(&mut self) { + if let Timer::Idle { keep_alive_at } = self { + if keep_alive_at.is_none() { + *keep_alive_at = Some(Instant::from_millis(0)) + } + } + } + + fn rewind_keep_alive(&mut self, timestamp: Instant, interval: Option<Duration>) { + if let Timer::Idle { keep_alive_at } = self { + *keep_alive_at = interval.map(|interval| timestamp + interval) + } + } + + fn set_for_retransmit(&mut self, timestamp: Instant, delay: Duration) { + match *self { + Timer::Idle { .. } | Timer::FastRetransmit { .. } => { + *self = Timer::Retransmit { + expires_at: timestamp + delay, + delay, + } + } + Timer::Retransmit { expires_at, delay } if timestamp >= expires_at => { + *self = Timer::Retransmit { + expires_at: timestamp + delay, + delay: delay * 2, + } + } + Timer::Retransmit { .. } => (), + Timer::Close { .. } => (), + } + } + + fn set_for_fast_retransmit(&mut self) { + *self = Timer::FastRetransmit + } + + fn set_for_close(&mut self, timestamp: Instant) { + *self = Timer::Close { + expires_at: timestamp + CLOSE_DELAY, + } + } + + fn is_retransmit(&self) -> bool { + match *self { + Timer::Retransmit { .. } | Timer::FastRetransmit => true, + _ => false, + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +enum AckDelayTimer { + Idle, + Waiting(Instant), + Immediate, +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct Tuple { + local: IpEndpoint, + remote: IpEndpoint, +} + +impl Display for Tuple { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}:{}", self.local, self.remote) + } +} + +/// A Transmission Control Protocol socket. +/// +/// A TCP socket may passively listen for connections or actively connect to another endpoint. +/// Note that, for listening sockets, there is no "backlog"; to be able to simultaneously +/// accept several connections, as many sockets must be allocated, or any new connection +/// attempts will be reset. +#[derive(Debug)] +pub struct Socket<'a> { + state: State, + timer: Timer, + rtte: RttEstimator, + assembler: Assembler, + rx_buffer: SocketBuffer<'a>, + rx_fin_received: bool, + tx_buffer: SocketBuffer<'a>, + /// Interval after which, if no inbound packets are received, the connection is aborted. + timeout: Option<Duration>, + /// Interval at which keep-alive packets will be sent. + keep_alive: Option<Duration>, + /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + hop_limit: Option<u8>, + /// Address passed to listen(). Listen address is set when listen() is called and + /// used every time the socket is reset back to the LISTEN state. + listen_endpoint: IpListenEndpoint, + /// Current 4-tuple (local and remote endpoints). + tuple: Option<Tuple>, + /// The sequence number corresponding to the beginning of the transmit buffer. + /// I.e. an ACK(local_seq_no+n) packet removes n bytes from the transmit buffer. + local_seq_no: TcpSeqNumber, + /// The sequence number corresponding to the beginning of the receive buffer. + /// I.e. userspace reading n bytes adds n to remote_seq_no. + remote_seq_no: TcpSeqNumber, + /// The last sequence number sent. + /// I.e. in an idle socket, local_seq_no+tx_buffer.len(). + remote_last_seq: TcpSeqNumber, + /// The last acknowledgement number sent. + /// I.e. in an idle socket, remote_seq_no+rx_buffer.len(). + remote_last_ack: Option<TcpSeqNumber>, + /// The last window length sent. + remote_last_win: u16, + /// The sending window scaling factor advertised to remotes which support RFC 1323. + /// It is zero if the window <= 64KiB and/or the remote does not support it. + remote_win_shift: u8, + /// The remote window size, relative to local_seq_no + /// I.e. we're allowed to send octets until local_seq_no+remote_win_len + remote_win_len: usize, + /// The receive window scaling factor for remotes which support RFC 1323, None if unsupported. + remote_win_scale: Option<u8>, + /// Whether or not the remote supports selective ACK as described in RFC 2018. + remote_has_sack: bool, + /// The maximum number of data octets that the remote side may receive. + remote_mss: usize, + /// The timestamp of the last packet received. + remote_last_ts: Option<Instant>, + /// The sequence number of the last packet received, used for sACK + local_rx_last_seq: Option<TcpSeqNumber>, + /// The ACK number of the last packet received. + local_rx_last_ack: Option<TcpSeqNumber>, + /// The number of packets received directly after + /// each other which have the same ACK number. + local_rx_dup_acks: u8, + + /// Duration for Delayed ACK. If None no ACKs will be delayed. + ack_delay: Option<Duration>, + /// Delayed ack timer. If set, packets containing exclusively + /// ACK or window updates (ie, no data) won't be sent until expiry. + ack_delay_timer: AckDelayTimer, + + /// Used for rate-limiting: No more challenge ACKs will be sent until this instant. + challenge_ack_timer: Instant, + + /// Nagle's Algorithm enabled. + nagle: bool, + + #[cfg(feature = "async")] + rx_waker: WakerRegistration, + #[cfg(feature = "async")] + tx_waker: WakerRegistration, +} + +const DEFAULT_MSS: usize = 536; + +impl<'a> Socket<'a> { + #[allow(unused_comparisons)] // small usize platforms always pass rx_capacity check + /// Create a socket using the given buffers. + pub fn new<T>(rx_buffer: T, tx_buffer: T) -> Socket<'a> + where + T: Into<SocketBuffer<'a>>, + { + let (rx_buffer, tx_buffer) = (rx_buffer.into(), tx_buffer.into()); + let rx_capacity = rx_buffer.capacity(); + + // From RFC 1323: + // [...] the above constraints imply that 2 * the max window size must be less + // than 2**31 [...] Thus, the shift count must be limited to 14 (which allows + // windows of 2**30 = 1 Gbyte). + if rx_capacity > (1 << 30) { + panic!("receiving buffer too large, cannot exceed 1 GiB") + } + let rx_cap_log2 = mem::size_of::<usize>() * 8 - rx_capacity.leading_zeros() as usize; + + Socket { + state: State::Closed, + timer: Timer::new(), + rtte: RttEstimator::default(), + assembler: Assembler::new(), + tx_buffer, + rx_buffer, + rx_fin_received: false, + timeout: None, + keep_alive: None, + hop_limit: None, + listen_endpoint: IpListenEndpoint::default(), + tuple: None, + local_seq_no: TcpSeqNumber::default(), + remote_seq_no: TcpSeqNumber::default(), + remote_last_seq: TcpSeqNumber::default(), + remote_last_ack: None, + remote_last_win: 0, + remote_win_len: 0, + remote_win_shift: rx_cap_log2.saturating_sub(16) as u8, + remote_win_scale: None, + remote_has_sack: false, + remote_mss: DEFAULT_MSS, + remote_last_ts: None, + local_rx_last_ack: None, + local_rx_last_seq: None, + local_rx_dup_acks: 0, + ack_delay: Some(ACK_DELAY_DEFAULT), + ack_delay_timer: AckDelayTimer::Idle, + challenge_ack_timer: Instant::from_secs(0), + nagle: true, + + #[cfg(feature = "async")] + rx_waker: WakerRegistration::new(), + #[cfg(feature = "async")] + tx_waker: WakerRegistration::new(), + } + } + + /// Register a waker for receive operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `recv` method calls, such as receiving data, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_recv_waker(&mut self, waker: &Waker) { + self.rx_waker.register(waker) + } + + /// Register a waker for send operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `send` method calls, such as space becoming available in the transmit + /// buffer, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_send_waker(&mut self, waker: &Waker) { + self.tx_waker.register(waker) + } + + /// Return the timeout duration. + /// + /// See also the [set_timeout](#method.set_timeout) method. + pub fn timeout(&self) -> Option<Duration> { + self.timeout + } + + /// Return the ACK delay duration. + /// + /// See also the [set_ack_delay](#method.set_ack_delay) method. + pub fn ack_delay(&self) -> Option<Duration> { + self.ack_delay + } + + /// Return whether Nagle's Algorithm is enabled. + /// + /// See also the [set_nagle_enabled](#method.set_nagle_enabled) method. + pub fn nagle_enabled(&self) -> bool { + self.nagle + } + + /// Return the current window field value, including scaling according to RFC 1323. + /// + /// Used in internal calculations as well as packet generation. + /// + #[inline] + fn scaled_window(&self) -> u16 { + cmp::min( + self.rx_buffer.window() >> self.remote_win_shift as usize, + (1 << 16) - 1, + ) as u16 + } + + /// Set the timeout duration. + /// + /// A socket with a timeout duration set will abort the connection if either of the following + /// occurs: + /// + /// * After a [connect](#method.connect) call, the remote endpoint does not respond within + /// the specified duration; + /// * After establishing a connection, there is data in the transmit buffer and the remote + /// endpoint exceeds the specified duration between any two packets it sends; + /// * After enabling [keep-alive](#method.set_keep_alive), the remote endpoint exceeds + /// the specified duration between any two packets it sends. + pub fn set_timeout(&mut self, duration: Option<Duration>) { + self.timeout = duration + } + + /// Set the ACK delay duration. + /// + /// By default, the ACK delay is set to 10ms. + pub fn set_ack_delay(&mut self, duration: Option<Duration>) { + self.ack_delay = duration + } + + /// Enable or disable Nagle's Algorithm. + /// + /// Also known as "tinygram prevention". By default, it is enabled. + /// Disabling it is equivalent to Linux's TCP_NODELAY flag. + /// + /// When enabled, Nagle's Algorithm prevents sending segments smaller than MSS if + /// there is data in flight (sent but not acknowledged). In other words, it ensures + /// at most only one segment smaller than MSS is in flight at a time. + /// + /// It ensures better network utilization by preventing sending many very small packets, + /// at the cost of increased latency in some situations, particularly when the remote peer + /// has ACK delay enabled. + pub fn set_nagle_enabled(&mut self, enabled: bool) { + self.nagle = enabled + } + + /// Return the keep-alive interval. + /// + /// See also the [set_keep_alive](#method.set_keep_alive) method. + pub fn keep_alive(&self) -> Option<Duration> { + self.keep_alive + } + + /// Set the keep-alive interval. + /// + /// An idle socket with a keep-alive interval set will transmit a "keep-alive ACK" packet + /// every time it receives no communication during that interval. As a result, three things + /// may happen: + /// + /// * The remote endpoint is fine and answers with an ACK packet. + /// * The remote endpoint has rebooted and answers with an RST packet. + /// * The remote endpoint has crashed and does not answer. + /// + /// The keep-alive functionality together with the timeout functionality allows to react + /// to these error conditions. + pub fn set_keep_alive(&mut self, interval: Option<Duration>) { + self.keep_alive = interval; + if self.keep_alive.is_some() { + // If the connection is idle and we've just set the option, it would not take effect + // until the next packet, unless we wind up the timer explicitly. + self.timer.set_keep_alive(); + } + } + + /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// See also the [set_hop_limit](#method.set_hop_limit) method + pub fn hop_limit(&self) -> Option<u8> { + self.hop_limit + } + + /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// A socket without an explicitly set hop limit value uses the default [IANA recommended] + /// value (64). + /// + /// # Panics + /// + /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7]. + /// + /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml + /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7 + pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { + // A host MUST NOT send a datagram with a hop limit value of 0 + if let Some(0) = hop_limit { + panic!("the time-to-live value of a packet must not be zero") + } + + self.hop_limit = hop_limit + } + + /// Return the local endpoint, or None if not connected. + #[inline] + pub fn local_endpoint(&self) -> Option<IpEndpoint> { + Some(self.tuple?.local) + } + + /// Return the remote endpoint, or None if not connected. + #[inline] + pub fn remote_endpoint(&self) -> Option<IpEndpoint> { + Some(self.tuple?.remote) + } + + /// Return the connection state, in terms of the TCP state machine. + #[inline] + pub fn state(&self) -> State { + self.state + } + + fn reset(&mut self) { + let rx_cap_log2 = + mem::size_of::<usize>() * 8 - self.rx_buffer.capacity().leading_zeros() as usize; + + self.state = State::Closed; + self.timer = Timer::new(); + self.rtte = RttEstimator::default(); + self.assembler = Assembler::new(); + self.tx_buffer.clear(); + self.rx_buffer.clear(); + self.rx_fin_received = false; + self.listen_endpoint = IpListenEndpoint::default(); + self.tuple = None; + self.local_seq_no = TcpSeqNumber::default(); + self.remote_seq_no = TcpSeqNumber::default(); + self.remote_last_seq = TcpSeqNumber::default(); + self.remote_last_ack = None; + self.remote_last_win = 0; + self.remote_win_len = 0; + self.remote_win_scale = None; + self.remote_win_shift = rx_cap_log2.saturating_sub(16) as u8; + self.remote_mss = DEFAULT_MSS; + self.remote_last_ts = None; + self.ack_delay_timer = AckDelayTimer::Idle; + self.challenge_ack_timer = Instant::from_secs(0); + + #[cfg(feature = "async")] + { + self.rx_waker.wake(); + self.tx_waker.wake(); + } + } + + /// Start listening on the given endpoint. + /// + /// This function returns `Err(Error::InvalidState)` if the socket was already open + /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)` + /// if the port in the given endpoint is zero. + pub fn listen<T>(&mut self, local_endpoint: T) -> Result<(), ListenError> + where + T: Into<IpListenEndpoint>, + { + let local_endpoint = local_endpoint.into(); + if local_endpoint.port == 0 { + return Err(ListenError::Unaddressable); + } + + if self.is_open() { + // If we were already listening to same endpoint there is nothing to do; exit early. + // + // In the past listening on an socket that was already listening was an error, + // however this makes writing an acceptor loop with multiple sockets impossible. + // Without this early exit, if you tried to listen on a socket that's already listening you'll + // immediately get an error. The only way around this is to abort the socket first + // before listening again, but this means that incoming connections can actually + // get aborted between the abort() and the next listen(). + if matches!(self.state, State::Listen) && self.listen_endpoint == local_endpoint { + return Ok(()); + } else { + return Err(ListenError::InvalidState); + } + } + + self.reset(); + self.listen_endpoint = local_endpoint; + self.tuple = None; + self.set_state(State::Listen); + Ok(()) + } + + /// Connect to a given endpoint. + /// + /// The local port must be provided explicitly. Assuming `fn get_ephemeral_port() -> u16` + /// allocates a port between 49152 and 65535, a connection may be established as follows: + /// + /// ```no_run + /// # #[cfg(all( + /// # feature = "medium-ethernet", + /// # feature = "proto-ipv4", + /// # ))] + /// # { + /// # use smoltcp::socket::tcp::{Socket, SocketBuffer}; + /// # use smoltcp::iface::Interface; + /// # use smoltcp::wire::IpAddress; + /// # + /// # fn get_ephemeral_port() -> u16 { + /// # 49152 + /// # } + /// # + /// # let mut socket = Socket::new( + /// # SocketBuffer::new(vec![0; 1200]), + /// # SocketBuffer::new(vec![0; 1200]) + /// # ); + /// # + /// # let mut iface: Interface = todo!(); + /// # + /// socket.connect( + /// iface.context(), + /// (IpAddress::v4(10, 0, 0, 1), 80), + /// get_ephemeral_port() + /// ).unwrap(); + /// # } + /// ``` + /// + /// The local address may optionally be provided. + /// + /// This function returns an error if the socket was open; see [is_open](#method.is_open). + /// It also returns an error if the local or remote port is zero, or if the remote address + /// is unspecified. + pub fn connect<T, U>( + &mut self, + cx: &mut Context, + remote_endpoint: T, + local_endpoint: U, + ) -> Result<(), ConnectError> + where + T: Into<IpEndpoint>, + U: Into<IpListenEndpoint>, + { + let remote_endpoint: IpEndpoint = remote_endpoint.into(); + let local_endpoint: IpListenEndpoint = local_endpoint.into(); + + if self.is_open() { + return Err(ConnectError::InvalidState); + } + if remote_endpoint.port == 0 || remote_endpoint.addr.is_unspecified() { + return Err(ConnectError::Unaddressable); + } + if local_endpoint.port == 0 { + return Err(ConnectError::Unaddressable); + } + + // If local address is not provided, choose it automatically. + let local_endpoint = IpEndpoint { + addr: match local_endpoint.addr { + Some(addr) => { + if addr.is_unspecified() { + return Err(ConnectError::Unaddressable); + } + addr + } + None => cx + .get_source_address(&remote_endpoint.addr) + .ok_or(ConnectError::Unaddressable)?, + }, + port: local_endpoint.port, + }; + + if local_endpoint.addr.version() != remote_endpoint.addr.version() { + return Err(ConnectError::Unaddressable); + } + + self.reset(); + self.tuple = Some(Tuple { + local: local_endpoint, + remote: remote_endpoint, + }); + self.set_state(State::SynSent); + + let seq = Self::random_seq_no(cx); + self.local_seq_no = seq; + self.remote_last_seq = seq; + Ok(()) + } + + #[cfg(test)] + fn random_seq_no(_cx: &mut Context) -> TcpSeqNumber { + TcpSeqNumber(10000) + } + + #[cfg(not(test))] + fn random_seq_no(cx: &mut Context) -> TcpSeqNumber { + TcpSeqNumber(cx.rand().rand_u32() as i32) + } + + /// Close the transmit half of the full-duplex connection. + /// + /// Note that there is no corresponding function for the receive half of the full-duplex + /// connection; only the remote end can close it. If you no longer wish to receive any + /// data and would like to reuse the socket right away, use [abort](#method.abort). + pub fn close(&mut self) { + match self.state { + // In the LISTEN state there is no established connection. + State::Listen => self.set_state(State::Closed), + // In the SYN-SENT state the remote endpoint is not yet synchronized and, upon + // receiving an RST, will abort the connection. + State::SynSent => self.set_state(State::Closed), + // In the SYN-RECEIVED, ESTABLISHED and CLOSE-WAIT states the transmit half + // of the connection is open, and needs to be explicitly closed with a FIN. + State::SynReceived | State::Established => self.set_state(State::FinWait1), + State::CloseWait => self.set_state(State::LastAck), + // In the FIN-WAIT-1, FIN-WAIT-2, CLOSING, LAST-ACK, TIME-WAIT and CLOSED states, + // the transmit half of the connection is already closed, and no further + // action is needed. + State::FinWait1 + | State::FinWait2 + | State::Closing + | State::TimeWait + | State::LastAck + | State::Closed => (), + } + } + + /// Aborts the connection, if any. + /// + /// This function instantly closes the socket. One reset packet will be sent to the remote + /// endpoint. + /// + /// In terms of the TCP state machine, the socket may be in any state and is moved to + /// the `CLOSED` state. + pub fn abort(&mut self) { + self.set_state(State::Closed); + } + + /// Return whether the socket is passively listening for incoming connections. + /// + /// In terms of the TCP state machine, the socket must be in the `LISTEN` state. + #[inline] + pub fn is_listening(&self) -> bool { + match self.state { + State::Listen => true, + _ => false, + } + } + + /// Return whether the socket is open. + /// + /// This function returns true if the socket will process incoming or dispatch outgoing + /// packets. Note that this does not mean that it is possible to send or receive data through + /// the socket; for that, use [can_send](#method.can_send) or [can_recv](#method.can_recv). + /// + /// In terms of the TCP state machine, the socket must not be in the `CLOSED` + /// or `TIME-WAIT` states. + #[inline] + pub fn is_open(&self) -> bool { + match self.state { + State::Closed => false, + State::TimeWait => false, + _ => true, + } + } + + /// Return whether a connection is active. + /// + /// This function returns true if the socket is actively exchanging packets with + /// a remote endpoint. Note that this does not mean that it is possible to send or receive + /// data through the socket; for that, use [can_send](#method.can_send) or + /// [can_recv](#method.can_recv). + /// + /// If a connection is established, [abort](#method.close) will send a reset to + /// the remote endpoint. + /// + /// In terms of the TCP state machine, the socket must not be in the `CLOSED`, `TIME-WAIT`, + /// or `LISTEN` state. + #[inline] + pub fn is_active(&self) -> bool { + match self.state { + State::Closed => false, + State::TimeWait => false, + State::Listen => false, + _ => true, + } + } + + /// Return whether the transmit half of the full-duplex connection is open. + /// + /// This function returns true if it's possible to send data and have it arrive + /// to the remote endpoint. However, it does not make any guarantees about the state + /// of the transmit buffer, and even if it returns true, [send](#method.send) may + /// not be able to enqueue any octets. + /// + /// In terms of the TCP state machine, the socket must be in the `ESTABLISHED` or + /// `CLOSE-WAIT` state. + #[inline] + pub fn may_send(&self) -> bool { + match self.state { + State::Established => true, + // In CLOSE-WAIT, the remote endpoint has closed our receive half of the connection + // but we still can transmit indefinitely. + State::CloseWait => true, + _ => false, + } + } + + /// Return whether the receive half of the full-duplex connection is open. + /// + /// This function returns true if it's possible to receive data from the remote endpoint. + /// It will return true while there is data in the receive buffer, and if there isn't, + /// as long as the remote endpoint has not closed the connection. + /// + /// In terms of the TCP state machine, the socket must be in the `ESTABLISHED`, + /// `FIN-WAIT-1`, or `FIN-WAIT-2` state, or have data in the receive buffer instead. + #[inline] + pub fn may_recv(&self) -> bool { + match self.state { + State::Established => true, + // In FIN-WAIT-1/2, we have closed our transmit half of the connection but + // we still can receive indefinitely. + State::FinWait1 | State::FinWait2 => true, + // If we have something in the receive buffer, we can receive that. + _ if !self.rx_buffer.is_empty() => true, + _ => false, + } + } + + /// Check whether the transmit half of the full-duplex connection is open + /// (see [may_send](#method.may_send)), and the transmit buffer is not full. + #[inline] + pub fn can_send(&self) -> bool { + if !self.may_send() { + return false; + } + + !self.tx_buffer.is_full() + } + + /// Return the maximum number of bytes inside the recv buffer. + #[inline] + pub fn recv_capacity(&self) -> usize { + self.rx_buffer.capacity() + } + + /// Return the maximum number of bytes inside the transmit buffer. + #[inline] + pub fn send_capacity(&self) -> usize { + self.tx_buffer.capacity() + } + + /// Check whether the receive half of the full-duplex connection buffer is open + /// (see [may_recv](#method.may_recv)), and the receive buffer is not empty. + #[inline] + pub fn can_recv(&self) -> bool { + if !self.may_recv() { + return false; + } + + !self.rx_buffer.is_empty() + } + + fn send_impl<'b, F, R>(&'b mut self, f: F) -> Result<R, SendError> + where + F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R), + { + if !self.may_send() { + return Err(SendError::InvalidState); + } + + // The connection might have been idle for a long time, and so remote_last_ts + // would be far in the past. Unless we clear it here, we'll abort the connection + // down over in dispatch() by erroneously detecting it as timed out. + if self.tx_buffer.is_empty() { + self.remote_last_ts = None + } + + let _old_length = self.tx_buffer.len(); + let (size, result) = f(&mut self.tx_buffer); + if size > 0 { + #[cfg(any(test, feature = "verbose"))] + tcp_trace!( + "tx buffer: enqueueing {} octets (now {})", + size, + _old_length + size + ); + } + Ok(result) + } + + /// Call `f` with the largest contiguous slice of octets in the transmit buffer, + /// and enqueue the amount of elements returned by `f`. + /// + /// This function returns `Err(Error::Illegal)` if the transmit half of + /// the connection is not open; see [may_send](#method.may_send). + pub fn send<'b, F, R>(&'b mut self, f: F) -> Result<R, SendError> + where + F: FnOnce(&'b mut [u8]) -> (usize, R), + { + self.send_impl(|tx_buffer| tx_buffer.enqueue_many_with(f)) + } + + /// Enqueue a sequence of octets to be sent, and fill it from a slice. + /// + /// This function returns the amount of octets actually enqueued, which is limited + /// by the amount of free space in the transmit buffer; down to zero. + /// + /// See also [send](#method.send). + pub fn send_slice(&mut self, data: &[u8]) -> Result<usize, SendError> { + self.send_impl(|tx_buffer| { + let size = tx_buffer.enqueue_slice(data); + (size, size) + }) + } + + fn recv_error_check(&mut self) -> Result<(), RecvError> { + // We may have received some data inside the initial SYN, but until the connection + // is fully open we must not dequeue any data, as it may be overwritten by e.g. + // another (stale) SYN. (We do not support TCP Fast Open.) + if !self.may_recv() { + if self.rx_fin_received { + return Err(RecvError::Finished); + } + return Err(RecvError::InvalidState); + } + + Ok(()) + } + + fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result<R, RecvError> + where + F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R), + { + self.recv_error_check()?; + + let _old_length = self.rx_buffer.len(); + let (size, result) = f(&mut self.rx_buffer); + self.remote_seq_no += size; + if size > 0 { + #[cfg(any(test, feature = "verbose"))] + tcp_trace!( + "rx buffer: dequeueing {} octets (now {})", + size, + _old_length - size + ); + } + Ok(result) + } + + /// Call `f` with the largest contiguous slice of octets in the receive buffer, + /// and dequeue the amount of elements returned by `f`. + /// + /// This function errors if the receive half of the connection is not open. + /// + /// If the receive half has been gracefully closed (with a FIN packet), `Err(Error::Finished)` + /// is returned. In this case, the previously received data is guaranteed to be complete. + /// + /// In all other cases, `Err(Error::Illegal)` is returned and previously received data (if any) + /// may be incomplete (truncated). + pub fn recv<'b, F, R>(&'b mut self, f: F) -> Result<R, RecvError> + where + F: FnOnce(&'b mut [u8]) -> (usize, R), + { + self.recv_impl(|rx_buffer| rx_buffer.dequeue_many_with(f)) + } + + /// Dequeue a sequence of received octets, and fill a slice from it. + /// + /// This function returns the amount of octets actually dequeued, which is limited + /// by the amount of occupied space in the receive buffer; down to zero. + /// + /// See also [recv](#method.recv). + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> { + self.recv_impl(|rx_buffer| { + let size = rx_buffer.dequeue_slice(data); + (size, size) + }) + } + + /// Peek at a sequence of received octets without removing them from + /// the receive buffer, and return a pointer to it. + /// + /// This function otherwise behaves identically to [recv](#method.recv). + pub fn peek(&mut self, size: usize) -> Result<&[u8], RecvError> { + self.recv_error_check()?; + + let buffer = self.rx_buffer.get_allocated(0, size); + if !buffer.is_empty() { + #[cfg(any(test, feature = "verbose"))] + tcp_trace!("rx buffer: peeking at {} octets", buffer.len()); + } + Ok(buffer) + } + + /// Peek at a sequence of received octets without removing them from + /// the receive buffer, and fill a slice from it. + /// + /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). + pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<usize, RecvError> { + Ok(self.rx_buffer.read_allocated(0, data)) + } + + /// Return the amount of octets queued in the transmit buffer. + /// + /// Note that the Berkeley sockets interface does not have an equivalent of this API. + pub fn send_queue(&self) -> usize { + self.tx_buffer.len() + } + + /// Return the amount of octets queued in the receive buffer. This value can be larger than + /// the slice read by the next `recv` or `peek` call because it includes all queued octets, + /// and not only the octets that may be returned as a contiguous slice. + /// + /// Note that the Berkeley sockets interface does not have an equivalent of this API. + pub fn recv_queue(&self) -> usize { + self.rx_buffer.len() + } + + fn set_state(&mut self, state: State) { + if self.state != state { + tcp_trace!("state={}=>{}", self.state, state); + } + + self.state = state; + + #[cfg(feature = "async")] + { + // Wake all tasks waiting. Even if we haven't received/sent data, this + // is needed because return values of functions may change depending on the state. + // For example, a pending read has to fail with an error if the socket is closed. + self.rx_waker.wake(); + self.tx_waker.wake(); + } + } + + pub(crate) fn reply(ip_repr: &IpRepr, repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) { + let reply_repr = TcpRepr { + src_port: repr.dst_port, + dst_port: repr.src_port, + control: TcpControl::None, + seq_number: TcpSeqNumber(0), + ack_number: None, + window_len: 0, + window_scale: None, + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &[], + }; + let ip_reply_repr = IpRepr::new( + ip_repr.dst_addr(), + ip_repr.src_addr(), + IpProtocol::Tcp, + reply_repr.buffer_len(), + 64, + ); + (ip_reply_repr, reply_repr) + } + + pub(crate) fn rst_reply(ip_repr: &IpRepr, repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) { + debug_assert!(repr.control != TcpControl::Rst); + + let (ip_reply_repr, mut reply_repr) = Self::reply(ip_repr, repr); + + // See https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for explanation + // of why we sometimes send an RST and sometimes an RST|ACK + reply_repr.control = TcpControl::Rst; + reply_repr.seq_number = repr.ack_number.unwrap_or_default(); + if repr.control == TcpControl::Syn && repr.ack_number.is_none() { + reply_repr.ack_number = Some(repr.seq_number + repr.segment_len()); + } + + (ip_reply_repr, reply_repr) + } + + fn ack_reply(&mut self, ip_repr: &IpRepr, repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) { + let (mut ip_reply_repr, mut reply_repr) = Self::reply(ip_repr, repr); + + // From RFC 793: + // [...] an empty acknowledgment segment containing the current send-sequence number + // and an acknowledgment indicating the next sequence number expected + // to be received. + reply_repr.seq_number = self.remote_last_seq; + reply_repr.ack_number = Some(self.remote_seq_no + self.rx_buffer.len()); + self.remote_last_ack = reply_repr.ack_number; + + // From RFC 1323: + // The window field [...] of every outgoing segment, with the exception of SYN + // segments, is right-shifted by [advertised scale value] bits[...] + reply_repr.window_len = self.scaled_window(); + self.remote_last_win = reply_repr.window_len; + + // If the remote supports selective acknowledgement, add the option to the outgoing + // segment. + if self.remote_has_sack { + net_debug!("sending sACK option with current assembler ranges"); + + // RFC 2018: The first SACK block (i.e., the one immediately following the kind and + // length fields in the option) MUST specify the contiguous block of data containing + // the segment which triggered this ACK, unless that segment advanced the + // Acknowledgment Number field in the header. + reply_repr.sack_ranges[0] = None; + + if let Some(last_seg_seq) = self.local_rx_last_seq.map(|s| s.0 as u32) { + reply_repr.sack_ranges[0] = self + .assembler + .iter_data(reply_repr.ack_number.map(|s| s.0 as usize).unwrap_or(0)) + .map(|(left, right)| (left as u32, right as u32)) + .find(|(left, right)| *left <= last_seg_seq && *right >= last_seg_seq); + } + + if reply_repr.sack_ranges[0].is_none() { + // The matching segment was removed from the assembler, meaning the acknowledgement + // number has advanced, or there was no previous sACK. + // + // While the RFC says we SHOULD keep a list of reported sACK ranges, and iterate + // through those, that is currently infeasible. Instead, we offer the range with + // the lowest sequence number (if one exists) to hint at what segments would + // most quickly advance the acknowledgement number. + reply_repr.sack_ranges[0] = self + .assembler + .iter_data(reply_repr.ack_number.map(|s| s.0 as usize).unwrap_or(0)) + .map(|(left, right)| (left as u32, right as u32)) + .next(); + } + } + + // Since the sACK option may have changed the length of the payload, update that. + ip_reply_repr.set_payload_len(reply_repr.buffer_len()); + (ip_reply_repr, reply_repr) + } + + fn challenge_ack_reply( + &mut self, + cx: &mut Context, + ip_repr: &IpRepr, + repr: &TcpRepr, + ) -> Option<(IpRepr, TcpRepr<'static>)> { + if cx.now() < self.challenge_ack_timer { + return None; + } + + // Rate-limit to 1 per second max. + self.challenge_ack_timer = cx.now() + Duration::from_secs(1); + + Some(self.ack_reply(ip_repr, repr)) + } + + pub(crate) fn accepts(&self, _cx: &mut Context, ip_repr: &IpRepr, repr: &TcpRepr) -> bool { + if self.state == State::Closed { + return false; + } + + // If we're still listening for SYNs and the packet has an ACK, it cannot + // be destined to this socket, but another one may well listen on the same + // local endpoint. + if self.state == State::Listen && repr.ack_number.is_some() { + return false; + } + + if let Some(tuple) = &self.tuple { + // Reject packets not matching the 4-tuple + ip_repr.dst_addr() == tuple.local.addr + && repr.dst_port == tuple.local.port + && ip_repr.src_addr() == tuple.remote.addr + && repr.src_port == tuple.remote.port + } else { + // We're listening, reject packets not matching the listen endpoint. + let addr_ok = match self.listen_endpoint.addr { + Some(addr) => ip_repr.dst_addr() == addr, + None => true, + }; + addr_ok && repr.dst_port != 0 && repr.dst_port == self.listen_endpoint.port + } + } + + pub(crate) fn process( + &mut self, + cx: &mut Context, + ip_repr: &IpRepr, + repr: &TcpRepr, + ) -> Option<(IpRepr, TcpRepr<'static>)> { + debug_assert!(self.accepts(cx, ip_repr, repr)); + + // Consider how much the sequence number space differs from the transmit buffer space. + let (sent_syn, sent_fin) = match self.state { + // In SYN-SENT or SYN-RECEIVED, we've just sent a SYN. + State::SynSent | State::SynReceived => (true, false), + // In FIN-WAIT-1, LAST-ACK, or CLOSING, we've just sent a FIN. + State::FinWait1 | State::LastAck | State::Closing => (false, true), + // In all other states we've already got acknowledgements for + // all of the control flags we sent. + _ => (false, false), + }; + let control_len = (sent_syn as usize) + (sent_fin as usize); + + // Reject unacceptable acknowledgements. + match (self.state, repr.control, repr.ack_number) { + // An RST received in response to initial SYN is acceptable if it acknowledges + // the initial SYN. + (State::SynSent, TcpControl::Rst, None) => { + net_debug!("unacceptable RST (expecting RST|ACK) in response to initial SYN"); + return None; + } + (State::SynSent, TcpControl::Rst, Some(ack_number)) => { + if ack_number != self.local_seq_no + 1 { + net_debug!("unacceptable RST|ACK in response to initial SYN"); + return None; + } + } + // Any other RST need only have a valid sequence number. + (_, TcpControl::Rst, _) => (), + // The initial SYN cannot contain an acknowledgement. + (State::Listen, _, None) => (), + // This case is handled in `accepts()`. + (State::Listen, _, Some(_)) => unreachable!(), + // Every packet after the initial SYN must be an acknowledgement. + (_, _, None) => { + net_debug!("expecting an ACK"); + return None; + } + // SYN|ACK in the SYN-SENT state must have the exact ACK number. + (State::SynSent, TcpControl::Syn, Some(ack_number)) => { + if ack_number != self.local_seq_no + 1 { + net_debug!("unacceptable SYN|ACK in response to initial SYN"); + return Some(Self::rst_reply(ip_repr, repr)); + } + } + // ACKs in the SYN-SENT state are invalid. + (State::SynSent, TcpControl::None, Some(ack_number)) => { + // If the sequence number matches, ignore it instead of RSTing. + // I'm not sure why, I think it may be a workaround for broken TCP + // servers, or a defense against reordering. Either way, if Linux + // does it, we do too. + if ack_number == self.local_seq_no + 1 { + net_debug!( + "expecting a SYN|ACK, received an ACK with the right ack_number, ignoring." + ); + return None; + } + + net_debug!( + "expecting a SYN|ACK, received an ACK with the wrong ack_number, sending RST." + ); + return Some(Self::rst_reply(ip_repr, repr)); + } + // Anything else in the SYN-SENT state is invalid. + (State::SynSent, _, _) => { + net_debug!("expecting a SYN|ACK"); + return None; + } + // ACK in the SYN-RECEIVED state must have the exact ACK number, or we RST it. + (State::SynReceived, _, Some(ack_number)) => { + if ack_number != self.local_seq_no + 1 { + net_debug!("unacceptable ACK in response to SYN|ACK"); + return Some(Self::rst_reply(ip_repr, repr)); + } + } + // Every acknowledgement must be for transmitted but unacknowledged data. + (_, _, Some(ack_number)) => { + let unacknowledged = self.tx_buffer.len() + control_len; + + // Acceptable ACK range (both inclusive) + let mut ack_min = self.local_seq_no; + let ack_max = self.local_seq_no + unacknowledged; + + // If we have sent a SYN, it MUST be acknowledged. + if sent_syn { + ack_min += 1; + } + + if ack_number < ack_min { + net_debug!( + "duplicate ACK ({} not in {}...{})", + ack_number, + ack_min, + ack_max + ); + return None; + } + + if ack_number > ack_max { + net_debug!( + "unacceptable ACK ({} not in {}...{})", + ack_number, + ack_min, + ack_max + ); + return self.challenge_ack_reply(cx, ip_repr, repr); + } + } + } + + let window_start = self.remote_seq_no + self.rx_buffer.len(); + let window_end = self.remote_seq_no + self.rx_buffer.capacity(); + let segment_start = repr.seq_number; + let segment_end = repr.seq_number + repr.payload.len(); + + let (payload, payload_offset) = match self.state { + // In LISTEN and SYN-SENT states, we have not yet synchronized with the remote end. + State::Listen | State::SynSent => (&[][..], 0), + _ => { + // https://www.rfc-editor.org/rfc/rfc9293.html#name-segment-acceptability-tests + let segment_in_window = match ( + segment_start == segment_end, + window_start == window_end, + ) { + (true, _) if segment_end == window_start - 1 => { + net_debug!( + "received a keep-alive or window probe packet, will send an ACK" + ); + false + } + (true, true) => { + if window_start == segment_start { + true + } else { + net_debug!( + "zero-length segment not inside zero-length window, will send an ACK." + ); + false + } + } + (true, false) => { + if window_start <= segment_start && segment_start < window_end { + true + } else { + net_debug!("zero-length segment not inside window, will send an ACK."); + false + } + } + (false, true) => { + net_debug!( + "non-zero-length segment with zero receive window, will only send an ACK" + ); + false + } + (false, false) => { + if (window_start <= segment_start && segment_start < window_end) + || (window_start < segment_end && segment_end <= window_end) + { + true + } else { + net_debug!( + "segment not in receive window ({}..{} not intersecting {}..{}), will send challenge ACK", + segment_start, + segment_end, + window_start, + window_end + ); + false + } + } + }; + + if segment_in_window { + let overlap_start = window_start.max(segment_start); + let overlap_end = window_end.min(segment_end); + + // the checks done above imply this. + debug_assert!(overlap_start <= overlap_end); + + self.local_rx_last_seq = Some(repr.seq_number); + + ( + &repr.payload[overlap_start - segment_start..overlap_end - segment_start], + overlap_start - window_start, + ) + } else { + // If we're in the TIME-WAIT state, restart the TIME-WAIT timeout, since + // the remote end may not have realized we've closed the connection. + if self.state == State::TimeWait { + self.timer.set_for_close(cx.now()); + } + + return self.challenge_ack_reply(cx, ip_repr, repr); + } + } + }; + + // Compute the amount of acknowledged octets, removing the SYN and FIN bits + // from the sequence space. + let mut ack_len = 0; + let mut ack_of_fin = false; + let mut ack_all = false; + if repr.control != TcpControl::Rst { + if let Some(ack_number) = repr.ack_number { + // Sequence number corresponding to the first byte in `tx_buffer`. + // This normally equals `local_seq_no`, but is 1 higher if we have sent a SYN, + // as the SYN occupies 1 sequence number "before" the data. + let tx_buffer_start_seq = self.local_seq_no + (sent_syn as usize); + + if ack_number >= tx_buffer_start_seq { + ack_len = ack_number - tx_buffer_start_seq; + + // We could've sent data before the FIN, so only remove FIN from the sequence + // space if all of that data is acknowledged. + if sent_fin && self.tx_buffer.len() + 1 == ack_len { + ack_len -= 1; + tcp_trace!("received ACK of FIN"); + ack_of_fin = true; + } + + ack_all = self.remote_last_seq == ack_number + } + + self.rtte.on_ack(cx.now(), ack_number); + } + } + + // Disregard control flags we don't care about or shouldn't act on yet. + let mut control = repr.control; + control = control.quash_psh(); + + // If a FIN is received at the end of the current segment but the start of the segment + // is not at the start of the receive window, disregard this FIN. + if control == TcpControl::Fin && window_start != segment_start { + tcp_trace!("ignoring FIN because we don't have full data yet. window_start={} segment_start={}", window_start, segment_start); + control = TcpControl::None; + } + + // Validate and update the state. + match (self.state, control) { + // RSTs are not accepted in the LISTEN state. + (State::Listen, TcpControl::Rst) => return None, + + // RSTs in SYN-RECEIVED flip the socket back to the LISTEN state. + (State::SynReceived, TcpControl::Rst) => { + tcp_trace!("received RST"); + self.tuple = None; + self.set_state(State::Listen); + return None; + } + + // RSTs in any other state close the socket. + (_, TcpControl::Rst) => { + tcp_trace!("received RST"); + self.set_state(State::Closed); + self.tuple = None; + return None; + } + + // SYN packets in the LISTEN state change it to SYN-RECEIVED. + (State::Listen, TcpControl::Syn) => { + tcp_trace!("received SYN"); + if let Some(max_seg_size) = repr.max_seg_size { + if max_seg_size == 0 { + tcp_trace!("received SYNACK with zero MSS, ignoring"); + return None; + } + self.remote_mss = max_seg_size as usize + } + + self.tuple = Some(Tuple { + local: IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port), + remote: IpEndpoint::new(ip_repr.src_addr(), repr.src_port), + }); + self.local_seq_no = Self::random_seq_no(cx); + self.remote_seq_no = repr.seq_number + 1; + self.remote_last_seq = self.local_seq_no; + self.remote_has_sack = repr.sack_permitted; + self.remote_win_scale = repr.window_scale; + // Remote doesn't support window scaling, don't do it. + if self.remote_win_scale.is_none() { + self.remote_win_shift = 0; + } + self.set_state(State::SynReceived); + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + + // ACK packets in the SYN-RECEIVED state change it to ESTABLISHED. + (State::SynReceived, TcpControl::None) => { + self.set_state(State::Established); + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + + // FIN packets in the SYN-RECEIVED state change it to CLOSE-WAIT. + // It's not obvious from RFC 793 that this is permitted, but + // 7th and 8th steps in the "SEGMENT ARRIVES" event describe this behavior. + (State::SynReceived, TcpControl::Fin) => { + self.remote_seq_no += 1; + self.rx_fin_received = true; + self.set_state(State::CloseWait); + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + + // SYN|ACK packets in the SYN-SENT state change it to ESTABLISHED. + (State::SynSent, TcpControl::Syn) => { + tcp_trace!("received SYN|ACK"); + if let Some(max_seg_size) = repr.max_seg_size { + if max_seg_size == 0 { + tcp_trace!("received SYNACK with zero MSS, ignoring"); + return None; + } + self.remote_mss = max_seg_size as usize; + } + + self.remote_seq_no = repr.seq_number + 1; + self.remote_last_seq = self.local_seq_no + 1; + self.remote_last_ack = Some(repr.seq_number); + self.remote_win_scale = repr.window_scale; + // Remote doesn't support window scaling, don't do it. + if self.remote_win_scale.is_none() { + self.remote_win_shift = 0; + } + + self.set_state(State::Established); + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + + // ACK packets in ESTABLISHED state reset the retransmit timer, + // except for duplicate ACK packets which preserve it. + (State::Established, TcpControl::None) => { + if !self.timer.is_retransmit() || ack_all { + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + } + + // FIN packets in ESTABLISHED state indicate the remote side has closed. + (State::Established, TcpControl::Fin) => { + self.remote_seq_no += 1; + self.rx_fin_received = true; + self.set_state(State::CloseWait); + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + + // ACK packets in FIN-WAIT-1 state change it to FIN-WAIT-2, if we've already + // sent everything in the transmit buffer. If not, they reset the retransmit timer. + (State::FinWait1, TcpControl::None) => { + if ack_of_fin { + self.set_state(State::FinWait2); + } + if ack_all { + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + } + + // FIN packets in FIN-WAIT-1 state change it to CLOSING, or to TIME-WAIT + // if they also acknowledge our FIN. + (State::FinWait1, TcpControl::Fin) => { + self.remote_seq_no += 1; + self.rx_fin_received = true; + if ack_of_fin { + self.set_state(State::TimeWait); + self.timer.set_for_close(cx.now()); + } else { + self.set_state(State::Closing); + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + } + + // Data packets in FIN-WAIT-2 reset the idle timer. + (State::FinWait2, TcpControl::None) => { + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + + // FIN packets in FIN-WAIT-2 state change it to TIME-WAIT. + (State::FinWait2, TcpControl::Fin) => { + self.remote_seq_no += 1; + self.rx_fin_received = true; + self.set_state(State::TimeWait); + self.timer.set_for_close(cx.now()); + } + + // ACK packets in CLOSING state change it to TIME-WAIT. + (State::Closing, TcpControl::None) => { + if ack_of_fin { + self.set_state(State::TimeWait); + self.timer.set_for_close(cx.now()); + } else { + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + } + + // ACK packets in CLOSE-WAIT state reset the retransmit timer. + (State::CloseWait, TcpControl::None) => { + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + + // ACK packets in LAST-ACK state change it to CLOSED. + (State::LastAck, TcpControl::None) => { + if ack_of_fin { + // Clear the remote endpoint, or we'll send an RST there. + self.set_state(State::Closed); + self.tuple = None; + } else { + self.timer.set_for_idle(cx.now(), self.keep_alive); + } + } + + _ => { + net_debug!("unexpected packet {}", repr); + return None; + } + } + + // Update remote state. + self.remote_last_ts = Some(cx.now()); + + // RFC 1323: The window field (SEG.WND) in the header of every incoming segment, with the + // exception of SYN segments, is left-shifted by Snd.Wind.Scale bits before updating SND.WND. + let scale = match repr.control { + TcpControl::Syn => 0, + _ => self.remote_win_scale.unwrap_or(0), + }; + let new_remote_win_len = (repr.window_len as usize) << (scale as usize); + let is_window_update = new_remote_win_len != self.remote_win_len; + self.remote_win_len = new_remote_win_len; + + if ack_len > 0 { + // Dequeue acknowledged octets. + debug_assert!(self.tx_buffer.len() >= ack_len); + tcp_trace!( + "tx buffer: dequeueing {} octets (now {})", + ack_len, + self.tx_buffer.len() - ack_len + ); + self.tx_buffer.dequeue_allocated(ack_len); + + // There's new room available in tx_buffer, wake the waiting task if any. + #[cfg(feature = "async")] + self.tx_waker.wake(); + } + + if let Some(ack_number) = repr.ack_number { + // TODO: When flow control is implemented, + // refractor the following block within that implementation + + // Detect and react to duplicate ACKs by: + // 1. Check if duplicate ACK and change self.local_rx_dup_acks accordingly + // 2. If exactly 3 duplicate ACKs received, set for fast retransmit + // 3. Update the last received ACK (self.local_rx_last_ack) + match self.local_rx_last_ack { + // Duplicate ACK if payload empty and ACK doesn't move send window -> + // Increment duplicate ACK count and set for retransmit if we just received + // the third duplicate ACK + Some(last_rx_ack) + if repr.payload.is_empty() + && last_rx_ack == ack_number + && ack_number < self.remote_last_seq + && !is_window_update => + { + // Increment duplicate ACK count + self.local_rx_dup_acks = self.local_rx_dup_acks.saturating_add(1); + + net_debug!( + "received duplicate ACK for seq {} (duplicate nr {}{})", + ack_number, + self.local_rx_dup_acks, + if self.local_rx_dup_acks == u8::max_value() { + "+" + } else { + "" + } + ); + + if self.local_rx_dup_acks == 3 { + self.timer.set_for_fast_retransmit(); + net_debug!("started fast retransmit"); + } + } + // No duplicate ACK -> Reset state and update last received ACK + _ => { + if self.local_rx_dup_acks > 0 { + self.local_rx_dup_acks = 0; + net_debug!("reset duplicate ACK count"); + } + self.local_rx_last_ack = Some(ack_number); + } + }; + // We've processed everything in the incoming segment, so advance the local + // sequence number past it. + self.local_seq_no = ack_number; + // During retransmission, if an earlier segment got lost but later was + // successfully received, self.local_seq_no can move past self.remote_last_seq. + // Do not attempt to retransmit the latter segments; not only this is pointless + // in theory but also impossible in practice, since they have been already + // deallocated from the buffer. + if self.remote_last_seq < self.local_seq_no { + self.remote_last_seq = self.local_seq_no + } + } + + let payload_len = payload.len(); + if payload_len == 0 { + return None; + } + + let assembler_was_empty = self.assembler.is_empty(); + + // Try adding payload octets to the assembler. + let Ok(contig_len) = self + .assembler + .add_then_remove_front(payload_offset, payload_len) + else { + net_debug!( + "assembler: too many holes to add {} octets at offset {}", + payload_len, + payload_offset + ); + return None; + }; + + // Place payload octets into the buffer. + tcp_trace!( + "rx buffer: receiving {} octets at offset {}", + payload_len, + payload_offset + ); + let len_written = self.rx_buffer.write_unallocated(payload_offset, payload); + debug_assert!(len_written == payload_len); + + if contig_len != 0 { + // Enqueue the contiguous data octets in front of the buffer. + tcp_trace!( + "rx buffer: enqueueing {} octets (now {})", + contig_len, + self.rx_buffer.len() + contig_len + ); + self.rx_buffer.enqueue_unallocated(contig_len); + + // There's new data in rx_buffer, notify waiting task if any. + #[cfg(feature = "async")] + self.rx_waker.wake(); + } + + if !self.assembler.is_empty() { + // Print the ranges recorded in the assembler. + tcp_trace!("assembler: {}", self.assembler); + } + + // Handle delayed acks + if let Some(ack_delay) = self.ack_delay { + if self.ack_to_transmit() || self.window_to_update() { + self.ack_delay_timer = match self.ack_delay_timer { + AckDelayTimer::Idle => { + tcp_trace!("starting delayed ack timer"); + + AckDelayTimer::Waiting(cx.now() + ack_delay) + } + // RFC1122 says "in a stream of full-sized segments there SHOULD be an ACK + // for at least every second segment". + // For now, we send an ACK every second received packet, full-sized or not. + AckDelayTimer::Waiting(_) => { + tcp_trace!("delayed ack timer already started, forcing expiry"); + AckDelayTimer::Immediate + } + AckDelayTimer::Immediate => { + tcp_trace!("delayed ack timer already force-expired"); + AckDelayTimer::Immediate + } + }; + } + } + + // Per RFC 5681, we should send an immediate ACK when either: + // 1) an out-of-order segment is received, or + // 2) a segment arrives that fills in all or part of a gap in sequence space. + if !self.assembler.is_empty() || !assembler_was_empty { + // Note that we change the transmitter state here. + // This is fine because smoltcp assumes that it can always transmit zero or one + // packets for every packet it receives. + tcp_trace!("ACKing incoming segment"); + Some(self.ack_reply(ip_repr, repr)) + } else { + None + } + } + + fn timed_out(&self, timestamp: Instant) -> bool { + match (self.remote_last_ts, self.timeout) { + (Some(remote_last_ts), Some(timeout)) => timestamp >= remote_last_ts + timeout, + (_, _) => false, + } + } + + fn seq_to_transmit(&self, cx: &mut Context) -> bool { + let ip_header_len = match self.tuple.unwrap().local.addr { + #[cfg(feature = "proto-ipv4")] + IpAddress::Ipv4(_) => crate::wire::IPV4_HEADER_LEN, + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(_) => crate::wire::IPV6_HEADER_LEN, + }; + + // Max segment size we're able to send due to MTU limitations. + let local_mss = cx.ip_mtu() - ip_header_len - TCP_HEADER_LEN; + + // The effective max segment size, taking into account our and remote's limits. + let effective_mss = local_mss.min(self.remote_mss); + + // Have we sent data that hasn't been ACKed yet? + let data_in_flight = self.remote_last_seq != self.local_seq_no; + + // If we want to send a SYN and we haven't done so, do it! + if matches!(self.state, State::SynSent | State::SynReceived) && !data_in_flight { + return true; + } + + // max sequence number we can send. + let max_send_seq = + self.local_seq_no + core::cmp::min(self.remote_win_len, self.tx_buffer.len()); + + // Max amount of octets we can send. + let max_send = if max_send_seq >= self.remote_last_seq { + max_send_seq - self.remote_last_seq + } else { + 0 + }; + + // Can we send at least 1 octet? + let mut can_send = max_send != 0; + // Can we send at least 1 full segment? + let can_send_full = max_send >= effective_mss; + + // Do we have to send a FIN? + let want_fin = match self.state { + State::FinWait1 => true, + State::Closing => true, + State::LastAck => true, + _ => false, + }; + + // If we're applying the Nagle algorithm we don't want to send more + // until one of: + // * There's no data in flight + // * We can send a full packet + // * We have all the data we'll ever send (we're closing send) + if self.nagle && data_in_flight && !can_send_full && !want_fin { + can_send = false; + } + + // Can we actually send the FIN? We can send it if: + // 1. We have unsent data that fits in the remote window. + // 2. We have no unsent data. + // This condition matches only if #2, because #1 is already covered by can_data and we're ORing them. + let can_fin = want_fin && self.remote_last_seq == self.local_seq_no + self.tx_buffer.len(); + + can_send || can_fin + } + + fn delayed_ack_expired(&self, timestamp: Instant) -> bool { + match self.ack_delay_timer { + AckDelayTimer::Idle => true, + AckDelayTimer::Waiting(t) => t <= timestamp, + AckDelayTimer::Immediate => true, + } + } + + fn ack_to_transmit(&self) -> bool { + if let Some(remote_last_ack) = self.remote_last_ack { + remote_last_ack < self.remote_seq_no + self.rx_buffer.len() + } else { + false + } + } + + fn window_to_update(&self) -> bool { + match self.state { + State::SynSent + | State::SynReceived + | State::Established + | State::FinWait1 + | State::FinWait2 => self.scaled_window() > self.remote_last_win, + _ => false, + } + } + + pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (IpRepr, TcpRepr)) -> Result<(), E>, + { + if self.tuple.is_none() { + return Ok(()); + } + + if self.remote_last_ts.is_none() { + // We get here in exactly two cases: + // 1) This socket just transitioned into SYN-SENT. + // 2) This socket had an empty transmit buffer and some data was added there. + // Both are similar in that the socket has been quiet for an indefinite + // period of time, it isn't anymore, and the local endpoint is talking. + // So, we start counting the timeout not from the last received packet + // but from the first transmitted one. + self.remote_last_ts = Some(cx.now()); + } + + // Check if any state needs to be changed because of a timer. + if self.timed_out(cx.now()) { + // If a timeout expires, we should abort the connection. + net_debug!("timeout exceeded"); + self.set_state(State::Closed); + } else if !self.seq_to_transmit(cx) { + if let Some(retransmit_delta) = self.timer.should_retransmit(cx.now()) { + // If a retransmit timer expired, we should resend data starting at the last ACK. + net_debug!("retransmitting at t+{}", retransmit_delta); + + // Rewind "last sequence number sent", as if we never + // had sent them. This will cause all data in the queue + // to be sent again. + self.remote_last_seq = self.local_seq_no; + + // Clear the `should_retransmit` state. If we can't retransmit right + // now for whatever reason (like zero window), this avoids an + // infinite polling loop where `poll_at` returns `Now` but `dispatch` + // can't actually do anything. + self.timer.set_for_idle(cx.now(), self.keep_alive); + + // Inform RTTE, so that it can avoid bogus measurements. + self.rtte.on_retransmit(); + } + } + + // Decide whether we're sending a packet. + if self.seq_to_transmit(cx) { + // If we have data to transmit and it fits into partner's window, do it. + tcp_trace!("outgoing segment will send data or flags"); + } else if self.ack_to_transmit() && self.delayed_ack_expired(cx.now()) { + // If we have data to acknowledge, do it. + tcp_trace!("outgoing segment will acknowledge"); + } else if self.window_to_update() && self.delayed_ack_expired(cx.now()) { + // If we have window length increase to advertise, do it. + tcp_trace!("outgoing segment will update window"); + } else if self.state == State::Closed { + // If we need to abort the connection, do it. + tcp_trace!("outgoing segment will abort connection"); + } else if self.timer.should_keep_alive(cx.now()) { + // If we need to transmit a keep-alive packet, do it. + tcp_trace!("keep-alive timer expired"); + } else if self.timer.should_close(cx.now()) { + // If we have spent enough time in the TIME-WAIT state, close the socket. + tcp_trace!("TIME-WAIT timer expired"); + self.reset(); + return Ok(()); + } else { + return Ok(()); + } + + // NOTE(unwrap): we check tuple is not None the first thing in this function. + let tuple = self.tuple.unwrap(); + + // Construct the lowered IP representation. + // We might need this to calculate the MSS, so do it early. + let mut ip_repr = IpRepr::new( + tuple.local.addr, + tuple.remote.addr, + IpProtocol::Tcp, + 0, + self.hop_limit.unwrap_or(64), + ); + + // Construct the basic TCP representation, an empty ACK packet. + // We'll adjust this to be more specific as needed. + let mut repr = TcpRepr { + src_port: tuple.local.port, + dst_port: tuple.remote.port, + control: TcpControl::None, + seq_number: self.remote_last_seq, + ack_number: Some(self.remote_seq_no + self.rx_buffer.len()), + window_len: self.scaled_window(), + window_scale: None, + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &[], + }; + + match self.state { + // We transmit an RST in the CLOSED state. If we ended up in the CLOSED state + // with a specified endpoint, it means that the socket was aborted. + State::Closed => { + repr.control = TcpControl::Rst; + } + + // We never transmit anything in the LISTEN state. + State::Listen => return Ok(()), + + // We transmit a SYN in the SYN-SENT state. + // We transmit a SYN|ACK in the SYN-RECEIVED state. + State::SynSent | State::SynReceived => { + repr.control = TcpControl::Syn; + // window len must NOT be scaled in SYNs. + repr.window_len = self.rx_buffer.window().min((1 << 16) - 1) as u16; + if self.state == State::SynSent { + repr.ack_number = None; + repr.window_scale = Some(self.remote_win_shift); + repr.sack_permitted = true; + } else { + repr.sack_permitted = self.remote_has_sack; + repr.window_scale = self.remote_win_scale.map(|_| self.remote_win_shift); + } + } + + // We transmit data in all states where we may have data in the buffer, + // or the transmit half of the connection is still open. + State::Established + | State::FinWait1 + | State::Closing + | State::CloseWait + | State::LastAck => { + // Extract as much data as the remote side can receive in this packet + // from the transmit buffer. + + // Right edge of window, ie the max sequence number we're allowed to send. + let win_right_edge = self.local_seq_no + self.remote_win_len; + + // Max amount of octets we're allowed to send according to the remote window. + let win_limit = if win_right_edge >= self.remote_last_seq { + win_right_edge - self.remote_last_seq + } else { + // This can happen if we've sent some data and later the remote side + // has shrunk its window so that data is no longer inside the window. + // This should be very rare and is strongly discouraged by the RFCs, + // but it does happen in practice. + // http://www.tcpipguide.com/free/t_TCPWindowManagementIssues.htm + 0 + }; + + // Maximum size we're allowed to send. This can be limited by 3 factors: + // 1. remote window + // 2. MSS the remote is willing to accept, probably determined by their MTU + // 3. MSS we can send, determined by our MTU. + let size = win_limit + .min(self.remote_mss) + .min(cx.ip_mtu() - ip_repr.header_len() - TCP_HEADER_LEN); + + let offset = self.remote_last_seq - self.local_seq_no; + repr.payload = self.tx_buffer.get_allocated(offset, size); + + // If we've sent everything we had in the buffer, follow it with the PSH or FIN + // flags, depending on whether the transmit half of the connection is open. + if offset + repr.payload.len() == self.tx_buffer.len() { + match self.state { + State::FinWait1 | State::LastAck | State::Closing => { + repr.control = TcpControl::Fin + } + State::Established | State::CloseWait if !repr.payload.is_empty() => { + repr.control = TcpControl::Psh + } + _ => (), + } + } + } + + // In FIN-WAIT-2 and TIME-WAIT states we may only transmit ACKs for incoming data or FIN + State::FinWait2 | State::TimeWait => {} + } + + // There might be more than one reason to send a packet. E.g. the keep-alive timer + // has expired, and we also have data in transmit buffer. Since any packet that occupies + // sequence space will elicit an ACK, we only need to send an explicit packet if we + // couldn't fill the sequence space with anything. + let is_keep_alive; + if self.timer.should_keep_alive(cx.now()) && repr.is_empty() { + repr.seq_number = repr.seq_number - 1; + repr.payload = b"\x00"; // RFC 1122 says we should do this + is_keep_alive = true; + } else { + is_keep_alive = false; + } + + // Trace a summary of what will be sent. + if is_keep_alive { + tcp_trace!("sending a keep-alive"); + } else if !repr.payload.is_empty() { + tcp_trace!( + "tx buffer: sending {} octets at offset {}", + repr.payload.len(), + self.remote_last_seq - self.local_seq_no + ); + } + if repr.control != TcpControl::None || repr.payload.is_empty() { + let flags = match (repr.control, repr.ack_number) { + (TcpControl::Syn, None) => "SYN", + (TcpControl::Syn, Some(_)) => "SYN|ACK", + (TcpControl::Fin, Some(_)) => "FIN|ACK", + (TcpControl::Rst, Some(_)) => "RST|ACK", + (TcpControl::Psh, Some(_)) => "PSH|ACK", + (TcpControl::None, Some(_)) => "ACK", + _ => "<unreachable>", + }; + tcp_trace!("sending {}", flags); + } + + if repr.control == TcpControl::Syn { + // Fill the MSS option. See RFC 6691 for an explanation of this calculation. + let max_segment_size = cx.ip_mtu() - ip_repr.header_len() - TCP_HEADER_LEN; + repr.max_seg_size = Some(max_segment_size as u16); + } + + // Actually send the packet. If this succeeds, it means the packet is in + // the device buffer, and its transmission is imminent. If not, we might have + // a number of problems, e.g. we need neighbor discovery. + // + // Bailing out if the packet isn't placed in the device buffer allows us + // to not waste time waiting for the retransmit timer on packets that we know + // for sure will not be successfully transmitted. + ip_repr.set_payload_len(repr.buffer_len()); + emit(cx, (ip_repr, repr))?; + + // We've sent something, whether useful data or a keep-alive packet, so rewind + // the keep-alive timer. + self.timer.rewind_keep_alive(cx.now(), self.keep_alive); + + // Reset delayed-ack timer + match self.ack_delay_timer { + AckDelayTimer::Idle => {} + AckDelayTimer::Waiting(_) => { + tcp_trace!("stop delayed ack timer") + } + AckDelayTimer::Immediate => { + tcp_trace!("stop delayed ack timer (was force-expired)") + } + } + self.ack_delay_timer = AckDelayTimer::Idle; + + // Leave the rest of the state intact if sending a keep-alive packet, since those + // carry a fake segment. + if is_keep_alive { + return Ok(()); + } + + // We've sent a packet successfully, so we can update the internal state now. + self.remote_last_seq = repr.seq_number + repr.segment_len(); + self.remote_last_ack = repr.ack_number; + self.remote_last_win = repr.window_len; + + if repr.segment_len() > 0 { + self.rtte + .on_send(cx.now(), repr.seq_number + repr.segment_len()); + } + + if !self.seq_to_transmit(cx) && repr.segment_len() > 0 { + // If we've transmitted all data we could (and there was something at all, + // data or flag, to transmit, not just an ACK), wind up the retransmit timer. + self.timer + .set_for_retransmit(cx.now(), self.rtte.retransmission_timeout()); + } + + if self.state == State::Closed { + // When aborting a connection, forget about it after sending a single RST packet. + self.tuple = None; + #[cfg(feature = "async")] + { + // Wake tx now so that async users can wait for the RST to be sent + self.tx_waker.wake(); + } + } + + Ok(()) + } + + #[allow(clippy::if_same_then_else)] + pub(crate) fn poll_at(&self, cx: &mut Context) -> PollAt { + // The logic here mirrors the beginning of dispatch() closely. + if self.tuple.is_none() { + // No one to talk to, nothing to transmit. + PollAt::Ingress + } else if self.remote_last_ts.is_none() { + // Socket stopped being quiet recently, we need to acquire a timestamp. + PollAt::Now + } else if self.state == State::Closed { + // Socket was aborted, we have an RST packet to transmit. + PollAt::Now + } else if self.seq_to_transmit(cx) { + // We have a data or flag packet to transmit. + PollAt::Now + } else { + let want_ack = self.ack_to_transmit() || self.window_to_update(); + + let delayed_ack_poll_at = match (want_ack, self.ack_delay_timer) { + (false, _) => PollAt::Ingress, + (true, AckDelayTimer::Idle) => PollAt::Now, + (true, AckDelayTimer::Waiting(t)) => PollAt::Time(t), + (true, AckDelayTimer::Immediate) => PollAt::Now, + }; + + let timeout_poll_at = match (self.remote_last_ts, self.timeout) { + // If we're transmitting or retransmitting data, we need to poll at the moment + // when the timeout would expire. + (Some(remote_last_ts), Some(timeout)) => PollAt::Time(remote_last_ts + timeout), + // Otherwise we have no timeout. + (_, _) => PollAt::Ingress, + }; + + // We wait for the earliest of our timers to fire. + *[self.timer.poll_at(), timeout_poll_at, delayed_ack_poll_at] + .iter() + .min() + .unwrap_or(&PollAt::Ingress) + } + } +} + +impl<'a> fmt::Write for Socket<'a> { + fn write_str(&mut self, slice: &str) -> fmt::Result { + let slice = slice.as_bytes(); + if self.send_slice(slice) == Ok(slice.len()) { + Ok(()) + } else { + Err(fmt::Error) + } + } +} + +// TODO: TCP should work for all features. For now, we only test with the IP feature. We could do +// it for other features as well with rstest, however, this means we have to modify a lot of the +// tests in here, which I didn't had the time for at the moment. +#[cfg(all(test, feature = "medium-ip"))] +mod test { + use super::*; + use crate::wire::IpRepr; + use core::i32; + use std::ops::{Deref, DerefMut}; + use std::vec::Vec; + + // =========================================================================================// + // Constants + // =========================================================================================// + + const LOCAL_PORT: u16 = 80; + const REMOTE_PORT: u16 = 49500; + const LISTEN_END: IpListenEndpoint = IpListenEndpoint { + addr: None, + port: LOCAL_PORT, + }; + const LOCAL_END: IpEndpoint = IpEndpoint { + addr: LOCAL_ADDR.into_address(), + port: LOCAL_PORT, + }; + const REMOTE_END: IpEndpoint = IpEndpoint { + addr: REMOTE_ADDR.into_address(), + port: REMOTE_PORT, + }; + const TUPLE: Tuple = Tuple { + local: LOCAL_END, + remote: REMOTE_END, + }; + const LOCAL_SEQ: TcpSeqNumber = TcpSeqNumber(10000); + const REMOTE_SEQ: TcpSeqNumber = TcpSeqNumber(-10001); + + cfg_if::cfg_if! { + if #[cfg(feature = "proto-ipv4")] { + use crate::wire::Ipv4Address as IpvXAddress; + use crate::wire::Ipv4Repr as IpvXRepr; + use IpRepr::Ipv4 as IpReprIpvX; + + const LOCAL_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 1]); + const REMOTE_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 2]); + const OTHER_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 3]); + + const BASE_MSS: u16 = 1460; + } else { + use crate::wire::Ipv6Address as IpvXAddress; + use crate::wire::Ipv6Repr as IpvXRepr; + use IpRepr::Ipv6 as IpReprIpvX; + + const LOCAL_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ]); + const REMOTE_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + ]); + const OTHER_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, + ]); + + const BASE_MSS: u16 = 1440; + } + } + + const SEND_IP_TEMPL: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: LOCAL_ADDR, + dst_addr: REMOTE_ADDR, + next_header: IpProtocol::Tcp, + payload_len: 20, + hop_limit: 64, + }); + const SEND_TEMPL: TcpRepr<'static> = TcpRepr { + src_port: REMOTE_PORT, + dst_port: LOCAL_PORT, + control: TcpControl::None, + seq_number: TcpSeqNumber(0), + ack_number: Some(TcpSeqNumber(0)), + window_len: 256, + window_scale: None, + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &[], + }; + const _RECV_IP_TEMPL: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: LOCAL_ADDR, + dst_addr: REMOTE_ADDR, + next_header: IpProtocol::Tcp, + payload_len: 20, + hop_limit: 64, + }); + const RECV_TEMPL: TcpRepr<'static> = TcpRepr { + src_port: LOCAL_PORT, + dst_port: REMOTE_PORT, + control: TcpControl::None, + seq_number: TcpSeqNumber(0), + ack_number: Some(TcpSeqNumber(0)), + window_len: 64, + window_scale: None, + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &[], + }; + + // =========================================================================================// + // Helper functions + // =========================================================================================// + + struct TestSocket { + socket: Socket<'static>, + cx: Context, + } + + impl Deref for TestSocket { + type Target = Socket<'static>; + fn deref(&self) -> &Self::Target { + &self.socket + } + } + + impl DerefMut for TestSocket { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.socket + } + } + + fn send( + socket: &mut TestSocket, + timestamp: Instant, + repr: &TcpRepr, + ) -> Option<TcpRepr<'static>> { + socket.cx.set_now(timestamp); + + let ip_repr = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: LOCAL_ADDR, + next_header: IpProtocol::Tcp, + payload_len: repr.buffer_len(), + hop_limit: 64, + }); + net_trace!("send: {}", repr); + + assert!(socket.socket.accepts(&mut socket.cx, &ip_repr, repr)); + + match socket.socket.process(&mut socket.cx, &ip_repr, repr) { + Some((_ip_repr, repr)) => { + net_trace!("recv: {}", repr); + Some(repr) + } + None => None, + } + } + + fn recv<F>(socket: &mut TestSocket, timestamp: Instant, mut f: F) + where + F: FnMut(Result<TcpRepr, ()>), + { + socket.cx.set_now(timestamp); + + let mut sent = 0; + let result = socket + .socket + .dispatch(&mut socket.cx, |_, (ip_repr, tcp_repr)| { + assert_eq!(ip_repr.next_header(), IpProtocol::Tcp); + assert_eq!(ip_repr.src_addr(), LOCAL_ADDR.into()); + assert_eq!(ip_repr.dst_addr(), REMOTE_ADDR.into()); + assert_eq!(ip_repr.payload_len(), tcp_repr.buffer_len()); + + net_trace!("recv: {}", tcp_repr); + sent += 1; + Ok(f(Ok(tcp_repr))) + }); + match result { + Ok(()) => assert_eq!(sent, 1, "Exactly one packet should be sent"), + Err(e) => f(Err(e)), + } + } + + fn recv_nothing(socket: &mut TestSocket, timestamp: Instant) { + socket.cx.set_now(timestamp); + + let result: Result<(), ()> = socket + .socket + .dispatch(&mut socket.cx, |_, (_ip_repr, _tcp_repr)| { + panic!("Should not send a packet") + }); + + assert_eq!(result, Ok(())) + } + + macro_rules! send { + ($socket:ident, $repr:expr) => + (send!($socket, time 0, $repr)); + ($socket:ident, $repr:expr, $result:expr) => + (send!($socket, time 0, $repr, $result)); + ($socket:ident, time $time:expr, $repr:expr) => + (send!($socket, time $time, $repr, None)); + ($socket:ident, time $time:expr, $repr:expr, $result:expr) => + (assert_eq!(send(&mut $socket, Instant::from_millis($time), &$repr), $result)); + } + + macro_rules! recv { + ($socket:ident, [$( $repr:expr ),*]) => ({ + $( recv!($socket, Ok($repr)); )* + recv_nothing!($socket) + }); + ($socket:ident, $result:expr) => + (recv!($socket, time 0, $result)); + ($socket:ident, time $time:expr, $result:expr) => + (recv(&mut $socket, Instant::from_millis($time), |result| { + // Most of the time we don't care about the PSH flag. + let result = result.map(|mut repr| { + repr.control = repr.control.quash_psh(); + repr + }); + assert_eq!(result, $result) + })); + ($socket:ident, time $time:expr, $result:expr, exact) => + (recv(&mut $socket, Instant::from_millis($time), |repr| assert_eq!(repr, $result))); + } + + macro_rules! recv_nothing { + ($socket:ident) => (recv_nothing!($socket, time 0)); + ($socket:ident, time $time:expr) => (recv_nothing(&mut $socket, Instant::from_millis($time))); + } + + macro_rules! sanity { + ($socket1:expr, $socket2:expr) => {{ + let (s1, s2) = ($socket1, $socket2); + assert_eq!(s1.state, s2.state, "state"); + assert_eq!(s1.tuple, s2.tuple, "tuple"); + assert_eq!(s1.local_seq_no, s2.local_seq_no, "local_seq_no"); + assert_eq!(s1.remote_seq_no, s2.remote_seq_no, "remote_seq_no"); + assert_eq!(s1.remote_last_seq, s2.remote_last_seq, "remote_last_seq"); + assert_eq!(s1.remote_last_ack, s2.remote_last_ack, "remote_last_ack"); + assert_eq!(s1.remote_last_win, s2.remote_last_win, "remote_last_win"); + assert_eq!(s1.remote_win_len, s2.remote_win_len, "remote_win_len"); + assert_eq!(s1.timer, s2.timer, "timer"); + }}; + } + + fn socket() -> TestSocket { + socket_with_buffer_sizes(64, 64) + } + + fn socket_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TestSocket { + let (iface, _, _) = crate::tests::setup(crate::phy::Medium::Ip); + + let rx_buffer = SocketBuffer::new(vec![0; rx_len]); + let tx_buffer = SocketBuffer::new(vec![0; tx_len]); + let mut socket = Socket::new(rx_buffer, tx_buffer); + socket.set_ack_delay(None); + TestSocket { + socket, + cx: iface.inner, + } + } + + fn socket_syn_received_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TestSocket { + let mut s = socket_with_buffer_sizes(tx_len, rx_len); + s.state = State::SynReceived; + s.tuple = Some(TUPLE); + s.local_seq_no = LOCAL_SEQ; + s.remote_seq_no = REMOTE_SEQ + 1; + s.remote_last_seq = LOCAL_SEQ; + s.remote_win_len = 256; + s + } + + fn socket_syn_received() -> TestSocket { + socket_syn_received_with_buffer_sizes(64, 64) + } + + fn socket_syn_sent_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TestSocket { + let mut s = socket_with_buffer_sizes(tx_len, rx_len); + s.state = State::SynSent; + s.tuple = Some(TUPLE); + s.local_seq_no = LOCAL_SEQ; + s.remote_last_seq = LOCAL_SEQ; + s + } + + fn socket_syn_sent() -> TestSocket { + socket_syn_sent_with_buffer_sizes(64, 64) + } + + fn socket_established_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TestSocket { + let mut s = socket_syn_received_with_buffer_sizes(tx_len, rx_len); + s.state = State::Established; + s.local_seq_no = LOCAL_SEQ + 1; + s.remote_last_seq = LOCAL_SEQ + 1; + s.remote_last_ack = Some(REMOTE_SEQ + 1); + s.remote_last_win = 64; + s + } + + fn socket_established() -> TestSocket { + socket_established_with_buffer_sizes(64, 64) + } + + fn socket_fin_wait_1() -> TestSocket { + let mut s = socket_established(); + s.state = State::FinWait1; + s + } + + fn socket_fin_wait_2() -> TestSocket { + let mut s = socket_fin_wait_1(); + s.state = State::FinWait2; + s.local_seq_no = LOCAL_SEQ + 1 + 1; + s.remote_last_seq = LOCAL_SEQ + 1 + 1; + s + } + + fn socket_closing() -> TestSocket { + let mut s = socket_fin_wait_1(); + s.state = State::Closing; + s.remote_last_seq = LOCAL_SEQ + 1 + 1; + s.remote_seq_no = REMOTE_SEQ + 1 + 1; + s + } + + fn socket_time_wait(from_closing: bool) -> TestSocket { + let mut s = socket_fin_wait_2(); + s.state = State::TimeWait; + s.remote_seq_no = REMOTE_SEQ + 1 + 1; + if from_closing { + s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1); + } + s.timer = Timer::Close { + expires_at: Instant::from_secs(1) + CLOSE_DELAY, + }; + s + } + + fn socket_close_wait() -> TestSocket { + let mut s = socket_established(); + s.state = State::CloseWait; + s.remote_seq_no = REMOTE_SEQ + 1 + 1; + s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1); + s + } + + fn socket_last_ack() -> TestSocket { + let mut s = socket_close_wait(); + s.state = State::LastAck; + s + } + + fn socket_recved() -> TestSocket { + let mut s = socket_established(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }] + ); + s + } + + // =========================================================================================// + // Tests for the CLOSED state. + // =========================================================================================// + #[test] + fn test_closed_reject() { + let mut s = socket(); + assert_eq!(s.state, State::Closed); + + let tcp_repr = TcpRepr { + control: TcpControl::Syn, + ..SEND_TEMPL + }; + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); + } + + #[test] + fn test_closed_reject_after_listen() { + let mut s = socket(); + s.listen(LOCAL_END).unwrap(); + s.close(); + + let tcp_repr = TcpRepr { + control: TcpControl::Syn, + ..SEND_TEMPL + }; + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); + } + + #[test] + fn test_closed_close() { + let mut s = socket(); + s.close(); + assert_eq!(s.state, State::Closed); + } + + // =========================================================================================// + // Tests for the LISTEN state. + // =========================================================================================// + fn socket_listen() -> TestSocket { + let mut s = socket(); + s.state = State::Listen; + s.listen_endpoint = LISTEN_END; + s + } + + #[test] + fn test_listen_sack_option() { + let mut s = socket_listen(); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + sack_permitted: false, + ..SEND_TEMPL + } + ); + assert!(!s.remote_has_sack); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + + let mut s = socket_listen(); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + sack_permitted: true, + ..SEND_TEMPL + } + ); + assert!(s.remote_has_sack); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_listen_syn_win_scale_buffers() { + for (buffer_size, shift_amt) in &[ + (64, 0), + (128, 0), + (1024, 0), + (65535, 0), + (65536, 1), + (65537, 1), + (131071, 1), + (131072, 2), + (524287, 3), + (524288, 4), + (655350, 4), + (1048576, 5), + ] { + let mut s = socket_with_buffer_sizes(64, *buffer_size); + s.state = State::Listen; + s.listen_endpoint = LISTEN_END; + assert_eq!(s.remote_win_shift, *shift_amt); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + window_scale: Some(0), + ..SEND_TEMPL + } + ); + assert_eq!(s.remote_win_shift, *shift_amt); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + window_scale: Some(*shift_amt), + window_len: cmp::min(*buffer_size, 65535) as u16, + ..RECV_TEMPL + }] + ); + } + } + + #[test] + fn test_listen_sanity() { + let mut s = socket(); + s.listen(LOCAL_PORT).unwrap(); + sanity!(s, socket_listen()); + } + + #[test] + fn test_listen_validation() { + let mut s = socket(); + assert_eq!(s.listen(0), Err(ListenError::Unaddressable)); + } + + #[test] + fn test_listen_twice() { + let mut s = socket(); + assert_eq!(s.listen(80), Ok(())); + // multiple calls to listen are okay if its the same local endpoint and the state is still in listening + assert_eq!(s.listen(80), Ok(())); + s.set_state(State::SynReceived); // state change, simulate incoming connection + assert_eq!(s.listen(80), Err(ListenError::InvalidState)); + } + + #[test] + fn test_listen_syn() { + let mut s = socket_listen(); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + } + ); + sanity!(s, socket_syn_received()); + } + + #[test] + fn test_listen_syn_reject_ack() { + let mut s = socket_listen(); + + let tcp_repr = TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ), + ..SEND_TEMPL + }; + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); + + assert_eq!(s.state, State::Listen); + } + + #[test] + fn test_listen_rst() { + let mut s = socket_listen(); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Listen); + } + + #[test] + fn test_listen_close() { + let mut s = socket_listen(); + s.close(); + assert_eq!(s.state, State::Closed); + } + + // =========================================================================================// + // Tests for the SYN-RECEIVED state. + // =========================================================================================// + + #[test] + fn test_syn_received_ack() { + let mut s = socket_syn_received(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Established); + sanity!(s, socket_established()); + } + + #[test] + fn test_syn_received_ack_too_low() { + let mut s = socket_syn_received(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ), // wrong + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ, + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + assert_eq!(s.state, State::SynReceived); + } + + #[test] + fn test_syn_received_ack_too_high() { + let mut s = socket_syn_received(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 2), // wrong + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 2, + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + assert_eq!(s.state, State::SynReceived); + } + + #[test] + fn test_syn_received_fin() { + let mut s = socket_syn_received(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6 + 1), + window_len: 58, + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::CloseWait); + + let mut s2 = socket_close_wait(); + s2.remote_last_ack = Some(REMOTE_SEQ + 1 + 6 + 1); + s2.remote_last_win = 58; + sanity!(s, s2); + } + + #[test] + fn test_syn_received_rst() { + let mut s = socket_syn_received(); + s.listen_endpoint = LISTEN_END; + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Listen); + assert_eq!(s.listen_endpoint, LISTEN_END); + assert_eq!(s.tuple, None); + } + + #[test] + fn test_syn_received_no_window_scaling() { + let mut s = socket_listen(); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + } + ); + assert_eq!(s.state(), State::SynReceived); + assert_eq!(s.tuple, Some(TUPLE)); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + window_scale: None, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_scale: None, + ..SEND_TEMPL + } + ); + assert_eq!(s.remote_win_shift, 0); + assert_eq!(s.remote_win_scale, None); + } + + #[test] + fn test_syn_received_window_scaling() { + for scale in 0..14 { + let mut s = socket_listen(); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + window_scale: Some(scale), + ..SEND_TEMPL + } + ); + assert_eq!(s.state(), State::SynReceived); + assert_eq!(s.tuple, Some(TUPLE)); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_scale: None, + ..SEND_TEMPL + } + ); + assert_eq!(s.remote_win_scale, Some(scale)); + } + } + + #[test] + fn test_syn_received_close() { + let mut s = socket_syn_received(); + s.close(); + assert_eq!(s.state, State::FinWait1); + } + + // =========================================================================================// + // Tests for the SYN-SENT state. + // =========================================================================================// + + #[test] + fn test_connect_validation() { + let mut s = socket(); + assert_eq!( + s.socket + .connect(&mut s.cx, REMOTE_END, (IpvXAddress::UNSPECIFIED, 0)), + Err(ConnectError::Unaddressable) + ); + assert_eq!( + s.socket + .connect(&mut s.cx, REMOTE_END, (IpvXAddress::UNSPECIFIED, 1024)), + Err(ConnectError::Unaddressable) + ); + assert_eq!( + s.socket + .connect(&mut s.cx, (IpvXAddress::UNSPECIFIED, 0), LOCAL_END), + Err(ConnectError::Unaddressable) + ); + s.socket + .connect(&mut s.cx, REMOTE_END, LOCAL_END) + .expect("Connect failed with valid parameters"); + assert_eq!(s.tuple, Some(TUPLE)); + } + + #[test] + fn test_connect() { + let mut s = socket(); + s.local_seq_no = LOCAL_SEQ; + s.socket + .connect(&mut s.cx, REMOTE_END, LOCAL_END.port) + .unwrap(); + assert_eq!(s.tuple, Some(TUPLE)); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: Some(0), + ..SEND_TEMPL + } + ); + assert_eq!(s.tuple, Some(TUPLE)); + } + + #[test] + fn test_connect_unspecified_local() { + let mut s = socket(); + assert_eq!(s.socket.connect(&mut s.cx, REMOTE_END, 80), Ok(())); + } + + #[test] + fn test_connect_specified_local() { + let mut s = socket(); + assert_eq!( + s.socket.connect(&mut s.cx, REMOTE_END, (REMOTE_ADDR, 80)), + Ok(()) + ); + } + + #[test] + fn test_connect_twice() { + let mut s = socket(); + assert_eq!(s.socket.connect(&mut s.cx, REMOTE_END, 80), Ok(())); + assert_eq!( + s.socket.connect(&mut s.cx, REMOTE_END, 80), + Err(ConnectError::InvalidState) + ); + } + + #[test] + fn test_syn_sent_sanity() { + let mut s = socket(); + s.local_seq_no = LOCAL_SEQ; + s.socket.connect(&mut s.cx, REMOTE_END, LOCAL_END).unwrap(); + sanity!(s, socket_syn_sent()); + } + + #[test] + fn test_syn_sent_syn_ack() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: Some(0), + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + recv_nothing!(s, time 1000); + assert_eq!(s.state, State::Established); + sanity!(s, socket_established()); + } + + #[test] + fn test_syn_sent_syn_ack_not_incremented() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ), // WRONG + max_seg_size: Some(BASE_MSS - 80), + window_scale: Some(0), + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ, + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_rst() { + let mut s = socket_syn_sent(); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_syn_sent_rst_no_ack() { + let mut s = socket_syn_sent(); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_rst_bad_ack() { + let mut s = socket_syn_sent(); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: Some(TcpSeqNumber(1234)), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_bad_ack() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::None, // Unexpected + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), // Correct + ..SEND_TEMPL + } + ); + + // It should trigger no response and change no state + recv!(s, []); + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_bad_ack_seq_1() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::None, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ), // WRONG + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ, // matching the ack_number of the unexpected ack + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + + // It should trigger a RST, and change no state + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_bad_ack_seq_2() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::None, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 123456), // WRONG + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 123456, // matching the ack_number of the unexpected ack + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + + // It should trigger a RST, and change no state + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_close() { + let mut s = socket(); + s.close(); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_syn_sent_win_scale_buffers() { + for (buffer_size, shift_amt) in &[ + (64, 0), + (128, 0), + (1024, 0), + (65535, 0), + (65536, 1), + (65537, 1), + (131071, 1), + (131072, 2), + (524287, 3), + (524288, 4), + (655350, 4), + (1048576, 5), + ] { + let mut s = socket_with_buffer_sizes(64, *buffer_size); + s.local_seq_no = LOCAL_SEQ; + assert_eq!(s.remote_win_shift, *shift_amt); + s.socket.connect(&mut s.cx, REMOTE_END, LOCAL_END).unwrap(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(*shift_amt), + window_len: cmp::min(*buffer_size, 65535) as u16, + sack_permitted: true, + ..RECV_TEMPL + }] + ); + } + } + + #[test] + fn test_syn_sent_syn_ack_no_window_scaling() { + let mut s = socket_syn_sent_with_buffer_sizes(1048576, 1048576); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + // scaling does NOT apply to the window value in SYN packets + window_len: 65535, + window_scale: Some(5), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + assert_eq!(s.remote_win_shift, 5); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: None, + window_len: 42, + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Established); + assert_eq!(s.remote_win_shift, 0); + assert_eq!(s.remote_win_scale, None); + assert_eq!(s.remote_win_len, 42); + } + + #[test] + fn test_syn_sent_syn_ack_window_scaling() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: Some(7), + window_len: 42, + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Established); + assert_eq!(s.remote_win_scale, Some(7)); + // scaling does NOT apply to the window value in SYN packets + assert_eq!(s.remote_win_len, 42); + } + + // =========================================================================================// + // Tests for the ESTABLISHED state. + // =========================================================================================// + + #[test] + fn test_established_recv() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }] + ); + assert_eq!(s.rx_buffer.dequeue_many(6), &b"abcdef"[..]); + } + + #[test] + fn test_peek_slice() { + const BUF_SIZE: usize = 10; + + let send_buf = b"0123456"; + + let mut s = socket_established_with_buffer_sizes(BUF_SIZE, BUF_SIZE); + + // Populate the recv buffer + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &send_buf[..], + ..SEND_TEMPL + } + ); + + // Peek into the recv buffer + let mut peeked_buf = [0u8; BUF_SIZE]; + let actually_peeked = s.peek_slice(&mut peeked_buf[..]).unwrap(); + let mut recv_buf = [0u8; BUF_SIZE]; + let actually_recvd = s.recv_slice(&mut recv_buf[..]).unwrap(); + assert_eq!( + &mut peeked_buf[..actually_peeked], + &mut recv_buf[..actually_recvd] + ); + } + + #[test] + fn test_peek_slice_buffer_wrap() { + const BUF_SIZE: usize = 10; + + let send_buf = b"0123456789"; + + let mut s = socket_established_with_buffer_sizes(BUF_SIZE, BUF_SIZE); + + let _ = s.rx_buffer.enqueue_slice(&send_buf[..8]); + let _ = s.rx_buffer.dequeue_many(6); + let _ = s.rx_buffer.enqueue_slice(&send_buf[..5]); + + let mut peeked_buf = [0u8; BUF_SIZE]; + let actually_peeked = s.peek_slice(&mut peeked_buf[..]).unwrap(); + let mut recv_buf = [0u8; BUF_SIZE]; + let actually_recvd = s.recv_slice(&mut recv_buf[..]).unwrap(); + assert_eq!( + &mut peeked_buf[..actually_peeked], + &mut recv_buf[..actually_recvd] + ); + } + + fn setup_rfc2018_cases() -> (TestSocket, Vec<u8>) { + // This is a utility function used by the tests for RFC 2018 cases. It configures a socket + // in a particular way suitable for those cases. + // + // RFC 2018: Assume the left window edge is 5000 and that the data transmitter sends [...] + // segments, each containing 500 data bytes. + let mut s = socket_established_with_buffer_sizes(4000, 4000); + s.remote_has_sack = true; + + // create a segment that is 500 bytes long + let mut segment: Vec<u8> = Vec::with_capacity(500); + + // move the last ack to 5000 by sending ten of them + for _ in 0..50 { + segment.extend_from_slice(b"abcdefghij") + } + for offset in (0..5000).step_by(500) { + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + offset, + ack_number: Some(LOCAL_SEQ + 1), + payload: &segment, + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + offset + 500), + window_len: 3500, + ..RECV_TEMPL + }] + ); + s.recv(|data| { + assert_eq!(data.len(), 500); + assert_eq!(data, segment.as_slice()); + (500, ()) + }) + .unwrap(); + } + assert_eq!(s.remote_last_win, 3500); + (s, segment) + } + + #[test] + fn test_established_rfc2018_cases() { + // This test case verifies the exact scenarios described on pages 8-9 of RFC 2018. Please + // ensure its behavior does not deviate from those scenarios. + + let (mut s, segment) = setup_rfc2018_cases(); + // RFC 2018: + // + // Case 2: The first segment is dropped but the remaining 7 are received. + // + // Upon receiving each of the last seven packets, the data receiver will return a TCP ACK + // segment that acknowledges sequence number 5000 and contains a SACK option specifying one + // block of queued data: + // + // Triggering ACK Left Edge Right Edge + // Segment + // + // 5000 (lost) + // 5500 5000 5500 6000 + // 6000 5000 5500 6500 + // 6500 5000 5500 7000 + // 7000 5000 5500 7500 + // 7500 5000 5500 8000 + // 8000 5000 5500 8500 + // 8500 5000 5500 9000 + // + for offset in (500..3500).step_by(500) { + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + offset + 5000, + ack_number: Some(LOCAL_SEQ + 1), + payload: &segment, + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 5000), + window_len: 4000, + sack_ranges: [ + Some(( + REMOTE_SEQ.0 as u32 + 1 + 5500, + REMOTE_SEQ.0 as u32 + 1 + 5500 + offset as u32 + )), + None, + None + ], + ..RECV_TEMPL + }) + ); + } + } + + #[test] + fn test_established_sliding_window_recv() { + let mut s = socket_established(); + // Update our scaling parameters for a TCP with a scaled buffer. + assert_eq!(s.rx_buffer.len(), 0); + s.rx_buffer = SocketBuffer::new(vec![0; 262143]); + s.assembler = Assembler::new(); + s.remote_win_scale = Some(0); + s.remote_last_win = 65535; + s.remote_win_shift = 2; + + // Create a TCP segment that will mostly fill an IP frame. + let mut segment: Vec<u8> = Vec::with_capacity(1400); + for _ in 0..100 { + segment.extend_from_slice(b"abcdefghijklmn") + } + assert_eq!(segment.len(), 1400); + + // Send the frame + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &segment, + ..SEND_TEMPL + } + ); + + // Ensure that the received window size is shifted right by 2. + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1400), + window_len: 65185, + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_established_send() { + let mut s = socket_established(); + // First roundtrip after establishing. + s.send_slice(b"abcdef").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + assert_eq!(s.tx_buffer.len(), 6); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + ..SEND_TEMPL + } + ); + assert_eq!(s.tx_buffer.len(), 0); + // Second roundtrip. + s.send_slice(b"foobar").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + ..SEND_TEMPL + } + ); + assert_eq!(s.tx_buffer.len(), 0); + } + + #[test] + fn test_established_send_no_ack_send() { + let mut s = socket_established(); + s.set_nagle_enabled(false); + s.send_slice(b"abcdef").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + s.send_slice(b"foobar").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_established_send_buf_gt_win() { + let mut data = [0; 32]; + for (i, elem) in data.iter_mut().enumerate() { + *elem = i as u8 + } + + let mut s = socket_established(); + s.remote_win_len = 16; + s.send_slice(&data[..]).unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &data[0..16], + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_established_send_window_shrink() { + let mut s = socket_established(); + + // 6 octets fit on the remote side's window, so we send them. + s.send_slice(b"abcdef").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + assert_eq!(s.tx_buffer.len(), 6); + + println!( + "local_seq_no={} remote_win_len={} remote_last_seq={}", + s.local_seq_no, s.remote_win_len, s.remote_last_seq + ); + + // - Peer doesn't ack them yet + // - Sends data so we need to reply with an ACK + // - ...AND and sends a window announcement that SHRINKS the window, so data we've + // previously sent is now outside the window. Yes, this is allowed by TCP. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_len: 3, + payload: &b"xyzxyz"[..], + ..SEND_TEMPL + } + ); + assert_eq!(s.tx_buffer.len(), 6); + + println!( + "local_seq_no={} remote_win_len={} remote_last_seq={}", + s.local_seq_no, s.remote_win_len, s.remote_last_seq + ); + + // More data should not get sent since it doesn't fit in the window + s.send_slice(b"foobar").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 64 - 6, + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_established_receive_partially_outside_window() { + let mut s = socket_established(); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + + // Peer decides to retransmit (perhaps because the ACK was lost) + // and also pushed data. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + + s.recv(|data| { + assert_eq!(data, b"def"); + (3, ()) + }) + .unwrap(); + } + + #[test] + fn test_established_send_wrap() { + let mut s = socket_established(); + let local_seq_start = TcpSeqNumber(i32::MAX - 1); + s.local_seq_no = local_seq_start + 1; + s.remote_last_seq = local_seq_start + 1; + s.send_slice(b"abc").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: local_seq_start + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_established_no_ack() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: None, + ..SEND_TEMPL + } + ); + } + + #[test] + fn test_established_bad_ack() { + let mut s = socket_established(); + // Already acknowledged data. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(TcpSeqNumber(LOCAL_SEQ.0 - 1)), + ..SEND_TEMPL + } + ); + assert_eq!(s.local_seq_no, LOCAL_SEQ + 1); + // Data not yet transmitted. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 10), + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); + assert_eq!(s.local_seq_no, LOCAL_SEQ + 1); + } + + #[test] + fn test_established_bad_seq() { + let mut s = socket_established(); + // Data outside of receive window. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 256, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); + assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1); + + // Challenge ACKs are rate-limited, we don't get a second one immediately. + send!( + s, + time 100, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 256, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + + // If we wait a bit, we do get a new one. + send!( + s, + time 2000, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 256, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); + assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1); + } + + #[test] + fn test_established_fin() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::CloseWait); + sanity!(s, socket_close_wait()); + } + + #[test] + fn test_established_fin_after_missing() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"123456"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); + assert_eq!(s.state, State::Established); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6 + 6), + window_len: 52, + ..RECV_TEMPL + }) + ); + assert_eq!(s.state, State::Established); + } + + #[test] + fn test_established_send_fin() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::CloseWait); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_established_rst() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_established_rst_no_ack() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1, + ack_number: None, + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_established_close() { + let mut s = socket_established(); + s.close(); + assert_eq!(s.state, State::FinWait1); + sanity!(s, socket_fin_wait_1()); + } + + #[test] + fn test_established_abort() { + let mut s = socket_established(); + s.abort(); + assert_eq!(s.state, State::Closed); + recv!( + s, + [TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_established_rst_bad_seq() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, // Wrong seq + ack_number: None, + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); + + assert_eq!(s.state, State::Established); + + // Send something to advance seq by 1 + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, // correct seq + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"a"[..], + ..SEND_TEMPL + } + ); + + // Send wrong rst again, check that the challenge ack is correctly updated + // The ack number must be updated even if we don't call dispatch on the socket + // See https://github.com/smoltcp-rs/smoltcp/issues/338 + send!( + s, + time 2000, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, // Wrong seq + ack_number: None, + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 2), // this has changed + window_len: 63, + ..RECV_TEMPL + }) + ); + } + + // =========================================================================================// + // Tests for the FIN-WAIT-1 state. + // =========================================================================================// + + #[test] + fn test_fin_wait_1_fin_ack() { + let mut s = socket_fin_wait_1(); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::FinWait2); + sanity!(s, socket_fin_wait_2()); + } + + #[test] + fn test_fin_wait_1_fin_fin() { + let mut s = socket_fin_wait_1(); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closing); + sanity!(s, socket_closing()); + } + + #[test] + fn test_fin_wait_1_fin_with_data_queued() { + let mut s = socket_established(); + s.remote_win_len = 6; + s.send_slice(b"abcdef123456").unwrap(); + s.close(); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }) + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::FinWait1); + } + + #[test] + fn test_fin_wait_1_recv() { + let mut s = socket_fin_wait_1(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::FinWait1); + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + } + + #[test] + fn test_fin_wait_1_close() { + let mut s = socket_fin_wait_1(); + s.close(); + assert_eq!(s.state, State::FinWait1); + } + + // =========================================================================================// + // Tests for the FIN-WAIT-2 state. + // =========================================================================================// + + #[test] + fn test_fin_wait_2_fin() { + let mut s = socket_fin_wait_2(); + send!(s, time 1_000, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::TimeWait); + sanity!(s, socket_time_wait(false)); + } + + #[test] + fn test_fin_wait_2_recv() { + let mut s = socket_fin_wait_2(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::FinWait2); + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_fin_wait_2_close() { + let mut s = socket_fin_wait_2(); + s.close(); + assert_eq!(s.state, State::FinWait2); + } + + // =========================================================================================// + // Tests for the CLOSING state. + // =========================================================================================// + + #[test] + fn test_closing_ack_fin() { + let mut s = socket_closing(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + send!(s, time 1_000, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::TimeWait); + sanity!(s, socket_time_wait(true)); + } + + #[test] + fn test_closing_close() { + let mut s = socket_closing(); + s.close(); + assert_eq!(s.state, State::Closing); + } + + // =========================================================================================// + // Tests for the TIME-WAIT state. + // =========================================================================================// + + #[test] + fn test_time_wait_from_fin_wait_2_ack() { + let mut s = socket_time_wait(false); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_time_wait_from_closing_no_ack() { + let mut s = socket_time_wait(true); + recv!(s, []); + } + + #[test] + fn test_time_wait_close() { + let mut s = socket_time_wait(false); + s.close(); + assert_eq!(s.state, State::TimeWait); + } + + #[test] + fn test_time_wait_retransmit() { + let mut s = socket_time_wait(false); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + send!(s, time 5_000, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }, Some(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + })); + assert_eq!( + s.timer, + Timer::Close { + expires_at: Instant::from_secs(5) + CLOSE_DELAY + } + ); + } + + #[test] + fn test_time_wait_timeout() { + let mut s = socket_time_wait(false); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::TimeWait); + recv_nothing!(s, time 60_000); + assert_eq!(s.state, State::Closed); + } + + // =========================================================================================// + // Tests for the CLOSE-WAIT state. + // =========================================================================================// + + #[test] + fn test_close_wait_ack() { + let mut s = socket_close_wait(); + s.send_slice(b"abcdef").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + ..SEND_TEMPL + } + ); + } + + #[test] + fn test_close_wait_close() { + let mut s = socket_close_wait(); + s.close(); + assert_eq!(s.state, State::LastAck); + sanity!(s, socket_last_ack()); + } + + // =========================================================================================// + // Tests for the LAST-ACK state. + // =========================================================================================// + #[test] + fn test_last_ack_fin_ack() { + let mut s = socket_last_ack(); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::LastAck); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_last_ack_ack_not_of_fin() { + let mut s = socket_last_ack(); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::LastAck); + + // ACK received that doesn't ack the FIN: socket should stay in LastAck. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::LastAck); + + // ACK received of fin: socket should change to Closed. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_last_ack_close() { + let mut s = socket_last_ack(); + s.close(); + assert_eq!(s.state, State::LastAck); + } + + // =========================================================================================// + // Tests for transitioning through multiple states. + // =========================================================================================// + + #[test] + fn test_listen() { + let mut s = socket(); + s.listen(LISTEN_END).unwrap(); + assert_eq!(s.state, State::Listen); + } + + #[test] + fn test_three_way_handshake() { + let mut s = socket_listen(); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + } + ); + assert_eq!(s.state(), State::SynReceived); + assert_eq!(s.tuple, Some(TUPLE)); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state(), State::Established); + assert_eq!(s.local_seq_no, LOCAL_SEQ + 1); + assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1); + } + + #[test] + fn test_remote_close() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::CloseWait); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + s.close(); + assert_eq!(s.state, State::LastAck); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_local_close() { + let mut s = socket_established(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::FinWait2); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::TimeWait); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_simultaneous_close() { + let mut s = socket_established(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!( + s, + [TcpRepr { + // due to reordering, this is logically located... + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closing); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + // ... at this point + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::TimeWait); + recv!(s, []); + } + + #[test] + fn test_simultaneous_close_combined_fin_ack() { + let mut s = socket_established(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::TimeWait); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_simultaneous_close_raced() { + let mut s = socket_established(); + s.close(); + assert_eq!(s.state, State::FinWait1); + + // Socket receives FIN before it has a chance to send its own FIN + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closing); + + // FIN + ack-of-FIN + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::Closing); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::TimeWait); + recv!(s, []); + } + + #[test] + fn test_simultaneous_close_raced_with_data() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + assert_eq!(s.state, State::FinWait1); + + // Socket receives FIN before it has a chance to send its own data+FIN + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closing); + + // data + FIN + ack-of-FIN + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::Closing); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::TimeWait); + recv!(s, []); + } + + #[test] + fn test_fin_with_data() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ) + } + + #[test] + fn test_mutual_close_with_data_1() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + } + ); + } + + #[test] + fn test_mutual_close_with_data_2() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::FinWait2); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::TimeWait); + } + + // =========================================================================================// + // Tests for retransmission on packet loss. + // =========================================================================================// + + #[test] + fn test_duplicate_seq_ack() { + let mut s = socket_recved(); + // remote retransmission + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }) + ); + } + + #[test] + fn test_data_retransmit() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + recv_nothing!(s, time 1050); + recv!(s, time 2000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_data_retransmit_bursts() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef012345").unwrap(); + + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + recv_nothing!(s, time 0); + + recv_nothing!(s, time 50); + + recv!(s, time 1000, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 1500, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + recv_nothing!(s, time 1550); + } + + #[test] + fn test_data_retransmit_bursts_half_ack() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef012345").unwrap(); + + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + // Acknowledge the first packet + send!(s, time 5, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + // The second packet should be re-sent. + recv!(s, time 1500, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + + recv_nothing!(s, time 1550); + } + + #[test] + fn test_data_retransmit_bursts_half_ack_close() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef012345").unwrap(); + s.close(); + + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + // Acknowledge the first packet + send!(s, time 5, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + // The second packet should be re-sent. + recv!(s, time 1500, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + + recv_nothing!(s, time 1550); + } + + #[test] + fn test_send_data_after_syn_ack_retransmit() { + let mut s = socket_syn_received(); + recv!(s, time 50, Ok(TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + })); + recv!(s, time 750, Ok(TcpRepr { // retransmit + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + })); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state(), State::Established); + s.send_slice(b"abcdef").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ) + } + + #[test] + fn test_established_retransmit_for_dup_ack() { + let mut s = socket_established(); + // Duplicate ACKs do not replace the retransmission timer + s.send_slice(b"abc").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + // Retransmit timer is on because all data was sent + assert_eq!(s.tx_buffer.len(), 3); + // ACK nothing new + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + // Retransmit + recv!(s, time 4000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_established_retransmit_reset_after_ack() { + let mut s = socket_established(); + s.remote_win_len = 6; + s.send_slice(b"abcdef").unwrap(); + s.send_slice(b"123456").unwrap(); + s.send_slice(b"ABCDEF").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + send!(s, time 1005, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + })); + send!(s, time 1015, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1020, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_established_queue_during_retransmission() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef123456ABCDEF").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); // this one is dropped + recv!(s, time 1005, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + })); // this one is received + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); // also dropped + recv!(s, time 2000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); // retransmission + send!(s, time 2005, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + ..SEND_TEMPL + }); // acknowledgement of both segments + recv!(s, time 2010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); // retransmission of only unacknowledged data + } + + #[test] + fn test_close_wait_retransmit_reset_after_ack() { + let mut s = socket_close_wait(); + s.remote_win_len = 6; + s.send_slice(b"abcdef").unwrap(); + s.send_slice(b"123456").unwrap(); + s.send_slice(b"ABCDEF").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + send!(s, time 1005, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + })); + send!(s, time 1015, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1020, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_fin_wait_1_retransmit_reset_after_ack() { + let mut s = socket_established(); + s.remote_win_len = 6; + s.send_slice(b"abcdef").unwrap(); + s.send_slice(b"123456").unwrap(); + s.send_slice(b"ABCDEF").unwrap(); + s.close(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + send!(s, time 1005, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + })); + send!(s, time 1015, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1020, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_fast_retransmit_after_triple_duplicate_ack() { + let mut s = socket_established(); + s.remote_mss = 6; + + // Normal ACK of previously received segment + send!(s, time 0, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + + // Send a long string of text divided into several packets + // because of previously received "window_len" + s.send_slice(b"xxxxxxyyyyyywwwwwwzzzzzz").unwrap(); + // This packet is lost + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"xxxxxx"[..], + ..RECV_TEMPL + })); + recv!(s, time 1005, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"yyyyyy"[..], + ..RECV_TEMPL + })); + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + (6 * 2), + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"wwwwww"[..], + ..RECV_TEMPL + })); + recv!(s, time 1015, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + (6 * 3), + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"zzzzzz"[..], + ..RECV_TEMPL + })); + + // First duplicate ACK + send!(s, time 1050, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + // Second duplicate ACK + send!(s, time 1055, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + // Third duplicate ACK + // Should trigger a fast retransmit of dropped packet + send!(s, time 1060, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + + // Fast retransmit packet + recv!(s, time 1100, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"xxxxxx"[..], + ..RECV_TEMPL + })); + + recv!(s, time 1105, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"yyyyyy"[..], + ..RECV_TEMPL + })); + recv!(s, time 1110, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + (6 * 2), + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"wwwwww"[..], + ..RECV_TEMPL + })); + recv!(s, time 1115, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + (6 * 3), + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"zzzzzz"[..], + ..RECV_TEMPL + })); + + // After all was send out, enter *normal* retransmission, + // don't stay in fast retransmission. + assert!(match s.timer { + Timer::Retransmit { expires_at, .. } => expires_at > Instant::from_millis(1115), + _ => false, + }); + + // ACK all received segments + send!(s, time 1120, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + (6 * 4)), + ..SEND_TEMPL + }); + } + + #[test] + fn test_fast_retransmit_duplicate_detection_with_data() { + let mut s = socket_established(); + + s.send_slice(b"abc").unwrap(); // This is lost + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + + // Normal ACK of previously received segment + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + // First duplicate + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + // Second duplicate + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + + assert_eq!(s.local_rx_dup_acks, 2, "duplicate ACK counter is not set"); + + // This packet has content, hence should not be detected + // as a duplicate ACK and should reset the duplicate ACK count + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"xxxxxx"[..], + ..SEND_TEMPL + } + ); + + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 3, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }] + ); + + assert_eq!( + s.local_rx_dup_acks, 0, + "duplicate ACK counter is not reset when receiving data" + ); + } + + #[test] + fn test_fast_retransmit_duplicate_detection_with_window_update() { + let mut s = socket_established(); + + s.send_slice(b"abc").unwrap(); // This is lost + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + + // Normal ACK of previously received segment + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + // First duplicate + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + // Second duplicate + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + + assert_eq!(s.local_rx_dup_acks, 2, "duplicate ACK counter is not set"); + + // This packet has a window update, hence should not be detected + // as a duplicate ACK and should reset the duplicate ACK count + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_len: 400, + ..SEND_TEMPL + } + ); + + assert_eq!( + s.local_rx_dup_acks, 0, + "duplicate ACK counter is not reset when receiving a window update" + ); + } + + #[test] + fn test_fast_retransmit_duplicate_detection() { + let mut s = socket_established(); + s.remote_mss = 6; + + // Normal ACK of previously received segment + send!(s, time 0, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + + // First duplicate, should not be counted as there is nothing to resend + send!(s, time 0, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + + assert_eq!( + s.local_rx_dup_acks, 0, + "duplicate ACK counter is set but wound not transmit data" + ); + + // Send a long string of text divided into several packets + // because of small remote_mss + s.send_slice(b"xxxxxxyyyyyywwwwwwzzzzzz").unwrap(); + + // This packet is reordered in network + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"xxxxxx"[..], + ..RECV_TEMPL + })); + recv!(s, time 1005, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"yyyyyy"[..], + ..RECV_TEMPL + })); + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + (6 * 2), + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"wwwwww"[..], + ..RECV_TEMPL + })); + recv!(s, time 1015, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + (6 * 3), + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"zzzzzz"[..], + ..RECV_TEMPL + })); + + // First duplicate ACK + send!(s, time 1050, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + // Second duplicate ACK + send!(s, time 1055, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + // Reordered packet arrives which should reset duplicate ACK count + send!(s, time 1060, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + (6 * 3)), + ..SEND_TEMPL + }); + + assert_eq!( + s.local_rx_dup_acks, 0, + "duplicate ACK counter is not reset when receiving ACK which updates send window" + ); + + // ACK all received segments + send!(s, time 1120, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + (6 * 4)), + ..SEND_TEMPL + }); + } + + #[test] + fn test_fast_retransmit_dup_acks_counter() { + let mut s = socket_established(); + + s.send_slice(b"abc").unwrap(); // This is lost + recv!(s, time 0, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + + send!(s, time 0, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + + // A lot of retransmits happen here + s.local_rx_dup_acks = u8::max_value() - 1; + + // Send 3 more ACKs, which could overflow local_rx_dup_acks, + // but intended behaviour is that we saturate the bounds + // of local_rx_dup_acks + send!(s, time 0, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + send!(s, time 0, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + send!(s, time 0, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!( + s.local_rx_dup_acks, + u8::max_value(), + "duplicate ACK count should not overflow but saturate" + ); + } + + #[test] + fn test_fast_retransmit_zero_window() { + let mut s = socket_established(); + + send!(s, time 1000, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + + s.send_slice(b"abc").unwrap(); + + recv!(s, time 0, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + + // 3 dup acks + send!(s, time 1050, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + send!(s, time 1050, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + send!(s, time 1050, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_len: 0, // boom + ..SEND_TEMPL + }); + + // even though we're in "fast retransmit", we shouldn't + // force-send anything because the remote's window is full. + recv_nothing!(s); + } + + // =========================================================================================// + // Tests for window management. + // =========================================================================================// + + #[test] + fn test_maximum_segment_size() { + let mut s = socket_listen(); + s.tx_buffer = SocketBuffer::new(vec![0; 32767]); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + max_seg_size: Some(1000), + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_len: 32767, + ..SEND_TEMPL + } + ); + s.send_slice(&[0; 1200][..]).unwrap(); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0; 1000][..], + ..RECV_TEMPL + }) + ); + } + + #[test] + fn test_close_wait_no_window_update() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &[1, 2, 3, 4], + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::CloseWait); + + // we ack the FIN, with the reduced window size. + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 6), + window_len: 60, + ..RECV_TEMPL + }) + ); + + let rx_buf = &mut [0; 32]; + assert_eq!(s.recv_slice(rx_buf), Ok(4)); + + // check that we do NOT send a window update even if it has changed. + recv_nothing!(s); + } + + #[test] + fn test_time_wait_no_window_update() { + let mut s = socket_fin_wait_2(); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 2), + payload: &[1, 2, 3, 4], + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::TimeWait); + + // we ack the FIN, with the reduced window size. + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 2, + ack_number: Some(REMOTE_SEQ + 6), + window_len: 60, + ..RECV_TEMPL + }) + ); + + let rx_buf = &mut [0; 32]; + assert_eq!(s.recv_slice(rx_buf), Ok(4)); + + // check that we do NOT send a window update even if it has changed. + recv_nothing!(s); + } + + // =========================================================================================// + // Tests for flow control. + // =========================================================================================// + + #[test] + fn test_psh_transmit() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef").unwrap(); + s.send_slice(b"123456").unwrap(); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + }), exact); + } + + #[test] + fn test_psh_receive() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + control: TcpControl::Psh, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_zero_window_ack() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"123456"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }) + ); + } + + #[test] + fn test_zero_window_fin() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(); + s.ack_delay = None; + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }] + ); + + // Even though the sequence space for the FIN itself is outside the window, + // it is not data, so FIN must be accepted when window full. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &[], + control: TcpControl::Fin, + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::CloseWait); + + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 7), + window_len: 0, + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_zero_window_ack_on_window_growth() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }] + ); + recv_nothing!(s, time 0); + s.recv(|buffer| { + assert_eq!(&buffer[..3], b"abc"); + (3, ()) + }) + .unwrap(); + recv!(s, time 0, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 3, + ..RECV_TEMPL + })); + recv_nothing!(s, time 0); + s.recv(|buffer| { + assert_eq!(buffer, b"def"); + (buffer.len(), ()) + }) + .unwrap(); + recv!(s, time 0, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 6, + ..RECV_TEMPL + })); + } + + #[test] + fn test_fill_peer_window() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef123456!@#$%^").unwrap(); + recv!( + s, + [ + TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }, + TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + }, + TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"!@#$%^"[..], + ..RECV_TEMPL + } + ] + ); + } + + #[test] + fn test_announce_window_after_read() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 3, + ..RECV_TEMPL + }] + ); + // Test that `dispatch` updates `remote_last_win` + assert_eq!(s.remote_last_win, s.rx_buffer.window() as u16); + s.recv(|buffer| (buffer.len(), ())).unwrap(); + assert!(s.window_to_update()); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 6, + ..RECV_TEMPL + }] + ); + assert_eq!(s.remote_last_win, s.rx_buffer.window() as u16); + // Provoke immediate ACK to test that `process` updates `remote_last_win` + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"def"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 6, + ..RECV_TEMPL + }) + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 9), + window_len: 0, + ..RECV_TEMPL + }) + ); + assert_eq!(s.remote_last_win, s.rx_buffer.window() as u16); + s.recv(|buffer| (buffer.len(), ())).unwrap(); + assert!(s.window_to_update()); + } + + // =========================================================================================// + // Tests for timeouts. + // =========================================================================================// + + #[test] + fn test_listen_timeout() { + let mut s = socket_listen(); + s.set_timeout(Some(Duration::from_millis(100))); + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Ingress); + } + + #[test] + fn test_connect_timeout() { + let mut s = socket(); + s.local_seq_no = LOCAL_SEQ; + s.socket + .connect(&mut s.cx, REMOTE_END, LOCAL_END.port) + .unwrap(); + s.set_timeout(Some(Duration::from_millis(100))); + recv!(s, time 150, Ok(TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + })); + assert_eq!(s.state, State::SynSent); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(250)) + ); + recv!(s, time 250, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(TcpSeqNumber(0)), + window_scale: None, + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_established_timeout() { + let mut s = socket_established(); + s.set_timeout(Some(Duration::from_millis(1000))); + recv_nothing!(s, time 250); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(1250)) + ); + s.send_slice(b"abcdef").unwrap(); + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Now); + recv!(s, time 255, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(955)) + ); + recv!(s, time 955, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(1255)) + ); + recv!(s, time 1255, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_established_keep_alive_timeout() { + let mut s = socket_established(); + s.set_keep_alive(Some(Duration::from_millis(50))); + s.set_timeout(Some(Duration::from_millis(100))); + recv!(s, time 100, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + recv_nothing!(s, time 100); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(150)) + ); + send!(s, time 105, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(155)) + ); + recv!(s, time 155, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + recv_nothing!(s, time 155); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(205)) + ); + recv_nothing!(s, time 200); + recv!(s, time 205, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + recv_nothing!(s, time 205); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_fin_wait_1_timeout() { + let mut s = socket_fin_wait_1(); + s.set_timeout(Some(Duration::from_millis(1000))); + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + recv!(s, time 1100, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_last_ack_timeout() { + let mut s = socket_last_ack(); + s.set_timeout(Some(Duration::from_millis(1000))); + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + })); + recv!(s, time 1100, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_closed_timeout() { + let mut s = socket_established(); + s.set_timeout(Some(Duration::from_millis(200))); + s.remote_last_ts = Some(Instant::from_millis(100)); + s.abort(); + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Now); + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Ingress); + } + + // =========================================================================================// + // Tests for keep-alive. + // =========================================================================================// + + #[test] + fn test_responds_to_keep_alive() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); + } + + #[test] + fn test_sends_keep_alive() { + let mut s = socket_established(); + s.set_keep_alive(Some(Duration::from_millis(100))); + + // drain the forced keep-alive packet + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Now); + recv!(s, time 0, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(100)) + ); + recv_nothing!(s, time 95); + recv!(s, time 100, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(200)) + ); + recv_nothing!(s, time 195); + recv!(s, time 200, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + + send!(s, time 250, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(350)) + ); + recv_nothing!(s, time 345); + recv!(s, time 350, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"\x00"[..], + ..RECV_TEMPL + })); + } + + // =========================================================================================// + // Tests for time-to-live configuration. + // =========================================================================================// + + #[test] + fn test_set_hop_limit() { + let mut s = socket_syn_received(); + + s.set_hop_limit(Some(0x2a)); + assert_eq!( + s.socket.dispatch(&mut s.cx, |_, (ip_repr, _)| { + assert_eq!(ip_repr.hop_limit(), 0x2a); + Ok::<_, ()>(()) + }), + Ok(()) + ); + + // assert that user-configurable settings are kept, + // see https://github.com/smoltcp-rs/smoltcp/issues/601. + s.reset(); + assert_eq!(s.hop_limit(), Some(0x2a)); + } + + #[test] + #[should_panic(expected = "the time-to-live value of a packet must not be zero")] + fn test_set_hop_limit_zero() { + let mut s = socket_syn_received(); + s.set_hop_limit(Some(0)); + } + + // =========================================================================================// + // Tests for reassembly. + // =========================================================================================// + + #[test] + fn test_out_of_order() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"def"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); + s.recv(|buffer| { + assert_eq!(buffer, b""); + (buffer.len(), ()) + }) + .unwrap(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }) + ); + s.recv(|buffer| { + assert_eq!(buffer, b"abcdef"); + (buffer.len(), ()) + }) + .unwrap(); + } + + #[test] + fn test_buffer_wraparound_rx() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + s.recv(|buffer| { + assert_eq!(buffer, b"abc"); + (buffer.len(), ()) + }) + .unwrap(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"defghi"[..], + ..SEND_TEMPL + } + ); + let mut data = [0; 6]; + assert_eq!(s.recv_slice(&mut data[..]), Ok(6)); + assert_eq!(data, &b"defghi"[..]); + } + + #[test] + fn test_buffer_wraparound_tx() { + let mut s = socket_established(); + s.set_nagle_enabled(false); + + s.tx_buffer = SocketBuffer::new(vec![b'.'; 9]); + assert_eq!(s.send_slice(b"xxxyyy"), Ok(6)); + assert_eq!(s.tx_buffer.dequeue_many(3), &b"xxx"[..]); + assert_eq!(s.tx_buffer.len(), 3); + + // "abcdef" not contiguous in tx buffer + assert_eq!(s.send_slice(b"abcdef"), Ok(6)); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"yyyabc"[..], + ..RECV_TEMPL + }) + ); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"def"[..], + ..RECV_TEMPL + }) + ); + } + + // =========================================================================================// + // Tests for graceful vs ungraceful rx close + // =========================================================================================// + + #[test] + fn test_rx_close_fin() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::Finished)); + } + + #[test] + fn test_rx_close_fin_in_fin_wait_1() { + let mut s = socket_fin_wait_1(); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closing); + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::Finished)); + } + + #[test] + fn test_rx_close_fin_in_fin_wait_2() { + let mut s = socket_fin_wait_2(); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::TimeWait); + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::Finished)); + } + + #[test] + fn test_rx_close_fin_with_hole() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"ghi"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 61, + ..RECV_TEMPL + }) + ); + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + s.recv(|data| { + assert_eq!(data, b""); + (0, ()) + }) + .unwrap(); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1 + 9, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + // Error must be `Illegal` even if we've received a FIN, + // because we are missing data. + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::InvalidState)); + } + + #[test] + fn test_rx_close_rst() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::InvalidState)); + } + + #[test] + fn test_rx_close_rst_with_hole() { + let mut s = socket_established(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"ghi"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 61, + ..RECV_TEMPL + }) + ); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1 + 9, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::InvalidState)); + } + + // =========================================================================================// + // Tests for delayed ACK + // =========================================================================================// + + #[test] + fn test_delayed_ack() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + // No ACK is immediately sent. + recv_nothing!(s); + + // After 10ms, it is sent. + recv!(s, time 11, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 61, + ..RECV_TEMPL + })); + } + + #[test] + fn test_delayed_ack_win() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + // Reading the data off the buffer should cause a window update. + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + + // However, no ACK or window update is immediately sent. + recv_nothing!(s); + + // After 10ms, it is sent. + recv!(s, time 11, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + ..RECV_TEMPL + })); + } + + #[test] + fn test_delayed_ack_reply() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + + s.send_slice(&b"xyz"[..]).unwrap(); + + // Writing data to the socket causes ACK to not be delayed, + // because it is immediately sent with the data. + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + payload: &b"xyz"[..], + ..RECV_TEMPL + }) + ); + } + + #[test] + fn test_delayed_ack_every_second_packet() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + // No ACK is immediately sent. + recv_nothing!(s); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"def"[..], + ..SEND_TEMPL + } + ); + + // Every 2nd packet, ACK is sent without delay. + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }) + ); + } + + #[test] + fn test_delayed_ack_three_packets() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + // No ACK is immediately sent. + recv_nothing!(s); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"def"[..], + ..SEND_TEMPL + } + ); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"ghi"[..], + ..SEND_TEMPL + } + ); + + // Every 2nd (or more) packet, ACK is sent without delay. + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 9), + window_len: 55, + ..RECV_TEMPL + }) + ); + } + + // =========================================================================================// + // Tests for Nagle's Algorithm + // =========================================================================================// + + #[test] + fn test_nagle() { + let mut s = socket_established(); + s.remote_mss = 6; + + s.send_slice(b"abcdef").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + + // If there's data in flight, full segments get sent. + s.send_slice(b"foobar").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }] + ); + + s.send_slice(b"aaabbbccc").unwrap(); + // If there's data in flight, not-full segments don't get sent. + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"aaabbb"[..], + ..RECV_TEMPL + }] + ); + + // Data gets ACKd, so there's no longer data in flight + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6 + 6), + ..SEND_TEMPL + } + ); + + // Now non-full segment gets sent. + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ccc"[..], + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_final_packet_in_stream_doesnt_wait_for_nagle() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef0").unwrap(); + s.socket.close(); + + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"0"[..], + ..RECV_TEMPL + }), exact); + } + + // =========================================================================================// + // Tests for packet filtering. + // =========================================================================================// + + #[test] + fn test_doesnt_accept_wrong_port() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(); + + let tcp_repr = TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + dst_port: LOCAL_PORT + 1, + ..SEND_TEMPL + }; + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); + + let tcp_repr = TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + src_port: REMOTE_PORT + 1, + ..SEND_TEMPL + }; + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); + } + + #[test] + fn test_doesnt_accept_wrong_ip() { + let mut s = socket_established(); + + let tcp_repr = TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }; + + let ip_repr = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: LOCAL_ADDR, + next_header: IpProtocol::Tcp, + payload_len: tcp_repr.buffer_len(), + hop_limit: 64, + }); + assert!(s.socket.accepts(&mut s.cx, &ip_repr, &tcp_repr)); + + let ip_repr_wrong_src = IpReprIpvX(IpvXRepr { + src_addr: OTHER_ADDR, + dst_addr: LOCAL_ADDR, + next_header: IpProtocol::Tcp, + payload_len: tcp_repr.buffer_len(), + hop_limit: 64, + }); + assert!(!s.socket.accepts(&mut s.cx, &ip_repr_wrong_src, &tcp_repr)); + + let ip_repr_wrong_dst = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: OTHER_ADDR, + next_header: IpProtocol::Tcp, + payload_len: tcp_repr.buffer_len(), + hop_limit: 64, + }); + assert!(!s.socket.accepts(&mut s.cx, &ip_repr_wrong_dst, &tcp_repr)); + } + + // =========================================================================================// + // Timer tests + // =========================================================================================// + + #[test] + fn test_timer_retransmit() { + const RTO: Duration = Duration::from_millis(100); + let mut r = Timer::new(); + assert_eq!(r.should_retransmit(Instant::from_secs(1)), None); + r.set_for_retransmit(Instant::from_millis(1000), RTO); + assert_eq!(r.should_retransmit(Instant::from_millis(1000)), None); + assert_eq!(r.should_retransmit(Instant::from_millis(1050)), None); + assert_eq!( + r.should_retransmit(Instant::from_millis(1101)), + Some(Duration::from_millis(101)) + ); + r.set_for_retransmit(Instant::from_millis(1101), RTO); + assert_eq!(r.should_retransmit(Instant::from_millis(1101)), None); + assert_eq!(r.should_retransmit(Instant::from_millis(1150)), None); + assert_eq!(r.should_retransmit(Instant::from_millis(1200)), None); + assert_eq!( + r.should_retransmit(Instant::from_millis(1301)), + Some(Duration::from_millis(300)) + ); + r.set_for_idle(Instant::from_millis(1301), None); + assert_eq!(r.should_retransmit(Instant::from_millis(1350)), None); + } + + #[test] + fn test_rtt_estimator() { + let mut r = RttEstimator::default(); + + let rtos = &[ + 751, 766, 755, 731, 697, 656, 613, 567, 523, 484, 445, 411, 378, 350, 322, 299, 280, + 261, 243, 229, 215, 206, 197, 188, + ]; + + for &rto in rtos { + r.sample(100); + assert_eq!(r.retransmission_timeout(), Duration::from_millis(rto)); + } + } +} diff --git a/src/socket/udp.rs b/src/socket/udp.rs new file mode 100644 index 0000000..82eebf2 --- /dev/null +++ b/src/socket/udp.rs @@ -0,0 +1,1030 @@ +use core::cmp::min; +#[cfg(feature = "async")] +use core::task::Waker; + +use crate::iface::Context; +use crate::phy::PacketMeta; +use crate::socket::PollAt; +#[cfg(feature = "async")] +use crate::socket::WakerRegistration; +use crate::storage::Empty; +use crate::wire::{IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, UdpRepr}; + +/// Metadata for a sent or received UDP packet. +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct UdpMetadata { + pub endpoint: IpEndpoint, + pub meta: PacketMeta, +} + +impl<T: Into<IpEndpoint>> From<T> for UdpMetadata { + fn from(value: T) -> Self { + Self { + endpoint: value.into(), + meta: PacketMeta::default(), + } + } +} + +impl core::fmt::Display for UdpMetadata { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + #[cfg(feature = "packetmeta-id")] + return write!(f, "{}, PacketID: {:?}", self.endpoint, self.meta); + + #[cfg(not(feature = "packetmeta-id"))] + write!(f, "{}", self.endpoint) + } +} + +/// A UDP packet metadata. +pub type PacketMetadata = crate::storage::PacketMetadata<UdpMetadata>; + +/// A UDP packet ring buffer. +pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, UdpMetadata>; + +/// Error returned by [`Socket::bind`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum BindError { + InvalidState, + Unaddressable, +} + +impl core::fmt::Display for BindError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + BindError::InvalidState => write!(f, "invalid state"), + BindError::Unaddressable => write!(f, "unaddressable"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for BindError {} + +/// Error returned by [`Socket::send`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SendError { + Unaddressable, + BufferFull, +} + +impl core::fmt::Display for SendError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + SendError::Unaddressable => write!(f, "unaddressable"), + SendError::BufferFull => write!(f, "buffer full"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for SendError {} + +/// Error returned by [`Socket::recv`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecvError { + Exhausted, + Truncated, +} + +impl core::fmt::Display for RecvError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated => write!(f, "truncated"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RecvError {} + +/// A User Datagram Protocol socket. +/// +/// A UDP socket is bound to a specific endpoint, and owns transmit and receive +/// packet buffers. +#[derive(Debug)] +pub struct Socket<'a> { + endpoint: IpListenEndpoint, + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, + /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + hop_limit: Option<u8>, + #[cfg(feature = "async")] + rx_waker: WakerRegistration, + #[cfg(feature = "async")] + tx_waker: WakerRegistration, +} + +impl<'a> Socket<'a> { + /// Create an UDP socket with the given buffers. + pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> { + Socket { + endpoint: IpListenEndpoint::default(), + rx_buffer, + tx_buffer, + hop_limit: None, + #[cfg(feature = "async")] + rx_waker: WakerRegistration::new(), + #[cfg(feature = "async")] + tx_waker: WakerRegistration::new(), + } + } + + /// Register a waker for receive operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `recv` method calls, such as receiving data, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_recv_waker(&mut self, waker: &Waker) { + self.rx_waker.register(waker) + } + + /// Register a waker for send operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `send` method calls, such as space becoming available in the transmit + /// buffer, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_send_waker(&mut self, waker: &Waker) { + self.tx_waker.register(waker) + } + + /// Return the bound endpoint. + #[inline] + pub fn endpoint(&self) -> IpListenEndpoint { + self.endpoint + } + + /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// See also the [set_hop_limit](#method.set_hop_limit) method + pub fn hop_limit(&self) -> Option<u8> { + self.hop_limit + } + + /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// A socket without an explicitly set hop limit value uses the default [IANA recommended] + /// value (64). + /// + /// # Panics + /// + /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7]. + /// + /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml + /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7 + pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) { + // A host MUST NOT send a datagram with a hop limit value of 0 + if let Some(0) = hop_limit { + panic!("the time-to-live value of a packet must not be zero") + } + + self.hop_limit = hop_limit + } + + /// Bind the socket to the given endpoint. + /// + /// This function returns `Err(Error::Illegal)` if the socket was open + /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)` + /// if the port in the given endpoint is zero. + pub fn bind<T: Into<IpListenEndpoint>>(&mut self, endpoint: T) -> Result<(), BindError> { + let endpoint = endpoint.into(); + if endpoint.port == 0 { + return Err(BindError::Unaddressable); + } + + if self.is_open() { + return Err(BindError::InvalidState); + } + + self.endpoint = endpoint; + + #[cfg(feature = "async")] + { + self.rx_waker.wake(); + self.tx_waker.wake(); + } + + Ok(()) + } + + /// Close the socket. + pub fn close(&mut self) { + // Clear the bound endpoint of the socket. + self.endpoint = IpListenEndpoint::default(); + + // Reset the RX and TX buffers of the socket. + self.tx_buffer.reset(); + self.rx_buffer.reset(); + + #[cfg(feature = "async")] + { + self.rx_waker.wake(); + self.tx_waker.wake(); + } + } + + /// Check whether the socket is open. + #[inline] + pub fn is_open(&self) -> bool { + self.endpoint.port != 0 + } + + /// Check whether the transmit buffer is full. + #[inline] + pub fn can_send(&self) -> bool { + !self.tx_buffer.is_full() + } + + /// Check whether the receive buffer is not empty. + #[inline] + pub fn can_recv(&self) -> bool { + !self.rx_buffer.is_empty() + } + + /// Return the maximum number packets the socket can receive. + #[inline] + pub fn packet_recv_capacity(&self) -> usize { + self.rx_buffer.packet_capacity() + } + + /// Return the maximum number packets the socket can transmit. + #[inline] + pub fn packet_send_capacity(&self) -> usize { + self.tx_buffer.packet_capacity() + } + + /// Return the maximum number of bytes inside the recv buffer. + #[inline] + pub fn payload_recv_capacity(&self) -> usize { + self.rx_buffer.payload_capacity() + } + + /// Return the maximum number of bytes inside the transmit buffer. + #[inline] + pub fn payload_send_capacity(&self) -> usize { + self.tx_buffer.payload_capacity() + } + + /// Enqueue a packet to be sent to a given remote endpoint, and return a pointer + /// to its payload. + /// + /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full, + /// `Err(Error::Unaddressable)` if local or remote port, or remote address are unspecified, + /// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity + /// to ever send this packet. + pub fn send( + &mut self, + size: usize, + meta: impl Into<UdpMetadata>, + ) -> Result<&mut [u8], SendError> { + let meta = meta.into(); + if self.endpoint.port == 0 { + return Err(SendError::Unaddressable); + } + if meta.endpoint.addr.is_unspecified() { + return Err(SendError::Unaddressable); + } + if meta.endpoint.port == 0 { + return Err(SendError::Unaddressable); + } + + let payload_buf = self + .tx_buffer + .enqueue(size, meta) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "udp:{}:{}: buffer to send {} octets", + self.endpoint, + meta.endpoint, + size + ); + Ok(payload_buf) + } + + /// Enqueue a packet to be send to a given remote endpoint and pass the buffer + /// to the provided closure. The closure then returns the size of the data written + /// into the buffer. + /// + /// Also see [send](#method.send). + pub fn send_with<F>( + &mut self, + max_size: usize, + meta: impl Into<UdpMetadata>, + f: F, + ) -> Result<usize, SendError> + where + F: FnOnce(&mut [u8]) -> usize, + { + let meta = meta.into(); + if self.endpoint.port == 0 { + return Err(SendError::Unaddressable); + } + if meta.endpoint.addr.is_unspecified() { + return Err(SendError::Unaddressable); + } + if meta.endpoint.port == 0 { + return Err(SendError::Unaddressable); + } + + let size = self + .tx_buffer + .enqueue_with_infallible(max_size, meta, f) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "udp:{}:{}: buffer to send {} octets", + self.endpoint, + meta.endpoint, + size + ); + Ok(size) + } + + /// Enqueue a packet to be sent to a given remote endpoint, and fill it from a slice. + /// + /// See also [send](#method.send). + pub fn send_slice( + &mut self, + data: &[u8], + meta: impl Into<UdpMetadata>, + ) -> Result<(), SendError> { + self.send(data.len(), meta)?.copy_from_slice(data); + Ok(()) + } + + /// Dequeue a packet received from a remote endpoint, and return the endpoint as well + /// as a pointer to the payload. + /// + /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty. + pub fn recv(&mut self) -> Result<(&[u8], UdpMetadata), RecvError> { + let (remote_endpoint, payload_buf) = + self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; + + net_trace!( + "udp:{}:{}: receive {} buffered octets", + self.endpoint, + remote_endpoint.endpoint, + payload_buf.len() + ); + Ok((payload_buf, remote_endpoint)) + } + + /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, + /// and return the amount of octets copied as well as the endpoint. + /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// the packet is dropped and a `RecvError::Truncated` error is returned. + /// + /// See also [recv](#method.recv). + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> { + let (buffer, endpoint) = self.recv().map_err(|_| RecvError::Exhausted)?; + + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + + let length = min(data.len(), buffer.len()); + data[..length].copy_from_slice(&buffer[..length]); + Ok((length, endpoint)) + } + + /// Peek at a packet received from a remote endpoint, and return the endpoint as well + /// as a pointer to the payload without removing the packet from the receive buffer. + /// This function otherwise behaves identically to [recv](#method.recv). + /// + /// It returns `Err(Error::Exhausted)` if the receive buffer is empty. + pub fn peek(&mut self) -> Result<(&[u8], &UdpMetadata), RecvError> { + let endpoint = self.endpoint; + self.rx_buffer.peek().map_err(|_| RecvError::Exhausted).map( + |(remote_endpoint, payload_buf)| { + net_trace!( + "udp:{}:{}: peek {} buffered octets", + endpoint, + remote_endpoint.endpoint, + payload_buf.len() + ); + (payload_buf, remote_endpoint) + }, + ) + } + + /// Peek at a packet received from a remote endpoint, copy the payload into the given slice, + /// and return the amount of octets copied as well as the endpoint without removing the + /// packet from the receive buffer. + /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). + /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// no data is copied into the provided buffer and a `RecvError::Truncated` error is returned. + /// + /// See also [peek](#method.peek). + pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &UdpMetadata), RecvError> { + let (buffer, endpoint) = self.peek()?; + + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + + let length = min(data.len(), buffer.len()); + data[..length].copy_from_slice(&buffer[..length]); + Ok((length, endpoint)) + } + + pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, repr: &UdpRepr) -> bool { + if self.endpoint.port != repr.dst_port { + return false; + } + if self.endpoint.addr.is_some() + && self.endpoint.addr != Some(ip_repr.dst_addr()) + && !cx.is_broadcast(&ip_repr.dst_addr()) + && !ip_repr.dst_addr().is_multicast() + { + return false; + } + + true + } + + pub(crate) fn process( + &mut self, + cx: &mut Context, + meta: PacketMeta, + ip_repr: &IpRepr, + repr: &UdpRepr, + payload: &[u8], + ) { + debug_assert!(self.accepts(cx, ip_repr, repr)); + + let size = payload.len(); + + let remote_endpoint = IpEndpoint { + addr: ip_repr.src_addr(), + port: repr.src_port, + }; + + net_trace!( + "udp:{}:{}: receiving {} octets", + self.endpoint, + remote_endpoint, + size + ); + + let metadata = UdpMetadata { + endpoint: remote_endpoint, + meta, + }; + + match self.rx_buffer.enqueue(size, metadata) { + Ok(buf) => buf.copy_from_slice(payload), + Err(_) => net_trace!( + "udp:{}:{}: buffer full, dropped incoming packet", + self.endpoint, + remote_endpoint + ), + } + + #[cfg(feature = "async")] + self.rx_waker.wake(); + } + + pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, PacketMeta, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>, + { + let endpoint = self.endpoint; + let hop_limit = self.hop_limit.unwrap_or(64); + + let res = self.tx_buffer.dequeue_with(|packet_meta, payload_buf| { + let src_addr = match endpoint.addr { + Some(addr) => addr, + None => match cx.get_source_address(&packet_meta.endpoint.addr) { + Some(addr) => addr, + None => { + net_trace!( + "udp:{}:{}: cannot find suitable source address, dropping.", + endpoint, + packet_meta.endpoint + ); + return Ok(()); + } + }, + }; + + net_trace!( + "udp:{}:{}: sending {} octets", + endpoint, + packet_meta.endpoint, + payload_buf.len() + ); + + let repr = UdpRepr { + src_port: endpoint.port, + dst_port: packet_meta.endpoint.port, + }; + let ip_repr = IpRepr::new( + src_addr, + packet_meta.endpoint.addr, + IpProtocol::Udp, + repr.header_len() + payload_buf.len(), + hop_limit, + ); + + emit(cx, packet_meta.meta, (ip_repr, repr, payload_buf)) + }); + match res { + Err(Empty) => Ok(()), + Ok(Err(e)) => Err(e), + Ok(Ok(())) => { + #[cfg(feature = "async")] + self.tx_waker.wake(); + Ok(()) + } + } + } + + pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt { + if self.tx_buffer.is_empty() { + PollAt::Ingress + } else { + PollAt::Now + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::wire::{IpRepr, UdpRepr}; + + use crate::phy::Medium; + use crate::tests::setup; + use rstest::*; + + fn buffer(packets: usize) -> PacketBuffer<'static> { + PacketBuffer::new( + (0..packets) + .map(|_| PacketMetadata::EMPTY) + .collect::<Vec<_>>(), + vec![0; 16 * packets], + ) + } + + fn socket( + rx_buffer: PacketBuffer<'static>, + tx_buffer: PacketBuffer<'static>, + ) -> Socket<'static> { + Socket::new(rx_buffer, tx_buffer) + } + + const LOCAL_PORT: u16 = 53; + const REMOTE_PORT: u16 = 49500; + + cfg_if::cfg_if! { + if #[cfg(feature = "proto-ipv4")] { + use crate::wire::Ipv4Address as IpvXAddress; + use crate::wire::Ipv4Repr as IpvXRepr; + use IpRepr::Ipv4 as IpReprIpvX; + + const LOCAL_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 1]); + const REMOTE_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 2]); + const OTHER_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 3]); + } else { + use crate::wire::Ipv6Address as IpvXAddress; + use crate::wire::Ipv6Repr as IpvXRepr; + use IpRepr::Ipv6 as IpReprIpvX; + + const LOCAL_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ]); + const REMOTE_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + ]); + const OTHER_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, + ]); + } + } + + pub const LOCAL_END: IpEndpoint = IpEndpoint { + addr: LOCAL_ADDR.into_address(), + port: LOCAL_PORT, + }; + pub const REMOTE_END: IpEndpoint = IpEndpoint { + addr: REMOTE_ADDR.into_address(), + port: REMOTE_PORT, + }; + + pub const LOCAL_IP_REPR: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: LOCAL_ADDR, + dst_addr: REMOTE_ADDR, + next_header: IpProtocol::Udp, + payload_len: 8 + 6, + hop_limit: 64, + }); + + pub const REMOTE_IP_REPR: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: LOCAL_ADDR, + next_header: IpProtocol::Udp, + payload_len: 8 + 6, + hop_limit: 64, + }); + + pub const BAD_IP_REPR: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: OTHER_ADDR, + next_header: IpProtocol::Udp, + payload_len: 8 + 6, + hop_limit: 64, + }); + + const LOCAL_UDP_REPR: UdpRepr = UdpRepr { + src_port: LOCAL_PORT, + dst_port: REMOTE_PORT, + }; + + const REMOTE_UDP_REPR: UdpRepr = UdpRepr { + src_port: REMOTE_PORT, + dst_port: LOCAL_PORT, + }; + + const PAYLOAD: &[u8] = b"abcdef"; + + #[test] + fn test_bind_unaddressable() { + let mut socket = socket(buffer(0), buffer(0)); + assert_eq!(socket.bind(0), Err(BindError::Unaddressable)); + } + + #[test] + fn test_bind_twice() { + let mut socket = socket(buffer(0), buffer(0)); + assert_eq!(socket.bind(1), Ok(())); + assert_eq!(socket.bind(2), Err(BindError::InvalidState)); + } + + #[test] + #[should_panic(expected = "the time-to-live value of a packet must not be zero")] + fn test_set_hop_limit_zero() { + let mut s = socket(buffer(0), buffer(1)); + s.set_hop_limit(Some(0)); + } + + #[test] + fn test_send_unaddressable() { + let mut socket = socket(buffer(0), buffer(1)); + + assert_eq!( + socket.send_slice(b"abcdef", REMOTE_END), + Err(SendError::Unaddressable) + ); + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); + assert_eq!( + socket.send_slice( + b"abcdef", + IpEndpoint { + addr: IpvXAddress::UNSPECIFIED.into(), + ..REMOTE_END + } + ), + Err(SendError::Unaddressable) + ); + assert_eq!( + socket.send_slice( + b"abcdef", + IpEndpoint { + port: 0, + ..REMOTE_END + } + ), + Err(SendError::Unaddressable) + ); + assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(())); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_send_dispatch(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + let mut socket = socket(buffer(0), buffer(1)); + + assert_eq!(socket.bind(LOCAL_END), Ok(())); + + assert!(socket.can_send()); + assert_eq!( + socket.dispatch(cx, |_, _, _| unreachable!()), + Ok::<_, ()>(()) + ); + + assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(())); + assert_eq!( + socket.send_slice(b"123456", REMOTE_END), + Err(SendError::BufferFull) + ); + assert!(!socket.can_send()); + + assert_eq!( + socket.dispatch(cx, |_, _, (ip_repr, udp_repr, payload)| { + assert_eq!(ip_repr, LOCAL_IP_REPR); + assert_eq!(udp_repr, LOCAL_UDP_REPR); + assert_eq!(payload, PAYLOAD); + Err(()) + }), + Err(()) + ); + assert!(!socket.can_send()); + + assert_eq!( + socket.dispatch(cx, |_, _, (ip_repr, udp_repr, payload)| { + assert_eq!(ip_repr, LOCAL_IP_REPR); + assert_eq!(udp_repr, LOCAL_UDP_REPR); + assert_eq!(payload, PAYLOAD); + Ok::<_, ()>(()) + }), + Ok(()) + ); + assert!(socket.can_send()); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_recv_process(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(0)); + + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); + + assert!(!socket.can_recv()); + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); + + assert!(socket.accepts(cx, &REMOTE_IP_REPR, &REMOTE_UDP_REPR)); + socket.process( + cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); + assert!(socket.can_recv()); + + assert!(socket.accepts(cx, &REMOTE_IP_REPR, &REMOTE_UDP_REPR)); + socket.process( + cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); + + assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END.into()))); + assert!(!socket.can_recv()); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_peek_process(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(0)); + + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); + + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + + socket.process( + cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); + assert_eq!(socket.peek(), Ok((&b"abcdef"[..], &REMOTE_END.into(),))); + assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END.into(),))); + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_recv_truncated_slice(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(0)); + + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); + + assert!(socket.accepts(cx, &REMOTE_IP_REPR, &REMOTE_UDP_REPR)); + socket.process( + cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); + + let mut slice = [0; 4]; + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_peek_truncated_slice(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(0)); + + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); + + socket.process( + cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); + + let mut slice = [0; 4]; + assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Truncated)); + assert_eq!(socket.recv_slice(&mut slice[..]), Err(RecvError::Truncated)); + assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted)); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_set_hop_limit(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut s = socket(buffer(0), buffer(1)); + + assert_eq!(s.bind(LOCAL_END), Ok(())); + + s.set_hop_limit(Some(0x2a)); + assert_eq!(s.send_slice(b"abcdef", REMOTE_END), Ok(())); + assert_eq!( + s.dispatch(cx, |_, _, (ip_repr, _, _)| { + assert_eq!( + ip_repr, + IpReprIpvX(IpvXRepr { + src_addr: LOCAL_ADDR, + dst_addr: REMOTE_ADDR, + next_header: IpProtocol::Udp, + payload_len: 8 + 6, + hop_limit: 0x2a, + }) + ); + Ok::<_, ()>(()) + }), + Ok(()) + ); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_doesnt_accept_wrong_port(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut socket = socket(buffer(1), buffer(0)); + + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); + + let mut udp_repr = REMOTE_UDP_REPR; + assert!(socket.accepts(cx, &REMOTE_IP_REPR, &udp_repr)); + udp_repr.dst_port += 1; + assert!(!socket.accepts(cx, &REMOTE_IP_REPR, &udp_repr)); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_doesnt_accept_wrong_ip(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + let mut port_bound_socket = socket(buffer(1), buffer(0)); + assert_eq!(port_bound_socket.bind(LOCAL_PORT), Ok(())); + assert!(port_bound_socket.accepts(cx, &BAD_IP_REPR, &REMOTE_UDP_REPR)); + + let mut ip_bound_socket = socket(buffer(1), buffer(0)); + assert_eq!(ip_bound_socket.bind(LOCAL_END), Ok(())); + assert!(!ip_bound_socket.accepts(cx, &BAD_IP_REPR, &REMOTE_UDP_REPR)); + } + + #[test] + fn test_send_large_packet() { + // buffer(4) creates a payload buffer of size 16*4 + let mut socket = socket(buffer(0), buffer(4)); + assert_eq!(socket.bind(LOCAL_END), Ok(())); + + let too_large = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdefx"; + assert_eq!( + socket.send_slice(too_large, REMOTE_END), + Err(SendError::BufferFull) + ); + assert_eq!(socket.send_slice(&too_large[..16 * 4], REMOTE_END), Ok(())); + } + + #[rstest] + #[case::ip(Medium::Ip)] + #[cfg(feature = "medium-ip")] + #[case::ethernet(Medium::Ethernet)] + #[cfg(feature = "medium-ethernet")] + #[case::ieee802154(Medium::Ieee802154)] + #[cfg(feature = "medium-ieee802154")] + fn test_process_empty_payload(#[case] medium: Medium) { + let meta = Box::leak(Box::new([PacketMetadata::EMPTY])); + let recv_buffer = PacketBuffer::new(&mut meta[..], vec![]); + let mut socket = socket(recv_buffer, buffer(0)); + + let (mut iface, _, _) = setup(medium); + let cx = iface.context(); + + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); + + let repr = UdpRepr { + src_port: REMOTE_PORT, + dst_port: LOCAL_PORT, + }; + socket.process(cx, PacketMeta::default(), &REMOTE_IP_REPR, &repr, &[]); + assert_eq!(socket.recv(), Ok((&[][..], REMOTE_END.into()))); + } + + #[test] + fn test_closing() { + let meta = Box::leak(Box::new([PacketMetadata::EMPTY])); + let recv_buffer = PacketBuffer::new(&mut meta[..], vec![]); + let mut socket = socket(recv_buffer, buffer(0)); + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); + + assert!(socket.is_open()); + socket.close(); + assert!(!socket.is_open()); + } +} diff --git a/src/socket/waker.rs b/src/socket/waker.rs new file mode 100644 index 0000000..4f42197 --- /dev/null +++ b/src/socket/waker.rs @@ -0,0 +1,33 @@ +use core::task::Waker; + +/// Utility struct to register and wake a waker. +#[derive(Debug)] +pub struct WakerRegistration { + waker: Option<Waker>, +} + +impl WakerRegistration { + pub const fn new() -> Self { + Self { waker: None } + } + + /// Register a waker. Overwrites the previous waker, if any. + pub fn register(&mut self, w: &Waker) { + match self.waker { + // Optimization: If both the old and new Wakers wake the same task, we can simply + // keep the old waker, skipping the clone. (In most executor implementations, + // cloning a waker is somewhat expensive, comparable to cloning an Arc). + Some(ref w2) if (w2.will_wake(w)) => {} + // In all other cases + // - we have no waker registered + // - we have a waker registered but it's for a different task. + // then clone the new waker and store it + _ => self.waker = Some(w.clone()), + } + } + + /// Wake the registered waker, if any. + pub fn wake(&mut self) { + self.waker.take().map(|w| w.wake()); + } +} diff --git a/src/storage/assembler.rs b/src/storage/assembler.rs new file mode 100644 index 0000000..365a1e0 --- /dev/null +++ b/src/storage/assembler.rs @@ -0,0 +1,750 @@ +use core::fmt; + +use crate::config::ASSEMBLER_MAX_SEGMENT_COUNT; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TooManyHolesError; + +impl fmt::Display for TooManyHolesError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "too many holes") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for TooManyHolesError {} + +/// A contiguous chunk of absent data, followed by a contiguous chunk of present data. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct Contig { + hole_size: usize, + data_size: usize, +} + +impl fmt::Display for Contig { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.has_hole() { + write!(f, "({})", self.hole_size)?; + } + if self.has_hole() && self.has_data() { + write!(f, " ")?; + } + if self.has_data() { + write!(f, "{}", self.data_size)?; + } + Ok(()) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Contig { + fn format(&self, fmt: defmt::Formatter) { + if self.has_hole() { + defmt::write!(fmt, "({})", self.hole_size); + } + if self.has_hole() && self.has_data() { + defmt::write!(fmt, " "); + } + if self.has_data() { + defmt::write!(fmt, "{}", self.data_size); + } + } +} + +impl Contig { + const fn empty() -> Contig { + Contig { + hole_size: 0, + data_size: 0, + } + } + + fn hole_and_data(hole_size: usize, data_size: usize) -> Contig { + Contig { + hole_size, + data_size, + } + } + + fn has_hole(&self) -> bool { + self.hole_size != 0 + } + + fn has_data(&self) -> bool { + self.data_size != 0 + } + + fn total_size(&self) -> usize { + self.hole_size + self.data_size + } + + fn shrink_hole_by(&mut self, size: usize) { + self.hole_size -= size; + } + + fn shrink_hole_to(&mut self, size: usize) { + debug_assert!(self.hole_size >= size); + + let total_size = self.total_size(); + self.hole_size = size; + self.data_size = total_size - size; + } +} + +/// A buffer (re)assembler. +/// +/// Currently, up to a hardcoded limit of 4 or 32 holes can be tracked in the buffer. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Assembler { + contigs: [Contig; ASSEMBLER_MAX_SEGMENT_COUNT], +} + +impl fmt::Display for Assembler { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "[ ")?; + for contig in self.contigs.iter() { + if !contig.has_data() { + break; + } + write!(f, "{contig} ")?; + } + write!(f, "]")?; + Ok(()) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Assembler { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "[ "); + for contig in self.contigs.iter() { + if !contig.has_data() { + break; + } + defmt::write!(fmt, "{} ", contig); + } + defmt::write!(fmt, "]"); + } +} + +// Invariant on Assembler::contigs: +// - There's an index `i` where all contigs before have data, and all contigs after don't (are unused). +// - All contigs with data must have hole_size != 0, except the first. + +impl Assembler { + /// Create a new buffer assembler. + pub const fn new() -> Assembler { + const EMPTY: Contig = Contig::empty(); + Assembler { + contigs: [EMPTY; ASSEMBLER_MAX_SEGMENT_COUNT], + } + } + + pub fn clear(&mut self) { + self.contigs.fill(Contig::empty()); + } + + fn front(&self) -> Contig { + self.contigs[0] + } + + /// Return length of the front contiguous range without removing it from the assembler + pub fn peek_front(&self) -> usize { + let front = self.front(); + if front.has_hole() { + 0 + } else { + front.data_size + } + } + + fn back(&self) -> Contig { + self.contigs[self.contigs.len() - 1] + } + + /// Return whether the assembler contains no data. + pub fn is_empty(&self) -> bool { + !self.front().has_data() + } + + /// Remove a contig at the given index. + fn remove_contig_at(&mut self, at: usize) { + debug_assert!(self.contigs[at].has_data()); + + for i in at..self.contigs.len() - 1 { + if !self.contigs[i].has_data() { + return; + } + self.contigs[i] = self.contigs[i + 1]; + } + + // Removing the last one. + self.contigs[self.contigs.len() - 1] = Contig::empty(); + } + + /// Add a contig at the given index, and return a pointer to it. + fn add_contig_at(&mut self, at: usize) -> Result<&mut Contig, TooManyHolesError> { + if self.back().has_data() { + return Err(TooManyHolesError); + } + + for i in (at + 1..self.contigs.len()).rev() { + self.contigs[i] = self.contigs[i - 1]; + } + + self.contigs[at] = Contig::empty(); + Ok(&mut self.contigs[at]) + } + + /// Add a new contiguous range to the assembler, + /// or return `Err(TooManyHolesError)` if too many discontinuities are already recorded. + pub fn add(&mut self, mut offset: usize, size: usize) -> Result<(), TooManyHolesError> { + if size == 0 { + return Ok(()); + } + + let mut i = 0; + + // Find index of the contig containing the start of the range. + loop { + if i == self.contigs.len() { + // The new range is after all the previous ranges, but there/s no space to add it. + return Err(TooManyHolesError); + } + let contig = &mut self.contigs[i]; + if !contig.has_data() { + // The new range is after all the previous ranges. Add it. + *contig = Contig::hole_and_data(offset, size); + return Ok(()); + } + if offset <= contig.total_size() { + break; + } + offset -= contig.total_size(); + i += 1; + } + + let contig = &mut self.contigs[i]; + if offset < contig.hole_size { + // Range starts within the hole. + + if offset + size < contig.hole_size { + // Range also ends within the hole. + let new_contig = self.add_contig_at(i)?; + new_contig.hole_size = offset; + new_contig.data_size = size; + + // Previous contigs[index] got moved to contigs[index+1] + self.contigs[i + 1].shrink_hole_by(offset + size); + return Ok(()); + } + + // The range being added covers both a part of the hole and a part of the data + // in this contig, shrink the hole in this contig. + contig.shrink_hole_to(offset); + } + + // coalesce contigs to the right. + let mut j = i + 1; + while j < self.contigs.len() + && self.contigs[j].has_data() + && offset + size >= self.contigs[i].total_size() + self.contigs[j].hole_size + { + self.contigs[i].data_size += self.contigs[j].total_size(); + j += 1; + } + let shift = j - i - 1; + if shift != 0 { + for x in i + 1..self.contigs.len() { + if !self.contigs[x].has_data() { + break; + } + + self.contigs[x] = self + .contigs + .get(x + shift) + .copied() + .unwrap_or_else(Contig::empty); + } + } + + if offset + size > self.contigs[i].total_size() { + // The added range still extends beyond the current contig. Increase data size. + let left = offset + size - self.contigs[i].total_size(); + self.contigs[i].data_size += left; + + // Decrease hole size of the next, if any. + if i + 1 < self.contigs.len() && self.contigs[i + 1].has_data() { + self.contigs[i + 1].hole_size -= left; + } + } + + Ok(()) + } + + /// Remove a contiguous range from the front of the assembler. + /// If no such range, return 0. + pub fn remove_front(&mut self) -> usize { + let front = self.front(); + if front.has_hole() || !front.has_data() { + 0 + } else { + self.remove_contig_at(0); + debug_assert!(front.data_size > 0); + front.data_size + } + } + + /// Add a segment, then remove_front. + /// + /// This is equivalent to calling `add` then `remove_front` individually, + /// except it's guaranteed to not fail when offset = 0. + /// This is required for TCP: we must never drop the next expected segment, or + /// the protocol might get stuck. + pub fn add_then_remove_front( + &mut self, + offset: usize, + size: usize, + ) -> Result<usize, TooManyHolesError> { + // This is the only case where a segment at offset=0 would cause the + // total amount of contigs to rise (and therefore can potentially cause + // a TooManyHolesError). Handle it in a way that is guaranteed to succeed. + if offset == 0 && size < self.contigs[0].hole_size { + self.contigs[0].hole_size -= size; + return Ok(size); + } + + self.add(offset, size)?; + Ok(self.remove_front()) + } + + /// Iterate over all of the contiguous data ranges. + /// + /// This is used in calculating what data ranges have been received. The offset indicates the + /// number of bytes of contiguous data received before the beginnings of this Assembler. + /// + /// Data Hole Data + /// |--- 100 ---|--- 200 ---|--- 100 ---| + /// + /// An offset of 1500 would return the ranges: ``(1500, 1600), (1800, 1900)`` + pub fn iter_data(&self, first_offset: usize) -> AssemblerIter { + AssemblerIter::new(self, first_offset) + } +} + +pub struct AssemblerIter<'a> { + assembler: &'a Assembler, + offset: usize, + index: usize, + left: usize, + right: usize, +} + +impl<'a> AssemblerIter<'a> { + fn new(assembler: &'a Assembler, offset: usize) -> AssemblerIter<'a> { + AssemblerIter { + assembler, + offset, + index: 0, + left: 0, + right: 0, + } + } +} + +impl<'a> Iterator for AssemblerIter<'a> { + type Item = (usize, usize); + + fn next(&mut self) -> Option<(usize, usize)> { + let mut data_range = None; + while data_range.is_none() && self.index < self.assembler.contigs.len() { + let contig = self.assembler.contigs[self.index]; + self.left += contig.hole_size; + self.right = self.left + contig.data_size; + data_range = if self.left < self.right { + let data_range = (self.left + self.offset, self.right + self.offset); + self.left = self.right; + Some(data_range) + } else { + None + }; + self.index += 1; + } + data_range + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::vec::Vec; + + impl From<Vec<(usize, usize)>> for Assembler { + fn from(vec: Vec<(usize, usize)>) -> Assembler { + const EMPTY: Contig = Contig::empty(); + + let mut contigs = [EMPTY; ASSEMBLER_MAX_SEGMENT_COUNT]; + for (i, &(hole_size, data_size)) in vec.iter().enumerate() { + contigs[i] = Contig { + hole_size, + data_size, + }; + } + Assembler { contigs } + } + } + + macro_rules! contigs { + [$( $x:expr ),*] => ({ + Assembler::from(vec![$( $x ),*]) + }) + } + + #[test] + fn test_new() { + let assr = Assembler::new(); + assert_eq!(assr, contigs![]); + } + + #[test] + fn test_empty_add_full() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(0, 16), Ok(())); + assert_eq!(assr, contigs![(0, 16)]); + } + + #[test] + fn test_empty_add_front() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(0, 4), Ok(())); + assert_eq!(assr, contigs![(0, 4)]); + } + + #[test] + fn test_empty_add_back() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(12, 4), Ok(())); + assert_eq!(assr, contigs![(12, 4)]); + } + + #[test] + fn test_empty_add_mid() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(4, 8), Ok(())); + assert_eq!(assr, contigs![(4, 8)]); + } + + #[test] + fn test_partial_add_front() { + let mut assr = contigs![(4, 8)]; + assert_eq!(assr.add(0, 4), Ok(())); + assert_eq!(assr, contigs![(0, 12)]); + } + + #[test] + fn test_partial_add_back() { + let mut assr = contigs![(4, 8)]; + assert_eq!(assr.add(12, 4), Ok(())); + assert_eq!(assr, contigs![(4, 12)]); + } + + #[test] + fn test_partial_add_front_overlap() { + let mut assr = contigs![(4, 8)]; + assert_eq!(assr.add(0, 8), Ok(())); + assert_eq!(assr, contigs![(0, 12)]); + } + + #[test] + fn test_partial_add_front_overlap_split() { + let mut assr = contigs![(4, 8)]; + assert_eq!(assr.add(2, 6), Ok(())); + assert_eq!(assr, contigs![(2, 10)]); + } + + #[test] + fn test_partial_add_back_overlap() { + let mut assr = contigs![(4, 8)]; + assert_eq!(assr.add(8, 8), Ok(())); + assert_eq!(assr, contigs![(4, 12)]); + } + + #[test] + fn test_partial_add_back_overlap_split() { + let mut assr = contigs![(4, 8)]; + assert_eq!(assr.add(10, 4), Ok(())); + assert_eq!(assr, contigs![(4, 10)]); + } + + #[test] + fn test_partial_add_both_overlap() { + let mut assr = contigs![(4, 8)]; + assert_eq!(assr.add(0, 16), Ok(())); + assert_eq!(assr, contigs![(0, 16)]); + } + + #[test] + fn test_partial_add_both_overlap_split() { + let mut assr = contigs![(4, 8)]; + assert_eq!(assr.add(2, 12), Ok(())); + assert_eq!(assr, contigs![(2, 12)]); + } + + #[test] + fn test_rejected_add_keeps_state() { + let mut assr = Assembler::new(); + for c in 1..=ASSEMBLER_MAX_SEGMENT_COUNT { + assert_eq!(assr.add(c * 10, 3), Ok(())); + } + // Maximum of allowed holes is reached + let assr_before = assr.clone(); + assert_eq!(assr.add(1, 3), Err(TooManyHolesError)); + assert_eq!(assr_before, assr); + } + + #[test] + fn test_empty_remove_front() { + let mut assr = contigs![]; + assert_eq!(assr.remove_front(), 0); + } + + #[test] + fn test_trailing_hole_remove_front() { + let mut assr = contigs![(0, 4)]; + assert_eq!(assr.remove_front(), 4); + assert_eq!(assr, contigs![]); + } + + #[test] + fn test_trailing_data_remove_front() { + let mut assr = contigs![(0, 4), (4, 4)]; + assert_eq!(assr.remove_front(), 4); + assert_eq!(assr, contigs![(4, 4)]); + } + + #[test] + fn test_boundary_case_remove_front() { + let mut vec = vec![(1, 1); ASSEMBLER_MAX_SEGMENT_COUNT]; + vec[0] = (0, 2); + let mut assr = Assembler::from(vec); + assert_eq!(assr.remove_front(), 2); + let mut vec = vec![(1, 1); ASSEMBLER_MAX_SEGMENT_COUNT]; + vec[ASSEMBLER_MAX_SEGMENT_COUNT - 1] = (0, 0); + let exp_assr = Assembler::from(vec); + assert_eq!(assr, exp_assr); + } + + #[test] + fn test_shrink_next_hole() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(100, 10), Ok(())); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add(40, 30), Ok(())); + assert_eq!(assr, contigs![(40, 30), (30, 10)]); + } + + #[test] + fn test_join_two() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(10, 10), Ok(())); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add(15, 40), Ok(())); + assert_eq!(assr, contigs![(10, 50)]); + } + + #[test] + fn test_join_two_reversed() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add(10, 10), Ok(())); + assert_eq!(assr.add(15, 40), Ok(())); + assert_eq!(assr, contigs![(10, 50)]); + } + + #[test] + fn test_join_two_overlong() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add(10, 10), Ok(())); + assert_eq!(assr.add(15, 60), Ok(())); + assert_eq!(assr, contigs![(10, 65)]); + } + + #[test] + fn test_iter_empty() { + let assr = Assembler::new(); + let segments: Vec<_> = assr.iter_data(10).collect(); + assert_eq!(segments, vec![]); + } + + #[test] + fn test_iter_full() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(0, 16), Ok(())); + let segments: Vec<_> = assr.iter_data(10).collect(); + assert_eq!(segments, vec![(10, 26)]); + } + + #[test] + fn test_iter_offset() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(0, 16), Ok(())); + let segments: Vec<_> = assr.iter_data(100).collect(); + assert_eq!(segments, vec![(100, 116)]); + } + + #[test] + fn test_iter_one_front() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(0, 4), Ok(())); + let segments: Vec<_> = assr.iter_data(10).collect(); + assert_eq!(segments, vec![(10, 14)]); + } + + #[test] + fn test_iter_one_back() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(12, 4), Ok(())); + let segments: Vec<_> = assr.iter_data(10).collect(); + assert_eq!(segments, vec![(22, 26)]); + } + + #[test] + fn test_iter_one_mid() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(4, 8), Ok(())); + let segments: Vec<_> = assr.iter_data(10).collect(); + assert_eq!(segments, vec![(14, 22)]); + } + + #[test] + fn test_iter_one_trailing_gap() { + let assr = contigs![(4, 8)]; + let segments: Vec<_> = assr.iter_data(100).collect(); + assert_eq!(segments, vec![(104, 112)]); + } + + #[test] + fn test_iter_two_split() { + let assr = contigs![(2, 6), (4, 1)]; + let segments: Vec<_> = assr.iter_data(100).collect(); + assert_eq!(segments, vec![(102, 108), (112, 113)]); + } + + #[test] + fn test_iter_three_split() { + let assr = contigs![(2, 6), (2, 1), (2, 2)]; + let segments: Vec<_> = assr.iter_data(100).collect(); + assert_eq!(segments, vec![(102, 108), (110, 111), (113, 115)]); + } + + #[test] + fn test_issue_694() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(0, 1), Ok(())); + assert_eq!(assr.add(2, 1), Ok(())); + assert_eq!(assr.add(1, 1), Ok(())); + } + + #[test] + fn test_add_then_remove_front() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add_then_remove_front(10, 10), Ok(0)); + assert_eq!(assr, contigs![(10, 10), (30, 10)]); + } + + #[test] + fn test_add_then_remove_front_at_front() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add_then_remove_front(0, 10), Ok(10)); + assert_eq!(assr, contigs![(40, 10)]); + } + + #[test] + fn test_add_then_remove_front_at_front_touch() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add_then_remove_front(0, 50), Ok(60)); + assert_eq!(assr, contigs![]); + } + + #[test] + fn test_add_then_remove_front_at_front_full() { + let mut assr = Assembler::new(); + for c in 1..=ASSEMBLER_MAX_SEGMENT_COUNT { + assert_eq!(assr.add(c * 10, 3), Ok(())); + } + // Maximum of allowed holes is reached + let assr_before = assr.clone(); + assert_eq!(assr.add_then_remove_front(1, 3), Err(TooManyHolesError)); + assert_eq!(assr_before, assr); + } + + #[test] + fn test_add_then_remove_front_at_front_full_offset_0() { + let mut assr = Assembler::new(); + for c in 1..=ASSEMBLER_MAX_SEGMENT_COUNT { + assert_eq!(assr.add(c * 10, 3), Ok(())); + } + assert_eq!(assr.add_then_remove_front(0, 3), Ok(3)); + } + + // Test against an obviously-correct but inefficient bitmap impl. + #[test] + fn test_random() { + use rand::Rng; + + const MAX_INDEX: usize = 256; + + for max_size in [2, 5, 10, 100] { + for _ in 0..300 { + //println!("==="); + let mut assr = Assembler::new(); + let mut map = [false; MAX_INDEX]; + + for _ in 0..60 { + let offset = rand::thread_rng().gen_range(0..MAX_INDEX - max_size - 1); + let size = rand::thread_rng().gen_range(1..=max_size); + + //println!("add {}..{} {}", offset, offset + size, size); + // Real impl + let res = assr.add(offset, size); + + // Bitmap impl + let mut map2 = map; + map2[offset..][..size].fill(true); + + let mut contigs = vec![]; + let mut hole: usize = 0; + let mut data: usize = 0; + for b in map2 { + if b { + data += 1; + } else { + if data != 0 { + contigs.push((hole, data)); + hole = 0; + data = 0; + } + hole += 1; + } + } + + // Compare. + let wanted_res = if contigs.len() > ASSEMBLER_MAX_SEGMENT_COUNT { + Err(TooManyHolesError) + } else { + Ok(()) + }; + assert_eq!(res, wanted_res); + if res.is_ok() { + map = map2; + assert_eq!(assr, Assembler::from(contigs)); + } + } + } + } + } +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs new file mode 100644 index 0000000..b03de71 --- /dev/null +++ b/src/storage/mod.rs @@ -0,0 +1,31 @@ +/*! Specialized containers. + +The `storage` module provides containers for use in other modules. +The containers support both pre-allocated memory, without the `std` +or `alloc` crates being available, and heap-allocated memory. +*/ + +mod assembler; +mod packet_buffer; +mod ring_buffer; + +pub use self::assembler::Assembler; +pub use self::packet_buffer::{PacketBuffer, PacketMetadata}; +pub use self::ring_buffer::RingBuffer; + +/// A trait for setting a value to a known state. +/// +/// In-place analog of Default. +pub trait Resettable { + fn reset(&mut self); +} + +/// Error returned when enqueuing into a full buffer. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Full; + +/// Error returned when dequeuing from an empty buffer. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Empty; diff --git a/src/storage/packet_buffer.rs b/src/storage/packet_buffer.rs new file mode 100644 index 0000000..28119fa --- /dev/null +++ b/src/storage/packet_buffer.rs @@ -0,0 +1,402 @@ +use managed::ManagedSlice; + +use crate::storage::{Full, RingBuffer}; + +use super::Empty; + +/// Size and header of a packet. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct PacketMetadata<H> { + size: usize, + header: Option<H>, +} + +impl<H> PacketMetadata<H> { + /// Empty packet description. + pub const EMPTY: PacketMetadata<H> = PacketMetadata { + size: 0, + header: None, + }; + + fn padding(size: usize) -> PacketMetadata<H> { + PacketMetadata { + size: size, + header: None, + } + } + + fn packet(size: usize, header: H) -> PacketMetadata<H> { + PacketMetadata { + size: size, + header: Some(header), + } + } + + fn is_padding(&self) -> bool { + self.header.is_none() + } +} + +/// An UDP packet ring buffer. +#[derive(Debug)] +pub struct PacketBuffer<'a, H: 'a> { + metadata_ring: RingBuffer<'a, PacketMetadata<H>>, + payload_ring: RingBuffer<'a, u8>, +} + +impl<'a, H> PacketBuffer<'a, H> { + /// Create a new packet buffer with the provided metadata and payload storage. + /// + /// Metadata storage limits the maximum _number_ of packets in the buffer and payload + /// storage limits the maximum _total size_ of packets. + pub fn new<MS, PS>(metadata_storage: MS, payload_storage: PS) -> PacketBuffer<'a, H> + where + MS: Into<ManagedSlice<'a, PacketMetadata<H>>>, + PS: Into<ManagedSlice<'a, u8>>, + { + PacketBuffer { + metadata_ring: RingBuffer::new(metadata_storage), + payload_ring: RingBuffer::new(payload_storage), + } + } + + /// Query whether the buffer is empty. + pub fn is_empty(&self) -> bool { + self.metadata_ring.is_empty() + } + + /// Query whether the buffer is full. + pub fn is_full(&self) -> bool { + self.metadata_ring.is_full() + } + + // There is currently no enqueue_with() because of the complexity of managing padding + // in case of failure. + + /// Enqueue a single packet with the given header into the buffer, and + /// return a reference to its payload, or return `Err(Full)` + /// if the buffer is full. + pub fn enqueue(&mut self, size: usize, header: H) -> Result<&mut [u8], Full> { + if self.payload_ring.capacity() < size || self.metadata_ring.is_full() { + return Err(Full); + } + + // Ring is currently empty. Clear it (resetting `read_at`) to maximize + // for contiguous space. + if self.payload_ring.is_empty() { + self.payload_ring.clear(); + } + + let window = self.payload_ring.window(); + let contig_window = self.payload_ring.contiguous_window(); + + if window < size { + return Err(Full); + } else if contig_window < size { + if window - contig_window < size { + // The buffer length is larger than the current contiguous window + // and is larger than the contiguous window will be after adding + // the padding necessary to circle around to the beginning of the + // ring buffer. + return Err(Full); + } else { + // Add padding to the end of the ring buffer so that the + // contiguous window is at the beginning of the ring buffer. + *self.metadata_ring.enqueue_one()? = PacketMetadata::padding(contig_window); + // note(discard): function does not write to the result + // enqueued padding buffer location + let _buf_enqueued = self.payload_ring.enqueue_many(contig_window); + } + } + + *self.metadata_ring.enqueue_one()? = PacketMetadata::packet(size, header); + + let payload_buf = self.payload_ring.enqueue_many(size); + debug_assert!(payload_buf.len() == size); + Ok(payload_buf) + } + + /// Call `f` with a packet from the buffer large enough to fit `max_size` bytes. The packet + /// is shrunk to the size returned from `f` and enqueued into the buffer. + pub fn enqueue_with_infallible<'b, F>( + &'b mut self, + max_size: usize, + header: H, + f: F, + ) -> Result<usize, Full> + where + F: FnOnce(&'b mut [u8]) -> usize, + { + if self.payload_ring.capacity() < max_size || self.metadata_ring.is_full() { + return Err(Full); + } + + let window = self.payload_ring.window(); + let contig_window = self.payload_ring.contiguous_window(); + + if window < max_size { + return Err(Full); + } else if contig_window < max_size { + if window - contig_window < max_size { + // The buffer length is larger than the current contiguous window + // and is larger than the contiguous window will be after adding + // the padding necessary to circle around to the beginning of the + // ring buffer. + return Err(Full); + } else { + // Add padding to the end of the ring buffer so that the + // contiguous window is at the beginning of the ring buffer. + *self.metadata_ring.enqueue_one()? = PacketMetadata::padding(contig_window); + // note(discard): function does not write to the result + // enqueued padding buffer location + let _buf_enqueued = self.payload_ring.enqueue_many(contig_window); + } + } + + let (size, _) = self + .payload_ring + .enqueue_many_with(|data| (f(&mut data[..max_size]), ())); + + *self.metadata_ring.enqueue_one()? = PacketMetadata::packet(size, header); + + Ok(size) + } + + fn dequeue_padding(&mut self) { + let _ = self.metadata_ring.dequeue_one_with(|metadata| { + if metadata.is_padding() { + // note(discard): function does not use value of dequeued padding bytes + let _buf_dequeued = self.payload_ring.dequeue_many(metadata.size); + Ok(()) // dequeue metadata + } else { + Err(()) // don't dequeue metadata + } + }); + } + + /// Call `f` with a single packet from the buffer, and dequeue the packet if `f` + /// returns successfully, or return `Err(EmptyError)` if the buffer is empty. + pub fn dequeue_with<'c, R, E, F>(&'c mut self, f: F) -> Result<Result<R, E>, Empty> + where + F: FnOnce(&mut H, &'c mut [u8]) -> Result<R, E>, + { + self.dequeue_padding(); + + self.metadata_ring.dequeue_one_with(|metadata| { + self.payload_ring + .dequeue_many_with(|payload_buf| { + debug_assert!(payload_buf.len() >= metadata.size); + + match f( + metadata.header.as_mut().unwrap(), + &mut payload_buf[..metadata.size], + ) { + Ok(val) => (metadata.size, Ok(val)), + Err(err) => (0, Err(err)), + } + }) + .1 + }) + } + + /// Dequeue a single packet from the buffer, and return a reference to its payload + /// as well as its header, or return `Err(Error::Exhausted)` if the buffer is empty. + pub fn dequeue(&mut self) -> Result<(H, &mut [u8]), Empty> { + self.dequeue_padding(); + + let meta = self.metadata_ring.dequeue_one()?; + + let payload_buf = self.payload_ring.dequeue_many(meta.size); + debug_assert!(payload_buf.len() == meta.size); + Ok((meta.header.take().unwrap(), payload_buf)) + } + + /// Peek at a single packet from the buffer without removing it, and return a reference to + /// its payload as well as its header, or return `Err(Error:Exhausted)` if the buffer is empty. + /// + /// This function otherwise behaves identically to [dequeue](#method.dequeue). + pub fn peek(&mut self) -> Result<(&H, &[u8]), Empty> { + self.dequeue_padding(); + + if let Some(metadata) = self.metadata_ring.get_allocated(0, 1).first() { + Ok(( + metadata.header.as_ref().unwrap(), + self.payload_ring.get_allocated(0, metadata.size), + )) + } else { + Err(Empty) + } + } + + /// Return the maximum number packets that can be stored. + pub fn packet_capacity(&self) -> usize { + self.metadata_ring.capacity() + } + + /// Return the maximum number of bytes in the payload ring buffer. + pub fn payload_capacity(&self) -> usize { + self.payload_ring.capacity() + } + + /// Reset the packet buffer and clear any staged. + #[allow(unused)] + pub(crate) fn reset(&mut self) { + self.payload_ring.clear(); + self.metadata_ring.clear(); + } +} + +#[cfg(test)] +mod test { + use super::*; + + fn buffer() -> PacketBuffer<'static, ()> { + PacketBuffer::new(vec![PacketMetadata::EMPTY; 4], vec![0u8; 16]) + } + + #[test] + fn test_simple() { + let mut buffer = buffer(); + buffer.enqueue(6, ()).unwrap().copy_from_slice(b"abcdef"); + assert_eq!(buffer.enqueue(16, ()), Err(Full)); + assert_eq!(buffer.metadata_ring.len(), 1); + assert_eq!(buffer.dequeue().unwrap().1, &b"abcdef"[..]); + assert_eq!(buffer.dequeue(), Err(Empty)); + } + + #[test] + fn test_peek() { + let mut buffer = buffer(); + assert_eq!(buffer.peek(), Err(Empty)); + buffer.enqueue(6, ()).unwrap().copy_from_slice(b"abcdef"); + assert_eq!(buffer.metadata_ring.len(), 1); + assert_eq!(buffer.peek().unwrap().1, &b"abcdef"[..]); + assert_eq!(buffer.dequeue().unwrap().1, &b"abcdef"[..]); + assert_eq!(buffer.peek(), Err(Empty)); + } + + #[test] + fn test_padding() { + let mut buffer = buffer(); + assert!(buffer.enqueue(6, ()).is_ok()); + assert!(buffer.enqueue(8, ()).is_ok()); + assert!(buffer.dequeue().is_ok()); + buffer.enqueue(4, ()).unwrap().copy_from_slice(b"abcd"); + assert_eq!(buffer.metadata_ring.len(), 3); + assert!(buffer.dequeue().is_ok()); + + assert_eq!(buffer.dequeue().unwrap().1, &b"abcd"[..]); + assert_eq!(buffer.metadata_ring.len(), 0); + } + + #[test] + fn test_padding_with_large_payload() { + let mut buffer = buffer(); + assert!(buffer.enqueue(12, ()).is_ok()); + assert!(buffer.dequeue().is_ok()); + buffer + .enqueue(12, ()) + .unwrap() + .copy_from_slice(b"abcdefghijkl"); + } + + #[test] + fn test_dequeue_with() { + let mut buffer = buffer(); + assert!(buffer.enqueue(6, ()).is_ok()); + assert!(buffer.enqueue(8, ()).is_ok()); + assert!(buffer.dequeue().is_ok()); + buffer.enqueue(4, ()).unwrap().copy_from_slice(b"abcd"); + assert_eq!(buffer.metadata_ring.len(), 3); + assert!(buffer.dequeue().is_ok()); + + assert!(matches!( + buffer.dequeue_with(|_, _| Result::<(), u32>::Err(123)), + Ok(Err(_)) + )); + assert_eq!(buffer.metadata_ring.len(), 1); + + assert!(buffer + .dequeue_with(|&mut (), payload| { + assert_eq!(payload, &b"abcd"[..]); + Result::<(), ()>::Ok(()) + }) + .is_ok()); + assert_eq!(buffer.metadata_ring.len(), 0); + } + + #[test] + fn test_metadata_full_empty() { + let mut buffer = buffer(); + assert!(buffer.is_empty()); + assert!(!buffer.is_full()); + assert!(buffer.enqueue(1, ()).is_ok()); + assert!(!buffer.is_empty()); + assert!(buffer.enqueue(1, ()).is_ok()); + assert!(buffer.enqueue(1, ()).is_ok()); + assert!(!buffer.is_full()); + assert!(!buffer.is_empty()); + assert!(buffer.enqueue(1, ()).is_ok()); + assert!(buffer.is_full()); + assert!(!buffer.is_empty()); + assert_eq!(buffer.metadata_ring.len(), 4); + assert_eq!(buffer.enqueue(1, ()), Err(Full)); + } + + #[test] + fn test_window_too_small() { + let mut buffer = buffer(); + assert!(buffer.enqueue(4, ()).is_ok()); + assert!(buffer.enqueue(8, ()).is_ok()); + assert!(buffer.dequeue().is_ok()); + assert_eq!(buffer.enqueue(16, ()), Err(Full)); + assert_eq!(buffer.metadata_ring.len(), 1); + } + + #[test] + fn test_contiguous_window_too_small() { + let mut buffer = buffer(); + assert!(buffer.enqueue(4, ()).is_ok()); + assert!(buffer.enqueue(8, ()).is_ok()); + assert!(buffer.dequeue().is_ok()); + assert_eq!(buffer.enqueue(8, ()), Err(Full)); + assert_eq!(buffer.metadata_ring.len(), 1); + } + + #[test] + fn test_contiguous_window_wrap() { + let mut buffer = buffer(); + assert!(buffer.enqueue(15, ()).is_ok()); + assert!(buffer.dequeue().is_ok()); + assert!(buffer.enqueue(16, ()).is_ok()); + } + + #[test] + fn test_capacity_too_small() { + let mut buffer = buffer(); + assert_eq!(buffer.enqueue(32, ()), Err(Full)); + } + + #[test] + fn test_contig_window_prioritized() { + let mut buffer = buffer(); + assert!(buffer.enqueue(4, ()).is_ok()); + assert!(buffer.dequeue().is_ok()); + assert!(buffer.enqueue(5, ()).is_ok()); + } + + #[test] + fn clear() { + let mut buffer = buffer(); + + // Ensure enqueuing data in the buffer fills it somewhat. + assert!(buffer.is_empty()); + assert!(buffer.enqueue(6, ()).is_ok()); + + // Ensure that resetting the buffer causes it to be empty. + assert!(!buffer.is_empty()); + buffer.reset(); + assert!(buffer.is_empty()); + } +} diff --git a/src/storage/ring_buffer.rs b/src/storage/ring_buffer.rs new file mode 100644 index 0000000..7d461b6 --- /dev/null +++ b/src/storage/ring_buffer.rs @@ -0,0 +1,803 @@ +// Some of the functions in ring buffer is marked as #[must_use]. It notes that +// these functions may have side effects, and it's implemented by [RFC 1940]. +// [RFC 1940]: https://github.com/rust-lang/rust/issues/43302 + +use core::cmp; +use managed::ManagedSlice; + +use crate::storage::Resettable; + +use super::{Empty, Full}; + +/// A ring buffer. +/// +/// This ring buffer implementation provides many ways to interact with it: +/// +/// * Enqueueing or dequeueing one element from corresponding side of the buffer; +/// * Enqueueing or dequeueing a slice of elements from corresponding side of the buffer; +/// * Accessing allocated and unallocated areas directly. +/// +/// It is also zero-copy; all methods provide references into the buffer's storage. +/// Note that all references are mutable; it is considered more important to allow +/// in-place processing than to protect from accidental mutation. +/// +/// This implementation is suitable for both simple uses such as a FIFO queue +/// of UDP packets, and advanced ones such as a TCP reassembly buffer. +#[derive(Debug)] +pub struct RingBuffer<'a, T: 'a> { + storage: ManagedSlice<'a, T>, + read_at: usize, + length: usize, +} + +impl<'a, T: 'a> RingBuffer<'a, T> { + /// Create a ring buffer with the given storage. + /// + /// During creation, every element in `storage` is reset. + pub fn new<S>(storage: S) -> RingBuffer<'a, T> + where + S: Into<ManagedSlice<'a, T>>, + { + RingBuffer { + storage: storage.into(), + read_at: 0, + length: 0, + } + } + + /// Clear the ring buffer. + pub fn clear(&mut self) { + self.read_at = 0; + self.length = 0; + } + + /// Return the maximum number of elements in the ring buffer. + pub fn capacity(&self) -> usize { + self.storage.len() + } + + /// Clear the ring buffer, and reset every element. + pub fn reset(&mut self) + where + T: Resettable, + { + self.clear(); + for elem in self.storage.iter_mut() { + elem.reset(); + } + } + + /// Return the current number of elements in the ring buffer. + pub fn len(&self) -> usize { + self.length + } + + /// Return the number of elements that can be added to the ring buffer. + pub fn window(&self) -> usize { + self.capacity() - self.len() + } + + /// Return the largest number of elements that can be added to the buffer + /// without wrapping around (i.e. in a single `enqueue_many` call). + pub fn contiguous_window(&self) -> usize { + cmp::min(self.window(), self.capacity() - self.get_idx(self.length)) + } + + /// Query whether the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Query whether the buffer is full. + pub fn is_full(&self) -> bool { + self.window() == 0 + } + + /// Shorthand for `(self.read + idx) % self.capacity()` with an + /// additional check to ensure that the capacity is not zero. + fn get_idx(&self, idx: usize) -> usize { + let len = self.capacity(); + if len > 0 { + (self.read_at + idx) % len + } else { + 0 + } + } + + /// Shorthand for `(self.read + idx) % self.capacity()` with no + /// additional checks to ensure the capacity is not zero. + fn get_idx_unchecked(&self, idx: usize) -> usize { + (self.read_at + idx) % self.capacity() + } +} + +/// This is the "discrete" ring buffer interface: it operates with single elements, +/// and boundary conditions (empty/full) are errors. +impl<'a, T: 'a> RingBuffer<'a, T> { + /// Call `f` with a single buffer element, and enqueue the element if `f` + /// returns successfully, or return `Err(Full)` if the buffer is full. + pub fn enqueue_one_with<'b, R, E, F>(&'b mut self, f: F) -> Result<Result<R, E>, Full> + where + F: FnOnce(&'b mut T) -> Result<R, E>, + { + if self.is_full() { + return Err(Full); + } + + let index = self.get_idx_unchecked(self.length); + let res = f(&mut self.storage[index]); + if res.is_ok() { + self.length += 1; + } + Ok(res) + } + + /// Enqueue a single element into the buffer, and return a reference to it, + /// or return `Err(Full)` if the buffer is full. + /// + /// This function is a shortcut for `ring_buf.enqueue_one_with(Ok)`. + pub fn enqueue_one(&mut self) -> Result<&mut T, Full> { + self.enqueue_one_with(Ok)? + } + + /// Call `f` with a single buffer element, and dequeue the element if `f` + /// returns successfully, or return `Err(Empty)` if the buffer is empty. + pub fn dequeue_one_with<'b, R, E, F>(&'b mut self, f: F) -> Result<Result<R, E>, Empty> + where + F: FnOnce(&'b mut T) -> Result<R, E>, + { + if self.is_empty() { + return Err(Empty); + } + + let next_at = self.get_idx_unchecked(1); + let res = f(&mut self.storage[self.read_at]); + + if res.is_ok() { + self.length -= 1; + self.read_at = next_at; + } + Ok(res) + } + + /// Dequeue an element from the buffer, and return a reference to it, + /// or return `Err(Empty)` if the buffer is empty. + /// + /// This function is a shortcut for `ring_buf.dequeue_one_with(Ok)`. + pub fn dequeue_one(&mut self) -> Result<&mut T, Empty> { + self.dequeue_one_with(Ok)? + } +} + +/// This is the "continuous" ring buffer interface: it operates with element slices, +/// and boundary conditions (empty/full) simply result in empty slices. +impl<'a, T: 'a> RingBuffer<'a, T> { + /// Call `f` with the largest contiguous slice of unallocated buffer elements, + /// and enqueue the amount of elements returned by `f`. + /// + /// # Panics + /// This function panics if the amount of elements returned by `f` is larger + /// than the size of the slice passed into it. + pub fn enqueue_many_with<'b, R, F>(&'b mut self, f: F) -> (usize, R) + where + F: FnOnce(&'b mut [T]) -> (usize, R), + { + if self.length == 0 { + // Ring is currently empty. Reset `read_at` to optimize + // for contiguous space. + self.read_at = 0; + } + + let write_at = self.get_idx(self.length); + let max_size = self.contiguous_window(); + let (size, result) = f(&mut self.storage[write_at..write_at + max_size]); + assert!(size <= max_size); + self.length += size; + (size, result) + } + + /// Enqueue a slice of elements up to the given size into the buffer, + /// and return a reference to them. + /// + /// This function may return a slice smaller than the given size + /// if the free space in the buffer is not contiguous. + #[must_use] + pub fn enqueue_many(&mut self, size: usize) -> &mut [T] { + self.enqueue_many_with(|buf| { + let size = cmp::min(size, buf.len()); + (size, &mut buf[..size]) + }) + .1 + } + + /// Enqueue as many elements from the given slice into the buffer as possible, + /// and return the amount of elements that could fit. + #[must_use] + pub fn enqueue_slice(&mut self, data: &[T]) -> usize + where + T: Copy, + { + let (size_1, data) = self.enqueue_many_with(|buf| { + let size = cmp::min(buf.len(), data.len()); + buf[..size].copy_from_slice(&data[..size]); + (size, &data[size..]) + }); + let (size_2, ()) = self.enqueue_many_with(|buf| { + let size = cmp::min(buf.len(), data.len()); + buf[..size].copy_from_slice(&data[..size]); + (size, ()) + }); + size_1 + size_2 + } + + /// Call `f` with the largest contiguous slice of allocated buffer elements, + /// and dequeue the amount of elements returned by `f`. + /// + /// # Panics + /// This function panics if the amount of elements returned by `f` is larger + /// than the size of the slice passed into it. + pub fn dequeue_many_with<'b, R, F>(&'b mut self, f: F) -> (usize, R) + where + F: FnOnce(&'b mut [T]) -> (usize, R), + { + let capacity = self.capacity(); + let max_size = cmp::min(self.len(), capacity - self.read_at); + let (size, result) = f(&mut self.storage[self.read_at..self.read_at + max_size]); + assert!(size <= max_size); + self.read_at = if capacity > 0 { + (self.read_at + size) % capacity + } else { + 0 + }; + self.length -= size; + (size, result) + } + + /// Dequeue a slice of elements up to the given size from the buffer, + /// and return a reference to them. + /// + /// This function may return a slice smaller than the given size + /// if the allocated space in the buffer is not contiguous. + #[must_use] + pub fn dequeue_many(&mut self, size: usize) -> &mut [T] { + self.dequeue_many_with(|buf| { + let size = cmp::min(size, buf.len()); + (size, &mut buf[..size]) + }) + .1 + } + + /// Dequeue as many elements from the buffer into the given slice as possible, + /// and return the amount of elements that could fit. + #[must_use] + pub fn dequeue_slice(&mut self, data: &mut [T]) -> usize + where + T: Copy, + { + let (size_1, data) = self.dequeue_many_with(|buf| { + let size = cmp::min(buf.len(), data.len()); + data[..size].copy_from_slice(&buf[..size]); + (size, &mut data[size..]) + }); + let (size_2, ()) = self.dequeue_many_with(|buf| { + let size = cmp::min(buf.len(), data.len()); + data[..size].copy_from_slice(&buf[..size]); + (size, ()) + }); + size_1 + size_2 + } +} + +/// This is the "random access" ring buffer interface: it operates with element slices, +/// and allows to access elements of the buffer that are not adjacent to its head or tail. +impl<'a, T: 'a> RingBuffer<'a, T> { + /// Return the largest contiguous slice of unallocated buffer elements starting + /// at the given offset past the last allocated element, and up to the given size. + #[must_use] + pub fn get_unallocated(&mut self, offset: usize, mut size: usize) -> &mut [T] { + let start_at = self.get_idx(self.length + offset); + // We can't access past the end of unallocated data. + if offset > self.window() { + return &mut []; + } + // We can't enqueue more than there is free space. + let clamped_window = self.window() - offset; + if size > clamped_window { + size = clamped_window + } + // We can't contiguously enqueue past the end of the storage. + let until_end = self.capacity() - start_at; + if size > until_end { + size = until_end + } + + &mut self.storage[start_at..start_at + size] + } + + /// Write as many elements from the given slice into unallocated buffer elements + /// starting at the given offset past the last allocated element, and return + /// the amount written. + #[must_use] + pub fn write_unallocated(&mut self, offset: usize, data: &[T]) -> usize + where + T: Copy, + { + let (size_1, offset, data) = { + let slice = self.get_unallocated(offset, data.len()); + let slice_len = slice.len(); + slice.copy_from_slice(&data[..slice_len]); + (slice_len, offset + slice_len, &data[slice_len..]) + }; + let size_2 = { + let slice = self.get_unallocated(offset, data.len()); + let slice_len = slice.len(); + slice.copy_from_slice(&data[..slice_len]); + slice_len + }; + size_1 + size_2 + } + + /// Enqueue the given number of unallocated buffer elements. + /// + /// # Panics + /// Panics if the number of elements given exceeds the number of unallocated elements. + pub fn enqueue_unallocated(&mut self, count: usize) { + assert!(count <= self.window()); + self.length += count; + } + + /// Return the largest contiguous slice of allocated buffer elements starting + /// at the given offset past the first allocated element, and up to the given size. + #[must_use] + pub fn get_allocated(&self, offset: usize, mut size: usize) -> &[T] { + let start_at = self.get_idx(offset); + // We can't read past the end of the allocated data. + if offset > self.length { + return &mut []; + } + // We can't read more than we have allocated. + let clamped_length = self.length - offset; + if size > clamped_length { + size = clamped_length + } + // We can't contiguously dequeue past the end of the storage. + let until_end = self.capacity() - start_at; + if size > until_end { + size = until_end + } + + &self.storage[start_at..start_at + size] + } + + /// Read as many elements from allocated buffer elements into the given slice + /// starting at the given offset past the first allocated element, and return + /// the amount read. + #[must_use] + pub fn read_allocated(&mut self, offset: usize, data: &mut [T]) -> usize + where + T: Copy, + { + let (size_1, offset, data) = { + let slice = self.get_allocated(offset, data.len()); + data[..slice.len()].copy_from_slice(slice); + (slice.len(), offset + slice.len(), &mut data[slice.len()..]) + }; + let size_2 = { + let slice = self.get_allocated(offset, data.len()); + data[..slice.len()].copy_from_slice(slice); + slice.len() + }; + size_1 + size_2 + } + + /// Dequeue the given number of allocated buffer elements. + /// + /// # Panics + /// Panics if the number of elements given exceeds the number of allocated elements. + pub fn dequeue_allocated(&mut self, count: usize) { + assert!(count <= self.len()); + self.length -= count; + self.read_at = self.get_idx(count); + } +} + +impl<'a, T: 'a> From<ManagedSlice<'a, T>> for RingBuffer<'a, T> { + fn from(slice: ManagedSlice<'a, T>) -> RingBuffer<'a, T> { + RingBuffer::new(slice) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_buffer_length_changes() { + let mut ring = RingBuffer::new(vec![0; 2]); + assert!(ring.is_empty()); + assert!(!ring.is_full()); + assert_eq!(ring.len(), 0); + assert_eq!(ring.capacity(), 2); + assert_eq!(ring.window(), 2); + + ring.length = 1; + assert!(!ring.is_empty()); + assert!(!ring.is_full()); + assert_eq!(ring.len(), 1); + assert_eq!(ring.capacity(), 2); + assert_eq!(ring.window(), 1); + + ring.length = 2; + assert!(!ring.is_empty()); + assert!(ring.is_full()); + assert_eq!(ring.len(), 2); + assert_eq!(ring.capacity(), 2); + assert_eq!(ring.window(), 0); + } + + #[test] + fn test_buffer_enqueue_dequeue_one_with() { + let mut ring = RingBuffer::new(vec![0; 5]); + assert_eq!( + ring.dequeue_one_with(|_| -> Result::<(), ()> { unreachable!() }), + Err(Empty) + ); + + ring.enqueue_one_with(Ok::<_, ()>).unwrap().unwrap(); + assert!(!ring.is_empty()); + assert!(!ring.is_full()); + + for i in 1..5 { + ring.enqueue_one_with(|e| Ok::<_, ()>(*e = i)) + .unwrap() + .unwrap(); + assert!(!ring.is_empty()); + } + assert!(ring.is_full()); + assert_eq!( + ring.enqueue_one_with(|_| -> Result::<(), ()> { unreachable!() }), + Err(Full) + ); + + for i in 0..5 { + assert_eq!( + ring.dequeue_one_with(|e| Ok::<_, ()>(*e)).unwrap().unwrap(), + i + ); + assert!(!ring.is_full()); + } + assert_eq!( + ring.dequeue_one_with(|_| -> Result::<(), ()> { unreachable!() }), + Err(Empty) + ); + assert!(ring.is_empty()); + } + + #[test] + fn test_buffer_enqueue_dequeue_one() { + let mut ring = RingBuffer::new(vec![0; 5]); + assert_eq!(ring.dequeue_one(), Err(Empty)); + + ring.enqueue_one().unwrap(); + assert!(!ring.is_empty()); + assert!(!ring.is_full()); + + for i in 1..5 { + *ring.enqueue_one().unwrap() = i; + assert!(!ring.is_empty()); + } + assert!(ring.is_full()); + assert_eq!(ring.enqueue_one(), Err(Full)); + + for i in 0..5 { + assert_eq!(*ring.dequeue_one().unwrap(), i); + assert!(!ring.is_full()); + } + assert_eq!(ring.dequeue_one(), Err(Empty)); + assert!(ring.is_empty()); + } + + #[test] + fn test_buffer_enqueue_many_with() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + + assert_eq!( + ring.enqueue_many_with(|buf| { + assert_eq!(buf.len(), 12); + buf[0..2].copy_from_slice(b"ab"); + (2, true) + }), + (2, true) + ); + assert_eq!(ring.len(), 2); + assert_eq!(&ring.storage[..], b"ab.........."); + + ring.enqueue_many_with(|buf| { + assert_eq!(buf.len(), 12 - 2); + buf[0..4].copy_from_slice(b"cdXX"); + (2, ()) + }); + assert_eq!(ring.len(), 4); + assert_eq!(&ring.storage[..], b"abcdXX......"); + + ring.enqueue_many_with(|buf| { + assert_eq!(buf.len(), 12 - 4); + buf[0..4].copy_from_slice(b"efgh"); + (4, ()) + }); + assert_eq!(ring.len(), 8); + assert_eq!(&ring.storage[..], b"abcdefgh...."); + + for _ in 0..4 { + *ring.dequeue_one().unwrap() = b'.'; + } + assert_eq!(ring.len(), 4); + assert_eq!(&ring.storage[..], b"....efgh...."); + + ring.enqueue_many_with(|buf| { + assert_eq!(buf.len(), 12 - 8); + buf[0..4].copy_from_slice(b"ijkl"); + (4, ()) + }); + assert_eq!(ring.len(), 8); + assert_eq!(&ring.storage[..], b"....efghijkl"); + + ring.enqueue_many_with(|buf| { + assert_eq!(buf.len(), 4); + buf[0..4].copy_from_slice(b"abcd"); + (4, ()) + }); + assert_eq!(ring.len(), 12); + assert_eq!(&ring.storage[..], b"abcdefghijkl"); + + for _ in 0..4 { + *ring.dequeue_one().unwrap() = b'.'; + } + assert_eq!(ring.len(), 8); + assert_eq!(&ring.storage[..], b"abcd....ijkl"); + } + + #[test] + fn test_buffer_enqueue_many() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + + ring.enqueue_many(8).copy_from_slice(b"abcdefgh"); + assert_eq!(ring.len(), 8); + assert_eq!(&ring.storage[..], b"abcdefgh...."); + + ring.enqueue_many(8).copy_from_slice(b"ijkl"); + assert_eq!(ring.len(), 12); + assert_eq!(&ring.storage[..], b"abcdefghijkl"); + } + + #[test] + fn test_buffer_enqueue_slice() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + + assert_eq!(ring.enqueue_slice(b"abcdefgh"), 8); + assert_eq!(ring.len(), 8); + assert_eq!(&ring.storage[..], b"abcdefgh...."); + + for _ in 0..4 { + *ring.dequeue_one().unwrap() = b'.'; + } + assert_eq!(ring.len(), 4); + assert_eq!(&ring.storage[..], b"....efgh...."); + + assert_eq!(ring.enqueue_slice(b"ijklabcd"), 8); + assert_eq!(ring.len(), 12); + assert_eq!(&ring.storage[..], b"abcdefghijkl"); + } + + #[test] + fn test_buffer_dequeue_many_with() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + + assert_eq!(ring.enqueue_slice(b"abcdefghijkl"), 12); + + assert_eq!( + ring.dequeue_many_with(|buf| { + assert_eq!(buf.len(), 12); + assert_eq!(buf, b"abcdefghijkl"); + buf[..4].copy_from_slice(b"...."); + (4, true) + }), + (4, true) + ); + assert_eq!(ring.len(), 8); + assert_eq!(&ring.storage[..], b"....efghijkl"); + + ring.dequeue_many_with(|buf| { + assert_eq!(buf, b"efghijkl"); + buf[..4].copy_from_slice(b"...."); + (4, ()) + }); + assert_eq!(ring.len(), 4); + assert_eq!(&ring.storage[..], b"........ijkl"); + + assert_eq!(ring.enqueue_slice(b"abcd"), 4); + assert_eq!(ring.len(), 8); + + ring.dequeue_many_with(|buf| { + assert_eq!(buf, b"ijkl"); + buf[..4].copy_from_slice(b"...."); + (4, ()) + }); + ring.dequeue_many_with(|buf| { + assert_eq!(buf, b"abcd"); + buf[..4].copy_from_slice(b"...."); + (4, ()) + }); + assert_eq!(ring.len(), 0); + assert_eq!(&ring.storage[..], b"............"); + } + + #[test] + fn test_buffer_dequeue_many() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + + assert_eq!(ring.enqueue_slice(b"abcdefghijkl"), 12); + + { + let buf = ring.dequeue_many(8); + assert_eq!(buf, b"abcdefgh"); + buf.copy_from_slice(b"........"); + } + assert_eq!(ring.len(), 4); + assert_eq!(&ring.storage[..], b"........ijkl"); + + { + let buf = ring.dequeue_many(8); + assert_eq!(buf, b"ijkl"); + buf.copy_from_slice(b"...."); + } + assert_eq!(ring.len(), 0); + assert_eq!(&ring.storage[..], b"............"); + } + + #[test] + fn test_buffer_dequeue_slice() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + + assert_eq!(ring.enqueue_slice(b"abcdefghijkl"), 12); + + { + let mut buf = [0; 8]; + assert_eq!(ring.dequeue_slice(&mut buf[..]), 8); + assert_eq!(&buf[..], b"abcdefgh"); + assert_eq!(ring.len(), 4); + } + + assert_eq!(ring.enqueue_slice(b"abcd"), 4); + + { + let mut buf = [0; 8]; + assert_eq!(ring.dequeue_slice(&mut buf[..]), 8); + assert_eq!(&buf[..], b"ijklabcd"); + assert_eq!(ring.len(), 0); + } + } + + #[test] + fn test_buffer_get_unallocated() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + + assert_eq!(ring.get_unallocated(16, 4), b""); + + { + let buf = ring.get_unallocated(0, 4); + buf.copy_from_slice(b"abcd"); + } + assert_eq!(&ring.storage[..], b"abcd........"); + + let buf_enqueued = ring.enqueue_many(4); + assert_eq!(buf_enqueued.len(), 4); + assert_eq!(ring.len(), 4); + + { + let buf = ring.get_unallocated(4, 8); + buf.copy_from_slice(b"ijkl"); + } + assert_eq!(&ring.storage[..], b"abcd....ijkl"); + + ring.enqueue_many(8).copy_from_slice(b"EFGHIJKL"); + ring.dequeue_many(4).copy_from_slice(b"abcd"); + assert_eq!(ring.len(), 8); + assert_eq!(&ring.storage[..], b"abcdEFGHIJKL"); + + { + let buf = ring.get_unallocated(0, 8); + buf.copy_from_slice(b"ABCD"); + } + assert_eq!(&ring.storage[..], b"ABCDEFGHIJKL"); + } + + #[test] + fn test_buffer_write_unallocated() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + ring.enqueue_many(6).copy_from_slice(b"abcdef"); + ring.dequeue_many(6).copy_from_slice(b"ABCDEF"); + + assert_eq!(ring.write_unallocated(0, b"ghi"), 3); + assert_eq!(ring.get_unallocated(0, 3), b"ghi"); + + assert_eq!(ring.write_unallocated(3, b"jklmno"), 6); + assert_eq!(ring.get_unallocated(3, 3), b"jkl"); + + assert_eq!(ring.write_unallocated(9, b"pqrstu"), 3); + assert_eq!(ring.get_unallocated(9, 3), b"pqr"); + } + + #[test] + fn test_buffer_get_allocated() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + + assert_eq!(ring.get_allocated(16, 4), b""); + assert_eq!(ring.get_allocated(0, 4), b""); + + let len_enqueued = ring.enqueue_slice(b"abcd"); + assert_eq!(ring.get_allocated(0, 8), b"abcd"); + assert_eq!(len_enqueued, 4); + + let len_enqueued = ring.enqueue_slice(b"efghijkl"); + ring.dequeue_many(4).copy_from_slice(b"...."); + assert_eq!(ring.get_allocated(4, 8), b"ijkl"); + assert_eq!(len_enqueued, 8); + + let len_enqueued = ring.enqueue_slice(b"abcd"); + assert_eq!(ring.get_allocated(4, 8), b"ijkl"); + assert_eq!(len_enqueued, 4); + } + + #[test] + fn test_buffer_read_allocated() { + let mut ring = RingBuffer::new(vec![b'.'; 12]); + ring.enqueue_many(12).copy_from_slice(b"abcdefghijkl"); + + let mut data = [0; 6]; + assert_eq!(ring.read_allocated(0, &mut data[..]), 6); + assert_eq!(&data[..], b"abcdef"); + + ring.dequeue_many(6).copy_from_slice(b"ABCDEF"); + ring.enqueue_many(3).copy_from_slice(b"mno"); + + let mut data = [0; 6]; + assert_eq!(ring.read_allocated(3, &mut data[..]), 6); + assert_eq!(&data[..], b"jklmno"); + + let mut data = [0; 6]; + assert_eq!(ring.read_allocated(6, &mut data[..]), 3); + assert_eq!(&data[..], b"mno\x00\x00\x00"); + } + + #[test] + fn test_buffer_with_no_capacity() { + let mut no_capacity: RingBuffer<u8> = RingBuffer::new(vec![]); + + // Call all functions that calculate the remainder against rx_buffer.capacity() + // with a backing storage with a length of 0. + assert_eq!(no_capacity.get_unallocated(0, 0), &[]); + assert_eq!(no_capacity.get_allocated(0, 0), &[]); + no_capacity.dequeue_allocated(0); + assert_eq!(no_capacity.enqueue_many(0), &[]); + assert_eq!(no_capacity.enqueue_one(), Err(Full)); + assert_eq!(no_capacity.contiguous_window(), 0); + } + + /// Use the buffer a bit. Then empty it and put in an item of + /// maximum size. By detecting a length of 0, the implementation + /// can reset the current buffer position. + #[test] + fn test_buffer_write_wholly() { + let mut ring = RingBuffer::new(vec![b'.'; 8]); + ring.enqueue_many(2).copy_from_slice(b"ab"); + ring.enqueue_many(2).copy_from_slice(b"cd"); + assert_eq!(ring.len(), 4); + let buf_dequeued = ring.dequeue_many(4); + assert_eq!(buf_dequeued, b"abcd"); + assert_eq!(ring.len(), 0); + + let large = ring.enqueue_many(8); + assert_eq!(large.len(), 8); + } +} diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..ec026ab --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,148 @@ +use crate::iface::*; +use crate::wire::*; + +pub(crate) fn setup<'a>(medium: Medium) -> (Interface, SocketSet<'a>, TestingDevice) { + let mut device = TestingDevice::new(medium); + + let config = Config::new(match medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + HardwareAddress::Ethernet(EthernetAddress([0x02, 0x02, 0x02, 0x02, 0x02, 0x02])) + } + #[cfg(feature = "medium-ip")] + Medium::Ip => HardwareAddress::Ip, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => HardwareAddress::Ieee802154(Ieee802154Address::Extended([ + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + ])), + }); + + let mut iface = Interface::new(config, &mut device, Instant::ZERO); + + #[cfg(feature = "proto-ipv4")] + { + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 1, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8)) + .unwrap(); + }); + } + + #[cfg(feature = "proto-ipv6")] + { + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0, 0, 0, 0, 0, 0, 0, 1), 128)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdbe, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + } + + (iface, SocketSet::new(vec![]), device) +} + +use heapless::Deque; +use heapless::Vec; + +use crate::phy::{self, Device, DeviceCapabilities, Medium}; +use crate::time::Instant; + +/// A testing device. +#[derive(Debug)] +pub struct TestingDevice { + pub(crate) queue: Deque<Vec<u8, 1514>, 4>, + max_transmission_unit: usize, + medium: Medium, +} + +#[allow(clippy::new_without_default)] +impl TestingDevice { + /// Creates a testing device. + /// + /// Every packet transmitted through this device will be received through it + /// in FIFO order. + pub fn new(medium: Medium) -> Self { + TestingDevice { + queue: Deque::new(), + max_transmission_unit: match medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => 1514, + #[cfg(feature = "medium-ip")] + Medium::Ip => 1500, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => 1500, + }, + medium, + } + } +} + +impl Device for TestingDevice { + type RxToken<'a> = RxToken; + type TxToken<'a> = TxToken<'a>; + + fn capabilities(&self) -> DeviceCapabilities { + DeviceCapabilities { + medium: self.medium, + max_transmission_unit: self.max_transmission_unit, + ..DeviceCapabilities::default() + } + } + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + self.queue.pop_front().map(move |buffer| { + let rx = RxToken { buffer }; + let tx = TxToken { + queue: &mut self.queue, + }; + (rx, tx) + }) + } + + fn transmit(&mut self, _timestamp: Instant) -> Option<Self::TxToken<'_>> { + Some(TxToken { + queue: &mut self.queue, + }) + } +} + +#[doc(hidden)] +pub struct RxToken { + buffer: Vec<u8, 1514>, +} + +impl phy::RxToken for RxToken { + fn consume<R, F>(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + f(&mut self.buffer) + } +} + +#[doc(hidden)] +#[derive(Debug)] +pub struct TxToken<'a> { + queue: &'a mut Deque<Vec<u8, 1514>, 4>, +} + +impl<'a> phy::TxToken for TxToken<'a> { + fn consume<R, F>(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut buffer = Vec::new(); + buffer.resize(len, 0).unwrap(); + let result = f(&mut buffer); + self.queue.push_back(buffer).unwrap(); + result + } +} diff --git a/src/time.rs b/src/time.rs new file mode 100644 index 0000000..e6904af --- /dev/null +++ b/src/time.rs @@ -0,0 +1,460 @@ +/*! Time structures. + +The `time` module contains structures used to represent both +absolute and relative time. + + - [Instant] is used to represent absolute time. + - [Duration] is used to represent relative time. + +[Instant]: struct.Instant.html +[Duration]: struct.Duration.html +*/ + +use core::{fmt, ops}; + +/// A representation of an absolute time value. +/// +/// The `Instant` type is a wrapper around a `i64` value that +/// represents a number of microseconds, monotonically increasing +/// since an arbitrary moment in time, such as system startup. +/// +/// * A value of `0` is inherently arbitrary. +/// * A value less than `0` indicates a time before the starting +/// point. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct Instant { + micros: i64, +} + +impl Instant { + pub const ZERO: Instant = Instant::from_micros_const(0); + + /// Create a new `Instant` from a number of microseconds. + pub fn from_micros<T: Into<i64>>(micros: T) -> Instant { + Instant { + micros: micros.into(), + } + } + + pub const fn from_micros_const(micros: i64) -> Instant { + Instant { micros } + } + + /// Create a new `Instant` from a number of milliseconds. + pub fn from_millis<T: Into<i64>>(millis: T) -> Instant { + Instant { + micros: millis.into() * 1000, + } + } + + /// Create a new `Instant` from a number of milliseconds. + pub const fn from_millis_const(millis: i64) -> Instant { + Instant { + micros: millis * 1000, + } + } + + /// Create a new `Instant` from a number of seconds. + pub fn from_secs<T: Into<i64>>(secs: T) -> Instant { + Instant { + micros: secs.into() * 1000000, + } + } + + /// Create a new `Instant` from the current [std::time::SystemTime]. + /// + /// See [std::time::SystemTime::now] + /// + /// [std::time::SystemTime]: https://doc.rust-lang.org/std/time/struct.SystemTime.html + /// [std::time::SystemTime::now]: https://doc.rust-lang.org/std/time/struct.SystemTime.html#method.now + #[cfg(feature = "std")] + pub fn now() -> Instant { + Self::from(::std::time::SystemTime::now()) + } + + /// The fractional number of milliseconds that have passed + /// since the beginning of time. + pub const fn millis(&self) -> i64 { + self.micros % 1000000 / 1000 + } + + /// The fractional number of microseconds that have passed + /// since the beginning of time. + pub const fn micros(&self) -> i64 { + self.micros % 1000000 + } + + /// The number of whole seconds that have passed since the + /// beginning of time. + pub const fn secs(&self) -> i64 { + self.micros / 1000000 + } + + /// The total number of milliseconds that have passed since + /// the beginning of time. + pub const fn total_millis(&self) -> i64 { + self.micros / 1000 + } + /// The total number of milliseconds that have passed since + /// the beginning of time. + pub const fn total_micros(&self) -> i64 { + self.micros + } +} + +#[cfg(feature = "std")] +impl From<::std::time::Instant> for Instant { + fn from(other: ::std::time::Instant) -> Instant { + let elapsed = other.elapsed(); + Instant::from_micros((elapsed.as_secs() * 1_000000) as i64 + elapsed.subsec_micros() as i64) + } +} + +#[cfg(feature = "std")] +impl From<::std::time::SystemTime> for Instant { + fn from(other: ::std::time::SystemTime) -> Instant { + let n = other + .duration_since(::std::time::UNIX_EPOCH) + .expect("start time must not be before the unix epoch"); + Self::from_micros(n.as_secs() as i64 * 1000000 + n.subsec_micros() as i64) + } +} + +#[cfg(feature = "std")] +impl From<Instant> for ::std::time::SystemTime { + fn from(val: Instant) -> Self { + ::std::time::UNIX_EPOCH + ::std::time::Duration::from_micros(val.micros as u64) + } +} + +impl fmt::Display for Instant { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}.{:0>3}s", self.secs(), self.millis()) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Instant { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{}.{:03}s", self.secs(), self.millis()); + } +} + +impl ops::Add<Duration> for Instant { + type Output = Instant; + + fn add(self, rhs: Duration) -> Instant { + Instant::from_micros(self.micros + rhs.total_micros() as i64) + } +} + +impl ops::AddAssign<Duration> for Instant { + fn add_assign(&mut self, rhs: Duration) { + self.micros += rhs.total_micros() as i64; + } +} + +impl ops::Sub<Duration> for Instant { + type Output = Instant; + + fn sub(self, rhs: Duration) -> Instant { + Instant::from_micros(self.micros - rhs.total_micros() as i64) + } +} + +impl ops::SubAssign<Duration> for Instant { + fn sub_assign(&mut self, rhs: Duration) { + self.micros -= rhs.total_micros() as i64; + } +} + +impl ops::Sub<Instant> for Instant { + type Output = Duration; + + fn sub(self, rhs: Instant) -> Duration { + Duration::from_micros((self.micros - rhs.micros).unsigned_abs()) + } +} + +/// A relative amount of time. +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct Duration { + micros: u64, +} + +impl Duration { + pub const ZERO: Duration = Duration::from_micros(0); + /// The longest possible duration we can encode. + pub const MAX: Duration = Duration::from_micros(u64::MAX); + /// Create a new `Duration` from a number of microseconds. + pub const fn from_micros(micros: u64) -> Duration { + Duration { micros } + } + + /// Create a new `Duration` from a number of milliseconds. + pub const fn from_millis(millis: u64) -> Duration { + Duration { + micros: millis * 1000, + } + } + + /// Create a new `Instant` from a number of seconds. + pub const fn from_secs(secs: u64) -> Duration { + Duration { + micros: secs * 1000000, + } + } + + /// The fractional number of milliseconds in this `Duration`. + pub const fn millis(&self) -> u64 { + self.micros / 1000 % 1000 + } + + /// The fractional number of milliseconds in this `Duration`. + pub const fn micros(&self) -> u64 { + self.micros % 1000000 + } + + /// The number of whole seconds in this `Duration`. + pub const fn secs(&self) -> u64 { + self.micros / 1000000 + } + + /// The total number of milliseconds in this `Duration`. + pub const fn total_millis(&self) -> u64 { + self.micros / 1000 + } + + /// The total number of microseconds in this `Duration`. + pub const fn total_micros(&self) -> u64 { + self.micros + } +} + +impl fmt::Display for Duration { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}.{:03}s", self.secs(), self.millis()) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Duration { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{}.{:03}s", self.secs(), self.millis()); + } +} + +impl ops::Add<Duration> for Duration { + type Output = Duration; + + fn add(self, rhs: Duration) -> Duration { + Duration::from_micros(self.micros + rhs.total_micros()) + } +} + +impl ops::AddAssign<Duration> for Duration { + fn add_assign(&mut self, rhs: Duration) { + self.micros += rhs.total_micros(); + } +} + +impl ops::Sub<Duration> for Duration { + type Output = Duration; + + fn sub(self, rhs: Duration) -> Duration { + Duration::from_micros( + self.micros + .checked_sub(rhs.total_micros()) + .expect("overflow when subtracting durations"), + ) + } +} + +impl ops::SubAssign<Duration> for Duration { + fn sub_assign(&mut self, rhs: Duration) { + self.micros = self + .micros + .checked_sub(rhs.total_micros()) + .expect("overflow when subtracting durations"); + } +} + +impl ops::Mul<u32> for Duration { + type Output = Duration; + + fn mul(self, rhs: u32) -> Duration { + Duration::from_micros(self.micros * rhs as u64) + } +} + +impl ops::MulAssign<u32> for Duration { + fn mul_assign(&mut self, rhs: u32) { + self.micros *= rhs as u64; + } +} + +impl ops::Div<u32> for Duration { + type Output = Duration; + + fn div(self, rhs: u32) -> Duration { + Duration::from_micros(self.micros / rhs as u64) + } +} + +impl ops::DivAssign<u32> for Duration { + fn div_assign(&mut self, rhs: u32) { + self.micros /= rhs as u64; + } +} + +impl ops::Shl<u32> for Duration { + type Output = Duration; + + fn shl(self, rhs: u32) -> Duration { + Duration::from_micros(self.micros << rhs) + } +} + +impl ops::ShlAssign<u32> for Duration { + fn shl_assign(&mut self, rhs: u32) { + self.micros <<= rhs; + } +} + +impl ops::Shr<u32> for Duration { + type Output = Duration; + + fn shr(self, rhs: u32) -> Duration { + Duration::from_micros(self.micros >> rhs) + } +} + +impl ops::ShrAssign<u32> for Duration { + fn shr_assign(&mut self, rhs: u32) { + self.micros >>= rhs; + } +} + +impl From<::core::time::Duration> for Duration { + fn from(other: ::core::time::Duration) -> Duration { + Duration::from_micros(other.as_secs() * 1000000 + other.subsec_micros() as u64) + } +} + +impl From<Duration> for ::core::time::Duration { + fn from(val: Duration) -> Self { + ::core::time::Duration::from_micros(val.total_micros()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_instant_ops() { + // std::ops::Add + assert_eq!( + Instant::from_millis(4) + Duration::from_millis(6), + Instant::from_millis(10) + ); + // std::ops::Sub + assert_eq!( + Instant::from_millis(7) - Duration::from_millis(5), + Instant::from_millis(2) + ); + } + + #[test] + fn test_instant_getters() { + let instant = Instant::from_millis(5674); + assert_eq!(instant.secs(), 5); + assert_eq!(instant.millis(), 674); + assert_eq!(instant.total_millis(), 5674); + } + + #[test] + fn test_instant_display() { + assert_eq!(format!("{}", Instant::from_millis(74)), "0.074s"); + assert_eq!(format!("{}", Instant::from_millis(5674)), "5.674s"); + assert_eq!(format!("{}", Instant::from_millis(5000)), "5.000s"); + } + + #[test] + #[cfg(feature = "std")] + fn test_instant_conversions() { + let mut epoc: ::std::time::SystemTime = Instant::from_millis(0).into(); + assert_eq!( + Instant::from(::std::time::UNIX_EPOCH), + Instant::from_millis(0) + ); + assert_eq!(epoc, ::std::time::UNIX_EPOCH); + epoc = Instant::from_millis(2085955200i64 * 1000).into(); + assert_eq!( + epoc, + ::std::time::UNIX_EPOCH + ::std::time::Duration::from_secs(2085955200) + ); + } + + #[test] + fn test_duration_ops() { + // std::ops::Add + assert_eq!( + Duration::from_millis(40) + Duration::from_millis(2), + Duration::from_millis(42) + ); + // std::ops::Sub + assert_eq!( + Duration::from_millis(555) - Duration::from_millis(42), + Duration::from_millis(513) + ); + // std::ops::Mul + assert_eq!(Duration::from_millis(13) * 22, Duration::from_millis(286)); + // std::ops::Div + assert_eq!(Duration::from_millis(53) / 4, Duration::from_micros(13250)); + } + + #[test] + fn test_duration_assign_ops() { + let mut duration = Duration::from_millis(4735); + duration += Duration::from_millis(1733); + assert_eq!(duration, Duration::from_millis(6468)); + duration -= Duration::from_millis(1234); + assert_eq!(duration, Duration::from_millis(5234)); + duration *= 4; + assert_eq!(duration, Duration::from_millis(20936)); + duration /= 5; + assert_eq!(duration, Duration::from_micros(4187200)); + } + + #[test] + #[should_panic(expected = "overflow when subtracting durations")] + fn test_sub_from_zero_overflow() { + let _ = Duration::from_millis(0) - Duration::from_millis(1); + } + + #[test] + #[should_panic(expected = "attempt to divide by zero")] + fn test_div_by_zero() { + let _ = Duration::from_millis(4) / 0; + } + + #[test] + fn test_duration_getters() { + let instant = Duration::from_millis(4934); + assert_eq!(instant.secs(), 4); + assert_eq!(instant.millis(), 934); + assert_eq!(instant.total_millis(), 4934); + } + + #[test] + fn test_duration_conversions() { + let mut std_duration = ::core::time::Duration::from_millis(4934); + let duration: Duration = std_duration.into(); + assert_eq!(duration, Duration::from_millis(4934)); + assert_eq!(Duration::from(std_duration), Duration::from_millis(4934)); + + std_duration = duration.into(); + assert_eq!(std_duration, ::core::time::Duration::from_millis(4934)); + } +} diff --git a/src/wire/arp.rs b/src/wire/arp.rs new file mode 100644 index 0000000..bb0df3a --- /dev/null +++ b/src/wire/arp.rs @@ -0,0 +1,458 @@ +use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; + +use super::{Error, Result}; + +pub use super::EthernetProtocol as Protocol; + +enum_with_unknown! { + /// ARP hardware type. + pub enum Hardware(u16) { + Ethernet = 1 + } +} + +enum_with_unknown! { + /// ARP operation type. + pub enum Operation(u16) { + Request = 1, + Reply = 2 + } +} + +/// A read/write wrapper around an Address Resolution Protocol packet buffer. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +mod field { + #![allow(non_snake_case)] + + use crate::wire::field::*; + + pub const HTYPE: Field = 0..2; + pub const PTYPE: Field = 2..4; + pub const HLEN: usize = 4; + pub const PLEN: usize = 5; + pub const OPER: Field = 6..8; + + #[inline] + pub const fn SHA(hardware_len: u8, _protocol_len: u8) -> Field { + let start = OPER.end; + start..(start + hardware_len as usize) + } + + #[inline] + pub const fn SPA(hardware_len: u8, protocol_len: u8) -> Field { + let start = SHA(hardware_len, protocol_len).end; + start..(start + protocol_len as usize) + } + + #[inline] + pub const fn THA(hardware_len: u8, protocol_len: u8) -> Field { + let start = SPA(hardware_len, protocol_len).end; + start..(start + hardware_len as usize) + } + + #[inline] + pub const fn TPA(hardware_len: u8, protocol_len: u8) -> Field { + let start = THA(hardware_len, protocol_len).end; + start..(start + protocol_len as usize) + } +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with ARP packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// The result of this check is invalidated by calling [set_hardware_len] or + /// [set_protocol_len]. + /// + /// [set_hardware_len]: #method.set_hardware_len + /// [set_protocol_len]: #method.set_protocol_len + #[allow(clippy::if_same_then_else)] + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::OPER.end { + Err(Error) + } else if len < field::TPA(self.hardware_len(), self.protocol_len()).end { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the hardware type field. + #[inline] + pub fn hardware_type(&self) -> Hardware { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::HTYPE]); + Hardware::from(raw) + } + + /// Return the protocol type field. + #[inline] + pub fn protocol_type(&self) -> Protocol { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::PTYPE]); + Protocol::from(raw) + } + + /// Return the hardware length field. + #[inline] + pub fn hardware_len(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::HLEN] + } + + /// Return the protocol length field. + #[inline] + pub fn protocol_len(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::PLEN] + } + + /// Return the operation field. + #[inline] + pub fn operation(&self) -> Operation { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::OPER]); + Operation::from(raw) + } + + /// Return the source hardware address field. + pub fn source_hardware_addr(&self) -> &[u8] { + let data = self.buffer.as_ref(); + &data[field::SHA(self.hardware_len(), self.protocol_len())] + } + + /// Return the source protocol address field. + pub fn source_protocol_addr(&self) -> &[u8] { + let data = self.buffer.as_ref(); + &data[field::SPA(self.hardware_len(), self.protocol_len())] + } + + /// Return the target hardware address field. + pub fn target_hardware_addr(&self) -> &[u8] { + let data = self.buffer.as_ref(); + &data[field::THA(self.hardware_len(), self.protocol_len())] + } + + /// Return the target protocol address field. + pub fn target_protocol_addr(&self) -> &[u8] { + let data = self.buffer.as_ref(); + &data[field::TPA(self.hardware_len(), self.protocol_len())] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the hardware type field. + #[inline] + pub fn set_hardware_type(&mut self, value: Hardware) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::HTYPE], value.into()) + } + + /// Set the protocol type field. + #[inline] + pub fn set_protocol_type(&mut self, value: Protocol) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::PTYPE], value.into()) + } + + /// Set the hardware length field. + #[inline] + pub fn set_hardware_len(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::HLEN] = value + } + + /// Set the protocol length field. + #[inline] + pub fn set_protocol_len(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::PLEN] = value + } + + /// Set the operation field. + #[inline] + pub fn set_operation(&mut self, value: Operation) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::OPER], value.into()) + } + + /// Set the source hardware address field. + /// + /// # Panics + /// The function panics if `value` is not `self.hardware_len()` long. + pub fn set_source_hardware_addr(&mut self, value: &[u8]) { + let (hardware_len, protocol_len) = (self.hardware_len(), self.protocol_len()); + let data = self.buffer.as_mut(); + data[field::SHA(hardware_len, protocol_len)].copy_from_slice(value) + } + + /// Set the source protocol address field. + /// + /// # Panics + /// The function panics if `value` is not `self.protocol_len()` long. + pub fn set_source_protocol_addr(&mut self, value: &[u8]) { + let (hardware_len, protocol_len) = (self.hardware_len(), self.protocol_len()); + let data = self.buffer.as_mut(); + data[field::SPA(hardware_len, protocol_len)].copy_from_slice(value) + } + + /// Set the target hardware address field. + /// + /// # Panics + /// The function panics if `value` is not `self.hardware_len()` long. + pub fn set_target_hardware_addr(&mut self, value: &[u8]) { + let (hardware_len, protocol_len) = (self.hardware_len(), self.protocol_len()); + let data = self.buffer.as_mut(); + data[field::THA(hardware_len, protocol_len)].copy_from_slice(value) + } + + /// Set the target protocol address field. + /// + /// # Panics + /// The function panics if `value` is not `self.protocol_len()` long. + pub fn set_target_protocol_addr(&mut self, value: &[u8]) { + let (hardware_len, protocol_len) = (self.hardware_len(), self.protocol_len()); + let data = self.buffer.as_mut(); + data[field::TPA(hardware_len, protocol_len)].copy_from_slice(value) + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +use crate::wire::{EthernetAddress, Ipv4Address}; + +/// A high-level representation of an Address Resolution Protocol packet. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] +pub enum Repr { + /// An Ethernet and IPv4 Address Resolution Protocol packet. + EthernetIpv4 { + operation: Operation, + source_hardware_addr: EthernetAddress, + source_protocol_addr: Ipv4Address, + target_hardware_addr: EthernetAddress, + target_protocol_addr: Ipv4Address, + }, +} + +impl Repr { + /// Parse an Address Resolution Protocol packet and return a high-level representation, + /// or return `Err(Error)` if the packet is not recognized. + pub fn parse<T: AsRef<[u8]>>(packet: &Packet<T>) -> Result<Repr> { + match ( + packet.hardware_type(), + packet.protocol_type(), + packet.hardware_len(), + packet.protocol_len(), + ) { + (Hardware::Ethernet, Protocol::Ipv4, 6, 4) => Ok(Repr::EthernetIpv4 { + operation: packet.operation(), + source_hardware_addr: EthernetAddress::from_bytes(packet.source_hardware_addr()), + source_protocol_addr: Ipv4Address::from_bytes(packet.source_protocol_addr()), + target_hardware_addr: EthernetAddress::from_bytes(packet.target_hardware_addr()), + target_protocol_addr: Ipv4Address::from_bytes(packet.target_protocol_addr()), + }), + _ => Err(Error), + } + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + match *self { + Repr::EthernetIpv4 { .. } => field::TPA(6, 4).end, + } + } + + /// Emit a high-level representation into an Address Resolution Protocol packet. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut Packet<T>) { + match *self { + Repr::EthernetIpv4 { + operation, + source_hardware_addr, + source_protocol_addr, + target_hardware_addr, + target_protocol_addr, + } => { + packet.set_hardware_type(Hardware::Ethernet); + packet.set_protocol_type(Protocol::Ipv4); + packet.set_hardware_len(6); + packet.set_protocol_len(4); + packet.set_operation(operation); + packet.set_source_hardware_addr(source_hardware_addr.as_bytes()); + packet.set_source_protocol_addr(source_protocol_addr.as_bytes()); + packet.set_target_hardware_addr(target_hardware_addr.as_bytes()); + packet.set_target_protocol_addr(target_protocol_addr.as_bytes()); + } + } + } +} + +impl<T: AsRef<[u8]>> fmt::Display for Packet<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match Repr::parse(self) { + Ok(repr) => write!(f, "{repr}"), + _ => { + write!(f, "ARP (unrecognized)")?; + write!( + f, + " htype={:?} ptype={:?} hlen={:?} plen={:?} op={:?}", + self.hardware_type(), + self.protocol_type(), + self.hardware_len(), + self.protocol_len(), + self.operation() + )?; + write!( + f, + " sha={:?} spa={:?} tha={:?} tpa={:?}", + self.source_hardware_addr(), + self.source_protocol_addr(), + self.target_hardware_addr(), + self.target_protocol_addr() + )?; + Ok(()) + } + } + } +} + +impl fmt::Display for Repr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Repr::EthernetIpv4 { + operation, + source_hardware_addr, + source_protocol_addr, + target_hardware_addr, + target_protocol_addr, + } => { + write!( + f, + "ARP type=Ethernet+IPv4 src={source_hardware_addr}/{source_protocol_addr} tgt={target_hardware_addr}/{target_protocol_addr} op={operation:?}" + ) + } + } + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + match Packet::new_checked(buffer) { + Err(err) => write!(f, "{indent}({err})"), + Ok(packet) => write!(f, "{indent}{packet}"), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + static PACKET_BYTES: [u8; 28] = [ + 0x00, 0x01, 0x08, 0x00, 0x06, 0x04, 0x00, 0x01, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x21, + 0x22, 0x23, 0x24, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x41, 0x42, 0x43, 0x44, + ]; + + #[test] + fn test_deconstruct() { + let packet = Packet::new_unchecked(&PACKET_BYTES[..]); + assert_eq!(packet.hardware_type(), Hardware::Ethernet); + assert_eq!(packet.protocol_type(), Protocol::Ipv4); + assert_eq!(packet.hardware_len(), 6); + assert_eq!(packet.protocol_len(), 4); + assert_eq!(packet.operation(), Operation::Request); + assert_eq!( + packet.source_hardware_addr(), + &[0x11, 0x12, 0x13, 0x14, 0x15, 0x16] + ); + assert_eq!(packet.source_protocol_addr(), &[0x21, 0x22, 0x23, 0x24]); + assert_eq!( + packet.target_hardware_addr(), + &[0x31, 0x32, 0x33, 0x34, 0x35, 0x36] + ); + assert_eq!(packet.target_protocol_addr(), &[0x41, 0x42, 0x43, 0x44]); + } + + #[test] + fn test_construct() { + let mut bytes = vec![0xa5; 28]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_hardware_type(Hardware::Ethernet); + packet.set_protocol_type(Protocol::Ipv4); + packet.set_hardware_len(6); + packet.set_protocol_len(4); + packet.set_operation(Operation::Request); + packet.set_source_hardware_addr(&[0x11, 0x12, 0x13, 0x14, 0x15, 0x16]); + packet.set_source_protocol_addr(&[0x21, 0x22, 0x23, 0x24]); + packet.set_target_hardware_addr(&[0x31, 0x32, 0x33, 0x34, 0x35, 0x36]); + packet.set_target_protocol_addr(&[0x41, 0x42, 0x43, 0x44]); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); + } + + fn packet_repr() -> Repr { + Repr::EthernetIpv4 { + operation: Operation::Request, + source_hardware_addr: EthernetAddress::from_bytes(&[ + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + ]), + source_protocol_addr: Ipv4Address::from_bytes(&[0x21, 0x22, 0x23, 0x24]), + target_hardware_addr: EthernetAddress::from_bytes(&[ + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, + ]), + target_protocol_addr: Ipv4Address::from_bytes(&[0x41, 0x42, 0x43, 0x44]), + } + } + + #[test] + fn test_parse() { + let packet = Packet::new_unchecked(&PACKET_BYTES[..]); + let repr = Repr::parse(&packet).unwrap(); + assert_eq!(repr, packet_repr()); + } + + #[test] + fn test_emit() { + let mut bytes = vec![0xa5; 28]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet_repr().emit(&mut packet); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); + } +} diff --git a/src/wire/dhcpv4.rs b/src/wire/dhcpv4.rs new file mode 100644 index 0000000..cae9129 --- /dev/null +++ b/src/wire/dhcpv4.rs @@ -0,0 +1,1315 @@ +// See https://tools.ietf.org/html/rfc2131 for the DHCP specification. + +use bitflags::bitflags; +use byteorder::{ByteOrder, NetworkEndian}; +use core::iter; +use heapless::Vec; + +use super::{Error, Result}; +use crate::wire::arp::Hardware; +use crate::wire::{EthernetAddress, Ipv4Address}; + +pub const SERVER_PORT: u16 = 67; +pub const CLIENT_PORT: u16 = 68; +pub const MAX_DNS_SERVER_COUNT: usize = 3; + +const DHCP_MAGIC_NUMBER: u32 = 0x63825363; + +enum_with_unknown! { + /// The possible opcodes of a DHCP packet. + pub enum OpCode(u8) { + Request = 1, + Reply = 2, + } +} + +enum_with_unknown! { + /// The possible message types of a DHCP packet. + pub enum MessageType(u8) { + Discover = 1, + Offer = 2, + Request = 3, + Decline = 4, + Ack = 5, + Nak = 6, + Release = 7, + Inform = 8, + } +} + +bitflags! { + pub struct Flags: u16 { + const BROADCAST = 0b1000_0000_0000_0000; + } +} + +impl MessageType { + const fn opcode(&self) -> OpCode { + match *self { + MessageType::Discover + | MessageType::Inform + | MessageType::Request + | MessageType::Decline + | MessageType::Release => OpCode::Request, + MessageType::Offer | MessageType::Ack | MessageType::Nak => OpCode::Reply, + MessageType::Unknown(_) => OpCode::Unknown(0), + } + } +} + +/// A buffer for DHCP options. +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct DhcpOptionWriter<'a> { + /// The underlying buffer, directly from the DHCP packet representation. + buffer: &'a mut [u8], +} + +impl<'a> DhcpOptionWriter<'a> { + pub fn new(buffer: &'a mut [u8]) -> Self { + Self { buffer } + } + + /// Emit a [`DhcpOption`] into a [`DhcpOptionWriter`]. + pub fn emit(&mut self, option: DhcpOption<'_>) -> Result<()> { + if option.data.len() > u8::MAX as _ { + return Err(Error); + } + + let total_len = 2 + option.data.len(); + if self.buffer.len() < total_len { + return Err(Error); + } + + let (buf, rest) = core::mem::take(&mut self.buffer).split_at_mut(total_len); + self.buffer = rest; + + buf[0] = option.kind; + buf[1] = option.data.len() as _; + buf[2..].copy_from_slice(option.data); + + Ok(()) + } + + pub fn end(&mut self) -> Result<()> { + if self.buffer.is_empty() { + return Err(Error); + } + + self.buffer[0] = field::OPT_END; + self.buffer = &mut []; + Ok(()) + } +} + +/// A representation of a single DHCP option. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct DhcpOption<'a> { + pub kind: u8, + pub data: &'a [u8], +} + +/// A read/write wrapper around a Dynamic Host Configuration Protocol packet buffer. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +pub(crate) mod field { + #![allow(non_snake_case)] + #![allow(unused)] + + use crate::wire::field::*; + + pub const OP: usize = 0; + pub const HTYPE: usize = 1; + pub const HLEN: usize = 2; + pub const HOPS: usize = 3; + pub const XID: Field = 4..8; + pub const SECS: Field = 8..10; + pub const FLAGS: Field = 10..12; + pub const CIADDR: Field = 12..16; + pub const YIADDR: Field = 16..20; + pub const SIADDR: Field = 20..24; + pub const GIADDR: Field = 24..28; + pub const CHADDR: Field = 28..34; + pub const SNAME: Field = 34..108; + pub const FILE: Field = 108..236; + pub const MAGIC_NUMBER: Field = 236..240; + pub const OPTIONS: Rest = 240..; + + // Vendor Extensions + pub const OPT_END: u8 = 255; + pub const OPT_PAD: u8 = 0; + pub const OPT_SUBNET_MASK: u8 = 1; + pub const OPT_TIME_OFFSET: u8 = 2; + pub const OPT_ROUTER: u8 = 3; + pub const OPT_TIME_SERVER: u8 = 4; + pub const OPT_NAME_SERVER: u8 = 5; + pub const OPT_DOMAIN_NAME_SERVER: u8 = 6; + pub const OPT_LOG_SERVER: u8 = 7; + pub const OPT_COOKIE_SERVER: u8 = 8; + pub const OPT_LPR_SERVER: u8 = 9; + pub const OPT_IMPRESS_SERVER: u8 = 10; + pub const OPT_RESOURCE_LOCATION_SERVER: u8 = 11; + pub const OPT_HOST_NAME: u8 = 12; + pub const OPT_BOOT_FILE_SIZE: u8 = 13; + pub const OPT_MERIT_DUMP: u8 = 14; + pub const OPT_DOMAIN_NAME: u8 = 15; + pub const OPT_SWAP_SERVER: u8 = 16; + pub const OPT_ROOT_PATH: u8 = 17; + pub const OPT_EXTENSIONS_PATH: u8 = 18; + + // IP Layer Parameters per Host + pub const OPT_IP_FORWARDING: u8 = 19; + pub const OPT_NON_LOCAL_SOURCE_ROUTING: u8 = 20; + pub const OPT_POLICY_FILTER: u8 = 21; + pub const OPT_MAX_DATAGRAM_REASSEMBLY_SIZE: u8 = 22; + pub const OPT_DEFAULT_TTL: u8 = 23; + pub const OPT_PATH_MTU_AGING_TIMEOUT: u8 = 24; + pub const OPT_PATH_MTU_PLATEAU_TABLE: u8 = 25; + + // IP Layer Parameters per Interface + pub const OPT_INTERFACE_MTU: u8 = 26; + pub const OPT_ALL_SUBNETS_ARE_LOCAL: u8 = 27; + pub const OPT_BROADCAST_ADDRESS: u8 = 28; + pub const OPT_PERFORM_MASK_DISCOVERY: u8 = 29; + pub const OPT_MASK_SUPPLIER: u8 = 30; + pub const OPT_PERFORM_ROUTER_DISCOVERY: u8 = 31; + pub const OPT_ROUTER_SOLICITATION_ADDRESS: u8 = 32; + pub const OPT_STATIC_ROUTE: u8 = 33; + + // Link Layer Parameters per Interface + pub const OPT_TRAILER_ENCAPSULATION: u8 = 34; + pub const OPT_ARP_CACHE_TIMEOUT: u8 = 35; + pub const OPT_ETHERNET_ENCAPSULATION: u8 = 36; + + // TCP Parameters + pub const OPT_TCP_DEFAULT_TTL: u8 = 37; + pub const OPT_TCP_KEEPALIVE_INTERVAL: u8 = 38; + pub const OPT_TCP_KEEPALIVE_GARBAGE: u8 = 39; + + // Application and Service Parameters + pub const OPT_NIS_DOMAIN: u8 = 40; + pub const OPT_NIS_SERVERS: u8 = 41; + pub const OPT_NTP_SERVERS: u8 = 42; + pub const OPT_VENDOR_SPECIFIC_INFO: u8 = 43; + pub const OPT_NETBIOS_NAME_SERVER: u8 = 44; + pub const OPT_NETBIOS_DISTRIBUTION_SERVER: u8 = 45; + pub const OPT_NETBIOS_NODE_TYPE: u8 = 46; + pub const OPT_NETBIOS_SCOPE: u8 = 47; + pub const OPT_X_WINDOW_FONT_SERVER: u8 = 48; + pub const OPT_X_WINDOW_DISPLAY_MANAGER: u8 = 49; + pub const OPT_NIS_PLUS_DOMAIN: u8 = 64; + pub const OPT_NIS_PLUS_SERVERS: u8 = 65; + pub const OPT_MOBILE_IP_HOME_AGENT: u8 = 68; + pub const OPT_SMTP_SERVER: u8 = 69; + pub const OPT_POP3_SERVER: u8 = 70; + pub const OPT_NNTP_SERVER: u8 = 71; + pub const OPT_WWW_SERVER: u8 = 72; + pub const OPT_FINGER_SERVER: u8 = 73; + pub const OPT_IRC_SERVER: u8 = 74; + pub const OPT_STREETTALK_SERVER: u8 = 75; + pub const OPT_STDA_SERVER: u8 = 76; + + // DHCP Extensions + pub const OPT_REQUESTED_IP: u8 = 50; + pub const OPT_IP_LEASE_TIME: u8 = 51; + pub const OPT_OPTION_OVERLOAD: u8 = 52; + pub const OPT_TFTP_SERVER_NAME: u8 = 66; + pub const OPT_BOOTFILE_NAME: u8 = 67; + pub const OPT_DHCP_MESSAGE_TYPE: u8 = 53; + pub const OPT_SERVER_IDENTIFIER: u8 = 54; + pub const OPT_PARAMETER_REQUEST_LIST: u8 = 55; + pub const OPT_MESSAGE: u8 = 56; + pub const OPT_MAX_DHCP_MESSAGE_SIZE: u8 = 57; + pub const OPT_RENEWAL_TIME_VALUE: u8 = 58; + pub const OPT_REBINDING_TIME_VALUE: u8 = 59; + pub const OPT_VENDOR_CLASS_ID: u8 = 60; + pub const OPT_CLIENT_ID: u8 = 61; +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with DHCP packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// [set_header_len]: #method.set_header_len + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::MAGIC_NUMBER.end { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Returns the operation code of this packet. + pub fn opcode(&self) -> OpCode { + let data = self.buffer.as_ref(); + OpCode::from(data[field::OP]) + } + + /// Returns the hardware protocol type (e.g. ethernet). + pub fn hardware_type(&self) -> Hardware { + let data = self.buffer.as_ref(); + Hardware::from(u16::from(data[field::HTYPE])) + } + + /// Returns the length of a hardware address in bytes (e.g. 6 for ethernet). + pub fn hardware_len(&self) -> u8 { + self.buffer.as_ref()[field::HLEN] + } + + /// Returns the transaction ID. + /// + /// The transaction ID (called `xid` in the specification) is a random number used to + /// associate messages and responses between client and server. The number is chosen by + /// the client. + pub fn transaction_id(&self) -> u32 { + let field = &self.buffer.as_ref()[field::XID]; + NetworkEndian::read_u32(field) + } + + /// Returns the hardware address of the client (called `chaddr` in the specification). + /// + /// Only ethernet is supported by `smoltcp`, so this functions returns + /// an `EthernetAddress`. + pub fn client_hardware_address(&self) -> EthernetAddress { + let field = &self.buffer.as_ref()[field::CHADDR]; + EthernetAddress::from_bytes(field) + } + + /// Returns the value of the `hops` field. + /// + /// The `hops` field is set to zero by clients and optionally used by relay agents. + pub fn hops(&self) -> u8 { + self.buffer.as_ref()[field::HOPS] + } + + /// Returns the value of the `secs` field. + /// + /// The secs field is filled by clients and describes the number of seconds elapsed + /// since client began process. + pub fn secs(&self) -> u16 { + let field = &self.buffer.as_ref()[field::SECS]; + NetworkEndian::read_u16(field) + } + + /// Returns the value of the `magic cookie` field in the DHCP options. + /// + /// This field should be always be `0x63825363`. + pub fn magic_number(&self) -> u32 { + let field = &self.buffer.as_ref()[field::MAGIC_NUMBER]; + NetworkEndian::read_u32(field) + } + + /// Returns the Ipv4 address of the client, zero if not set. + /// + /// This corresponds to the `ciaddr` field in the DHCP specification. According to it, + /// this field is “only filled in if client is in `BOUND`, `RENEW` or `REBINDING` state + /// and can respond to ARP requests”. + pub fn client_ip(&self) -> Ipv4Address { + let field = &self.buffer.as_ref()[field::CIADDR]; + Ipv4Address::from_bytes(field) + } + + /// Returns the value of the `yiaddr` field, zero if not set. + pub fn your_ip(&self) -> Ipv4Address { + let field = &self.buffer.as_ref()[field::YIADDR]; + Ipv4Address::from_bytes(field) + } + + /// Returns the value of the `siaddr` field, zero if not set. + pub fn server_ip(&self) -> Ipv4Address { + let field = &self.buffer.as_ref()[field::SIADDR]; + Ipv4Address::from_bytes(field) + } + + /// Returns the value of the `giaddr` field, zero if not set. + pub fn relay_agent_ip(&self) -> Ipv4Address { + let field = &self.buffer.as_ref()[field::GIADDR]; + Ipv4Address::from_bytes(field) + } + + pub fn flags(&self) -> Flags { + let field = &self.buffer.as_ref()[field::FLAGS]; + Flags::from_bits_truncate(NetworkEndian::read_u16(field)) + } + + /// Return an iterator over the options. + #[inline] + pub fn options(&self) -> impl Iterator<Item = DhcpOption<'_>> + '_ { + let mut buf = &self.buffer.as_ref()[field::OPTIONS]; + iter::from_fn(move || { + loop { + match buf.first().copied() { + // No more options, return. + None => return None, + Some(field::OPT_END) => return None, + + // Skip padding. + Some(field::OPT_PAD) => buf = &buf[1..], + Some(kind) => { + if buf.len() < 2 { + return None; + } + + let len = buf[1] as usize; + + if buf.len() < 2 + len { + return None; + } + + let opt = DhcpOption { + kind, + data: &buf[2..2 + len], + }; + + buf = &buf[2 + len..]; + return Some(opt); + } + } + } + }) + } + + pub fn get_sname(&self) -> Result<&str> { + let data = &self.buffer.as_ref()[field::SNAME]; + let len = data.iter().position(|&x| x == 0).ok_or(Error)?; + if len == 0 { + return Err(Error); + } + + let data = core::str::from_utf8(&data[..len]).map_err(|_| Error)?; + Ok(data) + } + + pub fn get_boot_file(&self) -> Result<&str> { + let data = &self.buffer.as_ref()[field::FILE]; + let len = data.iter().position(|&x| x == 0).ok_or(Error)?; + if len == 0 { + return Err(Error); + } + let data = core::str::from_utf8(&data[..len]).map_err(|_| Error)?; + Ok(data) + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Sets the optional `sname` (“server name”) and `file` (“boot file name”) fields to zero. + /// + /// The fields are not commonly used, so we set their value always to zero. **This method + /// must be called when creating a packet, otherwise the emitted values for these fields + /// are undefined!** + pub fn set_sname_and_boot_file_to_zero(&mut self) { + let data = self.buffer.as_mut(); + for byte in &mut data[field::SNAME] { + *byte = 0; + } + for byte in &mut data[field::FILE] { + *byte = 0; + } + } + + /// Sets the `OpCode` for the packet. + pub fn set_opcode(&mut self, value: OpCode) { + let data = self.buffer.as_mut(); + data[field::OP] = value.into(); + } + + /// Sets the hardware address type (only ethernet is supported). + pub fn set_hardware_type(&mut self, value: Hardware) { + let data = self.buffer.as_mut(); + let number: u16 = value.into(); + assert!(number <= u16::from(u8::max_value())); // TODO: Replace with TryFrom when it's stable + data[field::HTYPE] = number as u8; + } + + /// Sets the hardware address length. + /// + /// Only ethernet is supported, so this field should be set to the value `6`. + pub fn set_hardware_len(&mut self, value: u8) { + self.buffer.as_mut()[field::HLEN] = value; + } + + /// Sets the transaction ID. + /// + /// The transaction ID (called `xid` in the specification) is a random number used to + /// associate messages and responses between client and server. The number is chosen by + /// the client. + pub fn set_transaction_id(&mut self, value: u32) { + let field = &mut self.buffer.as_mut()[field::XID]; + NetworkEndian::write_u32(field, value) + } + + /// Sets the ethernet address of the client. + /// + /// Sets the `chaddr` field. + pub fn set_client_hardware_address(&mut self, value: EthernetAddress) { + let field = &mut self.buffer.as_mut()[field::CHADDR]; + field.copy_from_slice(value.as_bytes()); + } + + /// Sets the hops field. + /// + /// The `hops` field is set to zero by clients and optionally used by relay agents. + pub fn set_hops(&mut self, value: u8) { + self.buffer.as_mut()[field::HOPS] = value; + } + + /// Sets the `secs` field. + /// + /// The secs field is filled by clients and describes the number of seconds elapsed + /// since client began process. + pub fn set_secs(&mut self, value: u16) { + let field = &mut self.buffer.as_mut()[field::SECS]; + NetworkEndian::write_u16(field, value); + } + + /// Sets the value of the `magic cookie` field in the DHCP options. + /// + /// This field should be always be `0x63825363`. + pub fn set_magic_number(&mut self, value: u32) { + let field = &mut self.buffer.as_mut()[field::MAGIC_NUMBER]; + NetworkEndian::write_u32(field, value); + } + + /// Sets the Ipv4 address of the client. + /// + /// This corresponds to the `ciaddr` field in the DHCP specification. According to it, + /// this field is “only filled in if client is in `BOUND`, `RENEW` or `REBINDING` state + /// and can respond to ARP requests”. + pub fn set_client_ip(&mut self, value: Ipv4Address) { + let field = &mut self.buffer.as_mut()[field::CIADDR]; + field.copy_from_slice(value.as_bytes()); + } + + /// Sets the value of the `yiaddr` field. + pub fn set_your_ip(&mut self, value: Ipv4Address) { + let field = &mut self.buffer.as_mut()[field::YIADDR]; + field.copy_from_slice(value.as_bytes()); + } + + /// Sets the value of the `siaddr` field. + pub fn set_server_ip(&mut self, value: Ipv4Address) { + let field = &mut self.buffer.as_mut()[field::SIADDR]; + field.copy_from_slice(value.as_bytes()); + } + + /// Sets the value of the `giaddr` field. + pub fn set_relay_agent_ip(&mut self, value: Ipv4Address) { + let field = &mut self.buffer.as_mut()[field::GIADDR]; + field.copy_from_slice(value.as_bytes()); + } + + /// Sets the flags to the specified value. + pub fn set_flags(&mut self, val: Flags) { + let field = &mut self.buffer.as_mut()[field::FLAGS]; + NetworkEndian::write_u16(field, val.bits()); + } +} + +impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'a mut T> { + /// Return a pointer to the options. + #[inline] + pub fn options_mut(&mut self) -> DhcpOptionWriter<'_> { + DhcpOptionWriter::new(&mut self.buffer.as_mut()[field::OPTIONS]) + } +} + +/// A high-level representation of a Dynamic Host Configuration Protocol packet. +/// +/// DHCP messages have the following layout (see [RFC 2131](https://tools.ietf.org/html/rfc2131) +/// for details): +/// +/// ```no_rust +/// 0 1 2 3 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | message_type | htype (N/A) | hlen (N/A) | hops | +/// +---------------+---------------+---------------+---------------+ +/// | transaction_id | +/// +-------------------------------+-------------------------------+ +/// | secs | flags | +/// +-------------------------------+-------------------------------+ +/// | client_ip | +/// +---------------------------------------------------------------+ +/// | your_ip | +/// +---------------------------------------------------------------+ +/// | server_ip | +/// +---------------------------------------------------------------+ +/// | relay_agent_ip | +/// +---------------------------------------------------------------+ +/// | | +/// | client_hardware_address | +/// | | +/// | | +/// +---------------------------------------------------------------+ +/// | | +/// | sname (N/A) | +/// +---------------------------------------------------------------+ +/// | | +/// | file (N/A) | +/// +---------------------------------------------------------------+ +/// | | +/// | options | +/// +---------------------------------------------------------------+ +/// ``` +/// +/// It is assumed that the access layer is Ethernet, so `htype` (the field representing the +/// hardware address type) is always set to `1`, and `hlen` (which represents the hardware address +/// length) is set to `6`. +/// +/// The `options` field has a variable length. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr<'a> { + /// This field is also known as `op` in the RFC. It indicates the type of DHCP message this + /// packet represents. + pub message_type: MessageType, + /// This field is also known as `xid` in the RFC. It is a random number chosen by the client, + /// used by the client and server to associate messages and responses between a client and a + /// server. + pub transaction_id: u32, + /// seconds elapsed since client began address acquisition or renewal + /// process the DHCPREQUEST message MUST use the same value in the DHCP + /// message header's 'secs' field and be sent to the same IP broadcast + /// address as the original DHCPDISCOVER message. + pub secs: u16, + /// This field is also known as `chaddr` in the RFC and for networks where the access layer is + /// ethernet, it is the client MAC address. + pub client_hardware_address: EthernetAddress, + /// This field is also known as `ciaddr` in the RFC. It is only filled in if client is in + /// BOUND, RENEW or REBINDING state and can respond to ARP requests. + pub client_ip: Ipv4Address, + /// This field is also known as `yiaddr` in the RFC. + pub your_ip: Ipv4Address, + /// This field is also known as `siaddr` in the RFC. It may be set by the server in DHCPOFFER + /// and DHCPACK messages, and represent the address of the next server to use in bootstrap. + pub server_ip: Ipv4Address, + /// Default gateway + pub router: Option<Ipv4Address>, + /// This field comes from a corresponding DhcpOption. + pub subnet_mask: Option<Ipv4Address>, + /// This field is also known as `giaddr` in the RFC. In order to allow DHCP clients on subnets + /// not directly served by DHCP servers to communicate with DHCP servers, DHCP relay agents can + /// be installed on these subnets. The DHCP client broadcasts on the local link; the relay + /// agent receives the broadcast and transmits it to one or more DHCP servers using unicast. + /// The relay agent stores its own IP address in the `relay_agent_ip` field of the DHCP packet. + /// The DHCP server uses the `relay_agent_ip` to determine the subnet on which the relay agent + /// received the broadcast, and allocates an IP address on that subnet. When the DHCP server + /// replies to the client, it sends the reply to the `relay_agent_ip` address, again using + /// unicast. The relay agent then retransmits the response on the local network + pub relay_agent_ip: Ipv4Address, + /// Broadcast flags. It can be set in DHCPDISCOVER, DHCPINFORM and DHCPREQUEST message if the + /// client requires the response to be broadcasted. + pub broadcast: bool, + /// The "requested IP address" option. It can be used by clients in DHCPREQUEST or DHCPDISCOVER + /// messages, or by servers in DHCPDECLINE messages. + pub requested_ip: Option<Ipv4Address>, + /// The "client identifier" option. + /// + /// The 'client identifier' is an opaque key, not to be interpreted by the server; for example, + /// the 'client identifier' may contain a hardware address, identical to the contents of the + /// 'chaddr' field, or it may contain another type of identifier, such as a DNS name. The + /// 'client identifier' chosen by a DHCP client MUST be unique to that client within the subnet + /// to which the client is attached. If the client uses a 'client identifier' in one message, + /// it MUST use that same identifier in all subsequent messages, to ensure that all servers + /// correctly identify the client. + pub client_identifier: Option<EthernetAddress>, + /// The "server identifier" option. It is used both to identify a DHCP server + /// in a DHCP message and as a destination address from clients to servers. + pub server_identifier: Option<Ipv4Address>, + /// The parameter request list informs the server about which configuration parameters + /// the client is interested in. + pub parameter_request_list: Option<&'a [u8]>, + /// DNS servers + pub dns_servers: Option<Vec<Ipv4Address, MAX_DNS_SERVER_COUNT>>, + /// The maximum size dhcp packet the interface can receive + pub max_size: Option<u16>, + /// The DHCP IP lease duration, specified in seconds. + pub lease_duration: Option<u32>, + /// The DHCP IP renew duration (T1 interval), in seconds, if specified in the packet. + pub renew_duration: Option<u32>, + /// The DHCP IP rebind duration (T2 interval), in seconds, if specified in the packet. + pub rebind_duration: Option<u32>, + /// When returned from [`Repr::parse`], this field will be `None`. + /// However, when calling [`Repr::emit`], this field should contain only + /// additional DHCP options not known to smoltcp. + pub additional_options: &'a [DhcpOption<'a>], +} + +impl<'a> Repr<'a> { + /// Return the length of a packet that will be emitted from this high-level representation. + pub fn buffer_len(&self) -> usize { + let mut len = field::OPTIONS.start; + // message type and end-of-options options + len += 3 + 1; + if self.requested_ip.is_some() { + len += 6; + } + if self.client_identifier.is_some() { + len += 9; + } + if self.server_identifier.is_some() { + len += 6; + } + if self.max_size.is_some() { + len += 4; + } + if self.router.is_some() { + len += 6; + } + if self.subnet_mask.is_some() { + len += 6; + } + if self.lease_duration.is_some() { + len += 6; + } + if let Some(dns_servers) = &self.dns_servers { + len += 2; + len += dns_servers.iter().count() * core::mem::size_of::<u32>(); + } + if let Some(list) = self.parameter_request_list { + len += list.len() + 2; + } + for opt in self.additional_options { + len += 2 + opt.data.len() + } + + len + } + + /// Parse a DHCP packet and return a high-level representation. + pub fn parse<T>(packet: &'a Packet<&'a T>) -> Result<Self> + where + T: AsRef<[u8]> + ?Sized, + { + let transaction_id = packet.transaction_id(); + let client_hardware_address = packet.client_hardware_address(); + let client_ip = packet.client_ip(); + let your_ip = packet.your_ip(); + let server_ip = packet.server_ip(); + let relay_agent_ip = packet.relay_agent_ip(); + let secs = packet.secs(); + + // only ethernet is supported right now + match packet.hardware_type() { + Hardware::Ethernet => { + if packet.hardware_len() != 6 { + return Err(Error); + } + } + Hardware::Unknown(_) => return Err(Error), // unimplemented + } + + if packet.magic_number() != DHCP_MAGIC_NUMBER { + return Err(Error); + } + + let mut message_type = Err(Error); + let mut requested_ip = None; + let mut client_identifier = None; + let mut server_identifier = None; + let mut router = None; + let mut subnet_mask = None; + let mut parameter_request_list = None; + let mut dns_servers = None; + let mut max_size = None; + let mut lease_duration = None; + let mut renew_duration = None; + let mut rebind_duration = None; + + for option in packet.options() { + let data = option.data; + match (option.kind, data.len()) { + (field::OPT_DHCP_MESSAGE_TYPE, 1) => { + let value = MessageType::from(data[0]); + if value.opcode() == packet.opcode() { + message_type = Ok(value); + } + } + (field::OPT_REQUESTED_IP, 4) => { + requested_ip = Some(Ipv4Address::from_bytes(data)); + } + (field::OPT_CLIENT_ID, 7) => { + let hardware_type = Hardware::from(u16::from(data[0])); + if hardware_type != Hardware::Ethernet { + return Err(Error); + } + client_identifier = Some(EthernetAddress::from_bytes(&data[1..])); + } + (field::OPT_SERVER_IDENTIFIER, 4) => { + server_identifier = Some(Ipv4Address::from_bytes(data)); + } + (field::OPT_ROUTER, 4) => { + router = Some(Ipv4Address::from_bytes(data)); + } + (field::OPT_SUBNET_MASK, 4) => { + subnet_mask = Some(Ipv4Address::from_bytes(data)); + } + (field::OPT_MAX_DHCP_MESSAGE_SIZE, 2) => { + max_size = Some(u16::from_be_bytes([data[0], data[1]])); + } + (field::OPT_RENEWAL_TIME_VALUE, 4) => { + renew_duration = Some(u32::from_be_bytes([data[0], data[1], data[2], data[3]])) + } + (field::OPT_REBINDING_TIME_VALUE, 4) => { + rebind_duration = Some(u32::from_be_bytes([data[0], data[1], data[2], data[3]])) + } + (field::OPT_IP_LEASE_TIME, 4) => { + lease_duration = Some(u32::from_be_bytes([data[0], data[1], data[2], data[3]])) + } + (field::OPT_PARAMETER_REQUEST_LIST, _) => { + parameter_request_list = Some(data); + } + (field::OPT_DOMAIN_NAME_SERVER, _) => { + let mut servers = Vec::new(); + const IP_ADDR_BYTE_LEN: usize = 4; + let mut addrs = data.chunks_exact(IP_ADDR_BYTE_LEN); + for chunk in &mut addrs { + // We ignore push failures because that will only happen + // if we attempt to push more than 4 addresses, and the only + // solution to that is to support more addresses. + servers.push(Ipv4Address::from_bytes(chunk)).ok(); + } + dns_servers = Some(servers); + + if !addrs.remainder().is_empty() { + net_trace!("DHCP domain name servers contained invalid address"); + } + } + _ => {} + } + } + + let broadcast = packet.flags().contains(Flags::BROADCAST); + + Ok(Repr { + secs, + transaction_id, + client_hardware_address, + client_ip, + your_ip, + server_ip, + relay_agent_ip, + broadcast, + requested_ip, + server_identifier, + router, + subnet_mask, + client_identifier, + parameter_request_list, + dns_servers, + max_size, + lease_duration, + renew_duration, + rebind_duration, + message_type: message_type?, + additional_options: &[], + }) + } + + /// Emit a high-level representation into a Dynamic Host + /// Configuration Protocol packet. + pub fn emit<T>(&self, packet: &mut Packet<&mut T>) -> Result<()> + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + packet.set_sname_and_boot_file_to_zero(); + packet.set_opcode(self.message_type.opcode()); + packet.set_hardware_type(Hardware::Ethernet); + packet.set_hardware_len(6); + packet.set_transaction_id(self.transaction_id); + packet.set_client_hardware_address(self.client_hardware_address); + packet.set_hops(0); + packet.set_secs(self.secs); + packet.set_magic_number(0x63825363); + packet.set_client_ip(self.client_ip); + packet.set_your_ip(self.your_ip); + packet.set_server_ip(self.server_ip); + packet.set_relay_agent_ip(self.relay_agent_ip); + + let mut flags = Flags::empty(); + if self.broadcast { + flags |= Flags::BROADCAST; + } + packet.set_flags(flags); + + { + let mut options = packet.options_mut(); + + options.emit(DhcpOption { + kind: field::OPT_DHCP_MESSAGE_TYPE, + data: &[self.message_type.into()], + })?; + + if let Some(val) = &self.client_identifier { + let mut data = [0; 7]; + data[0] = u16::from(Hardware::Ethernet) as u8; + data[1..].copy_from_slice(val.as_bytes()); + + options.emit(DhcpOption { + kind: field::OPT_CLIENT_ID, + data: &data, + })?; + } + + if let Some(val) = &self.server_identifier { + options.emit(DhcpOption { + kind: field::OPT_SERVER_IDENTIFIER, + data: val.as_bytes(), + })?; + } + + if let Some(val) = &self.router { + options.emit(DhcpOption { + kind: field::OPT_ROUTER, + data: val.as_bytes(), + })?; + } + if let Some(val) = &self.subnet_mask { + options.emit(DhcpOption { + kind: field::OPT_SUBNET_MASK, + data: val.as_bytes(), + })?; + } + if let Some(val) = &self.requested_ip { + options.emit(DhcpOption { + kind: field::OPT_REQUESTED_IP, + data: val.as_bytes(), + })?; + } + if let Some(val) = &self.max_size { + options.emit(DhcpOption { + kind: field::OPT_MAX_DHCP_MESSAGE_SIZE, + data: &val.to_be_bytes(), + })?; + } + if let Some(val) = &self.lease_duration { + options.emit(DhcpOption { + kind: field::OPT_IP_LEASE_TIME, + data: &val.to_be_bytes(), + })?; + } + if let Some(val) = &self.parameter_request_list { + options.emit(DhcpOption { + kind: field::OPT_PARAMETER_REQUEST_LIST, + data: val, + })?; + } + + if let Some(dns_servers) = &self.dns_servers { + const IP_SIZE: usize = core::mem::size_of::<u32>(); + let mut servers = [0; MAX_DNS_SERVER_COUNT * IP_SIZE]; + + let data_len = dns_servers + .iter() + .enumerate() + .inspect(|(i, ip)| { + servers[(i * IP_SIZE)..((i + 1) * IP_SIZE)].copy_from_slice(ip.as_bytes()); + }) + .count() + * IP_SIZE; + options.emit(DhcpOption { + kind: field::OPT_DOMAIN_NAME_SERVER, + data: &servers[..data_len], + })?; + } + + for option in self.additional_options { + options.emit(*option)?; + } + + options.end()?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::wire::Ipv4Address; + + const MAGIC_COOKIE: u32 = 0x63825363; + + static DISCOVER_BYTES: &[u8] = &[ + 0x01, 0x01, 0x06, 0x00, 0x00, 0x00, 0x3d, 0x1d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0b, + 0x82, 0x01, 0xfc, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x82, 0x53, 0x63, + 0x35, 0x01, 0x01, 0x3d, 0x07, 0x01, 0x00, 0x0b, 0x82, 0x01, 0xfc, 0x42, 0x32, 0x04, 0x00, + 0x00, 0x00, 0x00, 0x39, 0x2, 0x5, 0xdc, 0x37, 0x04, 0x01, 0x03, 0x06, 0x2a, 0xff, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + + static ACK_DNS_SERVER_BYTES: &[u8] = &[ + 0x02, 0x01, 0x06, 0x00, 0xcc, 0x34, 0x75, 0xab, 0x00, 0x00, 0x80, 0x00, 0x0a, 0xff, 0x06, + 0x91, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0xff, 0x06, 0xfe, 0x34, 0x17, + 0xeb, 0xc9, 0xaa, 0x2f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x82, 0x53, 0x63, + 0x35, 0x01, 0x05, 0x36, 0x04, 0xa3, 0x01, 0x4a, 0x16, 0x01, 0x04, 0xff, 0xff, 0xff, 0x00, + 0x2b, 0x05, 0xdc, 0x03, 0x4e, 0x41, 0x50, 0x0f, 0x15, 0x6e, 0x61, 0x74, 0x2e, 0x70, 0x68, + 0x79, 0x73, 0x69, 0x63, 0x73, 0x2e, 0x6f, 0x78, 0x2e, 0x61, 0x63, 0x2e, 0x75, 0x6b, 0x00, + 0x03, 0x04, 0x0a, 0xff, 0x06, 0xfe, 0x06, 0x10, 0xa3, 0x01, 0x4a, 0x06, 0xa3, 0x01, 0x4a, + 0x07, 0xa3, 0x01, 0x4a, 0x03, 0xa3, 0x01, 0x4a, 0x04, 0x2c, 0x10, 0xa3, 0x01, 0x4a, 0x03, + 0xa3, 0x01, 0x4a, 0x04, 0xa3, 0x01, 0x4a, 0x06, 0xa3, 0x01, 0x4a, 0x07, 0x2e, 0x01, 0x08, + 0xff, + ]; + + static ACK_LEASE_TIME_BYTES: &[u8] = &[ + 0x02, 0x01, 0x06, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x0a, 0x22, 0x10, 0x0b, 0x0a, 0x22, 0x10, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x04, 0x91, + 0x62, 0xd2, 0xa8, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x82, 0x53, 0x63, + 0x35, 0x01, 0x05, 0x36, 0x04, 0x0a, 0x22, 0x10, 0x0a, 0x33, 0x04, 0x00, 0x00, 0x02, 0x56, + 0x01, 0x04, 0xff, 0xff, 0xff, 0x00, 0x03, 0x04, 0x0a, 0x22, 0x10, 0x0a, 0xff, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + + const IP_NULL: Ipv4Address = Ipv4Address([0, 0, 0, 0]); + const CLIENT_MAC: EthernetAddress = EthernetAddress([0x0, 0x0b, 0x82, 0x01, 0xfc, 0x42]); + const DHCP_SIZE: u16 = 1500; + + #[test] + fn test_deconstruct_discover() { + let packet = Packet::new_unchecked(DISCOVER_BYTES); + assert_eq!(packet.magic_number(), MAGIC_COOKIE); + assert_eq!(packet.opcode(), OpCode::Request); + assert_eq!(packet.hardware_type(), Hardware::Ethernet); + assert_eq!(packet.hardware_len(), 6); + assert_eq!(packet.hops(), 0); + assert_eq!(packet.transaction_id(), 0x3d1d); + assert_eq!(packet.secs(), 0); + assert_eq!(packet.client_ip(), IP_NULL); + assert_eq!(packet.your_ip(), IP_NULL); + assert_eq!(packet.server_ip(), IP_NULL); + assert_eq!(packet.relay_agent_ip(), IP_NULL); + assert_eq!(packet.client_hardware_address(), CLIENT_MAC); + + let mut options = packet.options(); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_DHCP_MESSAGE_TYPE, + data: &[0x01] + }) + ); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_CLIENT_ID, + data: &[0x01, 0x00, 0x0b, 0x82, 0x01, 0xfc, 0x42], + }) + ); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_REQUESTED_IP, + data: &[0x00, 0x00, 0x00, 0x00], + }) + ); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_MAX_DHCP_MESSAGE_SIZE, + data: &DHCP_SIZE.to_be_bytes(), + }) + ); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_PARAMETER_REQUEST_LIST, + data: &[1, 3, 6, 42] + }) + ); + assert_eq!(options.next(), None); + } + + #[test] + fn test_construct_discover() { + let mut bytes = vec![0xa5; 276]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_magic_number(MAGIC_COOKIE); + packet.set_sname_and_boot_file_to_zero(); + packet.set_opcode(OpCode::Request); + packet.set_hardware_type(Hardware::Ethernet); + packet.set_hardware_len(6); + packet.set_hops(0); + packet.set_transaction_id(0x3d1d); + packet.set_secs(0); + packet.set_flags(Flags::empty()); + packet.set_client_ip(IP_NULL); + packet.set_your_ip(IP_NULL); + packet.set_server_ip(IP_NULL); + packet.set_relay_agent_ip(IP_NULL); + packet.set_client_hardware_address(CLIENT_MAC); + + let mut options = packet.options_mut(); + + options + .emit(DhcpOption { + kind: field::OPT_DHCP_MESSAGE_TYPE, + data: &[0x01], + }) + .unwrap(); + options + .emit(DhcpOption { + kind: field::OPT_CLIENT_ID, + data: &[0x01, 0x00, 0x0b, 0x82, 0x01, 0xfc, 0x42], + }) + .unwrap(); + options + .emit(DhcpOption { + kind: field::OPT_REQUESTED_IP, + data: &[0x00, 0x00, 0x00, 0x00], + }) + .unwrap(); + options + .emit(DhcpOption { + kind: field::OPT_MAX_DHCP_MESSAGE_SIZE, + data: &DHCP_SIZE.to_be_bytes(), + }) + .unwrap(); + options + .emit(DhcpOption { + kind: field::OPT_PARAMETER_REQUEST_LIST, + data: &[1, 3, 6, 42], + }) + .unwrap(); + options.end().unwrap(); + + let packet = &mut packet.into_inner()[..]; + for byte in &mut packet[269..276] { + *byte = 0; // padding bytes + } + + assert_eq!(packet, DISCOVER_BYTES); + } + + const fn offer_repr() -> Repr<'static> { + Repr { + message_type: MessageType::Offer, + transaction_id: 0x3d1d, + client_hardware_address: CLIENT_MAC, + client_ip: IP_NULL, + your_ip: IP_NULL, + server_ip: IP_NULL, + router: Some(IP_NULL), + subnet_mask: Some(IP_NULL), + relay_agent_ip: IP_NULL, + secs: 0, + broadcast: false, + requested_ip: None, + client_identifier: Some(CLIENT_MAC), + server_identifier: None, + parameter_request_list: None, + dns_servers: None, + max_size: None, + renew_duration: None, + rebind_duration: None, + lease_duration: Some(0xffff_ffff), // Infinite lease + additional_options: &[], + } + } + + const fn discover_repr() -> Repr<'static> { + Repr { + message_type: MessageType::Discover, + transaction_id: 0x3d1d, + client_hardware_address: CLIENT_MAC, + client_ip: IP_NULL, + your_ip: IP_NULL, + server_ip: IP_NULL, + router: None, + subnet_mask: None, + relay_agent_ip: IP_NULL, + broadcast: false, + secs: 0, + max_size: Some(DHCP_SIZE), + renew_duration: None, + rebind_duration: None, + lease_duration: None, + requested_ip: Some(IP_NULL), + client_identifier: Some(CLIENT_MAC), + server_identifier: None, + parameter_request_list: Some(&[1, 3, 6, 42]), + dns_servers: None, + additional_options: &[], + } + } + + #[test] + fn test_parse_discover() { + let packet = Packet::new_unchecked(DISCOVER_BYTES); + let repr = Repr::parse(&packet).unwrap(); + assert_eq!(repr, discover_repr()); + } + + #[test] + fn test_emit_discover() { + let repr = discover_repr(); + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet).unwrap(); + let packet = &*packet.into_inner(); + let packet_len = packet.len(); + assert_eq!(packet, &DISCOVER_BYTES[..packet_len]); + for byte in &DISCOVER_BYTES[packet_len..] { + assert_eq!(*byte, 0); // padding bytes + } + } + + #[test] + fn test_emit_offer() { + let repr = offer_repr(); + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet).unwrap(); + } + + #[test] + fn test_emit_offer_dns() { + let repr = { + let mut repr = offer_repr(); + repr.dns_servers = Some( + Vec::from_slice(&[ + Ipv4Address([163, 1, 74, 6]), + Ipv4Address([163, 1, 74, 7]), + Ipv4Address([163, 1, 74, 3]), + ]) + .unwrap(), + ); + repr + }; + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet).unwrap(); + + let packet = Packet::new_unchecked(&bytes); + let repr_parsed = Repr::parse(&packet).unwrap(); + + assert_eq!( + repr_parsed.dns_servers, + Some( + Vec::from_slice(&[ + Ipv4Address([163, 1, 74, 6]), + Ipv4Address([163, 1, 74, 7]), + Ipv4Address([163, 1, 74, 3]), + ]) + .unwrap() + ) + ); + } + + #[test] + fn test_emit_dhcp_option() { + static DATA: &[u8] = &[1, 3, 6]; + let dhcp_option = DhcpOption { + kind: field::OPT_PARAMETER_REQUEST_LIST, + data: DATA, + }; + + let mut bytes = vec![0xa5; 5]; + let mut writer = DhcpOptionWriter::new(&mut bytes); + writer.emit(dhcp_option).unwrap(); + + assert_eq!( + &bytes[0..2], + &[field::OPT_PARAMETER_REQUEST_LIST, DATA.len() as u8] + ); + assert_eq!(&bytes[2..], DATA); + } + + #[test] + fn test_parse_ack_dns_servers() { + let packet = Packet::new_unchecked(ACK_DNS_SERVER_BYTES); + let repr = Repr::parse(&packet).unwrap(); + + // The packet described by ACK_BYTES advertises 4 DNS servers + // Here we ensure that we correctly parse the first 3 into our fixed + // length-3 array (see issue #305) + assert_eq!( + repr.dns_servers, + Some( + Vec::from_slice(&[ + Ipv4Address([163, 1, 74, 6]), + Ipv4Address([163, 1, 74, 7]), + Ipv4Address([163, 1, 74, 3]) + ]) + .unwrap() + ) + ); + } + + #[test] + fn test_parse_ack_lease_duration() { + let packet = Packet::new_unchecked(ACK_LEASE_TIME_BYTES); + let repr = Repr::parse(&packet).unwrap(); + + // Verify that the lease time in the ACK is properly parsed. The packet contains a lease + // duration of 598s. + assert_eq!(repr.lease_duration, Some(598)); + } +} diff --git a/src/wire/dns.rs b/src/wire/dns.rs new file mode 100644 index 0000000..2c9c10d --- /dev/null +++ b/src/wire/dns.rs @@ -0,0 +1,793 @@ +#![allow(dead_code)] + +use bitflags::bitflags; +use byteorder::{ByteOrder, NetworkEndian}; +use core::iter; +use core::iter::Iterator; + +use super::{Error, Result}; +#[cfg(feature = "proto-ipv4")] +use crate::wire::Ipv4Address; +#[cfg(feature = "proto-ipv6")] +use crate::wire::Ipv6Address; + +enum_with_unknown! { + /// DNS OpCodes + pub enum Opcode(u8) { + Query = 0x00, + Status = 0x01, + } +} +enum_with_unknown! { + /// DNS OpCodes + pub enum Rcode(u8) { + NoError = 0x00, + FormErr = 0x01, + ServFail = 0x02, + NXDomain = 0x03, + NotImp = 0x04, + Refused = 0x05, + YXDomain = 0x06, + YXRRSet = 0x07, + NXRRSet = 0x08, + NotAuth = 0x09, + NotZone = 0x0a, + } +} + +enum_with_unknown! { + /// DNS record types + pub enum Type(u16) { + A = 0x0001, + Ns = 0x0002, + Cname = 0x0005, + Soa = 0x0006, + Aaaa = 0x001c, + } +} + +bitflags! { + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct Flags: u16 { + const RESPONSE = 0b1000_0000_0000_0000; + const AUTHORITATIVE = 0b0000_0100_0000_0000; + const TRUNCATED = 0b0000_0010_0000_0000; + const RECURSION_DESIRED = 0b0000_0001_0000_0000; + const RECURSION_AVAILABLE = 0b0000_0000_1000_0000; + const AUTHENTIC_DATA = 0b0000_0000_0010_0000; + const CHECK_DISABLED = 0b0000_0000_0001_0000; + } +} + +mod field { + use crate::wire::field::*; + + pub const ID: Field = 0..2; + pub const FLAGS: Field = 2..4; + pub const QDCOUNT: Field = 4..6; + pub const ANCOUNT: Field = 6..8; + pub const NSCOUNT: Field = 8..10; + pub const ARCOUNT: Field = 10..12; + + pub const HEADER_END: usize = 12; +} + +// DNS class IN (Internet) +const CLASS_IN: u16 = 1; + +/// A read/write wrapper around a DNS packet buffer. +#[derive(Debug, PartialEq, Eq)] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with DNS packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is smaller than + /// the header length. + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::HEADER_END { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + pub fn payload(&self) -> &[u8] { + &self.buffer.as_ref()[field::HEADER_END..] + } + + pub fn transaction_id(&self) -> u16 { + let field = &self.buffer.as_ref()[field::ID]; + NetworkEndian::read_u16(field) + } + + pub fn flags(&self) -> Flags { + let field = &self.buffer.as_ref()[field::FLAGS]; + Flags::from_bits_truncate(NetworkEndian::read_u16(field)) + } + + pub fn opcode(&self) -> Opcode { + let field = &self.buffer.as_ref()[field::FLAGS]; + let flags = NetworkEndian::read_u16(field); + Opcode::from((flags >> 11 & 0xF) as u8) + } + + pub fn rcode(&self) -> Rcode { + let field = &self.buffer.as_ref()[field::FLAGS]; + let flags = NetworkEndian::read_u16(field); + Rcode::from((flags & 0xF) as u8) + } + + pub fn question_count(&self) -> u16 { + let field = &self.buffer.as_ref()[field::QDCOUNT]; + NetworkEndian::read_u16(field) + } + + pub fn answer_record_count(&self) -> u16 { + let field = &self.buffer.as_ref()[field::ANCOUNT]; + NetworkEndian::read_u16(field) + } + + pub fn authority_record_count(&self) -> u16 { + let field = &self.buffer.as_ref()[field::NSCOUNT]; + NetworkEndian::read_u16(field) + } + + pub fn additional_record_count(&self) -> u16 { + let field = &self.buffer.as_ref()[field::ARCOUNT]; + NetworkEndian::read_u16(field) + } + + /// Parse part of a name from `bytes`, following pointers if any. + pub fn parse_name<'a>(&'a self, mut bytes: &'a [u8]) -> impl Iterator<Item = Result<&'a [u8]>> { + let mut packet = self.buffer.as_ref(); + + iter::from_fn(move || loop { + if bytes.is_empty() { + return Some(Err(Error)); + } + match bytes[0] { + 0x00 => return None, + x if x & 0xC0 == 0x00 => { + let len = (x & 0x3F) as usize; + if bytes.len() < 1 + len { + return Some(Err(Error)); + } + let label = &bytes[1..1 + len]; + bytes = &bytes[1 + len..]; + return Some(Ok(label)); + } + x if x & 0xC0 == 0xC0 => { + if bytes.len() < 2 { + return Some(Err(Error)); + } + let y = bytes[1]; + let ptr = ((x & 0x3F) as usize) << 8 | (y as usize); + if packet.len() <= ptr { + return Some(Err(Error)); + } + + // RFC1035 says: "In this scheme, an entire domain name or a list of labels at + // the end of a domain name is replaced with a pointer to a ***prior*** occurrence + // of the same name. + // + // Is it unclear if this means the pointer MUST point backwards in the packet or not. Either way, + // pointers that don't point backwards are never seen in the fields, so use this to check that + // there are no pointer loops. + + // Split packet into parts before and after `ptr`. + // parse the part after, keep only the part before in `packet`. This ensure we never + // parse the same byte twice, therefore eliminating pointer loops. + + bytes = &packet[ptr..]; + packet = &packet[..ptr]; + } + _ => return Some(Err(Error)), + } + }) + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + pub fn payload_mut(&mut self) -> &mut [u8] { + let data = self.buffer.as_mut(); + &mut data[field::HEADER_END..] + } + + pub fn set_transaction_id(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::ID]; + NetworkEndian::write_u16(field, val) + } + + pub fn set_flags(&mut self, val: Flags) { + let field = &mut self.buffer.as_mut()[field::FLAGS]; + let mask = Flags::all().bits; + let old = NetworkEndian::read_u16(field); + NetworkEndian::write_u16(field, (old & !mask) | val.bits()); + } + + pub fn set_opcode(&mut self, val: Opcode) { + let field = &mut self.buffer.as_mut()[field::FLAGS]; + let mask = 0x3800; + let val: u8 = val.into(); + let val = (val as u16) << 11; + let old = NetworkEndian::read_u16(field); + NetworkEndian::write_u16(field, (old & !mask) | val); + } + + pub fn set_question_count(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::QDCOUNT]; + NetworkEndian::write_u16(field, val) + } + pub fn set_answer_record_count(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::ANCOUNT]; + NetworkEndian::write_u16(field, val) + } + pub fn set_authority_record_count(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::NSCOUNT]; + NetworkEndian::write_u16(field, val) + } + pub fn set_additional_record_count(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::ARCOUNT]; + NetworkEndian::write_u16(field, val) + } +} + +/// Parse part of a name from `bytes`, not following pointers. +/// Returns the unused part of `bytes`, and the pointer offset if the sequence ends with a pointer. +fn parse_name_part<'a>( + mut bytes: &'a [u8], + mut f: impl FnMut(&'a [u8]), +) -> Result<(&'a [u8], Option<usize>)> { + loop { + let x = *bytes.first().ok_or(Error)?; + bytes = &bytes[1..]; + match x { + 0x00 => return Ok((bytes, None)), + x if x & 0xC0 == 0x00 => { + let len = (x & 0x3F) as usize; + let label = bytes.get(..len).ok_or(Error)?; + bytes = &bytes[len..]; + f(label); + } + x if x & 0xC0 == 0xC0 => { + let y = *bytes.first().ok_or(Error)?; + bytes = &bytes[1..]; + + let ptr = ((x & 0x3F) as usize) << 8 | (y as usize); + return Ok((bytes, Some(ptr))); + } + _ => return Err(Error), + } + } +} + +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Question<'a> { + pub name: &'a [u8], + pub type_: Type, +} + +impl<'a> Question<'a> { + pub fn parse(buffer: &'a [u8]) -> Result<(&'a [u8], Question<'a>)> { + let (rest, _) = parse_name_part(buffer, |_| ())?; + let name = &buffer[..buffer.len() - rest.len()]; + + if rest.len() < 4 { + return Err(Error); + } + let type_ = NetworkEndian::read_u16(&rest[0..2]).into(); + let class = NetworkEndian::read_u16(&rest[2..4]); + let rest = &rest[4..]; + + if class != CLASS_IN { + return Err(Error); + } + + Ok((rest, Question { name, type_ })) + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + self.name.len() + 4 + } + + /// Emit a high-level representation into a DNS packet. + pub fn emit(&self, packet: &mut [u8]) { + packet[..self.name.len()].copy_from_slice(self.name); + let rest = &mut packet[self.name.len()..]; + NetworkEndian::write_u16(&mut rest[0..2], self.type_.into()); + NetworkEndian::write_u16(&mut rest[2..4], CLASS_IN); + } +} + +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Record<'a> { + pub name: &'a [u8], + pub ttl: u32, + pub data: RecordData<'a>, +} + +impl<'a> RecordData<'a> { + pub fn parse(type_: Type, data: &'a [u8]) -> Result<RecordData<'a>> { + match type_ { + #[cfg(feature = "proto-ipv4")] + Type::A => { + if data.len() != 4 { + return Err(Error); + } + Ok(RecordData::A(Ipv4Address::from_bytes(data))) + } + #[cfg(feature = "proto-ipv6")] + Type::Aaaa => { + if data.len() != 16 { + return Err(Error); + } + Ok(RecordData::Aaaa(Ipv6Address::from_bytes(data))) + } + Type::Cname => Ok(RecordData::Cname(data)), + x => Ok(RecordData::Other(x, data)), + } + } +} + +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecordData<'a> { + #[cfg(feature = "proto-ipv4")] + A(Ipv4Address), + #[cfg(feature = "proto-ipv6")] + Aaaa(Ipv6Address), + Cname(&'a [u8]), + Other(Type, &'a [u8]), +} + +impl<'a> Record<'a> { + pub fn parse(buffer: &'a [u8]) -> Result<(&'a [u8], Record<'a>)> { + let (rest, _) = parse_name_part(buffer, |_| ())?; + let name = &buffer[..buffer.len() - rest.len()]; + + if rest.len() < 10 { + return Err(Error); + } + let type_ = NetworkEndian::read_u16(&rest[0..2]).into(); + let class = NetworkEndian::read_u16(&rest[2..4]); + let ttl = NetworkEndian::read_u32(&rest[4..8]); + let len = NetworkEndian::read_u16(&rest[8..10]) as usize; + let rest = &rest[10..]; + + if class != CLASS_IN { + return Err(Error); + } + + let data = rest.get(..len).ok_or(Error)?; + let rest = &rest[len..]; + + Ok(( + rest, + Record { + name, + ttl, + data: RecordData::parse(type_, data)?, + }, + )) + } +} + +/// High-level DNS packet representation. +/// +/// Currently only supports query packets. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr<'a> { + pub transaction_id: u16, + pub opcode: Opcode, + pub flags: Flags, + pub question: Question<'a>, +} + +impl<'a> Repr<'a> { + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + field::HEADER_END + self.question.buffer_len() + } + + /// Emit a high-level representation into a DNS packet. + pub fn emit<T: ?Sized>(&self, packet: &mut Packet<&mut T>) + where + T: AsRef<[u8]> + AsMut<[u8]>, + { + packet.set_transaction_id(self.transaction_id); + packet.set_flags(self.flags); + packet.set_opcode(self.opcode); + packet.set_question_count(1); + packet.set_answer_record_count(0); + packet.set_authority_record_count(0); + packet.set_additional_record_count(0); + self.question.emit(packet.payload_mut()) + } +} + +#[cfg(feature = "proto-ipv4")] // tests assume ipv4 +#[cfg(test)] +mod test { + use super::*; + use std::vec::Vec; + + #[test] + fn test_parse_name() { + let bytes = &[ + 0x78, 0x6c, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, + 0x77, 0x77, 0x08, 0x66, 0x61, 0x63, 0x65, 0x62, 0x6f, 0x6f, 0x6b, 0x03, 0x63, 0x6f, + 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, + 0x05, 0xf3, 0x00, 0x11, 0x09, 0x73, 0x74, 0x61, 0x72, 0x2d, 0x6d, 0x69, 0x6e, 0x69, + 0x04, 0x63, 0x31, 0x30, 0x72, 0xc0, 0x10, 0xc0, 0x2e, 0x00, 0x01, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x05, 0x00, 0x04, 0x1f, 0x0d, 0x53, 0x24, + ]; + let packet = Packet::new_unchecked(bytes); + + let name_vec = |bytes| { + let mut v = Vec::new(); + packet + .parse_name(bytes) + .try_for_each(|label| label.map(|label| v.push(label))) + .map(|_| v) + }; + + //assert_eq!(parse_name_len(bytes, 0x0c), Ok(18)); + assert_eq!( + name_vec(&bytes[0x0c..]), + Ok(vec![&b"www"[..], &b"facebook"[..], &b"com"[..]]) + ); + //assert_eq!(parse_name_len(bytes, 0x22), Ok(2)); + assert_eq!( + name_vec(&bytes[0x22..]), + Ok(vec![&b"www"[..], &b"facebook"[..], &b"com"[..]]) + ); + //assert_eq!(parse_name_len(bytes, 0x2e), Ok(17)); + assert_eq!( + name_vec(&bytes[0x2e..]), + Ok(vec![ + &b"star-mini"[..], + &b"c10r"[..], + &b"facebook"[..], + &b"com"[..] + ]) + ); + //assert_eq!(parse_name_len(bytes, 0x3f), Ok(2)); + assert_eq!( + name_vec(&bytes[0x3f..]), + Ok(vec![ + &b"star-mini"[..], + &b"c10r"[..], + &b"facebook"[..], + &b"com"[..] + ]) + ); + } + + struct Parsed<'a> { + packet: Packet<&'a [u8]>, + questions: Vec<Question<'a>>, + answers: Vec<Record<'a>>, + authorities: Vec<Record<'a>>, + additionals: Vec<Record<'a>>, + } + + impl<'a> Parsed<'a> { + fn parse(bytes: &'a [u8]) -> Result<Self> { + let packet = Packet::new_unchecked(bytes); + let mut questions = Vec::new(); + let mut answers = Vec::new(); + let mut authorities = Vec::new(); + let mut additionals = Vec::new(); + + let mut payload = &bytes[12..]; + + for _ in 0..packet.question_count() { + let (p, r) = Question::parse(payload)?; + questions.push(r); + payload = p; + } + for _ in 0..packet.answer_record_count() { + let (p, r) = Record::parse(payload)?; + answers.push(r); + payload = p; + } + for _ in 0..packet.authority_record_count() { + let (p, r) = Record::parse(payload)?; + authorities.push(r); + payload = p; + } + for _ in 0..packet.additional_record_count() { + let (p, r) = Record::parse(payload)?; + additionals.push(r); + payload = p; + } + + // Check that there are no bytes left + assert_eq!(payload.len(), 0); + + Ok(Parsed { + packet, + questions, + answers, + authorities, + additionals, + }) + } + } + + #[test] + fn test_parse_request() { + let p = Parsed::parse(&[ + 0x51, 0x84, 0x01, 0x20, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x5184); + assert_eq!( + p.packet.flags(), + Flags::RECURSION_DESIRED | Flags::AUTHENTIC_DATA + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 0); + assert_eq!(p.packet.authority_record_count(), 0); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!(p.questions.len(), 1); + assert_eq!( + p.questions[0].name, + &[0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00] + ); + assert_eq!(p.questions[0].type_, Type::A); + + assert_eq!(p.answers.len(), 0); + assert_eq!(p.authorities.len(), 0); + assert_eq!(p.additionals.len(), 0); + } + + #[test] + fn test_parse_response() { + let p = Parsed::parse(&[ + 0x51, 0x84, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x06, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, + 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xca, 0x00, 0x04, 0xac, 0xd9, + 0xa8, 0xae, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x5184); + assert_eq!( + p.packet.flags(), + Flags::RESPONSE | Flags::RECURSION_DESIRED | Flags::RECURSION_AVAILABLE + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.rcode(), Rcode::NoError); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 1); + assert_eq!(p.packet.authority_record_count(), 0); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!( + p.questions[0].name, + &[0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00] + ); + assert_eq!(p.questions[0].type_, Type::A); + + assert_eq!(p.answers[0].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[0].ttl, 202); + assert_eq!( + p.answers[0].data, + RecordData::A(Ipv4Address::new(0xac, 0xd9, 0xa8, 0xae)) + ); + } + + #[test] + fn test_parse_response_multiple_a() { + let p = Parsed::parse(&[ + 0x4b, 0x9e, 0x81, 0x80, 0x00, 0x01, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x09, 0x72, + 0x75, 0x73, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, 0x03, 0x6f, 0x72, 0x67, 0x00, 0x00, + 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, + 0x04, 0x0d, 0xe0, 0x77, 0x35, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x04, 0x0d, 0xe0, 0x77, 0x28, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x09, 0x00, 0x04, 0x0d, 0xe0, 0x77, 0x43, 0xc0, 0x0c, 0x00, 0x01, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x04, 0x0d, 0xe0, 0x77, 0x62, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x4b9e); + assert_eq!( + p.packet.flags(), + Flags::RESPONSE | Flags::RECURSION_DESIRED | Flags::RECURSION_AVAILABLE + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.rcode(), Rcode::NoError); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 4); + assert_eq!(p.packet.authority_record_count(), 0); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!( + p.questions[0].name, + &[ + 0x09, 0x72, 0x75, 0x73, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, 0x03, 0x6f, 0x72, 0x67, + 0x00 + ] + ); + assert_eq!(p.questions[0].type_, Type::A); + + assert_eq!(p.answers[0].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[0].ttl, 9); + assert_eq!( + p.answers[0].data, + RecordData::A(Ipv4Address::new(0x0d, 0xe0, 0x77, 0x35)) + ); + + assert_eq!(p.answers[1].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[1].ttl, 9); + assert_eq!( + p.answers[1].data, + RecordData::A(Ipv4Address::new(0x0d, 0xe0, 0x77, 0x28)) + ); + + assert_eq!(p.answers[2].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[2].ttl, 9); + assert_eq!( + p.answers[2].data, + RecordData::A(Ipv4Address::new(0x0d, 0xe0, 0x77, 0x43)) + ); + + assert_eq!(p.answers[3].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[3].ttl, 9); + assert_eq!( + p.answers[3].data, + RecordData::A(Ipv4Address::new(0x0d, 0xe0, 0x77, 0x62)) + ); + } + + #[test] + fn test_parse_response_cname() { + let p = Parsed::parse(&[ + 0x78, 0x6c, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, + 0x77, 0x77, 0x08, 0x66, 0x61, 0x63, 0x65, 0x62, 0x6f, 0x6f, 0x6b, 0x03, 0x63, 0x6f, + 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, + 0x05, 0xf3, 0x00, 0x11, 0x09, 0x73, 0x74, 0x61, 0x72, 0x2d, 0x6d, 0x69, 0x6e, 0x69, + 0x04, 0x63, 0x31, 0x30, 0x72, 0xc0, 0x10, 0xc0, 0x2e, 0x00, 0x01, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x05, 0x00, 0x04, 0x1f, 0x0d, 0x53, 0x24, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x786c); + assert_eq!( + p.packet.flags(), + Flags::RESPONSE | Flags::RECURSION_DESIRED | Flags::RECURSION_AVAILABLE + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.rcode(), Rcode::NoError); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 2); + assert_eq!(p.packet.authority_record_count(), 0); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!( + p.questions[0].name, + &[ + 0x03, 0x77, 0x77, 0x77, 0x08, 0x66, 0x61, 0x63, 0x65, 0x62, 0x6f, 0x6f, 0x6b, 0x03, + 0x63, 0x6f, 0x6d, 0x00 + ] + ); + assert_eq!(p.questions[0].type_, Type::A); + + // cname + assert_eq!(p.answers[0].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[0].ttl, 1523); + assert_eq!( + p.answers[0].data, + RecordData::Cname(&[ + 0x09, 0x73, 0x74, 0x61, 0x72, 0x2d, 0x6d, 0x69, 0x6e, 0x69, 0x04, 0x63, 0x31, 0x30, + 0x72, 0xc0, 0x10 + ]) + ); + // a + assert_eq!(p.answers[1].name, &[0xc0, 0x2e]); + assert_eq!(p.answers[1].ttl, 5); + assert_eq!( + p.answers[1].data, + RecordData::A(Ipv4Address::new(0x1f, 0x0d, 0x53, 0x24)) + ); + } + + #[test] + fn test_parse_response_nxdomain() { + let p = Parsed::parse(&[ + 0x63, 0xc4, 0x81, 0x83, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x13, 0x61, + 0x68, 0x61, 0x73, 0x64, 0x67, 0x68, 0x6c, 0x61, 0x6b, 0x73, 0x6a, 0x68, 0x62, 0x61, + 0x61, 0x73, 0x6c, 0x64, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, + 0x20, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x03, 0x83, 0x00, 0x3d, 0x01, 0x61, 0x0c, + 0x67, 0x74, 0x6c, 0x64, 0x2d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x03, 0x6e, + 0x65, 0x74, 0x00, 0x05, 0x6e, 0x73, 0x74, 0x6c, 0x64, 0x0c, 0x76, 0x65, 0x72, 0x69, + 0x73, 0x69, 0x67, 0x6e, 0x2d, 0x67, 0x72, 0x73, 0xc0, 0x20, 0x5f, 0xce, 0x8b, 0x85, + 0x00, 0x00, 0x07, 0x08, 0x00, 0x00, 0x03, 0x84, 0x00, 0x09, 0x3a, 0x80, 0x00, 0x01, + 0x51, 0x80, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x63c4); + assert_eq!( + p.packet.flags(), + Flags::RESPONSE | Flags::RECURSION_DESIRED | Flags::RECURSION_AVAILABLE + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.rcode(), Rcode::NXDomain); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 0); + assert_eq!(p.packet.authority_record_count(), 1); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!(p.questions[0].type_, Type::A); + + // SOA authority + assert_eq!(p.authorities[0].name, &[0xc0, 0x20]); // com. + assert_eq!(p.authorities[0].ttl, 899); + assert!(matches!( + p.authorities[0].data, + RecordData::Other(Type::Soa, _) + )); + } + + #[test] + fn test_emit() { + let name = &[ + 0x09, 0x72, 0x75, 0x73, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, 0x03, 0x6f, 0x72, 0x67, + 0x00, + ]; + + let repr = Repr { + transaction_id: 0x1234, + flags: Flags::RECURSION_DESIRED, + opcode: Opcode::Query, + question: Question { + name, + type_: Type::A, + }, + }; + + let mut buf = Vec::new(); + buf.resize(repr.buffer_len(), 0); + repr.emit(&mut Packet::new_unchecked(&mut buf)); + + let want = &[ + 0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x72, + 0x75, 0x73, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, 0x03, 0x6f, 0x72, 0x67, 0x00, 0x00, + 0x01, 0x00, 0x01, + ]; + assert_eq!(&buf, want); + } +} diff --git a/src/wire/ethernet.rs b/src/wire/ethernet.rs new file mode 100644 index 0000000..53dc1ea --- /dev/null +++ b/src/wire/ethernet.rs @@ -0,0 +1,400 @@ +use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; + +use super::{Error, Result}; + +enum_with_unknown! { + /// Ethernet protocol type. + pub enum EtherType(u16) { + Ipv4 = 0x0800, + Arp = 0x0806, + Ipv6 = 0x86DD + } +} + +impl fmt::Display for EtherType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + EtherType::Ipv4 => write!(f, "IPv4"), + EtherType::Ipv6 => write!(f, "IPv6"), + EtherType::Arp => write!(f, "ARP"), + EtherType::Unknown(id) => write!(f, "0x{id:04x}"), + } + } +} + +/// A six-octet Ethernet II address. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Address(pub [u8; 6]); + +impl Address { + /// The broadcast address. + pub const BROADCAST: Address = Address([0xff; 6]); + + /// Construct an Ethernet address from a sequence of octets, in big-endian. + /// + /// # Panics + /// The function panics if `data` is not six octets long. + pub fn from_bytes(data: &[u8]) -> Address { + let mut bytes = [0; 6]; + bytes.copy_from_slice(data); + Address(bytes) + } + + /// Return an Ethernet address as a sequence of octets, in big-endian. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } + + /// Query whether the address is an unicast address. + pub fn is_unicast(&self) -> bool { + !(self.is_broadcast() || self.is_multicast()) + } + + /// Query whether this address is the broadcast address. + pub fn is_broadcast(&self) -> bool { + *self == Self::BROADCAST + } + + /// Query whether the "multicast" bit in the OUI is set. + pub const fn is_multicast(&self) -> bool { + self.0[0] & 0x01 != 0 + } + + /// Query whether the "locally administered" bit in the OUI is set. + pub const fn is_local(&self) -> bool { + self.0[0] & 0x02 != 0 + } +} + +impl fmt::Display for Address { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let bytes = self.0; + write!( + f, + "{:02x}-{:02x}-{:02x}-{:02x}-{:02x}-{:02x}", + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5] + ) + } +} + +/// A read/write wrapper around an Ethernet II frame buffer. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Frame<T: AsRef<[u8]>> { + buffer: T, +} + +mod field { + use crate::wire::field::*; + + pub const DESTINATION: Field = 0..6; + pub const SOURCE: Field = 6..12; + pub const ETHERTYPE: Field = 12..14; + pub const PAYLOAD: Rest = 14..; +} + +/// The Ethernet header length +pub const HEADER_LEN: usize = field::PAYLOAD.start; + +impl<T: AsRef<[u8]>> Frame<T> { + /// Imbue a raw octet buffer with Ethernet frame structure. + pub const fn new_unchecked(buffer: T) -> Frame<T> { + Frame { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Frame<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < HEADER_LEN { + Err(Error) + } else { + Ok(()) + } + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the length of a frame header. + pub const fn header_len() -> usize { + HEADER_LEN + } + + /// Return the length of a buffer required to hold a packet with the payload + /// of a given length. + pub const fn buffer_len(payload_len: usize) -> usize { + HEADER_LEN + payload_len + } + + /// Return the destination address field. + #[inline] + pub fn dst_addr(&self) -> Address { + let data = self.buffer.as_ref(); + Address::from_bytes(&data[field::DESTINATION]) + } + + /// Return the source address field. + #[inline] + pub fn src_addr(&self) -> Address { + let data = self.buffer.as_ref(); + Address::from_bytes(&data[field::SOURCE]) + } + + /// Return the EtherType field, without checking for 802.1Q. + #[inline] + pub fn ethertype(&self) -> EtherType { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::ETHERTYPE]); + EtherType::from(raw) + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Frame<&'a T> { + /// Return a pointer to the payload, without checking for 802.1Q. + #[inline] + pub fn payload(&self) -> &'a [u8] { + let data = self.buffer.as_ref(); + &data[field::PAYLOAD] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Frame<T> { + /// Set the destination address field. + #[inline] + pub fn set_dst_addr(&mut self, value: Address) { + let data = self.buffer.as_mut(); + data[field::DESTINATION].copy_from_slice(value.as_bytes()) + } + + /// Set the source address field. + #[inline] + pub fn set_src_addr(&mut self, value: Address) { + let data = self.buffer.as_mut(); + data[field::SOURCE].copy_from_slice(value.as_bytes()) + } + + /// Set the EtherType field. + #[inline] + pub fn set_ethertype(&mut self, value: EtherType) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::ETHERTYPE], value.into()) + } + + /// Return a mutable pointer to the payload. + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + let data = self.buffer.as_mut(); + &mut data[field::PAYLOAD] + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Frame<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +impl<T: AsRef<[u8]>> fmt::Display for Frame<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "EthernetII src={} dst={} type={}", + self.src_addr(), + self.dst_addr(), + self.ethertype() + ) + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +impl<T: AsRef<[u8]>> PrettyPrint for Frame<T> { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + let frame = match Frame::new_checked(buffer) { + Err(err) => return write!(f, "{indent}({err})"), + Ok(frame) => frame, + }; + write!(f, "{indent}{frame}")?; + + match frame.ethertype() { + #[cfg(feature = "proto-ipv4")] + EtherType::Arp => { + indent.increase(f)?; + super::ArpPacket::<&[u8]>::pretty_print(&frame.payload(), f, indent) + } + #[cfg(feature = "proto-ipv4")] + EtherType::Ipv4 => { + indent.increase(f)?; + super::Ipv4Packet::<&[u8]>::pretty_print(&frame.payload(), f, indent) + } + #[cfg(feature = "proto-ipv6")] + EtherType::Ipv6 => { + indent.increase(f)?; + super::Ipv6Packet::<&[u8]>::pretty_print(&frame.payload(), f, indent) + } + _ => Ok(()), + } + } +} + +/// A high-level representation of an Internet Protocol version 4 packet header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr { + pub src_addr: Address, + pub dst_addr: Address, + pub ethertype: EtherType, +} + +impl Repr { + /// Parse an Ethernet II frame and return a high-level representation. + pub fn parse<T: AsRef<[u8]> + ?Sized>(frame: &Frame<&T>) -> Result<Repr> { + frame.check_len()?; + Ok(Repr { + src_addr: frame.src_addr(), + dst_addr: frame.dst_addr(), + ethertype: frame.ethertype(), + }) + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + HEADER_LEN + } + + /// Emit a high-level representation into an Ethernet II frame. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, frame: &mut Frame<T>) { + frame.set_src_addr(self.src_addr); + frame.set_dst_addr(self.dst_addr); + frame.set_ethertype(self.ethertype); + } +} + +#[cfg(test)] +mod test { + // Tests that are valid with any combination of + // "proto-*" features. + use super::*; + + #[test] + fn test_broadcast() { + assert!(Address::BROADCAST.is_broadcast()); + assert!(!Address::BROADCAST.is_unicast()); + assert!(Address::BROADCAST.is_multicast()); + assert!(Address::BROADCAST.is_local()); + } +} + +#[cfg(test)] +#[cfg(feature = "proto-ipv4")] +mod test_ipv4 { + // Tests that are valid only with "proto-ipv4" + use super::*; + + static FRAME_BYTES: [u8; 64] = [ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x08, 0x00, 0xaa, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0xff, + ]; + + static PAYLOAD_BYTES: [u8; 50] = [ + 0xaa, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xff, + ]; + + #[test] + fn test_deconstruct() { + let frame = Frame::new_unchecked(&FRAME_BYTES[..]); + assert_eq!( + frame.dst_addr(), + Address([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]) + ); + assert_eq!( + frame.src_addr(), + Address([0x11, 0x12, 0x13, 0x14, 0x15, 0x16]) + ); + assert_eq!(frame.ethertype(), EtherType::Ipv4); + assert_eq!(frame.payload(), &PAYLOAD_BYTES[..]); + } + + #[test] + fn test_construct() { + let mut bytes = vec![0xa5; 64]; + let mut frame = Frame::new_unchecked(&mut bytes); + frame.set_dst_addr(Address([0x01, 0x02, 0x03, 0x04, 0x05, 0x06])); + frame.set_src_addr(Address([0x11, 0x12, 0x13, 0x14, 0x15, 0x16])); + frame.set_ethertype(EtherType::Ipv4); + frame.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]); + assert_eq!(&frame.into_inner()[..], &FRAME_BYTES[..]); + } +} + +#[cfg(test)] +#[cfg(feature = "proto-ipv6")] +mod test_ipv6 { + // Tests that are valid only with "proto-ipv6" + use super::*; + + static FRAME_BYTES: [u8; 54] = [ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x86, 0xdd, 0x60, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + ]; + + static PAYLOAD_BYTES: [u8; 40] = [ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + ]; + + #[test] + fn test_deconstruct() { + let frame = Frame::new_unchecked(&FRAME_BYTES[..]); + assert_eq!( + frame.dst_addr(), + Address([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]) + ); + assert_eq!( + frame.src_addr(), + Address([0x11, 0x12, 0x13, 0x14, 0x15, 0x16]) + ); + assert_eq!(frame.ethertype(), EtherType::Ipv6); + assert_eq!(frame.payload(), &PAYLOAD_BYTES[..]); + } + + #[test] + fn test_construct() { + let mut bytes = vec![0xa5; 54]; + let mut frame = Frame::new_unchecked(&mut bytes); + frame.set_dst_addr(Address([0x01, 0x02, 0x03, 0x04, 0x05, 0x06])); + frame.set_src_addr(Address([0x11, 0x12, 0x13, 0x14, 0x15, 0x16])); + frame.set_ethertype(EtherType::Ipv6); + assert_eq!(PAYLOAD_BYTES.len(), frame.payload_mut().len()); + frame.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]); + assert_eq!(&frame.into_inner()[..], &FRAME_BYTES[..]); + } +} diff --git a/src/wire/icmp.rs b/src/wire/icmp.rs new file mode 100644 index 0000000..6bbc574 --- /dev/null +++ b/src/wire/icmp.rs @@ -0,0 +1,25 @@ +#[cfg(feature = "proto-ipv4")] +use crate::wire::icmpv4; +#[cfg(feature = "proto-ipv6")] +use crate::wire::icmpv6; + +#[derive(Clone, PartialEq, Eq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Repr<'a> { + #[cfg(feature = "proto-ipv4")] + Ipv4(icmpv4::Repr<'a>), + #[cfg(feature = "proto-ipv6")] + Ipv6(icmpv6::Repr<'a>), +} +#[cfg(feature = "proto-ipv4")] +impl<'a> From<icmpv4::Repr<'a>> for Repr<'a> { + fn from(s: icmpv4::Repr<'a>) -> Self { + Repr::Ipv4(s) + } +} +#[cfg(feature = "proto-ipv6")] +impl<'a> From<icmpv6::Repr<'a>> for Repr<'a> { + fn from(s: icmpv6::Repr<'a>) -> Self { + Repr::Ipv6(s) + } +} diff --git a/src/wire/icmpv4.rs b/src/wire/icmpv4.rs new file mode 100644 index 0000000..60e1215 --- /dev/null +++ b/src/wire/icmpv4.rs @@ -0,0 +1,702 @@ +use byteorder::{ByteOrder, NetworkEndian}; +use core::{cmp, fmt}; + +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::checksum; +use crate::wire::{Ipv4Packet, Ipv4Repr}; + +enum_with_unknown! { + /// Internet protocol control message type. + pub enum Message(u8) { + /// Echo reply + EchoReply = 0, + /// Destination unreachable + DstUnreachable = 3, + /// Message redirect + Redirect = 5, + /// Echo request + EchoRequest = 8, + /// Router advertisement + RouterAdvert = 9, + /// Router solicitation + RouterSolicit = 10, + /// Time exceeded + TimeExceeded = 11, + /// Parameter problem + ParamProblem = 12, + /// Timestamp + Timestamp = 13, + /// Timestamp reply + TimestampReply = 14 + } +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Message::EchoReply => write!(f, "echo reply"), + Message::DstUnreachable => write!(f, "destination unreachable"), + Message::Redirect => write!(f, "message redirect"), + Message::EchoRequest => write!(f, "echo request"), + Message::RouterAdvert => write!(f, "router advertisement"), + Message::RouterSolicit => write!(f, "router solicitation"), + Message::TimeExceeded => write!(f, "time exceeded"), + Message::ParamProblem => write!(f, "parameter problem"), + Message::Timestamp => write!(f, "timestamp"), + Message::TimestampReply => write!(f, "timestamp reply"), + Message::Unknown(id) => write!(f, "{id}"), + } + } +} + +enum_with_unknown! { + /// Internet protocol control message subtype for type "Destination Unreachable". + pub enum DstUnreachable(u8) { + /// Destination network unreachable + NetUnreachable = 0, + /// Destination host unreachable + HostUnreachable = 1, + /// Destination protocol unreachable + ProtoUnreachable = 2, + /// Destination port unreachable + PortUnreachable = 3, + /// Fragmentation required, and DF flag set + FragRequired = 4, + /// Source route failed + SrcRouteFailed = 5, + /// Destination network unknown + DstNetUnknown = 6, + /// Destination host unknown + DstHostUnknown = 7, + /// Source host isolated + SrcHostIsolated = 8, + /// Network administratively prohibited + NetProhibited = 9, + /// Host administratively prohibited + HostProhibited = 10, + /// Network unreachable for ToS + NetUnreachToS = 11, + /// Host unreachable for ToS + HostUnreachToS = 12, + /// Communication administratively prohibited + CommProhibited = 13, + /// Host precedence violation + HostPrecedViol = 14, + /// Precedence cutoff in effect + PrecedCutoff = 15 + } +} + +impl fmt::Display for DstUnreachable { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + DstUnreachable::NetUnreachable => write!(f, "destination network unreachable"), + DstUnreachable::HostUnreachable => write!(f, "destination host unreachable"), + DstUnreachable::ProtoUnreachable => write!(f, "destination protocol unreachable"), + DstUnreachable::PortUnreachable => write!(f, "destination port unreachable"), + DstUnreachable::FragRequired => write!(f, "fragmentation required, and DF flag set"), + DstUnreachable::SrcRouteFailed => write!(f, "source route failed"), + DstUnreachable::DstNetUnknown => write!(f, "destination network unknown"), + DstUnreachable::DstHostUnknown => write!(f, "destination host unknown"), + DstUnreachable::SrcHostIsolated => write!(f, "source host isolated"), + DstUnreachable::NetProhibited => write!(f, "network administratively prohibited"), + DstUnreachable::HostProhibited => write!(f, "host administratively prohibited"), + DstUnreachable::NetUnreachToS => write!(f, "network unreachable for ToS"), + DstUnreachable::HostUnreachToS => write!(f, "host unreachable for ToS"), + DstUnreachable::CommProhibited => { + write!(f, "communication administratively prohibited") + } + DstUnreachable::HostPrecedViol => write!(f, "host precedence violation"), + DstUnreachable::PrecedCutoff => write!(f, "precedence cutoff in effect"), + DstUnreachable::Unknown(id) => write!(f, "{id}"), + } + } +} + +enum_with_unknown! { + /// Internet protocol control message subtype for type "Redirect Message". + pub enum Redirect(u8) { + /// Redirect Datagram for the Network + Net = 0, + /// Redirect Datagram for the Host + Host = 1, + /// Redirect Datagram for the ToS & network + NetToS = 2, + /// Redirect Datagram for the ToS & host + HostToS = 3 + } +} + +enum_with_unknown! { + /// Internet protocol control message subtype for type "Time Exceeded". + pub enum TimeExceeded(u8) { + /// TTL expired in transit + TtlExpired = 0, + /// Fragment reassembly time exceeded + FragExpired = 1 + } +} + +impl fmt::Display for TimeExceeded { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + TimeExceeded::TtlExpired => write!(f, "time-to-live exceeded in transit"), + TimeExceeded::FragExpired => write!(f, "fragment reassembly time exceeded"), + TimeExceeded::Unknown(id) => write!(f, "{id}"), + } + } +} + +enum_with_unknown! { + /// Internet protocol control message subtype for type "Parameter Problem". + pub enum ParamProblem(u8) { + /// Pointer indicates the error + AtPointer = 0, + /// Missing a required option + MissingOption = 1, + /// Bad length + BadLength = 2 + } +} + +/// A read/write wrapper around an Internet Control Message Protocol version 4 packet buffer. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +mod field { + use crate::wire::field::*; + + pub const TYPE: usize = 0; + pub const CODE: usize = 1; + pub const CHECKSUM: Field = 2..4; + + pub const UNUSED: Field = 4..8; + + pub const ECHO_IDENT: Field = 4..6; + pub const ECHO_SEQNO: Field = 6..8; + + pub const HEADER_END: usize = 8; +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with ICMPv4 packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// The result of this check is invalidated by calling [set_header_len]. + /// + /// [set_header_len]: #method.set_header_len + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::HEADER_END { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the message type field. + #[inline] + pub fn msg_type(&self) -> Message { + let data = self.buffer.as_ref(); + Message::from(data[field::TYPE]) + } + + /// Return the message code field. + #[inline] + pub fn msg_code(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::CODE] + } + + /// Return the checksum field. + #[inline] + pub fn checksum(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::CHECKSUM]) + } + + /// Return the identifier field (for echo request and reply packets). + /// + /// # Panics + /// This function may panic if this packet is not an echo request or reply packet. + #[inline] + pub fn echo_ident(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::ECHO_IDENT]) + } + + /// Return the sequence number field (for echo request and reply packets). + /// + /// # Panics + /// This function may panic if this packet is not an echo request or reply packet. + #[inline] + pub fn echo_seq_no(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::ECHO_SEQNO]) + } + + /// Return the header length. + /// The result depends on the value of the message type field. + pub fn header_len(&self) -> usize { + match self.msg_type() { + Message::EchoRequest => field::ECHO_SEQNO.end, + Message::EchoReply => field::ECHO_SEQNO.end, + Message::DstUnreachable => field::UNUSED.end, + _ => field::UNUSED.end, // make a conservative assumption + } + } + + /// Validate the header checksum. + /// + /// # Fuzzing + /// This function always returns `true` when fuzzing. + pub fn verify_checksum(&self) -> bool { + if cfg!(fuzzing) { + return true; + } + + let data = self.buffer.as_ref(); + checksum::data(data) == !0 + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return a pointer to the type-specific data. + #[inline] + pub fn data(&self) -> &'a [u8] { + let data = self.buffer.as_ref(); + &data[self.header_len()..] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the message type field. + #[inline] + pub fn set_msg_type(&mut self, value: Message) { + let data = self.buffer.as_mut(); + data[field::TYPE] = value.into() + } + + /// Set the message code field. + #[inline] + pub fn set_msg_code(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::CODE] = value + } + + /// Set the checksum field. + #[inline] + pub fn set_checksum(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::CHECKSUM], value) + } + + /// Set the identifier field (for echo request and reply packets). + /// + /// # Panics + /// This function may panic if this packet is not an echo request or reply packet. + #[inline] + pub fn set_echo_ident(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::ECHO_IDENT], value) + } + + /// Set the sequence number field (for echo request and reply packets). + /// + /// # Panics + /// This function may panic if this packet is not an echo request or reply packet. + #[inline] + pub fn set_echo_seq_no(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::ECHO_SEQNO], value) + } + + /// Compute and fill in the header checksum. + pub fn fill_checksum(&mut self) { + self.set_checksum(0); + let checksum = { + let data = self.buffer.as_ref(); + !checksum::data(data) + }; + self.set_checksum(checksum) + } +} + +impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'a mut T> { + /// Return a mutable pointer to the type-specific data. + #[inline] + pub fn data_mut(&mut self) -> &mut [u8] { + let range = self.header_len()..; + let data = self.buffer.as_mut(); + &mut data[range] + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +/// A high-level representation of an Internet Control Message Protocol version 4 packet header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] +pub enum Repr<'a> { + EchoRequest { + ident: u16, + seq_no: u16, + data: &'a [u8], + }, + EchoReply { + ident: u16, + seq_no: u16, + data: &'a [u8], + }, + DstUnreachable { + reason: DstUnreachable, + header: Ipv4Repr, + data: &'a [u8], + }, + TimeExceeded { + reason: TimeExceeded, + header: Ipv4Repr, + data: &'a [u8], + }, +} + +impl<'a> Repr<'a> { + /// Parse an Internet Control Message Protocol version 4 packet and return + /// a high-level representation. + pub fn parse<T>( + packet: &Packet<&'a T>, + checksum_caps: &ChecksumCapabilities, + ) -> Result<Repr<'a>> + where + T: AsRef<[u8]> + ?Sized, + { + // Valid checksum is expected. + if checksum_caps.icmpv4.rx() && !packet.verify_checksum() { + return Err(Error); + } + + match (packet.msg_type(), packet.msg_code()) { + (Message::EchoRequest, 0) => Ok(Repr::EchoRequest { + ident: packet.echo_ident(), + seq_no: packet.echo_seq_no(), + data: packet.data(), + }), + + (Message::EchoReply, 0) => Ok(Repr::EchoReply { + ident: packet.echo_ident(), + seq_no: packet.echo_seq_no(), + data: packet.data(), + }), + + (Message::DstUnreachable, code) => { + let ip_packet = Ipv4Packet::new_checked(packet.data())?; + + let payload = &packet.data()[ip_packet.header_len() as usize..]; + // RFC 792 requires exactly eight bytes to be returned. + // We allow more, since there isn't a reason not to, but require at least eight. + if payload.len() < 8 { + return Err(Error); + } + + Ok(Repr::DstUnreachable { + reason: DstUnreachable::from(code), + header: Ipv4Repr { + src_addr: ip_packet.src_addr(), + dst_addr: ip_packet.dst_addr(), + next_header: ip_packet.next_header(), + payload_len: payload.len(), + hop_limit: ip_packet.hop_limit(), + }, + data: payload, + }) + } + + (Message::TimeExceeded, code) => { + let ip_packet = Ipv4Packet::new_checked(packet.data())?; + + let payload = &packet.data()[ip_packet.header_len() as usize..]; + // RFC 792 requires exactly eight bytes to be returned. + // We allow more, since there isn't a reason not to, but require at least eight. + if payload.len() < 8 { + return Err(Error); + } + + Ok(Repr::TimeExceeded { + reason: TimeExceeded::from(code), + header: Ipv4Repr { + src_addr: ip_packet.src_addr(), + dst_addr: ip_packet.dst_addr(), + next_header: ip_packet.next_header(), + payload_len: payload.len(), + hop_limit: ip_packet.hop_limit(), + }, + data: payload, + }) + } + + _ => Err(Error), + } + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + match self { + &Repr::EchoRequest { data, .. } | &Repr::EchoReply { data, .. } => { + field::ECHO_SEQNO.end + data.len() + } + &Repr::DstUnreachable { header, data, .. } + | &Repr::TimeExceeded { header, data, .. } => { + field::UNUSED.end + header.buffer_len() + data.len() + } + } + } + + /// Emit a high-level representation into an Internet Control Message Protocol version 4 + /// packet. + pub fn emit<T>(&self, packet: &mut Packet<&mut T>, checksum_caps: &ChecksumCapabilities) + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + packet.set_msg_code(0); + match *self { + Repr::EchoRequest { + ident, + seq_no, + data, + } => { + packet.set_msg_type(Message::EchoRequest); + packet.set_msg_code(0); + packet.set_echo_ident(ident); + packet.set_echo_seq_no(seq_no); + let data_len = cmp::min(packet.data_mut().len(), data.len()); + packet.data_mut()[..data_len].copy_from_slice(&data[..data_len]) + } + + Repr::EchoReply { + ident, + seq_no, + data, + } => { + packet.set_msg_type(Message::EchoReply); + packet.set_msg_code(0); + packet.set_echo_ident(ident); + packet.set_echo_seq_no(seq_no); + let data_len = cmp::min(packet.data_mut().len(), data.len()); + packet.data_mut()[..data_len].copy_from_slice(&data[..data_len]) + } + + Repr::DstUnreachable { + reason, + header, + data, + } => { + packet.set_msg_type(Message::DstUnreachable); + packet.set_msg_code(reason.into()); + + let mut ip_packet = Ipv4Packet::new_unchecked(packet.data_mut()); + header.emit(&mut ip_packet, checksum_caps); + let payload = &mut ip_packet.into_inner()[header.buffer_len()..]; + payload.copy_from_slice(data) + } + + Repr::TimeExceeded { + reason, + header, + data, + } => { + packet.set_msg_type(Message::TimeExceeded); + packet.set_msg_code(reason.into()); + + let mut ip_packet = Ipv4Packet::new_unchecked(packet.data_mut()); + header.emit(&mut ip_packet, checksum_caps); + let payload = &mut ip_packet.into_inner()[header.buffer_len()..]; + payload.copy_from_slice(data) + } + } + + if checksum_caps.icmpv4.tx() { + packet.fill_checksum() + } else { + // make sure we get a consistently zeroed checksum, + // since implementations might rely on it + packet.set_checksum(0); + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match Repr::parse(self, &ChecksumCapabilities::default()) { + Ok(repr) => write!(f, "{repr}"), + Err(err) => { + write!(f, "ICMPv4 ({err})")?; + write!(f, " type={:?}", self.msg_type())?; + match self.msg_type() { + Message::DstUnreachable => { + write!(f, " code={:?}", DstUnreachable::from(self.msg_code())) + } + Message::TimeExceeded => { + write!(f, " code={:?}", TimeExceeded::from(self.msg_code())) + } + _ => write!(f, " code={}", self.msg_code()), + } + } + } + } +} + +impl<'a> fmt::Display for Repr<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Repr::EchoRequest { + ident, + seq_no, + data, + } => write!( + f, + "ICMPv4 echo request id={} seq={} len={}", + ident, + seq_no, + data.len() + ), + Repr::EchoReply { + ident, + seq_no, + data, + } => write!( + f, + "ICMPv4 echo reply id={} seq={} len={}", + ident, + seq_no, + data.len() + ), + Repr::DstUnreachable { reason, .. } => { + write!(f, "ICMPv4 destination unreachable ({reason})") + } + Repr::TimeExceeded { reason, .. } => { + write!(f, "ICMPv4 time exceeded ({reason})") + } + } + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + let packet = match Packet::new_checked(buffer) { + Err(err) => return write!(f, "{indent}({err})"), + Ok(packet) => packet, + }; + write!(f, "{indent}{packet}")?; + + match packet.msg_type() { + Message::DstUnreachable | Message::TimeExceeded => { + indent.increase(f)?; + super::Ipv4Packet::<&[u8]>::pretty_print(&packet.data(), f, indent) + } + _ => Ok(()), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + static ECHO_PACKET_BYTES: [u8; 12] = [ + 0x08, 0x00, 0x8e, 0xfe, 0x12, 0x34, 0xab, 0xcd, 0xaa, 0x00, 0x00, 0xff, + ]; + + static ECHO_DATA_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + + #[test] + fn test_echo_deconstruct() { + let packet = Packet::new_unchecked(&ECHO_PACKET_BYTES[..]); + assert_eq!(packet.msg_type(), Message::EchoRequest); + assert_eq!(packet.msg_code(), 0); + assert_eq!(packet.checksum(), 0x8efe); + assert_eq!(packet.echo_ident(), 0x1234); + assert_eq!(packet.echo_seq_no(), 0xabcd); + assert_eq!(packet.data(), &ECHO_DATA_BYTES[..]); + assert!(packet.verify_checksum()); + } + + #[test] + fn test_echo_construct() { + let mut bytes = vec![0xa5; 12]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_msg_type(Message::EchoRequest); + packet.set_msg_code(0); + packet.set_echo_ident(0x1234); + packet.set_echo_seq_no(0xabcd); + packet.data_mut().copy_from_slice(&ECHO_DATA_BYTES[..]); + packet.fill_checksum(); + assert_eq!(&packet.into_inner()[..], &ECHO_PACKET_BYTES[..]); + } + + fn echo_packet_repr() -> Repr<'static> { + Repr::EchoRequest { + ident: 0x1234, + seq_no: 0xabcd, + data: &ECHO_DATA_BYTES, + } + } + + #[test] + fn test_echo_parse() { + let packet = Packet::new_unchecked(&ECHO_PACKET_BYTES[..]); + let repr = Repr::parse(&packet, &ChecksumCapabilities::default()).unwrap(); + assert_eq!(repr, echo_packet_repr()); + } + + #[test] + fn test_echo_emit() { + let repr = echo_packet_repr(); + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet, &ChecksumCapabilities::default()); + assert_eq!(&packet.into_inner()[..], &ECHO_PACKET_BYTES[..]); + } + + #[test] + fn test_check_len() { + let bytes = [0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + assert_eq!(Packet::new_checked(&[]), Err(Error)); + assert_eq!(Packet::new_checked(&bytes[..4]), Err(Error)); + assert!(Packet::new_checked(&bytes[..]).is_ok()); + } +} diff --git a/src/wire/icmpv6.rs b/src/wire/icmpv6.rs new file mode 100644 index 0000000..72d2451 --- /dev/null +++ b/src/wire/icmpv6.rs @@ -0,0 +1,1085 @@ +use byteorder::{ByteOrder, NetworkEndian}; +use core::{cmp, fmt}; + +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::checksum; +use crate::wire::MldRepr; +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +use crate::wire::NdiscRepr; +#[cfg(feature = "proto-rpl")] +use crate::wire::RplRepr; +use crate::wire::{IpAddress, IpProtocol, Ipv6Packet, Ipv6Repr}; +use crate::wire::{IPV6_HEADER_LEN, IPV6_MIN_MTU}; + +/// Error packets must not exceed min MTU +const MAX_ERROR_PACKET_LEN: usize = IPV6_MIN_MTU - IPV6_HEADER_LEN; + +enum_with_unknown! { + /// Internet protocol control message type. + pub enum Message(u8) { + /// Destination Unreachable. + DstUnreachable = 0x01, + /// Packet Too Big. + PktTooBig = 0x02, + /// Time Exceeded. + TimeExceeded = 0x03, + /// Parameter Problem. + ParamProblem = 0x04, + /// Echo Request + EchoRequest = 0x80, + /// Echo Reply + EchoReply = 0x81, + /// Multicast Listener Query + MldQuery = 0x82, + /// Router Solicitation + RouterSolicit = 0x85, + /// Router Advertisement + RouterAdvert = 0x86, + /// Neighbor Solicitation + NeighborSolicit = 0x87, + /// Neighbor Advertisement + NeighborAdvert = 0x88, + /// Redirect + Redirect = 0x89, + /// Multicast Listener Report + MldReport = 0x8f, + /// RPL Control Message + RplControl = 0x9b, + } +} + +impl Message { + /// Per [RFC 4443 § 2.1] ICMPv6 message types with the highest order + /// bit set are informational messages while message types without + /// the highest order bit set are error messages. + /// + /// [RFC 4443 § 2.1]: https://tools.ietf.org/html/rfc4443#section-2.1 + pub fn is_error(&self) -> bool { + (u8::from(*self) & 0x80) != 0x80 + } + + /// Return a boolean value indicating if the given message type + /// is an [NDISC] message type. + /// + /// [NDISC]: https://tools.ietf.org/html/rfc4861 + pub const fn is_ndisc(&self) -> bool { + match *self { + Message::RouterSolicit + | Message::RouterAdvert + | Message::NeighborSolicit + | Message::NeighborAdvert + | Message::Redirect => true, + _ => false, + } + } + + /// Return a boolean value indicating if the given message type + /// is an [MLD] message type. + /// + /// [MLD]: https://tools.ietf.org/html/rfc3810 + pub const fn is_mld(&self) -> bool { + match *self { + Message::MldQuery | Message::MldReport => true, + _ => false, + } + } +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Message::DstUnreachable => write!(f, "destination unreachable"), + Message::PktTooBig => write!(f, "packet too big"), + Message::TimeExceeded => write!(f, "time exceeded"), + Message::ParamProblem => write!(f, "parameter problem"), + Message::EchoReply => write!(f, "echo reply"), + Message::EchoRequest => write!(f, "echo request"), + Message::RouterSolicit => write!(f, "router solicitation"), + Message::RouterAdvert => write!(f, "router advertisement"), + Message::NeighborSolicit => write!(f, "neighbor solicitation"), + Message::NeighborAdvert => write!(f, "neighbor advert"), + Message::Redirect => write!(f, "redirect"), + Message::MldQuery => write!(f, "multicast listener query"), + Message::MldReport => write!(f, "multicast listener report"), + Message::RplControl => write!(f, "RPL control message"), + Message::Unknown(id) => write!(f, "{id}"), + } + } +} + +enum_with_unknown! { + /// Internet protocol control message subtype for type "Destination Unreachable". + pub enum DstUnreachable(u8) { + /// No Route to destination. + NoRoute = 0, + /// Communication with destination administratively prohibited. + AdminProhibit = 1, + /// Beyond scope of source address. + BeyondScope = 2, + /// Address unreachable. + AddrUnreachable = 3, + /// Port unreachable. + PortUnreachable = 4, + /// Source address failed ingress/egress policy. + FailedPolicy = 5, + /// Reject route to destination. + RejectRoute = 6 + } +} + +impl fmt::Display for DstUnreachable { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + DstUnreachable::NoRoute => write!(f, "no route to destination"), + DstUnreachable::AdminProhibit => write!( + f, + "communication with destination administratively prohibited" + ), + DstUnreachable::BeyondScope => write!(f, "beyond scope of source address"), + DstUnreachable::AddrUnreachable => write!(f, "address unreachable"), + DstUnreachable::PortUnreachable => write!(f, "port unreachable"), + DstUnreachable::FailedPolicy => { + write!(f, "source address failed ingress/egress policy") + } + DstUnreachable::RejectRoute => write!(f, "reject route to destination"), + DstUnreachable::Unknown(id) => write!(f, "{id}"), + } + } +} + +enum_with_unknown! { + /// Internet protocol control message subtype for the type "Parameter Problem". + pub enum ParamProblem(u8) { + /// Erroneous header field encountered. + ErroneousHdrField = 0, + /// Unrecognized Next Header type encountered. + UnrecognizedNxtHdr = 1, + /// Unrecognized IPv6 option encountered. + UnrecognizedOption = 2 + } +} + +impl fmt::Display for ParamProblem { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ParamProblem::ErroneousHdrField => write!(f, "erroneous header field."), + ParamProblem::UnrecognizedNxtHdr => write!(f, "unrecognized next header type."), + ParamProblem::UnrecognizedOption => write!(f, "unrecognized IPv6 option."), + ParamProblem::Unknown(id) => write!(f, "{id}"), + } + } +} + +enum_with_unknown! { + /// Internet protocol control message subtype for the type "Time Exceeded". + pub enum TimeExceeded(u8) { + /// Hop limit exceeded in transit. + HopLimitExceeded = 0, + /// Fragment reassembly time exceeded. + FragReassemExceeded = 1 + } +} + +impl fmt::Display for TimeExceeded { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + TimeExceeded::HopLimitExceeded => write!(f, "hop limit exceeded in transit"), + TimeExceeded::FragReassemExceeded => write!(f, "fragment reassembly time exceeded"), + TimeExceeded::Unknown(id) => write!(f, "{id}"), + } + } +} + +/// A read/write wrapper around an Internet Control Message Protocol version 6 packet buffer. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + pub(super) buffer: T, +} + +// Ranges and constants describing key boundaries in the ICMPv6 header. +pub(super) mod field { + use crate::wire::field::*; + + // ICMPv6: See https://tools.ietf.org/html/rfc4443 + pub const TYPE: usize = 0; + pub const CODE: usize = 1; + pub const CHECKSUM: Field = 2..4; + + pub const UNUSED: Field = 4..8; + pub const MTU: Field = 4..8; + pub const POINTER: Field = 4..8; + pub const ECHO_IDENT: Field = 4..6; + pub const ECHO_SEQNO: Field = 6..8; + + pub const HEADER_END: usize = 8; + + // NDISC: See https://tools.ietf.org/html/rfc4861 + // Router Advertisement message offsets + pub const CUR_HOP_LIMIT: usize = 4; + pub const ROUTER_FLAGS: usize = 5; + pub const ROUTER_LT: Field = 6..8; + pub const REACHABLE_TM: Field = 8..12; + pub const RETRANS_TM: Field = 12..16; + + // Neighbor Solicitation message offsets + pub const TARGET_ADDR: Field = 8..24; + + // Neighbor Advertisement message offsets + pub const NEIGH_FLAGS: usize = 4; + + // Redirected Header message offsets + pub const DEST_ADDR: Field = 24..40; + + // MLD: + // - https://tools.ietf.org/html/rfc3810 + // - https://tools.ietf.org/html/rfc3810 + // Multicast Listener Query message + pub const MAX_RESP_CODE: Field = 4..6; + pub const QUERY_RESV: Field = 6..8; + pub const QUERY_MCAST_ADDR: Field = 8..24; + pub const SQRV: usize = 24; + pub const QQIC: usize = 25; + pub const QUERY_NUM_SRCS: Field = 26..28; + + // Multicast Listener Report Message + pub const RECORD_RESV: Field = 4..6; + pub const NR_MCAST_RCRDS: Field = 6..8; + + // Multicast Address Record Offsets + pub const RECORD_TYPE: usize = 0; + pub const AUX_DATA_LEN: usize = 1; + pub const RECORD_NUM_SRCS: Field = 2..4; + pub const RECORD_MCAST_ADDR: Field = 4..20; +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with ICMPv6 packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + + if len < 4 { + return Err(Error); + } + + match self.msg_type() { + Message::DstUnreachable + | Message::PktTooBig + | Message::TimeExceeded + | Message::ParamProblem + | Message::EchoRequest + | Message::EchoReply + | Message::MldQuery + | Message::RouterSolicit + | Message::RouterAdvert + | Message::NeighborSolicit + | Message::NeighborAdvert + | Message::Redirect + | Message::MldReport => { + if len < field::HEADER_END || len < self.header_len() { + return Err(Error); + } + } + #[cfg(feature = "proto-rpl")] + Message::RplControl => match super::rpl::RplControlMessage::from(self.msg_code()) { + super::rpl::RplControlMessage::DodagInformationSolicitation => { + // TODO(thvdveld): replace magic number + if len < 6 { + return Err(Error); + } + } + super::rpl::RplControlMessage::DodagInformationObject => { + // TODO(thvdveld): replace magic number + if len < 28 { + return Err(Error); + } + } + super::rpl::RplControlMessage::DestinationAdvertisementObject => { + // TODO(thvdveld): replace magic number + if len < 8 || (self.dao_dodag_id_present() && len < 24) { + return Err(Error); + } + } + super::rpl::RplControlMessage::DestinationAdvertisementObjectAck => { + // TODO(thvdveld): replace magic number + if len < 8 || (self.dao_dodag_id_present() && len < 24) { + return Err(Error); + } + } + super::rpl::RplControlMessage::SecureDodagInformationSolicitation + | super::rpl::RplControlMessage::SecureDodagInformationObject + | super::rpl::RplControlMessage::SecureDestinationAdvertisementObject + | super::rpl::RplControlMessage::SecureDestinationAdvertisementObjectAck + | super::rpl::RplControlMessage::ConsistencyCheck => return Err(Error), + super::rpl::RplControlMessage::Unknown(_) => return Err(Error), + }, + #[cfg(not(feature = "proto-rpl"))] + Message::RplControl => return Err(Error), + Message::Unknown(_) => return Err(Error), + } + + Ok(()) + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the message type field. + #[inline] + pub fn msg_type(&self) -> Message { + let data = self.buffer.as_ref(); + Message::from(data[field::TYPE]) + } + + /// Return the message code field. + #[inline] + pub fn msg_code(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::CODE] + } + + /// Return the checksum field. + #[inline] + pub fn checksum(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::CHECKSUM]) + } + + /// Return the identifier field (for echo request and reply packets). + #[inline] + pub fn echo_ident(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::ECHO_IDENT]) + } + + /// Return the sequence number field (for echo request and reply packets). + #[inline] + pub fn echo_seq_no(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::ECHO_SEQNO]) + } + + /// Return the MTU field (for packet too big messages). + #[inline] + pub fn pkt_too_big_mtu(&self) -> u32 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u32(&data[field::MTU]) + } + + /// Return the pointer field (for parameter problem messages). + #[inline] + pub fn param_problem_ptr(&self) -> u32 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u32(&data[field::POINTER]) + } + + /// Return the header length. The result depends on the value of + /// the message type field. + pub fn header_len(&self) -> usize { + match self.msg_type() { + Message::DstUnreachable => field::UNUSED.end, + Message::PktTooBig => field::MTU.end, + Message::TimeExceeded => field::UNUSED.end, + Message::ParamProblem => field::POINTER.end, + Message::EchoRequest => field::ECHO_SEQNO.end, + Message::EchoReply => field::ECHO_SEQNO.end, + Message::RouterSolicit => field::UNUSED.end, + Message::RouterAdvert => field::RETRANS_TM.end, + Message::NeighborSolicit => field::TARGET_ADDR.end, + Message::NeighborAdvert => field::TARGET_ADDR.end, + Message::Redirect => field::DEST_ADDR.end, + Message::MldQuery => field::QUERY_NUM_SRCS.end, + Message::MldReport => field::NR_MCAST_RCRDS.end, + // For packets that are not included in RFC 4443, do not + // include the last 32 bits of the ICMPv6 header in + // `header_bytes`. This must be done so that these bytes + // can be accessed in the `payload`. + _ => field::CHECKSUM.end, + } + } + + /// Validate the header checksum. + /// + /// # Fuzzing + /// This function always returns `true` when fuzzing. + pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool { + if cfg!(fuzzing) { + return true; + } + + let data = self.buffer.as_ref(); + checksum::combine(&[ + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Icmpv6, data.len() as u32), + checksum::data(data), + ]) == !0 + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return a pointer to the type-specific data. + #[inline] + pub fn payload(&self) -> &'a [u8] { + let data = self.buffer.as_ref(); + &data[self.header_len()..] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the message type field. + #[inline] + pub fn set_msg_type(&mut self, value: Message) { + let data = self.buffer.as_mut(); + data[field::TYPE] = value.into() + } + + /// Set the message code field. + #[inline] + pub fn set_msg_code(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::CODE] = value + } + + /// Clear any reserved fields in the message header. + /// + /// # Panics + /// This function panics if the message type has not been set. + /// See [set_msg_type]. + /// + /// [set_msg_type]: #method.set_msg_type + #[inline] + pub fn clear_reserved(&mut self) { + match self.msg_type() { + Message::RouterSolicit + | Message::NeighborSolicit + | Message::NeighborAdvert + | Message::Redirect => { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::UNUSED], 0); + } + Message::MldQuery => { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::QUERY_RESV], 0); + data[field::SQRV] &= 0xf; + } + Message::MldReport => { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::RECORD_RESV], 0); + } + ty => panic!("Message type `{ty}` does not have any reserved fields."), + } + } + + #[inline] + pub fn set_checksum(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::CHECKSUM], value) + } + + /// Set the identifier field (for echo request and reply packets). + /// + /// # Panics + /// This function may panic if this packet is not an echo request or reply packet. + #[inline] + pub fn set_echo_ident(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::ECHO_IDENT], value) + } + + /// Set the sequence number field (for echo request and reply packets). + /// + /// # Panics + /// This function may panic if this packet is not an echo request or reply packet. + #[inline] + pub fn set_echo_seq_no(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::ECHO_SEQNO], value) + } + + /// Set the MTU field (for packet too big messages). + /// + /// # Panics + /// This function may panic if this packet is not an packet too big packet. + #[inline] + pub fn set_pkt_too_big_mtu(&mut self, value: u32) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::MTU], value) + } + + /// Set the pointer field (for parameter problem messages). + /// + /// # Panics + /// This function may panic if this packet is not a parameter problem message. + #[inline] + pub fn set_param_problem_ptr(&mut self, value: u32) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::POINTER], value) + } + + /// Compute and fill in the header checksum. + pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) { + self.set_checksum(0); + let checksum = { + let data = self.buffer.as_ref(); + !checksum::combine(&[ + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Icmpv6, data.len() as u32), + checksum::data(data), + ]) + }; + self.set_checksum(checksum) + } + + /// Return a mutable pointer to the type-specific data. + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + let range = self.header_len()..; + let data = self.buffer.as_mut(); + &mut data[range] + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +/// A high-level representation of an Internet Control Message Protocol version 6 packet header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] +pub enum Repr<'a> { + DstUnreachable { + reason: DstUnreachable, + header: Ipv6Repr, + data: &'a [u8], + }, + PktTooBig { + mtu: u32, + header: Ipv6Repr, + data: &'a [u8], + }, + TimeExceeded { + reason: TimeExceeded, + header: Ipv6Repr, + data: &'a [u8], + }, + ParamProblem { + reason: ParamProblem, + pointer: u32, + header: Ipv6Repr, + data: &'a [u8], + }, + EchoRequest { + ident: u16, + seq_no: u16, + data: &'a [u8], + }, + EchoReply { + ident: u16, + seq_no: u16, + data: &'a [u8], + }, + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + Ndisc(NdiscRepr<'a>), + Mld(MldRepr<'a>), + #[cfg(feature = "proto-rpl")] + Rpl(RplRepr<'a>), +} + +impl<'a> Repr<'a> { + /// Parse an Internet Control Message Protocol version 6 packet and return + /// a high-level representation. + pub fn parse<T>( + src_addr: &IpAddress, + dst_addr: &IpAddress, + packet: &Packet<&'a T>, + checksum_caps: &ChecksumCapabilities, + ) -> Result<Repr<'a>> + where + T: AsRef<[u8]> + ?Sized, + { + fn create_packet_from_payload<'a, T>(packet: &Packet<&'a T>) -> Result<(&'a [u8], Ipv6Repr)> + where + T: AsRef<[u8]> + ?Sized, + { + // The packet must be truncated to fit the min MTU. Since we don't know the offset of + // the ICMPv6 header in the L2 frame, we should only check whether the payload's IPv6 + // header is present, the rest is allowed to be truncated. + let ip_packet = if packet.payload().len() >= IPV6_HEADER_LEN { + Ipv6Packet::new_unchecked(packet.payload()) + } else { + return Err(Error); + }; + + let payload = &packet.payload()[ip_packet.header_len()..]; + let repr = Ipv6Repr { + src_addr: ip_packet.src_addr(), + dst_addr: ip_packet.dst_addr(), + next_header: ip_packet.next_header(), + payload_len: ip_packet.payload_len().into(), + hop_limit: ip_packet.hop_limit(), + }; + Ok((payload, repr)) + } + // Valid checksum is expected. + if checksum_caps.icmpv6.rx() && !packet.verify_checksum(src_addr, dst_addr) { + return Err(Error); + } + + match (packet.msg_type(), packet.msg_code()) { + (Message::DstUnreachable, code) => { + let (payload, repr) = create_packet_from_payload(packet)?; + Ok(Repr::DstUnreachable { + reason: DstUnreachable::from(code), + header: repr, + data: payload, + }) + } + (Message::PktTooBig, 0) => { + let (payload, repr) = create_packet_from_payload(packet)?; + Ok(Repr::PktTooBig { + mtu: packet.pkt_too_big_mtu(), + header: repr, + data: payload, + }) + } + (Message::TimeExceeded, code) => { + let (payload, repr) = create_packet_from_payload(packet)?; + Ok(Repr::TimeExceeded { + reason: TimeExceeded::from(code), + header: repr, + data: payload, + }) + } + (Message::ParamProblem, code) => { + let (payload, repr) = create_packet_from_payload(packet)?; + Ok(Repr::ParamProblem { + reason: ParamProblem::from(code), + pointer: packet.param_problem_ptr(), + header: repr, + data: payload, + }) + } + (Message::EchoRequest, 0) => Ok(Repr::EchoRequest { + ident: packet.echo_ident(), + seq_no: packet.echo_seq_no(), + data: packet.payload(), + }), + (Message::EchoReply, 0) => Ok(Repr::EchoReply { + ident: packet.echo_ident(), + seq_no: packet.echo_seq_no(), + data: packet.payload(), + }), + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + (msg_type, 0) if msg_type.is_ndisc() => NdiscRepr::parse(packet).map(Repr::Ndisc), + (msg_type, 0) if msg_type.is_mld() => MldRepr::parse(packet).map(Repr::Mld), + #[cfg(feature = "proto-rpl")] + (Message::RplControl, _) => RplRepr::parse(packet).map(Repr::Rpl), + _ => Err(Error), + } + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub fn buffer_len(&self) -> usize { + match self { + &Repr::DstUnreachable { header, data, .. } + | &Repr::PktTooBig { header, data, .. } + | &Repr::TimeExceeded { header, data, .. } + | &Repr::ParamProblem { header, data, .. } => cmp::min( + field::UNUSED.end + header.buffer_len() + data.len(), + MAX_ERROR_PACKET_LEN, + ), + &Repr::EchoRequest { data, .. } | &Repr::EchoReply { data, .. } => { + field::ECHO_SEQNO.end + data.len() + } + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + &Repr::Ndisc(ndisc) => ndisc.buffer_len(), + &Repr::Mld(mld) => mld.buffer_len(), + #[cfg(feature = "proto-rpl")] + Repr::Rpl(rpl) => rpl.buffer_len(), + } + } + + /// Emit a high-level representation into an Internet Control Message Protocol version 6 + /// packet. + pub fn emit<T>( + &self, + src_addr: &IpAddress, + dst_addr: &IpAddress, + packet: &mut Packet<&mut T>, + checksum_caps: &ChecksumCapabilities, + ) where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + fn emit_contained_packet<T>(packet: &mut Packet<&mut T>, header: Ipv6Repr, data: &[u8]) + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + let icmp_header_len = packet.header_len(); + let mut ip_packet = Ipv6Packet::new_unchecked(packet.payload_mut()); + header.emit(&mut ip_packet); + let payload = &mut ip_packet.into_inner()[header.buffer_len()..]; + // FIXME: this should rather be checked at link level, as we can't know in advance how + // much space we have for the packet due to IPv6 options and etc + let payload_len = cmp::min( + data.len(), + MAX_ERROR_PACKET_LEN - icmp_header_len - IPV6_HEADER_LEN, + ); + payload[..payload_len].copy_from_slice(&data[..payload_len]); + } + + match *self { + Repr::DstUnreachable { + reason, + header, + data, + } => { + packet.set_msg_type(Message::DstUnreachable); + packet.set_msg_code(reason.into()); + + emit_contained_packet(packet, header, data); + } + + Repr::PktTooBig { mtu, header, data } => { + packet.set_msg_type(Message::PktTooBig); + packet.set_msg_code(0); + packet.set_pkt_too_big_mtu(mtu); + + emit_contained_packet(packet, header, data); + } + + Repr::TimeExceeded { + reason, + header, + data, + } => { + packet.set_msg_type(Message::TimeExceeded); + packet.set_msg_code(reason.into()); + + emit_contained_packet(packet, header, data); + } + + Repr::ParamProblem { + reason, + pointer, + header, + data, + } => { + packet.set_msg_type(Message::ParamProblem); + packet.set_msg_code(reason.into()); + packet.set_param_problem_ptr(pointer); + + emit_contained_packet(packet, header, data); + } + + Repr::EchoRequest { + ident, + seq_no, + data, + } => { + packet.set_msg_type(Message::EchoRequest); + packet.set_msg_code(0); + packet.set_echo_ident(ident); + packet.set_echo_seq_no(seq_no); + let data_len = cmp::min(packet.payload_mut().len(), data.len()); + packet.payload_mut()[..data_len].copy_from_slice(&data[..data_len]) + } + + Repr::EchoReply { + ident, + seq_no, + data, + } => { + packet.set_msg_type(Message::EchoReply); + packet.set_msg_code(0); + packet.set_echo_ident(ident); + packet.set_echo_seq_no(seq_no); + let data_len = cmp::min(packet.payload_mut().len(), data.len()); + packet.payload_mut()[..data_len].copy_from_slice(&data[..data_len]) + } + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + Repr::Ndisc(ndisc) => ndisc.emit(packet), + + Repr::Mld(mld) => mld.emit(packet), + + #[cfg(feature = "proto-rpl")] + Repr::Rpl(ref rpl) => rpl.emit(packet), + } + + if checksum_caps.icmpv6.tx() { + packet.fill_checksum(src_addr, dst_addr); + } else { + // make sure we get a consistently zeroed checksum, since implementations might rely on it + packet.set_checksum(0); + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2}; + use crate::wire::{IpProtocol, Ipv6Address, Ipv6Repr}; + + static ECHO_PACKET_BYTES: [u8; 12] = [ + 0x80, 0x00, 0x19, 0xb3, 0x12, 0x34, 0xab, 0xcd, 0xaa, 0x00, 0x00, 0xff, + ]; + + static ECHO_PACKET_PAYLOAD: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + + static PKT_TOO_BIG_BYTES: [u8; 60] = [ + 0x02, 0x00, 0x0f, 0xc9, 0x00, 0x00, 0x05, 0xdc, 0x60, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x11, + 0x40, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x02, 0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff, + ]; + + static PKT_TOO_BIG_IP_PAYLOAD: [u8; 52] = [ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x11, 0x40, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xbf, 0x00, 0x00, 0x35, 0x00, + 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff, + ]; + + static PKT_TOO_BIG_UDP_PAYLOAD: [u8; 12] = [ + 0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff, + ]; + + fn echo_packet_repr() -> Repr<'static> { + Repr::EchoRequest { + ident: 0x1234, + seq_no: 0xabcd, + data: &ECHO_PACKET_PAYLOAD, + } + } + + fn too_big_packet_repr() -> Repr<'static> { + Repr::PktTooBig { + mtu: 1500, + header: Ipv6Repr { + src_addr: Ipv6Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, + ]), + dst_addr: Ipv6Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x02, + ]), + next_header: IpProtocol::Udp, + payload_len: 12, + hop_limit: 0x40, + }, + data: &PKT_TOO_BIG_UDP_PAYLOAD, + } + } + + #[test] + fn test_echo_deconstruct() { + let packet = Packet::new_unchecked(&ECHO_PACKET_BYTES[..]); + assert_eq!(packet.msg_type(), Message::EchoRequest); + assert_eq!(packet.msg_code(), 0); + assert_eq!(packet.checksum(), 0x19b3); + assert_eq!(packet.echo_ident(), 0x1234); + assert_eq!(packet.echo_seq_no(), 0xabcd); + assert_eq!(packet.payload(), &ECHO_PACKET_PAYLOAD[..]); + assert!(packet.verify_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2)); + assert!(!packet.msg_type().is_error()); + } + + #[test] + fn test_echo_construct() { + let mut bytes = vec![0xa5; 12]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_msg_type(Message::EchoRequest); + packet.set_msg_code(0); + packet.set_echo_ident(0x1234); + packet.set_echo_seq_no(0xabcd); + packet + .payload_mut() + .copy_from_slice(&ECHO_PACKET_PAYLOAD[..]); + packet.fill_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2); + assert_eq!(&*packet.into_inner(), &ECHO_PACKET_BYTES[..]); + } + + #[test] + fn test_echo_repr_parse() { + let packet = Packet::new_unchecked(&ECHO_PACKET_BYTES[..]); + let repr = Repr::parse( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &packet, + &ChecksumCapabilities::default(), + ) + .unwrap(); + assert_eq!(repr, echo_packet_repr()); + } + + #[test] + fn test_echo_emit() { + let repr = echo_packet_repr(); + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &ECHO_PACKET_BYTES[..]); + } + + #[test] + fn test_too_big_deconstruct() { + let packet = Packet::new_unchecked(&PKT_TOO_BIG_BYTES[..]); + assert_eq!(packet.msg_type(), Message::PktTooBig); + assert_eq!(packet.msg_code(), 0); + assert_eq!(packet.checksum(), 0x0fc9); + assert_eq!(packet.pkt_too_big_mtu(), 1500); + assert_eq!(packet.payload(), &PKT_TOO_BIG_IP_PAYLOAD[..]); + assert!(packet.verify_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2)); + assert!(packet.msg_type().is_error()); + } + + #[test] + fn test_too_big_construct() { + let mut bytes = vec![0xa5; 60]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_msg_type(Message::PktTooBig); + packet.set_msg_code(0); + packet.set_pkt_too_big_mtu(1500); + packet + .payload_mut() + .copy_from_slice(&PKT_TOO_BIG_IP_PAYLOAD[..]); + packet.fill_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2); + assert_eq!(&*packet.into_inner(), &PKT_TOO_BIG_BYTES[..]); + } + + #[test] + fn test_too_big_repr_parse() { + let packet = Packet::new_unchecked(&PKT_TOO_BIG_BYTES[..]); + let repr = Repr::parse( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &packet, + &ChecksumCapabilities::default(), + ) + .unwrap(); + assert_eq!(repr, too_big_packet_repr()); + } + + #[test] + fn test_too_big_emit() { + let repr = too_big_packet_repr(); + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &PKT_TOO_BIG_BYTES[..]); + } + + #[test] + fn test_buffer_length_is_truncated_to_mtu() { + let repr = Repr::PktTooBig { + mtu: 1280, + header: Ipv6Repr { + src_addr: Default::default(), + dst_addr: Default::default(), + next_header: IpProtocol::Tcp, + hop_limit: 64, + payload_len: 1280, + }, + data: &vec![0; 9999], + }; + assert_eq!(repr.buffer_len(), 1280 - IPV6_HEADER_LEN); + } + + #[test] + fn test_mtu_truncated_payload_roundtrip() { + let ip_packet_repr = Ipv6Repr { + src_addr: Default::default(), + dst_addr: Default::default(), + next_header: IpProtocol::Tcp, + hop_limit: 64, + payload_len: IPV6_MIN_MTU - IPV6_HEADER_LEN, + }; + let mut ip_packet = Ipv6Packet::new_unchecked(vec![0; IPV6_MIN_MTU]); + ip_packet_repr.emit(&mut ip_packet); + + let repr1 = Repr::PktTooBig { + mtu: IPV6_MIN_MTU as u32, + header: ip_packet_repr, + data: &ip_packet.as_ref()[IPV6_HEADER_LEN..], + }; + // this is needed to make sure roundtrip gives the same value + // it is not needed for ensuring the correct bytes get emitted + let repr1 = Repr::PktTooBig { + mtu: IPV6_MIN_MTU as u32, + header: ip_packet_repr, + data: &ip_packet.as_ref()[IPV6_HEADER_LEN..repr1.buffer_len() - field::UNUSED.end], + }; + let mut data = vec![0; MAX_ERROR_PACKET_LEN]; + let mut packet = Packet::new_unchecked(&mut data); + repr1.emit( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &mut packet, + &ChecksumCapabilities::default(), + ); + + let packet = Packet::new_unchecked(&data); + let repr2 = Repr::parse( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &packet, + &ChecksumCapabilities::default(), + ) + .unwrap(); + + assert_eq!(repr1, repr2); + } + + #[test] + fn test_truncated_payload_ipv6_header_parse_fails() { + let repr = too_big_packet_repr(); + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &mut packet, + &ChecksumCapabilities::default(), + ); + let packet = Packet::new_unchecked(&bytes[..field::HEADER_END + IPV6_HEADER_LEN - 1]); + assert!(Repr::parse( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &packet, + &ChecksumCapabilities::ignored(), + ) + .is_err()); + } +} diff --git a/src/wire/ieee802154.rs b/src/wire/ieee802154.rs new file mode 100644 index 0000000..33fafb6 --- /dev/null +++ b/src/wire/ieee802154.rs @@ -0,0 +1,1182 @@ +use core::fmt; + +use byteorder::{ByteOrder, LittleEndian}; + +use super::{Error, Result}; +use crate::wire::ipv6::Address as Ipv6Address; + +enum_with_unknown! { + /// IEEE 802.15.4 frame type. + pub enum FrameType(u8) { + Beacon = 0b000, + Data = 0b001, + Acknowledgement = 0b010, + MacCommand = 0b011, + Multipurpose = 0b101, + FragmentOrFrak = 0b110, + Extended = 0b111, + } +} + +impl fmt::Display for FrameType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + FrameType::Beacon => write!(f, "Beacon"), + FrameType::Data => write!(f, "Data"), + FrameType::Acknowledgement => write!(f, "Ack"), + FrameType::MacCommand => write!(f, "MAC command"), + FrameType::Multipurpose => write!(f, "Multipurpose"), + FrameType::FragmentOrFrak => write!(f, "FragmentOrFrak"), + FrameType::Extended => write!(f, "Extended"), + FrameType::Unknown(id) => write!(f, "0b{id:04b}"), + } + } +} +enum_with_unknown! { + /// IEEE 802.15.4 addressing mode for destination and source addresses. + pub enum AddressingMode(u8) { + Absent = 0b00, + Short = 0b10, + Extended = 0b11, + } +} + +impl AddressingMode { + /// Return the size in octets of the address. + const fn size(&self) -> usize { + match self { + AddressingMode::Absent => 0, + AddressingMode::Short => 2, + AddressingMode::Extended => 8, + AddressingMode::Unknown(_) => 0, // TODO(thvdveld): what do we need to here? + } + } +} + +impl fmt::Display for AddressingMode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AddressingMode::Absent => write!(f, "Absent"), + AddressingMode::Short => write!(f, "Short"), + AddressingMode::Extended => write!(f, "Extended"), + AddressingMode::Unknown(id) => write!(f, "0b{id:04b}"), + } + } +} + +/// A IEEE 802.15.4 PAN. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub struct Pan(pub u16); + +impl Pan { + pub const BROADCAST: Self = Self(0xffff); + + /// Return the PAN ID as bytes. + pub fn as_bytes(&self) -> [u8; 2] { + let mut pan = [0u8; 2]; + LittleEndian::write_u16(&mut pan, self.0); + pan + } +} + +impl fmt::Display for Pan { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:0x}", self.0) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Pan { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "{:02x}", self.0) + } +} + +/// A IEEE 802.15.4 address. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum Address { + Absent, + Short([u8; 2]), + Extended([u8; 8]), +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Address { + fn format(&self, f: defmt::Formatter) { + match self { + Self::Absent => defmt::write!(f, "not-present"), + Self::Short(bytes) => defmt::write!(f, "{:02x}:{:02x}", bytes[0], bytes[1]), + Self::Extended(bytes) => defmt::write!( + f, + "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", + bytes[0], + bytes[1], + bytes[2], + bytes[3], + bytes[4], + bytes[5], + bytes[6], + bytes[7] + ), + } + } +} + +#[cfg(test)] +impl Default for Address { + fn default() -> Self { + Address::Extended([0u8; 8]) + } +} + +impl Address { + /// The broadcast address. + pub const BROADCAST: Address = Address::Short([0xff; 2]); + + /// Query whether the address is an unicast address. + pub fn is_unicast(&self) -> bool { + !self.is_broadcast() + } + + /// Query whether this address is the broadcast address. + pub fn is_broadcast(&self) -> bool { + *self == Self::BROADCAST + } + + const fn short_from_bytes(a: [u8; 2]) -> Self { + Self::Short(a) + } + + const fn extended_from_bytes(a: [u8; 8]) -> Self { + Self::Extended(a) + } + + pub fn from_bytes(a: &[u8]) -> Self { + if a.len() == 2 { + let mut b = [0u8; 2]; + b.copy_from_slice(a); + Address::Short(b) + } else if a.len() == 8 { + let mut b = [0u8; 8]; + b.copy_from_slice(a); + Address::Extended(b) + } else { + panic!("Not an IEEE802.15.4 address"); + } + } + + pub const fn as_bytes(&self) -> &[u8] { + match self { + Address::Absent => &[], + Address::Short(value) => value, + Address::Extended(value) => value, + } + } + + /// Convert the extended address to an Extended Unique Identifier (EUI-64) + pub fn as_eui_64(&self) -> Option<[u8; 8]> { + match self { + Address::Absent | Address::Short(_) => None, + Address::Extended(value) => { + let mut bytes = [0; 8]; + bytes.copy_from_slice(&value[..]); + + bytes[0] ^= 1 << 1; + + Some(bytes) + } + } + } + + /// Convert an extended address to a link-local IPv6 address using the EUI-64 format from + /// RFC2464. + pub fn as_link_local_address(&self) -> Option<Ipv6Address> { + let mut bytes = [0; 16]; + bytes[0] = 0xfe; + bytes[1] = 0x80; + bytes[8..].copy_from_slice(&self.as_eui_64()?); + + Some(Ipv6Address::from_bytes(&bytes)) + } +} + +impl fmt::Display for Address { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Absent => write!(f, "not-present"), + Self::Short(bytes) => write!(f, "{:02x}:{:02x}", bytes[0], bytes[1]), + Self::Extended(bytes) => write!( + f, + "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7] + ), + } + } +} + +enum_with_unknown! { + /// IEEE 802.15.4 addressing mode for destination and source addresses. + pub enum FrameVersion(u8) { + Ieee802154_2003 = 0b00, + Ieee802154_2006 = 0b01, + Ieee802154 = 0b10, + } +} + +/// A read/write wrapper around an IEEE 802.15.4 frame buffer. +#[derive(Debug, Clone)] +pub struct Frame<T: AsRef<[u8]>> { + buffer: T, +} + +mod field { + use crate::wire::field::*; + + pub const FRAMECONTROL: Field = 0..2; + pub const SEQUENCE_NUMBER: usize = 2; + pub const ADDRESSING: Rest = 3..; +} + +macro_rules! fc_bit_field { + ($field:ident, $bit:literal) => { + #[inline] + pub fn $field(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + + ((raw >> $bit) & 0b1) == 0b1 + } + }; +} + +macro_rules! set_fc_bit_field { + ($field:ident, $bit:literal) => { + #[inline] + pub fn $field(&mut self, val: bool) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + raw |= ((val as u16) << $bit); + + data.copy_from_slice(&raw.to_le_bytes()); + } + }; +} + +impl<T: AsRef<[u8]>> Frame<T> { + /// Input a raw octet buffer with Ethernet frame structure. + pub const fn new_unchecked(buffer: T) -> Frame<T> { + Frame { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Frame<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + + // We don't handle unknown frame versions. + if matches!(packet.frame_version(), FrameVersion::Unknown(_)) { + return Err(Error); + } + + // We don't handle unknown addressing modes. + if matches!(packet.dst_addressing_mode(), AddressingMode::Unknown(_)) + || matches!(packet.src_addressing_mode(), AddressingMode::Unknown(_)) + { + return Err(Error); + } + + // We don't handle absent addressing mode with PAN ID compression for older frame versions. + if matches!( + packet.frame_version(), + FrameVersion::Ieee802154_2003 | FrameVersion::Ieee802154_2006 + ) && packet.pan_id_compression() + && matches!(packet.dst_addressing_mode(), AddressingMode::Absent) + && matches!(packet.src_addressing_mode(), AddressingMode::Absent) + { + return Err(Error); + } + + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + // We need at least 3 bytes + if self.buffer.as_ref().len() < 3 { + return Err(Error); + } + + // We don't handle frames with a payload larger than 127 bytes. + if self.buffer.as_ref().len() > 127 { + return Err(Error); + } + + let mut offset = field::ADDRESSING.start + + if let Some((dst_pan_id, dst_addr, src_pan_id, src_addr)) = self.addr_present_flags() + { + let mut offset = if dst_pan_id { 2 } else { 0 }; + offset += dst_addr.size(); + offset += if src_pan_id { 2 } else { 0 }; + offset += src_addr.size(); + + if offset > self.buffer.as_ref().len() { + return Err(Error); + } + offset + } else { + 0 + }; + + if self.security_enabled() { + // First check that we can access the security header control bits. + if offset + 1 > self.buffer.as_ref().len() { + return Err(Error); + } + + offset += self.security_header_len(); + } + + if offset > self.buffer.as_ref().len() { + return Err(Error); + } + + Ok(()) + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the FrameType field. + #[inline] + pub fn frame_type(&self) -> FrameType { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + let ft = (raw & 0b111) as u8; + FrameType::from(ft) + } + + fc_bit_field!(security_enabled, 3); + fc_bit_field!(frame_pending, 4); + fc_bit_field!(ack_request, 5); + fc_bit_field!(pan_id_compression, 6); + + fc_bit_field!(sequence_number_suppression, 8); + fc_bit_field!(ie_present, 9); + + /// Return the destination addressing mode. + #[inline] + pub fn dst_addressing_mode(&self) -> AddressingMode { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + let am = ((raw >> 10) & 0b11) as u8; + AddressingMode::from(am) + } + + /// Return the frame version. + #[inline] + pub fn frame_version(&self) -> FrameVersion { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + let fv = ((raw >> 12) & 0b11) as u8; + FrameVersion::from(fv) + } + + /// Return the source addressing mode. + #[inline] + pub fn src_addressing_mode(&self) -> AddressingMode { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + let am = ((raw >> 14) & 0b11) as u8; + AddressingMode::from(am) + } + + /// Return the sequence number of the frame. + #[inline] + pub fn sequence_number(&self) -> Option<u8> { + match self.frame_type() { + FrameType::Beacon + | FrameType::Data + | FrameType::Acknowledgement + | FrameType::MacCommand + | FrameType::Multipurpose => { + let data = self.buffer.as_ref(); + let raw = data[field::SEQUENCE_NUMBER]; + Some(raw) + } + FrameType::Extended | FrameType::FragmentOrFrak | FrameType::Unknown(_) => None, + } + } + + /// Return the addressing fields. + #[inline] + fn addressing_fields(&self) -> Option<&[u8]> { + match self.frame_type() { + FrameType::Beacon + | FrameType::Data + | FrameType::MacCommand + | FrameType::Multipurpose => (), + FrameType::Acknowledgement if self.frame_version() == FrameVersion::Ieee802154 => (), + FrameType::Acknowledgement + | FrameType::Extended + | FrameType::FragmentOrFrak + | FrameType::Unknown(_) => return None, + } + + if let Some((dst_pan_id, dst_addr, src_pan_id, src_addr)) = self.addr_present_flags() { + let mut offset = if dst_pan_id { 2 } else { 0 }; + offset += dst_addr.size(); + offset += if src_pan_id { 2 } else { 0 }; + offset += src_addr.size(); + + let data = self.buffer.as_ref(); + Some(&data[field::ADDRESSING][..offset]) + } else { + None + } + } + + fn addr_present_flags(&self) -> Option<(bool, AddressingMode, bool, AddressingMode)> { + let dst_addr_mode = self.dst_addressing_mode(); + let src_addr_mode = self.src_addressing_mode(); + let pan_id_compression = self.pan_id_compression(); + + use AddressingMode::*; + match self.frame_version() { + FrameVersion::Ieee802154_2003 | FrameVersion::Ieee802154_2006 => { + match (dst_addr_mode, src_addr_mode) { + (Absent, src) => Some((false, Absent, true, src)), + (dst, Absent) => Some((true, dst, false, Absent)), + + (dst, src) if pan_id_compression => Some((true, dst, false, src)), + (dst, src) if !pan_id_compression => Some((true, dst, true, src)), + _ => None, + } + } + FrameVersion::Ieee802154 => { + Some(match (dst_addr_mode, src_addr_mode, pan_id_compression) { + (Absent, Absent, false) => (false, Absent, false, Absent), + (Absent, Absent, true) => (true, Absent, false, Absent), + (dst, Absent, false) if !matches!(dst, Absent) => (true, dst, false, Absent), + (dst, Absent, true) if !matches!(dst, Absent) => (false, dst, false, Absent), + (Absent, src, false) if !matches!(src, Absent) => (false, Absent, true, src), + (Absent, src, true) if !matches!(src, Absent) => (false, Absent, true, src), + (Extended, Extended, false) => (true, Extended, false, Extended), + (Extended, Extended, true) => (false, Extended, false, Extended), + (Short, Short, false) => (true, Short, true, Short), + (Short, Extended, false) => (true, Short, true, Extended), + (Extended, Short, false) => (true, Extended, true, Short), + (Short, Extended, true) => (true, Short, false, Extended), + (Extended, Short, true) => (true, Extended, false, Short), + (Short, Short, true) => (true, Short, false, Short), + _ => return None, + }) + } + _ => None, + } + } + + /// Return the destination PAN field. + #[inline] + pub fn dst_pan_id(&self) -> Option<Pan> { + if let Some((true, _, _, _)) = self.addr_present_flags() { + let addressing_fields = self.addressing_fields()?; + Some(Pan(LittleEndian::read_u16(&addressing_fields[..2]))) + } else { + None + } + } + + /// Return the destination address field. + #[inline] + pub fn dst_addr(&self) -> Option<Address> { + if let Some((dst_pan_id, dst_addr, _, _)) = self.addr_present_flags() { + let addressing_fields = self.addressing_fields()?; + let offset = if dst_pan_id { 2 } else { 0 }; + + match dst_addr { + AddressingMode::Absent => Some(Address::Absent), + AddressingMode::Short => { + let mut raw = [0u8; 2]; + raw.clone_from_slice(&addressing_fields[offset..offset + 2]); + raw.reverse(); + Some(Address::short_from_bytes(raw)) + } + AddressingMode::Extended => { + let mut raw = [0u8; 8]; + raw.clone_from_slice(&addressing_fields[offset..offset + 8]); + raw.reverse(); + Some(Address::extended_from_bytes(raw)) + } + AddressingMode::Unknown(_) => None, + } + } else { + None + } + } + + /// Return the destination PAN field. + #[inline] + pub fn src_pan_id(&self) -> Option<Pan> { + if let Some((dst_pan_id, dst_addr, true, _)) = self.addr_present_flags() { + let mut offset = if dst_pan_id { 2 } else { 0 }; + offset += dst_addr.size(); + let addressing_fields = self.addressing_fields()?; + Some(Pan(LittleEndian::read_u16( + &addressing_fields[offset..][..2], + ))) + } else { + None + } + } + + /// Return the source address field. + #[inline] + pub fn src_addr(&self) -> Option<Address> { + if let Some((dst_pan_id, dst_addr, src_pan_id, src_addr)) = self.addr_present_flags() { + let addressing_fields = self.addressing_fields()?; + let mut offset = if dst_pan_id { 2 } else { 0 }; + offset += dst_addr.size(); + offset += if src_pan_id { 2 } else { 0 }; + + match src_addr { + AddressingMode::Absent => Some(Address::Absent), + AddressingMode::Short => { + let mut raw = [0u8; 2]; + raw.clone_from_slice(&addressing_fields[offset..offset + 2]); + raw.reverse(); + Some(Address::short_from_bytes(raw)) + } + AddressingMode::Extended => { + let mut raw = [0u8; 8]; + raw.clone_from_slice(&addressing_fields[offset..offset + 8]); + raw.reverse(); + Some(Address::extended_from_bytes(raw)) + } + AddressingMode::Unknown(_) => None, + } + } else { + None + } + } + + /// Return the index where the auxiliary security header starts. + fn aux_security_header_start(&self) -> usize { + // We start with 3, because 2 bytes for frame control and the sequence number. + let mut index = 3; + index += if let Some(addrs) = self.addressing_fields() { + addrs.len() + } else { + 0 + }; + index + } + + /// Return the size of the security header. + fn security_header_len(&self) -> usize { + let mut size = 1; + size += if self.frame_counter_suppressed() { + 0 + } else { + 4 + }; + size += if let Some(len) = self.key_identifier_length() { + len as usize + } else { + 0 + }; + size + } + + /// Return the index where the payload starts. + fn payload_start(&self) -> usize { + let mut index = self.aux_security_header_start(); + + if self.security_enabled() { + index += self.security_header_len(); + } + + index + } + + /// Return the length of the key identifier field. + fn key_identifier_length(&self) -> Option<u8> { + Some(match self.key_identifier_mode() { + 0 => 0, + 1 => 1, + 2 => 5, + 3 => 9, + _ => return None, + }) + } + + /// Return the security level of the auxiliary security header. + pub fn security_level(&self) -> u8 { + let index = self.aux_security_header_start(); + let b = self.buffer.as_ref()[index..][0]; + b & 0b111 + } + + /// Return the key identifier mode used by the auxiliary security header. + pub fn key_identifier_mode(&self) -> u8 { + let index = self.aux_security_header_start(); + let b = self.buffer.as_ref()[index..][0]; + (b >> 3) & 0b11 + } + + /// Return `true` when the frame counter in the security header is suppressed. + pub fn frame_counter_suppressed(&self) -> bool { + let index = self.aux_security_header_start(); + let b = self.buffer.as_ref()[index..][0]; + ((b >> 5) & 0b1) == 0b1 + } + + /// Return the frame counter field. + pub fn frame_counter(&self) -> Option<u32> { + if self.frame_counter_suppressed() { + None + } else { + let index = self.aux_security_header_start(); + let b = &self.buffer.as_ref()[index..]; + Some(LittleEndian::read_u32(&b[1..1 + 4])) + } + } + + /// Return the Key Identifier field. + fn key_identifier(&self) -> &[u8] { + let index = self.aux_security_header_start(); + let b = &self.buffer.as_ref()[index..]; + let length = if let Some(len) = self.key_identifier_length() { + len as usize + } else { + 0 + }; + &b[5..][..length] + } + + /// Return the Key Source field. + pub fn key_source(&self) -> Option<&[u8]> { + let ki = self.key_identifier(); + let len = ki.len(); + if len > 1 { + Some(&ki[..len - 1]) + } else { + None + } + } + + /// Return the Key Index field. + pub fn key_index(&self) -> Option<u8> { + let ki = self.key_identifier(); + let len = ki.len(); + + if len > 0 { + Some(ki[len - 1]) + } else { + None + } + } + + /// Return the Message Integrity Code (MIC). + pub fn message_integrity_code(&self) -> Option<&[u8]> { + let mic_len = match self.security_level() { + 0 | 4 => return None, + 1 | 5 => 4, + 2 | 6 => 8, + 3 | 7 => 16, + _ => panic!(), + }; + + let data = &self.buffer.as_ref(); + let len = data.len(); + + Some(&data[len - mic_len..]) + } + + /// Return the MAC header. + pub fn mac_header(&self) -> &[u8] { + let data = &self.buffer.as_ref(); + &data[..self.payload_start()] + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Frame<&'a T> { + /// Return a pointer to the payload. + #[inline] + pub fn payload(&self) -> Option<&'a [u8]> { + match self.frame_type() { + FrameType::Data => { + let index = self.payload_start(); + let data = &self.buffer.as_ref(); + + Some(&data[index..]) + } + _ => None, + } + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Frame<T> { + /// Set the frame type. + #[inline] + pub fn set_frame_type(&mut self, frame_type: FrameType) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + + raw = (raw & !(0b111)) | (u8::from(frame_type) as u16 & 0b111); + data.copy_from_slice(&raw.to_le_bytes()); + } + + set_fc_bit_field!(set_security_enabled, 3); + set_fc_bit_field!(set_frame_pending, 4); + set_fc_bit_field!(set_ack_request, 5); + set_fc_bit_field!(set_pan_id_compression, 6); + + /// Set the frame version. + #[inline] + pub fn set_frame_version(&mut self, version: FrameVersion) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + + raw = (raw & !(0b11 << 12)) | ((u8::from(version) as u16 & 0b11) << 12); + data.copy_from_slice(&raw.to_le_bytes()); + } + + /// Set the frame sequence number. + #[inline] + pub fn set_sequence_number(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::SEQUENCE_NUMBER] = value; + } + + /// Set the destination PAN ID. + #[inline] + pub fn set_dst_pan_id(&mut self, value: Pan) { + // NOTE the destination addressing mode must be different than Absent. + // This is the reason why we set it to Extended. + self.set_dst_addressing_mode(AddressingMode::Extended); + + let data = self.buffer.as_mut(); + data[field::ADDRESSING][..2].copy_from_slice(&value.as_bytes()); + } + + /// Set the destination address. + #[inline] + pub fn set_dst_addr(&mut self, value: Address) { + match value { + Address::Absent => self.set_dst_addressing_mode(AddressingMode::Absent), + Address::Short(mut value) => { + value.reverse(); + self.set_dst_addressing_mode(AddressingMode::Short); + let data = self.buffer.as_mut(); + data[field::ADDRESSING][2..2 + 2].copy_from_slice(&value); + value.reverse(); + } + Address::Extended(mut value) => { + value.reverse(); + self.set_dst_addressing_mode(AddressingMode::Extended); + let data = &mut self.buffer.as_mut()[field::ADDRESSING]; + data[2..2 + 8].copy_from_slice(&value); + value.reverse(); + } + } + } + + /// Set the destination addressing mode. + #[inline] + fn set_dst_addressing_mode(&mut self, value: AddressingMode) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + + raw = (raw & !(0b11 << 10)) | ((u8::from(value) as u16 & 0b11) << 10); + data.copy_from_slice(&raw.to_le_bytes()); + } + + /// Set the source PAN ID. + #[inline] + pub fn set_src_pan_id(&mut self, value: Pan) { + let offset = match self.dst_addressing_mode() { + AddressingMode::Absent => 0, + AddressingMode::Short => 2, + AddressingMode::Extended => 8, + _ => unreachable!(), + } + 2; + + let data = &mut self.buffer.as_mut()[field::ADDRESSING]; + data[offset..offset + 2].copy_from_slice(&value.as_bytes()); + } + + /// Set the source address. + #[inline] + pub fn set_src_addr(&mut self, value: Address) { + let offset = match self.dst_addressing_mode() { + AddressingMode::Absent => 0, + AddressingMode::Short => 2, + AddressingMode::Extended => 8, + _ => unreachable!(), + } + 2; + + let offset = offset + if self.pan_id_compression() { 0 } else { 2 }; + + match value { + Address::Absent => self.set_src_addressing_mode(AddressingMode::Absent), + Address::Short(mut value) => { + value.reverse(); + self.set_src_addressing_mode(AddressingMode::Short); + let data = &mut self.buffer.as_mut()[field::ADDRESSING]; + data[offset..offset + 2].copy_from_slice(&value); + value.reverse(); + } + Address::Extended(mut value) => { + value.reverse(); + self.set_src_addressing_mode(AddressingMode::Extended); + let data = &mut self.buffer.as_mut()[field::ADDRESSING]; + data[offset..offset + 8].copy_from_slice(&value); + value.reverse(); + } + } + } + + /// Set the source addressing mode. + #[inline] + fn set_src_addressing_mode(&mut self, value: AddressingMode) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + + raw = (raw & !(0b11 << 14)) | ((u8::from(value) as u16 & 0b11) << 14); + data.copy_from_slice(&raw.to_le_bytes()); + } + + /// Return a mutable pointer to the payload. + #[inline] + pub fn payload_mut(&mut self) -> Option<&mut [u8]> { + match self.frame_type() { + FrameType::Data => { + let index = self.payload_start(); + let data = self.buffer.as_mut(); + Some(&mut data[index..]) + } + _ => None, + } + } +} + +impl<T: AsRef<[u8]>> fmt::Display for Frame<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "IEEE802.15.4 frame type={}", self.frame_type())?; + + if let Some(seq) = self.sequence_number() { + write!(f, " seq={:02x}", seq)?; + } + + if let Some(pan) = self.dst_pan_id() { + write!(f, " dst-pan={}", pan)?; + } + + if let Some(pan) = self.src_pan_id() { + write!(f, " src-pan={}", pan)?; + } + + if let Some(addr) = self.dst_addr() { + write!(f, " dst={}", addr)?; + } + + if let Some(addr) = self.src_addr() { + write!(f, " src={}", addr)?; + } + + Ok(()) + } +} + +#[cfg(feature = "defmt")] +impl<T: AsRef<[u8]>> defmt::Format for Frame<T> { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "IEEE802.15.4 frame type={}", self.frame_type()); + + if let Some(seq) = self.sequence_number() { + defmt::write!(f, " seq={:02x}", seq); + } + + if let Some(pan) = self.dst_pan_id() { + defmt::write!(f, " dst-pan={}", pan); + } + + if let Some(pan) = self.src_pan_id() { + defmt::write!(f, " src-pan={}", pan); + } + + if let Some(addr) = self.dst_addr() { + defmt::write!(f, " dst={}", addr); + } + + if let Some(addr) = self.src_addr() { + defmt::write!(f, " src={}", addr); + } + } +} + +/// A high-level representation of an IEEE802.15.4 frame. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr { + pub frame_type: FrameType, + pub security_enabled: bool, + pub frame_pending: bool, + pub ack_request: bool, + pub sequence_number: Option<u8>, + pub pan_id_compression: bool, + pub frame_version: FrameVersion, + pub dst_pan_id: Option<Pan>, + pub dst_addr: Option<Address>, + pub src_pan_id: Option<Pan>, + pub src_addr: Option<Address>, +} + +impl Repr { + /// Parse an IEEE 802.15.4 frame and return a high-level representation. + pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Frame<&T>) -> Result<Repr> { + // Ensure the basic accessors will work. + packet.check_len()?; + + Ok(Repr { + frame_type: packet.frame_type(), + security_enabled: packet.security_enabled(), + frame_pending: packet.frame_pending(), + ack_request: packet.ack_request(), + sequence_number: packet.sequence_number(), + pan_id_compression: packet.pan_id_compression(), + frame_version: packet.frame_version(), + dst_pan_id: packet.dst_pan_id(), + dst_addr: packet.dst_addr(), + src_pan_id: packet.src_pan_id(), + src_addr: packet.src_addr(), + }) + } + + /// Return the length of a buffer required to hold a packet with the payload of a given length. + #[inline] + pub const fn buffer_len(&self) -> usize { + 3 + 2 + + match self.dst_addr { + Some(Address::Absent) | None => 0, + Some(Address::Short(_)) => 2, + Some(Address::Extended(_)) => 8, + } + + if !self.pan_id_compression { 2 } else { 0 } + + match self.src_addr { + Some(Address::Absent) | None => 0, + Some(Address::Short(_)) => 2, + Some(Address::Extended(_)) => 8, + } + } + + /// Emit a high-level representation into an IEEE802.15.4 frame. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, frame: &mut Frame<T>) { + frame.set_frame_type(self.frame_type); + frame.set_security_enabled(self.security_enabled); + frame.set_frame_pending(self.frame_pending); + frame.set_ack_request(self.ack_request); + frame.set_pan_id_compression(self.pan_id_compression); + frame.set_frame_version(self.frame_version); + + if let Some(sequence_number) = self.sequence_number { + frame.set_sequence_number(sequence_number); + } + + if let Some(dst_pan_id) = self.dst_pan_id { + frame.set_dst_pan_id(dst_pan_id); + } + if let Some(dst_addr) = self.dst_addr { + frame.set_dst_addr(dst_addr); + } + + if !self.pan_id_compression && self.src_pan_id.is_some() { + frame.set_src_pan_id(self.src_pan_id.unwrap()); + } + + if let Some(src_addr) = self.src_addr { + frame.set_src_addr(src_addr); + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_broadcast() { + assert!(Address::BROADCAST.is_broadcast()); + assert!(!Address::BROADCAST.is_unicast()); + } + + #[test] + fn prepare_frame() { + let mut buffer = [0u8; 128]; + + let repr = Repr { + frame_type: FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: true, + pan_id_compression: true, + frame_version: FrameVersion::Ieee802154, + sequence_number: Some(1), + dst_pan_id: Some(Pan(0xabcd)), + dst_addr: Some(Address::BROADCAST), + src_pan_id: None, + src_addr: Some(Address::Extended([ + 0xc7, 0xd9, 0xb5, 0x14, 0x00, 0x4b, 0x12, 0x00, + ])), + }; + + let buffer_len = repr.buffer_len(); + + let mut frame = Frame::new_unchecked(&mut buffer[..buffer_len]); + repr.emit(&mut frame); + + println!("{frame:2x?}"); + + assert_eq!(frame.frame_type(), FrameType::Data); + assert!(!frame.security_enabled()); + assert!(!frame.frame_pending()); + assert!(frame.ack_request()); + assert!(frame.pan_id_compression()); + assert_eq!(frame.frame_version(), FrameVersion::Ieee802154); + assert_eq!(frame.sequence_number(), Some(1)); + assert_eq!(frame.dst_pan_id(), Some(Pan(0xabcd))); + assert_eq!(frame.dst_addr(), Some(Address::BROADCAST)); + assert_eq!(frame.src_pan_id(), None); + assert_eq!( + frame.src_addr(), + Some(Address::Extended([ + 0xc7, 0xd9, 0xb5, 0x14, 0x00, 0x4b, 0x12, 0x00 + ])) + ); + } + + macro_rules! vector_test { + ($name:ident $bytes:expr ; $($test_method:ident -> $expected:expr,)*) => { + #[test] + #[allow(clippy::bool_assert_comparison)] + fn $name() -> Result<()> { + let frame = &$bytes; + let frame = Frame::new_checked(frame)?; + + $( + assert_eq!(frame.$test_method(), $expected, stringify!($test_method)); + )* + + Ok(()) + } + } + } + + vector_test! { + extended_addr + [ + 0b0000_0001, 0b1100_1100, // frame control + 0b0, // seq + 0xcd, 0xab, // pan id + 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, // dst addr + 0x03, 0x04, // pan id + 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x02, // src addr + ]; + frame_type -> FrameType::Data, + dst_addr -> Some(Address::Extended([0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00])), + src_addr -> Some(Address::Extended([0x02, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00])), + dst_pan_id -> Some(Pan(0xabcd)), + } + + vector_test! { + short_addr + [ + 0x01, 0x98, // frame control + 0x00, // sequence number + 0x34, 0x12, 0x78, 0x56, // PAN identifier and address of destination + 0x34, 0x12, 0xbc, 0x9a, // PAN identifier and address of source + ]; + frame_type -> FrameType::Data, + security_enabled -> false, + frame_pending -> false, + ack_request -> false, + pan_id_compression -> false, + dst_addressing_mode -> AddressingMode::Short, + frame_version -> FrameVersion::Ieee802154_2006, + src_addressing_mode -> AddressingMode::Short, + dst_pan_id -> Some(Pan(0x1234)), + dst_addr -> Some(Address::Short([0x56, 0x78])), + src_pan_id -> Some(Pan(0x1234)), + src_addr -> Some(Address::Short([0x9a, 0xbc])), + } + + vector_test! { + zolertia_remote + [ + 0x41, 0xd8, // frame control + 0x01, // sequence number + 0xcd, 0xab, // Destination PAN id + 0xff, 0xff, // Short destination address + 0xc7, 0xd9, 0xb5, 0x14, 0x00, 0x4b, 0x12, 0x00, // Extended source address + 0x2b, 0x00, 0x00, 0x00, // payload + ]; + frame_type -> FrameType::Data, + security_enabled -> false, + frame_pending -> false, + ack_request -> false, + pan_id_compression -> true, + dst_addressing_mode -> AddressingMode::Short, + frame_version -> FrameVersion::Ieee802154_2006, + src_addressing_mode -> AddressingMode::Extended, + payload -> Some(&[0x2b, 0x00, 0x00, 0x00][..]), + } + + vector_test! { + security + [ + 0x69,0xdc, // frame control + 0x32, // sequence number + 0xcd,0xab, // destination PAN id + 0xbf,0x9b,0x15,0x06,0x00,0x4b,0x12,0x00, // extended destination address + 0xc7,0xd9,0xb5,0x14,0x00,0x4b,0x12,0x00, // extended source address + 0x05, // security control field + 0x31,0x01,0x00,0x00, // frame counter + 0x3e,0xe8,0xfb,0x85,0xe4,0xcc,0xf4,0x48,0x90,0xfe,0x56,0x66,0xf7,0x1c,0x65,0x9e,0xf9, // data + 0x93,0xc8,0x34,0x2e,// MIC + ]; + frame_type -> FrameType::Data, + security_enabled -> true, + frame_pending -> false, + ack_request -> true, + pan_id_compression -> true, + dst_addressing_mode -> AddressingMode::Extended, + frame_version -> FrameVersion::Ieee802154_2006, + src_addressing_mode -> AddressingMode::Extended, + dst_pan_id -> Some(Pan(0xabcd)), + dst_addr -> Some(Address::Extended([0x00,0x12,0x4b,0x00,0x06,0x15,0x9b,0xbf])), + src_pan_id -> None, + src_addr -> Some(Address::Extended([0x00,0x12,0x4b,0x00,0x14,0xb5,0xd9,0xc7])), + security_level -> 5, + key_identifier_mode -> 0, + frame_counter -> Some(305), + key_source -> None, + key_index -> None, + payload -> Some(&[0x3e,0xe8,0xfb,0x85,0xe4,0xcc,0xf4,0x48,0x90,0xfe,0x56,0x66,0xf7,0x1c,0x65,0x9e,0xf9,0x93,0xc8,0x34,0x2e][..]), + message_integrity_code -> Some(&[0x93, 0xC8, 0x34, 0x2E][..]), + mac_header -> &[ + 0x69,0xdc, // frame control + 0x32, // sequence number + 0xcd,0xab, // destination PAN id + 0xbf,0x9b,0x15,0x06,0x00,0x4b,0x12,0x00, // extended destination address + 0xc7,0xd9,0xb5,0x14,0x00,0x4b,0x12,0x00, // extended source address + 0x05, // security control field + 0x31,0x01,0x00,0x00, // frame counter + ][..], + } +} diff --git a/src/wire/igmp.rs b/src/wire/igmp.rs new file mode 100644 index 0000000..ac13ece --- /dev/null +++ b/src/wire/igmp.rs @@ -0,0 +1,445 @@ +use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; + +use super::{Error, Result}; +use crate::time::Duration; +use crate::wire::ip::checksum; + +use crate::wire::Ipv4Address; + +enum_with_unknown! { + /// Internet Group Management Protocol v1/v2 message version/type. + pub enum Message(u8) { + /// Membership Query + MembershipQuery = 0x11, + /// Version 2 Membership Report + MembershipReportV2 = 0x16, + /// Leave Group + LeaveGroup = 0x17, + /// Version 1 Membership Report + MembershipReportV1 = 0x12 + } +} + +/// A read/write wrapper around an Internet Group Management Protocol v1/v2 packet buffer. +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +mod field { + use crate::wire::field::*; + + pub const TYPE: usize = 0; + pub const MAX_RESP_CODE: usize = 1; + pub const CHECKSUM: Field = 2..4; + pub const GROUP_ADDRESS: Field = 4..8; +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Message::MembershipQuery => write!(f, "membership query"), + Message::MembershipReportV2 => write!(f, "version 2 membership report"), + Message::LeaveGroup => write!(f, "leave group"), + Message::MembershipReportV1 => write!(f, "version 1 membership report"), + Message::Unknown(id) => write!(f, "{id}"), + } + } +} + +/// Internet Group Management Protocol v1/v2 defined in [RFC 2236]. +/// +/// [RFC 2236]: https://tools.ietf.org/html/rfc2236 +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with IGMPv2 packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::GROUP_ADDRESS.end { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the message type field. + #[inline] + pub fn msg_type(&self) -> Message { + let data = self.buffer.as_ref(); + Message::from(data[field::TYPE]) + } + + /// Return the maximum response time, using the encoding specified in + /// [RFC 3376]: 4.1.1. Max Resp Code. + /// + /// [RFC 3376]: https://tools.ietf.org/html/rfc3376 + #[inline] + pub fn max_resp_code(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::MAX_RESP_CODE] + } + + /// Return the checksum field. + #[inline] + pub fn checksum(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::CHECKSUM]) + } + + /// Return the source address field. + #[inline] + pub fn group_addr(&self) -> Ipv4Address { + let data = self.buffer.as_ref(); + Ipv4Address::from_bytes(&data[field::GROUP_ADDRESS]) + } + + /// Validate the header checksum. + /// + /// # Fuzzing + /// This function always returns `true` when fuzzing. + pub fn verify_checksum(&self) -> bool { + if cfg!(fuzzing) { + return true; + } + + let data = self.buffer.as_ref(); + checksum::data(data) == !0 + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the message type field. + #[inline] + pub fn set_msg_type(&mut self, value: Message) { + let data = self.buffer.as_mut(); + data[field::TYPE] = value.into() + } + + /// Set the maximum response time, using the encoding specified in + /// [RFC 3376]: 4.1.1. Max Resp Code. + #[inline] + pub fn set_max_resp_code(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::MAX_RESP_CODE] = value; + } + + /// Set the checksum field. + #[inline] + pub fn set_checksum(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::CHECKSUM], value) + } + + /// Set the group address field + #[inline] + pub fn set_group_address(&mut self, addr: Ipv4Address) { + let data = self.buffer.as_mut(); + data[field::GROUP_ADDRESS].copy_from_slice(addr.as_bytes()); + } + + /// Compute and fill in the header checksum. + pub fn fill_checksum(&mut self) { + self.set_checksum(0); + let checksum = { + let data = self.buffer.as_ref(); + !checksum::data(data) + }; + self.set_checksum(checksum) + } +} + +/// A high-level representation of an Internet Group Management Protocol v1/v2 header. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Repr { + MembershipQuery { + max_resp_time: Duration, + group_addr: Ipv4Address, + version: IgmpVersion, + }, + MembershipReport { + group_addr: Ipv4Address, + version: IgmpVersion, + }, + LeaveGroup { + group_addr: Ipv4Address, + }, +} + +/// Type of IGMP membership report version +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum IgmpVersion { + /// IGMPv1 + Version1, + /// IGMPv2 + Version2, +} + +impl Repr { + /// Parse an Internet Group Management Protocol v1/v2 packet and return + /// a high-level representation. + pub fn parse<T>(packet: &Packet<&T>) -> Result<Repr> + where + T: AsRef<[u8]> + ?Sized, + { + // Check if the address is 0.0.0.0 or multicast + let addr = packet.group_addr(); + if !addr.is_unspecified() && !addr.is_multicast() { + return Err(Error); + } + + // construct a packet based on the Type field + match packet.msg_type() { + Message::MembershipQuery => { + let max_resp_time = max_resp_code_to_duration(packet.max_resp_code()); + // See RFC 3376: 7.1. Query Version Distinctions + let version = if packet.max_resp_code() == 0 { + IgmpVersion::Version1 + } else { + IgmpVersion::Version2 + }; + Ok(Repr::MembershipQuery { + max_resp_time, + group_addr: addr, + version, + }) + } + Message::MembershipReportV2 => Ok(Repr::MembershipReport { + group_addr: packet.group_addr(), + version: IgmpVersion::Version2, + }), + Message::LeaveGroup => Ok(Repr::LeaveGroup { + group_addr: packet.group_addr(), + }), + Message::MembershipReportV1 => { + // for backwards compatibility with IGMPv1 + Ok(Repr::MembershipReport { + group_addr: packet.group_addr(), + version: IgmpVersion::Version1, + }) + } + _ => Err(Error), + } + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + // always 8 bytes + field::GROUP_ADDRESS.end + } + + /// Emit a high-level representation into an Internet Group Management Protocol v2 packet. + pub fn emit<T>(&self, packet: &mut Packet<&mut T>) + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + match *self { + Repr::MembershipQuery { + max_resp_time, + group_addr, + version, + } => { + packet.set_msg_type(Message::MembershipQuery); + match version { + IgmpVersion::Version1 => packet.set_max_resp_code(0), + IgmpVersion::Version2 => { + packet.set_max_resp_code(duration_to_max_resp_code(max_resp_time)) + } + } + packet.set_group_address(group_addr); + } + Repr::MembershipReport { + group_addr, + version, + } => { + match version { + IgmpVersion::Version1 => packet.set_msg_type(Message::MembershipReportV1), + IgmpVersion::Version2 => packet.set_msg_type(Message::MembershipReportV2), + }; + packet.set_max_resp_code(0); + packet.set_group_address(group_addr); + } + Repr::LeaveGroup { group_addr } => { + packet.set_msg_type(Message::LeaveGroup); + packet.set_group_address(group_addr); + } + } + + packet.fill_checksum() + } +} + +fn max_resp_code_to_duration(value: u8) -> Duration { + let value: u64 = value.into(); + let decisecs = if value < 128 { + value + } else { + let mant = value & 0xF; + let exp = (value >> 4) & 0x7; + (mant | 0x10) << (exp + 3) + }; + Duration::from_millis(decisecs * 100) +} + +const fn duration_to_max_resp_code(duration: Duration) -> u8 { + let decisecs = duration.total_millis() / 100; + if decisecs < 128 { + decisecs as u8 + } else if decisecs < 31744 { + let mut mant = decisecs >> 3; + let mut exp = 0u8; + while mant > 0x1F && exp < 0x8 { + mant >>= 1; + exp += 1; + } + 0x80 | (exp << 4) | (mant as u8 & 0xF) + } else { + 0xFF + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match Repr::parse(self) { + Ok(repr) => write!(f, "{repr}"), + Err(err) => write!(f, "IGMP ({err})"), + } + } +} + +impl fmt::Display for Repr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Repr::MembershipQuery { + max_resp_time, + group_addr, + version, + } => write!( + f, + "IGMP membership query max_resp_time={max_resp_time} group_addr={group_addr} version={version:?}" + ), + Repr::MembershipReport { + group_addr, + version, + } => write!( + f, + "IGMP membership report group_addr={group_addr} version={version:?}" + ), + Repr::LeaveGroup { group_addr } => { + write!(f, "IGMP leave group group_addr={group_addr})") + } + } + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + match Packet::new_checked(buffer) { + Err(err) => writeln!(f, "{indent}({err})"), + Ok(packet) => writeln!(f, "{indent}{packet}"), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + static LEAVE_PACKET_BYTES: [u8; 8] = [0x17, 0x00, 0x02, 0x69, 0xe0, 0x00, 0x06, 0x96]; + static REPORT_PACKET_BYTES: [u8; 8] = [0x16, 0x00, 0x08, 0xda, 0xe1, 0x00, 0x00, 0x25]; + + #[test] + fn test_leave_group_deconstruct() { + let packet = Packet::new_unchecked(&LEAVE_PACKET_BYTES[..]); + assert_eq!(packet.msg_type(), Message::LeaveGroup); + assert_eq!(packet.max_resp_code(), 0); + assert_eq!(packet.checksum(), 0x269); + assert_eq!( + packet.group_addr(), + Ipv4Address::from_bytes(&[224, 0, 6, 150]) + ); + assert!(packet.verify_checksum()); + } + + #[test] + fn test_report_deconstruct() { + let packet = Packet::new_unchecked(&REPORT_PACKET_BYTES[..]); + assert_eq!(packet.msg_type(), Message::MembershipReportV2); + assert_eq!(packet.max_resp_code(), 0); + assert_eq!(packet.checksum(), 0x08da); + assert_eq!( + packet.group_addr(), + Ipv4Address::from_bytes(&[225, 0, 0, 37]) + ); + assert!(packet.verify_checksum()); + } + + #[test] + fn test_leave_construct() { + let mut bytes = vec![0xa5; 8]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_msg_type(Message::LeaveGroup); + packet.set_max_resp_code(0); + packet.set_group_address(Ipv4Address::from_bytes(&[224, 0, 6, 150])); + packet.fill_checksum(); + assert_eq!(&*packet.into_inner(), &LEAVE_PACKET_BYTES[..]); + } + + #[test] + fn test_report_construct() { + let mut bytes = vec![0xa5; 8]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_msg_type(Message::MembershipReportV2); + packet.set_max_resp_code(0); + packet.set_group_address(Ipv4Address::from_bytes(&[225, 0, 0, 37])); + packet.fill_checksum(); + assert_eq!(&*packet.into_inner(), &REPORT_PACKET_BYTES[..]); + } + + #[test] + fn max_resp_time_to_duration_and_back() { + for i in 0..256usize { + let time1 = i as u8; + let duration = max_resp_code_to_duration(time1); + let time2 = duration_to_max_resp_code(duration); + assert!(time1 == time2); + } + } + + #[test] + fn duration_to_max_resp_time_max() { + for duration in 31744..65536 { + let time = duration_to_max_resp_code(Duration::from_millis(duration * 100)); + assert_eq!(time, 0xFF); + } + } +} diff --git a/src/wire/ip.rs b/src/wire/ip.rs new file mode 100644 index 0000000..da80aba --- /dev/null +++ b/src/wire/ip.rs @@ -0,0 +1,998 @@ +use core::convert::From; +use core::fmt; + +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +#[cfg(feature = "proto-ipv4")] +use crate::wire::{Ipv4Address, Ipv4Cidr, Ipv4Packet, Ipv4Repr}; +#[cfg(feature = "proto-ipv6")] +use crate::wire::{Ipv6Address, Ipv6Cidr, Ipv6Packet, Ipv6Repr}; + +/// Internet protocol version. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Version { + #[cfg(feature = "proto-ipv4")] + Ipv4, + #[cfg(feature = "proto-ipv6")] + Ipv6, +} + +impl Version { + /// Return the version of an IP packet stored in the provided buffer. + /// + /// This function never returns `Ok(IpVersion::Unspecified)`; instead, + /// unknown versions result in `Err(Error)`. + pub const fn of_packet(data: &[u8]) -> Result<Version> { + match data[0] >> 4 { + #[cfg(feature = "proto-ipv4")] + 4 => Ok(Version::Ipv4), + #[cfg(feature = "proto-ipv6")] + 6 => Ok(Version::Ipv6), + _ => Err(Error), + } + } +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + #[cfg(feature = "proto-ipv4")] + Version::Ipv4 => write!(f, "IPv4"), + #[cfg(feature = "proto-ipv6")] + Version::Ipv6 => write!(f, "IPv6"), + } + } +} + +enum_with_unknown! { + /// IP datagram encapsulated protocol. + pub enum Protocol(u8) { + HopByHop = 0x00, + Icmp = 0x01, + Igmp = 0x02, + Tcp = 0x06, + Udp = 0x11, + Ipv6Route = 0x2b, + Ipv6Frag = 0x2c, + IpSecEsp = 0x32, + IpSecAh = 0x33, + Icmpv6 = 0x3a, + Ipv6NoNxt = 0x3b, + Ipv6Opts = 0x3c + } +} + +impl fmt::Display for Protocol { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Protocol::HopByHop => write!(f, "Hop-by-Hop"), + Protocol::Icmp => write!(f, "ICMP"), + Protocol::Igmp => write!(f, "IGMP"), + Protocol::Tcp => write!(f, "TCP"), + Protocol::Udp => write!(f, "UDP"), + Protocol::Ipv6Route => write!(f, "IPv6-Route"), + Protocol::Ipv6Frag => write!(f, "IPv6-Frag"), + Protocol::IpSecEsp => write!(f, "IPsec-ESP"), + Protocol::IpSecAh => write!(f, "IPsec-AH"), + Protocol::Icmpv6 => write!(f, "ICMPv6"), + Protocol::Ipv6NoNxt => write!(f, "IPv6-NoNxt"), + Protocol::Ipv6Opts => write!(f, "IPv6-Opts"), + Protocol::Unknown(id) => write!(f, "0x{id:02x}"), + } + } +} + +/// An internetworking address. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum Address { + /// An IPv4 address. + #[cfg(feature = "proto-ipv4")] + Ipv4(Ipv4Address), + /// An IPv6 address. + #[cfg(feature = "proto-ipv6")] + Ipv6(Ipv6Address), +} + +impl Address { + /// Create an address wrapping an IPv4 address with the given octets. + #[cfg(feature = "proto-ipv4")] + pub const fn v4(a0: u8, a1: u8, a2: u8, a3: u8) -> Address { + Address::Ipv4(Ipv4Address::new(a0, a1, a2, a3)) + } + + /// Create an address wrapping an IPv6 address with the given octets. + #[cfg(feature = "proto-ipv6")] + #[allow(clippy::too_many_arguments)] + pub fn v6(a0: u16, a1: u16, a2: u16, a3: u16, a4: u16, a5: u16, a6: u16, a7: u16) -> Address { + Address::Ipv6(Ipv6Address::new(a0, a1, a2, a3, a4, a5, a6, a7)) + } + + /// Return the protocol version. + pub const fn version(&self) -> Version { + match self { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(_) => Version::Ipv4, + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(_) => Version::Ipv6, + } + } + + /// Return an address as a sequence of octets, in big-endian. + pub const fn as_bytes(&self) -> &[u8] { + match self { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(addr) => addr.as_bytes(), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(addr) => addr.as_bytes(), + } + } + + /// Query whether the address is a valid unicast address. + pub fn is_unicast(&self) -> bool { + match self { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(addr) => addr.is_unicast(), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(addr) => addr.is_unicast(), + } + } + + /// Query whether the address is a valid multicast address. + pub const fn is_multicast(&self) -> bool { + match self { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(addr) => addr.is_multicast(), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(addr) => addr.is_multicast(), + } + } + + /// Query whether the address is the broadcast address. + pub fn is_broadcast(&self) -> bool { + match self { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(addr) => addr.is_broadcast(), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(_) => false, + } + } + + /// Query whether the address falls into the "unspecified" range. + pub fn is_unspecified(&self) -> bool { + match self { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(addr) => addr.is_unspecified(), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(addr) => addr.is_unspecified(), + } + } + + /// If `self` is a CIDR-compatible subnet mask, return `Some(prefix_len)`, + /// where `prefix_len` is the number of leading zeroes. Return `None` otherwise. + pub fn prefix_len(&self) -> Option<u8> { + let mut ones = true; + let mut prefix_len = 0; + for byte in self.as_bytes() { + let mut mask = 0x80; + for _ in 0..8 { + let one = *byte & mask != 0; + if ones { + // Expect 1s until first 0 + if one { + prefix_len += 1; + } else { + ones = false; + } + } else if one { + // 1 where 0 was expected + return None; + } + mask >>= 1; + } + } + Some(prefix_len) + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv4", feature = "proto-ipv6"))] +impl From<::std::net::IpAddr> for Address { + fn from(x: ::std::net::IpAddr) -> Address { + match x { + ::std::net::IpAddr::V4(ipv4) => Address::Ipv4(ipv4.into()), + ::std::net::IpAddr::V6(ipv6) => Address::Ipv6(ipv6.into()), + } + } +} + +#[cfg(feature = "std")] +impl From<Address> for ::std::net::IpAddr { + fn from(x: Address) -> ::std::net::IpAddr { + match x { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(ipv4) => ::std::net::IpAddr::V4(ipv4.into()), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(ipv6) => ::std::net::IpAddr::V6(ipv6.into()), + } + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv4"))] +impl From<::std::net::Ipv4Addr> for Address { + fn from(ipv4: ::std::net::Ipv4Addr) -> Address { + Address::Ipv4(ipv4.into()) + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv6"))] +impl From<::std::net::Ipv6Addr> for Address { + fn from(ipv6: ::std::net::Ipv6Addr) -> Address { + Address::Ipv6(ipv6.into()) + } +} + +#[cfg(feature = "proto-ipv4")] +impl From<Ipv4Address> for Address { + fn from(addr: Ipv4Address) -> Self { + Address::Ipv4(addr) + } +} + +#[cfg(feature = "proto-ipv6")] +impl From<Ipv6Address> for Address { + fn from(addr: Ipv6Address) -> Self { + Address::Ipv6(addr) + } +} + +impl fmt::Display for Address { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(addr) => write!(f, "{addr}"), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(addr) => write!(f, "{addr}"), + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Address { + fn format(&self, f: defmt::Formatter) { + match self { + #[cfg(feature = "proto-ipv4")] + &Address::Ipv4(addr) => defmt::write!(f, "{:?}", addr), + #[cfg(feature = "proto-ipv6")] + &Address::Ipv6(addr) => defmt::write!(f, "{:?}", addr), + } + } +} + +/// A specification of a CIDR block, containing an address and a variable-length +/// subnet masking prefix length. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum Cidr { + #[cfg(feature = "proto-ipv4")] + Ipv4(Ipv4Cidr), + #[cfg(feature = "proto-ipv6")] + Ipv6(Ipv6Cidr), +} + +impl Cidr { + /// Create a CIDR block from the given address and prefix length. + /// + /// # Panics + /// This function panics if the given prefix length is invalid for the given address. + pub fn new(addr: Address, prefix_len: u8) -> Cidr { + match addr { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(addr) => Cidr::Ipv4(Ipv4Cidr::new(addr, prefix_len)), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(addr) => Cidr::Ipv6(Ipv6Cidr::new(addr, prefix_len)), + } + } + + /// Return the IP address of this CIDR block. + pub const fn address(&self) -> Address { + match *self { + #[cfg(feature = "proto-ipv4")] + Cidr::Ipv4(cidr) => Address::Ipv4(cidr.address()), + #[cfg(feature = "proto-ipv6")] + Cidr::Ipv6(cidr) => Address::Ipv6(cidr.address()), + } + } + + /// Return the prefix length of this CIDR block. + pub const fn prefix_len(&self) -> u8 { + match *self { + #[cfg(feature = "proto-ipv4")] + Cidr::Ipv4(cidr) => cidr.prefix_len(), + #[cfg(feature = "proto-ipv6")] + Cidr::Ipv6(cidr) => cidr.prefix_len(), + } + } + + /// Query whether the subnetwork described by this CIDR block contains + /// the given address. + pub fn contains_addr(&self, addr: &Address) -> bool { + match (self, addr) { + #[cfg(feature = "proto-ipv4")] + (Cidr::Ipv4(cidr), Address::Ipv4(addr)) => cidr.contains_addr(addr), + #[cfg(feature = "proto-ipv6")] + (Cidr::Ipv6(cidr), Address::Ipv6(addr)) => cidr.contains_addr(addr), + #[allow(unreachable_patterns)] + _ => false, + } + } + + /// Query whether the subnetwork described by this CIDR block contains + /// the subnetwork described by the given CIDR block. + pub fn contains_subnet(&self, subnet: &Cidr) -> bool { + match (self, subnet) { + #[cfg(feature = "proto-ipv4")] + (Cidr::Ipv4(cidr), Cidr::Ipv4(other)) => cidr.contains_subnet(other), + #[cfg(feature = "proto-ipv6")] + (Cidr::Ipv6(cidr), Cidr::Ipv6(other)) => cidr.contains_subnet(other), + #[allow(unreachable_patterns)] + _ => false, + } + } +} + +#[cfg(feature = "proto-ipv4")] +impl From<Ipv4Cidr> for Cidr { + fn from(addr: Ipv4Cidr) -> Self { + Cidr::Ipv4(addr) + } +} + +#[cfg(feature = "proto-ipv6")] +impl From<Ipv6Cidr> for Cidr { + fn from(addr: Ipv6Cidr) -> Self { + Cidr::Ipv6(addr) + } +} + +impl fmt::Display for Cidr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + #[cfg(feature = "proto-ipv4")] + Cidr::Ipv4(cidr) => write!(f, "{cidr}"), + #[cfg(feature = "proto-ipv6")] + Cidr::Ipv6(cidr) => write!(f, "{cidr}"), + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Cidr { + fn format(&self, f: defmt::Formatter) { + match self { + #[cfg(feature = "proto-ipv4")] + &Cidr::Ipv4(cidr) => defmt::write!(f, "{:?}", cidr), + #[cfg(feature = "proto-ipv6")] + &Cidr::Ipv6(cidr) => defmt::write!(f, "{:?}", cidr), + } + } +} + +/// An internet endpoint address. +/// +/// `Endpoint` always fully specifies both the address and the port. +/// +/// See also ['ListenEndpoint'], which allows not specifying the address +/// in order to listen on a given port on any address. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub struct Endpoint { + pub addr: Address, + pub port: u16, +} + +impl Endpoint { + /// Create an endpoint address from given address and port. + pub const fn new(addr: Address, port: u16) -> Endpoint { + Endpoint { addr: addr, port } + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv4", feature = "proto-ipv6"))] +impl From<::std::net::SocketAddr> for Endpoint { + fn from(x: ::std::net::SocketAddr) -> Endpoint { + Endpoint { + addr: x.ip().into(), + port: x.port(), + } + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv4"))] +impl From<::std::net::SocketAddrV4> for Endpoint { + fn from(x: ::std::net::SocketAddrV4) -> Endpoint { + Endpoint { + addr: (*x.ip()).into(), + port: x.port(), + } + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv6"))] +impl From<::std::net::SocketAddrV6> for Endpoint { + fn from(x: ::std::net::SocketAddrV6) -> Endpoint { + Endpoint { + addr: (*x.ip()).into(), + port: x.port(), + } + } +} + +impl fmt::Display for Endpoint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}:{}", self.addr, self.port) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Endpoint { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{:?}:{=u16}", self.addr, self.port); + } +} + +impl<T: Into<Address>> From<(T, u16)> for Endpoint { + fn from((addr, port): (T, u16)) -> Endpoint { + Endpoint { + addr: addr.into(), + port, + } + } +} + +/// An internet endpoint address for listening. +/// +/// In contrast with [`Endpoint`], `ListenEndpoint` allows not specifying the address, +/// in order to listen on a given port at all our addresses. +/// +/// An endpoint can be constructed from a port, in which case the address is unspecified. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +pub struct ListenEndpoint { + pub addr: Option<Address>, + pub port: u16, +} + +impl ListenEndpoint { + /// Query whether the endpoint has a specified address and port. + pub const fn is_specified(&self) -> bool { + self.addr.is_some() && self.port != 0 + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv4", feature = "proto-ipv6"))] +impl From<::std::net::SocketAddr> for ListenEndpoint { + fn from(x: ::std::net::SocketAddr) -> ListenEndpoint { + ListenEndpoint { + addr: Some(x.ip().into()), + port: x.port(), + } + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv4"))] +impl From<::std::net::SocketAddrV4> for ListenEndpoint { + fn from(x: ::std::net::SocketAddrV4) -> ListenEndpoint { + ListenEndpoint { + addr: Some((*x.ip()).into()), + port: x.port(), + } + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv6"))] +impl From<::std::net::SocketAddrV6> for ListenEndpoint { + fn from(x: ::std::net::SocketAddrV6) -> ListenEndpoint { + ListenEndpoint { + addr: Some((*x.ip()).into()), + port: x.port(), + } + } +} + +impl fmt::Display for ListenEndpoint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(addr) = self.addr { + write!(f, "{}:{}", addr, self.port) + } else { + write!(f, "*:{}", self.port) + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for ListenEndpoint { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{:?}:{=u16}", self.addr, self.port); + } +} + +impl From<u16> for ListenEndpoint { + fn from(port: u16) -> ListenEndpoint { + ListenEndpoint { addr: None, port } + } +} + +impl From<Endpoint> for ListenEndpoint { + fn from(endpoint: Endpoint) -> ListenEndpoint { + ListenEndpoint { + addr: Some(endpoint.addr), + port: endpoint.port, + } + } +} + +impl<T: Into<Address>> From<(T, u16)> for ListenEndpoint { + fn from((addr, port): (T, u16)) -> ListenEndpoint { + ListenEndpoint { + addr: Some(addr.into()), + port, + } + } +} + +/// An IP packet representation. +/// +/// This enum abstracts the various versions of IP packets. It either contains an IPv4 +/// or IPv6 concrete high-level representation. +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Repr { + #[cfg(feature = "proto-ipv4")] + Ipv4(Ipv4Repr), + #[cfg(feature = "proto-ipv6")] + Ipv6(Ipv6Repr), +} + +#[cfg(feature = "proto-ipv4")] +impl From<Ipv4Repr> for Repr { + fn from(repr: Ipv4Repr) -> Repr { + Repr::Ipv4(repr) + } +} + +#[cfg(feature = "proto-ipv6")] +impl From<Ipv6Repr> for Repr { + fn from(repr: Ipv6Repr) -> Repr { + Repr::Ipv6(repr) + } +} + +impl Repr { + /// Create a new IpRepr, choosing the right IP version for the src/dst addrs. + /// + /// # Panics + /// + /// Panics if `src_addr` and `dst_addr` are different IP version. + pub fn new( + src_addr: Address, + dst_addr: Address, + next_header: Protocol, + payload_len: usize, + hop_limit: u8, + ) -> Self { + match (src_addr, dst_addr) { + #[cfg(feature = "proto-ipv4")] + (Address::Ipv4(src_addr), Address::Ipv4(dst_addr)) => Self::Ipv4(Ipv4Repr { + src_addr, + dst_addr, + next_header, + payload_len, + hop_limit, + }), + #[cfg(feature = "proto-ipv6")] + (Address::Ipv6(src_addr), Address::Ipv6(dst_addr)) => Self::Ipv6(Ipv6Repr { + src_addr, + dst_addr, + next_header, + payload_len, + hop_limit, + }), + #[allow(unreachable_patterns)] + _ => panic!("IP version mismatch: src={src_addr:?} dst={dst_addr:?}"), + } + } + + /// Return the protocol version. + pub const fn version(&self) -> Version { + match *self { + #[cfg(feature = "proto-ipv4")] + Repr::Ipv4(_) => Version::Ipv4, + #[cfg(feature = "proto-ipv6")] + Repr::Ipv6(_) => Version::Ipv6, + } + } + + /// Return the source address. + pub const fn src_addr(&self) -> Address { + match *self { + #[cfg(feature = "proto-ipv4")] + Repr::Ipv4(repr) => Address::Ipv4(repr.src_addr), + #[cfg(feature = "proto-ipv6")] + Repr::Ipv6(repr) => Address::Ipv6(repr.src_addr), + } + } + + /// Return the destination address. + pub const fn dst_addr(&self) -> Address { + match *self { + #[cfg(feature = "proto-ipv4")] + Repr::Ipv4(repr) => Address::Ipv4(repr.dst_addr), + #[cfg(feature = "proto-ipv6")] + Repr::Ipv6(repr) => Address::Ipv6(repr.dst_addr), + } + } + + /// Return the next header (protocol). + pub const fn next_header(&self) -> Protocol { + match *self { + #[cfg(feature = "proto-ipv4")] + Repr::Ipv4(repr) => repr.next_header, + #[cfg(feature = "proto-ipv6")] + Repr::Ipv6(repr) => repr.next_header, + } + } + + /// Return the payload length. + pub const fn payload_len(&self) -> usize { + match *self { + #[cfg(feature = "proto-ipv4")] + Repr::Ipv4(repr) => repr.payload_len, + #[cfg(feature = "proto-ipv6")] + Repr::Ipv6(repr) => repr.payload_len, + } + } + + /// Set the payload length. + pub fn set_payload_len(&mut self, length: usize) { + match self { + #[cfg(feature = "proto-ipv4")] + Repr::Ipv4(Ipv4Repr { payload_len, .. }) => *payload_len = length, + #[cfg(feature = "proto-ipv6")] + Repr::Ipv6(Ipv6Repr { payload_len, .. }) => *payload_len = length, + } + } + + /// Return the TTL value. + pub const fn hop_limit(&self) -> u8 { + match *self { + #[cfg(feature = "proto-ipv4")] + Repr::Ipv4(Ipv4Repr { hop_limit, .. }) => hop_limit, + #[cfg(feature = "proto-ipv6")] + Repr::Ipv6(Ipv6Repr { hop_limit, .. }) => hop_limit, + } + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub const fn header_len(&self) -> usize { + match *self { + #[cfg(feature = "proto-ipv4")] + Repr::Ipv4(repr) => repr.buffer_len(), + #[cfg(feature = "proto-ipv6")] + Repr::Ipv6(repr) => repr.buffer_len(), + } + } + + /// Emit this high-level representation into a buffer. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>( + &self, + buffer: T, + _checksum_caps: &ChecksumCapabilities, + ) { + match *self { + #[cfg(feature = "proto-ipv4")] + Repr::Ipv4(repr) => repr.emit(&mut Ipv4Packet::new_unchecked(buffer), _checksum_caps), + #[cfg(feature = "proto-ipv6")] + Repr::Ipv6(repr) => repr.emit(&mut Ipv6Packet::new_unchecked(buffer)), + } + } + + /// Return the total length of a packet that will be emitted from this + /// high-level representation. + /// + /// This is the same as `repr.buffer_len() + repr.payload_len()`. + pub const fn buffer_len(&self) -> usize { + self.header_len() + self.payload_len() + } +} + +pub mod checksum { + use byteorder::{ByteOrder, NetworkEndian}; + + use super::*; + + const fn propagate_carries(word: u32) -> u16 { + let sum = (word >> 16) + (word & 0xffff); + ((sum >> 16) as u16) + (sum as u16) + } + + /// Compute an RFC 1071 compliant checksum (without the final complement). + pub fn data(mut data: &[u8]) -> u16 { + let mut accum = 0; + + // For each 32-byte chunk... + const CHUNK_SIZE: usize = 32; + while data.len() >= CHUNK_SIZE { + let mut d = &data[..CHUNK_SIZE]; + // ... take by 2 bytes and sum them. + while d.len() >= 2 { + accum += NetworkEndian::read_u16(d) as u32; + d = &d[2..]; + } + + data = &data[CHUNK_SIZE..]; + } + + // Sum the rest that does not fit the last 32-byte chunk, + // taking by 2 bytes. + while data.len() >= 2 { + accum += NetworkEndian::read_u16(data) as u32; + data = &data[2..]; + } + + // Add the last remaining odd byte, if any. + if let Some(&value) = data.first() { + accum += (value as u32) << 8; + } + + propagate_carries(accum) + } + + /// Combine several RFC 1071 compliant checksums. + pub fn combine(checksums: &[u16]) -> u16 { + let mut accum: u32 = 0; + for &word in checksums { + accum += word as u32; + } + propagate_carries(accum) + } + + /// Compute an IP pseudo header checksum. + pub fn pseudo_header( + src_addr: &Address, + dst_addr: &Address, + next_header: Protocol, + length: u32, + ) -> u16 { + match (src_addr, dst_addr) { + #[cfg(feature = "proto-ipv4")] + (&Address::Ipv4(src_addr), &Address::Ipv4(dst_addr)) => { + let mut proto_len = [0u8; 4]; + proto_len[1] = next_header.into(); + NetworkEndian::write_u16(&mut proto_len[2..4], length as u16); + + combine(&[ + data(src_addr.as_bytes()), + data(dst_addr.as_bytes()), + data(&proto_len[..]), + ]) + } + + #[cfg(feature = "proto-ipv6")] + (&Address::Ipv6(src_addr), &Address::Ipv6(dst_addr)) => { + let mut proto_len = [0u8; 8]; + proto_len[7] = next_header.into(); + NetworkEndian::write_u32(&mut proto_len[0..4], length); + combine(&[ + data(src_addr.as_bytes()), + data(dst_addr.as_bytes()), + data(&proto_len[..]), + ]) + } + + #[allow(unreachable_patterns)] + _ => panic!("Unexpected pseudo header addresses: {src_addr}, {dst_addr}"), + } + } + + // We use this in pretty printer implementations. + pub(crate) fn format_checksum(f: &mut fmt::Formatter, correct: bool) -> fmt::Result { + if !correct { + write!(f, " (checksum incorrect)") + } else { + Ok(()) + } + } +} + +use crate::wire::pretty_print::PrettyIndent; + +pub fn pretty_print_ip_payload<T: Into<Repr>>( + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ip_repr: T, + payload: &[u8], +) -> fmt::Result { + #[cfg(feature = "proto-ipv4")] + use super::pretty_print::PrettyPrint; + use crate::wire::ip::checksum::format_checksum; + #[cfg(feature = "proto-ipv4")] + use crate::wire::Icmpv4Packet; + use crate::wire::{TcpPacket, TcpRepr, UdpPacket, UdpRepr}; + + let checksum_caps = ChecksumCapabilities::ignored(); + let repr = ip_repr.into(); + match repr.next_header() { + #[cfg(feature = "proto-ipv4")] + Protocol::Icmp => { + indent.increase(f)?; + Icmpv4Packet::<&[u8]>::pretty_print(&payload, f, indent) + } + Protocol::Udp => { + indent.increase(f)?; + match UdpPacket::<&[u8]>::new_checked(payload) { + Err(err) => write!(f, "{indent}({err})"), + Ok(udp_packet) => { + match UdpRepr::parse( + &udp_packet, + &repr.src_addr(), + &repr.dst_addr(), + &checksum_caps, + ) { + Err(err) => write!(f, "{indent}{udp_packet} ({err})"), + Ok(udp_repr) => { + write!( + f, + "{}{} len={}", + indent, + udp_repr, + udp_packet.payload().len() + )?; + let valid = + udp_packet.verify_checksum(&repr.src_addr(), &repr.dst_addr()); + format_checksum(f, valid) + } + } + } + } + } + Protocol::Tcp => { + indent.increase(f)?; + match TcpPacket::<&[u8]>::new_checked(payload) { + Err(err) => write!(f, "{indent}({err})"), + Ok(tcp_packet) => { + match TcpRepr::parse( + &tcp_packet, + &repr.src_addr(), + &repr.dst_addr(), + &checksum_caps, + ) { + Err(err) => write!(f, "{indent}{tcp_packet} ({err})"), + Ok(tcp_repr) => { + write!(f, "{indent}{tcp_repr}")?; + let valid = + tcp_packet.verify_checksum(&repr.src_addr(), &repr.dst_addr()); + format_checksum(f, valid) + } + } + } + } + } + _ => Ok(()), + } +} + +#[cfg(test)] +pub(crate) mod test { + #![allow(unused)] + + #[cfg(feature = "proto-ipv6")] + pub(crate) const MOCK_IP_ADDR_1: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ])); + #[cfg(feature = "proto-ipv6")] + pub(crate) const MOCK_IP_ADDR_2: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + ])); + #[cfg(feature = "proto-ipv6")] + pub(crate) const MOCK_IP_ADDR_3: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, + ])); + #[cfg(feature = "proto-ipv6")] + pub(crate) const MOCK_IP_ADDR_4: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, + ])); + #[cfg(feature = "proto-ipv6")] + pub(crate) const MOCK_UNSPECIFIED: IpAddress = IpAddress::Ipv6(Ipv6Address::UNSPECIFIED); + + #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] + pub(crate) const MOCK_IP_ADDR_1: IpAddress = IpAddress::Ipv4(Ipv4Address([192, 168, 1, 1])); + #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] + pub(crate) const MOCK_IP_ADDR_2: IpAddress = IpAddress::Ipv4(Ipv4Address([192, 168, 1, 2])); + #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] + pub(crate) const MOCK_IP_ADDR_3: IpAddress = IpAddress::Ipv4(Ipv4Address([192, 168, 1, 3])); + #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] + pub(crate) const MOCK_IP_ADDR_4: IpAddress = IpAddress::Ipv4(Ipv4Address([192, 168, 1, 4])); + #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] + pub(crate) const MOCK_UNSPECIFIED: IpAddress = IpAddress::Ipv4(Ipv4Address::UNSPECIFIED); + + use super::*; + use crate::wire::{IpAddress, IpCidr, IpProtocol}; + #[cfg(feature = "proto-ipv4")] + use crate::wire::{Ipv4Address, Ipv4Repr}; + + #[test] + #[cfg(feature = "proto-ipv4")] + fn to_prefix_len_ipv4() { + fn test_eq<A: Into<Address>>(prefix_len: u8, mask: A) { + assert_eq!(Some(prefix_len), mask.into().prefix_len()); + } + + test_eq(0, Ipv4Address::new(0, 0, 0, 0)); + test_eq(1, Ipv4Address::new(128, 0, 0, 0)); + test_eq(2, Ipv4Address::new(192, 0, 0, 0)); + test_eq(3, Ipv4Address::new(224, 0, 0, 0)); + test_eq(4, Ipv4Address::new(240, 0, 0, 0)); + test_eq(5, Ipv4Address::new(248, 0, 0, 0)); + test_eq(6, Ipv4Address::new(252, 0, 0, 0)); + test_eq(7, Ipv4Address::new(254, 0, 0, 0)); + test_eq(8, Ipv4Address::new(255, 0, 0, 0)); + test_eq(9, Ipv4Address::new(255, 128, 0, 0)); + test_eq(10, Ipv4Address::new(255, 192, 0, 0)); + test_eq(11, Ipv4Address::new(255, 224, 0, 0)); + test_eq(12, Ipv4Address::new(255, 240, 0, 0)); + test_eq(13, Ipv4Address::new(255, 248, 0, 0)); + test_eq(14, Ipv4Address::new(255, 252, 0, 0)); + test_eq(15, Ipv4Address::new(255, 254, 0, 0)); + test_eq(16, Ipv4Address::new(255, 255, 0, 0)); + test_eq(17, Ipv4Address::new(255, 255, 128, 0)); + test_eq(18, Ipv4Address::new(255, 255, 192, 0)); + test_eq(19, Ipv4Address::new(255, 255, 224, 0)); + test_eq(20, Ipv4Address::new(255, 255, 240, 0)); + test_eq(21, Ipv4Address::new(255, 255, 248, 0)); + test_eq(22, Ipv4Address::new(255, 255, 252, 0)); + test_eq(23, Ipv4Address::new(255, 255, 254, 0)); + test_eq(24, Ipv4Address::new(255, 255, 255, 0)); + test_eq(25, Ipv4Address::new(255, 255, 255, 128)); + test_eq(26, Ipv4Address::new(255, 255, 255, 192)); + test_eq(27, Ipv4Address::new(255, 255, 255, 224)); + test_eq(28, Ipv4Address::new(255, 255, 255, 240)); + test_eq(29, Ipv4Address::new(255, 255, 255, 248)); + test_eq(30, Ipv4Address::new(255, 255, 255, 252)); + test_eq(31, Ipv4Address::new(255, 255, 255, 254)); + test_eq(32, Ipv4Address::new(255, 255, 255, 255)); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn to_prefix_len_ipv4_error() { + assert_eq!( + None, + IpAddress::from(Ipv4Address::new(255, 255, 255, 1)).prefix_len() + ); + } + + #[test] + #[cfg(feature = "proto-ipv6")] + fn to_prefix_len_ipv6() { + fn test_eq<A: Into<Address>>(prefix_len: u8, mask: A) { + assert_eq!(Some(prefix_len), mask.into().prefix_len()); + } + + test_eq(0, Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 0)); + test_eq( + 128, + Ipv6Address::new( + 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, + ), + ); + } + + #[test] + #[cfg(feature = "proto-ipv6")] + fn to_prefix_len_ipv6_error() { + assert_eq!( + None, + IpAddress::from(Ipv6Address::new( + 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 1 + )) + .prefix_len() + ); + } +} diff --git a/src/wire/ipsec_ah.rs b/src/wire/ipsec_ah.rs new file mode 100644 index 0000000..1c3f00b --- /dev/null +++ b/src/wire/ipsec_ah.rs @@ -0,0 +1,286 @@ +use byteorder::{ByteOrder, NetworkEndian}; + +use super::{Error, IpProtocol, Result}; + +/// A read/write wrapper around an IPSec Authentication Header packet buffer. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Next Header | Payload Len | RESERVED | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Security Parameters Index (SPI) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Sequence Number Field | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// + Integrity Check Value-ICV (variable) | +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +mod field { + #![allow(non_snake_case)] + + use crate::wire::field::Field; + + pub const NEXT_HEADER: usize = 0; + pub const PAYLOAD_LEN: usize = 1; + pub const RESERVED: Field = 2..4; + pub const SPI: Field = 4..8; + pub const SEQUENCE_NUMBER: Field = 8..12; + + pub const fn ICV(payload_len: u8) -> Field { + // The `payload_len` is the length of this Authentication Header in 4-octet units, minus 2. + let header_len = (payload_len as usize + 2) * 4; + + SEQUENCE_NUMBER.end..header_len + } +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with IPsec Authentication Header packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short or shorter than payload length. + /// + /// The result of this check is invalidated by calling [set_payload_len]. + /// + /// [set_payload_len]: #method.set_payload_len + #[allow(clippy::if_same_then_else)] + pub fn check_len(&self) -> Result<()> { + let data = self.buffer.as_ref(); + let len = data.len(); + if len < field::SEQUENCE_NUMBER.end { + Err(Error) + } else if len < field::ICV(data[field::PAYLOAD_LEN]).end { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return next header protocol type + /// The value is taken from the list of IP protocol numbers. + pub fn next_header(&self) -> IpProtocol { + let data = self.buffer.as_ref(); + IpProtocol::from(data[field::NEXT_HEADER]) + } + + /// Return the length of this Authentication Header in 4-octet units, minus 2 + pub fn payload_len(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::PAYLOAD_LEN] + } + + /// Return the security parameters index + pub fn security_parameters_index(&self) -> u32 { + let field = &self.buffer.as_ref()[field::SPI]; + NetworkEndian::read_u32(field) + } + + /// Return sequence number + pub fn sequence_number(&self) -> u32 { + let field = &self.buffer.as_ref()[field::SEQUENCE_NUMBER]; + NetworkEndian::read_u32(field) + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return a pointer to the integrity check value + #[inline] + pub fn integrity_check_value(&self) -> &'a [u8] { + let data = self.buffer.as_ref(); + &data[field::ICV(data[field::PAYLOAD_LEN])] + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set next header protocol field + fn set_next_header(&mut self, value: IpProtocol) { + let data = self.buffer.as_mut(); + data[field::NEXT_HEADER] = value.into() + } + + /// Set payload length field + fn set_payload_len(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::PAYLOAD_LEN] = value + } + + /// Clear reserved field + fn clear_reserved(&mut self) { + let data = self.buffer.as_mut(); + data[field::RESERVED].fill(0) + } + + /// Set security parameters index field + fn set_security_parameters_index(&mut self, value: u32) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::SPI], value) + } + + /// Set sequence number + fn set_sequence_number(&mut self, value: u32) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::SEQUENCE_NUMBER], value) + } + + /// Return a mutable pointer to the integrity check value. + #[inline] + pub fn integrity_check_value_mut(&mut self) -> &mut [u8] { + let data = self.buffer.as_mut(); + let range = field::ICV(data[field::PAYLOAD_LEN]); + &mut data[range] + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr<'a> { + next_header: IpProtocol, + security_parameters_index: u32, + sequence_number: u32, + integrity_check_value: &'a [u8], +} + +impl<'a> Repr<'a> { + /// Parse an IPSec Authentication Header packet and return a high-level representation. + pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&'a T>) -> Result<Repr<'a>> { + Ok(Repr { + next_header: packet.next_header(), + security_parameters_index: packet.security_parameters_index(), + sequence_number: packet.sequence_number(), + integrity_check_value: packet.integrity_check_value(), + }) + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + self.integrity_check_value.len() + field::SEQUENCE_NUMBER.end + } + + /// Emit a high-level representation into an IPSec Authentication Header. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, packet: &mut Packet<&'a mut T>) { + packet.set_next_header(self.next_header); + + let payload_len = ((field::SEQUENCE_NUMBER.end + self.integrity_check_value.len()) / 4) - 2; + packet.set_payload_len(payload_len as u8); + + packet.clear_reserved(); + packet.set_security_parameters_index(self.security_parameters_index); + packet.set_sequence_number(self.sequence_number); + packet + .integrity_check_value_mut() + .copy_from_slice(self.integrity_check_value); + } +} + +#[cfg(test)] +mod test { + use super::*; + + static PACKET_BYTES1: [u8; 24] = [ + 0x32, 0x04, 0x00, 0x00, 0x81, 0x79, 0xb7, 0x05, 0x00, 0x00, 0x00, 0x01, 0x27, 0xcf, 0xc0, + 0xa5, 0xe4, 0x3d, 0x69, 0xb3, 0x72, 0x8e, 0xc5, 0xb0, + ]; + + static PACKET_BYTES2: [u8; 24] = [ + 0x32, 0x04, 0x00, 0x00, 0xba, 0x8b, 0xd0, 0x60, 0x00, 0x00, 0x00, 0x01, 0xaf, 0xd2, 0xe7, + 0xa1, 0x73, 0xd3, 0x29, 0x0b, 0xfe, 0x6b, 0x63, 0x73, + ]; + + #[test] + fn test_deconstruct() { + let packet = Packet::new_unchecked(&PACKET_BYTES1[..]); + assert_eq!(packet.next_header(), IpProtocol::IpSecEsp); + assert_eq!(packet.payload_len(), 4); + assert_eq!(packet.security_parameters_index(), 0x8179b705); + assert_eq!(packet.sequence_number(), 1); + assert_eq!( + packet.integrity_check_value(), + &[0x27, 0xcf, 0xc0, 0xa5, 0xe4, 0x3d, 0x69, 0xb3, 0x72, 0x8e, 0xc5, 0xb0] + ); + } + + #[test] + fn test_construct() { + let mut bytes = vec![0xa5; 24]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_next_header(IpProtocol::IpSecEsp); + packet.set_payload_len(4); + packet.clear_reserved(); + packet.set_security_parameters_index(0xba8bd060); + packet.set_sequence_number(1); + const ICV: [u8; 12] = [ + 0xaf, 0xd2, 0xe7, 0xa1, 0x73, 0xd3, 0x29, 0x0b, 0xfe, 0x6b, 0x63, 0x73, + ]; + packet.integrity_check_value_mut().copy_from_slice(&ICV); + assert_eq!(bytes, PACKET_BYTES2); + } + #[test] + fn test_check_len() { + assert!(matches!(Packet::new_checked(&PACKET_BYTES1[..10]), Err(_))); + assert!(matches!(Packet::new_checked(&PACKET_BYTES1[..22]), Err(_))); + assert!(matches!(Packet::new_checked(&PACKET_BYTES1[..]), Ok(_))); + } + + fn packet_repr<'a>() -> Repr<'a> { + Repr { + next_header: IpProtocol::IpSecEsp, + security_parameters_index: 0xba8bd060, + sequence_number: 1, + integrity_check_value: &[ + 0xaf, 0xd2, 0xe7, 0xa1, 0x73, 0xd3, 0x29, 0x0b, 0xfe, 0x6b, 0x63, 0x73, + ], + } + } + + #[test] + fn test_parse() { + let packet = Packet::new_unchecked(&PACKET_BYTES2[..]); + assert_eq!(Repr::parse(&packet).unwrap(), packet_repr()); + } + + #[test] + fn test_emit() { + let mut bytes = vec![0x17; 24]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet_repr().emit(&mut packet); + assert_eq!(bytes, PACKET_BYTES2); + } + + #[test] + fn test_buffer_len() { + let header = Packet::new_unchecked(&PACKET_BYTES1[..]); + let repr = Repr::parse(&header).unwrap(); + assert_eq!(repr.buffer_len(), PACKET_BYTES1.len()); + } +} diff --git a/src/wire/ipsec_esp.rs b/src/wire/ipsec_esp.rs new file mode 100644 index 0000000..d0cd572 --- /dev/null +++ b/src/wire/ipsec_esp.rs @@ -0,0 +1,177 @@ +use super::{Error, Result}; +use byteorder::{ByteOrder, NetworkEndian}; + +/// A read/write wrapper around an IPSec Encapsulating Security Payload (ESP) packet buffer. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +mod field { + use crate::wire::field::Field; + + pub const SPI: Field = 0..4; + pub const SEQUENCE_NUMBER: Field = 4..8; +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with IPsec Encapsulating Security Payload packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let data = self.buffer.as_ref(); + let len = data.len(); + if len < field::SEQUENCE_NUMBER.end { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the security parameters index + pub fn security_parameters_index(&self) -> u32 { + let field = &self.buffer.as_ref()[field::SPI]; + NetworkEndian::read_u32(field) + } + + /// Return sequence number + pub fn sequence_number(&self) -> u32 { + let field = &self.buffer.as_ref()[field::SEQUENCE_NUMBER]; + NetworkEndian::read_u32(field) + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set security parameters index field + fn set_security_parameters_index(&mut self, value: u32) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::SPI], value) + } + + /// Set sequence number + fn set_sequence_number(&mut self, value: u32) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::SEQUENCE_NUMBER], value) + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr { + security_parameters_index: u32, + sequence_number: u32, +} + +impl Repr { + /// Parse an IPSec Encapsulating Security Payload packet and return a high-level representation. + pub fn parse<T: AsRef<[u8]>>(packet: &Packet<T>) -> Result<Repr> { + Ok(Repr { + security_parameters_index: packet.security_parameters_index(), + sequence_number: packet.sequence_number(), + }) + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + field::SEQUENCE_NUMBER.end + } + + /// Emit a high-level representation into an IPSec Encapsulating Security Payload. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut Packet<T>) { + packet.set_security_parameters_index(self.security_parameters_index); + packet.set_sequence_number(self.sequence_number); + } +} + +#[cfg(test)] +mod test { + use super::*; + + static PACKET_BYTES: [u8; 136] = [ + 0xfb, 0x51, 0x28, 0xa6, 0x00, 0x00, 0x00, 0x02, 0x5d, 0xbe, 0x2d, 0x56, 0xd4, 0x6a, 0x57, + 0xf5, 0xfc, 0x69, 0x8b, 0x3c, 0xa6, 0xb6, 0x88, 0x3a, 0x6c, 0xc1, 0x33, 0x92, 0xdb, 0x40, + 0xab, 0x11, 0x54, 0xb4, 0x0f, 0x22, 0x4d, 0x37, 0x3a, 0x06, 0x94, 0x1e, 0xd4, 0x25, 0xaf, + 0xf0, 0xb0, 0x11, 0x1f, 0x07, 0x96, 0x2a, 0xa7, 0x20, 0xb1, 0xf5, 0x52, 0xb2, 0x12, 0x46, + 0xd6, 0xa5, 0x13, 0x4e, 0x97, 0x75, 0x44, 0x19, 0xc7, 0x29, 0x35, 0xc5, 0xed, 0xa4, 0x0c, + 0xe7, 0x87, 0xec, 0x9c, 0xb1, 0x12, 0x42, 0x74, 0x7c, 0x12, 0x3c, 0x7f, 0x44, 0x9c, 0x6b, + 0x46, 0x27, 0x28, 0xd2, 0x0e, 0xb1, 0x28, 0xd3, 0xd8, 0xc2, 0xd1, 0xac, 0x25, 0xfe, 0xef, + 0xed, 0x13, 0xfd, 0x8f, 0x18, 0x9c, 0x2d, 0xb1, 0x0e, 0x50, 0xe9, 0xaa, 0x65, 0x93, 0x56, + 0x40, 0x43, 0xa3, 0x72, 0x54, 0xba, 0x1b, 0xb1, 0xaf, 0xca, 0x04, 0x15, 0xf9, 0xef, 0xb7, + 0x1d, + ]; + + #[test] + fn test_deconstruct() { + let packet = Packet::new_unchecked(&PACKET_BYTES[..]); + assert_eq!(packet.security_parameters_index(), 0xfb5128a6); + assert_eq!(packet.sequence_number(), 2); + } + + #[test] + fn test_construct() { + let mut bytes = vec![0xa5; 8]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_security_parameters_index(0xfb5128a6); + packet.set_sequence_number(2); + assert_eq!(&bytes, &PACKET_BYTES[..8]); + } + #[test] + fn test_check_len() { + assert!(matches!(Packet::new_checked(&PACKET_BYTES[..7]), Err(_))); + assert!(matches!(Packet::new_checked(&PACKET_BYTES[..]), Ok(_))); + } + + fn packet_repr() -> Repr { + Repr { + security_parameters_index: 0xfb5128a6, + sequence_number: 2, + } + } + + #[test] + fn test_parse() { + let packet = Packet::new_unchecked(&PACKET_BYTES[..]); + assert_eq!(Repr::parse(&packet).unwrap(), packet_repr()); + } + + #[test] + fn test_emit() { + let mut bytes = vec![0x17; 8]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet_repr().emit(&mut packet); + assert_eq!(&bytes, &PACKET_BYTES[..8]); + } + + #[test] + fn test_buffer_len() { + let header = Packet::new_unchecked(&PACKET_BYTES[..]); + let repr = Repr::parse(&header).unwrap(); + assert_eq!(repr.buffer_len(), 8); + } +} diff --git a/src/wire/ipv4.rs b/src/wire/ipv4.rs new file mode 100644 index 0000000..1027fc2 --- /dev/null +++ b/src/wire/ipv4.rs @@ -0,0 +1,1178 @@ +use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; + +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::{checksum, pretty_print_ip_payload}; + +pub use super::IpProtocol as Protocol; + +/// Minimum MTU required of all links supporting IPv4. See [RFC 791 § 3.1]. +/// +/// [RFC 791 § 3.1]: https://tools.ietf.org/html/rfc791#section-3.1 +// RFC 791 states the following: +// +// > Every internet module must be able to forward a datagram of 68 +// > octets without further fragmentation... Every internet destination +// > must be able to receive a datagram of 576 octets either in one piece +// > or in fragments to be reassembled. +// +// As a result, we can assume that every host we send packets to can +// accept a packet of the following size. +pub const MIN_MTU: usize = 576; + +/// Size of IPv4 adderess in octets. +/// +/// [RFC 8200 § 2]: https://www.rfc-editor.org/rfc/rfc791#section-3.2 +pub const ADDR_SIZE: usize = 4; + +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Key { + id: u16, + src_addr: Address, + dst_addr: Address, + protocol: Protocol, +} + +/// A four-octet IPv4 address. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +pub struct Address(pub [u8; ADDR_SIZE]); + +impl Address { + /// An unspecified address. + pub const UNSPECIFIED: Address = Address([0x00; ADDR_SIZE]); + + /// The broadcast address. + pub const BROADCAST: Address = Address([0xff; ADDR_SIZE]); + + /// All multicast-capable nodes + pub const MULTICAST_ALL_SYSTEMS: Address = Address([224, 0, 0, 1]); + + /// All multicast-capable routers + pub const MULTICAST_ALL_ROUTERS: Address = Address([224, 0, 0, 2]); + + /// Construct an IPv4 address from parts. + pub const fn new(a0: u8, a1: u8, a2: u8, a3: u8) -> Address { + Address([a0, a1, a2, a3]) + } + + /// Construct an IPv4 address from a sequence of octets, in big-endian. + /// + /// # Panics + /// The function panics if `data` is not four octets long. + pub fn from_bytes(data: &[u8]) -> Address { + let mut bytes = [0; ADDR_SIZE]; + bytes.copy_from_slice(data); + Address(bytes) + } + + /// Return an IPv4 address as a sequence of octets, in big-endian. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } + + /// Query whether the address is an unicast address. + pub fn is_unicast(&self) -> bool { + !(self.is_broadcast() || self.is_multicast() || self.is_unspecified()) + } + + /// Query whether the address is the broadcast address. + pub fn is_broadcast(&self) -> bool { + self.0[0..4] == [255; ADDR_SIZE] + } + + /// Query whether the address is a multicast address. + pub const fn is_multicast(&self) -> bool { + self.0[0] & 0xf0 == 224 + } + + /// Query whether the address falls into the "unspecified" range. + pub const fn is_unspecified(&self) -> bool { + self.0[0] == 0 + } + + /// Query whether the address falls into the "link-local" range. + pub fn is_link_local(&self) -> bool { + self.0[0..2] == [169, 254] + } + + /// Query whether the address falls into the "loopback" range. + pub const fn is_loopback(&self) -> bool { + self.0[0] == 127 + } + + /// Convert to an `IpAddress`. + /// + /// Same as `.into()`, but works in `const`. + pub const fn into_address(self) -> super::IpAddress { + super::IpAddress::Ipv4(self) + } +} + +#[cfg(feature = "std")] +impl From<::std::net::Ipv4Addr> for Address { + fn from(x: ::std::net::Ipv4Addr) -> Address { + Address(x.octets()) + } +} + +#[cfg(feature = "std")] +impl From<Address> for ::std::net::Ipv4Addr { + fn from(Address(x): Address) -> ::std::net::Ipv4Addr { + x.into() + } +} + +impl fmt::Display for Address { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let bytes = self.0; + write!(f, "{}.{}.{}.{}", bytes[0], bytes[1], bytes[2], bytes[3]) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Address { + fn format(&self, f: defmt::Formatter) { + defmt::write!( + f, + "{=u8}.{=u8}.{=u8}.{=u8}", + self.0[0], + self.0[1], + self.0[2], + self.0[3] + ) + } +} + +/// A specification of an IPv4 CIDR block, containing an address and a variable-length +/// subnet masking prefix length. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +pub struct Cidr { + address: Address, + prefix_len: u8, +} + +impl Cidr { + /// Create an IPv4 CIDR block from the given address and prefix length. + /// + /// # Panics + /// This function panics if the prefix length is larger than 32. + #[allow(clippy::no_effect)] + pub const fn new(address: Address, prefix_len: u8) -> Cidr { + // Replace with const panic (or assert) when stabilized + // see: https://github.com/rust-lang/rust/issues/51999 + ["Prefix length should be <= 32"][(prefix_len > 32) as usize]; + Cidr { + address, + prefix_len, + } + } + + /// Create an IPv4 CIDR block from the given address and network mask. + pub fn from_netmask(addr: Address, netmask: Address) -> Result<Cidr> { + let netmask = NetworkEndian::read_u32(&netmask.0[..]); + if netmask.leading_zeros() == 0 && netmask.trailing_zeros() == netmask.count_zeros() { + Ok(Cidr { + address: addr, + prefix_len: netmask.count_ones() as u8, + }) + } else { + Err(Error) + } + } + + /// Return the address of this IPv4 CIDR block. + pub const fn address(&self) -> Address { + self.address + } + + /// Return the prefix length of this IPv4 CIDR block. + pub const fn prefix_len(&self) -> u8 { + self.prefix_len + } + + /// Return the network mask of this IPv4 CIDR. + pub const fn netmask(&self) -> Address { + if self.prefix_len == 0 { + return Address([0, 0, 0, 0]); + } + + let number = 0xffffffffu32 << (32 - self.prefix_len); + let data = [ + ((number >> 24) & 0xff) as u8, + ((number >> 16) & 0xff) as u8, + ((number >> 8) & 0xff) as u8, + ((number >> 0) & 0xff) as u8, + ]; + + Address(data) + } + + /// Return the broadcast address of this IPv4 CIDR. + pub fn broadcast(&self) -> Option<Address> { + let network = self.network(); + + if network.prefix_len == 31 || network.prefix_len == 32 { + return None; + } + + let network_number = NetworkEndian::read_u32(&network.address.0[..]); + let number = network_number | 0xffffffffu32 >> network.prefix_len; + let data = [ + ((number >> 24) & 0xff) as u8, + ((number >> 16) & 0xff) as u8, + ((number >> 8) & 0xff) as u8, + ((number >> 0) & 0xff) as u8, + ]; + + Some(Address(data)) + } + + /// Return the network block of this IPv4 CIDR. + pub const fn network(&self) -> Cidr { + let mask = self.netmask().0; + let network = [ + self.address.0[0] & mask[0], + self.address.0[1] & mask[1], + self.address.0[2] & mask[2], + self.address.0[3] & mask[3], + ]; + Cidr { + address: Address(network), + prefix_len: self.prefix_len, + } + } + + /// Query whether the subnetwork described by this IPv4 CIDR block contains + /// the given address. + pub fn contains_addr(&self, addr: &Address) -> bool { + // right shift by 32 is not legal + if self.prefix_len == 0 { + return true; + } + + let shift = 32 - self.prefix_len; + let self_prefix = NetworkEndian::read_u32(self.address.as_bytes()) >> shift; + let addr_prefix = NetworkEndian::read_u32(addr.as_bytes()) >> shift; + self_prefix == addr_prefix + } + + /// Query whether the subnetwork described by this IPv4 CIDR block contains + /// the subnetwork described by the given IPv4 CIDR block. + pub fn contains_subnet(&self, subnet: &Cidr) -> bool { + self.prefix_len <= subnet.prefix_len && self.contains_addr(&subnet.address) + } +} + +impl fmt::Display for Cidr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}/{}", self.address, self.prefix_len) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Cidr { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{}/{=u8}", self.address, self.prefix_len); + } +} + +/// A read/write wrapper around an Internet Protocol version 4 packet buffer. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +mod field { + use crate::wire::field::*; + + pub const VER_IHL: usize = 0; + pub const DSCP_ECN: usize = 1; + pub const LENGTH: Field = 2..4; + pub const IDENT: Field = 4..6; + pub const FLG_OFF: Field = 6..8; + pub const TTL: usize = 8; + pub const PROTOCOL: usize = 9; + pub const CHECKSUM: Field = 10..12; + pub const SRC_ADDR: Field = 12..16; + pub const DST_ADDR: Field = 16..20; +} + +pub const HEADER_LEN: usize = field::DST_ADDR.end; + +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with IPv4 packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// Returns `Err(Error)` if the header length is greater + /// than total length. + /// + /// The result of this check is invalidated by calling [set_header_len] + /// and [set_total_len]. + /// + /// [set_header_len]: #method.set_header_len + /// [set_total_len]: #method.set_total_len + #[allow(clippy::if_same_then_else)] + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::DST_ADDR.end { + Err(Error) + } else if len < self.header_len() as usize { + Err(Error) + } else if self.header_len() as u16 > self.total_len() { + Err(Error) + } else if len < self.total_len() as usize { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the version field. + #[inline] + pub fn version(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::VER_IHL] >> 4 + } + + /// Return the header length, in octets. + #[inline] + pub fn header_len(&self) -> u8 { + let data = self.buffer.as_ref(); + (data[field::VER_IHL] & 0x0f) * 4 + } + + /// Return the Differential Services Code Point field. + pub fn dscp(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::DSCP_ECN] >> 2 + } + + /// Return the Explicit Congestion Notification field. + pub fn ecn(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::DSCP_ECN] & 0x03 + } + + /// Return the total length field. + #[inline] + pub fn total_len(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::LENGTH]) + } + + /// Return the fragment identification field. + #[inline] + pub fn ident(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::IDENT]) + } + + /// Return the "don't fragment" flag. + #[inline] + pub fn dont_frag(&self) -> bool { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::FLG_OFF]) & 0x4000 != 0 + } + + /// Return the "more fragments" flag. + #[inline] + pub fn more_frags(&self) -> bool { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::FLG_OFF]) & 0x2000 != 0 + } + + /// Return the fragment offset, in octets. + #[inline] + pub fn frag_offset(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::FLG_OFF]) << 3 + } + + /// Return the time to live field. + #[inline] + pub fn hop_limit(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::TTL] + } + + /// Return the next_header (protocol) field. + #[inline] + pub fn next_header(&self) -> Protocol { + let data = self.buffer.as_ref(); + Protocol::from(data[field::PROTOCOL]) + } + + /// Return the header checksum field. + #[inline] + pub fn checksum(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::CHECKSUM]) + } + + /// Return the source address field. + #[inline] + pub fn src_addr(&self) -> Address { + let data = self.buffer.as_ref(); + Address::from_bytes(&data[field::SRC_ADDR]) + } + + /// Return the destination address field. + #[inline] + pub fn dst_addr(&self) -> Address { + let data = self.buffer.as_ref(); + Address::from_bytes(&data[field::DST_ADDR]) + } + + /// Validate the header checksum. + /// + /// # Fuzzing + /// This function always returns `true` when fuzzing. + pub fn verify_checksum(&self) -> bool { + if cfg!(fuzzing) { + return true; + } + + let data = self.buffer.as_ref(); + checksum::data(&data[..self.header_len() as usize]) == !0 + } + + /// Returns the key for identifying the packet. + pub fn get_key(&self) -> Key { + Key { + id: self.ident(), + src_addr: self.src_addr(), + dst_addr: self.dst_addr(), + protocol: self.next_header(), + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return a pointer to the payload. + #[inline] + pub fn payload(&self) -> &'a [u8] { + let range = self.header_len() as usize..self.total_len() as usize; + let data = self.buffer.as_ref(); + &data[range] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the version field. + #[inline] + pub fn set_version(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::VER_IHL] = (data[field::VER_IHL] & !0xf0) | (value << 4); + } + + /// Set the header length, in octets. + #[inline] + pub fn set_header_len(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::VER_IHL] = (data[field::VER_IHL] & !0x0f) | ((value / 4) & 0x0f); + } + + /// Set the Differential Services Code Point field. + pub fn set_dscp(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::DSCP_ECN] = (data[field::DSCP_ECN] & !0xfc) | (value << 2) + } + + /// Set the Explicit Congestion Notification field. + pub fn set_ecn(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::DSCP_ECN] = (data[field::DSCP_ECN] & !0x03) | (value & 0x03) + } + + /// Set the total length field. + #[inline] + pub fn set_total_len(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::LENGTH], value) + } + + /// Set the fragment identification field. + #[inline] + pub fn set_ident(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::IDENT], value) + } + + /// Clear the entire flags field. + #[inline] + pub fn clear_flags(&mut self) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLG_OFF]); + let raw = raw & !0xe000; + NetworkEndian::write_u16(&mut data[field::FLG_OFF], raw); + } + + /// Set the "don't fragment" flag. + #[inline] + pub fn set_dont_frag(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLG_OFF]); + let raw = if value { raw | 0x4000 } else { raw & !0x4000 }; + NetworkEndian::write_u16(&mut data[field::FLG_OFF], raw); + } + + /// Set the "more fragments" flag. + #[inline] + pub fn set_more_frags(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLG_OFF]); + let raw = if value { raw | 0x2000 } else { raw & !0x2000 }; + NetworkEndian::write_u16(&mut data[field::FLG_OFF], raw); + } + + /// Set the fragment offset, in octets. + #[inline] + pub fn set_frag_offset(&mut self, value: u16) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLG_OFF]); + let raw = (raw & 0xe000) | (value >> 3); + NetworkEndian::write_u16(&mut data[field::FLG_OFF], raw); + } + + /// Set the time to live field. + #[inline] + pub fn set_hop_limit(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::TTL] = value + } + + /// Set the next header (protocol) field. + #[inline] + pub fn set_next_header(&mut self, value: Protocol) { + let data = self.buffer.as_mut(); + data[field::PROTOCOL] = value.into() + } + + /// Set the header checksum field. + #[inline] + pub fn set_checksum(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::CHECKSUM], value) + } + + /// Set the source address field. + #[inline] + pub fn set_src_addr(&mut self, value: Address) { + let data = self.buffer.as_mut(); + data[field::SRC_ADDR].copy_from_slice(value.as_bytes()) + } + + /// Set the destination address field. + #[inline] + pub fn set_dst_addr(&mut self, value: Address) { + let data = self.buffer.as_mut(); + data[field::DST_ADDR].copy_from_slice(value.as_bytes()) + } + + /// Compute and fill in the header checksum. + pub fn fill_checksum(&mut self) { + self.set_checksum(0); + let checksum = { + let data = self.buffer.as_ref(); + !checksum::data(&data[..self.header_len() as usize]) + }; + self.set_checksum(checksum) + } + + /// Return a mutable pointer to the payload. + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + let range = self.header_len() as usize..self.total_len() as usize; + let data = self.buffer.as_mut(); + &mut data[range] + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +/// A high-level representation of an Internet Protocol version 4 packet header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr { + pub src_addr: Address, + pub dst_addr: Address, + pub next_header: Protocol, + pub payload_len: usize, + pub hop_limit: u8, +} + +impl Repr { + /// Parse an Internet Protocol version 4 packet and return a high-level representation. + pub fn parse<T: AsRef<[u8]> + ?Sized>( + packet: &Packet<&T>, + checksum_caps: &ChecksumCapabilities, + ) -> Result<Repr> { + // Version 4 is expected. + if packet.version() != 4 { + return Err(Error); + } + // Valid checksum is expected. + if checksum_caps.ipv4.rx() && !packet.verify_checksum() { + return Err(Error); + } + + #[cfg(not(feature = "proto-ipv4-fragmentation"))] + // We do not support fragmentation. + if packet.more_frags() || packet.frag_offset() != 0 { + return Err(Error); + } + + let payload_len = packet.total_len() as usize - packet.header_len() as usize; + + // All DSCP values are acceptable, since they are of no concern to receiving endpoint. + // All ECN values are acceptable, since ECN requires opt-in from both endpoints. + // All TTL values are acceptable, since we do not perform routing. + Ok(Repr { + src_addr: packet.src_addr(), + dst_addr: packet.dst_addr(), + next_header: packet.next_header(), + payload_len, + hop_limit: packet.hop_limit(), + }) + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + // We never emit any options. + field::DST_ADDR.end + } + + /// Emit a high-level representation into an Internet Protocol version 4 packet. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>( + &self, + packet: &mut Packet<T>, + checksum_caps: &ChecksumCapabilities, + ) { + packet.set_version(4); + packet.set_header_len(field::DST_ADDR.end as u8); + packet.set_dscp(0); + packet.set_ecn(0); + let total_len = packet.header_len() as u16 + self.payload_len as u16; + packet.set_total_len(total_len); + packet.set_ident(0); + packet.clear_flags(); + packet.set_more_frags(false); + packet.set_dont_frag(true); + packet.set_frag_offset(0); + packet.set_hop_limit(self.hop_limit); + packet.set_next_header(self.next_header); + packet.set_src_addr(self.src_addr); + packet.set_dst_addr(self.dst_addr); + + if checksum_caps.ipv4.tx() { + packet.fill_checksum(); + } else { + // make sure we get a consistently zeroed checksum, + // since implementations might rely on it + packet.set_checksum(0); + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match Repr::parse(self, &ChecksumCapabilities::ignored()) { + Ok(repr) => write!(f, "{repr}"), + Err(err) => { + write!(f, "IPv4 ({err})")?; + write!( + f, + " src={} dst={} proto={} hop_limit={}", + self.src_addr(), + self.dst_addr(), + self.next_header(), + self.hop_limit() + )?; + if self.version() != 4 { + write!(f, " ver={}", self.version())?; + } + if self.header_len() != 20 { + write!(f, " hlen={}", self.header_len())?; + } + if self.dscp() != 0 { + write!(f, " dscp={}", self.dscp())?; + } + if self.ecn() != 0 { + write!(f, " ecn={}", self.ecn())?; + } + write!(f, " tlen={}", self.total_len())?; + if self.dont_frag() { + write!(f, " df")?; + } + if self.more_frags() { + write!(f, " mf")?; + } + if self.frag_offset() != 0 { + write!(f, " off={}", self.frag_offset())?; + } + if self.more_frags() || self.frag_offset() != 0 { + write!(f, " id={}", self.ident())?; + } + Ok(()) + } + } + } +} + +impl fmt::Display for Repr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "IPv4 src={} dst={} proto={}", + self.src_addr, self.dst_addr, self.next_header + ) + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + use crate::wire::ip::checksum::format_checksum; + + let checksum_caps = ChecksumCapabilities::ignored(); + + let (ip_repr, payload) = match Packet::new_checked(buffer) { + Err(err) => return write!(f, "{indent}({err})"), + Ok(ip_packet) => match Repr::parse(&ip_packet, &checksum_caps) { + Err(_) => return Ok(()), + Ok(ip_repr) => { + if ip_packet.more_frags() || ip_packet.frag_offset() != 0 { + write!( + f, + "{}IPv4 Fragment more_frags={} offset={}", + indent, + ip_packet.more_frags(), + ip_packet.frag_offset() + )?; + return Ok(()); + } else { + write!(f, "{indent}{ip_repr}")?; + format_checksum(f, ip_packet.verify_checksum())?; + (ip_repr, ip_packet.payload()) + } + } + }, + }; + + pretty_print_ip_payload(f, indent, ip_repr, payload) + } +} + +#[cfg(test)] +mod test { + use super::*; + + static PACKET_BYTES: [u8; 30] = [ + 0x45, 0x00, 0x00, 0x1e, 0x01, 0x02, 0x62, 0x03, 0x1a, 0x01, 0xd5, 0x6e, 0x11, 0x12, 0x13, + 0x14, 0x21, 0x22, 0x23, 0x24, 0xaa, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, + ]; + + static PAYLOAD_BYTES: [u8; 10] = [0xaa, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff]; + + #[test] + fn test_deconstruct() { + let packet = Packet::new_unchecked(&PACKET_BYTES[..]); + assert_eq!(packet.version(), 4); + assert_eq!(packet.header_len(), 20); + assert_eq!(packet.dscp(), 0); + assert_eq!(packet.ecn(), 0); + assert_eq!(packet.total_len(), 30); + assert_eq!(packet.ident(), 0x102); + assert!(packet.more_frags()); + assert!(packet.dont_frag()); + assert_eq!(packet.frag_offset(), 0x203 * 8); + assert_eq!(packet.hop_limit(), 0x1a); + assert_eq!(packet.next_header(), Protocol::Icmp); + assert_eq!(packet.checksum(), 0xd56e); + assert_eq!(packet.src_addr(), Address([0x11, 0x12, 0x13, 0x14])); + assert_eq!(packet.dst_addr(), Address([0x21, 0x22, 0x23, 0x24])); + assert!(packet.verify_checksum()); + assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]); + } + + #[test] + fn test_construct() { + let mut bytes = vec![0xa5; 30]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_version(4); + packet.set_header_len(20); + packet.clear_flags(); + packet.set_dscp(0); + packet.set_ecn(0); + packet.set_total_len(30); + packet.set_ident(0x102); + packet.set_more_frags(true); + packet.set_dont_frag(true); + packet.set_frag_offset(0x203 * 8); + packet.set_hop_limit(0x1a); + packet.set_next_header(Protocol::Icmp); + packet.set_src_addr(Address([0x11, 0x12, 0x13, 0x14])); + packet.set_dst_addr(Address([0x21, 0x22, 0x23, 0x24])); + packet.fill_checksum(); + packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); + } + + #[test] + fn test_overlong() { + let mut bytes = vec![]; + bytes.extend(&PACKET_BYTES[..]); + bytes.push(0); + + assert_eq!( + Packet::new_unchecked(&bytes).payload().len(), + PAYLOAD_BYTES.len() + ); + assert_eq!( + Packet::new_unchecked(&mut bytes).payload_mut().len(), + PAYLOAD_BYTES.len() + ); + } + + #[test] + fn test_total_len_overflow() { + let mut bytes = vec![]; + bytes.extend(&PACKET_BYTES[..]); + Packet::new_unchecked(&mut bytes).set_total_len(128); + + assert_eq!(Packet::new_checked(&bytes).unwrap_err(), Error); + } + + static REPR_PACKET_BYTES: [u8; 24] = [ + 0x45, 0x00, 0x00, 0x18, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0xd2, 0x79, 0x11, 0x12, 0x13, + 0x14, 0x21, 0x22, 0x23, 0x24, 0xaa, 0x00, 0x00, 0xff, + ]; + + static REPR_PAYLOAD_BYTES: [u8; ADDR_SIZE] = [0xaa, 0x00, 0x00, 0xff]; + + const fn packet_repr() -> Repr { + Repr { + src_addr: Address([0x11, 0x12, 0x13, 0x14]), + dst_addr: Address([0x21, 0x22, 0x23, 0x24]), + next_header: Protocol::Icmp, + payload_len: 4, + hop_limit: 64, + } + } + + #[test] + fn test_parse() { + let packet = Packet::new_unchecked(&REPR_PACKET_BYTES[..]); + let repr = Repr::parse(&packet, &ChecksumCapabilities::default()).unwrap(); + assert_eq!(repr, packet_repr()); + } + + #[test] + fn test_parse_bad_version() { + let mut bytes = vec![0; 24]; + bytes.copy_from_slice(&REPR_PACKET_BYTES[..]); + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_version(6); + packet.fill_checksum(); + let packet = Packet::new_unchecked(&*packet.into_inner()); + assert_eq!( + Repr::parse(&packet, &ChecksumCapabilities::default()), + Err(Error) + ); + } + + #[test] + fn test_parse_total_len_less_than_header_len() { + let mut bytes = vec![0; 40]; + bytes[0] = 0x09; + assert_eq!(Packet::new_checked(&mut bytes), Err(Error)); + } + + #[test] + fn test_emit() { + let repr = packet_repr(); + let mut bytes = vec![0xa5; repr.buffer_len() + REPR_PAYLOAD_BYTES.len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet, &ChecksumCapabilities::default()); + packet.payload_mut().copy_from_slice(&REPR_PAYLOAD_BYTES); + assert_eq!(&*packet.into_inner(), &REPR_PACKET_BYTES[..]); + } + + #[test] + fn test_unspecified() { + assert!(Address::UNSPECIFIED.is_unspecified()); + assert!(!Address::UNSPECIFIED.is_broadcast()); + assert!(!Address::UNSPECIFIED.is_multicast()); + assert!(!Address::UNSPECIFIED.is_link_local()); + assert!(!Address::UNSPECIFIED.is_loopback()); + } + + #[test] + fn test_broadcast() { + assert!(!Address::BROADCAST.is_unspecified()); + assert!(Address::BROADCAST.is_broadcast()); + assert!(!Address::BROADCAST.is_multicast()); + assert!(!Address::BROADCAST.is_link_local()); + assert!(!Address::BROADCAST.is_loopback()); + } + + #[test] + fn test_cidr() { + let cidr = Cidr::new(Address::new(192, 168, 1, 10), 24); + + let inside_subnet = [ + [192, 168, 1, 0], + [192, 168, 1, 1], + [192, 168, 1, 2], + [192, 168, 1, 10], + [192, 168, 1, 127], + [192, 168, 1, 255], + ]; + + let outside_subnet = [ + [192, 168, 0, 0], + [127, 0, 0, 1], + [192, 168, 2, 0], + [192, 168, 0, 255], + [0, 0, 0, 0], + [255, 255, 255, 255], + ]; + + let subnets = [ + ([192, 168, 1, 0], 32), + ([192, 168, 1, 255], 24), + ([192, 168, 1, 10], 30), + ]; + + let not_subnets = [ + ([192, 168, 1, 10], 23), + ([127, 0, 0, 1], 8), + ([192, 168, 1, 0], 0), + ([192, 168, 0, 255], 32), + ]; + + for addr in inside_subnet.iter().map(|a| Address::from_bytes(a)) { + assert!(cidr.contains_addr(&addr)); + } + + for addr in outside_subnet.iter().map(|a| Address::from_bytes(a)) { + assert!(!cidr.contains_addr(&addr)); + } + + for subnet in subnets + .iter() + .map(|&(a, p)| Cidr::new(Address::new(a[0], a[1], a[2], a[3]), p)) + { + assert!(cidr.contains_subnet(&subnet)); + } + + for subnet in not_subnets + .iter() + .map(|&(a, p)| Cidr::new(Address::new(a[0], a[1], a[2], a[3]), p)) + { + assert!(!cidr.contains_subnet(&subnet)); + } + + let cidr_without_prefix = Cidr::new(cidr.address(), 0); + assert!(cidr_without_prefix.contains_addr(&Address::new(127, 0, 0, 1))); + } + + #[test] + fn test_cidr_from_netmask() { + assert!(Cidr::from_netmask(Address([0, 0, 0, 0]), Address([1, 0, 2, 0])).is_err()); + assert!(Cidr::from_netmask(Address([0, 0, 0, 0]), Address([0, 0, 0, 0])).is_err()); + assert_eq!( + Cidr::from_netmask(Address([0, 0, 0, 1]), Address([255, 255, 255, 0])).unwrap(), + Cidr::new(Address([0, 0, 0, 1]), 24) + ); + assert_eq!( + Cidr::from_netmask(Address([192, 168, 0, 1]), Address([255, 255, 0, 0])).unwrap(), + Cidr::new(Address([192, 168, 0, 1]), 16) + ); + assert_eq!( + Cidr::from_netmask(Address([172, 16, 0, 1]), Address([255, 240, 0, 0])).unwrap(), + Cidr::new(Address([172, 16, 0, 1]), 12) + ); + assert_eq!( + Cidr::from_netmask(Address([255, 255, 255, 1]), Address([255, 255, 255, 0])).unwrap(), + Cidr::new(Address([255, 255, 255, 1]), 24) + ); + assert_eq!( + Cidr::from_netmask(Address([255, 255, 255, 255]), Address([255, 255, 255, 255])) + .unwrap(), + Cidr::new(Address([255, 255, 255, 255]), 32) + ); + } + + #[test] + fn test_cidr_netmask() { + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 0).netmask(), + Address([0, 0, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 1]), 24).netmask(), + Address([255, 255, 255, 0]) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 32).netmask(), + Address([255, 255, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([127, 0, 0, 0]), 8).netmask(), + Address([255, 0, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 0, 0]), 16).netmask(), + Address([255, 255, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 16).netmask(), + Address([255, 255, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 17).netmask(), + Address([255, 255, 128, 0]) + ); + assert_eq!( + Cidr::new(Address([172, 16, 0, 0]), 12).netmask(), + Address([255, 240, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 1]), 24).netmask(), + Address([255, 255, 255, 0]) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 255]), 32).netmask(), + Address([255, 255, 255, 255]) + ); + } + + #[test] + fn test_cidr_broadcast() { + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 0).broadcast().unwrap(), + Address([255, 255, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 1]), 24).broadcast().unwrap(), + Address([0, 0, 0, 255]) + ); + assert_eq!(Cidr::new(Address([0, 0, 0, 0]), 32).broadcast(), None); + assert_eq!( + Cidr::new(Address([127, 0, 0, 0]), 8).broadcast().unwrap(), + Address([127, 255, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 0, 0]), 16) + .broadcast() + .unwrap(), + Address([192, 168, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 16) + .broadcast() + .unwrap(), + Address([192, 168, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 17) + .broadcast() + .unwrap(), + Address([192, 168, 127, 255]) + ); + assert_eq!( + Cidr::new(Address([172, 16, 0, 1]), 12).broadcast().unwrap(), + Address([172, 31, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 1]), 24) + .broadcast() + .unwrap(), + Address([255, 255, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 254]), 31).broadcast(), + None + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 255]), 32).broadcast(), + None + ); + } + + #[test] + fn test_cidr_network() { + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 0).network(), + Cidr::new(Address([0, 0, 0, 0]), 0) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 1]), 24).network(), + Cidr::new(Address([0, 0, 0, 0]), 24) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 32).network(), + Cidr::new(Address([0, 0, 0, 0]), 32) + ); + assert_eq!( + Cidr::new(Address([127, 0, 0, 0]), 8).network(), + Cidr::new(Address([127, 0, 0, 0]), 8) + ); + assert_eq!( + Cidr::new(Address([192, 168, 0, 0]), 16).network(), + Cidr::new(Address([192, 168, 0, 0]), 16) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 16).network(), + Cidr::new(Address([192, 168, 0, 0]), 16) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 17).network(), + Cidr::new(Address([192, 168, 0, 0]), 17) + ); + assert_eq!( + Cidr::new(Address([172, 16, 0, 1]), 12).network(), + Cidr::new(Address([172, 16, 0, 0]), 12) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 1]), 24).network(), + Cidr::new(Address([255, 255, 255, 0]), 24) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 255]), 32).network(), + Cidr::new(Address([255, 255, 255, 255]), 32) + ); + } +} diff --git a/src/wire/ipv6.rs b/src/wire/ipv6.rs new file mode 100644 index 0000000..236600d --- /dev/null +++ b/src/wire/ipv6.rs @@ -0,0 +1,1449 @@ +#![deny(missing_docs)] + +use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; + +use super::{Error, Result}; +use crate::wire::ip::pretty_print_ip_payload; +#[cfg(feature = "proto-ipv4")] +use crate::wire::ipv4; + +pub use super::IpProtocol as Protocol; + +/// Minimum MTU required of all links supporting IPv6. See [RFC 8200 § 5]. +/// +/// [RFC 8200 § 5]: https://tools.ietf.org/html/rfc8200#section-5 +pub const MIN_MTU: usize = 1280; + +/// Size of IPv6 adderess in octets. +/// +/// [RFC 8200 § 2]: https://www.rfc-editor.org/rfc/rfc4291#section-2 +pub const ADDR_SIZE: usize = 16; + +/// Size of IPv4-mapping prefix in octets. +/// +/// [RFC 8200 § 2]: https://www.rfc-editor.org/rfc/rfc4291#section-2 +pub const IPV4_MAPPED_PREFIX_SIZE: usize = ADDR_SIZE - 4; // 4 == ipv4::ADDR_SIZE , cannot DRY here because of dependency on a IPv4 module which is behind the feature + +/// The [scope] of an address. +/// +/// [scope]: https://www.rfc-editor.org/rfc/rfc4291#section-2.7 +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum Scope { + /// Interface Local scope + InterfaceLocal = 0x1, + /// Link local scope + LinkLocal = 0x2, + /// Administratively configured + AdminLocal = 0x4, + /// Single site scope + SiteLocal = 0x5, + /// Organization scope + OrganizationLocal = 0x8, + /// Global scope + Global = 0xE, + /// Unknown scope + Unknown = 0xFF, +} + +impl From<u8> for Scope { + fn from(value: u8) -> Self { + match value { + 0x1 => Self::InterfaceLocal, + 0x2 => Self::LinkLocal, + 0x4 => Self::AdminLocal, + 0x5 => Self::SiteLocal, + 0x8 => Self::OrganizationLocal, + 0xE => Self::Global, + _ => Self::Unknown, + } + } +} + +/// A sixteen-octet IPv6 address. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +pub struct Address(pub [u8; ADDR_SIZE]); + +impl Address { + /// The [unspecified address]. + /// + /// [unspecified address]: https://tools.ietf.org/html/rfc4291#section-2.5.2 + pub const UNSPECIFIED: Address = Address([0x00; ADDR_SIZE]); + + /// The link-local [all nodes multicast address]. + /// + /// [all nodes multicast address]: https://tools.ietf.org/html/rfc4291#section-2.7.1 + pub const LINK_LOCAL_ALL_NODES: Address = Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + ]); + + /// The link-local [all routers multicast address]. + /// + /// [all routers multicast address]: https://tools.ietf.org/html/rfc4291#section-2.7.1 + pub const LINK_LOCAL_ALL_ROUTERS: Address = Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, + ]); + + /// The link-local [all RPL nodes multicast address]. + /// + /// [all RPL nodes multicast address]: https://www.rfc-editor.org/rfc/rfc6550.html#section-20.19 + pub const LINK_LOCAL_ALL_RPL_NODES: Address = Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x1a, + ]); + + /// The [loopback address]. + /// + /// [loopback address]: https://tools.ietf.org/html/rfc4291#section-2.5.3 + pub const LOOPBACK: Address = Address([ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + ]); + + /// The prefix used in [IPv4-mapped addresses]. + /// + /// [IPv4-mapped addresses]: https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5.2 + pub const IPV4_MAPPED_PREFIX: [u8; IPV4_MAPPED_PREFIX_SIZE] = + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff]; + + /// Construct an IPv6 address from parts. + #[allow(clippy::too_many_arguments)] + pub const fn new( + a0: u16, + a1: u16, + a2: u16, + a3: u16, + a4: u16, + a5: u16, + a6: u16, + a7: u16, + ) -> Address { + Address([ + (a0 >> 8) as u8, + a0 as u8, + (a1 >> 8) as u8, + a1 as u8, + (a2 >> 8) as u8, + a2 as u8, + (a3 >> 8) as u8, + a3 as u8, + (a4 >> 8) as u8, + a4 as u8, + (a5 >> 8) as u8, + a5 as u8, + (a6 >> 8) as u8, + a6 as u8, + (a7 >> 8) as u8, + a7 as u8, + ]) + } + + /// Construct an IPv6 address from a sequence of octets, in big-endian. + /// + /// # Panics + /// The function panics if `data` is not sixteen octets long. + pub fn from_bytes(data: &[u8]) -> Address { + let mut bytes = [0; ADDR_SIZE]; + bytes.copy_from_slice(data); + Address(bytes) + } + + /// Construct an IPv6 address from a sequence of words, in big-endian. + /// + /// # Panics + /// The function panics if `data` is not 8 words long. + pub fn from_parts(data: &[u16]) -> Address { + assert!(data.len() >= 8); + let mut bytes = [0; ADDR_SIZE]; + for (word_idx, chunk) in bytes.chunks_mut(2).enumerate() { + NetworkEndian::write_u16(chunk, data[word_idx]); + } + Address(bytes) + } + + /// Write a IPv6 address to the given slice. + /// + /// # Panics + /// The function panics if `data` is not 8 words long. + pub fn write_parts(&self, data: &mut [u16]) { + assert!(data.len() >= 8); + for (i, chunk) in self.0.chunks(2).enumerate() { + data[i] = NetworkEndian::read_u16(chunk); + } + } + + /// Return an IPv6 address as a sequence of octets, in big-endian. + pub const fn as_bytes(&self) -> &[u8] { + &self.0 + } + + /// Query whether the IPv6 address is an [unicast address]. + /// + /// [unicast address]: https://tools.ietf.org/html/rfc4291#section-2.5 + pub fn is_unicast(&self) -> bool { + !(self.is_multicast() || self.is_unspecified()) + } + + /// Query whether the IPv6 address is a [global unicast address]. + /// + /// [global unicast address]: https://datatracker.ietf.org/doc/html/rfc3587 + pub const fn is_global_unicast(&self) -> bool { + (self.0[0] >> 5) == 0b001 + } + + /// Query whether the IPv6 address is a [multicast address]. + /// + /// [multicast address]: https://tools.ietf.org/html/rfc4291#section-2.7 + pub const fn is_multicast(&self) -> bool { + self.0[0] == 0xff + } + + /// Query whether the IPv6 address is the [unspecified address]. + /// + /// [unspecified address]: https://tools.ietf.org/html/rfc4291#section-2.5.2 + pub fn is_unspecified(&self) -> bool { + self.0 == [0x00; ADDR_SIZE] + } + + /// Query whether the IPv6 address is in the [link-local] scope. + /// + /// [link-local]: https://tools.ietf.org/html/rfc4291#section-2.5.6 + pub fn is_link_local(&self) -> bool { + self.0[0..8] == [0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] + } + + /// Query whether the IPv6 address is a [Unique Local Address] (ULA). + /// + /// [Unique Local Address]: https://tools.ietf.org/html/rfc4193 + pub fn is_unique_local(&self) -> bool { + (self.0[0] & 0b1111_1110) == 0xfc + } + + /// Query whether the IPv6 address is the [loopback address]. + /// + /// [loopback address]: https://tools.ietf.org/html/rfc4291#section-2.5.3 + pub fn is_loopback(&self) -> bool { + *self == Self::LOOPBACK + } + + /// Query whether the IPv6 address is an [IPv4 mapped IPv6 address]. + /// + /// [IPv4 mapped IPv6 address]: https://tools.ietf.org/html/rfc4291#section-2.5.5.2 + pub fn is_ipv4_mapped(&self) -> bool { + self.0[..IPV4_MAPPED_PREFIX_SIZE] == Self::IPV4_MAPPED_PREFIX + } + + #[cfg(feature = "proto-ipv4")] + /// Convert an IPv4 mapped IPv6 address to an IPv4 address. + pub fn as_ipv4(&self) -> Option<ipv4::Address> { + if self.is_ipv4_mapped() { + Some(ipv4::Address::from_bytes( + &self.0[IPV4_MAPPED_PREFIX_SIZE..], + )) + } else { + None + } + } + + /// Helper function used to mask an address given a prefix. + /// + /// # Panics + /// This function panics if `mask` is greater than 128. + pub(super) fn mask(&self, mask: u8) -> [u8; ADDR_SIZE] { + assert!(mask <= 128); + let mut bytes = [0u8; ADDR_SIZE]; + let idx = (mask as usize) / 8; + let modulus = (mask as usize) % 8; + let (first, second) = self.0.split_at(idx); + bytes[0..idx].copy_from_slice(first); + if idx < ADDR_SIZE { + let part = second[0]; + bytes[idx] = part & (!(0xff >> modulus) as u8); + } + bytes + } + + /// The solicited node for the given unicast address. + /// + /// # Panics + /// This function panics if the given address is not + /// unicast. + pub fn solicited_node(&self) -> Address { + assert!(self.is_unicast()); + Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xFF, + self.0[13], self.0[14], self.0[15], + ]) + } + + /// Return the scope of the address. + pub(crate) fn scope(&self) -> Scope { + if self.is_multicast() { + return Scope::from(self.as_bytes()[1] & 0b1111); + } + + if self.is_link_local() { + Scope::LinkLocal + } else if self.is_unique_local() || self.is_global_unicast() { + // ULA are considered global scope + // https://www.rfc-editor.org/rfc/rfc6724#section-3.1 + Scope::Global + } else { + Scope::Unknown + } + } + + /// Convert to an `IpAddress`. + /// + /// Same as `.into()`, but works in `const`. + pub const fn into_address(self) -> super::IpAddress { + super::IpAddress::Ipv6(self) + } +} + +#[cfg(feature = "std")] +impl From<::std::net::Ipv6Addr> for Address { + fn from(x: ::std::net::Ipv6Addr) -> Address { + Address(x.octets()) + } +} + +#[cfg(feature = "std")] +impl From<Address> for ::std::net::Ipv6Addr { + fn from(Address(x): Address) -> ::std::net::Ipv6Addr { + x.into() + } +} + +impl fmt::Display for Address { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.is_ipv4_mapped() { + return write!( + f, + "::ffff:{}.{}.{}.{}", + self.0[IPV4_MAPPED_PREFIX_SIZE + 0], + self.0[IPV4_MAPPED_PREFIX_SIZE + 1], + self.0[IPV4_MAPPED_PREFIX_SIZE + 2], + self.0[IPV4_MAPPED_PREFIX_SIZE + 3] + ); + } + + // The string representation of an IPv6 address should + // collapse a series of 16 bit sections that evaluate + // to 0 to "::" + // + // See https://tools.ietf.org/html/rfc4291#section-2.2 + // for details. + enum State { + Head, + HeadBody, + Tail, + TailBody, + } + let mut words = [0u16; 8]; + self.write_parts(&mut words); + let mut state = State::Head; + for word in words.iter() { + state = match (*word, &state) { + // Once a u16 equal to zero write a double colon and + // skip to the next non-zero u16. + (0, &State::Head) | (0, &State::HeadBody) => { + write!(f, "::")?; + State::Tail + } + // Continue iterating without writing any characters until + // we hit a non-zero value. + (0, &State::Tail) => State::Tail, + // When the state is Head or Tail write a u16 in hexadecimal + // without the leading colon if the value is not 0. + (_, &State::Head) => { + write!(f, "{word:x}")?; + State::HeadBody + } + (_, &State::Tail) => { + write!(f, "{word:x}")?; + State::TailBody + } + // Write the u16 with a leading colon when parsing a value + // that isn't the first in a section + (_, &State::HeadBody) | (_, &State::TailBody) => { + write!(f, ":{word:x}")?; + state + } + } + } + Ok(()) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Address { + fn format(&self, f: defmt::Formatter) { + if self.is_ipv4_mapped() { + return defmt::write!( + f, + "::ffff:{}.{}.{}.{}", + self.0[IPV4_MAPPED_PREFIX_SIZE + 0], + self.0[IPV4_MAPPED_PREFIX_SIZE + 1], + self.0[IPV4_MAPPED_PREFIX_SIZE + 2], + self.0[IPV4_MAPPED_PREFIX_SIZE + 3] + ); + } + + // The string representation of an IPv6 address should + // collapse a series of 16 bit sections that evaluate + // to 0 to "::" + // + // See https://tools.ietf.org/html/rfc4291#section-2.2 + // for details. + enum State { + Head, + HeadBody, + Tail, + TailBody, + } + let mut words = [0u16; 8]; + self.write_parts(&mut words); + let mut state = State::Head; + for word in words.iter() { + state = match (*word, &state) { + // Once a u16 equal to zero write a double colon and + // skip to the next non-zero u16. + (0, &State::Head) | (0, &State::HeadBody) => { + defmt::write!(f, "::"); + State::Tail + } + // Continue iterating without writing any characters until + // we hit a non-zero value. + (0, &State::Tail) => State::Tail, + // When the state is Head or Tail write a u16 in hexadecimal + // without the leading colon if the value is not 0. + (_, &State::Head) => { + defmt::write!(f, "{:x}", word); + State::HeadBody + } + (_, &State::Tail) => { + defmt::write!(f, "{:x}", word); + State::TailBody + } + // Write the u16 with a leading colon when parsing a value + // that isn't the first in a section + (_, &State::HeadBody) | (_, &State::TailBody) => { + defmt::write!(f, ":{:x}", word); + state + } + } + } + } +} + +#[cfg(feature = "proto-ipv4")] +/// Convert the given IPv4 address into a IPv4-mapped IPv6 address +impl From<ipv4::Address> for Address { + fn from(address: ipv4::Address) -> Self { + let mut b = [0_u8; ADDR_SIZE]; + b[..Self::IPV4_MAPPED_PREFIX.len()].copy_from_slice(&Self::IPV4_MAPPED_PREFIX); + b[Self::IPV4_MAPPED_PREFIX.len()..].copy_from_slice(&address.0); + Self(b) + } +} + +/// A specification of an IPv6 CIDR block, containing an address and a variable-length +/// subnet masking prefix length. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +pub struct Cidr { + address: Address, + prefix_len: u8, +} + +impl Cidr { + /// The [solicited node prefix]. + /// + /// [solicited node prefix]: https://tools.ietf.org/html/rfc4291#section-2.7.1 + pub const SOLICITED_NODE_PREFIX: Cidr = Cidr { + address: Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0x00, + 0x00, 0x00, + ]), + prefix_len: 104, + }; + + /// Create an IPv6 CIDR block from the given address and prefix length. + /// + /// # Panics + /// This function panics if the prefix length is larger than 128. + pub const fn new(address: Address, prefix_len: u8) -> Cidr { + assert!(prefix_len <= 128); + Cidr { + address, + prefix_len, + } + } + + /// Return the address of this IPv6 CIDR block. + pub const fn address(&self) -> Address { + self.address + } + + /// Return the prefix length of this IPv6 CIDR block. + pub const fn prefix_len(&self) -> u8 { + self.prefix_len + } + + /// Query whether the subnetwork described by this IPv6 CIDR block contains + /// the given address. + pub fn contains_addr(&self, addr: &Address) -> bool { + // right shift by 128 is not legal + if self.prefix_len == 0 { + return true; + } + + self.address.mask(self.prefix_len) == addr.mask(self.prefix_len) + } + + /// Query whether the subnetwork described by this IPV6 CIDR block contains + /// the subnetwork described by the given IPv6 CIDR block. + pub fn contains_subnet(&self, subnet: &Cidr) -> bool { + self.prefix_len <= subnet.prefix_len && self.contains_addr(&subnet.address) + } +} + +impl fmt::Display for Cidr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // https://tools.ietf.org/html/rfc4291#section-2.3 + write!(f, "{}/{}", self.address, self.prefix_len) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Cidr { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{}/{=u8}", self.address, self.prefix_len); + } +} + +/// A read/write wrapper around an Internet Protocol version 6 packet buffer. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +// Ranges and constants describing the IPv6 header +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Version| Traffic Class | Flow Label | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Payload Length | Next Header | Hop Limit | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// + + +// | | +// + Source Address + +// | | +// + + +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// + + +// | | +// + Destination Address + +// | | +// + + +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// See https://tools.ietf.org/html/rfc2460#section-3 for details. +mod field { + use crate::wire::field::*; + // 4-bit version number, 8-bit traffic class, and the + // 20-bit flow label. + pub const VER_TC_FLOW: Field = 0..4; + // 16-bit value representing the length of the payload. + // Note: Options are included in this length. + pub const LENGTH: Field = 4..6; + // 8-bit value identifying the type of header following this + // one. Note: The same numbers are used in IPv4. + pub const NXT_HDR: usize = 6; + // 8-bit value decremented by each node that forwards this + // packet. The packet is discarded when the value is 0. + pub const HOP_LIMIT: usize = 7; + // IPv6 address of the source node. + pub const SRC_ADDR: Field = 8..24; + // IPv6 address of the destination node. + pub const DST_ADDR: Field = 24..40; +} + +/// Length of an IPv6 header. +pub const HEADER_LEN: usize = field::DST_ADDR.end; + +impl<T: AsRef<[u8]>> Packet<T> { + /// Create a raw octet buffer with an IPv6 packet structure. + #[inline] + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + #[inline] + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// The result of this check is invalidated by calling [set_payload_len]. + /// + /// [set_payload_len]: #method.set_payload_len + #[inline] + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::DST_ADDR.end || len < self.total_len() { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + #[inline] + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the header length. + #[inline] + pub const fn header_len(&self) -> usize { + // This is not a strictly necessary function, but it makes + // code more readable. + field::DST_ADDR.end + } + + /// Return the version field. + #[inline] + pub fn version(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::VER_TC_FLOW.start] >> 4 + } + + /// Return the traffic class. + #[inline] + pub fn traffic_class(&self) -> u8 { + let data = self.buffer.as_ref(); + ((NetworkEndian::read_u16(&data[0..2]) & 0x0ff0) >> 4) as u8 + } + + /// Return the flow label field. + #[inline] + pub fn flow_label(&self) -> u32 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u24(&data[1..4]) & 0x000fffff + } + + /// Return the payload length field. + #[inline] + pub fn payload_len(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::LENGTH]) + } + + /// Return the payload length added to the known header length. + #[inline] + pub fn total_len(&self) -> usize { + self.header_len() + self.payload_len() as usize + } + + /// Return the next header field. + #[inline] + pub fn next_header(&self) -> Protocol { + let data = self.buffer.as_ref(); + Protocol::from(data[field::NXT_HDR]) + } + + /// Return the hop limit field. + #[inline] + pub fn hop_limit(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::HOP_LIMIT] + } + + /// Return the source address field. + #[inline] + pub fn src_addr(&self) -> Address { + let data = self.buffer.as_ref(); + Address::from_bytes(&data[field::SRC_ADDR]) + } + + /// Return the destination address field. + #[inline] + pub fn dst_addr(&self) -> Address { + let data = self.buffer.as_ref(); + Address::from_bytes(&data[field::DST_ADDR]) + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return a pointer to the payload. + #[inline] + pub fn payload(&self) -> &'a [u8] { + let data = self.buffer.as_ref(); + let range = self.header_len()..self.total_len(); + &data[range] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the version field. + #[inline] + pub fn set_version(&mut self, value: u8) { + let data = self.buffer.as_mut(); + // Make sure to retain the lower order bits which contain + // the higher order bits of the traffic class + data[0] = (data[0] & 0x0f) | ((value & 0x0f) << 4); + } + + /// Set the traffic class field. + #[inline] + pub fn set_traffic_class(&mut self, value: u8) { + let data = self.buffer.as_mut(); + // Put the higher order 4-bits of value in the lower order + // 4-bits of the first byte + data[0] = (data[0] & 0xf0) | ((value & 0xf0) >> 4); + // Put the lower order 4-bits of value in the higher order + // 4-bits of the second byte + data[1] = (data[1] & 0x0f) | ((value & 0x0f) << 4); + } + + /// Set the flow label field. + #[inline] + pub fn set_flow_label(&mut self, value: u32) { + let data = self.buffer.as_mut(); + // Retain the lower order 4-bits of the traffic class + let raw = (((data[1] & 0xf0) as u32) << 16) | (value & 0x0fffff); + NetworkEndian::write_u24(&mut data[1..4], raw); + } + + /// Set the payload length field. + #[inline] + pub fn set_payload_len(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::LENGTH], value); + } + + /// Set the next header field. + #[inline] + pub fn set_next_header(&mut self, value: Protocol) { + let data = self.buffer.as_mut(); + data[field::NXT_HDR] = value.into(); + } + + /// Set the hop limit field. + #[inline] + pub fn set_hop_limit(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::HOP_LIMIT] = value; + } + + /// Set the source address field. + #[inline] + pub fn set_src_addr(&mut self, value: Address) { + let data = self.buffer.as_mut(); + data[field::SRC_ADDR].copy_from_slice(value.as_bytes()); + } + + /// Set the destination address field. + #[inline] + pub fn set_dst_addr(&mut self, value: Address) { + let data = self.buffer.as_mut(); + data[field::DST_ADDR].copy_from_slice(value.as_bytes()); + } + + /// Return a mutable pointer to the payload. + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + let range = self.header_len()..self.total_len(); + let data = self.buffer.as_mut(); + &mut data[range] + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match Repr::parse(self) { + Ok(repr) => write!(f, "{repr}"), + Err(err) => { + write!(f, "IPv6 ({err})")?; + Ok(()) + } + } + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +/// A high-level representation of an Internet Protocol version 6 packet header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct Repr { + /// IPv6 address of the source node. + pub src_addr: Address, + /// IPv6 address of the destination node. + pub dst_addr: Address, + /// Protocol contained in the next header. + pub next_header: Protocol, + /// Length of the payload including the extension headers. + pub payload_len: usize, + /// The 8-bit hop limit field. + pub hop_limit: u8, +} + +impl Repr { + /// Parse an Internet Protocol version 6 packet and return a high-level representation. + pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&T>) -> Result<Repr> { + // Ensure basic accessors will work + packet.check_len()?; + if packet.version() != 6 { + return Err(Error); + } + Ok(Repr { + src_addr: packet.src_addr(), + dst_addr: packet.dst_addr(), + next_header: packet.next_header(), + payload_len: packet.payload_len() as usize, + hop_limit: packet.hop_limit(), + }) + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + // This function is not strictly necessary, but it can make client code more readable. + field::DST_ADDR.end + } + + /// Emit a high-level representation into an Internet Protocol version 6 packet. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut Packet<T>) { + // Make no assumptions about the original state of the packet buffer. + // Make sure to set every byte. + packet.set_version(6); + packet.set_traffic_class(0); + packet.set_flow_label(0); + packet.set_payload_len(self.payload_len as u16); + packet.set_hop_limit(self.hop_limit); + packet.set_next_header(self.next_header); + packet.set_src_addr(self.src_addr); + packet.set_dst_addr(self.dst_addr); + } +} + +impl fmt::Display for Repr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "IPv6 src={} dst={} nxt_hdr={} hop_limit={}", + self.src_addr, self.dst_addr, self.next_header, self.hop_limit + ) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Repr { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "IPv6 src={} dst={} nxt_hdr={} hop_limit={}", + self.src_addr, + self.dst_addr, + self.next_header, + self.hop_limit + ) + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +// TODO: This is very similar to the implementation for IPv4. Make +// a way to have less copy and pasted code here. +impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + let (ip_repr, payload) = match Packet::new_checked(buffer) { + Err(err) => return write!(f, "{indent}({err})"), + Ok(ip_packet) => match Repr::parse(&ip_packet) { + Err(_) => return Ok(()), + Ok(ip_repr) => { + write!(f, "{indent}{ip_repr}")?; + (ip_repr, ip_packet.payload()) + } + }, + }; + + pretty_print_ip_payload(f, indent, ip_repr, payload) + } +} + +#[cfg(test)] +mod test { + use super::Error; + use super::{Address, Cidr}; + use super::{Packet, Protocol, Repr}; + use crate::wire::pretty_print::PrettyPrinter; + + #[cfg(feature = "proto-ipv4")] + use crate::wire::ipv4::Address as Ipv4Address; + + const LINK_LOCAL_ADDR: Address = Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1); + const UNIQUE_LOCAL_ADDR: Address = Address::new(0xfd00, 0, 0, 201, 1, 1, 1, 1); + const GLOBAL_UNICAST_ADDR: Address = Address::new(0x2001, 0xdb8, 0x3, 0, 0, 0, 0, 1); + + #[test] + fn test_basic_multicast() { + assert!(!Address::LINK_LOCAL_ALL_ROUTERS.is_unspecified()); + assert!(Address::LINK_LOCAL_ALL_ROUTERS.is_multicast()); + assert!(!Address::LINK_LOCAL_ALL_ROUTERS.is_link_local()); + assert!(!Address::LINK_LOCAL_ALL_ROUTERS.is_loopback()); + assert!(!Address::LINK_LOCAL_ALL_ROUTERS.is_unique_local()); + assert!(!Address::LINK_LOCAL_ALL_ROUTERS.is_global_unicast()); + assert!(!Address::LINK_LOCAL_ALL_NODES.is_unspecified()); + assert!(Address::LINK_LOCAL_ALL_NODES.is_multicast()); + assert!(!Address::LINK_LOCAL_ALL_NODES.is_link_local()); + assert!(!Address::LINK_LOCAL_ALL_NODES.is_loopback()); + assert!(!Address::LINK_LOCAL_ALL_NODES.is_unique_local()); + assert!(!Address::LINK_LOCAL_ALL_NODES.is_global_unicast()); + } + + #[test] + fn test_basic_link_local() { + assert!(!LINK_LOCAL_ADDR.is_unspecified()); + assert!(!LINK_LOCAL_ADDR.is_multicast()); + assert!(LINK_LOCAL_ADDR.is_link_local()); + assert!(!LINK_LOCAL_ADDR.is_loopback()); + assert!(!LINK_LOCAL_ADDR.is_unique_local()); + assert!(!LINK_LOCAL_ADDR.is_global_unicast()); + } + + #[test] + fn test_basic_loopback() { + assert!(!Address::LOOPBACK.is_unspecified()); + assert!(!Address::LOOPBACK.is_multicast()); + assert!(!Address::LOOPBACK.is_link_local()); + assert!(Address::LOOPBACK.is_loopback()); + assert!(!Address::LOOPBACK.is_unique_local()); + assert!(!Address::LOOPBACK.is_global_unicast()); + } + + #[test] + fn test_unique_local() { + assert!(!UNIQUE_LOCAL_ADDR.is_unspecified()); + assert!(!UNIQUE_LOCAL_ADDR.is_multicast()); + assert!(!UNIQUE_LOCAL_ADDR.is_link_local()); + assert!(!UNIQUE_LOCAL_ADDR.is_loopback()); + assert!(UNIQUE_LOCAL_ADDR.is_unique_local()); + assert!(!UNIQUE_LOCAL_ADDR.is_global_unicast()); + } + + #[test] + fn test_global_unicast() { + assert!(!GLOBAL_UNICAST_ADDR.is_unspecified()); + assert!(!GLOBAL_UNICAST_ADDR.is_multicast()); + assert!(!GLOBAL_UNICAST_ADDR.is_link_local()); + assert!(!GLOBAL_UNICAST_ADDR.is_loopback()); + assert!(!GLOBAL_UNICAST_ADDR.is_unique_local()); + assert!(GLOBAL_UNICAST_ADDR.is_global_unicast()); + } + + #[test] + fn test_address_format() { + assert_eq!("ff02::1", format!("{}", Address::LINK_LOCAL_ALL_NODES)); + assert_eq!("fe80::1", format!("{LINK_LOCAL_ADDR}")); + assert_eq!( + "fe80::7f00:0:1", + format!( + "{}", + Address::new(0xfe80, 0, 0, 0, 0, 0x7f00, 0x0000, 0x0001) + ) + ); + assert_eq!("::", format!("{}", Address::UNSPECIFIED)); + assert_eq!("::1", format!("{}", Address::LOOPBACK)); + + #[cfg(feature = "proto-ipv4")] + assert_eq!( + "::ffff:192.168.1.1", + format!("{}", Address::from(Ipv4Address::new(192, 168, 1, 1))) + ); + } + + #[test] + fn test_new() { + assert_eq!( + Address::new(0xff02, 0, 0, 0, 0, 0, 0, 1), + Address::LINK_LOCAL_ALL_NODES + ); + assert_eq!( + Address::new(0xff02, 0, 0, 0, 0, 0, 0, 2), + Address::LINK_LOCAL_ALL_ROUTERS + ); + assert_eq!(Address::new(0, 0, 0, 0, 0, 0, 0, 1), Address::LOOPBACK); + assert_eq!(Address::new(0, 0, 0, 0, 0, 0, 0, 0), Address::UNSPECIFIED); + assert_eq!(Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), LINK_LOCAL_ADDR); + } + + #[test] + fn test_from_parts() { + assert_eq!( + Address::from_parts(&[0xff02, 0, 0, 0, 0, 0, 0, 1]), + Address::LINK_LOCAL_ALL_NODES + ); + assert_eq!( + Address::from_parts(&[0xff02, 0, 0, 0, 0, 0, 0, 2]), + Address::LINK_LOCAL_ALL_ROUTERS + ); + assert_eq!( + Address::from_parts(&[0, 0, 0, 0, 0, 0, 0, 1]), + Address::LOOPBACK + ); + assert_eq!( + Address::from_parts(&[0, 0, 0, 0, 0, 0, 0, 0]), + Address::UNSPECIFIED + ); + assert_eq!( + Address::from_parts(&[0xfe80, 0, 0, 0, 0, 0, 0, 1]), + LINK_LOCAL_ADDR + ); + } + + #[test] + fn test_write_parts() { + let mut bytes = [0u16; 8]; + { + Address::LOOPBACK.write_parts(&mut bytes); + assert_eq!(Address::LOOPBACK, Address::from_parts(&bytes)); + } + { + Address::LINK_LOCAL_ALL_ROUTERS.write_parts(&mut bytes); + assert_eq!(Address::LINK_LOCAL_ALL_ROUTERS, Address::from_parts(&bytes)); + } + { + LINK_LOCAL_ADDR.write_parts(&mut bytes); + assert_eq!(LINK_LOCAL_ADDR, Address::from_parts(&bytes)); + } + } + + #[test] + fn test_mask() { + let addr = Address::new(0x0123, 0x4567, 0x89ab, 0, 0, 0, 0, 1); + assert_eq!( + addr.mask(11), + [0x01, 0x20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + addr.mask(15), + [0x01, 0x22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + addr.mask(26), + [0x01, 0x23, 0x45, 0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + addr.mask(128), + [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] + ); + assert_eq!( + addr.mask(127), + [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + } + + #[cfg(feature = "proto-ipv4")] + #[test] + fn test_is_ipv4_mapped() { + assert!(!Address::UNSPECIFIED.is_ipv4_mapped()); + assert!(Address::from(Ipv4Address::new(192, 168, 1, 1)).is_ipv4_mapped()); + } + + #[cfg(feature = "proto-ipv4")] + #[test] + fn test_as_ipv4() { + assert_eq!(None, Address::UNSPECIFIED.as_ipv4()); + + let ipv4 = Ipv4Address::new(192, 168, 1, 1); + assert_eq!(Some(ipv4), Address::from(ipv4).as_ipv4()); + } + + #[cfg(feature = "proto-ipv4")] + #[test] + fn test_from_ipv4_address() { + assert_eq!( + Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1]), + Address::from(Ipv4Address::new(192, 168, 1, 1)) + ); + assert_eq!( + Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 222, 1, 41, 90]), + Address::from(Ipv4Address::new(222, 1, 41, 90)) + ); + } + + #[test] + fn test_cidr() { + // fe80::1/56 + // 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + // 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + let cidr = Cidr::new(LINK_LOCAL_ADDR, 56); + + let inside_subnet = [ + // fe80::2 + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, + ], + // fe80::1122:3344:5566:7788 + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, + 0x77, 0x88, + ], + // fe80::ff00:0:0:0 + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + ], + // fe80::ff + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0xff, + ], + ]; + + let outside_subnet = [ + // fe80:0:0:101::1 + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ], + // ::1 + [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ], + // ff02::1 + [ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ], + // ff02::2 + [ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, + ], + ]; + + let subnets = [ + // fe80::ffff:ffff:ffff:ffff/65 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ], + 65, + ), + // fe80::1/128 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, + ], + 128, + ), + // fe80::1234:5678/96 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, + 0x34, 0x56, 0x78, + ], + 96, + ), + ]; + + let not_subnets = [ + // fe80::101:ffff:ffff:ffff:ffff/55 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ], + 55, + ), + // fe80::101:ffff:ffff:ffff:ffff/56 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ], + 56, + ), + // fe80::101:ffff:ffff:ffff:ffff/57 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ], + 57, + ), + // ::1/128 + ( + [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, + ], + 128, + ), + ]; + + for addr in inside_subnet.iter().map(|a| Address::from_bytes(a)) { + assert!(cidr.contains_addr(&addr)); + } + + for addr in outside_subnet.iter().map(|a| Address::from_bytes(a)) { + assert!(!cidr.contains_addr(&addr)); + } + + for subnet in subnets.iter().map(|&(a, p)| Cidr::new(Address(a), p)) { + assert!(cidr.contains_subnet(&subnet)); + } + + for subnet in not_subnets.iter().map(|&(a, p)| Cidr::new(Address(a), p)) { + assert!(!cidr.contains_subnet(&subnet)); + } + + let cidr_without_prefix = Cidr::new(LINK_LOCAL_ADDR, 0); + assert!(cidr_without_prefix.contains_addr(&Address::LOOPBACK)); + } + + #[test] + #[should_panic(expected = "length")] + fn test_from_bytes_too_long() { + let _ = Address::from_bytes(&[0u8; 15]); + } + + #[test] + #[should_panic(expected = "data.len() >= 8")] + fn test_from_parts_too_long() { + let _ = Address::from_parts(&[0u16; 7]); + } + + #[test] + fn test_scope() { + use super::*; + assert_eq!( + Address::new(0xff01, 0, 0, 0, 0, 0, 0, 1).scope(), + Scope::InterfaceLocal + ); + assert_eq!( + Address::new(0xff02, 0, 0, 0, 0, 0, 0, 1).scope(), + Scope::LinkLocal + ); + assert_eq!( + Address::new(0xff03, 0, 0, 0, 0, 0, 0, 1).scope(), + Scope::Unknown + ); + assert_eq!( + Address::new(0xff04, 0, 0, 0, 0, 0, 0, 1).scope(), + Scope::AdminLocal + ); + assert_eq!( + Address::new(0xff05, 0, 0, 0, 0, 0, 0, 1).scope(), + Scope::SiteLocal + ); + assert_eq!( + Address::new(0xff08, 0, 0, 0, 0, 0, 0, 1).scope(), + Scope::OrganizationLocal + ); + assert_eq!( + Address::new(0xff0e, 0, 0, 0, 0, 0, 0, 1).scope(), + Scope::Global + ); + + assert_eq!(Address::LINK_LOCAL_ALL_NODES.scope(), Scope::LinkLocal); + + // For source address selection, unicast addresses also have a scope: + assert_eq!(LINK_LOCAL_ADDR.scope(), Scope::LinkLocal); + assert_eq!(GLOBAL_UNICAST_ADDR.scope(), Scope::Global); + assert_eq!(UNIQUE_LOCAL_ADDR.scope(), Scope::Global); + } + + static REPR_PACKET_BYTES: [u8; 52] = [ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x11, 0x40, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0x00, + 0x0c, 0x02, 0x4e, 0xff, 0xff, 0xff, 0xff, + ]; + static REPR_PAYLOAD_BYTES: [u8; 12] = [ + 0x00, 0x01, 0x00, 0x02, 0x00, 0x0c, 0x02, 0x4e, 0xff, 0xff, 0xff, 0xff, + ]; + + const fn packet_repr() -> Repr { + Repr { + src_addr: Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ]), + dst_addr: Address::LINK_LOCAL_ALL_NODES, + next_header: Protocol::Udp, + payload_len: 12, + hop_limit: 64, + } + } + + #[test] + fn test_packet_deconstruction() { + let packet = Packet::new_unchecked(&REPR_PACKET_BYTES[..]); + assert_eq!(packet.check_len(), Ok(())); + assert_eq!(packet.version(), 6); + assert_eq!(packet.traffic_class(), 0); + assert_eq!(packet.flow_label(), 0); + assert_eq!(packet.total_len(), 0x34); + assert_eq!(packet.payload_len() as usize, REPR_PAYLOAD_BYTES.len()); + assert_eq!(packet.next_header(), Protocol::Udp); + assert_eq!(packet.hop_limit(), 0x40); + assert_eq!( + packet.src_addr(), + Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01 + ]) + ); + assert_eq!(packet.dst_addr(), Address::LINK_LOCAL_ALL_NODES); + assert_eq!(packet.payload(), &REPR_PAYLOAD_BYTES[..]); + } + + #[test] + fn test_packet_construction() { + let mut bytes = [0xff; 52]; + let mut packet = Packet::new_unchecked(&mut bytes[..]); + // Version, Traffic Class, and Flow Label are not + // byte aligned. make sure the setters and getters + // do not interfere with each other. + packet.set_version(6); + assert_eq!(packet.version(), 6); + packet.set_traffic_class(0x99); + assert_eq!(packet.version(), 6); + assert_eq!(packet.traffic_class(), 0x99); + packet.set_flow_label(0x54321); + assert_eq!(packet.traffic_class(), 0x99); + assert_eq!(packet.flow_label(), 0x54321); + packet.set_payload_len(0xc); + packet.set_next_header(Protocol::Udp); + packet.set_hop_limit(0xfe); + packet.set_src_addr(Address::LINK_LOCAL_ALL_ROUTERS); + packet.set_dst_addr(Address::LINK_LOCAL_ALL_NODES); + packet + .payload_mut() + .copy_from_slice(&REPR_PAYLOAD_BYTES[..]); + let mut expected_bytes = [ + 0x69, 0x95, 0x43, 0x21, 0x00, 0x0c, 0x11, 0xfe, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x02, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + let start = expected_bytes.len() - REPR_PAYLOAD_BYTES.len(); + expected_bytes[start..].copy_from_slice(&REPR_PAYLOAD_BYTES[..]); + assert_eq!(packet.check_len(), Ok(())); + assert_eq!(&*packet.into_inner(), &expected_bytes[..]); + } + + #[test] + fn test_overlong() { + let mut bytes = vec![]; + bytes.extend(&REPR_PACKET_BYTES[..]); + bytes.push(0); + + assert_eq!( + Packet::new_unchecked(&bytes).payload().len(), + REPR_PAYLOAD_BYTES.len() + ); + assert_eq!( + Packet::new_unchecked(&mut bytes).payload_mut().len(), + REPR_PAYLOAD_BYTES.len() + ); + } + + #[test] + fn test_total_len_overflow() { + let mut bytes = vec![]; + bytes.extend(&REPR_PACKET_BYTES[..]); + Packet::new_unchecked(&mut bytes).set_payload_len(0x80); + + assert_eq!(Packet::new_checked(&bytes).unwrap_err(), Error); + } + + #[test] + fn test_repr_parse_valid() { + let packet = Packet::new_unchecked(&REPR_PACKET_BYTES[..]); + let repr = Repr::parse(&packet).unwrap(); + assert_eq!(repr, packet_repr()); + } + + #[test] + fn test_repr_parse_bad_version() { + let mut bytes = vec![0; 40]; + let mut packet = Packet::new_unchecked(&mut bytes[..]); + packet.set_version(4); + packet.set_payload_len(0); + let packet = Packet::new_unchecked(&*packet.into_inner()); + assert_eq!(Repr::parse(&packet), Err(Error)); + } + + #[test] + fn test_repr_parse_smaller_than_header() { + let mut bytes = vec![0; 40]; + let mut packet = Packet::new_unchecked(&mut bytes[..]); + packet.set_version(6); + packet.set_payload_len(39); + let packet = Packet::new_unchecked(&*packet.into_inner()); + assert_eq!(Repr::parse(&packet), Err(Error)); + } + + #[test] + fn test_repr_parse_smaller_than_payload() { + let mut bytes = vec![0; 40]; + let mut packet = Packet::new_unchecked(&mut bytes[..]); + packet.set_version(6); + packet.set_payload_len(1); + let packet = Packet::new_unchecked(&*packet.into_inner()); + assert_eq!(Repr::parse(&packet), Err(Error)); + } + + #[test] + fn test_basic_repr_emit() { + let repr = packet_repr(); + let mut bytes = vec![0xff; repr.buffer_len() + REPR_PAYLOAD_BYTES.len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet); + packet.payload_mut().copy_from_slice(&REPR_PAYLOAD_BYTES); + assert_eq!(&*packet.into_inner(), &REPR_PACKET_BYTES[..]); + } + + #[test] + fn test_pretty_print() { + assert_eq!( + format!( + "{}", + PrettyPrinter::<Packet<&'static [u8]>>::new("\n", &&REPR_PACKET_BYTES[..]) + ), + "\nIPv6 src=fe80::1 dst=ff02::1 nxt_hdr=UDP hop_limit=64\n \\ UDP src=1 dst=2 len=4" + ); + } +} diff --git a/src/wire/ipv6ext_header.rs b/src/wire/ipv6ext_header.rs new file mode 100644 index 0000000..bc8ef87 --- /dev/null +++ b/src/wire/ipv6ext_header.rs @@ -0,0 +1,305 @@ +#![allow(unused)] + +use super::IpProtocol; +use super::{Error, Result}; + +mod field { + #![allow(non_snake_case)] + + use crate::wire::field::*; + + pub const MIN_HEADER_SIZE: usize = 8; + + pub const NXT_HDR: usize = 0; + pub const LENGTH: usize = 1; + // Variable-length field. + // + // Length of the header is in 8-octet units, not including the first 8 octets. + // The first two octets are the next header type and the header length. + pub const fn PAYLOAD(length_field: u8) -> Field { + let bytes = length_field as usize * 8 + 8; + 2..bytes + } +} + +/// A read/write wrapper around an IPv6 Extension Header buffer. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Header<T: AsRef<[u8]>> { + buffer: T, +} + +/// Core getter methods relevant to any IPv6 extension header. +impl<T: AsRef<[u8]>> Header<T> { + /// Create a raw octet buffer with an IPv6 Extension Header structure. + pub const fn new_unchecked(buffer: T) -> Self { + Header { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Self> { + let header = Self::new_unchecked(buffer); + header.check_len()?; + Ok(header) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// The result of this check is invalidated by calling [set_header_len]. + /// + /// [set_header_len]: #method.set_header_len + pub fn check_len(&self) -> Result<()> { + let data = self.buffer.as_ref(); + + let len = data.len(); + if len < field::MIN_HEADER_SIZE { + return Err(Error); + } + + let of = field::PAYLOAD(data[field::LENGTH]); + if len < of.end { + return Err(Error); + } + + Ok(()) + } + + /// Consume the header, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the next header field. + pub fn next_header(&self) -> IpProtocol { + let data = self.buffer.as_ref(); + IpProtocol::from(data[field::NXT_HDR]) + } + + /// Return the header length field. + pub fn header_len(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::LENGTH] + } +} + +impl<'h, T: AsRef<[u8]> + ?Sized> Header<&'h T> { + /// Return the payload of the IPv6 extension header. + pub fn payload(&self) -> &'h [u8] { + let data = self.buffer.as_ref(); + &data[field::PAYLOAD(data[field::LENGTH])] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Header<T> { + /// Set the next header field. + #[inline] + pub fn set_next_header(&mut self, value: IpProtocol) { + let data = self.buffer.as_mut(); + data[field::NXT_HDR] = value.into(); + } + + /// Set the extension header data length. The length of the header is + /// in 8-octet units, not including the first 8 octets. + #[inline] + pub fn set_header_len(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::LENGTH] = value; + } +} + +impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Header<&'a mut T> { + /// Return a mutable pointer to the payload data. + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + let data = self.buffer.as_mut(); + let len = data[field::LENGTH]; + &mut data[field::PAYLOAD(len)] + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr<'a> { + pub next_header: IpProtocol, + pub length: u8, + pub data: &'a [u8], +} + +impl<'a> Repr<'a> { + /// Parse an IPv6 Extension Header Header and return a high-level representation. + pub fn parse<T>(header: &Header<&'a T>) -> Result<Self> + where + T: AsRef<[u8]> + ?Sized, + { + Ok(Self { + next_header: header.next_header(), + length: header.header_len(), + data: header.payload(), + }) + } + + /// Return the length, in bytes, of a header that will be emitted from this high-level + /// representation. + pub const fn header_len(&self) -> usize { + 2 + } + + /// Emit a high-level representation into an IPv6 Extension Header. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, header: &mut Header<&mut T>) { + header.set_next_header(self.next_header); + header.set_header_len(self.length); + } +} + +#[cfg(test)] +mod test { + use super::*; + + // A Hop-by-Hop Option header with a PadN option of option data length 4. + static REPR_PACKET_PAD4: [u8; 8] = [0x6, 0x0, 0x1, 0x4, 0x0, 0x0, 0x0, 0x0]; + + // A Hop-by-Hop Option header with a PadN option of option data length 12. + static REPR_PACKET_PAD12: [u8; 16] = [ + 0x06, 0x1, 0x1, 0x0C, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + ]; + + #[test] + fn test_check_len() { + // zero byte buffer + assert_eq!( + Err(Error), + Header::new_unchecked(&REPR_PACKET_PAD4[..0]).check_len() + ); + // no length field + assert_eq!( + Err(Error), + Header::new_unchecked(&REPR_PACKET_PAD4[..1]).check_len() + ); + // less than 8 bytes + assert_eq!( + Err(Error), + Header::new_unchecked(&REPR_PACKET_PAD4[..7]).check_len() + ); + // valid + assert_eq!(Ok(()), Header::new_unchecked(&REPR_PACKET_PAD4).check_len()); + // valid + assert_eq!( + Ok(()), + Header::new_unchecked(&REPR_PACKET_PAD12).check_len() + ); + // length field value greater than number of bytes + let header: [u8; 8] = [0x06, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0]; + assert_eq!(Err(Error), Header::new_unchecked(&header).check_len()); + } + + #[test] + fn test_header_deconstruct() { + let header = Header::new_unchecked(&REPR_PACKET_PAD4); + assert_eq!(header.next_header(), IpProtocol::Tcp); + assert_eq!(header.header_len(), 0); + assert_eq!(header.payload(), &REPR_PACKET_PAD4[2..]); + + let header = Header::new_unchecked(&REPR_PACKET_PAD12); + assert_eq!(header.next_header(), IpProtocol::Tcp); + assert_eq!(header.header_len(), 1); + assert_eq!(header.payload(), &REPR_PACKET_PAD12[2..]); + } + + #[test] + fn test_overlong() { + let mut bytes = vec![]; + bytes.extend(&REPR_PACKET_PAD4[..]); + bytes.push(0); + + assert_eq!( + Header::new_unchecked(&bytes).payload().len(), + REPR_PACKET_PAD4[2..].len() + ); + assert_eq!( + Header::new_unchecked(&mut bytes).payload_mut().len(), + REPR_PACKET_PAD4[2..].len() + ); + + let mut bytes = vec![]; + bytes.extend(&REPR_PACKET_PAD12[..]); + bytes.push(0); + + assert_eq!( + Header::new_unchecked(&bytes).payload().len(), + REPR_PACKET_PAD12[2..].len() + ); + assert_eq!( + Header::new_unchecked(&mut bytes).payload_mut().len(), + REPR_PACKET_PAD12[2..].len() + ); + } + + #[test] + fn test_header_len_overflow() { + let mut bytes = vec![]; + bytes.extend(REPR_PACKET_PAD4); + let len = bytes.len() as u8; + Header::new_unchecked(&mut bytes).set_header_len(len + 1); + + assert_eq!(Header::new_checked(&bytes).unwrap_err(), Error); + + let mut bytes = vec![]; + bytes.extend(REPR_PACKET_PAD12); + let len = bytes.len() as u8; + Header::new_unchecked(&mut bytes).set_header_len(len + 1); + + assert_eq!(Header::new_checked(&bytes).unwrap_err(), Error); + } + + #[test] + fn test_repr_parse_valid() { + let header = Header::new_unchecked(&REPR_PACKET_PAD4); + let repr = Repr::parse(&header).unwrap(); + assert_eq!( + repr, + Repr { + next_header: IpProtocol::Tcp, + length: 0, + data: &REPR_PACKET_PAD4[2..] + } + ); + + let header = Header::new_unchecked(&REPR_PACKET_PAD12); + let repr = Repr::parse(&header).unwrap(); + assert_eq!( + repr, + Repr { + next_header: IpProtocol::Tcp, + length: 1, + data: &REPR_PACKET_PAD12[2..] + } + ); + } + + #[test] + fn test_repr_emit() { + let repr = Repr { + next_header: IpProtocol::Tcp, + length: 0, + data: &REPR_PACKET_PAD4[2..], + }; + let mut bytes = [0u8; 2]; + let mut header = Header::new_unchecked(&mut bytes); + repr.emit(&mut header); + assert_eq!(header.into_inner(), &REPR_PACKET_PAD4[..2]); + + let repr = Repr { + next_header: IpProtocol::Tcp, + length: 1, + data: &REPR_PACKET_PAD12[2..], + }; + let mut bytes = [0u8; 2]; + let mut header = Header::new_unchecked(&mut bytes); + repr.emit(&mut header); + assert_eq!(header.into_inner(), &REPR_PACKET_PAD12[..2]); + } +} diff --git a/src/wire/ipv6fragment.rs b/src/wire/ipv6fragment.rs new file mode 100644 index 0000000..cf6b6d0 --- /dev/null +++ b/src/wire/ipv6fragment.rs @@ -0,0 +1,283 @@ +use super::{Error, Result}; +use core::fmt; + +use byteorder::{ByteOrder, NetworkEndian}; + +/// A read/write wrapper around an IPv6 Fragment Header. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Header<T: AsRef<[u8]>> { + buffer: T, +} + +// Format of the Fragment Header +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Next Header | Reserved | Fragment Offset |Res|M| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Identification | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// See https://tools.ietf.org/html/rfc8200#section-4.5 for details. +// +// **NOTE**: The fields start counting after the header length field. +mod field { + use crate::wire::field::*; + + // 16-bit field containing the fragment offset, reserved and more fragments values. + pub const FR_OF_M: Field = 0..2; + // 32-bit field identifying the fragmented packet + pub const IDENT: Field = 2..6; + /// 1 bit flag indicating if there are more fragments coming. + pub const M: usize = 1; +} + +impl<T: AsRef<[u8]>> Header<T> { + /// Create a raw octet buffer with an IPv6 Fragment Header structure. + pub const fn new_unchecked(buffer: T) -> Header<T> { + Header { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Header<T>> { + let header = Self::new_unchecked(buffer); + header.check_len()?; + Ok(header) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let data = self.buffer.as_ref(); + let len = data.len(); + + if len < field::IDENT.end { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the header, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the fragment offset field. + #[inline] + pub fn frag_offset(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::FR_OF_M]) >> 3 + } + + /// Return more fragment flag field. + #[inline] + pub fn more_frags(&self) -> bool { + let data = self.buffer.as_ref(); + (data[field::M] & 0x1) == 1 + } + + /// Return the fragment identification value field. + #[inline] + pub fn ident(&self) -> u32 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u32(&data[field::IDENT]) + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Header<T> { + /// Set reserved fields. + /// + /// Set 8-bit reserved field after the next header field. + /// Set 2-bit reserved field between fragment offset and more fragments. + #[inline] + pub fn clear_reserved(&mut self) { + let data = self.buffer.as_mut(); + // Retain the higher order 5 bits and lower order 1 bit + data[field::M] &= 0xf9; + } + + /// Set the fragment offset field. + #[inline] + pub fn set_frag_offset(&mut self, value: u16) { + let data = self.buffer.as_mut(); + // Retain the lower order 3 bits + let raw = ((value & 0x1fff) << 3) | ((data[field::M] & 0x7) as u16); + NetworkEndian::write_u16(&mut data[field::FR_OF_M], raw); + } + + /// Set the more fragments flag field. + #[inline] + pub fn set_more_frags(&mut self, value: bool) { + let data = self.buffer.as_mut(); + // Retain the high order 7 bits + let raw = (data[field::M] & 0xfe) | (value as u8 & 0x1); + data[field::M] = raw; + } + + /// Set the fragmentation identification field. + #[inline] + pub fn set_ident(&mut self, value: u32) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::IDENT], value); + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Header<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match Repr::parse(self) { + Ok(repr) => write!(f, "{repr}"), + Err(err) => { + write!(f, "IPv6 Fragment ({err})")?; + Ok(()) + } + } + } +} + +/// A high-level representation of an IPv6 Fragment header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr { + /// The offset of the data following this header, relative to the start of the Fragmentable + /// Part of the original packet. + pub frag_offset: u16, + /// When there are more fragments following this header + pub more_frags: bool, + /// The identification for every packet that is fragmented. + pub ident: u32, +} + +impl Repr { + /// Parse an IPv6 Fragment Header and return a high-level representation. + pub fn parse<T>(header: &Header<&T>) -> Result<Repr> + where + T: AsRef<[u8]> + ?Sized, + { + Ok(Repr { + frag_offset: header.frag_offset(), + more_frags: header.more_frags(), + ident: header.ident(), + }) + } + + /// Return the length, in bytes, of a header that will be emitted from this high-level + /// representation. + pub const fn buffer_len(&self) -> usize { + field::IDENT.end + } + + /// Emit a high-level representation into an IPv6 Fragment Header. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, header: &mut Header<&mut T>) { + header.clear_reserved(); + header.set_frag_offset(self.frag_offset); + header.set_more_frags(self.more_frags); + header.set_ident(self.ident); + } +} + +impl fmt::Display for Repr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "IPv6 Fragment offset={} more={} ident={}", + self.frag_offset, self.more_frags, self.ident + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + + // A Fragment Header with more fragments remaining + static BYTES_HEADER_MORE_FRAG: [u8; 6] = [0x0, 0x1, 0x0, 0x0, 0x30, 0x39]; + + // A Fragment Header with no more fragments remaining + static BYTES_HEADER_LAST_FRAG: [u8; 6] = [0xa, 0x0, 0x0, 0x1, 0x9, 0x32]; + + #[test] + fn test_check_len() { + // less than 6 bytes + assert_eq!( + Err(Error), + Header::new_unchecked(&BYTES_HEADER_MORE_FRAG[..5]).check_len() + ); + // valid + assert_eq!( + Ok(()), + Header::new_unchecked(&BYTES_HEADER_MORE_FRAG).check_len() + ); + } + + #[test] + fn test_header_deconstruct() { + let header = Header::new_unchecked(&BYTES_HEADER_MORE_FRAG); + assert_eq!(header.frag_offset(), 0); + assert!(header.more_frags()); + assert_eq!(header.ident(), 12345); + + let header = Header::new_unchecked(&BYTES_HEADER_LAST_FRAG); + assert_eq!(header.frag_offset(), 320); + assert!(!header.more_frags()); + assert_eq!(header.ident(), 67890); + } + + #[test] + fn test_repr_parse_valid() { + let header = Header::new_unchecked(&BYTES_HEADER_MORE_FRAG); + let repr = Repr::parse(&header).unwrap(); + assert_eq!( + repr, + Repr { + frag_offset: 0, + more_frags: true, + ident: 12345 + } + ); + + let header = Header::new_unchecked(&BYTES_HEADER_LAST_FRAG); + let repr = Repr::parse(&header).unwrap(); + assert_eq!( + repr, + Repr { + frag_offset: 320, + more_frags: false, + ident: 67890 + } + ); + } + + #[test] + fn test_repr_emit() { + let repr = Repr { + frag_offset: 0, + more_frags: true, + ident: 12345, + }; + let mut bytes = [0u8; 6]; + let mut header = Header::new_unchecked(&mut bytes); + repr.emit(&mut header); + assert_eq!(header.into_inner(), &BYTES_HEADER_MORE_FRAG[0..6]); + + let repr = Repr { + frag_offset: 320, + more_frags: false, + ident: 67890, + }; + let mut bytes = [0u8; 6]; + let mut header = Header::new_unchecked(&mut bytes); + repr.emit(&mut header); + assert_eq!(header.into_inner(), &BYTES_HEADER_LAST_FRAG[0..6]); + } + + #[test] + fn test_buffer_len() { + let header = Header::new_unchecked(&BYTES_HEADER_MORE_FRAG); + let repr = Repr::parse(&header).unwrap(); + assert_eq!(repr.buffer_len(), BYTES_HEADER_MORE_FRAG.len()); + } +} diff --git a/src/wire/ipv6hbh.rs b/src/wire/ipv6hbh.rs new file mode 100644 index 0000000..bc68300 --- /dev/null +++ b/src/wire/ipv6hbh.rs @@ -0,0 +1,176 @@ +use super::{Error, Ipv6Option, Ipv6OptionRepr, Ipv6OptionsIterator, Result}; + +use heapless::Vec; + +/// A read/write wrapper around an IPv6 Hop-by-Hop Header buffer. +pub struct Header<T: AsRef<[u8]>> { + buffer: T, +} + +impl<T: AsRef<[u8]>> Header<T> { + /// Create a raw octet buffer with an IPv6 Hop-by-Hop Header structure. + pub const fn new_unchecked(buffer: T) -> Self { + Header { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Self> { + let header = Self::new_unchecked(buffer); + header.check_len()?; + Ok(header) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// The result of this check is invalidated by calling [set_header_len]. + /// + /// [set_header_len]: #method.set_header_len + pub fn check_len(&self) -> Result<()> { + if self.buffer.as_ref().is_empty() { + return Err(Error); + } + + Ok(()) + } + + /// Consume the header, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Header<&'a T> { + /// Return the options of the IPv6 Hop-by-Hop header. + pub fn options(&self) -> &'a [u8] { + self.buffer.as_ref() + } +} + +impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Header<&'a mut T> { + /// Return a mutable pointer to the options of the IPv6 Hop-by-Hop header. + pub fn options_mut(&mut self) -> &mut [u8] { + self.buffer.as_mut() + } +} + +/// A high-level representation of an IPv6 Hop-by-Hop Header. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr<'a> { + pub options: heapless::Vec<Ipv6OptionRepr<'a>, { crate::config::IPV6_HBH_MAX_OPTIONS }>, +} + +impl<'a> Repr<'a> { + /// Parse an IPv6 Hop-by-Hop Header and return a high-level representation. + pub fn parse<T>(header: &'a Header<&'a T>) -> Result<Repr<'a>> + where + T: AsRef<[u8]> + ?Sized, + { + let mut options = Vec::new(); + + let iter = Ipv6OptionsIterator::new(header.options()); + + for option in iter { + let option = option?; + + if let Err(e) = options.push(option) { + net_trace!("error when parsing hop-by-hop options: {}", e); + break; + } + } + + Ok(Self { options }) + } + + /// Return the length, in bytes, of a header that will be emitted from this high-level + /// representation. + pub fn buffer_len(&self) -> usize { + self.options.iter().map(|o| o.buffer_len()).sum() + } + + /// Emit a high-level representation into an IPv6 Hop-by-Hop Header. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, header: &mut Header<&mut T>) { + let mut buffer = header.options_mut(); + + for opt in &self.options { + opt.emit(&mut Ipv6Option::new_unchecked( + &mut buffer[..opt.buffer_len()], + )); + buffer = &mut buffer[opt.buffer_len()..]; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::wire::Error; + + // A Hop-by-Hop Option header with a PadN option of option data length 4. + static REPR_PACKET_PAD4: [u8; 6] = [0x1, 0x4, 0x0, 0x0, 0x0, 0x0]; + + // A Hop-by-Hop Option header with a PadN option of option data length 12. + static REPR_PACKET_PAD12: [u8; 14] = [ + 0x1, 0x0C, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + ]; + + #[test] + fn test_check_len() { + // zero byte buffer + assert_eq!( + Err(Error), + Header::new_unchecked(&REPR_PACKET_PAD4[..0]).check_len() + ); + // valid + assert_eq!(Ok(()), Header::new_unchecked(&REPR_PACKET_PAD4).check_len()); + // valid + assert_eq!( + Ok(()), + Header::new_unchecked(&REPR_PACKET_PAD12).check_len() + ); + } + + #[test] + fn test_repr_parse_valid() { + let header = Header::new_unchecked(&REPR_PACKET_PAD4); + let repr = Repr::parse(&header).unwrap(); + + let mut options = Vec::new(); + options.push(Ipv6OptionRepr::PadN(4)).unwrap(); + assert_eq!(repr, Repr { options }); + + let header = Header::new_unchecked(&REPR_PACKET_PAD12); + let repr = Repr::parse(&header).unwrap(); + + let mut options = Vec::new(); + options.push(Ipv6OptionRepr::PadN(12)).unwrap(); + assert_eq!(repr, Repr { options }); + } + + #[test] + fn test_repr_emit() { + let mut options = Vec::new(); + options.push(Ipv6OptionRepr::PadN(4)).unwrap(); + let repr = Repr { options }; + + let mut bytes = [0u8; 6]; + let mut header = Header::new_unchecked(&mut bytes); + repr.emit(&mut header); + + assert_eq!(header.into_inner(), &REPR_PACKET_PAD4[..]); + + let mut options = Vec::new(); + options.push(Ipv6OptionRepr::PadN(12)).unwrap(); + let repr = Repr { options }; + + let mut bytes = [0u8; 14]; + let mut header = Header::new_unchecked(&mut bytes); + repr.emit(&mut header); + + assert_eq!(header.into_inner(), &REPR_PACKET_PAD12[..]); + } +} diff --git a/src/wire/ipv6option.rs b/src/wire/ipv6option.rs new file mode 100644 index 0000000..dfbd6ac --- /dev/null +++ b/src/wire/ipv6option.rs @@ -0,0 +1,611 @@ +use super::{Error, Result}; +#[cfg(feature = "proto-rpl")] +use super::{RplHopByHopPacket, RplHopByHopRepr}; + +use core::fmt; + +enum_with_unknown! { + /// IPv6 Extension Header Option Type + pub enum Type(u8) { + /// 1 byte of padding + Pad1 = 0, + /// Multiple bytes of padding + PadN = 1, + /// RPL Option + Rpl = 0x63, + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Type::Pad1 => write!(f, "Pad1"), + Type::PadN => write!(f, "PadN"), + Type::Rpl => write!(f, "RPL"), + Type::Unknown(id) => write!(f, "{id}"), + } + } +} + +enum_with_unknown! { + /// Action required when parsing the given IPv6 Extension + /// Header Option Type fails + pub enum FailureType(u8) { + /// Skip this option and continue processing the packet + Skip = 0b00000000, + /// Discard the containing packet + Discard = 0b01000000, + /// Discard the containing packet and notify the sender + DiscardSendAll = 0b10000000, + /// Discard the containing packet and only notify the sender + /// if the sender is a unicast address + DiscardSendUnicast = 0b11000000, + } +} + +impl fmt::Display for FailureType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + FailureType::Skip => write!(f, "skip"), + FailureType::Discard => write!(f, "discard"), + FailureType::DiscardSendAll => write!(f, "discard and send error"), + FailureType::DiscardSendUnicast => write!(f, "discard and send error if unicast"), + FailureType::Unknown(id) => write!(f, "Unknown({id})"), + } + } +} + +impl From<Type> for FailureType { + fn from(other: Type) -> FailureType { + let raw: u8 = other.into(); + Self::from(raw & 0b11000000u8) + } +} + +/// A read/write wrapper around an IPv6 Extension Header Option. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Ipv6Option<T: AsRef<[u8]>> { + buffer: T, +} + +// Format of Option +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- - - - - - - - - +// | Option Type | Opt Data Len | Option Data +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- - - - - - - - - +// +// +// See https://tools.ietf.org/html/rfc8200#section-4.2 for details. +mod field { + #![allow(non_snake_case)] + + use crate::wire::field::*; + + // 8-bit identifier of the type of option. + pub const TYPE: usize = 0; + // 8-bit unsigned integer. Length of the DATA field of this option, in octets. + pub const LENGTH: usize = 1; + // Variable-length field. Option-Type-specific data. + pub const fn DATA(length: u8) -> Field { + 2..length as usize + 2 + } +} + +impl<T: AsRef<[u8]>> Ipv6Option<T> { + /// Create a raw octet buffer with an IPv6 Extension Header Option structure. + pub const fn new_unchecked(buffer: T) -> Ipv6Option<T> { + Ipv6Option { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Ipv6Option<T>> { + let opt = Self::new_unchecked(buffer); + opt.check_len()?; + Ok(opt) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// The result of this check is invalidated by calling [set_data_len]. + /// + /// [set_data_len]: #method.set_data_len + pub fn check_len(&self) -> Result<()> { + let data = self.buffer.as_ref(); + let len = data.len(); + + if len < field::LENGTH { + return Err(Error); + } + + if self.option_type() == Type::Pad1 { + return Ok(()); + } + + if len == field::LENGTH { + return Err(Error); + } + + let df = field::DATA(data[field::LENGTH]); + + if len < df.end { + return Err(Error); + } + + Ok(()) + } + + /// Consume the ipv6 option, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the option type. + #[inline] + pub fn option_type(&self) -> Type { + let data = self.buffer.as_ref(); + Type::from(data[field::TYPE]) + } + + /// Return the length of the data. + /// + /// # Panics + /// This function panics if this is an 1-byte padding option. + #[inline] + pub fn data_len(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::LENGTH] + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Ipv6Option<&'a T> { + /// Return the option data. + /// + /// # Panics + /// This function panics if this is an 1-byte padding option. + #[inline] + pub fn data(&self) -> &'a [u8] { + let len = self.data_len(); + let data = self.buffer.as_ref(); + &data[field::DATA(len)] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Ipv6Option<T> { + /// Set the option type. + #[inline] + pub fn set_option_type(&mut self, value: Type) { + let data = self.buffer.as_mut(); + data[field::TYPE] = value.into(); + } + + /// Set the option data length. + /// + /// # Panics + /// This function panics if this is an 1-byte padding option. + #[inline] + pub fn set_data_len(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::LENGTH] = value; + } +} + +impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Ipv6Option<&'a mut T> { + /// Return a mutable pointer to the option data. + /// + /// # Panics + /// This function panics if this is an 1-byte padding option. + #[inline] + pub fn data_mut(&mut self) -> &mut [u8] { + let len = self.data_len(); + let data = self.buffer.as_mut(); + &mut data[field::DATA(len)] + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Ipv6Option<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match Repr::parse(self) { + Ok(repr) => write!(f, "{repr}"), + Err(err) => { + write!(f, "IPv6 Extension Option ({err})")?; + Ok(()) + } + } + } +} + +/// A high-level representation of an IPv6 Extension Header Option. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] +pub enum Repr<'a> { + Pad1, + PadN(u8), + #[cfg(feature = "proto-rpl")] + Rpl(RplHopByHopRepr), + Unknown { + type_: Type, + length: u8, + data: &'a [u8], + }, +} + +impl<'a> Repr<'a> { + /// Parse an IPv6 Extension Header Option and return a high-level representation. + pub fn parse<T>(opt: &Ipv6Option<&'a T>) -> Result<Repr<'a>> + where + T: AsRef<[u8]> + ?Sized, + { + match opt.option_type() { + Type::Pad1 => Ok(Repr::Pad1), + Type::PadN => Ok(Repr::PadN(opt.data_len())), + + #[cfg(feature = "proto-rpl")] + Type::Rpl => Ok(Repr::Rpl(RplHopByHopRepr::parse( + &RplHopByHopPacket::new_checked(opt.data())?, + ))), + #[cfg(not(feature = "proto-rpl"))] + Type::Rpl => Ok(Repr::Unknown { + type_: Type::Rpl, + length: opt.data_len(), + data: opt.data(), + }), + + unknown_type @ Type::Unknown(_) => Ok(Repr::Unknown { + type_: unknown_type, + length: opt.data_len(), + data: opt.data(), + }), + } + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + match *self { + Repr::Pad1 => 1, + Repr::PadN(length) => field::DATA(length).end, + #[cfg(feature = "proto-rpl")] + Repr::Rpl(opt) => field::DATA(opt.buffer_len() as u8).end, + Repr::Unknown { length, .. } => field::DATA(length).end, + } + } + + /// Emit a high-level representation into an IPv6 Extension Header Option. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, opt: &mut Ipv6Option<&'a mut T>) { + match *self { + Repr::Pad1 => opt.set_option_type(Type::Pad1), + Repr::PadN(len) => { + opt.set_option_type(Type::PadN); + opt.set_data_len(len); + // Ensure all padding bytes are set to zero. + for x in opt.data_mut().iter_mut() { + *x = 0 + } + } + #[cfg(feature = "proto-rpl")] + Repr::Rpl(rpl) => { + opt.set_option_type(Type::Rpl); + opt.set_data_len(4); + rpl.emit(&mut crate::wire::RplHopByHopPacket::new_unchecked( + opt.data_mut(), + )); + } + Repr::Unknown { + type_, + length, + data, + } => { + opt.set_option_type(type_); + opt.set_data_len(length); + opt.data_mut().copy_from_slice(&data[..length as usize]); + } + } + } +} + +/// A iterator for IPv6 options. +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Ipv6OptionsIterator<'a> { + pos: usize, + length: usize, + data: &'a [u8], + hit_error: bool, +} + +impl<'a> Ipv6OptionsIterator<'a> { + /// Create a new `Ipv6OptionsIterator`, used to iterate over the + /// options contained in a IPv6 Extension Header (e.g. the Hop-by-Hop + /// header). + pub fn new(data: &'a [u8]) -> Ipv6OptionsIterator<'a> { + let length = data.len(); + Ipv6OptionsIterator { + pos: 0, + hit_error: false, + length, + data, + } + } +} + +impl<'a> Iterator for Ipv6OptionsIterator<'a> { + type Item = Result<Repr<'a>>; + + fn next(&mut self) -> Option<Self::Item> { + if self.pos < self.length && !self.hit_error { + // If we still have data to parse and we have not previously + // hit an error, attempt to parse the next option. + match Ipv6Option::new_checked(&self.data[self.pos..]) { + Ok(hdr) => match Repr::parse(&hdr) { + Ok(repr) => { + self.pos += repr.buffer_len(); + Some(Ok(repr)) + } + Err(e) => { + self.hit_error = true; + Some(Err(e)) + } + }, + Err(e) => { + self.hit_error = true; + Some(Err(e)) + } + } + } else { + // If we failed to parse a previous option or hit the end of the + // buffer, we do not continue to iterate. + None + } + } +} + +impl<'a> fmt::Display for Repr<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "IPv6 Option ")?; + match *self { + Repr::Pad1 => write!(f, "{} ", Type::Pad1), + Repr::PadN(len) => write!(f, "{} length={} ", Type::PadN, len), + #[cfg(feature = "proto-rpl")] + Repr::Rpl(rpl) => write!(f, "{} {rpl}", Type::Rpl), + Repr::Unknown { type_, length, .. } => write!(f, "{type_} length={length} "), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + static IPV6OPTION_BYTES_PAD1: [u8; 1] = [0x0]; + static IPV6OPTION_BYTES_PADN: [u8; 3] = [0x1, 0x1, 0x0]; + static IPV6OPTION_BYTES_UNKNOWN: [u8; 5] = [0xff, 0x3, 0x0, 0x0, 0x0]; + #[cfg(feature = "proto-rpl")] + static IPV6OPTION_BYTES_RPL: [u8; 6] = [0x63, 0x04, 0x00, 0x1e, 0x08, 0x00]; + + #[test] + fn test_check_len() { + let bytes = [0u8]; + // zero byte buffer + assert_eq!( + Err(Error), + Ipv6Option::new_unchecked(&bytes[..0]).check_len() + ); + // pad1 + assert_eq!( + Ok(()), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PAD1).check_len() + ); + + // padn with truncated data + assert_eq!( + Err(Error), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PADN[..2]).check_len() + ); + // padn + assert_eq!( + Ok(()), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PADN).check_len() + ); + + // unknown option type with truncated data + assert_eq!( + Err(Error), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN[..4]).check_len() + ); + assert_eq!( + Err(Error), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN[..1]).check_len() + ); + // unknown type + assert_eq!( + Ok(()), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN).check_len() + ); + + #[cfg(feature = "proto-rpl")] + { + assert_eq!( + Ok(()), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_RPL).check_len() + ); + } + } + + #[test] + #[should_panic(expected = "index out of bounds")] + fn test_data_len() { + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PAD1); + opt.data_len(); + } + + #[test] + fn test_option_deconstruct() { + // one octet of padding + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PAD1); + assert_eq!(opt.option_type(), Type::Pad1); + + // two octets of padding + let bytes: [u8; 2] = [0x1, 0x0]; + let opt = Ipv6Option::new_unchecked(&bytes); + assert_eq!(opt.option_type(), Type::PadN); + assert_eq!(opt.data_len(), 0); + + // three octets of padding + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PADN); + assert_eq!(opt.option_type(), Type::PadN); + assert_eq!(opt.data_len(), 1); + assert_eq!(opt.data(), &[0]); + + // extra bytes in buffer + let bytes: [u8; 10] = [0x1, 0x7, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff]; + let opt = Ipv6Option::new_unchecked(&bytes); + assert_eq!(opt.option_type(), Type::PadN); + assert_eq!(opt.data_len(), 7); + assert_eq!(opt.data(), &[0, 0, 0, 0, 0, 0, 0]); + + // unrecognized option + let bytes: [u8; 1] = [0xff]; + let opt = Ipv6Option::new_unchecked(&bytes); + assert_eq!(opt.option_type(), Type::Unknown(255)); + + // unrecognized option without length and data + assert_eq!(Ipv6Option::new_checked(&bytes), Err(Error)); + + #[cfg(feature = "proto-rpl")] + { + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_RPL); + assert_eq!(opt.option_type(), Type::Rpl); + assert_eq!(opt.data_len(), 4); + assert_eq!(opt.data(), &[0x00, 0x1e, 0x08, 0x00]); + } + } + + #[test] + fn test_option_parse() { + // one octet of padding + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PAD1); + let pad1 = Repr::parse(&opt).unwrap(); + assert_eq!(pad1, Repr::Pad1); + assert_eq!(pad1.buffer_len(), 1); + + // two or more octets of padding + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PADN); + let padn = Repr::parse(&opt).unwrap(); + assert_eq!(padn, Repr::PadN(1)); + assert_eq!(padn.buffer_len(), 3); + + // unrecognized option type + let data = [0u8; 3]; + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN); + let unknown = Repr::parse(&opt).unwrap(); + assert_eq!( + unknown, + Repr::Unknown { + type_: Type::Unknown(255), + length: 3, + data: &data + } + ); + + #[cfg(feature = "proto-rpl")] + { + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_RPL); + let rpl = Repr::parse(&opt).unwrap(); + + assert_eq!( + rpl, + Repr::Rpl(crate::wire::RplHopByHopRepr { + down: false, + rank_error: false, + forwarding_error: false, + instance_id: crate::wire::RplInstanceId::from(0x1e), + sender_rank: 0x0800, + }) + ); + } + } + + #[test] + fn test_option_emit() { + let repr = Repr::Pad1; + let mut bytes = [255u8; 1]; // don't assume bytes are initialized to zero + let mut opt = Ipv6Option::new_unchecked(&mut bytes); + repr.emit(&mut opt); + assert_eq!(opt.into_inner(), &IPV6OPTION_BYTES_PAD1); + + let repr = Repr::PadN(1); + let mut bytes = [255u8; 3]; // don't assume bytes are initialized to zero + let mut opt = Ipv6Option::new_unchecked(&mut bytes); + repr.emit(&mut opt); + assert_eq!(opt.into_inner(), &IPV6OPTION_BYTES_PADN); + + let data = [0u8; 3]; + let repr = Repr::Unknown { + type_: Type::Unknown(255), + length: 3, + data: &data, + }; + let mut bytes = [254u8; 5]; // don't assume bytes are initialized to zero + let mut opt = Ipv6Option::new_unchecked(&mut bytes); + repr.emit(&mut opt); + assert_eq!(opt.into_inner(), &IPV6OPTION_BYTES_UNKNOWN); + + #[cfg(feature = "proto-rpl")] + { + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_RPL); + let rpl = Repr::parse(&opt).unwrap(); + let mut bytes = [0u8; 6]; + rpl.emit(&mut Ipv6Option::new_unchecked(&mut bytes)); + + assert_eq!(&bytes, &IPV6OPTION_BYTES_RPL); + } + } + + #[test] + fn test_failure_type() { + let mut failure_type: FailureType = Type::Pad1.into(); + assert_eq!(failure_type, FailureType::Skip); + failure_type = Type::PadN.into(); + assert_eq!(failure_type, FailureType::Skip); + failure_type = Type::Unknown(0b01000001).into(); + assert_eq!(failure_type, FailureType::Discard); + failure_type = Type::Unknown(0b10100000).into(); + assert_eq!(failure_type, FailureType::DiscardSendAll); + failure_type = Type::Unknown(0b11000100).into(); + assert_eq!(failure_type, FailureType::DiscardSendUnicast); + } + + #[test] + fn test_options_iter() { + let options = [ + 0x00, 0x01, 0x01, 0x00, 0x01, 0x02, 0x00, 0x00, 0x01, 0x00, 0x00, 0x11, 0x00, 0x01, + 0x08, 0x00, + ]; + + let iterator = Ipv6OptionsIterator::new(&options); + for (i, opt) in iterator.enumerate() { + match (i, opt) { + (0, Ok(Repr::Pad1)) => continue, + (1, Ok(Repr::PadN(1))) => continue, + (2, Ok(Repr::PadN(2))) => continue, + (3, Ok(Repr::PadN(0))) => continue, + (4, Ok(Repr::Pad1)) => continue, + ( + 5, + Ok(Repr::Unknown { + type_: Type::Unknown(0x11), + length: 0, + .. + }), + ) => continue, + (6, Err(Error)) => continue, + (i, res) => panic!("Unexpected option `{res:?}` at index {i}"), + } + } + } +} diff --git a/src/wire/ipv6routing.rs b/src/wire/ipv6routing.rs new file mode 100644 index 0000000..a4943b7 --- /dev/null +++ b/src/wire/ipv6routing.rs @@ -0,0 +1,606 @@ +use super::{Error, Result}; +use core::fmt; + +use crate::wire::Ipv6Address as Address; + +enum_with_unknown! { + /// IPv6 Extension Routing Header Routing Type + pub enum Type(u8) { + /// Source Route (DEPRECATED) + /// + /// See https://tools.ietf.org/html/rfc5095 for details. + Type0 = 0, + /// Nimrod (DEPRECATED 2009-05-06) + Nimrod = 1, + /// Type 2 Routing Header for Mobile IPv6 + /// + /// See https://tools.ietf.org/html/rfc6275#section-6.4 for details. + Type2 = 2, + /// RPL Source Routing Header + /// + /// See https://tools.ietf.org/html/rfc6554 for details. + Rpl = 3, + /// RFC3692-style Experiment 1 + /// + /// See https://tools.ietf.org/html/rfc4727 for details. + Experiment1 = 253, + /// RFC3692-style Experiment 2 + /// + /// See https://tools.ietf.org/html/rfc4727 for details. + Experiment2 = 254, + /// Reserved for future use + Reserved = 252 + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Type::Type0 => write!(f, "Type0"), + Type::Nimrod => write!(f, "Nimrod"), + Type::Type2 => write!(f, "Type2"), + Type::Rpl => write!(f, "Rpl"), + Type::Experiment1 => write!(f, "Experiment1"), + Type::Experiment2 => write!(f, "Experiment2"), + Type::Reserved => write!(f, "Reserved"), + Type::Unknown(id) => write!(f, "{id}"), + } + } +} + +/// A read/write wrapper around an IPv6 Routing Header buffer. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Header<T: AsRef<[u8]>> { + buffer: T, +} + +// Format of the Routing Header +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Next Header | Hdr Ext Len | Routing Type | Segments Left | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// . . +// . type-specific data . +// . . +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// +// See https://tools.ietf.org/html/rfc8200#section-4.4 for details. +// +// **NOTE**: The fields start counting after the header length field. +mod field { + #![allow(non_snake_case)] + + use crate::wire::field::*; + + // Minimum size of the header. + pub const MIN_HEADER_SIZE: usize = 2; + + // 8-bit identifier of a particular Routing header variant. + pub const TYPE: usize = 0; + // 8-bit unsigned integer. The number of route segments remaining. + pub const SEG_LEFT: usize = 1; + + // The Type 2 Routing Header has the following format: + // + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Next Header | Hdr Ext Len=2 | Routing Type=2|Segments Left=1| + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Reserved | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | | + // + + + // | | + // + Home Address + + // | | + // + + + // | | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + // 16-byte field containing the home address of the destination mobile node. + pub const HOME_ADDRESS: Field = 6..22; + + // The RPL Source Routing Header has the following format: + // + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Next Header | Hdr Ext Len | Routing Type | Segments Left | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | CmprI | CmprE | Pad | Reserved | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | | + // . . + // . Addresses[1..n] . + // . . + // | | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + // 8-bit field containing the CmprI and CmprE values. + pub const CMPR: usize = 2; + // 8-bit field containing the Pad value. + pub const PAD: usize = 3; + // Variable length field containing addresses + pub const ADDRESSES: usize = 6; +} + +/// Core getter methods relevant to any routing type. +impl<T: AsRef<[u8]>> Header<T> { + /// Create a raw octet buffer with an IPv6 Routing Header structure. + pub const fn new_unchecked(buffer: T) -> Header<T> { + Header { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Header<T>> { + let header = Self::new_unchecked(buffer); + header.check_len()?; + Ok(header) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// The result of this check is invalidated by calling [set_header_len]. + /// + /// [set_header_len]: #method.set_header_len + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::MIN_HEADER_SIZE { + return Err(Error); + } + + match self.routing_type() { + Type::Type2 if len < field::HOME_ADDRESS.end => return Err(Error), + Type::Rpl if len < field::ADDRESSES => return Err(Error), + _ => (), + } + + Ok(()) + } + + /// Consume the header, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the routing type field. + #[inline] + pub fn routing_type(&self) -> Type { + let data = self.buffer.as_ref(); + Type::from(data[field::TYPE]) + } + + /// Return the segments left field. + #[inline] + pub fn segments_left(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::SEG_LEFT] + } +} + +/// Getter methods for the Type 2 Routing Header routing type. +impl<T: AsRef<[u8]>> Header<T> { + /// Return the IPv6 Home Address + /// + /// # Panics + /// This function may panic if this header is not the Type2 Routing Header routing type. + pub fn home_address(&self) -> Address { + let data = self.buffer.as_ref(); + Address::from_bytes(&data[field::HOME_ADDRESS]) + } +} + +/// Getter methods for the RPL Source Routing Header routing type. +impl<T: AsRef<[u8]>> Header<T> { + /// Return the number of prefix octets elided from addresses[1..n-1]. + /// + /// # Panics + /// This function may panic if this header is not the RPL Source Routing Header routing type. + pub fn cmpr_i(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::CMPR] >> 4 + } + + /// Return the number of prefix octets elided from the last address (`addresses[n]`). + /// + /// # Panics + /// This function may panic if this header is not the RPL Source Routing Header routing type. + pub fn cmpr_e(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::CMPR] & 0xf + } + + /// Return the number of octets used for padding after `addresses[n]`. + /// + /// # Panics + /// This function may panic if this header is not the RPL Source Routing Header routing type. + pub fn pad(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::PAD] >> 4 + } + + /// Return the address vector in bytes + /// + /// # Panics + /// This function may panic if this header is not the RPL Source Routing Header routing type. + pub fn addresses(&self) -> &[u8] { + let data = self.buffer.as_ref(); + &data[field::ADDRESSES..] + } +} + +/// Core setter methods relevant to any routing type. +impl<T: AsRef<[u8]> + AsMut<[u8]>> Header<T> { + /// Set the routing type. + #[inline] + pub fn set_routing_type(&mut self, value: Type) { + let data = self.buffer.as_mut(); + data[field::TYPE] = value.into(); + } + + /// Set the segments left field. + #[inline] + pub fn set_segments_left(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::SEG_LEFT] = value; + } + + /// Initialize reserved fields to 0. + /// + /// # Panics + /// This function may panic if the routing type is not set. + #[inline] + pub fn clear_reserved(&mut self) { + let routing_type = self.routing_type(); + let data = self.buffer.as_mut(); + + match routing_type { + Type::Type2 => { + data[4] = 0; + data[5] = 0; + data[6] = 0; + data[7] = 0; + } + Type::Rpl => { + // Retain the higher order 4 bits of the padding field + data[field::PAD] &= 0xF0; + data[4] = 0; + data[5] = 0; + } + + _ => panic!("Unrecognized routing type when clearing reserved fields."), + } + } +} + +/// Setter methods for the RPL Source Routing Header routing type. +impl<T: AsRef<[u8]> + AsMut<[u8]>> Header<T> { + /// Set the Ipv6 Home Address + /// + /// # Panics + /// This function may panic if this header is not the Type 2 Routing Header routing type. + pub fn set_home_address(&mut self, value: Address) { + let data = self.buffer.as_mut(); + data[field::HOME_ADDRESS].copy_from_slice(value.as_bytes()); + } +} + +/// Setter methods for the RPL Source Routing Header routing type. +impl<T: AsRef<[u8]> + AsMut<[u8]>> Header<T> { + /// Set the number of prefix octets elided from addresses[1..n-1]. + /// + /// # Panics + /// This function may panic if this header is not the RPL Source Routing Header routing type. + pub fn set_cmpr_i(&mut self, value: u8) { + let data = self.buffer.as_mut(); + let raw = (value << 4) | (data[field::CMPR] & 0xF); + data[field::CMPR] = raw; + } + + /// Set the number of prefix octets elided from the last address (`addresses[n]`). + /// + /// # Panics + /// This function may panic if this header is not the RPL Source Routing Header routing type. + pub fn set_cmpr_e(&mut self, value: u8) { + let data = self.buffer.as_mut(); + let raw = (value & 0xF) | (data[field::CMPR] & 0xF0); + data[field::CMPR] = raw; + } + + /// Set the number of octets used for padding after `addresses[n]`. + /// + /// # Panics + /// This function may panic if this header is not the RPL Source Routing Header routing type. + pub fn set_pad(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::PAD] = value << 4; + } + + /// Set address data + /// + /// # Panics + /// This function may panic if this header is not the RPL Source Routing Header routing type. + pub fn set_addresses(&mut self, value: &[u8]) { + let data = self.buffer.as_mut(); + let addresses = &mut data[field::ADDRESSES..]; + addresses.copy_from_slice(value); + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Header<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match Repr::parse(self) { + Ok(repr) => write!(f, "{repr}"), + Err(err) => { + write!(f, "IPv6 Routing ({err})")?; + Ok(()) + } + } + } +} + +/// A high-level representation of an IPv6 Routing Header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] +pub enum Repr<'a> { + Type2 { + /// Number of route segments remaining. + segments_left: u8, + /// The home address of the destination mobile node. + home_address: Address, + }, + Rpl { + /// Number of route segments remaining. + segments_left: u8, + /// Number of prefix octets from each segment, except the last segment, that are elided. + cmpr_i: u8, + /// Number of prefix octets from the last segment that are elided. + cmpr_e: u8, + /// Number of octets that are used for padding after `address[n]` at the end of the + /// RPL Source Route Header. + pad: u8, + /// Vector of addresses, numbered 1 to `n`. + addresses: &'a [u8], + }, +} + +impl<'a> Repr<'a> { + /// Parse an IPv6 Routing Header and return a high-level representation. + pub fn parse<T>(header: &'a Header<&'a T>) -> Result<Repr<'a>> + where + T: AsRef<[u8]> + ?Sized, + { + match header.routing_type() { + Type::Type2 => Ok(Repr::Type2 { + segments_left: header.segments_left(), + home_address: header.home_address(), + }), + Type::Rpl => Ok(Repr::Rpl { + segments_left: header.segments_left(), + cmpr_i: header.cmpr_i(), + cmpr_e: header.cmpr_e(), + pad: header.pad(), + addresses: header.addresses(), + }), + + _ => Err(Error), + } + } + + /// Return the length, in bytes, of a header that will be emitted from this high-level + /// representation. + pub const fn buffer_len(&self) -> usize { + match self { + // Routing Type + Segments Left + Reserved + Home Address + Repr::Type2 { home_address, .. } => 2 + 4 + home_address.as_bytes().len(), + Repr::Rpl { addresses, .. } => 2 + 4 + addresses.len(), + } + } + + /// Emit a high-level representation into an IPv6 Routing Header. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, header: &mut Header<&mut T>) { + match *self { + Repr::Type2 { + segments_left, + home_address, + } => { + header.set_routing_type(Type::Type2); + header.set_segments_left(segments_left); + header.clear_reserved(); + header.set_home_address(home_address); + } + Repr::Rpl { + segments_left, + cmpr_i, + cmpr_e, + pad, + addresses, + } => { + header.set_routing_type(Type::Rpl); + header.set_segments_left(segments_left); + header.set_cmpr_i(cmpr_i); + header.set_cmpr_e(cmpr_e); + header.set_pad(pad); + header.clear_reserved(); + header.set_addresses(addresses); + } + } + } +} + +impl<'a> fmt::Display for Repr<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Repr::Type2 { + segments_left, + home_address, + } => { + write!( + f, + "IPv6 Routing type={} seg_left={} home_address={}", + Type::Type2, + segments_left, + home_address + ) + } + Repr::Rpl { + segments_left, + cmpr_i, + cmpr_e, + pad, + .. + } => { + write!( + f, + "IPv6 Routing type={} seg_left={} cmpr_i={} cmpr_e={} pad={}", + Type::Rpl, + segments_left, + cmpr_i, + cmpr_e, + pad + ) + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + // A Type 2 Routing Header + static BYTES_TYPE2: [u8; 22] = [ + 0x2, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x1, + ]; + + // A representation of a Type 2 Routing header + static REPR_TYPE2: Repr = Repr::Type2 { + segments_left: 1, + home_address: Address::LOOPBACK, + }; + + // A Source Routing Header with full IPv6 addresses in bytes + static BYTES_SRH_FULL: [u8; 38] = [ + 0x3, 0x2, 0x0, 0x0, 0x0, 0x0, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x2, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x3, 0x1, + ]; + + // A representation of a Source Routing Header with full IPv6 addresses + static REPR_SRH_FULL: Repr = Repr::Rpl { + segments_left: 2, + cmpr_i: 0, + cmpr_e: 0, + pad: 0, + addresses: &[ + 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x1, + ], + }; + + // A Source Routing Header with elided IPv6 addresses in bytes + static BYTES_SRH_ELIDED: [u8; 14] = [ + 0x3, 0x2, 0xfe, 0x50, 0x0, 0x0, 0x2, 0x3, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, + ]; + + // A representation of a Source Routing Header with elided IPv6 addresses + static REPR_SRH_ELIDED: Repr = Repr::Rpl { + segments_left: 2, + cmpr_i: 15, + cmpr_e: 14, + pad: 5, + addresses: &[0x2, 0x3, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0], + }; + + #[test] + fn test_check_len() { + // less than min header size + assert_eq!( + Err(Error), + Header::new_unchecked(&BYTES_TYPE2[..3]).check_len() + ); + assert_eq!( + Err(Error), + Header::new_unchecked(&BYTES_SRH_FULL[..3]).check_len() + ); + assert_eq!( + Err(Error), + Header::new_unchecked(&BYTES_SRH_ELIDED[..3]).check_len() + ); + // valid + assert_eq!(Ok(()), Header::new_unchecked(&BYTES_TYPE2[..]).check_len()); + assert_eq!( + Ok(()), + Header::new_unchecked(&BYTES_SRH_FULL[..]).check_len() + ); + assert_eq!( + Ok(()), + Header::new_unchecked(&BYTES_SRH_ELIDED[..]).check_len() + ); + } + + #[test] + fn test_header_deconstruct() { + let header = Header::new_unchecked(&BYTES_TYPE2[..]); + assert_eq!(header.routing_type(), Type::Type2); + assert_eq!(header.segments_left(), 1); + assert_eq!(header.home_address(), Address::LOOPBACK); + + let header = Header::new_unchecked(&BYTES_SRH_FULL[..]); + assert_eq!(header.routing_type(), Type::Rpl); + assert_eq!(header.segments_left(), 2); + assert_eq!(header.addresses(), &BYTES_SRH_FULL[6..]); + + let header = Header::new_unchecked(&BYTES_SRH_ELIDED[..]); + assert_eq!(header.routing_type(), Type::Rpl); + assert_eq!(header.segments_left(), 2); + assert_eq!(header.addresses(), &BYTES_SRH_ELIDED[6..]); + } + + #[test] + fn test_repr_parse_valid() { + let header = Header::new_checked(&BYTES_TYPE2[..]).unwrap(); + let repr = Repr::parse(&header).unwrap(); + assert_eq!(repr, REPR_TYPE2); + + let header = Header::new_checked(&BYTES_SRH_FULL[..]).unwrap(); + let repr = Repr::parse(&header).unwrap(); + assert_eq!(repr, REPR_SRH_FULL); + + let header = Header::new_checked(&BYTES_SRH_ELIDED[..]).unwrap(); + let repr = Repr::parse(&header).unwrap(); + assert_eq!(repr, REPR_SRH_ELIDED); + } + + #[test] + fn test_repr_emit() { + let mut bytes = [0u8; 22]; + let mut header = Header::new_unchecked(&mut bytes[..]); + REPR_TYPE2.emit(&mut header); + assert_eq!(header.into_inner(), &BYTES_TYPE2[..]); + + let mut bytes = [0u8; 38]; + let mut header = Header::new_unchecked(&mut bytes[..]); + REPR_SRH_FULL.emit(&mut header); + assert_eq!(header.into_inner(), &BYTES_SRH_FULL[..]); + + let mut bytes = [0u8; 14]; + let mut header = Header::new_unchecked(&mut bytes[..]); + REPR_SRH_ELIDED.emit(&mut header); + assert_eq!(header.into_inner(), &BYTES_SRH_ELIDED[..]); + } + + #[test] + fn test_buffer_len() { + assert_eq!(REPR_TYPE2.buffer_len(), 22); + assert_eq!(REPR_SRH_FULL.buffer_len(), 38); + assert_eq!(REPR_SRH_ELIDED.buffer_len(), 14); + } +} diff --git a/src/wire/mld.rs b/src/wire/mld.rs new file mode 100644 index 0000000..18872b5 --- /dev/null +++ b/src/wire/mld.rs @@ -0,0 +1,578 @@ +// Packet implementation for the Multicast Listener Discovery +// protocol. See [RFC 3810] and [RFC 2710]. +// +// [RFC 3810]: https://tools.ietf.org/html/rfc3810 +// [RFC 2710]: https://tools.ietf.org/html/rfc2710 + +use byteorder::{ByteOrder, NetworkEndian}; + +use super::{Error, Result}; +use crate::wire::icmpv6::{field, Message, Packet}; +use crate::wire::Ipv6Address; + +enum_with_unknown! { + /// MLDv2 Multicast Listener Report Record Type. See [RFC 3810 § 5.2.12] for + /// more details. + /// + /// [RFC 3810 § 5.2.12]: https://tools.ietf.org/html/rfc3010#section-5.2.12 + pub enum RecordType(u8) { + /// Interface has a filter mode of INCLUDE for the specified multicast address. + ModeIsInclude = 0x01, + /// Interface has a filter mode of EXCLUDE for the specified multicast address. + ModeIsExclude = 0x02, + /// Interface has changed to a filter mode of INCLUDE for the specified + /// multicast address. + ChangeToInclude = 0x03, + /// Interface has changed to a filter mode of EXCLUDE for the specified + /// multicast address. + ChangeToExclude = 0x04, + /// Interface wishes to listen to the sources in the specified list. + AllowNewSources = 0x05, + /// Interface no longer wishes to listen to the sources in the specified list. + BlockOldSources = 0x06 + } +} + +/// Getters for the Multicast Listener Query message header. +/// See [RFC 3810 § 5.1]. +/// +/// [RFC 3810 § 5.1]: https://tools.ietf.org/html/rfc3010#section-5.1 +impl<T: AsRef<[u8]>> Packet<T> { + /// Return the maximum response code field. + #[inline] + pub fn max_resp_code(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::MAX_RESP_CODE]) + } + + /// Return the address being queried. + #[inline] + pub fn mcast_addr(&self) -> Ipv6Address { + let data = self.buffer.as_ref(); + Ipv6Address::from_bytes(&data[field::QUERY_MCAST_ADDR]) + } + + /// Return the Suppress Router-Side Processing flag. + #[inline] + pub fn s_flag(&self) -> bool { + let data = self.buffer.as_ref(); + (data[field::SQRV] & 0x08) != 0 + } + + /// Return the Querier's Robustness Variable. + #[inline] + pub fn qrv(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::SQRV] & 0x7 + } + + /// Return the Querier's Query Interval Code. + #[inline] + pub fn qqic(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::QQIC] + } + + /// Return number of sources. + #[inline] + pub fn num_srcs(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::QUERY_NUM_SRCS]) + } +} + +/// Getters for the Multicast Listener Report message header. +/// See [RFC 3810 § 5.2]. +/// +/// [RFC 3810 § 5.2]: https://tools.ietf.org/html/rfc3010#section-5.2 +impl<T: AsRef<[u8]>> Packet<T> { + /// Return the number of Multicast Address Records. + #[inline] + pub fn nr_mcast_addr_rcrds(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::NR_MCAST_RCRDS]) + } +} + +/// Setters for the Multicast Listener Query message header. +/// See [RFC 3810 § 5.1]. +/// +/// [RFC 3810 § 5.1]: https://tools.ietf.org/html/rfc3010#section-5.1 +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the maximum response code field. + #[inline] + pub fn set_max_resp_code(&mut self, code: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::MAX_RESP_CODE], code); + } + + /// Set the address being queried. + #[inline] + pub fn set_mcast_addr(&mut self, addr: Ipv6Address) { + let data = self.buffer.as_mut(); + data[field::QUERY_MCAST_ADDR].copy_from_slice(addr.as_bytes()); + } + + /// Set the Suppress Router-Side Processing flag. + #[inline] + pub fn set_s_flag(&mut self) { + let data = self.buffer.as_mut(); + let current = data[field::SQRV]; + data[field::SQRV] = 0x8 | (current & 0x7); + } + + /// Clear the Suppress Router-Side Processing flag. + #[inline] + pub fn clear_s_flag(&mut self) { + let data = self.buffer.as_mut(); + data[field::SQRV] &= 0x7; + } + + /// Set the Querier's Robustness Variable. + #[inline] + pub fn set_qrv(&mut self, value: u8) { + assert!(value < 8); + let data = self.buffer.as_mut(); + data[field::SQRV] = (data[field::SQRV] & 0x8) | value & 0x7; + } + + /// Set the Querier's Query Interval Code. + #[inline] + pub fn set_qqic(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::QQIC] = value; + } + + /// Set number of sources. + #[inline] + pub fn set_num_srcs(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::QUERY_NUM_SRCS], value); + } +} + +/// Setters for the Multicast Listener Report message header. +/// See [RFC 3810 § 5.2]. +/// +/// [RFC 3810 § 5.2]: https://tools.ietf.org/html/rfc3010#section-5.2 +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the number of Multicast Address Records. + #[inline] + pub fn set_nr_mcast_addr_rcrds(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::NR_MCAST_RCRDS], value) + } +} + +/// A read/write wrapper around an MLDv2 Listener Report Message Address Record. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct AddressRecord<T: AsRef<[u8]>> { + buffer: T, +} + +impl<T: AsRef<[u8]>> AddressRecord<T> { + /// Imbue a raw octet buffer with a Address Record structure. + pub const fn new_unchecked(buffer: T) -> Self { + Self { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Self> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error::Truncated)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::RECORD_MCAST_ADDR.end { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } +} + +/// Getters for a MLDv2 Listener Report Message Address Record. +/// See [RFC 3810 § 5.2]. +/// +/// [RFC 3810 § 5.2]: https://tools.ietf.org/html/rfc3010#section-5.2 +impl<T: AsRef<[u8]>> AddressRecord<T> { + /// Return the record type for the given sources. + #[inline] + pub fn record_type(&self) -> RecordType { + let data = self.buffer.as_ref(); + RecordType::from(data[field::RECORD_TYPE]) + } + + /// Return the length of the auxiliary data. + #[inline] + pub fn aux_data_len(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::AUX_DATA_LEN] + } + + /// Return the number of sources field. + #[inline] + pub fn num_srcs(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::RECORD_NUM_SRCS]) + } + + /// Return the multicast address field. + #[inline] + pub fn mcast_addr(&self) -> Ipv6Address { + let data = self.buffer.as_ref(); + Ipv6Address::from_bytes(&data[field::RECORD_MCAST_ADDR]) + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> AddressRecord<&'a T> { + /// Return a pointer to the address records. + #[inline] + pub fn payload(&self) -> &'a [u8] { + let data = self.buffer.as_ref(); + &data[field::RECORD_MCAST_ADDR.end..] + } +} + +/// Setters for a MLDv2 Listener Report Message Address Record. +/// See [RFC 3810 § 5.2]. +/// +/// [RFC 3810 § 5.2]: https://tools.ietf.org/html/rfc3010#section-5.2 +impl<T: AsMut<[u8]> + AsRef<[u8]>> AddressRecord<T> { + /// Return the record type for the given sources. + #[inline] + pub fn set_record_type(&mut self, rty: RecordType) { + let data = self.buffer.as_mut(); + data[field::RECORD_TYPE] = rty.into(); + } + + /// Return the length of the auxiliary data. + #[inline] + pub fn set_aux_data_len(&mut self, len: u8) { + let data = self.buffer.as_mut(); + data[field::AUX_DATA_LEN] = len; + } + + /// Return the number of sources field. + #[inline] + pub fn set_num_srcs(&mut self, num_srcs: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::RECORD_NUM_SRCS], num_srcs); + } + + /// Return the multicast address field. + /// + /// # Panics + /// This function panics if the given address is not a multicast address. + #[inline] + pub fn set_mcast_addr(&mut self, addr: Ipv6Address) { + assert!(addr.is_multicast()); + let data = self.buffer.as_mut(); + data[field::RECORD_MCAST_ADDR].copy_from_slice(addr.as_bytes()); + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> AddressRecord<T> { + /// Return a pointer to the address records. + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + let data = self.buffer.as_mut(); + &mut data[field::RECORD_MCAST_ADDR.end..] + } +} + +/// A high-level representation of an MLDv2 packet header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Repr<'a> { + Query { + max_resp_code: u16, + mcast_addr: Ipv6Address, + s_flag: bool, + qrv: u8, + qqic: u8, + num_srcs: u16, + data: &'a [u8], + }, + Report { + nr_mcast_addr_rcrds: u16, + data: &'a [u8], + }, +} + +impl<'a> Repr<'a> { + /// Parse an MLDv2 packet and return a high-level representation. + pub fn parse<T>(packet: &Packet<&'a T>) -> Result<Repr<'a>> + where + T: AsRef<[u8]> + ?Sized, + { + match packet.msg_type() { + Message::MldQuery => Ok(Repr::Query { + max_resp_code: packet.max_resp_code(), + mcast_addr: packet.mcast_addr(), + s_flag: packet.s_flag(), + qrv: packet.qrv(), + qqic: packet.qqic(), + num_srcs: packet.num_srcs(), + data: packet.payload(), + }), + Message::MldReport => Ok(Repr::Report { + nr_mcast_addr_rcrds: packet.nr_mcast_addr_rcrds(), + data: packet.payload(), + }), + _ => Err(Error), + } + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + match self { + Repr::Query { data, .. } => field::QUERY_NUM_SRCS.end + data.len(), + Repr::Report { data, .. } => field::NR_MCAST_RCRDS.end + data.len(), + } + } + + /// Emit a high-level representation into an MLDv2 packet. + pub fn emit<T>(&self, packet: &mut Packet<&mut T>) + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + match self { + Repr::Query { + max_resp_code, + mcast_addr, + s_flag, + qrv, + qqic, + num_srcs, + data, + } => { + packet.set_msg_type(Message::MldQuery); + packet.set_msg_code(0); + packet.clear_reserved(); + packet.set_max_resp_code(*max_resp_code); + packet.set_mcast_addr(*mcast_addr); + if *s_flag { + packet.set_s_flag(); + } else { + packet.clear_s_flag(); + } + packet.set_qrv(*qrv); + packet.set_qqic(*qqic); + packet.set_num_srcs(*num_srcs); + packet.payload_mut().copy_from_slice(&data[..]); + } + Repr::Report { + nr_mcast_addr_rcrds, + data, + } => { + packet.set_msg_type(Message::MldReport); + packet.set_msg_code(0); + packet.clear_reserved(); + packet.set_nr_mcast_addr_rcrds(*nr_mcast_addr_rcrds); + packet.payload_mut().copy_from_slice(&data[..]); + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::phy::ChecksumCapabilities; + use crate::wire::icmpv6::Message; + use crate::wire::Icmpv6Repr; + + static QUERY_PACKET_BYTES: [u8; 44] = [ + 0x82, 0x00, 0x73, 0x74, 0x04, 0x00, 0x00, 0x00, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x0a, 0x12, 0x00, 0x01, 0xff, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + ]; + + static QUERY_PACKET_PAYLOAD: [u8; 16] = [ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, + ]; + + static REPORT_PACKET_BYTES: [u8; 44] = [ + 0x8f, 0x00, 0x73, 0x85, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0xff, 0x02, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + ]; + + static REPORT_PACKET_PAYLOAD: [u8; 36] = [ + 0x01, 0x00, 0x00, 0x01, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + ]; + + fn create_repr<'a>(ty: Message) -> Icmpv6Repr<'a> { + match ty { + Message::MldQuery => Icmpv6Repr::Mld(Repr::Query { + max_resp_code: 0x400, + mcast_addr: Ipv6Address::LINK_LOCAL_ALL_NODES, + s_flag: true, + qrv: 0x02, + qqic: 0x12, + num_srcs: 0x01, + data: &QUERY_PACKET_PAYLOAD, + }), + Message::MldReport => Icmpv6Repr::Mld(Repr::Report { + nr_mcast_addr_rcrds: 1, + data: &REPORT_PACKET_PAYLOAD, + }), + _ => { + panic!("Message type must be a MLDv2 message type"); + } + } + } + + #[test] + fn test_query_deconstruct() { + let packet = Packet::new_unchecked(&QUERY_PACKET_BYTES[..]); + assert_eq!(packet.msg_type(), Message::MldQuery); + assert_eq!(packet.msg_code(), 0); + assert_eq!(packet.checksum(), 0x7374); + assert_eq!(packet.max_resp_code(), 0x0400); + assert_eq!(packet.mcast_addr(), Ipv6Address::LINK_LOCAL_ALL_NODES); + assert!(packet.s_flag()); + assert_eq!(packet.qrv(), 0x02); + assert_eq!(packet.qqic(), 0x12); + assert_eq!(packet.num_srcs(), 0x01); + assert_eq!( + Ipv6Address::from_bytes(packet.payload()), + Ipv6Address::LINK_LOCAL_ALL_ROUTERS + ); + } + + #[test] + fn test_query_construct() { + let mut bytes = vec![0xff; 44]; + let mut packet = Packet::new_unchecked(&mut bytes[..]); + packet.set_msg_type(Message::MldQuery); + packet.set_msg_code(0); + packet.set_max_resp_code(0x0400); + packet.set_mcast_addr(Ipv6Address::LINK_LOCAL_ALL_NODES); + packet.set_s_flag(); + packet.set_qrv(0x02); + packet.set_qqic(0x12); + packet.set_num_srcs(0x01); + packet + .payload_mut() + .copy_from_slice(Ipv6Address::LINK_LOCAL_ALL_ROUTERS.as_bytes()); + packet.clear_reserved(); + packet.fill_checksum( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + ); + assert_eq!(&*packet.into_inner(), &QUERY_PACKET_BYTES[..]); + } + + #[test] + fn test_record_deconstruct() { + let packet = Packet::new_unchecked(&REPORT_PACKET_BYTES[..]); + assert_eq!(packet.msg_type(), Message::MldReport); + assert_eq!(packet.msg_code(), 0); + assert_eq!(packet.checksum(), 0x7385); + assert_eq!(packet.nr_mcast_addr_rcrds(), 0x01); + let addr_rcrd = AddressRecord::new_unchecked(packet.payload()); + assert_eq!(addr_rcrd.record_type(), RecordType::ModeIsInclude); + assert_eq!(addr_rcrd.aux_data_len(), 0x00); + assert_eq!(addr_rcrd.num_srcs(), 0x01); + assert_eq!(addr_rcrd.mcast_addr(), Ipv6Address::LINK_LOCAL_ALL_NODES); + assert_eq!( + Ipv6Address::from_bytes(addr_rcrd.payload()), + Ipv6Address::LINK_LOCAL_ALL_ROUTERS + ); + } + + #[test] + fn test_record_construct() { + let mut bytes = vec![0xff; 44]; + let mut packet = Packet::new_unchecked(&mut bytes[..]); + packet.set_msg_type(Message::MldReport); + packet.set_msg_code(0); + packet.clear_reserved(); + packet.set_nr_mcast_addr_rcrds(1); + { + let mut addr_rcrd = AddressRecord::new_unchecked(packet.payload_mut()); + addr_rcrd.set_record_type(RecordType::ModeIsInclude); + addr_rcrd.set_aux_data_len(0); + addr_rcrd.set_num_srcs(1); + addr_rcrd.set_mcast_addr(Ipv6Address::LINK_LOCAL_ALL_NODES); + addr_rcrd + .payload_mut() + .copy_from_slice(Ipv6Address::LINK_LOCAL_ALL_ROUTERS.as_bytes()); + } + packet.fill_checksum( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + ); + assert_eq!(&*packet.into_inner(), &REPORT_PACKET_BYTES[..]); + } + + #[test] + fn test_query_repr_parse() { + let packet = Packet::new_unchecked(&QUERY_PACKET_BYTES[..]); + let repr = Icmpv6Repr::parse( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + &packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(repr, Ok(create_repr(Message::MldQuery))); + } + + #[test] + fn test_report_repr_parse() { + let packet = Packet::new_unchecked(&REPORT_PACKET_BYTES[..]); + let repr = Icmpv6Repr::parse( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + &packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(repr, Ok(create_repr(Message::MldReport))); + } + + #[test] + fn test_query_repr_emit() { + let mut bytes = [0x2a; 44]; + let mut packet = Packet::new_unchecked(&mut bytes[..]); + let repr = create_repr(Message::MldQuery); + repr.emit( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &QUERY_PACKET_BYTES[..]); + } + + #[test] + fn test_report_repr_emit() { + let mut bytes = [0x2a; 44]; + let mut packet = Packet::new_unchecked(&mut bytes[..]); + let repr = create_repr(Message::MldReport); + repr.emit( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &REPORT_PACKET_BYTES[..]); + } +} diff --git a/src/wire/mod.rs b/src/wire/mod.rs new file mode 100644 index 0000000..5aed2d4 --- /dev/null +++ b/src/wire/mod.rs @@ -0,0 +1,524 @@ +/*! Low-level packet access and construction. + +The `wire` module deals with the packet *representation*. It provides two levels +of functionality. + + * First, it provides functions to extract fields from sequences of octets, + and to insert fields into sequences of octets. This happens `Packet` family of + structures, e.g. [EthernetFrame] or [Ipv4Packet]. + * Second, in cases where the space of valid field values is much smaller than the space + of possible field values, it provides a compact, high-level representation + of packet data that can be parsed from and emitted into a sequence of octets. + This happens through the `Repr` family of structs and enums, e.g. [ArpRepr] or [Ipv4Repr]. + +[EthernetFrame]: struct.EthernetFrame.html +[Ipv4Packet]: struct.Ipv4Packet.html +[ArpRepr]: enum.ArpRepr.html +[Ipv4Repr]: struct.Ipv4Repr.html + +The functions in the `wire` module are designed for use together with `-Cpanic=abort`. + +The `Packet` family of data structures guarantees that, if the `Packet::check_len()` method +returned `Ok(())`, then no accessor or setter method will panic; however, the guarantee +provided by `Packet::check_len()` may no longer hold after changing certain fields, +which are listed in the documentation for the specific packet. + +The `Packet::new_checked` method is a shorthand for a combination of `Packet::new_unchecked` +and `Packet::check_len`. +When parsing untrusted input, it is *necessary* to use `Packet::new_checked()`; +so long as the buffer is not modified, no accessor will fail. +When emitting output, though, it is *incorrect* to use `Packet::new_checked()`; +the length check is likely to succeed on a zeroed buffer, but fail on a buffer +filled with data from a previous packet, such as when reusing buffers, resulting +in nondeterministic panics with some network devices but not others. +The buffer length for emission is not calculated by the `Packet` layer. + +In the `Repr` family of data structures, the `Repr::parse()` method never panics +as long as `Packet::new_checked()` (or `Packet::check_len()`) has succeeded, and +the `Repr::emit()` method never panics as long as the underlying buffer is exactly +`Repr::buffer_len()` octets long. + +# Examples + +To emit an IP packet header into an octet buffer, and then parse it back: + +```rust +# #[cfg(feature = "proto-ipv4")] +# { +use smoltcp::phy::ChecksumCapabilities; +use smoltcp::wire::*; +let repr = Ipv4Repr { + src_addr: Ipv4Address::new(10, 0, 0, 1), + dst_addr: Ipv4Address::new(10, 0, 0, 2), + next_header: IpProtocol::Tcp, + payload_len: 10, + hop_limit: 64, +}; +let mut buffer = vec![0; repr.buffer_len() + repr.payload_len]; +{ // emission + let mut packet = Ipv4Packet::new_unchecked(&mut buffer); + repr.emit(&mut packet, &ChecksumCapabilities::default()); +} +{ // parsing + let packet = Ipv4Packet::new_checked(&buffer) + .expect("truncated packet"); + let parsed = Ipv4Repr::parse(&packet, &ChecksumCapabilities::default()) + .expect("malformed packet"); + assert_eq!(repr, parsed); +} +# } +``` +*/ + +mod field { + pub type Field = ::core::ops::Range<usize>; + pub type Rest = ::core::ops::RangeFrom<usize>; +} + +pub mod pretty_print; + +#[cfg(all(feature = "proto-ipv4", feature = "medium-ethernet"))] +mod arp; +#[cfg(feature = "proto-dhcpv4")] +pub(crate) mod dhcpv4; +#[cfg(feature = "proto-dns")] +pub(crate) mod dns; +#[cfg(feature = "medium-ethernet")] +mod ethernet; +#[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] +mod icmp; +#[cfg(feature = "proto-ipv4")] +mod icmpv4; +#[cfg(feature = "proto-ipv6")] +mod icmpv6; +#[cfg(feature = "medium-ieee802154")] +pub mod ieee802154; +#[cfg(feature = "proto-igmp")] +mod igmp; +pub(crate) mod ip; +#[cfg(feature = "proto-ipv4")] +mod ipv4; +#[cfg(feature = "proto-ipv6")] +mod ipv6; +#[cfg(feature = "proto-ipv6")] +mod ipv6ext_header; +#[cfg(feature = "proto-ipv6")] +mod ipv6fragment; +#[cfg(feature = "proto-ipv6")] +mod ipv6hbh; +#[cfg(feature = "proto-ipv6")] +mod ipv6option; +#[cfg(feature = "proto-ipv6")] +mod ipv6routing; +#[cfg(feature = "proto-ipv6")] +mod mld; +#[cfg(all( + feature = "proto-ipv6", + any(feature = "medium-ethernet", feature = "medium-ieee802154") +))] +mod ndisc; +#[cfg(all( + feature = "proto-ipv6", + any(feature = "medium-ethernet", feature = "medium-ieee802154") +))] +mod ndiscoption; +#[cfg(feature = "proto-rpl")] +mod rpl; +#[cfg(all(feature = "proto-sixlowpan", feature = "medium-ieee802154"))] +mod sixlowpan; +mod tcp; +mod udp; + +#[cfg(feature = "proto-ipsec-ah")] +mod ipsec_ah; + +#[cfg(feature = "proto-ipsec-esp")] +mod ipsec_esp; + +use core::fmt; + +use crate::phy::Medium; + +pub use self::pretty_print::PrettyPrinter; + +#[cfg(feature = "medium-ethernet")] +pub use self::ethernet::{ + Address as EthernetAddress, EtherType as EthernetProtocol, Frame as EthernetFrame, + Repr as EthernetRepr, HEADER_LEN as ETHERNET_HEADER_LEN, +}; + +#[cfg(all(feature = "proto-ipv4", feature = "medium-ethernet"))] +pub use self::arp::{ + Hardware as ArpHardware, Operation as ArpOperation, Packet as ArpPacket, Repr as ArpRepr, +}; + +#[cfg(feature = "proto-rpl")] +pub use self::rpl::{ + data::HopByHopOption as RplHopByHopRepr, data::Packet as RplHopByHopPacket, + options::Packet as RplOptionPacket, options::Repr as RplOptionRepr, + InstanceId as RplInstanceId, Repr as RplRepr, +}; + +#[cfg(all(feature = "proto-sixlowpan", feature = "medium-ieee802154"))] +pub use self::sixlowpan::{ + frag::{Key as SixlowpanFragKey, Packet as SixlowpanFragPacket, Repr as SixlowpanFragRepr}, + iphc::{Packet as SixlowpanIphcPacket, Repr as SixlowpanIphcRepr}, + nhc::{ + ExtHeaderId as SixlowpanExtHeaderId, ExtHeaderPacket as SixlowpanExtHeaderPacket, + ExtHeaderRepr as SixlowpanExtHeaderRepr, NhcPacket as SixlowpanNhcPacket, + UdpNhcPacket as SixlowpanUdpNhcPacket, UdpNhcRepr as SixlowpanUdpNhcRepr, + }, + AddressContext as SixlowpanAddressContext, NextHeader as SixlowpanNextHeader, SixlowpanPacket, +}; + +#[cfg(feature = "medium-ieee802154")] +pub use self::ieee802154::{ + Address as Ieee802154Address, AddressingMode as Ieee802154AddressingMode, + Frame as Ieee802154Frame, FrameType as Ieee802154FrameType, + FrameVersion as Ieee802154FrameVersion, Pan as Ieee802154Pan, Repr as Ieee802154Repr, +}; + +pub use self::ip::{ + Address as IpAddress, Cidr as IpCidr, Endpoint as IpEndpoint, + ListenEndpoint as IpListenEndpoint, Protocol as IpProtocol, Repr as IpRepr, + Version as IpVersion, +}; + +#[cfg(feature = "proto-ipv4")] +pub use self::ipv4::{ + Address as Ipv4Address, Cidr as Ipv4Cidr, Key as Ipv4FragKey, Packet as Ipv4Packet, + Repr as Ipv4Repr, HEADER_LEN as IPV4_HEADER_LEN, MIN_MTU as IPV4_MIN_MTU, +}; + +#[cfg(feature = "proto-ipv6")] +pub(crate) use self::ipv6::Scope as Ipv6AddressScope; +#[cfg(feature = "proto-ipv6")] +pub use self::ipv6::{ + Address as Ipv6Address, Cidr as Ipv6Cidr, Packet as Ipv6Packet, Repr as Ipv6Repr, + HEADER_LEN as IPV6_HEADER_LEN, MIN_MTU as IPV6_MIN_MTU, +}; + +#[cfg(feature = "proto-ipv6")] +pub use self::ipv6option::{ + FailureType as Ipv6OptionFailureType, Ipv6Option, Ipv6OptionsIterator, Repr as Ipv6OptionRepr, + Type as Ipv6OptionType, +}; + +#[cfg(feature = "proto-ipv6")] +pub use self::ipv6ext_header::{Header as Ipv6ExtHeader, Repr as Ipv6ExtHeaderRepr}; + +#[cfg(feature = "proto-ipv6")] +pub use self::ipv6fragment::{Header as Ipv6FragmentHeader, Repr as Ipv6FragmentRepr}; + +#[cfg(feature = "proto-ipv6")] +pub use self::ipv6hbh::{Header as Ipv6HopByHopHeader, Repr as Ipv6HopByHopRepr}; + +#[cfg(feature = "proto-ipv6")] +pub use self::ipv6routing::{ + Header as Ipv6RoutingHeader, Repr as Ipv6RoutingRepr, Type as Ipv6RoutingType, +}; + +#[cfg(feature = "proto-ipv4")] +pub use self::icmpv4::{ + DstUnreachable as Icmpv4DstUnreachable, Message as Icmpv4Message, Packet as Icmpv4Packet, + ParamProblem as Icmpv4ParamProblem, Redirect as Icmpv4Redirect, Repr as Icmpv4Repr, + TimeExceeded as Icmpv4TimeExceeded, +}; + +#[cfg(feature = "proto-igmp")] +pub use self::igmp::{IgmpVersion, Packet as IgmpPacket, Repr as IgmpRepr}; + +#[cfg(feature = "proto-ipv6")] +pub use self::icmpv6::{ + DstUnreachable as Icmpv6DstUnreachable, Message as Icmpv6Message, Packet as Icmpv6Packet, + ParamProblem as Icmpv6ParamProblem, Repr as Icmpv6Repr, TimeExceeded as Icmpv6TimeExceeded, +}; + +#[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] +pub use self::icmp::Repr as IcmpRepr; + +#[cfg(all( + feature = "proto-ipv6", + any(feature = "medium-ethernet", feature = "medium-ieee802154") +))] +pub use self::ndisc::{ + NeighborFlags as NdiscNeighborFlags, Repr as NdiscRepr, RouterFlags as NdiscRouterFlags, +}; + +#[cfg(all( + feature = "proto-ipv6", + any(feature = "medium-ethernet", feature = "medium-ieee802154") +))] +pub use self::ndiscoption::{ + NdiscOption, PrefixInfoFlags as NdiscPrefixInfoFlags, + PrefixInformation as NdiscPrefixInformation, RedirectedHeader as NdiscRedirectedHeader, + Repr as NdiscOptionRepr, Type as NdiscOptionType, +}; + +#[cfg(feature = "proto-ipv6")] +pub use self::mld::{AddressRecord as MldAddressRecord, Repr as MldRepr}; + +pub use self::udp::{Packet as UdpPacket, Repr as UdpRepr, HEADER_LEN as UDP_HEADER_LEN}; + +pub use self::tcp::{ + Control as TcpControl, Packet as TcpPacket, Repr as TcpRepr, SeqNumber as TcpSeqNumber, + TcpOption, HEADER_LEN as TCP_HEADER_LEN, +}; + +#[cfg(feature = "proto-dhcpv4")] +pub use self::dhcpv4::{ + DhcpOption, DhcpOptionWriter, MessageType as DhcpMessageType, Packet as DhcpPacket, + Repr as DhcpRepr, CLIENT_PORT as DHCP_CLIENT_PORT, + MAX_DNS_SERVER_COUNT as DHCP_MAX_DNS_SERVER_COUNT, SERVER_PORT as DHCP_SERVER_PORT, +}; + +#[cfg(feature = "proto-dns")] +pub use self::dns::{ + Flags as DnsFlags, Opcode as DnsOpcode, Packet as DnsPacket, Rcode as DnsRcode, + Repr as DnsRepr, Type as DnsQueryType, +}; + +#[cfg(feature = "proto-ipsec-ah")] +pub use self::ipsec_ah::{Packet as IpSecAuthHeaderPacket, Repr as IpSecAuthHeaderRepr}; + +#[cfg(feature = "proto-ipsec-esp")] +pub use self::ipsec_esp::{Packet as IpSecEspPacket, Repr as IpSecEspRepr}; + +/// Parsing a packet failed. +/// +/// Either it is malformed, or it is not supported by smoltcp. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Error; + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "wire::Error") + } +} + +pub type Result<T> = core::result::Result<T, Error>; + +/// Representation of an hardware address, such as an Ethernet address or an IEEE802.15.4 address. +#[cfg(any( + feature = "medium-ip", + feature = "medium-ethernet", + feature = "medium-ieee802154" +))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum HardwareAddress { + #[cfg(feature = "medium-ip")] + Ip, + #[cfg(feature = "medium-ethernet")] + Ethernet(EthernetAddress), + #[cfg(feature = "medium-ieee802154")] + Ieee802154(Ieee802154Address), +} + +#[cfg(any( + feature = "medium-ip", + feature = "medium-ethernet", + feature = "medium-ieee802154" +))] +impl HardwareAddress { + pub const fn as_bytes(&self) -> &[u8] { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => unreachable!(), + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(addr) => addr.as_bytes(), + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(addr) => addr.as_bytes(), + } + } + + /// Query whether the address is an unicast address. + pub fn is_unicast(&self) -> bool { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => unreachable!(), + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(addr) => addr.is_unicast(), + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(addr) => addr.is_unicast(), + } + } + + /// Query whether the address is a broadcast address. + pub fn is_broadcast(&self) -> bool { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => unreachable!(), + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(addr) => addr.is_broadcast(), + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(addr) => addr.is_broadcast(), + } + } + + #[cfg(feature = "medium-ethernet")] + pub(crate) fn ethernet_or_panic(&self) -> EthernetAddress { + match self { + HardwareAddress::Ethernet(addr) => *addr, + #[allow(unreachable_patterns)] + _ => panic!("HardwareAddress is not Ethernet."), + } + } + + #[cfg(feature = "medium-ieee802154")] + pub(crate) fn ieee802154_or_panic(&self) -> Ieee802154Address { + match self { + HardwareAddress::Ieee802154(addr) => *addr, + #[allow(unreachable_patterns)] + _ => panic!("HardwareAddress is not Ethernet."), + } + } + + #[inline] + pub(crate) fn medium(&self) -> Medium { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => Medium::Ip, + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(_) => Medium::Ethernet, + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(_) => Medium::Ieee802154, + } + } +} + +#[cfg(any( + feature = "medium-ip", + feature = "medium-ethernet", + feature = "medium-ieee802154" +))] +impl core::fmt::Display for HardwareAddress { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => write!(f, "no hardware addr"), + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(addr) => write!(f, "{addr}"), + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(addr) => write!(f, "{addr}"), + } + } +} + +#[cfg(feature = "medium-ethernet")] +impl From<EthernetAddress> for HardwareAddress { + fn from(addr: EthernetAddress) -> Self { + HardwareAddress::Ethernet(addr) + } +} + +#[cfg(feature = "medium-ieee802154")] +impl From<Ieee802154Address> for HardwareAddress { + fn from(addr: Ieee802154Address) -> Self { + HardwareAddress::Ieee802154(addr) + } +} + +#[cfg(not(feature = "medium-ieee802154"))] +pub const MAX_HARDWARE_ADDRESS_LEN: usize = 6; +#[cfg(feature = "medium-ieee802154")] +pub const MAX_HARDWARE_ADDRESS_LEN: usize = 8; + +/// Unparsed hardware address. +/// +/// Used to make NDISC parsing agnostic of the hardware medium in use. +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct RawHardwareAddress { + len: u8, + data: [u8; MAX_HARDWARE_ADDRESS_LEN], +} + +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +impl RawHardwareAddress { + pub fn from_bytes(addr: &[u8]) -> Self { + let mut data = [0u8; MAX_HARDWARE_ADDRESS_LEN]; + data[..addr.len()].copy_from_slice(addr); + + Self { + len: addr.len() as u8, + data, + } + } + + pub fn as_bytes(&self) -> &[u8] { + &self.data[..self.len as usize] + } + + pub const fn len(&self) -> usize { + self.len as usize + } + + pub const fn is_empty(&self) -> bool { + self.len == 0 + } + + pub fn parse(&self, medium: Medium) -> Result<HardwareAddress> { + match medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + if self.len() < 6 { + return Err(Error); + } + Ok(HardwareAddress::Ethernet(EthernetAddress::from_bytes( + self.as_bytes(), + ))) + } + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => { + if self.len() < 8 { + return Err(Error); + } + Ok(HardwareAddress::Ieee802154(Ieee802154Address::from_bytes( + self.as_bytes(), + ))) + } + #[cfg(feature = "medium-ip")] + Medium::Ip => unreachable!(), + } + } +} + +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +impl core::fmt::Display for RawHardwareAddress { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + for (i, &b) in self.as_bytes().iter().enumerate() { + if i != 0 { + write!(f, ":")?; + } + write!(f, "{b:02x}")?; + } + Ok(()) + } +} + +#[cfg(feature = "medium-ethernet")] +impl From<EthernetAddress> for RawHardwareAddress { + fn from(addr: EthernetAddress) -> Self { + Self::from_bytes(addr.as_bytes()) + } +} + +#[cfg(feature = "medium-ieee802154")] +impl From<Ieee802154Address> for RawHardwareAddress { + fn from(addr: Ieee802154Address) -> Self { + Self::from_bytes(addr.as_bytes()) + } +} + +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +impl From<HardwareAddress> for RawHardwareAddress { + fn from(addr: HardwareAddress) -> Self { + Self::from_bytes(addr.as_bytes()) + } +} diff --git a/src/wire/ndisc.rs b/src/wire/ndisc.rs new file mode 100644 index 0000000..691b69b --- /dev/null +++ b/src/wire/ndisc.rs @@ -0,0 +1,541 @@ +use bitflags::bitflags; +use byteorder::{ByteOrder, NetworkEndian}; + +use super::{Error, Result}; +use crate::time::Duration; +use crate::wire::icmpv6::{field, Message, Packet}; +use crate::wire::Ipv6Address; +use crate::wire::RawHardwareAddress; +use crate::wire::{NdiscOption, NdiscOptionRepr}; +use crate::wire::{NdiscPrefixInformation, NdiscRedirectedHeader}; + +bitflags! { + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct RouterFlags: u8 { + const MANAGED = 0b10000000; + const OTHER = 0b01000000; + } +} + +bitflags! { + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct NeighborFlags: u8 { + const ROUTER = 0b10000000; + const SOLICITED = 0b01000000; + const OVERRIDE = 0b00100000; + } +} + +/// Getters for the Router Advertisement message header. +/// See [RFC 4861 § 4.2]. +/// +/// [RFC 4861 § 4.2]: https://tools.ietf.org/html/rfc4861#section-4.2 +impl<T: AsRef<[u8]>> Packet<T> { + /// Return the current hop limit field. + #[inline] + pub fn current_hop_limit(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::CUR_HOP_LIMIT] + } + + /// Return the Router Advertisement flags. + #[inline] + pub fn router_flags(&self) -> RouterFlags { + let data = self.buffer.as_ref(); + RouterFlags::from_bits_truncate(data[field::ROUTER_FLAGS]) + } + + /// Return the router lifetime field. + #[inline] + pub fn router_lifetime(&self) -> Duration { + let data = self.buffer.as_ref(); + Duration::from_secs(NetworkEndian::read_u16(&data[field::ROUTER_LT]) as u64) + } + + /// Return the reachable time field. + #[inline] + pub fn reachable_time(&self) -> Duration { + let data = self.buffer.as_ref(); + Duration::from_millis(NetworkEndian::read_u32(&data[field::REACHABLE_TM]) as u64) + } + + /// Return the retransmit time field. + #[inline] + pub fn retrans_time(&self) -> Duration { + let data = self.buffer.as_ref(); + Duration::from_millis(NetworkEndian::read_u32(&data[field::RETRANS_TM]) as u64) + } +} + +/// Common getters for the [Neighbor Solicitation], [Neighbor Advertisement], and +/// [Redirect] message types. +/// +/// [Neighbor Solicitation]: https://tools.ietf.org/html/rfc4861#section-4.3 +/// [Neighbor Advertisement]: https://tools.ietf.org/html/rfc4861#section-4.4 +/// [Redirect]: https://tools.ietf.org/html/rfc4861#section-4.5 +impl<T: AsRef<[u8]>> Packet<T> { + /// Return the target address field. + #[inline] + pub fn target_addr(&self) -> Ipv6Address { + let data = self.buffer.as_ref(); + Ipv6Address::from_bytes(&data[field::TARGET_ADDR]) + } +} + +/// Getters for the Neighbor Solicitation message header. +/// See [RFC 4861 § 4.3]. +/// +/// [RFC 4861 § 4.3]: https://tools.ietf.org/html/rfc4861#section-4.3 +impl<T: AsRef<[u8]>> Packet<T> { + /// Return the Neighbor Solicitation flags. + #[inline] + pub fn neighbor_flags(&self) -> NeighborFlags { + let data = self.buffer.as_ref(); + NeighborFlags::from_bits_truncate(data[field::NEIGH_FLAGS]) + } +} + +/// Getters for the Redirect message header. +/// See [RFC 4861 § 4.5]. +/// +/// [RFC 4861 § 4.5]: https://tools.ietf.org/html/rfc4861#section-4.5 +impl<T: AsRef<[u8]>> Packet<T> { + /// Return the destination address field. + #[inline] + pub fn dest_addr(&self) -> Ipv6Address { + let data = self.buffer.as_ref(); + Ipv6Address::from_bytes(&data[field::DEST_ADDR]) + } +} + +/// Setters for the Router Advertisement message header. +/// See [RFC 4861 § 4.2]. +/// +/// [RFC 4861 § 4.2]: https://tools.ietf.org/html/rfc4861#section-4.2 +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the current hop limit field. + #[inline] + pub fn set_current_hop_limit(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::CUR_HOP_LIMIT] = value; + } + + /// Set the Router Advertisement flags. + #[inline] + pub fn set_router_flags(&mut self, flags: RouterFlags) { + self.buffer.as_mut()[field::ROUTER_FLAGS] = flags.bits(); + } + + /// Set the router lifetime field. + #[inline] + pub fn set_router_lifetime(&mut self, value: Duration) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::ROUTER_LT], value.secs() as u16); + } + + /// Set the reachable time field. + #[inline] + pub fn set_reachable_time(&mut self, value: Duration) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::REACHABLE_TM], value.total_millis() as u32); + } + + /// Set the retransmit time field. + #[inline] + pub fn set_retrans_time(&mut self, value: Duration) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::RETRANS_TM], value.total_millis() as u32); + } +} + +/// Common setters for the [Neighbor Solicitation], [Neighbor Advertisement], and +/// [Redirect] message types. +/// +/// [Neighbor Solicitation]: https://tools.ietf.org/html/rfc4861#section-4.3 +/// [Neighbor Advertisement]: https://tools.ietf.org/html/rfc4861#section-4.4 +/// [Redirect]: https://tools.ietf.org/html/rfc4861#section-4.5 +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the target address field. + #[inline] + pub fn set_target_addr(&mut self, value: Ipv6Address) { + let data = self.buffer.as_mut(); + data[field::TARGET_ADDR].copy_from_slice(value.as_bytes()); + } +} + +/// Setters for the Neighbor Solicitation message header. +/// See [RFC 4861 § 4.3]. +/// +/// [RFC 4861 § 4.3]: https://tools.ietf.org/html/rfc4861#section-4.3 +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the Neighbor Solicitation flags. + #[inline] + pub fn set_neighbor_flags(&mut self, flags: NeighborFlags) { + self.buffer.as_mut()[field::NEIGH_FLAGS] = flags.bits(); + } +} + +/// Setters for the Redirect message header. +/// See [RFC 4861 § 4.5]. +/// +/// [RFC 4861 § 4.5]: https://tools.ietf.org/html/rfc4861#section-4.5 +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the destination address field. + #[inline] + pub fn set_dest_addr(&mut self, value: Ipv6Address) { + let data = self.buffer.as_mut(); + data[field::DEST_ADDR].copy_from_slice(value.as_bytes()); + } +} + +/// A high-level representation of an Neighbor Discovery packet header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Repr<'a> { + RouterSolicit { + lladdr: Option<RawHardwareAddress>, + }, + RouterAdvert { + hop_limit: u8, + flags: RouterFlags, + router_lifetime: Duration, + reachable_time: Duration, + retrans_time: Duration, + lladdr: Option<RawHardwareAddress>, + mtu: Option<u32>, + prefix_info: Option<NdiscPrefixInformation>, + }, + NeighborSolicit { + target_addr: Ipv6Address, + lladdr: Option<RawHardwareAddress>, + }, + NeighborAdvert { + flags: NeighborFlags, + target_addr: Ipv6Address, + lladdr: Option<RawHardwareAddress>, + }, + Redirect { + target_addr: Ipv6Address, + dest_addr: Ipv6Address, + lladdr: Option<RawHardwareAddress>, + redirected_hdr: Option<NdiscRedirectedHeader<'a>>, + }, +} + +impl<'a> Repr<'a> { + /// Parse an NDISC packet and return a high-level representation of the + /// packet. + #[allow(clippy::single_match)] + pub fn parse<T>(packet: &Packet<&'a T>) -> Result<Repr<'a>> + where + T: AsRef<[u8]> + ?Sized, + { + let (mut src_ll_addr, mut mtu, mut prefix_info, mut target_ll_addr, mut redirected_hdr) = + (None, None, None, None, None); + + let mut offset = 0; + while packet.payload().len() > offset { + let pkt = NdiscOption::new_checked(&packet.payload()[offset..])?; + + // If an option doesn't parse, ignore it and still parse the others. + if let Ok(opt) = NdiscOptionRepr::parse(&pkt) { + match opt { + NdiscOptionRepr::SourceLinkLayerAddr(addr) => src_ll_addr = Some(addr), + NdiscOptionRepr::TargetLinkLayerAddr(addr) => target_ll_addr = Some(addr), + NdiscOptionRepr::PrefixInformation(prefix) => prefix_info = Some(prefix), + NdiscOptionRepr::RedirectedHeader(redirect) => redirected_hdr = Some(redirect), + NdiscOptionRepr::Mtu(m) => mtu = Some(m), + _ => {} + } + } + + let len = pkt.data_len() as usize * 8; + if len == 0 { + return Err(Error); + } + offset += len; + } + + match packet.msg_type() { + Message::RouterSolicit => Ok(Repr::RouterSolicit { + lladdr: src_ll_addr, + }), + Message::RouterAdvert => Ok(Repr::RouterAdvert { + hop_limit: packet.current_hop_limit(), + flags: packet.router_flags(), + router_lifetime: packet.router_lifetime(), + reachable_time: packet.reachable_time(), + retrans_time: packet.retrans_time(), + lladdr: src_ll_addr, + mtu, + prefix_info, + }), + Message::NeighborSolicit => Ok(Repr::NeighborSolicit { + target_addr: packet.target_addr(), + lladdr: src_ll_addr, + }), + Message::NeighborAdvert => Ok(Repr::NeighborAdvert { + flags: packet.neighbor_flags(), + target_addr: packet.target_addr(), + lladdr: target_ll_addr, + }), + Message::Redirect => Ok(Repr::Redirect { + target_addr: packet.target_addr(), + dest_addr: packet.dest_addr(), + lladdr: src_ll_addr, + redirected_hdr, + }), + _ => Err(Error), + } + } + + pub const fn buffer_len(&self) -> usize { + match self { + &Repr::RouterSolicit { lladdr } => match lladdr { + Some(addr) => { + field::UNUSED.end + { NdiscOptionRepr::SourceLinkLayerAddr(addr).buffer_len() } + } + None => field::UNUSED.end, + }, + &Repr::RouterAdvert { + lladdr, + mtu, + prefix_info, + .. + } => { + let mut offset = 0; + if let Some(lladdr) = lladdr { + offset += NdiscOptionRepr::TargetLinkLayerAddr(lladdr).buffer_len(); + } + if let Some(mtu) = mtu { + offset += NdiscOptionRepr::Mtu(mtu).buffer_len(); + } + if let Some(prefix_info) = prefix_info { + offset += NdiscOptionRepr::PrefixInformation(prefix_info).buffer_len(); + } + field::RETRANS_TM.end + offset + } + &Repr::NeighborSolicit { lladdr, .. } | &Repr::NeighborAdvert { lladdr, .. } => { + let mut offset = field::TARGET_ADDR.end; + if let Some(lladdr) = lladdr { + offset += NdiscOptionRepr::SourceLinkLayerAddr(lladdr).buffer_len(); + } + offset + } + &Repr::Redirect { + lladdr, + redirected_hdr, + .. + } => { + let mut offset = field::DEST_ADDR.end; + if let Some(lladdr) = lladdr { + offset += NdiscOptionRepr::TargetLinkLayerAddr(lladdr).buffer_len(); + } + if let Some(NdiscRedirectedHeader { header, data }) = redirected_hdr { + offset += + NdiscOptionRepr::RedirectedHeader(NdiscRedirectedHeader { header, data }) + .buffer_len(); + } + offset + } + } + } + + pub fn emit<T>(&self, packet: &mut Packet<&mut T>) + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + match *self { + Repr::RouterSolicit { lladdr } => { + packet.set_msg_type(Message::RouterSolicit); + packet.set_msg_code(0); + packet.clear_reserved(); + if let Some(lladdr) = lladdr { + let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); + NdiscOptionRepr::SourceLinkLayerAddr(lladdr).emit(&mut opt_pkt); + } + } + + Repr::RouterAdvert { + hop_limit, + flags, + router_lifetime, + reachable_time, + retrans_time, + lladdr, + mtu, + prefix_info, + } => { + packet.set_msg_type(Message::RouterAdvert); + packet.set_msg_code(0); + packet.set_current_hop_limit(hop_limit); + packet.set_router_flags(flags); + packet.set_router_lifetime(router_lifetime); + packet.set_reachable_time(reachable_time); + packet.set_retrans_time(retrans_time); + let mut offset = 0; + if let Some(lladdr) = lladdr { + let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); + let opt = NdiscOptionRepr::SourceLinkLayerAddr(lladdr); + opt.emit(&mut opt_pkt); + offset += opt.buffer_len(); + } + if let Some(mtu) = mtu { + let mut opt_pkt = + NdiscOption::new_unchecked(&mut packet.payload_mut()[offset..]); + NdiscOptionRepr::Mtu(mtu).emit(&mut opt_pkt); + offset += NdiscOptionRepr::Mtu(mtu).buffer_len(); + } + if let Some(prefix_info) = prefix_info { + let mut opt_pkt = + NdiscOption::new_unchecked(&mut packet.payload_mut()[offset..]); + NdiscOptionRepr::PrefixInformation(prefix_info).emit(&mut opt_pkt) + } + } + + Repr::NeighborSolicit { + target_addr, + lladdr, + } => { + packet.set_msg_type(Message::NeighborSolicit); + packet.set_msg_code(0); + packet.clear_reserved(); + packet.set_target_addr(target_addr); + if let Some(lladdr) = lladdr { + let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); + NdiscOptionRepr::SourceLinkLayerAddr(lladdr).emit(&mut opt_pkt); + } + } + + Repr::NeighborAdvert { + flags, + target_addr, + lladdr, + } => { + packet.set_msg_type(Message::NeighborAdvert); + packet.set_msg_code(0); + packet.clear_reserved(); + packet.set_neighbor_flags(flags); + packet.set_target_addr(target_addr); + if let Some(lladdr) = lladdr { + let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); + NdiscOptionRepr::TargetLinkLayerAddr(lladdr).emit(&mut opt_pkt); + } + } + + Repr::Redirect { + target_addr, + dest_addr, + lladdr, + redirected_hdr, + } => { + packet.set_msg_type(Message::Redirect); + packet.set_msg_code(0); + packet.clear_reserved(); + packet.set_target_addr(target_addr); + packet.set_dest_addr(dest_addr); + let offset = match lladdr { + Some(lladdr) => { + let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); + NdiscOptionRepr::TargetLinkLayerAddr(lladdr).emit(&mut opt_pkt); + NdiscOptionRepr::TargetLinkLayerAddr(lladdr).buffer_len() + } + None => 0, + }; + if let Some(redirected_hdr) = redirected_hdr { + let mut opt_pkt = + NdiscOption::new_unchecked(&mut packet.payload_mut()[offset..]); + NdiscOptionRepr::RedirectedHeader(redirected_hdr).emit(&mut opt_pkt); + } + } + } + } +} + +#[cfg(feature = "medium-ethernet")] +#[cfg(test)] +mod test { + use super::*; + use crate::phy::ChecksumCapabilities; + use crate::wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2}; + use crate::wire::EthernetAddress; + use crate::wire::Icmpv6Repr; + + static ROUTER_ADVERT_BYTES: [u8; 24] = [ + 0x86, 0x00, 0xa9, 0xde, 0x40, 0x80, 0x03, 0x84, 0x00, 0x00, 0x03, 0x84, 0x00, 0x00, 0x03, + 0x84, 0x01, 0x01, 0x52, 0x54, 0x00, 0x12, 0x34, 0x56, + ]; + static SOURCE_LINK_LAYER_OPT: [u8; 8] = [0x01, 0x01, 0x52, 0x54, 0x00, 0x12, 0x34, 0x56]; + + fn create_repr<'a>() -> Icmpv6Repr<'a> { + Icmpv6Repr::Ndisc(Repr::RouterAdvert { + hop_limit: 64, + flags: RouterFlags::MANAGED, + router_lifetime: Duration::from_secs(900), + reachable_time: Duration::from_millis(900), + retrans_time: Duration::from_millis(900), + lladdr: Some(EthernetAddress([0x52, 0x54, 0x00, 0x12, 0x34, 0x56]).into()), + mtu: None, + prefix_info: None, + }) + } + + #[test] + fn test_router_advert_deconstruct() { + let packet = Packet::new_unchecked(&ROUTER_ADVERT_BYTES[..]); + assert_eq!(packet.msg_type(), Message::RouterAdvert); + assert_eq!(packet.msg_code(), 0); + assert_eq!(packet.current_hop_limit(), 64); + assert_eq!(packet.router_flags(), RouterFlags::MANAGED); + assert_eq!(packet.router_lifetime(), Duration::from_secs(900)); + assert_eq!(packet.reachable_time(), Duration::from_millis(900)); + assert_eq!(packet.retrans_time(), Duration::from_millis(900)); + assert_eq!(packet.payload(), &SOURCE_LINK_LAYER_OPT[..]); + } + + #[test] + fn test_router_advert_construct() { + let mut bytes = vec![0x0; 24]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_msg_type(Message::RouterAdvert); + packet.set_msg_code(0); + packet.set_current_hop_limit(64); + packet.set_router_flags(RouterFlags::MANAGED); + packet.set_router_lifetime(Duration::from_secs(900)); + packet.set_reachable_time(Duration::from_millis(900)); + packet.set_retrans_time(Duration::from_millis(900)); + packet + .payload_mut() + .copy_from_slice(&SOURCE_LINK_LAYER_OPT[..]); + packet.fill_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2); + assert_eq!(&*packet.into_inner(), &ROUTER_ADVERT_BYTES[..]); + } + + #[test] + fn test_router_advert_repr_parse() { + let packet = Packet::new_unchecked(&ROUTER_ADVERT_BYTES[..]); + assert_eq!( + Icmpv6Repr::parse( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &packet, + &ChecksumCapabilities::default() + ) + .unwrap(), + create_repr() + ); + } + + #[test] + fn test_router_advert_repr_emit() { + let mut bytes = vec![0x2a; 24]; + let mut packet = Packet::new_unchecked(&mut bytes[..]); + create_repr().emit( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &ROUTER_ADVERT_BYTES[..]); + } +} diff --git a/src/wire/ndiscoption.rs b/src/wire/ndiscoption.rs new file mode 100644 index 0000000..eff7a93 --- /dev/null +++ b/src/wire/ndiscoption.rs @@ -0,0 +1,768 @@ +use bitflags::bitflags; +use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; + +use super::{Error, Result}; +use crate::time::Duration; +use crate::wire::{Ipv6Address, Ipv6Packet, Ipv6Repr, MAX_HARDWARE_ADDRESS_LEN}; + +use crate::wire::RawHardwareAddress; + +enum_with_unknown! { + /// NDISC Option Type + pub enum Type(u8) { + /// Source Link-layer Address + SourceLinkLayerAddr = 0x1, + /// Target Link-layer Address + TargetLinkLayerAddr = 0x2, + /// Prefix Information + PrefixInformation = 0x3, + /// Redirected Header + RedirectedHeader = 0x4, + /// MTU + Mtu = 0x5 + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Type::SourceLinkLayerAddr => write!(f, "source link-layer address"), + Type::TargetLinkLayerAddr => write!(f, "target link-layer address"), + Type::PrefixInformation => write!(f, "prefix information"), + Type::RedirectedHeader => write!(f, "redirected header"), + Type::Mtu => write!(f, "mtu"), + Type::Unknown(id) => write!(f, "{id}"), + } + } +} + +bitflags! { + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct PrefixInfoFlags: u8 { + const ON_LINK = 0b10000000; + const ADDRCONF = 0b01000000; + } +} + +/// A read/write wrapper around an [NDISC Option]. +/// +/// [NDISC Option]: https://tools.ietf.org/html/rfc4861#section-4.6 +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct NdiscOption<T: AsRef<[u8]>> { + buffer: T, +} + +// Format of an NDISC Option +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Length | ... | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// ~ ... ~ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// See https://tools.ietf.org/html/rfc4861#section-4.6 for details. +mod field { + #![allow(non_snake_case)] + + use crate::wire::field::*; + + // 8-bit identifier of the type of option. + pub const TYPE: usize = 0; + // 8-bit unsigned integer. Length of the option, in units of 8 octets. + pub const LENGTH: usize = 1; + // Minimum length of an option. + pub const MIN_OPT_LEN: usize = 8; + // Variable-length field. Option-Type-specific data. + pub const fn DATA(length: u8) -> Field { + 2..length as usize * 8 + } + + // Source/Target Link-layer Option fields. + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Length | Link-Layer Address ... + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + // Prefix Information Option fields. + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Length | Prefix Length |L|A| Reserved1 | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Valid Lifetime | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Preferred Lifetime | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Reserved2 | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | | + // + + + // | | + // + Prefix + + // | | + // + + + // | | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + // Prefix length. + pub const PREFIX_LEN: usize = 2; + // Flags field of prefix header. + pub const FLAGS: usize = 3; + // Valid lifetime. + pub const VALID_LT: Field = 4..8; + // Preferred lifetime. + pub const PREF_LT: Field = 8..12; + // Reserved bits + pub const PREF_RESERVED: Field = 12..16; + // Prefix + pub const PREFIX: Field = 16..32; + + // Redirected Header Option fields. + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Length | Reserved | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Reserved | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | | + // ~ IP header + data ~ + // | | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + // Reserved bits. + pub const REDIRECTED_RESERVED: Field = 2..8; + pub const REDIR_MIN_SZ: usize = 48; + + // MTU Option fields + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Length | Reserved | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | MTU | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + // MTU + pub const MTU: Field = 4..8; +} + +/// Core getter methods relevant to any type of NDISC option. +impl<T: AsRef<[u8]>> NdiscOption<T> { + /// Create a raw octet buffer with an NDISC Option structure. + pub const fn new_unchecked(buffer: T) -> NdiscOption<T> { + NdiscOption { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<NdiscOption<T>> { + let opt = Self::new_unchecked(buffer); + opt.check_len()?; + + // A data length field of 0 is invalid. + if opt.data_len() == 0 { + return Err(Error); + } + + Ok(opt) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// The result of this check is invalidated by calling [set_data_len]. + /// + /// [set_data_len]: #method.set_data_len + pub fn check_len(&self) -> Result<()> { + let data = self.buffer.as_ref(); + let len = data.len(); + + if len < field::MIN_OPT_LEN { + Err(Error) + } else { + let data_range = field::DATA(data[field::LENGTH]); + if len < data_range.end { + Err(Error) + } else { + match self.option_type() { + Type::SourceLinkLayerAddr | Type::TargetLinkLayerAddr | Type::Mtu => Ok(()), + Type::PrefixInformation if data_range.end >= field::PREFIX.end => Ok(()), + Type::RedirectedHeader if data_range.end >= field::REDIR_MIN_SZ => Ok(()), + Type::Unknown(_) => Ok(()), + _ => Err(Error), + } + } + } + } + + /// Consume the NDISC option, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the option type. + #[inline] + pub fn option_type(&self) -> Type { + let data = self.buffer.as_ref(); + Type::from(data[field::TYPE]) + } + + /// Return the length of the data. + #[inline] + pub fn data_len(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::LENGTH] + } +} + +/// Getter methods only relevant for Source/Target Link-layer Address options. +impl<T: AsRef<[u8]>> NdiscOption<T> { + /// Return the Source/Target Link-layer Address. + #[inline] + pub fn link_layer_addr(&self) -> RawHardwareAddress { + let len = MAX_HARDWARE_ADDRESS_LEN.min(self.data_len() as usize * 8 - 2); + let data = self.buffer.as_ref(); + RawHardwareAddress::from_bytes(&data[2..len + 2]) + } +} + +/// Getter methods only relevant for the MTU option. +impl<T: AsRef<[u8]>> NdiscOption<T> { + /// Return the MTU value. + #[inline] + pub fn mtu(&self) -> u32 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u32(&data[field::MTU]) + } +} + +/// Getter methods only relevant for the Prefix Information option. +impl<T: AsRef<[u8]>> NdiscOption<T> { + /// Return the prefix length. + #[inline] + pub fn prefix_len(&self) -> u8 { + self.buffer.as_ref()[field::PREFIX_LEN] + } + + /// Return the prefix information flags. + #[inline] + pub fn prefix_flags(&self) -> PrefixInfoFlags { + PrefixInfoFlags::from_bits_truncate(self.buffer.as_ref()[field::FLAGS]) + } + + /// Return the valid lifetime of the prefix. + #[inline] + pub fn valid_lifetime(&self) -> Duration { + let data = self.buffer.as_ref(); + Duration::from_secs(NetworkEndian::read_u32(&data[field::VALID_LT]) as u64) + } + + /// Return the preferred lifetime of the prefix. + #[inline] + pub fn preferred_lifetime(&self) -> Duration { + let data = self.buffer.as_ref(); + Duration::from_secs(NetworkEndian::read_u32(&data[field::PREF_LT]) as u64) + } + + /// Return the prefix. + #[inline] + pub fn prefix(&self) -> Ipv6Address { + let data = self.buffer.as_ref(); + Ipv6Address::from_bytes(&data[field::PREFIX]) + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> NdiscOption<&'a T> { + /// Return the option data. + #[inline] + pub fn data(&self) -> &'a [u8] { + let len = self.data_len(); + let data = self.buffer.as_ref(); + &data[field::DATA(len)] + } +} + +/// Core setter methods relevant to any type of NDISC option. +impl<T: AsRef<[u8]> + AsMut<[u8]>> NdiscOption<T> { + /// Set the option type. + #[inline] + pub fn set_option_type(&mut self, value: Type) { + let data = self.buffer.as_mut(); + data[field::TYPE] = value.into(); + } + + /// Set the option data length. + #[inline] + pub fn set_data_len(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::LENGTH] = value; + } +} + +/// Setter methods only relevant for Source/Target Link-layer Address options. +impl<T: AsRef<[u8]> + AsMut<[u8]>> NdiscOption<T> { + /// Set the Source/Target Link-layer Address. + #[inline] + pub fn set_link_layer_addr(&mut self, addr: RawHardwareAddress) { + let data = self.buffer.as_mut(); + data[2..2 + addr.len()].copy_from_slice(addr.as_bytes()) + } +} + +/// Setter methods only relevant for the MTU option. +impl<T: AsRef<[u8]> + AsMut<[u8]>> NdiscOption<T> { + /// Set the MTU value. + #[inline] + pub fn set_mtu(&mut self, value: u32) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::MTU], value); + } +} + +/// Setter methods only relevant for the Prefix Information option. +impl<T: AsRef<[u8]> + AsMut<[u8]>> NdiscOption<T> { + /// Set the prefix length. + #[inline] + pub fn set_prefix_len(&mut self, value: u8) { + self.buffer.as_mut()[field::PREFIX_LEN] = value; + } + + /// Set the prefix information flags. + #[inline] + pub fn set_prefix_flags(&mut self, flags: PrefixInfoFlags) { + self.buffer.as_mut()[field::FLAGS] = flags.bits(); + } + + /// Set the valid lifetime of the prefix. + #[inline] + pub fn set_valid_lifetime(&mut self, time: Duration) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::VALID_LT], time.secs() as u32); + } + + /// Set the preferred lifetime of the prefix. + #[inline] + pub fn set_preferred_lifetime(&mut self, time: Duration) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::PREF_LT], time.secs() as u32); + } + + /// Clear the reserved bits. + #[inline] + pub fn clear_prefix_reserved(&mut self) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u32(&mut data[field::PREF_RESERVED], 0); + } + + /// Set the prefix. + #[inline] + pub fn set_prefix(&mut self, addr: Ipv6Address) { + let data = self.buffer.as_mut(); + data[field::PREFIX].copy_from_slice(addr.as_bytes()); + } +} + +/// Setter methods only relevant for the Redirected Header option. +impl<T: AsRef<[u8]> + AsMut<[u8]>> NdiscOption<T> { + /// Clear the reserved bits. + #[inline] + pub fn clear_redirected_reserved(&mut self) { + let data = self.buffer.as_mut(); + data[field::REDIRECTED_RESERVED].fill_with(|| 0); + } +} + +impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NdiscOption<&'a mut T> { + /// Return a mutable pointer to the option data. + #[inline] + pub fn data_mut(&mut self) -> &mut [u8] { + let len = self.data_len(); + let data = self.buffer.as_mut(); + &mut data[field::DATA(len)] + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for NdiscOption<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match Repr::parse(self) { + Ok(repr) => write!(f, "{repr}"), + Err(err) => { + write!(f, "NDISC Option ({err})")?; + Ok(()) + } + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct PrefixInformation { + pub prefix_len: u8, + pub flags: PrefixInfoFlags, + pub valid_lifetime: Duration, + pub preferred_lifetime: Duration, + pub prefix: Ipv6Address, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct RedirectedHeader<'a> { + pub header: Ipv6Repr, + pub data: &'a [u8], +} + +/// A high-level representation of an NDISC Option. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Repr<'a> { + SourceLinkLayerAddr(RawHardwareAddress), + TargetLinkLayerAddr(RawHardwareAddress), + PrefixInformation(PrefixInformation), + RedirectedHeader(RedirectedHeader<'a>), + Mtu(u32), + Unknown { + type_: u8, + length: u8, + data: &'a [u8], + }, +} + +impl<'a> Repr<'a> { + /// Parse an NDISC Option and return a high-level representation. + pub fn parse<T>(opt: &NdiscOption<&'a T>) -> Result<Repr<'a>> + where + T: AsRef<[u8]> + ?Sized, + { + match opt.option_type() { + Type::SourceLinkLayerAddr => { + if opt.data_len() >= 1 { + Ok(Repr::SourceLinkLayerAddr(opt.link_layer_addr())) + } else { + Err(Error) + } + } + Type::TargetLinkLayerAddr => { + if opt.data_len() >= 1 { + Ok(Repr::TargetLinkLayerAddr(opt.link_layer_addr())) + } else { + Err(Error) + } + } + Type::PrefixInformation => { + if opt.data_len() == 4 { + Ok(Repr::PrefixInformation(PrefixInformation { + prefix_len: opt.prefix_len(), + flags: opt.prefix_flags(), + valid_lifetime: opt.valid_lifetime(), + preferred_lifetime: opt.preferred_lifetime(), + prefix: opt.prefix(), + })) + } else { + Err(Error) + } + } + Type::RedirectedHeader => { + // If the options data length is less than 6, the option + // does not have enough data to fill out the IP header + // and common option fields. + if opt.data_len() < 6 { + Err(Error) + } else { + let redirected_packet = &opt.data()[field::REDIRECTED_RESERVED.len()..]; + + let ip_packet = Ipv6Packet::new_checked(redirected_packet)?; + let ip_repr = Ipv6Repr::parse(&ip_packet)?; + + Ok(Repr::RedirectedHeader(RedirectedHeader { + header: ip_repr, + data: &redirected_packet[ip_repr.buffer_len()..][..ip_repr.payload_len], + })) + } + } + Type::Mtu => { + if opt.data_len() == 1 { + Ok(Repr::Mtu(opt.mtu())) + } else { + Err(Error) + } + } + Type::Unknown(id) => { + // A length of 0 is invalid. + if opt.data_len() != 0 { + Ok(Repr::Unknown { + type_: id, + length: opt.data_len(), + data: opt.data(), + }) + } else { + Err(Error) + } + } + } + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + match self { + &Repr::SourceLinkLayerAddr(addr) | &Repr::TargetLinkLayerAddr(addr) => { + let len = 2 + addr.len(); + // Round up to next multiple of 8 + (len + 7) / 8 * 8 + } + &Repr::PrefixInformation(_) => field::PREFIX.end, + &Repr::RedirectedHeader(RedirectedHeader { header, data }) => { + (8 + header.buffer_len() + data.len() + 7) / 8 * 8 + } + &Repr::Mtu(_) => field::MTU.end, + &Repr::Unknown { length, .. } => field::DATA(length).end, + } + } + + /// Emit a high-level representation into an NDISC Option. + pub fn emit<T>(&self, opt: &mut NdiscOption<&'a mut T>) + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + match *self { + Repr::SourceLinkLayerAddr(addr) => { + opt.set_option_type(Type::SourceLinkLayerAddr); + let opt_len = addr.len() + 2; + opt.set_data_len(((opt_len + 7) / 8) as u8); // round to next multiple of 8. + opt.set_link_layer_addr(addr); + } + Repr::TargetLinkLayerAddr(addr) => { + opt.set_option_type(Type::TargetLinkLayerAddr); + let opt_len = addr.len() + 2; + opt.set_data_len(((opt_len + 7) / 8) as u8); // round to next multiple of 8. + opt.set_link_layer_addr(addr); + } + Repr::PrefixInformation(PrefixInformation { + prefix_len, + flags, + valid_lifetime, + preferred_lifetime, + prefix, + }) => { + opt.clear_prefix_reserved(); + opt.set_option_type(Type::PrefixInformation); + opt.set_data_len(4); + opt.set_prefix_len(prefix_len); + opt.set_prefix_flags(flags); + opt.set_valid_lifetime(valid_lifetime); + opt.set_preferred_lifetime(preferred_lifetime); + opt.set_prefix(prefix); + } + Repr::RedirectedHeader(RedirectedHeader { header, data }) => { + // TODO(thvdveld): I think we need to check if the data we are sending is not + // exceeding the MTU. + opt.clear_redirected_reserved(); + opt.set_option_type(Type::RedirectedHeader); + opt.set_data_len((((8 + header.buffer_len() + data.len()) + 7) / 8) as u8); + let mut packet = &mut opt.data_mut()[field::REDIRECTED_RESERVED.end - 2..]; + let mut ip_packet = Ipv6Packet::new_unchecked(&mut packet); + header.emit(&mut ip_packet); + ip_packet.payload_mut().copy_from_slice(data); + } + Repr::Mtu(mtu) => { + opt.set_option_type(Type::Mtu); + opt.set_data_len(1); + opt.set_mtu(mtu); + } + Repr::Unknown { + type_: id, + length, + data, + } => { + opt.set_option_type(Type::Unknown(id)); + opt.set_data_len(length); + opt.data_mut().copy_from_slice(data); + } + } + } +} + +impl<'a> fmt::Display for Repr<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "NDISC Option: ")?; + match *self { + Repr::SourceLinkLayerAddr(addr) => { + write!(f, "SourceLinkLayer addr={addr}") + } + Repr::TargetLinkLayerAddr(addr) => { + write!(f, "TargetLinkLayer addr={addr}") + } + Repr::PrefixInformation(PrefixInformation { + prefix, prefix_len, .. + }) => { + write!(f, "PrefixInformation prefix={prefix}/{prefix_len}") + } + Repr::RedirectedHeader(RedirectedHeader { header, .. }) => { + write!(f, "RedirectedHeader header={header}") + } + Repr::Mtu(mtu) => { + write!(f, "MTU mtu={mtu}") + } + Repr::Unknown { + type_: id, length, .. + } => { + write!(f, "Unknown({id}) length={length}") + } + } + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +impl<T: AsRef<[u8]>> PrettyPrint for NdiscOption<T> { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + match NdiscOption::new_checked(buffer) { + Err(err) => write!(f, "{indent}({err})"), + Ok(ndisc) => match Repr::parse(&ndisc) { + Err(_) => Ok(()), + Ok(repr) => { + write!(f, "{indent}{repr}") + } + }, + } + } +} + +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +#[cfg(test)] +mod test { + use super::Error; + use super::{NdiscOption, PrefixInfoFlags, PrefixInformation, Repr, Type}; + use crate::time::Duration; + use crate::wire::Ipv6Address; + + #[cfg(feature = "medium-ethernet")] + use crate::wire::EthernetAddress; + #[cfg(all(not(feature = "medium-ethernet"), feature = "medium-ieee802154"))] + use crate::wire::Ieee802154Address; + + static PREFIX_OPT_BYTES: [u8; 32] = [ + 0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x03, 0x84, 0x00, 0x00, 0x03, 0xe8, 0x00, 0x00, 0x00, + 0x00, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ]; + + #[test] + fn test_deconstruct() { + let opt = NdiscOption::new_unchecked(&PREFIX_OPT_BYTES[..]); + assert_eq!(opt.option_type(), Type::PrefixInformation); + assert_eq!(opt.data_len(), 4); + assert_eq!(opt.prefix_len(), 64); + assert_eq!( + opt.prefix_flags(), + PrefixInfoFlags::ON_LINK | PrefixInfoFlags::ADDRCONF + ); + assert_eq!(opt.valid_lifetime(), Duration::from_secs(900)); + assert_eq!(opt.preferred_lifetime(), Duration::from_secs(1000)); + assert_eq!(opt.prefix(), Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)); + } + + #[test] + fn test_construct() { + let mut bytes = [0x00; 32]; + let mut opt = NdiscOption::new_unchecked(&mut bytes[..]); + opt.set_option_type(Type::PrefixInformation); + opt.set_data_len(4); + opt.set_prefix_len(64); + opt.set_prefix_flags(PrefixInfoFlags::ON_LINK | PrefixInfoFlags::ADDRCONF); + opt.set_valid_lifetime(Duration::from_secs(900)); + opt.set_preferred_lifetime(Duration::from_secs(1000)); + opt.set_prefix(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)); + assert_eq!(&PREFIX_OPT_BYTES[..], &*opt.into_inner()); + } + + #[test] + fn test_short_packet() { + assert_eq!(NdiscOption::new_checked(&[0x00, 0x00]), Err(Error)); + let bytes = [0x03, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + assert_eq!(NdiscOption::new_checked(&bytes), Err(Error)); + } + + #[cfg(feature = "medium-ethernet")] + #[test] + fn test_repr_parse_link_layer_opt_ethernet() { + let mut bytes = [0x01, 0x01, 0x54, 0x52, 0x00, 0x12, 0x23, 0x34]; + let addr = EthernetAddress([0x54, 0x52, 0x00, 0x12, 0x23, 0x34]); + { + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::SourceLinkLayerAddr(addr.into())) + ); + } + bytes[0] = 0x02; + { + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::TargetLinkLayerAddr(addr.into())) + ); + } + } + + #[cfg(all(not(feature = "medium-ethernet"), feature = "medium-ieee802154"))] + #[test] + fn test_repr_parse_link_layer_opt_ieee802154() { + let mut bytes = [ + 0x01, 0x02, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + ]; + let addr = Ieee802154Address::Extended([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); + { + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::SourceLinkLayerAddr(addr.into())) + ); + } + bytes[0] = 0x02; + { + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::TargetLinkLayerAddr(addr.into())) + ); + } + } + + #[test] + fn test_repr_parse_prefix_info() { + let repr = Repr::PrefixInformation(PrefixInformation { + prefix_len: 64, + flags: PrefixInfoFlags::ON_LINK | PrefixInfoFlags::ADDRCONF, + valid_lifetime: Duration::from_secs(900), + preferred_lifetime: Duration::from_secs(1000), + prefix: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), + }); + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&PREFIX_OPT_BYTES)), + Ok(repr) + ); + } + + #[test] + fn test_repr_emit_prefix_info() { + let mut bytes = [0x2a; 32]; + let repr = Repr::PrefixInformation(PrefixInformation { + prefix_len: 64, + flags: PrefixInfoFlags::ON_LINK | PrefixInfoFlags::ADDRCONF, + valid_lifetime: Duration::from_secs(900), + preferred_lifetime: Duration::from_secs(1000), + prefix: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), + }); + let mut opt = NdiscOption::new_unchecked(&mut bytes); + repr.emit(&mut opt); + assert_eq!(&opt.into_inner()[..], &PREFIX_OPT_BYTES[..]); + } + + #[test] + fn test_repr_parse_mtu() { + let bytes = [0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0xdc]; + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::Mtu(1500)) + ); + } +} diff --git a/src/wire/pretty_print.rs b/src/wire/pretty_print.rs new file mode 100644 index 0000000..fe7d8b8 --- /dev/null +++ b/src/wire/pretty_print.rs @@ -0,0 +1,126 @@ +/*! Pretty-printing of packet representation. + +The `pretty_print` module provides bits and pieces for printing concise, +easily human readable packet listings. + +# Example + +A packet can be formatted using the `PrettyPrinter` wrapper: + +```rust +use smoltcp::wire::*; +let buffer = vec![ + // Ethernet II + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + 0x08, 0x00, + // IPv4 + 0x45, 0x00, 0x00, 0x20, + 0x00, 0x00, 0x40, 0x00, + 0x40, 0x01, 0xd2, 0x79, + 0x11, 0x12, 0x13, 0x14, + 0x21, 0x22, 0x23, 0x24, + // ICMPv4 + 0x08, 0x00, 0x8e, 0xfe, + 0x12, 0x34, 0xab, 0xcd, + 0xaa, 0x00, 0x00, 0xff +]; + +let result = "\ +EthernetII src=11-12-13-14-15-16 dst=01-02-03-04-05-06 type=IPv4\n\ +\\ IPv4 src=17.18.19.20 dst=33.34.35.36 proto=ICMP (checksum incorrect)\n \ + \\ ICMPv4 echo request id=4660 seq=43981 len=4\ +"; + +#[cfg(all(feature = "medium-ethernet", feature = "proto-ipv4"))] +assert_eq!( + result, + &format!("{}", PrettyPrinter::<EthernetFrame<&'static [u8]>>::new("", &buffer)) +); +``` +*/ + +use core::fmt; +use core::marker::PhantomData; + +/// Indentation state. +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct PrettyIndent { + prefix: &'static str, + level: usize, +} + +impl PrettyIndent { + /// Create an indentation state. The entire listing will be indented by the width + /// of `prefix`, and `prefix` will appear at the start of the first line. + pub fn new(prefix: &'static str) -> PrettyIndent { + PrettyIndent { prefix, level: 0 } + } + + /// Increase indentation level. + pub fn increase(&mut self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f)?; + self.level += 1; + Ok(()) + } +} + +impl fmt::Display for PrettyIndent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.level == 0 { + write!(f, "{}", self.prefix) + } else { + write!(f, "{0:1$}{0:2$}\\ ", "", self.prefix.len(), self.level - 1) + } + } +} + +/// Interface for printing listings. +pub trait PrettyPrint { + /// Write a concise, formatted representation of a packet contained in the provided + /// buffer, and any nested packets it may contain. + /// + /// `pretty_print` accepts a buffer and not a packet wrapper because the packet might + /// be truncated, and so it might not be possible to create the packet wrapper. + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + fmt: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result; +} + +/// Wrapper for using a `PrettyPrint` where a `Display` is expected. +pub struct PrettyPrinter<'a, T: PrettyPrint> { + prefix: &'static str, + buffer: &'a dyn AsRef<[u8]>, + phantom: PhantomData<T>, +} + +impl<'a, T: PrettyPrint> PrettyPrinter<'a, T> { + /// Format the listing with the recorded parameters when Display::fmt is called. + pub fn new(prefix: &'static str, buffer: &'a dyn AsRef<[u8]>) -> PrettyPrinter<'a, T> { + PrettyPrinter { + prefix: prefix, + buffer: buffer, + phantom: PhantomData, + } + } +} + +impl<'a, T: PrettyPrint + AsRef<[u8]>> PrettyPrinter<'a, T> { + /// Create a `PrettyPrinter` which prints the given object. + pub fn print(printable: &'a T) -> PrettyPrinter<'a, T> { + PrettyPrinter { + prefix: "", + buffer: printable, + phantom: PhantomData, + } + } +} + +impl<'a, T: PrettyPrint> fmt::Display for PrettyPrinter<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + T::pretty_print(&self.buffer, f, &mut PrettyIndent::new(self.prefix)) + } +} diff --git a/src/wire/rpl.rs b/src/wire/rpl.rs new file mode 100644 index 0000000..0a8467c --- /dev/null +++ b/src/wire/rpl.rs @@ -0,0 +1,2721 @@ +//! Implementation of the RPL packet formats. See [RFC 6550 § 6]. +//! +//! [RFC 6550 § 6]: https://datatracker.ietf.org/doc/html/rfc6550#section-6 + +use byteorder::{ByteOrder, NetworkEndian}; + +use super::{Error, Result}; +use crate::wire::icmpv6::Packet; +use crate::wire::ipv6::Address; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[repr(u8)] +pub enum InstanceId { + Global(u8), + Local(u8), +} + +impl From<u8> for InstanceId { + fn from(val: u8) -> Self { + const MASK: u8 = 0b0111_1111; + + if ((val >> 7) & 0xb1) == 0b0 { + Self::Global(val & MASK) + } else { + Self::Local(val & MASK) + } + } +} + +impl From<InstanceId> for u8 { + fn from(val: InstanceId) -> Self { + match val { + InstanceId::Global(val) => 0b0000_0000 | val, + InstanceId::Local(val) => 0b1000_0000 | val, + } + } +} + +impl InstanceId { + /// Return the real part of the ID. + pub fn id(&self) -> u8 { + match self { + Self::Global(val) => *val, + Self::Local(val) => *val, + } + } + + /// Returns `true` when the DODAG ID is the destination address of the IPv6 packet. + #[inline] + pub fn dodag_is_destination(&self) -> bool { + match self { + Self::Global(_) => false, + Self::Local(val) => ((val >> 6) & 0b1) == 0b1, + } + } + + /// Returns `true` when the DODAG ID is the source address of the IPv6 packet. + /// + /// *NOTE*: this only makes sense when using a local RPL Instance ID and the packet is not a + /// RPL control message. + #[inline] + pub fn dodag_is_source(&self) -> bool { + !self.dodag_is_destination() + } +} + +mod field { + use crate::wire::field::*; + + pub const RPL_INSTANCE_ID: usize = 4; + + // DODAG information solicitation fields (DIS) + pub const DIS_FLAGS: usize = 4; + pub const DIS_RESERVED: usize = 5; + + // DODAG information object fields (DIO) + pub const DIO_VERSION_NUMBER: usize = 5; + pub const DIO_RANK: Field = 6..8; + pub const DIO_GROUNDED: usize = 8; + pub const DIO_MOP: usize = 8; + pub const DIO_PRF: usize = 8; + pub const DIO_DTSN: usize = 9; + //pub const DIO_FLAGS: usize = 10; + //pub const DIO_RESERVED: usize = 11; + pub const DIO_DODAG_ID: Field = 12..12 + 16; + + // Destination advertisement object (DAO) + pub const DAO_K: usize = 5; + pub const DAO_D: usize = 5; + //pub const DAO_FLAGS: usize = 5; + //pub const DAO_RESERVED: usize = 6; + pub const DAO_SEQUENCE: usize = 7; + pub const DAO_DODAG_ID: Field = 8..8 + 16; + + // Destination advertisement object ack (DAO-ACK) + pub const DAO_ACK_D: usize = 5; + //pub const DAO_ACK_RESERVED: usize = 5; + pub const DAO_ACK_SEQUENCE: usize = 6; + pub const DAO_ACK_STATUS: usize = 7; + pub const DAO_ACK_DODAG_ID: Field = 8..8 + 16; +} + +enum_with_unknown! { + /// RPL Control Message subtypes. + pub enum RplControlMessage(u8) { + DodagInformationSolicitation = 0x00, + DodagInformationObject = 0x01, + DestinationAdvertisementObject = 0x02, + DestinationAdvertisementObjectAck = 0x03, + SecureDodagInformationSolicitation = 0x80, + SecureDodagInformationObject = 0x81, + SecureDestinationAdvertisementObject = 0x82, + SecureDestinationAdvertisementObjectAck = 0x83, + ConsistencyCheck = 0x8a, + } +} + +impl core::fmt::Display for RplControlMessage { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + RplControlMessage::DodagInformationSolicitation => { + write!(f, "DODAG information solicitation (DIS)") + } + RplControlMessage::DodagInformationObject => { + write!(f, "DODAG information object (DIO)") + } + RplControlMessage::DestinationAdvertisementObject => { + write!(f, "destination advertisement object (DAO)") + } + RplControlMessage::DestinationAdvertisementObjectAck => write!( + f, + "destination advertisement object acknowledgement (DAO-ACK)" + ), + RplControlMessage::SecureDodagInformationSolicitation => { + write!(f, "secure DODAG information solicitation (DIS)") + } + RplControlMessage::SecureDodagInformationObject => { + write!(f, "secure DODAG information object (DIO)") + } + RplControlMessage::SecureDestinationAdvertisementObject => { + write!(f, "secure destination advertisement object (DAO)") + } + RplControlMessage::SecureDestinationAdvertisementObjectAck => write!( + f, + "secure destination advertisement object acknowledgement (DAO-ACK)" + ), + RplControlMessage::ConsistencyCheck => write!(f, "consistency check (CC)"), + RplControlMessage::Unknown(id) => write!(f, "{}", id), + } + } +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Return the RPL instance ID. + #[inline] + pub fn rpl_instance_id(&self) -> InstanceId { + get!(self.buffer, into: InstanceId, field: field::RPL_INSTANCE_ID) + } +} + +impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return a pointer to the options. + pub fn options(&self) -> Result<&'p [u8]> { + let len = self.buffer.as_ref().len(); + match RplControlMessage::from(self.msg_code()) { + RplControlMessage::DodagInformationSolicitation if len < field::DIS_RESERVED + 1 => { + return Err(Error) + } + RplControlMessage::DodagInformationObject if len < field::DIO_DODAG_ID.end => { + return Err(Error) + } + RplControlMessage::DestinationAdvertisementObject + if self.dao_dodag_id_present() && len < field::DAO_DODAG_ID.end => + { + return Err(Error) + } + RplControlMessage::DestinationAdvertisementObject if len < field::DAO_SEQUENCE + 1 => { + return Err(Error) + } + RplControlMessage::DestinationAdvertisementObjectAck + if self.dao_ack_dodag_id_present() && len < field::DAO_ACK_DODAG_ID.end => + { + return Err(Error) + } + RplControlMessage::DestinationAdvertisementObjectAck + if len < field::DAO_ACK_STATUS + 1 => + { + return Err(Error) + } + RplControlMessage::SecureDodagInformationSolicitation + | RplControlMessage::SecureDodagInformationObject + | RplControlMessage::SecureDestinationAdvertisementObject + | RplControlMessage::SecureDestinationAdvertisementObjectAck + | RplControlMessage::ConsistencyCheck => return Err(Error), + RplControlMessage::Unknown(_) => return Err(Error), + _ => {} + } + + let buffer = &self.buffer.as_ref(); + Ok(match RplControlMessage::from(self.msg_code()) { + RplControlMessage::DodagInformationSolicitation => &buffer[field::DIS_RESERVED + 1..], + RplControlMessage::DodagInformationObject => &buffer[field::DIO_DODAG_ID.end..], + RplControlMessage::DestinationAdvertisementObject if self.dao_dodag_id_present() => { + &buffer[field::DAO_DODAG_ID.end..] + } + RplControlMessage::DestinationAdvertisementObject => &buffer[field::DAO_SEQUENCE + 1..], + RplControlMessage::DestinationAdvertisementObjectAck + if self.dao_ack_dodag_id_present() => + { + &buffer[field::DAO_ACK_DODAG_ID.end..] + } + RplControlMessage::DestinationAdvertisementObjectAck => { + &buffer[field::DAO_ACK_STATUS + 1..] + } + RplControlMessage::SecureDodagInformationSolicitation + | RplControlMessage::SecureDodagInformationObject + | RplControlMessage::SecureDestinationAdvertisementObject + | RplControlMessage::SecureDestinationAdvertisementObjectAck + | RplControlMessage::ConsistencyCheck => unreachable!(), + RplControlMessage::Unknown(_) => unreachable!(), + }) + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the RPL Instance ID field. + #[inline] + pub fn set_rpl_instance_id(&mut self, value: u8) { + set!(self.buffer, value, field: field::RPL_INSTANCE_ID) + } +} + +impl<'p, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'p mut T> { + /// Return a pointer to the options. + pub fn options_mut(&mut self) -> &mut [u8] { + match RplControlMessage::from(self.msg_code()) { + RplControlMessage::DodagInformationSolicitation => { + &mut self.buffer.as_mut()[field::DIS_RESERVED + 1..] + } + RplControlMessage::DodagInformationObject => { + &mut self.buffer.as_mut()[field::DIO_DODAG_ID.end..] + } + RplControlMessage::DestinationAdvertisementObject => { + if self.dao_dodag_id_present() { + &mut self.buffer.as_mut()[field::DAO_DODAG_ID.end..] + } else { + &mut self.buffer.as_mut()[field::DAO_SEQUENCE + 1..] + } + } + RplControlMessage::DestinationAdvertisementObjectAck => { + if self.dao_ack_dodag_id_present() { + &mut self.buffer.as_mut()[field::DAO_ACK_DODAG_ID.end..] + } else { + &mut self.buffer.as_mut()[field::DAO_ACK_STATUS + 1..] + } + } + RplControlMessage::SecureDodagInformationSolicitation + | RplControlMessage::SecureDodagInformationObject + | RplControlMessage::SecureDestinationAdvertisementObject + | RplControlMessage::SecureDestinationAdvertisementObjectAck + | RplControlMessage::ConsistencyCheck => todo!("Secure messages not supported"), + RplControlMessage::Unknown(_) => todo!(), + } + } +} + +/// Getters for the DODAG information solicitation (DIS) message. +/// +/// ```txt +/// 0 1 2 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Flags | Reserved | Option(s)... +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// ``` +impl<T: AsRef<[u8]>> Packet<T> { + /// Return the DIS flags field. + #[inline] + pub fn dis_flags(&self) -> u8 { + get!(self.buffer, field: field::DIS_FLAGS) + } + + /// Return the DIS reserved field. + #[inline] + pub fn dis_reserved(&self) -> u8 { + get!(self.buffer, field: field::DIS_RESERVED) + } +} + +/// Setters for the DODAG information solicitation (DIS) message. +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Clear the DIS flags field. + pub fn clear_dis_flags(&mut self) { + self.buffer.as_mut()[field::DIS_FLAGS] = 0; + } + + /// Clear the DIS rserved field. + pub fn clear_dis_reserved(&mut self) { + self.buffer.as_mut()[field::DIS_RESERVED] = 0; + } +} + +enum_with_unknown! { + pub enum ModeOfOperation(u8) { + NoDownwardRoutesMaintained = 0x00, + NonStoringMode = 0x01, + StoringModeWithoutMulticast = 0x02, + StoringModeWithMulticast = 0x03, + } +} + +impl Default for ModeOfOperation { + fn default() -> Self { + Self::StoringModeWithoutMulticast + } +} + +/// Getters for the DODAG information object (DIO) message. +/// +/// ```txt +/// 0 1 2 3 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | RPLInstanceID |Version Number | Rank | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// |G|0| MOP | Prf | DTSN | Flags | Reserved | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | | +/// + + +/// | | +/// + DODAGID + +/// | | +/// + + +/// | | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Option(s)... +/// +-+-+-+-+-+-+-+-+ +/// ``` +impl<T: AsRef<[u8]>> Packet<T> { + /// Return the Version Number field. + #[inline] + pub fn dio_version_number(&self) -> u8 { + get!(self.buffer, field: field::DIO_VERSION_NUMBER) + } + + /// Return the Rank field. + #[inline] + pub fn dio_rank(&self) -> u16 { + get!(self.buffer, u16, field: field::DIO_RANK) + } + + /// Return the value of the Grounded flag. + #[inline] + pub fn dio_grounded(&self) -> bool { + get!(self.buffer, bool, field: field::DIO_GROUNDED, shift: 7, mask: 0b01) + } + + /// Return the mode of operation field. + #[inline] + pub fn dio_mode_of_operation(&self) -> ModeOfOperation { + get!(self.buffer, into: ModeOfOperation, field: field::DIO_MOP, shift: 3, mask: 0b111) + } + + /// Return the DODAG preference field. + #[inline] + pub fn dio_dodag_preference(&self) -> u8 { + get!(self.buffer, field: field::DIO_PRF, mask: 0b111) + } + + /// Return the destination advertisement trigger sequence number. + #[inline] + pub fn dio_dest_adv_trigger_seq_number(&self) -> u8 { + get!(self.buffer, field: field::DIO_DTSN) + } + + /// Return the DODAG id, which is an IPv6 address. + #[inline] + pub fn dio_dodag_id(&self) -> Address { + get!( + self.buffer, + into: Address, + fun: from_bytes, + field: field::DIO_DODAG_ID + ) + } +} + +/// Setters for the DODAG information object (DIO) message. +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the Version Number field. + #[inline] + pub fn set_dio_version_number(&mut self, value: u8) { + set!(self.buffer, value, field: field::DIO_VERSION_NUMBER) + } + + /// Set the Rank field. + #[inline] + pub fn set_dio_rank(&mut self, value: u16) { + set!(self.buffer, value, u16, field: field::DIO_RANK) + } + + /// Set the value of the Grounded flag. + #[inline] + pub fn set_dio_grounded(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::DIO_GROUNDED, shift: 7, mask: 0b01) + } + + /// Set the mode of operation field. + #[inline] + pub fn set_dio_mode_of_operation(&mut self, mode: ModeOfOperation) { + let raw = (self.buffer.as_ref()[field::DIO_MOP] & !(0b111 << 3)) | (u8::from(mode) << 3); + self.buffer.as_mut()[field::DIO_MOP] = raw; + } + + /// Set the DODAG preference field. + #[inline] + pub fn set_dio_dodag_preference(&mut self, value: u8) { + set!(self.buffer, value, field: field::DIO_PRF, mask: 0b111) + } + + /// Set the destination advertisement trigger sequence number. + #[inline] + pub fn set_dio_dest_adv_trigger_seq_number(&mut self, value: u8) { + set!(self.buffer, value, field: field::DIO_DTSN) + } + + /// Set the DODAG id, which is an IPv6 address. + #[inline] + pub fn set_dio_dodag_id(&mut self, address: Address) { + set!(self.buffer, address: address, field: field::DIO_DODAG_ID) + } +} + +/// Getters for the Destination Advertisement Object (DAO) message. +/// +/// ```txt +/// 0 1 2 3 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | RPLInstanceID |K|D| Flags | Reserved | DAOSequence | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | | +/// + + +/// | | +/// + DODAGID* + +/// | | +/// + + +/// | | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Option(s)... +/// +-+-+-+-+-+-+-+-+ +/// ``` +impl<T: AsRef<[u8]>> Packet<T> { + /// Returns the Expect DAO-ACK flag. + #[inline] + pub fn dao_ack_request(&self) -> bool { + get!(self.buffer, bool, field: field::DAO_K, shift: 7, mask: 0b1) + } + + /// Returns the flag indicating that the DODAG ID is present or not. + #[inline] + pub fn dao_dodag_id_present(&self) -> bool { + get!(self.buffer, bool, field: field::DAO_D, shift: 6, mask: 0b1) + } + + /// Returns the DODAG sequence flag. + #[inline] + pub fn dao_dodag_sequence(&self) -> u8 { + get!(self.buffer, field: field::DAO_SEQUENCE) + } + + /// Returns the DODAG ID, an IPv6 address, when it is present. + #[inline] + pub fn dao_dodag_id(&self) -> Option<Address> { + if self.dao_dodag_id_present() { + Some(Address::from_bytes( + &self.buffer.as_ref()[field::DAO_DODAG_ID], + )) + } else { + None + } + } +} + +/// Setters for the Destination Advertisement Object (DAO) message. +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the Expect DAO-ACK flag. + #[inline] + pub fn set_dao_ack_request(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::DAO_K, shift: 7, mask: 0b1,) + } + + /// Set the flag indicating that the DODAG ID is present or not. + #[inline] + pub fn set_dao_dodag_id_present(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::DAO_D, shift: 6, mask: 0b1) + } + + /// Set the DODAG sequence flag. + #[inline] + pub fn set_dao_dodag_sequence(&mut self, value: u8) { + set!(self.buffer, value, field: field::DAO_SEQUENCE) + } + + /// Set the DODAG ID. + #[inline] + pub fn set_dao_dodag_id(&mut self, address: Option<Address>) { + match address { + Some(address) => { + self.buffer.as_mut()[field::DAO_DODAG_ID].copy_from_slice(address.as_bytes()); + self.set_dao_dodag_id_present(true); + } + None => { + self.set_dao_dodag_id_present(false); + } + } + } +} + +/// Getters for the Destination Advertisement Object acknowledgement (DAO-ACK) message. +/// +/// ```txt +/// 0 1 2 3 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | RPLInstanceID |D| Reserved | DAOSequence | Status | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | | +/// + + +/// | | +/// + DODAGID* + +/// | | +/// + + +/// | | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Option(s)... +/// +-+-+-+-+-+-+-+-+ +/// ``` +impl<T: AsRef<[u8]>> Packet<T> { + /// Returns the flag indicating that the DODAG ID is present or not. + #[inline] + pub fn dao_ack_dodag_id_present(&self) -> bool { + get!(self.buffer, bool, field: field::DAO_ACK_D, shift: 7, mask: 0b1) + } + + /// Return the DODAG sequence number. + #[inline] + pub fn dao_ack_sequence(&self) -> u8 { + get!(self.buffer, field: field::DAO_ACK_SEQUENCE) + } + + /// Return the DOA status field. + #[inline] + pub fn dao_ack_status(&self) -> u8 { + get!(self.buffer, field: field::DAO_ACK_STATUS) + } + + /// Returns the DODAG ID, an IPv6 address, when it is present. + #[inline] + pub fn dao_ack_dodag_id(&self) -> Option<Address> { + if self.dao_ack_dodag_id_present() { + Some(Address::from_bytes( + &self.buffer.as_ref()[field::DAO_ACK_DODAG_ID], + )) + } else { + None + } + } +} + +/// Setters for the Destination Advertisement Object acknowledgement (DAO-ACK) message. +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the flag indicating that the DODAG ID is present or not. + #[inline] + pub fn set_dao_ack_dodag_id_present(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::DAO_ACK_D, shift: 7, mask: 0b1) + } + + /// Set the DODAG sequence number. + #[inline] + pub fn set_dao_ack_sequence(&mut self, value: u8) { + set!(self.buffer, value, field: field::DAO_ACK_SEQUENCE) + } + + /// Set the DOA status field. + #[inline] + pub fn set_dao_ack_status(&mut self, value: u8) { + set!(self.buffer, value, field: field::DAO_ACK_STATUS) + } + + /// Set the DODAG ID. + #[inline] + pub fn set_dao_ack_dodag_id(&mut self, address: Option<Address>) { + match address { + Some(address) => { + self.buffer.as_mut()[field::DAO_ACK_DODAG_ID].copy_from_slice(address.as_bytes()); + self.set_dao_ack_dodag_id_present(true); + } + None => { + self.set_dao_ack_dodag_id_present(false); + } + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Repr<'p> { + DodagInformationSolicitation { + options: &'p [u8], + }, + DodagInformationObject { + rpl_instance_id: InstanceId, + version_number: u8, + rank: u16, + grounded: bool, + mode_of_operation: ModeOfOperation, + dodag_preference: u8, + dtsn: u8, + dodag_id: Address, + options: &'p [u8], + }, + DestinationAdvertisementObject { + rpl_instance_id: InstanceId, + expect_ack: bool, + sequence: u8, + dodag_id: Option<Address>, + options: &'p [u8], + }, + DestinationAdvertisementObjectAck { + rpl_instance_id: InstanceId, + sequence: u8, + status: u8, + dodag_id: Option<Address>, + }, +} + +impl core::fmt::Display for Repr<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Repr::DodagInformationSolicitation { .. } => { + write!(f, "DIS")?; + } + Repr::DodagInformationObject { + rpl_instance_id, + version_number, + rank, + grounded, + mode_of_operation, + dodag_preference, + dtsn, + dodag_id, + .. + } => { + write!( + f, + "DIO \ + IID={rpl_instance_id:?} \ + V={version_number} \ + R={rank} \ + G={grounded} \ + MOP={mode_of_operation:?} \ + Pref={dodag_preference} \ + DTSN={dtsn} \ + DODAGID={dodag_id}" + )?; + } + Repr::DestinationAdvertisementObject { + rpl_instance_id, + expect_ack, + sequence, + dodag_id, + .. + } => { + write!( + f, + "DAO \ + IID={rpl_instance_id:?} \ + Ack={expect_ack} \ + Seq={sequence} \ + DODAGID={dodag_id:?}", + )?; + } + Repr::DestinationAdvertisementObjectAck { + rpl_instance_id, + sequence, + status, + dodag_id, + .. + } => { + write!( + f, + "DAO-ACK \ + IID={rpl_instance_id:?} \ + Seq={sequence} \ + Status={status} \ + DODAGID={dodag_id:?}", + )?; + } + }; + + Ok(()) + } +} + +impl<'p> Repr<'p> { + pub fn set_options(&mut self, options: &'p [u8]) { + let opts = match self { + Repr::DodagInformationSolicitation { options } => options, + Repr::DodagInformationObject { options, .. } => options, + Repr::DestinationAdvertisementObject { options, .. } => options, + Repr::DestinationAdvertisementObjectAck { .. } => unreachable!(), + }; + + *opts = options; + } + + pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&'p T>) -> Result<Self> { + let options = packet.options()?; + match RplControlMessage::from(packet.msg_code()) { + RplControlMessage::DodagInformationSolicitation => { + Ok(Repr::DodagInformationSolicitation { options }) + } + RplControlMessage::DodagInformationObject => Ok(Repr::DodagInformationObject { + rpl_instance_id: packet.rpl_instance_id(), + version_number: packet.dio_version_number(), + rank: packet.dio_rank(), + grounded: packet.dio_grounded(), + mode_of_operation: packet.dio_mode_of_operation(), + dodag_preference: packet.dio_dodag_preference(), + dtsn: packet.dio_dest_adv_trigger_seq_number(), + dodag_id: packet.dio_dodag_id(), + options, + }), + RplControlMessage::DestinationAdvertisementObject => { + Ok(Repr::DestinationAdvertisementObject { + rpl_instance_id: packet.rpl_instance_id(), + expect_ack: packet.dao_ack_request(), + sequence: packet.dao_dodag_sequence(), + dodag_id: packet.dao_dodag_id(), + options, + }) + } + RplControlMessage::DestinationAdvertisementObjectAck => { + Ok(Repr::DestinationAdvertisementObjectAck { + rpl_instance_id: packet.rpl_instance_id(), + sequence: packet.dao_ack_sequence(), + status: packet.dao_ack_status(), + dodag_id: packet.dao_ack_dodag_id(), + }) + } + RplControlMessage::SecureDodagInformationSolicitation + | RplControlMessage::SecureDodagInformationObject + | RplControlMessage::SecureDestinationAdvertisementObject + | RplControlMessage::SecureDestinationAdvertisementObjectAck + | RplControlMessage::ConsistencyCheck => Err(Error), + RplControlMessage::Unknown(_) => Err(Error), + } + } + + pub fn buffer_len(&self) -> usize { + let mut len = 4 + match self { + Repr::DodagInformationSolicitation { .. } => 2, + Repr::DodagInformationObject { .. } => 24, + Repr::DestinationAdvertisementObject { dodag_id, .. } => { + if dodag_id.is_some() { + 20 + } else { + 4 + } + } + Repr::DestinationAdvertisementObjectAck { dodag_id, .. } => { + if dodag_id.is_some() { + 20 + } else { + 4 + } + } + }; + + let opts = match self { + Repr::DodagInformationSolicitation { options } => &options[..], + Repr::DodagInformationObject { options, .. } => &options[..], + Repr::DestinationAdvertisementObject { options, .. } => &options[..], + Repr::DestinationAdvertisementObjectAck { .. } => &[], + }; + + len += opts.len(); + + len + } + + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, packet: &mut Packet<&mut T>) { + packet.set_msg_type(crate::wire::icmpv6::Message::RplControl); + + match self { + Repr::DodagInformationSolicitation { .. } => { + packet.set_msg_code(RplControlMessage::DodagInformationSolicitation.into()); + packet.clear_dis_flags(); + packet.clear_dis_reserved(); + } + Repr::DodagInformationObject { + rpl_instance_id, + version_number, + rank, + grounded, + mode_of_operation, + dodag_preference, + dtsn, + dodag_id, + .. + } => { + packet.set_msg_code(RplControlMessage::DodagInformationObject.into()); + packet.set_rpl_instance_id((*rpl_instance_id).into()); + packet.set_dio_version_number(*version_number); + packet.set_dio_rank(*rank); + packet.set_dio_grounded(*grounded); + packet.set_dio_mode_of_operation(*mode_of_operation); + packet.set_dio_dodag_preference(*dodag_preference); + packet.set_dio_dest_adv_trigger_seq_number(*dtsn); + packet.set_dio_dodag_id(*dodag_id); + } + Repr::DestinationAdvertisementObject { + rpl_instance_id, + expect_ack, + sequence, + dodag_id, + .. + } => { + packet.set_msg_code(RplControlMessage::DestinationAdvertisementObject.into()); + packet.set_rpl_instance_id((*rpl_instance_id).into()); + packet.set_dao_ack_request(*expect_ack); + packet.set_dao_dodag_sequence(*sequence); + packet.set_dao_dodag_id(*dodag_id); + } + Repr::DestinationAdvertisementObjectAck { + rpl_instance_id, + sequence, + status, + dodag_id, + .. + } => { + packet.set_msg_code(RplControlMessage::DestinationAdvertisementObjectAck.into()); + packet.set_rpl_instance_id((*rpl_instance_id).into()); + packet.set_dao_ack_sequence(*sequence); + packet.set_dao_ack_status(*status); + packet.set_dao_ack_dodag_id(*dodag_id); + } + } + + let options = match self { + Repr::DodagInformationSolicitation { options } => &options[..], + Repr::DodagInformationObject { options, .. } => &options[..], + Repr::DestinationAdvertisementObject { options, .. } => &options[..], + Repr::DestinationAdvertisementObjectAck { .. } => &[], + }; + + packet.options_mut().copy_from_slice(options); + } +} + +pub mod options { + use byteorder::{ByteOrder, NetworkEndian}; + + use super::{Error, InstanceId, Result}; + use crate::wire::ipv6::Address; + + /// A read/write wrapper around a RPL Control Message Option. + #[derive(Debug, Clone)] + pub struct Packet<T: AsRef<[u8]>> { + buffer: T, + } + + enum_with_unknown! { + pub enum OptionType(u8) { + Pad1 = 0x00, + PadN = 0x01, + DagMetricContainer = 0x02, + RouteInformation = 0x03, + DodagConfiguration = 0x04, + RplTarget = 0x05, + TransitInformation = 0x06, + SolicitedInformation = 0x07, + PrefixInformation = 0x08, + RplTargetDescriptor = 0x09, + } + } + + impl From<&Repr<'_>> for OptionType { + fn from(repr: &Repr) -> Self { + match repr { + Repr::Pad1 => Self::Pad1, + Repr::PadN(_) => Self::PadN, + Repr::DagMetricContainer => Self::DagMetricContainer, + Repr::RouteInformation { .. } => Self::RouteInformation, + Repr::DodagConfiguration { .. } => Self::DodagConfiguration, + Repr::RplTarget { .. } => Self::RplTarget, + Repr::TransitInformation { .. } => Self::TransitInformation, + Repr::SolicitedInformation { .. } => Self::SolicitedInformation, + Repr::PrefixInformation { .. } => Self::PrefixInformation, + Repr::RplTargetDescriptor { .. } => Self::RplTargetDescriptor, + } + } + } + + mod field { + use crate::wire::field::*; + + // Generic fields. + pub const TYPE: usize = 0; + pub const LENGTH: usize = 1; + + pub const PADN: Rest = 2..; + + // Route Information fields. + pub const ROUTE_INFO_PREFIX_LENGTH: usize = 2; + pub const ROUTE_INFO_RESERVED: usize = 3; + pub const ROUTE_INFO_PREFERENCE: usize = 3; + pub const ROUTE_INFO_LIFETIME: Field = 4..9; + + // DODAG Configuration fields. + pub const DODAG_CONF_FLAGS: usize = 2; + pub const DODAG_CONF_AUTHENTICATION_ENABLED: usize = 2; + pub const DODAG_CONF_PATH_CONTROL_SIZE: usize = 2; + pub const DODAG_CONF_DIO_INTERVAL_DOUBLINGS: usize = 3; + pub const DODAG_CONF_DIO_INTERVAL_MINIMUM: usize = 4; + pub const DODAG_CONF_DIO_REDUNDANCY_CONSTANT: usize = 5; + pub const DODAG_CONF_DIO_MAX_RANK_INCREASE: Field = 6..8; + pub const DODAG_CONF_MIN_HOP_RANK_INCREASE: Field = 8..10; + pub const DODAG_CONF_OBJECTIVE_CODE_POINT: Field = 10..12; + pub const DODAG_CONF_DEFAULT_LIFETIME: usize = 13; + pub const DODAG_CONF_LIFETIME_UNIT: Field = 14..16; + + // RPL Target fields. + pub const RPL_TARGET_FLAGS: usize = 2; + pub const RPL_TARGET_PREFIX_LENGTH: usize = 3; + + // Transit Information fields. + pub const TRANSIT_INFO_FLAGS: usize = 2; + pub const TRANSIT_INFO_EXTERNAL: usize = 2; + pub const TRANSIT_INFO_PATH_CONTROL: usize = 3; + pub const TRANSIT_INFO_PATH_SEQUENCE: usize = 4; + pub const TRANSIT_INFO_PATH_LIFETIME: usize = 5; + pub const TRANSIT_INFO_PARENT_ADDRESS: Field = 6..6 + 16; + + // Solicited Information fields. + pub const SOLICITED_INFO_RPL_INSTANCE_ID: usize = 2; + pub const SOLICITED_INFO_FLAGS: usize = 3; + pub const SOLICITED_INFO_VERSION_PREDICATE: usize = 3; + pub const SOLICITED_INFO_INSTANCE_ID_PREDICATE: usize = 3; + pub const SOLICITED_INFO_DODAG_ID_PREDICATE: usize = 3; + pub const SOLICITED_INFO_DODAG_ID: Field = 4..20; + pub const SOLICITED_INFO_VERSION_NUMBER: usize = 20; + + // Prefix Information fields. + pub const PREFIX_INFO_PREFIX_LENGTH: usize = 2; + pub const PREFIX_INFO_RESERVED1: usize = 3; + pub const PREFIX_INFO_ON_LINK: usize = 3; + pub const PREFIX_INFO_AUTONOMOUS_CONF: usize = 3; + pub const PREFIX_INFO_ROUTER_ADDRESS_FLAG: usize = 3; + pub const PREFIX_INFO_VALID_LIFETIME: Field = 4..8; + pub const PREFIX_INFO_PREFERRED_LIFETIME: Field = 8..12; + pub const PREFIX_INFO_RESERVED2: Field = 12..16; + pub const PREFIX_INFO_PREFIX: Field = 16..16 + 16; + + // RPL Target Descriptor fields. + pub const TARGET_DESCRIPTOR: Field = 2..6; + } + + /// Getters for the RPL Control Message Options. + impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with RPL Control Message Option structure. + #[inline] + pub fn new_unchecked(buffer: T) -> Self { + Packet { buffer } + } + + #[inline] + pub fn new_checked(buffer: T) -> Result<Self> { + if buffer.as_ref().is_empty() { + return Err(Error); + } + + Ok(Packet { buffer }) + } + + /// Return the type field. + #[inline] + pub fn option_type(&self) -> OptionType { + OptionType::from(self.buffer.as_ref()[field::TYPE]) + } + + /// Return the length field. + #[inline] + pub fn option_length(&self) -> u8 { + get!(self.buffer, field: field::LENGTH) + } + } + + impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return a pointer to the next option. + #[inline] + pub fn next_option(&self) -> Option<&'p [u8]> { + if !self.buffer.as_ref().is_empty() { + match self.option_type() { + OptionType::Pad1 => Some(&self.buffer.as_ref()[1..]), + OptionType::Unknown(_) => unreachable!(), + _ => { + let len = self.option_length(); + Some(&self.buffer.as_ref()[2 + len as usize..]) + } + } + } else { + None + } + } + } + + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the Option Type field. + #[inline] + pub fn set_option_type(&mut self, option_type: OptionType) { + self.buffer.as_mut()[field::TYPE] = option_type.into(); + } + + /// Set the Option Length field. + #[inline] + pub fn set_option_length(&mut self, length: u8) { + self.buffer.as_mut()[field::LENGTH] = length; + } + } + + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + #[inline] + pub fn clear_padn(&mut self, size: u8) { + for b in &mut self.buffer.as_mut()[field::PADN][..size as usize] { + *b = 0; + } + } + } + + /// Getters for the DAG Metric Container Option Message. + + /// Getters for the Route Information Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x03 | Option Length | Prefix Length |Resvd|Prf|Resvd| + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Route Lifetime | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// . Prefix (Variable Length) . + /// . . + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl<T: AsRef<[u8]>> Packet<T> { + /// Return the Prefix Length field. + #[inline] + pub fn prefix_length(&self) -> u8 { + get!(self.buffer, field: field::ROUTE_INFO_PREFIX_LENGTH) + } + + /// Return the Route Preference field. + #[inline] + pub fn route_preference(&self) -> u8 { + (self.buffer.as_ref()[field::ROUTE_INFO_PREFERENCE] & 0b0001_1000) >> 3 + } + + /// Return the Route Lifetime field. + #[inline] + pub fn route_lifetime(&self) -> u32 { + get!(self.buffer, u32, field: field::ROUTE_INFO_LIFETIME) + } + } + + impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return the Prefix field. + #[inline] + pub fn prefix(&self) -> &'p [u8] { + let option_len = self.option_length(); + &self.buffer.as_ref()[field::ROUTE_INFO_LIFETIME.end..] + [..option_len as usize - field::ROUTE_INFO_LIFETIME.end] + } + } + + /// Setters for the Route Information Option Message. + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the Prefix Length field. + #[inline] + pub fn set_route_info_prefix_length(&mut self, value: u8) { + set!(self.buffer, value, field: field::ROUTE_INFO_PREFIX_LENGTH) + } + + /// Set the Route Preference field. + #[inline] + pub fn set_route_info_route_preference(&mut self, _value: u8) { + todo!(); + } + + /// Set the Route Lifetime field. + #[inline] + pub fn set_route_info_route_lifetime(&mut self, value: u32) { + set!(self.buffer, value, u32, field: field::ROUTE_INFO_LIFETIME) + } + + /// Set the prefix field. + #[inline] + pub fn set_route_info_prefix(&mut self, _prefix: &[u8]) { + todo!(); + } + + /// Clear the reserved field. + #[inline] + pub fn clear_route_info_reserved(&mut self) { + self.buffer.as_mut()[field::ROUTE_INFO_RESERVED] = 0; + } + } + + /// Getters for the DODAG Configuration Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x04 |Opt Length = 14| Flags |A| PCS | DIOIntDoubl. | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | DIOIntMin. | DIORedun. | MaxRankIncrease | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | MinHopRankIncrease | OCP | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Reserved | Def. Lifetime | Lifetime Unit | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl<T: AsRef<[u8]>> Packet<T> { + /// Return the Authentication Enabled field. + #[inline] + pub fn authentication_enabled(&self) -> bool { + get!( + self.buffer, + bool, + field: field::DODAG_CONF_AUTHENTICATION_ENABLED, + shift: 3, + mask: 0b1 + ) + } + + /// Return the Path Control Size field. + #[inline] + pub fn path_control_size(&self) -> u8 { + get!(self.buffer, field: field::DODAG_CONF_PATH_CONTROL_SIZE, mask: 0b111) + } + + /// Return the DIO Interval Doublings field. + #[inline] + pub fn dio_interval_doublings(&self) -> u8 { + get!(self.buffer, field: field::DODAG_CONF_DIO_INTERVAL_DOUBLINGS) + } + + /// Return the DIO Interval Minimum field. + #[inline] + pub fn dio_interval_minimum(&self) -> u8 { + get!(self.buffer, field: field::DODAG_CONF_DIO_INTERVAL_MINIMUM) + } + + /// Return the DIO Redundancy Constant field. + #[inline] + pub fn dio_redundancy_constant(&self) -> u8 { + get!( + self.buffer, + field: field::DODAG_CONF_DIO_REDUNDANCY_CONSTANT + ) + } + + /// Return the Max Rank Increase field. + #[inline] + pub fn max_rank_increase(&self) -> u16 { + get!( + self.buffer, + u16, + field: field::DODAG_CONF_DIO_MAX_RANK_INCREASE + ) + } + + /// Return the Minimum Hop Rank Increase field. + #[inline] + pub fn minimum_hop_rank_increase(&self) -> u16 { + get!( + self.buffer, + u16, + field: field::DODAG_CONF_MIN_HOP_RANK_INCREASE + ) + } + + /// Return the Objective Code Point field. + #[inline] + pub fn objective_code_point(&self) -> u16 { + get!( + self.buffer, + u16, + field: field::DODAG_CONF_OBJECTIVE_CODE_POINT + ) + } + + /// Return the Default Lifetime field. + #[inline] + pub fn default_lifetime(&self) -> u8 { + get!(self.buffer, field: field::DODAG_CONF_DEFAULT_LIFETIME) + } + + /// Return the Lifetime Unit field. + #[inline] + pub fn lifetime_unit(&self) -> u16 { + get!(self.buffer, u16, field: field::DODAG_CONF_LIFETIME_UNIT) + } + } + + /// Getters for the DODAG Configuration Option Message. + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Clear the Flags field. + #[inline] + pub fn clear_dodag_conf_flags(&mut self) { + self.buffer.as_mut()[field::DODAG_CONF_FLAGS] = 0; + } + + /// Set the Authentication Enabled field. + #[inline] + pub fn set_dodag_conf_authentication_enabled(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::DODAG_CONF_AUTHENTICATION_ENABLED, + shift: 3, + mask: 0b1 + ) + } + + /// Set the Path Control Size field. + #[inline] + pub fn set_dodag_conf_path_control_size(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_PATH_CONTROL_SIZE, + mask: 0b111 + ) + } + + /// Set the DIO Interval Doublings field. + #[inline] + pub fn set_dodag_conf_dio_interval_doublings(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_DIO_INTERVAL_DOUBLINGS + ) + } + + /// Set the DIO Interval Minimum field. + #[inline] + pub fn set_dodag_conf_dio_interval_minimum(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_DIO_INTERVAL_MINIMUM + ) + } + + /// Set the DIO Redundancy Constant field. + #[inline] + pub fn set_dodag_conf_dio_redundancy_constant(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_DIO_REDUNDANCY_CONSTANT + ) + } + + /// Set the Max Rank Increase field. + #[inline] + pub fn set_dodag_conf_max_rank_increase(&mut self, value: u16) { + set!( + self.buffer, + value, + u16, + field: field::DODAG_CONF_DIO_MAX_RANK_INCREASE + ) + } + + /// Set the Minimum Hop Rank Increase field. + #[inline] + pub fn set_dodag_conf_minimum_hop_rank_increase(&mut self, value: u16) { + set!( + self.buffer, + value, + u16, + field: field::DODAG_CONF_MIN_HOP_RANK_INCREASE + ) + } + + /// Set the Objective Code Point field. + #[inline] + pub fn set_dodag_conf_objective_code_point(&mut self, value: u16) { + set!( + self.buffer, + value, + u16, + field: field::DODAG_CONF_OBJECTIVE_CODE_POINT + ) + } + + /// Set the Default Lifetime field. + #[inline] + pub fn set_dodag_conf_default_lifetime(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_DEFAULT_LIFETIME + ) + } + + /// Set the Lifetime Unit field. + #[inline] + pub fn set_dodag_conf_lifetime_unit(&mut self, value: u16) { + set!( + self.buffer, + value, + u16, + field: field::DODAG_CONF_LIFETIME_UNIT + ) + } + } + + /// Getters for the RPL Target Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x05 | Option Length | Flags | Prefix Length | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// + + + /// | Target Prefix (Variable Length) | + /// . . + /// . . + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl<T: AsRef<[u8]>> Packet<T> { + /// Return the Target Prefix Length field. + pub fn target_prefix_length(&self) -> u8 { + get!(self.buffer, field: field::RPL_TARGET_PREFIX_LENGTH) + } + } + + impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return the Target Prefix field. + #[inline] + pub fn target_prefix(&self) -> &'p [u8] { + let option_len = self.option_length(); + &self.buffer.as_ref()[field::RPL_TARGET_PREFIX_LENGTH + 1..] + [..option_len as usize - field::RPL_TARGET_PREFIX_LENGTH + 1] + } + } + + /// Setters for the RPL Target Option Message. + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Clear the Flags field. + #[inline] + pub fn clear_rpl_target_flags(&mut self) { + self.buffer.as_mut()[field::RPL_TARGET_FLAGS] = 0; + } + + /// Set the Target Prefix Length field. + #[inline] + pub fn set_rpl_target_prefix_length(&mut self, value: u8) { + set!(self.buffer, value, field: field::RPL_TARGET_PREFIX_LENGTH) + } + + /// Set the Target Prefix field. + #[inline] + pub fn set_rpl_target_prefix(&mut self, prefix: &[u8]) { + self.buffer.as_mut()[field::RPL_TARGET_PREFIX_LENGTH + 1..][..prefix.len()] + .copy_from_slice(prefix); + } + } + + /// Getters for the Transit Information Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x06 | Option Length |E| Flags | Path Control | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Path Sequence | Path Lifetime | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + /// | | + /// + + + /// | | + /// + Parent Address* + + /// | | + /// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl<T: AsRef<[u8]>> Packet<T> { + /// Return the External flag. + #[inline] + pub fn is_external(&self) -> bool { + get!( + self.buffer, + bool, + field: field::TRANSIT_INFO_EXTERNAL, + shift: 7, + mask: 0b1, + ) + } + + /// Return the Path Control field. + #[inline] + pub fn path_control(&self) -> u8 { + get!(self.buffer, field: field::TRANSIT_INFO_PATH_CONTROL) + } + + /// Return the Path Sequence field. + #[inline] + pub fn path_sequence(&self) -> u8 { + get!(self.buffer, field: field::TRANSIT_INFO_PATH_SEQUENCE) + } + + /// Return the Path Lifetime field. + #[inline] + pub fn path_lifetime(&self) -> u8 { + get!(self.buffer, field: field::TRANSIT_INFO_PATH_LIFETIME) + } + + /// Return the Parent Address field. + #[inline] + pub fn parent_address(&self) -> Option<Address> { + if self.option_length() > 5 { + Some(Address::from_bytes( + &self.buffer.as_ref()[field::TRANSIT_INFO_PARENT_ADDRESS], + )) + } else { + None + } + } + } + + /// Setters for the Transit Information Option Message. + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Clear the Flags field. + #[inline] + pub fn clear_transit_info_flags(&mut self) { + self.buffer.as_mut()[field::TRANSIT_INFO_FLAGS] = 0; + } + + /// Set the External flag. + #[inline] + pub fn set_transit_info_is_external(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::TRANSIT_INFO_EXTERNAL, + shift: 7, + mask: 0b1 + ) + } + + /// Set the Path Control field. + #[inline] + pub fn set_transit_info_path_control(&mut self, value: u8) { + set!(self.buffer, value, field: field::TRANSIT_INFO_PATH_CONTROL) + } + + /// Set the Path Sequence field. + #[inline] + pub fn set_transit_info_path_sequence(&mut self, value: u8) { + set!(self.buffer, value, field: field::TRANSIT_INFO_PATH_SEQUENCE) + } + + /// Set the Path Lifetime field. + #[inline] + pub fn set_transit_info_path_lifetime(&mut self, value: u8) { + set!(self.buffer, value, field: field::TRANSIT_INFO_PATH_LIFETIME) + } + + /// Set the Parent Address field. + #[inline] + pub fn set_transit_info_parent_address(&mut self, address: Address) { + self.buffer.as_mut()[field::TRANSIT_INFO_PARENT_ADDRESS] + .copy_from_slice(address.as_bytes()); + } + } + + /// Getters for the Solicited Information Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x07 |Opt Length = 19| RPLInstanceID |V|I|D| Flags | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// + + + /// | | + /// + DODAGID + + /// | | + /// + + + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |Version Number | + /// +-+-+-+-+-+-+-+-+ + /// ``` + impl<T: AsRef<[u8]>> Packet<T> { + /// Return the RPL Instance ID field. + #[inline] + pub fn rpl_instance_id(&self) -> u8 { + get!(self.buffer, field: field::SOLICITED_INFO_RPL_INSTANCE_ID) + } + + /// Return the Version Predicate flag. + #[inline] + pub fn version_predicate(&self) -> bool { + get!( + self.buffer, + bool, + field: field::SOLICITED_INFO_VERSION_PREDICATE, + shift: 7, + mask: 0b1, + ) + } + + /// Return the Instance ID Predicate flag. + #[inline] + pub fn instance_id_predicate(&self) -> bool { + get!( + self.buffer, + bool, + field: field::SOLICITED_INFO_INSTANCE_ID_PREDICATE, + shift: 6, + mask: 0b1, + ) + } + + /// Return the DODAG Predicate ID flag. + #[inline] + pub fn dodag_id_predicate(&self) -> bool { + get!( + self.buffer, + bool, + field: field::SOLICITED_INFO_DODAG_ID_PREDICATE, + shift: 5, + mask: 0b1, + ) + } + + /// Return the DODAG ID field. + #[inline] + pub fn dodag_id(&self) -> Address { + get!( + self.buffer, + into: Address, + fun: from_bytes, + field: field::SOLICITED_INFO_DODAG_ID + ) + } + + /// Return the Version Number field. + #[inline] + pub fn version_number(&self) -> u8 { + get!(self.buffer, field: field::SOLICITED_INFO_VERSION_NUMBER) + } + } + + /// Setters for the Solicited Information Option Message. + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Clear the Flags field. + #[inline] + pub fn clear_solicited_info_flags(&mut self) { + self.buffer.as_mut()[field::SOLICITED_INFO_FLAGS] = 0; + } + + /// Set the RPL Instance ID field. + #[inline] + pub fn set_solicited_info_rpl_instance_id(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::SOLICITED_INFO_RPL_INSTANCE_ID + ) + } + + /// Set the Version Predicate flag. + #[inline] + pub fn set_solicited_info_version_predicate(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::SOLICITED_INFO_VERSION_PREDICATE, + shift: 7, + mask: 0b1 + ) + } + + /// Set the Instance ID Predicate flag. + #[inline] + pub fn set_solicited_info_instance_id_predicate(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::SOLICITED_INFO_INSTANCE_ID_PREDICATE, + shift: 6, + mask: 0b1 + ) + } + + /// Set the DODAG Predicate ID flag. + #[inline] + pub fn set_solicited_info_dodag_id_predicate(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::SOLICITED_INFO_DODAG_ID_PREDICATE, + shift: 5, + mask: 0b1 + ) + } + + /// Set the DODAG ID field. + #[inline] + pub fn set_solicited_info_dodag_id(&mut self, address: Address) { + set!( + self.buffer, + address: address, + field: field::SOLICITED_INFO_DODAG_ID + ) + } + + /// Set the Version Number field. + #[inline] + pub fn set_solicited_info_version_number(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::SOLICITED_INFO_VERSION_NUMBER + ) + } + } + + /// Getters for the Prefix Information Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x08 |Opt Length = 30| Prefix Length |L|A|R|Reserved1| + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Valid Lifetime | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Preferred Lifetime | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Reserved2 | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// + + + /// | | + /// + Prefix + + /// | | + /// + + + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl<T: AsRef<[u8]>> Packet<T> { + /// Return the Prefix Length field. + #[inline] + pub fn prefix_info_prefix_length(&self) -> u8 { + get!(self.buffer, field: field::PREFIX_INFO_PREFIX_LENGTH) + } + + /// Return the On-Link flag. + #[inline] + pub fn on_link(&self) -> bool { + get!( + self.buffer, + bool, + field: field::PREFIX_INFO_ON_LINK, + shift: 7, + mask: 0b1, + ) + } + + /// Return the Autonomous Address-Configuration flag. + #[inline] + pub fn autonomous_address_configuration(&self) -> bool { + get!( + self.buffer, + bool, + field: field::PREFIX_INFO_AUTONOMOUS_CONF, + shift: 6, + mask: 0b1, + ) + } + + /// Return the Router Address flag. + #[inline] + pub fn router_address(&self) -> bool { + get!( + self.buffer, + bool, + field: field::PREFIX_INFO_ROUTER_ADDRESS_FLAG, + shift: 5, + mask: 0b1, + ) + } + + /// Return the Valid Lifetime field. + #[inline] + pub fn valid_lifetime(&self) -> u32 { + get!(self.buffer, u32, field: field::PREFIX_INFO_VALID_LIFETIME) + } + + /// Return the Preferred Lifetime field. + #[inline] + pub fn preferred_lifetime(&self) -> u32 { + get!( + self.buffer, + u32, + field: field::PREFIX_INFO_PREFERRED_LIFETIME + ) + } + } + + impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return the Prefix field. + #[inline] + pub fn destination_prefix(&self) -> &'p [u8] { + &self.buffer.as_ref()[field::PREFIX_INFO_PREFIX] + } + } + + /// Setters for the Prefix Information Option Message. + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Clear the reserved fields. + #[inline] + pub fn clear_prefix_info_reserved(&mut self) { + self.buffer.as_mut()[field::PREFIX_INFO_RESERVED1] = 0; + self.buffer.as_mut()[field::PREFIX_INFO_RESERVED2].copy_from_slice(&[0; 4]); + } + + /// Set the Prefix Length field. + #[inline] + pub fn set_prefix_info_prefix_length(&mut self, value: u8) { + set!(self.buffer, value, field: field::PREFIX_INFO_PREFIX_LENGTH) + } + + /// Set the On-Link flag. + #[inline] + pub fn set_prefix_info_on_link(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::PREFIX_INFO_ON_LINK, shift: 7, mask: 0b1) + } + + /// Set the Autonomous Address-Configuration flag. + #[inline] + pub fn set_prefix_info_autonomous_address_configuration(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::PREFIX_INFO_AUTONOMOUS_CONF, + shift: 6, + mask: 0b1 + ) + } + + /// Set the Router Address flag. + #[inline] + pub fn set_prefix_info_router_address(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::PREFIX_INFO_ROUTER_ADDRESS_FLAG, + shift: 5, + mask: 0b1 + ) + } + + /// Set the Valid Lifetime field. + #[inline] + pub fn set_prefix_info_valid_lifetime(&mut self, value: u32) { + set!( + self.buffer, + value, + u32, + field: field::PREFIX_INFO_VALID_LIFETIME + ) + } + + /// Set the Preferred Lifetime field. + #[inline] + pub fn set_prefix_info_preferred_lifetime(&mut self, value: u32) { + set!( + self.buffer, + value, + u32, + field: field::PREFIX_INFO_PREFERRED_LIFETIME + ) + } + + /// Set the Prefix field. + #[inline] + pub fn set_prefix_info_destination_prefix(&mut self, prefix: &[u8]) { + self.buffer.as_mut()[field::PREFIX_INFO_PREFIX].copy_from_slice(prefix); + } + } + + /// Getters for the RPL Target Descriptor Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x09 |Opt Length = 4 | Descriptor + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// Descriptor (cont.) | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl<T: AsRef<[u8]>> Packet<T> { + /// Return the Descriptor field. + #[inline] + pub fn descriptor(&self) -> u32 { + get!(self.buffer, u32, field: field::TARGET_DESCRIPTOR) + } + } + + /// Setters for the RPL Target Descriptor Option Message. + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the Descriptor field. + #[inline] + pub fn set_rpl_target_descriptor_descriptor(&mut self, value: u32) { + set!(self.buffer, value, u32, field: field::TARGET_DESCRIPTOR) + } + } + + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub enum Repr<'p> { + Pad1, + PadN(u8), + DagMetricContainer, + RouteInformation { + prefix_length: u8, + preference: u8, + lifetime: u32, + prefix: &'p [u8], + }, + DodagConfiguration { + authentication_enabled: bool, + path_control_size: u8, + dio_interval_doublings: u8, + dio_interval_min: u8, + dio_redundancy_constant: u8, + max_rank_increase: u16, + minimum_hop_rank_increase: u16, + objective_code_point: u16, + default_lifetime: u8, + lifetime_unit: u16, + }, + RplTarget { + prefix_length: u8, + prefix: crate::wire::Ipv6Address, // FIXME: this is not the correct type, because the + // field can be an IPv6 address, a prefix or a + // multicast group. + }, + TransitInformation { + external: bool, + path_control: u8, + path_sequence: u8, + path_lifetime: u8, + parent_address: Option<Address>, + }, + SolicitedInformation { + rpl_instance_id: InstanceId, + version_predicate: bool, + instance_id_predicate: bool, + dodag_id_predicate: bool, + dodag_id: Address, + version_number: u8, + }, + PrefixInformation { + prefix_length: u8, + on_link: bool, + autonomous_address_configuration: bool, + router_address: bool, + valid_lifetime: u32, + preferred_lifetime: u32, + destination_prefix: &'p [u8], + }, + RplTargetDescriptor { + descriptor: u32, + }, + } + + impl core::fmt::Display for Repr<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Repr::Pad1 => write!(f, "Pad1"), + Repr::PadN(n) => write!(f, "PadN({n})"), + Repr::DagMetricContainer => todo!(), + Repr::RouteInformation { + prefix_length, + preference, + lifetime, + prefix, + } => { + write!( + f, + "ROUTE INFO \ + PrefixLength={prefix_length} \ + Preference={preference} \ + Lifetime={lifetime} \ + Prefix={prefix:0x?}" + ) + } + Repr::DodagConfiguration { + dio_interval_doublings, + dio_interval_min, + dio_redundancy_constant, + max_rank_increase, + minimum_hop_rank_increase, + objective_code_point, + default_lifetime, + lifetime_unit, + .. + } => { + write!( + f, + "DODAG CONF \ + IntD={dio_interval_doublings} \ + IntMin={dio_interval_min} \ + RedCst={dio_redundancy_constant} \ + MaxRankIncr={max_rank_increase} \ + MinHopRankIncr={minimum_hop_rank_increase} \ + OCP={objective_code_point} \ + DefaultLifetime={default_lifetime} \ + LifeUnit={lifetime_unit}" + ) + } + Repr::RplTarget { + prefix_length, + prefix, + } => { + write!( + f, + "RPL Target \ + PrefixLength={prefix_length} \ + Prefix={prefix:0x?}" + ) + } + Repr::TransitInformation { + external, + path_control, + path_sequence, + path_lifetime, + parent_address, + } => { + write!( + f, + "Transit Info \ + External={external} \ + PathCtrl={path_control} \ + PathSqnc={path_sequence} \ + PathLifetime={path_lifetime} \ + Parent={parent_address:0x?}" + ) + } + Repr::SolicitedInformation { + rpl_instance_id, + version_predicate, + instance_id_predicate, + dodag_id_predicate, + dodag_id, + version_number, + } => { + write!( + f, + "Solicited Info \ + I={instance_id_predicate} \ + IID={rpl_instance_id:0x?} \ + D={dodag_id_predicate} \ + DODAGID={dodag_id} \ + V={version_predicate} \ + Version={version_number}" + ) + } + Repr::PrefixInformation { + prefix_length, + on_link, + autonomous_address_configuration, + router_address, + valid_lifetime, + preferred_lifetime, + destination_prefix, + } => { + write!( + f, + "Prefix Info \ + PrefixLength={prefix_length} \ + L={on_link} A={autonomous_address_configuration} R={router_address} \ + Valid={valid_lifetime} \ + Preferred={preferred_lifetime} \ + Prefix={destination_prefix:0x?}" + ) + } + Repr::RplTargetDescriptor { .. } => write!(f, "Target Descriptor"), + } + } + } + + impl<'p> Repr<'p> { + pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&'p T>) -> Result<Self> { + match packet.option_type() { + OptionType::Pad1 => Ok(Repr::Pad1), + OptionType::PadN => Ok(Repr::PadN(packet.option_length())), + OptionType::DagMetricContainer => todo!(), + OptionType::RouteInformation => Ok(Repr::RouteInformation { + prefix_length: packet.prefix_length(), + preference: packet.route_preference(), + lifetime: packet.route_lifetime(), + prefix: packet.prefix(), + }), + OptionType::DodagConfiguration => Ok(Repr::DodagConfiguration { + authentication_enabled: packet.authentication_enabled(), + path_control_size: packet.path_control_size(), + dio_interval_doublings: packet.dio_interval_doublings(), + dio_interval_min: packet.dio_interval_minimum(), + dio_redundancy_constant: packet.dio_redundancy_constant(), + max_rank_increase: packet.max_rank_increase(), + minimum_hop_rank_increase: packet.minimum_hop_rank_increase(), + objective_code_point: packet.objective_code_point(), + default_lifetime: packet.default_lifetime(), + lifetime_unit: packet.lifetime_unit(), + }), + OptionType::RplTarget => Ok(Repr::RplTarget { + prefix_length: packet.target_prefix_length(), + prefix: crate::wire::Ipv6Address::from_bytes(packet.target_prefix()), + }), + OptionType::TransitInformation => Ok(Repr::TransitInformation { + external: packet.is_external(), + path_control: packet.path_control(), + path_sequence: packet.path_sequence(), + path_lifetime: packet.path_lifetime(), + parent_address: packet.parent_address(), + }), + OptionType::SolicitedInformation => Ok(Repr::SolicitedInformation { + rpl_instance_id: InstanceId::from(packet.rpl_instance_id()), + version_predicate: packet.version_predicate(), + instance_id_predicate: packet.instance_id_predicate(), + dodag_id_predicate: packet.dodag_id_predicate(), + dodag_id: packet.dodag_id(), + version_number: packet.version_number(), + }), + OptionType::PrefixInformation => Ok(Repr::PrefixInformation { + prefix_length: packet.prefix_info_prefix_length(), + on_link: packet.on_link(), + autonomous_address_configuration: packet.autonomous_address_configuration(), + router_address: packet.router_address(), + valid_lifetime: packet.valid_lifetime(), + preferred_lifetime: packet.preferred_lifetime(), + destination_prefix: packet.destination_prefix(), + }), + OptionType::RplTargetDescriptor => Ok(Repr::RplTargetDescriptor { + descriptor: packet.descriptor(), + }), + OptionType::Unknown(_) => Err(Error), + } + } + + pub fn buffer_len(&self) -> usize { + match self { + Repr::Pad1 => 1, + Repr::PadN(size) => 2 + *size as usize, + Repr::DagMetricContainer => todo!(), + Repr::RouteInformation { prefix, .. } => 2 + 6 + prefix.len(), + Repr::DodagConfiguration { .. } => 2 + 14, + Repr::RplTarget { prefix, .. } => 2 + 2 + prefix.0.len(), + Repr::TransitInformation { parent_address, .. } => { + 2 + 4 + if parent_address.is_some() { 16 } else { 0 } + } + Repr::SolicitedInformation { .. } => 2 + 2 + 16 + 1, + Repr::PrefixInformation { .. } => 32, + Repr::RplTargetDescriptor { .. } => 2 + 4, + } + } + + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, packet: &mut Packet<&'p mut T>) { + let mut option_length = self.buffer_len() as u8; + + packet.set_option_type(self.into()); + + if !matches!(self, Repr::Pad1) { + option_length -= 2; + packet.set_option_length(option_length); + } + + match self { + Repr::Pad1 => {} + Repr::PadN(size) => { + packet.clear_padn(*size); + } + Repr::DagMetricContainer => { + unimplemented!(); + } + Repr::RouteInformation { + prefix_length, + preference, + lifetime, + prefix, + } => { + packet.clear_route_info_reserved(); + packet.set_route_info_prefix_length(*prefix_length); + packet.set_route_info_route_preference(*preference); + packet.set_route_info_route_lifetime(*lifetime); + packet.set_route_info_prefix(prefix); + } + Repr::DodagConfiguration { + authentication_enabled, + path_control_size, + dio_interval_doublings, + dio_interval_min, + dio_redundancy_constant, + max_rank_increase, + minimum_hop_rank_increase, + objective_code_point, + default_lifetime, + lifetime_unit, + } => { + packet.clear_dodag_conf_flags(); + packet.set_dodag_conf_authentication_enabled(*authentication_enabled); + packet.set_dodag_conf_path_control_size(*path_control_size); + packet.set_dodag_conf_dio_interval_doublings(*dio_interval_doublings); + packet.set_dodag_conf_dio_interval_minimum(*dio_interval_min); + packet.set_dodag_conf_dio_redundancy_constant(*dio_redundancy_constant); + packet.set_dodag_conf_max_rank_increase(*max_rank_increase); + packet.set_dodag_conf_minimum_hop_rank_increase(*minimum_hop_rank_increase); + packet.set_dodag_conf_objective_code_point(*objective_code_point); + packet.set_dodag_conf_default_lifetime(*default_lifetime); + packet.set_dodag_conf_lifetime_unit(*lifetime_unit); + } + Repr::RplTarget { + prefix_length, + prefix, + } => { + packet.clear_rpl_target_flags(); + packet.set_rpl_target_prefix_length(*prefix_length); + packet.set_rpl_target_prefix(prefix.as_bytes()); + } + Repr::TransitInformation { + external, + path_control, + path_sequence, + path_lifetime, + parent_address, + } => { + packet.clear_transit_info_flags(); + packet.set_transit_info_is_external(*external); + packet.set_transit_info_path_control(*path_control); + packet.set_transit_info_path_sequence(*path_sequence); + packet.set_transit_info_path_lifetime(*path_lifetime); + + if let Some(address) = parent_address { + packet.set_transit_info_parent_address(*address); + } + } + Repr::SolicitedInformation { + rpl_instance_id, + version_predicate, + instance_id_predicate, + dodag_id_predicate, + dodag_id, + version_number, + } => { + packet.clear_solicited_info_flags(); + packet.set_solicited_info_rpl_instance_id((*rpl_instance_id).into()); + packet.set_solicited_info_version_predicate(*version_predicate); + packet.set_solicited_info_instance_id_predicate(*instance_id_predicate); + packet.set_solicited_info_dodag_id_predicate(*dodag_id_predicate); + packet.set_solicited_info_version_number(*version_number); + packet.set_solicited_info_dodag_id(*dodag_id); + } + Repr::PrefixInformation { + prefix_length, + on_link, + autonomous_address_configuration, + router_address, + valid_lifetime, + preferred_lifetime, + destination_prefix, + } => { + packet.clear_prefix_info_reserved(); + packet.set_prefix_info_prefix_length(*prefix_length); + packet.set_prefix_info_on_link(*on_link); + packet.set_prefix_info_autonomous_address_configuration( + *autonomous_address_configuration, + ); + packet.set_prefix_info_router_address(*router_address); + packet.set_prefix_info_valid_lifetime(*valid_lifetime); + packet.set_prefix_info_preferred_lifetime(*preferred_lifetime); + packet.set_prefix_info_destination_prefix(destination_prefix); + } + Repr::RplTargetDescriptor { descriptor } => { + packet.set_rpl_target_descriptor_descriptor(*descriptor); + } + } + } + } +} + +pub mod data { + use super::{InstanceId, Result}; + use byteorder::{ByteOrder, NetworkEndian}; + + mod field { + use crate::wire::field::*; + + pub const FLAGS: usize = 0; + pub const INSTANCE_ID: usize = 1; + pub const SENDER_RANK: Field = 2..4; + } + + /// A read/write wrapper around a RPL Packet Information send with + /// an IPv6 Hop-by-Hop option, defined in RFC6553. + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Option Type | Opt Data Len | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |O|R|F|0|0|0|0|0| RPLInstanceID | SenderRank | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | (sub-TLVs) | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub struct Packet<T: AsRef<[u8]>> { + buffer: T, + } + + impl<T: AsRef<[u8]>> Packet<T> { + #[inline] + pub fn new_unchecked(buffer: T) -> Self { + Self { buffer } + } + + #[inline] + pub fn new_checked(buffer: T) -> Result<Self> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + #[inline] + pub fn check_len(&self) -> Result<()> { + if self.buffer.as_ref().len() == 4 { + Ok(()) + } else { + Err(crate::wire::Error) + } + } + + #[inline] + pub fn is_down(&self) -> bool { + get!(self.buffer, bool, field: field::FLAGS, shift: 7, mask: 0b1) + } + + #[inline] + pub fn has_rank_error(&self) -> bool { + get!(self.buffer, bool, field: field::FLAGS, shift: 6, mask: 0b1) + } + + #[inline] + pub fn has_forwarding_error(&self) -> bool { + get!(self.buffer, bool, field: field::FLAGS, shift: 5, mask: 0b1) + } + + #[inline] + pub fn rpl_instance_id(&self) -> InstanceId { + get!(self.buffer, into: InstanceId, field: field::INSTANCE_ID) + } + + #[inline] + pub fn sender_rank(&self) -> u16 { + get!(self.buffer, u16, field: field::SENDER_RANK) + } + } + + impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + #[inline] + pub fn set_is_down(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::FLAGS, shift: 7, mask: 0b1) + } + + #[inline] + pub fn set_has_rank_error(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::FLAGS, shift: 6, mask: 0b1) + } + + #[inline] + pub fn set_has_forwarding_error(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::FLAGS, shift: 5, mask: 0b1) + } + + #[inline] + pub fn set_rpl_instance_id(&mut self, value: u8) { + set!(self.buffer, value, field: field::INSTANCE_ID) + } + + #[inline] + pub fn set_sender_rank(&mut self, value: u16) { + set!(self.buffer, value, u16, field: field::SENDER_RANK) + } + } + + /// A high-level representation of an IPv6 Extension Header Option. + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct HopByHopOption { + pub down: bool, + pub rank_error: bool, + pub forwarding_error: bool, + pub instance_id: InstanceId, + pub sender_rank: u16, + } + + impl HopByHopOption { + /// Parse an IPv6 Extension Header Option and return a high-level representation. + pub fn parse<T>(opt: &Packet<&T>) -> Self + where + T: AsRef<[u8]> + ?Sized, + { + Self { + down: opt.is_down(), + rank_error: opt.has_rank_error(), + forwarding_error: opt.has_forwarding_error(), + instance_id: opt.rpl_instance_id(), + sender_rank: opt.sender_rank(), + } + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + 4 + } + + /// Emit a high-level representation into an IPv6 Extension Header Option. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]> + ?Sized>(&self, opt: &mut Packet<&mut T>) { + opt.set_is_down(self.down); + opt.set_has_rank_error(self.rank_error); + opt.set_has_forwarding_error(self.forwarding_error); + opt.set_rpl_instance_id(self.instance_id.into()); + opt.set_sender_rank(self.sender_rank); + } + } + + impl core::fmt::Display for HopByHopOption { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "down={} rank_error={} forw_error={} IID={:?} sender_rank={}", + self.down, + self.rank_error, + self.forwarding_error, + self.instance_id, + self.sender_rank + ) + } + } +} + +#[cfg(test)] +mod tests { + use super::options::{Packet as OptionPacket, Repr as OptionRepr}; + use super::Repr as RplRepr; + use super::*; + use crate::phy::ChecksumCapabilities; + use crate::wire::{icmpv6::*, *}; + + #[test] + fn dis_packet() { + let data = [0x7a, 0x3b, 0x3a, 0x1a, 0x9b, 0x00, 0x00, 0x00, 0x00, 0x00]; + + let ll_src_address = + Ieee802154Address::Extended([0x9e, 0xd3, 0xa2, 0x9c, 0x57, 0x1a, 0x4f, 0xe4]); + let ll_dst_address = Ieee802154Address::Short([0xff, 0xff]); + + let packet = SixlowpanIphcPacket::new_checked(&data).unwrap(); + let repr = + SixlowpanIphcRepr::parse(&packet, Some(ll_src_address), Some(ll_dst_address), &[]) + .unwrap(); + + let icmp_repr = match repr.next_header { + SixlowpanNextHeader::Uncompressed(IpProtocol::Icmpv6) => { + let icmp_packet = Icmpv6Packet::new_checked(packet.payload()).unwrap(); + match Icmpv6Repr::parse( + &IpAddress::Ipv6(repr.src_addr), + &IpAddress::Ipv6(repr.dst_addr), + &icmp_packet, + &ChecksumCapabilities::ignored(), + ) { + Ok(icmp @ Icmpv6Repr::Rpl(RplRepr::DodagInformationSolicitation { .. })) => { + icmp + } + _ => unreachable!(), + } + } + _ => unreachable!(), + }; + + // We also try to emit the packet: + let mut buffer = vec![0u8; repr.buffer_len() + icmp_repr.buffer_len()]; + repr.emit(&mut SixlowpanIphcPacket::new_unchecked( + &mut buffer[..repr.buffer_len()], + )); + icmp_repr.emit( + &repr.src_addr.into(), + &repr.dst_addr.into(), + &mut Icmpv6Packet::new_unchecked( + &mut buffer[repr.buffer_len()..][..icmp_repr.buffer_len()], + ), + &ChecksumCapabilities::ignored(), + ); + + assert_eq!(&data[..], &buffer[..]); + } + + /// Parsing of DIO packets. + #[test] + fn dio_packet() { + let data = [ + 0x9b, 0x01, 0x00, 0x00, 0x00, 0xf0, 0x00, 0x80, 0x08, 0xf0, 0x00, 0x00, 0xfd, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, + 0x04, 0x0e, 0x00, 0x08, 0x0c, 0x00, 0x04, 0x00, 0x00, 0x80, 0x00, 0x01, 0x00, 0x1e, + 0x00, 0x3c, 0x08, 0x1e, 0x40, 0x40, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + + let addr = Address::from_bytes(&[ + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, + 0x00, 0x01, + ]); + + let dest_prefix = [ + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + ]; + + let packet = Packet::new_checked(&data[..]).unwrap(); + assert_eq!(packet.msg_type(), Message::RplControl); + assert_eq!( + RplControlMessage::from(packet.msg_code()), + RplControlMessage::DodagInformationObject + ); + + let mut dio_repr = RplRepr::parse(&packet).unwrap(); + match dio_repr { + RplRepr::DodagInformationObject { + rpl_instance_id, + version_number, + rank, + grounded, + mode_of_operation, + dodag_preference, + dtsn, + dodag_id, + .. + } => { + assert_eq!(rpl_instance_id, InstanceId::from(0)); + assert_eq!(version_number, 240); + assert_eq!(rank, 128); + assert!(!grounded); + assert_eq!(mode_of_operation, ModeOfOperation::NonStoringMode); + assert_eq!(dodag_preference, 0); + assert_eq!(dtsn, 240); + assert_eq!(dodag_id, addr); + } + _ => unreachable!(), + } + + let option = OptionPacket::new_unchecked(packet.options().unwrap()); + let dodag_conf_option = OptionRepr::parse(&option).unwrap(); + match dodag_conf_option { + OptionRepr::DodagConfiguration { + authentication_enabled, + path_control_size, + dio_interval_doublings, + dio_interval_min, + dio_redundancy_constant, + max_rank_increase, + minimum_hop_rank_increase, + objective_code_point, + default_lifetime, + lifetime_unit, + } => { + assert!(!authentication_enabled); + assert_eq!(path_control_size, 0); + assert_eq!(dio_interval_doublings, 8); + assert_eq!(dio_interval_min, 12); + assert_eq!(dio_redundancy_constant, 0); + assert_eq!(max_rank_increase, 1024); + assert_eq!(minimum_hop_rank_increase, 128); + assert_eq!(objective_code_point, 1); + assert_eq!(default_lifetime, 30); + assert_eq!(lifetime_unit, 60); + } + _ => unreachable!(), + } + + let option = OptionPacket::new_unchecked(option.next_option().unwrap()); + let prefix_info_option = OptionRepr::parse(&option).unwrap(); + match prefix_info_option { + OptionRepr::PrefixInformation { + prefix_length, + on_link, + autonomous_address_configuration, + valid_lifetime, + preferred_lifetime, + destination_prefix, + .. + } => { + assert_eq!(prefix_length, 64); + assert!(!on_link); + assert!(autonomous_address_configuration); + assert_eq!(valid_lifetime, u32::MAX); + assert_eq!(preferred_lifetime, u32::MAX); + assert_eq!(destination_prefix, &dest_prefix[..]); + } + _ => unreachable!(), + } + + let mut options_buffer = + vec![0u8; dodag_conf_option.buffer_len() + prefix_info_option.buffer_len()]; + + dodag_conf_option.emit(&mut OptionPacket::new_unchecked( + &mut options_buffer[..dodag_conf_option.buffer_len()], + )); + prefix_info_option.emit(&mut OptionPacket::new_unchecked( + &mut options_buffer[dodag_conf_option.buffer_len()..] + [..prefix_info_option.buffer_len()], + )); + + dio_repr.set_options(&options_buffer[..]); + + let mut buffer = vec![0u8; dio_repr.buffer_len()]; + dio_repr.emit(&mut Packet::new_unchecked(&mut buffer[..])); + + assert_eq!(&data[..], &buffer[..]); + } + + /// Parsing of DAO packets. + #[test] + fn dao_packet() { + let data = [ + 0x9b, 0x02, 0x00, 0x00, 0x00, 0x80, 0x00, 0xf1, 0x05, 0x12, 0x00, 0x80, 0xfd, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x02, + 0x06, 0x14, 0x00, 0x00, 0x00, 0x1e, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, + ]; + + let target_prefix = [ + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x00, 0x02, 0x00, 0x02, + 0x00, 0x02, + ]; + + let parent_addr = Address::from_bytes(&[ + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, + 0x00, 0x01, + ]); + + let packet = Packet::new_checked(&data[..]).unwrap(); + let mut dao_repr = RplRepr::parse(&packet).unwrap(); + match dao_repr { + RplRepr::DestinationAdvertisementObject { + rpl_instance_id, + expect_ack, + sequence, + dodag_id, + .. + } => { + assert_eq!(rpl_instance_id, InstanceId::from(0)); + assert!(expect_ack); + assert_eq!(sequence, 241); + assert_eq!(dodag_id, None); + } + _ => unreachable!(), + } + + let option = OptionPacket::new_unchecked(packet.options().unwrap()); + + let rpl_target_option = OptionRepr::parse(&option).unwrap(); + match rpl_target_option { + OptionRepr::RplTarget { + prefix_length, + prefix, + } => { + assert_eq!(prefix_length, 128); + assert_eq!(prefix.as_bytes(), &target_prefix[..]); + } + _ => unreachable!(), + } + + let option = OptionPacket::new_unchecked(option.next_option().unwrap()); + let transit_info_option = OptionRepr::parse(&option).unwrap(); + match transit_info_option { + OptionRepr::TransitInformation { + external, + path_control, + path_sequence, + path_lifetime, + parent_address, + } => { + assert!(!external); + assert_eq!(path_control, 0); + assert_eq!(path_sequence, 0); + assert_eq!(path_lifetime, 30); + assert_eq!(parent_address, Some(parent_addr)); + } + _ => unreachable!(), + } + + let mut options_buffer = + vec![0u8; rpl_target_option.buffer_len() + transit_info_option.buffer_len()]; + + rpl_target_option.emit(&mut OptionPacket::new_unchecked( + &mut options_buffer[..rpl_target_option.buffer_len()], + )); + transit_info_option.emit(&mut OptionPacket::new_unchecked( + &mut options_buffer[rpl_target_option.buffer_len()..] + [..transit_info_option.buffer_len()], + )); + + dao_repr.set_options(&options_buffer[..]); + + let mut buffer = vec![0u8; dao_repr.buffer_len()]; + dao_repr.emit(&mut Packet::new_unchecked(&mut buffer[..])); + + assert_eq!(&data[..], &buffer[..]); + } + + /// Parsing of DAO-ACK packets. + #[test] + fn dao_ack_packet() { + let data = [0x9b, 0x03, 0x00, 0x00, 0x00, 0x00, 0xf1, 0x00]; + + let packet = Packet::new_checked(&data[..]).unwrap(); + let dao_ack_repr = RplRepr::parse(&packet).unwrap(); + match dao_ack_repr { + RplRepr::DestinationAdvertisementObjectAck { + rpl_instance_id, + sequence, + status, + dodag_id, + .. + } => { + assert_eq!(rpl_instance_id, InstanceId::from(0)); + assert_eq!(sequence, 241); + assert_eq!(status, 0); + assert_eq!(dodag_id, None); + } + _ => unreachable!(), + } + + let mut buffer = vec![0u8; dao_ack_repr.buffer_len()]; + dao_ack_repr.emit(&mut Packet::new_unchecked(&mut buffer[..])); + + assert_eq!(&data[..], &buffer[..]); + + let data = [ + 0x9b, 0x03, 0x0, 0x0, 0x1e, 0x80, 0xf0, 0x00, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + ]; + + let packet = Packet::new_checked(&data[..]).unwrap(); + let dao_ack_repr = RplRepr::parse(&packet).unwrap(); + match dao_ack_repr { + RplRepr::DestinationAdvertisementObjectAck { + rpl_instance_id, + sequence, + status, + dodag_id, + .. + } => { + assert_eq!(rpl_instance_id, InstanceId::from(30)); + assert_eq!(sequence, 240); + assert_eq!(status, 0x0); + assert_eq!( + dodag_id, + Some(Ipv6Address([ + 254, 128, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1 + ])) + ); + } + _ => unreachable!(), + } + + let mut buffer = vec![0u8; dao_ack_repr.buffer_len()]; + dao_ack_repr.emit(&mut Packet::new_unchecked(&mut buffer[..])); + + assert_eq!(&data[..], &buffer[..]); + } +} diff --git a/src/wire/sixlowpan/frag.rs b/src/wire/sixlowpan/frag.rs new file mode 100644 index 0000000..de45702 --- /dev/null +++ b/src/wire/sixlowpan/frag.rs @@ -0,0 +1,275 @@ +//! Implementation of the fragment headers from [RFC 4944 § 5.3]. +//! +//! [RFC 4944 § 5.3]: https://datatracker.ietf.org/doc/html/rfc4944#section-5.3 + +use super::{DISPATCH_FIRST_FRAGMENT_HEADER, DISPATCH_FRAGMENT_HEADER}; +use crate::wire::{Error, Result}; +use crate::wire::{Ieee802154Address, Ieee802154Repr}; +use byteorder::{ByteOrder, NetworkEndian}; + +/// Key used for identifying all the link fragments that belong to the same packet. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Key { + pub(crate) ll_src_addr: Ieee802154Address, + pub(crate) ll_dst_addr: Ieee802154Address, + pub(crate) datagram_size: u16, + pub(crate) datagram_tag: u16, +} + +/// A read/write wrapper around a 6LoWPAN Fragment header. +/// [RFC 4944 § 5.3] specifies the format of the header. +/// +/// A First Fragment header has the following format: +/// ```txt +/// 1 2 3 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// |1 1 0 0 0| datagram_size | datagram_tag | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// ``` +/// +/// Subsequent fragment headers have the following format: +/// ```txt +/// 1 2 3 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// |1 1 1 0 0| datagram_size | datagram_tag | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// |datagram_offset| +/// +-+-+-+-+-+-+-+-+ +/// ``` +/// +/// [RFC 4944 § 5.3]: https://datatracker.ietf.org/doc/html/rfc4944#section-5.3 +#[derive(Debug, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +pub const FIRST_FRAGMENT_HEADER_SIZE: usize = 4; +pub const NEXT_FRAGMENT_HEADER_SIZE: usize = 5; + +mod field { + use crate::wire::field::*; + + pub const DISPATCH: usize = 0; + pub const DATAGRAM_SIZE: Field = 0..2; + pub const DATAGRAM_TAG: Field = 2..4; + pub const DATAGRAM_OFFSET: usize = 4; + + pub const FIRST_FRAGMENT_REST: Rest = super::FIRST_FRAGMENT_HEADER_SIZE..; + pub const NEXT_FRAGMENT_REST: Rest = super::NEXT_FRAGMENT_HEADER_SIZE..; +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Input a raw octet buffer with a 6LoWPAN Fragment header structure. + pub const fn new_unchecked(buffer: T) -> Self { + Self { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Self> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + + let dispatch = packet.dispatch(); + + if dispatch != DISPATCH_FIRST_FRAGMENT_HEADER && dispatch != DISPATCH_FRAGMENT_HEADER { + return Err(Error); + } + + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let buffer = self.buffer.as_ref(); + if buffer.is_empty() { + return Err(Error); + } + + match self.dispatch() { + DISPATCH_FIRST_FRAGMENT_HEADER if buffer.len() >= FIRST_FRAGMENT_HEADER_SIZE => Ok(()), + DISPATCH_FIRST_FRAGMENT_HEADER if buffer.len() < FIRST_FRAGMENT_HEADER_SIZE => { + Err(Error) + } + DISPATCH_FRAGMENT_HEADER if buffer.len() >= NEXT_FRAGMENT_HEADER_SIZE => Ok(()), + DISPATCH_FRAGMENT_HEADER if buffer.len() < NEXT_FRAGMENT_HEADER_SIZE => Err(Error), + _ => Err(Error), + } + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the dispatch field. + pub fn dispatch(&self) -> u8 { + let raw = self.buffer.as_ref(); + raw[field::DISPATCH] >> 3 + } + + /// Return the total datagram size. + pub fn datagram_size(&self) -> u16 { + let raw = self.buffer.as_ref(); + NetworkEndian::read_u16(&raw[field::DATAGRAM_SIZE]) & 0b111_1111_1111 + } + + /// Return the datagram tag. + pub fn datagram_tag(&self) -> u16 { + let raw = self.buffer.as_ref(); + NetworkEndian::read_u16(&raw[field::DATAGRAM_TAG]) + } + + /// Return the datagram offset. + pub fn datagram_offset(&self) -> u8 { + match self.dispatch() { + DISPATCH_FIRST_FRAGMENT_HEADER => 0, + DISPATCH_FRAGMENT_HEADER => { + let raw = self.buffer.as_ref(); + raw[field::DATAGRAM_OFFSET] + } + _ => unreachable!(), + } + } + + /// Returns `true` when this header is from the first fragment of a link. + pub fn is_first_fragment(&self) -> bool { + self.dispatch() == DISPATCH_FIRST_FRAGMENT_HEADER + } + + /// Returns the key for identifying the packet it belongs to. + pub fn get_key(&self, ieee802154_repr: &Ieee802154Repr) -> Key { + Key { + ll_src_addr: ieee802154_repr.src_addr.unwrap(), + ll_dst_addr: ieee802154_repr.dst_addr.unwrap(), + datagram_size: self.datagram_size(), + datagram_tag: self.datagram_tag(), + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return the payload. + pub fn payload(&self) -> &'a [u8] { + match self.dispatch() { + DISPATCH_FIRST_FRAGMENT_HEADER => { + let raw = self.buffer.as_ref(); + &raw[field::FIRST_FRAGMENT_REST] + } + DISPATCH_FRAGMENT_HEADER => { + let raw = self.buffer.as_ref(); + &raw[field::NEXT_FRAGMENT_REST] + } + _ => unreachable!(), + } + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + fn set_dispatch_field(&mut self, value: u8) { + let raw = self.buffer.as_mut(); + raw[field::DISPATCH] = (raw[field::DISPATCH] & !(0b11111 << 3)) | (value << 3); + } + + fn set_datagram_size(&mut self, size: u16) { + let raw = self.buffer.as_mut(); + let mut v = NetworkEndian::read_u16(&raw[field::DATAGRAM_SIZE]); + v = (v & !0b111_1111_1111) | size; + + NetworkEndian::write_u16(&mut raw[field::DATAGRAM_SIZE], v); + } + + fn set_datagram_tag(&mut self, tag: u16) { + let raw = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut raw[field::DATAGRAM_TAG], tag); + } + + fn set_datagram_offset(&mut self, offset: u8) { + let raw = self.buffer.as_mut(); + raw[field::DATAGRAM_OFFSET] = offset; + } +} + +/// A high-level representation of a 6LoWPAN Fragment header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Repr { + FirstFragment { size: u16, tag: u16 }, + Fragment { size: u16, tag: u16, offset: u8 }, +} + +impl core::fmt::Display for Repr { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Repr::FirstFragment { size, tag } => { + write!(f, "FirstFrag size={size} tag={tag}") + } + Repr::Fragment { size, tag, offset } => { + write!(f, "NthFrag size={size} tag={tag} offset={offset}") + } + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Repr { + fn format(&self, fmt: defmt::Formatter) { + match self { + Repr::FirstFragment { size, tag } => { + defmt::write!(fmt, "FirstFrag size={} tag={}", size, tag); + } + Repr::Fragment { size, tag, offset } => { + defmt::write!(fmt, "NthFrag size={} tag={} offset={}", size, tag, offset); + } + } + } +} + +impl Repr { + /// Parse a 6LoWPAN Fragment header. + pub fn parse<T: AsRef<[u8]>>(packet: &Packet<T>) -> Result<Self> { + let size = packet.datagram_size(); + let tag = packet.datagram_tag(); + + match packet.dispatch() { + DISPATCH_FIRST_FRAGMENT_HEADER => Ok(Self::FirstFragment { size, tag }), + DISPATCH_FRAGMENT_HEADER => Ok(Self::Fragment { + size, + tag, + offset: packet.datagram_offset(), + }), + _ => Err(Error), + } + } + + /// Returns the length of the Fragment header. + pub const fn buffer_len(&self) -> usize { + match self { + Self::FirstFragment { .. } => field::FIRST_FRAGMENT_REST.start, + Self::Fragment { .. } => field::NEXT_FRAGMENT_REST.start, + } + } + + /// Emit a high-level representation into a 6LoWPAN Fragment header. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut Packet<T>) { + match self { + Self::FirstFragment { size, tag } => { + packet.set_dispatch_field(DISPATCH_FIRST_FRAGMENT_HEADER); + packet.set_datagram_size(*size); + packet.set_datagram_tag(*tag); + } + Self::Fragment { size, tag, offset } => { + packet.set_dispatch_field(DISPATCH_FRAGMENT_HEADER); + packet.set_datagram_size(*size); + packet.set_datagram_tag(*tag); + packet.set_datagram_offset(*offset); + } + } + } +} diff --git a/src/wire/sixlowpan/iphc.rs b/src/wire/sixlowpan/iphc.rs new file mode 100644 index 0000000..f9dcc2b --- /dev/null +++ b/src/wire/sixlowpan/iphc.rs @@ -0,0 +1,948 @@ +//! Implementation of IP Header Compression from [RFC 6282 § 3.1]. +//! It defines the compression of IPv6 headers. +//! +//! [RFC 6282 § 3.1]: https://datatracker.ietf.org/doc/html/rfc6282#section-3.1 + +use super::{ + AddressContext, AddressMode, Error, NextHeader, Result, UnresolvedAddress, DISPATCH_IPHC_HEADER, +}; +use crate::wire::{ieee802154::Address as LlAddress, ipv6, IpProtocol}; +use byteorder::{ByteOrder, NetworkEndian}; + +mod field { + use crate::wire::field::*; + + pub const IPHC_FIELD: Field = 0..2; +} + +macro_rules! get_field { + ($name:ident, $mask:expr, $shift:expr) => { + fn $name(&self) -> u8 { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::IPHC_FIELD]); + ((raw >> $shift) & $mask) as u8 + } + }; +} + +macro_rules! set_field { + ($name:ident, $mask:expr, $shift:expr) => { + fn $name(&mut self, val: u8) { + let data = &mut self.buffer.as_mut()[field::IPHC_FIELD]; + let mut raw = NetworkEndian::read_u16(data); + + raw = (raw & !($mask << $shift)) | ((val as u16) << $shift); + NetworkEndian::write_u16(data, raw); + } + }; +} + +/// A read/write wrapper around a 6LoWPAN IPHC header. +/// [RFC 6282 § 3.1] specifies the format of the header. +/// +/// The header always start with the following base format (from [RFC 6282 § 3.1.1]): +/// ```txt +/// 0 1 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +/// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ +/// | 0 | 1 | 1 | TF |NH | HLIM |CID|SAC| SAM | M |DAC| DAM | +/// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ +/// ``` +/// With: +/// - TF: Traffic Class and Flow Label +/// - NH: Next Header +/// - HLIM: Hop Limit +/// - CID: Context Identifier Extension +/// - SAC: Source Address Compression +/// - SAM: Source Address Mode +/// - M: Multicast Compression +/// - DAC: Destination Address Compression +/// - DAM: Destination Address Mode +/// +/// Depending on the flags in the base format, the following fields are added to the header: +/// - Traffic Class and Flow Label +/// - Next Header +/// - Hop Limit +/// - IPv6 source address +/// - IPv6 destination address +/// +/// [RFC 6282 § 3.1]: https://datatracker.ietf.org/doc/html/rfc6282#section-3.1 +/// [RFC 6282 § 3.1.1]: https://datatracker.ietf.org/doc/html/rfc6282#section-3.1.1 +#[derive(Debug, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +impl<T: AsRef<[u8]>> Packet<T> { + /// Input a raw octet buffer with a 6LoWPAN IPHC header structure. + pub const fn new_unchecked(buffer: T) -> Self { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Self> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let buffer = self.buffer.as_ref(); + if buffer.len() < 2 { + return Err(Error); + } + + let mut offset = self.ip_fields_start() + + self.traffic_class_size() + + self.next_header_size() + + self.hop_limit_size(); + offset += self.src_address_size(); + offset += self.dst_address_size(); + + if offset as usize > buffer.len() { + return Err(Error); + } + + Ok(()) + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the Next Header field. + pub fn next_header(&self) -> NextHeader { + let nh = self.nh_field(); + + if nh == 1 { + // The next header field is compressed. + // It is also encoded using LOWPAN_NHC. + NextHeader::Compressed + } else { + // The full 8 bits for Next Header are carried in-line. + let start = (self.ip_fields_start() + self.traffic_class_size()) as usize; + + let data = self.buffer.as_ref(); + let nh = data[start..start + 1][0]; + NextHeader::Uncompressed(IpProtocol::from(nh)) + } + } + + /// Return the Hop Limit. + pub fn hop_limit(&self) -> u8 { + match self.hlim_field() { + 0b00 => { + let start = (self.ip_fields_start() + + self.traffic_class_size() + + self.next_header_size()) as usize; + + let data = self.buffer.as_ref(); + data[start..start + 1][0] + } + 0b01 => 1, + 0b10 => 64, + 0b11 => 255, + _ => unreachable!(), + } + } + + /// Return the Source Context Identifier. + pub fn src_context_id(&self) -> Option<u8> { + if self.cid_field() == 1 { + let data = self.buffer.as_ref(); + Some(data[2] >> 4) + } else { + None + } + } + + /// Return the Destination Context Identifier. + pub fn dst_context_id(&self) -> Option<u8> { + if self.cid_field() == 1 { + let data = self.buffer.as_ref(); + Some(data[2] & 0x0f) + } else { + None + } + } + + /// Return the ECN field (when it is inlined). + pub fn ecn_field(&self) -> Option<u8> { + match self.tf_field() { + 0b00 | 0b01 | 0b10 => { + let start = self.ip_fields_start() as usize; + Some(self.buffer.as_ref()[start..][0] & 0b1100_0000) + } + 0b11 => None, + _ => unreachable!(), + } + } + + /// Return the DSCP field (when it is inlined). + pub fn dscp_field(&self) -> Option<u8> { + match self.tf_field() { + 0b00 | 0b10 => { + let start = self.ip_fields_start() as usize; + Some(self.buffer.as_ref()[start..][0] & 0b111111) + } + 0b01 | 0b11 => None, + _ => unreachable!(), + } + } + + /// Return the flow label field (when it is inlined). + pub fn flow_label_field(&self) -> Option<u16> { + match self.tf_field() { + 0b00 => { + let start = self.ip_fields_start() as usize; + Some(NetworkEndian::read_u16( + &self.buffer.as_ref()[start..][2..4], + )) + } + 0b01 => { + let start = self.ip_fields_start() as usize; + Some(NetworkEndian::read_u16( + &self.buffer.as_ref()[start..][1..3], + )) + } + 0b10 | 0b11 => None, + _ => unreachable!(), + } + } + + /// Return the Source Address. + pub fn src_addr(&self) -> Result<UnresolvedAddress> { + let start = (self.ip_fields_start() + + self.traffic_class_size() + + self.next_header_size() + + self.hop_limit_size()) as usize; + + let data = self.buffer.as_ref(); + match (self.sac_field(), self.sam_field()) { + (0, 0b00) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullInline( + &data[start..][..16], + ))), + (0, 0b01) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::InLine64bits(&data[start..][..8]), + )), + (0, 0b10) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::InLine16bits(&data[start..][..2]), + )), + (0, 0b11) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullyElided)), + (1, 0b00) => Ok(UnresolvedAddress::WithContext(( + 0, + AddressMode::Unspecified, + ))), + (1, 0b01) => { + if let Some(id) = self.src_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::InLine64bits(&data[start..][..8]), + ))) + } else { + Err(Error) + } + } + (1, 0b10) => { + if let Some(id) = self.src_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::InLine16bits(&data[start..][..2]), + ))) + } else { + Err(Error) + } + } + (1, 0b11) => { + if let Some(id) = self.src_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::FullyElided, + ))) + } else { + Err(Error) + } + } + _ => Err(Error), + } + } + + /// Return the Destination Address. + pub fn dst_addr(&self) -> Result<UnresolvedAddress> { + let start = (self.ip_fields_start() + + self.traffic_class_size() + + self.next_header_size() + + self.hop_limit_size() + + self.src_address_size()) as usize; + + let data = self.buffer.as_ref(); + match (self.m_field(), self.dac_field(), self.dam_field()) { + (0, 0, 0b00) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullInline( + &data[start..][..16], + ))), + (0, 0, 0b01) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::InLine64bits(&data[start..][..8]), + )), + (0, 0, 0b10) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::InLine16bits(&data[start..][..2]), + )), + (0, 0, 0b11) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullyElided)), + (0, 1, 0b00) => Ok(UnresolvedAddress::Reserved), + (0, 1, 0b01) => { + if let Some(id) = self.dst_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::InLine64bits(&data[start..][..8]), + ))) + } else { + Err(Error) + } + } + (0, 1, 0b10) => { + if let Some(id) = self.dst_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::InLine16bits(&data[start..][..2]), + ))) + } else { + Err(Error) + } + } + (0, 1, 0b11) => { + if let Some(id) = self.dst_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::FullyElided, + ))) + } else { + Err(Error) + } + } + (1, 0, 0b00) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullInline( + &data[start..][..16], + ))), + (1, 0, 0b01) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::Multicast48bits(&data[start..][..6]), + )), + (1, 0, 0b10) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::Multicast32bits(&data[start..][..4]), + )), + (1, 0, 0b11) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::Multicast8bits(&data[start..][..1]), + )), + (1, 1, 0b00) => Ok(UnresolvedAddress::WithContext(( + 0, + AddressMode::NotSupported, + ))), + (1, 1, 0b01 | 0b10 | 0b11) => Ok(UnresolvedAddress::Reserved), + _ => Err(Error), + } + } + + get_field!(dispatch_field, 0b111, 13); + get_field!(tf_field, 0b11, 11); + get_field!(nh_field, 0b1, 10); + get_field!(hlim_field, 0b11, 8); + get_field!(cid_field, 0b1, 7); + get_field!(sac_field, 0b1, 6); + get_field!(sam_field, 0b11, 4); + get_field!(m_field, 0b1, 3); + get_field!(dac_field, 0b1, 2); + get_field!(dam_field, 0b11, 0); + + /// Return the start for the IP fields. + fn ip_fields_start(&self) -> u8 { + 2 + self.cid_size() + } + + /// Get the size in octets of the traffic class field. + fn traffic_class_size(&self) -> u8 { + match self.tf_field() { + 0b00 => 4, + 0b01 => 3, + 0b10 => 1, + 0b11 => 0, + _ => unreachable!(), + } + } + + /// Get the size in octets of the next header field. + fn next_header_size(&self) -> u8 { + (self.nh_field() != 1) as u8 + } + + /// Get the size in octets of the hop limit field. + fn hop_limit_size(&self) -> u8 { + (self.hlim_field() == 0b00) as u8 + } + + /// Get the size in octets of the CID field. + fn cid_size(&self) -> u8 { + (self.cid_field() == 1) as u8 + } + + /// Get the size in octets of the source address. + fn src_address_size(&self) -> u8 { + match (self.sac_field(), self.sam_field()) { + (0, 0b00) => 16, // The full address is carried in-line. + (0, 0b01) => 8, // The first 64 bits are elided. + (0, 0b10) => 2, // The first 112 bits are elided. + (0, 0b11) => 0, // The address is fully elided. + (1, 0b00) => 0, // The UNSPECIFIED address. + (1, 0b01) => 8, // Address derived using context information. + (1, 0b10) => 2, // Address derived using context information. + (1, 0b11) => 0, // Address derived using context information. + _ => unreachable!(), + } + } + + /// Get the size in octets of the address address. + fn dst_address_size(&self) -> u8 { + match (self.m_field(), self.dac_field(), self.dam_field()) { + (0, 0, 0b00) => 16, // The full address is carried in-line. + (0, 0, 0b01) => 8, // The first 64 bits are elided. + (0, 0, 0b10) => 2, // The first 112 bits are elided. + (0, 0, 0b11) => 0, // The address is fully elided. + (0, 1, 0b00) => 0, // Reserved. + (0, 1, 0b01) => 8, // Address derived using context information. + (0, 1, 0b10) => 2, // Address derived using context information. + (0, 1, 0b11) => 0, // Address derived using context information. + (1, 0, 0b00) => 16, // The full address is carried in-line. + (1, 0, 0b01) => 6, // The address takes the form ffXX::00XX:XXXX:XXXX. + (1, 0, 0b10) => 4, // The address takes the form ffXX::00XX:XXXX. + (1, 0, 0b11) => 1, // The address takes the form ff02::00XX. + (1, 1, 0b00) => 6, // Match Unicast-Prefix-based IPv6. + (1, 1, 0b01) => 0, // Reserved. + (1, 1, 0b10) => 0, // Reserved. + (1, 1, 0b11) => 0, // Reserved. + _ => unreachable!(), + } + } + + /// Return the length of the header. + pub fn header_len(&self) -> usize { + let mut len = self.ip_fields_start(); + len += self.traffic_class_size(); + len += self.next_header_size(); + len += self.hop_limit_size(); + len += self.src_address_size(); + len += self.dst_address_size(); + + len as usize + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return a pointer to the payload. + pub fn payload(&self) -> &'a [u8] { + let mut len = self.ip_fields_start(); + len += self.traffic_class_size(); + len += self.next_header_size(); + len += self.hop_limit_size(); + len += self.src_address_size(); + len += self.dst_address_size(); + + let len = len as usize; + + let data = self.buffer.as_ref(); + &data[len..] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the dispatch field to `0b011`. + fn set_dispatch_field(&mut self) { + let data = &mut self.buffer.as_mut()[field::IPHC_FIELD]; + let mut raw = NetworkEndian::read_u16(data); + + raw = (raw & !(0b111 << 13)) | (0b11 << 13); + NetworkEndian::write_u16(data, raw); + } + + set_field!(set_tf_field, 0b11, 11); + set_field!(set_nh_field, 0b1, 10); + set_field!(set_hlim_field, 0b11, 8); + set_field!(set_cid_field, 0b1, 7); + set_field!(set_sac_field, 0b1, 6); + set_field!(set_sam_field, 0b11, 4); + set_field!(set_m_field, 0b1, 3); + set_field!(set_dac_field, 0b1, 2); + set_field!(set_dam_field, 0b11, 0); + + fn set_field(&mut self, idx: usize, value: &[u8]) { + let raw = self.buffer.as_mut(); + raw[idx..idx + value.len()].copy_from_slice(value); + } + + /// Set the Next Header. + /// + /// **NOTE**: `idx` is the offset at which the Next Header needs to be written to. + fn set_next_header(&mut self, nh: NextHeader, mut idx: usize) -> usize { + match nh { + NextHeader::Uncompressed(nh) => { + self.set_nh_field(0); + self.set_field(idx, &[nh.into()]); + idx += 1; + } + NextHeader::Compressed => self.set_nh_field(1), + } + + idx + } + + /// Set the Hop Limit. + /// + /// **NOTE**: `idx` is the offset at which the Next Header needs to be written to. + fn set_hop_limit(&mut self, hl: u8, mut idx: usize) -> usize { + match hl { + 255 => self.set_hlim_field(0b11), + 64 => self.set_hlim_field(0b10), + 1 => self.set_hlim_field(0b01), + _ => { + self.set_hlim_field(0b00); + self.set_field(idx, &[hl]); + idx += 1; + } + } + + idx + } + + /// Set the Source Address based on the IPv6 address and the Link-Local address. + /// + /// **NOTE**: `idx` is the offset at which the Next Header needs to be written to. + fn set_src_address( + &mut self, + src_addr: ipv6::Address, + ll_src_addr: Option<LlAddress>, + mut idx: usize, + ) -> usize { + self.set_cid_field(0); + self.set_sac_field(0); + let src = src_addr.as_bytes(); + if src_addr == ipv6::Address::UNSPECIFIED { + self.set_sac_field(1); + self.set_sam_field(0b00); + } else if src_addr.is_link_local() { + // We have a link local address. + // The remainder of the address can be elided when the context contains + // a 802.15.4 short address or a 802.15.4 extended address which can be + // converted to a eui64 address. + let is_eui_64 = ll_src_addr + .map(|addr| { + addr.as_eui_64() + .map(|addr| addr[..] == src[8..]) + .unwrap_or(false) + }) + .unwrap_or(false); + + if src[8..14] == [0, 0, 0, 0xff, 0xfe, 0] { + let ll = [src[14], src[15]]; + + if ll_src_addr == Some(LlAddress::Short(ll)) { + // We have the context from the 802.15.4 frame. + // The context contains the short address. + // We can elide the source address. + self.set_sam_field(0b11); + } else { + // We don't have the context from the 802.15.4 frame. + // We cannot elide the source address, however we can elide 112 bits. + self.set_sam_field(0b10); + + self.set_field(idx, &src[14..]); + idx += 2; + } + } else if is_eui_64 { + // We have the context from the 802.15.4 frame. + // The context contains the extended address. + // We can elide the source address. + self.set_sam_field(0b11); + } else { + // We cannot elide the source address, however we can elide 64 bits. + self.set_sam_field(0b01); + + self.set_field(idx, &src[8..]); + idx += 8; + } + } else { + // We cannot elide anything. + self.set_sam_field(0b00); + self.set_field(idx, src); + idx += 16; + } + + idx + } + + /// Set the Destination Address based on the IPv6 address and the Link-Local address. + /// + /// **NOTE**: `idx` is the offset at which the Next Header needs to be written to. + fn set_dst_address( + &mut self, + dst_addr: ipv6::Address, + ll_dst_addr: Option<LlAddress>, + mut idx: usize, + ) -> usize { + self.set_dac_field(0); + self.set_dam_field(0); + self.set_m_field(0); + let dst = dst_addr.as_bytes(); + if dst_addr.is_multicast() { + self.set_m_field(1); + + if dst[1] == 0x02 && dst[2..15] == [0; 13] { + self.set_dam_field(0b11); + + self.set_field(idx, &[dst[15]]); + idx += 1; + } else if dst[2..13] == [0; 11] { + self.set_dam_field(0b10); + + self.set_field(idx, &[dst[1]]); + idx += 1; + self.set_field(idx, &dst[13..]); + idx += 3; + } else if dst[2..11] == [0; 9] { + self.set_dam_field(0b01); + + self.set_field(idx, &[dst[1]]); + idx += 1; + self.set_field(idx, &dst[11..]); + idx += 5; + } else { + self.set_dam_field(0b11); + + self.set_field(idx, dst); + idx += 16; + } + } else if dst_addr.is_link_local() { + let is_eui_64 = ll_dst_addr + .map(|addr| { + addr.as_eui_64() + .map(|addr| addr[..] == dst[8..]) + .unwrap_or(false) + }) + .unwrap_or(false); + + if dst[8..14] == [0, 0, 0, 0xff, 0xfe, 0] { + let ll = [dst[14], dst[15]]; + + if ll_dst_addr == Some(LlAddress::Short(ll)) { + self.set_dam_field(0b11); + } else { + self.set_dam_field(0b10); + + self.set_field(idx, &dst[14..]); + idx += 2; + } + } else if is_eui_64 { + self.set_dam_field(0b11); + } else { + self.set_dam_field(0b01); + + self.set_field(idx, &dst[8..]); + idx += 8; + } + } else { + self.set_dam_field(0b00); + + self.set_field(idx, dst); + idx += 16; + } + + idx + } + + /// Return a mutable pointer to the payload. + pub fn payload_mut(&mut self) -> &mut [u8] { + let mut len = self.ip_fields_start(); + + len += self.traffic_class_size(); + len += self.next_header_size(); + len += self.hop_limit_size(); + len += self.src_address_size(); + len += self.dst_address_size(); + + let len = len as usize; + + let data = self.buffer.as_mut(); + &mut data[len..] + } +} + +/// A high-level representation of a 6LoWPAN IPHC header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct Repr { + pub src_addr: ipv6::Address, + pub ll_src_addr: Option<LlAddress>, + pub dst_addr: ipv6::Address, + pub ll_dst_addr: Option<LlAddress>, + pub next_header: NextHeader, + pub hop_limit: u8, + // TODO(thvdveld): refactor the following fields into something else + pub ecn: Option<u8>, + pub dscp: Option<u8>, + pub flow_label: Option<u16>, +} + +impl core::fmt::Display for Repr { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "IPHC src={} dst={} nxt-hdr={} hop-limit={}", + self.src_addr, self.dst_addr, self.next_header, self.hop_limit + ) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Repr { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "IPHC src={} dst={} nxt-hdr={} hop-limit={}", + self.src_addr, + self.dst_addr, + self.next_header, + self.hop_limit + ); + } +} + +impl Repr { + /// Parse a 6LoWPAN IPHC header and return a high-level representation. + /// + /// The `ll_src_addr` and `ll_dst_addr` are the link-local addresses used for resolving the + /// IPv6 packets. + pub fn parse<T: AsRef<[u8]> + ?Sized>( + packet: &Packet<&T>, + ll_src_addr: Option<LlAddress>, + ll_dst_addr: Option<LlAddress>, + addr_context: &[AddressContext], + ) -> Result<Self> { + // Ensure basic accessors will work. + packet.check_len()?; + + if packet.dispatch_field() != DISPATCH_IPHC_HEADER { + // This is not an LOWPAN_IPHC packet. + return Err(Error); + } + + let src_addr = packet.src_addr()?.resolve(ll_src_addr, addr_context)?; + let dst_addr = packet.dst_addr()?.resolve(ll_dst_addr, addr_context)?; + + Ok(Self { + src_addr, + ll_src_addr, + dst_addr, + ll_dst_addr, + next_header: packet.next_header(), + hop_limit: packet.hop_limit(), + ecn: packet.ecn_field(), + dscp: packet.dscp_field(), + flow_label: packet.flow_label_field(), + }) + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub fn buffer_len(&self) -> usize { + let mut len = 0; + len += 2; // The minimal header length + + len += match self.next_header { + NextHeader::Compressed => 0, // The next header is compressed (we don't need to inline what the next header is) + NextHeader::Uncompressed(_) => 1, // The next header field is inlined + }; + + // Hop Limit size + len += match self.hop_limit { + 255 | 64 | 1 => 0, // We can inline the hop limit + _ => 1, + }; + + // Add the length of the source address + len += if self.src_addr == ipv6::Address::UNSPECIFIED { + 0 + } else if self.src_addr.is_link_local() { + let src = self.src_addr.as_bytes(); + let ll = [src[14], src[15]]; + + let is_eui_64 = self + .ll_src_addr + .map(|addr| { + addr.as_eui_64() + .map(|addr| addr[..] == src[8..]) + .unwrap_or(false) + }) + .unwrap_or(false); + + if src[8..14] == [0, 0, 0, 0xff, 0xfe, 0] { + if self.ll_src_addr == Some(LlAddress::Short(ll)) { + 0 + } else { + 2 + } + } else if is_eui_64 { + 0 + } else { + 8 + } + } else { + 16 + }; + + // Add the size of the destination header + let dst = self.dst_addr.as_bytes(); + len += if self.dst_addr.is_multicast() { + if dst[1] == 0x02 && dst[2..15] == [0; 13] { + 1 + } else if dst[2..13] == [0; 11] { + 4 + } else if dst[2..11] == [0; 9] { + 6 + } else { + 16 + } + } else if self.dst_addr.is_link_local() { + let is_eui_64 = self + .ll_dst_addr + .map(|addr| { + addr.as_eui_64() + .map(|addr| addr[..] == dst[8..]) + .unwrap_or(false) + }) + .unwrap_or(false); + + if dst[8..14] == [0, 0, 0, 0xff, 0xfe, 0] { + let ll = [dst[14], dst[15]]; + + if self.ll_dst_addr == Some(LlAddress::Short(ll)) { + 0 + } else { + 2 + } + } else if is_eui_64 { + 0 + } else { + 8 + } + } else { + 16 + }; + + len += match (self.ecn, self.dscp, self.flow_label) { + (Some(_), Some(_), Some(_)) => 4, + (Some(_), None, Some(_)) => 3, + (Some(_), Some(_), None) => 1, + (None, None, None) => 0, + _ => unreachable!(), + }; + + len + } + + /// Emit a high-level representation into a 6LoWPAN IPHC header. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut Packet<T>) { + let idx = 2; + + packet.set_dispatch_field(); + + // FIXME(thvdveld): we don't set anything from the traffic flow. + packet.set_tf_field(0b11); + + let idx = packet.set_next_header(self.next_header, idx); + let idx = packet.set_hop_limit(self.hop_limit, idx); + let idx = packet.set_src_address(self.src_addr, self.ll_src_addr, idx); + packet.set_dst_address(self.dst_addr, self.ll_dst_addr, idx); + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn iphc_fields() { + let bytes = [ + 0x7a, 0x33, // IPHC + 0x3a, // Next header + ]; + + let packet = Packet::new_unchecked(bytes); + + assert_eq!(packet.dispatch_field(), 0b011); + assert_eq!(packet.tf_field(), 0b11); + assert_eq!(packet.nh_field(), 0b0); + assert_eq!(packet.hlim_field(), 0b10); + assert_eq!(packet.cid_field(), 0b0); + assert_eq!(packet.sac_field(), 0b0); + assert_eq!(packet.sam_field(), 0b11); + assert_eq!(packet.m_field(), 0b0); + assert_eq!(packet.dac_field(), 0b0); + assert_eq!(packet.dam_field(), 0b11); + + assert_eq!( + packet.next_header(), + NextHeader::Uncompressed(IpProtocol::Icmpv6) + ); + + assert_eq!(packet.src_address_size(), 0); + assert_eq!(packet.dst_address_size(), 0); + assert_eq!(packet.hop_limit(), 64); + + assert_eq!( + packet.src_addr(), + Ok(UnresolvedAddress::WithoutContext(AddressMode::FullyElided)) + ); + assert_eq!( + packet.dst_addr(), + Ok(UnresolvedAddress::WithoutContext(AddressMode::FullyElided)) + ); + + let bytes = [ + 0x7e, 0xf7, // IPHC, + 0x00, // CID + ]; + + let packet = Packet::new_unchecked(bytes); + + assert_eq!(packet.dispatch_field(), 0b011); + assert_eq!(packet.tf_field(), 0b11); + assert_eq!(packet.nh_field(), 0b1); + assert_eq!(packet.hlim_field(), 0b10); + assert_eq!(packet.cid_field(), 0b1); + assert_eq!(packet.sac_field(), 0b1); + assert_eq!(packet.sam_field(), 0b11); + assert_eq!(packet.m_field(), 0b0); + assert_eq!(packet.dac_field(), 0b1); + assert_eq!(packet.dam_field(), 0b11); + + assert_eq!(packet.next_header(), NextHeader::Compressed); + + assert_eq!(packet.src_address_size(), 0); + assert_eq!(packet.dst_address_size(), 0); + assert_eq!(packet.hop_limit(), 64); + + assert_eq!( + packet.src_addr(), + Ok(UnresolvedAddress::WithContext(( + 0, + AddressMode::FullyElided + ))) + ); + assert_eq!( + packet.dst_addr(), + Ok(UnresolvedAddress::WithContext(( + 0, + AddressMode::FullyElided + ))) + ); + } +} diff --git a/src/wire/sixlowpan/mod.rs b/src/wire/sixlowpan/mod.rs new file mode 100644 index 0000000..03a5218 --- /dev/null +++ b/src/wire/sixlowpan/mod.rs @@ -0,0 +1,371 @@ +//! Implementation of [RFC 6282] which specifies a compression format for IPv6 datagrams over +//! IEEE802.154-based networks. +//! +//! [RFC 6282]: https://datatracker.ietf.org/doc/html/rfc6282 + +use super::{Error, Result}; +use crate::wire::ieee802154::Address as LlAddress; +use crate::wire::ipv6; +use crate::wire::IpProtocol; + +pub mod frag; +pub mod iphc; +pub mod nhc; + +const ADDRESS_CONTEXT_LENGTH: usize = 8; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct AddressContext(pub [u8; ADDRESS_CONTEXT_LENGTH]); + +/// The representation of an unresolved address. 6LoWPAN compression of IPv6 addresses can be with +/// and without context information. The decompression with context information is not yet +/// implemented. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum UnresolvedAddress<'a> { + WithoutContext(AddressMode<'a>), + WithContext((usize, AddressMode<'a>)), + Reserved, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum AddressMode<'a> { + /// The full address is carried in-line. + FullInline(&'a [u8]), + /// The first 64-bits of the address are elided. The value of those bits + /// is the link-local prefix padded with zeros. The remaining 64 bits are + /// carried in-line. + InLine64bits(&'a [u8]), + /// The first 112 bits of the address are elided. The value of the first + /// 64 bits is the link-local prefix padded with zeros. The following 64 bits + /// are 0000:00ff:fe00:XXXX, where XXXX are the 16 bits carried in-line. + InLine16bits(&'a [u8]), + /// The address is fully elided. The first 64 bits of the address are + /// the link-local prefix padded with zeros. The remaining 64 bits are + /// computed from the encapsulating header (e.g., 802.15.4 or IPv6 source address) + /// as specified in Section 3.2.2. + FullyElided, + /// The address takes the form ffXX::00XX:XXXX:XXXX + Multicast48bits(&'a [u8]), + /// The address takes the form ffXX::00XX:XXXX. + Multicast32bits(&'a [u8]), + /// The address takes the form ff02::00XX. + Multicast8bits(&'a [u8]), + /// The unspecified address. + Unspecified, + NotSupported, +} + +const LINK_LOCAL_PREFIX: [u8; 2] = [0xfe, 0x80]; +const EUI64_MIDDLE_VALUE: [u8; 2] = [0xff, 0xfe]; + +impl<'a> UnresolvedAddress<'a> { + pub fn resolve( + self, + ll_address: Option<LlAddress>, + addr_context: &[AddressContext], + ) -> Result<ipv6::Address> { + let mut bytes = [0; 16]; + + let copy_context = |index: usize, bytes: &mut [u8]| -> Result<()> { + if index >= addr_context.len() { + return Err(Error); + } + + let context = addr_context[index]; + bytes[..ADDRESS_CONTEXT_LENGTH].copy_from_slice(&context.0); + + Ok(()) + }; + + match self { + UnresolvedAddress::WithoutContext(mode) => match mode { + AddressMode::FullInline(addr) => Ok(ipv6::Address::from_bytes(addr)), + AddressMode::InLine64bits(inline) => { + bytes[0..2].copy_from_slice(&LINK_LOCAL_PREFIX[..]); + bytes[8..].copy_from_slice(inline); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::InLine16bits(inline) => { + bytes[0..2].copy_from_slice(&LINK_LOCAL_PREFIX[..]); + bytes[11..13].copy_from_slice(&EUI64_MIDDLE_VALUE[..]); + bytes[14..].copy_from_slice(inline); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::FullyElided => { + bytes[0..2].copy_from_slice(&LINK_LOCAL_PREFIX[..]); + match ll_address { + Some(LlAddress::Short(ll)) => { + bytes[11..13].copy_from_slice(&EUI64_MIDDLE_VALUE[..]); + bytes[14..].copy_from_slice(&ll); + } + Some(addr @ LlAddress::Extended(_)) => match addr.as_eui_64() { + Some(addr) => bytes[8..].copy_from_slice(&addr), + None => return Err(Error), + }, + Some(LlAddress::Absent) => return Err(Error), + None => return Err(Error), + } + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::Multicast48bits(inline) => { + bytes[0] = 0xff; + bytes[1] = inline[0]; + bytes[11..].copy_from_slice(&inline[1..][..5]); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::Multicast32bits(inline) => { + bytes[0] = 0xff; + bytes[1] = inline[0]; + bytes[13..].copy_from_slice(&inline[1..][..3]); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::Multicast8bits(inline) => { + bytes[0] = 0xff; + bytes[1] = 0x02; + bytes[15] = inline[0]; + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + _ => Err(Error), + }, + UnresolvedAddress::WithContext(mode) => match mode { + (_, AddressMode::Unspecified) => Ok(ipv6::Address::UNSPECIFIED), + (index, AddressMode::InLine64bits(inline)) => { + copy_context(index, &mut bytes[..])?; + bytes[16 - inline.len()..].copy_from_slice(inline); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + (index, AddressMode::InLine16bits(inline)) => { + copy_context(index, &mut bytes[..])?; + bytes[16 - inline.len()..].copy_from_slice(inline); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + (index, AddressMode::FullyElided) => { + match ll_address { + Some(LlAddress::Short(ll)) => { + bytes[11..13].copy_from_slice(&EUI64_MIDDLE_VALUE[..]); + bytes[14..].copy_from_slice(&ll); + } + Some(addr @ LlAddress::Extended(_)) => match addr.as_eui_64() { + Some(addr) => bytes[8..].copy_from_slice(&addr), + None => return Err(Error), + }, + Some(LlAddress::Absent) => return Err(Error), + None => return Err(Error), + } + + copy_context(index, &mut bytes[..])?; + + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + _ => Err(Error), + }, + UnresolvedAddress::Reserved => Err(Error), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SixlowpanPacket { + FragmentHeader, + IphcHeader, +} + +const DISPATCH_FIRST_FRAGMENT_HEADER: u8 = 0b11000; +const DISPATCH_FRAGMENT_HEADER: u8 = 0b11100; +const DISPATCH_IPHC_HEADER: u8 = 0b011; +const DISPATCH_UDP_HEADER: u8 = 0b11110; +const DISPATCH_EXT_HEADER: u8 = 0b1110; + +impl SixlowpanPacket { + /// Returns the type of the 6LoWPAN header. + /// This can either be a fragment header or an IPHC header. + /// + /// # Errors + /// Returns `[Error::Unrecognized]` when neither the Fragment Header dispatch or the IPHC + /// dispatch is recognized. + pub fn dispatch(buffer: impl AsRef<[u8]>) -> Result<Self> { + let raw = buffer.as_ref(); + + if raw.is_empty() { + return Err(Error); + } + + if raw[0] >> 3 == DISPATCH_FIRST_FRAGMENT_HEADER || raw[0] >> 3 == DISPATCH_FRAGMENT_HEADER + { + Ok(Self::FragmentHeader) + } else if raw[0] >> 5 == DISPATCH_IPHC_HEADER { + Ok(Self::IphcHeader) + } else { + Err(Error) + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum NextHeader { + Compressed, + Uncompressed(IpProtocol), +} + +impl core::fmt::Display for NextHeader { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + NextHeader::Compressed => write!(f, "compressed"), + NextHeader::Uncompressed(protocol) => write!(f, "{protocol}"), + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for NextHeader { + fn format(&self, fmt: defmt::Formatter) { + match self { + NextHeader::Compressed => defmt::write!(fmt, "compressed"), + NextHeader::Uncompressed(protocol) => defmt::write!(fmt, "{}", protocol), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn sixlowpan_fragment_emit() { + let repr = frag::Repr::FirstFragment { + size: 0xff, + tag: 0xabcd, + }; + let buffer = [0u8; 4]; + let mut packet = frag::Packet::new_unchecked(buffer); + + assert_eq!(repr.buffer_len(), 4); + repr.emit(&mut packet); + + assert_eq!(packet.datagram_size(), 0xff); + assert_eq!(packet.datagram_tag(), 0xabcd); + assert_eq!(packet.into_inner(), [0xc0, 0xff, 0xab, 0xcd]); + + let repr = frag::Repr::Fragment { + size: 0xff, + tag: 0xabcd, + offset: 0xcc, + }; + let buffer = [0u8; 5]; + let mut packet = frag::Packet::new_unchecked(buffer); + + assert_eq!(repr.buffer_len(), 5); + repr.emit(&mut packet); + + assert_eq!(packet.datagram_size(), 0xff); + assert_eq!(packet.datagram_tag(), 0xabcd); + assert_eq!(packet.into_inner(), [0xe0, 0xff, 0xab, 0xcd, 0xcc]); + } + + #[test] + fn sixlowpan_three_fragments() { + use crate::wire::ieee802154::Frame as Ieee802154Frame; + use crate::wire::ieee802154::Repr as Ieee802154Repr; + use crate::wire::Ieee802154Address; + + let key = frag::Key { + ll_src_addr: Ieee802154Address::Extended([50, 147, 130, 47, 40, 8, 62, 217]), + ll_dst_addr: Ieee802154Address::Extended([26, 11, 66, 66, 66, 66, 66, 66]), + datagram_size: 307, + datagram_tag: 63, + }; + + let frame1: &[u8] = &[ + 0x41, 0xcc, 0x92, 0xef, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0xd9, + 0x3e, 0x08, 0x28, 0x2f, 0x82, 0x93, 0x32, 0xc1, 0x33, 0x00, 0x3f, 0x6e, 0x33, 0x02, + 0x35, 0x3d, 0xf0, 0xd2, 0x5f, 0x1b, 0x39, 0xb4, 0x6b, 0x4c, 0x6f, 0x72, 0x65, 0x6d, + 0x20, 0x69, 0x70, 0x73, 0x75, 0x6d, 0x20, 0x64, 0x6f, 0x6c, 0x6f, 0x72, 0x20, 0x73, + 0x69, 0x74, 0x20, 0x61, 0x6d, 0x65, 0x74, 0x2c, 0x20, 0x63, 0x6f, 0x6e, 0x73, 0x65, + 0x63, 0x74, 0x65, 0x74, 0x75, 0x72, 0x20, 0x61, 0x64, 0x69, 0x70, 0x69, 0x73, 0x63, + 0x69, 0x6e, 0x67, 0x20, 0x65, 0x6c, 0x69, 0x74, 0x2e, 0x20, 0x41, 0x6c, 0x69, 0x71, + 0x75, 0x61, 0x6d, 0x20, 0x64, 0x75, 0x69, 0x20, 0x6f, 0x64, 0x69, 0x6f, 0x2c, 0x20, + 0x69, 0x61, 0x63, 0x75, 0x6c, 0x69, 0x73, 0x20, 0x76, 0x65, 0x6c, 0x20, 0x72, + ]; + + let ieee802154_frame = Ieee802154Frame::new_checked(frame1).unwrap(); + let ieee802154_repr = Ieee802154Repr::parse(&ieee802154_frame).unwrap(); + + let sixlowpan_frame = + SixlowpanPacket::dispatch(ieee802154_frame.payload().unwrap()).unwrap(); + + let frag = if let SixlowpanPacket::FragmentHeader = sixlowpan_frame { + frag::Packet::new_checked(ieee802154_frame.payload().unwrap()).unwrap() + } else { + unreachable!() + }; + + assert_eq!(frag.datagram_size(), 307); + assert_eq!(frag.datagram_tag(), 0x003f); + assert_eq!(frag.datagram_offset(), 0); + + assert_eq!(frag.get_key(&ieee802154_repr), key); + + let frame2: &[u8] = &[ + 0x41, 0xcc, 0x93, 0xef, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0xd9, + 0x3e, 0x08, 0x28, 0x2f, 0x82, 0x93, 0x32, 0xe1, 0x33, 0x00, 0x3f, 0x11, 0x75, 0x74, + 0x72, 0x75, 0x6d, 0x20, 0x61, 0x74, 0x2c, 0x20, 0x74, 0x72, 0x69, 0x73, 0x74, 0x69, + 0x71, 0x75, 0x65, 0x20, 0x6e, 0x6f, 0x6e, 0x20, 0x6e, 0x75, 0x6e, 0x63, 0x20, 0x65, + 0x72, 0x61, 0x74, 0x20, 0x63, 0x75, 0x72, 0x61, 0x65, 0x2e, 0x20, 0x4c, 0x6f, 0x72, + 0x65, 0x6d, 0x20, 0x69, 0x70, 0x73, 0x75, 0x6d, 0x20, 0x64, 0x6f, 0x6c, 0x6f, 0x72, + 0x20, 0x73, 0x69, 0x74, 0x20, 0x61, 0x6d, 0x65, 0x74, 0x2c, 0x20, 0x63, 0x6f, 0x6e, + 0x73, 0x65, 0x63, 0x74, 0x65, 0x74, 0x75, 0x72, 0x20, 0x61, 0x64, 0x69, 0x70, 0x69, + 0x73, 0x63, 0x69, 0x6e, 0x67, 0x20, 0x65, 0x6c, 0x69, 0x74, + ]; + + let ieee802154_frame = Ieee802154Frame::new_checked(frame2).unwrap(); + let ieee802154_repr = Ieee802154Repr::parse(&ieee802154_frame).unwrap(); + + let sixlowpan_frame = + SixlowpanPacket::dispatch(ieee802154_frame.payload().unwrap()).unwrap(); + + let frag = if let SixlowpanPacket::FragmentHeader = sixlowpan_frame { + frag::Packet::new_checked(ieee802154_frame.payload().unwrap()).unwrap() + } else { + unreachable!() + }; + + assert_eq!(frag.datagram_size(), 307); + assert_eq!(frag.datagram_tag(), 0x003f); + assert_eq!(frag.datagram_offset(), 136 / 8); + + assert_eq!(frag.get_key(&ieee802154_repr), key); + + let frame3: &[u8] = &[ + 0x41, 0xcc, 0x94, 0xef, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0xd9, + 0x3e, 0x08, 0x28, 0x2f, 0x82, 0x93, 0x32, 0xe1, 0x33, 0x00, 0x3f, 0x1d, 0x2e, 0x20, + 0x41, 0x6c, 0x69, 0x71, 0x75, 0x61, 0x6d, 0x20, 0x64, 0x75, 0x69, 0x20, 0x6f, 0x64, + 0x69, 0x6f, 0x2c, 0x20, 0x69, 0x61, 0x63, 0x75, 0x6c, 0x69, 0x73, 0x20, 0x76, 0x65, + 0x6c, 0x20, 0x72, 0x75, 0x74, 0x72, 0x75, 0x6d, 0x20, 0x61, 0x74, 0x2c, 0x20, 0x74, + 0x72, 0x69, 0x73, 0x74, 0x69, 0x71, 0x75, 0x65, 0x20, 0x6e, 0x6f, 0x6e, 0x20, 0x6e, + 0x75, 0x6e, 0x63, 0x20, 0x65, 0x72, 0x61, 0x74, 0x20, 0x63, 0x75, 0x72, 0x61, 0x65, + 0x2e, 0x20, 0x0a, + ]; + + let ieee802154_frame = Ieee802154Frame::new_checked(frame3).unwrap(); + let ieee802154_repr = Ieee802154Repr::parse(&ieee802154_frame).unwrap(); + + let sixlowpan_frame = + SixlowpanPacket::dispatch(ieee802154_frame.payload().unwrap()).unwrap(); + + let frag = if let SixlowpanPacket::FragmentHeader = sixlowpan_frame { + frag::Packet::new_checked(ieee802154_frame.payload().unwrap()).unwrap() + } else { + unreachable!() + }; + + assert_eq!(frag.datagram_size(), 307); + assert_eq!(frag.datagram_tag(), 0x003f); + assert_eq!(frag.datagram_offset(), 232 / 8); + + assert_eq!(frag.get_key(&ieee802154_repr), key); + } +} diff --git a/src/wire/sixlowpan/nhc.rs b/src/wire/sixlowpan/nhc.rs new file mode 100644 index 0000000..5539ce4 --- /dev/null +++ b/src/wire/sixlowpan/nhc.rs @@ -0,0 +1,890 @@ +//! Implementation of Next Header Compression from [RFC 6282 § 4]. +//! +//! [RFC 6282 § 4]: https://datatracker.ietf.org/doc/html/rfc6282#section-4 +use super::{Error, NextHeader, Result, DISPATCH_EXT_HEADER, DISPATCH_UDP_HEADER}; +use crate::{ + phy::ChecksumCapabilities, + wire::{ + ip::{checksum, Address as IpAddress}, + ipv6, + udp::Repr as UdpRepr, + IpProtocol, + }, +}; +use byteorder::{ByteOrder, NetworkEndian}; +use ipv6::Address; + +macro_rules! get_field { + ($name:ident, $mask:expr, $shift:expr) => { + fn $name(&self) -> u8 { + let data = self.buffer.as_ref(); + let raw = &data[0]; + ((raw >> $shift) & $mask) as u8 + } + }; +} + +macro_rules! set_field { + ($name:ident, $mask:expr, $shift:expr) => { + fn $name(&mut self, val: u8) { + let data = self.buffer.as_mut(); + let mut raw = data[0]; + raw = (raw & !($mask << $shift)) | (val << $shift); + data[0] = raw; + } + }; +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +/// A read/write wrapper around a 6LoWPAN_NHC Header. +/// [RFC 6282 § 4.2] specifies the format of the header. +/// +/// The header has the following format: +/// ```txt +/// 0 1 2 3 4 5 6 7 +/// +---+---+---+---+---+---+---+---+ +/// | 1 | 1 | 1 | 0 | EID |NH | +/// +---+---+---+---+---+---+---+---+ +/// ``` +/// +/// With: +/// - EID: the extension header ID +/// - NH: Next Header +/// +/// [RFC 6282 § 4.2]: https://datatracker.ietf.org/doc/html/rfc6282#section-4.2 +pub enum NhcPacket { + ExtHeader, + UdpHeader, +} + +impl NhcPacket { + /// Returns the type of the Next Header header. + /// This can either be an Extension header or an 6LoWPAN Udp header. + /// + /// # Errors + /// Returns `[Error::Unrecognized]` when neither the Extension Header dispatch or the Udp + /// dispatch is recognized. + pub fn dispatch(buffer: impl AsRef<[u8]>) -> Result<Self> { + let raw = buffer.as_ref(); + if raw.is_empty() { + return Err(Error); + } + + if raw[0] >> 4 == DISPATCH_EXT_HEADER { + // We have a compressed IPv6 Extension Header. + Ok(Self::ExtHeader) + } else if raw[0] >> 3 == DISPATCH_UDP_HEADER { + // We have a compressed UDP header. + Ok(Self::UdpHeader) + } else { + Err(Error) + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ExtHeaderId { + HopByHopHeader, + RoutingHeader, + FragmentHeader, + DestinationOptionsHeader, + MobilityHeader, + Header, + Reserved, +} + +impl From<ExtHeaderId> for IpProtocol { + fn from(val: ExtHeaderId) -> Self { + match val { + ExtHeaderId::HopByHopHeader => Self::HopByHop, + ExtHeaderId::RoutingHeader => Self::Ipv6Route, + ExtHeaderId::FragmentHeader => Self::Ipv6Frag, + ExtHeaderId::DestinationOptionsHeader => Self::Ipv6Opts, + ExtHeaderId::MobilityHeader => Self::Unknown(0), + ExtHeaderId::Header => Self::Unknown(0), + ExtHeaderId::Reserved => Self::Unknown(0), + } + } +} + +/// A read/write wrapper around a 6LoWPAN NHC Extension header. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ExtHeaderPacket<T: AsRef<[u8]>> { + buffer: T, +} + +impl<T: AsRef<[u8]>> ExtHeaderPacket<T> { + /// Input a raw octet buffer with a 6LoWPAN NHC Extension Header structure. + pub const fn new_unchecked(buffer: T) -> Self { + ExtHeaderPacket { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Self> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + + if packet.eid_field() > 7 { + return Err(Error); + } + + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let buffer = self.buffer.as_ref(); + + if buffer.is_empty() { + return Err(Error); + } + + let mut len = 2; + len += self.next_header_size(); + + if len <= buffer.len() { + Ok(()) + } else { + Err(Error) + } + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + get_field!(dispatch_field, 0b1111, 4); + get_field!(eid_field, 0b111, 1); + get_field!(nh_field, 0b1, 0); + + /// Return the Extension Header ID. + pub fn extension_header_id(&self) -> ExtHeaderId { + match self.eid_field() { + 0 => ExtHeaderId::HopByHopHeader, + 1 => ExtHeaderId::RoutingHeader, + 2 => ExtHeaderId::FragmentHeader, + 3 => ExtHeaderId::DestinationOptionsHeader, + 4 => ExtHeaderId::MobilityHeader, + 5 | 6 => ExtHeaderId::Reserved, + 7 => ExtHeaderId::Header, + _ => unreachable!(), + } + } + + /// Return the length field. + pub fn length(&self) -> u8 { + self.buffer.as_ref()[1 + self.next_header_size()] + } + + /// Parse the next header field. + pub fn next_header(&self) -> NextHeader { + if self.nh_field() == 1 { + NextHeader::Compressed + } else { + // The full 8 bits for Next Header are carried in-line. + NextHeader::Uncompressed(IpProtocol::from(self.buffer.as_ref()[1])) + } + } + + /// Return the size of the Next Header field. + fn next_header_size(&self) -> usize { + // If nh is set, then the Next Header is compressed using LOWPAN_NHC + match self.nh_field() { + 0 => 1, + 1 => 0, + _ => unreachable!(), + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> ExtHeaderPacket<&'a T> { + /// Return a pointer to the payload. + pub fn payload(&self) -> &'a [u8] { + let start = 2 + self.next_header_size(); + let len = self.length() as usize; + &self.buffer.as_ref()[start..][..len] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> ExtHeaderPacket<T> { + /// Return a mutable pointer to the payload. + pub fn payload_mut(&mut self) -> &mut [u8] { + let start = 2 + self.next_header_size(); + let len = self.length() as usize; + &mut self.buffer.as_mut()[start..][..len] + } + + /// Set the dispatch field to `0b1110`. + fn set_dispatch_field(&mut self) { + let data = self.buffer.as_mut(); + data[0] = (data[0] & !(0b1111 << 4)) | (DISPATCH_EXT_HEADER << 4); + } + + set_field!(set_eid_field, 0b111, 1); + set_field!(set_nh_field, 0b1, 0); + + /// Set the Extension Header ID field. + fn set_extension_header_id(&mut self, ext_header_id: ExtHeaderId) { + let id = match ext_header_id { + ExtHeaderId::HopByHopHeader => 0, + ExtHeaderId::RoutingHeader => 1, + ExtHeaderId::FragmentHeader => 2, + ExtHeaderId::DestinationOptionsHeader => 3, + ExtHeaderId::MobilityHeader => 4, + ExtHeaderId::Reserved => 5, + ExtHeaderId::Header => 7, + }; + + self.set_eid_field(id); + } + + /// Set the Next Header. + fn set_next_header(&mut self, next_header: NextHeader) { + match next_header { + NextHeader::Compressed => self.set_nh_field(0b1), + NextHeader::Uncompressed(nh) => { + self.set_nh_field(0b0); + + let start = 1; + let data = self.buffer.as_mut(); + data[start] = nh.into(); + } + } + } + + /// Set the length. + fn set_length(&mut self, length: u8) { + let start = 1 + self.next_header_size(); + + let data = self.buffer.as_mut(); + data[start] = length; + } +} + +/// A high-level representation of an 6LoWPAN NHC Extension header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ExtHeaderRepr { + pub ext_header_id: ExtHeaderId, + pub next_header: NextHeader, + pub length: u8, +} + +impl ExtHeaderRepr { + /// Parse a 6LoWPAN NHC Extension Header packet and return a high-level representation. + pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &ExtHeaderPacket<&T>) -> Result<Self> { + // Ensure basic accessors will work. + packet.check_len()?; + + if packet.dispatch_field() != DISPATCH_EXT_HEADER { + return Err(Error); + } + + Ok(Self { + ext_header_id: packet.extension_header_id(), + next_header: packet.next_header(), + length: packet.length(), + }) + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub fn buffer_len(&self) -> usize { + let mut len = 1; // The minimal header size + + if self.next_header != NextHeader::Compressed { + len += 1; + } + + len += 1; // The length + + len + } + + /// Emit a high-level representation into a 6LoWPAN NHC Extension Header packet. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut ExtHeaderPacket<T>) { + packet.set_dispatch_field(); + packet.set_extension_header_id(self.ext_header_id); + packet.set_next_header(self.next_header); + packet.set_length(self.length); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::wire::{Ipv6RoutingHeader, Ipv6RoutingRepr}; + + #[cfg(feature = "proto-rpl")] + use crate::wire::{ + Ipv6Option, Ipv6OptionRepr, Ipv6OptionsIterator, RplHopByHopRepr, RplInstanceId, + }; + + #[cfg(feature = "proto-rpl")] + const RPL_HOP_BY_HOP_PACKET: [u8; 9] = [0xe0, 0x3a, 0x06, 0x63, 0x04, 0x00, 0x1e, 0x03, 0x00]; + + const ROUTING_SR_PACKET: [u8; 32] = [ + 0xe3, 0x1e, 0x03, 0x03, 0x99, 0x30, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, 0x05, 0x00, 0x05, + 0x06, 0x00, 0x06, 0x00, 0x06, 0x00, 0x06, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, + 0x00, 0x00, + ]; + + #[test] + #[cfg(feature = "proto-rpl")] + fn test_rpl_hop_by_hop_option_deconstruct() { + let header = ExtHeaderPacket::new_checked(&RPL_HOP_BY_HOP_PACKET).unwrap(); + assert_eq!( + header.next_header(), + NextHeader::Uncompressed(IpProtocol::Icmpv6) + ); + assert_eq!(header.extension_header_id(), ExtHeaderId::HopByHopHeader); + + let options = header.payload(); + let mut options = Ipv6OptionsIterator::new(options); + let rpl_repr = options.next().unwrap(); + let rpl_repr = rpl_repr.unwrap(); + + match rpl_repr { + Ipv6OptionRepr::Rpl(rpl) => { + assert_eq!( + rpl, + RplHopByHopRepr { + down: false, + rank_error: false, + forwarding_error: false, + instance_id: RplInstanceId::from(0x1e), + sender_rank: 0x0300, + } + ); + } + _ => unreachable!(), + } + } + + #[test] + #[cfg(feature = "proto-rpl")] + fn test_rpl_hop_by_hop_option_emit() { + let repr = Ipv6OptionRepr::Rpl(RplHopByHopRepr { + down: false, + rank_error: false, + forwarding_error: false, + instance_id: RplInstanceId::from(0x1e), + sender_rank: 0x0300, + }); + + let ext_hdr = ExtHeaderRepr { + ext_header_id: ExtHeaderId::HopByHopHeader, + next_header: NextHeader::Uncompressed(IpProtocol::Icmpv6), + length: repr.buffer_len() as u8, + }; + + let mut buffer = vec![0u8; ext_hdr.buffer_len() + repr.buffer_len()]; + ext_hdr.emit(&mut ExtHeaderPacket::new_unchecked( + &mut buffer[..ext_hdr.buffer_len()], + )); + repr.emit(&mut Ipv6Option::new_unchecked( + &mut buffer[ext_hdr.buffer_len()..], + )); + + assert_eq!(&buffer[..], RPL_HOP_BY_HOP_PACKET); + } + + #[test] + fn test_source_routing_deconstruct() { + let header = ExtHeaderPacket::new_checked(&ROUTING_SR_PACKET).unwrap(); + assert_eq!(header.next_header(), NextHeader::Compressed); + assert_eq!(header.extension_header_id(), ExtHeaderId::RoutingHeader); + + let routing_hdr = Ipv6RoutingHeader::new_checked(header.payload()).unwrap(); + let repr = Ipv6RoutingRepr::parse(&routing_hdr).unwrap(); + assert_eq!( + repr, + Ipv6RoutingRepr::Rpl { + segments_left: 3, + cmpr_i: 9, + cmpr_e: 9, + pad: 3, + addresses: &[ + 0x05, 0x00, 0x05, 0x00, 0x05, 0x00, 0x05, 0x06, 0x00, 0x06, 0x00, 0x06, 0x00, + 0x06, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x00, 0x00 + ], + } + ); + } + + #[test] + fn test_source_routing_emit() { + let routing_hdr = Ipv6RoutingRepr::Rpl { + segments_left: 3, + cmpr_i: 9, + cmpr_e: 9, + pad: 3, + addresses: &[ + 0x05, 0x00, 0x05, 0x00, 0x05, 0x00, 0x05, 0x06, 0x00, 0x06, 0x00, 0x06, 0x00, 0x06, + 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x00, 0x00, + ], + }; + + let ext_hdr = ExtHeaderRepr { + ext_header_id: ExtHeaderId::RoutingHeader, + next_header: NextHeader::Compressed, + length: routing_hdr.buffer_len() as u8, + }; + + let mut buffer = vec![0u8; ext_hdr.buffer_len() + routing_hdr.buffer_len()]; + ext_hdr.emit(&mut ExtHeaderPacket::new_unchecked( + &mut buffer[..ext_hdr.buffer_len()], + )); + routing_hdr.emit(&mut Ipv6RoutingHeader::new_unchecked( + &mut buffer[ext_hdr.buffer_len()..], + )); + + assert_eq!(&buffer[..], ROUTING_SR_PACKET); + } +} + +/// A read/write wrapper around a 6LoWPAN_NHC UDP frame. +/// [RFC 6282 § 4.3] specifies the format of the header. +/// +/// The base header has the following format: +/// ```txt +/// 0 1 2 3 4 5 6 7 +/// +---+---+---+---+---+---+---+---+ +/// | 1 | 1 | 1 | 1 | 0 | C | P | +/// +---+---+---+---+---+---+---+---+ +/// With: +/// - C: checksum, specifies if the checksum is elided. +/// - P: ports, specifies if the ports are elided. +/// ``` +/// +/// [RFC 6282 § 4.3]: https://datatracker.ietf.org/doc/html/rfc6282#section-4.3 +#[derive(Debug, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct UdpNhcPacket<T: AsRef<[u8]>> { + buffer: T, +} + +impl<T: AsRef<[u8]>> UdpNhcPacket<T> { + /// Input a raw octet buffer with a LOWPAN_NHC frame structure for UDP. + pub const fn new_unchecked(buffer: T) -> Self { + Self { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Self> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error::Truncated)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let buffer = self.buffer.as_ref(); + + if buffer.is_empty() { + return Err(Error); + } + + let index = 1 + self.ports_size() + self.checksum_size(); + if index > buffer.len() { + return Err(Error); + } + + Ok(()) + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + get_field!(dispatch_field, 0b11111, 3); + get_field!(checksum_field, 0b1, 2); + get_field!(ports_field, 0b11, 0); + + /// Returns the index of the start of the next header compressed fields. + const fn nhc_fields_start(&self) -> usize { + 1 + } + + /// Return the source port number. + pub fn src_port(&self) -> u16 { + match self.ports_field() { + 0b00 | 0b01 => { + // The full 16 bits are carried in-line. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start(); + + NetworkEndian::read_u16(&data[start..start + 2]) + } + 0b10 => { + // The first 8 bits are elided. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start(); + + 0xf000 + data[start] as u16 + } + 0b11 => { + // The first 12 bits are elided. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start(); + + 0xf0b0 + (data[start] >> 4) as u16 + } + _ => unreachable!(), + } + } + + /// Return the destination port number. + pub fn dst_port(&self) -> u16 { + match self.ports_field() { + 0b00 => { + // The full 16 bits are carried in-line. + let data = self.buffer.as_ref(); + let idx = self.nhc_fields_start(); + + NetworkEndian::read_u16(&data[idx + 2..idx + 4]) + } + 0b01 => { + // The first 8 bits are elided. + let data = self.buffer.as_ref(); + let idx = self.nhc_fields_start(); + + 0xf000 + data[idx] as u16 + } + 0b10 => { + // The full 16 bits are carried in-line. + let data = self.buffer.as_ref(); + let idx = self.nhc_fields_start(); + + NetworkEndian::read_u16(&data[idx + 1..idx + 1 + 2]) + } + 0b11 => { + // The first 12 bits are elided. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start(); + + 0xf0b0 + (data[start] & 0xff) as u16 + } + _ => unreachable!(), + } + } + + /// Return the checksum. + pub fn checksum(&self) -> Option<u16> { + if self.checksum_field() == 0b0 { + // The first 12 bits are elided. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start() + self.ports_size(); + Some(NetworkEndian::read_u16(&data[start..start + 2])) + } else { + // The checksum is elided and needs to be recomputed on the 6LoWPAN termination point. + None + } + } + + // Return the size of the checksum field. + pub(crate) fn checksum_size(&self) -> usize { + match self.checksum_field() { + 0b0 => 2, + 0b1 => 0, + _ => unreachable!(), + } + } + + /// Returns the total size of both port numbers. + pub(crate) fn ports_size(&self) -> usize { + match self.ports_field() { + 0b00 => 4, // 16 bits + 16 bits + 0b01 => 3, // 16 bits + 8 bits + 0b10 => 3, // 8 bits + 16 bits + 0b11 => 1, // 4 bits + 4 bits + _ => unreachable!(), + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> UdpNhcPacket<&'a T> { + /// Return a pointer to the payload. + pub fn payload(&self) -> &'a [u8] { + let start = 1 + self.ports_size() + self.checksum_size(); + &self.buffer.as_ref()[start..] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> UdpNhcPacket<T> { + /// Return a mutable pointer to the payload. + pub fn payload_mut(&mut self) -> &mut [u8] { + let start = 1 + self.ports_size() + 2; // XXX(thvdveld): we assume we put the checksum inlined. + &mut self.buffer.as_mut()[start..] + } + + /// Set the dispatch field to `0b11110`. + fn set_dispatch_field(&mut self) { + let data = self.buffer.as_mut(); + data[0] = (data[0] & !(0b11111 << 3)) | (DISPATCH_UDP_HEADER << 3); + } + + set_field!(set_checksum_field, 0b1, 2); + set_field!(set_ports_field, 0b11, 0); + + fn set_ports(&mut self, src_port: u16, dst_port: u16) { + let mut idx = 1; + + match (src_port, dst_port) { + (0xf0b0..=0xf0bf, 0xf0b0..=0xf0bf) => { + // We can compress both the source and destination ports. + self.set_ports_field(0b11); + let data = self.buffer.as_mut(); + data[idx] = (((src_port - 0xf0b0) as u8) << 4) & ((dst_port - 0xf0b0) as u8); + } + (0xf000..=0xf0ff, _) => { + // We can compress the source port, but not the destination port. + self.set_ports_field(0b10); + let data = self.buffer.as_mut(); + data[idx] = (src_port - 0xf000) as u8; + idx += 1; + + NetworkEndian::write_u16(&mut data[idx..idx + 2], dst_port); + } + (_, 0xf000..=0xf0ff) => { + // We can compress the destination port, but not the source port. + self.set_ports_field(0b01); + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[idx..idx + 2], src_port); + idx += 2; + data[idx] = (dst_port - 0xf000) as u8; + } + (_, _) => { + // We cannot compress any port. + self.set_ports_field(0b00); + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[idx..idx + 2], src_port); + idx += 2; + NetworkEndian::write_u16(&mut data[idx..idx + 2], dst_port); + } + }; + } + + fn set_checksum(&mut self, checksum: u16) { + self.set_checksum_field(0b0); + let idx = 1 + self.ports_size(); + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[idx..idx + 2], checksum); + } +} + +/// A high-level representation of a 6LoWPAN NHC UDP header. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct UdpNhcRepr(pub UdpRepr); + +impl<'a> UdpNhcRepr { + /// Parse a 6LoWPAN NHC UDP packet and return a high-level representation. + pub fn parse<T: AsRef<[u8]> + ?Sized>( + packet: &UdpNhcPacket<&'a T>, + src_addr: &ipv6::Address, + dst_addr: &ipv6::Address, + checksum_caps: &ChecksumCapabilities, + ) -> Result<Self> { + packet.check_len()?; + + if packet.dispatch_field() != DISPATCH_UDP_HEADER { + return Err(Error); + } + + if checksum_caps.udp.rx() { + let payload_len = packet.payload().len(); + let chk_sum = !checksum::combine(&[ + checksum::pseudo_header( + &IpAddress::Ipv6(*src_addr), + &IpAddress::Ipv6(*dst_addr), + crate::wire::ip::Protocol::Udp, + payload_len as u32 + 8, + ), + packet.src_port(), + packet.dst_port(), + payload_len as u16 + 8, + checksum::data(packet.payload()), + ]); + + if let Some(checksum) = packet.checksum() { + if chk_sum != checksum { + return Err(Error); + } + } + } + + Ok(Self(UdpRepr { + src_port: packet.src_port(), + dst_port: packet.dst_port(), + })) + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub fn header_len(&self) -> usize { + let mut len = 1; // The minimal header size + + len += 2; // XXX We assume we will add the checksum at the end + + // Check if we can compress the source and destination ports + match (self.src_port, self.dst_port) { + (0xf0b0..=0xf0bf, 0xf0b0..=0xf0bf) => len + 1, + (0xf000..=0xf0ff, _) | (_, 0xf000..=0xf0ff) => len + 3, + (_, _) => len + 4, + } + } + + /// Emit a high-level representation into a LOWPAN_NHC UDP header. + pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>( + &self, + packet: &mut UdpNhcPacket<T>, + src_addr: &Address, + dst_addr: &Address, + payload_len: usize, + emit_payload: impl FnOnce(&mut [u8]), + checksum_caps: &ChecksumCapabilities, + ) { + packet.set_dispatch_field(); + packet.set_ports(self.src_port, self.dst_port); + emit_payload(packet.payload_mut()); + + if checksum_caps.udp.tx() { + let chk_sum = !checksum::combine(&[ + checksum::pseudo_header( + &IpAddress::Ipv6(*src_addr), + &IpAddress::Ipv6(*dst_addr), + crate::wire::ip::Protocol::Udp, + payload_len as u32 + 8, + ), + self.src_port, + self.dst_port, + payload_len as u16 + 8, + checksum::data(packet.payload_mut()), + ]); + + packet.set_checksum(chk_sum); + } + } +} + +impl core::ops::Deref for UdpNhcRepr { + type Target = UdpRepr; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl core::ops::DerefMut for UdpNhcRepr { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn ext_header_nh_inlined() { + let bytes = [0xe2, 0x3a, 0x6, 0x3, 0x0, 0xff, 0x0, 0x0, 0x0]; + + let packet = ExtHeaderPacket::new_checked(&bytes[..]).unwrap(); + assert_eq!(packet.next_header_size(), 1); + assert_eq!(packet.length(), 6); + assert_eq!(packet.dispatch_field(), DISPATCH_EXT_HEADER); + assert_eq!(packet.extension_header_id(), ExtHeaderId::RoutingHeader); + assert_eq!( + packet.next_header(), + NextHeader::Uncompressed(IpProtocol::Icmpv6) + ); + + assert_eq!(packet.payload(), [0x03, 0x00, 0xff, 0x00, 0x00, 0x00]); + } + + #[test] + fn ext_header_nh_elided() { + let bytes = [0xe3, 0x06, 0x03, 0x00, 0xff, 0x00, 0x00, 0x00]; + + let packet = ExtHeaderPacket::new_checked(&bytes[..]).unwrap(); + assert_eq!(packet.next_header_size(), 0); + assert_eq!(packet.length(), 6); + assert_eq!(packet.dispatch_field(), DISPATCH_EXT_HEADER); + assert_eq!(packet.extension_header_id(), ExtHeaderId::RoutingHeader); + assert_eq!(packet.next_header(), NextHeader::Compressed); + + assert_eq!(packet.payload(), [0x03, 0x00, 0xff, 0x00, 0x00, 0x00]); + } + + #[test] + fn ext_header_emit() { + let ext_header = ExtHeaderRepr { + ext_header_id: ExtHeaderId::RoutingHeader, + next_header: NextHeader::Compressed, + length: 6, + }; + + let len = ext_header.buffer_len(); + let mut buffer = [0u8; 127]; + let mut packet = ExtHeaderPacket::new_unchecked(&mut buffer[..len]); + ext_header.emit(&mut packet); + + assert_eq!(packet.dispatch_field(), DISPATCH_EXT_HEADER); + assert_eq!(packet.next_header(), NextHeader::Compressed); + assert_eq!(packet.extension_header_id(), ExtHeaderId::RoutingHeader); + } + + #[test] + fn udp_nhc_fields() { + let bytes = [0xf0, 0x16, 0x2e, 0x22, 0x3d, 0x28, 0xc4]; + + let packet = UdpNhcPacket::new_checked(&bytes[..]).unwrap(); + assert_eq!(packet.dispatch_field(), DISPATCH_UDP_HEADER); + assert_eq!(packet.checksum(), Some(0x28c4)); + assert_eq!(packet.src_port(), 5678); + assert_eq!(packet.dst_port(), 8765); + } + + #[test] + fn udp_emit() { + let udp = UdpNhcRepr(UdpRepr { + src_port: 0xf0b1, + dst_port: 0xf001, + }); + + let payload = b"Hello World!"; + + let src_addr = ipv6::Address::default(); + let dst_addr = ipv6::Address::default(); + + let len = udp.header_len() + payload.len(); + let mut buffer = [0u8; 127]; + let mut packet = UdpNhcPacket::new_unchecked(&mut buffer[..len]); + udp.emit( + &mut packet, + &src_addr, + &dst_addr, + payload.len(), + |buf| buf.copy_from_slice(&payload[..]), + &ChecksumCapabilities::default(), + ); + + assert_eq!(packet.dispatch_field(), DISPATCH_UDP_HEADER); + assert_eq!(packet.src_port(), 0xf0b1); + assert_eq!(packet.dst_port(), 0xf001); + assert_eq!(packet.payload_mut(), b"Hello World!"); + } +} diff --git a/src/wire/tcp.rs b/src/wire/tcp.rs new file mode 100644 index 0000000..2482143 --- /dev/null +++ b/src/wire/tcp.rs @@ -0,0 +1,1331 @@ +use byteorder::{ByteOrder, NetworkEndian}; +use core::{cmp, fmt, i32, ops}; + +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::checksum; +use crate::wire::{IpAddress, IpProtocol}; + +/// A TCP sequence number. +/// +/// A sequence number is a monotonically advancing integer modulo 2<sup>32</sup>. +/// Sequence numbers do not have a discontiguity when compared pairwise across a signed overflow. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)] +pub struct SeqNumber(pub i32); + +impl SeqNumber { + pub fn max(self, rhs: Self) -> Self { + if self > rhs { + self + } else { + rhs + } + } + + pub fn min(self, rhs: Self) -> Self { + if self < rhs { + self + } else { + rhs + } + } +} + +impl fmt::Display for SeqNumber { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0 as u32) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for SeqNumber { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "{}", self.0 as u32); + } +} + +impl ops::Add<usize> for SeqNumber { + type Output = SeqNumber; + + fn add(self, rhs: usize) -> SeqNumber { + if rhs > i32::MAX as usize { + panic!("attempt to add to sequence number with unsigned overflow") + } + SeqNumber(self.0.wrapping_add(rhs as i32)) + } +} + +impl ops::Sub<usize> for SeqNumber { + type Output = SeqNumber; + + fn sub(self, rhs: usize) -> SeqNumber { + if rhs > i32::MAX as usize { + panic!("attempt to subtract to sequence number with unsigned overflow") + } + SeqNumber(self.0.wrapping_sub(rhs as i32)) + } +} + +impl ops::AddAssign<usize> for SeqNumber { + fn add_assign(&mut self, rhs: usize) { + *self = *self + rhs; + } +} + +impl ops::Sub for SeqNumber { + type Output = usize; + + fn sub(self, rhs: SeqNumber) -> usize { + let result = self.0.wrapping_sub(rhs.0); + if result < 0 { + panic!("attempt to subtract sequence numbers with underflow") + } + result as usize + } +} + +impl cmp::PartialOrd for SeqNumber { + fn partial_cmp(&self, other: &SeqNumber) -> Option<cmp::Ordering> { + self.0.wrapping_sub(other.0).partial_cmp(&0) + } +} + +/// A read/write wrapper around a Transmission Control Protocol packet buffer. +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +mod field { + #![allow(non_snake_case)] + + use crate::wire::field::*; + + pub const SRC_PORT: Field = 0..2; + pub const DST_PORT: Field = 2..4; + pub const SEQ_NUM: Field = 4..8; + pub const ACK_NUM: Field = 8..12; + pub const FLAGS: Field = 12..14; + pub const WIN_SIZE: Field = 14..16; + pub const CHECKSUM: Field = 16..18; + pub const URGENT: Field = 18..20; + + pub const fn OPTIONS(length: u8) -> Field { + URGENT.end..(length as usize) + } + + pub const FLG_FIN: u16 = 0x001; + pub const FLG_SYN: u16 = 0x002; + pub const FLG_RST: u16 = 0x004; + pub const FLG_PSH: u16 = 0x008; + pub const FLG_ACK: u16 = 0x010; + pub const FLG_URG: u16 = 0x020; + pub const FLG_ECE: u16 = 0x040; + pub const FLG_CWR: u16 = 0x080; + pub const FLG_NS: u16 = 0x100; + + pub const OPT_END: u8 = 0x00; + pub const OPT_NOP: u8 = 0x01; + pub const OPT_MSS: u8 = 0x02; + pub const OPT_WS: u8 = 0x03; + pub const OPT_SACKPERM: u8 = 0x04; + pub const OPT_SACKRNG: u8 = 0x05; +} + +pub const HEADER_LEN: usize = field::URGENT.end; + +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with TCP packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// Returns `Err(Error)` if the header length field has a value smaller + /// than the minimal header length. + /// + /// The result of this check is invalidated by calling [set_header_len]. + /// + /// [set_header_len]: #method.set_header_len + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::URGENT.end { + Err(Error) + } else { + let header_len = self.header_len() as usize; + if len < header_len || header_len < field::URGENT.end { + Err(Error) + } else { + Ok(()) + } + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the source port field. + #[inline] + pub fn src_port(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::SRC_PORT]) + } + + /// Return the destination port field. + #[inline] + pub fn dst_port(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::DST_PORT]) + } + + /// Return the sequence number field. + #[inline] + pub fn seq_number(&self) -> SeqNumber { + let data = self.buffer.as_ref(); + SeqNumber(NetworkEndian::read_i32(&data[field::SEQ_NUM])) + } + + /// Return the acknowledgement number field. + #[inline] + pub fn ack_number(&self) -> SeqNumber { + let data = self.buffer.as_ref(); + SeqNumber(NetworkEndian::read_i32(&data[field::ACK_NUM])) + } + + /// Return the FIN flag. + #[inline] + pub fn fin(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + raw & field::FLG_FIN != 0 + } + + /// Return the SYN flag. + #[inline] + pub fn syn(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + raw & field::FLG_SYN != 0 + } + + /// Return the RST flag. + #[inline] + pub fn rst(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + raw & field::FLG_RST != 0 + } + + /// Return the PSH flag. + #[inline] + pub fn psh(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + raw & field::FLG_PSH != 0 + } + + /// Return the ACK flag. + #[inline] + pub fn ack(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + raw & field::FLG_ACK != 0 + } + + /// Return the URG flag. + #[inline] + pub fn urg(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + raw & field::FLG_URG != 0 + } + + /// Return the ECE flag. + #[inline] + pub fn ece(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + raw & field::FLG_ECE != 0 + } + + /// Return the CWR flag. + #[inline] + pub fn cwr(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + raw & field::FLG_CWR != 0 + } + + /// Return the NS flag. + #[inline] + pub fn ns(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + raw & field::FLG_NS != 0 + } + + /// Return the header length, in octets. + #[inline] + pub fn header_len(&self) -> u8 { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + ((raw >> 12) * 4) as u8 + } + + /// Return the window size field. + #[inline] + pub fn window_len(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::WIN_SIZE]) + } + + /// Return the checksum field. + #[inline] + pub fn checksum(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::CHECKSUM]) + } + + /// Return the urgent pointer field. + #[inline] + pub fn urgent_at(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::URGENT]) + } + + /// Return the length of the segment, in terms of sequence space. + pub fn segment_len(&self) -> usize { + let data = self.buffer.as_ref(); + let mut length = data.len() - self.header_len() as usize; + if self.syn() { + length += 1 + } + if self.fin() { + length += 1 + } + length + } + + /// Returns whether the selective acknowledgement SYN flag is set or not. + pub fn selective_ack_permitted(&self) -> Result<bool> { + let data = self.buffer.as_ref(); + let mut options = &data[field::OPTIONS(self.header_len())]; + while !options.is_empty() { + let (next_options, option) = TcpOption::parse(options)?; + if option == TcpOption::SackPermitted { + return Ok(true); + } + options = next_options; + } + Ok(false) + } + + /// Return the selective acknowledgement ranges, if any. If there are none in the packet, an + /// array of ``None`` values will be returned. + /// + pub fn selective_ack_ranges(&self) -> Result<[Option<(u32, u32)>; 3]> { + let data = self.buffer.as_ref(); + let mut options = &data[field::OPTIONS(self.header_len())]; + while !options.is_empty() { + let (next_options, option) = TcpOption::parse(options)?; + if let TcpOption::SackRange(slice) = option { + return Ok(slice); + } + options = next_options; + } + Ok([None, None, None]) + } + + /// Validate the packet checksum. + /// + /// # Panics + /// This function panics unless `src_addr` and `dst_addr` belong to the same family, + /// and that family is IPv4 or IPv6. + /// + /// # Fuzzing + /// This function always returns `true` when fuzzing. + pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool { + if cfg!(fuzzing) { + return true; + } + + let data = self.buffer.as_ref(); + checksum::combine(&[ + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Tcp, data.len() as u32), + checksum::data(data), + ]) == !0 + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return a pointer to the options. + #[inline] + pub fn options(&self) -> &'a [u8] { + let header_len = self.header_len(); + let data = self.buffer.as_ref(); + &data[field::OPTIONS(header_len)] + } + + /// Return a pointer to the payload. + #[inline] + pub fn payload(&self) -> &'a [u8] { + let header_len = self.header_len() as usize; + let data = self.buffer.as_ref(); + &data[header_len..] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the source port field. + #[inline] + pub fn set_src_port(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::SRC_PORT], value) + } + + /// Set the destination port field. + #[inline] + pub fn set_dst_port(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::DST_PORT], value) + } + + /// Set the sequence number field. + #[inline] + pub fn set_seq_number(&mut self, value: SeqNumber) { + let data = self.buffer.as_mut(); + NetworkEndian::write_i32(&mut data[field::SEQ_NUM], value.0) + } + + /// Set the acknowledgement number field. + #[inline] + pub fn set_ack_number(&mut self, value: SeqNumber) { + let data = self.buffer.as_mut(); + NetworkEndian::write_i32(&mut data[field::ACK_NUM], value.0) + } + + /// Clear the entire flags field. + #[inline] + pub fn clear_flags(&mut self) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = raw & !0x0fff; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the FIN flag. + #[inline] + pub fn set_fin(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = if value { + raw | field::FLG_FIN + } else { + raw & !field::FLG_FIN + }; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the SYN flag. + #[inline] + pub fn set_syn(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = if value { + raw | field::FLG_SYN + } else { + raw & !field::FLG_SYN + }; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the RST flag. + #[inline] + pub fn set_rst(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = if value { + raw | field::FLG_RST + } else { + raw & !field::FLG_RST + }; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the PSH flag. + #[inline] + pub fn set_psh(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = if value { + raw | field::FLG_PSH + } else { + raw & !field::FLG_PSH + }; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the ACK flag. + #[inline] + pub fn set_ack(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = if value { + raw | field::FLG_ACK + } else { + raw & !field::FLG_ACK + }; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the URG flag. + #[inline] + pub fn set_urg(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = if value { + raw | field::FLG_URG + } else { + raw & !field::FLG_URG + }; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the ECE flag. + #[inline] + pub fn set_ece(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = if value { + raw | field::FLG_ECE + } else { + raw & !field::FLG_ECE + }; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the CWR flag. + #[inline] + pub fn set_cwr(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = if value { + raw | field::FLG_CWR + } else { + raw & !field::FLG_CWR + }; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the NS flag. + #[inline] + pub fn set_ns(&mut self, value: bool) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = if value { + raw | field::FLG_NS + } else { + raw & !field::FLG_NS + }; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the header length, in octets. + #[inline] + pub fn set_header_len(&mut self, value: u8) { + let data = self.buffer.as_mut(); + let raw = NetworkEndian::read_u16(&data[field::FLAGS]); + let raw = (raw & !0xf000) | ((value as u16) / 4) << 12; + NetworkEndian::write_u16(&mut data[field::FLAGS], raw) + } + + /// Set the window size field. + #[inline] + pub fn set_window_len(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::WIN_SIZE], value) + } + + /// Set the checksum field. + #[inline] + pub fn set_checksum(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::CHECKSUM], value) + } + + /// Set the urgent pointer field. + #[inline] + pub fn set_urgent_at(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::URGENT], value) + } + + /// Compute and fill in the header checksum. + /// + /// # Panics + /// This function panics unless `src_addr` and `dst_addr` belong to the same family, + /// and that family is IPv4 or IPv6. + pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) { + self.set_checksum(0); + let checksum = { + let data = self.buffer.as_ref(); + !checksum::combine(&[ + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Tcp, data.len() as u32), + checksum::data(data), + ]) + }; + self.set_checksum(checksum) + } + + /// Return a pointer to the options. + #[inline] + pub fn options_mut(&mut self) -> &mut [u8] { + let header_len = self.header_len(); + let data = self.buffer.as_mut(); + &mut data[field::OPTIONS(header_len)] + } + + /// Return a mutable pointer to the payload data. + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + let header_len = self.header_len() as usize; + let data = self.buffer.as_mut(); + &mut data[header_len..] + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +/// A representation of a single TCP option. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum TcpOption<'a> { + EndOfList, + NoOperation, + MaxSegmentSize(u16), + WindowScale(u8), + SackPermitted, + SackRange([Option<(u32, u32)>; 3]), + Unknown { kind: u8, data: &'a [u8] }, +} + +impl<'a> TcpOption<'a> { + pub fn parse(buffer: &'a [u8]) -> Result<(&'a [u8], TcpOption<'a>)> { + let (length, option); + match *buffer.first().ok_or(Error)? { + field::OPT_END => { + length = 1; + option = TcpOption::EndOfList; + } + field::OPT_NOP => { + length = 1; + option = TcpOption::NoOperation; + } + kind => { + length = *buffer.get(1).ok_or(Error)? as usize; + let data = buffer.get(2..length).ok_or(Error)?; + match (kind, length) { + (field::OPT_END, _) | (field::OPT_NOP, _) => unreachable!(), + (field::OPT_MSS, 4) => { + option = TcpOption::MaxSegmentSize(NetworkEndian::read_u16(data)) + } + (field::OPT_MSS, _) => return Err(Error), + (field::OPT_WS, 3) => option = TcpOption::WindowScale(data[0]), + (field::OPT_WS, _) => return Err(Error), + (field::OPT_SACKPERM, 2) => option = TcpOption::SackPermitted, + (field::OPT_SACKPERM, _) => return Err(Error), + (field::OPT_SACKRNG, n) => { + if n < 10 || (n - 2) % 8 != 0 { + return Err(Error); + } + if n > 26 { + // It's possible for a remote to send 4 SACK blocks, but extremely rare. + // Better to "lose" that 4th block and save the extra RAM and CPU + // cycles in the vastly more common case. + // + // RFC 2018: SACK option that specifies n blocks will have a length of + // 8*n+2 bytes, so the 40 bytes available for TCP options can specify a + // maximum of 4 blocks. It is expected that SACK will often be used in + // conjunction with the Timestamp option used for RTTM [...] thus a + // maximum of 3 SACK blocks will be allowed in this case. + net_debug!("sACK with >3 blocks, truncating to 3"); + } + let mut sack_ranges: [Option<(u32, u32)>; 3] = [None; 3]; + + // RFC 2018: Each contiguous block of data queued at the data receiver is + // defined in the SACK option by two 32-bit unsigned integers in network + // byte order[...] + sack_ranges.iter_mut().enumerate().for_each(|(i, nmut)| { + let left = i * 8; + *nmut = if left < data.len() { + let mid = left + 4; + let right = mid + 4; + let range_left = NetworkEndian::read_u32(&data[left..mid]); + let range_right = NetworkEndian::read_u32(&data[mid..right]); + Some((range_left, range_right)) + } else { + None + }; + }); + option = TcpOption::SackRange(sack_ranges); + } + (_, _) => option = TcpOption::Unknown { kind, data }, + } + } + } + Ok((&buffer[length..], option)) + } + + pub fn buffer_len(&self) -> usize { + match *self { + TcpOption::EndOfList => 1, + TcpOption::NoOperation => 1, + TcpOption::MaxSegmentSize(_) => 4, + TcpOption::WindowScale(_) => 3, + TcpOption::SackPermitted => 2, + TcpOption::SackRange(s) => s.iter().filter(|s| s.is_some()).count() * 8 + 2, + TcpOption::Unknown { data, .. } => 2 + data.len(), + } + } + + pub fn emit<'b>(&self, buffer: &'b mut [u8]) -> &'b mut [u8] { + let length; + match *self { + TcpOption::EndOfList => { + length = 1; + // There may be padding space which also should be initialized. + for p in buffer.iter_mut() { + *p = field::OPT_END; + } + } + TcpOption::NoOperation => { + length = 1; + buffer[0] = field::OPT_NOP; + } + _ => { + length = self.buffer_len(); + buffer[1] = length as u8; + match self { + &TcpOption::EndOfList | &TcpOption::NoOperation => unreachable!(), + &TcpOption::MaxSegmentSize(value) => { + buffer[0] = field::OPT_MSS; + NetworkEndian::write_u16(&mut buffer[2..], value) + } + &TcpOption::WindowScale(value) => { + buffer[0] = field::OPT_WS; + buffer[2] = value; + } + &TcpOption::SackPermitted => { + buffer[0] = field::OPT_SACKPERM; + } + &TcpOption::SackRange(slice) => { + buffer[0] = field::OPT_SACKRNG; + slice + .iter() + .filter(|s| s.is_some()) + .enumerate() + .for_each(|(i, s)| { + let (first, second) = *s.as_ref().unwrap(); + let pos = i * 8 + 2; + NetworkEndian::write_u32(&mut buffer[pos..], first); + NetworkEndian::write_u32(&mut buffer[pos + 4..], second); + }); + } + &TcpOption::Unknown { + kind, + data: provided, + } => { + buffer[0] = kind; + buffer[2..].copy_from_slice(provided) + } + } + } + } + &mut buffer[length..] + } +} + +/// The possible control flags of a Transmission Control Protocol packet. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Control { + None, + Psh, + Syn, + Fin, + Rst, +} + +#[allow(clippy::len_without_is_empty)] +impl Control { + /// Return the length of a control flag, in terms of sequence space. + pub const fn len(self) -> usize { + match self { + Control::Syn | Control::Fin => 1, + _ => 0, + } + } + + /// Turn the PSH flag into no flag, and keep the rest as-is. + pub const fn quash_psh(self) -> Control { + match self { + Control::Psh => Control::None, + _ => self, + } + } +} + +/// A high-level representation of a Transmission Control Protocol packet. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct Repr<'a> { + pub src_port: u16, + pub dst_port: u16, + pub control: Control, + pub seq_number: SeqNumber, + pub ack_number: Option<SeqNumber>, + pub window_len: u16, + pub window_scale: Option<u8>, + pub max_seg_size: Option<u16>, + pub sack_permitted: bool, + pub sack_ranges: [Option<(u32, u32)>; 3], + pub payload: &'a [u8], +} + +impl<'a> Repr<'a> { + /// Parse a Transmission Control Protocol packet and return a high-level representation. + pub fn parse<T>( + packet: &Packet<&'a T>, + src_addr: &IpAddress, + dst_addr: &IpAddress, + checksum_caps: &ChecksumCapabilities, + ) -> Result<Repr<'a>> + where + T: AsRef<[u8]> + ?Sized, + { + // Source and destination ports must be present. + if packet.src_port() == 0 { + return Err(Error); + } + if packet.dst_port() == 0 { + return Err(Error); + } + // Valid checksum is expected. + if checksum_caps.tcp.rx() && !packet.verify_checksum(src_addr, dst_addr) { + return Err(Error); + } + + let control = match (packet.syn(), packet.fin(), packet.rst(), packet.psh()) { + (false, false, false, false) => Control::None, + (false, false, false, true) => Control::Psh, + (true, false, false, _) => Control::Syn, + (false, true, false, _) => Control::Fin, + (false, false, true, _) => Control::Rst, + _ => return Err(Error), + }; + let ack_number = match packet.ack() { + true => Some(packet.ack_number()), + false => None, + }; + // The PSH flag is ignored. + // The URG flag and the urgent field is ignored. This behavior is standards-compliant, + // however, most deployed systems (e.g. Linux) are *not* standards-compliant, and would + // cut the byte at the urgent pointer from the stream. + + let mut max_seg_size = None; + let mut window_scale = None; + let mut options = packet.options(); + let mut sack_permitted = false; + let mut sack_ranges = [None, None, None]; + while !options.is_empty() { + let (next_options, option) = TcpOption::parse(options)?; + match option { + TcpOption::EndOfList => break, + TcpOption::NoOperation => (), + TcpOption::MaxSegmentSize(value) => max_seg_size = Some(value), + TcpOption::WindowScale(value) => { + // RFC 1323: Thus, the shift count must be limited to 14 (which allows windows + // of 2**30 = 1 Gigabyte). If a Window Scale option is received with a shift.cnt + // value exceeding 14, the TCP should log the error but use 14 instead of the + // specified value. + window_scale = if value > 14 { + net_debug!( + "{}:{}:{}:{}: parsed window scaling factor >14, setting to 14", + src_addr, + packet.src_port(), + dst_addr, + packet.dst_port() + ); + Some(14) + } else { + Some(value) + }; + } + TcpOption::SackPermitted => sack_permitted = true, + TcpOption::SackRange(slice) => sack_ranges = slice, + _ => (), + } + options = next_options; + } + + Ok(Repr { + src_port: packet.src_port(), + dst_port: packet.dst_port(), + control: control, + seq_number: packet.seq_number(), + ack_number: ack_number, + window_len: packet.window_len(), + window_scale: window_scale, + max_seg_size: max_seg_size, + sack_permitted: sack_permitted, + sack_ranges: sack_ranges, + payload: packet.payload(), + }) + } + + /// Return the length of a header that will be emitted from this high-level representation. + /// + /// This should be used for buffer space calculations. + /// The TCP header length is a multiple of 4. + pub fn header_len(&self) -> usize { + let mut length = field::URGENT.end; + if self.max_seg_size.is_some() { + length += 4 + } + if self.window_scale.is_some() { + length += 3 + } + if self.sack_permitted { + length += 2; + } + let sack_range_len: usize = self + .sack_ranges + .iter() + .map(|o| o.map(|_| 8).unwrap_or(0)) + .sum(); + if sack_range_len > 0 { + length += sack_range_len + 2; + } + if length % 4 != 0 { + length += 4 - length % 4; + } + length + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub fn buffer_len(&self) -> usize { + self.header_len() + self.payload.len() + } + + /// Emit a high-level representation into a Transmission Control Protocol packet. + pub fn emit<T>( + &self, + packet: &mut Packet<&mut T>, + src_addr: &IpAddress, + dst_addr: &IpAddress, + checksum_caps: &ChecksumCapabilities, + ) where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + packet.set_src_port(self.src_port); + packet.set_dst_port(self.dst_port); + packet.set_seq_number(self.seq_number); + packet.set_ack_number(self.ack_number.unwrap_or(SeqNumber(0))); + packet.set_window_len(self.window_len); + packet.set_header_len(self.header_len() as u8); + packet.clear_flags(); + match self.control { + Control::None => (), + Control::Psh => packet.set_psh(true), + Control::Syn => packet.set_syn(true), + Control::Fin => packet.set_fin(true), + Control::Rst => packet.set_rst(true), + } + packet.set_ack(self.ack_number.is_some()); + { + let mut options = packet.options_mut(); + if let Some(value) = self.max_seg_size { + let tmp = options; + options = TcpOption::MaxSegmentSize(value).emit(tmp); + } + if let Some(value) = self.window_scale { + let tmp = options; + options = TcpOption::WindowScale(value).emit(tmp); + } + if self.sack_permitted { + let tmp = options; + options = TcpOption::SackPermitted.emit(tmp); + } else if self.ack_number.is_some() && self.sack_ranges.iter().any(|s| s.is_some()) { + let tmp = options; + options = TcpOption::SackRange(self.sack_ranges).emit(tmp); + } + + if !options.is_empty() { + TcpOption::EndOfList.emit(options); + } + } + packet.set_urgent_at(0); + packet.payload_mut()[..self.payload.len()].copy_from_slice(self.payload); + + if checksum_caps.tcp.tx() { + packet.fill_checksum(src_addr, dst_addr) + } else { + // make sure we get a consistently zeroed checksum, + // since implementations might rely on it + packet.set_checksum(0); + } + } + + /// Return the length of the segment, in terms of sequence space. + pub const fn segment_len(&self) -> usize { + self.payload.len() + self.control.len() + } + + /// Return whether the segment has no flags set (except PSH) and no data. + pub const fn is_empty(&self) -> bool { + match self.control { + _ if !self.payload.is_empty() => false, + Control::Syn | Control::Fin | Control::Rst => false, + Control::None | Control::Psh => true, + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Cannot use Repr::parse because we don't have the IP addresses. + write!(f, "TCP src={} dst={}", self.src_port(), self.dst_port())?; + if self.syn() { + write!(f, " syn")? + } + if self.fin() { + write!(f, " fin")? + } + if self.rst() { + write!(f, " rst")? + } + if self.psh() { + write!(f, " psh")? + } + if self.ece() { + write!(f, " ece")? + } + if self.cwr() { + write!(f, " cwr")? + } + if self.ns() { + write!(f, " ns")? + } + write!(f, " seq={}", self.seq_number())?; + if self.ack() { + write!(f, " ack={}", self.ack_number())?; + } + write!(f, " win={}", self.window_len())?; + if self.urg() { + write!(f, " urg={}", self.urgent_at())?; + } + write!(f, " len={}", self.payload().len())?; + + let mut options = self.options(); + while !options.is_empty() { + let (next_options, option) = match TcpOption::parse(options) { + Ok(res) => res, + Err(err) => return write!(f, " ({err})"), + }; + match option { + TcpOption::EndOfList => break, + TcpOption::NoOperation => (), + TcpOption::MaxSegmentSize(value) => write!(f, " mss={value}")?, + TcpOption::WindowScale(value) => write!(f, " ws={value}")?, + TcpOption::SackPermitted => write!(f, " sACK")?, + TcpOption::SackRange(slice) => write!(f, " sACKr{slice:?}")?, // debug print conveniently includes the []s + TcpOption::Unknown { kind, .. } => write!(f, " opt({kind})")?, + } + options = next_options; + } + Ok(()) + } +} + +impl<'a> fmt::Display for Repr<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "TCP src={} dst={}", self.src_port, self.dst_port)?; + match self.control { + Control::Syn => write!(f, " syn")?, + Control::Fin => write!(f, " fin")?, + Control::Rst => write!(f, " rst")?, + Control::Psh => write!(f, " psh")?, + Control::None => (), + } + write!(f, " seq={}", self.seq_number)?; + if let Some(ack_number) = self.ack_number { + write!(f, " ack={ack_number}")?; + } + write!(f, " win={}", self.window_len)?; + write!(f, " len={}", self.payload.len())?; + if let Some(max_seg_size) = self.max_seg_size { + write!(f, " mss={max_seg_size}")?; + } + Ok(()) + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for Repr<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "TCP src={} dst={}", self.src_port, self.dst_port); + match self.control { + Control::Syn => defmt::write!(fmt, " syn"), + Control::Fin => defmt::write!(fmt, " fin"), + Control::Rst => defmt::write!(fmt, " rst"), + Control::Psh => defmt::write!(fmt, " psh"), + Control::None => (), + } + defmt::write!(fmt, " seq={}", self.seq_number); + if let Some(ack_number) = self.ack_number { + defmt::write!(fmt, " ack={}", ack_number); + } + defmt::write!(fmt, " win={}", self.window_len); + defmt::write!(fmt, " len={}", self.payload.len()); + if let Some(max_seg_size) = self.max_seg_size { + defmt::write!(fmt, " mss={}", max_seg_size); + } + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + match Packet::new_checked(buffer) { + Err(err) => write!(f, "{indent}({err})"), + Ok(packet) => write!(f, "{indent}{packet}"), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(feature = "proto-ipv4")] + use crate::wire::Ipv4Address; + + #[cfg(feature = "proto-ipv4")] + const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]); + #[cfg(feature = "proto-ipv4")] + const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]); + + #[cfg(feature = "proto-ipv4")] + static PACKET_BYTES: [u8; 28] = [ + 0xbf, 0x00, 0x00, 0x50, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x60, 0x35, 0x01, + 0x23, 0x01, 0xb6, 0x02, 0x01, 0x03, 0x03, 0x0c, 0x01, 0xaa, 0x00, 0x00, 0xff, + ]; + + #[cfg(feature = "proto-ipv4")] + static OPTION_BYTES: [u8; 4] = [0x03, 0x03, 0x0c, 0x01]; + + #[cfg(feature = "proto-ipv4")] + static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_deconstruct() { + let packet = Packet::new_unchecked(&PACKET_BYTES[..]); + assert_eq!(packet.src_port(), 48896); + assert_eq!(packet.dst_port(), 80); + assert_eq!(packet.seq_number(), SeqNumber(0x01234567)); + assert_eq!(packet.ack_number(), SeqNumber(0x89abcdefu32 as i32)); + assert_eq!(packet.header_len(), 24); + assert!(packet.fin()); + assert!(!packet.syn()); + assert!(packet.rst()); + assert!(!packet.psh()); + assert!(packet.ack()); + assert!(packet.urg()); + assert_eq!(packet.window_len(), 0x0123); + assert_eq!(packet.urgent_at(), 0x0201); + assert_eq!(packet.checksum(), 0x01b6); + assert_eq!(packet.options(), &OPTION_BYTES[..]); + assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]); + assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into())); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_construct() { + let mut bytes = vec![0xa5; PACKET_BYTES.len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_src_port(48896); + packet.set_dst_port(80); + packet.set_seq_number(SeqNumber(0x01234567)); + packet.set_ack_number(SeqNumber(0x89abcdefu32 as i32)); + packet.set_header_len(24); + packet.clear_flags(); + packet.set_fin(true); + packet.set_syn(false); + packet.set_rst(true); + packet.set_psh(false); + packet.set_ack(true); + packet.set_urg(true); + packet.set_window_len(0x0123); + packet.set_urgent_at(0x0201); + packet.set_checksum(0xEEEE); + packet.options_mut().copy_from_slice(&OPTION_BYTES[..]); + packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]); + packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into()); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_truncated() { + let packet = Packet::new_unchecked(&PACKET_BYTES[..23]); + assert_eq!(packet.check_len(), Err(Error)); + } + + #[test] + fn test_impossible_len() { + let mut bytes = vec![0; 20]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_header_len(10); + assert_eq!(packet.check_len(), Err(Error)); + } + + #[cfg(feature = "proto-ipv4")] + static SYN_PACKET_BYTES: [u8; 24] = [ + 0xbf, 0x00, 0x00, 0x50, 0x01, 0x23, 0x45, 0x67, 0x00, 0x00, 0x00, 0x00, 0x50, 0x02, 0x01, + 0x23, 0x7a, 0x8d, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff, + ]; + + #[cfg(feature = "proto-ipv4")] + fn packet_repr() -> Repr<'static> { + Repr { + src_port: 48896, + dst_port: 80, + seq_number: SeqNumber(0x01234567), + ack_number: None, + window_len: 0x0123, + window_scale: None, + control: Control::Syn, + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &PAYLOAD_BYTES, + } + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_parse() { + let packet = Packet::new_unchecked(&SYN_PACKET_BYTES[..]); + let repr = Repr::parse( + &packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + &ChecksumCapabilities::default(), + ) + .unwrap(); + assert_eq!(repr, packet_repr()); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_emit() { + let repr = packet_repr(); + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit( + &mut packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &SYN_PACKET_BYTES[..]); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_header_len_multiple_of_4() { + let mut repr = packet_repr(); + repr.window_scale = Some(0); // This TCP Option needs 3 bytes. + assert_eq!(repr.header_len() % 4, 0); // Should e.g. be 28 instead of 27. + } + + macro_rules! assert_option_parses { + ($opt:expr, $data:expr) => {{ + assert_eq!(TcpOption::parse($data), Ok((&[][..], $opt))); + let buffer = &mut [0; 40][..$opt.buffer_len()]; + assert_eq!($opt.emit(buffer), &mut []); + assert_eq!(&*buffer, $data); + }}; + } + + #[test] + fn test_tcp_options() { + assert_option_parses!(TcpOption::EndOfList, &[0x00]); + assert_option_parses!(TcpOption::NoOperation, &[0x01]); + assert_option_parses!(TcpOption::MaxSegmentSize(1500), &[0x02, 0x04, 0x05, 0xdc]); + assert_option_parses!(TcpOption::WindowScale(12), &[0x03, 0x03, 0x0c]); + assert_option_parses!(TcpOption::SackPermitted, &[0x4, 0x02]); + assert_option_parses!( + TcpOption::SackRange([Some((500, 1500)), None, None]), + &[0x05, 0x0a, 0x00, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x05, 0xdc] + ); + assert_option_parses!( + TcpOption::SackRange([Some((875, 1225)), Some((1500, 2500)), None]), + &[ + 0x05, 0x12, 0x00, 0x00, 0x03, 0x6b, 0x00, 0x00, 0x04, 0xc9, 0x00, 0x00, 0x05, 0xdc, + 0x00, 0x00, 0x09, 0xc4 + ] + ); + assert_option_parses!( + TcpOption::SackRange([ + Some((875000, 1225000)), + Some((1500000, 2500000)), + Some((876543210, 876654320)) + ]), + &[ + 0x05, 0x1a, 0x00, 0x0d, 0x59, 0xf8, 0x00, 0x12, 0xb1, 0x28, 0x00, 0x16, 0xe3, 0x60, + 0x00, 0x26, 0x25, 0xa0, 0x34, 0x3e, 0xfc, 0xea, 0x34, 0x40, 0xae, 0xf0 + ] + ); + assert_option_parses!( + TcpOption::Unknown { + kind: 12, + data: &[1, 2, 3][..] + }, + &[0x0c, 0x05, 0x01, 0x02, 0x03] + ) + } + + #[test] + fn test_malformed_tcp_options() { + assert_eq!(TcpOption::parse(&[]), Err(Error)); + assert_eq!(TcpOption::parse(&[0xc]), Err(Error)); + assert_eq!(TcpOption::parse(&[0xc, 0x05, 0x01, 0x02]), Err(Error)); + assert_eq!(TcpOption::parse(&[0xc, 0x01]), Err(Error)); + assert_eq!(TcpOption::parse(&[0x2, 0x02]), Err(Error)); + assert_eq!(TcpOption::parse(&[0x3, 0x02]), Err(Error)); + } +} diff --git a/src/wire/udp.rs b/src/wire/udp.rs new file mode 100644 index 0000000..77f9f84 --- /dev/null +++ b/src/wire/udp.rs @@ -0,0 +1,482 @@ +use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; + +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::checksum; +use crate::wire::{IpAddress, IpProtocol}; + +/// A read/write wrapper around an User Datagram Protocol packet buffer. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Packet<T: AsRef<[u8]>> { + buffer: T, +} + +mod field { + #![allow(non_snake_case)] + + use crate::wire::field::*; + + pub const SRC_PORT: Field = 0..2; + pub const DST_PORT: Field = 2..4; + pub const LENGTH: Field = 4..6; + pub const CHECKSUM: Field = 6..8; + + pub const fn PAYLOAD(length: u16) -> Field { + CHECKSUM.end..(length as usize) + } +} + +pub const HEADER_LEN: usize = field::CHECKSUM.end; + +#[allow(clippy::len_without_is_empty)] +impl<T: AsRef<[u8]>> Packet<T> { + /// Imbue a raw octet buffer with UDP packet structure. + pub const fn new_unchecked(buffer: T) -> Packet<T> { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result<Packet<T>> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// Returns `Err(Error)` if the length field has a value smaller + /// than the header length. + /// + /// The result of this check is invalidated by calling [set_len]. + /// + /// [set_len]: #method.set_len + pub fn check_len(&self) -> Result<()> { + let buffer_len = self.buffer.as_ref().len(); + if buffer_len < HEADER_LEN { + Err(Error) + } else { + let field_len = self.len() as usize; + if buffer_len < field_len || field_len < HEADER_LEN { + Err(Error) + } else { + Ok(()) + } + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the source port field. + #[inline] + pub fn src_port(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::SRC_PORT]) + } + + /// Return the destination port field. + #[inline] + pub fn dst_port(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::DST_PORT]) + } + + /// Return the length field. + #[inline] + pub fn len(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::LENGTH]) + } + + /// Return the checksum field. + #[inline] + pub fn checksum(&self) -> u16 { + let data = self.buffer.as_ref(); + NetworkEndian::read_u16(&data[field::CHECKSUM]) + } + + /// Validate the packet checksum. + /// + /// # Panics + /// This function panics unless `src_addr` and `dst_addr` belong to the same family, + /// and that family is IPv4 or IPv6. + /// + /// # Fuzzing + /// This function always returns `true` when fuzzing. + pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool { + if cfg!(fuzzing) { + return true; + } + + // From the RFC: + // > An all zero transmitted checksum value means that the transmitter + // > generated no checksum (for debugging or for higher level protocols + // > that don't care). + if self.checksum() == 0 { + return true; + } + + let data = self.buffer.as_ref(); + checksum::combine(&[ + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32), + checksum::data(&data[..self.len() as usize]), + ]) == !0 + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return a pointer to the payload. + #[inline] + pub fn payload(&self) -> &'a [u8] { + let length = self.len(); + let data = self.buffer.as_ref(); + &data[field::PAYLOAD(length)] + } +} + +impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> { + /// Set the source port field. + #[inline] + pub fn set_src_port(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::SRC_PORT], value) + } + + /// Set the destination port field. + #[inline] + pub fn set_dst_port(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::DST_PORT], value) + } + + /// Set the length field. + #[inline] + pub fn set_len(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::LENGTH], value) + } + + /// Set the checksum field. + #[inline] + pub fn set_checksum(&mut self, value: u16) { + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[field::CHECKSUM], value) + } + + /// Compute and fill in the header checksum. + /// + /// # Panics + /// This function panics unless `src_addr` and `dst_addr` belong to the same family, + /// and that family is IPv4 or IPv6. + pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) { + self.set_checksum(0); + let checksum = { + let data = self.buffer.as_ref(); + !checksum::combine(&[ + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32), + checksum::data(&data[..self.len() as usize]), + ]) + }; + // UDP checksum value of 0 means no checksum; if the checksum really is zero, + // use all-ones, which indicates that the remote end must verify the checksum. + // Arithmetically, RFC 1071 checksums of all-zeroes and all-ones behave identically, + // so no action is necessary on the remote end. + self.set_checksum(if checksum == 0 { 0xffff } else { checksum }) + } + + /// Return a mutable pointer to the payload. + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + let length = self.len(); + let data = self.buffer.as_mut(); + &mut data[field::PAYLOAD(length)] + } +} + +impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> { + fn as_ref(&self) -> &[u8] { + self.buffer.as_ref() + } +} + +/// A high-level representation of an User Datagram Protocol packet. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct Repr { + pub src_port: u16, + pub dst_port: u16, +} + +impl Repr { + /// Parse an User Datagram Protocol packet and return a high-level representation. + pub fn parse<T>( + packet: &Packet<&T>, + src_addr: &IpAddress, + dst_addr: &IpAddress, + checksum_caps: &ChecksumCapabilities, + ) -> Result<Repr> + where + T: AsRef<[u8]> + ?Sized, + { + // Destination port cannot be omitted (but source port can be). + if packet.dst_port() == 0 { + return Err(Error); + } + // Valid checksum is expected... + if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) { + match (src_addr, dst_addr) { + // ... except on UDP-over-IPv4, where it can be omitted. + #[cfg(feature = "proto-ipv4")] + (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_)) if packet.checksum() == 0 => (), + _ => return Err(Error), + } + } + + Ok(Repr { + src_port: packet.src_port(), + dst_port: packet.dst_port(), + }) + } + + /// Return the length of the packet header that will be emitted from this high-level representation. + pub const fn header_len(&self) -> usize { + HEADER_LEN + } + + /// Emit a high-level representation into an User Datagram Protocol packet. + /// + /// This never calculates the checksum, and is intended for internal-use only, + /// not for packets that are going to be actually sent over the network. For + /// example, when decompressing 6lowpan. + pub(crate) fn emit_header<T: ?Sized>(&self, packet: &mut Packet<&mut T>, payload_len: usize) + where + T: AsRef<[u8]> + AsMut<[u8]>, + { + packet.set_src_port(self.src_port); + packet.set_dst_port(self.dst_port); + packet.set_len((HEADER_LEN + payload_len) as u16); + packet.set_checksum(0); + } + + /// Emit a high-level representation into an User Datagram Protocol packet. + pub fn emit<T: ?Sized>( + &self, + packet: &mut Packet<&mut T>, + src_addr: &IpAddress, + dst_addr: &IpAddress, + payload_len: usize, + emit_payload: impl FnOnce(&mut [u8]), + checksum_caps: &ChecksumCapabilities, + ) where + T: AsRef<[u8]> + AsMut<[u8]>, + { + packet.set_src_port(self.src_port); + packet.set_dst_port(self.dst_port); + packet.set_len((HEADER_LEN + payload_len) as u16); + emit_payload(packet.payload_mut()); + + if checksum_caps.udp.tx() { + packet.fill_checksum(src_addr, dst_addr) + } else { + // make sure we get a consistently zeroed checksum, + // since implementations might rely on it + packet.set_checksum(0); + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Cannot use Repr::parse because we don't have the IP addresses. + write!( + f, + "UDP src={} dst={} len={}", + self.src_port(), + self.dst_port(), + self.payload().len() + ) + } +} + +#[cfg(feature = "defmt")] +impl<'a, T: AsRef<[u8]> + ?Sized> defmt::Format for Packet<&'a T> { + fn format(&self, fmt: defmt::Formatter) { + // Cannot use Repr::parse because we don't have the IP addresses. + defmt::write!( + fmt, + "UDP src={} dst={} len={}", + self.src_port(), + self.dst_port(), + self.payload().len() + ); + } +} + +impl fmt::Display for Repr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "UDP src={} dst={}", self.src_port, self.dst_port) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Repr { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "UDP src={} dst={}", self.src_port, self.dst_port); + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; + +impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + match Packet::new_checked(buffer) { + Err(err) => write!(f, "{indent}({err})"), + Ok(packet) => write!(f, "{indent}{packet}"), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(feature = "proto-ipv4")] + use crate::wire::Ipv4Address; + + #[cfg(feature = "proto-ipv4")] + const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]); + #[cfg(feature = "proto-ipv4")] + const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]); + + #[cfg(feature = "proto-ipv4")] + static PACKET_BYTES: [u8; 12] = [ + 0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff, + ]; + + #[cfg(feature = "proto-ipv4")] + static NO_CHECKSUM_PACKET: [u8; 12] = [ + 0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff, + ]; + + #[cfg(feature = "proto-ipv4")] + static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_deconstruct() { + let packet = Packet::new_unchecked(&PACKET_BYTES[..]); + assert_eq!(packet.src_port(), 48896); + assert_eq!(packet.dst_port(), 53); + assert_eq!(packet.len(), 12); + assert_eq!(packet.checksum(), 0x124d); + assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]); + assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into())); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_construct() { + let mut bytes = vec![0xa5; 12]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_src_port(48896); + packet.set_dst_port(53); + packet.set_len(12); + packet.set_checksum(0xffff); + packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]); + packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into()); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); + } + + #[test] + fn test_impossible_len() { + let mut bytes = vec![0; 12]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_len(4); + assert_eq!(packet.check_len(), Err(Error)); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_zero_checksum() { + let mut bytes = vec![0; 8]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_src_port(1); + packet.set_dst_port(31881); + packet.set_len(8); + packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into()); + assert_eq!(packet.checksum(), 0xffff); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_no_checksum() { + let mut bytes = vec![0; 8]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_src_port(1); + packet.set_dst_port(31881); + packet.set_len(8); + packet.set_checksum(0); + assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into())); + } + + #[cfg(feature = "proto-ipv4")] + fn packet_repr() -> Repr { + Repr { + src_port: 48896, + dst_port: 53, + } + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_parse() { + let packet = Packet::new_unchecked(&PACKET_BYTES[..]); + let repr = Repr::parse( + &packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + &ChecksumCapabilities::default(), + ) + .unwrap(); + assert_eq!(repr, packet_repr()); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_emit() { + let repr = packet_repr(); + let mut bytes = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit( + &mut packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + PAYLOAD_BYTES.len(), + |payload| payload.copy_from_slice(&PAYLOAD_BYTES), + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); + } + + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_checksum_omitted() { + let packet = Packet::new_unchecked(&NO_CHECKSUM_PACKET[..]); + let repr = Repr::parse( + &packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + &ChecksumCapabilities::default(), + ) + .unwrap(); + assert_eq!(repr, packet_repr()); + } +} diff --git a/utils/packet2pcap.rs b/utils/packet2pcap.rs new file mode 100644 index 0000000..7d06c6f --- /dev/null +++ b/utils/packet2pcap.rs @@ -0,0 +1,74 @@ +use getopts::Options; +use smoltcp::phy::{PcapLinkType, PcapSink}; +use smoltcp::time::Instant; +use std::env; +use std::fs::File; +use std::io::{self, Read}; +use std::path::Path; +use std::process::exit; + +fn convert( + packet_filename: &Path, + pcap_filename: &Path, + link_type: PcapLinkType, +) -> io::Result<()> { + let mut packet_file = File::open(packet_filename)?; + let mut packet = Vec::new(); + packet_file.read_to_end(&mut packet)?; + + let mut pcap_file = File::create(pcap_filename)?; + PcapSink::global_header(&mut pcap_file, link_type); + PcapSink::packet(&mut pcap_file, Instant::from_millis(0), &packet[..]); + + Ok(()) +} + +fn print_usage(program: &str, opts: Options) { + let brief = format!("Usage: {program} [options] INPUT OUTPUT"); + print!("{}", opts.usage(&brief)); +} + +fn main() { + let args: Vec<String> = env::args().collect(); + let program = args[0].clone(); + + let mut opts = Options::new(); + opts.optflag("h", "help", "print this help menu"); + opts.optopt( + "t", + "link-type", + "set link type (one of: ethernet ip)", + "TYPE", + ); + + let matches = match opts.parse(&args[1..]) { + Ok(m) => m, + Err(e) => { + eprintln!("{e}"); + return; + } + }; + + let link_type = match matches.opt_str("t").as_ref().map(|s| &s[..]) { + Some("ethernet") => Some(PcapLinkType::Ethernet), + Some("ip") => Some(PcapLinkType::Ip), + _ => None, + }; + + if matches.opt_present("h") || matches.free.len() != 2 || link_type.is_none() { + print_usage(&program, opts); + return; + } + + match convert( + Path::new(&matches.free[0]), + Path::new(&matches.free[1]), + link_type.unwrap(), + ) { + Ok(()) => (), + Err(e) => { + eprintln!("Cannot convert packet to pcap: {e}"); + exit(1); + } + } +} |