diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 7d9f6c38e36a4..95dbd6f6c8f80 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -86,6 +86,7 @@ mod multiple_return_terminators; mod normalize_array_len; mod nrvo; mod prettify; +mod ref_cmp_simplify; mod ref_prop; mod remove_noop_landing_pads; mod remove_storage_markers; @@ -561,6 +562,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &instsimplify::InstSimplify, &simplify::SimplifyLocals::BeforeConstProp, ©_prop::CopyProp, + &ref_cmp_simplify::RefCmpSimplify, &ref_prop::ReferencePropagation, // Perform `SeparateConstSwitch` after SSA-based analyses, as cloning blocks may // destroy the SSA property. It should still happen before const-propagation, so the diff --git a/compiler/rustc_mir_transform/src/ref_cmp_simplify.rs b/compiler/rustc_mir_transform/src/ref_cmp_simplify.rs new file mode 100644 index 0000000000000..2ac61b72c156a --- /dev/null +++ b/compiler/rustc_mir_transform/src/ref_cmp_simplify.rs @@ -0,0 +1,93 @@ +use crate::MirPass; +use rustc_middle::mir::patch::MirPatch; +use rustc_middle::mir::*; +use rustc_middle::ty::TyCtxt; + +pub struct RefCmpSimplify; + +impl<'tcx> MirPass<'tcx> for RefCmpSimplify { + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + self.simplify_ref_cmp(tcx, body) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum MatchState { + Empty, + Deref { src_statement_idx: usize, dst: Local, src: Local }, + CopiedFrom { src_statement_idx: usize, dst: Local, real_src: Local }, + Completed { src_statement_idx: usize, dst: Local, real_src: Local }, +} + +impl RefCmpSimplify { + fn simplify_ref_cmp<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!("body: {:#?}", body); + + let n_bbs = body.basic_blocks.len() as u32; + for bb in 0..n_bbs { + let bb = BasicBlock::from_u32(bb); + let mut max = Local::MAX; + 'repeat: loop { + let mut state = MatchState::Empty; + let bb_data = &body.basic_blocks[bb]; + for (i, stmt) in bb_data.statements.iter().enumerate().rev() { + state = match (state, &stmt.kind) { + ( + MatchState::Empty, + StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Copy(rhs)))), + ) if rhs.has_deref() && lhs.ty(body, tcx).ty.is_primitive() => { + let Some(dst) = lhs.as_local() else { + continue + }; + let Some(src) = rhs.local_or_deref_local() else { + continue; + }; + if max <= dst { + continue; + } + max = dst; + MatchState::Deref { dst, src, src_statement_idx: i } + } + ( + MatchState::Deref { src, dst, src_statement_idx }, + StatementKind::Assign(box (lhs, Rvalue::CopyForDeref(rhs))), + ) if lhs.as_local() == Some(src) && rhs.has_deref() => { + let Some(real_src) = rhs.local_or_deref_local() else{ + continue; + }; + MatchState::CopiedFrom { src_statement_idx, dst, real_src } + } + ( + MatchState::CopiedFrom { src_statement_idx, dst, real_src }, + StatementKind::Assign(box ( + lhs, + Rvalue::Ref(_, BorrowKind::Shared | BorrowKind::Shallow, rhs), + )), + ) if lhs.as_local() == Some(real_src) => { + let Some(real_src) = rhs.as_local() else { + continue; + }; + MatchState::Completed { dst, real_src, src_statement_idx } + } + _ => continue, + }; + if let MatchState::Completed { dst, real_src, src_statement_idx } = state { + let mut patch = MirPatch::new(&body); + let src = Place::from(real_src); + let src = src.project_deeper(&[PlaceElem::Deref], tcx); + let dst = Place::from(dst); + let new_stmt = + StatementKind::Assign(Box::new((dst, Rvalue::Use(Operand::Copy(src))))); + patch.add_statement( + Location { block: bb, statement_index: src_statement_idx + 1 }, + new_stmt, + ); + patch.apply(body); + continue 'repeat; + } + } + break; + } + } + } +} diff --git a/tests/mir-opt/ref_int_cmp.opt1.RefCmpSimplify.diff b/tests/mir-opt/ref_int_cmp.opt1.RefCmpSimplify.diff new file mode 100644 index 0000000000000..285d0d9b65ebe --- /dev/null +++ b/tests/mir-opt/ref_int_cmp.opt1.RefCmpSimplify.diff @@ -0,0 +1,50 @@ +- // MIR for `opt1` before RefCmpSimplify ++ // MIR for `opt1` after RefCmpSimplify + + fn opt1(_1: &u8, _2: &u8) -> bool { + debug x => _1; // in scope 0 at $DIR/ref_int_cmp.rs:+0:13: +0:14 + debug y => _2; // in scope 0 at $DIR/ref_int_cmp.rs:+0:21: +0:22 + let mut _0: bool; // return place in scope 0 at $DIR/ref_int_cmp.rs:+0:32: +0:36 + let mut _3: &&u8; // in scope 0 at $DIR/ref_int_cmp.rs:+1:3: +1:4 + let mut _4: &&u8; // in scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8 + let _5: &u8; // in scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8 + let mut _8: &u8; // in scope 0 at $SRC_DIR/core/src/cmp.rs:LL:COL + let mut _9: &u8; // in scope 0 at $SRC_DIR/core/src/cmp.rs:LL:COL + scope 1 (inlined cmp::impls::::lt) { // at $DIR/ref_int_cmp.rs:5:3: 5:8 + debug self => _3; // in scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL + debug other => _4; // in scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL + let mut _6: &u8; // in scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL + let mut _7: &u8; // in scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL + scope 2 (inlined cmp::impls::::lt) { // at $SRC_DIR/core/src/cmp.rs:LL:COL + debug self => _6; // in scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + debug other => _7; // in scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + let mut _10: u8; // in scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + let mut _11: u8; // in scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + } + } + + bb0: { + StorageLive(_3); // scope 0 at $DIR/ref_int_cmp.rs:+1:3: +1:4 + _3 = &_1; // scope 0 at $DIR/ref_int_cmp.rs:+1:3: +1:4 + StorageLive(_4); // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8 + StorageLive(_5); // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8 + _5 = _2; // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8 + _4 = &_5; // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8 + _6 = deref_copy (*_3); // scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL + _7 = deref_copy (*_4); // scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL + StorageLive(_10); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + _10 = (*_6); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL ++ _10 = (*_1); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + StorageLive(_11); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + _11 = (*_7); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL ++ _11 = (*_5); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + _0 = Lt(move _10, move _11); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + StorageDead(_11); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + StorageDead(_10); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL + StorageDead(_4); // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8 + StorageDead(_3); // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8 + StorageDead(_5); // scope 0 at $DIR/ref_int_cmp.rs:+2:1: +2:2 + return; // scope 0 at $DIR/ref_int_cmp.rs:+2:2: +2:2 + } + } + diff --git a/tests/mir-opt/ref_int_cmp.rs b/tests/mir-opt/ref_int_cmp.rs new file mode 100644 index 0000000000000..305e09a7b2146 --- /dev/null +++ b/tests/mir-opt/ref_int_cmp.rs @@ -0,0 +1,10 @@ +// compile-flags: -O -Zmir-opt-level=3 + +// EMIT_MIR ref_int_cmp.opt1.RefCmpSimplify.diff +pub fn opt1(x: &u8, y: &u8) -> bool { + x < y +} + +fn main() { + opt1(&1, &2); +}