restructuring

This commit is contained in:
nora 2022-02-20 14:59:54 +01:00
parent ed4a107c44
commit 9b48dec533
12 changed files with 1988 additions and 1586 deletions

View file

@ -40,7 +40,7 @@ const renderConnections = (connections) => {
}; };
const refresh = async () => { const refresh = async () => {
const fetched = await fetch('http://localhost:3000/api/data'); const fetched = await fetch('api/data');
const data = await fetched.json(); const data = await fetched.json();
renderConnections(data.connections); renderConnections(data.connections);
}; };

File diff suppressed because it is too large Load diff

View file

@ -36,7 +36,7 @@ pub enum FieldValue {
pub use generated::*; pub use generated::*;
/// Parses the payload of a method frame into the class/method /// Parses the payload of a method frame into the class/method
pub fn parse_method(payload: &[u8]) -> Result<generated::Class, TransError> { pub fn parse_method(payload: &[u8]) -> Result<generated::Method, TransError> {
let nom_result = generated::parse::parse_method(payload); let nom_result = generated::parse::parse_method(payload);
match nom_result { match nom_result {

View file

@ -25,8 +25,7 @@ impl<T> nom::error::ParseError<T> for TransError {
} }
} }
// todo: make this into fail_err to avoid useless allocations pub fn fail_err<S: Into<String>>(msg: S) -> impl FnOnce(Err<TransError>) -> Err<TransError> {
pub fn err<S: Into<String>>(msg: S) -> impl FnOnce(Err<TransError>) -> Err<TransError> {
move |err| { move |err| {
let error_level = if matches!(err, nom::Err::Failure(_)) { let error_level = if matches!(err, nom::Err::Failure(_)) {
Err::Failure Err::Failure
@ -156,7 +155,8 @@ pub fn table(input: &[u8]) -> IResult<'_, Table> {
fn table_value_pair(input: &[u8]) -> IResult<'_, (TableFieldName, FieldValue)> { fn table_value_pair(input: &[u8]) -> IResult<'_, (TableFieldName, FieldValue)> {
let (input, field_name) = shortstr(input)?; let (input, field_name) = shortstr(input)?;
let (input, field_value) = field_value(input).map_err(err(format!("field {field_name}")))?; let (input, field_value) =
field_value(input).map_err(fail_err(format!("field {field_name}")))?;
Ok((input, (field_name, field_value))) Ok((input, (field_name, field_value)))
} }

View file

@ -1,7 +1,7 @@
// create random methods to test the ser/de code together. if they diverge, we have a bug // create random methods to test the ser/de code together. if they diverge, we have a bug
// this is not perfect, if they both have the same bug it won't be found, but tha's an ok tradeoff // this is not perfect, if they both have the same bug it won't be found, but tha's an ok tradeoff
use crate::classes::{Class, FieldValue}; use crate::classes::{FieldValue, Method};
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use std::collections::HashMap; use std::collections::HashMap;
@ -103,7 +103,7 @@ fn random_ser_de() {
let mut rng = rand::rngs::StdRng::from_seed([0; 32]); let mut rng = rand::rngs::StdRng::from_seed([0; 32]);
for _ in 0..ITERATIONS { for _ in 0..ITERATIONS {
let class = Class::random(&mut rng); let class = Method::random(&mut rng);
let mut bytes = Vec::new(); let mut bytes = Vec::new();
if let Err(err) = super::write::write_method(class.clone(), &mut bytes) { if let Err(err) = super::write::write_method(class.clone(), &mut bytes) {

View file

@ -1,4 +1,4 @@
use crate::classes::Class; use crate::classes::Method;
use crate::error::{ConException, ProtocolError, Result}; use crate::error::{ConException, ProtocolError, Result};
use crate::frame::{Frame, FrameType}; use crate::frame::{Frame, FrameType};
use crate::{classes, frame, sasl}; use crate::{classes, frame, sasl};
@ -88,7 +88,7 @@ impl Connection {
self.main_loop().await self.main_loop().await
} }
async fn send_method(&mut self, channel: u16, method: classes::Class) -> Result<()> { async fn send_method(&mut self, channel: u16, method: Method) -> Result<()> {
let mut payload = Vec::with_capacity(64); let mut payload = Vec::with_capacity(64);
classes::write::write_method(method, &mut payload)?; classes::write::write_method(method, &mut payload)?;
frame::write_frame( frame::write_frame(
@ -102,7 +102,7 @@ impl Connection {
.await .await
} }
async fn recv_method(&mut self) -> Result<classes::Class> { async fn recv_method(&mut self) -> Result<Method> {
let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?; let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
ensure_conn(start_ok_frame.kind == FrameType::Method)?; ensure_conn(start_ok_frame.kind == FrameType::Method)?;
@ -112,7 +112,7 @@ impl Connection {
} }
async fn start(&mut self) -> Result<()> { async fn start(&mut self) -> Result<()> {
let start_method = classes::Class::Connection(classes::Connection::Start { let start_method = Method::ConnectionStart {
version_major: 0, version_major: 0,
version_minor: 9, version_minor: 9,
server_properties: server_properties( server_properties: server_properties(
@ -122,7 +122,7 @@ impl Connection {
), ),
mechanisms: "PLAIN".into(), mechanisms: "PLAIN".into(),
locales: "en_US".into(), locales: "en_US".into(),
}); };
debug!(?start_method, "Sending Start method"); debug!(?start_method, "Sending Start method");
self.send_method(0, start_method).await?; self.send_method(0, start_method).await?;
@ -130,12 +130,12 @@ impl Connection {
let start_ok = self.recv_method().await?; let start_ok = self.recv_method().await?;
debug!(?start_ok, "Received Start-Ok"); debug!(?start_ok, "Received Start-Ok");
if let classes::Class::Connection(classes::Connection::StartOk { if let Method::ConnectionStartOk {
mechanism, mechanism,
locale, locale,
response, response,
.. ..
}) = start_ok } = start_ok
{ {
ensure_conn(mechanism == "PLAIN")?; ensure_conn(mechanism == "PLAIN")?;
ensure_conn(locale == "en_US")?; ensure_conn(locale == "en_US")?;
@ -149,11 +149,11 @@ impl Connection {
} }
async fn tune(&mut self) -> Result<()> { async fn tune(&mut self) -> Result<()> {
let tune_method = classes::Class::Connection(classes::Connection::Tune { let tune_method = Method::ConnectionTune {
channel_max: CHANNEL_MAX, channel_max: CHANNEL_MAX,
frame_max: FRAME_SIZE_MAX, frame_max: FRAME_SIZE_MAX,
heartbeat: HEARTBEAT_DELAY, heartbeat: HEARTBEAT_DELAY,
}); };
debug!("Sending Tune method"); debug!("Sending Tune method");
self.send_method(0, tune_method).await?; self.send_method(0, tune_method).await?;
@ -161,11 +161,11 @@ impl Connection {
let tune_ok = self.recv_method().await?; let tune_ok = self.recv_method().await?;
debug!(?tune_ok, "Received Tune-Ok method"); debug!(?tune_ok, "Received Tune-Ok method");
if let classes::Class::Connection(classes::Connection::TuneOk { if let Method::ConnectionTuneOk {
channel_max, channel_max,
frame_max, frame_max,
heartbeat, heartbeat,
}) = tune_ok } = tune_ok
{ {
self.channel_max = channel_max; self.channel_max = channel_max;
self.max_frame_size = usize::try_from(frame_max).unwrap(); self.max_frame_size = usize::try_from(frame_max).unwrap();
@ -180,15 +180,15 @@ impl Connection {
let open = self.recv_method().await?; let open = self.recv_method().await?;
debug!(?open, "Received Open method"); debug!(?open, "Received Open method");
if let classes::Class::Connection(classes::Connection::Open { virtual_host, .. }) = open { if let Method::ConnectionOpen { virtual_host, .. } = open {
ensure_conn(virtual_host == "/")?; ensure_conn(virtual_host == "/")?;
} }
self.send_method( self.send_method(
0, 0,
classes::Class::Connection(classes::Connection::OpenOk { Method::ConnectionOpenOk {
reserved_1: "".to_string(), reserved_1: "".to_string(),
}), },
) )
.await?; .await?;
@ -197,10 +197,8 @@ impl Connection {
async fn main_loop(&mut self) -> Result<()> { async fn main_loop(&mut self) -> Result<()> {
loop { loop {
tokio::select! { let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
frame = frame::read_frame(&mut self.stream, self.max_frame_size) => {
debug!(?frame); debug!(?frame);
let frame = frame?;
self.reset_timeout(); self.reset_timeout();
match frame.kind { match frame.kind {
@ -209,13 +207,6 @@ impl Connection {
_ => warn!(frame_type = ?frame.kind, "TODO"), _ => warn!(frame_type = ?frame.kind, "TODO"),
} }
} }
_ = &mut self.next_timeout => {
if self.heartbeat_delay != 0 {
return Err(ProtocolError::CloseNow.into());
}
}
}
}
} }
async fn dispatch_method(&mut self, frame: Frame) -> Result<()> { async fn dispatch_method(&mut self, frame: Frame) -> Result<()> {
@ -223,12 +214,10 @@ impl Connection {
debug!(?method, "Received method"); debug!(?method, "Received method");
match method { match method {
classes::Class::Connection(classes::Connection::Close { .. }) => { Method::ConnectionClose { .. } => {
// todo: handle closing // todo: handle closing
} }
classes::Class::Channel(classes::Channel::Open { .. }) => { Method::ChannelOpen { .. } => self.channel_open(frame.channel).await?,
self.channel_open(frame.channel).await?
}
_ => { _ => {
// we don't handle this here, forward it to *somewhere* // we don't handle this here, forward it to *somewhere*
@ -274,9 +263,9 @@ impl Connection {
self.send_method( self.send_method(
num, num,
Class::Channel(classes::Channel::OpenOk { Method::ChannelOpenOk {
reserved_1: Vec::new(), reserved_1: Vec::new(),
}), },
) )
.await?; .await?;
@ -325,6 +314,18 @@ impl Connection {
} }
} }
impl Drop for Connection {
fn drop(&mut self) {
self.connection_handle.lock().close();
}
}
impl Drop for Channel {
fn drop(&mut self) {
self.channel_handle.lock().close();
}
}
fn server_properties(host: SocketAddr) -> classes::Table { fn server_properties(host: SocketAddr) -> classes::Table {
fn ls(str: &str) -> classes::FieldValue { fn ls(str: &str) -> classes::FieldValue {
classes::FieldValue::LongString(str.into()) classes::FieldValue::LongString(str.into())

View file

@ -1,4 +1,4 @@
use crate::classes::{Class, Connection, FieldValue}; use crate::classes::{FieldValue, Method};
use crate::frame::FrameType; use crate::frame::FrameType;
use crate::{classes, frame}; use crate::{classes, frame};
use std::collections::HashMap; use std::collections::HashMap;
@ -6,7 +6,7 @@ use std::collections::HashMap;
#[tokio::test] #[tokio::test]
async fn write_start_ok_frame() { async fn write_start_ok_frame() {
let mut payload = Vec::new(); let mut payload = Vec::new();
let method = classes::Class::Connection(classes::Connection::Start { let method = Method::ConnectionStart {
version_major: 0, version_major: 0,
version_minor: 9, version_minor: 9,
server_properties: HashMap::from([( server_properties: HashMap::from([(
@ -15,7 +15,7 @@ async fn write_start_ok_frame() {
)]), )]),
mechanisms: "PLAIN".into(), mechanisms: "PLAIN".into(),
locales: "en_US".into(), locales: "en_US".into(),
}); };
classes::write::write_method(method, &mut payload).unwrap(); classes::write::write_method(method, &mut payload).unwrap();
@ -140,7 +140,7 @@ fn read_start_ok_payload() {
assert_eq!( assert_eq!(
method, method,
Class::Connection(Connection::StartOk { Method::ConnectionStartOk {
client_properties: HashMap::from([ client_properties: HashMap::from([
( (
"product".to_string(), "product".to_string(),
@ -178,6 +178,6 @@ fn read_start_ok_payload() {
mechanism: "PLAIN".to_string(), mechanism: "PLAIN".to_string(),
response: "\x00admin\x00".into(), response: "\x00admin\x00".into(),
locale: "en_US".to_string() locale: "en_US".to_string()
}) }
); );
} }

View file

@ -12,6 +12,7 @@ async fn main() -> Result<()> {
for arg in env::args().skip(1) { for arg in env::args().skip(1) {
match arg.as_str() { match arg.as_str() {
"--debug" => level = Level::DEBUG,
"--trace" => level = Level::TRACE, "--trace" => level = Level::TRACE,
"--dashboard" => dashboard = true, "--dashboard" => dashboard = true,
"ignore-this-clippy" => eprintln!("yes please"), "ignore-this-clippy" => eprintln!("yes please"),

View file

@ -123,7 +123,7 @@ pub fn main() {
fn codegen(amqp: &Amqp) { fn codegen(amqp: &Amqp) {
println!("#![allow(dead_code)]"); println!("#![allow(dead_code)]");
println!("// This file has been generated by `amqp_codegen`. Do not edit it manually.\n"); println!("// This file has been generated by `xtask/src/codegen`. Do not edit it manually.\n");
codegen_domain_defs(amqp); codegen_domain_defs(amqp);
codegen_class_defs(amqp); codegen_class_defs(amqp);
codegen_parser(amqp); codegen_parser(amqp);
@ -159,22 +159,15 @@ fn codegen_domain_defs(amqp: &Amqp) {
fn codegen_class_defs(amqp: &Amqp) { fn codegen_class_defs(amqp: &Amqp) {
println!("#[derive(Debug, Clone, PartialEq)]"); println!("#[derive(Debug, Clone, PartialEq)]");
println!("pub enum Class {{"); println!("pub enum Method {{");
for class in &amqp.classes {
let class_name = class.name.to_upper_camel_case();
println!(" {class_name}({class_name}),");
}
println!("}}\n");
for class in &amqp.classes { for class in &amqp.classes {
let enum_name = class.name.to_upper_camel_case(); let enum_name = class.name.to_upper_camel_case();
doc_comment(&class.doc, 0);
println!("#[derive(Debug, Clone, PartialEq)]");
println!("pub enum {enum_name} {{");
for method in &class.methods { for method in &class.methods {
let method_name = method.name.to_upper_camel_case(); let method_name = method.name.to_upper_camel_case();
doc_comment(&class.doc, 4);
doc_comment(&method.doc, 4); doc_comment(&method.doc, 4);
print!(" {method_name}"); print!(" {enum_name}{method_name}");
if !method.fields.is_empty() { if !method.fields.is_empty() {
println!(" {{"); println!(" {{");
for field in &method.fields { for field in &method.fields {
@ -197,8 +190,9 @@ fn codegen_class_defs(amqp: &Amqp) {
println!(","); println!(",");
} }
} }
println!("}}");
} }
println!("}}\n");
} }
fn amqp_type_to_rust_type(amqp_type: &str) -> &'static str { fn amqp_type_to_rust_type(amqp_type: &str) -> &'static str {

View file

@ -31,7 +31,7 @@ pub type IResult<'a, T> = nom::IResult<&'a [u8], T, TransError>;
" "
); );
println!( println!(
"pub fn parse_method(input: &[u8]) -> Result<(&[u8], Class), nom::Err<TransError>> {{ "pub fn parse_method(input: &[u8]) -> Result<(&[u8], Method), nom::Err<TransError>> {{
alt(({}))(input) alt(({}))(input)
}}", }}",
amqp.classes amqp.classes
@ -47,7 +47,7 @@ pub type IResult<'a, T> = nom::IResult<&'a [u8], T, TransError>;
for class in &amqp.classes { for class in &amqp.classes {
let class_name = class.name.to_snake_case(); let class_name = class.name.to_snake_case();
function(&class_name, "Class", || { function(&class_name, "Method", || {
let class_index = class.index; let class_index = class.index;
let all_methods = class let all_methods = class
.methods .methods
@ -56,8 +56,8 @@ pub type IResult<'a, T> = nom::IResult<&'a [u8], T, TransError>;
.join(", "); .join(", ");
let class_name_raw = &class.name; let class_name_raw = &class.name;
println!( println!(
r#" let (input, _) = tag({class_index}_u16.to_be_bytes())(input).map_err(err("invalid tag for class {class_name_raw}"))?; r#" let (input, _) = tag({class_index}_u16.to_be_bytes())(input).map_err(fail_err("invalid tag for class {class_name_raw}"))?;
alt(({all_methods}))(input).map_err(err("class {class_name_raw}")).map_err(failure)"# alt(({all_methods}))(input).map_err(fail_err("class {class_name_raw}"))"#
); );
}); });
@ -94,10 +94,10 @@ fn method_parser(amqp: &Amqp, class: &Class, method: &Method) {
let method_name_raw = &method.name; let method_name_raw = &method.name;
let function_name = method_function_name(&class_name)(method); let function_name = method_function_name(&class_name)(method);
function(&function_name, "Class", || { function(&function_name, "Method", || {
let method_index = method.index; let method_index = method.index;
println!( println!(
r#" let (input, _) = tag({method_index}_u16.to_be_bytes())(input).map_err(err("parsing method index"))?;"# r#" let (input, _) = tag({method_index}_u16.to_be_bytes())(input).map_err(fail_err("parsing method index"))?;"#
); );
let mut iter = method.fields.iter().peekable(); let mut iter = method.fields.iter().peekable();
while let Some(field) = iter.next() { while let Some(field) = iter.next() {
@ -108,8 +108,9 @@ fn method_parser(amqp: &Amqp, class: &Class, method: &Method) {
let fields_with_bit = subsequent_bit_fields(amqp, field, &mut iter); let fields_with_bit = subsequent_bit_fields(amqp, field, &mut iter);
let amount = fields_with_bit.len(); let amount = fields_with_bit.len();
// todo: remove those map_err(failure)
println!( println!(
r#" let (input, bits) = bit(input, {amount}).map_err(err("field {field_name_raw} in method {method_name_raw}")).map_err(failure)?;"# r#" let (input, bits) = bit(input, {amount}).map_err(fail_err("field {field_name_raw} in method {method_name_raw}")).map_err(failure)?;"#
); );
for (i, field) in fields_with_bit.iter().enumerate() { for (i, field) in fields_with_bit.iter().enumerate() {
@ -120,7 +121,7 @@ fn method_parser(amqp: &Amqp, class: &Class, method: &Method) {
let fn_name = domain_function_name(field_type(field)); let fn_name = domain_function_name(field_type(field));
let field_name = snake_case(&field.name); let field_name = snake_case(&field.name);
println!( println!(
r#" let (input, {field_name}) = {fn_name}(input).map_err(err("field {field_name_raw} in method {method_name_raw}")).map_err(failure)?;"# r#" let (input, {field_name}) = {fn_name}(input).map_err(fail_err("field {field_name_raw} in method {method_name_raw}")).map_err(failure)?;"#
); );
for assert in &field.asserts { for assert in &field.asserts {
@ -130,12 +131,12 @@ fn method_parser(amqp: &Amqp, class: &Class, method: &Method) {
} }
let class_name = class_name.to_upper_camel_case(); let class_name = class_name.to_upper_camel_case();
let method_name = method.name.to_upper_camel_case(); let method_name = method.name.to_upper_camel_case();
println!(" Ok((input, Class::{class_name}({class_name}::{method_name} {{"); println!(" Ok((input, Method::{class_name}{method_name} {{");
for field in &method.fields { for field in &method.fields {
let field_name = snake_case(&field.name); let field_name = snake_case(&field.name);
println!(" {field_name},"); println!(" {field_name},");
} }
println!(" }})))"); println!(" }}))");
}); });
} }

View file

@ -11,28 +11,19 @@ use super::*;
" "
); );
impl_random("Class", || { impl_random("Method", || {
let class_lens = amqp.classes.len(); let class_lens = amqp.classes.len();
println!(" match rng.gen_range(0u32..{class_lens}) {{"); println!(" match rng.gen_range(0u32..{class_lens}) {{");
for (i, class) in amqp.classes.iter().enumerate() { for (i, class) in amqp.classes.iter().enumerate() {
let class_name = class.name.to_upper_camel_case(); let class_name = class.name.to_upper_camel_case();
println!(" {i} => Class::{class_name}({class_name}::random(rng)),"); println!(" {i} => {{");
}
println!(
" _ => unreachable!(),
}}"
);
});
for class in &amqp.classes {
let class_name = class.name.to_upper_camel_case();
impl_random(&class_name, || {
let method_len = class.methods.len(); let method_len = class.methods.len();
println!(" match rng.gen_range(0u32..{method_len}) {{"); println!(" match rng.gen_range(0u32..{method_len}) {{");
for (i, method) in class.methods.iter().enumerate() { for (i, method) in class.methods.iter().enumerate() {
let method_name = method.name.to_upper_camel_case(); let method_name = method.name.to_upper_camel_case();
println!(" {i} => {class_name}::{method_name} {{"); println!(" {i} => Method::{class_name}{method_name} {{");
for field in &method.fields { for field in &method.fields {
let field_name = snake_case(&field.name); let field_name = snake_case(&field.name);
println!(" {field_name}: RandomMethod::random(rng),"); println!(" {field_name}: RandomMethod::random(rng),");
@ -43,8 +34,14 @@ use super::*;
" _ => unreachable!(), " _ => unreachable!(),
}}" }}"
); );
});
println!(" }}");
} }
println!(
" _ => unreachable!(),
}}"
);
});
println!("}}"); println!("}}");
} }

View file

@ -9,7 +9,7 @@ use crate::classes::write_helper::*;
use crate::error::TransError; use crate::error::TransError;
use std::io::Write; use std::io::Write;
pub fn write_method<W: Write>(class: Class, mut writer: W) -> Result<(), TransError> {{ pub fn write_method<W: Write>(class: Method, mut writer: W) -> Result<(), TransError> {{
match class {{" match class {{"
); );
@ -19,12 +19,12 @@ pub fn write_method<W: Write>(class: Class, mut writer: W) -> Result<(), TransEr
for method in &class.methods { for method in &class.methods {
let method_name = method.name.to_upper_camel_case(); let method_name = method.name.to_upper_camel_case();
let method_index = method.index; let method_index = method.index;
println!(" Class::{class_name}({class_name}::{method_name} {{"); println!(" Method::{class_name}{method_name} {{");
for field in &method.fields { for field in &method.fields {
let field_name = snake_case(&field.name); let field_name = snake_case(&field.name);
println!(" {field_name},"); println!(" {field_name},");
} }
println!(" }}) => {{"); println!(" }} => {{");
let [ci0, ci1] = class_index.to_be_bytes(); let [ci0, ci1] = class_index.to_be_bytes();
let [mi0, mi1] = method_index.to_be_bytes(); let [mi0, mi1] = method_index.to_be_bytes();
println!(" writer.write_all(&[{ci0}, {ci1}, {mi0}, {mi1}])?;"); println!(" writer.write_all(&[{ci0}, {ci1}, {mi0}, {mi1}])?;");