diff --git a/llvm/include/llvm/Analysis/IVDescriptors.h b/llvm/include/llvm/Analysis/IVDescriptors.h
--- a/llvm/include/llvm/Analysis/IVDescriptors.h
+++ b/llvm/include/llvm/Analysis/IVDescriptors.h
@@ -33,25 +33,29 @@
 
 /// These are the kinds of recurrences that we support.
 enum class RecurKind {
-  None,       ///< Not a recurrence.
-  Add,        ///< Sum of integers.
-  Mul,        ///< Product of integers.
-  Or,         ///< Bitwise or logical OR of integers.
-  And,        ///< Bitwise or logical AND of integers.
-  Xor,        ///< Bitwise or logical XOR of integers.
-  SMin,       ///< Signed integer min implemented in terms of select(cmp()).
-  SMax,       ///< Signed integer max implemented in terms of select(cmp()).
-  UMin,       ///< Unisgned integer min implemented in terms of select(cmp()).
-  UMax,       ///< Unsigned integer max implemented in terms of select(cmp()).
-  FAdd,       ///< Sum of floats.
-  FMul,       ///< Product of floats.
-  FMin,       ///< FP min implemented in terms of select(cmp()).
-  FMax,       ///< FP max implemented in terms of select(cmp()).
-  FMulAdd,    ///< Fused multiply-add of floats (a * b + c).
-  SelectICmp, ///< Integer select(icmp(),x,y) where one of (x,y) is loop
-              ///< invariant
-  SelectFCmp  ///< Integer select(fcmp(),x,y) where one of (x,y) is loop
-              ///< invariant
+  None,         ///< Not a recurrence.
+  Add,          ///< Sum of integers.
+  Mul,          ///< Product of integers.
+  Or,           ///< Bitwise or logical OR of integers.
+  And,          ///< Bitwise or logical AND of integers.
+  Xor,          ///< Bitwise or logical XOR of integers.
+  SMin,         ///< Signed integer min implemented in terms of select(cmp()).
+  SMax,         ///< Signed integer max implemented in terms of select(cmp()).
+  UMin,         ///< Unsigned integer min implemented in terms of select(cmp()).
+  UMax,         ///< Unsigned integer max implemented in terms of select(cmp()).
+  FAdd,         ///< Sum of floats.
+  FMul,         ///< Product of floats.
+  FMin,         ///< FP min implemented in terms of select(cmp()).
+  FMax,         ///< FP max implemented in terms of select(cmp()).
+  FMulAdd,      ///< Fused multiply-add of floats (a * b + c).
+  SelectICmp,   ///< Integer select(icmp(),x,y) where one of (x,y) is loop
+                ///< invariant
+  SelectFCmp,   ///< Integer select(fcmp(),x,y) where one of (x,y) is loop
+                ///< invariant
+  SelectMinIdx, ///< Select of the index with minimal value, needs to be
+                ///< combined with SelectMinIdxMinVal
+  SelectMinIdxMinVal, ///< Unsigned integer min (llvm.umin) whose index is
+                      ///< selected by the paired SelectMinIdx recurrence
 };
 
 /// The RecurrenceDescriptor is used to identify recurrences variables in a
@@ -123,7 +127,7 @@
   /// the returned struct.
   static InstDesc isRecurrenceInstr(Loop *L, PHINode *Phi, Instruction *I,
                                     RecurKind Kind, InstDesc &Prev,
-                                    FastMathFlags FuncFMF);
+                                    FastMathFlags FuncFMF, ScalarEvolution *SE);
 
   /// Returns true if instruction I has multiple uses in Insts
   static bool hasMultipleUsesOf(Instruction *I,
@@ -221,7 +225,8 @@
   /// Returns true if the recurrence kind is an integer min/max kind.
   static bool isIntMinMaxRecurrenceKind(RecurKind Kind) {
     return Kind == RecurKind::UMin || Kind == RecurKind::UMax ||
-           Kind == RecurKind::SMin || Kind == RecurKind::SMax;
+           Kind == RecurKind::SMin || Kind == RecurKind::SMax ||
+           Kind == RecurKind::SelectMinIdx;
   }
 
   /// Returns true if the recurrence kind is a floating-point min/max kind.
diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
--- a/llvm/lib/Analysis/IVDescriptors.cpp
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -375,7 +375,7 @@
     // type-promoted).
     if (Cur != Start) {
       ReduxDesc =
-          isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF);
+          isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE);
       ExactFPMathInst = ExactFPMathInst == nullptr
                             ? ReduxDesc.getExactFPMathInst()
                             : ExactFPMathInst;
