mirror of
https://github.com/Noratrieb/the-good-stuff.git
synced 2026-01-14 16:45:01 +01:00
safe attr
This commit is contained in:
parent
5f98cb10bc
commit
0cf263257c
5 changed files with 209 additions and 44 deletions
|
|
@ -1,50 +1,46 @@
|
|||
use proc_macro::TokenStream;
|
||||
use proc_macro2::{Ident, Span};
|
||||
use quote::quote;
|
||||
use syn::{fold::Fold, parse_macro_input, parse_quote, ItemFn, Stmt};
|
||||
|
||||
mod safe_extern;
|
||||
mod scratch;
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn scratch_space(_: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let fn_def = parse_macro_input!(input as ItemFn);
|
||||
let track_ident = Ident::new("scratch_local", Span::mixed_site());
|
||||
|
||||
let mut fn_def = LocalInitFolder {
|
||||
track_ident: track_ident.clone(),
|
||||
}
|
||||
.fold_item_fn(fn_def);
|
||||
|
||||
let init: Stmt = parse_quote! { let #track_ident: (); };
|
||||
|
||||
fn_def.block.stmts.insert(0, init);
|
||||
|
||||
quote! { #fn_def }.into()
|
||||
pub fn scratch_space(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
scratch::scratch_space(attr, input)
|
||||
}
|
||||
|
||||
struct LocalInitFolder {
|
||||
track_ident: Ident,
|
||||
}
|
||||
|
||||
impl syn::fold::Fold for LocalInitFolder {
|
||||
fn fold_macro(&mut self, mut mac: syn::Macro) -> syn::Macro {
|
||||
if let Some(last_path) = mac.path.segments.iter().next_back() {
|
||||
match last_path.ident.to_string().as_str() {
|
||||
"scratch_write" => {
|
||||
let track_ident = &self.track_ident.clone();
|
||||
mac.path = parse_quote! { actual_scratch_write };
|
||||
mac.tokens.extend(quote! { ; #track_ident });
|
||||
}
|
||||
"scratch_read" => {
|
||||
let mut track_ident = self.track_ident.clone();
|
||||
track_ident.set_span(track_ident.span().located_at(last_path.ident.span()));
|
||||
mac.path = parse_quote! { actual_scratch_read };
|
||||
mac.tokens.extend(quote! { ; #track_ident });
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
mac
|
||||
} else {
|
||||
mac
|
||||
}
|
||||
}
|
||||
/// # safe-extern
|
||||
///
|
||||
/// Mark foreign functions as to be safe to call.
|
||||
///
|
||||
/// ```ignore
|
||||
/// #[safe_extern]
|
||||
/// extern "Rust" {
|
||||
/// fn add(a: u8, b: u8) -> u8;
|
||||
/// }
|
||||
///
|
||||
/// fn main() {
|
||||
/// assert_eq!(add(1, 2), 3);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// It works by expanding the above to this
|
||||
///
|
||||
/// ```ignore
|
||||
/// extern "Rust" {
|
||||
/// #[link_name = "add"]
|
||||
/// fn _safe_extern_inner_add(a: u8, b: u8) -> u8;
|
||||
/// }
|
||||
/// fn add(a: u8, b: u8) -> u8 {
|
||||
/// unsafe { _safe_extern_inner_add(a, b) }
|
||||
/// }
|
||||
///
|
||||
/// fn main() {
|
||||
/// assert_eq!(add(1, 2), 3);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// This is of course unsound and the macro needs to be `unsafe` somehow but I can't be bothered with that right now lol.
|
||||
#[proc_macro_attribute]
|
||||
pub fn safe_extern(attr: TokenStream, input: TokenStream) -> TokenStream {
|
||||
safe_extern::safe_extern(attr, input)
|
||||
}
|
||||
|
|
|
|||
98
pm/src/safe_extern.rs
Normal file
98
pm/src/safe_extern.rs
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
use proc_macro::TokenStream;
|
||||
use proc_macro2::Ident;
|
||||
use quote::{quote, quote_spanned};
|
||||
use syn::{
|
||||
parse_macro_input, ForeignItem, ForeignItemFn, ItemFn, Pat, PatIdent, PatType, Visibility,
|
||||
};
|
||||
|
||||
pub fn safe_extern(_: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let mut foreign = parse_macro_input!(input as syn::ItemForeignMod);
|
||||
|
||||
let mut safe_wrappers = Vec::new();
|
||||
let src_items = std::mem::take(&mut foreign.items);
|
||||
|
||||
for item in src_items {
|
||||
match item {
|
||||
ForeignItem::Fn(item_fn) => {
|
||||
let (replacement, safe_wrapper) = mangle_ident_and_add_link_name(item_fn);
|
||||
foreign.items.push(ForeignItem::Fn(replacement));
|
||||
|
||||
safe_wrappers.push(safe_wrapper);
|
||||
}
|
||||
item => match head_span_foreign_item(&item) {
|
||||
Some(span) => {
|
||||
return quote_spanned! {
|
||||
span => compile_error! { "only foreign functions are allowed" }
|
||||
}
|
||||
.into();
|
||||
}
|
||||
None => {
|
||||
return quote! {
|
||||
compile_error! { "only foreign functions are allowed" }
|
||||
}
|
||||
.into();
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
quote! { #foreign #(#safe_wrappers)* }.into()
|
||||
}
|
||||
|
||||
fn mangle_ident_and_add_link_name(mut item: ForeignItemFn) -> (ForeignItemFn, ItemFn) {
|
||||
if item.attrs.iter().any(|attr| {
|
||||
attr.path
|
||||
.get_ident()
|
||||
.map_or(false, |ident| ident.to_string() == "link_name")
|
||||
}) {
|
||||
panic!("oh no you have alink name already")
|
||||
}
|
||||
|
||||
let vis = std::mem::replace(&mut item.vis, Visibility::Inherited);
|
||||
|
||||
let name = item.sig.ident;
|
||||
let name_str = name.to_string();
|
||||
if name_str.starts_with("r#") {
|
||||
panic!("rawr :>(");
|
||||
}
|
||||
|
||||
let mangled = format!("_safe_extern_inner_{name_str}");
|
||||
let new_name = Ident::new(&mangled, name.span());
|
||||
item.sig.ident = new_name.clone();
|
||||
|
||||
item.attrs
|
||||
.push(syn::parse_quote! { #[link_name = #name_str] });
|
||||
|
||||
let args = item.sig.inputs.iter().map(|param| match param {
|
||||
syn::FnArg::Receiver(_) => panic!("cannot have reciver in foreign function"),
|
||||
syn::FnArg::Typed(PatType { pat, .. }) => match &**pat {
|
||||
Pat::Ident(PatIdent { ident, .. }) => quote! { #ident },
|
||||
_ => panic!("invalid argument in foreign function"),
|
||||
},
|
||||
});
|
||||
|
||||
let mut safe_sig = item.sig.clone();
|
||||
safe_sig.ident = name;
|
||||
let safe_wrapper = ItemFn {
|
||||
attrs: Vec::new(),
|
||||
vis,
|
||||
sig: safe_sig,
|
||||
block: syn::parse_quote! {
|
||||
{
|
||||
unsafe { #new_name(#(#args),*) }
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
(item, safe_wrapper)
|
||||
}
|
||||
|
||||
fn head_span_foreign_item(item: &ForeignItem) -> Option<proc_macro2::Span> {
|
||||
Some(match item {
|
||||
ForeignItem::Fn(_) => unreachable!(),
|
||||
ForeignItem::Static(s) => s.static_token.span,
|
||||
ForeignItem::Type(ty) => ty.type_token.span,
|
||||
ForeignItem::Macro(m) => m.mac.path.segments[0].ident.span(),
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
49
pm/src/scratch.rs
Normal file
49
pm/src/scratch.rs
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
use proc_macro::TokenStream;
|
||||
use proc_macro2::{Ident, Span};
|
||||
use quote::quote;
|
||||
use syn::{fold::Fold, parse_macro_input, parse_quote, ItemFn, Stmt};
|
||||
|
||||
pub fn scratch_space(_: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let fn_def = parse_macro_input!(input as ItemFn);
|
||||
let track_ident = Ident::new("scratch_local", Span::mixed_site());
|
||||
|
||||
let mut fn_def = LocalInitFolder {
|
||||
track_ident: track_ident.clone(),
|
||||
}
|
||||
.fold_item_fn(fn_def);
|
||||
|
||||
let init: Stmt = parse_quote! { let #track_ident: (); };
|
||||
|
||||
fn_def.block.stmts.insert(0, init);
|
||||
|
||||
quote! { #fn_def }.into()
|
||||
}
|
||||
|
||||
struct LocalInitFolder {
|
||||
track_ident: Ident,
|
||||
}
|
||||
|
||||
impl syn::fold::Fold for LocalInitFolder {
|
||||
fn fold_macro(&mut self, mut mac: syn::Macro) -> syn::Macro {
|
||||
if let Some(last_path) = mac.path.segments.iter().next_back() {
|
||||
match last_path.ident.to_string().as_str() {
|
||||
"scratch_write" => {
|
||||
let track_ident = &self.track_ident.clone();
|
||||
mac.path = parse_quote! { actual_scratch_write };
|
||||
mac.tokens.extend(quote! { ; #track_ident });
|
||||
}
|
||||
"scratch_read" => {
|
||||
let mut track_ident = self.track_ident.clone();
|
||||
track_ident.set_span(track_ident.span().located_at(last_path.ident.span()));
|
||||
mac.path = parse_quote! { actual_scratch_read };
|
||||
mac.tokens.extend(quote! { ; #track_ident });
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
mac
|
||||
} else {
|
||||
mac
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -9,3 +9,7 @@ pub mod scratch;
|
|||
pub mod sendsync;
|
||||
pub mod unroll_int;
|
||||
pub mod unsized_clone;
|
||||
|
||||
pub mod safe_extern {
|
||||
pub use pm::safe_extern;
|
||||
}
|
||||
18
tests/safe_extern.rs
Normal file
18
tests/safe_extern.rs
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
use uwu::safe_extern::safe_extern;
|
||||
|
||||
#[safe_extern]
|
||||
extern "Rust" {
|
||||
fn add(a: u8, b: u8) -> u8;
|
||||
}
|
||||
|
||||
mod _impl {
|
||||
#[no_mangle]
|
||||
pub(super) fn add(a: u8, b: u8) -> u8 {
|
||||
a + b
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adding() {
|
||||
assert_eq!(add(1, 2), 3);
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue