not working :(

This commit is contained in:
nora 2022-03-05 16:26:12 +01:00
parent 3bcce76885
commit 08ba799d23
4 changed files with 140 additions and 74 deletions

View file

@ -1,10 +1,11 @@
use crate::error::{ConException, ProtocolError, Result};
use amqp_core::{
amqp_todo,
connection::{ChannelNum, ContentHeader},
};
use amqp_core::connection::{ChannelNum, ContentHeader};
use anyhow::Context;
use bytes::Bytes;
use std::{
fmt::{Debug, Formatter},
num::NonZeroUsize,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::trace;
@ -133,71 +134,82 @@ pub fn parse_content_header(input: &[u8]) -> Result<ContentHeader> {
mod content_header_write {
use crate::{
methods::write_helper::{octet, shortstr, table, timestamp},
Result,
error::Result,
methods::write_helper::{longlong, octet, short, shortstr, table, timestamp},
};
use amqp_core::{
connection::ContentHeader,
methods::FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp},
methods::{
FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp},
Table,
},
};
pub fn write_content_header(buf: &mut Vec<u8>, header: ContentHeader) -> Result<()> {
pub fn write_content_header(buf: &mut Vec<u8>, header: &ContentHeader) -> Result<()> {
short(&header.class_id, buf)?;
short(&header.weight, buf)?;
longlong(&header.body_size, buf)?;
write_content_header_props(buf, &header.property_fields)
}
pub fn write_content_header_props(buf: &mut Vec<u8>, header: &Table) -> Result<()> {
let mut flags = 0_u16;
buf.extend_from_slice(&flags.to_be_bytes()); // placeholder
if let Some(ShortString(value)) = header.property_fields.get("content-type") {
if let Some(ShortString(value)) = header.get("content-type") {
flags |= 1 << 15;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.property_fields.get("content-encoding") {
if let Some(ShortString(value)) = header.get("content-encoding") {
flags |= 1 << 14;
shortstr(value, buf)?;
}
if let Some(FieldTable(value)) = header.property_fields.get("headers") {
if let Some(FieldTable(value)) = header.get("headers") {
flags |= 1 << 13;
table(value, buf)?;
}
if let Some(ShortShortUInt(value)) = header.property_fields.get("delivery-mode") {
if let Some(ShortShortUInt(value)) = header.get("delivery-mode") {
flags |= 1 << 12;
octet(value, buf)?;
}
if let Some(ShortShortUInt(value)) = header.property_fields.get("priority") {
if let Some(ShortShortUInt(value)) = header.get("priority") {
flags |= 1 << 11;
octet(value, buf)?;
}
if let Some(ShortString(value)) = header.property_fields.get("correlation-id") {
if let Some(ShortString(value)) = header.get("correlation-id") {
flags |= 1 << 10;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.property_fields.get("reply-to") {
if let Some(ShortString(value)) = header.get("reply-to") {
flags |= 1 << 9;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.property_fields.get("expiration") {
if let Some(ShortString(value)) = header.get("expiration") {
flags |= 1 << 8;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.property_fields.get("message-id") {
if let Some(ShortString(value)) = header.get("message-id") {
flags |= 1 << 7;
shortstr(value, buf)?;
}
if let Some(Timestamp(value)) = header.property_fields.get("timestamp") {
if let Some(Timestamp(value)) = header.get("timestamp") {
flags |= 1 << 6;
timestamp(value, buf)?;
}
if let Some(ShortString(value)) = header.property_fields.get("type") {
if let Some(ShortString(value)) = header.get("type") {
flags |= 1 << 5;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.property_fields.get("user-id") {
if let Some(ShortString(value)) = header.get("user-id") {
flags |= 1 << 4;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.property_fields.get("app-id") {
if let Some(ShortString(value)) = header.get("app-id") {
flags |= 1 << 3;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.property_fields.get("reserved") {
if let Some(ShortString(value)) = header.get("reserved") {
flags |= 1 << 2;
shortstr(value, buf)?;
}
@ -210,27 +222,51 @@ mod content_header_write {
}
}
pub fn write_content_header(buf: &mut Vec<u8>, content_header: ContentHeader) -> Result<()> {
write_content_header(buf, content_header)
pub fn write_content_header(buf: &mut Vec<u8>, content_header: &ContentHeader) -> Result<()> {
content_header_write::write_content_header(buf, content_header)
}
pub async fn write_frame<W>(frame: &Frame, mut w: W) -> Result<()>
#[derive(Clone, Copy)]
pub struct MaxFrameSize(Option<NonZeroUsize>);
impl MaxFrameSize {
pub const fn new(size: usize) -> Self {
Self(NonZeroUsize::new(size))
}
pub fn as_usize(&self) -> usize {
self.0.map(NonZeroUsize::get).unwrap_or(usize::MAX)
}
}
impl Debug for MaxFrameSize {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
pub async fn write_frame<W>(
mut w: W,
kind: FrameType,
channel: ChannelNum,
payload: &[u8],
) -> Result<()>
where
W: AsyncWriteExt + Unpin + Send,
{
trace!(?frame, "Sending frame");
trace!(?kind, ?channel, ?payload, "Sending frame");
w.write_u8(frame.kind as u8).await?;
w.write_u16(frame.channel.num()).await?;
w.write_u32(u32::try_from(frame.payload.len()).context("frame size too big")?)
w.write_u8(kind as u8).await?;
w.write_u16(channel.num()).await?;
w.write_u32(u32::try_from(payload.len()).context("frame size too big")?)
.await?;
w.write_all(&frame.payload).await?;
w.write_all(payload).await?;
w.write_u8(REQUIRED_FRAME_END).await?;
Ok(())
}
pub async fn read_frame<R>(r: &mut R, max_frame_size: usize) -> Result<Frame>
pub async fn read_frame<R>(r: &mut R, max_frame_size: MaxFrameSize) -> Result<Frame>
where
R: AsyncReadExt + Unpin + Send,
{
@ -248,7 +284,7 @@ where
return Err(ProtocolError::Fatal.into());
}
if max_frame_size != 0 && payload.len() > max_frame_size {
if payload.len() > max_frame_size.as_usize() {
return Err(ConException::FrameError.into());
}
@ -283,7 +319,7 @@ fn parse_frame_type(kind: u8, channel: ChannelNum) -> Result<FrameType> {
#[cfg(test)]
mod tests {
use crate::frame::{ChannelNum, Frame, FrameType};
use crate::frame::{ChannelNum, Frame, FrameType, MaxFrameSize};
use bytes::Bytes;
#[tokio::test]
@ -307,7 +343,9 @@ mod tests {
super::REQUIRED_FRAME_END,
];
let frame = super::read_frame(&mut bytes, 10000).await.unwrap();
let frame = super::read_frame(&mut bytes, MaxFrameSize::new(10000))
.await
.unwrap();
assert_eq!(
frame,
Frame {