move codegen to xtask

This commit is contained in:
nora 2022-02-19 20:07:17 +01:00
parent c5d83fe776
commit 077b6fd633
12 changed files with 1456 additions and 1727 deletions

11
xtask/Cargo.toml Normal file
View file

@ -0,0 +1,11 @@
[package]
name = "xtask"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
heck = "0.4.0"
itertools = "0.10.3"
strong-xml = "0.6.3"

2845
xtask/src/amqp0-9-1.xml Normal file

File diff suppressed because it is too large Load diff

301
xtask/src/codegen/mod.rs Normal file
View file

@ -0,0 +1,301 @@
mod parser;
mod random;
mod write;
use heck::ToUpperCamelCase;
use parser::codegen_parser;
use random::codegen_random;
use std::fs;
use std::iter::Peekable;
use std::path::PathBuf;
use std::str::FromStr;
use strong_xml::XmlRead;
use write::codegen_write;
/// Root `<amqp>` element of the specification XML, as read via `strong_xml`.
#[derive(Debug, XmlRead)]
#[xml(tag = "amqp")]
struct Amqp {
/// Every `<domain>` child element.
#[xml(child = "domain")]
domains: Vec<Domain>,
/// Every `<class>` child element.
#[xml(child = "class")]
classes: Vec<Class>,
}
/// A `<domain>` element; the code generator emits one `pub type` alias per domain.
#[derive(Debug, XmlRead)]
#[xml(tag = "domain")]
struct Domain {
/// Value of the `name` attribute.
#[xml(attr = "name")]
name: String,
/// Value of the `type` attribute (named `kind` because `type` is a Rust keyword).
#[xml(attr = "type")]
kind: String,
/// Optional `label` attribute; emitted as the first doc line of the alias when present.
#[xml(attr = "label")]
label: Option<String>,
/// `<assert>` children attached to this domain.
#[xml(child = "assert")]
asserts: Vec<Assert>,
/// `<doc>` children attached to this domain.
#[xml(child = "doc")]
doc: Vec<Doc>,
}
/// An `<assert>` element describing a constraint on a domain or field.
#[derive(Debug, XmlRead)]
#[xml(tag = "assert")]
struct Assert {
/// Value of the `check` attribute; the generator handles `notnull`, `length`, `regexp` and `le`.
#[xml(attr = "check")]
check: String,
/// Optional `method` attribute (read by the `le` check).
#[xml(attr = "method")]
method: Option<String>,
/// Optional `field` attribute (read by the `le` check).
#[xml(attr = "field")]
field: Option<String>,
/// Optional `value` attribute (read by the `length` and `regexp` checks).
#[xml(attr = "value")]
value: Option<String>,
}
/// A `<class>` element; becomes one generated enum with a variant per method.
#[derive(Debug, XmlRead)]
#[xml(tag = "class")]
struct Class {
/// Value of the `name` attribute.
#[xml(attr = "name")]
name: String,
/// Value of the `index` attribute; emitted as a big-endian `u16` by the parser/writer generators.
#[xml(attr = "index")]
index: u16,
/// `<method>` children of this class.
#[xml(child = "method")]
methods: Vec<Method>,
/// `<doc>` children of this class.
#[xml(child = "doc")]
doc: Vec<Doc>,
}
/// A `<method>` element; becomes one variant of its class enum.
#[derive(Debug, XmlRead)]
#[xml(tag = "method")]
struct Method {
/// Value of the `name` attribute.
#[xml(attr = "name")]
name: String,
/// `<field>` children, in document order.
#[xml(child = "field")]
fields: Vec<Field>,
/// Value of the `index` attribute; emitted as a big-endian `u16` by the parser/writer generators.
#[xml(attr = "index")]
index: u16,
/// `<doc>` children of this method.
#[xml(child = "doc")]
doc: Vec<Doc>,
}
/// A `<field>` element of a method.
#[derive(Debug, XmlRead)]
#[xml(tag = "field")]
struct Field {
/// Value of the `name` attribute.
#[xml(attr = "name")]
name: String,
/// Optional `domain` attribute; `field_type` prefers it over `kind`.
#[xml(attr = "domain")]
domain: Option<String>,
/// Optional `type` attribute; used by `field_type` when `domain` is absent.
#[xml(attr = "type")]
kind: Option<String>,
/// `<assert>` children attached to this field.
#[xml(child = "assert")]
asserts: Vec<Assert>,
/// `<doc>` children attached to this field.
#[xml(child = "doc")]
doc: Vec<Doc>,
}
/// A `<doc>` element carrying free-form documentation text.
#[derive(Debug, XmlRead)]
#[xml(tag = "doc")]
struct Doc {
/// Text content of the element.
#[xml(text)]
text: String,
/// Optional `type` attribute; `doc_comment` skips docs whose type is `grammar`.
#[xml(attr = "type")]
kind: Option<String>,
}
/// Entry point of the code generation task.
///
/// Locates `amqp0-9-1.xml` relative to this source file, parses it into an
/// [`Amqp`] tree and prints the generated Rust code to stdout. Any failure to
/// read or parse the specification is reported on stderr and terminates the
/// process with exit code 1.
pub fn main() {
    // `file!()` is the path of this source file as seen by the compiler, so the
    // lookup presumably only works from the directory that path is relative to
    // (NOTE(review): confirm the expected working directory).
    let this_file = PathBuf::from_str(file!()).expect("PathBuf::from_str is infallible");
    // The specification lives one directory above this module's directory.
    let expected_location = this_file
        .parent()
        .and_then(|module_dir| module_dir.parent())
        .expect("source file path has no grandparent directory")
        .join("amqp0-9-1.xml");
    let content = match fs::read_to_string(&expected_location) {
        Ok(content) => content,
        Err(err) => {
            // Name the path: a bare I/O error does not say which file was missing.
            eprintln!("failed to read {}: {err}", expected_location.display());
            std::process::exit(1);
        }
    };
    let amqp = match Amqp::from_str(&content) {
        Ok(amqp) => amqp,
        Err(err) => {
            eprintln!("{err}");
            std::process::exit(1);
        }
    };
    codegen(&amqp);
}
fn codegen(amqp: &Amqp) {
println!("#![allow(dead_code)]");
println!("// This file has been generated by `amqp_codegen`. Do not edit it manually.\n");
codegen_domain_defs(amqp);
codegen_class_defs(amqp);
codegen_parser(amqp);
codegen_write(amqp);
codegen_random(amqp);
}
/// Prints one documented `pub type` alias per domain.
///
/// The doc block is built from, in order: the label (if any), the rendered
/// invariants (if any), and the domain's `<doc>` text (if any), with bare
/// `///` lines as separators.
fn codegen_domain_defs(amqp: &Amqp) {
    amqp.domains.iter().for_each(|domain| {
        let constraint_docs = invariants(domain.asserts.iter());

        let has_label = match domain.label.as_deref() {
            Some(label) => {
                println!("/// {label}");
                true
            }
            None => false,
        };

        if !constraint_docs.is_empty() {
            // Only separate from the label when there is one.
            if has_label {
                println!("///");
            }
            println!("/// {constraint_docs}");
        }

        if !domain.doc.is_empty() {
            println!("///");
            doc_comment(&domain.doc, 0);
        }

        let alias = domain.name.to_upper_camel_case();
        let target = amqp_type_to_rust_type(&domain.kind);
        println!("pub type {alias} = {target};\n");
    });
}
/// Prints the top-level `Class` enum (one tuple variant per class) followed by
/// one enum per class whose variants are that class's methods.
///
/// Methods without fields become unit variants; methods with fields become
/// struct variants, each field preceded by its invariant docs and `<doc>` text.
fn codegen_class_defs(amqp: &Amqp) {
    const DERIVES: &str = "#[derive(Debug, Clone, PartialEq)]";

    println!("{DERIVES}");
    println!("pub enum Class {{");
    for variant in amqp
        .classes
        .iter()
        .map(|class| class.name.to_upper_camel_case())
    {
        println!(" {variant}({variant}),");
    }
    println!("}}\n");

    for class in &amqp.classes {
        let enum_name = class.name.to_upper_camel_case();
        doc_comment(&class.doc, 0);
        println!("{DERIVES}");
        println!("pub enum {enum_name} {{");

        for method in &class.methods {
            let method_name = method.name.to_upper_camel_case();
            doc_comment(&method.doc, 4);
            print!(" {method_name}");

            // Unit variant: nothing more to emit for this method.
            if method.fields.is_empty() {
                println!(",");
                continue;
            }

            println!(" {{");
            for field in &method.fields {
                let field_name = snake_case(&field.name);
                let (type_name, constraint_docs) =
                    get_invariants_with_type(field_type(field), field.asserts.as_ref());

                if !constraint_docs.is_empty() {
                    println!(" /// {constraint_docs}");
                    // Separator only when more documentation follows.
                    if !field.doc.is_empty() {
                        println!(" ///");
                    }
                }
                // Prints nothing when the field has no `<doc>` children.
                doc_comment(&field.doc, 8);
                println!(" {field_name}: {type_name},");
            }
            println!(" }},");
        }
        println!("}}");
    }
}
/// Maps an elementary AMQP type name to the Rust type used in generated code.
///
/// # Panics
///
/// Panics if `amqp_type` is not one of the known elementary type names.
fn amqp_type_to_rust_type(amqp_type: &str) -> &'static str {
    // (AMQP type name, Rust type spelled in the generated file)
    const MAPPING: [(&str, &str); 9] = [
        ("octet", "u8"),
        ("short", "u16"),
        ("long", "u32"),
        ("longlong", "u64"),
        ("bit", "bool"),
        ("shortstr", "String"),
        ("longstr", "Vec<u8>"),
        ("timestamp", "u64"),
        ("table", "super::Table"),
    ];

    MAPPING
        .iter()
        .find(|(amqp, _)| *amqp == amqp_type)
        .map(|&(_, rust)| rust)
        .unwrap_or_else(|| unreachable!("invalid type {}", amqp_type))
}
/// Returns the field's declared type name: its `domain` when present,
/// otherwise its `kind`.
///
/// # Panics
///
/// Panics if the field has neither.
fn field_type(field: &Field) -> &String {
    let declared = match &field.domain {
        Some(domain) => Some(domain),
        None => field.kind.as_ref(),
    };
    declared.unwrap()
}
/// Looks up the domain called `domain` and returns its underlying type name
/// (the domain's `kind`).
///
/// # Panics
///
/// Panics, naming the offending domain, if no domain with that name exists.
fn resolve_type_from_domain(amqp: &Amqp, domain: &str) -> String {
    amqp.domains
        .iter()
        .find(|candidate| candidate.name == domain)
        .map(|found| found.kind.clone())
        // A missing domain means the spec (or a caller) is inconsistent; say which one.
        .unwrap_or_else(|| panic!("unknown domain `{domain}`: no matching <domain> element"))
}
/// Returns `(type name, invariant docs)`: the domain name in upper camel case
/// and the rendered description of `asserts`.
fn get_invariants_with_type(domain: &str, asserts: &[Assert]) -> (String, String) {
    let invariant_docs = invariants(asserts.iter());
    (domain.to_upper_camel_case(), invariant_docs)
}
/// Converts a spec identifier to a snake_case Rust identifier, turning the
/// name `type` into the raw identifier `r#type`.
fn snake_case(ident: &str) -> String {
    use heck::ToSnakeCase;
    match ident {
        "type" => String::from("r#type"),
        other => other.to_snake_case(),
    }
}
/// Collects `bit_field` plus every directly following field in `iter` whose
/// resolved type is `bit`.
///
/// The matching fields are consumed from `iter`; the first non-bit field (if
/// any) is left in place for the caller.
fn subsequent_bit_fields<'a>(
    amqp: &Amqp,
    bit_field: &'a Field,
    iter: &mut Peekable<impl Iterator<Item = &'a Field>>,
) -> Vec<&'a Field> {
    let mut run = vec![bit_field];
    // `next_if` only advances the iterator while the predicate holds.
    while let Some(next) =
        iter.next_if(|candidate| resolve_type_from_domain(amqp, field_type(candidate)) == "bit")
    {
        run.push(next);
    }
    run
}
/// Renders a list of asserts as a human-readable, comma-separated sentence
/// fragment for use in generated doc comments.
///
/// Returns an empty string when `asserts` yields nothing.
///
/// # Panics
///
/// Panics if an assert is missing the attribute its check needs (`value` for
/// `length`/`regexp`, `method` and `field` for `le`) or uses an unsupported
/// check.
fn invariants<'a>(asserts: impl Iterator<Item = &'a Assert>) -> String {
    asserts
        .map(|assert| match &*assert.check {
            "notnull" => "must not be null".to_string(),
            "length" => format!("must be shorter than {}", assert.value.as_ref().unwrap()),
            "regexp" => format!("must match `{}`", assert.value.as_ref().unwrap()),
            "le" => {
                // The sentence names the field first and the method second, so the
                // arguments must be passed in that order.
                format!(
                    "must be less than the {} field of the method {}",
                    assert.field.as_ref().unwrap(),
                    assert.method.as_ref().unwrap()
                )
            }
            other => unimplemented!("unsupported assert check `{other}`"),
        })
        .collect::<Vec<_>>()
        .join(", ")
}
/// Prints the text of `docs` as `///` doc-comment lines, each prefixed with
/// `indent` spaces.
///
/// Docs whose kind is `grammar` are skipped; every remaining line is trimmed
/// and blank lines are dropped.
fn doc_comment(docs: &[Doc], indent: usize) {
    let padding = " ".repeat(indent);
    docs.iter()
        .filter(|doc| doc.kind.as_deref() != Some("grammar"))
        .flat_map(|doc| doc.text.lines())
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .for_each(|line| println!("{padding}/// {line}"));
}

