From 08ba799d237f032cc62e9319989fbe2aead8272f Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Sat, 5 Mar 2022 16:26:12 +0100 Subject: [PATCH] not working :( --- amqp_messaging/src/methods/publish.rs | 14 ++-- amqp_transport/src/connection.rs | 91 ++++++++++++++-------- amqp_transport/src/frame.rs | 106 +++++++++++++++++--------- amqp_transport/src/lib.rs | 3 +- 4 files changed, 140 insertions(+), 74 deletions(-) diff --git a/amqp_messaging/src/methods/publish.rs b/amqp_messaging/src/methods/publish.rs index 32958ae..a0ca73f 100644 --- a/amqp_messaging/src/methods/publish.rs +++ b/amqp_messaging/src/methods/publish.rs @@ -4,12 +4,12 @@ use amqp_core::{ connection::{Channel, ConnectionEvent}, error::ChannelException, message::Message, - methods::{BasicPublish, Method}, + methods::{BasicDeliver, Method}, }; -use tracing::info; +use tracing::debug; pub async fn publish(channel_handle: Channel, message: Message) -> Result<()> { - info!(?message, "Publishing message"); + debug!(?message, "Publishing message"); let global_data = channel_handle.global_data.clone(); @@ -31,12 +31,12 @@ pub async fn publish(channel_handle: Channel, message: Message) -> Result<()> { // consuming is hard, but this should work *for now* let consumers = queue.consumers.lock(); if let Some(consumer) = consumers.first() { - let method = Box::new(Method::BasicPublish(BasicPublish { - reserved_1: 0, + let method = Box::new(Method::BasicDeliver(BasicDeliver { + consumer_tag: consumer.tag.clone(), + delivery_tag: 0, + redelivered: false, exchange: routing.exchange.clone(), routing_key: routing.routing_key.clone(), - mandatory: false, - immediate: false, })); consumer diff --git a/amqp_transport/src/connection.rs b/amqp_transport/src/connection.rs index ee10969..35546c4 100644 --- a/amqp_transport/src/connection.rs +++ b/amqp_transport/src/connection.rs @@ -1,11 +1,10 @@ use crate::{ error::{ConException, ProtocolError, Result, TransError}, frame, - frame::{parse_content_header, Frame, FrameType}, + frame::{parse_content_header, Frame, FrameType, MaxFrameSize}, methods, sasl, }; use amqp_core::{ - amqp_todo, connection::{ Channel, ChannelInner, ChannelNum, ConEventReceiver, ConEventSender, Connection, ConnectionEvent, ConnectionId, ContentHeader, @@ -14,7 +13,7 @@ use amqp_core::{ methods::{ BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose, ConnectionCloseOk, ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk, - ConnectionTune, ConnectionTuneOk, FieldValue, Method, ReplyCode, ReplyText, Table, + ConnectionTune, ConnectionTuneOk, FieldValue, Longstr, Method, ReplyCode, ReplyText, Table, }, GlobalData, }; @@ -39,7 +38,7 @@ fn ensure_conn(condition: bool) -> Result<()> { } } -const FRAME_SIZE_MIN_MAX: usize = 4096; +const FRAME_SIZE_MIN_MAX: MaxFrameSize = MaxFrameSize::new(4096); const CHANNEL_MAX: u16 = 0; const FRAME_SIZE_MAX: u32 = 0; const HEARTBEAT_DELAY: u16 = 0; @@ -56,7 +55,7 @@ pub struct TransportChannel { pub struct TransportConnection { id: ConnectionId, stream: TcpStream, - max_frame_size: usize, + max_frame_size: MaxFrameSize, heartbeat_delay: u16, channel_max: u16, /// When the next heartbeat expires @@ -149,23 +148,55 @@ impl TransportConnection { channel: ChannelNum, method: &Method, header: ContentHeader, - _body: SmallVec<[Bytes; 1]>, + body: &SmallVec<[Bytes; 1]>, ) -> Result<()> { self.send_method(channel, method).await?; let mut header_buf = Vec::new(); - frame::write_content_header(&mut header_buf, header)?; - frame::write_frame( - &Frame { - kind: FrameType::Method, - channel, - payload: header_buf.into(), - }, - &mut self.stream, - ) - .await?; + frame::write_content_header(&mut header_buf, &header)?; + warn!(?header, ?header_buf, "Sending content header"); + frame::write_frame(&mut self.stream, FrameType::Header, channel, &header_buf).await?; - amqp_todo!() + self.send_bodies(channel, body).await + } + + async fn send_bodies( + &mut self, + channel: ChannelNum, + body: &SmallVec<[Bytes; 1]>, + ) -> Result<()> { + // this is inefficient if it's a huge message sent by a client with big frames to one with + // small frames + // we assume that this won't happen that that the first branch will be taken in most cases, + // elimination the overhead. What we win from keeping each frame as it is that we don't have + // to allocate again for each message + + let max_size = self.max_frame_size.as_usize(); + + for payload in body { + if max_size > payload.len() { + trace!("Sending single method body frame"); + // single frame + frame::write_frame(&mut self.stream, FrameType::Body, channel, payload).await?; + } else { + trace!(max = ?self.max_frame_size, "Chunking up method body frames"); + // chunk it up into multiple sub-frames + let mut start = 0; + let mut end = max_size; + + while end < payload.len() { + let sub_payload = &payload[start..end]; + + frame::write_frame(&mut self.stream, FrameType::Body, channel, sub_payload) + .await?; + + start = end; + end = (end + max_size).max(payload.len()); + } + } + } + + Ok(()) } async fn send_method(&mut self, channel: ChannelNum, method: &Method) -> Result<()> { @@ -173,15 +204,7 @@ impl TransportConnection { let mut payload = Vec::with_capacity(64); methods::write::write_method(method, &mut payload)?; - frame::write_frame( - &Frame { - kind: FrameType::Method, - channel, - payload: payload.into(), - }, - &mut self.stream, - ) - .await + frame::write_frame(&mut self.stream, FrameType::Method, channel, &payload).await } async fn recv_method(&mut self) -> Result { @@ -250,7 +273,7 @@ impl TransportConnection { }) = tune_ok { self.channel_max = channel_max; - self.max_frame_size = usize::try_from(frame_max).unwrap(); + self.max_frame_size = MaxFrameSize::new(usize::try_from(frame_max).unwrap()); self.heartbeat_delay = heartbeat; self.reset_timeout(); } @@ -286,8 +309,14 @@ impl TransportConnection { } queued_method = self.event_receiver.recv() => { match queued_method { - Some(ConnectionEvent::Method(channel, method)) => self.send_method(channel, &method).await?, - Some(ConnectionEvent::MethodContent(channel, method, header, body)) => self.send_method_content(channel, &method, header, body).await?, + Some(ConnectionEvent::Method(channel, method)) => { + trace!(?channel, ?method, "Received method from event queue"); + self.send_method(channel, &method).await? + } + Some(ConnectionEvent::MethodContent(channel, method, header, body)) => { + trace!(?channel, ?method, ?header, ?body, "Received method with body from event queue"); + self.send_method_content(channel, &method, header, &body).await? + } Some(ConnectionEvent::Shutdown) => return self.close(0, "".to_owned()).await, None => {} } @@ -640,13 +669,13 @@ impl Drop for TransportChannel { } fn server_properties(host: SocketAddr) -> Table { - fn ls(str: &str) -> FieldValue { + fn ls(str: impl Into) -> FieldValue { FieldValue::LongString(str.into()) } let host_str = host.ip().to_string(); HashMap::from([ - ("host".to_owned(), ls(&host_str)), + ("host".to_owned(), ls(host_str)), ("product".to_owned(), ls("no name yet")), ("version".to_owned(), ls("0.1.0")), ("platform".to_owned(), ls("microsoft linux")), diff --git a/amqp_transport/src/frame.rs b/amqp_transport/src/frame.rs index d2ff318..4df5c29 100644 --- a/amqp_transport/src/frame.rs +++ b/amqp_transport/src/frame.rs @@ -1,10 +1,11 @@ use crate::error::{ConException, ProtocolError, Result}; -use amqp_core::{ - amqp_todo, - connection::{ChannelNum, ContentHeader}, -}; +use amqp_core::connection::{ChannelNum, ContentHeader}; use anyhow::Context; use bytes::Bytes; +use std::{ + fmt::{Debug, Formatter}, + num::NonZeroUsize, +}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing::trace; @@ -133,71 +134,82 @@ pub fn parse_content_header(input: &[u8]) -> Result { mod content_header_write { use crate::{ - methods::write_helper::{octet, shortstr, table, timestamp}, - Result, + error::Result, + methods::write_helper::{longlong, octet, short, shortstr, table, timestamp}, }; use amqp_core::{ connection::ContentHeader, - methods::FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp}, + methods::{ + FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp}, + Table, + }, }; - pub fn write_content_header(buf: &mut Vec, header: ContentHeader) -> Result<()> { + pub fn write_content_header(buf: &mut Vec, header: &ContentHeader) -> Result<()> { + short(&header.class_id, buf)?; + short(&header.weight, buf)?; + longlong(&header.body_size, buf)?; + + write_content_header_props(buf, &header.property_fields) + } + + pub fn write_content_header_props(buf: &mut Vec, header: &Table) -> Result<()> { let mut flags = 0_u16; buf.extend_from_slice(&flags.to_be_bytes()); // placeholder - if let Some(ShortString(value)) = header.property_fields.get("content-type") { + if let Some(ShortString(value)) = header.get("content-type") { flags |= 1 << 15; shortstr(value, buf)?; } - if let Some(ShortString(value)) = header.property_fields.get("content-encoding") { + if let Some(ShortString(value)) = header.get("content-encoding") { flags |= 1 << 14; shortstr(value, buf)?; } - if let Some(FieldTable(value)) = header.property_fields.get("headers") { + if let Some(FieldTable(value)) = header.get("headers") { flags |= 1 << 13; table(value, buf)?; } - if let Some(ShortShortUInt(value)) = header.property_fields.get("delivery-mode") { + if let Some(ShortShortUInt(value)) = header.get("delivery-mode") { flags |= 1 << 12; octet(value, buf)?; } - if let Some(ShortShortUInt(value)) = header.property_fields.get("priority") { + if let Some(ShortShortUInt(value)) = header.get("priority") { flags |= 1 << 11; octet(value, buf)?; } - if let Some(ShortString(value)) = header.property_fields.get("correlation-id") { + if let Some(ShortString(value)) = header.get("correlation-id") { flags |= 1 << 10; shortstr(value, buf)?; } - if let Some(ShortString(value)) = header.property_fields.get("reply-to") { + if let Some(ShortString(value)) = header.get("reply-to") { flags |= 1 << 9; shortstr(value, buf)?; } - if let Some(ShortString(value)) = header.property_fields.get("expiration") { + if let Some(ShortString(value)) = header.get("expiration") { flags |= 1 << 8; shortstr(value, buf)?; } - if let Some(ShortString(value)) = header.property_fields.get("message-id") { + if let Some(ShortString(value)) = header.get("message-id") { flags |= 1 << 7; shortstr(value, buf)?; } - if let Some(Timestamp(value)) = header.property_fields.get("timestamp") { + if let Some(Timestamp(value)) = header.get("timestamp") { flags |= 1 << 6; timestamp(value, buf)?; } - if let Some(ShortString(value)) = header.property_fields.get("type") { + if let Some(ShortString(value)) = header.get("type") { flags |= 1 << 5; shortstr(value, buf)?; } - if let Some(ShortString(value)) = header.property_fields.get("user-id") { + if let Some(ShortString(value)) = header.get("user-id") { flags |= 1 << 4; shortstr(value, buf)?; } - if let Some(ShortString(value)) = header.property_fields.get("app-id") { + if let Some(ShortString(value)) = header.get("app-id") { flags |= 1 << 3; shortstr(value, buf)?; } - if let Some(ShortString(value)) = header.property_fields.get("reserved") { + if let Some(ShortString(value)) = header.get("reserved") { flags |= 1 << 2; shortstr(value, buf)?; } @@ -210,27 +222,51 @@ mod content_header_write { } } -pub fn write_content_header(buf: &mut Vec, content_header: ContentHeader) -> Result<()> { - write_content_header(buf, content_header) +pub fn write_content_header(buf: &mut Vec, content_header: &ContentHeader) -> Result<()> { + content_header_write::write_content_header(buf, content_header) } -pub async fn write_frame(frame: &Frame, mut w: W) -> Result<()> +#[derive(Clone, Copy)] +pub struct MaxFrameSize(Option); + +impl MaxFrameSize { + pub const fn new(size: usize) -> Self { + Self(NonZeroUsize::new(size)) + } + + pub fn as_usize(&self) -> usize { + self.0.map(NonZeroUsize::get).unwrap_or(usize::MAX) + } +} + +impl Debug for MaxFrameSize { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +pub async fn write_frame( + mut w: W, + kind: FrameType, + channel: ChannelNum, + payload: &[u8], +) -> Result<()> where W: AsyncWriteExt + Unpin + Send, { - trace!(?frame, "Sending frame"); + trace!(?kind, ?channel, ?payload, "Sending frame"); - w.write_u8(frame.kind as u8).await?; - w.write_u16(frame.channel.num()).await?; - w.write_u32(u32::try_from(frame.payload.len()).context("frame size too big")?) + w.write_u8(kind as u8).await?; + w.write_u16(channel.num()).await?; + w.write_u32(u32::try_from(payload.len()).context("frame size too big")?) .await?; - w.write_all(&frame.payload).await?; + w.write_all(payload).await?; w.write_u8(REQUIRED_FRAME_END).await?; Ok(()) } -pub async fn read_frame(r: &mut R, max_frame_size: usize) -> Result +pub async fn read_frame(r: &mut R, max_frame_size: MaxFrameSize) -> Result where R: AsyncReadExt + Unpin + Send, { @@ -248,7 +284,7 @@ where return Err(ProtocolError::Fatal.into()); } - if max_frame_size != 0 && payload.len() > max_frame_size { + if payload.len() > max_frame_size.as_usize() { return Err(ConException::FrameError.into()); } @@ -283,7 +319,7 @@ fn parse_frame_type(kind: u8, channel: ChannelNum) -> Result { #[cfg(test)] mod tests { - use crate::frame::{ChannelNum, Frame, FrameType}; + use crate::frame::{ChannelNum, Frame, FrameType, MaxFrameSize}; use bytes::Bytes; #[tokio::test] @@ -307,7 +343,9 @@ mod tests { super::REQUIRED_FRAME_END, ]; - let frame = super::read_frame(&mut bytes, 10000).await.unwrap(); + let frame = super::read_frame(&mut bytes, MaxFrameSize::new(10000)) + .await + .unwrap(); assert_eq!( frame, Frame { diff --git a/amqp_transport/src/lib.rs b/amqp_transport/src/lib.rs index cdeefd6..16c2a15 100644 --- a/amqp_transport/src/lib.rs +++ b/amqp_transport/src/lib.rs @@ -12,11 +12,10 @@ mod tests; use crate::connection::TransportConnection; use amqp_core::GlobalData; -use anyhow::Result; use tokio::net; use tracing::{info, info_span, Instrument}; -pub async fn do_thing_i_guess(global_data: GlobalData) -> Result<()> { +pub async fn do_thing_i_guess(global_data: GlobalData) -> anyhow::Result<()> { info!("Binding TCP listener..."); let listener = net::TcpListener::bind(("127.0.0.1", 5672)).await?; info!(addr = ?listener.local_addr()?, "Successfully bound TCP listener");