working codegen

This commit is contained in:
nora 2022-02-09 21:32:32 +01:00
parent 08628022c2
commit 0d548f7798
3 changed files with 463 additions and 13 deletions

View file

@ -1,7 +1,7 @@
use anyhow::Result;
use heck::{ToSnakeCase, ToUpperCamelCase};
use anyhow::{Context, Result};
use heck::ToUpperCamelCase;
use std::fs;
use strong_xml::{XmlError, XmlRead};
use strong_xml::XmlRead;
#[derive(Debug, XmlRead)]
#[xml(tag = "amqp")]
@ -28,6 +28,12 @@ struct Domain {
struct Assert {
#[xml(attr = "check")]
check: String,
#[xml(attr = "method")]
method: Option<String>,
#[xml(attr = "field")]
field: Option<String>,
#[xml(attr = "value")]
value: Option<String>,
}
#[derive(Debug, XmlRead)]
@ -61,6 +67,8 @@ struct Field {
name: String,
#[xml(attr = "domain")]
domain: Option<String>,
#[xml(attr = "type")]
kind: Option<String>,
#[xml(child = "assert")]
asserts: Vec<Assert>,
}
@ -73,20 +81,52 @@ fn main() -> Result<()> {
}
fn codegen(amqp: &Amqp) -> Result<()> {
println!("use std::collections::HashMap;\n");
domain_defs(amqp)?;
class_defs(amqp)
}
fn domain_defs(amqp: &Amqp) -> Result<()> {
for domain in &amqp.domains {
let invariants = invariants(domain.asserts.iter());
if !invariants.is_empty() {
println!("/// {invariants}");
}
println!(
"type {} = {};\n",
domain.name.to_upper_camel_case(),
amqp_type_to_rust_type(&domain.kind),
);
}
Ok(())
}
fn class_defs(amqp: &Amqp) -> Result<()> {
for class in &amqp.classes {
let enum_name = class.name.to_upper_camel_case();
println!("///////// ---- Class {enum_name}");
println!("enum {enum_name} {{");
println!("/// Index {}, handler = {}", class.index, class.handler);
println!("pub enum {enum_name} {{");
for method in &class.methods {
let method_name = method.name.to_upper_camel_case();
println!(" /// Index {}", method.index);
print!(" {method_name}");
if method.fields.len() > 0 {
println!(" {{");
for field in &method.fields {
let field_name = field.name.to_snake_case();
println!(" {field_name}: (),");
let field_name = snake_case(&field.name);
let (field_type, field_docs) = resolve_type(
amqp,
&field.domain.as_ref().or(field.kind.as_ref()).unwrap(),
field.asserts.as_ref(),
)?;
if !field_docs.is_empty() {
println!(" /// {field_docs}");
}
println!(" {field_name}: {field_type},");
}
println!(" }}");
println!(" }},");
} else {
println!(",");
}
@ -96,3 +136,75 @@ fn codegen(amqp: &Amqp) -> Result<()> {
Ok(())
}
fn amqp_type_to_rust_type<'a>(amqp_type: &str) -> &'static str {
match amqp_type {
"octet" => "u8",
"short" => "u16",
"long" => "u32",
"longlong" => "u64",
"bit" => "u8",
"shortstr" | "longstr" => "String",
"timestamp" => "u64",
"table" => "HashMap<Shortstr, (Octet, /* todo */ Box<dyn std::any::Any>)>",
_ => unreachable!("invalid type {}", amqp_type),
}
}
/// returns (type name, invariant docs)
fn resolve_type(amqp: &Amqp, domain: &str, asserts: &[Assert]) -> Result<(String, String)> {
let kind = amqp
.domains
.iter()
.find(|d| &d.name == domain)
.context("domain not found")?;
let is_nonnull = is_nonnull(asserts.iter().chain(kind.asserts.iter()));
let additional_docs = invariants(asserts.iter());
let type_name = domain.to_upper_camel_case();
Ok((
if is_nonnull {
type_name
} else {
format!("Option<{type_name}>")
},
additional_docs,
))
}
fn is_nonnull<'a>(mut asserts: impl Iterator<Item = &'a Assert>) -> bool {
asserts.find(|assert| assert.check == "notnull").is_some()
}
fn snake_case(ident: &str) -> String {
use heck::ToSnakeCase;
if ident == "type" {
"r#type".to_string()
} else {
ident.to_snake_case()
}
}
fn invariants<'a>(asserts: impl Iterator<Item = &'a Assert>) -> String {
asserts
.filter_map(|assert| match &*assert.check {
"notnull" => None,
"length" => Some(format!(
"must be shorter than {}",
assert.value.as_ref().unwrap()
)),
"regexp" => Some(format!("must match `{}`", assert.value.as_ref().unwrap())),
"le" => Some(format!(
"must be less than the {} field of the method {}",
assert.method.as_ref().unwrap(),
assert.field.as_ref().unwrap()
)),
_ => unimplemented!(),
})
.collect::<Vec<_>>()
.join(", ")
}

View file

@ -1,7 +1,345 @@
mod connection;
use std::collections::HashMap;
use crate::classes::connection::Connection;
type ClassId = u16;
pub enum Class {
Connection(Connection),
type ConsumerTag = String;
type DeliveryTag = u64;
/// must be shorter than 127, must match `^[a-zA-Z0-9-_.:]*$`
type ExchangeName = String;
type MethodId = u16;
type NoAck = u8;
type NoLocal = u8;
type NoWait = u8;
/// must be shorter than 127
type Path = String;
type PeerProperties = HashMap<Shortstr, (Octet, /* todo */ Box<dyn std::any::Any>)>;
/// must be shorter than 127, must match `^[a-zA-Z0-9-_.:]*$`
type QueueName = String;
type Redelivered = u8;
type MessageCount = u32;
type ReplyCode = u16;
type ReplyText = String;
type Bit = u8;
type Octet = u8;
type Short = u16;
type Long = u32;
type Longlong = u64;
type Shortstr = String;
type Longstr = String;
type Timestamp = u64;
type Table = HashMap<Shortstr, (Octet, /* todo */ Box<dyn std::any::Any>)>;
/// Index 10, handler = connection
pub enum Connection {
/// Index 10
Start {
version_major: Option<Octet>,
version_minor: Option<Octet>,
server_properties: Option<PeerProperties>,
mechanisms: Longstr,
locales: Longstr,
},
/// Index 11
StartOk {
client_properties: Option<PeerProperties>,
mechanism: Shortstr,
response: Longstr,
locale: Shortstr,
},
/// Index 20
Secure {
challenge: Option<Longstr>,
},
/// Index 21
SecureOk {
response: Longstr,
},
/// Index 30
Tune {
channel_max: Option<Short>,
frame_max: Option<Long>,
heartbeat: Option<Short>,
},
/// Index 31
TuneOk {
/// must be less than the tune field of the method channel-max
channel_max: Short,
frame_max: Option<Long>,
heartbeat: Option<Short>,
},
/// Index 40
Open {
virtual_host: Path,
reserved_1: Option<Shortstr>,
reserved_2: Option<Bit>,
},
/// Index 41
OpenOk {
reserved_1: Option<Shortstr>,
},
/// Index 50
Close {
reply_code: ReplyCode,
reply_text: ReplyText,
class_id: Option<ClassId>,
method_id: Option<MethodId>,
},
/// Index 51
CloseOk,
/// Index 60
Blocked {
reason: Option<Shortstr>,
},
/// Index 61
Unblocked,
}
/// Index 20, handler = channel
pub enum Channel {
/// Index 10
Open {
reserved_1: Option<Shortstr>,
},
/// Index 11
OpenOk {
reserved_1: Option<Longstr>,
},
/// Index 20
Flow {
active: Option<Bit>,
},
/// Index 21
FlowOk {
active: Option<Bit>,
},
/// Index 40
Close {
reply_code: ReplyCode,
reply_text: ReplyText,
class_id: Option<ClassId>,
method_id: Option<MethodId>,
},
/// Index 41
CloseOk,
}
/// Index 40, handler = channel
pub enum Exchange {
/// Index 10
Declare {
reserved_1: Option<Short>,
exchange: ExchangeName,
r#type: Option<Shortstr>,
passive: Option<Bit>,
durable: Option<Bit>,
reserved_2: Option<Bit>,
reserved_3: Option<Bit>,
no_wait: Option<NoWait>,
arguments: Option<Table>,
},
/// Index 11
DeclareOk,
/// Index 20
Delete {
reserved_1: Option<Short>,
exchange: ExchangeName,
if_unused: Option<Bit>,
no_wait: Option<NoWait>,
},
/// Index 21
DeleteOk,
}
/// Index 50, handler = channel
pub enum Queue {
/// Index 10
Declare {
reserved_1: Option<Short>,
queue: Option<QueueName>,
passive: Option<Bit>,
durable: Option<Bit>,
exclusive: Option<Bit>,
auto_delete: Option<Bit>,
no_wait: Option<NoWait>,
arguments: Option<Table>,
},
/// Index 11
DeclareOk {
queue: QueueName,
message_count: Option<MessageCount>,
consumer_count: Option<Long>,
},
/// Index 20
Bind {
reserved_1: Option<Short>,
queue: Option<QueueName>,
exchange: Option<ExchangeName>,
routing_key: Option<Shortstr>,
no_wait: Option<NoWait>,
arguments: Option<Table>,
},
/// Index 21
BindOk,
/// Index 50
Unbind {
reserved_1: Option<Short>,
queue: Option<QueueName>,
exchange: Option<ExchangeName>,
routing_key: Option<Shortstr>,
arguments: Option<Table>,
},
/// Index 51
UnbindOk,
/// Index 30
Purge {
reserved_1: Option<Short>,
queue: Option<QueueName>,
no_wait: Option<NoWait>,
},
/// Index 31
PurgeOk {
message_count: Option<MessageCount>,
},
/// Index 40
Delete {
reserved_1: Option<Short>,
queue: Option<QueueName>,
if_unused: Option<Bit>,
if_empty: Option<Bit>,
no_wait: Option<NoWait>,
},
/// Index 41
DeleteOk {
message_count: Option<MessageCount>,
},
}
/// Index 60, handler = channel
pub enum Basic {
/// Index 10
Qos {
prefetch_size: Option<Long>,
prefetch_count: Option<Short>,
global: Option<Bit>,
},
/// Index 11
QosOk,
/// Index 20
Consume {
reserved_1: Option<Short>,
queue: Option<QueueName>,
consumer_tag: Option<ConsumerTag>,
no_local: Option<NoLocal>,
no_ack: Option<NoAck>,
exclusive: Option<Bit>,
no_wait: Option<NoWait>,
arguments: Option<Table>,
},
/// Index 21
ConsumeOk {
consumer_tag: Option<ConsumerTag>,
},
/// Index 30
Cancel {
consumer_tag: Option<ConsumerTag>,
no_wait: Option<NoWait>,
},
/// Index 31
CancelOk {
consumer_tag: Option<ConsumerTag>,
},
/// Index 40
Publish {
reserved_1: Option<Short>,
exchange: Option<ExchangeName>,
routing_key: Option<Shortstr>,
mandatory: Option<Bit>,
immediate: Option<Bit>,
},
/// Index 50
Return {
reply_code: ReplyCode,
reply_text: ReplyText,
exchange: Option<ExchangeName>,
routing_key: Option<Shortstr>,
},
/// Index 60
Deliver {
consumer_tag: Option<ConsumerTag>,
delivery_tag: Option<DeliveryTag>,
redelivered: Option<Redelivered>,
exchange: Option<ExchangeName>,
routing_key: Option<Shortstr>,
},
/// Index 70
Get {
reserved_1: Option<Short>,
queue: Option<QueueName>,
no_ack: Option<NoAck>,
},
/// Index 71
GetOk {
delivery_tag: Option<DeliveryTag>,
redelivered: Option<Redelivered>,
exchange: Option<ExchangeName>,
routing_key: Option<Shortstr>,
message_count: Option<MessageCount>,
},
/// Index 72
GetEmpty {
reserved_1: Option<Shortstr>,
},
/// Index 80
Ack {
delivery_tag: Option<DeliveryTag>,
multiple: Option<Bit>,
},
/// Index 90
Reject {
delivery_tag: Option<DeliveryTag>,
requeue: Option<Bit>,
},
/// Index 100
RecoverAsync {
requeue: Option<Bit>,
},
/// Index 110
Recover {
requeue: Option<Bit>,
},
/// Index 111
RecoverOk,
}
/// Index 90, handler = channel
pub enum Tx {
/// Index 10
Select,
/// Index 11
SelectOk,
/// Index 20
Commit,
/// Index 21
CommitOk,
/// Index 30
Rollback,
/// Index 31
RollbackOk,
}

View file

@ -8,7 +8,7 @@ mod frame_type {
pub const METHOD: u8 = 1;
pub const HEADER: u8 = 2;
pub const BODY: u8 = 3;
pub const HEARTBEAT: u8 = 4;
pub const HEARTBEAT: u8 = 8;
}
#[derive(Debug, Clone, PartialEq, Eq)]