better frame parsing and handling

This commit is contained in:
nora 2022-02-09 14:05:13 +01:00
parent 2b0770705a
commit 706219c046
6 changed files with 132 additions and 31 deletions

21
Cargo.lock generated
View file

@ -19,6 +19,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"rand",
"thiserror",
"tokio",
"tracing",
"tracing-subscriber",
@ -315,6 +316,26 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "thiserror"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "thread_local"
version = "1.1.4"

View file

@ -8,6 +8,7 @@ edition = "2021"
[dependencies]
anyhow = "1.0.53"
rand = "0.8.4"
thiserror = "1.0.30"
tokio = { version = "1.16.1", features = ["full"] }
tracing = "0.1.30"
tracing-subscriber = "0.3.8"

View file

@ -1,4 +1,6 @@
use anyhow::{bail, ensure, Result};
use crate::error::{ConError, ProtocolError};
use crate::frame;
use anyhow::{ensure, Context};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, error};
@ -19,12 +21,16 @@ impl Connection {
}
}
pub async fn run(&mut self) -> Result<()> {
pub async fn run(&mut self) -> Result<(), ConError> {
self.negotiate_version().await?;
Ok(())
loop {
let frame = frame::read_frame(&mut self.stream, 10000).await?;
debug!(?frame, "received frame");
}
}
async fn negotiate_version(&mut self) -> Result<()> {
async fn negotiate_version(&mut self) -> Result<(), ConError> {
const HEADER_SIZE: usize = 8;
const PROTOCOL_VERSION: &[u8] = &[0, 9, 1];
const PROTOCOL_HEADER: &[u8] = b"AMQP\0\0\x09\x01";
@ -33,26 +39,26 @@ impl Connection {
let mut read_header_buf = [0; HEADER_SIZE];
self.stream.read_exact(&mut read_header_buf).await?;
self.stream
.read_exact(&mut read_header_buf)
.await
.context("read protocol header")?;
debug!(received_header = ?read_header_buf,"Received protocol header");
ensure!(
&read_header_buf[0..5] == b"AMQP\0",
"Received wrong protocol"
);
let version = &read_header_buf[5..8];
self.stream.write_all(PROTOCOL_HEADER).await?;
self.stream
.write_all(PROTOCOL_HEADER)
.await
.context("write protocol header")?;
if version == PROTOCOL_VERSION {
if &read_header_buf[0..5] == b"AMQP\0" && version == PROTOCOL_VERSION {
debug!(?version, "Version negotiation successful");
Ok(())
} else {
debug!(?version, expected_version = ?PROTOCOL_VERSION, "Version negotiation failed, unsupported version");
self.stream.shutdown().await?;
bail!("Unsupported protocol version {:?}", version);
return Err(ProtocolError::OtherCloseConnection.into());
}
}
}

View file

@ -0,0 +1,30 @@
#[derive(Debug, thiserror::Error)]
pub enum ConError {
#[error("{0}")]
Invalid(#[from] ProtocolError),
#[error("connection error: `{0}`")]
Other(#[from] anyhow::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
#[error("fatal error")]
Fatal,
#[error("{0}")]
ConException(#[from] ConException),
#[error("{0}")]
ChannelException(#[from] ChannelException),
#[error("closing connection")]
OtherCloseConnection,
}
#[derive(Debug, thiserror::Error)]
pub enum ConException {
#[error("501 Frame error")]
FrameError,
#[error("503 Command invalid")]
CommandInvalid,
}
#[derive(Debug, thiserror::Error)]
pub enum ChannelException {}

View file

@ -1,57 +1,97 @@
use anyhow::Result;
use crate::error::{ConError, ConException, ProtocolError};
use anyhow::Context;
use tokio::io::AsyncReadExt;
const REQUIRED_FRAME_END: u8 = 0xCE;
#[derive(Debug, Clone, PartialEq, Eq)]
#[repr(u8)]
pub enum FrameType {
Method = 1,
Header = 2,
Body = 3,
Heartbeat = 4,
}
impl TryFrom<u8> for FrameType {
type Error = ConError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
Ok(match value {
1 => Self::Method,
2 => Self::Header,
3 => Self::Body,
4 => Self::Heartbeat,
_ => return Err(ProtocolError::Fatal.into()),
})
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
r#type: u8,
r#type: FrameType,
channel: u16,
size: u32,
payload: Vec<u8>,
frame_end: u8,
}
pub async fn read_frame<R>(r: &mut R) -> Result<Frame>
pub async fn read_frame<R>(r: &mut R, max_frame_size: usize) -> Result<Frame, ConError>
where
R: AsyncReadExt + Unpin,
{
let r#type = r.read_u8().await?;
let channel = r.read_u16().await?;
let size = r.read_u32().await?;
let r#type = r.read_u8().await.context("read type")?;
let channel = r.read_u16().await.context("read channel")?;
let size = r.read_u32().await.context("read size")?;
let mut payload = vec![0; size.try_into().unwrap()];
r.read_exact(&mut payload).await?;
r.read_exact(&mut payload).await.context("read payload")?;
let frame_end = r.read_u8().await?;
let frame_end = r.read_u8().await.context("read frame end")?;
if frame_end != REQUIRED_FRAME_END {
return Err(ProtocolError::Fatal.into());
}
if payload.len() > max_frame_size {
return Err(ProtocolError::ConException(ConException::FrameError).into());
}
Ok(Frame {
r#type,
r#type: r#type.try_into()?,
channel,
size,
payload,
frame_end,
})
}
#[cfg(test)]
mod tests {
use crate::frame::Frame;
use crate::frame::{Frame, FrameType};
#[tokio::test]
async fn read_small_body() {
let mut bytes: &[u8] = &[
/*type*/ 1, /*channel*/ 0, 0, /*size*/ 0, 0, 0, 3, /*payload*/ 1,
2, 3, /*frame-end*/ 0,
/*type*/ 1,
/*channel*/ 0,
0,
/*size*/ 0,
0,
0,
3,
/*payload*/ 1,
2,
3,
/*frame-end*/ super::REQUIRED_FRAME_END,
];
let frame = super::read_frame(&mut bytes).await.unwrap();
let frame = super::read_frame(&mut bytes, 10000).await.unwrap();
assert_eq!(
frame,
Frame {
r#type: 1,
r#type: FrameType::Method,
channel: 0,
size: 3,
payload: vec![1, 2, 3],
frame_end: 0
}
);
}

View file

@ -1,4 +1,7 @@
#![allow(dead_code)]
mod connection;
mod error;
mod frame;
use crate::connection::Connection;