diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -241,6 +241,7 @@ void initializeLoopInfoWrapperPassPass(PassRegistry&); void initializeLoopInstSimplifyLegacyPassPass(PassRegistry&); void initializeLoopInterchangeLegacyPassPass(PassRegistry &); +void initializeLoopIntWrapPredicationLegacyPassPass(PassRegistry &); void initializeLoopFlattenLegacyPassPass(PassRegistry&); void initializeLoopLoadEliminationPass(PassRegistry&); void initializeLoopPassPass(PassRegistry&); diff --git a/llvm/include/llvm/Transforms/Scalar.h b/llvm/include/llvm/Transforms/Scalar.h --- a/llvm/include/llvm/Transforms/Scalar.h +++ b/llvm/include/llvm/Transforms/Scalar.h @@ -149,6 +149,13 @@ // Pass *createLoopPredicationPass(); +//===----------------------------------------------------------------------===// +// +// LoopIntWrapPredication - This pass does predication on overflowing binary +// operators inside loops. +// +Pass *createLoopIntWrapPredicationPass(); + //===----------------------------------------------------------------------===// // // LoopInterchange - This pass interchanges loops to provide a more diff --git a/llvm/include/llvm/Transforms/Scalar/LoopIntWrapPredication.h b/llvm/include/llvm/Transforms/Scalar/LoopIntWrapPredication.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/Transforms/Scalar/LoopIntWrapPredication.h @@ -0,0 +1,22 @@ +#ifndef LLVM_TRANSFORMS_SCALAR_LOOPINTWRAPPREDICATION_H +#define LLVM_TRANSFORMS_SCALAR_LOOPINTWRAPPREDICATION_H + +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/IR/PassManager.h" + +namespace llvm { + +class Loop; +class LPMUpdater; + +/// Performs Loop Integer Wrapping Predication Pass. 
+class LoopIntWrapPredicationPass + : public PassInfoMixin { +public: + PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, LPMUpdater &U); +}; + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_SCALAR_LOOPINTWRAPPREDICATION_H diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -174,6 +174,7 @@ #include "llvm/Transforms/Scalar/LoopFuse.h" #include "llvm/Transforms/Scalar/LoopIdiomRecognize.h" #include "llvm/Transforms/Scalar/LoopInstSimplify.h" +#include "llvm/Transforms/Scalar/LoopIntWrapPredication.h" #include "llvm/Transforms/Scalar/LoopInterchange.h" #include "llvm/Transforms/Scalar/LoopLoadElimination.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -93,6 +93,7 @@ #include "llvm/Transforms/Scalar/LoopFlatten.h" #include "llvm/Transforms/Scalar/LoopIdiomRecognize.h" #include "llvm/Transforms/Scalar/LoopInstSimplify.h" +#include "llvm/Transforms/Scalar/LoopIntWrapPredication.h" #include "llvm/Transforms/Scalar/LoopInterchange.h" #include "llvm/Transforms/Scalar/LoopLoadElimination.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" @@ -206,6 +207,7 @@ extern cl::opt EnableIROutliner; extern cl::opt EnableOrderFileInstrumentation; extern cl::opt EnableCHR; +extern cl::opt EnableLoopIntWrapPredication; extern cl::opt EnableLoopInterchange; extern cl::opt EnableUnrollAndJam; extern cl::opt EnableLoopFlatten; @@ -487,6 +489,8 @@ // TODO: Investigate promotion cap for O1. 
LPM1.addPass(LICMPass(PTO.LicmMssaOptCap, PTO.LicmMssaNoAccForPromotionCap, /*AllowSpeculation=*/true)); + if (EnableLoopIntWrapPredication) + LPM1.addPass(LoopIntWrapPredicationPass()); LPM1.addPass( SimpleLoopUnswitchPass(/* NonTrivial */ Level == OptimizationLevel::O3 && EnableO3NonTrivialUnswitching)); diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -526,6 +526,7 @@ LOOP_PASS("loop-bound-split", LoopBoundSplitPass()) LOOP_PASS("loop-reroll", LoopRerollPass()) LOOP_PASS("loop-versioning-licm", LoopVersioningLICMPass()) +LOOP_PASS("loop-int-wrap-predication", LoopIntWrapPredicationPass()) #undef LOOP_PASS #ifndef LOOP_PASS_WITH_PARAMS diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp --- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -75,6 +75,10 @@ clEnumValN(::CFLAAType::Both, "both", "Enable both variants of CFL-AA"))); +cl::opt EnableLoopIntWrapPredication( + "enable-loop-int-wrap-predication", cl::init(false), cl::Hidden, + cl::desc("Enable Loop Integer Wrapping Predication Pass")); + cl::opt EnableLoopInterchange( "enable-loopinterchange", cl::init(false), cl::Hidden, cl::desc("Enable the experimental LoopInterchange Pass")); diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt --- a/llvm/lib/Transforms/Scalar/CMakeLists.txt +++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt @@ -34,6 +34,7 @@ LoopFuse.cpp LoopIdiomRecognize.cpp LoopInstSimplify.cpp + LoopIntWrapPredication.cpp LoopInterchange.cpp LoopFlatten.cpp LoopLoadElimination.cpp diff --git a/llvm/lib/Transforms/Scalar/LoopIntWrapPredication.cpp b/llvm/lib/Transforms/Scalar/LoopIntWrapPredication.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Transforms/Scalar/LoopIntWrapPredication.cpp @@ -0,0 +1,340 @@ +#include 
"llvm/Transforms/Scalar/LoopIntWrapPredication.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PassManager.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" + +using namespace llvm; + +#define DEBUG_TYPE "loop-int-wrap-predication" + +static const char *LoopIntWrapPredicationMetadata = + "llvm.loop.int_wrap_predication.disable"; + +static cl::opt + OverflowProbability("overflow-probability", cl::Hidden, cl::init(0.01), + cl::desc("Weight of branch with overflow condition")); + +STATISTIC(NumPredicatedChains, "Number of predicated chains"); + +// Arithmetic chain - chain of arithmetic instructions (add/sub/mul/shl) that +// starts from induction variable and finishes on zext; this zext can be +// eliminated by widening of induction variable in case when instructions in +// this chain will not overflow. 
+class ArithmeticChain { +public: + ArithmeticChain(const Loop &L, const Use &U); + ArithmeticChain(const ArithmeticChain &) = delete; + ArithmeticChain(ArithmeticChain &&) = default; + ArithmeticChain &operator=(const ArithmeticChain &) = delete; + ArithmeticChain &operator=(ArithmeticChain &&) = default; + + Value *generateOverflowCheck(ScalarEvolution &SE, Instruction *Loc) const; + void doPredication(Loop &L, Value *Cond, DominatorTree &DT, LoopInfo &LI, + MemorySSAUpdater *MSSAU) const; + + bool empty() const { return Chain.empty() || !ZExt; } + + void print(raw_ostream &OS) const { + OS << "Chain:\n"; + for (auto *Inst : Chain) + OS << *Inst << '\n'; + OS << "ZExt:\n" << *ZExt << '\n'; + } + + void dump() const { print(dbgs()); } + +private: + SmallVector Chain; + ZExtInst *ZExt = nullptr; +}; + +static Instruction *getChainInst(const Loop &L, const Use &U) { + auto *Inst = dyn_cast(U.getUser()); + if (!Inst) + return nullptr; + unsigned Opc = Inst->getOpcode(); + if (Opc != Instruction::Add && Opc != Instruction::Sub && + Opc != Instruction::Mul) + return nullptr; + assert(Inst->getNumOperands() == 2); + const Value *OtherOperand = Inst->getOperand(1 - U.getOperandNo()); + if (!L.isLoopInvariant(OtherOperand)) + return nullptr; + return Inst; +} + +ArithmeticChain::ArithmeticChain(const Loop &L, const Use &U) { + const Use *CurUse = &U; + while (Instruction *Inst = getChainInst(L, *CurUse)) { + if (!Inst->hasOneUse()) + break; + Chain.push_back(Inst); + CurUse = &*Inst->use_begin(); + } + ZExt = dyn_cast(CurUse->getUser()); +} + +struct AffineExpr { + const SCEV *Start; + const SCEV *TripCount; +}; + +// For the recurrence chain AR, that represents an expression inside some number +// of nested loops, determine if it can be flattened to the form: +// Start + i * RequiredStep, where i = 0 .. TripCount-1 +// and return Start and TripCount parameters if it is possible. 
+static Optional getFlattenedAffineExpr(ScalarEvolution &SE, + const SCEVAddRecExpr *AR, + const SCEV *RequiredStep) { + if (!AR->isAffine()) + return None; + auto *Start = AR->getStart(); + auto *Step = AR->getStepRecurrence(SE); + auto *L = AR->getLoop(); + auto *TripCount = SE.getBackedgeTakenCount(L); + if (isa(TripCount)) + return None; + // In rotated loop form number of header executions is one more than number of + // taken back edges. + if (L->isRotatedForm()) + TripCount = SE.getTripCountFromExitCount(TripCount, false); + auto *StartAR = dyn_cast(Start); + if (!StartAR) { + if (Step == RequiredStep) + return AffineExpr{Start, TripCount}; + return None; + } + // Linear case - parent loop should have Stride = TripCount * Step. + if (Step == RequiredStep) { + if (auto Res = + getFlattenedAffineExpr(SE, StartAR, SE.getMulExpr(TripCount, Step))) + return AffineExpr{Res->Start, SE.getMulExpr(TripCount, Res->TripCount)}; + return None; + } + // Possibly transposed case - parent loop should have required stride and its + // TripCount * RequiredStride = Stride of current loop. + if (auto Res = getFlattenedAffineExpr(SE, StartAR, RequiredStep)) + if (SE.getMulExpr(Res->TripCount, RequiredStep) == Step) + return AffineExpr{Res->Start, SE.getMulExpr(TripCount, Res->TripCount)}; + return None; +} + +Value *ArithmeticChain::generateOverflowCheck(ScalarEvolution &SE, + Instruction *Loc) const { + assert(!Chain.empty()); + auto *AR = dyn_cast(SE.getSCEV(Chain.back())); + if (!AR || AR->hasNoUnsignedWrap()) + return nullptr; + LLVM_DEBUG(dbgs() << "SCEV: " << *AR << '\n'); + auto *Ty = AR->getType(); + // For the simplicity of overflow check, process the most common case with + // Step = 1. + auto Res = getFlattenedAffineExpr(SE, AR, SE.getOne(Ty)); + if (!Res) + return nullptr; + LLVM_DEBUG(dbgs() << "Start: " << *Res->Start + << " TripCount: " << *Res->TripCount << '\n'); + // Generate overflow checking code. 
+ SmallVector OverflowChecks; + SCEVExpander Expander(SE, SE.getDataLayout(), "overflowcheck"); + Expander.setInsertPoint(Loc); + IRBuilder<> Builder(Loc); + // Check overall TripCount overflow. + Value *MulRes = nullptr; + if (auto *Mul = dyn_cast(Res->TripCount)) { + MulRes = Expander.expandCodeFor(Mul->getOperand(0)); + for (unsigned i = 1; i < Mul->getNumOperands(); ++i) { + CallInst *MulOverflow = Builder.CreateCall( + Intrinsic::getDeclaration(Loc->getModule(), + Intrinsic::umul_with_overflow, Ty), + {MulRes, Expander.expandCodeFor(Mul->getOperand(i))}); + MulRes = Builder.CreateExtractValue(MulOverflow, 0); + OverflowChecks.push_back(Builder.CreateExtractValue(MulOverflow, 1)); + } + } else + MulRes = Expander.expandCodeFor(Res->TripCount); + // Check overflow of Start + TripCount. + CallInst *AddOverflow = Builder.CreateCall( + Intrinsic::getDeclaration(Loc->getModule(), Intrinsic::uadd_with_overflow, + Ty), + {Expander.expandCodeFor(Res->Start), MulRes}); + OverflowChecks.push_back(Builder.CreateExtractValue(AddOverflow, 1)); + return Builder.CreateOr(OverflowChecks); +} + +void ArithmeticChain::doPredication(Loop &L, Value *Cond, DominatorTree &DT, + LoopInfo &LI, + MemorySSAUpdater *MSSAU) const { + // Create if-then-else CFG. + BasicBlock *HeadBB = ZExt->getParent(); + BasicBlock *TailBB = SplitBlock(HeadBB, ZExt, &DT, &LI, MSSAU); + LLVMContext &C = HeadBB->getContext(); + BasicBlock *ThenBB = BasicBlock::Create(C, ZExt->getName() + ".then", + HeadBB->getParent(), TailBB); + BasicBlock *ElseBB = BasicBlock::Create(C, ZExt->getName() + ".else", + HeadBB->getParent(), TailBB); + BranchInst *NewBr = BranchInst::Create(ThenBB, ElseBB, Cond); + uint32_t OverflowWeight = UINT32_MAX * OverflowProbability; + uint32_t NoOverflowWeight = UINT32_MAX * (1.0 - OverflowProbability); + // Set branch weight metadata with overflow probability, so the following + // passes will consider it as unlikely event. 
+ NewBr->setMetadata( + LLVMContext::MD_prof, + MDBuilder(C).createBranchWeights(OverflowWeight, NoOverflowWeight)); + ReplaceInstWithInst(HeadBB->getTerminator(), NewBr); + if (MSSAU) + MSSAU->applyUpdates(DominatorTree::UpdateType{DT.Delete, HeadBB, TailBB}, + DT); + for (BasicBlock *BranchBB : {ThenBB, ElseBB}) { + BranchInst *Br = BranchInst::Create(TailBB, BranchBB); + Br->setDebugLoc(ZExt->getDebugLoc()); + L.addBasicBlockToLoop(BranchBB, LI); + DT.addNewBlock(BranchBB, HeadBB); + if (MSSAU) + MSSAU->applyUpdates( + {DominatorTree::UpdateType{DT.Insert, HeadBB, BranchBB}, + DominatorTree::UpdateType{DT.Insert, BranchBB, TailBB}}, + DT); + } + // Create versions with no integer wrapping and with possible wrapping. + Instruction *ThenRes = nullptr; + Instruction *ElseRes = nullptr; + for (auto *Inst : Chain) { + Instruction *PrevThenRes = ThenRes; + Instruction *PrevElseRes = ElseRes; + // Then branch (possible wrap version). + ThenRes = Inst; + ThenRes->moveBefore(ThenBB->getTerminator()); + // Else branch (nowrap version). + ElseRes = Inst->clone(); + ElseRes->setName(Inst->getName() + ".nowrap"); + ElseRes->replaceUsesOfWith(PrevThenRes, PrevElseRes); + ElseRes->setHasNoUnsignedWrap(true); + ElseRes->insertBefore(ElseBB->getTerminator()); + } + // Insert Phi node. + auto *Phi = IRBuilder<>(ZExt).CreatePHI(ZExt->getSrcTy(), 2); + Phi->addIncoming(ThenRes, ThenBB); + Phi->addIncoming(ElseRes, ElseBB); + ZExt->setOperand(0, Phi); +} + +static bool LoopIntWrapPredication(Loop &L, DominatorTree &DT, LoopInfo &LI, + ScalarEvolution &SE, MemorySSA *MSSA) { + // Skip the loop if it was already processed. 
+ if (findStringMetadataForLoop(&L, LoopIntWrapPredicationMetadata)) + return false; + auto *IndVar = L.getInductionVariable(SE); + if (!IndVar) + return false; + LLVM_DEBUG(dbgs() << "Processing loop: " << L.getName() << '\n' + << "Induction variable: " << IndVar->getName() << '\n'); + BasicBlock *Preheader = L.getLoopPreheader(); + assert(Preheader); + Instruction *Loc = Preheader->getTerminator(); + assert(Loc); + SmallVector Chains; + SmallVector OverflowChecks; + // Collect arithmetic chains from IndVar uses. + for (auto &U : IndVar->uses()) { + ArithmeticChain Chain(L, U); + if (Chain.empty()) + continue; + LLVM_DEBUG(Chain.dump()); + // Try to generate overflow check for this chain. + Value *OverflowCheck = Chain.generateOverflowCheck(SE, Loc); + if (!OverflowCheck) + continue; + Chains.push_back(std::move(Chain)); + OverflowChecks.push_back(OverflowCheck); + } + if (Chains.empty()) + return false; + Value *Cond = IRBuilder<>(Loc).CreateOr(OverflowChecks); + Optional MSSAU; + if (MSSA) + MSSAU = MemorySSAUpdater(MSSA); + // Predicate all chains with one common overflow condition. + for (auto &Chain : Chains) { + Chain.doPredication(L, Cond, DT, LI, MSSAU ? MSSAU.getPointer() : nullptr); + ++NumPredicatedChains; + } + // Mark the loop so that it is not processed again. 
+ addStringMetadataToLoop(&L, LoopIntWrapPredicationMetadata); + if (MSSA && VerifyMemorySSA) + MSSA->verifyMemorySSA(); + return true; +} + +PreservedAnalyses +LoopIntWrapPredicationPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + if (!LoopIntWrapPredication(L, AR.DT, AR.LI, AR.SE, AR.MSSA)) + return PreservedAnalyses::all(); + auto PA = getLoopPassPreservedAnalyses(); + if (AR.MSSA) + PA.preserve(); + return PA; +} + +namespace { + +class LoopIntWrapPredicationLegacyPass : public LoopPass { +public: + static char ID; + + LoopIntWrapPredicationLegacyPass() : LoopPass(ID) { + initializeLoopIntWrapPredicationLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + auto &DT = getAnalysis().getDomTree(); + auto &LI = getAnalysis().getLoopInfo(); + auto &SE = getAnalysis().getSE(); + auto *MSSAWP = getAnalysisIfAvailable(); + return LoopIntWrapPredication(*L, DT, LI, SE, + MSSAWP ? 
&MSSAWP->getMSSA() : nullptr); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequired(); + getLoopAnalysisUsage(AU); + } +}; + +} // end anonymous namespace + +char LoopIntWrapPredicationLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(LoopIntWrapPredicationLegacyPass, + "loop-int-wrap-predication", + "Predicate overflowing binary operators", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(LoopIntWrapPredicationLegacyPass, + "loop-int-wrap-predication", + "Predicate overflowing binary operators", false, false) + +Pass *llvm::createLoopIntWrapPredicationPass() { + return new LoopIntWrapPredicationLegacyPass(); +} diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp --- a/llvm/lib/Transforms/Scalar/Scalar.cpp +++ b/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -66,6 +66,7 @@ initializeLoopAccessLegacyAnalysisPass(Registry); initializeLoopInstSimplifyLegacyPassPass(Registry); initializeLoopInterchangeLegacyPassPass(Registry); + initializeLoopIntWrapPredicationLegacyPassPass(Registry); initializeLoopFlattenLegacyPassPass(Registry); initializeLoopPredicationLegacyPassPass(Registry); initializeLoopRotateLegacyPassPass(Registry); diff --git a/llvm/test/Transforms/LoopIntWrapPredication/2d-array-linear.ll b/llvm/test/Transforms/LoopIntWrapPredication/2d-array-linear.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopIntWrapPredication/2d-array-linear.ll @@ -0,0 +1,92 @@ +; RUN: opt -passes=loop-int-wrap-predication -S < %s | FileCheck %s + +declare i32 @f() + +define void @foo(i32 %N1, i32 %N2, ptr %C) { +; CHECK-LABEL: @foo( +; CHECK: for.body3.lr.ph.us: ; preds = %for.cond1.preheader.us +; CHECK-NEXT: [[MUL:%.*]] = mul i32 %i.015.us, %N1 +; CHECK-NEXT: [[UMUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %N1, i32 %N2) +; CHECK-NEXT: 
[[UMULRES:%.*]] = extractvalue { i32, i1 } [[UMUL]], 0 +; CHECK-NEXT: [[OVERFLOW1:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1 +; CHECK-NEXT: [[UADD:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 0, i32 [[UMULRES]]) +; CHECK-NEXT: [[OVERFLOW2:%.*]] = extractvalue { i32, i1 } [[UADD]], 1 +; CHECK-NEXT: [[OVERFLOW:%.*]] = or i1 [[OVERFLOW1]], [[OVERFLOW2]] +; CHECK-NEXT: br label %for.body3.us +; CHECK: for.body3.us: +; CHECK-NEXT: [[INDVAR:%.*]] = phi i32 [ 0, %for.body3.lr.ph.us ], [ %inc.us, %for.body3.us.split ] +; CHECK-NEXT: [[CALL:%.*]] = tail call signext i32 @f() +; CHECK-NEXT: br i1 [[OVERFLOW]], label %idxprom.us.then, label %idxprom.us.else +; CHECK-DAG: idxprom.us.then: +; CHECK-NEXT: [[ADD:%.*]] = add i32 [[INDVAR]], [[MUL]] +; CHECK-NEXT: br label %for.body3.us.split +; CHECK-DAG: idxprom.us.else: +; CHECK-NEXT: [[ADD_NOWRAP:%.*]] = add nuw i32 [[INDVAR]], [[MUL]] +; CHECK-NEXT: br label %for.body3.us.split +; CHECK: for.body3.us.split: +; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ [[ADD]], %idxprom.us.then ], [ [[ADD_NOWRAP]], %idxprom.us.else ] +; CHECK-NEXT: [[IDX:%.*]] = zext i32 [[PHI]] to i64 +; CHECK-NEXT: [[PTR:%.*]] = getelementptr inbounds i32, ptr %C, i64 [[IDX]] +; CHECK-NEXT: store i32 [[CALL]], ptr [[PTR]], align 4 +entry: + %cmp14 = icmp ult i32 0, %N2 + br i1 %cmp14, label %for.cond1.preheader.lr.ph, label %for.end6 + +for.cond1.preheader.lr.ph: ; preds = %entry + %cmp212 = icmp ult i32 0, %N1 + br i1 %cmp212, label %for.cond1.preheader.lr.ph.split.us, label %for.cond1.preheader.lr.ph.split + +for.cond1.preheader.lr.ph.split.us: ; preds = %for.cond1.preheader.lr.ph + br label %for.cond1.preheader.us + +for.cond1.preheader.us: ; preds = %for.inc4.us, %for.cond1.preheader.lr.ph.split.us + %i.015.us = phi i32 [ 0, %for.cond1.preheader.lr.ph.split.us ], [ %inc5.us, %for.inc4.us ] + br label %for.body3.lr.ph.us + +for.body3.lr.ph.us: ; preds = %for.cond1.preheader.us + %mul.us = mul i32 %i.015.us, %N1 + br label %for.body3.us + 
+for.body3.us: ; preds = %for.body3.lr.ph.us, %for.body3.us + %j.013.us = phi i32 [ 0, %for.body3.lr.ph.us ], [ %inc.us, %for.body3.us ] + %call.us = tail call signext i32 @f() + %add.us = add i32 %j.013.us, %mul.us + %idxprom.us = zext i32 %add.us to i64 + %arrayidx.us = getelementptr inbounds i32, ptr %C, i64 %idxprom.us + store i32 %call.us, ptr %arrayidx.us, align 4 + %inc.us = add nuw i32 %j.013.us, 1 + %cmp2.us = icmp ult i32 %inc.us, %N1 + br i1 %cmp2.us, label %for.body3.us, label %for.cond1.for.inc4_crit_edge.us + +for.cond1.for.inc4_crit_edge.us: ; preds = %for.body3.us + br label %for.inc4.us + +for.inc4.us: ; preds = %for.cond1.for.inc4_crit_edge.us + %inc5.us = add i32 %i.015.us, 1 + %cmp.us = icmp ult i32 %inc5.us, %N2 + br i1 %cmp.us, label %for.cond1.preheader.us, label %for.cond.for.end6_crit_edge.split.us + +for.cond.for.end6_crit_edge.split.us: ; preds = %for.inc4.us + br label %for.cond.for.end6_crit_edge + +for.cond1.preheader.lr.ph.split: ; preds = %for.cond1.preheader.lr.ph + br label %for.cond1.preheader + +for.cond1.preheader: ; preds = %for.cond1.preheader.lr.ph.split, %for.inc4 + %i.015 = phi i32 [ 0, %for.cond1.preheader.lr.ph.split ], [ %inc5, %for.inc4 ] + br label %for.inc4 + +for.inc4: ; preds = %for.cond1.preheader + %inc5 = add i32 %i.015, 1 + %cmp = icmp ult i32 %inc5, %N2 + br i1 %cmp, label %for.cond1.preheader, label %for.cond.for.end6_crit_edge.split + +for.cond.for.end6_crit_edge.split: ; preds = %for.inc4 + br label %for.cond.for.end6_crit_edge + +for.cond.for.end6_crit_edge: ; preds = %for.cond.for.end6_crit_edge.split.us, %for.cond.for.end6_crit_edge.split + br label %for.end6 + +for.end6: ; preds = %for.cond.for.end6_crit_edge, %entry + ret void +} diff --git a/llvm/test/Transforms/LoopIntWrapPredication/2d-array-transposed.ll b/llvm/test/Transforms/LoopIntWrapPredication/2d-array-transposed.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopIntWrapPredication/2d-array-transposed.ll @@ -0,0 +1,93 @@ +; 
RUN: opt -passes=loop-int-wrap-predication -S < %s | FileCheck %s + +declare i32 @f() + +define void @foo(i32 %N1, i32 %N2, ptr %C) { +; CHECK-LABEL: @foo( +; CHECK: for.body3.lr.ph.us: ; preds = %for.cond1.preheader.us +; CHECK-NEXT: [[UMUL:%.*]] = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %N1, i32 %N2) +; CHECK-NEXT: [[UMULRES:%.*]] = extractvalue { i32, i1 } [[UMUL]], 0 +; CHECK-NEXT: [[OVERFLOW1:%.*]] = extractvalue { i32, i1 } [[UMUL]], 1 +; CHECK-NEXT: [[UADD:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 0, i32 [[UMULRES]]) +; CHECK-NEXT: [[OVERFLOW2:%.*]] = extractvalue { i32, i1 } [[UADD]], 1 +; CHECK-NEXT: [[OVERFLOW:%.*]] = or i1 [[OVERFLOW1]], [[OVERFLOW2]] +; CHECK-NEXT: br label %for.body3.us +; CHECK: for.body3.us: +; CHECK-NEXT: [[INDVAR:%.*]] = phi i32 [ 0, %for.body3.lr.ph.us ], [ %inc.us, %for.body3.us.split ] +; CHECK-NEXT: [[CALL:%.*]] = tail call signext i32 @f() +; CHECK-NEXT: br i1 [[OVERFLOW]], label %idxprom.us.then, label %idxprom.us.else +; CHECK-DAG: idxprom.us.then: +; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[INDVAR]], %N2 +; CHECK-NEXT: [[ADD:%.*]] = add i32 %i.015.us, [[MUL]] +; CHECK-NEXT: br label %for.body3.us.split +; CHECK-DAG: idxprom.us.else: +; CHECK-NEXT: [[MUL_NOWRAP:%.*]] = mul nuw i32 [[INDVAR]], %N2 +; CHECK-NEXT: [[ADD_NOWRAP:%.*]] = add nuw i32 %i.015.us, [[MUL_NOWRAP]] +; CHECK-NEXT: br label %for.body3.us.split +; CHECK: for.body3.us.split: +; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ [[ADD]], %idxprom.us.then ], [ [[ADD_NOWRAP]], %idxprom.us.else ] +; CHECK-NEXT: [[IDX:%.*]] = zext i32 [[PHI]] to i64 +; CHECK-NEXT: [[PTR:%.*]] = getelementptr inbounds i32, ptr %C, i64 [[IDX]] +; CHECK-NEXT: store i32 [[CALL]], ptr [[PTR]], align 4 +entry: + %cmp14 = icmp ult i32 0, %N2 + br i1 %cmp14, label %for.cond1.preheader.lr.ph, label %for.end6 + +for.cond1.preheader.lr.ph: ; preds = %entry + %cmp212 = icmp ult i32 0, %N1 + br i1 %cmp212, label %for.cond1.preheader.lr.ph.split.us, label %for.cond1.preheader.lr.ph.split 
+ +for.cond1.preheader.lr.ph.split.us: ; preds = %for.cond1.preheader.lr.ph + br label %for.cond1.preheader.us + +for.cond1.preheader.us: ; preds = %for.inc4.us, %for.cond1.preheader.lr.ph.split.us + %i.015.us = phi i32 [ 0, %for.cond1.preheader.lr.ph.split.us ], [ %inc5.us, %for.inc4.us ] + br label %for.body3.lr.ph.us + +for.body3.lr.ph.us: ; preds = %for.cond1.preheader.us + br label %for.body3.us + +for.body3.us: ; preds = %for.body3.lr.ph.us, %for.body3.us + %j.013.us = phi i32 [ 0, %for.body3.lr.ph.us ], [ %inc.us, %for.body3.us ] + %call.us = tail call signext i32 @f() + %mul.us = mul i32 %j.013.us, %N2 + %add.us = add i32 %i.015.us, %mul.us + %idxprom.us = zext i32 %add.us to i64 + %arrayidx.us = getelementptr inbounds i32, ptr %C, i64 %idxprom.us + store i32 %call.us, ptr %arrayidx.us, align 4 + %inc.us = add nuw i32 %j.013.us, 1 + %cmp2.us = icmp ult i32 %inc.us, %N1 + br i1 %cmp2.us, label %for.body3.us, label %for.cond1.for.inc4_crit_edge.us + +for.cond1.for.inc4_crit_edge.us: ; preds = %for.body3.us + br label %for.inc4.us + +for.inc4.us: ; preds = %for.cond1.for.inc4_crit_edge.us + %inc5.us = add i32 %i.015.us, 1 + %cmp.us = icmp ult i32 %inc5.us, %N2 + br i1 %cmp.us, label %for.cond1.preheader.us, label %for.cond.for.end6_crit_edge.split.us + +for.cond.for.end6_crit_edge.split.us: ; preds = %for.inc4.us + br label %for.cond.for.end6_crit_edge + +for.cond1.preheader.lr.ph.split: ; preds = %for.cond1.preheader.lr.ph + br label %for.cond1.preheader + +for.cond1.preheader: ; preds = %for.cond1.preheader.lr.ph.split, %for.inc4 + %i.015 = phi i32 [ 0, %for.cond1.preheader.lr.ph.split ], [ %inc5, %for.inc4 ] + br label %for.inc4 + +for.inc4: ; preds = %for.cond1.preheader + %inc5 = add i32 %i.015, 1 + %cmp = icmp ult i32 %inc5, %N2 + br i1 %cmp, label %for.cond1.preheader, label %for.cond.for.end6_crit_edge.split + +for.cond.for.end6_crit_edge.split: ; preds = %for.inc4 + br label %for.cond.for.end6_crit_edge + +for.cond.for.end6_crit_edge: ; preds = 
%for.cond.for.end6_crit_edge.split.us, %for.cond.for.end6_crit_edge.split + br label %for.end6 + +for.end6: ; preds = %for.cond.for.end6_crit_edge, %entry + ret void +} diff --git a/llvm/test/Transforms/LoopIntWrapPredication/basic.ll b/llvm/test/Transforms/LoopIntWrapPredication/basic.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopIntWrapPredication/basic.ll @@ -0,0 +1,49 @@ +; RUN: opt -passes=loop-int-wrap-predication -S < %s | FileCheck %s + +declare i32 @f() + +define void @foo(i32 %offset, i32 %N, ptr %C) { +; CHECK-LABEL: @foo( +; CHECK: for.body.lr.ph: +; CHECK-NEXT: [[UADD:%.*]] = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %offset, i32 %N) +; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i32, i1 } [[UADD]], 1 +; CHECK-NEXT: br label %for.body +; CHECK: for.body: +; CHECK-NEXT: [[INDVAR:%.*]] = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %for.body.split ] +; CHECK-NEXT: [[CALL:%.*]] = tail call signext i32 @f() +; CHECK-NEXT: br i1 [[OVERFLOW]], label %idxprom.then, label %idxprom.else +; CHECK-DAG: idxprom.then: +; CHECK-NEXT: [[ADD:%.*]] = add i32 [[INDVAR]], %offset +; CHECK-NEXT: br label %for.body.split +; CHECK-DAG: idxprom.else: +; CHECK-NEXT: [[ADD_NOWRAP:%.*]] = add nuw i32 [[INDVAR]], %offset +; CHECK-NEXT: br label %for.body.split +; CHECK: for.body.split: +; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ [[ADD]], %idxprom.then ], [ [[ADD_NOWRAP]], %idxprom.else ] +; CHECK-NEXT: [[IDX:%.*]] = zext i32 [[PHI]] to i64 +; CHECK-NEXT: [[PTR:%.*]] = getelementptr inbounds i32, ptr %C, i64 [[IDX]] +; CHECK-NEXT: store i32 [[CALL]], ptr [[PTR]], align 4 +entry: + %cmp3 = icmp ult i32 0, %N + br i1 %cmp3, label %for.body.lr.ph, label %for.cond.cleanup + +for.body.lr.ph: ; preds = %entry + br label %for.body + +for.body: ; preds = %for.body.lr.ph, %for.body + %i.04 = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %for.body ] + %call = tail call signext i32 @f() + %add = add i32 %i.04, %offset + %idxprom = zext i32 %add to i64 + %arrayidx = 
getelementptr inbounds i32, ptr %C, i64 %idxprom + store i32 %call, ptr %arrayidx, align 4 + %inc = add nuw i32 %i.04, 1 + %cmp = icmp ult i32 %inc, %N + br i1 %cmp, label %for.body, label %for.cond.for.cond.cleanup_crit_edge + +for.cond.for.cond.cleanup_crit_edge: ; preds = %for.body + br label %for.cond.cleanup + +for.cond.cleanup: ; preds = %entry, %for.cond.for.cond.cleanup_crit_edge + ret void +} diff --git a/llvm/test/Transforms/LoopIntWrapPredication/non-invariant-trip-count.ll b/llvm/test/Transforms/LoopIntWrapPredication/non-invariant-trip-count.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopIntWrapPredication/non-invariant-trip-count.ll @@ -0,0 +1,51 @@ +; RUN: opt -passes=loop-int-wrap-predication -S < %s | FileCheck %s + +; COM: Linear memory access with variable number of steps in inner loop: +; COM: unsigned offset = 0; +; COM: for (unsigned i = 0; i < N; ++i) { +; COM: offset += i; +; COM: for (unsigned j = 0; j < i + 1; ++j) +; COM: C[offset + j] = f(); +; COM: } + +declare i32 @f() + +define void @foo(i32 %N, ptr %C) { +; COM: not applied, because trip count of inner loop is SCEVCouldNotCompute +; CHECK-LABEL: @foo( +; CHECK: for.body4: +; CHECK-NEXT: [[INDVAR:%.*]] = phi i32 [ 0, %for.body ], [ %inc, %for.body4 ] +; CHECK-NEXT: [[CALL:%.*]] = tail call signext i32 @f() +; CHECK-NEXT: [[ADD:%.*]] = add i32 [[INDVAR]], %add +; CHECK-NEXT: [[IDX:%.*]] = zext i32 [[ADD]] to i64 +; CHECK-NEXT: [[PTR:%.*]] = getelementptr inbounds i32, ptr %C, i64 [[IDX]] +; CHECK-NEXT: store i32 [[CALL]], ptr [[PTR]], align 4 +entry: + %cmp16 = icmp ult i32 0, %N + br i1 %cmp16, label %for.body, label %for.end8 + +for.cond.loopexit: ; preds = %for.body4 + %add2.le = add nuw i32 %i.017, 1 + %cmp = icmp ult i32 %add2.le, %N + br i1 %cmp, label %for.body, label %for.end8 + +for.body: ; preds = %entry, %for.cond.loopexit + %offset.018 = phi i32 [ %add, %for.cond.loopexit ], [ 0, %entry ] + %i.017 = phi i32 [ %add2.le, %for.cond.loopexit ], [ 0, %entry 
] + %add = add i32 %offset.018, %i.017 + br label %for.body4 + +for.body4: ; preds = %for.body, %for.body4 + %j.015 = phi i32 [ 0, %for.body ], [ %inc, %for.body4 ] + %call = tail call signext i32 @f() + %add5 = add i32 %j.015, %add + %idxprom = zext i32 %add5 to i64 + %arrayidx = getelementptr inbounds i32, ptr %C, i64 %idxprom + store i32 %call, ptr %arrayidx, align 4 + %inc = add nuw i32 %j.015, 1 + %cmp3.not = icmp ugt i32 %inc, %i.017 + br i1 %cmp3.not, label %for.cond.loopexit, label %for.body4 + +for.end8: ; preds = %for.cond.loopexit, %entry + ret void +}