safe attr

This commit is contained in:
nora 2023-02-16 19:05:53 +01:00
parent 5f98cb10bc
commit 0cf263257c
5 changed files with 209 additions and 44 deletions

View file

@ -1,50 +1,46 @@
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::quote;
use syn::{fold::Fold, parse_macro_input, parse_quote, ItemFn, Stmt};
mod safe_extern;
mod scratch;
#[proc_macro_attribute]
pub fn scratch_space(_: TokenStream, input: TokenStream) -> TokenStream {
let fn_def = parse_macro_input!(input as ItemFn);
let track_ident = Ident::new("scratch_local", Span::mixed_site());
let mut fn_def = LocalInitFolder {
track_ident: track_ident.clone(),
}
.fold_item_fn(fn_def);
let init: Stmt = parse_quote! { let #track_ident: (); };
fn_def.block.stmts.insert(0, init);
quote! { #fn_def }.into()
pub fn scratch_space(attr: TokenStream, input: TokenStream) -> TokenStream {
scratch::scratch_space(attr, input)
}
struct LocalInitFolder {
track_ident: Ident,
}
impl syn::fold::Fold for LocalInitFolder {
fn fold_macro(&mut self, mut mac: syn::Macro) -> syn::Macro {
if let Some(last_path) = mac.path.segments.iter().next_back() {
match last_path.ident.to_string().as_str() {
"scratch_write" => {
let track_ident = &self.track_ident.clone();
mac.path = parse_quote! { actual_scratch_write };
mac.tokens.extend(quote! { ; #track_ident });
}
"scratch_read" => {
let mut track_ident = self.track_ident.clone();
track_ident.set_span(track_ident.span().located_at(last_path.ident.span()));
mac.path = parse_quote! { actual_scratch_read };
mac.tokens.extend(quote! { ; #track_ident });
}
_ => {}
}
mac
} else {
mac
}
}
/// # safe-extern
///
/// Mark foreign functions as to be safe to call.
///
/// ```ignore
/// #[safe_extern]
/// extern "Rust" {
/// fn add(a: u8, b: u8) -> u8;
/// }
///
/// fn main() {
/// assert_eq!(add(1, 2), 3);
/// }
/// ```
///
/// It works by expanding the above to this
///
/// ```ignore
/// extern "Rust" {
/// #[link_name = "add"]
/// fn _safe_extern_inner_add(a: u8, b: u8) -> u8;
/// }
/// fn add(a: u8, b: u8) -> u8 {
/// unsafe { _safe_extern_inner_add(a, b) }
/// }
///
/// fn main() {
/// assert_eq!(add(1, 2), 3);
/// }
/// ```
///
/// This is of course unsound and the macro needs to be `unsafe` somehow but I can't be bothered with that right now lol.
#[proc_macro_attribute]
pub fn safe_extern(attr: TokenStream, input: TokenStream) -> TokenStream {
safe_extern::safe_extern(attr, input)
}

98
pm/src/safe_extern.rs Normal file
View file

@ -0,0 +1,98 @@
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::{quote, quote_spanned};
use syn::{
parse_macro_input, ForeignItem, ForeignItemFn, ItemFn, Pat, PatIdent, PatType, Visibility,
};
pub fn safe_extern(_: TokenStream, input: TokenStream) -> TokenStream {
let mut foreign = parse_macro_input!(input as syn::ItemForeignMod);
let mut safe_wrappers = Vec::new();
let src_items = std::mem::take(&mut foreign.items);
for item in src_items {
match item {
ForeignItem::Fn(item_fn) => {
let (replacement, safe_wrapper) = mangle_ident_and_add_link_name(item_fn);
foreign.items.push(ForeignItem::Fn(replacement));
safe_wrappers.push(safe_wrapper);
}
item => match head_span_foreign_item(&item) {
Some(span) => {
return quote_spanned! {
span => compile_error! { "only foreign functions are allowed" }
}
.into();
}
None => {
return quote! {
compile_error! { "only foreign functions are allowed" }
}
.into();
}
},
}
}
quote! { #foreign #(#safe_wrappers)* }.into()
}
fn mangle_ident_and_add_link_name(mut item: ForeignItemFn) -> (ForeignItemFn, ItemFn) {
if item.attrs.iter().any(|attr| {
attr.path
.get_ident()
.map_or(false, |ident| ident.to_string() == "link_name")
}) {
panic!("oh no you have alink name already")
}
let vis = std::mem::replace(&mut item.vis, Visibility::Inherited);
let name = item.sig.ident;
let name_str = name.to_string();
if name_str.starts_with("r#") {
panic!("rawr :>(");
}
let mangled = format!("_safe_extern_inner_{name_str}");
let new_name = Ident::new(&mangled, name.span());
item.sig.ident = new_name.clone();
item.attrs
.push(syn::parse_quote! { #[link_name = #name_str] });
let args = item.sig.inputs.iter().map(|param| match param {
syn::FnArg::Receiver(_) => panic!("cannot have reciver in foreign function"),
syn::FnArg::Typed(PatType { pat, .. }) => match &**pat {
Pat::Ident(PatIdent { ident, .. }) => quote! { #ident },
_ => panic!("invalid argument in foreign function"),
},
});
let mut safe_sig = item.sig.clone();
safe_sig.ident = name;
let safe_wrapper = ItemFn {
attrs: Vec::new(),
vis,
sig: safe_sig,
block: syn::parse_quote! {
{
unsafe { #new_name(#(#args),*) }
}
},
};
(item, safe_wrapper)
}
fn head_span_foreign_item(item: &ForeignItem) -> Option<proc_macro2::Span> {
Some(match item {
ForeignItem::Fn(_) => unreachable!(),
ForeignItem::Static(s) => s.static_token.span,
ForeignItem::Type(ty) => ty.type_token.span,
ForeignItem::Macro(m) => m.mac.path.segments[0].ident.span(),
_ => return None,
})
}

49
pm/src/scratch.rs Normal file
View file

@ -0,0 +1,49 @@
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::quote;
use syn::{fold::Fold, parse_macro_input, parse_quote, ItemFn, Stmt};
pub fn scratch_space(_: TokenStream, input: TokenStream) -> TokenStream {
let fn_def = parse_macro_input!(input as ItemFn);
let track_ident = Ident::new("scratch_local", Span::mixed_site());
let mut fn_def = LocalInitFolder {
track_ident: track_ident.clone(),
}
.fold_item_fn(fn_def);
let init: Stmt = parse_quote! { let #track_ident: (); };
fn_def.block.stmts.insert(0, init);
quote! { #fn_def }.into()
}
struct LocalInitFolder {
track_ident: Ident,
}
impl syn::fold::Fold for LocalInitFolder {
fn fold_macro(&mut self, mut mac: syn::Macro) -> syn::Macro {
if let Some(last_path) = mac.path.segments.iter().next_back() {
match last_path.ident.to_string().as_str() {
"scratch_write" => {
let track_ident = &self.track_ident.clone();
mac.path = parse_quote! { actual_scratch_write };
mac.tokens.extend(quote! { ; #track_ident });
}
"scratch_read" => {
let mut track_ident = self.track_ident.clone();
track_ident.set_span(track_ident.span().located_at(last_path.ident.span()));
mac.path = parse_quote! { actual_scratch_read };
mac.tokens.extend(quote! { ; #track_ident });
}
_ => {}
}
mac
} else {
mac
}
}
}

View file

@ -9,3 +9,7 @@ pub mod scratch;
pub mod sendsync;
pub mod unroll_int;
pub mod unsized_clone;
pub mod safe_extern {
pub use pm::safe_extern;
}

18
tests/safe_extern.rs Normal file
View file

@ -0,0 +1,18 @@
use uwu::safe_extern::safe_extern;
#[safe_extern]
extern "Rust" {
fn add(a: u8, b: u8) -> u8;
}
mod _impl {
#[no_mangle]
pub(super) fn add(a: u8, b: u8) -> u8 {
a + b
}
}
#[test]
fn adding() {
assert_eq!(add(1, 2), 3);
}