From 706219c04625807b4dc979b04352a425f60c2426 Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Wed, 9 Feb 2022 14:05:13 +0100 Subject: [PATCH] better frame parsing and handling --- Cargo.lock | 21 +++++++++ amqp_transport/Cargo.toml | 1 + amqp_transport/src/connection.rs | 34 +++++++++------ amqp_transport/src/error.rs | 30 +++++++++++++ amqp_transport/src/frame.rs | 74 ++++++++++++++++++++++++-------- amqp_transport/src/lib.rs | 3 ++ 6 files changed, 132 insertions(+), 31 deletions(-) create mode 100644 amqp_transport/src/error.rs diff --git a/Cargo.lock b/Cargo.lock index c8271da..e1c577c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,6 +19,7 @@ version = "0.1.0" dependencies = [ "anyhow", "rand", + "thiserror", "tokio", "tracing", "tracing-subscriber", @@ -315,6 +316,26 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "thiserror" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.4" diff --git a/amqp_transport/Cargo.toml b/amqp_transport/Cargo.toml index 5843ed1..c596e00 100644 --- a/amqp_transport/Cargo.toml +++ b/amqp_transport/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" [dependencies] anyhow = "1.0.53" rand = "0.8.4" +thiserror = "1.0.30" tokio = { version = "1.16.1", features = ["full"] } tracing = "0.1.30" tracing-subscriber = "0.3.8" diff --git a/amqp_transport/src/connection.rs b/amqp_transport/src/connection.rs index a910702..4021a38 100644 --- a/amqp_transport/src/connection.rs +++ b/amqp_transport/src/connection.rs @@ -1,4 +1,6 @@ -use anyhow::{bail, ensure, Result}; +use crate::error::{ConError, ProtocolError}; +use crate::frame; +use anyhow::{ensure, Context}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tracing::{debug, error}; @@ -19,12 +21,16 @@ impl Connection { } } - pub async fn run(&mut self) -> Result<()> { + pub async fn run(&mut self) -> Result<(), ConError> { self.negotiate_version().await?; - Ok(()) + + loop { + let frame = frame::read_frame(&mut self.stream, 10000).await?; + debug!(?frame, "received frame"); + } } - async fn negotiate_version(&mut self) -> Result<()> { + async fn negotiate_version(&mut self) -> Result<(), ConError> { const HEADER_SIZE: usize = 8; const PROTOCOL_VERSION: &[u8] = &[0, 9, 1]; const PROTOCOL_HEADER: &[u8] = b"AMQP\0\0\x09\x01"; @@ -33,26 +39,26 @@ impl Connection { let mut read_header_buf = [0; HEADER_SIZE]; - self.stream.read_exact(&mut read_header_buf).await?; + self.stream + .read_exact(&mut read_header_buf) + .await + .context("read protocol header")?; debug!(received_header = ?read_header_buf,"Received protocol header"); - ensure!( - &read_header_buf[0..5] == b"AMQP\0", - "Received wrong protocol" - ); - let version = &read_header_buf[5..8]; - self.stream.write_all(PROTOCOL_HEADER).await?; + self.stream + .write_all(PROTOCOL_HEADER) + .await + .context("write protocol header")?; - if version == PROTOCOL_VERSION { + if &read_header_buf[0..5] == b"AMQP\0" && version == PROTOCOL_VERSION { debug!(?version, "Version negotiation successful"); Ok(()) } else { debug!(?version, expected_version = ?PROTOCOL_VERSION, "Version negotiation failed, unsupported version"); - self.stream.shutdown().await?; - bail!("Unsupported protocol version {:?}", version); + return Err(ProtocolError::OtherCloseConnection.into()); } } } diff --git a/amqp_transport/src/error.rs b/amqp_transport/src/error.rs new file mode 100644 index 0000000..db36cd2 --- /dev/null +++ b/amqp_transport/src/error.rs @@ -0,0 +1,30 @@ +#[derive(Debug, thiserror::Error)] +pub enum ConError { + #[error("{0}")] + Invalid(#[from] ProtocolError), + #[error("connection error: `{0}`")] + Other(#[from] anyhow::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum ProtocolError { + #[error("fatal error")] + Fatal, + #[error("{0}")] + ConException(#[from] ConException), + #[error("{0}")] + ChannelException(#[from] ChannelException), + #[error("closing connection")] + OtherCloseConnection, +} + +#[derive(Debug, thiserror::Error)] +pub enum ConException { + #[error("501 Frame error")] + FrameError, + #[error("503 Command invalid")] + CommandInvalid, +} + +#[derive(Debug, thiserror::Error)] +pub enum ChannelException {} diff --git a/amqp_transport/src/frame.rs b/amqp_transport/src/frame.rs index 7691b58..16dae6d 100644 --- a/amqp_transport/src/frame.rs +++ b/amqp_transport/src/frame.rs @@ -1,57 +1,97 @@ -use anyhow::Result; +use crate::error::{ConError, ConException, ProtocolError}; +use anyhow::Context; use tokio::io::AsyncReadExt; +const REQUIRED_FRAME_END: u8 = 0xCE; + +#[derive(Debug, Clone, PartialEq, Eq)] +#[repr(u8)] +pub enum FrameType { + Method = 1, + Header = 2, + Body = 3, + Heartbeat = 4, +} + +impl TryFrom for FrameType { + type Error = ConError; + + fn try_from(value: u8) -> Result { + Ok(match value { + 1 => Self::Method, + 2 => Self::Header, + 3 => Self::Body, + 4 => Self::Heartbeat, + _ => return Err(ProtocolError::Fatal.into()), + }) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct Frame { - r#type: u8, + r#type: FrameType, channel: u16, size: u32, payload: Vec, - frame_end: u8, } -pub async fn read_frame(r: &mut R) -> Result +pub async fn read_frame(r: &mut R, max_frame_size: usize) -> Result where R: AsyncReadExt + Unpin, { - let r#type = r.read_u8().await?; - let channel = r.read_u16().await?; - let size = r.read_u32().await?; + let r#type = r.read_u8().await.context("read type")?; + let channel = r.read_u16().await.context("read channel")?; + let size = r.read_u32().await.context("read size")?; let mut payload = vec![0; size.try_into().unwrap()]; - r.read_exact(&mut payload).await?; + r.read_exact(&mut payload).await.context("read payload")?; - let frame_end = r.read_u8().await?; + let frame_end = r.read_u8().await.context("read frame end")?; + + if frame_end != REQUIRED_FRAME_END { + return Err(ProtocolError::Fatal.into()); + } + + if payload.len() > max_frame_size { + return Err(ProtocolError::ConException(ConException::FrameError).into()); + } Ok(Frame { - r#type, + r#type: r#type.try_into()?, channel, size, payload, - frame_end, }) } #[cfg(test)] mod tests { - use crate::frame::Frame; + use crate::frame::{Frame, FrameType}; #[tokio::test] async fn read_small_body() { let mut bytes: &[u8] = &[ - /*type*/ 1, /*channel*/ 0, 0, /*size*/ 0, 0, 0, 3, /*payload*/ 1, - 2, 3, /*frame-end*/ 0, + /*type*/ 1, + /*channel*/ 0, + 0, + /*size*/ 0, + 0, + 0, + 3, + /*payload*/ 1, + 2, + 3, + /*frame-end*/ super::REQUIRED_FRAME_END, ]; - let frame = super::read_frame(&mut bytes).await.unwrap(); + let frame = super::read_frame(&mut bytes, 10000).await.unwrap(); assert_eq!( frame, Frame { - r#type: 1, + r#type: FrameType::Method, channel: 0, size: 3, payload: vec![1, 2, 3], - frame_end: 0 } ); } diff --git a/amqp_transport/src/lib.rs b/amqp_transport/src/lib.rs index 2ca54cb..c4fa279 100644 --- a/amqp_transport/src/lib.rs +++ b/amqp_transport/src/lib.rs @@ -1,4 +1,7 @@ +#![allow(dead_code)] + mod connection; +mod error; mod frame; use crate::connection::Connection;