#![allow(clippy::needless_late_init)] // because of a bad derive macro mod parser; mod random; mod write; use anyhow::{bail, Context}; use heck::ToUpperCamelCase; use std::fs; use std::fs::File; use std::io::Write; use std::iter::Peekable; use std::path::{Path, PathBuf}; use std::process::Command; use std::str::FromStr; use strong_xml::XmlRead; #[derive(Debug, XmlRead)] #[xml(tag = "amqp")] struct Amqp { #[xml(child = "domain")] domains: Vec, #[xml(child = "class")] classes: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "domain")] struct Domain { #[xml(attr = "name")] name: String, #[xml(attr = "type")] kind: String, #[xml(attr = "label")] label: Option, #[xml(child = "assert")] asserts: Vec, #[xml(child = "doc")] doc: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "assert")] struct Assert { #[xml(attr = "check")] check: String, #[xml(attr = "method")] method: Option, #[xml(attr = "field")] field: Option, #[xml(attr = "value")] value: Option, } #[derive(Debug, XmlRead)] #[xml(tag = "class")] struct Class { #[xml(attr = "name")] name: String, #[xml(attr = "index")] index: u16, #[xml(child = "method")] methods: Vec, #[xml(child = "doc")] doc: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "method")] struct Method { #[xml(attr = "name")] name: String, #[xml(child = "field")] fields: Vec, #[xml(attr = "index")] index: u16, #[xml(child = "doc")] doc: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "field")] struct Field { #[xml(attr = "name")] name: String, #[xml(attr = "domain")] domain: Option, #[xml(attr = "type")] kind: Option, #[xml(child = "assert")] asserts: Vec, #[xml(child = "doc")] doc: Vec, } #[derive(Debug, XmlRead)] #[xml(tag = "doc")] struct Doc { #[xml(text)] text: String, #[xml(attr = "type")] kind: Option, } struct Codegen { output: Box, } fn fmt(path: &Path) -> anyhow::Result<()> { println!("Formatting {path:?}..."); let status = Command::new("rustfmt").arg(path).status()?; if !status.success() { bail!("error formatting {path:?}"); } Ok(()) } pub fn main() -> anyhow::Result<()> { let this_file = PathBuf::from_str(file!()).context("own file path")?; let xtask_root = this_file .parent() .context("codegen directory path")? .parent() .context("src directory path")? .parent() .context("xtask root path")?; let amqp_spec = xtask_root.join("amqp0-9-1.xml"); let project_root = xtask_root.parent().context("get project root parent")?; let transport_generated_path = project_root.join("amqp_transport/src/methods/generated.rs"); let core_generated_path = project_root.join("amqp_core/src/methods/generated.rs"); let content = fs::read_to_string(amqp_spec).context("read amqp spec file")?; let amqp = Amqp::from_str(&content).context("parse amqp spec file")?; let transport_output = File::create(&transport_generated_path).context("transport output file create")?; let core_output = File::create(&core_generated_path).context("core output file create")?; Codegen { output: Box::new(transport_output), } .transport_codegen(&amqp); Codegen { output: Box::new(core_output), } .core_codegen(&amqp); fmt(&transport_generated_path)?; fmt(&core_generated_path)?; Ok(()) } impl Codegen { fn transport_codegen(&mut self, amqp: &Amqp) { writeln!(self.output, "#![allow(dead_code)]").ok(); writeln!( self.output, "// This file has been generated by `xtask/src/codegen`. Do not edit it manually.\n" ) .ok(); self.codegen_parser(amqp); self.codegen_write(amqp); self.codegen_random(amqp); } fn core_codegen(&mut self, amqp: &Amqp) { writeln!(self.output, "#![allow(dead_code)]").ok(); writeln!( self.output, "// This file has been generated by `xtask/src/codegen`. Do not edit it manually.\n" ) .ok(); self.codegen_domain_defs(amqp); self.codegen_class_defs(amqp); } fn codegen_domain_defs(&mut self, amqp: &Amqp) { for domain in &amqp.domains { let invariants = self.invariants(domain.asserts.iter()); if let Some(label) = &domain.label { writeln!(self.output, "/// {label}").ok(); } if !invariants.is_empty() { if domain.label.is_some() { writeln!(self.output, "///").ok(); } writeln!(self.output, "/// {invariants}").ok(); } if !domain.doc.is_empty() { writeln!(self.output, "///").ok(); self.doc_comment(&domain.doc, 0); } writeln!( self.output, "pub type {} = {};\n", domain.name.to_upper_camel_case(), self.amqp_type_to_rust_type(&domain.kind), ) .ok(); } } fn codegen_class_defs(&mut self, amqp: &Amqp) { writeln!(self.output, "#[derive(Debug, Clone, PartialEq)]").ok(); writeln!(self.output, "pub enum Method {{").ok(); for class in &amqp.classes { let enum_name = class.name.to_upper_camel_case(); for method in &class.methods { let method_name = method.name.to_upper_camel_case(); self.doc_comment(&class.doc, 4); self.doc_comment(&method.doc, 4); write!(self.output, " {enum_name}{method_name}").ok(); if !method.fields.is_empty() { writeln!(self.output, " {{").ok(); for field in &method.fields { let field_name = self.snake_case(&field.name); let (field_type, field_docs) = self.get_invariants_with_type( self.field_type(field), field.asserts.as_ref(), ); if !field_docs.is_empty() { writeln!(self.output, " /// {field_docs}").ok(); if !field.doc.is_empty() { writeln!(self.output, " ///").ok(); self.doc_comment(&field.doc, 8); } } else { self.doc_comment(&field.doc, 8); } writeln!(self.output, " {field_name}: {field_type},").ok(); } writeln!(self.output, " }},").ok(); } else { writeln!(self.output, ",").ok(); } } } writeln!(self.output, "}}\n").ok(); } fn amqp_type_to_rust_type(&self, amqp_type: &str) -> &'static str { match amqp_type { "octet" => "u8", "short" => "u16", "long" => "u32", "longlong" => "u64", "bit" => "bool", "shortstr" => "String", "longstr" => "Vec", "timestamp" => "u64", "table" => "super::Table", _ => unreachable!("invalid type {}", amqp_type), } } fn field_type<'a>(&self, field: &'a Field) -> &'a String { field.domain.as_ref().or(field.kind.as_ref()).unwrap() } fn resolve_type_from_domain(&self, amqp: &Amqp, domain: &str) -> String { amqp.domains .iter() .find(|d| d.name == domain) .map(|d| d.kind.clone()) .unwrap() } /// returns (type name, invariant docs) fn get_invariants_with_type(&self, domain: &str, asserts: &[Assert]) -> (String, String) { let additional_docs = self.invariants(asserts.iter()); let type_name = domain.to_upper_camel_case(); (type_name, additional_docs) } fn snake_case(&self, ident: &str) -> String { use heck::ToSnakeCase; if ident == "type" { "r#type".to_string() } else { ident.to_snake_case() } } fn subsequent_bit_fields<'a>( &self, bit_field: &'a Field, iter: &mut Peekable>, amqp: &Amqp, ) -> Vec<&'a Field> { let mut fields_with_bit = vec![bit_field]; loop { if iter .peek() .map(|f| self.resolve_type_from_domain(amqp, self.field_type(f)) == "bit") .unwrap_or(false) { fields_with_bit.push(iter.next().unwrap()); } else { break; } } fields_with_bit } fn invariants<'a>(&self, asserts: impl Iterator) -> String { asserts .map(|assert| match &*assert.check { "notnull" => "must not be null".to_string(), "length" => format!("must be shorter than {}", assert.value.as_ref().unwrap()), "regexp" => format!("must match `{}`", assert.value.as_ref().unwrap()), "le" => { format!( "must be less than the {} field of the method {}", assert.method.as_ref().unwrap(), assert.field.as_ref().unwrap() ) } _ => unimplemented!(), }) .collect::>() .join(", ") } fn doc_comment(&mut self, docs: &[Doc], indent: usize) { for doc in docs { if doc.kind == Some("grammar".to_string()) { continue; } for line in doc.text.lines() { let line = line.trim(); if !line.is_empty() { let indent = " ".repeat(indent); writeln!(self.output, "{indent}/// {line}").ok(); } } } } }