better frame parsing and handling

This commit is contained in:
nora 2022-02-09 14:05:13 +01:00
parent 2b0770705a
commit 706219c046
6 changed files with 132 additions and 31 deletions

21
Cargo.lock generated
View file

@ -19,6 +19,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"rand", "rand",
"thiserror",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@ -315,6 +316,26 @@ dependencies = [
"unicode-xid", "unicode-xid",
] ]
[[package]]
name = "thiserror"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "thread_local" name = "thread_local"
version = "1.1.4" version = "1.1.4"

View file

@ -8,6 +8,7 @@ edition = "2021"
[dependencies] [dependencies]
anyhow = "1.0.53" anyhow = "1.0.53"
rand = "0.8.4" rand = "0.8.4"
thiserror = "1.0.30"
tokio = { version = "1.16.1", features = ["full"] } tokio = { version = "1.16.1", features = ["full"] }
tracing = "0.1.30" tracing = "0.1.30"
tracing-subscriber = "0.3.8" tracing-subscriber = "0.3.8"

View file

@ -1,4 +1,6 @@
use anyhow::{bail, ensure, Result}; use crate::error::{ConError, ProtocolError};
use crate::frame;
use anyhow::{ensure, Context};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tracing::{debug, error}; use tracing::{debug, error};
@ -19,12 +21,16 @@ impl Connection {
} }
} }
pub async fn run(&mut self) -> Result<()> { pub async fn run(&mut self) -> Result<(), ConError> {
self.negotiate_version().await?; self.negotiate_version().await?;
Ok(())
loop {
let frame = frame::read_frame(&mut self.stream, 10000).await?;
debug!(?frame, "received frame");
}
} }
async fn negotiate_version(&mut self) -> Result<()> { async fn negotiate_version(&mut self) -> Result<(), ConError> {
const HEADER_SIZE: usize = 8; const HEADER_SIZE: usize = 8;
const PROTOCOL_VERSION: &[u8] = &[0, 9, 1]; const PROTOCOL_VERSION: &[u8] = &[0, 9, 1];
const PROTOCOL_HEADER: &[u8] = b"AMQP\0\0\x09\x01"; const PROTOCOL_HEADER: &[u8] = b"AMQP\0\0\x09\x01";
@ -33,26 +39,26 @@ impl Connection {
let mut read_header_buf = [0; HEADER_SIZE]; let mut read_header_buf = [0; HEADER_SIZE];
self.stream.read_exact(&mut read_header_buf).await?; self.stream
.read_exact(&mut read_header_buf)
.await
.context("read protocol header")?;
debug!(received_header = ?read_header_buf,"Received protocol header"); debug!(received_header = ?read_header_buf,"Received protocol header");
ensure!(
&read_header_buf[0..5] == b"AMQP\0",
"Received wrong protocol"
);
let version = &read_header_buf[5..8]; let version = &read_header_buf[5..8];
self.stream.write_all(PROTOCOL_HEADER).await?; self.stream
.write_all(PROTOCOL_HEADER)
.await
.context("write protocol header")?;
if version == PROTOCOL_VERSION { if &read_header_buf[0..5] == b"AMQP\0" && version == PROTOCOL_VERSION {
debug!(?version, "Version negotiation successful"); debug!(?version, "Version negotiation successful");
Ok(()) Ok(())
} else { } else {
debug!(?version, expected_version = ?PROTOCOL_VERSION, "Version negotiation failed, unsupported version"); debug!(?version, expected_version = ?PROTOCOL_VERSION, "Version negotiation failed, unsupported version");
self.stream.shutdown().await?; return Err(ProtocolError::OtherCloseConnection.into());
bail!("Unsupported protocol version {:?}", version);
} }
} }
} }

View file

