rename lol

This commit is contained in:
nora 2022-03-19 14:27:30 +01:00
parent c68cd04af7
commit 543e39f129
70 changed files with 283 additions and 266 deletions

View file

@ -0,0 +1,681 @@
use std::{
cmp::Ordering, collections::HashMap, net::SocketAddr, pin::Pin, sync::Arc, time::Duration,
};
use anyhow::{anyhow, Context};
use bytes::Bytes;
use haesli_core::{
connection::{
Channel, ChannelInner, ChannelNum, ConEventReceiver, ConEventSender, Connection,
ConnectionEvent, ConnectionId, ContentHeader,
},
message::{MessageId, MessageInner, RoutingInformation},
methods::{
BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose,
ConnectionCloseOk, ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk,
ConnectionTune, ConnectionTuneOk, FieldValue, Longstr, Method, ReplyCode, ReplyText, Table,
},
GlobalData,
};
use smallvec::SmallVec;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
select, time,
};
use tracing::{debug, error, info, trace, warn};
use crate::{
error::{ConException, ProtocolError, Result, TransError},
frame::{self, parse_content_header, Frame, FrameType, MaxFrameSize},
methods, sasl,
};
fn ensure_conn(condition: bool) -> Result<()> {
if condition {
Ok(())
} else {
Err(ConException::Todo.into())
}
}
const FRAME_SIZE_MIN_MAX: MaxFrameSize = MaxFrameSize::new(4096);
const CHANNEL_MAX: u16 = 0;
const FRAME_SIZE_MAX: u32 = 0;
const HEARTBEAT_DELAY: u16 = 0;
const BASIC_CLASS_ID: u16 = 60;
pub struct TransportChannel {
/// A handle to the global channel representation. Used to remove the channel when it's dropped
global_chan: Channel,
/// The current status of the channel, whether it has sent a method that expects a body
status: ChannelStatus,
}
pub struct TransportConnection {
id: ConnectionId,
stream: TcpStream,
max_frame_size: MaxFrameSize,
heartbeat_delay: u16,
channel_max: u16,
/// When the next heartbeat expires
next_timeout: Pin<Box<time::Sleep>>,
channels: HashMap<ChannelNum, TransportChannel>,
global_con: Connection,
global_data: GlobalData,
/// Only here to forward to other futures so they can send events
event_sender: ConEventSender,
/// To receive events from other futures
event_receiver: ConEventReceiver,
}
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
enum ChannelStatus {
Default,
NeedHeader(u16, Box<Method>),
NeedsBody(Box<Method>, ContentHeader, SmallVec<[Bytes; 1]>),
}
impl ChannelStatus {
fn take(&mut self) -> Self {
std::mem::replace(self, Self::Default)
}
}
impl TransportConnection {
pub fn new(
id: ConnectionId,
stream: TcpStream,
global_con: Connection,
global_data: GlobalData,
method_queue_send: ConEventSender,
method_queue_recv: ConEventReceiver,
) -> Self {
Self {
id,
stream,
max_frame_size: FRAME_SIZE_MIN_MAX,
heartbeat_delay: HEARTBEAT_DELAY,
channel_max: CHANNEL_MAX,
next_timeout: Box::pin(time::sleep(DEFAULT_TIMEOUT)),
global_con,
channels: HashMap::with_capacity(4),
global_data,
event_sender: method_queue_send,
event_receiver: method_queue_recv,
}
}
pub async fn start_connection_processing(mut self) {
let process_result = self.process_connection().await;
match process_result {
Ok(()) => {}
Err(TransError::Protocol(ProtocolError::GracefullyClosed)) => {
/* do nothing, remove below */
}
Err(TransError::Protocol(ProtocolError::ConException(ex))) => {
warn!(%ex, "Connection exception occurred. This indicates a faulty client.");
let close_result = self.close(ex.reply_code(), ex.reply_text()).await;
match close_result {
Ok(()) => {}
Err(err) => {
error!(%ex, %err, "Failed to close connection after ConnectionException");
}
}
}
Err(err) => error!(%err, "Error during processing of connection"),
}
// global connection is closed on drop
}
pub async fn process_connection(&mut self) -> Result<()> {
self.negotiate_version().await?;
self.start().await?;
self.tune().await?;
self.open().await?;
info!("Connection is ready for usage!");
self.main_loop().await
}
async fn send_method_content(
&mut self,
channel: ChannelNum,
method: &Method,
header: ContentHeader,
body: &SmallVec<[Bytes; 1]>,
) -> Result<()> {
self.send_method(channel, method).await?;
let mut header_buf = Vec::new();
frame::write_content_header(&mut header_buf, &header)?;
frame::write_frame(&mut self.stream, FrameType::Header, channel, &header_buf).await?;
self.send_bodies(channel, body).await
}
async fn send_bodies(
&mut self,
channel: ChannelNum,
body: &SmallVec<[Bytes; 1]>,
) -> Result<()> {
// this is inefficient if it's a huge message sent by a client with big frames to one with
// small frames
// we assume that this won't happen that that the first branch will be taken in most cases,
// elimination the overhead. What we win from keeping each frame as it is that we don't have
// to allocate again for each message
let max_size = self.max_frame_size.as_usize();
for payload in body {
if max_size > payload.len() {
trace!("Sending single method body frame");
// single frame
frame::write_frame(&mut self.stream, FrameType::Body, channel, payload).await?;
} else {
trace!(max = ?self.max_frame_size, "Chunking up method body frames");
// chunk it up into multiple sub-frames
let mut start = 0;
let mut end = max_size;
while end < payload.len() {
let sub_payload = &payload[start..end];
frame::write_frame(&mut self.stream, FrameType::Body, channel, sub_payload)
.await?;
start = end;
end = (end + max_size).max(payload.len());
}
}
}
Ok(())
}
#[tracing::instrument(skip(self), level = "trace")]
async fn send_method(&mut self, channel: ChannelNum, method: &Method) -> Result<()> {
let mut payload = Vec::with_capacity(64);
methods::write::write_method(method, &mut payload)?;
frame::write_frame(&mut self.stream, FrameType::Method, channel, &payload).await
}
async fn recv_method(&mut self) -> Result<Method> {
let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
ensure_conn(start_ok_frame.kind == FrameType::Method)?;
let method = methods::parse_method(&start_ok_frame.payload)?;
Ok(method)
}
async fn start(&mut self) -> Result<()> {
let start_method = Method::ConnectionStart(ConnectionStart {
version_major: 0,
version_minor: 9,
server_properties: server_properties(
self.stream
.local_addr()
.context("failed to get local_addr")?,
),
mechanisms: "PLAIN".into(),
locales: "en_US".into(),
});
debug!(?start_method, "Sending Start method");
self.send_method(ChannelNum::zero(), &start_method).await?;
let start_ok = self.recv_method().await?;
debug!(?start_ok, "Received Start-Ok");
if let Method::ConnectionStartOk(ConnectionStartOk {
mechanism,
locale,
response,
..
}) = start_ok
{
ensure_conn(mechanism == "PLAIN")?;
ensure_conn(locale == "en_US")?;
let plain_user = sasl::parse_sasl_plain_response(&response)?;
info!(username = %plain_user.authentication_identity, "SASL Authentication successful");
} else {
return Err(ConException::Todo.into());
}
Ok(())
}
async fn tune(&mut self) -> Result<()> {
let tune_method = Method::ConnectionTune(ConnectionTune {
channel_max: CHANNEL_MAX,
frame_max: FRAME_SIZE_MAX,
heartbeat: HEARTBEAT_DELAY,
});
debug!("Sending Tune method");
self.send_method(ChannelNum::zero(), &tune_method).await?;
let tune_ok = self.recv_method().await?;
debug!(?tune_ok, "Received Tune-Ok method");
if let Method::ConnectionTuneOk(ConnectionTuneOk {
channel_max,
frame_max,
heartbeat,
}) = tune_ok
{
self.channel_max = channel_max;
self.max_frame_size = MaxFrameSize::new(usize::try_from(frame_max).unwrap());
self.heartbeat_delay = heartbeat;
self.reset_timeout();
}
Ok(())
}
async fn open(&mut self) -> Result<()> {
let open = self.recv_method().await?;
debug!(?open, "Received Open method");
if let Method::ConnectionOpen(ConnectionOpen { virtual_host, .. }) = open {
ensure_conn(virtual_host == "/")?;
}
self.send_method(
ChannelNum::zero(),
&Method::ConnectionOpenOk(ConnectionOpenOk {
reserved_1: "".to_owned(),
}),
)
.await?;
Ok(())
}
async fn main_loop(&mut self) -> Result<()> {
loop {
select! {
frame = frame::read_frame(&mut self.stream, self.max_frame_size) => {
let frame = frame?;
self.handle_frame(frame).await?;
}
queued_method = self.event_receiver.recv() => {
match queued_method {
Some(ConnectionEvent::Method(channel, method)) => {
trace!(?channel, ?method, "Received method from event queue");
self.send_method(channel, &method).await?
}
Some(ConnectionEvent::MethodContent(channel, method, header, body)) => {
trace!(?channel, ?method, ?header, ?body, "Received method with body from event queue");
self.send_method_content(channel, &method, header, &body).await?
}
Some(ConnectionEvent::Shutdown) => return self.close(0, "".to_owned()).await,
None => {}
}
}
}
}
}
#[tracing::instrument(skip(self), level = "debug")]
async fn handle_frame(&mut self, frame: Frame) -> Result<()> {
let channel = frame.channel;
self.reset_timeout();
let result = match frame.kind {
FrameType::Method => self.dispatch_method(frame).await,
FrameType::Heartbeat => {
Ok(()) /* Nothing here, just the `reset_timeout` above */
}
FrameType::Header => self.dispatch_header(frame),
FrameType::Body => self.dispatch_body(frame),
};
match result {
Ok(()) => Ok(()),
Err(TransError::Protocol(ProtocolError::ChannelException(ex))) => {
warn!(%ex, "Channel exception occurred");
self.send_method(
channel,
&Method::ChannelClose(ChannelClose {
reply_code: ex.reply_code(),
reply_text: ex.reply_text(),
class_id: 0, // todo: do this
method_id: 0,
}),
)
.await?;
drop(self.channels.remove(&channel));
Ok(())
}
Err(other_err) => Err(other_err),
}
}
#[tracing::instrument(skip(self, frame), level = "trace")]
async fn dispatch_method(&mut self, frame: Frame) -> Result<()> {
let method = methods::parse_method(&frame.payload)?;
// Sending a method implicitly cancels the content frames that might be ongoing
self.channels
.get_mut(&frame.channel)
.map(|channel| channel.status.take());
match method {
Method::ConnectionClose(ConnectionClose {
reply_code,
reply_text,
class_id,
method_id,
}) => {
info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection");
self.send_method(
ChannelNum::zero(),
&Method::ConnectionCloseOk(ConnectionCloseOk),
)
.await?;
return Err(ProtocolError::GracefullyClosed.into());
}
Method::ChannelOpen { .. } => self.channel_open(frame.channel).await?,
Method::ChannelClose { .. } => self.channel_close(frame.channel, method).await?,
Method::BasicPublish { .. } => match self.channels.get_mut(&frame.channel) {
Some(channel) => {
channel.status = ChannelStatus::NeedHeader(BASIC_CLASS_ID, Box::new(method));
}
None => return Err(ConException::Todo.into()),
},
_ => {
let channel_handle = self
.channels
.get(&frame.channel)
.ok_or(ConException::Todo)?
.global_chan
.clone();
// call into haesli_messaging to handle the method
// it returns the response method that we are supposed to send
// maybe this might become an `Option` in the future
let return_method =
haesli_messaging::methods::handle_method(channel_handle, method).await?;
self.send_method(frame.channel, &return_method).await?;
}
}
Ok(())
}
fn dispatch_header(&mut self, frame: Frame) -> Result<()> {
self.channels
.get_mut(&frame.channel)
.ok_or_else(|| ConException::Todo.into())
.and_then(|channel| match channel.status.take() {
ChannelStatus::Default => {
warn!(channel = %frame.channel, "unexpected header");
Err(ConException::UnexpectedFrame.into())
}
ChannelStatus::NeedHeader(class_id, method) => {
let header = parse_content_header(&frame.payload)?;
ensure_conn(header.class_id == class_id)?;
channel.status = ChannelStatus::NeedsBody(method, header, SmallVec::new());
Ok(())
}
ChannelStatus::NeedsBody(_, _, _) => {
warn!(channel = %frame.channel, "already got header");
Err(ConException::UnexpectedFrame.into())
}
})
}
fn dispatch_body(&mut self, frame: Frame) -> Result<()> {
let channel = self
.channels
.get_mut(&frame.channel)
.ok_or(ConException::Todo)?;
match channel.status.take() {
ChannelStatus::Default => {
warn!(channel = %frame.channel, "unexpected body");
Err(ConException::UnexpectedFrame.into())
}
ChannelStatus::NeedHeader(_, _) => {
warn!(channel = %frame.channel, "unexpected body");
Err(ConException::UnexpectedFrame.into())
}
ChannelStatus::NeedsBody(method, header, mut vec) => {
vec.push(frame.payload);
match vec
.iter()
.map(Bytes::len)
.sum::<usize>()
.cmp(&usize::try_from(header.body_size).unwrap())
{
Ordering::Equal => {
self.process_method_with_body(*method, header, vec, frame.channel)
}
Ordering::Greater => Err(ConException::Todo.into()),
Ordering::Less => Ok(()), // wait for next body
}
}
}
}
fn process_method_with_body(
&mut self,
method: Method,
header: ContentHeader,
payloads: SmallVec<[Bytes; 1]>,
channel: ChannelNum,
) -> Result<()> {
// The only method with content that is sent to the server is Basic.Publish.
ensure_conn(header.class_id == BASIC_CLASS_ID)?;
if let Method::BasicPublish(BasicPublish {
exchange,
routing_key,
mandatory,
immediate,
..
}) = method
{
let message = MessageInner {
id: MessageId::random(),
header,
routing: RoutingInformation {
exchange,
routing_key,
mandatory,
immediate,
},
content: payloads,
};
let message = Arc::new(message);
let channel = self.channels.get(&channel).ok_or(ConException::Todo)?;
haesli_messaging::methods::handle_basic_publish(channel.global_chan.clone(), message)?;
Ok(())
} else {
Err(ConException::Todo.into())
}
}
async fn channel_open(&mut self, channel_num: ChannelNum) -> Result<()> {
let id = rand::random();
let channel_handle = ChannelInner::new(
id,
channel_num,
self.global_con.clone(),
self.global_data.clone(),
self.event_sender.clone(),
);
let channel = TransportChannel {
global_chan: channel_handle.clone(),
status: ChannelStatus::Default,
};
let prev = self.channels.insert(channel_num, channel);
if let Some(prev) = prev {
self.channels.insert(channel_num, prev); // restore previous state
return Err(ConException::ChannelError.into());
}
{
let mut global_data = self.global_data.lock();
global_data.channels.insert(id, channel_handle.clone());
global_data
.connections
.get_mut(&self.id)
.unwrap()
.channels
.lock()
.insert(channel_num, channel_handle);
}
info!(%channel_num, "Opened new channel");
self.send_method(
channel_num,
&Method::ChannelOpenOk(ChannelOpenOk {
reserved_1: Vec::new(),
}),
)
.await?;
Ok(())
}
async fn channel_close(&mut self, channel_id: ChannelNum, method: Method) -> Result<()> {
if let Method::ChannelClose(ChannelClose {
reply_code: code,
reply_text: reason,
..
}) = method
{
info!(%code, %reason, "Closing channel");
if let Some(channel) = self.channels.remove(&channel_id) {
drop(channel);
self.send_method(channel_id, &Method::ChannelCloseOk(ChannelCloseOk))
.await?;
} else {
return Err(ConException::Todo.into());
}
} else {
unreachable!()
}
Ok(())
}
fn reset_timeout(&mut self) {
if self.heartbeat_delay != 0 {
let next = Duration::from_secs(u64::from(self.heartbeat_delay / 2));
self.next_timeout = Box::pin(time::sleep(next));
}
}
async fn negotiate_version(&mut self) -> Result<()> {
const HEADER_SIZE: usize = 8;
const SUPPORTED_PROTOCOL_VERSION: &[u8] = &[0, 9, 1];
const AMQP_PROTOCOL: &[u8] = b"AMQP";
const OWN_PROTOCOL_HEADER: &[u8] = b"AMQP\0\0\x09\x01";
debug!("Negotiating version");
let mut read_header_buf = [0; HEADER_SIZE];
self.stream
.read_exact(&mut read_header_buf)
.await
.context("read protocol header")?;
debug!(received_header = ?read_header_buf,"Received protocol header");
let protocol = &read_header_buf[0..4];
let version = &read_header_buf[5..8];
if protocol != AMQP_PROTOCOL {
self.stream
.write_all(OWN_PROTOCOL_HEADER)
.await
.context("write protocol header")?;
debug!(?protocol, "Version negotiation failed");
return Err(ProtocolError::ProtocolNegotiationFailed.into());
}
if &read_header_buf[0..5] == b"AMQP\0" && version == SUPPORTED_PROTOCOL_VERSION {
debug!(?version, "Version negotiation successful");
Ok(())
} else {
self.stream
.write_all(OWN_PROTOCOL_HEADER)
.await
.context("write protocol header")?;
debug!(?version, expected_version = ?SUPPORTED_PROTOCOL_VERSION, "Version negotiation failed");
Err(ProtocolError::ProtocolNegotiationFailed.into())
}
}
async fn close(&mut self, reply_code: ReplyCode, reply_text: ReplyText) -> Result<()> {
self.send_method(
ChannelNum::zero(),
&Method::ConnectionClose(ConnectionClose {
reply_code,
reply_text,
class_id: 0, // todo: do this
method_id: 0,
}),
)
.await?;
match self.recv_method().await {
Ok(Method::ConnectionCloseOk(_)) => Ok(()),
Ok(method) => {
return Err(TransError::Other(anyhow!(
"Received wrong method after closing, method: {method:?}"
)));
}
Err(err) => {
return Err(TransError::Other(anyhow!(
"Failed to receive Connection.CloseOk method after closing, err: {err}"
)));
}
}
}
}
impl Drop for TransportConnection {
fn drop(&mut self) {
self.global_con.close();
}
}
impl Drop for TransportChannel {
fn drop(&mut self) {
self.global_chan.close();
}
}
fn server_properties(host: SocketAddr) -> Table {
fn ls(str: impl Into<Longstr>) -> FieldValue {
FieldValue::LongString(str.into())
}
let host_str = host.ip().to_string();
HashMap::from([
("host".to_owned(), ls(host_str)),
("product".to_owned(), ls("no name yet")),
("version".to_owned(), ls("0.1.0")),
("platform".to_owned(), ls("microsoft linux")),
("copyright".to_owned(), ls("MIT")),
("information".to_owned(), ls("hello reader")),
("uwu".to_owned(), ls("owo")),
])
}

