some cleanup

This commit is contained in:
nora 2022-02-26 22:26:35 +01:00
parent 8532d454c3
commit 6d944e1265
12 changed files with 190 additions and 151 deletions

View file

@ -0,0 +1,99 @@
use crate::{newtype_id, GlobalData, Handle, Queue};
use parking_lot::Mutex;
use std::collections::HashMap;
use std::fmt::{Display, Formatter};
use std::net::SocketAddr;
use std::sync::Arc;
newtype_id!(pub ConnectionId);
newtype_id!(pub ChannelId);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ChannelNum(u16);
impl ChannelNum {
pub fn new(num: u16) -> Self {
Self(num)
}
pub fn num(self) -> u16 {
self.0
}
pub fn is_zero(self) -> bool {
self.0 == 0
}
pub fn zero() -> Self {
Self(0)
}
}
impl Display for ChannelNum {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
pub type ConnectionHandle = Handle<Connection>;
#[derive(Debug)]
pub struct Connection {
pub id: ConnectionId,
pub peer_addr: SocketAddr,
pub global_data: GlobalData,
pub channels: HashMap<u16, ChannelHandle>,
pub exclusive_queues: Vec<Queue>,
}
impl Connection {
pub fn new_handle(
id: ConnectionId,
peer_addr: SocketAddr,
global_data: GlobalData,
) -> ConnectionHandle {
Arc::new(Mutex::new(Self {
id,
peer_addr,
global_data,
channels: HashMap::new(),
exclusive_queues: vec![],
}))
}
pub fn close(&self) {
let mut global_data = self.global_data.lock();
global_data.connections.remove(&self.id);
}
}
pub type ChannelHandle = Handle<Channel>;
#[derive(Debug)]
pub struct Channel {
pub id: ChannelId,
pub num: u16,
pub connection: ConnectionHandle,
pub global_data: GlobalData,
}
impl Channel {
pub fn new_handle(
id: ChannelId,
num: u16,
connection: ConnectionHandle,
global_data: GlobalData,
) -> ChannelHandle {
Arc::new(Mutex::new(Self {
id,
num,
connection,
global_data,
}))
}
pub fn close(&self) {
let mut global_data = self.global_data.lock();
global_data.channels.remove(&self.id);
}
}

View file

@ -1,16 +1,18 @@
#![warn(rust_2018_idioms)] #![warn(rust_2018_idioms)]
pub mod connection;
pub mod error; pub mod error;
mod macros;
pub mod message; pub mod message;
pub mod methods; pub mod methods;
pub mod queue; pub mod queue;
use crate::queue::Queue; use crate::connection::{ChannelHandle, ConnectionHandle};
use crate::queue::{Queue, QueueId};
use connection::{ChannelId, ConnectionId};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid;
type Handle<T> = Arc<Mutex<T>>; type Handle<T> = Arc<Mutex<T>>;
@ -40,83 +42,9 @@ impl GlobalData {
#[derive(Debug)] #[derive(Debug)]
pub struct GlobalDataInner { pub struct GlobalDataInner {
pub connections: HashMap<Uuid, ConnectionHandle>, pub connections: HashMap<ConnectionId, ConnectionHandle>,
pub channels: HashMap<Uuid, ChannelHandle>, pub channels: HashMap<ChannelId, ChannelHandle>,
pub queues: HashMap<Uuid, Queue>, pub queues: HashMap<QueueId, Queue>,
/// Todo: This is just for testing and will be removed later! /// Todo: This is just for testing and will be removed later!
pub default_exchange: HashMap<String, Queue>, pub default_exchange: HashMap<String, Queue>,
} }
pub type ConnectionHandle = Handle<Connection>;
#[derive(Debug)]
pub struct Connection {
pub id: Uuid,
pub peer_addr: SocketAddr,
pub global_data: GlobalData,
pub channels: HashMap<u16, ChannelHandle>,
pub exclusive_queues: Vec<Queue>,
}
impl Connection {
pub fn new_handle(
id: Uuid,
peer_addr: SocketAddr,
global_data: GlobalData,
) -> ConnectionHandle {
Arc::new(Mutex::new(Self {
id,
peer_addr,
global_data,
channels: HashMap::new(),
exclusive_queues: vec![],
}))
}
pub fn close(&self) {
let mut global_data = self.global_data.lock();
global_data.connections.remove(&self.id);
}
}
pub type ChannelHandle = Handle<Channel>;
#[derive(Debug)]
pub struct Channel {
pub id: Uuid,
pub num: u16,
pub connection: ConnectionHandle,
pub global_data: GlobalData,
}
impl Channel {
pub fn new_handle(
id: Uuid,
num: u16,
connection: ConnectionHandle,
global_data: GlobalData,
) -> ChannelHandle {
Arc::new(Mutex::new(Self {
id,
num,
connection,
global_data,
}))
}
pub fn close(&self) {
let mut global_data = self.global_data.lock();
global_data.channels.remove(&self.id);
}
}
pub fn gen_uuid() -> Uuid {
Uuid::from_bytes(rand::random())
}
#[macro_export]
macro_rules! amqp_todo {
() => {
return Err(::amqp_core::error::ConException::NotImplemented.into())
};
}

32
amqp_core/src/macros.rs Normal file
View file

@ -0,0 +1,32 @@
#[macro_export]
macro_rules! newtype_id {
($vis:vis $name:ident) => {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
$vis struct $name(::uuid::Uuid);
impl $name {
pub fn random() -> Self {
::rand::random()
}
}
impl ::std::fmt::Display for $name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
self.0.fmt(f)
}
}
impl ::rand::prelude::Distribution<$name> for ::rand::distributions::Standard {
fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> $name {
$name(::uuid::Uuid::from_bytes(rng.gen()))
}
}
};
}
#[macro_export]
macro_rules! amqp_todo {
() => {
return Err(::amqp_core::error::ConException::NotImplemented.into())
};
}

