Index: llvm/include/llvm/Transforms/Scalar/OptimizeCompares.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/Transforms/Scalar/OptimizeCompares.h
@@ -0,0 +1,41 @@
+//===- OptimizeCompares.h - Optimize Compares -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file declares the OptimizeCompares loop pass.
+///
+/// When a loop-variant value is checked against several loop-invariant bounds
+/// with `icmp ult` on exiting branches that dominate the loop latch, a later
+/// check is implied by an earlier one whenever the earlier bound is unsigned
+/// less than or equal to the later bound. The pass inserts that loop-invariant
+/// bound comparison so that the implied check can be skipped at run time.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SCALAR_OPTIMIZE_COMPARES_H
+#define LLVM_TRANSFORMS_SCALAR_OPTIMIZE_COMPARES_H
+
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class Loop;
+class LPMUpdater;
+
+/// Loop pass that bypasses `icmp ult` loop-exit checks which are implied by
+/// a dominating check of the same loop-variant value.
+class OptimizeComparesPass : public PassInfoMixin<OptimizeComparesPass> {
+public:
+  PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
+                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_SCALAR_OPTIMIZE_COMPARES_H
Index: llvm/lib/Passes/PassBuilder.cpp
===================================================================
--- llvm/lib/Passes/PassBuilder.cpp
+++ llvm/lib/Passes/PassBuilder.cpp
@@ -199,6 +199,7 @@
 #include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h"
 #include "llvm/Transforms/Scalar/NaryReassociate.h"
 #include "llvm/Transforms/Scalar/NewGVN.h"
+#include "llvm/Transforms/Scalar/OptimizeCompares.h"
 #include "llvm/Transforms/Scalar/PartiallyInlineLibCalls.h"
 #include "llvm/Transforms/Scalar/Reassociate.h"
 #include "llvm/Transforms/Scalar/Reg2Mem.h"
Index: llvm/lib/Passes/PassRegistry.def
===================================================================
--- llvm/lib/Passes/PassRegistry.def
+++ llvm/lib/Passes/PassRegistry.def
@@ -513,6 +513,7 @@
 LOOP_PASS("loop-instsimplify", LoopInstSimplifyPass())
 LOOP_PASS("loop-rotate", LoopRotatePass())
 LOOP_PASS("no-op-loop", NoOpLoopPass())
+LOOP_PASS("opt-compares", OptimizeComparesPass())
 LOOP_PASS("print", PrintLoopPass(dbgs()))
 LOOP_PASS("loop-deletion", LoopDeletionPass())
 LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass())
Index: llvm/lib/Transforms/Scalar/CMakeLists.txt
===================================================================
--- llvm/lib/Transforms/Scalar/CMakeLists.txt
+++ llvm/lib/Transforms/Scalar/CMakeLists.txt
@@ -58,6 +58,7 @@
   MergedLoadStoreMotion.cpp
   NaryReassociate.cpp
   NewGVN.cpp
+  OptimizeCompares.cpp
   PartiallyInlineLibCalls.cpp
   PlaceSafepoints.cpp
   Reassociate.cpp
