more parser generation

This commit is contained in:
nora 2022-02-12 18:54:58 +01:00
parent 6f45a52871
commit c43126af1f
10 changed files with 904 additions and 252 deletions

View file

@ -1,4 +1,7 @@
use anyhow::{Context, Result};
mod parser;
use crate::parser::codegen_parser;
use anyhow::Result;
use heck::ToUpperCamelCase;
use std::fs;
use strong_xml::XmlRead;
@ -81,12 +84,13 @@ fn main() -> Result<()> {
}
fn codegen(amqp: &Amqp) -> Result<()> {
println!("use std::collections::HashMap;\n");
domain_defs(amqp)?;
class_defs(amqp)
println!("// This file has been generated by `amqp_codegen`. Do not edit it manually.\n");
codegen_domain_defs(amqp)?;
codegen_class_defs(amqp)?;
codegen_parser(amqp)
}
fn domain_defs(amqp: &Amqp) -> Result<()> {
fn codegen_domain_defs(amqp: &Amqp) -> Result<()> {
for domain in &amqp.domains {
let invariants = invariants(domain.asserts.iter());
@ -94,7 +98,7 @@ fn domain_defs(amqp: &Amqp) -> Result<()> {
println!("/// {invariants}");
}
println!(
"type {} = {};\n",
"pub type {} = {};\n",
domain.name.to_upper_camel_case(),
amqp_type_to_rust_type(&domain.kind),
);
@ -103,7 +107,7 @@ fn domain_defs(amqp: &Amqp) -> Result<()> {
Ok(())
}
fn class_defs(amqp: &Amqp) -> Result<()> {
fn codegen_class_defs(amqp: &Amqp) -> Result<()> {
println!("pub enum Class {{");
for class in &amqp.classes {
let class_name = class.name.to_upper_camel_case();
@ -111,12 +115,6 @@ fn class_defs(amqp: &Amqp) -> Result<()> {
}
println!("}}\n");
println!(
"pub enum TableValue {{
"
);
for class in &amqp.classes {
let enum_name = class.name.to_upper_camel_case();
println!("/// Index {}, handler = {}", class.index, class.handler);
@ -125,13 +123,12 @@ fn class_defs(amqp: &Amqp) -> Result<()> {
let method_name = method.name.to_upper_camel_case();
println!(" /// Index {}", method.index);
print!(" {method_name}");
if method.fields.len() > 0 {
if !method.fields.is_empty() {
println!(" {{");
for field in &method.fields {
let field_name = snake_case(&field.name);
let (field_type, field_docs) = resolve_type(
amqp,
&field.domain.as_ref().or(field.kind.as_ref()).unwrap(),
field.domain.as_ref().or(field.kind.as_ref()).unwrap(),
field.asserts.as_ref(),
)?;
if !field_docs.is_empty() {
@ -150,7 +147,7 @@ fn class_defs(amqp: &Amqp) -> Result<()> {
Ok(())
}
fn amqp_type_to_rust_type<'a>(amqp_type: &str) -> &'static str {
fn amqp_type_to_rust_type(amqp_type: &str) -> &'static str {
match amqp_type {
"octet" => "u8",
"short" => "u16",
@ -159,37 +156,18 @@ fn amqp_type_to_rust_type<'a>(amqp_type: &str) -> &'static str {
"bit" => "u8",
"shortstr" | "longstr" => "String",
"timestamp" => "u64",
"table" => "HashMap<Shortstr, (Octet, /* todo */ Box<dyn std::any::Any>)>",
"table" => "super::Table",
_ => unreachable!("invalid type {}", amqp_type),
}
}
/// returns (type name, invariant docs)
fn resolve_type(amqp: &Amqp, domain: &str, asserts: &[Assert]) -> Result<(String, String)> {
let kind = amqp
.domains
.iter()
.find(|d| &d.name == domain)
.context("domain not found")?;
let is_nonnull = is_nonnull(asserts.iter().chain(kind.asserts.iter()));
fn resolve_type(domain: &str, asserts: &[Assert]) -> Result<(String, String)> {
let additional_docs = invariants(asserts.iter());
let type_name = domain.to_upper_camel_case();
Ok((
if is_nonnull {
type_name
} else {
format!("Option<{type_name}>")
},
additional_docs,
))
}
fn is_nonnull<'a>(mut asserts: impl Iterator<Item = &'a Assert>) -> bool {
asserts.find(|assert| assert.check == "notnull").is_some()
Ok((type_name, additional_docs))
}
fn snake_case(ident: &str) -> String {
@ -204,18 +182,17 @@ fn snake_case(ident: &str) -> String {
fn invariants<'a>(asserts: impl Iterator<Item = &'a Assert>) -> String {
asserts
.filter_map(|assert| match &*assert.check {
"notnull" => None,
"length" => Some(format!(
"must be shorter than {}",
assert.value.as_ref().unwrap()
)),
"regexp" => Some(format!("must match `{}`", assert.value.as_ref().unwrap())),
"le" => Some(format!(
"must be less than the {} field of the method {}",
assert.method.as_ref().unwrap(),
assert.field.as_ref().unwrap()
)),
.map(|assert| match &*assert.check {
"notnull" => "must not be null".to_string(),
"length" => format!("must be shorter than {}", assert.value.as_ref().unwrap()),
"regexp" => format!("must match `{}`", assert.value.as_ref().unwrap()),
"le" => {
format!(
"must be less than the {} field of the method {}",
assert.method.as_ref().unwrap(),
assert.field.as_ref().unwrap()
)
}
_ => unimplemented!(),
})
.collect::<Vec<_>>()

133
amqp_codegen/src/parser.rs Normal file
View file

@ -0,0 +1,133 @@
use crate::{Amqp, Class, Domain, Method};
use anyhow::Result;
use heck::{ToSnakeCase, ToUpperCamelCase};
use itertools::Itertools;
/// Builds a closure that maps a method to the name of its generated parser
/// function.
///
/// The resulting name is `class_name`, an underscore, and the snake_case form
/// of the method's name. `class_name` is used verbatim, so callers pass it
/// already converted. The closure borrows `class_name` for its whole lifetime.
fn method_function_name(class_name: &str) -> impl Fn(&Method) -> String + '_ {
    move |method| format!("{}_{}", class_name, method.name.to_snake_case())
}
/// Returns the name of the generated parser function for a domain:
/// `domain_` followed by the snake_case form of `domain_name`.
fn domain_function_name(domain_name: &str) -> String {
    let mut fn_name = String::from("domain_");
    fn_name.push_str(&domain_name.to_snake_case());
    fn_name
}
/// Prints the generated `parse` module to stdout.
///
/// The emitted module contains, in order:
/// - its `use` lines and the `IResult` type alias,
/// - `parse_method`, which tries every class parser via `alt`,
/// - one parser function per domain (see `domain_parser`),
/// - one parser function per class, followed by one per method of that class
///   (see `method_parser`).
///
/// # Errors
///
/// Propagates any error returned by the nested code generators.
pub(crate) fn codegen_parser(amqp: &Amqp) -> Result<()> {
    // The continuation lines of the string literals below are part of the
    // generated output and are therefore kept at column 0.
    println!(
        "pub mod parse {{
use super::*;
use crate::classes::parse_helper::*;
use crate::error::TransError;
use nom::{{branch::alt, bytes::complete::tag}};
use regex::Regex;
use once_cell::sync::Lazy;
pub type IResult<'a, T> = nom::IResult<&'a [u8], T, TransError>;
"
    );
    println!(
        "pub fn parse_method(input: &[u8]) -> Result<(&[u8], Class), nom::Err<TransError>> {{
alt(({}))(input)
}}",
        amqp.classes
            .iter()
            .map(|class| class.name.to_snake_case())
            .join(", ")
    );

    for domain in &amqp.domains {
        domain_parser(domain)?;
    }

    for class in &amqp.classes {
        let class_name = class.name.to_snake_case();
        function(&class_name, "Class", || {
            let class_index = class.index;
            let all_methods = class
                .methods
                .iter()
                .map(method_function_name(&class_name))
                .join(", ");
            // AMQP 0-9-1 encodes the class id of a method frame as a 16-bit
            // big-endian integer, so the generated parser has to match two
            // bytes. Matching the single byte `[index]` would never succeed
            // on real input, whose first byte is the (zero) high byte.
            println!(
                " let (input, _) = tag({class_index}_u16.to_be_bytes())(input)?;
alt(({all_methods}))(input)"
            );
            Ok(())
        })?;

        for method in &class.methods {
            method_parser(class, method)?;
        }
    }

    // Closes the `pub mod parse {` opened at the top.
    println!("\n}}");
    Ok(())
}
/// Prints the generated parser function for a single domain.
///
/// The function is named via `domain_function_name` and returns the domain's
/// UpperCamelCase type. Without asserts it simply delegates to the parser of
/// the underlying type (`bit` is still a `todo!()`). With asserts it parses
/// the underlying type first and then emits one check per assert:
/// - `regexp`: the value must match the regular expression,
/// - `length`: the value must not be longer than the given length,
/// - `notnull` and `le`: nothing is emitted here.
///
/// # Panics
///
/// Panics if a `regexp` or `length` assert has no `value`, or if an assert
/// uses an unknown `check`. Both mean that the specification contains
/// something this generator does not handle; the message names the domain and
/// the offending check so the cause can be found.
fn domain_parser(domain: &Domain) -> Result<()> {
    let fn_name = domain_function_name(&domain.name);
    let type_name = domain.kind.to_snake_case();
    function(&fn_name, &domain.name.to_upper_camel_case(), || {
        if domain.asserts.is_empty() {
            if type_name == "bit" {
                println!(" todo!() // bit")
            } else {
                println!(" {type_name}(input)");
            }
        } else {
            println!(" let (input, result) = {type_name}(input)?;");
            for assert in &domain.asserts {
                match &*assert.check {
                    "notnull" => { /* todo */ }
                    "regexp" => {
                        let value = assert.value.as_ref().unwrap_or_else(|| {
                            panic!(
                                "domain `{}`: `regexp` assert has no `value`",
                                domain.name
                            )
                        });
                        // NOTE: every `regexp` assert emits a static named
                        // `REGEX` into the same function body, so a domain
                        // with two `regexp` asserts would produce a duplicate
                        // item. The pattern is embedded in a plain raw string,
                        // so it must not contain a `"`.
                        println!(
                            r#" static REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"{value}").unwrap());"#
                        );
                        println!(" if !REGEX.is_match(&result) {{ fail!() }}");
                    }
                    "le" => {} // can't validate this here
                    "length" => {
                        let length = assert.value.as_ref().unwrap_or_else(|| {
                            panic!(
                                "domain `{}`: `length` assert has no `value`",
                                domain.name
                            )
                        });
                        println!(" if result.len() > {length} {{ fail!() }}");
                    }
                    other => unimplemented!(
                        "assert check `{}` on domain `{}`",
                        other,
                        domain.name
                    ),
                }
            }
            println!(" Ok((input, result))");
        }
        Ok(())
    })
}
/// Prints the generated parser function for a single method of `class`.
///
/// The function is named via `method_function_name` and returns `Class`. It
/// currently only matches the method id; parsing of the method's fields is
/// not generated yet, so the emitted body ends in `todo!()`.
///
/// # Errors
///
/// Propagates any error returned by `function`.
fn method_parser(class: &Class, method: &Method) -> Result<()> {
    let class_name = class.name.to_snake_case();
    let function_name = method_function_name(&class_name)(method);
    function(&function_name, "Class", || {
        let method_index = method.index;
        // AMQP 0-9-1 encodes the method id of a method frame as a 16-bit
        // big-endian integer, so the generated parser has to match two bytes
        // rather than the single byte `[index]`.
        println!(" let (input, _) = tag({method_index}_u16.to_be_bytes())(input)?;");
        println!(" todo!()");
        Ok(())
    })
}
/// Prints a generated parser function named `name` that takes
/// `input: &[u8]` and returns `IResult<ret_ty>`.
///
/// The signature line is printed first, then `body` is run to print the
/// function body, and finally the closing brace is printed.
///
/// # Errors
///
/// Returns the error of `body` unchanged. In that case the closing brace is
/// not printed.
fn function<F>(name: &str, ret_ty: &str, body: F) -> Result<()>
where
    F: FnOnce() -> Result<()>,
{
    println!("fn {name}(input: &[u8]) -> IResult<{ret_ty}> {{");
    // Only close the function when the body was emitted successfully.
    body().map(|()| println!("}}"))
}