mirror of
https://github.com/Noratrieb/haesli.git
synced 2026-01-16 20:55:03 +01:00
rename lol
This commit is contained in:
parent
c68cd04af7
commit
543e39f129
70 changed files with 283 additions and 266 deletions
681
haesli_transport/src/connection.rs
Normal file
681
haesli_transport/src/connection.rs
Normal file
|
|
@ -0,0 +1,681 @@
|
|||
use std::{
|
||||
cmp::Ordering, collections::HashMap, net::SocketAddr, pin::Pin, sync::Arc, time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::Bytes;
|
||||
use haesli_core::{
|
||||
connection::{
|
||||
Channel, ChannelInner, ChannelNum, ConEventReceiver, ConEventSender, Connection,
|
||||
ConnectionEvent, ConnectionId, ContentHeader,
|
||||
},
|
||||
message::{MessageId, MessageInner, RoutingInformation},
|
||||
methods::{
|
||||
BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose,
|
||||
ConnectionCloseOk, ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk,
|
||||
ConnectionTune, ConnectionTuneOk, FieldValue, Longstr, Method, ReplyCode, ReplyText, Table,
|
||||
},
|
||||
GlobalData,
|
||||
};
|
||||
use smallvec::SmallVec;
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
select, time,
|
||||
};
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
use crate::{
|
||||
error::{ConException, ProtocolError, Result, TransError},
|
||||
frame::{self, parse_content_header, Frame, FrameType, MaxFrameSize},
|
||||
methods, sasl,
|
||||
};
|
||||
|
||||
fn ensure_conn(condition: bool) -> Result<()> {
|
||||
if condition {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ConException::Todo.into())
|
||||
}
|
||||
}
|
||||
|
||||
const FRAME_SIZE_MIN_MAX: MaxFrameSize = MaxFrameSize::new(4096);
|
||||
const CHANNEL_MAX: u16 = 0;
|
||||
const FRAME_SIZE_MAX: u32 = 0;
|
||||
const HEARTBEAT_DELAY: u16 = 0;
|
||||
|
||||
const BASIC_CLASS_ID: u16 = 60;
|
||||
|
||||
pub struct TransportChannel {
|
||||
/// A handle to the global channel representation. Used to remove the channel when it's dropped
|
||||
global_chan: Channel,
|
||||
/// The current status of the channel, whether it has sent a method that expects a body
|
||||
status: ChannelStatus,
|
||||
}
|
||||
|
||||
pub struct TransportConnection {
|
||||
id: ConnectionId,
|
||||
stream: TcpStream,
|
||||
max_frame_size: MaxFrameSize,
|
||||
heartbeat_delay: u16,
|
||||
channel_max: u16,
|
||||
/// When the next heartbeat expires
|
||||
next_timeout: Pin<Box<time::Sleep>>,
|
||||
channels: HashMap<ChannelNum, TransportChannel>,
|
||||
global_con: Connection,
|
||||
global_data: GlobalData,
|
||||
/// Only here to forward to other futures so they can send events
|
||||
event_sender: ConEventSender,
|
||||
/// To receive events from other futures
|
||||
event_receiver: ConEventReceiver,
|
||||
}
|
||||
|
||||
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
enum ChannelStatus {
|
||||
Default,
|
||||
NeedHeader(u16, Box<Method>),
|
||||
NeedsBody(Box<Method>, ContentHeader, SmallVec<[Bytes; 1]>),
|
||||
}
|
||||
|
||||
impl ChannelStatus {
|
||||
fn take(&mut self) -> Self {
|
||||
std::mem::replace(self, Self::Default)
|
||||
}
|
||||
}
|
||||
|
||||
impl TransportConnection {
|
||||
pub fn new(
|
||||
id: ConnectionId,
|
||||
stream: TcpStream,
|
||||
global_con: Connection,
|
||||
global_data: GlobalData,
|
||||
method_queue_send: ConEventSender,
|
||||
method_queue_recv: ConEventReceiver,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
stream,
|
||||
max_frame_size: FRAME_SIZE_MIN_MAX,
|
||||
heartbeat_delay: HEARTBEAT_DELAY,
|
||||
channel_max: CHANNEL_MAX,
|
||||
next_timeout: Box::pin(time::sleep(DEFAULT_TIMEOUT)),
|
||||
global_con,
|
||||
channels: HashMap::with_capacity(4),
|
||||
global_data,
|
||||
event_sender: method_queue_send,
|
||||
event_receiver: method_queue_recv,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_connection_processing(mut self) {
|
||||
let process_result = self.process_connection().await;
|
||||
|
||||
match process_result {
|
||||
Ok(()) => {}
|
||||
Err(TransError::Protocol(ProtocolError::GracefullyClosed)) => {
|
||||
/* do nothing, remove below */
|
||||
}
|
||||
Err(TransError::Protocol(ProtocolError::ConException(ex))) => {
|
||||
warn!(%ex, "Connection exception occurred. This indicates a faulty client.");
|
||||
let close_result = self.close(ex.reply_code(), ex.reply_text()).await;
|
||||
|
||||
match close_result {
|
||||
Ok(()) => {}
|
||||
Err(err) => {
|
||||
error!(%ex, %err, "Failed to close connection after ConnectionException");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => error!(%err, "Error during processing of connection"),
|
||||
}
|
||||
|
||||
// global connection is closed on drop
|
||||
}
|
||||
|
||||
pub async fn process_connection(&mut self) -> Result<()> {
|
||||
self.negotiate_version().await?;
|
||||
self.start().await?;
|
||||
self.tune().await?;
|
||||
self.open().await?;
|
||||
|
||||
info!("Connection is ready for usage!");
|
||||
|
||||
self.main_loop().await
|
||||
}
|
||||
|
||||
async fn send_method_content(
|
||||
&mut self,
|
||||
channel: ChannelNum,
|
||||
method: &Method,
|
||||
header: ContentHeader,
|
||||
body: &SmallVec<[Bytes; 1]>,
|
||||
) -> Result<()> {
|
||||
self.send_method(channel, method).await?;
|
||||
|
||||
let mut header_buf = Vec::new();
|
||||
frame::write_content_header(&mut header_buf, &header)?;
|
||||
frame::write_frame(&mut self.stream, FrameType::Header, channel, &header_buf).await?;
|
||||
|
||||
self.send_bodies(channel, body).await
|
||||
}
|
||||
|
||||
async fn send_bodies(
|
||||
&mut self,
|
||||
channel: ChannelNum,
|
||||
body: &SmallVec<[Bytes; 1]>,
|
||||
) -> Result<()> {
|
||||
// this is inefficient if it's a huge message sent by a client with big frames to one with
|
||||
// small frames
|
||||
// we assume that this won't happen that that the first branch will be taken in most cases,
|
||||
// elimination the overhead. What we win from keeping each frame as it is that we don't have
|
||||
// to allocate again for each message
|
||||
|
||||
let max_size = self.max_frame_size.as_usize();
|
||||
|
||||
for payload in body {
|
||||
if max_size > payload.len() {
|
||||
trace!("Sending single method body frame");
|
||||
// single frame
|
||||
frame::write_frame(&mut self.stream, FrameType::Body, channel, payload).await?;
|
||||
} else {
|
||||
trace!(max = ?self.max_frame_size, "Chunking up method body frames");
|
||||
// chunk it up into multiple sub-frames
|
||||
let mut start = 0;
|
||||
let mut end = max_size;
|
||||
|
||||
while end < payload.len() {
|
||||
let sub_payload = &payload[start..end];
|
||||
|
||||
frame::write_frame(&mut self.stream, FrameType::Body, channel, sub_payload)
|
||||
.await?;
|
||||
|
||||
start = end;
|
||||
end = (end + max_size).max(payload.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "trace")]
|
||||
async fn send_method(&mut self, channel: ChannelNum, method: &Method) -> Result<()> {
|
||||
let mut payload = Vec::with_capacity(64);
|
||||
methods::write::write_method(method, &mut payload)?;
|
||||
frame::write_frame(&mut self.stream, FrameType::Method, channel, &payload).await
|
||||
}
|
||||
|
||||
async fn recv_method(&mut self) -> Result<Method> {
|
||||
let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
|
||||
|
||||
ensure_conn(start_ok_frame.kind == FrameType::Method)?;
|
||||
|
||||
let method = methods::parse_method(&start_ok_frame.payload)?;
|
||||
Ok(method)
|
||||
}
|
||||
|
||||
async fn start(&mut self) -> Result<()> {
|
||||
let start_method = Method::ConnectionStart(ConnectionStart {
|
||||
version_major: 0,
|
||||
version_minor: 9,
|
||||
server_properties: server_properties(
|
||||
self.stream
|
||||
.local_addr()
|
||||
.context("failed to get local_addr")?,
|
||||
),
|
||||
mechanisms: "PLAIN".into(),
|
||||
locales: "en_US".into(),
|
||||
});
|
||||
|
||||
debug!(?start_method, "Sending Start method");
|
||||
self.send_method(ChannelNum::zero(), &start_method).await?;
|
||||
|
||||
let start_ok = self.recv_method().await?;
|
||||
debug!(?start_ok, "Received Start-Ok");
|
||||
|
||||
if let Method::ConnectionStartOk(ConnectionStartOk {
|
||||
mechanism,
|
||||
locale,
|
||||
response,
|
||||
..
|
||||
}) = start_ok
|
||||
{
|
||||
ensure_conn(mechanism == "PLAIN")?;
|
||||
ensure_conn(locale == "en_US")?;
|
||||
let plain_user = sasl::parse_sasl_plain_response(&response)?;
|
||||
info!(username = %plain_user.authentication_identity, "SASL Authentication successful");
|
||||
} else {
|
||||
return Err(ConException::Todo.into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn tune(&mut self) -> Result<()> {
|
||||
let tune_method = Method::ConnectionTune(ConnectionTune {
|
||||
channel_max: CHANNEL_MAX,
|
||||
frame_max: FRAME_SIZE_MAX,
|
||||
heartbeat: HEARTBEAT_DELAY,
|
||||
});
|
||||
|
||||
debug!("Sending Tune method");
|
||||
self.send_method(ChannelNum::zero(), &tune_method).await?;
|
||||
|
||||
let tune_ok = self.recv_method().await?;
|
||||
debug!(?tune_ok, "Received Tune-Ok method");
|
||||
|
||||
if let Method::ConnectionTuneOk(ConnectionTuneOk {
|
||||
channel_max,
|
||||
frame_max,
|
||||
heartbeat,
|
||||
}) = tune_ok
|
||||
{
|
||||
self.channel_max = channel_max;
|
||||
self.max_frame_size = MaxFrameSize::new(usize::try_from(frame_max).unwrap());
|
||||
self.heartbeat_delay = heartbeat;
|
||||
self.reset_timeout();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn open(&mut self) -> Result<()> {
|
||||
let open = self.recv_method().await?;
|
||||
debug!(?open, "Received Open method");
|
||||
|
||||
if let Method::ConnectionOpen(ConnectionOpen { virtual_host, .. }) = open {
|
||||
ensure_conn(virtual_host == "/")?;
|
||||
}
|
||||
|
||||
self.send_method(
|
||||
ChannelNum::zero(),
|
||||
&Method::ConnectionOpenOk(ConnectionOpenOk {
|
||||
reserved_1: "".to_owned(),
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn main_loop(&mut self) -> Result<()> {
|
||||
loop {
|
||||
select! {
|
||||
frame = frame::read_frame(&mut self.stream, self.max_frame_size) => {
|
||||
let frame = frame?;
|
||||
self.handle_frame(frame).await?;
|
||||
}
|
||||
queued_method = self.event_receiver.recv() => {
|
||||
match queued_method {
|
||||
Some(ConnectionEvent::Method(channel, method)) => {
|
||||
trace!(?channel, ?method, "Received method from event queue");
|
||||
self.send_method(channel, &method).await?
|
||||
}
|
||||
Some(ConnectionEvent::MethodContent(channel, method, header, body)) => {
|
||||
trace!(?channel, ?method, ?header, ?body, "Received method with body from event queue");
|
||||
self.send_method_content(channel, &method, header, &body).await?
|
||||
}
|
||||
Some(ConnectionEvent::Shutdown) => return self.close(0, "".to_owned()).await,
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), level = "debug")]
|
||||
async fn handle_frame(&mut self, frame: Frame) -> Result<()> {
|
||||
let channel = frame.channel;
|
||||
self.reset_timeout();
|
||||
|
||||
let result = match frame.kind {
|
||||
FrameType::Method => self.dispatch_method(frame).await,
|
||||
FrameType::Heartbeat => {
|
||||
Ok(()) /* Nothing here, just the `reset_timeout` above */
|
||||
}
|
||||
FrameType::Header => self.dispatch_header(frame),
|
||||
FrameType::Body => self.dispatch_body(frame),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(()) => Ok(()),
|
||||
Err(TransError::Protocol(ProtocolError::ChannelException(ex))) => {
|
||||
warn!(%ex, "Channel exception occurred");
|
||||
self.send_method(
|
||||
channel,
|
||||
&Method::ChannelClose(ChannelClose {
|
||||
reply_code: ex.reply_code(),
|
||||
reply_text: ex.reply_text(),
|
||||
class_id: 0, // todo: do this
|
||||
method_id: 0,
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
drop(self.channels.remove(&channel));
|
||||
Ok(())
|
||||
}
|
||||
Err(other_err) => Err(other_err),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, frame), level = "trace")]
|
||||
async fn dispatch_method(&mut self, frame: Frame) -> Result<()> {
|
||||
let method = methods::parse_method(&frame.payload)?;
|
||||
|
||||
// Sending a method implicitly cancels the content frames that might be ongoing
|
||||
self.channels
|
||||
.get_mut(&frame.channel)
|
||||
.map(|channel| channel.status.take());
|
||||
|
||||
match method {
|
||||
Method::ConnectionClose(ConnectionClose {
|
||||
reply_code,
|
||||
reply_text,
|
||||
class_id,
|
||||
method_id,
|
||||
}) => {
|
||||
info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection");
|
||||
self.send_method(
|
||||
ChannelNum::zero(),
|
||||
&Method::ConnectionCloseOk(ConnectionCloseOk),
|
||||
)
|
||||
.await?;
|
||||
return Err(ProtocolError::GracefullyClosed.into());
|
||||
}
|
||||
Method::ChannelOpen { .. } => self.channel_open(frame.channel).await?,
|
||||
Method::ChannelClose { .. } => self.channel_close(frame.channel, method).await?,
|
||||
Method::BasicPublish { .. } => match self.channels.get_mut(&frame.channel) {
|
||||
Some(channel) => {
|
||||
channel.status = ChannelStatus::NeedHeader(BASIC_CLASS_ID, Box::new(method));
|
||||
}
|
||||
None => return Err(ConException::Todo.into()),
|
||||
},
|
||||
_ => {
|
||||
let channel_handle = self
|
||||
.channels
|
||||
.get(&frame.channel)
|
||||
.ok_or(ConException::Todo)?
|
||||
.global_chan
|
||||
.clone();
|
||||
|
||||
// call into haesli_messaging to handle the method
|
||||
// it returns the response method that we are supposed to send
|
||||
// maybe this might become an `Option` in the future
|
||||
let return_method =
|
||||
haesli_messaging::methods::handle_method(channel_handle, method).await?;
|
||||
self.send_method(frame.channel, &return_method).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dispatch_header(&mut self, frame: Frame) -> Result<()> {
|
||||
self.channels
|
||||
.get_mut(&frame.channel)
|
||||
.ok_or_else(|| ConException::Todo.into())
|
||||
.and_then(|channel| match channel.status.take() {
|
||||
ChannelStatus::Default => {
|
||||
warn!(channel = %frame.channel, "unexpected header");
|
||||
Err(ConException::UnexpectedFrame.into())
|
||||
}
|
||||
ChannelStatus::NeedHeader(class_id, method) => {
|
||||
let header = parse_content_header(&frame.payload)?;
|
||||
ensure_conn(header.class_id == class_id)?;
|
||||
|
||||
channel.status = ChannelStatus::NeedsBody(method, header, SmallVec::new());
|
||||
Ok(())
|
||||
}
|
||||
ChannelStatus::NeedsBody(_, _, _) => {
|
||||
warn!(channel = %frame.channel, "already got header");
|
||||
Err(ConException::UnexpectedFrame.into())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn dispatch_body(&mut self, frame: Frame) -> Result<()> {
|
||||
let channel = self
|
||||
.channels
|
||||
.get_mut(&frame.channel)
|
||||
.ok_or(ConException::Todo)?;
|
||||
|
||||
match channel.status.take() {
|
||||
ChannelStatus::Default => {
|
||||
warn!(channel = %frame.channel, "unexpected body");
|
||||
Err(ConException::UnexpectedFrame.into())
|
||||
}
|
||||
ChannelStatus::NeedHeader(_, _) => {
|
||||
warn!(channel = %frame.channel, "unexpected body");
|
||||
Err(ConException::UnexpectedFrame.into())
|
||||
}
|
||||
ChannelStatus::NeedsBody(method, header, mut vec) => {
|
||||
vec.push(frame.payload);
|
||||
match vec
|
||||
.iter()
|
||||
.map(Bytes::len)
|
||||
.sum::<usize>()
|
||||
.cmp(&usize::try_from(header.body_size).unwrap())
|
||||
{
|
||||
Ordering::Equal => {
|
||||
self.process_method_with_body(*method, header, vec, frame.channel)
|
||||
}
|
||||
Ordering::Greater => Err(ConException::Todo.into()),
|
||||
Ordering::Less => Ok(()), // wait for next body
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_method_with_body(
|
||||
&mut self,
|
||||
method: Method,
|
||||
header: ContentHeader,
|
||||
payloads: SmallVec<[Bytes; 1]>,
|
||||
channel: ChannelNum,
|
||||
) -> Result<()> {
|
||||
// The only method with content that is sent to the server is Basic.Publish.
|
||||
ensure_conn(header.class_id == BASIC_CLASS_ID)?;
|
||||
|
||||
if let Method::BasicPublish(BasicPublish {
|
||||
exchange,
|
||||
routing_key,
|
||||
mandatory,
|
||||
immediate,
|
||||
..
|
||||
}) = method
|
||||
{
|
||||
let message = MessageInner {
|
||||
id: MessageId::random(),
|
||||
header,
|
||||
routing: RoutingInformation {
|
||||
exchange,
|
||||
routing_key,
|
||||
mandatory,
|
||||
immediate,
|
||||
},
|
||||
content: payloads,
|
||||
};
|
||||
let message = Arc::new(message);
|
||||
|
||||
let channel = self.channels.get(&channel).ok_or(ConException::Todo)?;
|
||||
|
||||
haesli_messaging::methods::handle_basic_publish(channel.global_chan.clone(), message)?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ConException::Todo.into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn channel_open(&mut self, channel_num: ChannelNum) -> Result<()> {
|
||||
let id = rand::random();
|
||||
let channel_handle = ChannelInner::new(
|
||||
id,
|
||||
channel_num,
|
||||
self.global_con.clone(),
|
||||
self.global_data.clone(),
|
||||
self.event_sender.clone(),
|
||||
);
|
||||
|
||||
let channel = TransportChannel {
|
||||
global_chan: channel_handle.clone(),
|
||||
status: ChannelStatus::Default,
|
||||
};
|
||||
|
||||
let prev = self.channels.insert(channel_num, channel);
|
||||
if let Some(prev) = prev {
|
||||
self.channels.insert(channel_num, prev); // restore previous state
|
||||
return Err(ConException::ChannelError.into());
|
||||
}
|
||||
|
||||
{
|
||||
let mut global_data = self.global_data.lock();
|
||||
global_data.channels.insert(id, channel_handle.clone());
|
||||
global_data
|
||||
.connections
|
||||
.get_mut(&self.id)
|
||||
.unwrap()
|
||||
.channels
|
||||
.lock()
|
||||
.insert(channel_num, channel_handle);
|
||||
}
|
||||
|
||||
info!(%channel_num, "Opened new channel");
|
||||
|
||||
self.send_method(
|
||||
channel_num,
|
||||
&Method::ChannelOpenOk(ChannelOpenOk {
|
||||
reserved_1: Vec::new(),
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn channel_close(&mut self, channel_id: ChannelNum, method: Method) -> Result<()> {
|
||||
if let Method::ChannelClose(ChannelClose {
|
||||
reply_code: code,
|
||||
reply_text: reason,
|
||||
..
|
||||
}) = method
|
||||
{
|
||||
info!(%code, %reason, "Closing channel");
|
||||
|
||||
if let Some(channel) = self.channels.remove(&channel_id) {
|
||||
drop(channel);
|
||||
self.send_method(channel_id, &Method::ChannelCloseOk(ChannelCloseOk))
|
||||
.await?;
|
||||
} else {
|
||||
return Err(ConException::Todo.into());
|
||||
}
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn reset_timeout(&mut self) {
|
||||
if self.heartbeat_delay != 0 {
|
||||
let next = Duration::from_secs(u64::from(self.heartbeat_delay / 2));
|
||||
self.next_timeout = Box::pin(time::sleep(next));
|
||||
}
|
||||
}
|
||||
|
||||
async fn negotiate_version(&mut self) -> Result<()> {
|
||||
const HEADER_SIZE: usize = 8;
|
||||
const SUPPORTED_PROTOCOL_VERSION: &[u8] = &[0, 9, 1];
|
||||
const AMQP_PROTOCOL: &[u8] = b"AMQP";
|
||||
const OWN_PROTOCOL_HEADER: &[u8] = b"AMQP\0\0\x09\x01";
|
||||
|
||||
debug!("Negotiating version");
|
||||
|
||||
let mut read_header_buf = [0; HEADER_SIZE];
|
||||
|
||||
self.stream
|
||||
.read_exact(&mut read_header_buf)
|
||||
.await
|
||||
.context("read protocol header")?;
|
||||
|
||||
debug!(received_header = ?read_header_buf,"Received protocol header");
|
||||
|
||||
let protocol = &read_header_buf[0..4];
|
||||
let version = &read_header_buf[5..8];
|
||||
|
||||
if protocol != AMQP_PROTOCOL {
|
||||
self.stream
|
||||
.write_all(OWN_PROTOCOL_HEADER)
|
||||
.await
|
||||
.context("write protocol header")?;
|
||||
debug!(?protocol, "Version negotiation failed");
|
||||
return Err(ProtocolError::ProtocolNegotiationFailed.into());
|
||||
}
|
||||
|
||||
if &read_header_buf[0..5] == b"AMQP\0" && version == SUPPORTED_PROTOCOL_VERSION {
|
||||
debug!(?version, "Version negotiation successful");
|
||||
Ok(())
|
||||
} else {
|
||||
self.stream
|
||||
.write_all(OWN_PROTOCOL_HEADER)
|
||||
.await
|
||||
.context("write protocol header")?;
|
||||
debug!(?version, expected_version = ?SUPPORTED_PROTOCOL_VERSION, "Version negotiation failed");
|
||||
Err(ProtocolError::ProtocolNegotiationFailed.into())
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self, reply_code: ReplyCode, reply_text: ReplyText) -> Result<()> {
|
||||
self.send_method(
|
||||
ChannelNum::zero(),
|
||||
&Method::ConnectionClose(ConnectionClose {
|
||||
reply_code,
|
||||
reply_text,
|
||||
class_id: 0, // todo: do this
|
||||
method_id: 0,
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
match self.recv_method().await {
|
||||
Ok(Method::ConnectionCloseOk(_)) => Ok(()),
|
||||
Ok(method) => {
|
||||
return Err(TransError::Other(anyhow!(
|
||||
"Received wrong method after closing, method: {method:?}"
|
||||
)));
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(TransError::Other(anyhow!(
|
||||
"Failed to receive Connection.CloseOk method after closing, err: {err}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TransportConnection {
|
||||
fn drop(&mut self) {
|
||||
self.global_con.close();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TransportChannel {
|
||||
fn drop(&mut self) {
|
||||
self.global_chan.close();
|
||||
}
|
||||
}
|
||||
|
||||
fn server_properties(host: SocketAddr) -> Table {
|
||||
fn ls(str: impl Into<Longstr>) -> FieldValue {
|
||||
FieldValue::LongString(str.into())
|
||||
}
|
||||
|
||||
let host_str = host.ip().to_string();
|
||||
HashMap::from([
|
||||
("host".to_owned(), ls(host_str)),
|
||||
("product".to_owned(), ls("no name yet")),
|
||||
("version".to_owned(), ls("0.1.0")),
|
||||
("platform".to_owned(), ls("microsoft linux")),
|
||||
("copyright".to_owned(), ls("MIT")),
|
||||
("information".to_owned(), ls("hello reader")),
|
||||
("uwu".to_owned(), ls("owo")),
|
||||
])
|
||||
}
|
||||
27
haesli_transport/src/error.rs
Normal file
27
haesli_transport/src/error.rs
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
use std::io::Error;
|
||||
|
||||
pub use haesli_core::error::{ConException, ProtocolError};
|
||||
|
||||
type StdResult<T, E> = std::result::Result<T, E>;
|
||||
|
||||
pub type Result<T> = StdResult<T, TransError>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TransError {
|
||||
#[error("{0}")]
|
||||
Protocol(#[from] ProtocolError),
|
||||
#[error("connection error: `{0}`")]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for TransError {
|
||||
fn from(err: Error) -> Self {
|
||||
Self::Other(err.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<haesli_core::error::ConException> for TransError {
|
||||
fn from(err: ConException) -> Self {
|
||||
Self::Protocol(ProtocolError::ConException(err))
|
||||
}
|
||||
}
|
||||
372
haesli_transport/src/frame.rs
Normal file
372
haesli_transport/src/frame.rs
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
use std::{
|
||||
fmt::{Debug, Formatter},
|
||||
num::NonZeroUsize,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use haesli_core::connection::{ChannelNum, ContentHeader};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tracing::trace;
|
||||
|
||||
use crate::error::{ConException, ProtocolError, Result};
|
||||
|
||||
const REQUIRED_FRAME_END: u8 = 0xCE;
|
||||
|
||||
mod frame_type {
|
||||
pub const METHOD: u8 = 1;
|
||||
pub const HEADER: u8 = 2;
|
||||
pub const BODY: u8 = 3;
|
||||
pub const HEARTBEAT: u8 = 8;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Frame {
|
||||
/// The type of the frame including its parsed metadata.
|
||||
pub kind: FrameType,
|
||||
pub channel: ChannelNum,
|
||||
/// Includes the whole payload, also including the metadata from each type.
|
||||
pub payload: Bytes,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum FrameType {
|
||||
Method = 1,
|
||||
Header = 2,
|
||||
Body = 3,
|
||||
Heartbeat = 8,
|
||||
}
|
||||
|
||||
mod content_header_parse {
|
||||
use haesli_core::{
|
||||
connection::ContentHeader,
|
||||
methods::{
|
||||
self,
|
||||
FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp},
|
||||
},
|
||||
};
|
||||
use nom::number::{
|
||||
complete::{u16, u64},
|
||||
Endianness::Big,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
error::TransError,
|
||||
methods::parse_helper::{octet, shortstr, table, timestamp},
|
||||
};
|
||||
|
||||
type IResult<'a, T> = nom::IResult<&'a [u8], T, TransError>;
|
||||
|
||||
pub fn basic_properties(flags: u16, input: &[u8]) -> IResult<'_, methods::Table> {
|
||||
macro_rules! parse_property {
|
||||
(if $flags:ident >> $n:literal, $parser:ident($input:ident)?, $map:ident.insert($name:expr, $ctor:path)) => {
|
||||
if (($flags >> $n) & 1) == 1 {
|
||||
let (input, value) = $parser($input)?;
|
||||
$map.insert(String::from($name), $ctor(value));
|
||||
input
|
||||
} else {
|
||||
$input
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let mut map = methods::Table::new();
|
||||
|
||||
let input = parse_property!(if flags >> 15, shortstr(input)?, map.insert("content-type", ShortString));
|
||||
let input = parse_property!(if flags >> 14, shortstr(input)?, map.insert("content-encoding", ShortString));
|
||||
let input =
|
||||
parse_property!(if flags >> 13, table(input)?, map.insert("headers", FieldTable));
|
||||
let input = parse_property!(if flags >> 12, octet(input)?, map.insert("delivery-mode", ShortShortUInt));
|
||||
let input =
|
||||
parse_property!(if flags >> 11, octet(input)?, map.insert("priority", ShortShortUInt));
|
||||
let input = parse_property!(if flags >> 10, shortstr(input)?, map.insert("correlation-id", ShortString));
|
||||
let input =
|
||||
parse_property!(if flags >> 9, shortstr(input)?, map.insert("reply-to", ShortString));
|
||||
let input =
|
||||
parse_property!(if flags >> 8, shortstr(input)?, map.insert("expiration", ShortString));
|
||||
let input =
|
||||
parse_property!(if flags >> 7, shortstr(input)?, map.insert("message-id", ShortString));
|
||||
let input =
|
||||
parse_property!(if flags >> 6, timestamp(input)?, map.insert("timestamp", Timestamp));
|
||||
let input =
|
||||
parse_property!(if flags >> 5, shortstr(input)?, map.insert("type", ShortString));
|
||||
let input =
|
||||
parse_property!(if flags >> 4, shortstr(input)?, map.insert("user-id", ShortString));
|
||||
let input =
|
||||
parse_property!(if flags >> 3, shortstr(input)?, map.insert("app-id", ShortString));
|
||||
let input =
|
||||
parse_property!(if flags >> 2, shortstr(input)?, map.insert("reserved", ShortString));
|
||||
|
||||
Ok((input, map))
|
||||
}
|
||||
|
||||
pub fn header(input: &[u8]) -> IResult<'_, ContentHeader> {
|
||||
let (input, class_id) = u16(Big)(input)?;
|
||||
let (input, weight) = u16(Big)(input)?;
|
||||
let (input, body_size) = u64(Big)(input)?;
|
||||
|
||||
// I do not quite understand this here. Apparently, there can be more than 15 flags?
|
||||
// But the Basic class only specifies 15, so idk. Don't care about this for now
|
||||
// Todo: But probably later.
|
||||
let (input, property_flags) = u16(Big)(input)?;
|
||||
let (input, property_fields) = basic_properties(property_flags, input)?;
|
||||
|
||||
Ok((
|
||||
input,
|
||||
ContentHeader {
|
||||
class_id,
|
||||
weight,
|
||||
body_size,
|
||||
property_fields,
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_content_header(input: &[u8]) -> Result<ContentHeader> {
|
||||
match content_header_parse::header(input) {
|
||||
Ok(([], header)) => Ok(header),
|
||||
Ok((_, _)) => {
|
||||
Err(ConException::SyntaxError(vec!["could not consume all input".to_owned()]).into())
|
||||
}
|
||||
Err(nom::Err::Incomplete(_)) => {
|
||||
Err(ConException::SyntaxError(vec!["there was not enough data".to_owned()]).into())
|
||||
}
|
||||
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
mod content_header_write {
|
||||
use std::io::Write;
|
||||
|
||||
use haesli_core::{
|
||||
connection::ContentHeader,
|
||||
methods::{
|
||||
FieldValue::{FieldTable, ShortShortUInt, ShortString, Timestamp},
|
||||
Table,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
error::Result,
|
||||
methods::write_helper::{longlong, octet, short, shortstr, table, timestamp},
|
||||
};
|
||||
|
||||
pub fn write_content_header<W: Write>(buf: &mut W, header: &ContentHeader) -> Result<()> {
|
||||
short(&header.class_id, buf)?;
|
||||
short(&header.weight, buf)?;
|
||||
longlong(&header.body_size, buf)?;
|
||||
|
||||
write_content_header_props(buf, &header.property_fields)
|
||||
}
|
||||
|
||||
pub fn write_content_header_props<W: Write>(writer: &mut W, header: &Table) -> Result<()> {
|
||||
let mut flags = 0_u16;
|
||||
// todo: don't allocate for no reason here
|
||||
let mut temp_buf = Vec::new();
|
||||
let buf = &mut temp_buf;
|
||||
|
||||
buf.extend_from_slice(&flags.to_be_bytes()); // placeholder
|
||||
|
||||
if let Some(ShortString(value)) = header.get("content-type") {
|
||||
flags |= 1 << 15;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
if let Some(ShortString(value)) = header.get("content-encoding") {
|
||||
flags |= 1 << 14;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
if let Some(FieldTable(value)) = header.get("headers") {
|
||||
flags |= 1 << 13;
|
||||
table(value, buf)?;
|
||||
}
|
||||
if let Some(ShortShortUInt(value)) = header.get("delivery-mode") {
|
||||
flags |= 1 << 12;
|
||||
octet(value, buf)?;
|
||||
}
|
||||
if let Some(ShortShortUInt(value)) = header.get("priority") {
|
||||
flags |= 1 << 11;
|
||||
octet(value, buf)?;
|
||||
}
|
||||
if let Some(ShortString(value)) = header.get("correlation-id") {
|
||||
flags |= 1 << 10;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
if let Some(ShortString(value)) = header.get("reply-to") {
|
||||
flags |= 1 << 9;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
if let Some(ShortString(value)) = header.get("expiration") {
|
||||
flags |= 1 << 8;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
if let Some(ShortString(value)) = header.get("message-id") {
|
||||
flags |= 1 << 7;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
if let Some(Timestamp(value)) = header.get("timestamp") {
|
||||
flags |= 1 << 6;
|
||||
timestamp(value, buf)?;
|
||||
}
|
||||
if let Some(ShortString(value)) = header.get("type") {
|
||||
flags |= 1 << 5;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
if let Some(ShortString(value)) = header.get("user-id") {
|
||||
flags |= 1 << 4;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
if let Some(ShortString(value)) = header.get("app-id") {
|
||||
flags |= 1 << 3;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
if let Some(ShortString(value)) = header.get("reserved") {
|
||||
flags |= 1 << 2;
|
||||
shortstr(value, buf)?;
|
||||
}
|
||||
|
||||
let [a, b] = flags.to_be_bytes();
|
||||
buf[0] = a;
|
||||
buf[1] = b;
|
||||
|
||||
writer.write_all(&temp_buf)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_content_header(buf: &mut Vec<u8>, content_header: &ContentHeader) -> Result<()> {
|
||||
content_header_write::write_content_header(buf, content_header)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct MaxFrameSize(Option<NonZeroUsize>);
|
||||
|
||||
impl MaxFrameSize {
|
||||
pub const fn new(size: usize) -> Self {
|
||||
Self(NonZeroUsize::new(size))
|
||||
}
|
||||
|
||||
pub fn as_usize(&self) -> usize {
|
||||
self.0.map(NonZeroUsize::get).unwrap_or(usize::MAX)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for MaxFrameSize {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(w), level = "trace")]
|
||||
pub async fn write_frame<W>(
|
||||
mut w: W,
|
||||
kind: FrameType,
|
||||
channel: ChannelNum,
|
||||
payload: &[u8],
|
||||
) -> Result<()>
|
||||
where
|
||||
W: AsyncWriteExt + Unpin + Send,
|
||||
{
|
||||
w.write_u8(kind as u8).await?;
|
||||
w.write_u16(channel.num()).await?;
|
||||
w.write_u32(u32::try_from(payload.len()).context("frame size too big")?)
|
||||
.await?;
|
||||
w.write_all(payload).await?;
|
||||
w.write_u8(REQUIRED_FRAME_END).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn read_frame<R>(r: &mut R, max_frame_size: MaxFrameSize) -> Result<Frame>
|
||||
where
|
||||
R: AsyncReadExt + Unpin + Send,
|
||||
{
|
||||
let kind = r.read_u8().await.context("read type")?;
|
||||
let channel = r.read_u16().await.context("read channel")?;
|
||||
let channel = ChannelNum::new(channel);
|
||||
let size = r.read_u32().await.context("read size")?;
|
||||
|
||||
let mut payload = vec![0; size.try_into().unwrap()];
|
||||
r.read_exact(&mut payload).await.context("read payload")?;
|
||||
|
||||
let frame_end = r.read_u8().await.context("read frame end")?;
|
||||
|
||||
if frame_end != REQUIRED_FRAME_END {
|
||||
return Err(ProtocolError::Fatal.into());
|
||||
}
|
||||
|
||||
if payload.len() > max_frame_size.as_usize() {
|
||||
return Err(ConException::FrameError.into());
|
||||
}
|
||||
|
||||
let kind = parse_frame_type(kind, channel)?;
|
||||
|
||||
let frame = Frame {
|
||||
kind,
|
||||
channel,
|
||||
payload: payload.into(),
|
||||
};
|
||||
|
||||
trace!(?frame, "Received frame");
|
||||
|
||||
Ok(frame)
|
||||
}
|
||||
|
||||
fn parse_frame_type(kind: u8, channel: ChannelNum) -> Result<FrameType> {
|
||||
match kind {
|
||||
frame_type::METHOD => Ok(FrameType::Method),
|
||||
frame_type::HEADER => Ok(FrameType::Header),
|
||||
frame_type::BODY => Ok(FrameType::Body),
|
||||
frame_type::HEARTBEAT => {
|
||||
if channel.is_zero() {
|
||||
Ok(FrameType::Heartbeat)
|
||||
} else {
|
||||
Err(ProtocolError::ConException(ConException::FrameError).into())
|
||||
}
|
||||
}
|
||||
_ => Err(ConException::FrameError.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::Bytes;
|
||||
|
||||
use crate::frame::{ChannelNum, Frame, FrameType, MaxFrameSize};
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_small_body() {
|
||||
let mut bytes: &[u8] = &[
|
||||
/*type*/
|
||||
1,
|
||||
/*channel*/
|
||||
0,
|
||||
0,
|
||||
/*size*/
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
3,
|
||||
/*payload*/
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
/*frame-end*/
|
||||
super::REQUIRED_FRAME_END,
|
||||
];
|
||||
|
||||
let frame = super::read_frame(&mut bytes, MaxFrameSize::new(10000))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
frame,
|
||||
Frame {
|
||||
kind: FrameType::Method,
|
||||
channel: ChannelNum::new(0),
|
||||
payload: Bytes::from_static(&[1, 2, 3]),
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
104
haesli_transport/src/lib.rs
Normal file
104
haesli_transport/src/lib.rs
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
#![warn(rust_2018_idioms)]
|
||||
|
||||
mod connection;
|
||||
mod error;
|
||||
mod frame;
|
||||
mod methods;
|
||||
mod sasl;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// TODO: handle big types
|
||||
|
||||
use std::{future::Future, net::SocketAddr};
|
||||
|
||||
use anyhow::Context;
|
||||
use haesli_core::{connection::ConnectionEvent, queue::QueueEvent, GlobalData};
|
||||
use tokio::{net, net::TcpStream, select};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use crate::connection::TransportConnection;
|
||||
|
||||
pub async fn do_thing_i_guess(
|
||||
global_data: GlobalData,
|
||||
terminate: impl Future + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
select! {
|
||||
res = accept_cons(global_data.clone()) => {
|
||||
res
|
||||
}
|
||||
_ = terminate => {
|
||||
handle_shutdown(global_data).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn accept_cons(global_data: GlobalData) -> anyhow::Result<()> {
|
||||
info!("Binding TCP listener...");
|
||||
let listener = net::TcpListener::bind(("127.0.0.1", 5672)).await?;
|
||||
info!(addr = ?listener.local_addr()?, "Successfully bound TCP listener");
|
||||
|
||||
loop {
|
||||
let connection = listener.accept().await?;
|
||||
handle_con(global_data.clone(), connection);
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_con(global_data: GlobalData, connection: (TcpStream, SocketAddr)) {
|
||||
let (stream, peer_addr) = connection;
|
||||
let id = rand::random();
|
||||
|
||||
info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection");
|
||||
let span = info_span!("client-connection", %id);
|
||||
|
||||
let (method_send, method_recv) = tokio::sync::mpsc::channel(10);
|
||||
|
||||
let connection_handle = haesli_core::connection::ConnectionInner::new(
|
||||
id,
|
||||
peer_addr,
|
||||
global_data.clone(),
|
||||
method_send.clone(),
|
||||
);
|
||||
|
||||
let mut global_data_guard = global_data.lock();
|
||||
global_data_guard
|
||||
.connections
|
||||
.insert(id, connection_handle.clone());
|
||||
|
||||
let connection = TransportConnection::new(
|
||||
id,
|
||||
stream,
|
||||
connection_handle,
|
||||
global_data.clone(),
|
||||
method_send,
|
||||
method_recv,
|
||||
);
|
||||
|
||||
tokio::spawn(connection.start_connection_processing().instrument(span));
|
||||
}
|
||||
|
||||
async fn handle_shutdown(global_data: GlobalData) -> anyhow::Result<()> {
|
||||
info!("Shutting down...");
|
||||
|
||||
let lock = global_data.lock();
|
||||
|
||||
for con in lock.connections.values() {
|
||||
con.event_sender
|
||||
.try_send(ConnectionEvent::Shutdown)
|
||||
.context("failed to stop connection")?;
|
||||
}
|
||||
|
||||
for queue in lock.queues.values() {
|
||||
queue
|
||||
.event_send
|
||||
.try_send(QueueEvent::Shutdown)
|
||||
.context("failed to stop queue worker")?;
|
||||
}
|
||||
|
||||
// todo: here we should wait for everything to close
|
||||
// https://github.com/tokio-rs/mini-redis/blob/4b4ecf0310e6bca43d336dde90a06d9dcad00d6c/src/server.rs#L51
|
||||
|
||||
info!("Finished shutdown");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
1554
haesli_transport/src/methods/generated.rs
generated
Normal file
1554
haesli_transport/src/methods/generated.rs
generated
Normal file
File diff suppressed because it is too large
Load diff
100
haesli_transport/src/methods/mod.rs
Normal file
100
haesli_transport/src/methods/mod.rs
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
use haesli_core::{
|
||||
error::ConException,
|
||||
methods::{FieldValue, Method, Table},
|
||||
};
|
||||
use rand::Rng;
|
||||
|
||||
use crate::error::TransError;
|
||||
|
||||
mod generated;
|
||||
pub mod parse_helper;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
pub mod write_helper;
|
||||
|
||||
pub use generated::*;
|
||||
|
||||
/// Parses the payload of a method frame into the method
|
||||
pub fn parse_method(payload: &[u8]) -> Result<Method, TransError> {
|
||||
let nom_result = generated::parse::parse_method(payload);
|
||||
|
||||
match nom_result {
|
||||
Ok(([], method)) => Ok(method),
|
||||
Ok((_, _)) => {
|
||||
Err(ConException::SyntaxError(vec!["could not consume all input".to_owned()]).into())
|
||||
}
|
||||
Err(nom::Err::Incomplete(_)) => {
|
||||
Err(ConException::SyntaxError(vec!["there was not enough data".to_owned()]).into())
|
||||
}
|
||||
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
/// Allows the creation of a random instance of that type
|
||||
pub trait RandomMethod<R: Rng> {
|
||||
fn random(rng: &mut R) -> Self;
|
||||
}
|
||||
|
||||
impl<R: Rng> RandomMethod<R> for String {
|
||||
fn random(rng: &mut R) -> Self {
|
||||
let n = rng.gen_range(0_u16..9999);
|
||||
format!("string{n}")
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Rng, T: RandomMethod<R>> RandomMethod<R> for Vec<T> {
|
||||
fn random(rng: &mut R) -> Self {
|
||||
let len = rng.gen_range(1_usize..10);
|
||||
let mut vec = Vec::with_capacity(len);
|
||||
(0..len).for_each(|_| vec.push(RandomMethod::random(rng)));
|
||||
vec
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! rand_random_method {
|
||||
($($ty:ty),+) => {
|
||||
$(
|
||||
impl<R: Rng> RandomMethod<R> for $ty {
|
||||
fn random(rng: &mut R) -> Self {
|
||||
rng.gen()
|
||||
}
|
||||
})+
|
||||
};
|
||||
}
|
||||
|
||||
rand_random_method!(bool, u8, i8, u16, i16, u32, i32, u64, i64, f32, f64);
|
||||
|
||||
impl<R: Rng> RandomMethod<R> for Table {
|
||||
fn random(rng: &mut R) -> Self {
|
||||
let len = rng.gen_range(0..3);
|
||||
(0..len)
|
||||
.map(|_| (String::random(rng), FieldValue::random(rng)))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Rng> RandomMethod<R> for FieldValue {
|
||||
fn random(rng: &mut R) -> Self {
|
||||
let index = rng.gen_range(0_u32..17);
|
||||
match index {
|
||||
0 => Self::Boolean(RandomMethod::random(rng)),
|
||||
1 => Self::ShortShortInt(RandomMethod::random(rng)),
|
||||
2 => Self::ShortShortUInt(RandomMethod::random(rng)),
|
||||
3 => Self::ShortInt(RandomMethod::random(rng)),
|
||||
4 => Self::ShortUInt(RandomMethod::random(rng)),
|
||||
5 => Self::LongInt(RandomMethod::random(rng)),
|
||||
6 => Self::LongUInt(RandomMethod::random(rng)),
|
||||
7 => Self::LongLongInt(RandomMethod::random(rng)),
|
||||
8 => Self::LongLongUInt(RandomMethod::random(rng)),
|
||||
9 => Self::Float(RandomMethod::random(rng)),
|
||||
10 => Self::Double(RandomMethod::random(rng)),
|
||||
11 => Self::ShortString(RandomMethod::random(rng)),
|
||||
12 => Self::LongString(RandomMethod::random(rng)),
|
||||
13 => Self::FieldArray(RandomMethod::random(rng)),
|
||||
14 => Self::Timestamp(RandomMethod::random(rng)),
|
||||
15 => Self::FieldTable(RandomMethod::random(rng)),
|
||||
16 => Self::Void,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
243
haesli_transport/src/methods/parse_helper.rs
Normal file
243
haesli_transport/src/methods/parse_helper.rs
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
use haesli_core::{
|
||||
error::{ConException, ProtocolError},
|
||||
methods::{
|
||||
Bit, FieldValue, Long, Longlong, Longstr, Octet, Short, Shortstr, Table, TableFieldName,
|
||||
Timestamp,
|
||||
},
|
||||
};
|
||||
use nom::{
|
||||
branch::alt,
|
||||
bytes::complete::{tag, take},
|
||||
error::ErrorKind,
|
||||
multi::{count, many0},
|
||||
number::{
|
||||
complete::{f32, f64, i16, i32, i64, i8, u16, u32, u64, u8},
|
||||
Endianness::Big,
|
||||
},
|
||||
Err,
|
||||
};
|
||||
|
||||
use crate::{error::TransError, methods::generated::parse::IResult};
|
||||
|
||||
impl<T> nom::error::ParseError<T> for TransError {
|
||||
fn from_error_kind(_input: T, _kind: ErrorKind) -> Self {
|
||||
ConException::SyntaxError(vec![]).into()
|
||||
}
|
||||
|
||||
fn append(_input: T, _kind: ErrorKind, other: Self) -> Self {
|
||||
other
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fail_err<S: Into<String>>(msg: S) -> impl FnOnce(Err<TransError>) -> Err<TransError> {
|
||||
move |err| {
|
||||
let msg = msg.into();
|
||||
let stack = match err {
|
||||
Err::Error(e) | Err::Failure(e) => match e {
|
||||
TransError::Protocol(ProtocolError::ConException(ConException::SyntaxError(
|
||||
mut stack,
|
||||
))) => {
|
||||
stack.push(msg);
|
||||
stack
|
||||
}
|
||||
_ => vec![msg],
|
||||
},
|
||||
Err::Incomplete(_) => vec![msg],
|
||||
};
|
||||
Err::Failure(ConException::SyntaxError(stack).into())
|
||||
}
|
||||
}
|
||||
pub fn other_fail<E, S: Into<String>>(msg: S) -> impl FnOnce(E) -> Err<TransError> {
|
||||
move |_| Err::Failure(ConException::SyntaxError(vec![msg.into()]).into())
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! fail {
|
||||
($cause:expr) => {
|
||||
return Err(nom::Err::Failure(
|
||||
::haesli_core::error::ProtocolError::ConException(
|
||||
::haesli_core::error::ConException::SyntaxError(vec![String::from($cause)]),
|
||||
)
|
||||
.into(),
|
||||
))
|
||||
};
|
||||
}
|
||||
|
||||
pub use fail;
|
||||
|
||||
pub fn octet(input: &[u8]) -> IResult<'_, Octet> {
|
||||
u8(input)
|
||||
}
|
||||
|
||||
pub fn short(input: &[u8]) -> IResult<'_, Short> {
|
||||
u16(Big)(input)
|
||||
}
|
||||
|
||||
pub fn long(input: &[u8]) -> IResult<'_, Long> {
|
||||
u32(Big)(input)
|
||||
}
|
||||
|
||||
pub fn longlong(input: &[u8]) -> IResult<'_, Longlong> {
|
||||
u64(Big)(input)
|
||||
}
|
||||
|
||||
pub fn bit(input: &[u8], amount: usize) -> IResult<'_, Vec<Bit>> {
|
||||
let octets = (amount + 7) / 8;
|
||||
let (input, bytes) = take(octets)(input)?;
|
||||
|
||||
let mut vec = Vec::new();
|
||||
let mut byte_index = 0;
|
||||
let mut total_index = 0;
|
||||
|
||||
for &byte in bytes {
|
||||
while byte_index < 8 && total_index < amount {
|
||||
let next_bit = 1 & (byte >> byte_index);
|
||||
let bit_bool = match next_bit {
|
||||
0 => false,
|
||||
1 => true,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
vec.push(bit_bool);
|
||||
byte_index += 1;
|
||||
total_index += 1;
|
||||
}
|
||||
byte_index = 0;
|
||||
}
|
||||
|
||||
Ok((input, vec))
|
||||
}
|
||||
|
||||
pub fn shortstr(input: &[u8]) -> IResult<'_, Shortstr> {
|
||||
let (input, len) = u8(input)?;
|
||||
let (input, str_data) = take(usize::from(len))(input)?;
|
||||
let data = String::from_utf8(str_data.into()).map_err(other_fail("shortstr"))?;
|
||||
Ok((input, data))
|
||||
}
|
||||
|
||||
pub fn longstr(input: &[u8]) -> IResult<'_, Longstr> {
|
||||
let (input, len) = u32(Big)(input)?;
|
||||
let (input, str_data) = take(usize::try_from(len).unwrap())(input)?;
|
||||
let data = str_data.into();
|
||||
Ok((input, data))
|
||||
}
|
||||
|
||||
pub fn timestamp(input: &[u8]) -> IResult<'_, Timestamp> {
|
||||
u64(Big)(input)
|
||||
}
|
||||
|
||||
pub fn table(input: &[u8]) -> IResult<'_, Table> {
|
||||
let (input, size) = u32(Big)(input)?;
|
||||
let (table_input, rest_input) = input.split_at(size.try_into().unwrap());
|
||||
|
||||
let (input, values) = many0(table_value_pair)(table_input)?;
|
||||
|
||||
if !input.is_empty() {
|
||||
fail!(format!(
|
||||
"table longer than expected, expected = {size}, remaining = {}",
|
||||
input.len()
|
||||
));
|
||||
}
|
||||
|
||||
let table = values.into_iter().collect();
|
||||
Ok((rest_input, table))
|
||||
}
|
||||
|
||||
fn table_value_pair(input: &[u8]) -> IResult<'_, (TableFieldName, FieldValue)> {
|
||||
let (input, field_name) = shortstr(input)?;
|
||||
let (input, field_value) =
|
||||
field_value(input).map_err(fail_err(format!("field {field_name}")))?;
|
||||
Ok((input, (field_name, field_value)))
|
||||
}
|
||||
|
||||
fn field_value(input: &[u8]) -> IResult<'_, FieldValue> {
|
||||
type R<'a> = IResult<'a, FieldValue>;
|
||||
|
||||
fn boolean(input: &[u8]) -> R<'_> {
|
||||
let (input, _) = tag(b"t")(input)?;
|
||||
let (input, bool_byte) = u8(input)?;
|
||||
match bool_byte {
|
||||
0 => Ok((input, FieldValue::Boolean(false))),
|
||||
1 => Ok((input, FieldValue::Boolean(true))),
|
||||
value => fail!(format!("invalid bool value {value}")),
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! number {
|
||||
($tag:literal, $name:ident, $comb:expr, $value:ident, $r:path) => {
|
||||
fn $name(input: &[u8]) -> $r {
|
||||
let (input, _) = tag($tag)(input)?;
|
||||
$comb(input).map(|(input, int)| (input, FieldValue::$value(int)))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
number!(b"b", short_short_int, i8, ShortShortInt, R<'_>);
|
||||
number!(b"B", short_short_uint, u8, ShortShortUInt, R<'_>);
|
||||
number!(b"U", short_int, i16(Big), ShortInt, R<'_>);
|
||||
number!(b"u", short_uint, u16(Big), ShortUInt, R<'_>);
|
||||
number!(b"I", long_int, i32(Big), LongInt, R<'_>);
|
||||
number!(b"i", long_uint, u32(Big), LongUInt, R<'_>);
|
||||
number!(b"L", long_long_int, i64(Big), LongLongInt, R<'_>);
|
||||
number!(b"l", long_long_uint, u64(Big), LongLongUInt, R<'_>);
|
||||
number!(b"f", float, f32(Big), Float, R<'_>);
|
||||
number!(b"d", double, f64(Big), Double, R<'_>);
|
||||
|
||||
fn decimal(input: &[u8]) -> R<'_> {
|
||||
let (input, _) = tag("D")(input)?;
|
||||
let (input, scale) = u8(input)?;
|
||||
let (input, value) = u32(Big)(input)?;
|
||||
Ok((input, FieldValue::DecimalValue(scale, value)))
|
||||
}
|
||||
|
||||
fn short_str(input: &[u8]) -> R<'_> {
|
||||
let (input, _) = tag("s")(input)?;
|
||||
let (input, str) = shortstr(input)?;
|
||||
Ok((input, FieldValue::ShortString(str)))
|
||||
}
|
||||
|
||||
fn long_str(input: &[u8]) -> R<'_> {
|
||||
let (input, _) = tag("S")(input)?;
|
||||
let (input, str) = longstr(input)?;
|
||||
Ok((input, FieldValue::LongString(str)))
|
||||
}
|
||||
|
||||
fn field_array(input: &[u8]) -> R<'_> {
|
||||
let (input, _) = tag("A")(input)?;
|
||||
// todo is it i32?
|
||||
let (input, len) = u32(Big)(input)?;
|
||||
count(field_value, usize::try_from(len).unwrap())(input)
|
||||
.map(|(input, value)| (input, FieldValue::FieldArray(value)))
|
||||
}
|
||||
|
||||
number!(b"T", timestamp, u64(Big), Timestamp, R<'_>);
|
||||
|
||||
fn field_table(input: &[u8]) -> R<'_> {
|
||||
let (input, _) = tag("F")(input)?;
|
||||
table(input).map(|(input, value)| (input, FieldValue::FieldTable(value)))
|
||||
}
|
||||
|
||||
fn void(input: &[u8]) -> R<'_> {
|
||||
tag("V")(input).map(|(input, _)| (input, FieldValue::Void))
|
||||
}
|
||||
|
||||
alt((
|
||||
boolean,
|
||||
short_short_int,
|
||||
short_short_uint,
|
||||
short_int,
|
||||
short_uint,
|
||||
long_int,
|
||||
long_uint,
|
||||
long_long_int,
|
||||
long_long_uint,
|
||||
float,
|
||||
double,
|
||||
decimal,
|
||||
short_str,
|
||||
long_str,
|
||||
field_array,
|
||||
timestamp,
|
||||
field_table,
|
||||
void,
|
||||
))(input)
|
||||
}
|
||||
85
haesli_transport/src/methods/tests.rs
Normal file
85
haesli_transport/src/methods/tests.rs
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
// create random methods to test the ser/de code together. if they diverge, we have a bug
|
||||
// this is not perfect, if they both have the same bug it won't be found, but that's an ok tradeoff
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rand::SeedableRng;
|
||||
|
||||
use crate::methods::{FieldValue, Method, RandomMethod};
|
||||
|
||||
#[test]
|
||||
fn pack_few_bits() {
|
||||
let bits = [true, false, true];
|
||||
|
||||
let mut buffer = [0u8; 2];
|
||||
super::write_helper::bit(&bits, &mut buffer.as_mut_slice()).unwrap();
|
||||
|
||||
let (_, parsed_bits) = super::parse_helper::bit(&buffer, 3).unwrap();
|
||||
assert_eq!(bits.as_slice(), parsed_bits.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pack_many_bits() {
|
||||
let bits = [
|
||||
/* first 8 */
|
||||
true, true, true, true, false, false, false, false, /* second 4 */
|
||||
true, false, true, true,
|
||||
];
|
||||
let mut buffer = [0u8; 2];
|
||||
super::write_helper::bit(&bits, &mut buffer.as_mut_slice()).unwrap();
|
||||
|
||||
let (_, parsed_bits) = super::parse_helper::bit(&buffer, 12).unwrap();
|
||||
assert_eq!(bits.as_slice(), parsed_bits.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn random_ser_de() {
|
||||
const ITERATIONS: usize = 10000;
|
||||
let mut rng = rand::rngs::StdRng::from_seed([0; 32]);
|
||||
|
||||
for _ in 0..ITERATIONS {
|
||||
let method = Method::random(&mut rng);
|
||||
let mut bytes = Vec::new();
|
||||
|
||||
if let Err(err) = super::write::write_method(&method, &mut bytes) {
|
||||
eprintln!("{method:#?}");
|
||||
eprintln!("{err:?}");
|
||||
panic!("Failed to serialize");
|
||||
}
|
||||
|
||||
match super::parse_method(&bytes) {
|
||||
Ok(parsed) => {
|
||||
if method != parsed {
|
||||
eprintln!("{method:#?}");
|
||||
eprintln!("{bytes:?}");
|
||||
eprintln!("{parsed:?}");
|
||||
panic!("Not equal!");
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("{method:#?}");
|
||||
eprintln!("{bytes:?}");
|
||||
eprintln!("{err:?}");
|
||||
panic!("Failed to deserialize");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nested_table() {
|
||||
let table = HashMap::from([(
|
||||
"A".to_owned(),
|
||||
FieldValue::FieldTable(HashMap::from([("B".to_owned(), FieldValue::Boolean(true))])),
|
||||
)]);
|
||||
eprintln!("{table:?}");
|
||||
|
||||
let mut bytes = Vec::new();
|
||||
crate::methods::write_helper::table(&table, &mut bytes).unwrap();
|
||||
eprintln!("{bytes:?}");
|
||||
|
||||
let (rest, parsed_table) = crate::methods::parse_helper::table(&bytes).unwrap();
|
||||
|
||||
assert!(rest.is_empty());
|
||||
assert_eq!(table, parsed_table);
|
||||
}
|
||||
199
haesli_transport/src/methods/write_helper.rs
Normal file
199
haesli_transport/src/methods/write_helper.rs
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
use std::io::Write;
|
||||
|
||||
use anyhow::Context;
|
||||
use haesli_core::methods::{
|
||||
Bit, Long, Longlong, Longstr, Octet, Short, Shortstr, Table, Timestamp,
|
||||
};
|
||||
|
||||
use crate::{error::TransError, methods::FieldValue};
|
||||
|
||||
pub fn octet<W: Write>(value: &Octet, writer: &mut W) -> Result<(), TransError> {
|
||||
writer.write_all(&[*value])?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn short<W: Write>(value: &Short, writer: &mut W) -> Result<(), TransError> {
|
||||
writer.write_all(&value.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn long<W: Write>(value: &Long, writer: &mut W) -> Result<(), TransError> {
|
||||
writer.write_all(&value.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn longlong<W: Write>(value: &Longlong, writer: &mut W) -> Result<(), TransError> {
|
||||
writer.write_all(&value.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn bit<W: Write>(value: &[Bit], writer: &mut W) -> Result<(), TransError> {
|
||||
// accumulate bits into bytes, starting from the least significant bit in each byte
|
||||
|
||||
// how many bits have already been packed into `current_buf`
|
||||
let mut already_filled = 0;
|
||||
let mut current_buf = 0u8;
|
||||
|
||||
for &bit in value {
|
||||
if already_filled >= 8 {
|
||||
writer.write_all(&[current_buf])?;
|
||||
current_buf = 0;
|
||||
already_filled = 0;
|
||||
}
|
||||
|
||||
let new_bit = (u8::from(bit)) << already_filled;
|
||||
current_buf |= new_bit;
|
||||
already_filled += 1;
|
||||
}
|
||||
|
||||
if already_filled > 0 {
|
||||
writer.write_all(&[current_buf])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn shortstr<W: Write>(value: &Shortstr, writer: &mut W) -> Result<(), TransError> {
|
||||
let len = u8::try_from(value.len()).context("shortstr too long")?;
|
||||
writer.write_all(&[len])?;
|
||||
writer.write_all(value.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn longstr<W: Write>(value: &Longstr, writer: &mut W) -> Result<(), TransError> {
|
||||
let len = u32::try_from(value.len()).context("longstr too long")?;
|
||||
writer.write_all(&len.to_be_bytes())?;
|
||||
writer.write_all(value.as_slice())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// this appears to be unused right now, but it could be used in `Basic` things?
|
||||
#[allow(dead_code)]
|
||||
pub fn timestamp<W: Write>(value: &Timestamp, writer: &mut W) -> Result<(), TransError> {
|
||||
writer.write_all(&value.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn table<W: Write>(table: &Table, writer: &mut W) -> Result<(), TransError> {
|
||||
let mut table_buf = Vec::new();
|
||||
|
||||
for (field_name, value) in table {
|
||||
shortstr(field_name, &mut table_buf)?;
|
||||
field_value(value, &mut table_buf)?;
|
||||
}
|
||||
|
||||
let len = u32::try_from(table_buf.len()).context("table too big")?;
|
||||
writer.write_all(&len.to_be_bytes())?;
|
||||
writer.write_all(&table_buf)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn field_value<W: Write>(value: &FieldValue, writer: &mut W) -> Result<(), TransError> {
|
||||
match value {
|
||||
FieldValue::Boolean(bool) => {
|
||||
writer.write_all(&[b't', u8::from(*bool)])?;
|
||||
}
|
||||
FieldValue::ShortShortInt(int) => {
|
||||
writer.write_all(b"b")?;
|
||||
writer.write_all(&int.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::ShortShortUInt(int) => {
|
||||
writer.write_all(&[b'B', *int])?;
|
||||
}
|
||||
FieldValue::ShortInt(int) => {
|
||||
writer.write_all(b"U")?;
|
||||
writer.write_all(&int.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::ShortUInt(int) => {
|
||||
writer.write_all(b"u")?;
|
||||
writer.write_all(&int.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::LongInt(int) => {
|
||||
writer.write_all(b"I")?;
|
||||
writer.write_all(&int.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::LongUInt(int) => {
|
||||
writer.write_all(b"i")?;
|
||||
writer.write_all(&int.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::LongLongInt(int) => {
|
||||
writer.write_all(b"L")?;
|
||||
writer.write_all(&int.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::LongLongUInt(int) => {
|
||||
writer.write_all(b"l")?;
|
||||
writer.write_all(&int.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::Float(float) => {
|
||||
writer.write_all(b"f")?;
|
||||
writer.write_all(&float.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::Double(float) => {
|
||||
writer.write_all(b"d")?;
|
||||
writer.write_all(&float.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::DecimalValue(scale, long) => {
|
||||
writer.write_all(&[b'D', *scale])?;
|
||||
writer.write_all(&long.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::ShortString(str) => {
|
||||
writer.write_all(b"s")?;
|
||||
shortstr(str, writer)?;
|
||||
}
|
||||
FieldValue::LongString(str) => {
|
||||
writer.write_all(b"S")?;
|
||||
longstr(str, writer)?;
|
||||
}
|
||||
FieldValue::FieldArray(array) => {
|
||||
writer.write_all(b"A")?;
|
||||
let len = u32::try_from(array.len()).context("array too long")?;
|
||||
writer.write_all(&len.to_be_bytes())?;
|
||||
|
||||
for element in array {
|
||||
field_value(element, writer)?;
|
||||
}
|
||||
}
|
||||
FieldValue::Timestamp(time) => {
|
||||
writer.write_all(b"T")?;
|
||||
writer.write_all(&time.to_be_bytes())?;
|
||||
}
|
||||
FieldValue::FieldTable(value) => {
|
||||
writer.write_all(b"F")?;
|
||||
table(value, writer)?;
|
||||
}
|
||||
FieldValue::Void => {
|
||||
writer.write_all(b"V")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn pack_few_bits() {
|
||||
let bits = [true, false, true];
|
||||
|
||||
let mut buffer = [0u8; 1];
|
||||
super::bit(&bits, &mut buffer.as_mut_slice()).unwrap();
|
||||
|
||||
assert_eq!(buffer, [0b00000101])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pack_many_bits() {
|
||||
let bits = [
|
||||
/* first 8 */
|
||||
true, true, true, true, false, false, false, false, /* second 4 */
|
||||
true, false, true, true,
|
||||
];
|
||||
|
||||
let mut buffer = [0u8; 2];
|
||||
super::bit(&bits, &mut buffer.as_mut_slice()).unwrap();
|
||||
|
||||
assert_eq!(buffer, [0b00001111, 0b00001101]);
|
||||
}
|
||||
}
|
||||
29
haesli_transport/src/sasl.rs
Normal file
29
haesli_transport/src/sasl.rs
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
//! (Very) partial implementation of SASL Authentication (see [RFC 4422](https://datatracker.ietf.org/doc/html/rfc4422))
|
||||
//!
|
||||
//! Currently only supports PLAIN (see [RFC 4616](https://datatracker.ietf.org/doc/html/rfc4616))
|
||||
|
||||
use haesli_core::error::ConException;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
pub struct PlainUser {
|
||||
pub authorization_identity: String,
|
||||
pub authentication_identity: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
pub fn parse_sasl_plain_response(response: &[u8]) -> Result<PlainUser> {
|
||||
let mut parts = response
|
||||
.split(|&n| n == 0)
|
||||
.map(|bytes| String::from_utf8(bytes.into()).map_err(|_| ConException::Todo));
|
||||
|
||||
let authorization_identity = parts.next().ok_or(ConException::Todo)??;
|
||||
let authentication_identity = parts.next().ok_or(ConException::Todo)??;
|
||||
let password = parts.next().ok_or(ConException::Todo)??;
|
||||
|
||||
Ok(PlainUser {
|
||||
authorization_identity,
|
||||
authentication_identity,
|
||||
password,
|
||||
})
|
||||
}
|
||||
180
haesli_transport/src/tests.rs
Normal file
180
haesli_transport/src/tests.rs
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use haesli_core::{
|
||||
connection::ChannelNum,
|
||||
methods::{ConnectionStart, ConnectionStartOk, FieldValue, Method},
|
||||
};
|
||||
|
||||
use crate::{frame, frame::FrameType, methods};
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_start_ok_frame() {
|
||||
let mut payload = Vec::new();
|
||||
let method = Method::ConnectionStart(ConnectionStart {
|
||||
version_major: 0,
|
||||
version_minor: 9,
|
||||
server_properties: HashMap::from([(
|
||||
"product".to_owned(),
|
||||
FieldValue::LongString("no name yet".into()),
|
||||
)]),
|
||||
mechanisms: "PLAIN".into(),
|
||||
locales: "en_US".into(),
|
||||
});
|
||||
|
||||
methods::write::write_method(&method, &mut payload).unwrap();
|
||||
|
||||
let mut output = Vec::new();
|
||||
|
||||
frame::write_frame(&mut output, FrameType::Method, ChannelNum::zero(), &payload)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
#[rustfmt::skip]
|
||||
let expected = [
|
||||
/* type, octet, method */
|
||||
1u8,
|
||||
/* channel, short */
|
||||
0, 0,
|
||||
/* size, long */
|
||||
/* count all the bytes in the payload, 33 here */
|
||||
0, 0, 0, 52,
|
||||
/* payload */
|
||||
/* class-id, short, connection */
|
||||
0, 10,
|
||||
/* method-id, short, start */
|
||||
0, 10,
|
||||
/* version-major, octet */
|
||||
0,
|
||||
/* version-minor, octet */
|
||||
9,
|
||||
/* server-properties, table */
|
||||
/* table-size, long (actual byte size) */
|
||||
0, 0, 0, 24,
|
||||
/* table-items */
|
||||
/* name ("product"), shortstr */
|
||||
/* len (7) ; bytes */
|
||||
7, b'p', b'r', b'o', b'd', b'u', b'c', b't',
|
||||
/* value, a shortstr ("no name yet") here */
|
||||
/* tag (s) ; len (11) ; data */
|
||||
b'S', 0, 0, 0, 11, b'n', b'o', b' ', b'n', b'a', b'm', b'e', b' ', b'y', b'e', b't',
|
||||
/* mechanisms, longstr */
|
||||
/* str-len, long ; len 5 ; data ("PLAIN") */
|
||||
0, 0, 0, 5,
|
||||
b'P', b'L', b'A', b'I', b'N',
|
||||
/* locales, longstr */
|
||||
/* str-len, long ; len 5 ; data ("en_US") */
|
||||
0, 0, 0, 5,
|
||||
b'e', b'n', b'_', b'U', b'S',
|
||||
/* frame-end */
|
||||
0xCE,
|
||||
];
|
||||
|
||||
assert_eq!(expected.as_slice(), output.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_start_ok_payload() {
|
||||
#[rustfmt::skip]
|
||||
let raw_data = [
|
||||
/* Connection.Start-Ok */
|
||||
0, 10, 0, 11,
|
||||
/* field client-properties */
|
||||
/* byte size of the table */
|
||||
0, 0, 0, 254,
|
||||
/* first key of len 7, "product"*/
|
||||
7, 112, 114, 111, 100, 117, 99, 116,
|
||||
/* value is of type 83 ("S"), long-string */
|
||||
/* has length 26 "Pika Python Client Library" */
|
||||
83, 0, 0, 0, 26,
|
||||
80, 105, 107, 97, 32, 80, 121, 116, 104, 111, 110, 32, 67, 108, 105, 101, 110, 116, 32, 76, 105, 98, 114, 97, 114, 121,
|
||||
/* second key of len 8, "platform" */
|
||||
8, 112, 108, 97, 116, 102, 111, 114, 109,
|
||||
/* value is of type 83("S"), long-string */
|
||||
/* has length 13, "Python 3.8.10" */
|
||||
83, 0, 0, 0, 13,
|
||||
80, 121, 116, 104, 111, 110, 32, 51, 46, 56, 46, 49, 48,
|
||||
/* third key has len 12 "capabilities" */
|
||||
12, 99, 97, 112, 97, 98, 105, 108, 105, 116, 105, 101, 115,
|
||||
/* type is 70 F (table), with byte-len of 111 */
|
||||
70, 0, 0, 0, 111,
|
||||
/* first key has length 28, "authentication_failure_close" */
|
||||
28, 97, 117, 116, 104, 101, 110, 116, 105, 99, 97, 116, 105, 111, 110, 95, 102, 97, 105, 108, 117, 114, 101, 95, 99, 108, 111, 115, 101,
|
||||
/* value of type 116, "t", boolean, true */
|
||||
116, 1,
|
||||
/* second key has length 10, "basic.nack" */
|
||||
10, 98, 97, 115, 105, 99, 46, 110, 97, 99, 107,
|
||||
/* value of type 116, "t", boolean, true */
|
||||
116, 1,
|
||||
/* third key has length 18 "connection.blocked" */
|
||||
18, 99, 111, 110, 110, 101, 99, 116, 105, 111, 110, 46, 98, 108, 111, 99, 107, 101, 100,
|
||||
/* value of type 116, "t", boolean, true */
|
||||
116, 1,
|
||||
/* fourth key has length 22 "consumer_cancel_notify" */
|
||||
22, 99, 111, 110, 115, 117, 109, 101, 114, 95, 99, 97, 110, 99, 101, 108, 95, 110, 111, 116, 105, 102, 121,
|
||||
/* value of type 116, "t", boolean, true */
|
||||
116, 1,
|
||||
/* fifth key has length 18 "publisher_confirms" */
|
||||
18, 112, 117, 98, 108, 105, 115, 104, 101, 114, 95, 99, 111, 110, 102, 105, 114, 109, 115,
|
||||
/* value of type 116, "t", boolean, true */
|
||||
116, 1,
|
||||
/* sixth key has length 11 "information" */
|
||||
11, 105, 110, 102, 111, 114, 109, 97, 116, 105, 111, 110,
|
||||
/* value of type 83, "S" long-str ; len 24 ; data "See http://pika.rtfd.org" */
|
||||
83, 0, 0, 0, 24,
|
||||
83, 101, 101, 32, 104, 116, 116, 112, 58, 47, 47, 112, 105, 107, 97, 46, 114, 116, 102, 100, 46, 111, 114, 103,
|
||||
/* seventh key has length 7, "version" */
|
||||
7, 118, 101, 114, 115, 105, 111, 110,
|
||||
/* value of type 83, "S" long-str ; length 5 ; "1.1.0" */
|
||||
83, 0, 0, 0, 5,
|
||||
49, 46, 49, 46, 48,
|
||||
/* client-properties table ends here */
|
||||
/* field mechanism, length 5, "PLAIN" */
|
||||
5, 80, 76, 65, 73, 78,
|
||||
/* field response, longstr, length 7, "\x00admin\x00" */
|
||||
0, 0, 0, 7, 0, 97, 100, 109, 105, 110, 0,
|
||||
/* locale, shortstr, len 5 "en_US" */
|
||||
5, 101, 110, 95, 85, 83,
|
||||
];
|
||||
|
||||
let method = methods::parse_method(&raw_data).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
method,
|
||||
Method::ConnectionStartOk(ConnectionStartOk {
|
||||
client_properties: HashMap::from([
|
||||
(
|
||||
"product".to_owned(),
|
||||
FieldValue::LongString("Pika Python Client Library".into())
|
||||
),
|
||||
(
|
||||
"platform".to_owned(),
|
||||
FieldValue::LongString("Python 3.8.10".into())
|
||||
),
|
||||
(
|
||||
"capabilities".to_owned(),
|
||||
FieldValue::FieldTable(HashMap::from([
|
||||
(
|
||||
"authentication_failure_close".to_owned(),
|
||||
FieldValue::Boolean(true)
|
||||
),
|
||||
("basic.nack".to_owned(), FieldValue::Boolean(true)),
|
||||
("connection.blocked".to_owned(), FieldValue::Boolean(true)),
|
||||
(
|
||||
"consumer_cancel_notify".to_owned(),
|
||||
FieldValue::Boolean(true)
|
||||
),
|
||||
("publisher_confirms".to_owned(), FieldValue::Boolean(true)),
|
||||
]))
|
||||
),
|
||||
(
|
||||
"information".to_owned(),
|
||||
FieldValue::LongString("See http://pika.rtfd.org".into())
|
||||
),
|
||||
("version".to_owned(), FieldValue::LongString("1.1.0".into()))
|
||||
]),
|
||||
mechanism: "PLAIN".to_owned(),
|
||||
response: "\x00admin\x00".into(),
|
||||
locale: "en_US".to_owned()
|
||||
})
|
||||
);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue