mirror of
https://github.com/Noratrieb/haesli.git
synced 2026-01-16 12:45:04 +01:00
restructuring
This commit is contained in:
parent
ed4a107c44
commit
9b48dec533
12 changed files with 1988 additions and 1586 deletions
File diff suppressed because it is too large
Load diff
|
|
@ -36,7 +36,7 @@ pub enum FieldValue {
|
|||
pub use generated::*;
|
||||
|
||||
/// Parses the payload of a method frame into the class/method
|
||||
pub fn parse_method(payload: &[u8]) -> Result<generated::Class, TransError> {
|
||||
pub fn parse_method(payload: &[u8]) -> Result<generated::Method, TransError> {
|
||||
let nom_result = generated::parse::parse_method(payload);
|
||||
|
||||
match nom_result {
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@ impl<T> nom::error::ParseError<T> for TransError {
|
|||
}
|
||||
}
|
||||
|
||||
// todo: make this into fail_err to avoid useless allocations
|
||||
pub fn err<S: Into<String>>(msg: S) -> impl FnOnce(Err<TransError>) -> Err<TransError> {
|
||||
pub fn fail_err<S: Into<String>>(msg: S) -> impl FnOnce(Err<TransError>) -> Err<TransError> {
|
||||
move |err| {
|
||||
let error_level = if matches!(err, nom::Err::Failure(_)) {
|
||||
Err::Failure
|
||||
|
|
@ -156,7 +155,8 @@ pub fn table(input: &[u8]) -> IResult<'_, Table> {
|
|||
|
||||
fn table_value_pair(input: &[u8]) -> IResult<'_, (TableFieldName, FieldValue)> {
|
||||
let (input, field_name) = shortstr(input)?;
|
||||
let (input, field_value) = field_value(input).map_err(err(format!("field {field_name}")))?;
|
||||
let (input, field_value) =
|
||||
field_value(input).map_err(fail_err(format!("field {field_name}")))?;
|
||||
Ok((input, (field_name, field_value)))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// create random methods to test the ser/de code together. if they diverge, we have a bug
|
||||
// this is not perfect, if they both have the same bug it won't be found, but tha's an ok tradeoff
|
||||
|
||||
use crate::classes::{Class, FieldValue};
|
||||
use crate::classes::{FieldValue, Method};
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::collections::HashMap;
|
||||
|
||||
|
|
@ -103,7 +103,7 @@ fn random_ser_de() {
|
|||
let mut rng = rand::rngs::StdRng::from_seed([0; 32]);
|
||||
|
||||
for _ in 0..ITERATIONS {
|
||||
let class = Class::random(&mut rng);
|
||||
let class = Method::random(&mut rng);
|
||||
let mut bytes = Vec::new();
|
||||
|
||||
if let Err(err) = super::write::write_method(class.clone(), &mut bytes) {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::classes::Class;
|
||||
use crate::classes::Method;
|
||||
use crate::error::{ConException, ProtocolError, Result};
|
||||
use crate::frame::{Frame, FrameType};
|
||||
use crate::{classes, frame, sasl};
|
||||
|
|
@ -88,7 +88,7 @@ impl Connection {
|
|||
self.main_loop().await
|
||||
}
|
||||
|
||||
async fn send_method(&mut self, channel: u16, method: classes::Class) -> Result<()> {
|
||||
async fn send_method(&mut self, channel: u16, method: Method) -> Result<()> {
|
||||
let mut payload = Vec::with_capacity(64);
|
||||
classes::write::write_method(method, &mut payload)?;
|
||||
frame::write_frame(
|
||||
|
|
@ -102,7 +102,7 @@ impl Connection {
|
|||
.await
|
||||
}
|
||||
|
||||
async fn recv_method(&mut self) -> Result<classes::Class> {
|
||||
async fn recv_method(&mut self) -> Result<Method> {
|
||||
let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
|
||||
|
||||
ensure_conn(start_ok_frame.kind == FrameType::Method)?;
|
||||
|
|
@ -112,7 +112,7 @@ impl Connection {
|
|||
}
|
||||
|
||||
async fn start(&mut self) -> Result<()> {
|
||||
let start_method = classes::Class::Connection(classes::Connection::Start {
|
||||
let start_method = Method::ConnectionStart {
|
||||
version_major: 0,
|
||||
version_minor: 9,
|
||||
server_properties: server_properties(
|
||||
|
|
@ -122,7 +122,7 @@ impl Connection {
|
|||
),
|
||||
mechanisms: "PLAIN".into(),
|
||||
locales: "en_US".into(),
|
||||
});
|
||||
};
|
||||
|
||||
debug!(?start_method, "Sending Start method");
|
||||
self.send_method(0, start_method).await?;
|
||||
|
|
@ -130,12 +130,12 @@ impl Connection {
|
|||
let start_ok = self.recv_method().await?;
|
||||
debug!(?start_ok, "Received Start-Ok");
|
||||
|
||||
if let classes::Class::Connection(classes::Connection::StartOk {
|
||||
if let Method::ConnectionStartOk {
|
||||
mechanism,
|
||||
locale,
|
||||
response,
|
||||
..
|
||||
}) = start_ok
|
||||
} = start_ok
|
||||
{
|
||||
ensure_conn(mechanism == "PLAIN")?;
|
||||
ensure_conn(locale == "en_US")?;
|
||||
|
|
@ -149,11 +149,11 @@ impl Connection {
|
|||
}
|
||||
|
||||
async fn tune(&mut self) -> Result<()> {
|
||||
let tune_method = classes::Class::Connection(classes::Connection::Tune {
|
||||
let tune_method = Method::ConnectionTune {
|
||||
channel_max: CHANNEL_MAX,
|
||||
frame_max: FRAME_SIZE_MAX,
|
||||
heartbeat: HEARTBEAT_DELAY,
|
||||
});
|
||||
};
|
||||
|
||||
debug!("Sending Tune method");
|
||||
self.send_method(0, tune_method).await?;
|
||||
|
|
@ -161,11 +161,11 @@ impl Connection {
|
|||
let tune_ok = self.recv_method().await?;
|
||||
debug!(?tune_ok, "Received Tune-Ok method");
|
||||
|
||||
if let classes::Class::Connection(classes::Connection::TuneOk {
|
||||
if let Method::ConnectionTuneOk {
|
||||
channel_max,
|
||||
frame_max,
|
||||
heartbeat,
|
||||
}) = tune_ok
|
||||
} = tune_ok
|
||||
{
|
||||
self.channel_max = channel_max;
|
||||
self.max_frame_size = usize::try_from(frame_max).unwrap();
|
||||
|
|
@ -180,15 +180,15 @@ impl Connection {
|
|||
let open = self.recv_method().await?;
|
||||
debug!(?open, "Received Open method");
|
||||
|
||||
if let classes::Class::Connection(classes::Connection::Open { virtual_host, .. }) = open {
|
||||
if let Method::ConnectionOpen { virtual_host, .. } = open {
|
||||
ensure_conn(virtual_host == "/")?;
|
||||
}
|
||||
|
||||
self.send_method(
|
||||
0,
|
||||
classes::Class::Connection(classes::Connection::OpenOk {
|
||||
Method::ConnectionOpenOk {
|
||||
reserved_1: "".to_string(),
|
||||
}),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
|
@ -197,23 +197,14 @@ impl Connection {
|
|||
|
||||
async fn main_loop(&mut self) -> Result<()> {
|
||||
loop {
|
||||
tokio::select! {
|
||||
frame = frame::read_frame(&mut self.stream, self.max_frame_size) => {
|
||||
debug!(?frame);
|
||||
let frame = frame?;
|
||||
self.reset_timeout();
|
||||
let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
|
||||
debug!(?frame);
|
||||
self.reset_timeout();
|
||||
|
||||
match frame.kind {
|
||||
FrameType::Method => self.dispatch_method(frame).await?,
|
||||
FrameType::Heartbeat => {}
|
||||
_ => warn!(frame_type = ?frame.kind, "TODO"),
|
||||
}
|
||||
}
|
||||
_ = &mut self.next_timeout => {
|
||||
if self.heartbeat_delay != 0 {
|
||||
return Err(ProtocolError::CloseNow.into());
|
||||
}
|
||||
}
|
||||
match frame.kind {
|
||||
FrameType::Method => self.dispatch_method(frame).await?,
|
||||
FrameType::Heartbeat => {}
|
||||
_ => warn!(frame_type = ?frame.kind, "TODO"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -223,12 +214,10 @@ impl Connection {
|
|||
debug!(?method, "Received method");
|
||||
|
||||
match method {
|
||||
classes::Class::Connection(classes::Connection::Close { .. }) => {
|
||||
Method::ConnectionClose { .. } => {
|
||||
// todo: handle closing
|
||||
}
|
||||
classes::Class::Channel(classes::Channel::Open { .. }) => {
|
||||
self.channel_open(frame.channel).await?
|
||||
}
|
||||
Method::ChannelOpen { .. } => self.channel_open(frame.channel).await?,
|
||||
|
||||
_ => {
|
||||
// we don't handle this here, forward it to *somewhere*
|
||||
|
|
@ -274,9 +263,9 @@ impl Connection {
|
|||
|
||||
self.send_method(
|
||||
num,
|
||||
Class::Channel(classes::Channel::OpenOk {
|
||||
Method::ChannelOpenOk {
|
||||
reserved_1: Vec::new(),
|
||||
}),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
|
@ -325,6 +314,18 @@ impl Connection {
|
|||
}
|
||||
}
|
||||
|
||||
impl Drop for Connection {
|
||||
fn drop(&mut self) {
|
||||
self.connection_handle.lock().close();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Channel {
|
||||
fn drop(&mut self) {
|
||||
self.channel_handle.lock().close();
|
||||
}
|
||||
}
|
||||
|
||||
fn server_properties(host: SocketAddr) -> classes::Table {
|
||||
fn ls(str: &str) -> classes::FieldValue {
|
||||
classes::FieldValue::LongString(str.into())
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::classes::{Class, Connection, FieldValue};
|
||||
use crate::classes::{FieldValue, Method};
|
||||
use crate::frame::FrameType;
|
||||
use crate::{classes, frame};
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -6,7 +6,7 @@ use std::collections::HashMap;
|
|||
#[tokio::test]
|
||||
async fn write_start_ok_frame() {
|
||||
let mut payload = Vec::new();
|
||||
let method = classes::Class::Connection(classes::Connection::Start {
|
||||
let method = Method::ConnectionStart {
|
||||
version_major: 0,
|
||||
version_minor: 9,
|
||||
server_properties: HashMap::from([(
|
||||
|
|
@ -15,7 +15,7 @@ async fn write_start_ok_frame() {
|
|||
)]),
|
||||
mechanisms: "PLAIN".into(),
|
||||
locales: "en_US".into(),
|
||||
});
|
||||
};
|
||||
|
||||
classes::write::write_method(method, &mut payload).unwrap();
|
||||
|
||||
|
|
@ -140,7 +140,7 @@ fn read_start_ok_payload() {
|
|||
|
||||
assert_eq!(
|
||||
method,
|
||||
Class::Connection(Connection::StartOk {
|
||||
Method::ConnectionStartOk {
|
||||
client_properties: HashMap::from([
|
||||
(
|
||||
"product".to_string(),
|
||||
|
|
@ -178,6 +178,6 @@ fn read_start_ok_payload() {
|
|||
mechanism: "PLAIN".to_string(),
|
||||
response: "\x00admin\x00".into(),
|
||||
locale: "en_US".to_string()
|
||||
})
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue