diff --git a/amqp_core/src/connection.rs b/amqp_core/src/connection.rs new file mode 100644 index 0000000..615f670 --- /dev/null +++ b/amqp_core/src/connection.rs @@ -0,0 +1,99 @@ +use crate::{newtype_id, GlobalData, Handle, Queue}; +use parking_lot::Mutex; +use std::collections::HashMap; +use std::fmt::{Display, Formatter}; +use std::net::SocketAddr; +use std::sync::Arc; + +newtype_id!(pub ConnectionId); +newtype_id!(pub ChannelId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ChannelNum(u16); + +impl ChannelNum { + pub fn new(num: u16) -> Self { + Self(num) + } + + pub fn num(self) -> u16 { + self.0 + } + + pub fn is_zero(self) -> bool { + self.0 == 0 + } + + pub fn zero() -> Self { + Self(0) + } +} + +impl Display for ChannelNum { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +pub type ConnectionHandle = Handle; + +#[derive(Debug)] +pub struct Connection { + pub id: ConnectionId, + pub peer_addr: SocketAddr, + pub global_data: GlobalData, + pub channels: HashMap, + pub exclusive_queues: Vec, +} + +impl Connection { + pub fn new_handle( + id: ConnectionId, + peer_addr: SocketAddr, + global_data: GlobalData, + ) -> ConnectionHandle { + Arc::new(Mutex::new(Self { + id, + peer_addr, + global_data, + channels: HashMap::new(), + exclusive_queues: vec![], + })) + } + + pub fn close(&self) { + let mut global_data = self.global_data.lock(); + global_data.connections.remove(&self.id); + } +} + +pub type ChannelHandle = Handle; + +#[derive(Debug)] +pub struct Channel { + pub id: ChannelId, + pub num: u16, + pub connection: ConnectionHandle, + pub global_data: GlobalData, +} + +impl Channel { + pub fn new_handle( + id: ChannelId, + num: u16, + connection: ConnectionHandle, + global_data: GlobalData, + ) -> ChannelHandle { + Arc::new(Mutex::new(Self { + id, + num, + connection, + global_data, + })) + } + + pub fn close(&self) { + let mut global_data = self.global_data.lock(); + global_data.channels.remove(&self.id); + } +} diff --git a/amqp_core/src/lib.rs b/amqp_core/src/lib.rs index 1901e88..90d1480 100644 --- a/amqp_core/src/lib.rs +++ b/amqp_core/src/lib.rs @@ -1,16 +1,18 @@ #![warn(rust_2018_idioms)] +pub mod connection; pub mod error; +mod macros; pub mod message; pub mod methods; pub mod queue; -use crate::queue::Queue; +use crate::connection::{ChannelHandle, ConnectionHandle}; +use crate::queue::{Queue, QueueId}; +use connection::{ChannelId, ConnectionId}; use parking_lot::Mutex; use std::collections::HashMap; -use std::net::SocketAddr; use std::sync::Arc; -use uuid::Uuid; type Handle = Arc>; @@ -40,83 +42,9 @@ impl GlobalData { #[derive(Debug)] pub struct GlobalDataInner { - pub connections: HashMap, - pub channels: HashMap, - pub queues: HashMap, + pub connections: HashMap, + pub channels: HashMap, + pub queues: HashMap, /// Todo: This is just for testing and will be removed later! pub default_exchange: HashMap, } - -pub type ConnectionHandle = Handle; - -#[derive(Debug)] -pub struct Connection { - pub id: Uuid, - pub peer_addr: SocketAddr, - pub global_data: GlobalData, - pub channels: HashMap, - pub exclusive_queues: Vec, -} - -impl Connection { - pub fn new_handle( - id: Uuid, - peer_addr: SocketAddr, - global_data: GlobalData, - ) -> ConnectionHandle { - Arc::new(Mutex::new(Self { - id, - peer_addr, - global_data, - channels: HashMap::new(), - exclusive_queues: vec![], - })) - } - - pub fn close(&self) { - let mut global_data = self.global_data.lock(); - global_data.connections.remove(&self.id); - } -} - -pub type ChannelHandle = Handle; - -#[derive(Debug)] -pub struct Channel { - pub id: Uuid, - pub num: u16, - pub connection: ConnectionHandle, - pub global_data: GlobalData, -} - -impl Channel { - pub fn new_handle( - id: Uuid, - num: u16, - connection: ConnectionHandle, - global_data: GlobalData, - ) -> ChannelHandle { - Arc::new(Mutex::new(Self { - id, - num, - connection, - global_data, - })) - } - - pub fn close(&self) { - let mut global_data = self.global_data.lock(); - global_data.channels.remove(&self.id); - } -} - -pub fn gen_uuid() -> Uuid { - Uuid::from_bytes(rand::random()) -} - -#[macro_export] -macro_rules! amqp_todo { - () => { - return Err(::amqp_core::error::ConException::NotImplemented.into()) - }; -} diff --git a/amqp_core/src/macros.rs b/amqp_core/src/macros.rs new file mode 100644 index 0000000..2a4c2bd --- /dev/null +++ b/amqp_core/src/macros.rs @@ -0,0 +1,32 @@ +#[macro_export] +macro_rules! newtype_id { + ($vis:vis $name:ident) => { + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + $vis struct $name(::uuid::Uuid); + + impl $name { + pub fn random() -> Self { + ::rand::random() + } + } + + impl ::std::fmt::Display for $name { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + self.0.fmt(f) + } + } + + impl ::rand::prelude::Distribution<$name> for ::rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> $name { + $name(::uuid::Uuid::from_bytes(rng.gen())) + } + } + }; +} + +#[macro_export] +macro_rules! amqp_todo { + () => { + return Err(::amqp_core::error::ConException::NotImplemented.into()) + }; +} diff --git a/amqp_core/src/message.rs b/amqp_core/src/message.rs index 2b190f2..190328e 100644 --- a/amqp_core/src/message.rs +++ b/amqp_core/src/message.rs @@ -1,16 +1,18 @@ #![allow(dead_code)] use crate::methods; +use crate::newtype_id; use bytes::Bytes; use smallvec::SmallVec; use std::sync::Arc; -use uuid::Uuid; pub type Message = Arc; +newtype_id!(pub MessageId); + #[derive(Debug)] pub struct RawMessage { - pub id: Uuid, + pub id: MessageId, pub properties: methods::Table, pub routing: RoutingInformation, pub content: SmallVec<[Bytes; 1]>, diff --git a/amqp_core/src/queue.rs b/amqp_core/src/queue.rs index ecfa9c9..9b75c13 100644 --- a/amqp_core/src/queue.rs +++ b/amqp_core/src/queue.rs @@ -1,18 +1,20 @@ use crate::message::Message; +use crate::{newtype_id, ChannelId}; use parking_lot::Mutex; use std::sync::atomic::AtomicUsize; use std::sync::Arc; -use uuid::Uuid; pub type Queue = Arc; +newtype_id!(pub QueueId); + #[derive(Debug)] pub struct RawQueue { - pub id: Uuid, + pub id: QueueId, pub name: String, pub messages: Mutex>, // use a concurrent linked list??? pub durable: bool, - pub exclusive: Option, + pub exclusive: Option, /// Whether the queue will automatically be deleted when no consumers uses it anymore. /// The queue can always be manually deleted. /// If auto-delete is enabled, it keeps track of the consumer count. diff --git a/amqp_messaging/src/methods/consume.rs b/amqp_messaging/src/methods/consume.rs index 7148705..c5cd341 100644 --- a/amqp_messaging/src/methods/consume.rs +++ b/amqp_messaging/src/methods/consume.rs @@ -1,6 +1,6 @@ +use amqp_core::connection::ChannelHandle; use amqp_core::error::ProtocolError; use amqp_core::methods::{Bit, ConsumerTag, NoAck, NoLocal, NoWait, QueueName, Table}; -use amqp_core::ChannelHandle; #[allow(clippy::too_many_arguments)] pub async fn consume( diff --git a/amqp_messaging/src/methods/mod.rs b/amqp_messaging/src/methods/mod.rs index 9a4c57d..478b23a 100644 --- a/amqp_messaging/src/methods/mod.rs +++ b/amqp_messaging/src/methods/mod.rs @@ -2,10 +2,10 @@ mod consume; mod queue; use amqp_core::amqp_todo; +use amqp_core::connection::ChannelHandle; use amqp_core::error::ProtocolError; use amqp_core::message::Message; use amqp_core::methods::Method; -use amqp_core::ChannelHandle; use tracing::info; pub async fn handle_basic_publish(_channel_handle: ChannelHandle, message: Message) { diff --git a/amqp_messaging/src/methods/queue.rs b/amqp_messaging/src/methods/queue.rs index e3ff92a..08a5fc4 100644 --- a/amqp_messaging/src/methods/queue.rs +++ b/amqp_messaging/src/methods/queue.rs @@ -1,14 +1,13 @@ #![deny(clippy::future_not_send)] +use amqp_core::connection::ChannelHandle; use amqp_core::error::{ConException, ProtocolError}; use amqp_core::methods::{Bit, ExchangeName, NoWait, QueueName, Shortstr, Table}; -use amqp_core::queue::{QueueDeletion, RawQueue}; -use amqp_core::ChannelHandle; +use amqp_core::queue::{QueueDeletion, QueueId, RawQueue}; use amqp_core::{amqp_todo, GlobalData}; use parking_lot::Mutex; use std::sync::atomic::AtomicUsize; use std::sync::Arc; -use uuid::Uuid; #[allow(clippy::too_many_arguments)] pub async fn declare( @@ -32,7 +31,7 @@ pub async fn declare( amqp_todo!(); } - let id = amqp_core::gen_uuid(); + let id = QueueId::random(); let queue = Arc::new(RawQueue { id, name: queue_name.clone(), @@ -72,7 +71,7 @@ pub async fn bind( async fn bind_queue( _global_data: GlobalData, - _queue: Uuid, + _queue: QueueId, _exchange: (), _routing_key: String, ) -> Result<(), ProtocolError> { diff --git a/amqp_transport/src/connection.rs b/amqp_transport/src/connection.rs index fd20366..4268a25 100644 --- a/amqp_transport/src/connection.rs +++ b/amqp_transport/src/connection.rs @@ -1,7 +1,8 @@ use crate::error::{ConException, ProtocolError, Result}; -use crate::frame::{ChannelId, ContentHeader, Frame, FrameType}; +use crate::frame::{ContentHeader, Frame, FrameType}; use crate::{frame, methods, sasl}; -use amqp_core::message::{RawMessage, RoutingInformation}; +use amqp_core::connection::{ChannelHandle, ChannelNum, ConnectionHandle, ConnectionId}; +use amqp_core::message::{MessageId, RawMessage, RoutingInformation}; use amqp_core::methods::{FieldValue, Method, Table}; use amqp_core::GlobalData; use anyhow::Context; @@ -17,7 +18,6 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::time; use tracing::{debug, error, info, warn}; -use uuid::Uuid; fn ensure_conn(condition: bool) -> Result<()> { if condition { @@ -36,21 +36,21 @@ const BASIC_CLASS_ID: u16 = 60; pub struct Channel { /// A handle to the global channel representation. Used to remove the channel when it's dropped - handle: amqp_core::ChannelHandle, + handle: ChannelHandle, /// The current status of the channel, whether it has sent a method that expects a body status: ChannelStatus, } pub struct Connection { - id: Uuid, + id: ConnectionId, stream: TcpStream, max_frame_size: usize, heartbeat_delay: u16, channel_max: u16, /// When the next heartbeat expires next_timeout: Pin>, - channels: HashMap, - handle: amqp_core::ConnectionHandle, + channels: HashMap, + handle: ConnectionHandle, global_data: GlobalData, } @@ -71,9 +71,9 @@ impl ChannelStatus { impl Connection { pub fn new( - id: Uuid, + id: ConnectionId, stream: TcpStream, - connection_handle: amqp_core::ConnectionHandle, + connection_handle: ConnectionHandle, global_data: GlobalData, ) -> Self { Self { @@ -110,7 +110,7 @@ impl Connection { self.main_loop().await } - async fn send_method(&mut self, channel: ChannelId, method: Method) -> Result<()> { + async fn send_method(&mut self, channel: ChannelNum, method: Method) -> Result<()> { let mut payload = Vec::with_capacity(64); methods::write::write_method(method, &mut payload)?; frame::write_frame( @@ -147,7 +147,7 @@ impl Connection { }; debug!(?start_method, "Sending Start method"); - self.send_method(ChannelId::zero(), start_method).await?; + self.send_method(ChannelNum::zero(), start_method).await?; let start_ok = self.recv_method().await?; debug!(?start_ok, "Received Start-Ok"); @@ -178,7 +178,7 @@ impl Connection { }; debug!("Sending Tune method"); - self.send_method(ChannelId::zero(), tune_method).await?; + self.send_method(ChannelNum::zero(), tune_method).await?; let tune_ok = self.recv_method().await?; debug!(?tune_ok, "Received Tune-Ok method"); @@ -207,7 +207,7 @@ impl Connection { } self.send_method( - ChannelId::zero(), + ChannelNum::zero(), Method::ConnectionOpenOk { reserved_1: "".to_string(), }, @@ -249,7 +249,7 @@ impl Connection { method_id, } => { info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection"); - self.send_method(ChannelId::zero(), Method::ConnectionCloseOk {}) + self.send_method(ChannelNum::zero(), Method::ConnectionCloseOk {}) .await?; return Err(ProtocolError::GracefulClose.into()); } @@ -339,7 +339,7 @@ impl Connection { method: Method, header: ContentHeader, payloads: SmallVec<[Bytes; 1]>, - channel: ChannelId, + channel: ChannelNum, ) -> Result<()> { // The only method with content that is sent to the server is Basic.Publish. ensure_conn(header.class_id == BASIC_CLASS_ID)?; @@ -353,7 +353,7 @@ impl Connection { } = method { let message = RawMessage { - id: amqp_core::gen_uuid(), + id: MessageId::random(), properties: header.property_fields, routing: RoutingInformation { exchange, @@ -379,11 +379,11 @@ impl Connection { } } - async fn channel_open(&mut self, channel_id: ChannelId) -> Result<()> { - let id = amqp_core::gen_uuid(); - let channel_handle = amqp_core::Channel::new_handle( + async fn channel_open(&mut self, channel_num: ChannelNum) -> Result<()> { + let id = rand::random(); + let channel_handle = amqp_core::connection::Channel::new_handle( id, - channel_id.num(), + channel_num.num(), self.handle.clone(), self.global_data.clone(), ); @@ -393,9 +393,9 @@ impl Connection { status: ChannelStatus::Default, }; - let prev = self.channels.insert(channel_id, channel); + let prev = self.channels.insert(channel_num, channel); if let Some(prev) = prev { - self.channels.insert(channel_id, prev); // restore previous state + self.channels.insert(channel_num, prev); // restore previous state return Err(ConException::ChannelError.into()); } @@ -408,13 +408,13 @@ impl Connection { .unwrap() .lock() .channels - .insert(channel_id.num(), channel_handle); + .insert(channel_num.num(), channel_handle); } - info!(%channel_id, "Opened new channel"); + info!(%channel_num, "Opened new channel"); self.send_method( - channel_id, + channel_num, Method::ChannelOpenOk { reserved_1: Vec::new(), }, @@ -424,7 +424,7 @@ impl Connection { Ok(()) } - async fn channel_close(&mut self, channel_id: ChannelId, method: Method) -> Result<()> { + async fn channel_close(&mut self, channel_id: ChannelNum, method: Method) -> Result<()> { if let Method::ChannelClose { reply_code: code, reply_text: reason, diff --git a/amqp_transport/src/frame.rs b/amqp_transport/src/frame.rs index b4f71ff..e957afb 100644 --- a/amqp_transport/src/frame.rs +++ b/amqp_transport/src/frame.rs @@ -1,36 +1,13 @@ use crate::error::{ConException, ProtocolError, Result}; +use amqp_core::connection::ChannelNum; use amqp_core::methods; use anyhow::Context; use bytes::Bytes; -use std::fmt::{Display, Formatter}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing::trace; const REQUIRED_FRAME_END: u8 = 0xCE; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ChannelId(u16); - -impl ChannelId { - pub fn num(self) -> u16 { - self.0 - } - - pub fn is_zero(self) -> bool { - self.0 == 0 - } - - pub fn zero() -> Self { - Self(0) - } -} - -impl Display for ChannelId { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} - mod frame_type { pub const METHOD: u8 = 1; pub const HEADER: u8 = 2; @@ -42,7 +19,7 @@ mod frame_type { pub struct Frame { /// The type of the frame including its parsed metadata. pub kind: FrameType, - pub channel: ChannelId, + pub channel: ChannelNum, /// Includes the whole payload, also including the metadata from each type. pub payload: Bytes, } @@ -181,7 +158,7 @@ where { let kind = r.read_u8().await.context("read type")?; let channel = r.read_u16().await.context("read channel")?; - let channel = ChannelId(channel); + let channel = ChannelNum::new(channel); let size = r.read_u32().await.context("read size")?; let mut payload = vec![0; size.try_into().unwrap()]; @@ -210,7 +187,7 @@ where Ok(frame) } -fn parse_frame_type(kind: u8, channel: ChannelId) -> Result { +fn parse_frame_type(kind: u8, channel: ChannelNum) -> Result { match kind { frame_type::METHOD => Ok(FrameType::Method), frame_type::HEADER => Ok(FrameType::Header), @@ -228,7 +205,7 @@ fn parse_frame_type(kind: u8, channel: ChannelId) -> Result { #[cfg(test)] mod tests { - use crate::frame::{ChannelId, Frame, FrameType}; + use crate::frame::{ChannelNum, Frame, FrameType}; use bytes::Bytes; #[tokio::test] @@ -257,7 +234,7 @@ mod tests { frame, Frame { kind: FrameType::Method, - channel: ChannelId(0), + channel: ChannelNum::new(0), payload: Bytes::from_static(&[1, 2, 3]), } ); diff --git a/amqp_transport/src/lib.rs b/amqp_transport/src/lib.rs index 9705bc6..ad498fa 100644 --- a/amqp_transport/src/lib.rs +++ b/amqp_transport/src/lib.rs @@ -24,13 +24,13 @@ pub async fn do_thing_i_guess(global_data: GlobalData) -> Result<()> { loop { let (stream, peer_addr) = listener.accept().await?; - let id = amqp_core::gen_uuid(); + let id = rand::random(); info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection"); let span = info_span!("client-connection", %id); let connection_handle = - amqp_core::Connection::new_handle(id, peer_addr, global_data.clone()); + amqp_core::connection::Connection::new_handle(id, peer_addr, global_data.clone()); let mut global_data_guard = global_data.lock(); global_data_guard diff --git a/amqp_transport/src/tests.rs b/amqp_transport/src/tests.rs index 1920fe1..fd89146 100644 --- a/amqp_transport/src/tests.rs +++ b/amqp_transport/src/tests.rs @@ -1,4 +1,4 @@ -use crate::frame::{ChannelId, FrameType}; +use crate::frame::{ChannelNum, FrameType}; use crate::{frame, methods}; use amqp_core::methods::{FieldValue, Method}; use std::collections::HashMap; @@ -21,7 +21,7 @@ async fn write_start_ok_frame() { let frame = frame::Frame { kind: FrameType::Method, - channel: ChannelId::zero(), + channel: ChannelNum::zero(), payload: payload.into(), };