diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -9170,10 +9170,10 @@ // need to create an fmul recipe (multiplying the first two operands of // the fmuladd together) to use as the vector operand for the fadd // reduction. - VPInstruction *FMulRecipe = - new VPInstruction(Instruction::FMul, {CurrentLink->getOperand(0), - CurrentLink->getOperand(1)}); - FMulRecipe->setFastMathFlags(CurrentLinkI->getFastMathFlags()); + VPInstruction *FMulRecipe = new VPInstruction( + Instruction::FMul, + {CurrentLink->getOperand(0), CurrentLink->getOperand(1)}, + CurrentLinkI->getFastMathFlags()); LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator()); VecOp = FMulRecipe; } else { diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -838,6 +838,8 @@ char AllowReciprocal : 1; char AllowContract : 1; char ApproxFunc : 1; + + FastMathFlagsTy(const FastMathFlags &FMF); }; OperationType OpType; @@ -874,17 +876,16 @@ GEPFlags.IsInBounds = GEP->isInBounds(); } else if (auto *Op = dyn_cast(&I)) { OpType = OperationType::FPMathOp; - FastMathFlags FMF = Op->getFastMathFlags(); - FMFs.AllowReassoc = FMF.allowReassoc(); - FMFs.NoNaNs = FMF.noNaNs(); - FMFs.NoInfs = FMF.noInfs(); - FMFs.NoSignedZeros = FMF.noSignedZeros(); - FMFs.AllowReciprocal = FMF.allowReciprocal(); - FMFs.AllowContract = FMF.allowContract(); - FMFs.ApproxFunc = FMF.approxFunc(); + FMFs = Op->getFastMathFlags(); } } + template + VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, + FastMathFlags FMFs) + : VPRecipeBase(SC, Operands), OpType(OperationType::FPMathOp), + FMFs(FMFs) {} + static inline bool classof(const VPRecipeBase *R) { return R->getVPDefID() == VPRecipeBase::VPWidenSC || R->getVPDefID() == VPRecipeBase::VPWidenGEPSC || @@ -959,7 +960,7 @@ /// While as any Recipe it may generate a sequence of IR instructions when /// executed, these instructions would always form a single-def expression as /// the VPInstruction is also a single def-use vertex. -class VPInstruction : public VPRecipeBase, public VPValue { +class VPInstruction : public VPRecipeWithIRFlags, public VPValue { friend class VPlanSlp; public: @@ -987,7 +988,6 @@ private: typedef unsigned char OpcodeTy; OpcodeTy Opcode; - FastMathFlags FMF; DebugLoc DL; /// An optional name that can be used for the generated IR instruction. @@ -999,19 +999,26 @@ /// one. Value *generateInstruction(VPTransformState &State, unsigned Part); + /// Return true if the VPInstruction is a floating point math operation, i.e. + /// has fast-math flags. + bool isFPMathOp() const; + protected: void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); } public: VPInstruction(unsigned Opcode, ArrayRef Operands, DebugLoc DL, const Twine &Name = "") - : VPRecipeBase(VPDef::VPInstructionSC, Operands), VPValue(this), + : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands), VPValue(this), Opcode(Opcode), DL(DL), Name(Name.str()) {} VPInstruction(unsigned Opcode, std::initializer_list Operands, DebugLoc DL = {}, const Twine &Name = "") : VPInstruction(Opcode, ArrayRef(Operands), DL, Name) {} + VPInstruction(unsigned Opcode, std::initializer_list Operands, + FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = ""); + VP_CLASSOF_IMPL(VPDef::VPInstructionSC) VPInstruction *clone() const { @@ -1065,9 +1072,6 @@ } } - /// Set the fast-math flags. - void setFastMathFlags(FastMathFlags FMFNew); - /// Returns true if the recipe only uses the first lane of operand \p Op. bool onlyFirstLaneUsed(const VPValue *Op) const override { assert(is_contained(operands(), Op) && diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -230,6 +230,15 @@ return Res; } +VPInstruction::VPInstruction(unsigned Opcode, + std::initializer_list Operands, + FastMathFlags FMFs, DebugLoc DL, const Twine &Name) + : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, FMFs), + VPValue(this), Opcode(Opcode), DL(DL), Name(Name.str()) { + // Make sure the VPInstruction is a floating-point operation. + assert(isFPMathOp() && "this op can't take fast-math flags"); +} + Value *VPInstruction::generateInstruction(VPTransformState &State, unsigned Part) { IRBuilderBase &Builder = State.Builder; @@ -375,10 +384,18 @@ } } +bool VPInstruction::isFPMathOp() const { + return Opcode == Instruction::FAdd || Opcode == Instruction::FMul || + Opcode == Instruction::FNeg || Opcode == Instruction::FSub || + Opcode == Instruction::FDiv || Opcode == Instruction::FRem || + Opcode == Instruction::FCmp; +} + void VPInstruction::execute(VPTransformState &State) { assert(!State.Instance && "VPInstruction executing an Instance"); IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); - State.Builder.setFastMathFlags(FMF); + if (isFPMathOp()) + State.Builder.setFastMathFlags(getFastMathFlags()); for (unsigned Part = 0; Part < State.UF; ++Part) { Value *GeneratedValue = generateInstruction(State, Part); if (!hasResult()) @@ -447,7 +464,7 @@ O << Instruction::getOpcodeName(getOpcode()); } - O << FMF; + printFlags(O); for (const VPValue *Operand : operands()) { O << " "; @@ -461,16 +478,6 @@ } #endif -void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) { - // Make sure the VPInstruction is a floating-point operation. - assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul || - Opcode == Instruction::FNeg || Opcode == Instruction::FSub || - Opcode == Instruction::FDiv || Opcode == Instruction::FRem || - Opcode == Instruction::FCmp) && - "this op can't take fast-math flags"); - FMF = FMFNew; -} - void VPWidenCallRecipe::execute(VPTransformState &State) { assert(State.VF.isVector() && "not widening"); auto &CI = *cast(getUnderlyingInstr()); @@ -587,6 +594,17 @@ } } +VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy( + const FastMathFlags &FMF) { + AllowReassoc = FMF.allowReassoc(); + NoNaNs = FMF.noNaNs(); + NoInfs = FMF.noInfs(); + NoSignedZeros = FMF.noSignedZeros(); + AllowReciprocal = FMF.allowReciprocal(); + AllowContract = FMF.allowContract(); + ApproxFunc = FMF.approxFunc(); +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const { switch (OpType) {