diff --git a/amqp_core/src/connection.rs b/amqp_core/src/connection.rs index 685542c..4055724 100644 --- a/amqp_core/src/connection.rs +++ b/amqp_core/src/connection.rs @@ -53,17 +53,18 @@ pub struct ConnectionInner { pub global_data: GlobalData, pub channels: Mutex>, pub exclusive_queues: Vec, - _events: ConEventSender, + pub event_sender: ConEventSender, } #[derive(Debug)] -pub enum QueuedMethod { - Normal(Method), - WithContent(Method, ContentHeader, SmallVec<[Bytes; 1]>), +pub enum ConnectionEvent { + Shutdown, + Method(ChannelNum, Box), + MethodContent(ChannelNum, Box, ContentHeader, SmallVec<[Bytes; 1]>), } -pub type ConEventSender = mpsc::Sender<(ChannelNum, QueuedMethod)>; -pub type ConEventReceiver = mpsc::Receiver<(ChannelNum, QueuedMethod)>; +pub type ConEventSender = mpsc::Sender; +pub type ConEventReceiver = mpsc::Receiver; impl ConnectionInner { #[must_use] @@ -71,7 +72,7 @@ impl ConnectionInner { id: ConnectionId, peer_addr: SocketAddr, global_data: GlobalData, - method_queue: ConEventSender, + event_sender: ConEventSender, ) -> Connection { Arc::new(Self { id, @@ -79,7 +80,7 @@ impl ConnectionInner { global_data, channels: Mutex::new(HashMap::new()), exclusive_queues: vec![], - _events: method_queue, + event_sender, }) } @@ -97,7 +98,7 @@ pub struct ChannelInner { pub num: ChannelNum, pub connection: Connection, pub global_data: GlobalData, - method_queue: ConEventSender, + pub event_sender: ConEventSender, } impl ChannelInner { @@ -114,7 +115,7 @@ impl ChannelInner { num, connection, global_data, - method_queue, + event_sender: method_queue, }) } @@ -122,13 +123,6 @@ impl ChannelInner { let mut global_data = self.global_data.lock(); global_data.channels.remove(&self.id); } - - pub fn queue_method(&self, method: QueuedMethod) { - // todo: this is a horrible hack around the lock chaos - self.method_queue - .try_send((self.num, method)) - .expect("could not send method to channel, RIP"); - } } /// A content frame header. diff --git a/amqp_messaging/src/methods/publish.rs b/amqp_messaging/src/methods/publish.rs index 301add2..32958ae 100644 --- a/amqp_messaging/src/methods/publish.rs +++ b/amqp_messaging/src/methods/publish.rs @@ -1,7 +1,7 @@ use crate::Result; use amqp_core::{ amqp_todo, - connection::{Channel, QueuedMethod}, + connection::{Channel, ConnectionEvent}, error::ChannelException, message::Message, methods::{BasicPublish, Method}, @@ -31,19 +31,24 @@ pub async fn publish(channel_handle: Channel, message: Message) -> Result<()> { // consuming is hard, but this should work *for now* let consumers = queue.consumers.lock(); if let Some(consumer) = consumers.first() { - let method = Method::BasicPublish(BasicPublish { + let method = Box::new(Method::BasicPublish(BasicPublish { reserved_1: 0, exchange: routing.exchange.clone(), routing_key: routing.routing_key.clone(), mandatory: false, immediate: false, - }); + })); - consumer.channel.queue_method(QueuedMethod::WithContent( - method, - message.header.clone(), - message.content.clone(), - )); + consumer + .channel + .event_sender + .try_send(ConnectionEvent::MethodContent( + consumer.channel.num, + method, + message.header.clone(), + message.content.clone(), + )) + .unwrap(); } } diff --git a/amqp_transport/src/connection.rs b/amqp_transport/src/connection.rs index c57436b..0b40d94 100644 --- a/amqp_transport/src/connection.rs +++ b/amqp_transport/src/connection.rs @@ -8,17 +8,17 @@ use amqp_core::{ amqp_todo, connection::{ Channel, ChannelInner, ChannelNum, ConEventReceiver, ConEventSender, Connection, - ConnectionId, ContentHeader, QueuedMethod, + ConnectionEvent, ConnectionId, ContentHeader, }, message::{MessageId, RawMessage, RoutingInformation}, methods::{ BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose, ConnectionCloseOk, ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk, - ConnectionTune, ConnectionTuneOk, FieldValue, Method, Table, + ConnectionTune, ConnectionTuneOk, FieldValue, Method, ReplyCode, ReplyText, Table, }, GlobalData, }; -use anyhow::Context; +use anyhow::{anyhow, Context}; use bytes::Bytes; use smallvec::SmallVec; use std::{ @@ -64,9 +64,10 @@ pub struct TransportConnection { channels: HashMap, global_con: Connection, global_data: GlobalData, - - method_queue_send: ConEventSender, - method_queue_recv: ConEventReceiver, + /// Only here to forward to other futures so they can send events + event_sender: ConEventSender, + /// To receive events from other futures + event_receiver: ConEventReceiver, } const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); @@ -102,8 +103,8 @@ impl TransportConnection { global_con, channels: HashMap::with_capacity(4), global_data, - method_queue_send, - method_queue_recv, + event_sender: method_queue_send, + event_receiver: method_queue_recv, } } @@ -117,27 +118,12 @@ impl TransportConnection { } Err(TransError::Protocol(ProtocolError::ConException(ex))) => { warn!(%ex, "Connection exception occurred. This indicates a faulty client."); - if let Err(err) = self - .send_method( - ChannelNum::zero(), - Method::ConnectionClose(ConnectionClose { - reply_code: ex.reply_code(), - reply_text: ex.reply_text(), - class_id: 0, // todo: do this - method_id: 0, - }), - ) - .await - { - error!(%ex, %err, "Failed to close connection after ConnectionException"); - } - match self.recv_method().await { - Ok(Method::ConnectionCloseOk(_)) => {} - Ok(method) => { - error!(%ex, ?method, "Received wrong method after ConnectionException") - } + let close_result = self.close(ex.reply_code(), ex.reply_text()).await; + + match close_result { + Ok(()) => {} Err(err) => { - error!(%ex, %err, "Failed to receive Connection.CloseOk method after ConnectionException") + error!(%ex, %err, "Failed to close connection after ConnectionException"); } } } @@ -161,7 +147,7 @@ impl TransportConnection { async fn send_method_content( &mut self, channel: ChannelNum, - method: Method, + method: &Method, _header: ContentHeader, _body: SmallVec<[Bytes; 1]>, ) -> Result<()> { @@ -169,7 +155,7 @@ impl TransportConnection { amqp_todo!() } - async fn send_method(&mut self, channel: ChannelNum, method: Method) -> Result<()> { + async fn send_method(&mut self, channel: ChannelNum, method: &Method) -> Result<()> { trace!(%channel, ?method, "Sending method"); let mut payload = Vec::with_capacity(64); @@ -208,7 +194,7 @@ impl TransportConnection { }); debug!(?start_method, "Sending Start method"); - self.send_method(ChannelNum::zero(), start_method).await?; + self.send_method(ChannelNum::zero(), &start_method).await?; let start_ok = self.recv_method().await?; debug!(?start_ok, "Received Start-Ok"); @@ -239,7 +225,7 @@ impl TransportConnection { }); debug!("Sending Tune method"); - self.send_method(ChannelNum::zero(), tune_method).await?; + self.send_method(ChannelNum::zero(), &tune_method).await?; let tune_ok = self.recv_method().await?; debug!(?tune_ok, "Received Tune-Ok method"); @@ -269,8 +255,8 @@ impl TransportConnection { self.send_method( ChannelNum::zero(), - Method::ConnectionOpenOk(ConnectionOpenOk { - reserved_1: "".to_string(), + &Method::ConnectionOpenOk(ConnectionOpenOk { + reserved_1: "".to_owned(), }), ) .await?; @@ -285,10 +271,11 @@ impl TransportConnection { let frame = frame?; self.handle_frame(frame).await?; } - queued_method = self.method_queue_recv.recv() => { + queued_method = self.event_receiver.recv() => { match queued_method { - Some((channel, QueuedMethod::Normal(method))) => self.send_method(channel, method).await?, - Some((channel, QueuedMethod::WithContent(method, header, body))) => self.send_method_content(channel, method, header, body).await?, + Some(ConnectionEvent::Method(channel, method)) => self.send_method(channel, &method).await?, + Some(ConnectionEvent::MethodContent(channel, method, header, body)) => self.send_method_content(channel, &method, header, body).await?, + Some(ConnectionEvent::Shutdown) => return self.close(0, "".to_owned()).await, None => {} } } @@ -315,7 +302,7 @@ impl TransportConnection { warn!(%ex, "Channel exception occurred"); self.send_method( channel, - Method::ChannelClose(ChannelClose { + &Method::ChannelClose(ChannelClose { reply_code: ex.reply_code(), reply_text: ex.reply_text(), class_id: 0, // todo: do this @@ -349,7 +336,7 @@ impl TransportConnection { info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection"); self.send_method( ChannelNum::zero(), - Method::ConnectionCloseOk(ConnectionCloseOk), + &Method::ConnectionCloseOk(ConnectionCloseOk), ) .await?; return Err(ProtocolError::GracefullyClosed.into()); @@ -375,7 +362,7 @@ impl TransportConnection { // maybe this might become an `Option` in the future let return_method = amqp_messaging::methods::handle_method(channel_handle, method).await?; - self.send_method(frame.channel, return_method).await?; + self.send_method(frame.channel, &return_method).await?; } } Ok(()) @@ -489,7 +476,7 @@ impl TransportConnection { channel_num, self.global_con.clone(), self.global_data.clone(), - self.method_queue_send.clone(), + self.event_sender.clone(), ); let channel = TransportChannel { @@ -519,7 +506,7 @@ impl TransportConnection { self.send_method( channel_num, - Method::ChannelOpenOk(ChannelOpenOk { + &Method::ChannelOpenOk(ChannelOpenOk { reserved_1: Vec::new(), }), ) @@ -539,7 +526,7 @@ impl TransportConnection { if let Some(channel) = self.channels.remove(&channel_id) { drop(channel); - self.send_method(channel_id, Method::ChannelCloseOk(ChannelCloseOk)) + self.send_method(channel_id, &Method::ChannelCloseOk(ChannelCloseOk)) .await?; } else { return Err(ConException::Todo.into()); @@ -598,6 +585,33 @@ impl TransportConnection { Err(ProtocolError::ProtocolNegotiationFailed.into()) } } + + async fn close(&mut self, reply_code: ReplyCode, reply_text: ReplyText) -> Result<()> { + self.send_method( + ChannelNum::zero(), + &Method::ConnectionClose(ConnectionClose { + reply_code, + reply_text, + class_id: 0, // todo: do this + method_id: 0, + }), + ) + .await?; + + match self.recv_method().await { + Ok(Method::ConnectionCloseOk(_)) => Ok(()), + Ok(method) => { + return Err(TransError::Other(anyhow!( + "Received wrong method after closing, method: {method:?}" + ))); + } + Err(err) => { + return Err(TransError::Other(anyhow!( + "Failed to receive Connection.CloseOk method after closing, err: {err}" + ))); + } + } + } } impl Drop for TransportConnection { @@ -619,12 +633,12 @@ fn server_properties(host: SocketAddr) -> Table { let host_str = host.ip().to_string(); HashMap::from([ - ("host".to_string(), ls(&host_str)), - ("product".to_string(), ls("no name yet")), - ("version".to_string(), ls("0.1.0")), - ("platform".to_string(), ls("microsoft linux")), - ("copyright".to_string(), ls("MIT")), - ("information".to_string(), ls("hello reader")), - ("uwu".to_string(), ls("owo")), + ("host".to_owned(), ls(&host_str)), + ("product".to_owned(), ls("no name yet")), + ("version".to_owned(), ls("0.1.0")), + ("platform".to_owned(), ls("microsoft linux")), + ("copyright".to_owned(), ls("MIT")), + ("information".to_owned(), ls("hello reader")), + ("uwu".to_owned(), ls("owo")), ]) } diff --git a/amqp_transport/src/frame.rs b/amqp_transport/src/frame.rs index 68c7d73..55812e4 100644 --- a/amqp_transport/src/frame.rs +++ b/amqp_transport/src/frame.rs @@ -119,10 +119,10 @@ pub fn parse_content_header(input: &[u8]) -> Result { match content_header_parse::header(input) { Ok(([], header)) => Ok(header), Ok((_, _)) => { - Err(ConException::SyntaxError(vec!["could not consume all input".to_string()]).into()) + Err(ConException::SyntaxError(vec!["could not consume all input".to_owned()]).into()) } Err(nom::Err::Incomplete(_)) => { - Err(ConException::SyntaxError(vec!["there was not enough data".to_string()]).into()) + Err(ConException::SyntaxError(vec!["there was not enough data".to_owned()]).into()) } Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err), } diff --git a/amqp_transport/src/methods/generated.rs b/amqp_transport/src/methods/generated.rs index b04473c..b4a8758 100644 --- a/amqp_transport/src/methods/generated.rs +++ b/amqp_transport/src/methods/generated.rs @@ -890,7 +890,7 @@ pub mod write { use amqp_core::methods::*; use std::io::Write; - pub fn write_method(method: Method, mut writer: W) -> Result<(), TransError> { + pub fn write_method(method: &Method, mut writer: W) -> Result<(), TransError> { match method { Method::ConnectionStart(ConnectionStart { version_major, diff --git a/amqp_transport/src/methods/mod.rs b/amqp_transport/src/methods/mod.rs index c921b44..a317451 100644 --- a/amqp_transport/src/methods/mod.rs +++ b/amqp_transport/src/methods/mod.rs @@ -20,10 +20,10 @@ pub fn parse_method(payload: &[u8]) -> Result { match nom_result { Ok(([], method)) => Ok(method), Ok((_, _)) => { - Err(ConException::SyntaxError(vec!["could not consume all input".to_string()]).into()) + Err(ConException::SyntaxError(vec!["could not consume all input".to_owned()]).into()) } Err(nom::Err::Incomplete(_)) => { - Err(ConException::SyntaxError(vec!["there was not enough data".to_string()]).into()) + Err(ConException::SyntaxError(vec!["there was not enough data".to_owned()]).into()) } Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err), } diff --git a/amqp_transport/src/methods/tests.rs b/amqp_transport/src/methods/tests.rs index 571ba3c..fa32001 100644 --- a/amqp_transport/src/methods/tests.rs +++ b/amqp_transport/src/methods/tests.rs @@ -67,11 +67,8 @@ fn random_ser_de() { #[test] fn nested_table() { let table = HashMap::from([( - "A".to_string(), - FieldValue::FieldTable(HashMap::from([( - "B".to_string(), - FieldValue::Boolean(true), - )])), + "A".to_owned(), + FieldValue::FieldTable(HashMap::from([("B".to_owned(), FieldValue::Boolean(true))])), )]); eprintln!("{table:?}"); diff --git a/amqp_transport/src/methods/write_helper.rs b/amqp_transport/src/methods/write_helper.rs index 89100c3..fb5e318 100644 --- a/amqp_transport/src/methods/write_helper.rs +++ b/amqp_transport/src/methods/write_helper.rs @@ -3,27 +3,27 @@ use amqp_core::methods::{Bit, Long, Longlong, Longstr, Octet, Short, Shortstr, T use anyhow::Context; use std::io::Write; -pub fn octet(value: Octet, writer: &mut W) -> Result<(), TransError> { - writer.write_all(&[value])?; +pub fn octet(value: &Octet, writer: &mut W) -> Result<(), TransError> { + writer.write_all(&[*value])?; Ok(()) } -pub fn short(value: Short, writer: &mut W) -> Result<(), TransError> { +pub fn short(value: &Short, writer: &mut W) -> Result<(), TransError> { writer.write_all(&value.to_be_bytes())?; Ok(()) } -pub fn long(value: Long, writer: &mut W) -> Result<(), TransError> { +pub fn long(value: &Long, writer: &mut W) -> Result<(), TransError> { writer.write_all(&value.to_be_bytes())?; Ok(()) } -pub fn longlong(value: Longlong, writer: &mut W) -> Result<(), TransError> { +pub fn longlong(value: &Longlong, writer: &mut W) -> Result<(), TransError> { writer.write_all(&value.to_be_bytes())?; Ok(()) } -pub fn bit(value: &[Bit], writer: &mut W) -> Result<(), TransError> { +pub fn bit(value: &[&Bit], writer: &mut W) -> Result<(), TransError> { // accumulate bits into bytes, starting from the least significant bit in each byte // how many bits have already been packed into `current_buf` @@ -37,7 +37,7 @@ pub fn bit(value: &[Bit], writer: &mut W) -> Result<(), TransError> { already_filled = 0; } - let new_bit = (u8::from(bit)) << already_filled; + let new_bit = (u8::from(*bit)) << already_filled; current_buf |= new_bit; already_filled += 1; } @@ -49,7 +49,7 @@ pub fn bit(value: &[Bit], writer: &mut W) -> Result<(), TransError> { Ok(()) } -pub fn shortstr(value: Shortstr, writer: &mut W) -> Result<(), TransError> { +pub fn shortstr(value: &Shortstr, writer: &mut W) -> Result<(), TransError> { let len = u8::try_from(value.len()).context("shortstr too long")?; writer.write_all(&[len])?; writer.write_all(value.as_bytes())?; @@ -57,7 +57,7 @@ pub fn shortstr(value: Shortstr, writer: &mut W) -> Result<(), TransEr Ok(()) } -pub fn longstr(value: Longstr, writer: &mut W) -> Result<(), TransError> { +pub fn longstr(value: &Longstr, writer: &mut W) -> Result<(), TransError> { let len = u32::try_from(value.len()).context("longstr too long")?; writer.write_all(&len.to_be_bytes())?; writer.write_all(value.as_slice())?; @@ -67,12 +67,12 @@ pub fn longstr(value: Longstr, writer: &mut W) -> Result<(), TransErro // this appears to be unused right now, but it could be used in `Basic` things? #[allow(dead_code)] -pub fn timestamp(value: Timestamp, writer: &mut W) -> Result<(), TransError> { +pub fn timestamp(value: &Timestamp, writer: &mut W) -> Result<(), TransError> { writer.write_all(&value.to_be_bytes())?; Ok(()) } -pub fn table(table: Table, writer: &mut W) -> Result<(), TransError> { +pub fn table(table: &Table, writer: &mut W) -> Result<(), TransError> { let mut table_buf = Vec::new(); for (field_name, value) in table { @@ -87,17 +87,17 @@ pub fn table(table: Table, writer: &mut W) -> Result<(), TransError> { Ok(()) } -fn field_value(value: FieldValue, writer: &mut W) -> Result<(), TransError> { +fn field_value(value: &FieldValue, writer: &mut W) -> Result<(), TransError> { match value { FieldValue::Boolean(bool) => { - writer.write_all(&[b't', u8::from(bool)])?; + writer.write_all(&[b't', u8::from(*bool)])?; } FieldValue::ShortShortInt(int) => { writer.write_all(b"b")?; writer.write_all(&int.to_be_bytes())?; } FieldValue::ShortShortUInt(int) => { - writer.write_all(&[b'B', int])?; + writer.write_all(&[b'B', *int])?; } FieldValue::ShortInt(int) => { writer.write_all(b"U")?; @@ -132,7 +132,7 @@ fn field_value(value: FieldValue, writer: &mut W) -> Result<(), TransE writer.write_all(&float.to_be_bytes())?; } FieldValue::DecimalValue(scale, long) => { - writer.write_all(&[b'D', scale])?; + writer.write_all(&[b'D', *scale])?; writer.write_all(&long.to_be_bytes())?; } FieldValue::ShortString(str) => { @@ -174,7 +174,7 @@ mod tests { let bits = [true, false, true]; let mut buffer = [0u8; 1]; - super::bit(&bits, &mut buffer.as_mut_slice()).unwrap(); + super::bit(&bits.map(|b| &b), &mut buffer.as_mut_slice()).unwrap(); assert_eq!(buffer, [0b00000101]) } @@ -188,7 +188,7 @@ mod tests { ]; let mut buffer = [0u8; 2]; - super::bit(&bits, &mut buffer.as_mut_slice()).unwrap(); + super::bit(&bits.map(|b| &b), &mut buffer.as_mut_slice()).unwrap(); assert_eq!(buffer, [0b00001111, 0b00001101]); } diff --git a/amqp_transport/src/tests.rs b/amqp_transport/src/tests.rs index 13425fe..334b147 100644 --- a/amqp_transport/src/tests.rs +++ b/amqp_transport/src/tests.rs @@ -12,7 +12,7 @@ async fn write_start_ok_frame() { version_major: 0, version_minor: 9, server_properties: HashMap::from([( - "product".to_string(), + "product".to_owned(), FieldValue::LongString("no name yet".into()), )]), mechanisms: "PLAIN".into(), @@ -145,41 +145,38 @@ fn read_start_ok_payload() { Method::ConnectionStartOk(ConnectionStartOk { client_properties: HashMap::from([ ( - "product".to_string(), + "product".to_owned(), FieldValue::LongString("Pika Python Client Library".into()) ), ( - "platform".to_string(), + "platform".to_owned(), FieldValue::LongString("Python 3.8.10".into()) ), ( - "capabilities".to_string(), + "capabilities".to_owned(), FieldValue::FieldTable(HashMap::from([ ( - "authentication_failure_close".to_string(), + "authentication_failure_close".to_owned(), FieldValue::Boolean(true) ), - ("basic.nack".to_string(), FieldValue::Boolean(true)), - ("connection.blocked".to_string(), FieldValue::Boolean(true)), + ("basic.nack".to_owned(), FieldValue::Boolean(true)), + ("connection.blocked".to_owned(), FieldValue::Boolean(true)), ( - "consumer_cancel_notify".to_string(), + "consumer_cancel_notify".to_owned(), FieldValue::Boolean(true) ), - ("publisher_confirms".to_string(), FieldValue::Boolean(true)), + ("publisher_confirms".to_owned(), FieldValue::Boolean(true)), ])) ), ( - "information".to_string(), + "information".to_owned(), FieldValue::LongString("See http://pika.rtfd.org".into()) ), - ( - "version".to_string(), - FieldValue::LongString("1.1.0".into()) - ) + ("version".to_owned(), FieldValue::LongString("1.1.0".into())) ]), - mechanism: "PLAIN".to_string(), + mechanism: "PLAIN".to_owned(), response: "\x00admin\x00".into(), - locale: "en_US".to_string() + locale: "en_US".to_owned() }) ); } diff --git a/xtask/src/codegen/mod.rs b/xtask/src/codegen/mod.rs index 3040271..5c672ea 100644 --- a/xtask/src/codegen/mod.rs +++ b/xtask/src/codegen/mod.rs @@ -305,7 +305,7 @@ pub struct {class_name}{method_name}" use heck::ToSnakeCase; if ident == "type" { - "r#type".to_string() + "r#type".to_owned() } else { ident.to_snake_case() } @@ -336,7 +336,7 @@ pub struct {class_name}{method_name}" fn invariants<'a>(&self, asserts: impl Iterator) -> String { asserts .map(|assert| match &*assert.check { - "notnull" => "must not be null".to_string(), + "notnull" => "must not be null".to_owned(), "length" => format!("must be shorter than {}", assert.value.as_ref().unwrap()), "regexp" => format!("must match `{}`", assert.value.as_ref().unwrap()), "le" => { @@ -354,7 +354,7 @@ pub struct {class_name}{method_name}" fn doc_comment(&mut self, docs: &[Doc], indent: usize) { for doc in docs { - if doc.kind == Some("grammar".to_string()) { + if doc.kind == Some("grammar".to_owned()) { continue; } for line in doc.text.lines() { diff --git a/xtask/src/codegen/write.rs b/xtask/src/codegen/write.rs index 46015f0..187e8c9 100644 --- a/xtask/src/codegen/write.rs +++ b/xtask/src/codegen/write.rs @@ -11,7 +11,7 @@ use crate::error::TransError; use crate::methods::write_helper::*; use std::io::Write; -pub fn write_method(method: Method, mut writer: W) -> Result<(), TransError> {{ +pub fn write_method(method: &Method, mut writer: W) -> Result<(), TransError> {{ match method {{" ) .ok();