@@ -762,10 +762,9 @@
   return InstDesc(true, SI);
 }
 
-RecurrenceDescriptor::InstDesc
-RecurrenceDescriptor::isRecurrenceInstr(Loop *L, PHINode *OrigPhi,
-                                        Instruction *I, RecurKind Kind,
-                                        InstDesc &Prev, FastMathFlags FuncFMF) {
+RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
+    Loop *L, PHINode *OrigPhi, Instruction *I, RecurKind Kind, InstDesc &Prev,
+    FastMathFlags FuncFMF, ScalarEvolution *SE) {
   assert(Prev.getRecKind() == RecurKind::None || Prev.getRecKind() == Kind);
   switch (I->getOpcode()) {
   default:
@@ -828,6 +827,126 @@
   return false;
 }
 
+/// Returns true if \p I has a user inside \p TheLoop other than \p Phi.
+static bool hasInLoopUserOtherThan(Instruction *I, PHINode *Phi,
+                                   Loop *TheLoop) {
+  for (User *U : I->users()) {
+    auto *UI = cast<Instruction>(U);
+    if (UI != Phi && TheLoop->contains(UI))
+      return true;
+  }
+  return false;
+}
+
+/// Check if \p Phi is a phi node selecting a minimum value in the loop, as
+/// required by a minimum-index reduction. \p Phi must have exactly two uses,
+/// one of them an llvm.umin call that has \p Phi as operand, is the value of
+/// \p Phi on the backedge and has no other users inside the loop. Returns the
+/// umin call, or nullptr if the pattern does not match. \p SE is unused.
+static Instruction *isMinIdxReductionMinVal(PHINode *Phi, Loop *TheLoop,
+                                            ScalarEvolution &SE) {
+  auto IsUMin = [Phi, TheLoop](Value *UMinU) -> bool {
+    if (!match(UMinU,
+               m_Intrinsic<Intrinsic::umin>(m_Specific(Phi), m_Value())) &&
+        !match(UMinU, m_Intrinsic<Intrinsic::umin>(m_Value(), m_Specific(Phi))))
+      return false;
+
+    if (Phi->getIncomingValueForBlock(TheLoop->getLoopLatch()) != UMinU)
+      return false;
+    // Other in-loop users are not part of the recognized pattern.
+    return !hasInLoopUserOtherThan(cast<Instruction>(UMinU), Phi, TheLoop);
+  };
+
+  if (!Phi->hasNUses(2))
+    return nullptr;
+
+  Value *U1 = *Phi->user_begin();
+  Value *U2 = *std::next(Phi->user_begin());
+  if (IsUMin(U1))
+    return cast<Instruction>(U1);
+
+  if (IsUMin(U2))
+    return cast<Instruction>(U2);
+  return nullptr;
+}
+
+/// Check if \p IdxPhi is a phi node selecting the index of a minimum value in
+/// the loop. It requires a second phi node that selects a minimum value in the
+/// loop. \p IdxPhi must select the current IV, if the corresponding value of
+/// the current iteration is less than the current minimum value and \p IdxPhi
+/// otherwise. Returns the select, or nullptr if the pattern does not match.
+static Instruction *isMinIdxReductionMinIdx(PHINode *IdxPhi, Loop *TheLoop,
+                                            ScalarEvolution &SE) {
+  if (!IdxPhi->hasNUses(1))
+    return nullptr;
+
+  auto *IdxSel = dyn_cast<SelectInst>(*IdxPhi->user_begin());
+  if (!IdxSel)
+    return nullptr;
+
+  if (!IdxSel->getType()->isIntegerTy())
+    return nullptr;
+
+  // The select must keep IdxPhi when the compare is false, feed IdxPhi on the
+  // backedge and have no other users inside the loop.
+  if (IdxSel->getFalseValue() != IdxPhi ||
+      IdxPhi->getIncomingValueForBlock(TheLoop->getLoopLatch()) != IdxSel ||
+      hasInLoopUserOtherThan(IdxSel, IdxPhi, TheLoop))
+    return nullptr;
+
+  auto *CI = dyn_cast<ICmpInst>(IdxSel->getCondition());
+  if (!CI)
+    return nullptr;
+
+  // One compare operand must be the minimum-value phi; the other one is the
+  // value of the current iteration. Try both operands, as the current value
+  // may itself be a phi.
+  PHINode *MinPhi = nullptr;
+  Value *CurVal = nullptr;
+  Instruction *UMin = nullptr;
+  for (unsigned OpIdx = 0; OpIdx != 2 && !UMin; ++OpIdx) {
+    auto *P = dyn_cast<PHINode>(CI->getOperand(OpIdx));
+    if (!P)
+      continue;
+    if ((UMin = isMinIdxReductionMinVal(P, TheLoop, SE))) {
+      MinPhi = P;
+      CurVal = CI->getOperand(1 - OpIdx);
+    }
+  }
+  if (!UMin)
+    return nullptr;
+
+  // The umin must combine exactly the minimum phi and the current value.
+  if (!match(UMin, m_Intrinsic<Intrinsic::umin>(m_Specific(MinPhi),
+                                                m_Specific(CurVal))) &&
+      !match(UMin, m_Intrinsic<Intrinsic::umin>(m_Specific(CurVal),
+                                                m_Specific(MinPhi))))
+    return nullptr;
+
+  // Only a strict "current < minimum" compare may select the new index.
+  if (!((CI->getPredicate() == CmpInst::ICMP_UGT &&
+         CI->getOperand(0) == MinPhi) ||
+        (CI->getPredicate() == CmpInst::ICMP_ULT &&
+         CI->getOperand(1) == MinPhi)) ||
+      !CI->hasOneUse())
+    return nullptr;
+
+  // The selected index must be an induction phi with step one, possibly
+  // truncated.
+  Value *U = IdxSel->getTrueValue();
+  if (auto *T = dyn_cast<TruncInst>(U))
+    U = T->getOperand(0);
+  InductionDescriptor ID;
+  auto *IVPhi = dyn_cast<PHINode>(U);
+  if (!IVPhi || !InductionDescriptor::isInductionPHI(IVPhi, TheLoop, &SE, ID))
+    return nullptr;
+
+  if (!ID.getStep()->isOne() || ID.getInductionOpcode() != Instruction::Add)
+    return nullptr;
+
+  return IdxSel;
+}
+
 bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
                                           RecurrenceDescriptor &RedDes,
                                           DemandedBits *DB, AssumptionCache *AC,
@@ -923,6 +1042,35 @@
     LLVM_DEBUG(dbgs() << "Found an FMulAdd reduction PHI." << *Phi << "\n");
     return true;
   }
