diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -45,13 +45,17 @@
   VPBasicBlock *BB = nullptr;
   VPBasicBlock::iterator InsertPt = VPBasicBlock::iterator();
 
+  /// Insert \p VPI in BB at InsertPt if BB is set, and return \p VPI.
+  VPInstruction *insertInstruction(VPInstruction *VPI) {
+    if (BB)
+      BB->insert(VPI, InsertPt);
+    return VPI;
+  }
+
   VPInstruction *createInstruction(unsigned Opcode,
                                    ArrayRef<VPValue *> Operands, DebugLoc DL,
                                    const Twine &Name = "") {
-    VPInstruction *Instr = new VPInstruction(Opcode, Operands, DL, Name);
-    if (BB)
-      BB->insert(Instr, InsertPt);
-    return Instr;
+    return insertInstruction(new VPInstruction(Opcode, Operands, DL, Name));
   }
 
   VPInstruction *createInstruction(unsigned Opcode,
@@ -152,6 +156,15 @@
                         Name);
   }
 
+  /// Create a new ICmp VPInstruction with predicate \p Pred and operands \p A
+  /// and \p B.
+  /// TODO: add createFCmp when needed.
+  VPValue *createICmp(CmpInst::Predicate Pred, VPValue *A, VPValue *B,
+                      DebugLoc DL = {}, const Twine &Name = "") {
+    return insertInstruction(
+        new VPInstruction(Instruction::ICmp, Pred, A, B, DL, Name));
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8045,7 +8045,7 @@
                                      nullptr, "active.lane.mask");
   } else {
     VPValue *BTC = Plan.getOrCreateBackedgeTakenCount();
-    BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC});
+    BlockMask = Builder.createICmp(CmpInst::ICMP_ULE, IV, BTC);
   }
   BlockMaskCache[Header] = BlockMask;
 }
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -814,6 +814,7 @@
 /// Class to record LLVM IR flag for a recipe along with it.
 class VPRecipeWithIRFlags : public VPRecipeBase {
   enum class OperationType : unsigned char {
+    Cmp,
     OverflowingBinOp,
     PossiblyExactOp,
     GEPOp,
@@ -851,11 +852,12 @@
   OperationType OpType;
 
   union {
+    CmpInst::Predicate CmpPredicate;
     WrapFlagsTy WrapFlags;
     ExactFlagsTy ExactFlags;
     GEPFlagsTy GEPFlags;
     FastMathFlagsTy FMFs;
-    unsigned char AllFlags;
+    unsigned AllFlags;
   };
 
 public:
@@ -869,7 +871,10 @@
   template <typename IterT>
   VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, Instruction &I)
       : VPRecipeWithIRFlags(SC, Operands) {
-    if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
+    if (auto *Op = dyn_cast<CmpInst>(&I)) {
+      OpType = OperationType::Cmp;
+      CmpPredicate = Op->getPredicate();
+    } else if (auto *Op = dyn_cast<OverflowingBinaryOperator>(&I)) {
       OpType = OperationType::OverflowingBinOp;
       WrapFlags = {Op->hasNoUnsignedWrap(), Op->hasNoSignedWrap()};
     } else if (auto *Op = dyn_cast<PossiblyExactOperator>(&I)) {
@@ -896,6 +901,13 @@
       : VPRecipeBase(SC, Operands), OpType(OperationType::FPMathOp),
         FMFs(FMFs) {}
 
+  /// Create a recipe for a compare that records the predicate \p Pred.
+  template <typename IterT>
+  VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
+                      CmpInst::Predicate Pred)
+      : VPRecipeBase(SC, Operands), OpType(OperationType::Cmp),
+        CmpPredicate(Pred) {}
+
   static inline bool classof(const VPRecipeBase *R) {
     return R->getVPDefID() == VPRecipeBase::VPInstructionSC ||
            R->getVPDefID() == VPRecipeBase::VPWidenSC ||
@@ -922,6 +934,7 @@
       FMFs.NoNaNs = false;
       FMFs.NoInfs = false;
       break;
+    case OperationType::Cmp:
     case OperationType::Other:
       break;
     }
@@ -949,6 +962,7 @@
       I->setHasAllowContract(FMFs.AllowContract);
       I->setHasApproxFunc(FMFs.ApproxFunc);
       break;
+    case OperationType::Cmp:
     case OperationType::Other:
       break;
     }
@@ -977,6 +991,13 @@
     return WrapFlags.HasNSW;
   }
 
+  /// Return the predicate of a compare recipe; only valid for Cmp recipes.
+  CmpInst::Predicate getPredicate() const {
+    assert(OpType == OperationType::Cmp &&
+           "recipe doesn't have a compare predicate");
+    return CmpPredicate;
+  }
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   void printFlags(raw_ostream &O) const;
 #endif
