diff --git a/rust2/src/codegen.rs b/rust2/src/codegen.rs index da3c71d..9be18e7 100644 --- a/rust2/src/codegen.rs +++ b/rust2/src/codegen.rs @@ -13,6 +13,9 @@ //! //! technically, the `JumpIfNotZero` would be an unconditional Jmp to the `JmpIfZero`, but that's //! a needless indirection. +//! +//! this module must not produce out of bounds jumps and always put the `End` instruction at the +//! end use crate::opts::{Ir, Stmt as IrStmt, StmtKind}; use crate::parse::Span; @@ -34,8 +37,18 @@ pub enum Stmt { #[derive(Debug, Clone)] pub struct Code<'c> { - pub stmts: Vec<Stmt, &'c Bump>, - pub debug: Vec<Span, &'c Bump>, + stmts: Vec<Stmt, &'c Bump>, + debug: Vec<Span, &'c Bump>, +} + +impl Code<'_> { + pub fn stmts(&self) -> &[Stmt] { + &self.stmts + } + + pub fn debug(&self) -> &[Span] { + &self.debug + } } pub fn generate<'c>(alloc: &'c Bump, ir: &Ir<'_>) -> Code<'c> { @@ -70,7 +83,7 @@ fn ir_to_stmt<'c>(code: &mut Code<'c>, ir_stmt: &IrStmt<'_>) { StmtKind::SetNull => Stmt::SetNull, StmtKind::Loop(instr) => { let skip_jmp_idx = code.stmts.len(); - code.stmts.push(Stmt::JmpIfZero(usize::MAX)); // placeholder + code.stmts.push(Stmt::JmpIfZero(0)); // placeholder code.debug.push(ir_stmt.span); // compile the loop body now diff --git a/rust2/src/codegen_interpreter.rs b/rust2/src/codegen_interpreter.rs index 67c4eb9..10cfa15 100644 --- a/rust2/src/codegen_interpreter.rs +++ b/rust2/src/codegen_interpreter.rs @@ -30,13 +30,22 @@ where mem: [Wrapping(0u8); MEM_SIZE], }; - interpreter.execute(); + // SAFETY: `Code` can only be produced by the `crate::codegen` module, which is trusted to not + // produce out of bounds jumps and put the `End` at the end + unsafe { + interpreter.execute(); + } } impl<'c, W: Write, R: Read> Interpreter<'c, W, R> { - fn execute(&mut self) { + unsafe fn execute(&mut self) { + let stmts = self.code.stmts(); loop { - let instr = self.code.stmts[self.ip]; + // SAFETY: If the code ends with an `End` and there are no out of bounds jumps, + // `self.ip` will never be out of bounds + // Removing this bounds check 
speeds up execution by about 40% + debug_assert!(self.ip < stmts.len()); + let instr = unsafe { *stmts.get_unchecked(self.ip) }; self.ip += 1; match instr { Stmt::Add(n) => { diff --git a/rust2/src/lib.rs b/rust2/src/lib.rs index dd2f7f7..90d4e08 100644 --- a/rust2/src/lib.rs +++ b/rust2/src/lib.rs @@ -1,4 +1,5 @@ #![feature(allocator_api, let_else)] +#![deny(unsafe_op_in_unsafe_fn)] #![warn(rust_2018_idioms)] #![allow(dead_code)]