aboutsummaryrefslogtreecommitdiff
path: root/src/handshake/machine.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/handshake/machine.rs')
-rw-r--r--src/handshake/machine.rs85
1 files changed, 69 insertions, 16 deletions
diff --git a/src/handshake/machine.rs b/src/handshake/machine.rs
index eacb4bf..2e3f2cb 100644
--- a/src/handshake/machine.rs
+++ b/src/handshake/machine.rs
@@ -20,7 +20,7 @@ pub struct HandshakeMachine<Stream> {
impl<Stream> HandshakeMachine<Stream> {
/// Start reading data from the peer.
pub fn start_read(stream: Stream) -> Self {
- HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) }
+ Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
}
/// Start writing data to the peer.
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
@@ -41,25 +41,31 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
trace!("Doing handshake round.");
match self.state {
- HandshakeState::Reading(mut buf) => {
+ HandshakeState::Reading(mut buf, mut attack_check) => {
let read = buf.read_from(&mut self.stream).no_block()?;
match read {
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
- Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
- buf.advance(size);
- RoundResult::StageFinished(StageResult::DoneReading {
- result: obj,
- stream: self.stream,
- tail: buf.into_vec(),
+ Some(count) => {
+ attack_check.check_incoming_packet_size(count)?;
+ // TODO: this is slow for big headers with too many small packets.
+ // The parser has to be reworked in order to work on streams instead
+ // of buffers.
+ Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
+ buf.advance(size);
+ RoundResult::StageFinished(StageResult::DoneReading {
+ result: obj,
+ stream: self.stream,
+ tail: buf.into_vec(),
+ })
+ } else {
+ RoundResult::Incomplete(HandshakeMachine {
+ state: HandshakeState::Reading(buf, attack_check),
+ ..self
+ })
})
- } else {
- RoundResult::Incomplete(HandshakeMachine {
- state: HandshakeState::Reading(buf),
- ..self
- })
- }),
+ }
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
- state: HandshakeState::Reading(buf),
+ state: HandshakeState::Reading(buf, attack_check),
..self
})),
}
@@ -119,7 +125,54 @@ pub trait TryParse: Sized {
#[derive(Debug)]
enum HandshakeState {
/// Reading data from the peer.
- Reading(ReadBuffer),
+ Reading(ReadBuffer, AttackCheck),
/// Sending data to the peer.
Writing(Cursor<Vec<u8>>),
}
+
+/// Attack mitigation. Contains counters needed to prevent DoS attacks
+/// and reject valid but useless headers.
+#[derive(Debug)]
+pub(crate) struct AttackCheck {
+ /// Number of HTTP header successful reads (TCP packets).
+ number_of_packets: usize,
+ /// Total number of bytes in HTTP header.
+ number_of_bytes: usize,
+}
+
+impl AttackCheck {
+ /// Initialize attack checking for incoming buffer.
+ fn new() -> Self {
+ Self { number_of_packets: 0, number_of_bytes: 0 }
+ }
+
+ /// Check the size of an incoming packet. To be called immediately after `read()`
+ /// passing its returned bytes count as `size`.
+ fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
+ self.number_of_packets += 1;
+ self.number_of_bytes += size;
+
+ // TODO: these values are hardcoded. Instead of making them configurable,
+ // rework the way HTTP header is parsed to remove this check at all.
+ const MAX_BYTES: usize = 65536;
+ const MAX_PACKETS: usize = 512;
+ const MIN_PACKET_SIZE: usize = 128;
+ const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
+
+ if self.number_of_bytes > MAX_BYTES {
+ return Err(Error::AttackAttempt);
+ }
+
+ if self.number_of_packets > MAX_PACKETS {
+ return Err(Error::AttackAttempt);
+ }
+
+ if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD
+ && self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes
+ {
+ return Err(Error::AttackAttempt);
+ }
+
+ Ok(())
+ }
+}