[VPlan] Make VPReductionRecipe a VPValue and a VPUser.

VPReductionRecipe now inherits from VPValue and VPUser instead of keeping the
original instruction and its VPValue operands in separate members. Operands are
stored in VPUser order {ChainOp, VecOp, [CondOp]}; the optional condition is
appended only when it is non-null, and all three are read back through
getChainOp()/getVecOp()/getCondOp(). A VPValue::VPReductionSC ID is added so
the recipe can be identified via classof(const VPValue *), and
VPRecipeBase::toVPValue() learns about the new subclass. The recipe is now
printed as: "REDUCE <self> = <chain> + reduce.<opcode> (<vec>[, <cond>]).
Both base classes declare a VPReductionSC enumerator, so the constructor
qualifies them (VPRecipeBase::VPReductionSC / VPValue::VPReductionSC) to avoid
ambiguity.

NOTE(review): LoopVectorize.cpp calls getUnderlyingInstr(), which this patch
does not declare; it is assumed to be provided by an existing base class --
confirm before landing.

Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8059,9 +8059,9 @@
   assert(!State.Instance && "Reduction being replicated.");
   for (unsigned Part = 0; Part < State.UF; ++Part) {
     RecurrenceDescriptor::RecurrenceKind Kind = RdxDesc->getRecurrenceKind();
-    Value *NewVecOp = State.get(VecOp, Part);
-    if (CondOp) {
-      Value *NewCond = State.get(CondOp, Part);
+    Value *NewVecOp = State.get(getVecOp(), Part);
+    if (getCondOp()) {
+      Value *NewCond = State.get(getCondOp(), Part);
       VectorType *VecTy = cast<VectorType>(NewVecOp->getType());
       Constant *Iden = RecurrenceDescriptor::getRecurrenceIdentity(
           Kind, RdxDesc->getMinMaxRecurrenceKind(), VecTy->getElementType());
@@ -8072,7 +8072,7 @@
     }
     Value *NewRed = createTargetReduction(State.Builder, *RdxDesc, NewVecOp,
                                           UseReductionIntrinsic, NoNaN);
-    Value *PrevInChain = State.get(ChainOp, Part);
+    Value *PrevInChain = State.get(getChainOp(), Part);
     Value *NextInChain;
     if (Kind == RecurrenceDescriptor::RK_IntegerMinMax ||
         Kind == RecurrenceDescriptor::RK_FloatMinMax) {
@@ -8081,9 +8081,10 @@
                          NewRed, PrevInChain);
     } else {
       NextInChain = State.Builder.CreateBinOp(
-          (Instruction::BinaryOps)I->getOpcode(), NewRed, PrevInChain);
+          (Instruction::BinaryOps)getUnderlyingInstr()->getOpcode(), NewRed,
+          PrevInChain);
     }
-    State.ValueMap.setVectorValue(I, Part, NextInChain);
+    State.ValueMap.setVectorValue(getUnderlyingInstr(), Part, NextInChain);
   }
 }
 
Index: llvm/lib/Transforms/Vectorize/VPlan.h
===================================================================
--- llvm/lib/Transforms/Vectorize/VPlan.h
+++ llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1063,17 +1063,10 @@
 
 /// A recipe to represent inloop reduction operations, performing a reduction on
 /// a vector operand into a scalar value, and adding the result to a chain.
