From 31f2304a30b0d096d11b42bfb4119a462e8be1ef Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Sat, 16 Apr 2022 00:37:50 +0200 Subject: [PATCH] simplify window passes --- rust2/src/opts.rs | 108 +++++++++++++++++++++++----------------------- 1 file changed, 55 insertions(+), 53 deletions(-) diff --git a/rust2/src/opts.rs b/rust2/src/opts.rs index fe3c6cb..ca57113 100644 --- a/rust2/src/opts.rs +++ b/rust2/src/opts.rs @@ -160,42 +160,65 @@ fn pass_find_set_null(ir: &mut Ir<'_>) { /// pass that replaces `SetN(n) Add(m)` with `SetN(n + m)` fn pass_set_n(ir: &mut Ir<'_>) { - let stmts = &mut ir.stmts; - for i in 0..stmts.len() { - let a = &mut stmts[i]; - if let StmtKind::Loop(body) = &mut a.kind { - pass_set_n(body); - } - - if i >= stmts.len() - 1 { - break; // we are the last element - } - - let a = &stmts[i]; + two_window_pass(ir, pass_set_n, |a, b| { if let StmtKind::SetN(before) = a.kind() { - let b = &stmts[i + 1]; let new = match b.kind() { StmtKind::Add(n) => StmtKind::SetN(before.wrapping_add(*n)), StmtKind::Sub(n) => StmtKind::SetN(before.wrapping_sub(*n)), _ => { - continue; + return WindowPassAction::None; } }; - let span = a.span.merge(b.span); - stmts.remove(i + 1); - stmts[i] = Stmt::new(new, span); + return WindowPassAction::Merge(new); } - } + WindowPassAction::None + }); } /// pass that replaces `Left(5) Right(3)` with `Left(2)` fn pass_cancel_left_right_add_sub(ir: &mut Ir<'_>) { + two_window_pass(ir, pass_cancel_left_right_add_sub, |a, b| { + match (a.kind(), b.kind()) { + (StmtKind::Right(r), StmtKind::Left(l)) | (StmtKind::Left(l), StmtKind::Right(r)) => { + let new = match r.cmp(l) { + Ordering::Equal => return WindowPassAction::RemoveBoth, + Ordering::Less => StmtKind::Left(l - r), + Ordering::Greater => StmtKind::Right(r - l), + }; + + WindowPassAction::Merge(new) + } + (StmtKind::Add(r), StmtKind::Sub(l)) | (StmtKind::Sub(l), StmtKind::Add(r)) => { + let new = match r.cmp(l) { + Ordering::Equal => return WindowPassAction::RemoveBoth, + Ordering::Less => StmtKind::Sub(l - r), + Ordering::Greater => StmtKind::Add(r - l), + }; + + WindowPassAction::Merge(new) + } + _ => WindowPassAction::None, + } + }) +} + +enum WindowPassAction<'ir> { + None, + Merge(StmtKind<'ir>), + RemoveBoth, +} + +fn two_window_pass<'ir, P, F>(ir: &mut Ir<'ir>, pass_recur: P, action: F) +where + P: Fn(&mut Ir<'ir>), + F: Fn(&Stmt<'ir>, &Stmt<'ir>) -> WindowPassAction<'ir>, +{ let stmts = &mut ir.stmts; let mut i = 0; while i < stmts.len() { let a = &mut stmts[i]; if let StmtKind::Loop(body) = &mut a.kind { - pass_cancel_left_right_add_sub(body); + pass_recur(body); } if i >= stmts.len() - 1 { @@ -205,44 +228,23 @@ fn pass_cancel_left_right_add_sub(ir: &mut Ir<'_>) { let a = &stmts[i]; let b = &stmts[i + 1]; - match (a.kind(), b.kind()) { - (StmtKind::Right(r), StmtKind::Left(l)) | (StmtKind::Left(l), StmtKind::Right(r)) => { - let new = match r.cmp(l) { - Ordering::Equal => { - // remove both - stmts.remove(i + 1); - stmts.remove(i); - continue; - } - Ordering::Less => StmtKind::Left(l - r), - Ordering::Greater => StmtKind::Right(r - l), - }; + let merged_span = a.span.merge(b.span); - let span = a.span.merge(b.span); - stmts.remove(i + 1); - stmts[i] = Stmt::new(new, span); - // don't increment i, maybe the next one matches as well (<><) - } - (StmtKind::Add(r), StmtKind::Sub(l)) | (StmtKind::Sub(l), StmtKind::Add(r)) => { - let new = match r.cmp(l) { - Ordering::Equal => { - // remove both - stmts.remove(i + 1); - stmts.remove(i); - continue; - } - Ordering::Less => StmtKind::Sub(l - r), - Ordering::Greater => StmtKind::Add(r - l), - }; + let result = action(a, b); - let span = a.span.merge(b.span); - stmts.remove(i + 1); - stmts[i] = Stmt::new(new, span); - // don't increment i, maybe the next one matches as well (<><) - } - _ => { + match result { + WindowPassAction::None => { + // only increment i if we haven't removed anything i += 1; } + WindowPassAction::RemoveBoth => { + stmts.remove(i); + stmts.remove(i + 1); + } + WindowPassAction::Merge(new) => { + stmts.remove(i + 1); + stmts[i] = Stmt::new(new, merged_span); + } } } }