mod parser; mod random; mod write; use crate::parser::codegen_parser; use crate::random::codegen_random; use crate::write::codegen_write; use heck::ToUpperCamelCase; use std::fs; use std::iter::Peekable; use strong_xml::XmlRead; #[derive(Debug, XmlRead)] #[xml(tag = "amqp")] struct Amqp { #[xml(child = "domain")] domains: Vec, #[xml(child = "class")] classes: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "domain")] struct Domain { #[xml(attr = "name")] name: String, #[xml(attr = "type")] kind: String, #[xml(attr = "label")] label: Option, #[xml(child = "assert")] asserts: Vec, #[xml(child = "doc")] doc: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "assert")] struct Assert { #[xml(attr = "check")] check: String, #[xml(attr = "method")] method: Option, #[xml(attr = "field")] field: Option, #[xml(attr = "value")] value: Option, } #[derive(Debug, XmlRead)] #[xml(tag = "class")] struct Class { #[xml(attr = "name")] name: String, #[xml(attr = "index")] index: u16, #[xml(child = "method")] methods: Vec, #[xml(child = "doc")] doc: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "method")] struct Method { #[xml(attr = "name")] name: String, #[xml(child = "field")] fields: Vec, #[xml(attr = "index")] index: u16, #[xml(child = "doc")] doc: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "field")] struct Field { #[xml(attr = "name")] name: String, #[xml(attr = "domain")] domain: Option, #[xml(attr = "type")] kind: Option, #[xml(child = "assert")] asserts: Vec, #[xml(child = "doc")] doc: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "doc")] struct Doc { #[xml(text)] text: String, #[xml(attr = "type")] kind: Option, } fn main() { let content = fs::read_to_string("./amqp0-9-1.xml").unwrap(); let amqp = match Amqp::from_str(&content) { Ok(amqp) => amqp, Err(err) => { eprintln!("{err}"); std::process::exit(1); } }; codegen(&amqp); } fn codegen(amqp: &Amqp) { println!("// This file has been generated by `amqp_codegen`. Do not edit it manually.\n"); codegen_domain_defs(amqp); codegen_class_defs(amqp); codegen_parser(amqp); codegen_write(amqp); codegen_random(amqp); } fn codegen_domain_defs(amqp: &Amqp) { for domain in &amqp.domains { let invariants = invariants(domain.asserts.iter()); if let Some(label) = &domain.label { println!("/// {label}"); } if !invariants.is_empty() { if domain.label.is_some() { println!("///"); } println!("/// {invariants}"); } if !domain.doc.is_empty() { println!("///"); doc_comment(&domain.doc, 0); } println!( "pub type {} = {};\n", domain.name.to_upper_camel_case(), amqp_type_to_rust_type(&domain.kind), ); } } fn codegen_class_defs(amqp: &Amqp) { println!("#[derive(Debug, Clone, PartialEq)]"); println!("pub enum Class {{"); for class in &amqp.classes { let class_name = class.name.to_upper_camel_case(); println!(" {class_name}({class_name}),"); } println!("}}\n"); for class in &amqp.classes { let enum_name = class.name.to_upper_camel_case(); doc_comment(&class.doc, 0); println!("#[derive(Debug, Clone, PartialEq)]"); println!("pub enum {enum_name} {{"); for method in &class.methods { let method_name = method.name.to_upper_camel_case(); doc_comment(&method.doc, 4); print!(" {method_name}"); if !method.fields.is_empty() { println!(" {{"); for field in &method.fields { let field_name = snake_case(&field.name); let (field_type, field_docs) = get_invariants_with_type(field_type(field), field.asserts.as_ref()); if !field_docs.is_empty() { println!(" /// {field_docs}"); if !field.doc.is_empty() { println!(" ///"); doc_comment(&field.doc, 8); } } else { doc_comment(&field.doc, 8); } println!(" {field_name}: {field_type},"); } println!(" }},"); } else { println!(","); } } println!("}}"); } } fn amqp_type_to_rust_type(amqp_type: &str) -> &'static str { match amqp_type { "octet" => "u8", "short" => "u16", "long" => "u32", "longlong" => "u64", "bit" => "bool", "shortstr" => "String", "longstr" => "Vec", "timestamp" => "u64", "table" => "super::Table", _ => unreachable!("invalid type {}", amqp_type), } } fn field_type(field: &Field) -> &String { field.domain.as_ref().or(field.kind.as_ref()).unwrap() } fn resolve_type_from_domain(amqp: &Amqp, domain: &str) -> String { amqp.domains .iter() .find(|d| d.name == domain) .map(|d| d.kind.clone()) .unwrap() } /// returns (type name, invariant docs) fn get_invariants_with_type(domain: &str, asserts: &[Assert]) -> (String, String) { let additional_docs = invariants(asserts.iter()); let type_name = domain.to_upper_camel_case(); (type_name, additional_docs) } fn snake_case(ident: &str) -> String { use heck::ToSnakeCase; if ident == "type" { "r#type".to_string() } else { ident.to_snake_case() } } fn subsequent_bit_fields<'a>( amqp: &Amqp, bit_field: &'a Field, iter: &mut Peekable>, ) -> Vec<&'a Field> { let mut fields_with_bit = vec![bit_field]; loop { if iter .peek() .map(|f| resolve_type_from_domain(amqp, field_type(f)) == "bit") .unwrap_or(false) { fields_with_bit.push(iter.next().unwrap()); } else { break; } } fields_with_bit } fn invariants<'a>(asserts: impl Iterator) -> String { asserts .map(|assert| match &*assert.check { "notnull" => "must not be null".to_string(), "length" => format!("must be shorter than {}", assert.value.as_ref().unwrap()), "regexp" => format!("must match `{}`", assert.value.as_ref().unwrap()), "le" => { format!( "must be less than the {} field of the method {}", assert.method.as_ref().unwrap(), assert.field.as_ref().unwrap() ) } _ => unimplemented!(), }) .collect::>() .join(", ") } fn doc_comment(docs: &[Doc], indent: usize) { for doc in docs { if doc.kind == Some("grammar".to_string()) { continue; } for line in doc.text.lines() { let line = line.trim(); if !line.is_empty() { let indent = " ".repeat(indent); println!("{indent}/// {line}"); } } } }