diff --git a/pm/src/lib.rs b/pm/src/lib.rs index c0ac2ba..b5a6949 100644 --- a/pm/src/lib.rs +++ b/pm/src/lib.rs @@ -1,50 +1,46 @@ use proc_macro::TokenStream; -use proc_macro2::{Ident, Span}; -use quote::quote; -use syn::{fold::Fold, parse_macro_input, parse_quote, ItemFn, Stmt}; + +mod safe_extern; +mod scratch; #[proc_macro_attribute] -pub fn scratch_space(_: TokenStream, input: TokenStream) -> TokenStream { - let fn_def = parse_macro_input!(input as ItemFn); - let track_ident = Ident::new("scratch_local", Span::mixed_site()); - - let mut fn_def = LocalInitFolder { - track_ident: track_ident.clone(), - } - .fold_item_fn(fn_def); - - let init: Stmt = parse_quote! { let #track_ident: (); }; - - fn_def.block.stmts.insert(0, init); - - quote! { #fn_def }.into() +pub fn scratch_space(attr: TokenStream, input: TokenStream) -> TokenStream { + scratch::scratch_space(attr, input) } -struct LocalInitFolder { - track_ident: Ident, -} - -impl syn::fold::Fold for LocalInitFolder { - fn fold_macro(&mut self, mut mac: syn::Macro) -> syn::Macro { - if let Some(last_path) = mac.path.segments.iter().next_back() { - match last_path.ident.to_string().as_str() { - "scratch_write" => { - let track_ident = &self.track_ident.clone(); - mac.path = parse_quote! { actual_scratch_write }; - mac.tokens.extend(quote! { ; #track_ident }); - } - "scratch_read" => { - let mut track_ident = self.track_ident.clone(); - track_ident.set_span(track_ident.span().located_at(last_path.ident.span())); - mac.path = parse_quote! { actual_scratch_read }; - mac.tokens.extend(quote! { ; #track_ident }); - } - _ => {} - } - - mac - } else { - mac - } - } +/// # safe-extern +/// +/// Mark foreign functions as to be safe to call. +/// +/// ```ignore +/// #[safe_extern] +/// extern "Rust" { +/// fn add(a: u8, b: u8) -> u8; +/// } +/// +/// fn main() { +/// assert_eq!(add(1, 2), 3); +/// } +/// ``` +/// +/// It works by expanding the above to this +/// +/// ```ignore +/// extern "Rust" { +/// #[link_name = "add"] +/// fn _safe_extern_inner_add(a: u8, b: u8) -> u8; +/// } +/// fn add(a: u8, b: u8) -> u8 { +/// unsafe { _safe_extern_inner_add(a, b) } +/// } +/// +/// fn main() { +/// assert_eq!(add(1, 2), 3); +/// } +/// ``` +/// +/// This is of course unsound and the macro needs to be `unsafe` somehow but I can't be bothered with that right now lol. +#[proc_macro_attribute] +pub fn safe_extern(attr: TokenStream, input: TokenStream) -> TokenStream { + safe_extern::safe_extern(attr, input) } diff --git a/pm/src/safe_extern.rs b/pm/src/safe_extern.rs new file mode 100644 index 0000000..2d50727 --- /dev/null +++ b/pm/src/safe_extern.rs @@ -0,0 +1,98 @@ +use proc_macro::TokenStream; +use proc_macro2::Ident; +use quote::{quote, quote_spanned}; +use syn::{ + parse_macro_input, ForeignItem, ForeignItemFn, ItemFn, Pat, PatIdent, PatType, Visibility, +}; + +pub fn safe_extern(_: TokenStream, input: TokenStream) -> TokenStream { + let mut foreign = parse_macro_input!(input as syn::ItemForeignMod); + + let mut safe_wrappers = Vec::new(); + let src_items = std::mem::take(&mut foreign.items); + + for item in src_items { + match item { + ForeignItem::Fn(item_fn) => { + let (replacement, safe_wrapper) = mangle_ident_and_add_link_name(item_fn); + foreign.items.push(ForeignItem::Fn(replacement)); + + safe_wrappers.push(safe_wrapper); + } + item => match head_span_foreign_item(&item) { + Some(span) => { + return quote_spanned! { + span => compile_error! { "only foreign functions are allowed" } + } + .into(); + } + None => { + return quote! { + compile_error! { "only foreign functions are allowed" } + } + .into(); + } + }, + } + } + + quote! { #foreign #(#safe_wrappers)* }.into() +} + +fn mangle_ident_and_add_link_name(mut item: ForeignItemFn) -> (ForeignItemFn, ItemFn) { + if item.attrs.iter().any(|attr| { + attr.path + .get_ident() + .map_or(false, |ident| ident.to_string() == "link_name") + }) { + panic!("oh no you have alink name already") + } + + let vis = std::mem::replace(&mut item.vis, Visibility::Inherited); + + let name = item.sig.ident; + let name_str = name.to_string(); + if name_str.starts_with("r#") { + panic!("rawr :>("); + } + + let mangled = format!("_safe_extern_inner_{name_str}"); + let new_name = Ident::new(&mangled, name.span()); + item.sig.ident = new_name.clone(); + + item.attrs + .push(syn::parse_quote! { #[link_name = #name_str] }); + + let args = item.sig.inputs.iter().map(|param| match param { + syn::FnArg::Receiver(_) => panic!("cannot have reciver in foreign function"), + syn::FnArg::Typed(PatType { pat, .. }) => match &**pat { + Pat::Ident(PatIdent { ident, .. }) => quote! { #ident }, + _ => panic!("invalid argument in foreign function"), + }, + }); + + let mut safe_sig = item.sig.clone(); + safe_sig.ident = name; + let safe_wrapper = ItemFn { + attrs: Vec::new(), + vis, + sig: safe_sig, + block: syn::parse_quote! { + { + unsafe { #new_name(#(#args),*) } + } + }, + }; + + (item, safe_wrapper) +} + +fn head_span_foreign_item(item: &ForeignItem) -> Option { + Some(match item { + ForeignItem::Fn(_) => unreachable!(), + ForeignItem::Static(s) => s.static_token.span, + ForeignItem::Type(ty) => ty.type_token.span, + ForeignItem::Macro(m) => m.mac.path.segments[0].ident.span(), + _ => return None, + }) +} diff --git a/pm/src/scratch.rs b/pm/src/scratch.rs new file mode 100644 index 0000000..f93a81b --- /dev/null +++ b/pm/src/scratch.rs @@ -0,0 +1,49 @@ +use proc_macro::TokenStream; +use proc_macro2::{Ident, Span}; +use quote::quote; +use syn::{fold::Fold, parse_macro_input, parse_quote, ItemFn, Stmt}; + +pub fn scratch_space(_: TokenStream, input: TokenStream) -> TokenStream { + let fn_def = parse_macro_input!(input as ItemFn); + let track_ident = Ident::new("scratch_local", Span::mixed_site()); + + let mut fn_def = LocalInitFolder { + track_ident: track_ident.clone(), + } + .fold_item_fn(fn_def); + + let init: Stmt = parse_quote! { let #track_ident: (); }; + + fn_def.block.stmts.insert(0, init); + + quote! { #fn_def }.into() +} + +struct LocalInitFolder { + track_ident: Ident, +} + +impl syn::fold::Fold for LocalInitFolder { + fn fold_macro(&mut self, mut mac: syn::Macro) -> syn::Macro { + if let Some(last_path) = mac.path.segments.iter().next_back() { + match last_path.ident.to_string().as_str() { + "scratch_write" => { + let track_ident = &self.track_ident.clone(); + mac.path = parse_quote! { actual_scratch_write }; + mac.tokens.extend(quote! { ; #track_ident }); + } + "scratch_read" => { + let mut track_ident = self.track_ident.clone(); + track_ident.set_span(track_ident.span().located_at(last_path.ident.span())); + mac.path = parse_quote! { actual_scratch_read }; + mac.tokens.extend(quote! { ; #track_ident }); + } + _ => {} + } + + mac + } else { + mac + } + } +} diff --git a/src/lib.rs b/src/lib.rs index b65bc0e..64e2b17 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,3 +9,7 @@ pub mod scratch; pub mod sendsync; pub mod unroll_int; pub mod unsized_clone; + +pub mod safe_extern { + pub use pm::safe_extern; +} \ No newline at end of file diff --git a/tests/safe_extern.rs b/tests/safe_extern.rs new file mode 100644 index 0000000..2de4fd7 --- /dev/null +++ b/tests/safe_extern.rs @@ -0,0 +1,18 @@ +use uwu::safe_extern::safe_extern; + +#[safe_extern] +extern "Rust" { + fn add(a: u8, b: u8) -> u8; +} + +mod _impl { + #[no_mangle] + pub(super) fn add(a: u8, b: u8) -> u8 { + a + b + } +} + +#[test] +fn adding() { + assert_eq!(add(1, 2), 3); +}