View file

@ -1,16 +1,18 @@
#![allow(dead_code)] #![allow(dead_code)]
use crate::methods; use crate::methods;
use crate::newtype_id;
use bytes::Bytes; use bytes::Bytes;
use smallvec::SmallVec; use smallvec::SmallVec;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid;
pub type Message = Arc<RawMessage>; pub type Message = Arc<RawMessage>;
newtype_id!(pub MessageId);
#[derive(Debug)] #[derive(Debug)]
pub struct RawMessage { pub struct RawMessage {
pub id: Uuid, pub id: MessageId,
pub properties: methods::Table, pub properties: methods::Table,
pub routing: RoutingInformation, pub routing: RoutingInformation,
pub content: SmallVec<[Bytes; 1]>, pub content: SmallVec<[Bytes; 1]>,

View file

@ -1,18 +1,20 @@
use crate::message::Message; use crate::message::Message;
use crate::{newtype_id, ChannelId};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid;
pub type Queue = Arc<RawQueue>; pub type Queue = Arc<RawQueue>;
newtype_id!(pub QueueId);
#[derive(Debug)] #[derive(Debug)]
pub struct RawQueue { pub struct RawQueue {
pub id: Uuid, pub id: QueueId,
pub name: String, pub name: String,
pub messages: Mutex<Vec<Message>>, // use a concurrent linked list??? pub messages: Mutex<Vec<Message>>, // use a concurrent linked list???
pub durable: bool, pub durable: bool,
pub exclusive: Option<Uuid>, pub exclusive: Option<ChannelId>,
/// Whether the queue will automatically be deleted when no consumers uses it anymore. /// Whether the queue will automatically be deleted when no consumers uses it anymore.
/// The queue can always be manually deleted. /// The queue can always be manually deleted.
/// If auto-delete is enabled, it keeps track of the consumer count. /// If auto-delete is enabled, it keeps track of the consumer count.

View file

@ -1,6 +1,6 @@
use amqp_core::connection::ChannelHandle;
use amqp_core::error::ProtocolError; use amqp_core::error::ProtocolError;
use amqp_core::methods::{Bit, ConsumerTag, NoAck, NoLocal, NoWait, QueueName, Table}; use amqp_core::methods::{Bit, ConsumerTag, NoAck, NoLocal, NoWait, QueueName, Table};
use amqp_core::ChannelHandle;
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn consume( pub async fn consume(

View file

@ -2,10 +2,10 @@ mod consume;
mod queue; mod queue;
use amqp_core::amqp_todo; use amqp_core::amqp_todo;
use amqp_core::connection::ChannelHandle;
use amqp_core::error::ProtocolError; use amqp_core::error::ProtocolError;
use amqp_core::message::Message; use amqp_core::message::Message;
use amqp_core::methods::Method; use amqp_core::methods::Method;
use amqp_core::ChannelHandle;
use tracing::info; use tracing::info;
pub async fn handle_basic_publish(_channel_handle: ChannelHandle, message: Message) { pub async fn handle_basic_publish(_channel_handle: ChannelHandle, message: Message) {

View file

@ -1,14 +1,13 @@
#![deny(clippy::future_not_send)] #![deny(clippy::future_not_send)]
use amqp_core::connection::ChannelHandle;
use amqp_core::error::{ConException, ProtocolError}; use amqp_core::error::{ConException, ProtocolError};
use amqp_core::methods::{Bit, ExchangeName, NoWait, QueueName, Shortstr, Table}; use amqp_core::methods::{Bit, ExchangeName, NoWait, QueueName, Shortstr, Table};
use amqp_core::queue::{QueueDeletion, RawQueue}; use amqp_core::queue::{QueueDeletion, QueueId, RawQueue};
use amqp_core::ChannelHandle;
use amqp_core::{amqp_todo, GlobalData}; use amqp_core::{amqp_todo, GlobalData};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid;
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn declare( pub async fn declare(
@ -32,7 +31,7 @@ pub async fn declare(
amqp_todo!(); amqp_todo!();
} }
let id = amqp_core::gen_uuid(); let id = QueueId::random();
let queue = Arc::new(RawQueue { let queue = Arc::new(RawQueue {
id, id,
name: queue_name.clone(), name: queue_name.clone(),
@ -72,7 +71,7 @@ pub async fn bind(
async fn bind_queue( async fn bind_queue(
_global_data: GlobalData, _global_data: GlobalData,
_queue: Uuid, _queue: QueueId,
_exchange: (), _exchange: (),
_routing_key: String, _routing_key: String,
) -> Result<(), ProtocolError> { ) -> Result<(), ProtocolError> {

View file

@ -1,7 +1,8 @@
use crate::error::{ConException, ProtocolError, Result}; use crate::error::{ConException, ProtocolError, Result};
use crate::frame::{ChannelId, ContentHeader, Frame, FrameType}; use crate::frame::{ContentHeader, Frame, FrameType};
use crate::{frame, methods, sasl}; use crate::{frame, methods, sasl};
use amqp_core::message::{RawMessage, RoutingInformation}; use amqp_core::connection::{ChannelHandle, ChannelNum, ConnectionHandle, ConnectionId};
use amqp_core::message::{MessageId, RawMessage, RoutingInformation};
use amqp_core::methods::{FieldValue, Method, Table}; use amqp_core::methods::{FieldValue, Method, Table};
use amqp_core::GlobalData; use amqp_core::GlobalData;
use anyhow::Context; use anyhow::Context;
@ -17,7 +18,6 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time; use tokio::time;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use uuid::Uuid;
fn ensure_conn(condition: bool) -> Result<()> { fn ensure_conn(condition: bool) -> Result<()> {
if condition { if condition {
@ -36,21 +36,21 @@ const BASIC_CLASS_ID: u16 = 60;
pub struct Channel { pub struct Channel {
/// A handle to the global channel representation. Used to remove the channel when it's dropped /// A handle to the global channel representation. Used to remove the channel when it's dropped
handle: amqp_core::ChannelHandle, handle: ChannelHandle,
/// The current status of the channel, whether it has sent a method that expects a body /// The current status of the channel, whether it has sent a method that expects a body
status: ChannelStatus, status: ChannelStatus,
} }
pub struct Connection { pub struct Connection {
id: Uuid, id: ConnectionId,
stream: TcpStream, stream: TcpStream,
max_frame_size: usize, max_frame_size: usize,
heartbeat_delay: u16, heartbeat_delay: u16,
channel_max: u16, channel_max: u16,
/// When the next heartbeat expires /// When the next heartbeat expires
next_timeout: Pin<Box<time::Sleep>>, next_timeout: Pin<Box<time::Sleep>>,
channels: HashMap<ChannelId, Channel>, channels: HashMap<ChannelNum, Channel>,
handle: amqp_core::ConnectionHandle, handle: ConnectionHandle,
global_data: GlobalData, global_data: GlobalData,
} }
@ -71,9 +71,9 @@ impl ChannelStatus {
impl Connection { impl Connection {
pub fn new( pub fn new(
id: Uuid, id: ConnectionId,
stream: TcpStream, stream: TcpStream,
connection_handle: amqp_core::ConnectionHandle, connection_handle: ConnectionHandle,
global_data: GlobalData, global_data: GlobalData,
) -> Self { ) -> Self {
Self { Self {
@ -110,7 +110,7 @@ impl Connection {
self.main_loop().await self.main_loop().await
} }
async fn send_method(&mut self, channel: ChannelId, method: Method) -> Result<()> { async fn send_method(&mut self, channel: ChannelNum, method: Method) -> Result<()> {
let mut payload = Vec::with_capacity(64); let mut payload = Vec::with_capacity(64);
methods::write::write_method(method, &mut payload)?; methods::write::write_method(method, &mut payload)?;
frame::write_frame( frame::write_frame(
@ -147,7 +147,7 @@ impl Connection {
}; };
debug!(?start_method, "Sending Start method"); debug!(?start_method, "Sending Start method");
self.send_method(ChannelId::zero(), start_method).await?; self.send_method(ChannelNum::zero(), start_method).await?;
let start_ok = self.recv_method().await?; let start_ok = self.recv_method().await?;
debug!(?start_ok, "Received Start-Ok"); debug!(?start_ok, "Received Start-Ok");
@ -178,7 +178,7 @@ impl Connection {
}; };
debug!("Sending Tune method"); debug!("Sending Tune method");
self.send_method(ChannelId::zero(), tune_method).await?; self.send_method(ChannelNum::zero(), tune_method).await?;
let tune_ok = self.recv_method().await?; let tune_ok = self.recv_method().await?;
debug!(?tune_ok, "Received Tune-Ok method"); debug!(?tune_ok, "Received Tune-Ok method");
@ -207,7 +207,7 @@ impl Connection {
} }
self.send_method( self.send_method(
ChannelId::zero(), ChannelNum::zero(),
Method::ConnectionOpenOk { Method::ConnectionOpenOk {
reserved_1: "".to_string(), reserved_1: "".to_string(),
}, },
@ -249,7 +249,7 @@ impl Connection {
method_id, method_id,
} => { } => {
info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection"); info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection");
self.send_method(ChannelId::zero(), Method::ConnectionCloseOk {}) self.send_method(ChannelNum::zero(), Method::ConnectionCloseOk {})
.await?; .await?;
return Err(ProtocolError::GracefulClose.into()); return Err(ProtocolError::GracefulClose.into());
} }
@ -339,7 +339,7 @@ impl Connection {
method: Method, method: Method,
header: ContentHeader, header: ContentHeader,
payloads: SmallVec<[Bytes; 1]>, payloads: SmallVec<[Bytes; 1]>,
channel: ChannelId, channel: ChannelNum,
) -> Result<()> { ) -> Result<()> {
// The only method with content that is sent to the server is Basic.Publish. // The only method with content that is sent to the server is Basic.Publish.
ensure_conn(header.class_id == BASIC_CLASS_ID)?; ensure_conn(header.class_id == BASIC_CLASS_ID)?;
@ -353,7 +353,7 @@ impl Connection {
} = method } = method
{ {
let message = RawMessage { let message = RawMessage {
id: amqp_core::gen_uuid(), id: MessageId::random(),
properties: header.property_fields, properties: header.property_fields,
routing: RoutingInformation { routing: RoutingInformation {
exchange, exchange,
@ -379,11 +379,11 @@ impl Connection {
} }
} }
async fn channel_open(&mut self, channel_id: ChannelId) -> Result<()> { async fn channel_open(&mut self, channel_num: ChannelNum) -> Result<()> {
let id = amqp_core::gen_uuid(); let id = rand::random();
let channel_handle = amqp_core::Channel::new_handle( let channel_handle = amqp_core::connection::Channel::new_handle(
id, id,
channel_id.num(), channel_num.num(),
self.handle.clone(), self.handle.clone(),
self.global_data.clone(), self.global_data.clone(),
); );
@ -393,9 +393,9 @@ impl Connection {
status: ChannelStatus::Default, status: ChannelStatus::Default,
}; };
let prev = self.channels.insert(channel_id, channel); let prev = self.channels.insert(channel_num, channel);
if let Some(prev) = prev { if let Some(prev) = prev {
self.channels.insert(channel_id, prev); // restore previous state self.channels.insert(channel_num, prev); // restore previous state
return Err(ConException::ChannelError.into()); return Err(ConException::ChannelError.into());
} }
@ -408,13 +408,13 @@ impl Connection {
.unwrap() .unwrap()
.lock() .lock()
.channels .channels
.insert(channel_id.num(), channel_handle); .insert(channel_num.num(), channel_handle);
} }
info!(%channel_id, "Opened new channel"); info!(%channel_num, "Opened new channel");
self.send_method( self.send_method(
channel_id, channel_num,
Method::ChannelOpenOk { Method::ChannelOpenOk {
reserved_1: Vec::new(), reserved_1: Vec::new(),
}, },
@ -424,7 +424,7 @@ impl Connection {
Ok(()) Ok(())
} }
async fn channel_close(&mut self, channel_id: ChannelId, method: Method) -> Result<()> { async fn channel_close(&mut self, channel_id: ChannelNum, method: Method) -> Result<()> {
if let Method::ChannelClose { if let Method::ChannelClose {
reply_code: code, reply_code: code,
reply_text: reason, reply_text: reason,

View file

@ -1,36 +1,13 @@
use crate::error::{ConException, ProtocolError, Result}; use crate::error::{ConException, ProtocolError, Result};
use amqp_core::connection::ChannelNum;
use amqp_core::methods; use amqp_core::methods;
use anyhow::Context; use anyhow::Context;
use bytes::Bytes; use bytes::Bytes;
use std::fmt::{Display, Formatter};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::trace; use tracing::trace;
const REQUIRED_FRAME_END: u8 = 0xCE; const REQUIRED_FRAME_END: u8 = 0xCE;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ChannelId(u16);
impl ChannelId {
pub fn num(self) -> u16 {
self.0
}
pub fn is_zero(self) -> bool {
self.0 == 0
}
pub fn zero() -> Self {
Self(0)
}
}
impl Display for ChannelId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
mod frame_type { mod frame_type {
pub const METHOD: u8 = 1; pub const METHOD: u8 = 1;
pub const HEADER: u8 = 2; pub const HEADER: u8 = 2;
@ -42,7 +19,7 @@ mod frame_type {
pub struct Frame { pub struct Frame {
/// The type of the frame including its parsed metadata. /// The type of the frame including its parsed metadata.
pub kind: FrameType, pub kind: FrameType,
pub channel: ChannelId, pub channel: ChannelNum,
/// Includes the whole payload, also including the metadata from each type. /// Includes the whole payload, also including the metadata from each type.
pub payload: Bytes, pub payload: Bytes,
} }
@ -181,7 +158,7 @@ where
{ {
let kind = r.read_u8().await.context("read type")?; let kind = r.read_u8().await.context("read type")?;
let channel = r.read_u16().await.context("read channel")?; let channel = r.read_u16().await.context("read channel")?;
let channel = ChannelId(channel); let channel = ChannelNum::new(channel);
let size = r.read_u32().await.context("read size")?; let size = r.read_u32().await.context("read size")?;
let mut payload = vec![0; size.try_into().unwrap()]; let mut payload = vec![0; size.try_into().unwrap()];
@ -210,7 +187,7 @@ where
Ok(frame) Ok(frame)
} }
fn parse_frame_type(kind: u8, channel: ChannelId) -> Result<FrameType> { fn parse_frame_type(kind: u8, channel: ChannelNum) -> Result<FrameType> {
match kind { match kind {
frame_type::METHOD => Ok(FrameType::Method), frame_type::METHOD => Ok(FrameType::Method),
frame_type::HEADER => Ok(FrameType::Header), frame_type::HEADER => Ok(FrameType::Header),
@ -228,7 +205,7 @@ fn parse_frame_type(kind: u8, channel: ChannelId) -> Result<FrameType> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::frame::{ChannelId, Frame, FrameType}; use crate::frame::{ChannelNum, Frame, FrameType};
use bytes::Bytes; use bytes::Bytes;
#[tokio::test] #[tokio::test]
@ -257,7 +234,7 @@ mod tests {
frame, frame,
Frame { Frame {
kind: FrameType::Method, kind: FrameType::Method,
channel: ChannelId(0), channel: ChannelNum::new(0),
payload: Bytes::from_static(&[1, 2, 3]), payload: Bytes::from_static(&[1, 2, 3]),
} }
); );

View file

@ -24,13 +24,13 @@ pub async fn do_thing_i_guess(global_data: GlobalData) -> Result<()> {
loop { loop {
let (stream, peer_addr) = listener.accept().await?; let (stream, peer_addr) = listener.accept().await?;
let id = amqp_core::gen_uuid(); let id = rand::random();
info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection"); info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection");
let span = info_span!("client-connection", %id); let span = info_span!("client-connection", %id);
let connection_handle = let connection_handle =
amqp_core::Connection::new_handle(id, peer_addr, global_data.clone()); amqp_core::connection::Connection::new_handle(id, peer_addr, global_data.clone());
let mut global_data_guard = global_data.lock(); let mut global_data_guard = global_data.lock();
global_data_guard global_data_guard

View file

@ -1,4 +1,4 @@
use crate::frame::{ChannelId, FrameType}; use crate::frame::{ChannelNum, FrameType};
use crate::{frame, methods}; use crate::{frame, methods};
use amqp_core::methods::{FieldValue, Method}; use amqp_core::methods::{FieldValue, Method};
use std::collections::HashMap; use std::collections::HashMap;
@ -21,7 +21,7 @@ async fn write_start_ok_frame() {
let frame = frame::Frame { let frame = frame::Frame {
kind: FrameType::Method, kind: FrameType::Method,
channel: ChannelId::zero(), channel: ChannelNum::zero(),
payload: payload.into(), payload: payload.into(),
}; };