+
+  // The min-index matchers take ScalarEvolution by reference; skip them if
+  // the caller did not provide one.
+  if (SE) {
+    if (Instruction *ExitI = isMinIdxReductionMinIdx(Phi, TheLoop, *SE)) {
+      FastMathFlags FMF;
+      SmallPtrSet<Instruction *, 4> CastInsts;
+      RecurrenceDescriptor RD(
+          Phi->getIncomingValueForBlock(TheLoop->getLoopPreheader()), ExitI,
+          nullptr, RecurKind::SelectMinIdx, FMF, nullptr, Phi->getType(),
+          false, false, CastInsts, -1U);
+      RedDes = RD;
+      LLVM_DEBUG(dbgs() << "Found a min-index reduction PHI." << *Phi << "\n");
+      return true;
+    }
+    if (Instruction *ExitI = isMinIdxReductionMinVal(Phi, TheLoop, *SE)) {
+      FastMathFlags FMF;
+      SmallPtrSet<Instruction *, 4> CastInsts;
+      RecurrenceDescriptor RD(
+          Phi->getIncomingValueForBlock(TheLoop->getLoopPreheader()), ExitI,
+          nullptr, RecurKind::SelectMinIdxMinVal, FMF, nullptr, Phi->getType(),
+          false, false, CastInsts, -1U);
+      RedDes = RD;
+      LLVM_DEBUG(dbgs() << "Found a min-index min-value reduction PHI." << *Phi
+                        << "\n");
+      return true;
+    }
+  }
+
   // Not a reduction of known type.
   return false;
 }
@@ -1107,7 +1255,10 @@
       return ConstantFP::get(Tp, 0.0L);
     return ConstantFP::get(Tp, -0.0L);
   case RecurKind::UMin:
+  case RecurKind::SelectMinIdxMinVal:
     return ConstantInt::get(Tp, -1);
+  case RecurKind::SelectMinIdx:
+    llvm_unreachable("should not be called");
   case RecurKind::UMax:
     return ConstantInt::get(Tp, 0);
   case RecurKind::SMin:
