From 8952e8082a835279883b0a1f7faee3a7a5ec46bc Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Sat, 6 Nov 2021 21:32:49 +0100 Subject: [PATCH] fix stackoverflow --- fuzz/fuzz_targets/lex_parse.rs | 9 +- src/lib.rs | 1 + src/parse/mod.rs | 217 +++++++++++++++++++++++++++------ src/parse/test.rs | 11 ++ 4 files changed, 201 insertions(+), 37 deletions(-) diff --git a/fuzz/fuzz_targets/lex_parse.rs b/fuzz/fuzz_targets/lex_parse.rs index c5f7755..750fa4b 100644 --- a/fuzz/fuzz_targets/lex_parse.rs +++ b/fuzz/fuzz_targets/lex_parse.rs @@ -1,6 +1,11 @@ #![no_main] use libfuzzer_sys::fuzz_target; -fuzz_target!(|data: &[u8]| { - // fuzzed code goes here +fuzz_target!(|data: String| { + let lexer = script_lang::Lexer::lex(&data); + let tokens = lexer.collect::, _>>(); + + if let Ok(tokens) = tokens { + let _ast = script_lang::parse(tokens); + } }); diff --git a/src/lib.rs b/src/lib.rs index 59e1682..93246cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ mod lex; mod parse; pub use lex::*; +pub use parse::*; pub fn run_program(program: &str) { let lexer = lex::Lexer::lex(program); diff --git a/src/parse/mod.rs b/src/parse/mod.rs index e0a755b..1845abb 100644 --- a/src/parse/mod.rs +++ b/src/parse/mod.rs @@ -9,6 +9,7 @@ use std::iter::Peekable; pub fn parse(tokens: Vec) -> Result { let mut parser = Parser { tokens: tokens.into_iter().peekable(), + depth: 0, inside_fn_depth: 0, inside_loop_depth: 0, }; @@ -19,6 +20,7 @@ pub fn parse(tokens: Vec) -> Result { #[derive(Debug)] struct Parser<'code> { tokens: Peekable>>, + depth: usize, inside_fn_depth: usize, inside_loop_depth: usize, } @@ -38,26 +40,60 @@ macro_rules! parse_bin_op { }}; } +macro_rules! exit_parse { + ($self: ident) => { + $self.depth -= 1; + }; +} + +macro_rules! enter_parse { + ($self: ident) => { + $self.depth += 1; + + if $self.depth > Self::MAX_DEPTH { + let _ = $self.too_nested_error()?; + } + }; +} + impl<'code> Parser<'code> { + const MAX_DEPTH: usize = 100; + fn program(&mut self) -> ParseResult<'code, Program> { Ok(Program(self.statement_list()?)) } - fn statement_list(&mut self) -> ParseResult<'code, Vec> { - let mut stmts = Vec::new(); - loop { - if let Some(TokenType::BraceC) | None = self.peek_kind() { - return Ok(stmts); - } - let stmt = self.statement()?; - stmts.push(stmt); + fn too_nested_error(&mut self) -> ParseResult<'code, ()> { + let next_token = self.next(); + match next_token { + Some(token) => Err(ParseErr::MaxDepth(token.span)), + None => Err(ParseErr::Eof("reached EOF while being nested to deeply")), } } + fn statement_list(&mut self) -> ParseResult<'code, Vec> { + enter_parse!(self); + let mut stmts = Vec::new(); + let return_stmts = loop { + if let Some(TokenType::BraceC) | None = self.peek_kind() { + break Ok(stmts); + } + let stmt = self.statement()?; + stmts.push(stmt); + }; + exit_parse!(self); + return_stmts + } + fn block(&mut self) -> ParseResult<'code, Block> { + enter_parse!(self); + let start_span = self.expect(TokenType::BraceO)?.span; let stmts = self.statement_list()?; let end_span = self.expect(TokenType::BraceC)?.span; + + exit_parse!(self); + Ok(Block { stmts, span: start_span.extend(end_span), @@ -65,7 +101,9 @@ impl<'code> Parser<'code> { } fn statement(&mut self) -> ParseResult<'code, Stmt> { - match *self.peek_kind().ok_or(ParseErr::Eof("statement"))? { + enter_parse!(self); + + let stmt = match *self.peek_kind().ok_or(ParseErr::Eof("statement"))? { TokenType::Let => self.declaration(), TokenType::Fn => self.fn_decl(), TokenType::If => Ok(Stmt::If(self.if_stmt()?)), @@ -78,15 +116,22 @@ impl<'code> Parser<'code> { let stmt = self.assignment()?; Ok(stmt) } - } + }; + exit_parse!(self); + stmt } fn declaration(&mut self) -> ParseResult<'code, Stmt> { + enter_parse!(self); + let keyword_span = self.expect(TokenType::Let)?.span; let name = self.ident()?; self.expect(TokenType::Equal)?; let init = self.expression()?; self.expect(TokenType::Semi)?; + + exit_parse!(self); + Ok(Stmt::Declaration(Declaration { span: keyword_span.extend(init.span()), name, @@ -95,6 +140,8 @@ impl<'code> Parser<'code> { } fn fn_decl(&mut self) -> ParseResult<'code, Stmt> { + enter_parse!(self); + let keyword_span = self.expect(TokenType::Fn)?.span; let name = self.ident()?; let args = self.fn_args()?; @@ -103,6 +150,8 @@ impl<'code> Parser<'code> { let body = self.block()?; self.inside_fn_depth -= 1; + exit_parse!(self); + Ok(Stmt::FnDecl(FnDecl { span: keyword_span.extend(body.span), name, @@ -112,13 +161,20 @@ impl<'code> Parser<'code> { } fn fn_args(&mut self) -> ParseResult<'code, Vec> { + enter_parse!(self); + self.expect(TokenType::ParenO)?; let params = self.parse_list(TokenType::ParenC, Self::ident)?; self.expect(TokenType::ParenC)?; + + exit_parse!(self); + Ok(params) } fn if_stmt(&mut self) -> ParseResult<'code, IfStmt> { + enter_parse!(self); + let keyword_span = self.expect(TokenType::If)?.span; let cond = self.expression()?; let body = self.block()?; @@ -129,6 +185,8 @@ impl<'code> Parser<'code> { None }; + exit_parse!(self); + Ok(IfStmt { span: keyword_span .extend(body.span) @@ -140,9 +198,11 @@ impl<'code> Parser<'code> { } fn else_part(&mut self) -> ParseResult<'code, ElsePart> { + enter_parse!(self); + let keyword_span = self.expect(TokenType::Else)?.span; - if let Some(TokenType::If) = self.peek_kind() { + let else_part = if let Some(TokenType::If) = self.peek_kind() { let else_if_stmt = self.if_stmt()?; let else_span = keyword_span.extend(else_if_stmt.span); Ok(ElsePart::ElseIf(else_if_stmt, else_span)) @@ -150,10 +210,16 @@ impl<'code> Parser<'code> { let block = self.block()?; let else_span = keyword_span.extend(block.span); Ok(ElsePart::Else(block, else_span)) - } + }; + + exit_parse!(self); + + else_part } fn loop_stmt(&mut self) -> ParseResult<'code, Stmt> { + enter_parse!(self); + let keyword_span = self.expect(TokenType::Loop)?.span; self.inside_loop_depth += 1; @@ -161,10 +227,15 @@ impl<'code> Parser<'code> { self.inside_loop_depth -= 1; let loop_span = keyword_span.extend(block.span); + + exit_parse!(self); + Ok(Stmt::Loop(block, keyword_span.extend(loop_span))) } fn while_stmt(&mut self) -> ParseResult<'code, Stmt> { + enter_parse!(self); + let keyword_span = self.expect(TokenType::While)?.span; let cond = self.expression()?; @@ -172,6 +243,8 @@ impl<'code> Parser<'code> { let body = self.block()?; self.inside_loop_depth -= 1; + exit_parse!(self); + Ok(Stmt::While(WhileStmt { span: keyword_span.extend(body.span), cond, @@ -180,9 +253,13 @@ impl<'code> Parser<'code> { } fn break_stmt(&mut self) -> ParseResult<'code, Stmt> { + enter_parse!(self); + let keyword_span = self.expect(TokenType::Break)?.span; let semi_span = self.expect(TokenType::Semi)?.span; + exit_parse!(self); + if self.inside_loop_depth == 0 { Err(ParseErr::BreakOutsideLoop(keyword_span.extend(semi_span))) } else { @@ -191,6 +268,8 @@ impl<'code> Parser<'code> { } fn return_stmt(&mut self) -> ParseResult<'code, Stmt> { + enter_parse!(self); + let keyword_span = self.expect(TokenType::Return)?.span; let expr = if let Some(TokenType::Semi) = self.peek_kind() { @@ -201,6 +280,8 @@ impl<'code> Parser<'code> { let semi_span = self.expect(TokenType::Semi)?.span; + exit_parse!(self); + if self.inside_fn_depth == 0 { Err(ParseErr::ReturnOutsideFunction( keyword_span.extend(semi_span), @@ -211,9 +292,11 @@ impl<'code> Parser<'code> { } fn assignment(&mut self) -> ParseResult<'code, Stmt> { + enter_parse!(self); + let expr = self.expression()?; - if let Some(TokenType::Equal) = self.peek_kind() { + let stmt = if let Some(TokenType::Equal) = self.peek_kind() { let _ = self.expect(TokenType::Equal)?; let init = self.expression()?; let semi_span = self.expect(TokenType::Semi)?.span; @@ -225,32 +308,50 @@ impl<'code> Parser<'code> { } else { let _ = self.expect(TokenType::Semi)?; Ok(Stmt::Expr(expr)) - } + }; + + exit_parse!(self); + stmt } fn expression(&mut self) -> ParseResult<'code, Expr> { - self.logical_or() + enter_parse!(self); + let return_expr = self.logical_or(); + exit_parse!(self); + return_expr } fn logical_or(&mut self) -> ParseResult<'code, Expr> { + enter_parse!(self); + let lhs = self.logical_and()?; - match self.peek_kind() { + let return_expr = match self.peek_kind() { Some(TokenType::Or) => parse_bin_op!(self, lhs, BinaryOpKind::Or, logical_or), _ => Ok(lhs), - } + }; + + exit_parse!(self); + return_expr } fn logical_and(&mut self) -> ParseResult<'code, Expr> { + enter_parse!(self); + let lhs = self.equality()?; - match self.peek_kind() { + let return_expr = match self.peek_kind() { Some(TokenType::And) => parse_bin_op!(self, lhs, BinaryOpKind::And, logical_and), _ => Ok(lhs), - } + }; + + exit_parse!(self); + return_expr } fn equality(&mut self) -> ParseResult<'code, Expr> { + enter_parse!(self); + let lhs = self.comparison()?; - match self.peek_kind() { + let return_expr = match self.peek_kind() { Some(TokenType::BangEqual) => { parse_bin_op!(self, lhs, BinaryOpKind::NotEqual, comparison) } @@ -258,12 +359,16 @@ impl<'code> Parser<'code> { parse_bin_op!(self, lhs, BinaryOpKind::Equal, comparison) } _ => Ok(lhs), - } + }; + exit_parse!(self); + return_expr } fn comparison(&mut self) -> ParseResult<'code, Expr> { + enter_parse!(self); + let lhs = self.term()?; - match self.peek_kind() { + let return_expr = match self.peek_kind() { Some(TokenType::Greater) => parse_bin_op!(self, lhs, BinaryOpKind::Greater, term), Some(TokenType::GreaterEqual) => { parse_bin_op!(self, lhs, BinaryOpKind::GreaterEqual, term) @@ -273,30 +378,42 @@ impl<'code> Parser<'code> { parse_bin_op!(self, lhs, BinaryOpKind::LessEqual, term) } _ => Ok(lhs), - } + }; + exit_parse!(self); + return_expr } fn term(&mut self) -> ParseResult<'code, Expr> { + enter_parse!(self); + let lhs = self.factor()?; - match self.peek_kind() { + let return_expr = match self.peek_kind() { Some(TokenType::Plus) => parse_bin_op!(self, lhs, BinaryOpKind::Add, term), Some(TokenType::Minus) => parse_bin_op!(self, lhs, BinaryOpKind::Sub, term), _ => Ok(lhs), - } + }; + exit_parse!(self); + return_expr } fn factor(&mut self) -> ParseResult<'code, Expr> { + enter_parse!(self); + let lhs = self.unary()?; - match self.peek_kind() { + let return_expr = match self.peek_kind() { Some(TokenType::Asterisk) => parse_bin_op!(self, lhs, BinaryOpKind::Mul, factor), Some(TokenType::Slash) => parse_bin_op!(self, lhs, BinaryOpKind::Div, factor), Some(TokenType::Percent) => parse_bin_op!(self, lhs, BinaryOpKind::Mod, factor), _ => Ok(lhs), - } + }; + exit_parse!(self); + return_expr } fn unary(&mut self) -> ParseResult<'code, Expr> { - match self.peek_kind() { + enter_parse!(self); + + let return_expr = match self.peek_kind() { Some(TokenType::Not) => { let unary_op_span = self.next().unwrap().span; let expr = self.call()?; @@ -316,10 +433,14 @@ impl<'code> Parser<'code> { }))) } _ => self.call(), - } + }; + exit_parse!(self); + return_expr } fn call(&mut self) -> ParseResult<'code, Expr> { + enter_parse!(self); + let mut expr = self.primary()?; loop { @@ -349,12 +470,16 @@ impl<'code> Parser<'code> { } } + exit_parse!(self); + Ok(expr) } fn primary(&mut self) -> ParseResult<'code, Expr> { + enter_parse!(self); + let next = self.next().ok_or(ParseErr::Eof("primary"))?; - match next.kind { + let return_expr = match next.kind { TokenType::String(literal) => Ok(Expr::Literal(Literal::String(literal, next.span))), TokenType::Number(literal) => Ok(Expr::Literal(Literal::Number(literal, next.span))), TokenType::False => Ok(Expr::Literal(Literal::Boolean(false, next.span))), @@ -375,12 +500,16 @@ impl<'code> Parser<'code> { })) } _ => Err(ParseErr::InvalidTokenPrimary(next)), - } + }; + exit_parse!(self); + return_expr } fn ident(&mut self) -> ParseResult<'code, Ident> { + enter_parse!(self); + let Token { kind, span } = self.next().ok_or(ParseErr::Eof("identifier"))?; - match kind { + let return_expr = match kind { TokenType::Ident(name) => { let name_owned = name.to_owned(); Ok(Ident { @@ -394,21 +523,32 @@ impl<'code> Parser<'code> { actual: Token { span, kind }, }) } - } + }; + exit_parse!(self); + return_expr } fn object_literal(&mut self, open_span: Span) -> ParseResult<'code, Expr> { + enter_parse!(self); + let close_span = self.expect(TokenType::BraceC)?.span; + + exit_parse!(self); Ok(Expr::Literal(Literal::Object(open_span.extend(close_span)))) } fn array_literal(&mut self, open_span: Span) -> ParseResult<'code, Expr> { + enter_parse!(self); + let elements = self.parse_list(TokenType::BracketC, Self::expression)?; let closing_bracket = self.expect(TokenType::BracketC)?; - Ok(Expr::Literal(Literal::Array( + + let return_expr = Ok(Expr::Literal(Literal::Array( elements, open_span.extend(closing_bracket.span), - ))) + ))); + exit_parse!(self); + return_expr } fn parse_list( @@ -419,6 +559,8 @@ impl<'code> Parser<'code> { where F: FnMut(&mut Self) -> ParseResult<'code, T>, { + enter_parse!(self); + let mut elements = Vec::new(); if self.peek_kind() == Some(&close) { @@ -443,6 +585,8 @@ impl<'code> Parser<'code> { let expr = parser(self)?; elements.push(expr); } + + exit_parse!(self); Ok(elements) } @@ -481,6 +625,7 @@ impl<'code> Parser<'code> { #[derive(Debug)] pub enum ParseErr<'code> { + MaxDepth(Span), BreakOutsideLoop(Span), ReturnOutsideFunction(Span), MismatchedKind { @@ -504,6 +649,7 @@ impl CompilerError for ParseErr<'_> { ParseErr::Eof(_) => Span::dummy(), ParseErr::BreakOutsideLoop(span) => *span, ParseErr::ReturnOutsideFunction(span) => *span, + ParseErr::MaxDepth(span) => *span, } } @@ -523,6 +669,7 @@ impl CompilerError for ParseErr<'_> { } ParseErr::BreakOutsideLoop(_) => "break used outside of loop".to_string(), ParseErr::ReturnOutsideFunction(_) => "return used outside of function".to_string(), + ParseErr::MaxDepth(_) => "reached maximal nesting depth".to_string(), } } diff --git a/src/parse/test.rs b/src/parse/test.rs index 89afac4..59c75db 100644 --- a/src/parse/test.rs +++ b/src/parse/test.rs @@ -42,6 +42,7 @@ fn empty_block() -> Block { fn parser(tokens: Vec) -> Parser { Parser { tokens: tokens.into_iter().peekable(), + depth: 0, inside_fn_depth: 0, inside_loop_depth: 0, } @@ -444,6 +445,16 @@ mod expr { parser.expression().unwrap() } + #[test] + fn stack_overflow() { + let tokens = std::iter::repeat(BracketO) + .map(token) + .take(100_000) + .collect(); + let expr = parser(tokens).expression(); + assert!(expr.is_err()); + } + #[test] fn number_literal() { test_number_literal(parse_expr);