182
xtask/src/codegen/parser.rs Normal file
View file

@ -0,0 +1,182 @@
use super::{
field_type, resolve_type_from_domain, snake_case, subsequent_bit_fields, Amqp, Assert, Class,
Domain, Method,
};
use heck::{ToSnakeCase, ToUpperCamelCase};
use itertools::Itertools;
/// Returns a closure that names the generated parser function for a method of
/// the class `class_name`: `<class_name>_<method name in snake_case>`.
fn method_function_name(class_name: &str) -> impl Fn(&Method) -> String + '_ {
    move |method| format!("{class_name}_{}", method.name.to_snake_case())
}
/// Names the generated parser function for a domain:
/// `domain_<domain name in snake_case>`.
fn domain_function_name(domain_name: &str) -> String {
    let mut fn_name = String::from("domain_");
    fn_name.push_str(&domain_name.to_snake_case());
    fn_name
}
/// Prints the generated `parse` module to stdout.
///
/// Output order: module header (imports and the `IResult` alias), the public
/// `parse_method` entry point, one parser per domain, then for every class its
/// class-level parser followed by one parser per method, and finally the
/// closing brace of the module.
pub(super) fn codegen_parser(amqp: &Amqp) {
// Module header shared by every generated parser function.
println!(
"pub mod parse {{
use super::*;
use crate::classes::parse_helper::*;
use crate::error::TransError;
use nom::{{branch::alt, bytes::complete::tag}};
use regex::Regex;
use once_cell::sync::Lazy;
pub type IResult<'a, T> = nom::IResult<&'a [u8], T, TransError>;
"
);
// Entry point: an `alt` over the class parsers, listed in spec order.
println!(
"pub fn parse_method(input: &[u8]) -> Result<(&[u8], Class), nom::Err<TransError>> {{
alt(({}))(input)
}}",
amqp.classes
.iter()
.map(|class| class.name.to_snake_case())
.join(", ")
);
for domain in &amqp.domains {
domain_parser(domain);
}
for class in &amqp.classes {
let class_name = class.name.to_snake_case();
// Class parser: match the class index as a big-endian u16 tag, then
// dispatch to the parsers of the class's methods.
function(&class_name, "Class", || {
let class_index = class.index;
let all_methods = class
.methods
.iter()
.map(method_function_name(&class_name))
.join(", ");
// Error messages use the name exactly as written in the spec.
let class_name_raw = &class.name;
println!(
r#" let (input, _) = tag({class_index}_u16.to_be_bytes())(input).map_err(err("invalid tag for class {class_name_raw}"))?;
alt(({all_methods}))(input).map_err(err("class {class_name_raw}")).map_err(failure)"#
);
});
for method in &class.methods {
method_parser(amqp, class, method);
}
}
// Closes `pub mod parse`.
println!("\n}}");
}
/// Prints the parser function for one domain.
///
/// The generated function delegates to the helper named after the domain's
/// underlying type and then applies the domain's asserts to the parsed value.
/// Nothing is emitted for `bit` domains.
fn domain_parser(domain: &Domain) {
    let type_name = domain.kind.to_snake_case();
    // don't even bother with bit domains, do them manually at call site
    if type_name == "bit" {
        return;
    }

    let fn_name = domain_function_name(&domain.name);
    let ret_ty = domain.name.to_upper_camel_case();
    function(&fn_name, &ret_ty, || {
        // Without asserts the helper's result can be returned directly.
        if domain.asserts.is_empty() {
            println!(" {type_name}(input)");
            return;
        }
        println!(" let (input, result) = {type_name}(input)?;");
        domain
            .asserts
            .iter()
            .for_each(|assert| assert_check(assert, &type_name, "result"));
        println!(" Ok((input, result))");
    });
}
/// Prints the parser function for one method of `class`.
///
/// The generated function matches the method index as a big-endian u16 tag,
/// parses every field in order, and returns the matching `Class` variant.
fn method_parser(amqp: &Amqp, class: &Class, method: &Method) {
let class_name = class.name.to_snake_case();
// Error messages use the method name exactly as written in the spec.
let method_name_raw = &method.name;
let function_name = method_function_name(&class_name)(method);
function(&function_name, "Class", || {
let method_index = method.index;
println!(
r#" let (input, _) = tag({method_index}_u16.to_be_bytes())(input).map_err(err("parsing method index"))?;"#
);
// Peekable so that a run of consecutive bit fields can be consumed together.
let mut iter = method.fields.iter().peekable();
while let Some(field) = iter.next() {
let field_name_raw = &field.name;
let type_name = resolve_type_from_domain(amqp, field_type(field));
if type_name == "bit" {
// One `bit(input, amount)` call reads the whole run; each field then
// takes its value from `bits` by position.
let fields_with_bit = subsequent_bit_fields(amqp, field, &mut iter);
let amount = fields_with_bit.len();
println!(
r#" let (input, bits) = bit(input, {amount}).map_err(err("field {field_name_raw} in method {method_name_raw}")).map_err(failure)?;"#
);
for (i, field) in fields_with_bit.iter().enumerate() {
let field_name = snake_case(&field.name);
println!(" let {field_name} = bits[{i}];");
}
} else {
// Non-bit fields go through the parser generated for their domain,
// followed by any asserts declared on the field itself.
let fn_name = domain_function_name(field_type(field));
let field_name = snake_case(&field.name);
println!(
r#" let (input, {field_name}) = {fn_name}(input).map_err(err("field {field_name_raw} in method {method_name_raw}")).map_err(failure)?;"#
);
for assert in &field.asserts {
assert_check(assert, &type_name, &field_name);
}
}
}
// Build the result variant using field-init shorthand for the parsed locals.
let class_name = class_name.to_upper_camel_case();
let method_name = method.name.to_upper_camel_case();
println!(" Ok((input, Class::{class_name}({class_name}::{method_name} {{");
for field in &method.fields {
let field_name = snake_case(&field.name);
println!(" {field_name},");
}
println!(" }})))");
});
}
/// Prints the generated validation code for one assert applied to the parsed
/// value held in the generated variable `var_name`.
///
/// * `notnull` — emptiness check for `shortstr`/`longstr`, zero check for `short`.
/// * `regexp` — a lazily compiled `REGEX` static plus a match check.
/// * `length` — fails when the value's `len()` exceeds the limit.
/// * `le` — emits nothing; it cannot be validated at this point.
///
/// # Panics
///
/// Panics on an unsupported check, on `notnull` for an unsupported type, or
/// when `regexp`/`length` lack their `value` attribute.
fn assert_check(assert: &Assert, type_name: &str, var_name: &str) {
    match &*assert.check {
        "notnull" => match type_name {
            "shortstr" | "longstr" => {
                println!(
                    r#" if {var_name}.is_empty() {{ fail!("string was null for field {var_name}") }}"#
                );
            }
            "short" => {
                println!(
                    r#" if {var_name} == 0 {{ fail!("number was 0 for field {var_name}") }}"#
                );
            }
            other => unimplemented!("`notnull` check for type `{other}`"),
        },
        "regexp" => {
            let value = assert.value.as_ref().unwrap();
            println!(
                r#" static REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"{value}").unwrap());"#
            );
            let cause = format!("regex `{value}` did not match value for field {var_name}");
            println!(r#" if !REGEX.is_match(&{var_name}) {{ fail!(r"{cause}") }}"#);
        }
        "le" => {} // can't validate this here
        "length" => {
            let length = assert.value.as_ref().unwrap();
            // The emitted check fires when the value is too long, so the cause must say so.
            let cause = format!("value is longer than {length} for field {var_name}");
            println!(r#" if {var_name}.len() > {length} {{ fail!("{cause}") }}"#);
        }
        other => unimplemented!("unsupported assert check `{other}`"),
    }
}
/// Prints a generated parser function: the signature line for `name`
/// returning `IResult<ret_ty>`, whatever `body` prints, and the closing brace.
fn function<F>(name: &str, ret_ty: &str, body: F)
where
    F: FnOnce(),
{
    let signature = format!("fn {name}(input: &[u8]) -> IResult<{ret_ty}> {{");
    println!("{signature}");
    body();
    println!("}}");
}

View file

@ -0,0 +1,62 @@
use super::{snake_case, Amqp};
use heck::ToUpperCamelCase;
/// Prints the test-only `random` module to stdout.
///
/// It contains a `RandomMethod` impl for `Class` that picks a random class,
/// and one impl per class enum that picks a random method and fills every
/// field through `RandomMethod::random`.
pub(super) fn codegen_random(amqp: &Amqp) {
    // Closes each generated `match`: the fallback arm and the closing brace.
    const MATCH_TAIL: &str = " _ => unreachable!(),\n}";

    println!(
        "#[cfg(test)]
mod random {{
use rand::Rng;
use crate::classes::tests::RandomMethod;
use super::*;
"
    );

    impl_random("Class", || {
        let class_count = amqp.classes.len();
        println!(" match rng.gen_range(0u32..{class_count}) {{");
        for (arm, class) in amqp.classes.iter().enumerate() {
            let class_name = class.name.to_upper_camel_case();
            println!(" {arm} => Class::{class_name}({class_name}::random(rng)),");
        }
        println!("{MATCH_TAIL}");
    });

    for class in &amqp.classes {
        let class_name = class.name.to_upper_camel_case();
        impl_random(&class_name, || {
            let method_count = class.methods.len();
            println!(" match rng.gen_range(0u32..{method_count}) {{");
            for (arm, method) in class.methods.iter().enumerate() {
                let method_name = method.name.to_upper_camel_case();
                println!(" {arm} => {class_name}::{method_name} {{");
                method
                    .fields
                    .iter()
                    .map(|field| snake_case(&field.name))
                    .for_each(|field_name| println!(" {field_name}: RandomMethod::random(rng),"));
                println!(" }},");
            }
            println!("{MATCH_TAIL}");
        });
    }

    // Closes `mod random`.
    println!("}}");
}
/// Prints a `RandomMethod` impl for the type `name`; `body` prints the
/// contents of the generated `random` function.
fn impl_random(name: &str, body: impl FnOnce()) {
    println!("impl<R: Rng> RandomMethod<R> for {name} {{");
    println!("#[allow(unused_variables)]");
    println!("fn random(rng: &mut R) -> Self {{");
    body();
    // Close the function, then the impl block.
    println!(" }}");
    println!("}}");
}

View file

@ -0,0 +1,58 @@
use super::{field_type, resolve_type_from_domain, snake_case, subsequent_bit_fields, Amqp};
use heck::ToUpperCamelCase;
/// Prints the generated `write` module to stdout.
///
/// The module holds a single `write_method` function whose `match` has one
/// arm per (class, method) pair. Each arm writes the class and method indices
/// and then every field in order.
pub(super) fn codegen_write(amqp: &Amqp) {
// Module header and the opening of `write_method`'s match.
println!(
"pub mod write {{
use super::*;
use crate::classes::write_helper::*;
use crate::error::TransError;
use std::io::Write;
pub fn write_method<W: Write>(class: Class, mut writer: W) -> Result<(), TransError> {{
match class {{"
);
for class in &amqp.classes {
let class_name = class.name.to_upper_camel_case();
let class_index = class.index;
for method in &class.methods {
let method_name = method.name.to_upper_camel_case();
let method_index = method.index;
// Match pattern: destructure the variant, binding every field by name.
println!(" Class::{class_name}({class_name}::{method_name} {{");
for field in &method.fields {
let field_name = snake_case(&field.name);
println!(" {field_name},");
}
println!(" }}) => {{");
// The four header bytes (class index, then method index, both big-endian)
// are computed here and emitted as literals.
let [ci0, ci1] = class_index.to_be_bytes();
let [mi0, mi1] = method_index.to_be_bytes();
println!(" writer.write_all(&[{ci0}, {ci1}, {mi0}, {mi1}])?;");
// Peekable so that a run of consecutive bit fields can be consumed together.
let mut iter = method.fields.iter().peekable();
while let Some(field) = iter.next() {
let field_name = snake_case(&field.name);
let type_name = resolve_type_from_domain(amqp, field_type(field));
if type_name == "bit" {
// A run of bit fields is passed to a single `bit(&[..], ..)` call.
let fields_with_bit = subsequent_bit_fields(amqp, field, &mut iter);
print!(" bit(&[");
for field in fields_with_bit {
let field_name = snake_case(&field.name);
print!("{field_name}, ");
}
println!("], &mut writer)?;");
} else {
// Other fields call the helper named after their resolved type.
println!(" {type_name}({field_name}, &mut writer)?;");
}
}
println!(" }}");
}
}
// Closes the match, returns `Ok(())`, and closes the function and module.
println!(
" }}
Ok(())
}}
}}"
);
}

22
xtask/src/main.rs Normal file
View file

@ -0,0 +1,22 @@
mod codegen;
/// Task runner entry point: dispatches on the first command-line argument.
///
/// `generate` (or `gen`) runs the code generator. A missing or unknown command
/// prints an error and the help text, then exits with status 1 so that
/// scripts and CI notice the failure.
fn main() {
    let command = std::env::args().nth(1).unwrap_or_else(|| {
        eprintln!("No task provided");
        help();
        std::process::exit(1);
    });
    match command.as_str() {
        "generate" | "gen" => codegen::main(),
        _ => {
            eprintln!("Unknown command {command}.");
            help();
            // An unknown task is a usage error, not a success.
            std::process::exit(1);
        }
    }
}
/// Prints the list of available tasks.
///
/// The text goes to stderr: stdout carries the generated code and is meant to
/// be redirected into a file, so usage information must not end up there.
fn help() {
    eprintln!(
        "Available tasks:
generate, gen - Generate amqp method code in `amqp_transport/src/classes/generated.rs`.
Dumps code to stdout and should be redirected manually."
    );
}