more parser generation

This commit is contained in:
nora 2022-02-12 18:54:58 +01:00
parent 6f45a52871
commit c43126af1f
10 changed files with 904 additions and 252 deletions

View file

@ -1,4 +1,7 @@
use anyhow::{Context, Result};
mod parser;
use crate::parser::codegen_parser;
use anyhow::Result;
use heck::ToUpperCamelCase;
use std::fs;
use strong_xml::XmlRead;
@ -81,12 +84,13 @@ fn main() -> Result<()> {
}
fn codegen(amqp: &Amqp) -> Result<()> {
println!("use std::collections::HashMap;\n");
domain_defs(amqp)?;
class_defs(amqp)
println!("// This file has been generated by `amqp_codegen`. Do not edit it manually.\n");
codegen_domain_defs(amqp)?;
codegen_class_defs(amqp)?;
codegen_parser(amqp)
}
fn domain_defs(amqp: &Amqp) -> Result<()> {
fn codegen_domain_defs(amqp: &Amqp) -> Result<()> {
for domain in &amqp.domains {
let invariants = invariants(domain.asserts.iter());
@ -94,7 +98,7 @@ fn domain_defs(amqp: &Amqp) -> Result<()> {
println!("/// {invariants}");
}
println!(
"type {} = {};\n",
"pub type {} = {};\n",
domain.name.to_upper_camel_case(),
amqp_type_to_rust_type(&domain.kind),
);
@ -103,7 +107,7 @@ fn domain_defs(amqp: &Amqp) -> Result<()> {
Ok(())
}
fn class_defs(amqp: &Amqp) -> Result<()> {
fn codegen_class_defs(amqp: &Amqp) -> Result<()> {
println!("pub enum Class {{");
for class in &amqp.classes {
let class_name = class.name.to_upper_camel_case();
@ -111,12 +115,6 @@ fn class_defs(amqp: &Amqp) -> Result<()> {
}
println!("}}\n");
println!(
"pub enum TableValue {{
"
);
for class in &amqp.classes {
let enum_name = class.name.to_upper_camel_case();
println!("/// Index {}, handler = {}", class.index, class.handler);
@ -125,13 +123,12 @@ fn class_defs(amqp: &Amqp) -> Result<()> {
let method_name = method.name.to_upper_camel_case();
println!(" /// Index {}", method.index);
print!(" {method_name}");
if method.fields.len() > 0 {
if !method.fields.is_empty() {
println!(" {{");
for field in &method.fields {
let field_name = snake_case(&field.name);
let (field_type, field_docs) = resolve_type(
amqp,
&field.domain.as_ref().or(field.kind.as_ref()).unwrap(),
field.domain.as_ref().or(field.kind.as_ref()).unwrap(),
field.asserts.as_ref(),
)?;
if !field_docs.is_empty() {
@ -150,7 +147,7 @@ fn class_defs(amqp: &Amqp) -> Result<()> {
Ok(())
}
fn amqp_type_to_rust_type<'a>(amqp_type: &str) -> &'static str {
fn amqp_type_to_rust_type(amqp_type: &str) -> &'static str {
match amqp_type {
"octet" => "u8",
"short" => "u16",
@ -159,37 +156,18 @@ fn amqp_type_to_rust_type<'a>(amqp_type: &str) -> &'static str {
"bit" => "u8",
"shortstr" | "longstr" => "String",
"timestamp" => "u64",
"table" => "HashMap<Shortstr, (Octet, /* todo */ Box<dyn std::any::Any>)>",
"table" => "super::Table",
_ => unreachable!("invalid type {}", amqp_type),
}
}
/// returns (type name, invariant docs)
fn resolve_type(amqp: &Amqp, domain: &str, asserts: &[Assert]) -> Result<(String, String)> {
let kind = amqp
.domains
.iter()
.find(|d| &d.name == domain)
.context("domain not found")?;
let is_nonnull = is_nonnull(asserts.iter().chain(kind.asserts.iter()));
fn resolve_type(domain: &str, asserts: &[Assert]) -> Result<(String, String)> {
let additional_docs = invariants(asserts.iter());
let type_name = domain.to_upper_camel_case();
Ok((
if is_nonnull {
type_name
} else {
format!("Option<{type_name}>")
},
additional_docs,
))
}
fn is_nonnull<'a>(mut asserts: impl Iterator<Item = &'a Assert>) -> bool {
asserts.find(|assert| assert.check == "notnull").is_some()
Ok((type_name, additional_docs))
}
fn snake_case(ident: &str) -> String {
@ -204,18 +182,17 @@ fn snake_case(ident: &str) -> String {
fn invariants<'a>(asserts: impl Iterator<Item = &'a Assert>) -> String {
asserts
.filter_map(|assert| match &*assert.check {
"notnull" => None,
"length" => Some(format!(
"must be shorter than {}",
assert.value.as_ref().unwrap()
)),
"regexp" => Some(format!("must match `{}`", assert.value.as_ref().unwrap())),
"le" => Some(format!(
"must be less than the {} field of the method {}",
assert.method.as_ref().unwrap(),
assert.field.as_ref().unwrap()
)),
.map(|assert| match &*assert.check {
"notnull" => "must not be null".to_string(),
"length" => format!("must be shorter than {}", assert.value.as_ref().unwrap()),
"regexp" => format!("must match `{}`", assert.value.as_ref().unwrap()),
"le" => {
format!(
"must be less than the {} field of the method {}",
assert.method.as_ref().unwrap(),
assert.field.as_ref().unwrap()
)
}
_ => unimplemented!(),
})
.collect::<Vec<_>>()