more things

This commit is contained in:
nora 2022-02-22 23:00:43 +01:00
parent 9a819bc3f4
commit b50634841d
6 changed files with 228 additions and 93 deletions

View file

@ -29,5 +29,5 @@ jobs:
run: cargo fmt --verbose --all -- --check run: cargo fmt --verbose --all -- --check
- name: Run tests - name: Run tests
run: cargo test --verbose --all run: cargo test --verbose --all
- name: Run client integration tests # - name: Run client integration tests
run: cargo xtask test-js # run: cargo xtask test-js

View file

@ -17,7 +17,7 @@ use amqp_core::methods::{FieldValue, Method, Table};
use amqp_core::GlobalData; use amqp_core::GlobalData;
use crate::error::{ConException, ProtocolError, Result}; use crate::error::{ConException, ProtocolError, Result};
use crate::frame::{ContentHeader, Frame, FrameType}; use crate::frame::{ChannelId, ContentHeader, Frame, FrameType};
use crate::{frame, methods, sasl}; use crate::{frame, methods, sasl};
fn ensure_conn(condition: bool) -> Result<()> { fn ensure_conn(condition: bool) -> Result<()> {
@ -33,10 +33,11 @@ const CHANNEL_MAX: u16 = 0;
const FRAME_SIZE_MAX: u32 = 0; const FRAME_SIZE_MAX: u32 = 0;
const HEARTBEAT_DELAY: u16 = 0; const HEARTBEAT_DELAY: u16 = 0;
#[allow(dead_code)]
pub struct Channel { pub struct Channel {
num: u16, /// A handle to the global channel representation. Used to remove the channel when it's dropped
channel_handle: amqp_core::ChannelHandle, handle: amqp_core::ChannelHandle,
/// The current status of the channel, whether it has sent a method that expects a body
status: ChannelStatus,
} }
pub struct Connection { pub struct Connection {
@ -45,18 +46,26 @@ pub struct Connection {
max_frame_size: usize, max_frame_size: usize,
heartbeat_delay: u16, heartbeat_delay: u16,
channel_max: u16, channel_max: u16,
/// When the next heartbeat expires
next_timeout: Pin<Box<time::Sleep>>, next_timeout: Pin<Box<time::Sleep>>,
channels: HashMap<u16, Channel>, channels: HashMap<ChannelId, Channel>,
connection_handle: amqp_core::ConnectionHandle, handle: amqp_core::ConnectionHandle,
global_data: GlobalData, global_data: GlobalData,
} }
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
enum WaitForBodyStatus { enum ChannelStatus {
Method(Method), Default,
Header(Method, ContentHeader, SmallVec<[Bytes; 1]>), /// ClassId // todo: newtype it
None, NeedHeader(u16, Box<Method>),
NeedsBody(Box<Method>, Box<ContentHeader>, SmallVec<[Bytes; 1]>),
}
impl ChannelStatus {
fn take(&mut self) -> Self {
std::mem::replace(self, Self::Default)
}
} }
impl Connection { impl Connection {
@ -73,8 +82,8 @@ impl Connection {
heartbeat_delay: HEARTBEAT_DELAY, heartbeat_delay: HEARTBEAT_DELAY,
channel_max: CHANNEL_MAX, channel_max: CHANNEL_MAX,
next_timeout: Box::pin(time::sleep(DEFAULT_TIMEOUT)), next_timeout: Box::pin(time::sleep(DEFAULT_TIMEOUT)),
connection_handle, handle: connection_handle,
channels: HashMap::new(), channels: HashMap::with_capacity(4),
global_data, global_data,
} }
} }
@ -85,7 +94,7 @@ impl Connection {
Err(err) => error!(%err, "Error during processing of connection"), Err(err) => error!(%err, "Error during processing of connection"),
} }
let connection_handle = self.connection_handle.lock(); let connection_handle = self.handle.lock();
connection_handle.close(); connection_handle.close();
} }
@ -100,7 +109,7 @@ impl Connection {
self.main_loop().await self.main_loop().await
} }
async fn send_method(&mut self, channel: u16, method: Method) -> Result<()> { async fn send_method(&mut self, channel: ChannelId, method: Method) -> Result<()> {
let mut payload = Vec::with_capacity(64); let mut payload = Vec::with_capacity(64);
methods::write::write_method(method, &mut payload)?; methods::write::write_method(method, &mut payload)?;
frame::write_frame( frame::write_frame(
@ -137,7 +146,7 @@ impl Connection {
}; };
debug!(?start_method, "Sending Start method"); debug!(?start_method, "Sending Start method");
self.send_method(0, start_method).await?; self.send_method(ChannelId::zero(), start_method).await?;
let start_ok = self.recv_method().await?; let start_ok = self.recv_method().await?;
debug!(?start_ok, "Received Start-Ok"); debug!(?start_ok, "Received Start-Ok");
@ -168,7 +177,7 @@ impl Connection {
}; };
debug!("Sending Tune method"); debug!("Sending Tune method");
self.send_method(0, tune_method).await?; self.send_method(ChannelId::zero(), tune_method).await?;
let tune_ok = self.recv_method().await?; let tune_ok = self.recv_method().await?;
debug!(?tune_ok, "Received Tune-Ok method"); debug!(?tune_ok, "Received Tune-Ok method");
@ -197,7 +206,7 @@ impl Connection {
} }
self.send_method( self.send_method(
0, ChannelId::zero(),
Method::ConnectionOpenOk { Method::ConnectionOpenOk {
reserved_1: "".to_string(), reserved_1: "".to_string(),
}, },
@ -208,54 +217,29 @@ impl Connection {
} }
async fn main_loop(&mut self) -> Result<()> { async fn main_loop(&mut self) -> Result<()> {
// todo: find out how header/body frames can interleave between channels
let mut wait_for_body = WaitForBodyStatus::None;
loop { loop {
debug!("Waiting for next frame"); debug!("Waiting for next frame");
let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?; let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
self.reset_timeout(); self.reset_timeout();
match frame.kind { match frame.kind {
FrameType::Method => wait_for_body = self.dispatch_method(frame).await?, FrameType::Method => self.dispatch_method(frame).await?,
FrameType::Heartbeat => {} FrameType::Heartbeat => { /* Nothing here, just the `reset_timeout` above */ }
FrameType::Header => match wait_for_body { FrameType::Header => self.dispatch_header(frame)?,
WaitForBodyStatus::None => warn!(channel = %frame.channel, "unexpected header"), FrameType::Body => self.dispatch_body(frame)?,
WaitForBodyStatus::Method(method) => {
wait_for_body =
WaitForBodyStatus::Header(method, ContentHeader::new(), SmallVec::new())
}
WaitForBodyStatus::Header(_, _, _) => {
warn!(channel = %frame.channel, "already got header")
}
},
FrameType::Body => match &mut wait_for_body {
WaitForBodyStatus::None => warn!(channel = %frame.channel, "unexpected body"),
WaitForBodyStatus::Method(_) => {
warn!(channel = %frame.channel, "unexpected body")
}
WaitForBodyStatus::Header(_, header, vec) => {
vec.push(frame.payload);
match vec
.iter()
.map(Bytes::len)
.sum::<usize>()
.cmp(&usize::try_from(header.body_size).unwrap())
{
Ordering::Equal => todo!("process body"),
Ordering::Greater => todo!("too much data!"),
Ordering::Less => {} // wait for next body
}
}
},
} }
} }
} }
async fn dispatch_method(&mut self, frame: Frame) -> Result<WaitForBodyStatus> { async fn dispatch_method(&mut self, frame: Frame) -> Result<()> {
let method = methods::parse_method(&frame.payload)?; let method = methods::parse_method(&frame.payload)?;
debug!(?method, "Received method"); debug!(?method, "Received method");
// Sending a method implicitly cancels the content frames that might be ongoing
self.channels
.get_mut(&frame.channel)
.map(|channel| channel.status.take());
match method { match method {
Method::ConnectionClose { Method::ConnectionClose {
reply_code, reply_code,
@ -264,18 +248,27 @@ impl Connection {
method_id, method_id,
} => { } => {
info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection"); info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection");
self.send_method(0, Method::ConnectionCloseOk {}).await?; self.send_method(ChannelId::zero(), Method::ConnectionCloseOk {})
.await?;
return Err(ProtocolError::GracefulClose.into()); return Err(ProtocolError::GracefulClose.into());
} }
Method::ChannelOpen { .. } => self.channel_open(frame.channel).await?, Method::ChannelOpen { .. } => self.channel_open(frame.channel).await?,
Method::ChannelClose { .. } => self.channel_close(frame.channel, method).await?, Method::ChannelClose { .. } => self.channel_close(frame.channel, method).await?,
Method::BasicPublish { .. } => return Ok(WaitForBodyStatus::Method(method)), Method::BasicPublish { .. } => {
const BASIC_CLASS_ID: u16 = 60;
match self.channels.get_mut(&frame.channel) {
Some(channel) => {
channel.status = ChannelStatus::NeedHeader(BASIC_CLASS_ID, Box::new(method))
}
None => return Err(ConException::Todo.into_trans()),
}
}
_ => { _ => {
let channel_handle = self let channel_handle = self
.channels .channels
.get(&frame.channel) .get(&frame.channel)
.ok_or_else(|| ConException::Todo.into_trans())? .ok_or_else(|| ConException::Todo.into_trans())?
.channel_handle .handle
.clone(); .clone();
tokio::spawn(amqp_messaging::methods::handle_method( tokio::spawn(amqp_messaging::methods::handle_method(
@ -285,27 +278,79 @@ impl Connection {
// we don't handle this here, forward it to *somewhere* // we don't handle this here, forward it to *somewhere*
} }
} }
Ok(())
Ok(WaitForBodyStatus::None)
} }
async fn channel_open(&mut self, num: u16) -> Result<()> { fn dispatch_header(&mut self, frame: Frame) -> Result<()> {
self.channels
.get_mut(&frame.channel)
.ok_or_else(|| ConException::Todo.into_trans())
.and_then(|channel| match channel.status.take() {
ChannelStatus::Default => {
warn!(channel = %frame.channel, "unexpected header");
Err(ConException::UnexpectedFrame.into_trans())
}
ChannelStatus::NeedHeader(class_id, method) => {
let header = ContentHeader::parse(&frame.payload)?;
ensure_conn(header.class_id == class_id)?;
channel.status = ChannelStatus::NeedsBody(method, header, SmallVec::new());
Ok(())
}
ChannelStatus::NeedsBody(_, _, _) => {
warn!(channel = %frame.channel, "already got header");
Err(ConException::UnexpectedFrame.into_trans())
}
})
}
fn dispatch_body(&mut self, frame: Frame) -> Result<()> {
self.channels
.get_mut(&frame.channel)
.ok_or_else(|| ConException::Todo.into_trans())
.and_then(|channel| match channel.status.take() {
ChannelStatus::Default => {
warn!(channel = %frame.channel, "unexpected body");
Err(ConException::UnexpectedFrame.into_trans())
}
ChannelStatus::NeedHeader(_, _) => {
warn!(channel = %frame.channel, "unexpected body");
Err(ConException::UnexpectedFrame.into_trans())
}
ChannelStatus::NeedsBody(_, header, mut vec) => {
vec.push(frame.payload);
match vec
.iter()
.map(Bytes::len)
.sum::<usize>()
.cmp(&usize::try_from(header.body_size).unwrap())
{
Ordering::Equal => todo!("process body"),
Ordering::Greater => todo!("too much data!"),
Ordering::Less => {} // wait for next body
}
Ok(())
}
})
}
async fn channel_open(&mut self, channel_id: ChannelId) -> Result<()> {
let id = Uuid::from_bytes(rand::random()); let id = Uuid::from_bytes(rand::random());
let channel_handle = amqp_core::Channel::new_handle( let channel_handle = amqp_core::Channel::new_handle(
id, id,
num, channel_id.num(),
self.connection_handle.clone(), self.handle.clone(),
self.global_data.clone(), self.global_data.clone(),
); );
let channel = Channel { let channel = Channel {
num, handle: channel_handle.clone(),
channel_handle: channel_handle.clone(), status: ChannelStatus::Default,
}; };
let prev = self.channels.insert(num, channel); let prev = self.channels.insert(channel_id, channel);
if let Some(prev) = prev { if let Some(prev) = prev {
self.channels.insert(num, prev); // restore previous state self.channels.insert(channel_id, prev); // restore previous state
return Err(ConException::ChannelError.into_trans()); return Err(ConException::ChannelError.into_trans());
} }
@ -318,13 +363,13 @@ impl Connection {
.unwrap() .unwrap()
.lock() .lock()
.channels .channels
.insert(num, channel_handle); .insert(channel_id.num(), channel_handle);
} }
info!(%num, "Opened new channel"); info!(%channel_id, "Opened new channel");
self.send_method( self.send_method(
num, channel_id,
Method::ChannelOpenOk { Method::ChannelOpenOk {
reserved_1: Vec::new(), reserved_1: Vec::new(),
}, },
@ -334,7 +379,7 @@ impl Connection {
Ok(()) Ok(())
} }
async fn channel_close(&mut self, num: u16, method: Method) -> Result<()> { async fn channel_close(&mut self, channel_id: ChannelId, method: Method) -> Result<()> {
if let Method::ChannelClose { if let Method::ChannelClose {
reply_code: code, reply_code: code,
reply_text: reason, reply_text: reason,
@ -343,9 +388,9 @@ impl Connection {
{ {
info!(%code, %reason, "Closing channel"); info!(%code, %reason, "Closing channel");
if let Some(channel) = self.channels.remove(&num) { if let Some(channel) = self.channels.remove(&channel_id) {
drop(channel); drop(channel);
self.send_method(num, Method::ChannelCloseOk).await?; self.send_method(channel_id, Method::ChannelCloseOk).await?;
} else { } else {
return Err(ConException::Todo.into_trans()); return Err(ConException::Todo.into_trans());
} }
@ -357,7 +402,7 @@ impl Connection {
fn reset_timeout(&mut self) { fn reset_timeout(&mut self) {
if self.heartbeat_delay != 0 { if self.heartbeat_delay != 0 {
let next = Duration::from_secs(u64::from(self.heartbeat_delay)); let next = Duration::from_secs(u64::from(self.heartbeat_delay / 2));
self.next_timeout = Box::pin(time::sleep(next)); self.next_timeout = Box::pin(time::sleep(next));
} }
} }
@ -396,13 +441,13 @@ impl Connection {
impl Drop for Connection { impl Drop for Connection {
fn drop(&mut self) { fn drop(&mut self) {
self.connection_handle.lock().close(); self.handle.lock().close();
} }
} }
impl Drop for Channel { impl Drop for Channel {
fn drop(&mut self) { fn drop(&mut self) {
self.channel_handle.lock().close(); self.handle.lock().close();
} }
} }

View file

@ -45,6 +45,8 @@ pub enum ConException {
SyntaxError(Vec<String>), SyntaxError(Vec<String>),
#[error("504 Channel error")] #[error("504 Channel error")]
ChannelError, ChannelError,
#[error("505 Unexpected Frame")]
UnexpectedFrame,
#[error("xxx Not decided yet")] #[error("xxx Not decided yet")]
Todo, Todo,
} }

View file

@ -1,13 +1,36 @@
use crate::error::{ConException, ProtocolError, Result}; use crate::error::{ConException, ProtocolError, Result};
use amqp_core::methods::FieldValue; use amqp_core::methods;
use anyhow::Context; use anyhow::Context;
use bytes::Bytes; use bytes::Bytes;
use smallvec::SmallVec; use std::fmt::{Display, Formatter};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::trace; use tracing::trace;
const REQUIRED_FRAME_END: u8 = 0xCE; const REQUIRED_FRAME_END: u8 = 0xCE;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ChannelId(u16);
impl ChannelId {
pub fn num(self) -> u16 {
self.0
}
pub fn is_zero(self) -> bool {
self.0 == 0
}
pub fn zero() -> Self {
Self(0)
}
}
impl Display for ChannelId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
mod frame_type { mod frame_type {
pub const METHOD: u8 = 1; pub const METHOD: u8 = 1;
pub const HEADER: u8 = 2; pub const HEADER: u8 = 2;
@ -19,7 +42,7 @@ mod frame_type {
pub struct Frame { pub struct Frame {
/// The type of the frame including its parsed metadata. /// The type of the frame including its parsed metadata.
pub kind: FrameType, pub kind: FrameType,
pub channel: u16, pub channel: ChannelId,
/// Includes the whole payload, also including the metadata from each type. /// Includes the whole payload, also including the metadata from each type.
pub payload: Bytes, pub payload: Bytes,
} }
@ -33,18 +56,84 @@ pub enum FrameType {
Heartbeat = 8, Heartbeat = 8,
} }
#[derive(Debug, Clone, PartialEq)]
pub struct BasicProperties {
content_type: Option<methods::Shortstr>,
content_encoding: Option<methods::Shortstr>,
headers: Option<methods::Table>,
delivery_mode: Option<methods::Octet>,
priority: Option<methods::Octet>,
correlation_id: Option<methods::Shortstr>,
reply_to: Option<methods::Shortstr>,
expiration: Option<methods::Shortstr>,
message_id: Option<methods::Shortstr>,
timestamp: Option<methods::Timestamp>,
r#type: Option<methods::Shortstr>,
user_id: Option<methods::Shortstr>,
app_id: Option<methods::Shortstr>,
reserved: Option<methods::Shortstr>,
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct ContentHeader { pub struct ContentHeader {
pub class_id: u16, pub class_id: u16,
pub weight: u16, pub weight: u16,
pub body_size: u64, pub body_size: u64,
pub property_flags: SmallVec<[u16; 1]>, pub property_fields: BasicProperties,
pub property_fields: Vec<FieldValue>, }
mod content_header_parse {
use crate::error::TransError;
use crate::frame::{BasicProperties, ContentHeader};
use nom::number::complete::{u16, u64};
use nom::number::Endianness::Big;
type IResult<'a, T> = nom::IResult<&'a [u8], T, TransError>;
pub fn basic_properties(_property_flags: u16, _input: &[u8]) -> IResult<'_, BasicProperties> {
todo!()
}
pub fn header(input: &[u8]) -> IResult<'_, Box<ContentHeader>> {
let (input, class_id) = u16(Big)(input)?;
let (input, weight) = u16(Big)(input)?;
let (input, body_size) = u64(Big)(input)?;
// I do not quite understand this here. Apparently, there can be more than 15 flags?
// But the Basic class only specifies 15, so idk. Don't care about this for now
let (input, property_flags) = u16(Big)(input)?;
let (input, property_fields) = basic_properties(property_flags, input)?;
Ok((
input,
Box::new(ContentHeader {
class_id,
weight,
body_size,
property_fields,
}),
))
}
} }
impl ContentHeader { impl ContentHeader {
pub fn new() -> Self { pub fn parse(input: &[u8]) -> Result<Box<Self>> {
todo!() match content_header_parse::header(input) {
Ok(([], header)) => Ok(header),
Ok((_, _)) => {
Err(
ConException::SyntaxError(vec!["could not consume all input".to_string()])
.into_trans(),
)
}
Err(nom::Err::Incomplete(_)) => {
Err(
ConException::SyntaxError(vec!["there was not enough data".to_string()])
.into_trans(),
)
}
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
}
} }
} }
@ -55,7 +144,7 @@ where
trace!(?frame, "Sending frame"); trace!(?frame, "Sending frame");
w.write_u8(frame.kind as u8).await?; w.write_u8(frame.kind as u8).await?;
w.write_u16(frame.channel).await?; w.write_u16(frame.channel.num()).await?;
w.write_u32(u32::try_from(frame.payload.len()).context("frame size too big")?) w.write_u32(u32::try_from(frame.payload.len()).context("frame size too big")?)
.await?; .await?;
w.write_all(&frame.payload).await?; w.write_all(&frame.payload).await?;
@ -70,6 +159,7 @@ where
{ {
let kind = r.read_u8().await.context("read type")?; let kind = r.read_u8().await.context("read type")?;
let channel = r.read_u16().await.context("read channel")?; let channel = r.read_u16().await.context("read channel")?;
let channel = ChannelId(channel);
let size = r.read_u32().await.context("read size")?; let size = r.read_u32().await.context("read size")?;
let mut payload = vec![0; size.try_into().unwrap()]; let mut payload = vec![0; size.try_into().unwrap()];
@ -98,16 +188,16 @@ where
Ok(frame) Ok(frame)
} }
fn parse_frame_type(kind: u8, channel: u16) -> Result<FrameType> { fn parse_frame_type(kind: u8, channel: ChannelId) -> Result<FrameType> {
match kind { match kind {
frame_type::METHOD => Ok(FrameType::Method), frame_type::METHOD => Ok(FrameType::Method),
frame_type::HEADER => Ok(FrameType::Header), frame_type::HEADER => Ok(FrameType::Header),
frame_type::BODY => Ok(FrameType::Body), frame_type::BODY => Ok(FrameType::Body),
frame_type::HEARTBEAT => { frame_type::HEARTBEAT => {
if channel != 0 { if channel.is_zero() {
Err(ProtocolError::ConException(ConException::FrameError).into())
} else {
Ok(FrameType::Heartbeat) Ok(FrameType::Heartbeat)
} else {
Err(ProtocolError::ConException(ConException::FrameError).into())
} }
} }
_ => Err(ConException::FrameError.into_trans()), _ => Err(ConException::FrameError.into_trans()),
@ -116,7 +206,7 @@ fn parse_frame_type(kind: u8, channel: u16) -> Result<FrameType> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::frame::{Frame, FrameType}; use crate::frame::{ChannelId, Frame, FrameType};
use bytes::Bytes; use bytes::Bytes;
#[tokio::test] #[tokio::test]
@ -145,7 +235,7 @@ mod tests {
frame, frame,
Frame { Frame {
kind: FrameType::Method, kind: FrameType::Method,
channel: 0, channel: ChannelId(0),
payload: Bytes::from_static(&[1, 2, 3]), payload: Bytes::from_static(&[1, 2, 3]),
} }
); );

View file

@ -8,6 +8,8 @@ mod sasl;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
// TODO: handle big types
use crate::connection::Connection; use crate::connection::Connection;
use amqp_core::GlobalData; use amqp_core::GlobalData;
use anyhow::Result; use anyhow::Result;

View file

@ -1,4 +0,0 @@
# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY.
# yarn lockfile v1