From f7395cd8175f257eb8d9387c3a19d11f9673901e Mon Sep 17 00:00:00 2001 From: nils <48135649+Nilstrieb@users.noreply.github.com> Date: Mon, 19 Dec 2022 16:12:55 +0100 Subject: [PATCH] some dylib cleanup --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- src/dylib_flag.rs | 44 +++++++++++++++++++++++++++++++++----------- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2030711..48ad3fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -301,9 +301,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.21" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" dependencies = [ "proc-macro2", ] diff --git a/Cargo.toml b/Cargo.toml index 167f723..d132d1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ anyhow = "1.0.65" clap = { version = "4.0.29", features = ["derive"] } prettyplease = "0.1.19" proc-macro2 = { version = "1.0.48", features = ["span-locations"] } -quote = "1.0.21" +quote = "1.0.23" rustfix = "0.6.1" serde = { version = "1.0.151", features = ["derive"] } serde_json = "1.0.90" diff --git a/src/dylib_flag.rs b/src/dylib_flag.rs index 25c1786..e3a38a2 100644 --- a/src/dylib_flag.rs +++ b/src/dylib_flag.rs @@ -4,13 +4,12 @@ use std::{fmt::Debug, str::FromStr}; use anyhow::{Context, Result}; -use quote::quote; -type Entrypoint = unsafe extern "C" fn(*const u8, usize) -> bool; +type CheckerCFn = unsafe extern "C" fn(*const u8, usize) -> bool; #[derive(Clone, Copy)] pub struct RustFunction { - func: Entrypoint, + func: CheckerCFn, } impl FromStr for RustFunction { @@ -24,24 +23,28 @@ impl FromStr for RustFunction { fn wrap_func_body(func: &str) -> Result { let closure = syn::parse_str::(func).context("invalid rust syntax")?; - let tokenstream = quote! { + let syn_file = syn::parse_quote! { #[no_mangle] pub extern "C" fn cargo_minimize_ffi_function(ptr: *const u8, len: usize) -> bool { - match ::std::panic::catch_unwind(|| __cargo_minimize_inner(ptr, len)) { + match std::panic::catch_unwind(|| __cargo_minimize_inner(ptr, len)) { Ok(bool) => bool, - Err(_) => ::std::process::abort(), + Err(_) => std::process::abort(), } } fn __cargo_minimize_inner(__ptr: *const u8, __len: usize) -> bool { - let __slice = unsafe { ::std::slice::from_raw_parts(__ptr, __len) }; - let __str = ::std::str::from_utf8(__slice).unwrap(); + let __slice = unsafe { std::slice::from_raw_parts(__ptr, __len) }; + let __str = std::str::from_utf8(__slice).unwrap(); - (#closure)(__str) + fn ascribe_type bool>(f: F, output: &str) -> bool { + f(output) + } + + ascribe_type((#closure), __str) } }; - Ok(tokenstream.to_string()) + Ok(prettyplease::unparse(&syn_file)) } impl RustFunction { @@ -96,7 +99,7 @@ impl RustFunction { bail!("didn't find entrypoint symbol"); } - let func = unsafe { std::mem::transmute::<*mut _, Entrypoint>(func) }; + let func = unsafe { std::mem::transmute::<*mut _, CheckerCFn>(func) }; Ok(Self { func }) } @@ -114,3 +117,22 @@ impl Debug for RustFunction { f.debug_struct("RustFunction").finish_non_exhaustive() } } + +#[cfg(test)] +mod tests { + use super::RustFunction; + + #[test] + #[cfg_attr(not(unix), ignore)] + fn basic_contains_work() { + let code = r#"|output| output.contains("test")"#; + + let function = RustFunction::compile(code).unwrap(); + + let input = "this is a test"; + let not_input = "this is not a tst"; + + assert!(function.call(input)); + assert!(!function.call(not_input)); + } +}