not working :(

This commit is contained in:
nora 2022-03-05 16:26:12 +01:00
parent 3bcce76885
commit 08ba799d23
4 changed files with 140 additions and 74 deletions

View file

@ -4,12 +4,12 @@ use amqp_core::{
connection::{Channel, ConnectionEvent}, connection::{Channel, ConnectionEvent},
error::ChannelException, error::ChannelException,
message::Message, message::Message,
methods::{BasicPublish, Method}, methods::{BasicDeliver, Method},
}; };
use tracing::info; use tracing::debug;
pub async fn publish(channel_handle: Channel, message: Message) -> Result<()> { pub async fn publish(channel_handle: Channel, message: Message) -> Result<()> {
info!(?message, "Publishing message"); debug!(?message, "Publishing message");
let global_data = channel_handle.global_data.clone(); let global_data = channel_handle.global_data.clone();
@ -31,12 +31,12 @@ pub async fn publish(channel_handle: Channel, message: Message) -> Result<()> {
// consuming is hard, but this should work *for now* // consuming is hard, but this should work *for now*
let consumers = queue.consumers.lock(); let consumers = queue.consumers.lock();
if let Some(consumer) = consumers.first() { if let Some(consumer) = consumers.first() {
let method = Box::new(Method::BasicPublish(BasicPublish { let method = Box::new(Method::BasicDeliver(BasicDeliver {
reserved_1: 0, consumer_tag: consumer.tag.clone(),
delivery_tag: 0,
redelivered: false,
exchange: routing.exchange.clone(), exchange: routing.exchange.clone(),
routing_key: routing.routing_key.clone(), routing_key: routing.routing_key.clone(),
mandatory: false,
immediate: false,
})); }));
consumer consumer

View file

@ -1,11 +1,10 @@
use crate::{ use crate::{
error::{ConException, ProtocolError, Result, TransError}, error::{ConException, ProtocolError, Result, TransError},
frame, frame,
frame::{parse_content_header, Frame, FrameType}, frame::{parse_content_header, Frame, FrameType, MaxFrameSize},
methods, sasl, methods, sasl,
}; };
use amqp_core::{ use amqp_core::{
amqp_todo,
connection::{ connection::{
Channel, ChannelInner, ChannelNum, ConEventReceiver, ConEventSender, Connection, Channel, ChannelInner, ChannelNum, ConEventReceiver, ConEventSender, Connection,
ConnectionEvent, ConnectionId, ContentHeader, ConnectionEvent, ConnectionId, ContentHeader,
@ -14,7 +13,7 @@ use amqp_core::{
methods::{ methods::{
BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose, BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose,
ConnectionCloseOk, ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk, ConnectionCloseOk, ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk,
ConnectionTune, ConnectionTuneOk, FieldValue, Method, ReplyCode, ReplyText, Table, ConnectionTune, ConnectionTuneOk, FieldValue, Longstr, Method, ReplyCode, ReplyText, Table,
}, },
GlobalData, GlobalData,
}; };
@ -39,7 +38,7 @@ fn ensure_conn(condition: bool) -> Result<()> {
} }
} }
const FRAME_SIZE_MIN_MAX: usize = 4096; const FRAME_SIZE_MIN_MAX: MaxFrameSize = MaxFrameSize::new(4096);
const CHANNEL_MAX: u16 = 0; const CHANNEL_MAX: u16 = 0;
const FRAME_SIZE_MAX: u32 = 0; const FRAME_SIZE_MAX: u32 = 0;
const HEARTBEAT_DELAY: u16 = 0; const HEARTBEAT_DELAY: u16 = 0;
@ -56,7 +55,7 @@ pub struct TransportChannel {
pub struct TransportConnection { pub struct TransportConnection {
id: ConnectionId, id: ConnectionId,
stream: TcpStream, stream: TcpStream,
max_frame_size: usize, max_frame_size: MaxFrameSize,
heartbeat_delay: u16, heartbeat_delay: u16,
channel_max: u16, channel_max: u16,
/// When the next heartbeat expires /// When the next heartbeat expires
@ -149,23 +148,55 @@ impl TransportConnection {
channel: ChannelNum, channel: ChannelNum,
method: &Method, method: &Method,
header: ContentHeader, header: ContentHeader,
_body: SmallVec<[Bytes; 1]>, body: &SmallVec<[Bytes; 1]>,
) -> Result<()> { ) -> Result<()> {
self.send_method(channel, method).await?; self.send_method(channel, method).await?;
let mut header_buf = Vec::new(); let mut header_buf = Vec::new();
frame::write_content_header(&mut header_buf, header)?; frame::write_content_header(&mut header_buf, &header)?;
frame::write_frame( warn!(?header, ?header_buf, "Sending content header");
&Frame { frame::write_frame(&mut self.stream, FrameType::Header, channel, &header_buf).await?;
kind: FrameType::Method,
channel, self.send_bodies(channel, body).await
payload: header_buf.into(), }
},
&mut self.stream, async fn send_bodies(
) &mut self,
channel: ChannelNum,
body: &SmallVec<[Bytes; 1]>,
) -> Result<()> {
// this is inefficient if it's a huge message sent by a client with big frames to one with
// small frames
// we assume that this won't happen that that the first branch will be taken in most cases,
// elimination the overhead. What we win from keeping each frame as it is that we don't have
// to allocate again for each message
let max_size = self.max_frame_size.as_usize();
for payload in body {
if max_size > payload.len() {
trace!("Sending single method body frame");
// single frame
frame::write_frame(&mut self.stream, FrameType::Body, channel, payload).await?;
} else {
trace!(max = ?self.max_frame_size, "Chunking up method body frames");
// chunk it up into multiple sub-frames
let mut start = 0;
let mut end = max_size;
while end < payload.len() {
let sub_payload = &payload[start..end];
frame::write_frame(&mut self.stream, FrameType::Body, channel, sub_payload)
.await?; .await?;
amqp_todo!() start = end;
end = (end + max_size).max(payload.len());
}
}
}
Ok(())
} }
async fn send_method(&mut self, channel: ChannelNum, method: &Method) -> Result<()> { async fn send_method(&mut self, channel: ChannelNum, method: &Method) -> Result<()> {
@ -173,15 +204,7 @@ impl TransportConnection {
let mut payload = Vec::with_capacity(64); let mut payload = Vec::with_capacity(64);
methods::write::write_method(method, &mut payload)?; methods::write::write_method(method, &mut payload)?;
frame::write_frame( frame::write_frame(&mut self.stream, FrameType::Method, channel, &payload).await
&Frame {
kind: FrameType::Method,
channel,
payload: payload.into(),
},
&mut self.stream,
)
.await
} }
async fn recv_method(&mut self) -> Result<Method> { async fn recv_method(&mut self) -> Result<Method> {
@ -250,7 +273,7 @@ impl TransportConnection {
}) = tune_ok }) = tune_ok
{ {
self.channel_max = channel_max; self.channel_max = channel_max;
self.max_frame_size = usize::try_from(frame_max).unwrap(); self.max_frame_size = MaxFrameSize::new(usize::try_from(frame_max).unwrap());
self.heartbeat_delay = heartbeat; self.heartbeat_delay = heartbeat;
self.reset_timeout(); self.reset_timeout();
} }
@ -286,8 +309,14 @@ impl TransportConnection {
} }
queued_method = self.event_receiver.recv() => { queued_method = self.event_receiver.recv() => {
match queued_method { match queued_method {
Some(ConnectionEvent::Method(channel, method)) => self.send_method(channel, &method).await?, Some(ConnectionEvent::Method(channel, method)) => {
Some(ConnectionEvent::MethodContent(channel, method, header, body)) => self.send_method_content(channel, &method, header, body).await?, trace!(?channel, ?method, "Received method from event queue");
self.send_method(channel, &method).await?
}
Some(ConnectionEvent::MethodContent(channel, method, header, body)) => {
trace!(?channel, ?method, ?header, ?body, "Received method with body from event queue");
self.send_method_content(channel, &method, header, &body).await?
}
Some(ConnectionEvent::Shutdown) => return self.close(0, "".to_owned()).await, Some(ConnectionEvent::Shutdown) => return self.close(0, "".to_owned()).await,
None => {} None => {}
} }
@ -640,13 +669,13 @@ impl Drop for TransportChannel {
} }
fn server_properties(host: SocketAddr) -> Table { fn server_properties(host: SocketAddr) -> Table {
fn ls(str: &str) -> FieldValue { fn ls(str: impl Into<Longstr>) -> FieldValue {
FieldValue::LongString(str.into()) FieldValue::LongString(str.into())
} }
let host_str = host.ip().to_string(); let host_str = host.ip().to_string();
HashMap::from([ HashMap::from([
("host".to_owned(), ls(&host_str)), ("host".to_owned(), ls(host_str)),
("product".to_owned(), ls("no name yet")), ("product".to_owned(), ls("no name yet")),
("version".to_owned(), ls("0.1.0")), ("version".to_owned(), ls("0.1.0")),
("platform".to_owned(), ls("microsoft linux")), ("platform".to_owned(), ls("microsoft linux")),

View file

@ -1,10 +1,11 @@
use crate::error::{ConException, ProtocolError, Result}; use crate::error::{ConException, ProtocolError, Result};
use amqp_core::{ use amqp_core::connection::{ChannelNum, ContentHeader};
amqp_todo,
connection::{ChannelNum, ContentHeader},
};
use anyhow::Context; use anyhow::Context;
use bytes::Bytes; use bytes::Bytes;
use std::{
fmt::{Debug, Formatter},
num::NonZeroUsize,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::trace; use tracing::trace;
@ -133,71 +134,82 @@ pub fn parse_content_header(input: &[u8]) -> Result<ContentHeader> {
mod content_header_write { mod content_header_write {
use crate::{ use crate::{
methods::write_helper::{octet, shortstr, table, timestamp}, error::Result,
Result, methods::write_helper::{longlong, octet, short, shortstr, table, timestamp},
}; };
use amqp_core::{ use amqp_core::{
connection::ContentHeader, connection::ContentHeader,
methods::FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp}, methods::{
FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp},
Table,
},
}; };
pub fn write_content_header(buf: &mut Vec<u8>, header: ContentHeader) -> Result<()> { pub fn write_content_header(buf: &mut Vec<u8>, header: &ContentHeader) -> Result<()> {
short(&header.class_id, buf)?;
short(&header.weight, buf)?;
longlong(&header.body_size, buf)?;
write_content_header_props(buf, &header.property_fields)
}
pub fn write_content_header_props(buf: &mut Vec<u8>, header: &Table) -> Result<()> {
let mut flags = 0_u16; let mut flags = 0_u16;
buf.extend_from_slice(&flags.to_be_bytes()); // placeholder buf.extend_from_slice(&flags.to_be_bytes()); // placeholder
if let Some(ShortString(value)) = header.property_fields.get("content-type") { if let Some(ShortString(value)) = header.get("content-type") {
flags |= 1 << 15; flags |= 1 << 15;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
if let Some(ShortString(value)) = header.property_fields.get("content-encoding") { if let Some(ShortString(value)) = header.get("content-encoding") {
flags |= 1 << 14; flags |= 1 << 14;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
if let Some(FieldTable(value)) = header.property_fields.get("headers") { if let Some(FieldTable(value)) = header.get("headers") {
flags |= 1 << 13; flags |= 1 << 13;
table(value, buf)?; table(value, buf)?;
} }
if let Some(ShortShortUInt(value)) = header.property_fields.get("delivery-mode") { if let Some(ShortShortUInt(value)) = header.get("delivery-mode") {
flags |= 1 << 12; flags |= 1 << 12;
octet(value, buf)?; octet(value, buf)?;
} }
if let Some(ShortShortUInt(value)) = header.property_fields.get("priority") { if let Some(ShortShortUInt(value)) = header.get("priority") {
flags |= 1 << 11; flags |= 1 << 11;
octet(value, buf)?; octet(value, buf)?;
} }
if let Some(ShortString(value)) = header.property_fields.get("correlation-id") { if let Some(ShortString(value)) = header.get("correlation-id") {
flags |= 1 << 10; flags |= 1 << 10;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
if let Some(ShortString(value)) = header.property_fields.get("reply-to") { if let Some(ShortString(value)) = header.get("reply-to") {
flags |= 1 << 9; flags |= 1 << 9;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
if let Some(ShortString(value)) = header.property_fields.get("expiration") { if let Some(ShortString(value)) = header.get("expiration") {
flags |= 1 << 8; flags |= 1 << 8;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
if let Some(ShortString(value)) = header.property_fields.get("message-id") { if let Some(ShortString(value)) = header.get("message-id") {
flags |= 1 << 7; flags |= 1 << 7;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
if let Some(Timestamp(value)) = header.property_fields.get("timestamp") { if let Some(Timestamp(value)) = header.get("timestamp") {
flags |= 1 << 6; flags |= 1 << 6;
timestamp(value, buf)?; timestamp(value, buf)?;
} }
if let Some(ShortString(value)) = header.property_fields.get("type") { if let Some(ShortString(value)) = header.get("type") {
flags |= 1 << 5; flags |= 1 << 5;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
if let Some(ShortString(value)) = header.property_fields.get("user-id") { if let Some(ShortString(value)) = header.get("user-id") {
flags |= 1 << 4; flags |= 1 << 4;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
if let Some(ShortString(value)) = header.property_fields.get("app-id") { if let Some(ShortString(value)) = header.get("app-id") {
flags |= 1 << 3; flags |= 1 << 3;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
if let Some(ShortString(value)) = header.property_fields.get("reserved") { if let Some(ShortString(value)) = header.get("reserved") {
flags |= 1 << 2; flags |= 1 << 2;
shortstr(value, buf)?; shortstr(value, buf)?;
} }
@ -210,27 +222,51 @@ mod content_header_write {
} }
} }
pub fn write_content_header(buf: &mut Vec<u8>, content_header: ContentHeader) -> Result<()> { pub fn write_content_header(buf: &mut Vec<u8>, content_header: &ContentHeader) -> Result<()> {
write_content_header(buf, content_header) content_header_write::write_content_header(buf, content_header)
} }
pub async fn write_frame<W>(frame: &Frame, mut w: W) -> Result<()> #[derive(Clone, Copy)]
pub struct MaxFrameSize(Option<NonZeroUsize>);
impl MaxFrameSize {
pub const fn new(size: usize) -> Self {
Self(NonZeroUsize::new(size))
}
pub fn as_usize(&self) -> usize {
self.0.map(NonZeroUsize::get).unwrap_or(usize::MAX)
}
}
impl Debug for MaxFrameSize {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
pub async fn write_frame<W>(
mut w: W,
kind: FrameType,
channel: ChannelNum,
payload: &[u8],
) -> Result<()>
where where
W: AsyncWriteExt + Unpin + Send, W: AsyncWriteExt + Unpin + Send,
{ {
trace!(?frame, "Sending frame"); trace!(?kind, ?channel, ?payload, "Sending frame");
w.write_u8(frame.kind as u8).await?; w.write_u8(kind as u8).await?;
w.write_u16(frame.channel.num()).await?; w.write_u16(channel.num()).await?;
w.write_u32(u32::try_from(frame.payload.len()).context("frame size too big")?) w.write_u32(u32::try_from(payload.len()).context("frame size too big")?)
.await?; .await?;
w.write_all(&frame.payload).await?; w.write_all(payload).await?;
w.write_u8(REQUIRED_FRAME_END).await?; w.write_u8(REQUIRED_FRAME_END).await?;
Ok(()) Ok(())
} }
pub async fn read_frame<R>(r: &mut R, max_frame_size: usize) -> Result<Frame> pub async fn read_frame<R>(r: &mut R, max_frame_size: MaxFrameSize) -> Result<Frame>
where where
R: AsyncReadExt + Unpin + Send, R: AsyncReadExt + Unpin + Send,
{ {
@ -248,7 +284,7 @@ where
return Err(ProtocolError::Fatal.into()); return Err(ProtocolError::Fatal.into());
} }
if max_frame_size != 0 && payload.len() > max_frame_size { if payload.len() > max_frame_size.as_usize() {
return Err(ConException::FrameError.into()); return Err(ConException::FrameError.into());
} }
@ -283,7 +319,7 @@ fn parse_frame_type(kind: u8, channel: ChannelNum) -> Result<FrameType> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::frame::{ChannelNum, Frame, FrameType}; use crate::frame::{ChannelNum, Frame, FrameType, MaxFrameSize};
use bytes::Bytes; use bytes::Bytes;
#[tokio::test] #[tokio::test]
@ -307,7 +343,9 @@ mod tests {
super::REQUIRED_FRAME_END, super::REQUIRED_FRAME_END,
]; ];
let frame = super::read_frame(&mut bytes, 10000).await.unwrap(); let frame = super::read_frame(&mut bytes, MaxFrameSize::new(10000))
.await
.unwrap();
assert_eq!( assert_eq!(
frame, frame,
Frame { Frame {

View file

@ -12,11 +12,10 @@ mod tests;
use crate::connection::TransportConnection; use crate::connection::TransportConnection;
use amqp_core::GlobalData; use amqp_core::GlobalData;
use anyhow::Result;
use tokio::net; use tokio::net;
use tracing::{info, info_span, Instrument}; use tracing::{info, info_span, Instrument};
pub async fn do_thing_i_guess(global_data: GlobalData) -> Result<()> { pub async fn do_thing_i_guess(global_data: GlobalData) -> anyhow::Result<()> {
info!("Binding TCP listener..."); info!("Binding TCP listener...");
let listener = net::TcpListener::bind(("127.0.0.1", 5672)).await?; let listener = net::TcpListener::bind(("127.0.0.1", 5672)).await?;
info!(addr = ?listener.local_addr()?, "Successfully bound TCP listener"); info!(addr = ?listener.local_addr()?, "Successfully bound TCP listener");