@@ -996,7 +1017,6 @@
         Instruction::OtherOpsEnd + 1, // Combines the incoming and previous
                                       // values of a first-order recurrence.
     Not,
-    ICmpULE,
     SLPLoad,
     SLPStore,
     ActiveLaneMask,
@@ -1050,6 +1070,11 @@
   VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands,
                 FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = "");
 
+  /// Create a compare VPInstruction with predicate \p Pred and operands \p A
+  /// and \p B. Only Instruction::ICmp is supported at the moment.
+  VPInstruction(unsigned Opcode, CmpInst::Predicate Pred, VPValue *A,
+                VPValue *B, DebugLoc DL = {}, const Twine &Name = "");
+
   VP_CLASSOF_IMPL(VPDef::VPInstructionSC)
 
   unsigned getOpcode() const { return Opcode; }
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -116,8 +116,8 @@
     return false;
   case VPInstructionSC:
     switch (cast<VPInstruction>(this)->getOpcode()) {
+    case Instruction::ICmp:
     case VPInstruction::Not:
-    case VPInstruction::ICmpULE:
     case VPInstruction::CalculateTripCountMinusVF:
     case VPInstruction::CanonicalIVIncrement:
     case VPInstruction::CanonicalIVIncrementForPart:
@@ -255,6 +255,18 @@
   assert(isFPMathOp() && "this op can't take fast-math flags");
 }
 
+VPInstruction::VPInstruction(unsigned Opcode, CmpInst::Predicate Pred,
+                             VPValue *A, VPValue *B, DebugLoc DL,
+                             const Twine &Name)
+    : VPRecipeWithIRFlags(VPDef::VPInstructionSC, ArrayRef<VPValue *>({A, B}),
+                          Pred),
+      VPValue(this), Opcode(Opcode), DL(DL), Name(Name.str()) {
+  assert(Opcode == Instruction::ICmp &&
+         "only ICmp predicates supported at the moment");
+  assert(CmpInst::isIntPredicate(Pred) &&
+         "ICmp VPInstruction requires an integer predicate");
+}
+
 Value *VPInstruction::generateInstruction(VPTransformState &State,
                                           unsigned Part) {
   IRBuilderBase &Builder = State.Builder;
@@ -271,10 +283,10 @@
     Value *A = State.get(getOperand(0), Part);
     return Builder.CreateNot(A, Name);
   }
-  case VPInstruction::ICmpULE: {
+  case Instruction::ICmp: {
     Value *A = State.get(getOperand(0), Part);
     Value *B = State.get(getOperand(1), Part);
-    return Builder.CreateICmpULE(A, B, Name);
+    return Builder.CreateCmp(getPredicate(), A, B, Name);
   }
   case Instruction::Select: {
     Value *Cond = State.get(getOperand(0), Part);
@@ -444,9 +456,6 @@
   case VPInstruction::Not:
     O << "not";
     break;
-  case VPInstruction::ICmpULE:
-    O << "icmp ule";
-    break;
   case VPInstruction::SLPLoad:
     O << "combined load";
     break;
@@ -618,6 +627,9 @@
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const {
   switch (OpType) {
+  case OperationType::Cmp:
+    O << " " << CmpInst::getPredicateName(getPredicate());
+    break;
   case OperationType::PossiblyExactOp:
     if (ExactFlags.IsExact)
       O << " exact";
@@ -741,8 +753,6 @@
   const Instruction *UI = getUnderlyingInstr();
   O << " = " << UI->getOpcodeName();
   printFlags(O);
-  if (auto *Cmp = dyn_cast<CmpInst>(UI))
-    O << Cmp->getPredicate() << " ";
   printOperands(O, SlotTracker);
 }
 #endif
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll b/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
--- a/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-sink-scalars-and-merge.ll
@@ -906,7 +906,7 @@
 ; CHECK-NEXT:   pred.store.if:
 ; CHECK-NEXT:     REPLICATE ir<%l1> = load ir<%src>
 ; CHECK-NEXT:     REPLICATE ir<%l2> = trunc ir<%l1>
-; CHECK-NEXT:     REPLICATE ir<%cmp> = icmp ir<%l1>, ir<0>
+; CHECK-NEXT:     REPLICATE ir<%cmp> = icmp eq ir<%l1>, ir<0>
 ; CHECK-NEXT:     REPLICATE ir<%sel> = select ir<%cmp>, ir<5>, ir<%l2>
 ; CHECK-NEXT:     REPLICATE store ir<%sel>, ir<%dst>
 ; CHECK-NEXT:   Successor(s): pred.store.continue