diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h --- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h @@ -52,6 +52,9 @@ // New instructions receive a name to identify them with the current pass. const char *IVName; + /// Indicates whether LCSSA phis should be created for inserted values. + bool PreserveLCSSA; + // InsertedExpressions caches Values for reuse, so must track RAUW. DenseMap, TrackingVH> InsertedExpressions; @@ -146,9 +149,10 @@ public: /// Construct a SCEVExpander in "canonical" mode. explicit SCEVExpander(ScalarEvolution &se, const DataLayout &DL, - const char *name) - : SE(se), DL(DL), IVName(name), IVIncInsertLoop(nullptr), - IVIncInsertPos(nullptr), CanonicalMode(true), LSRMode(false), + const char *name, bool PreserveLCSSA = true) + : SE(se), DL(DL), IVName(name), PreserveLCSSA(PreserveLCSSA), + IVIncInsertLoop(nullptr), IVIncInsertPos(nullptr), CanonicalMode(true), + LSRMode(false), Builder(se.getContext(), TargetFolder(DL), IRBuilderCallbackInserter( [this](Instruction *I) { rememberInstruction(I); })) { @@ -223,14 +227,18 @@ const TargetTransformInfo *TTI = nullptr); /// Insert code to directly compute the specified SCEV expression into the - /// program. The inserted code is inserted into the specified block. - Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I); + /// program. The code is inserted into the specified block. + Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I) { + return expandCodeForImpl(SH, Ty, I, true); + } /// Insert code to directly compute the specified SCEV expression into the - /// program. The inserted code is inserted into the SCEVExpander's current + /// program. The code is inserted into the SCEVExpander's current /// insertion point. 
If a type is specified, the result will be expanded to /// have that type, with a cast if necessary. - Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr); + Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr) { + return expandCodeForImpl(SH, Ty, true); + } /// Generates a code sequence that evaluates this predicate. The inserted /// instructions will be at position \p Loc. The result will be of type i1 @@ -338,6 +346,20 @@ private: LLVMContext &getContext() const { return SE.getContext(); } + /// Insert code to directly compute the specified SCEV expression into the + /// program. The code is inserted into the SCEVExpander's current + /// insertion point. If a type is specified, the result will be expanded to + /// have that type, with a cast if necessary. If \p Root is true, this + /// indicates that \p SH is the top-level expression to expand passed from + /// an external client call. + Value *expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root); + + /// Insert code to directly compute the specified SCEV expression into the + /// program. The code is inserted into the specified block. If \p + /// Root is true, this indicates that \p SH is the top-level expression to + /// expand passed from an external client call. + Value *expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *I, bool Root); + /// Recursive helper function for isHighCostExpansion. bool isHighCostExpansionHelper(const SCEV *S, Loop *L, const Instruction &At, int &BudgetRemaining, @@ -419,6 +441,11 @@ Instruction *Pos, PHINode *LoopPhi); void fixupInsertPoints(Instruction *I); + + /// If required, create LCSSA PHIs for \p User's operand \p OpIdx. If new + /// LCSSA PHIs have been created, return the LCSSA PHI available at \p User. + /// If no PHIs have been created, return the unchanged operand \p OpIdx.
+ Value *fixupLCSSAFormFor(Instruction *User, unsigned OpIdx); }; } // namespace llvm diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -5514,8 +5514,8 @@ // we can remove them after we are done working. SmallVector DeadInsts; - SCEVExpander Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), - "lsr"); + SCEVExpander Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), "lsr", + false); #ifndef NDEBUG Rewriter.setDebugType(DEBUG_TYPE); #endif @@ -5780,7 +5780,7 @@ if (EnablePhiElim && L->isLoopSimplifyForm()) { SmallVector DeadInsts; const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - SCEVExpander Rewriter(SE, DL, "lsr"); + SCEVExpander Rewriter(SE, DL, "lsr", false); #ifndef NDEBUG Rewriter.setDebugType(DEBUG_TYPE); #endif diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -27,6 +27,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -461,9 +462,10 @@ // we didn't find any operands that could be factored, tentatively // assume that element zero was selected (since the zero offset // would obviously be folded away). - Value *Scaled = ScaledOps.empty() ? - Constant::getNullValue(Ty) : - expandCodeFor(SE.getAddExpr(ScaledOps), Ty); + Value *Scaled = + ScaledOps.empty() + ? Constant::getNullValue(Ty) + : expandCodeForImpl(SE.getAddExpr(ScaledOps), Ty, false); GepIndices.push_back(Scaled); // Collect struct field index operands. 
@@ -522,7 +524,7 @@ SE.DT.dominates(cast(V), &*Builder.GetInsertPoint())); // Expand the operands for a plain byte offset. - Value *Idx = expandCodeFor(SE.getAddExpr(Ops), Ty); + Value *Idx = expandCodeForImpl(SE.getAddExpr(Ops), Ty, false); // Fold a GEP with constant operands. if (Constant *CLHS = dyn_cast(V)) @@ -743,14 +745,14 @@ Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, expand(Op)); } else if (Op->isNonConstantNegative()) { // Instead of doing a negate and add, just do a subtract. - Value *W = expandCodeFor(SE.getNegativeSCEV(Op), Ty); + Value *W = expandCodeForImpl(SE.getNegativeSCEV(Op), Ty, false); Sum = InsertNoopCastOfTo(Sum, Ty); Sum = InsertBinop(Instruction::Sub, Sum, W, SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true); ++I; } else { // A simple add. - Value *W = expandCodeFor(Op, Ty); + Value *W = expandCodeForImpl(Op, Ty, false); Sum = InsertNoopCastOfTo(Sum, Ty); // Canonicalize a constant to the RHS. if (isa(Sum)) std::swap(Sum, W); @@ -802,7 +804,7 @@ // Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them // that are needed into the result. - Value *P = expandCodeFor(I->second, Ty); + Value *P = expandCodeForImpl(I->second, Ty, false); Value *Result = nullptr; if (Exponent & 1) Result = P; @@ -861,7 +863,7 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); - Value *LHS = expandCodeFor(S->getLHS(), Ty); + Value *LHS = expandCodeForImpl(S->getLHS(), Ty, false); if (const SCEVConstant *SC = dyn_cast(S->getRHS())) { const APInt &RHS = SC->getAPInt(); if (RHS.isPowerOf2()) @@ -870,7 +872,7 @@ SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true); } - Value *RHS = expandCodeFor(S->getRHS(), Ty); + Value *RHS = expandCodeForImpl(S->getRHS(), Ty, false); return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap, /*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS())); } @@ -1265,8 +1267,9 @@ // Expand code for the start value into the loop preheader. 
assert(L->getLoopPreheader() && "Can't expand add recurrences without a loop preheader!"); - Value *StartV = expandCodeFor(Normalized->getStart(), ExpandTy, - L->getLoopPreheader()->getTerminator()); + Value *StartV = + expandCodeForImpl(Normalized->getStart(), ExpandTy, + L->getLoopPreheader()->getTerminator(), false); // StartV must have been be inserted into L's preheader to dominate the new // phi. @@ -1284,8 +1287,8 @@ if (useSubtract) Step = SE.getNegativeSCEV(Step); // Expand the step somewhere that dominates the loop header. - Value *StepV = expandCodeFor(Step, IntTy, - &*L->getHeader()->getFirstInsertionPt()); + Value *StepV = expandCodeForImpl( + Step, IntTy, &*L->getHeader()->getFirstInsertionPt(), false); // The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if // we actually do emit an addition. It does not apply if we emit a @@ -1430,8 +1433,8 @@ { // Expand the step somewhere that dominates the loop header. SCEVInsertPointGuard Guard(Builder, this); - StepV = expandCodeFor(Step, IntTy, - &*L->getHeader()->getFirstInsertionPt()); + StepV = expandCodeForImpl( + Step, IntTy, &*L->getHeader()->getFirstInsertionPt(), false); } Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); } @@ -1450,8 +1453,8 @@ // Invert the result. if (InvertStep) - Result = Builder.CreateSub(expandCodeFor(Normalized->getStart(), TruncTy), - Result); + Result = Builder.CreateSub( + expandCodeForImpl(Normalized->getStart(), TruncTy, false), Result); } // Re-apply any non-loop-dominating scale. @@ -1459,22 +1462,22 @@ assert(S->isAffine() && "Can't linearly scale non-affine recurrences."); Result = InsertNoopCastOfTo(Result, IntTy); Result = Builder.CreateMul(Result, - expandCodeFor(PostLoopScale, IntTy)); + expandCodeForImpl(PostLoopScale, IntTy, false)); } // Re-apply any non-loop-dominating offset. 
if (PostLoopOffset) { if (PointerType *PTy = dyn_cast(ExpandTy)) { if (Result->getType()->isIntegerTy()) { - Value *Base = expandCodeFor(PostLoopOffset, ExpandTy); + Value *Base = expandCodeForImpl(PostLoopOffset, ExpandTy, false); Result = expandAddToGEP(SE.getUnknown(Result), PTy, IntTy, Base); } else { Result = expandAddToGEP(PostLoopOffset, PTy, IntTy, Result); } } else { Result = InsertNoopCastOfTo(Result, IntTy); - Result = Builder.CreateAdd(Result, - expandCodeFor(PostLoopOffset, IntTy)); + Result = Builder.CreateAdd( + Result, expandCodeForImpl(PostLoopOffset, IntTy, false)); } } @@ -1516,8 +1519,8 @@ S->getNoWrapFlags(SCEV::FlagNW))); BasicBlock::iterator NewInsertPt = findInsertPointAfter(cast(V), Builder.GetInsertBlock()); - V = expandCodeFor(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr, - &*NewInsertPt); + V = expandCodeForImpl(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr, + &*NewInsertPt, false); return V; } @@ -1632,22 +1635,25 @@ Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); - Value *V = expandCodeFor(S->getOperand(), - SE.getEffectiveSCEVType(S->getOperand()->getType())); + Value *V = expandCodeForImpl( + S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()), + false); return Builder.CreateTrunc(V, Ty); } Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); - Value *V = expandCodeFor(S->getOperand(), - SE.getEffectiveSCEVType(S->getOperand()->getType())); + Value *V = expandCodeForImpl( + S->getOperand(), SE.getEffectiveSCEVType(S->getOperand()->getType()), + false); return Builder.CreateZExt(V, Ty); } Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) { Type *Ty = SE.getEffectiveSCEVType(S->getType()); - Value *V = expandCodeFor(S->getOperand(), - SE.getEffectiveSCEVType(S->getOperand()->getType())); + Value *V = expandCodeForImpl( + S->getOperand(), 
SE.getEffectiveSCEVType(S->getOperand()->getType()), + false); return Builder.CreateSExt(V, Ty); } @@ -1662,7 +1668,7 @@ Ty = SE.getEffectiveSCEVType(Ty); LHS = InsertNoopCastOfTo(LHS, Ty); } - Value *RHS = expandCodeFor(S->getOperand(i), Ty); + Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); Value *ICmp = Builder.CreateICmpSGT(LHS, RHS); Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smax"); LHS = Sel; @@ -1685,7 +1691,7 @@ Ty = SE.getEffectiveSCEVType(Ty); LHS = InsertNoopCastOfTo(LHS, Ty); } - Value *RHS = expandCodeFor(S->getOperand(i), Ty); + Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); Value *ICmp = Builder.CreateICmpUGT(LHS, RHS); Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umax"); LHS = Sel; @@ -1708,7 +1714,7 @@ Ty = SE.getEffectiveSCEVType(Ty); LHS = InsertNoopCastOfTo(LHS, Ty); } - Value *RHS = expandCodeFor(S->getOperand(i), Ty); + Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); Value *ICmp = Builder.CreateICmpSLT(LHS, RHS); Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smin"); LHS = Sel; @@ -1731,7 +1737,7 @@ Ty = SE.getEffectiveSCEVType(Ty); LHS = InsertNoopCastOfTo(LHS, Ty); } - Value *RHS = expandCodeFor(S->getOperand(i), Ty); + Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); Value *ICmp = Builder.CreateICmpULT(LHS, RHS); Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umin"); LHS = Sel; @@ -1743,15 +1749,43 @@ return LHS; } -Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty, - Instruction *IP) { +Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, + Instruction *IP, bool Root) { setInsertPoint(IP); - return expandCodeFor(SH, Ty); + Value *V = expandCodeForImpl(SH, Ty, Root); + return V; } -Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) { +Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root) { // Expand the code for this SCEV. 
Value *V = expand(SH); + + if (PreserveLCSSA) { + if (auto *Inst = dyn_cast(V)) { + // Create a temporary instruction at the current insertion point, so we + // can hand it off to the helper to create LCSSA PHIs if required for the + // new use. + // FIXME: Ideally formLCSSAForInstructions (used in fixupLCSSAFormFor) + // would accept an insertion point and return an LCSSA phi for that + // insertion point, so there is no need to insert & remove the temporary + // instruction. + Instruction *Tmp; + if (Inst->getType()->isIntegerTy()) + Tmp = cast(Builder.CreateAdd(Inst, Inst)); + else { + assert(Inst->getType()->isPointerTy()); + Tmp = cast(Builder.CreateGEP(Inst, Builder.getInt32(1))); + } + V = fixupLCSSAFormFor(Tmp, 0); + + // Clean up temporary instruction. + InsertedValues.erase(Tmp); + InsertedPostIncValues.erase(Tmp); + Tmp->eraseFromParent(); + } + } + + InsertedExpressions[std::make_pair(SH, &*Builder.GetInsertPoint())] = V; if (Ty) { assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) && "non-trivial casts should be done with the SCEVs directly!"); @@ -1891,10 +1925,28 @@ } void SCEVExpander::rememberInstruction(Value *I) { - if (!PostIncLoops.empty()) - InsertedPostIncValues.insert(I); - else - InsertedValues.insert(I); + auto DoInsert = [this](Value *V) { + if (!PostIncLoops.empty()) + InsertedPostIncValues.insert(V); + else + InsertedValues.insert(V); + }; + DoInsert(I); + + if (!PreserveLCSSA) + return; + + if (auto *Inst = dyn_cast(I)) { + // A new instruction has been added, which might introduce new uses outside + // a defining loop. Fix LCSSA form for each operand of the new instruction, + // if required. + for (unsigned OpIdx = 0, OpEnd = Inst->getNumOperands(); OpIdx != OpEnd; + OpIdx++) { + auto *V = fixupLCSSAFormFor(Inst, OpIdx); + if (V != I) + DoInsert(V); + } + } } /// getOrInsertCanonicalInductionVariable - This method returns the @@ -1913,9 +1965,8 @@ // Emit code for it.
SCEVInsertPointGuard Guard(Builder, this); - PHINode *V = - cast(expandCodeFor(H, nullptr, - &*L->getHeader()->getFirstInsertionPt())); + PHINode *V = cast(expandCodeForImpl( + H, nullptr, &*L->getHeader()->getFirstInsertionPt(), false)); return V; } @@ -2315,8 +2366,10 @@ Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred, Instruction *IP) { - Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP); - Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP); + Value *Expr0 = + expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP, false); + Value *Expr1 = + expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP, false); Builder.SetInsertPoint(IP); auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check"); @@ -2348,15 +2401,16 @@ IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits); Builder.SetInsertPoint(Loc); - Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc); + Value *TripCountVal = expandCodeForImpl(ExitCount, CountTy, Loc, false); IntegerType *Ty = IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy)); Type *ARExpandTy = DL.isNonIntegralPointerType(ARTy) ? 
ARTy : Ty; - Value *StepValue = expandCodeFor(Step, Ty, Loc); - Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc); - Value *StartValue = expandCodeFor(Start, ARExpandTy, Loc); + Value *StepValue = expandCodeForImpl(Step, Ty, Loc, false); + Value *NegStepValue = + expandCodeForImpl(SE.getNegativeSCEV(Step), Ty, Loc, false); + Value *StartValue = expandCodeForImpl(Start, ARExpandTy, Loc, false); ConstantInt *Zero = ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits)); @@ -2459,6 +2513,25 @@ return Check; } +Value *SCEVExpander::fixupLCSSAFormFor(Instruction *User, unsigned OpIdx) { + assert(PreserveLCSSA); + SmallVector ToUpdate; + + auto *OpV = User->getOperand(OpIdx); + auto *OpI = dyn_cast(OpV); + if (!OpI) + return OpV; + + Loop *DefLoop = SE.LI.getLoopFor(OpI->getParent()); + Loop *UseLoop = SE.LI.getLoopFor(User->getParent()); + if (!DefLoop || UseLoop == DefLoop || DefLoop->contains(UseLoop)) + return OpV; + + ToUpdate.push_back(OpI); + formLCSSAForInstructions(ToUpdate, SE.DT, SE.LI, &SE); + return User->getOperand(OpIdx); +} + namespace { // Search for a SCEV subexpression that is not safe to expand. 
Any expression // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely diff --git a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp --- a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp +++ b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp @@ -265,7 +265,7 @@ Phi->addIncoming(Add, L); Builder.SetInsertPoint(Post); - Builder.CreateRetVoid(); + Instruction *Ret = Builder.CreateRetVoid(); ScalarEvolution SE = buildSE(*F); const SCEV *S = SE.getSCEV(Phi); @@ -276,6 +276,11 @@ EXPECT_FALSE(isSafeToExpandAt(AR, LPh->getTerminator(), SE)); EXPECT_TRUE(isSafeToExpandAt(AR, L->getTerminator(), SE)); EXPECT_TRUE(isSafeToExpandAt(AR, Post->getTerminator(), SE)); + + EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT)); + SCEVExpander Exp(SE, M.getDataLayout(), "expander"); + Exp.expandCodeFor(SE.getSCEV(Add), nullptr, Ret); + EXPECT_TRUE(LI->getLoopFor(L)->isLCSSAForm(*DT)); } // Check that SCEV expander does not use the nuw instruction