View file

@ -0,0 +1,27 @@
use std::io::Error;
pub use haesli_core::error::{ConException, ProtocolError};
type StdResult<T, E> = std::result::Result<T, E>;
pub type Result<T> = StdResult<T, TransError>;
#[derive(Debug, thiserror::Error)]
pub enum TransError {
#[error("{0}")]
Protocol(#[from] ProtocolError),
#[error("connection error: `{0}`")]
Other(#[from] anyhow::Error),
}
impl From<std::io::Error> for TransError {
fn from(err: Error) -> Self {
Self::Other(err.into())
}
}
impl From<haesli_core::error::ConException> for TransError {
fn from(err: ConException) -> Self {
Self::Protocol(ProtocolError::ConException(err))
}
}

View file

@ -0,0 +1,372 @@
use std::{
fmt::{Debug, Formatter},
num::NonZeroUsize,
};
use anyhow::Context;
use bytes::Bytes;
use haesli_core::connection::{ChannelNum, ContentHeader};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::trace;
use crate::error::{ConException, ProtocolError, Result};
const REQUIRED_FRAME_END: u8 = 0xCE;
mod frame_type {
pub const METHOD: u8 = 1;
pub const HEADER: u8 = 2;
pub const BODY: u8 = 3;
pub const HEARTBEAT: u8 = 8;
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
/// The type of the frame including its parsed metadata.
pub kind: FrameType,
pub channel: ChannelNum,
/// Includes the whole payload, also including the metadata from each type.
pub payload: Bytes,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(u8)]
pub enum FrameType {
Method = 1,
Header = 2,
Body = 3,
Heartbeat = 8,
}
mod content_header_parse {
use haesli_core::{
connection::ContentHeader,
methods::{
self,
FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp},
},
};
use nom::number::{
complete::{u16, u64},
Endianness::Big,
};
use crate::{
error::TransError,
methods::parse_helper::{octet, shortstr, table, timestamp},
};
type IResult<'a, T> = nom::IResult<&'a [u8], T, TransError>;
pub fn basic_properties(flags: u16, input: &[u8]) -> IResult<'_, methods::Table> {
macro_rules! parse_property {
(if $flags:ident >> $n:literal, $parser:ident($input:ident)?, $map:ident.insert($name:expr, $ctor:path)) => {
if (($flags >> $n) & 1) == 1 {
let (input, value) = $parser($input)?;
$map.insert(String::from($name), $ctor(value));
input
} else {
$input
}
};
}
let mut map = methods::Table::new();
let input = parse_property!(if flags >> 15, shortstr(input)?, map.insert("content-type", ShortString));
let input = parse_property!(if flags >> 14, shortstr(input)?, map.insert("content-encoding", ShortString));
let input =
parse_property!(if flags >> 13, table(input)?, map.insert("headers", FieldTable));
let input = parse_property!(if flags >> 12, octet(input)?, map.insert("delivery-mode", ShortShortUInt));
let input =
parse_property!(if flags >> 11, octet(input)?, map.insert("priority", ShortShortUInt));
let input = parse_property!(if flags >> 10, shortstr(input)?, map.insert("correlation-id", ShortString));
let input =
parse_property!(if flags >> 9, shortstr(input)?, map.insert("reply-to", ShortString));
let input =
parse_property!(if flags >> 8, shortstr(input)?, map.insert("expiration", ShortString));
let input =
parse_property!(if flags >> 7, shortstr(input)?, map.insert("message-id", ShortString));
let input =
parse_property!(if flags >> 6, timestamp(input)?, map.insert("timestamp", Timestamp));
let input =
parse_property!(if flags >> 5, shortstr(input)?, map.insert("type", ShortString));
let input =
parse_property!(if flags >> 4, shortstr(input)?, map.insert("user-id", ShortString));
let input =
parse_property!(if flags >> 3, shortstr(input)?, map.insert("app-id", ShortString));
let input =
parse_property!(if flags >> 2, shortstr(input)?, map.insert("reserved", ShortString));
Ok((input, map))
}
pub fn header(input: &[u8]) -> IResult<'_, ContentHeader> {
let (input, class_id) = u16(Big)(input)?;
let (input, weight) = u16(Big)(input)?;
let (input, body_size) = u64(Big)(input)?;
// I do not quite understand this here. Apparently, there can be more than 15 flags?
// But the Basic class only specifies 15, so idk. Don't care about this for now
// Todo: But probably later.
let (input, property_flags) = u16(Big)(input)?;
let (input, property_fields) = basic_properties(property_flags, input)?;
Ok((
input,
ContentHeader {
class_id,
weight,
body_size,
property_fields,
},
))
}
}
pub fn parse_content_header(input: &[u8]) -> Result<ContentHeader> {
match content_header_parse::header(input) {
Ok(([], header)) => Ok(header),
Ok((_, _)) => {
Err(ConException::SyntaxError(vec!["could not consume all input".to_owned()]).into())
}
Err(nom::Err::Incomplete(_)) => {
Err(ConException::SyntaxError(vec!["there was not enough data".to_owned()]).into())
}
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
}
}
mod content_header_write {
use std::io::Write;
use haesli_core::{
connection::ContentHeader,
methods::{
FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp},
Table,
},
};
use crate::{
error::Result,
methods::write_helper::{longlong, octet, short, shortstr, table, timestamp},
};
pub fn write_content_header<W: Write>(buf: &mut W, header: &ContentHeader) -> Result<()> {
short(&header.class_id, buf)?;
short(&header.weight, buf)?;
longlong(&header.body_size, buf)?;
write_content_header_props(buf, &header.property_fields)
}
pub fn write_content_header_props<W: Write>(writer: &mut W, header: &Table) -> Result<()> {
let mut flags = 0_u16;
// todo: don't allocate for no reason here
let mut temp_buf = Vec::new();
let buf = &mut temp_buf;
buf.extend_from_slice(&flags.to_be_bytes()); // placeholder
if let Some(ShortString(value)) = header.get("content-type") {
flags |= 1 << 15;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.get("content-encoding") {
flags |= 1 << 14;
shortstr(value, buf)?;
}
if let Some(FieldTable(value)) = header.get("headers") {
flags |= 1 << 13;
table(value, buf)?;
}
if let Some(ShortShortUInt(value)) = header.get("delivery-mode") {
flags |= 1 << 12;
octet(value, buf)?;
}
if let Some(ShortShortUInt(value)) = header.get("priority") {
flags |= 1 << 11;
octet(value, buf)?;
}
if let Some(ShortString(value)) = header.get("correlation-id") {
flags |= 1 << 10;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.get("reply-to") {
flags |= 1 << 9;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.get("expiration") {
flags |= 1 << 8;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.get("message-id") {
flags |= 1 << 7;
shortstr(value, buf)?;
}
if let Some(Timestamp(value)) = header.get("timestamp") {
flags |= 1 << 6;
timestamp(value, buf)?;
}
if let Some(ShortString(value)) = header.get("type") {
flags |= 1 << 5;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.get("user-id") {
flags |= 1 << 4;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.get("app-id") {
flags |= 1 << 3;
shortstr(value, buf)?;
}
if let Some(ShortString(value)) = header.get("reserved") {
flags |= 1 << 2;
shortstr(value, buf)?;
}
let [a, b] = flags.to_be_bytes();
buf[0] = a;
buf[1] = b;
writer.write_all(&temp_buf)?;
Ok(())
}
}
pub fn write_content_header(buf: &mut Vec<u8>, content_header: &ContentHeader) -> Result<()> {
content_header_write::write_content_header(buf, content_header)
}
#[derive(Clone, Copy)]
pub struct MaxFrameSize(Option<NonZeroUsize>);
impl MaxFrameSize {
pub const fn new(size: usize) -> Self {
Self(NonZeroUsize::new(size))
}
pub fn as_usize(&self) -> usize {
self.0.map(NonZeroUsize::get).unwrap_or(usize::MAX)
}
}
impl Debug for MaxFrameSize {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
#[tracing::instrument(skip(w), level = "trace")]
pub async fn write_frame<W>(
mut w: W,
kind: FrameType,
channel: ChannelNum,
payload: &[u8],
) -> Result<()>
where
W: AsyncWriteExt + Unpin + Send,
{
w.write_u8(kind as u8).await?;
w.write_u16(channel.num()).await?;
w.write_u32(u32::try_from(payload.len()).context("frame size too big")?)
.await?;
w.write_all(payload).await?;
w.write_u8(REQUIRED_FRAME_END).await?;
Ok(())
}
pub async fn read_frame<R>(r: &mut R, max_frame_size: MaxFrameSize) -> Result<Frame>
where
R: AsyncReadExt + Unpin + Send,
{
let kind = r.read_u8().await.context("read type")?;
let channel = r.read_u16().await.context("read channel")?;
let channel = ChannelNum::new(channel);
let size = r.read_u32().await.context("read size")?;
let mut payload = vec![0; size.try_into().unwrap()];
r.read_exact(&mut payload).await.context("read payload")?;
let frame_end = r.read_u8().await.context("read frame end")?;
if frame_end != REQUIRED_FRAME_END {
return Err(ProtocolError::Fatal.into());
}
if payload.len() > max_frame_size.as_usize() {
return Err(ConException::FrameError.into());
}
let kind = parse_frame_type(kind, channel)?;
let frame = Frame {
kind,
channel,
payload: payload.into(),
};
trace!(?frame, "Received frame");
Ok(frame)
}
fn parse_frame_type(kind: u8, channel: ChannelNum) -> Result<FrameType> {
match kind {
frame_type::METHOD => Ok(FrameType::Method),
frame_type::HEADER => Ok(FrameType::Header),
frame_type::BODY => Ok(FrameType::Body),
frame_type::HEARTBEAT => {
if channel.is_zero() {
Ok(FrameType::Heartbeat)
} else {
Err(ProtocolError::ConException(ConException::FrameError).into())
}
}
_ => Err(ConException::FrameError.into()),
}
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use crate::frame::{ChannelNum, Frame, FrameType, MaxFrameSize};
#[tokio::test]
async fn read_small_body() {
let mut bytes: &[u8] = &[
/*type*/
1,
/*channel*/
0,
0,
/*size*/
0,
0,
0,
3,
/*payload*/
1,
2,
3,
/*frame-end*/
super::REQUIRED_FRAME_END,
];
let frame = super::read_frame(&mut bytes, MaxFrameSize::new(10000))
.await
.unwrap();
assert_eq!(
frame,
Frame {
kind: FrameType::Method,
channel: ChannelNum::new(0),
payload: Bytes::from_static(&[1, 2, 3]),
}
);
}
}

104
haesli_transport/src/lib.rs Normal file
View file

@ -0,0 +1,104 @@
#![warn(rust_2018_idioms)]
mod connection;
mod error;
mod frame;
mod methods;
mod sasl;
#[cfg(test)]
mod tests;
// TODO: handle big types
use std::{future::Future, net::SocketAddr};
use anyhow::Context;
use haesli_core::{connection::ConnectionEvent, queue::QueueEvent, GlobalData};
use tokio::{net, net::TcpStream, select};
use tracing::{info, info_span, Instrument};
use crate::connection::TransportConnection;
pub async fn do_thing_i_guess(
global_data: GlobalData,
terminate: impl Future + Send,
) -> anyhow::Result<()> {
select! {
res = accept_cons(global_data.clone()) => {
res
}
_ = terminate => {
handle_shutdown(global_data).await
}
}
}
async fn accept_cons(global_data: GlobalData) -> anyhow::Result<()> {
info!("Binding TCP listener...");
let listener = net::TcpListener::bind(("127.0.0.1", 5672)).await?;
info!(addr = ?listener.local_addr()?, "Successfully bound TCP listener");
loop {
let connection = listener.accept().await?;
handle_con(global_data.clone(), connection);
}
}
fn handle_con(global_data: GlobalData, connection: (TcpStream, SocketAddr)) {
let (stream, peer_addr) = connection;
let id = rand::random();
info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection");
let span = info_span!("client-connection", %id);
let (method_send, method_recv) = tokio::sync::mpsc::channel(10);
let connection_handle = haesli_core::connection::ConnectionInner::new(
id,
peer_addr,
global_data.clone(),
method_send.clone(),
);
let mut global_data_guard = global_data.lock();
global_data_guard
.connections
.insert(id, connection_handle.clone());
let connection = TransportConnection::new(
id,
stream,
connection_handle,
global_data.clone(),
method_send,
method_recv,
);
tokio::spawn(connection.start_connection_processing().instrument(span));
}
async fn handle_shutdown(global_data: GlobalData) -> anyhow::Result<()> {
info!("Shutting down...");
let lock = global_data.lock();
for con in lock.connections.values() {
con.event_sender
.try_send(ConnectionEvent::Shutdown)
.context("failed to stop connection")?;
}
for queue in lock.queues.values() {
queue
.event_send
.try_send(QueueEvent::Shutdown)
.context("failed to stop queue worker")?;
}
// todo: here we should wait for everything to close
// https://github.com/tokio-rs/mini-redis/blob/4b4ecf0310e6bca43d336dde90a06d9dcad00d6c/src/server.rs#L51
info!("Finished shutdown");
Ok(())
}

1554
haesli_transport/src/methods/generated.rs generated Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,100 @@
use haesli_core::{
error::ConException,
methods::{FieldValue, Method, Table},
};
use rand::Rng;
use crate::error::TransError;
mod generated;
pub mod parse_helper;
#[cfg(test)]
mod tests;
pub mod write_helper;
pub use generated::*;
/// Parses the payload of a method frame into the method
pub fn parse_method(payload: &[u8]) -> Result<Method, TransError> {
let nom_result = generated::parse::parse_method(payload);
match nom_result {
Ok(([], method)) => Ok(method),
Ok((_, _)) => {
Err(ConException::SyntaxError(vec!["could not consume all input".to_owned()]).into())
}
Err(nom::Err::Incomplete(_)) => {
Err(ConException::SyntaxError(vec!["there was not enough data".to_owned()]).into())
}
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
}
}
/// Allows the creation of a random instance of that type
pub trait RandomMethod<R: Rng> {
fn random(rng: &mut R) -> Self;
}
impl<R: Rng> RandomMethod<R> for String {
fn random(rng: &mut R) -> Self {
let n = rng.gen_range(0_u16..9999);
format!("string{n}")
}
}
impl<R: Rng, T: RandomMethod<R>> RandomMethod<R> for Vec<T> {
fn random(rng: &mut R) -> Self {
let len = rng.gen_range(1_usize..10);
let mut vec = Vec::with_capacity(len);
(0..len).for_each(|_| vec.push(RandomMethod::random(rng)));
vec
}
}
macro_rules! rand_random_method {
($($ty:ty),+) => {
$(
impl<R: Rng> RandomMethod<R> for $ty {
fn random(rng: &mut R) -> Self {
rng.gen()
}
})+
};
}
rand_random_method!(bool, u8, i8, u16, i16, u32, i32, u64, i64, f32, f64);
impl<R: Rng> RandomMethod<R> for Table {
fn random(rng: &mut R) -> Self {
let len = rng.gen_range(0..3);
(0..len)
.map(|_| (String::random(rng), FieldValue::random(rng)))
.collect()
}
}
impl<R: Rng> RandomMethod<R> for FieldValue {
fn random(rng: &mut R) -> Self {
let index = rng.gen_range(0_u32..17);
match index {
0 => Self::Boolean(RandomMethod::random(rng)),
1 => Self::ShortShortInt(RandomMethod::random(rng)),
2 => Self::ShortShortUInt(RandomMethod::random(rng)),
3 => Self::ShortInt(RandomMethod::random(rng)),
4 => Self::ShortUInt(RandomMethod::random(rng)),
5 => Self::LongInt(RandomMethod::random(rng)),
6 => Self::LongUInt(RandomMethod::random(rng)),
7 => Self::LongLongInt(RandomMethod::random(rng)),
8 => Self::LongLongUInt(RandomMethod::random(rng)),
9 => Self::Float(RandomMethod::random(rng)),
10 => Self::Double(RandomMethod::random(rng)),
11 => Self::ShortString(RandomMethod::random(rng)),
12 => Self::LongString(RandomMethod::random(rng)),
13 => Self::FieldArray(RandomMethod::random(rng)),
14 => Self::Timestamp(RandomMethod::random(rng)),
15 => Self::FieldTable(RandomMethod::random(rng)),
16 => Self::Void,
_ => unreachable!(),
}
}
}

View file

@ -0,0 +1,243 @@
use haesli_core::{
error::{ConException, ProtocolError},
methods::{
Bit, FieldValue, Long, Longlong, Longstr, Octet, Short, Shortstr, Table, TableFieldName,
Timestamp,
},
};
use nom::{
branch::alt,
bytes::complete::{tag, take},
error::ErrorKind,
multi::{count, many0},
number::{
complete::{f32, f64, i16, i32, i64, i8, u16, u32, u64, u8},
Endianness::Big,
},
Err,
};
use crate::{error::TransError, methods::generated::parse::IResult};
impl<T> nom::error::ParseError<T> for TransError {
fn from_error_kind(_input: T, _kind: ErrorKind) -> Self {
ConException::SyntaxError(vec![]).into()
}
fn append(_input: T, _kind: ErrorKind, other: Self) -> Self {
other
}
}
pub fn fail_err<S: Into<String>>(msg: S) -> impl FnOnce(Err<TransError>) -> Err<TransError> {
move |err| {
let msg = msg.into();
let stack = match err {
Err::Error(e) | Err::Failure(e) => match e {
TransError::Protocol(ProtocolError::ConException(ConException::SyntaxError(
mut stack,
))) => {
stack.push(msg);
stack
}
_ => vec![msg],
},
Err::Incomplete(_) => vec![msg],
};
Err::Failure(ConException::SyntaxError(stack).into())
}
}
pub fn other_fail<E, S: Into<String>>(msg: S) -> impl FnOnce(E) -> Err<TransError> {
move |_| Err::Failure(ConException::SyntaxError(vec![msg.into()]).into())
}
#[macro_export]
macro_rules! fail {
($cause:expr) => {
return Err(nom::Err::Failure(
::haesli_core::error::ProtocolError::ConException(
::haesli_core::error::ConException::SyntaxError(vec![String::from($cause)]),
)
.into(),
))
};
}
pub use fail;
pub fn octet(input: &[u8]) -> IResult<'_, Octet> {
u8(input)
}
pub fn short(input: &[u8]) -> IResult<'_, Short> {
u16(Big)(input)
}
pub fn long(input: &[u8]) -> IResult<'_, Long> {
u32(Big)(input)
}
pub fn longlong(input: &[u8]) -> IResult<'_, Longlong> {
u64(Big)(input)
}
pub fn bit(input: &[u8], amount: usize) -> IResult<'_, Vec<Bit>> {
let octets = (amount + 7) / 8;
let (input, bytes) = take(octets)(input)?;
let mut vec = Vec::new();
let mut byte_index = 0;
let mut total_index = 0;
for &byte in bytes {
while byte_index < 8 && total_index < amount {
let next_bit = 1 & (byte >> byte_index);
let bit_bool = match next_bit {
0 => false,
1 => true,
_ => unreachable!(),
};
vec.push(bit_bool);
byte_index += 1;
total_index += 1;
}
byte_index = 0;
}
Ok((input, vec))
}
pub fn shortstr(input: &[u8]) -> IResult<'_, Shortstr> {
let (input, len) = u8(input)?;
let (input, str_data) = take(usize::from(len))(input)?;
let data = String::from_utf8(str_data.into()).map_err(other_fail("shortstr"))?;
Ok((input, data))
}
pub fn longstr(input: &[u8]) -> IResult<'_, Longstr> {
let (input, len) = u32(Big)(input)?;
let (input, str_data) = take(usize::try_from(len).unwrap())(input)?;
let data = str_data.into();
Ok((input, data))
}
pub fn timestamp(input: &[u8]) -> IResult<'_, Timestamp> {
u64(Big)(input)
}
pub fn table(input: &[u8]) -> IResult<'_, Table> {
let (input, size) = u32(Big)(input)?;
let (table_input, rest_input) = input.split_at(size.try_into().unwrap());
let (input, values) = many0(table_value_pair)(table_input)?;
if !input.is_empty() {
fail!(format!(
"table longer than expected, expected = {size}, remaining = {}",
input.len()
));
}
let table = values.into_iter().collect();
Ok((rest_input, table))
}
fn table_value_pair(input: &[u8]) -> IResult<'_, (TableFieldName, FieldValue)> {
let (input, field_name) = shortstr(input)?;
let (input, field_value) =
field_value(input).map_err(fail_err(format!("field {field_name}")))?;
Ok((input, (field_name, field_value)))
}
fn field_value(input: &[u8]) -> IResult<'_, FieldValue> {
type R<'a> = IResult<'a, FieldValue>;
fn boolean(input: &[u8]) -> R<'_> {
let (input, _) = tag(b"t")(input)?;
let (input, bool_byte) = u8(input)?;
match bool_byte {
0 => Ok((input, FieldValue::Boolean(false))),
1 => Ok((input, FieldValue::Boolean(true))),
value => fail!(format!("invalid bool value {value}")),
}
}
macro_rules! number {
($tag:literal, $name:ident, $comb:expr, $value:ident, $r:path) => {
fn $name(input: &[u8]) -> $r {
let (input, _) = tag($tag)(input)?;
$comb(input).map(|(input, int)| (input, FieldValue::$value(int)))
}
};
}
number!(b"b", short_short_int, i8, ShortShortInt, R<'_>);
number!(b"B", short_short_uint, u8, ShortShortUInt, R<'_>);
number!(b"U", short_int, i16(Big), ShortInt, R<'_>);
number!(b"u", short_uint, u16(Big), ShortUInt, R<'_>);
number!(b"I", long_int, i32(Big), LongInt, R<'_>);
number!(b"i", long_uint, u32(Big), LongUInt, R<'_>);
number!(b"L", long_long_int, i64(Big), LongLongInt, R<'_>);
number!(b"l", long_long_uint, u64(Big), LongLongUInt, R<'_>);
number!(b"f", float, f32(Big), Float, R<'_>);
number!(b"d", double, f64(Big), Double, R<'_>);
fn decimal(input: &[u8]) -> R<'_> {
let (input, _) = tag("D")(input)?;
let (input, scale) = u8(input)?;
let (input, value) = u32(Big)(input)?;
Ok((input, FieldValue::DecimalValue(scale, value)))
}
fn short_str(input: &[u8]) -> R<'_> {
let (input, _) = tag("s")(input)?;
let (input, str) = shortstr(input)?;
Ok((input, FieldValue::ShortString(str)))
}
fn long_str(input: &[u8]) -> R<'_> {
let (input, _) = tag("S")(input)?;
let (input, str) = longstr(input)?;
Ok((input, FieldValue::LongString(str)))
}
fn field_array(input: &[u8]) -> R<'_> {
let (input, _) = tag("A")(input)?;
// todo is it i32?
let (input, len) = u32(Big)(input)?;
count(field_value, usize::try_from(len).unwrap())(input)
.map(|(input, value)| (input, FieldValue::FieldArray(value)))
}
number!(b"T", timestamp, u64(Big), Timestamp, R<'_>);
fn field_table(input: &[u8]) -> R<'_> {
let (input, _) = tag("F")(input)?;
table(input).map(|(input, value)| (input, FieldValue::FieldTable(value)))
}
fn void(input: &[u8]) -> R<'_> {
tag("V")(input).map(|(input, _)| (input, FieldValue::Void))
}
alt((
boolean,
short_short_int,
short_short_uint,
short_int,
short_uint,
long_int,
long_uint,
long_long_int,
long_long_uint,
float,
double,
decimal,
short_str,
long_str,
field_array,
timestamp,
field_table,
void,
))(input)
}

View file

@ -0,0 +1,85 @@
// create random methods to test the ser/de code together. if they diverge, we have a bug
// this is not perfect, if they both have the same bug it won't be found, but that's an ok tradeoff
use std::collections::HashMap;
use rand::SeedableRng;
use crate::methods::{FieldValue, Method, RandomMethod};
#[test]
fn pack_few_bits() {
let bits = [true, false, true];
let mut buffer = [0u8; 2];
super::write_helper::bit(&bits, &mut buffer.as_mut_slice()).unwrap();
let (_, parsed_bits) = super::parse_helper::bit(&buffer, 3).unwrap();
assert_eq!(bits.as_slice(), parsed_bits.as_slice());
}
#[test]
fn pack_many_bits() {
let bits = [
/* first 8 */
true, true, true, true, false, false, false, false, /* second 4 */
true, false, true, true,
];
let mut buffer = [0u8; 2];
super::write_helper::bit(&bits, &mut buffer.as_mut_slice()).unwrap();
let (_, parsed_bits) = super::parse_helper::bit(&buffer, 12).unwrap();
assert_eq!(bits.as_slice(), parsed_bits.as_slice());
}
#[test]
fn random_ser_de() {
const ITERATIONS: usize = 10000;
let mut rng = rand::rngs::StdRng::from_seed([0; 32]);
for _ in 0..ITERATIONS {
let method = Method::random(&mut rng);
let mut bytes = Vec::new();
if let Err(err) = super::write::write_method(&method, &mut bytes) {
eprintln!("{method:#?}");
eprintln!("{err:?}");
panic!("Failed to serialize");
}
match super::parse_method(&bytes) {
Ok(parsed) => {
if method != parsed {
eprintln!("{method:#?}");
eprintln!("{bytes:?}");
eprintln!("{parsed:?}");
panic!("Not equal!");
}
}
Err(err) => {
eprintln!("{method:#?}");
eprintln!("{bytes:?}");
eprintln!("{err:?}");
panic!("Failed to deserialize");
}
}
}
}
#[test]
fn nested_table() {
let table = HashMap::from([(
"A".to_owned(),
FieldValue::FieldTable(HashMap::from([("B".to_owned(), FieldValue::Boolean(true))])),
)]);
eprintln!("{table:?}");
let mut bytes = Vec::new();
crate::methods::write_helper::table(&table, &mut bytes).unwrap();
eprintln!("{bytes:?}");
let (rest, parsed_table) = crate::methods::parse_helper::table(&bytes).unwrap();
assert!(rest.is_empty());
assert_eq!(table, parsed_table);
}

View file

@ -0,0 +1,199 @@
use std::io::Write;
use anyhow::Context;
use haesli_core::methods::{
Bit, Long, Longlong, Longstr, Octet, Short, Shortstr, Table, Timestamp,
};
use crate::{error::TransError, methods::FieldValue};
pub fn octet<W: Write>(value: &Octet, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&[*value])?;
Ok(())
}
pub fn short<W: Write>(value: &Short, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&value.to_be_bytes())?;
Ok(())
}
pub fn long<W: Write>(value: &Long, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&value.to_be_bytes())?;
Ok(())
}
pub fn longlong<W: Write>(value: &Longlong, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&value.to_be_bytes())?;
Ok(())
}
pub fn bit<W: Write>(value: &[Bit], writer: &mut W) -> Result<(), TransError> {
// accumulate bits into bytes, starting from the least significant bit in each byte
// how many bits have already been packed into `current_buf`
let mut already_filled = 0;
let mut current_buf = 0u8;
for &bit in value {
if already_filled >= 8 {
writer.write_all(&[current_buf])?;
current_buf = 0;
already_filled = 0;
}
let new_bit = (u8::from(bit)) << already_filled;
current_buf |= new_bit;
already_filled += 1;
}
if already_filled > 0 {
writer.write_all(&[current_buf])?;
}
Ok(())
}
pub fn shortstr<W: Write>(value: &Shortstr, writer: &mut W) -> Result<(), TransError> {
let len = u8::try_from(value.len()).context("shortstr too long")?;
writer.write_all(&[len])?;
writer.write_all(value.as_bytes())?;
Ok(())
}
pub fn longstr<W: Write>(value: &Longstr, writer: &mut W) -> Result<(), TransError> {
let len = u32::try_from(value.len()).context("longstr too long")?;
writer.write_all(&len.to_be_bytes())?;
writer.write_all(value.as_slice())?;
Ok(())
}
// this appears to be unused right now, but it could be used in `Basic` things?
#[allow(dead_code)]
pub fn timestamp<W: Write>(value: &Timestamp, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&value.to_be_bytes())?;
Ok(())
}
pub fn table<W: Write>(table: &Table, writer: &mut W) -> Result<(), TransError> {
let mut table_buf = Vec::new();
for (field_name, value) in table {
shortstr(field_name, &mut table_buf)?;
field_value(value, &mut table_buf)?;
}
let len = u32::try_from(table_buf.len()).context("table too big")?;
writer.write_all(&len.to_be_bytes())?;
writer.write_all(&table_buf)?;
Ok(())
}
fn field_value<W: Write>(value: &FieldValue, writer: &mut W) -> Result<(), TransError> {
match value {
FieldValue::Boolean(bool) => {
writer.write_all(&[b't', u8::from(*bool)])?;
}
FieldValue::ShortShortInt(int) => {
writer.write_all(b"b")?;
writer.write_all(&int.to_be_bytes())?;
}
FieldValue::ShortShortUInt(int) => {
writer.write_all(&[b'B', *int])?;
}
FieldValue::ShortInt(int) => {
writer.write_all(b"U")?;
writer.write_all(&int.to_be_bytes())?;
}
FieldValue::ShortUInt(int) => {
writer.write_all(b"u")?;
writer.write_all(&int.to_be_bytes())?;
}
FieldValue::LongInt(int) => {
writer.write_all(b"I")?;
writer.write_all(&int.to_be_bytes())?;
}
FieldValue::LongUInt(int) => {
writer.write_all(b"i")?;
writer.write_all(&int.to_be_bytes())?;
}
FieldValue::LongLongInt(int) => {
writer.write_all(b"L")?;
writer.write_all(&int.to_be_bytes())?;
}
FieldValue::LongLongUInt(int) => {
writer.write_all(b"l")?;
writer.write_all(&int.to_be_bytes())?;
}
FieldValue::Float(float) => {
writer.write_all(b"f")?;
writer.write_all(&float.to_be_bytes())?;
}
FieldValue::Double(float) => {
writer.write_all(b"d")?;
writer.write_all(&float.to_be_bytes())?;
}
FieldValue::DecimalValue(scale, long) => {
writer.write_all(&[b'D', *scale])?;
writer.write_all(&long.to_be_bytes())?;
}
FieldValue::ShortString(str) => {
writer.write_all(b"s")?;
shortstr(str, writer)?;
}
FieldValue::LongString(str) => {
writer.write_all(b"S")?;
longstr(str, writer)?;
}
FieldValue::FieldArray(array) => {
writer.write_all(b"A")?;
let len = u32::try_from(array.len()).context("array too long")?;
writer.write_all(&len.to_be_bytes())?;
for element in array {
field_value(element, writer)?;
}
}
FieldValue::Timestamp(time) => {
writer.write_all(b"T")?;
writer.write_all(&time.to_be_bytes())?;
}
FieldValue::FieldTable(value) => {
writer.write_all(b"F")?;
table(value, writer)?;
}
FieldValue::Void => {
writer.write_all(b"V")?;
}
}
Ok(())
}
#[cfg(test)]
mod tests {
#[test]
fn pack_few_bits() {
let bits = [true, false, true];
let mut buffer = [0u8; 1];
super::bit(&bits, &mut buffer.as_mut_slice()).unwrap();
assert_eq!(buffer, [0b00000101])
}
#[test]
fn pack_many_bits() {
let bits = [
/* first 8 */
true, true, true, true, false, false, false, false, /* second 4 */
true, false, true, true,
];
let mut buffer = [0u8; 2];
super::bit(&bits, &mut buffer.as_mut_slice()).unwrap();
assert_eq!(buffer, [0b00001111, 0b00001101]);
}
}