Index: llvm/lib/Transforms/Scalar/OptimizeCompares.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Transforms/Scalar/OptimizeCompares.cpp
@@ -0,0 +1,193 @@
+//===- OptimizeCompares.cpp - Optimize Compares ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Consider a loop-variant value V that is tested against loop-invariant bounds
+// on a chain of exiting branches, each of which dominates the loop latch:
+//
+//   if (!(V u< A)) exit;   // dominating check
+//   ...
+//   if (!(V u< B)) exit;   // dominated check
+//
+// Whenever the dominated check is reached, V u< A is known to hold, so
+// A u<= B implies V u< B. The pass splits the block of the dominated check
+// and inserts the loop-invariant comparison A u<= B; when it holds, control
+// goes straight to the in-loop successor and the dominated check is skipped.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/OptimizeCompares.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+#define DEBUG_TYPE "opt-compares"
+
+using namespace llvm;
+
+namespace {
+/// An exiting branch `br (icmp ult Variant, Invariant), InLoop, Exit`, where
+/// Variant is loop-variant and Invariant is loop-invariant.
+struct CompareDesc {
+  BranchInst *Term;
+  Value *Invariant;
+};
+} // end anonymous namespace
+
+/// Bypass the redundant checks in \p Compares. All of them compare the same
+/// loop-variant value against a loop-invariant bound with the strict
+/// relational predicate \p Pred. They are ordered from the latch towards the
+/// header, so each element is dominated by the element that follows it.
+///
+/// For every adjacent pair (Previous, Desc), the block of Desc is split in
+/// front of its terminator and the invariant comparison
+/// `icmp NonStrictPred Previous.Invariant, Desc.Invariant` is inserted there.
+/// If it holds, Desc's check is known to pass, so control goes through a new
+/// empty block directly to Desc's in-loop successor.
+///
+/// \returns true if the IR was changed.
+static bool insertInvariantChecks(Loop &L, ICmpInst::Predicate Pred,
+                                  ArrayRef<CompareDesc> Compares,
+                                  DominatorTree &DT, LoopInfo &LI) {
+  assert(ICmpInst::isRelational(Pred));
+  assert(ICmpInst::isStrictPredicate(Pred));
+  if (Compares.size() < 2)
+    return false;
+  ICmpInst::Predicate NonStrictPred = ICmpInst::getNonStrictPredicate(Pred);
+
+  const CompareDesc *Previous = nullptr;
+  DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  for (const CompareDesc &Desc : llvm::reverse(Compares)) {
+    if (Previous) {
+      BranchInst *Term = Desc.Term;
+      BasicBlock *InLoopSucc = Term->getSuccessor(0);
+      BasicBlock *BB = Term->getParent();
+      // Move everything but the terminator into SplitBB, which becomes the
+      // only predecessor of BB.
+      BasicBlock *SplitBB =
+          SplitBlock(BB, Term, &DTU, &LI, /*MSSAU*/ nullptr,
+                     BB->getName() + ".split", /*Before*/ true);
+      // Route SplitBB around BB: when the invariant comparison holds, jump
+      // through SafeBB straight to the in-loop successor.
+      BasicBlock *SafeBB = BasicBlock::Create(
+          BB->getContext(), BB->getName() + ".safe", BB->getParent(), BB);
+      Instruction *OldTerm = SplitBB->getTerminator();
+      IRBuilder<> Builder(OldTerm);
+      Value *InvariantCond = Builder.CreateICmp(
+          NonStrictPred, Previous->Invariant, Desc.Invariant, "invariant.cmp");
+      Builder.CreateCondBr(InvariantCond, SafeBB, BB);
+      OldTerm->eraseFromParent();
+
+      Builder.SetInsertPoint(SafeBB);
+      Builder.CreateBr(InLoopSucc);
+
+      // SafeBB is a new predecessor of InLoopSucc, so every PHI there needs
+      // an entry for it. Reuse the value of the BB -> InLoopSucc edge: BB now
+      // holds only its terminator, so that value is available in SafeBB too.
+      for (PHINode &PN : InLoopSucc->phis())
+        PN.addIncoming(PN.getIncomingValueForBlock(BB), SafeBB);
+
+      SmallVector<DominatorTree::UpdateType, 2> DTUpdates = {
+          {DominatorTree::Insert, SplitBB, SafeBB},
+          {DominatorTree::Insert, SafeBB, InLoopSucc},
+      };
+      DTU.applyUpdates(DTUpdates);
+      L.addBasicBlockToLoop(SafeBB, LI);
+    }
+    Previous = &Desc;
+  }
+  return true;
+}
+
+/// Collect the `icmp ult` loop-exit checks of \p L that dominate its latch,
+/// group them by their loop-variant operand and bypass the redundant ones.
+///
+/// \returns true if the IR was changed.
+static bool optimizeComparisons(Loop &L, DominatorTree &DT, LoopInfo &LI) {
+  BasicBlock *Header = L.getHeader();
+  if (!DT.isReachableFromEntry(Header))
+    return false;
+  BasicBlock *Latch = L.getLoopLatch();
+  if (!Latch)
+    return false;
+
+  // Maps each loop-variant value to its checks, ordered from the latch towards
+  // the header. MapVector iterates in insertion order, so the order (and the
+  // names) of the new blocks does not depend on pointer values.
+  MapVector<Instruction *, SmallVector<CompareDesc, 4>> CandidatesULT;
+  for (auto *DTN = DT.getNode(Latch); L.contains(DTN->getBlock());
+       DTN = DTN->getIDom()) {
+    using namespace PatternMatch;
+    BasicBlock *BB = DTN->getBlock();
+    // SafeBB is always added to L, so skip blocks that belong to subloops.
+    if (LI.getLoopFor(BB) != &L)
+      continue;
+    ICmpInst::Predicate Pred;
+    Value *LHS = nullptr, *RHS = nullptr;
+    BasicBlock *IfTrue = nullptr, *IfFalse = nullptr;
+    Instruction *Term = BB->getTerminator();
+    if (!match(Term, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
+                          m_BasicBlock(IfTrue), m_BasicBlock(IfFalse))))
+      continue;
+    // Only deal with loop exits.
+    if (L.contains(IfFalse))
+      continue;
+    // A bypass adds an edge to IfTrue. An edge to the header would create a
+    // second latch, and SafeBB must belong to L, so require IfTrue to be a
+    // non-header block of L itself.
+    if (IfTrue == Header || LI.getLoopFor(IfTrue) != &L)
+      continue;
+    // Canonicalize to "Variant LHS, invariant RHS".
+    if (L.isLoopInvariant(LHS)) {
+      Pred = ICmpInst::getSwappedPredicate(Pred);
+      std::swap(LHS, RHS);
+    }
+    // TODO: Support other predicates?
+    if (Pred != ICmpInst::ICMP_ULT)
+      continue;
+    if (L.isLoopInvariant(LHS) || !L.isLoopInvariant(RHS))
+      continue;
+    CompareDesc Desc;
+    Desc.Term = cast<BranchInst>(Term);
+    Desc.Invariant = RHS;
+    CandidatesULT[cast<Instruction>(LHS)].push_back(Desc);
+  }
+
+  bool Changed = false;
+  for (auto &It : CandidatesULT)
+    Changed |= insertInvariantChecks(L, ICmpInst::ICMP_ULT, It.second, DT, LI);
+  return Changed;
+}
+
+PreservedAnalyses OptimizeComparesPass::run(Loop &L, LoopAnalysisManager &AM,
+                                            LoopStandardAnalysisResults &AR,
+                                            LPMUpdater &U) {
+  // The CFG changes are not reflected in MemorySSA, so stay out of loop
+  // pipelines that maintain it.
+  // TODO: Pass a MemorySSAUpdater through and update MemorySSA instead.
+  if (AR.MSSA)
+    return PreservedAnalyses::all();
+
+  if (!optimizeComparisons(L, AR.DT, AR.LI))
+    return PreservedAnalyses::all();
+
+#ifndef NDEBUG
+  assert(AR.DT.verify());
+  AR.LI.verify(AR.DT);
+#endif
+
+  // DominatorTree and LoopInfo were updated in place by the transform.
+  return getLoopPassPreservedAnalyses();
+}
Index: llvm/test/Transforms/OptimizeCompares/range_check.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/OptimizeCompares/range_check.ll
@@ -0,0 +1,174 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -S -passes=opt-compares | FileCheck %s
+
+define i32 @src(ptr noundef %p, i32 noundef %n, i32 noundef %limit, ptr noundef %arr, ptr noundef %len_p) {
+; CHECK-LABEL: @src(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[LEN:%.*]] = load i32, ptr [[LEN_P:%.*]], align 4, !range [[RNG0:![0-9]+]], !noundef !1
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IV:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[BACKEDGE:%.*]] ]
+; CHECK-NEXT:    [[EL_PTR:%.*]] = getelementptr i32, ptr [[P:%.*]], i32 [[IV]]
+; CHECK-NEXT:    [[EL:%.*]] = load i32, ptr [[EL_PTR]], align 4
+; CHECK-NEXT:    [[BOUND_CHECK:%.*]] = icmp ult i32 [[EL]], [[LIMIT:%.*]]
+; CHECK-NEXT:    br i1 [[BOUND_CHECK]], label [[GUARDED_SPLIT:%.*]], label [[BOUND_CHECK_FAILED:%.*]]
+; CHECK:       guarded.split:
+; CHECK-NEXT:    [[RANGE_CHECK:%.*]] = icmp ult i32 [[EL]], [[LEN]]
+; CHECK-NEXT:    [[INVARIANT_CMP:%.*]] = icmp ule i32 [[LIMIT]], [[LEN]]
+; CHECK-NEXT:    br i1 [[INVARIANT_CMP]], label [[GUARDED_SAFE:%.*]], label [[GUARDED:%.*]]
+; CHECK:       guarded.safe:
+; CHECK-NEXT:    br label [[BACKEDGE]]
+; CHECK:       guarded:
+; CHECK-NEXT:    br i1 [[RANGE_CHECK]], label [[BACKEDGE]], label [[RANGE_CHECK_FAILED:%.*]]
+; CHECK:       backedge:
+; CHECK-NEXT:    [[ARR_PTR:%.*]] = getelementptr i32, ptr [[ARR:%.*]], i32 [[EL]]
+; CHECK-NEXT:    store i32 [[IV]], ptr [[ARR_PTR]], align 4
+; CHECK-NEXT:    [[IV_NEXT]] = add i32 [[IV]], 1
+; CHECK-NEXT:    [[LOOP_COND:%.*]] = icmp slt i32 [[IV_NEXT]], [[N:%.*]]
+; CHECK-NEXT:    br i1 [[LOOP_COND]], label [[LOOP]], label [[EXIT:%.*]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret i32 0
+; CHECK:       bound_check_failed:
+; CHECK-NEXT:    ret i32 -1
+; CHECK:       range_check_failed:
+; CHECK-NEXT:    ret i32 -2
+;
+entry:
+  %len = load i32, ptr %len_p, align 4, !range !0, !noundef !1
+  br label %loop
+
+loop:                                             ; preds = %backedge, %entry
+  %iv = phi i32 [ 0, %entry ], [ %iv.next, %backedge ]
+  %el.ptr = getelementptr i32, ptr %p, i32 %iv
+  %el = load i32, ptr %el.ptr, align 4
+  %bound_check = icmp ult i32 %el, %limit
+  br i1 %bound_check, label %guarded, label %bound_check_failed
+
+guarded:                                          ; preds = %loop
+  %range_check = icmp ult i32 %el, %len
+  br i1 %range_check, label %backedge, label %range_check_failed
+
+backedge:                                         ; preds = %guarded
+  %arr.ptr = getelementptr i32, ptr %arr, i32 %el
+  store i32 %iv, ptr %arr.ptr, align 4
+  %iv.next = add i32 %iv, 1
+  %loop_cond = icmp slt i32 %iv.next, %n
+  br i1 %loop_cond, label %loop, label %exit
+
+exit:                                             ; preds = %backedge
+  ret i32 0
+
+bound_check_failed:                               ; preds = %loop
+  ret i32 -1
+
+range_check_failed:                               ; preds = %guarded
+  ret i32 -2
+}
+
+; The in-loop successor of the bypassed check has a PHI node, which must get
+; an incoming value for the new guarded.safe predecessor.
+define i32 @phi_in_successor(ptr noundef %p, i32 noundef %n, i32 noundef %limit, i32 noundef %len, ptr noundef %arr) {
+; CHECK-LABEL: @phi_in_successor(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IV:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[BACKEDGE:%.*]] ]
+; CHECK-NEXT:    [[EL_PTR:%.*]] = getelementptr i32, ptr [[P:%.*]], i32 [[IV]]
+; CHECK-NEXT:    [[EL:%.*]] = load i32, ptr [[EL_PTR]], align 4
+; CHECK-NEXT:    [[BOUND_CHECK:%.*]] = icmp ult i32 [[EL]], [[LIMIT:%.*]]
+; CHECK-NEXT:    br i1 [[BOUND_CHECK]], label [[GUARDED_SPLIT:%.*]], label [[BOUND_CHECK_FAILED:%.*]]
+; CHECK:       guarded.split:
+; CHECK-NEXT:    [[RANGE_CHECK:%.*]] = icmp ult i32 [[EL]], [[LEN:%.*]]
+; CHECK-NEXT:    [[INVARIANT_CMP:%.*]] = icmp ule i32 [[LIMIT]], [[LEN]]
+; CHECK-NEXT:    br i1 [[INVARIANT_CMP]], label [[GUARDED_SAFE:%.*]], label [[GUARDED:%.*]]
+; CHECK:       guarded.safe:
+; CHECK-NEXT:    br label [[BACKEDGE]]
+; CHECK:       guarded:
+; CHECK-NEXT:    br i1 [[RANGE_CHECK]], label [[BACKEDGE]], label [[RANGE_CHECK_FAILED:%.*]]
+; CHECK:       backedge:
+; CHECK-NEXT:    [[VAL:%.*]] = phi i32 [ [[EL]], [[GUARDED]] ], [ [[EL]], [[GUARDED_SAFE]] ]
+; CHECK-NEXT:    [[ARR_PTR:%.*]] = getelementptr i32, ptr [[ARR:%.*]], i32 [[IV]]
+; CHECK-NEXT:    store i32 [[VAL]], ptr [[ARR_PTR]], align 4
+; CHECK-NEXT:    [[IV_NEXT]] = add i32 [[IV]], 1
+; CHECK-NEXT:    [[LOOP_COND:%.*]] = icmp slt i32 [[IV_NEXT]], [[N:%.*]]
+; CHECK-NEXT:    br i1 [[LOOP_COND]], label [[LOOP]], label [[EXIT:%.*]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret i32 0
+; CHECK:       bound_check_failed:
+; CHECK-NEXT:    ret i32 -1
+; CHECK:       range_check_failed:
+; CHECK-NEXT:    ret i32 -2
+;
+entry:
+  br label %loop
+
+loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.next, %backedge ]
+  %el.ptr = getelementptr i32, ptr %p, i32 %iv
+  %el = load i32, ptr %el.ptr, align 4
+  %bound_check = icmp ult i32 %el, %limit
+  br i1 %bound_check, label %guarded, label %bound_check_failed
+
+guarded:
+  %range_check = icmp ult i32 %el, %len
+  br i1 %range_check, label %backedge, label %range_check_failed
+
+backedge:
+  %val = phi i32 [ %el, %guarded ]
+  %arr.ptr = getelementptr i32, ptr %arr, i32 %iv
+  store i32 %val, ptr %arr.ptr, align 4
+  %iv.next = add i32 %iv, 1
+  %loop_cond = icmp slt i32 %iv.next, %n
+  br i1 %loop_cond, label %loop, label %exit
+
+exit:
+  ret i32 0
+
+bound_check_failed:
+  ret i32 -1
+
+range_check_failed:
+  ret i32 -2
+}
+
+; The second check is in the latch and its in-loop successor is the header.
+; Bypassing it would add a second latch, so the loop is left unchanged.
+define void @latch_check_not_bypassed(i32 noundef %a, i32 noundef %b) {
+; CHECK-LABEL: @latch_check_not_bypassed(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IV:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LATCH:%.*]] ]
+; CHECK-NEXT:    [[IV_NEXT]] = add i32 [[IV]], 1
+; CHECK-NEXT:    [[C1:%.*]] = icmp ult i32 [[IV_NEXT]], [[A:%.*]]
+; CHECK-NEXT:    br i1 [[C1]], label [[LATCH]], label [[EXIT1:%.*]]
+; CHECK:       latch:
+; CHECK-NEXT:    [[C2:%.*]] = icmp ult i32 [[IV_NEXT]], [[B:%.*]]
+; CHECK-NEXT:    br i1 [[C2]], label [[LOOP]], label [[EXIT2:%.*]]
+; CHECK:       exit1:
+; CHECK-NEXT:    ret void
+; CHECK:       exit2:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %iv = phi i32 [ 0, %entry ], [ %iv.next, %latch ]
+  %iv.next = add i32 %iv, 1
+  %c1 = icmp ult i32 %iv.next, %a
+  br i1 %c1, label %latch, label %exit1
+
+latch:
+  %c2 = icmp ult i32 %iv.next, %b
+  br i1 %c2, label %loop, label %exit2
+
+exit1:
+  ret void
+
+exit2:
+  ret void
+}
+
+!0 = !{i32 0, i32 2147483646}
+!1 = !{}