diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/InstructionSimplify.h"
@@ -305,6 +306,7 @@
   const BasicBlockSectionsProfileReader *BBSectionsProfileReader = nullptr;
   const TargetLibraryInfo *TLInfo;
   const LoopInfo *LI;
+  AssumptionCache *AC;
   std::unique_ptr<BlockFrequencyInfo> BFI;
   std::unique_ptr<BranchProbabilityInfo> BPI;
   ProfileSummaryInfo *PSI;
@@ -386,6 +388,7 @@
     AU.addRequired<TargetPassConfig>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequired<AssumptionCacheTracker>();
     AU.addUsedIfAvailable<BasicBlockSectionsProfileReader>();
   }
 
@@ -482,6 +485,7 @@
                       "Optimize for code generation", false, false)
 INITIALIZE_PASS_DEPENDENCY(BasicBlockSectionsProfileReader)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
@@ -510,6 +514,7 @@
   TLInfo = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
   LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   BPI.reset(new BranchProbabilityInfo(F, *LI));
   BFI.reset(new BlockFrequencyInfo(F, *BPI, *LI));
   PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
@@ -7961,6 +7966,8 @@
 }
 
 static bool optimizeBranch(BranchInst *Branch, const TargetLowering &TLI,
+                           const DataLayout &DL, AssumptionCache *AC,
+                           const DominatorTree *DT,
                            SmallSet<BasicBlock *, 32> &FreshBBs,
                            bool IsHugeFunc) {
   // Try and convert
@@ -7981,7 +7988,11 @@
     return false;
 
   Value *X = Cmp->getOperand(0);
-  APInt CmpC = cast<ConstantInt>(Cmp->getOperand(1))->getValue();
+  Value *C = Cmp->getOperand(1);
+  APInt CmpC = cast<ConstantInt>(C)->getValue();
+  // The branch already compares against zero; nothing to rewrite.
+  if (CmpC.isZero())
+    return false;
 
   for (auto *U : X->users()) {
     Instruction *UI = dyn_cast<Instruction>(U);
@@ -8006,12 +8017,37 @@
       replaceAllUsesWith(Cmp, NewCmp, FreshBBs, IsHugeFunc);
       return true;
     }
-    if (Cmp->isEquality() &&
+    if ((Cmp->isEquality() || Cmp->isSigned()) &&
         (match(UI, m_Add(m_Specific(X), m_SpecificInt(-CmpC))) ||
          match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))))) {
+      // "X pred C" and "(X - C) pred 0" agree for equality; for a signed
+      // relational pred they agree only if X - C cannot signed-overflow.
+      if (Cmp->isRelational() &&
+          computeOverflowForSignedSub(X, C, DL, AC, Branch, DT) !=
+              OverflowResult::NeverOverflows) {
+        // ValueTracking could not prove it; try a dominating condition:
+        //   C < 0: X - C <= INT_MAX  <=>  X <= INT_MAX + C
+        //   C > 0: X - C >= INT_MIN  <=>  X >= INT_MIN + C
+        bool IsUpperBound = CmpC.isNegative();
+        APInt Bound =
+            (IsUpperBound ? APInt::getSignedMaxValue(CmpC.getBitWidth())
+                          : APInt::getSignedMinValue(CmpC.getBitWidth())) +
+            CmpC;
+        std::optional<bool> NoOverflow = isImpliedByDomCondition(
+            IsUpperBound ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_SGE, X,
+            ConstantInt::get(UI->getType(), Bound), Branch, DL);
+        if (!NoOverflow.value_or(false))
+          continue;
+      }
+
       IRBuilder<> Builder(Branch);
       if (UI->getParent() != Branch->getParent())
         UI->moveBefore(Branch);
+      // The branch now tests UI instead of X. Drop nsw/nuw (and any other
+      // poison-generating flags) even if UI was not hoisted: a wrapping UI
+      // would otherwise make the new branch condition poison where the
+      // original compare on X was well defined.
+      UI->dropPoisonGeneratingFlags();
       Value *NewCmp = Builder.CreateCmp(Cmp->getPredicate(), UI,
                                         ConstantInt::get(UI->getType(), 0));
       LLVM_DEBUG(dbgs() << "Converting " << *Cmp << "\n");
@@ -8194,7 +8230,8 @@
   case Instruction::ExtractElement:
     return optimizeExtractElementInst(cast<ExtractElementInst>(I));
   case Instruction::Br:
-    return optimizeBranch(cast<BranchInst>(I), *TLI, FreshBBs, IsHugeFunc);
+    return optimizeBranch(cast<BranchInst>(I), *TLI, *DL, AC,
+                          &getDT(*I->getFunction()), FreshBBs, IsHugeFunc);
   }
 
   return false;