View file

@ -0,0 +1,29 @@
//! (Very) partial implementation of SASL Authentication (see [RFC 4422](https://datatracker.ietf.org/doc/html/rfc4422))
//!
//! Currently only supports PLAIN (see [RFC 4616](https://datatracker.ietf.org/doc/html/rfc4616))
use haesli_core::error::ConException;
use crate::error::Result;
pub struct PlainUser {
pub authorization_identity: String,
pub authentication_identity: String,
pub password: String,
}
pub fn parse_sasl_plain_response(response: &[u8]) -> Result<PlainUser> {
let mut parts = response
.split(|&n| n == 0)
.map(|bytes| String::from_utf8(bytes.into()).map_err(|_| ConException::Todo));
let authorization_identity = parts.next().ok_or(ConException::Todo)??;
let authentication_identity = parts.next().ok_or(ConException::Todo)??;
let password = parts.next().ok_or(ConException::Todo)??;
Ok(PlainUser {
authorization_identity,
authentication_identity,
password,
})
}

View file

@ -0,0 +1,180 @@
use std::collections::HashMap;
use haesli_core::{
connection::ChannelNum,
methods::{ConnectionStart, ConnectionStartOk, FieldValue, Method},
};
use crate::{frame, frame::FrameType, methods};
#[tokio::test]
async fn write_start_ok_frame() {
let mut payload = Vec::new();
let method = Method::ConnectionStart(ConnectionStart {
version_major: 0,
version_minor: 9,
server_properties: HashMap::from([(
"product".to_owned(),
FieldValue::LongString("no name yet".into()),
)]),
mechanisms: "PLAIN".into(),
locales: "en_US".into(),
});
methods::write::write_method(&method, &mut payload).unwrap();
let mut output = Vec::new();
frame::write_frame(&mut output, FrameType::Method, ChannelNum::zero(), &payload)
.await
.unwrap();
#[rustfmt::skip]
let expected = [
/* type, octet, method */
1u8,
/* channel, short */
0, 0,
/* size, long */
/* count all the bytes in the payload, 33 here */
0, 0, 0, 52,
/* payload */
/* class-id, short, connection */
0, 10,
/* method-id, short, start */
0, 10,
/* version-major, octet */
0,
/* version-minor, octet */
9,
/* server-properties, table */
/* table-size, long (actual byte size) */
0, 0, 0, 24,
/* table-items */
/* name ("product"), shortstr */
/* len (7) ; bytes */
7, b'p', b'r', b'o', b'd', b'u', b'c', b't',
/* value, a shortstr ("no name yet") here */
/* tag (s) ; len (11) ; data */
b'S', 0, 0, 0, 11, b'n', b'o', b' ', b'n', b'a', b'm', b'e', b' ', b'y', b'e', b't',
/* mechanisms, longstr */
/* str-len, long ; len 5 ; data ("PLAIN") */
0, 0, 0, 5,
b'P', b'L', b'A', b'I', b'N',
/* locales, longstr */
/* str-len, long ; len 5 ; data ("en_US") */
0, 0, 0, 5,
b'e', b'n', b'_', b'U', b'S',
/* frame-end */
0xCE,
];
assert_eq!(expected.as_slice(), output.as_slice());
}
#[test]
fn read_start_ok_payload() {
#[rustfmt::skip]
let raw_data = [
/* Connection.Start-Ok */
0, 10, 0, 11,
/* field client-properties */
/* byte size of the table */
0, 0, 0, 254,
/* first key of len 7, "product"*/
7, 112, 114, 111, 100, 117, 99, 116,
/* value is of type 83 ("S"), long-string */
/* has length 26 "Pika Python Client Library" */
83, 0, 0, 0, 26,
80, 105, 107, 97, 32, 80, 121, 116, 104, 111, 110, 32, 67, 108, 105, 101, 110, 116, 32, 76, 105, 98, 114, 97, 114, 121,
/* second key of len 8, "platform" */
8, 112, 108, 97, 116, 102, 111, 114, 109,
/* value is of type 83("S"), long-string */
/* has length 13, "Python 3.8.10" */
83, 0, 0, 0, 13,
80, 121, 116, 104, 111, 110, 32, 51, 46, 56, 46, 49, 48,
/* third key has len 12 "capabilities" */
12, 99, 97, 112, 97, 98, 105, 108, 105, 116, 105, 101, 115,
/* type is 70 F (table), with byte-len of 111 */
70, 0, 0, 0, 111,
/* first key has length 28, "authentication_failure_close" */
28, 97, 117, 116, 104, 101, 110, 116, 105, 99, 97, 116, 105, 111, 110, 95, 102, 97, 105, 108, 117, 114, 101, 95, 99, 108, 111, 115, 101,
/* value of type 116, "t", boolean, true */
116, 1,
/* second key has length 10, "basic.nack" */
10, 98, 97, 115, 105, 99, 46, 110, 97, 99, 107,
/* value of type 116, "t", boolean, true */
116, 1,
/* third key has length 18 "connection.blocked" */
18, 99, 111, 110, 110, 101, 99, 116, 105, 111, 110, 46, 98, 108, 111, 99, 107, 101, 100,
/* value of type 116, "t", boolean, true */
116, 1,
/* fourth key has length 22 "consumer_cancel_notify" */
22, 99, 111, 110, 115, 117, 109, 101, 114, 95, 99, 97, 110, 99, 101, 108, 95, 110, 111, 116, 105, 102, 121,
/* value of type 116, "t", boolean, true */
116, 1,
/* fifth key has length 18 "publisher_confirms" */
18, 112, 117, 98, 108, 105, 115, 104, 101, 114, 95, 99, 111, 110, 102, 105, 114, 109, 115,
/* value of type 116, "t", boolean, true */
116, 1,
/* sixth key has length 11 "information" */
11, 105, 110, 102, 111, 114, 109, 97, 116, 105, 111, 110,
/* value of type 83, "S" long-str ; len 24 ; data "See http://pika.rtfd.org" */
83, 0, 0, 0, 24,
83, 101, 101, 32, 104, 116, 116, 112, 58, 47, 47, 112, 105, 107, 97, 46, 114, 116, 102, 100, 46, 111, 114, 103,
/* seventh key has length 7, "version" */
7, 118, 101, 114, 115, 105, 111, 110,
/* value of type 83, "S" long-str ; length 5 ; "1.1.0" */
83, 0, 0, 0, 5,
49, 46, 49, 46, 48,
/* client-properties table ends here */
/* field mechanism, length 5, "PLAIN" */
5, 80, 76, 65, 73, 78,
/* field response, longstr, length 7, "\x00admin\x00" */
0, 0, 0, 7, 0, 97, 100, 109, 105, 110, 0,
/* locale, shortstr, len 5 "en_US" */
5, 101, 110, 95, 85, 83,
];
let method = methods::parse_method(&raw_data).unwrap();
assert_eq!(
method,
Method::ConnectionStartOk(ConnectionStartOk {
client_properties: HashMap::from([
(
"product".to_owned(),
FieldValue::LongString("Pika Python Client Library".into())
),
(
"platform".to_owned(),
FieldValue::LongString("Python 3.8.10".into())
),
(
"capabilities".to_owned(),
FieldValue::FieldTable(HashMap::from([
(
"authentication_failure_close".to_owned(),
FieldValue::Boolean(true)
),
("basic.nack".to_owned(), FieldValue::Boolean(true)),
("connection.blocked".to_owned(), FieldValue::Boolean(true)),
(
"consumer_cancel_notify".to_owned(),
FieldValue::Boolean(true)
),
("publisher_confirms".to_owned(), FieldValue::Boolean(true)),
]))
),
(
"information".to_owned(),
FieldValue::LongString("See http://pika.rtfd.org".into())
),
("version".to_owned(), FieldValue::LongString("1.1.0".into()))
]),
mechanism: "PLAIN".to_owned(),
response: "\x00admin\x00".into(),
locale: "en_US".to_owned()
})
);
}