This commit is contained in:
nora 2022-03-04 22:35:55 +01:00
parent 4346db648f
commit 5d127eceee
8 changed files with 53 additions and 62 deletions

View file

@ -1,5 +1,6 @@
use crate::{methods, methods::Method, newtype_id, GlobalData, Queue}; use crate::{methods, methods::Method, newtype_id, GlobalData, Queue};
use bytes::Bytes; use bytes::Bytes;
use parking_lot::Mutex;
use smallvec::SmallVec; use smallvec::SmallVec;
use std::{ use std::{
collections::HashMap, collections::HashMap,
@ -43,12 +44,14 @@ impl Display for ChannelNum {
} }
} }
pub type Connection = Arc<ConnectionInner>;
#[derive(Debug)] #[derive(Debug)]
pub struct Connection { pub struct ConnectionInner {
pub id: ConnectionId, pub id: ConnectionId,
pub peer_addr: SocketAddr, pub peer_addr: SocketAddr,
pub global_data: GlobalData, pub global_data: GlobalData,
pub channels: HashMap<ChannelNum, Channel>, pub channels: Mutex<HashMap<ChannelNum, Channel>>,
pub exclusive_queues: Vec<Queue>, pub exclusive_queues: Vec<Queue>,
_events: ConEventSender, _events: ConEventSender,
} }
@ -62,19 +65,19 @@ pub enum QueuedMethod {
pub type ConEventSender = mpsc::Sender<(ChannelNum, QueuedMethod)>; pub type ConEventSender = mpsc::Sender<(ChannelNum, QueuedMethod)>;
pub type ConEventReceiver = mpsc::Receiver<(ChannelNum, QueuedMethod)>; pub type ConEventReceiver = mpsc::Receiver<(ChannelNum, QueuedMethod)>;
impl Connection { impl ConnectionInner {
#[must_use] #[must_use]
pub fn new( pub fn new(
id: ConnectionId, id: ConnectionId,
peer_addr: SocketAddr, peer_addr: SocketAddr,
global_data: GlobalData, global_data: GlobalData,
method_queue: ConEventSender, method_queue: ConEventSender,
) -> Arc<Connection> { ) -> Connection {
Arc::new(Self { Arc::new(Self {
id, id,
peer_addr, peer_addr,
global_data, global_data,
channels: HashMap::new(), channels: Mutex::new(HashMap::new()),
exclusive_queues: vec![], exclusive_queues: vec![],
_events: method_queue, _events: method_queue,
}) })
@ -86,8 +89,10 @@ impl Connection {
} }
} }
pub type Channel = Arc<ChannelInner>;
#[derive(Debug)] #[derive(Debug)]
pub struct Channel { pub struct ChannelInner {
pub id: ChannelId, pub id: ChannelId,
pub num: ChannelNum, pub num: ChannelNum,
pub connection: Connection, pub connection: Connection,
@ -95,7 +100,7 @@ pub struct Channel {
method_queue: ConEventSender, method_queue: ConEventSender,
} }
impl Channel { impl ChannelInner {
#[must_use] #[must_use]
pub fn new( pub fn new(
id: ChannelId, id: ChannelId,
@ -103,7 +108,7 @@ impl Channel {
connection: Connection, connection: Connection,
global_data: GlobalData, global_data: GlobalData,
method_queue: ConEventSender, method_queue: ConEventSender,
) -> Arc<Channel> { ) -> Channel {
Arc::new(Self { Arc::new(Self {
id, id,
num, num,

View file

@ -82,23 +82,18 @@ async fn get_data(global_data: GlobalData) -> impl IntoResponse {
let connections = global_data let connections = global_data
.connections .connections
.values() .values()
.map(|conn| { .map(|conn| Connection {
let conn = conn.lock(); id: conn.id.to_string(),
Connection { peer_addr: conn.peer_addr.to_string(),
id: conn.id.to_string(), channels: conn
peer_addr: conn.peer_addr.to_string(), .channels
channels: conn .lock()
.channels .values()
.values() .map(|chan| Channel {
.map(|chan| { id: chan.id.to_string(),
let chan = chan.lock(); number: chan.num.num(),
Channel { })
id: chan.id.to_string(), .collect(),
number: chan.num.num(),
}
})
.collect(),
}
}) })
.collect(); .collect();

View file

@ -9,7 +9,7 @@ use amqp_core::{
use std::sync::Arc; use std::sync::Arc;
use tracing::info; use tracing::info;
pub fn consume(channel_handle: Channel, basic_consume: BasicConsume) -> Result<Method> { pub fn consume(channel: Channel, basic_consume: BasicConsume) -> Result<Method> {
let BasicConsume { let BasicConsume {
queue: queue_name, queue: queue_name,
consumer_tag, consumer_tag,
@ -24,10 +24,7 @@ pub fn consume(channel_handle: Channel, basic_consume: BasicConsume) -> Result<M
amqp_todo!(); amqp_todo!();
} }
let global_data = { let global_data = channel.global_data.clone();
let channel = channel_handle.lock();
channel.global_data.clone()
};
let consumer_tag = if consumer_tag.is_empty() { let consumer_tag = if consumer_tag.is_empty() {
amqp_core::random_uuid().to_string() amqp_core::random_uuid().to_string()
@ -40,7 +37,7 @@ pub fn consume(channel_handle: Channel, basic_consume: BasicConsume) -> Result<M
let consumer = Consumer { let consumer = Consumer {
id: ConsumerId::random(), id: ConsumerId::random(),
tag: consumer_tag.clone(), tag: consumer_tag.clone(),
channel: Arc::clone(&channel_handle), channel: Arc::clone(&channel),
}; };
let queue = global_data let queue = global_data

View file

@ -4,17 +4,16 @@ mod queue;
use crate::Result; use crate::Result;
use amqp_core::{amqp_todo, connection::Channel, message::Message, methods::Method}; use amqp_core::{amqp_todo, connection::Channel, message::Message, methods::Method};
use std::sync::Arc;
use tracing::{error, info}; use tracing::{error, info};
pub async fn handle_basic_publish(channel_handle: Arc<Channel>, message: Message) { pub async fn handle_basic_publish(channel_handle: Channel, message: Message) {
match publish::publish(channel_handle, message).await { match publish::publish(channel_handle, message).await {
Ok(()) => {} Ok(()) => {}
Err(err) => error!(%err, "publish error occurred"), Err(err) => error!(%err, "publish error occurred"),
} }
} }
pub async fn handle_method(channel_handle: Arc<Channel>, method: Method) -> Result<Method> { pub async fn handle_method(channel_handle: Channel, method: Method) -> Result<Method> {
info!(?method, "Handling method"); info!(?method, "Handling method");
let response = match method { let response = match method {

View file

@ -8,10 +8,10 @@ use amqp_core::{
}; };
use tracing::info; use tracing::info;
pub async fn publish(channel_handle: Arc<Channel>, message: Message) -> Result<()> { pub async fn publish(channel_handle: Channel, message: Message) -> Result<()> {
info!(?message, "Publishing message"); info!(?message, "Publishing message");
let global_data = channel_handle.lock().global_data.clone(); let global_data = channel_handle.global_data.clone();
let routing = &message.routing; let routing = &message.routing;
@ -39,14 +39,11 @@ pub async fn publish(channel_handle: Arc<Channel>, message: Message) -> Result<(
immediate: false, immediate: false,
}); });
consumer consumer.channel.queue_method(QueuedMethod::WithContent(
.channel method,
.lock() message.header.clone(),
.queue_method(QueuedMethod::WithContent( message.content.clone(),
method, ));
message.header.clone(),
message.content.clone(),
));
} }
} }

View file

@ -9,7 +9,7 @@ use amqp_core::{
use parking_lot::Mutex; use parking_lot::Mutex;
use std::sync::{atomic::AtomicUsize, Arc}; use std::sync::{atomic::AtomicUsize, Arc};
pub fn declare(channel_handle: Channel, queue_declare: QueueDeclare) -> Result<Method> { pub fn declare(channel: Channel, queue_declare: QueueDeclare) -> Result<Method> {
let QueueDeclare { let QueueDeclare {
queue: queue_name, queue: queue_name,
passive, passive,
@ -34,7 +34,6 @@ pub fn declare(channel_handle: Channel, queue_declare: QueueDeclare) -> Result<M
} }
let global_data = { let global_data = {
let channel = channel_handle.lock();
let global_data = channel.global_data.clone(); let global_data = channel.global_data.clone();
let id = QueueId::random(); let id = QueueId::random();

View file

@ -7,8 +7,8 @@ use crate::{
use amqp_core::{ use amqp_core::{
amqp_todo, amqp_todo,
connection::{ connection::{
Channel, ChannelNum, ConEventReceiver, ConEventSender, Connection, ConnectionId, Channel, ChannelInner, ChannelNum, ConEventReceiver, ConEventSender, Connection,
ContentHeader, QueuedMethod, ConnectionId, ContentHeader, QueuedMethod,
}, },
message::{MessageId, RawMessage, RoutingInformation}, message::{MessageId, RawMessage, RoutingInformation},
methods::{ methods::{
@ -44,11 +44,11 @@ const CHANNEL_MAX: u16 = 0;
const FRAME_SIZE_MAX: u32 = 0; const FRAME_SIZE_MAX: u32 = 0;
const HEARTBEAT_DELAY: u16 = 0; const HEARTBEAT_DELAY: u16 = 0;
const BASIC_CLASS_ID: ChannelNum = ChannelNum::new(60); const BASIC_CLASS_ID: u16 = 60;
pub struct TransportChannel { pub struct TransportChannel {
/// A handle to the global channel representation. Used to remove the channel when it's dropped /// A handle to the global channel representation. Used to remove the channel when it's dropped
global_chan: Arc<Channel>, global_chan: Channel,
/// The current status of the channel, whether it has sent a method that expects a body /// The current status of the channel, whether it has sent a method that expects a body
status: ChannelStatus, status: ChannelStatus,
} }
@ -62,7 +62,7 @@ pub struct TransportConnection {
/// When the next heartbeat expires /// When the next heartbeat expires
next_timeout: Pin<Box<time::Sleep>>, next_timeout: Pin<Box<time::Sleep>>,
channels: HashMap<ChannelNum, TransportChannel>, channels: HashMap<ChannelNum, TransportChannel>,
global_con: Arc<Connection>, global_con: Connection,
global_data: GlobalData, global_data: GlobalData,
method_queue_send: ConEventSender, method_queue_send: ConEventSender,
@ -73,7 +73,7 @@ const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
enum ChannelStatus { enum ChannelStatus {
Default, Default,
NeedHeader(ChannelNum, Box<Method>), NeedHeader(u16, Box<Method>),
NeedsBody(Box<Method>, ContentHeader, SmallVec<[Bytes; 1]>), NeedsBody(Box<Method>, ContentHeader, SmallVec<[Bytes; 1]>),
} }
@ -87,7 +87,7 @@ impl TransportConnection {
pub fn new( pub fn new(
id: ConnectionId, id: ConnectionId,
stream: TcpStream, stream: TcpStream,
connection_handle: Arc<GConnection>, global_con: Connection,
global_data: GlobalData, global_data: GlobalData,
method_queue_send: ConEventSender, method_queue_send: ConEventSender,
method_queue_recv: ConEventReceiver, method_queue_recv: ConEventReceiver,
@ -99,7 +99,7 @@ impl TransportConnection {
heartbeat_delay: HEARTBEAT_DELAY, heartbeat_delay: HEARTBEAT_DELAY,
channel_max: CHANNEL_MAX, channel_max: CHANNEL_MAX,
next_timeout: Box::pin(time::sleep(DEFAULT_TIMEOUT)), next_timeout: Box::pin(time::sleep(DEFAULT_TIMEOUT)),
global_con: connection_handle, global_con,
channels: HashMap::with_capacity(4), channels: HashMap::with_capacity(4),
global_data, global_data,
method_queue_send, method_queue_send,
@ -144,8 +144,7 @@ impl TransportConnection {
Err(err) => error!(%err, "Error during processing of connection"), Err(err) => error!(%err, "Error during processing of connection"),
} }
let connection_handle = self.global_con.lock(); // global connection is closed on drop
connection_handle.close();
} }
pub async fn process_connection(&mut self) -> Result<()> { pub async fn process_connection(&mut self) -> Result<()> {
@ -485,7 +484,7 @@ impl TransportConnection {
async fn channel_open(&mut self, channel_num: ChannelNum) -> Result<()> { async fn channel_open(&mut self, channel_num: ChannelNum) -> Result<()> {
let id = rand::random(); let id = rand::random();
let channel_handle = amqp_core::connection::c::new_handle( let channel_handle = ChannelInner::new(
id, id,
channel_num, channel_num,
self.global_con.clone(), self.global_con.clone(),
@ -511,8 +510,8 @@ impl TransportConnection {
.connections .connections
.get_mut(&self.id) .get_mut(&self.id)
.unwrap() .unwrap()
.lock()
.channels .channels
.lock()
.insert(channel_num, channel_handle); .insert(channel_num, channel_handle);
} }
@ -603,13 +602,13 @@ impl TransportConnection {
impl Drop for TransportConnection { impl Drop for TransportConnection {
fn drop(&mut self) { fn drop(&mut self) {
self.global_con.lock().close(); self.global_con.close();
} }
} }
impl Drop for TransportChannel { impl Drop for TransportChannel {
fn drop(&mut self) { fn drop(&mut self) {
self.global_chan.lock().close(); self.global_chan.close();
} }
} }

View file

@ -3,7 +3,7 @@
mod connection; mod connection;
mod error; mod error;
mod frame; mod frame;
pub mod methods; mod methods;
mod sasl; mod sasl;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
@ -31,7 +31,7 @@ pub async fn do_thing_i_guess(global_data: GlobalData) -> Result<()> {
let (method_send, method_recv) = tokio::sync::mpsc::channel(10); let (method_send, method_recv) = tokio::sync::mpsc::channel(10);
let connection_handle = amqp_core::connection::ConnectionInner::new_handle( let connection_handle = amqp_core::connection::ConnectionInner::new(
id, id,
peer_addr, peer_addr,
global_data.clone(), global_data.clone(),