@@ -1155,6 +1306,8 @@
   case RecurKind::UMax:
   case RecurKind::UMin:
   case RecurKind::SelectICmp:
+  case RecurKind::SelectMinIdxMinVal:
+  case RecurKind::SelectMinIdx:
     return Instruction::ICmp;
   case RecurKind::FMax:
   case RecurKind::FMin:
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -884,6 +884,8 @@
   switch (RK) {
   default:
     llvm_unreachable("Unknown min/max recurrence kind");
+  case RecurKind::SelectMinIdx:
+  case RecurKind::SelectMinIdxMinVal:
   case RecurKind::UMin:
     return CmpInst::ICMP_ULT;
   case RecurKind::UMax:
@@ -1050,7 +1052,10 @@
   case RecurKind::UMax:
     return Builder.CreateIntMaxReduce(Src, false);
   case RecurKind::UMin:
+  case RecurKind::SelectMinIdxMinVal:
     return Builder.CreateIntMinReduce(Src, false);
+  case RecurKind::SelectMinIdx:
+    llvm_unreachable("Unhandled opcode");
   case RecurKind::FMax:
     return Builder.CreateFPMaxReduce(Src);
   case RecurKind::FMin:
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -912,7 +912,17 @@
   if (PrimaryInduction && WidestIndTy != PrimaryInduction->getType())
     PrimaryInduction = nullptr;
 
-  return true;
+  // A min-index reduction is only supported as a pair of exactly one index
+  // recurrence and one minimum-value recurrence.
+  unsigned NumMinIdx = 0;
+  unsigned NumMinIdxMinVal = 0;
+  for (auto &KV : Reductions) {
+    RecurKind Kind = KV.second.getRecurrenceKind();
+    NumMinIdx += Kind == RecurKind::SelectMinIdx;
+    NumMinIdxMinVal += Kind == RecurKind::SelectMinIdxMinVal;
+  }
+  return (NumMinIdx == 0 && NumMinIdxMinVal == 0) ||
+         (NumMinIdx == 1 && NumMinIdxMinVal == 1);
 }
 
 bool LoopVectorizationLegality::canVectorizeMemory() {
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -3993,7 +3993,56 @@
 
   // Create the reduction after the loop. Note that inloop reductions create the
   // target reduction in the loop using a Reduction recipe.
-  if (VF.isVector() && !PhiR->isInLoop()) {
+  if (RK == RecurKind::SelectMinIdx) {
+    // Find the reduction phi tracking the minimum value that pairs with this
+    // index reduction.
+    VPReductionPHIRecipe *MinP = nullptr;
+    for (VPRecipeBase &P : PhiR->getParent()
+                               ->getPlan()
+                               ->getVectorLoopRegion()
+                               ->getEntryBasicBlock()
+                               ->phis()) {
+      auto *RedP = dyn_cast<VPReductionPHIRecipe>(&P);
+      if (RedP && RedP->getRecurrenceDescriptor().getRecurrenceKind() ==
+                      RecurKind::SelectMinIdxMinVal) {
+        MinP = RedP;
+        break;
+      }
+    }
+    assert(MinP && "SelectMinIdx requires a SelectMinIdxMinVal reduction");
+
+    // Returns the value of Def for Part as a vector; for a scalar VF the
+    // value is wrapped in a single-element vector.
+    auto GetPartAsVector = [&](VPValue *Def, unsigned Part) -> Value * {
+      Value *V = State.get(Def, Part);
+      if (!VF.isScalar())
+        return V;
+      return Builder.CreateInsertElement(
+          PoisonValue::get(FixedVectorType::get(V->getType(), 1)), V,
+          uint64_t(0));
+    };
+
+    // Concatenate the indices and the minimum values of all unrolled parts.
+    VPValue *MinExitDef = MinP->getBackedgeValue();
+    VPValue *IndexExitDef = PhiR->getBackedgeValue();
+    Value *AllIndices = GetPartAsVector(IndexExitDef, 0);
+    Value *AllMins = GetPartAsVector(MinExitDef, 0);
+    for (unsigned Part = 1; Part < UF; ++Part) {
+      Value *IndexPart = GetPartAsVector(IndexExitDef, Part);
+      Value *MinPart = GetPartAsVector(MinExitDef, Part);
+      AllIndices = concatenateVectors(Builder, {AllIndices, IndexPart});
+      AllMins = concatenateVectors(Builder, {AllMins, MinPart});
+    }
+
+    // Pick the smallest index among the lanes holding the overall minimum;
+    // all other lanes are set to the largest unsigned value.
+    Value *MinOfMins = Builder.CreateIntMinReduce(AllMins, false);
+    Value *Mask = Builder.CreateICmpEQ(
+        AllMins, Builder.CreateVectorSplat(State.VF * UF, MinOfMins));
+    Value *Sel = Builder.CreateSelect(
+        Mask, AllIndices, Constant::getAllOnesValue(AllIndices->getType()));
+    ReducedPartRdx = Builder.CreateIntMinReduce(Sel, false);
+  } else if (VF.isVector() && !PhiR->isInLoop()) {
     ReducedPartRdx =
         createTargetReduction(Builder, TTI, RdxDesc, ReducedPartRdx, OrigPhi);
     // If the reduction can be performed in a smaller type, we need to extend
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -549,6 +549,10 @@
     Value *Cond = InvarCond ? InvarCond : State.get(getOperand(0), Part);
     Value *Op0 = State.get(getOperand(1), Part);
     Value *Op1 = State.get(getOperand(2), Part);
+    // Broadcast a scalar Op1 when Op0 is a vector, so both select operands
+    // have the same type.
+    if (Op0->getType()->isVectorTy() && !Op1->getType()->isVectorTy())
+      Op1 = State.Builder.CreateVectorSplat(State.VF, Op1);
     Value *Sel = State.Builder.CreateSelect(Cond, Op0, Op1);
     State.set(this, Sel, Part);
     State.addMetadata(Sel, &I);
@@ -640,6 +644,10 @@
         Builder.setFastMathFlags(Cmp->getFastMathFlags());
         C = Builder.CreateFCmp(Cmp->getPredicate(), A, B);
       } else {
+        // Broadcast a scalar B when A is a vector, so both compare operands
+        // have the same type.
+        if (A->getType()->isVectorTy() && !B->getType()->isVectorTy())
+          B = Builder.CreateVectorSplat(State.VF, B);
         C = Builder.CreateICmp(Cmp->getPredicate(), A, B);
       }
       State.set(this, C, Part);
diff --git a/llvm/test/Transforms/LoopVectorize/select-min-index.ll b/llvm/test/Transforms/LoopVectorize/select-min-index.ll
--- a/llvm/test/Transforms/LoopVectorize/select-min-index.ll
+++ b/llvm/test/Transforms/LoopVectorize/select-min-index.ll
@@ -5,7 +5,7 @@
 
 define i64 @test_vectorize_select_umin_idx(ptr %src) {
 ; CHECK-LABEL: @test_vectorize_select_umin_idx(
-; CHECK-NOT: vector.body:
+; CHECK: vector.body:
 ;
 entry:
   br label %loop
@@ -28,9 +28,34 @@
   ret i64 %res
 }
 
+define i64 @test_vectorize_select_umin_idx_phi_order_flipped(ptr %src) {
+; CHECK-LABEL: @test_vectorize_select_umin_idx_phi_order_flipped(
+; CHECK: vector.body:
+;
+entry:
+  br label %loop
+
+loop:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
+  %min.val = phi i64 [ 0, %entry ], [ %min.val.next, %loop ]
+  %min.idx = phi i64 [ 0, %entry ], [ %min.idx.next, %loop ]
+  %gep = getelementptr i64, ptr %src, i64 %iv
+  %l = load i64, ptr %gep
+  %cmp = icmp ugt i64 %min.val, %l
+  %min.val.next = tail call i64 @llvm.umin.i64(i64 %min.val, i64 %l)
+  %min.idx.next = select i1 %cmp, i64 %iv, i64 %min.idx
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 0
+  br i1 %exitcond.not, label %exit, label %loop
+
+exit:
+  %res = phi i64 [ %min.idx.next, %loop ]
+  ret i64 %res
+}
+
 define i64 @test_vectorize_select_umin_idx_min_ops_switched(ptr %src) {
 ; CHECK-LABEL: @test_vectorize_select_umin_idx_min_ops_switched(
-; CHECK-NOT: vector.body:
+; CHECK: vector.body:
 ;
 entry:
   br label %loop
@@ -105,7 +130,7 @@
 
 define i32 @test_vectorize_select_umin_idx_with_trunc() {
 ; CHECK-LABEL: @test_vectorize_select_umin_idx_with_trunc(
-; CHECK-NOT: vector.body:
+; CHECK: vector.body:
 ;
 entry:
   br label %loop