-class VPReductionRecipe : public VPRecipeBase {
+/// The Operands are {ChainOp, VecOp, [Condition]}.
+class VPReductionRecipe : public VPRecipeBase, public VPValue, public VPUser {
   /// The recurrence decriptor for the reduction in question.
   RecurrenceDescriptor *RdxDesc;
-  /// The original instruction being converted to a reduction.
-  Instruction *I;
-  /// The VPValue of the vector value to be reduced.
-  VPValue *VecOp;
-  /// The VPValue of the scalar Chain being accumulated.
-  VPValue *ChainOp;
-  /// The VPValue of the condition for the block.
-  VPValue *CondOp;
   /// Fast math flags to use for the resulting reduction operation.
   bool NoNaN;
   /// Flag for whether to use reduction intrinsics vs shuffle expansions.
@@ -1083,13 +1076,19 @@
   VPReductionRecipe(RecurrenceDescriptor *R, Instruction *I, VPValue *ChainOp,
                     VPValue *VecOp, VPValue *CondOp, bool NoNaN,
                     bool UseReductionIntrinsic)
-      : VPRecipeBase(VPReductionSC), RdxDesc(R), I(I), VecOp(VecOp),
-        ChainOp(ChainOp), CondOp(CondOp), NoNaN(NoNaN),
-        UseReductionIntrinsic(UseReductionIntrinsic) {}
+      : VPRecipeBase(VPRecipeBase::VPReductionSC),
+        VPValue(VPValue::VPReductionSC, I), VPUser({ChainOp, VecOp}),
+        RdxDesc(R), NoNaN(NoNaN), UseReductionIntrinsic(UseReductionIntrinsic) {
+    if (CondOp)
+      addOperand(CondOp);
+  }
 
   ~VPReductionRecipe() override = default;
 
   /// Method to support type inquiry through isa, cast, and dyn_cast.
+  static inline bool classof(const VPValue *V) {
+    return V->getVPValueID() == VPValue::VPReductionSC;
+  }
   static inline bool classof(const VPRecipeBase *V) {
     return V->getVPRecipeID() == VPRecipeBase::VPReductionSC;
   }
@@ -1100,6 +1099,15 @@
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
              VPSlotTracker &SlotTracker) const override;
+
+  /// The VPValue of the scalar Chain being accumulated.
+  VPValue *getChainOp() const { return getOperand(0); }
+  /// The VPValue of the vector value to be reduced.
+  VPValue *getVecOp() const { return getOperand(1); }
+  /// The VPValue of the condition for the block.
+  VPValue *getCondOp() const {
+    return getNumOperands() > 2 ? getOperand(2) : nullptr;
+  }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
Index: llvm/lib/Transforms/Vectorize/VPlan.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
@@ -104,6 +105,8 @@
 VPValue *VPRecipeBase::toVPValue() {
   if (auto *V = dyn_cast<VPInstruction>(this))
     return V;
+  if (auto *V = dyn_cast<VPReductionRecipe>(this))
+    return V;
   if (auto *V = dyn_cast<VPWidenMemoryInstructionRecipe>(this))
     return V;
   return nullptr;
@@ -112,6 +115,8 @@
 const VPValue *VPRecipeBase::toVPValue() const {
   if (auto *V = dyn_cast<VPInstruction>(this))
     return V;
+  if (auto *V = dyn_cast<VPReductionRecipe>(this))
+    return V;
   if (auto *V = dyn_cast<VPWidenMemoryInstructionRecipe>(this))
     return V;
   return nullptr;
@@ -894,13 +899,16 @@
 
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
-  O << "\"REDUCE of" << *I << " as ";
-  ChainOp->printAsOperand(O, SlotTracker);
-  O << " + reduce(";
-  VecOp->printAsOperand(O, SlotTracker);
-  if (CondOp) {
+  O << "\"REDUCE ";
+  printAsOperand(O, SlotTracker);
+  O << " = ";
+  getChainOp()->printAsOperand(O, SlotTracker);
+  O << " + reduce." << Instruction::getOpcodeName(RdxDesc->getRecurrenceBinOp())
+    << " (";
+  getVecOp()->printAsOperand(O, SlotTracker);
+  if (getCondOp()) {
     O << ", ";
-    CondOp->printAsOperand(O, SlotTracker);
+    getCondOp()->printAsOperand(O, SlotTracker);
   }
   O << ")";
 }
Index: llvm/lib/Transforms/Vectorize/VPlanValue.h
===================================================================
--- llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -78,7 +78,7 @@
   /// are actually instantiated. Values of this enumeration are kept in the
   /// SubclassID field of the VPValue objects. They are used for concrete
   /// type identification.
-  enum { VPValueSC, VPInstructionSC, VPMemoryInstructionSC };
+  enum { VPValueSC, VPInstructionSC, VPMemoryInstructionSC, VPReductionSC };
 
   VPValue(Value *UV = nullptr) : VPValue(VPValueSC, UV) {}
   VPValue(const VPValue &) = delete;