diff --git a/llvm/test/CodeGen/RISCV/branch-on-zero.ll b/llvm/test/CodeGen/RISCV/branch-on-zero.ll
--- a/llvm/test/CodeGen/RISCV/branch-on-zero.ll
+++ b/llvm/test/CodeGen/RISCV/branch-on-zero.ll
@@ -176,3 +176,77 @@
 while.end:                                        ; preds = %while.body, %entry
   ret i32 0
 }
+
+define i32 @test_nsw_add_may_overflow(i32 %0) {
+; RV32-LABEL: test_nsw_add_may_overflow:
+; RV32:       # %bb.0:
+; RV32-NEXT:    li a1, 13
+; RV32-NEXT:    bge a0, a1, .LBB4_2
+; RV32-NEXT:  # %bb.1:
+; RV32-NEXT:    li a0, 0
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB4_2:
+; RV32-NEXT:    addi a0, a0, -13
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: test_nsw_add_may_overflow:
+; RV64:       # %bb.0:
+; RV64-NEXT:    sext.w a1, a0
+; RV64-NEXT:    li a2, 13
+; RV64-NEXT:    bge a1, a2, .LBB4_2
+; RV64-NEXT:  # %bb.1:
+; RV64-NEXT:    li a0, 0
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB4_2:
+; RV64-NEXT:    addiw a0, a0, -13
+; RV64-NEXT:    ret
+  %2 = icmp slt i32 %0, 13
+  br i1 %2, label %5, label %3
+
+3:                                                ; preds = %1
+  %4 = add i32 %0, -13
+  ret i32 %4
+
+5:                                                ; preds = %1
+  ret i32 0
+}
+
+define i32 @test_nsw_add_no_overflow(i32 %0) {
+; RV32-LABEL: test_nsw_add_no_overflow:
+; RV32:       # %bb.0:
+; RV32-NEXT:    bltz a0, .LBB5_3
+; RV32-NEXT:  # %bb.1: # %nonnegative
+; RV32-NEXT:    addi a0, a0, -13
+; RV32-NEXT:    bltz a0, .LBB5_3
+; RV32-NEXT:  # %bb.2:
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB5_3:
+; RV32-NEXT:    li a0, 0
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: test_nsw_add_no_overflow:
+; RV64:       # %bb.0:
+; RV64-NEXT:    sext.w a1, a0
+; RV64-NEXT:    bltz a1, .LBB5_3
+; RV64-NEXT:  # %bb.1: # %nonnegative
+; RV64-NEXT:    addiw a0, a0, -13
+; RV64-NEXT:    bltz a0, .LBB5_3
+; RV64-NEXT:  # %bb.2:
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB5_3:
+; RV64-NEXT:    li a0, 0
+; RV64-NEXT:    ret
+  %x = icmp sge i32 %0, 0
+  br i1 %x, label %nonnegative, label %5
+
+nonnegative:                                      ; preds = %1
+  %2 = icmp slt i32 %0, 13
+  br i1 %2, label %5, label %3
+
+3:                                                ; preds = %nonnegative
+  %4 = add i32 %0, -13
+  ret i32 %4
+
+5:                                                ; preds = %nonnegative, %1
+  ret i32 0
+}