diff --git a/amqp_transport/src/classes/mod.rs b/amqp_transport/src/classes/mod.rs index 437073d..72bfd3f 100644 --- a/amqp_transport/src/classes/mod.rs +++ b/amqp_transport/src/classes/mod.rs @@ -1,4 +1,4 @@ -use crate::error::{ConException, ProtocolError, TransError}; +use crate::error::{ConException, TransError}; use std::collections::HashMap; mod generated; @@ -41,15 +41,17 @@ pub fn parse_method(payload: &[u8]) -> Result { match nom_result { Ok(([], class)) => Ok(class), - Ok((_, _)) => Err(ProtocolError::ConException(ConException::SyntaxError(vec![ - "could not consume all input".to_string(), - ])) - .into()), + Ok((_, _)) => { + Err( + ConException::SyntaxError(vec!["could not consume all input".to_string()]) + .into_trans(), + ) + } Err(nom::Err::Incomplete(_)) => { - Err(ProtocolError::ConException(ConException::SyntaxError(vec![ - "there was not enough data".to_string(), - ])) - .into()) + Err( + ConException::SyntaxError(vec!["there was not enough data".to_string()]) + .into_trans(), + ) } Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err), } diff --git a/amqp_transport/src/classes/parse_helper.rs b/amqp_transport/src/classes/parse_helper.rs index bac771c..c386c26 100644 --- a/amqp_transport/src/classes/parse_helper.rs +++ b/amqp_transport/src/classes/parse_helper.rs @@ -17,7 +17,7 @@ use std::collections::HashMap; impl nom::error::ParseError for TransError { fn from_error_kind(_input: T, _kind: ErrorKind) -> Self { - ProtocolError::ConException(ConException::SyntaxError(vec![])).into() + ConException::SyntaxError(vec![]).into_trans() } fn append(_input: T, _kind: ErrorKind, other: Self) -> Self { @@ -47,13 +47,11 @@ pub fn err>(msg: S) -> impl FnOnce(Err) -> Err vec![msg], }; - error_level(ProtocolError::ConException(ConException::SyntaxError(stack)).into()) + error_level(ConException::SyntaxError(stack).into_trans()) } } pub fn err_other>(msg: S) -> impl FnOnce(E) -> Err { - move |_| { - Err::Error(ProtocolError::ConException(ConException::SyntaxError(vec![msg.into()])).into()) - } + move |_| Err::Error(ConException::SyntaxError(vec![msg.into()]).into_trans()) } pub fn failure(err: Err) -> Err { @@ -145,7 +143,7 @@ pub fn table(input: &[u8]) -> IResult { let (input, values) = many0(table_value_pair)(table_input)?; - if input != &[] { + if !input.is_empty() { fail!(format!( "table longer than expected, expected = {size}, remaining = {}", input.len() diff --git a/amqp_transport/src/connection.rs b/amqp_transport/src/connection.rs index dd0fb8d..317ec0f 100644 --- a/amqp_transport/src/connection.rs +++ b/amqp_transport/src/connection.rs @@ -1,50 +1,90 @@ -use crate::classes::FieldValue; use crate::error::{ConException, ProtocolError, Result}; use crate::frame::{Frame, FrameType}; -use crate::{classes, frame}; +use crate::{classes, frame, sasl}; use anyhow::Context; use std::collections::HashMap; use std::net::SocketAddr; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; -use tracing::{debug, error}; +use tracing::{debug, error, info}; +use uuid::Uuid; -const MIN_MAX_FRAME_SIZE: usize = 4096; +fn ensure_conn(condition: bool) -> Result<()> { + if condition { + Ok(()) + } else { + Err(ConException::Todo.into_trans()) + } +} + +const FRAME_SIZE_MIN_MAX: usize = 4096; +const CHANNEL_MAX: u16 = 0; +const FRAME_SIZE_MAX: u32 = 0; +const HEARTBEAT_DELAY: u16 = 0; pub struct Connection { stream: TcpStream, max_frame_size: usize, + heartbeat_delay: u16, + channel_max: u16, + id: Uuid, } impl Connection { - pub fn new(stream: TcpStream) -> Self { + pub fn new(stream: TcpStream, id: Uuid) -> Self { Self { stream, - max_frame_size: MIN_MAX_FRAME_SIZE, + max_frame_size: FRAME_SIZE_MIN_MAX, + heartbeat_delay: HEARTBEAT_DELAY, + channel_max: CHANNEL_MAX, + id, } } - pub async fn open_connection(mut self) { - match self.run().await { + pub async fn start_connection_processing(mut self) { + match self.process_connection().await { Ok(()) => {} Err(err) => error!(%err, "Error during processing of connection"), } } - pub async fn run(&mut self) -> Result<()> { + pub async fn process_connection(&mut self) -> Result<()> { self.negotiate_version().await?; self.start().await?; + self.tune().await?; + self.open().await?; + + info!("Connection is ready for usage!"); loop { - let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?; - debug!(?frame, "received frame"); - if frame.kind == FrameType::Method { - let class = super::classes::parse_method(&frame.payload)?; - debug!(?class, "was method frame"); - } + let method = self.recv_method().await?; + debug!(?method, "Received method"); } } + async fn send_method(&mut self, channel: u16, method: classes::Class) -> Result<()> { + let mut payload = Vec::with_capacity(64); + classes::write::write_method(method, &mut payload)?; + frame::write_frame( + &Frame { + kind: FrameType::Method, + channel, + payload, + }, + &mut self.stream, + ) + .await + } + + async fn recv_method(&mut self) -> Result { + let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?; + + ensure_conn(start_ok_frame.kind == FrameType::Method)?; + + let class = classes::parse_method(&start_ok_frame.payload)?; + Ok(class) + } + async fn start(&mut self) -> Result<()> { let start_method = classes::Class::Connection(classes::Connection::Start { version_major: 0, @@ -58,30 +98,72 @@ impl Connection { locales: "en_US".into(), }); - debug!(?start_method, "Sending start method"); + debug!(?start_method, "Sending Start method"); + self.send_method(0, start_method).await?; - let mut payload = Vec::with_capacity(64); - classes::write::write_method(start_method, &mut payload)?; - frame::write_frame( - &Frame { - kind: FrameType::Method, - channel: 0, - payload, - }, - &mut self.stream, - ) - .await?; + let start_ok = self.recv_method().await?; + debug!(?start_ok, "Received Start-Ok"); - let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?; - debug!(?start_ok_frame, "Received Start-Ok frame"); - - if start_ok_frame.kind != FrameType::Method { - return Err(ProtocolError::ConException(ConException::Todo).into()); + if let classes::Class::Connection(classes::Connection::StartOk { + mechanism, + locale, + response, + .. + }) = start_ok + { + ensure_conn(mechanism == "PLAIN")?; + ensure_conn(locale == "en_US")?; + let plain_user = sasl::parse_sasl_plain_response(&response)?; + info!(username = %plain_user.authentication_identity, "SASL Authentication successful") + } else { + return Err(ConException::Todo.into_trans()); } - let class = classes::parse_method(&start_ok_frame.payload)?; + Ok(()) + } - debug!(?class, "extracted method"); + async fn tune(&mut self) -> Result<()> { + let tune_method = classes::Class::Connection(classes::Connection::Tune { + channel_max: CHANNEL_MAX, + frame_max: FRAME_SIZE_MAX, + heartbeat: HEARTBEAT_DELAY, + }); + + debug!("Sending Tune method"); + self.send_method(0, tune_method).await?; + + let tune_ok = self.recv_method().await?; + debug!(?tune_ok, "Received Tune-Ok method"); + + if let classes::Class::Connection(classes::Connection::TuneOk { + channel_max, + frame_max, + heartbeat, + }) = tune_ok + { + self.channel_max = channel_max; + self.max_frame_size = usize::try_from(frame_max).unwrap(); + self.heartbeat_delay = heartbeat; + } + + Ok(()) + } + + async fn open(&mut self) -> Result<()> { + let open = self.recv_method().await?; + debug!(?open, "Received Open method"); + + if let classes::Class::Connection(classes::Connection::Open { virtual_host, .. }) = open { + ensure_conn(virtual_host == "/")?; + } + + self.send_method( + 0, + classes::Class::Connection(classes::Connection::OpenOk { + reserved_1: "".to_string(), + }), + ) + .await?; Ok(()) } @@ -120,21 +202,18 @@ impl Connection { } fn server_properties(host: SocketAddr) -> classes::Table { - fn ss(str: &str) -> FieldValue { - FieldValue::LongString(str.into()) + fn ls(str: &str) -> classes::FieldValue { + classes::FieldValue::LongString(str.into()) } let host_str = host.ip().to_string(); HashMap::from([ - ("host".to_string(), ss(&host_str)), - ( - "product".to_string(), - ss("no name yet"), - ), - ("version".to_string(), ss("0.1.0")), - ("platform".to_string(), ss("microsoft linux")), - ("copyright".to_string(), ss("MIT")), - ("information".to_string(), ss("hello reader")), - ("uwu".to_string(), ss("owo")), + ("host".to_string(), ls(&host_str)), + ("product".to_string(), ls("no name yet")), + ("version".to_string(), ls("0.1.0")), + ("platform".to_string(), ls("microsoft linux")), + ("copyright".to_string(), ls("MIT")), + ("information".to_string(), ls("hello reader")), + ("uwu".to_string(), ls("owo")), ]) } diff --git a/amqp_transport/src/error.rs b/amqp_transport/src/error.rs index 7a9a13a..b5b4b6d 100644 --- a/amqp_transport/src/error.rs +++ b/amqp_transport/src/error.rs @@ -1,6 +1,8 @@ use std::io::Error; -pub type Result = std::result::Result; +pub type StdResult = std::result::Result; + +pub type Result = StdResult; #[derive(Debug, thiserror::Error)] pub enum TransError { @@ -34,7 +36,7 @@ pub enum ConException { FrameError, #[error("503 Command invalid")] CommandInvalid, - #[error("503 Syntax error")] + #[error("503 Syntax error | {0:?}")] /// A method was received but there was a syntax error. The string stores where it occured. SyntaxError(Vec), #[error("504 Channel error")] @@ -43,5 +45,11 @@ pub enum ConException { Todo, } +impl ConException { + pub fn into_trans(self) -> TransError { + TransError::Invalid(ProtocolError::ConException(self)) + } +} + #[derive(Debug, thiserror::Error)] pub enum ChannelException {} diff --git a/amqp_transport/src/frame.rs b/amqp_transport/src/frame.rs index 31b340b..689b5de 100644 --- a/amqp_transport/src/frame.rs +++ b/amqp_transport/src/frame.rs @@ -1,7 +1,7 @@ use crate::error::{ConException, ProtocolError, Result}; use anyhow::Context; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tracing::debug; +use tracing::trace; const REQUIRED_FRAME_END: u8 = 0xCE; @@ -30,11 +30,11 @@ pub enum FrameType { Heartbeat = 8, } -pub async fn write_frame(frame: &Frame, mut w: W, ) -> Result<()> +pub async fn write_frame(frame: &Frame, mut w: W) -> Result<()> where W: AsyncWriteExt + Unpin, { - debug!(?frame, "sending frame"); + trace!(?frame, "Sending frame"); w.write_u8(frame.kind as u8).await?; w.write_u16(frame.channel).await?; @@ -63,17 +63,21 @@ where return Err(ProtocolError::Fatal.into()); } - if payload.len() > max_frame_size { - return Err(ProtocolError::ConException(ConException::FrameError).into()); + if max_frame_size != 0 && payload.len() > max_frame_size { + return Err(ConException::FrameError.into_trans()); } let kind = parse_frame_type(kind, channel)?; - Ok(Frame { + let frame = Frame { kind, channel, payload, - }) + }; + + trace!(?frame, "Received frame"); + + Ok(frame) } fn parse_frame_type(kind: u8, channel: u16) -> Result { @@ -88,7 +92,7 @@ fn parse_frame_type(kind: u8, channel: u16) -> Result { Ok(FrameType::Heartbeat) } } - _ => Err(ProtocolError::ConException(ConException::FrameError).into()), + _ => Err(ConException::FrameError.into_trans()), } } diff --git a/amqp_transport/src/lib.rs b/amqp_transport/src/lib.rs index 18bc06f..ab6361e 100644 --- a/amqp_transport/src/lib.rs +++ b/amqp_transport/src/lib.rs @@ -6,13 +6,15 @@ mod classes; mod connection; mod error; mod frame; +mod sasl; #[cfg(test)] mod tests; use crate::connection::Connection; use anyhow::Result; use tokio::net; -use tracing::info; +use tracing::{info, info_span, Instrument}; +use uuid::Uuid; pub async fn do_thing_i_guess() -> Result<()> { info!("Binding TCP listener..."); @@ -22,10 +24,13 @@ pub async fn do_thing_i_guess() -> Result<()> { loop { let (stream, _) = listener.accept().await?; - info!(local_addr = ?stream.local_addr(), "Accepted new connection"); + let id = Uuid::from_bytes(rand::random()); - let connection = Connection::new(stream); + info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection"); + let span = info_span!("client-connection", %id); - tokio::spawn(connection.open_connection()); + let connection = Connection::new(stream, id); + + tokio::spawn(connection.start_connection_processing().instrument(span)); } } diff --git a/amqp_transport/src/sasl.rs b/amqp_transport/src/sasl.rs new file mode 100644 index 0000000..caefd12 --- /dev/null +++ b/amqp_transport/src/sasl.rs @@ -0,0 +1,33 @@ +//! Partial implementation of the SASL Authentication (see [RFC 4422](https://datatracker.ietf.org/doc/html/rfc4422)) +//! +//! Currently only supports PLAN (see [RFC 4616](https://datatracker.ietf.org/doc/html/rfc4616)) + +use crate::error::{ConException, Result}; + +pub struct PlainUser { + pub authorization_identity: String, + pub authentication_identity: String, + pub password: String, +} + +pub fn parse_sasl_plain_response(response: &[u8]) -> Result { + let mut parts = response + .split(|&n| n == 0) + .map(|bytes| String::from_utf8(bytes.into()).map_err(|_| ConException::Todo.into_trans())); + + let authorization_identity = parts + .next() + .ok_or_else(|| ConException::Todo.into_trans())??; + let authentication_identity = parts + .next() + .ok_or_else(|| ConException::Todo.into_trans())??; + let password = parts + .next() + .ok_or_else(|| ConException::Todo.into_trans())??; + + Ok(PlainUser { + authorization_identity, + authentication_identity, + password, + }) +} diff --git a/amqp_transport/src/tests.rs b/amqp_transport/src/tests.rs index db0fab5..a0a929a 100644 --- a/amqp_transport/src/tests.rs +++ b/amqp_transport/src/tests.rs @@ -29,8 +29,6 @@ async fn write_start_ok_frame() { frame::write_frame(&frame, &mut output).await.unwrap(); - - #[rustfmt::skip] let expected = [ /* type, octet, method */ @@ -76,8 +74,6 @@ async fn write_start_ok_frame() { #[test] fn read_start_ok_payload() { - - #[rustfmt::skip] let raw_data = [ /* Connection.Start-Ok */ diff --git a/src/main.rs b/src/main.rs index f7420b9..4eb2a69 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,28 @@ use anyhow::Result; +use std::env; use tracing::Level; #[tokio::main] async fn main() -> Result<()> { - setup_tracing(); + let mut level = Level::DEBUG; + + for arg in env::args().skip(1) { + match arg.as_str() { + "--trace" => level = Level::TRACE, + _ => {} + } + } + + setup_tracing(level); amqp_transport::do_thing_i_guess().await } -fn setup_tracing() { +fn setup_tracing(level: Level) { tracing_subscriber::fmt() .with_level(true) .with_timer(tracing_subscriber::fmt::time::time()) .with_ansi(true) .with_thread_names(true) - .with_max_level(Level::DEBUG) + .with_max_level(level) .init() }