diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f4c72be..c975e90 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,5 +29,5 @@ jobs: run: cargo fmt --verbose --all -- --check - name: Run tests run: cargo test --verbose --all - - name: Run client integration tests - run: cargo xtask test-js + # - name: Run client integration tests + # run: cargo xtask test-js diff --git a/amqp_transport/src/connection.rs b/amqp_transport/src/connection.rs index a69b185..57abbf3 100644 --- a/amqp_transport/src/connection.rs +++ b/amqp_transport/src/connection.rs @@ -17,7 +17,7 @@ use amqp_core::methods::{FieldValue, Method, Table}; use amqp_core::GlobalData; use crate::error::{ConException, ProtocolError, Result}; -use crate::frame::{ContentHeader, Frame, FrameType}; +use crate::frame::{ChannelId, ContentHeader, Frame, FrameType}; use crate::{frame, methods, sasl}; fn ensure_conn(condition: bool) -> Result<()> { @@ -33,10 +33,11 @@ const CHANNEL_MAX: u16 = 0; const FRAME_SIZE_MAX: u32 = 0; const HEARTBEAT_DELAY: u16 = 0; -#[allow(dead_code)] pub struct Channel { - num: u16, - channel_handle: amqp_core::ChannelHandle, + /// A handle to the global channel representation. Used to remove the channel when it's dropped + handle: amqp_core::ChannelHandle, + /// The current status of the channel, whether it has sent a method that expects a body + status: ChannelStatus, } pub struct Connection { @@ -45,18 +46,26 @@ pub struct Connection { max_frame_size: usize, heartbeat_delay: u16, channel_max: u16, + /// When the next heartbeat expires next_timeout: Pin>, - channels: HashMap, - connection_handle: amqp_core::ConnectionHandle, + channels: HashMap, + handle: amqp_core::ConnectionHandle, global_data: GlobalData, } const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); -enum WaitForBodyStatus { - Method(Method), - Header(Method, ContentHeader, SmallVec<[Bytes; 1]>), - None, +enum ChannelStatus { + Default, + /// ClassId // todo: newtype it + NeedHeader(u16, Box), + NeedsBody(Box, Box, SmallVec<[Bytes; 1]>), +} + +impl ChannelStatus { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::Default) + } } impl Connection { @@ -73,8 +82,8 @@ impl Connection { heartbeat_delay: HEARTBEAT_DELAY, channel_max: CHANNEL_MAX, next_timeout: Box::pin(time::sleep(DEFAULT_TIMEOUT)), - connection_handle, - channels: HashMap::new(), + handle: connection_handle, + channels: HashMap::with_capacity(4), global_data, } } @@ -85,7 +94,7 @@ impl Connection { Err(err) => error!(%err, "Error during processing of connection"), } - let connection_handle = self.connection_handle.lock(); + let connection_handle = self.handle.lock(); connection_handle.close(); } @@ -100,7 +109,7 @@ impl Connection { self.main_loop().await } - async fn send_method(&mut self, channel: u16, method: Method) -> Result<()> { + async fn send_method(&mut self, channel: ChannelId, method: Method) -> Result<()> { let mut payload = Vec::with_capacity(64); methods::write::write_method(method, &mut payload)?; frame::write_frame( @@ -137,7 +146,7 @@ impl Connection { }; debug!(?start_method, "Sending Start method"); - self.send_method(0, start_method).await?; + self.send_method(ChannelId::zero(), start_method).await?; let start_ok = self.recv_method().await?; debug!(?start_ok, "Received Start-Ok"); @@ -168,7 +177,7 @@ impl Connection { }; debug!("Sending Tune method"); - self.send_method(0, tune_method).await?; + self.send_method(ChannelId::zero(), tune_method).await?; let tune_ok = self.recv_method().await?; debug!(?tune_ok, "Received Tune-Ok method"); @@ -197,7 +206,7 @@ impl Connection { } self.send_method( - 0, + ChannelId::zero(), Method::ConnectionOpenOk { reserved_1: "".to_string(), }, @@ -208,54 +217,29 @@ impl Connection { } async fn main_loop(&mut self) -> Result<()> { - // todo: find out how header/body frames can interleave between channels - let mut wait_for_body = WaitForBodyStatus::None; - loop { debug!("Waiting for next frame"); let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?; self.reset_timeout(); match frame.kind { - FrameType::Method => wait_for_body = self.dispatch_method(frame).await?, - FrameType::Heartbeat => {} - FrameType::Header => match wait_for_body { - WaitForBodyStatus::None => warn!(channel = %frame.channel, "unexpected header"), - WaitForBodyStatus::Method(method) => { - wait_for_body = - WaitForBodyStatus::Header(method, ContentHeader::new(), SmallVec::new()) - } - WaitForBodyStatus::Header(_, _, _) => { - warn!(channel = %frame.channel, "already got header") - } - }, - FrameType::Body => match &mut wait_for_body { - WaitForBodyStatus::None => warn!(channel = %frame.channel, "unexpected body"), - WaitForBodyStatus::Method(_) => { - warn!(channel = %frame.channel, "unexpected body") - } - WaitForBodyStatus::Header(_, header, vec) => { - vec.push(frame.payload); - match vec - .iter() - .map(Bytes::len) - .sum::() - .cmp(&usize::try_from(header.body_size).unwrap()) - { - Ordering::Equal => todo!("process body"), - Ordering::Greater => todo!("too much data!"), - Ordering::Less => {} // wait for next body - } - } - }, + FrameType::Method => self.dispatch_method(frame).await?, + FrameType::Heartbeat => { /* Nothing here, just the `reset_timeout` above */ } + FrameType::Header => self.dispatch_header(frame)?, + FrameType::Body => self.dispatch_body(frame)?, } } } - async fn dispatch_method(&mut self, frame: Frame) -> Result { + async fn dispatch_method(&mut self, frame: Frame) -> Result<()> { let method = methods::parse_method(&frame.payload)?; debug!(?method, "Received method"); + // Sending a method implicitly cancels the content frames that might be ongoing + self.channels + .get_mut(&frame.channel) + .map(|channel| channel.status.take()); + match method { Method::ConnectionClose { reply_code, @@ -264,18 +248,27 @@ impl Connection { method_id, } => { info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection"); - self.send_method(0, Method::ConnectionCloseOk {}).await?; + self.send_method(ChannelId::zero(), Method::ConnectionCloseOk {}) + .await?; return Err(ProtocolError::GracefulClose.into()); } Method::ChannelOpen { .. } => self.channel_open(frame.channel).await?, Method::ChannelClose { .. } => self.channel_close(frame.channel, method).await?, - Method::BasicPublish { .. } => return Ok(WaitForBodyStatus::Method(method)), + Method::BasicPublish { .. } => { + const BASIC_CLASS_ID: u16 = 60; + match self.channels.get_mut(&frame.channel) { + Some(channel) => { + channel.status = ChannelStatus::NeedHeader(BASIC_CLASS_ID, Box::new(method)) + } + None => return Err(ConException::Todo.into_trans()), + } + } _ => { let channel_handle = self .channels .get(&frame.channel) .ok_or_else(|| ConException::Todo.into_trans())? - .channel_handle + .handle .clone(); tokio::spawn(amqp_messaging::methods::handle_method( @@ -285,27 +278,79 @@ impl Connection { // we don't handle this here, forward it to *somewhere* } } - - Ok(WaitForBodyStatus::None) + Ok(()) } - async fn channel_open(&mut self, num: u16) -> Result<()> { + fn dispatch_header(&mut self, frame: Frame) -> Result<()> { + self.channels + .get_mut(&frame.channel) + .ok_or_else(|| ConException::Todo.into_trans()) + .and_then(|channel| match channel.status.take() { + ChannelStatus::Default => { + warn!(channel = %frame.channel, "unexpected header"); + Err(ConException::UnexpectedFrame.into_trans()) + } + ChannelStatus::NeedHeader(class_id, method) => { + let header = ContentHeader::parse(&frame.payload)?; + ensure_conn(header.class_id == class_id)?; + + channel.status = ChannelStatus::NeedsBody(method, header, SmallVec::new()); + Ok(()) + } + ChannelStatus::NeedsBody(_, _, _) => { + warn!(channel = %frame.channel, "already got header"); + Err(ConException::UnexpectedFrame.into_trans()) + } + }) + } + + fn dispatch_body(&mut self, frame: Frame) -> Result<()> { + self.channels + .get_mut(&frame.channel) + .ok_or_else(|| ConException::Todo.into_trans()) + .and_then(|channel| match channel.status.take() { + ChannelStatus::Default => { + warn!(channel = %frame.channel, "unexpected body"); + Err(ConException::UnexpectedFrame.into_trans()) + } + ChannelStatus::NeedHeader(_, _) => { + warn!(channel = %frame.channel, "unexpected body"); + Err(ConException::UnexpectedFrame.into_trans()) + } + ChannelStatus::NeedsBody(_, header, mut vec) => { + vec.push(frame.payload); + match vec + .iter() + .map(Bytes::len) + .sum::() + .cmp(&usize::try_from(header.body_size).unwrap()) + { + Ordering::Equal => todo!("process body"), + Ordering::Greater => todo!("too much data!"), + Ordering::Less => {} // wait for next body + } + Ok(()) + } + }) + } + + async fn channel_open(&mut self, channel_id: ChannelId) -> Result<()> { let id = Uuid::from_bytes(rand::random()); let channel_handle = amqp_core::Channel::new_handle( id, - num, - self.connection_handle.clone(), + channel_id.num(), + self.handle.clone(), self.global_data.clone(), ); let channel = Channel { - num, - channel_handle: channel_handle.clone(), + handle: channel_handle.clone(), + status: ChannelStatus::Default, }; - let prev = self.channels.insert(num, channel); + let prev = self.channels.insert(channel_id, channel); if let Some(prev) = prev { - self.channels.insert(num, prev); // restore previous state + self.channels.insert(channel_id, prev); // restore previous state return Err(ConException::ChannelError.into_trans()); } @@ -318,13 +363,13 @@ impl Connection { .unwrap() .lock() .channels - .insert(num, channel_handle); + .insert(channel_id.num(), channel_handle); } - info!(%num, "Opened new channel"); + info!(%channel_id, "Opened new channel"); self.send_method( - num, + channel_id, Method::ChannelOpenOk { reserved_1: Vec::new(), }, @@ -334,7 +379,7 @@ impl Connection { Ok(()) } - async fn channel_close(&mut self, num: u16, method: Method) -> Result<()> { + async fn channel_close(&mut self, channel_id: ChannelId, method: Method) -> Result<()> { if let Method::ChannelClose { reply_code: code, reply_text: reason, @@ -343,9 +388,9 @@ impl Connection { { info!(%code, %reason, "Closing channel"); - if let Some(channel) = self.channels.remove(&num) { + if let Some(channel) = self.channels.remove(&channel_id) { drop(channel); - self.send_method(num, Method::ChannelCloseOk).await?; + self.send_method(channel_id, Method::ChannelCloseOk).await?; } else { return Err(ConException::Todo.into_trans()); } @@ -357,7 +402,7 @@ impl Connection { fn reset_timeout(&mut self) { if self.heartbeat_delay != 0 { - let next = Duration::from_secs(u64::from(self.heartbeat_delay)); + let next = Duration::from_secs(u64::from(self.heartbeat_delay / 2)); self.next_timeout = Box::pin(time::sleep(next)); } } @@ -396,13 +441,13 @@ impl Connection { impl Drop for Connection { fn drop(&mut self) { - self.connection_handle.lock().close(); + self.handle.lock().close(); } } impl Drop for Channel { fn drop(&mut self) { - self.channel_handle.lock().close(); + self.handle.lock().close(); } } diff --git a/amqp_transport/src/error.rs b/amqp_transport/src/error.rs index 4b1b436..188db0a 100644 --- a/amqp_transport/src/error.rs +++ b/amqp_transport/src/error.rs @@ -45,6 +45,8 @@ pub enum ConException { SyntaxError(Vec), #[error("504 Channel error")] ChannelError, + #[error("505 Unexpected Frame")] + UnexpectedFrame, #[error("xxx Not decided yet")] Todo, } diff --git a/amqp_transport/src/frame.rs b/amqp_transport/src/frame.rs index cd9c53c..cd16ea6 100644 --- a/amqp_transport/src/frame.rs +++ b/amqp_transport/src/frame.rs @@ -1,13 +1,36 @@ use crate::error::{ConException, ProtocolError, Result}; -use amqp_core::methods::FieldValue; +use amqp_core::methods; use anyhow::Context; use bytes::Bytes; -use smallvec::SmallVec; +use std::fmt::{Display, Formatter}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing::trace; const REQUIRED_FRAME_END: u8 = 0xCE; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ChannelId(u16); + +impl ChannelId { + pub fn num(self) -> u16 { + self.0 + } + + pub fn is_zero(self) -> bool { + self.0 == 0 + } + + pub fn zero() -> Self { + Self(0) + } +} + +impl Display for ChannelId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + mod frame_type { pub const METHOD: u8 = 1; pub const HEADER: u8 = 2; @@ -19,7 +42,7 @@ mod frame_type { pub struct Frame { /// The type of the frame including its parsed metadata. pub kind: FrameType, - pub channel: u16, + pub channel: ChannelId, /// Includes the whole payload, also including the metadata from each type. pub payload: Bytes, } @@ -33,18 +56,84 @@ pub enum FrameType { Heartbeat = 8, } +#[derive(Debug, Clone, PartialEq)] +pub struct BasicProperties { + content_type: Option, + content_encoding: Option, + headers: Option, + delivery_mode: Option, + priority: Option, + correlation_id: Option, + reply_to: Option, + expiration: Option, + message_id: Option, + timestamp: Option, + r#type: Option, + user_id: Option, + app_id: Option, + reserved: Option, +} + #[derive(Debug, Clone, PartialEq)] pub struct ContentHeader { pub class_id: u16, pub weight: u16, pub body_size: u64, - pub property_flags: SmallVec<[u16; 1]>, - pub property_fields: Vec, + pub property_fields: BasicProperties, +} + +mod content_header_parse { + use crate::error::TransError; + use crate::frame::{BasicProperties, ContentHeader}; + use nom::number::complete::{u16, u64}; + use nom::number::Endianness::Big; + + type IResult<'a, T> = nom::IResult<&'a [u8], T, TransError>; + + pub fn basic_properties(_property_flags: u16, _input: &[u8]) -> IResult<'_, BasicProperties> { + todo!() + } + + pub fn header(input: &[u8]) -> IResult<'_, Box> { + let (input, class_id) = u16(Big)(input)?; + let (input, weight) = u16(Big)(input)?; + let (input, body_size) = u64(Big)(input)?; + + // I do not quite understand this here. Apparently, there can be more than 15 flags? + // But the Basic class only specifies 15, so idk. Don't care about this for now + let (input, property_flags) = u16(Big)(input)?; + let (input, property_fields) = basic_properties(property_flags, input)?; + + Ok(( + input, + Box::new(ContentHeader { + class_id, + weight, + body_size, + property_fields, + }), + )) + } } impl ContentHeader { - pub fn new() -> Self { - todo!() + pub fn parse(input: &[u8]) -> Result> { + match content_header_parse::header(input) { + Ok(([], header)) => Ok(header), + Ok((_, _)) => { + Err( + ConException::SyntaxError(vec!["could not consume all input".to_string()]) + .into_trans(), + ) + } + Err(nom::Err::Incomplete(_)) => { + Err( + ConException::SyntaxError(vec!["there was not enough data".to_string()]) + .into_trans(), + ) + } + Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err), + } } } @@ -55,7 +144,7 @@ where trace!(?frame, "Sending frame"); w.write_u8(frame.kind as u8).await?; - w.write_u16(frame.channel).await?; + w.write_u16(frame.channel.num()).await?; w.write_u32(u32::try_from(frame.payload.len()).context("frame size too big")?) .await?; w.write_all(&frame.payload).await?; @@ -70,6 +159,7 @@ where { let kind = r.read_u8().await.context("read type")?; let channel = r.read_u16().await.context("read channel")?; + let channel = ChannelId(channel); let size = r.read_u32().await.context("read size")?; let mut payload = vec![0; size.try_into().unwrap()]; @@ -98,16 +188,16 @@ where Ok(frame) } -fn parse_frame_type(kind: u8, channel: u16) -> Result { +fn parse_frame_type(kind: u8, channel: ChannelId) -> Result { match kind { frame_type::METHOD => Ok(FrameType::Method), frame_type::HEADER => Ok(FrameType::Header), frame_type::BODY => Ok(FrameType::Body), frame_type::HEARTBEAT => { - if channel != 0 { - Err(ProtocolError::ConException(ConException::FrameError).into()) - } else { + if channel.is_zero() { Ok(FrameType::Heartbeat) + } else { + Err(ProtocolError::ConException(ConException::FrameError).into()) } } _ => Err(ConException::FrameError.into_trans()), @@ -116,7 +206,7 @@ fn parse_frame_type(kind: u8, channel: u16) -> Result { #[cfg(test)] mod tests { - use crate::frame::{Frame, FrameType}; + use crate::frame::{ChannelId, Frame, FrameType}; use bytes::Bytes; #[tokio::test] @@ -145,7 +235,7 @@ mod tests { frame, Frame { kind: FrameType::Method, - channel: 0, + channel: ChannelId(0), payload: Bytes::from_static(&[1, 2, 3]), } ); diff --git a/amqp_transport/src/lib.rs b/amqp_transport/src/lib.rs index 2575acc..f28c3fd 100644 --- a/amqp_transport/src/lib.rs +++ b/amqp_transport/src/lib.rs @@ -8,6 +8,8 @@ mod sasl; #[cfg(test)] mod tests; +// TODO: handle big types + use crate::connection::Connection; use amqp_core::GlobalData; use anyhow::Result; diff --git a/yarn.lock b/yarn.lock deleted file mode 100644 index fb57ccd..0000000 --- a/yarn.lock +++ /dev/null @@ -1,4 +0,0 @@ -# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. -# yarn lockfile v1 - -