consume prototype

This commit is contained in:
nora 2022-03-01 21:52:00 +01:00
parent beb2187cd6
commit 93ce632b5d
21 changed files with 328 additions and 108 deletions

1
Cargo.lock generated
View file

@ -33,6 +33,7 @@ dependencies = [
"rand", "rand",
"smallvec", "smallvec",
"thiserror", "thiserror",
"tokio",
"uuid", "uuid",
] ]

View file

@ -11,4 +11,7 @@ parking_lot = "0.12.0"
rand = "0.8.5" rand = "0.8.5"
smallvec = { version = "1.8.0", features = ["union"] } smallvec = { version = "1.8.0", features = ["union"] }
thiserror = "1.0.30" thiserror = "1.0.30"
tokio = { version = "1.17.0", features = ["sync"] }
uuid = "0.8.2" uuid = "0.8.2"
[features]

View file

@ -1,9 +1,13 @@
use crate::{newtype_id, GlobalData, Handle, Queue}; use crate::methods::Method;
use crate::{methods, newtype_id, GlobalData, Handle, Queue};
use bytes::Bytes;
use parking_lot::Mutex; use parking_lot::Mutex;
use smallvec::SmallVec;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::mpsc;
newtype_id!(pub ConnectionId); newtype_id!(pub ConnectionId);
newtype_id!(pub ChannelId); newtype_id!(pub ChannelId);
@ -48,14 +52,25 @@ pub struct Connection {
pub global_data: GlobalData, pub global_data: GlobalData,
pub channels: HashMap<ChannelNum, ChannelHandle>, pub channels: HashMap<ChannelNum, ChannelHandle>,
pub exclusive_queues: Vec<Queue>, pub exclusive_queues: Vec<Queue>,
_method_queue: MethodSender,
} }
#[derive(Debug)]
pub enum QueuedMethod {
Normal(Method),
WithContent(Method, ContentHeader, SmallVec<[Bytes; 1]>),
}
pub type MethodSender = mpsc::Sender<(ChannelNum, QueuedMethod)>;
pub type MethodReceiver = mpsc::Receiver<(ChannelNum, QueuedMethod)>;
impl Connection { impl Connection {
#[must_use] #[must_use]
pub fn new_handle( pub fn new_handle(
id: ConnectionId, id: ConnectionId,
peer_addr: SocketAddr, peer_addr: SocketAddr,
global_data: GlobalData, global_data: GlobalData,
method_queue: MethodSender,
) -> ConnectionHandle { ) -> ConnectionHandle {
Arc::new(Mutex::new(Self { Arc::new(Mutex::new(Self {
id, id,
@ -63,6 +78,7 @@ impl Connection {
global_data, global_data,
channels: HashMap::new(), channels: HashMap::new(),
exclusive_queues: vec![], exclusive_queues: vec![],
_method_queue: method_queue,
})) }))
} }
@ -77,24 +93,27 @@ pub type ChannelHandle = Handle<Channel>;
#[derive(Debug)] #[derive(Debug)]
pub struct Channel { pub struct Channel {
pub id: ChannelId, pub id: ChannelId,
pub num: u16, pub num: ChannelNum,
pub connection: ConnectionHandle, pub connection: ConnectionHandle,
pub global_data: GlobalData, pub global_data: GlobalData,
method_queue: MethodSender,
} }
impl Channel { impl Channel {
#[must_use] #[must_use]
pub fn new_handle( pub fn new_handle(
id: ChannelId, id: ChannelId,
num: u16, num: ChannelNum,
connection: ConnectionHandle, connection: ConnectionHandle,
global_data: GlobalData, global_data: GlobalData,
method_queue: MethodSender,
) -> ChannelHandle { ) -> ChannelHandle {
Arc::new(Mutex::new(Self { Arc::new(Mutex::new(Self {
id, id,
num, num,
connection, connection,
global_data, global_data,
method_queue,
})) }))
} }
@ -102,4 +121,19 @@ impl Channel {
let mut global_data = self.global_data.lock(); let mut global_data = self.global_data.lock();
global_data.channels.remove(&self.id); global_data.channels.remove(&self.id);
} }
pub fn queue_method(&self, method: QueuedMethod) {
// todo: this is a horrible hack around the lock chaos
self.method_queue
.try_send((self.num, method))
.expect("could not send method to channel, RIP");
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct ContentHeader {
pub class_id: u16,
pub weight: u16,
pub body_size: u64,
pub property_fields: methods::Table,
} }

12
amqp_core/src/consumer.rs Normal file
View file

@ -0,0 +1,12 @@
use crate::{newtype_id, ChannelHandle};
newtype_id!(
pub ConsumerId
);
#[derive(Debug)]
pub struct Consumer {
pub id: ConsumerId,
pub tag: String,
pub channel: ChannelHandle,
}

View file

@ -1,6 +1,7 @@
#![warn(rust_2018_idioms)] #![warn(rust_2018_idioms)]
pub mod connection; pub mod connection;
pub mod consumer;
pub mod error; pub mod error;
mod macros; mod macros;
pub mod message; pub mod message;
@ -13,6 +14,7 @@ use connection::{ChannelId, ConnectionId};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid;
type Handle<T> = Arc<Mutex<T>>; type Handle<T> = Arc<Mutex<T>>;
@ -48,3 +50,7 @@ pub struct GlobalDataInner {
/// Todo: This is just for testing and will be removed later! /// Todo: This is just for testing and will be removed later!
pub default_exchange: HashMap<String, Queue>, pub default_exchange: HashMap<String, Queue>,
} }
pub fn random_uuid() -> Uuid {
Uuid::from_bytes(rand::random())
}

View file

@ -1,6 +1,7 @@
#[macro_export] #[macro_export]
macro_rules! newtype_id { macro_rules! newtype_id {
($vis:vis $name:ident) => { ($(#[$meta:meta])* $vis:vis $name:ident) => {
$(#[$meta])*
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
$vis struct $name(::uuid::Uuid); $vis struct $name(::uuid::Uuid);

View file

@ -1,6 +1,4 @@
#![allow(dead_code)] use crate::connection::ContentHeader;
use crate::methods;
use crate::newtype_id; use crate::newtype_id;
use bytes::Bytes; use bytes::Bytes;
use smallvec::SmallVec; use smallvec::SmallVec;
@ -13,7 +11,7 @@ newtype_id!(pub MessageId);
#[derive(Debug)] #[derive(Debug)]
pub struct RawMessage { pub struct RawMessage {
pub id: MessageId, pub id: MessageId,
pub properties: methods::Table, pub header: ContentHeader,
pub routing: RoutingInformation, pub routing: RoutingInformation,
pub content: SmallVec<[Bytes; 1]>, pub content: SmallVec<[Bytes; 1]>,
} }

View file

@ -1,6 +1,8 @@
use crate::consumer::Consumer;
use crate::message::Message; use crate::message::Message;
use crate::{newtype, newtype_id, ChannelId}; use crate::{newtype, newtype_id, ChannelId};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::borrow::Borrow;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::Arc; use std::sync::Arc;
@ -14,6 +16,12 @@ newtype!(
pub QueueName: Arc<str> pub QueueName: Arc<str>
); );
impl Borrow<str> for QueueName {
fn borrow(&self) -> &str {
std::borrow::Borrow::borrow(&self.0)
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct RawQueue { pub struct RawQueue {
pub id: QueueId, pub id: QueueId,
@ -25,6 +33,7 @@ pub struct RawQueue {
/// The queue can always be manually deleted. /// The queue can always be manually deleted.
/// If auto-delete is enabled, it keeps track of the consumer count. /// If auto-delete is enabled, it keeps track of the consumer count.
pub deletion: QueueDeletion, pub deletion: QueueDeletion,
pub consumers: Mutex<Vec<Consumer>>,
} }
#[derive(Debug)] #[derive(Debug)]

View file

@ -92,7 +92,7 @@ async fn get_data(global_data: GlobalData) -> impl IntoResponse {
let chan = chan.lock(); let chan = chan.lock();
Channel { Channel {
id: chan.id.to_string(), id: chan.id.to_string(),
number: chan.num, number: chan.num.num(),
} }
}) })
.collect(), .collect(),

View file

@ -10,3 +10,5 @@ amqp_core = { path = "../amqp_core" }
parking_lot = "0.12.0" parking_lot = "0.12.0"
tracing = "0.1.31" tracing = "0.1.31"
tokio = { version = "1.17.0", features = ["full"] } tokio = { version = "1.17.0", features = ["full"] }
[features]

View file

@ -1,3 +1,7 @@
#![warn(rust_2018_idioms)] #![warn(rust_2018_idioms)]
use amqp_core::error::ProtocolError;
pub mod methods; pub mod methods;
type Result<T> = std::result::Result<T, ProtocolError>;

View file

@ -1,13 +1,56 @@
use crate::Result;
use amqp_core::amqp_todo; use amqp_core::amqp_todo;
use amqp_core::connection::ChannelHandle; use amqp_core::connection::ChannelHandle;
use amqp_core::error::ProtocolError; use amqp_core::consumer::{Consumer, ConsumerId};
use amqp_core::methods::{BasicConsume, Method}; use amqp_core::error::{ChannelException};
use amqp_core::methods::{BasicConsume, BasicConsumeOk, Method};
use std::sync::Arc;
use tracing::info;
pub async fn consume( pub fn consume(channel_handle: ChannelHandle, basic_consume: BasicConsume) -> Result<Method> {
channel_handle: ChannelHandle, let BasicConsume {
_basic_consume: BasicConsume, queue: queue_name,
) -> Result<Method, ProtocolError> { consumer_tag,
let _channel = channel_handle.lock(); no_local,
no_ack,
exclusive,
no_wait,
..
} = basic_consume;
amqp_todo!() if no_wait || no_local || exclusive || no_ack {
amqp_todo!();
}
let global_data = {
let channel = channel_handle.lock();
channel.global_data.clone()
};
let consumer_tag = if consumer_tag.is_empty() {
amqp_core::random_uuid().to_string()
} else {
consumer_tag
};
let mut global_data = global_data.lock();
let consumer = Consumer {
id: ConsumerId::random(),
tag: consumer_tag.clone(),
channel: Arc::clone(&channel_handle),
};
let queue = global_data
.queues
.get_mut(queue_name.as_str())
.ok_or(ChannelException::NotFound)?;
queue.consumers.lock().push(consumer);
info!(%queue_name, %consumer_tag, "Consumer started consuming");
let method = Method::BasicConsumeOk(BasicConsumeOk { consumer_tag });
Ok(method)
} }

View file

@ -1,24 +1,22 @@
mod consume; mod consume;
mod publish;
mod queue; mod queue;
use crate::Result;
use amqp_core::amqp_todo; use amqp_core::amqp_todo;
use amqp_core::connection::ChannelHandle; use amqp_core::connection::ChannelHandle;
use amqp_core::error::ProtocolError;
use amqp_core::message::Message; use amqp_core::message::Message;
use amqp_core::methods::Method; use amqp_core::methods::Method;
use tracing::info; use tracing::{error, info};
pub async fn handle_basic_publish(_channel_handle: ChannelHandle, message: Message) { pub async fn handle_basic_publish(channel_handle: ChannelHandle, message: Message) {
info!( match publish::publish(channel_handle, message).await {
?message, Ok(()) => {}
"Someone has summoned the almighty Basic.Publish handler" Err(err) => error!(%err, "publish error occurred"),
); }
} }
pub async fn handle_method( pub async fn handle_method(channel_handle: ChannelHandle, method: Method) -> Result<Method> {
channel_handle: ChannelHandle,
method: Method,
) -> Result<Method, ProtocolError> {
info!(?method, "Handling method"); info!(?method, "Handling method");
let response = match method { let response = match method {
@ -26,9 +24,7 @@ pub async fn handle_method(
Method::ExchangeDeclareOk(_) => amqp_todo!(), Method::ExchangeDeclareOk(_) => amqp_todo!(),
Method::ExchangeDelete(_) => amqp_todo!(), Method::ExchangeDelete(_) => amqp_todo!(),
Method::ExchangeDeleteOk(_) => amqp_todo!(), Method::ExchangeDeleteOk(_) => amqp_todo!(),
Method::QueueDeclare(queue_declare) => { Method::QueueDeclare(queue_declare) => queue::declare(channel_handle, queue_declare)?,
queue::declare(channel_handle, queue_declare).await?
}
Method::QueueDeclareOk { .. } => amqp_todo!(), Method::QueueDeclareOk { .. } => amqp_todo!(),
Method::QueueBind(queue_bind) => queue::bind(channel_handle, queue_bind).await?, Method::QueueBind(queue_bind) => queue::bind(channel_handle, queue_bind).await?,
Method::QueueBindOk(_) => amqp_todo!(), Method::QueueBindOk(_) => amqp_todo!(),
@ -40,7 +36,7 @@ pub async fn handle_method(
Method::QueueDeleteOk { .. } => amqp_todo!(), Method::QueueDeleteOk { .. } => amqp_todo!(),
Method::BasicQos { .. } => amqp_todo!(), Method::BasicQos { .. } => amqp_todo!(),
Method::BasicQosOk(_) => amqp_todo!(), Method::BasicQosOk(_) => amqp_todo!(),
Method::BasicConsume(consume) => consume::consume(channel_handle, consume).await?, Method::BasicConsume(consume) => consume::consume(channel_handle, consume)?,
Method::BasicConsumeOk { .. } => amqp_todo!(), Method::BasicConsumeOk { .. } => amqp_todo!(),
Method::BasicCancel { .. } => amqp_todo!(), Method::BasicCancel { .. } => amqp_todo!(),
Method::BasicCancelOk { .. } => amqp_todo!(), Method::BasicCancelOk { .. } => amqp_todo!(),

View file

@ -0,0 +1,52 @@
use crate::Result;
use amqp_core::amqp_todo;
use amqp_core::connection::{ChannelHandle, QueuedMethod};
use amqp_core::error::ChannelException;
use amqp_core::message::Message;
use amqp_core::methods::{BasicPublish, Method};
use tracing::info;
pub async fn publish(channel_handle: ChannelHandle, message: Message) -> Result<()> {
info!(?message, "Publishing message");
let global_data = channel_handle.lock().global_data.clone();
let routing = &message.routing;
if !routing.exchange.is_empty() {
amqp_todo!();
}
let mut global_data = global_data.lock();
let queue = global_data
.queues
.get_mut(routing.routing_key.as_str())
.ok_or(ChannelException::NotFound)?;
{
// todo: we just send it to the consumer directly and ignore it if the consumer doesn't exist
// consuming is hard, but this should work *for now*
let consumers = queue.consumers.lock();
if let Some(consumer) = consumers.first() {
let method = Method::BasicPublish(BasicPublish {
reserved_1: 0,
exchange: routing.exchange.clone(),
routing_key: routing.routing_key.clone(),
mandatory: false,
immediate: false,
});
consumer
.channel
.lock()
.queue_method(QueuedMethod::WithContent(
method,
message.header.clone(),
message.content.clone(),
));
}
}
Ok(())
}

View file

@ -1,16 +1,16 @@
use amqp_core::connection::ChannelHandle; use amqp_core::connection::ChannelHandle;
use amqp_core::error::ProtocolError;
use amqp_core::methods::{Method, QueueBind, QueueDeclare, QueueDeclareOk}; use amqp_core::methods::{Method, QueueBind, QueueDeclare, QueueDeclareOk};
use amqp_core::queue::{QueueDeletion, QueueId, QueueName, RawQueue}; use amqp_core::queue::{QueueDeletion, QueueId, QueueName, RawQueue};
use amqp_core::{amqp_todo, GlobalData}; use amqp_core::{amqp_todo, GlobalData};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::Arc; use std::sync::Arc;
use crate::Result;
pub async fn declare( pub fn declare(
channel_handle: ChannelHandle, channel_handle: ChannelHandle,
queue_declare: QueueDeclare, queue_declare: QueueDeclare,
) -> Result<Method, ProtocolError> { ) -> Result<Method> {
let QueueDeclare { let QueueDeclare {
queue: queue_name, queue: queue_name,
passive, passive,
@ -28,7 +28,9 @@ pub async fn declare(
amqp_todo!(); amqp_todo!();
} }
if passive || no_wait || durable { // todo: durable is technically spec-compliant, the spec doesn't really require it, but it's a todo
// not checked here because it's the default for amqplib which is annoying
if passive || no_wait {
amqp_todo!(); amqp_todo!();
} }
@ -48,6 +50,7 @@ pub async fn declare(
} else { } else {
QueueDeletion::Manual QueueDeletion::Manual
}, },
consumers: Mutex::default(),
}); });
{ {
@ -58,7 +61,7 @@ pub async fn declare(
global_data global_data
}; };
bind_queue(global_data, (), queue_name.clone().into_inner()).await?; bind_queue(global_data, (), queue_name.clone().into_inner())?;
Ok(Method::QueueDeclareOk(QueueDeclareOk { Ok(Method::QueueDeclareOk(QueueDeclareOk {
queue: queue_name.to_string(), queue: queue_name.to_string(),
@ -70,15 +73,15 @@ pub async fn declare(
pub async fn bind( pub async fn bind(
_channel_handle: ChannelHandle, _channel_handle: ChannelHandle,
_queue_bind: QueueBind, _queue_bind: QueueBind,
) -> Result<Method, ProtocolError> { ) -> Result<Method> {
amqp_todo!(); amqp_todo!();
} }
async fn bind_queue( fn bind_queue(
global_data: GlobalData, global_data: GlobalData,
_exchange: (), _exchange: (),
routing_key: Arc<str>, routing_key: Arc<str>,
) -> Result<(), ProtocolError> { ) -> Result<()> {
let mut global_data = global_data.lock(); let mut global_data = global_data.lock();
// todo: don't // todo: don't

View file

@ -1,14 +1,17 @@
use crate::error::{ConException, ProtocolError, Result, TransError}; use crate::error::{ConException, ProtocolError, Result, TransError};
use crate::frame::{ContentHeader, Frame, FrameType}; use crate::frame::{parse_content_header, Frame, FrameType};
use crate::{frame, methods, sasl}; use crate::{frame, methods, sasl};
use amqp_core::connection::{ChannelHandle, ChannelNum, ConnectionHandle, ConnectionId}; use amqp_core::connection::{
ChannelHandle, ChannelNum, ConnectionHandle, ConnectionId, ContentHeader, MethodReceiver,
MethodSender, QueuedMethod,
};
use amqp_core::message::{MessageId, RawMessage, RoutingInformation}; use amqp_core::message::{MessageId, RawMessage, RoutingInformation};
use amqp_core::methods::{ use amqp_core::methods::{
BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose, ConnectionCloseOk, BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose, ConnectionCloseOk,
ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk, ConnectionTune, ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk, ConnectionTune,
ConnectionTuneOk, FieldValue, Method, Table, ConnectionTuneOk, FieldValue, Method, Table,
}; };
use amqp_core::GlobalData; use amqp_core::{amqp_todo, GlobalData};
use anyhow::Context; use anyhow::Context;
use bytes::Bytes; use bytes::Bytes;
use smallvec::SmallVec; use smallvec::SmallVec;
@ -20,7 +23,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time; use tokio::{select, time};
use tracing::{debug, error, info, trace, warn}; use tracing::{debug, error, info, trace, warn};
fn ensure_conn(condition: bool) -> Result<()> { fn ensure_conn(condition: bool) -> Result<()> {
@ -56,6 +59,9 @@ pub struct Connection {
channels: HashMap<ChannelNum, Channel>, channels: HashMap<ChannelNum, Channel>,
handle: ConnectionHandle, handle: ConnectionHandle,
global_data: GlobalData, global_data: GlobalData,
method_queue_send: MethodSender,
method_queue_recv: MethodReceiver,
} }
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
@ -64,7 +70,7 @@ enum ChannelStatus {
Default, Default,
/// ClassId // todo: newtype it /// ClassId // todo: newtype it
NeedHeader(u16, Box<Method>), NeedHeader(u16, Box<Method>),
NeedsBody(Box<Method>, Box<ContentHeader>, SmallVec<[Bytes; 1]>), NeedsBody(Box<Method>, ContentHeader, SmallVec<[Bytes; 1]>),
} }
impl ChannelStatus { impl ChannelStatus {
@ -79,6 +85,8 @@ impl Connection {
stream: TcpStream, stream: TcpStream,
connection_handle: ConnectionHandle, connection_handle: ConnectionHandle,
global_data: GlobalData, global_data: GlobalData,
method_queue_send: MethodSender,
method_queue_recv: MethodReceiver,
) -> Self { ) -> Self {
Self { Self {
id, id,
@ -90,6 +98,8 @@ impl Connection {
handle: connection_handle, handle: connection_handle,
channels: HashMap::with_capacity(4), channels: HashMap::with_capacity(4),
global_data, global_data,
method_queue_send,
method_queue_recv: method_queue_recv,
} }
} }
@ -145,6 +155,17 @@ impl Connection {
self.main_loop().await self.main_loop().await
} }
async fn send_method_content(
&mut self,
channel: ChannelNum,
method: Method,
_header: ContentHeader,
_body: SmallVec<[Bytes; 1]>,
) -> Result<()> {
self.send_method(channel, method).await?;
amqp_todo!()
}
async fn send_method(&mut self, channel: ChannelNum, method: Method) -> Result<()> { async fn send_method(&mut self, channel: ChannelNum, method: Method) -> Result<()> {
trace!(%channel, ?method, "Sending method"); trace!(%channel, ?method, "Sending method");
@ -256,41 +277,54 @@ impl Connection {
async fn main_loop(&mut self) -> Result<()> { async fn main_loop(&mut self) -> Result<()> {
loop { loop {
let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?; select! {
let channel = frame.channel; frame = frame::read_frame(&mut self.stream, self.max_frame_size) => {
let result = self.handle_frame(frame).await; let frame = frame?;
match result { self.handle_frame(frame).await?;
Ok(()) => {} }
Err(TransError::Protocol(ProtocolError::ChannelException(ex))) => { queued_method = self.method_queue_recv.recv() => {
warn!(%ex, "Channel exception occurred"); match queued_method {
self.send_method( Some((channel, QueuedMethod::Normal(method))) => self.send_method(channel, method).await?,
channel, Some((channel, QueuedMethod::WithContent(method, header, body))) => self.send_method_content(channel, method, header, body).await?,
Method::ChannelClose(ChannelClose { None => {}
reply_code: ex.reply_code(), }
reply_text: ex.reply_text(),
class_id: 0, // todo: do this
method_id: 0,
}),
)
.await?;
drop(self.channels.remove(&channel));
} }
Err(other_err) => return Err(other_err),
} }
} }
} }
async fn handle_frame(&mut self, frame: Frame) -> Result<()> { async fn handle_frame(&mut self, frame: Frame) -> Result<()> {
let channel = frame.channel;
self.reset_timeout(); self.reset_timeout();
match frame.kind { let result = match frame.kind {
FrameType::Method => self.dispatch_method(frame).await?, FrameType::Method => self.dispatch_method(frame).await,
FrameType::Heartbeat => { /* Nothing here, just the `reset_timeout` above */ } FrameType::Heartbeat => {
FrameType::Header => self.dispatch_header(frame)?, Ok(()) /* Nothing here, just the `reset_timeout` above */
FrameType::Body => self.dispatch_body(frame)?, }
} FrameType::Header => self.dispatch_header(frame),
FrameType::Body => self.dispatch_body(frame),
};
Ok(()) match result {
Ok(()) => Ok(()),
Err(TransError::Protocol(ProtocolError::ChannelException(ex))) => {
warn!(%ex, "Channel exception occurred");
self.send_method(
channel,
Method::ChannelClose(ChannelClose {
reply_code: ex.reply_code(),
reply_text: ex.reply_text(),
class_id: 0, // todo: do this
method_id: 0,
}),
)
.await?;
drop(self.channels.remove(&channel));
Ok(())
}
Err(other_err) => Err(other_err),
}
} }
async fn dispatch_method(&mut self, frame: Frame) -> Result<()> { async fn dispatch_method(&mut self, frame: Frame) -> Result<()> {
@ -354,7 +388,7 @@ impl Connection {
Err(ConException::UnexpectedFrame.into()) Err(ConException::UnexpectedFrame.into())
} }
ChannelStatus::NeedHeader(class_id, method) => { ChannelStatus::NeedHeader(class_id, method) => {
let header = ContentHeader::parse(&frame.payload)?; let header = parse_content_header(&frame.payload)?;
ensure_conn(header.class_id == class_id)?; ensure_conn(header.class_id == class_id)?;
channel.status = ChannelStatus::NeedsBody(method, header, SmallVec::new()); channel.status = ChannelStatus::NeedsBody(method, header, SmallVec::new());
@ -391,7 +425,7 @@ impl Connection {
.cmp(&usize::try_from(header.body_size).unwrap()) .cmp(&usize::try_from(header.body_size).unwrap())
{ {
Ordering::Equal => { Ordering::Equal => {
self.process_method_with_body(*method, *header, vec, frame.channel) self.process_method_with_body(*method, header, vec, frame.channel)
} }
Ordering::Greater => Err(ConException::Todo.into()), Ordering::Greater => Err(ConException::Todo.into()),
Ordering::Less => Ok(()), // wait for next body Ordering::Less => Ok(()), // wait for next body
@ -420,7 +454,7 @@ impl Connection {
{ {
let message = RawMessage { let message = RawMessage {
id: MessageId::random(), id: MessageId::random(),
properties: header.property_fields, header,
routing: RoutingInformation { routing: RoutingInformation {
exchange, exchange,
routing_key, routing_key,
@ -449,9 +483,10 @@ impl Connection {
let id = rand::random(); let id = rand::random();
let channel_handle = amqp_core::connection::Channel::new_handle( let channel_handle = amqp_core::connection::Channel::new_handle(
id, id,
channel_num.num(), channel_num,
self.handle.clone(), self.handle.clone(),
self.global_data.clone(), self.global_data.clone(),
self.method_queue_send.clone(),
); );
let channel = Channel { let channel = Channel {

View file

@ -1,6 +1,5 @@
use crate::error::{ConException, ProtocolError, Result}; use crate::error::{ConException, ProtocolError, Result};
use amqp_core::connection::ChannelNum; use amqp_core::connection::{ChannelNum, ContentHeader};
use amqp_core::methods;
use anyhow::Context; use anyhow::Context;
use bytes::Bytes; use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
@ -33,18 +32,10 @@ pub enum FrameType {
Heartbeat = 8, Heartbeat = 8,
} }
#[derive(Debug, Clone, PartialEq)]
pub struct ContentHeader {
pub class_id: u16,
pub weight: u16,
pub body_size: u64,
pub property_fields: methods::Table,
}
mod content_header_parse { mod content_header_parse {
use crate::error::TransError; use crate::error::TransError;
use crate::frame::ContentHeader;
use crate::methods::parse_helper::{octet, shortstr, table, timestamp}; use crate::methods::parse_helper::{octet, shortstr, table, timestamp};
use amqp_core::connection::ContentHeader;
use amqp_core::methods; use amqp_core::methods;
use amqp_core::methods::FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp}; use amqp_core::methods::FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp};
use nom::number::complete::{u16, u64}; use nom::number::complete::{u16, u64};
@ -95,7 +86,7 @@ mod content_header_parse {
Ok((input, map)) Ok((input, map))
} }
pub fn header(input: &[u8]) -> IResult<'_, Box<ContentHeader>> { pub fn header(input: &[u8]) -> IResult<'_, ContentHeader> {
let (input, class_id) = u16(Big)(input)?; let (input, class_id) = u16(Big)(input)?;
let (input, weight) = u16(Big)(input)?; let (input, weight) = u16(Big)(input)?;
let (input, body_size) = u64(Big)(input)?; let (input, body_size) = u64(Big)(input)?;
@ -108,31 +99,26 @@ mod content_header_parse {
Ok(( Ok((
input, input,
Box::new(ContentHeader { ContentHeader {
class_id, class_id,
weight, weight,
body_size, body_size,
property_fields, property_fields,
}), },
)) ))
} }
} }
impl ContentHeader { pub fn parse_content_header(input: &[u8]) -> Result<ContentHeader> {
pub fn parse(input: &[u8]) -> Result<Box<Self>> { match content_header_parse::header(input) {
match content_header_parse::header(input) { Ok(([], header)) => Ok(header),
Ok(([], header)) => Ok(header), Ok((_, _)) => {
Ok((_, _)) => { Err(ConException::SyntaxError(vec!["could not consume all input".to_string()]).into())
Err(
ConException::SyntaxError(vec!["could not consume all input".to_string()])
.into(),
)
}
Err(nom::Err::Incomplete(_)) => {
Err(ConException::SyntaxError(vec!["there was not enough data".to_string()]).into())
}
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
} }
Err(nom::Err::Incomplete(_)) => {
Err(ConException::SyntaxError(vec!["there was not enough data".to_string()]).into())
}
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
} }
} }

View file

@ -29,15 +29,28 @@ pub async fn do_thing_i_guess(global_data: GlobalData) -> Result<()> {
info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection"); info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection");
let span = info_span!("client-connection", %id); let span = info_span!("client-connection", %id);
let connection_handle = let (method_send, method_recv) = tokio::sync::mpsc::channel(10);
amqp_core::connection::Connection::new_handle(id, peer_addr, global_data.clone());
let connection_handle = amqp_core::connection::Connection::new_handle(
id,
peer_addr,
global_data.clone(),
method_send.clone(),
);
let mut global_data_guard = global_data.lock(); let mut global_data_guard = global_data.lock();
global_data_guard global_data_guard
.connections .connections
.insert(id, connection_handle.clone()); .insert(id, connection_handle.clone());
let connection = Connection::new(id, stream, connection_handle, global_data.clone()); let connection = Connection::new(
id,
stream,
connection_handle,
global_data.clone(),
method_send,
method_recv,
);
tokio::spawn(connection.start_connection_processing().instrument(span)); tokio::spawn(connection.start_connection_processing().instrument(span));
} }

View file

@ -0,0 +1,23 @@
import { connectAmqp } from './utils/utils.js';
const connection = await connectAmqp();
const channel = await connection.createChannel();
await channel.assertQueue('consume-queue-1415');
const consumePromise = new Promise((resolve) => {
channel
.consume('consume-queue-1415', (msg) => {
if (msg.content.toString() === 'STOP') {
resolve();
}
})
.then((response) =>
console.log(`Registered consumer, consumerTag: "${response.consumerTag}"`)
);
});
await channel.sendToQueue('consume-queue-1415', Buffer.from('STOP'));
console.log('Sent STOP message to queue');
await consumePromise;

View file

@ -6,7 +6,7 @@ const connection = await connectAmqp();
const channel = await connection.createChannel(); const channel = await connection.createChannel();
const reply = await channel.assertQueue(queueName, { durable: false }); const reply = await channel.assertQueue(queueName);
assert(reply.messageCount === 0, 'Message found in queue'); assert(reply.messageCount === 0, 'Message found in queue');
assert(reply.consumerCount === 0, 'Consumer listening on queue'); assert(reply.consumerCount === 0, 'Consumer listening on queue');

View file

@ -1,7 +1,6 @@
import { connectAmqp } from './utils/utils.js'; import { connectAmqp } from './utils/utils.js';
const connection = await connectAmqp(); const connection = await connectAmqp();
const channel = await connection.createChannel(); const channel = await connection.createChannel();
channel.publish('exchange-1', 'queue-1', Buffer.from('hello')); channel.publish('exchange-1', 'queue-1', Buffer.from('hello'));