@ -0,0 +1,30 @@
#[derive(Debug, thiserror::Error)]
pub enum ConError {
#[error("{0}")]
Invalid(#[from] ProtocolError),
#[error("connection error: `{0}`")]
Other(#[from] anyhow::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
#[error("fatal error")]
Fatal,
#[error("{0}")]
ConException(#[from] ConException),
#[error("{0}")]
ChannelException(#[from] ChannelException),
#[error("closing connection")]
OtherCloseConnection,
}
#[derive(Debug, thiserror::Error)]
pub enum ConException {
#[error("501 Frame error")]
FrameError,
#[error("503 Command invalid")]
CommandInvalid,
}
#[derive(Debug, thiserror::Error)]
pub enum ChannelException {}

View file

@ -1,57 +1,97 @@
use anyhow::Result; use crate::error::{ConError, ConException, ProtocolError};
use anyhow::Context;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
const REQUIRED_FRAME_END: u8 = 0xCE;
#[derive(Debug, Clone, PartialEq, Eq)]
#[repr(u8)]
pub enum FrameType {
Method = 1,
Header = 2,
Body = 3,
Heartbeat = 4,
}
impl TryFrom<u8> for FrameType {
type Error = ConError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
Ok(match value {
1 => Self::Method,
2 => Self::Header,
3 => Self::Body,
4 => Self::Heartbeat,
_ => return Err(ProtocolError::Fatal.into()),
})
}
}
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame { pub struct Frame {
r#type: u8, r#type: FrameType,
channel: u16, channel: u16,
size: u32, size: u32,
payload: Vec<u8>, payload: Vec<u8>,
frame_end: u8,
} }
pub async fn read_frame<R>(r: &mut R) -> Result<Frame> pub async fn read_frame<R>(r: &mut R, max_frame_size: usize) -> Result<Frame, ConError>
where where
R: AsyncReadExt + Unpin, R: AsyncReadExt + Unpin,
{ {
let r#type = r.read_u8().await?; let r#type = r.read_u8().await.context("read type")?;
let channel = r.read_u16().await?; let channel = r.read_u16().await.context("read channel")?;
let size = r.read_u32().await?; let size = r.read_u32().await.context("read size")?;
let mut payload = vec![0; size.try_into().unwrap()]; let mut payload = vec![0; size.try_into().unwrap()];
r.read_exact(&mut payload).await?; r.read_exact(&mut payload).await.context("read payload")?;
let frame_end = r.read_u8().await?; let frame_end = r.read_u8().await.context("read frame end")?;
if frame_end != REQUIRED_FRAME_END {
return Err(ProtocolError::Fatal.into());
}
if payload.len() > max_frame_size {
return Err(ProtocolError::ConException(ConException::FrameError).into());
}
Ok(Frame { Ok(Frame {
r#type, r#type: r#type.try_into()?,
channel, channel,
size, size,
payload, payload,
frame_end,
}) })
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::frame::Frame; use crate::frame::{Frame, FrameType};
#[tokio::test] #[tokio::test]
async fn read_small_body() { async fn read_small_body() {
let mut bytes: &[u8] = &[ let mut bytes: &[u8] = &[
/*type*/ 1, /*channel*/ 0, 0, /*size*/ 0, 0, 0, 3, /*payload*/ 1, /*type*/ 1,
2, 3, /*frame-end*/ 0, /*channel*/ 0,
0,
/*size*/ 0,
0,
0,
3,
/*payload*/ 1,
2,
3,
/*frame-end*/ super::REQUIRED_FRAME_END,
]; ];
let frame = super::read_frame(&mut bytes).await.unwrap(); let frame = super::read_frame(&mut bytes, 10000).await.unwrap();
assert_eq!( assert_eq!(
frame, frame,
Frame { Frame {
r#type: 1, r#type: FrameType::Method,
channel: 0, channel: 0,
size: 3, size: 3,
payload: vec![1, 2, 3], payload: vec![1, 2, 3],
frame_end: 0
} }
); );
} }

View file

@ -1,4 +1,7 @@
#![allow(dead_code)]
mod connection; mod connection;
mod error;
mod frame; mod frame;
use crate::connection::Connection; use crate::connection::Connection;