consume prototype

This commit is contained in:
nora 2022-03-01 21:52:00 +01:00
parent beb2187cd6
commit 93ce632b5d
21 changed files with 328 additions and 108 deletions

View file

@ -1,14 +1,17 @@
use crate::error::{ConException, ProtocolError, Result, TransError};
use crate::frame::{ContentHeader, Frame, FrameType};
use crate::frame::{parse_content_header, Frame, FrameType};
use crate::{frame, methods, sasl};
use amqp_core::connection::{ChannelHandle, ChannelNum, ConnectionHandle, ConnectionId};
use amqp_core::connection::{
ChannelHandle, ChannelNum, ConnectionHandle, ConnectionId, ContentHeader, MethodReceiver,
MethodSender, QueuedMethod,
};
use amqp_core::message::{MessageId, RawMessage, RoutingInformation};
use amqp_core::methods::{
BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose, ConnectionCloseOk,
ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk, ConnectionTune,
ConnectionTuneOk, FieldValue, Method, Table,
};
use amqp_core::GlobalData;
use amqp_core::{amqp_todo, GlobalData};
use anyhow::Context;
use bytes::Bytes;
use smallvec::SmallVec;
@ -20,7 +23,7 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time;
use tokio::{select, time};
use tracing::{debug, error, info, trace, warn};
fn ensure_conn(condition: bool) -> Result<()> {
@ -56,6 +59,9 @@ pub struct Connection {
channels: HashMap<ChannelNum, Channel>,
handle: ConnectionHandle,
global_data: GlobalData,
method_queue_send: MethodSender,
method_queue_recv: MethodReceiver,
}
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
@ -64,7 +70,7 @@ enum ChannelStatus {
Default,
/// ClassId // todo: newtype it
NeedHeader(u16, Box<Method>),
NeedsBody(Box<Method>, Box<ContentHeader>, SmallVec<[Bytes; 1]>),
NeedsBody(Box<Method>, ContentHeader, SmallVec<[Bytes; 1]>),
}
impl ChannelStatus {
@ -79,6 +85,8 @@ impl Connection {
stream: TcpStream,
connection_handle: ConnectionHandle,
global_data: GlobalData,
method_queue_send: MethodSender,
method_queue_recv: MethodReceiver,
) -> Self {
Self {
id,
@ -90,6 +98,8 @@ impl Connection {
handle: connection_handle,
channels: HashMap::with_capacity(4),
global_data,
method_queue_send,
method_queue_recv: method_queue_recv,
}
}
@ -145,6 +155,17 @@ impl Connection {
self.main_loop().await
}
async fn send_method_content(
&mut self,
channel: ChannelNum,
method: Method,
_header: ContentHeader,
_body: SmallVec<[Bytes; 1]>,
) -> Result<()> {
self.send_method(channel, method).await?;
amqp_todo!()
}
async fn send_method(&mut self, channel: ChannelNum, method: Method) -> Result<()> {
trace!(%channel, ?method, "Sending method");
@ -256,41 +277,54 @@ impl Connection {
async fn main_loop(&mut self) -> Result<()> {
loop {
let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
let channel = frame.channel;
let result = self.handle_frame(frame).await;
match result {
Ok(()) => {}
Err(TransError::Protocol(ProtocolError::ChannelException(ex))) => {
warn!(%ex, "Channel exception occurred");
self.send_method(
channel,
Method::ChannelClose(ChannelClose {
reply_code: ex.reply_code(),
reply_text: ex.reply_text(),
class_id: 0, // todo: do this
method_id: 0,
}),
)
.await?;
drop(self.channels.remove(&channel));
select! {
frame = frame::read_frame(&mut self.stream, self.max_frame_size) => {
let frame = frame?;
self.handle_frame(frame).await?;
}
queued_method = self.method_queue_recv.recv() => {
match queued_method {
Some((channel, QueuedMethod::Normal(method))) => self.send_method(channel, method).await?,
Some((channel, QueuedMethod::WithContent(method, header, body))) => self.send_method_content(channel, method, header, body).await?,
None => {}
}
}
Err(other_err) => return Err(other_err),
}
}
}
async fn handle_frame(&mut self, frame: Frame) -> Result<()> {
let channel = frame.channel;
self.reset_timeout();
match frame.kind {
FrameType::Method => self.dispatch_method(frame).await?,
FrameType::Heartbeat => { /* Nothing here, just the `reset_timeout` above */ }
FrameType::Header => self.dispatch_header(frame)?,
FrameType::Body => self.dispatch_body(frame)?,
}
let result = match frame.kind {
FrameType::Method => self.dispatch_method(frame).await,
FrameType::Heartbeat => {
Ok(()) /* Nothing here, just the `reset_timeout` above */
}
FrameType::Header => self.dispatch_header(frame),
FrameType::Body => self.dispatch_body(frame),
};
Ok(())
match result {
Ok(()) => Ok(()),
Err(TransError::Protocol(ProtocolError::ChannelException(ex))) => {
warn!(%ex, "Channel exception occurred");
self.send_method(
channel,
Method::ChannelClose(ChannelClose {
reply_code: ex.reply_code(),
reply_text: ex.reply_text(),
class_id: 0, // todo: do this
method_id: 0,
}),
)
.await?;
drop(self.channels.remove(&channel));
Ok(())
}
Err(other_err) => Err(other_err),
}
}
async fn dispatch_method(&mut self, frame: Frame) -> Result<()> {
@ -354,7 +388,7 @@ impl Connection {
Err(ConException::UnexpectedFrame.into())
}
ChannelStatus::NeedHeader(class_id, method) => {
let header = ContentHeader::parse(&frame.payload)?;
let header = parse_content_header(&frame.payload)?;
ensure_conn(header.class_id == class_id)?;
channel.status = ChannelStatus::NeedsBody(method, header, SmallVec::new());
@ -391,7 +425,7 @@ impl Connection {
.cmp(&usize::try_from(header.body_size).unwrap())
{
Ordering::Equal => {
self.process_method_with_body(*method, *header, vec, frame.channel)
self.process_method_with_body(*method, header, vec, frame.channel)
}
Ordering::Greater => Err(ConException::Todo.into()),
Ordering::Less => Ok(()), // wait for next body
@ -420,7 +454,7 @@ impl Connection {
{
let message = RawMessage {
id: MessageId::random(),
properties: header.property_fields,
header,
routing: RoutingInformation {
exchange,
routing_key,
@ -449,9 +483,10 @@ impl Connection {
let id = rand::random();
let channel_handle = amqp_core::connection::Channel::new_handle(
id,
channel_num.num(),
channel_num,
self.handle.clone(),
self.global_data.clone(),
self.method_queue_send.clone(),
);
let channel = Channel {

View file

@ -1,6 +1,5 @@
use crate::error::{ConException, ProtocolError, Result};
use amqp_core::connection::ChannelNum;
use amqp_core::methods;
use amqp_core::connection::{ChannelNum, ContentHeader};
use anyhow::Context;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@ -33,18 +32,10 @@ pub enum FrameType {
Heartbeat = 8,
}
#[derive(Debug, Clone, PartialEq)]
pub struct ContentHeader {
pub class_id: u16,
pub weight: u16,
pub body_size: u64,
pub property_fields: methods::Table,
}
mod content_header_parse {
use crate::error::TransError;
use crate::frame::ContentHeader;
use crate::methods::parse_helper::{octet, shortstr, table, timestamp};
use amqp_core::connection::ContentHeader;
use amqp_core::methods;
use amqp_core::methods::FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp};
use nom::number::complete::{u16, u64};
@ -95,7 +86,7 @@ mod content_header_parse {
Ok((input, map))
}
pub fn header(input: &[u8]) -> IResult<'_, Box<ContentHeader>> {
pub fn header(input: &[u8]) -> IResult<'_, ContentHeader> {
let (input, class_id) = u16(Big)(input)?;
let (input, weight) = u16(Big)(input)?;
let (input, body_size) = u64(Big)(input)?;
@ -108,31 +99,26 @@ mod content_header_parse {
Ok((
input,
Box::new(ContentHeader {
ContentHeader {
class_id,
weight,
body_size,
property_fields,
}),
},
))
}
}
impl ContentHeader {
pub fn parse(input: &[u8]) -> Result<Box<Self>> {
match content_header_parse::header(input) {
Ok(([], header)) => Ok(header),
Ok((_, _)) => {
Err(
ConException::SyntaxError(vec!["could not consume all input".to_string()])
.into(),
)
}
Err(nom::Err::Incomplete(_)) => {
Err(ConException::SyntaxError(vec!["there was not enough data".to_string()]).into())
}
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
pub fn parse_content_header(input: &[u8]) -> Result<ContentHeader> {
match content_header_parse::header(input) {
Ok(([], header)) => Ok(header),
Ok((_, _)) => {
Err(ConException::SyntaxError(vec!["could not consume all input".to_string()]).into())
}
Err(nom::Err::Incomplete(_)) => {
Err(ConException::SyntaxError(vec!["there was not enough data".to_string()]).into())
}
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
}
}

View file

@ -29,15 +29,28 @@ pub async fn do_thing_i_guess(global_data: GlobalData) -> Result<()> {
info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection");
let span = info_span!("client-connection", %id);
let connection_handle =
amqp_core::connection::Connection::new_handle(id, peer_addr, global_data.clone());
let (method_send, method_recv) = tokio::sync::mpsc::channel(10);
let connection_handle = amqp_core::connection::Connection::new_handle(
id,
peer_addr,
global_data.clone(),
method_send.clone(),
);
let mut global_data_guard = global_data.lock();
global_data_guard
.connections
.insert(id, connection_handle.clone());
let connection = Connection::new(id, stream, connection_handle, global_data.clone());
let connection = Connection::new(
id,
stream,
connection_handle,
global_data.clone(),
method_send,
method_recv,
);
tokio::spawn(connection.start_connection_processing().instrument(span));
}