diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -9184,10 +9184,10 @@ // need to create an fmul recipe (multiplying the first two operands of // the fmuladd together) to use as the vector operand for the fadd // reduction. - VPInstruction *FMulRecipe = - new VPInstruction(Instruction::FMul, {CurrentLink->getOperand(0), - CurrentLink->getOperand(1)}); - FMulRecipe->setFastMathFlags(CurrentLinkI->getFastMathFlags()); + VPInstruction *FMulRecipe = VPInstruction::createFMFOperator( + Instruction::FMul, + {CurrentLink->getOperand(0), CurrentLink->getOperand(1)}, + CurrentLinkI->getFastMathFlags()); LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator()); VecOp = FMulRecipe; } else { diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -850,9 +850,13 @@ unsigned char AllFlags; }; +protected: + /// Set the fast-math flags. + void setFastMathFlags(FastMathFlags FMFNew); + public: template - VPRecipeWithIRFlags(const unsigned char SC, iterator_range Operands) + VPRecipeWithIRFlags(const unsigned char SC, IterT Operands) : VPRecipeBase(SC, Operands) { OpType = OperationType::Other; AllFlags = 0; @@ -873,15 +877,7 @@ OpType = OperationType::GEPOp; GEPFlags.IsInBounds = GEP->isInBounds(); } else if (auto *Op = dyn_cast(&I)) { - OpType = OperationType::FPMathOp; - FastMathFlags FMF = Op->getFastMathFlags(); - FMFs.AllowReassoc = FMF.allowReassoc(); - FMFs.NoNaNs = FMF.noNaNs(); - FMFs.NoInfs = FMF.noInfs(); - FMFs.NoSignedZeros = FMF.noSignedZeros(); - FMFs.AllowReciprocal = FMF.allowReciprocal(); - FMFs.AllowContract = FMF.allowContract(); - FMFs.ApproxFunc = FMF.approxFunc(); + setFastMathFlags(Op->getFastMathFlags()); } } @@ -948,7 +944,6 @@ return GEPFlags.IsInBounds; } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) FastMathFlags getFastMathFlags() const { FastMathFlags Res; Res.setAllowReassoc(FMFs.AllowReassoc); @@ -961,6 +956,7 @@ return Res; } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void printFlags(raw_ostream &O) const; #endif }; @@ -969,7 +965,7 @@ /// While as any Recipe it may generate a sequence of IR instructions when /// executed, these instructions would always form a single-def expression as /// the VPInstruction is also a single def-use vertex. -class VPInstruction : public VPRecipeBase, public VPValue { +class VPInstruction : public VPRecipeWithIRFlags, public VPValue { friend class VPlanSlp; public: @@ -997,7 +993,6 @@ private: typedef unsigned char OpcodeTy; OpcodeTy Opcode; - FastMathFlags FMF; DebugLoc DL; /// An optional name that can be used for the generated IR instruction. @@ -1015,7 +1010,7 @@ public: VPInstruction(unsigned Opcode, ArrayRef Operands, DebugLoc DL, const Twine &Name = "") - : VPRecipeBase(VPDef::VPInstructionSC, Operands), VPValue(this), + : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands), VPValue(this), Opcode(Opcode), DL(DL), Name(Name.str()) {} VPInstruction(unsigned Opcode, std::initializer_list Operands, @@ -1024,6 +1019,11 @@ VP_CLASSOF_IMPL(VPDef::VPInstructionSC) + static VPInstruction *createFMFOperator(unsigned Opcode, + ArrayRef Operands, + FastMathFlags FMF, DebugLoc DL = {}, + const Twine &Name = ""); + VPInstruction *clone() const { SmallVector Operands(operands()); return new VPInstruction(Opcode, Operands, DL, Name); @@ -1075,9 +1075,6 @@ } } - /// Set the fast-math flags. - void setFastMathFlags(FastMathFlags FMFNew); - /// Returns true if the recipe only uses the first lane of operand \p Op. bool onlyFirstLaneUsed(const VPValue *Op) const override { assert(is_contained(operands(), Op) && diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -216,6 +216,21 @@ insertBefore(BB, I); } +VPInstruction *VPInstruction::createFMFOperator(unsigned Opcode, + ArrayRef Operands, + FastMathFlags FMF, DebugLoc DL, + const Twine &Name) { + auto *I = new VPInstruction(Opcode, Operands, DL, Name); + // Make sure the VPInstruction is a floating-point operation. + assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul || + Opcode == Instruction::FNeg || Opcode == Instruction::FSub || + Opcode == Instruction::FDiv || Opcode == Instruction::FRem || + Opcode == Instruction::FCmp) && + "this op can't take fast-math flags"); + I->setFastMathFlags(FMF); + return I; +} + Value *VPInstruction::generateInstruction(VPTransformState &State, unsigned Part) { IRBuilderBase &Builder = State.Builder; @@ -364,7 +379,7 @@ void VPInstruction::execute(VPTransformState &State) { assert(!State.Instance && "VPInstruction executing an Instance"); IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); - State.Builder.setFastMathFlags(FMF); + State.Builder.setFastMathFlags(getFastMathFlags()); for (unsigned Part = 0; Part < State.UF; ++Part) { Value *GeneratedValue = generateInstruction(State, Part); if (!hasResult()) @@ -433,7 +448,7 @@ O << Instruction::getOpcodeName(getOpcode()); } - O << FMF; + printFlags(O); for (const VPValue *Operand : operands()) { O << " "; @@ -447,16 +462,6 @@ } #endif -void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) { - // Make sure the VPInstruction is a floating-point operation. - assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul || - Opcode == Instruction::FNeg || Opcode == Instruction::FSub || - Opcode == Instruction::FDiv || Opcode == Instruction::FRem || - Opcode == Instruction::FCmp) && - "this op can't take fast-math flags"); - FMF = FMFNew; -} - void VPWidenCallRecipe::execute(VPTransformState &State) { assert(State.VF.isVector() && "not widening"); auto &CI = *cast(getUnderlyingInstr()); @@ -573,6 +578,17 @@ } } +void VPRecipeWithIRFlags::setFastMathFlags(FastMathFlags FMFNew) { + OpType = OperationType::FPMathOp; + FMFs.AllowReassoc = FMFNew.allowReassoc(); + FMFs.NoNaNs = FMFNew.noNaNs(); + FMFs.NoInfs = FMFNew.noInfs(); + FMFs.NoSignedZeros = FMFNew.noSignedZeros(); + FMFs.AllowReciprocal = FMFNew.allowReciprocal(); + FMFs.AllowContract = FMFNew.allowContract(); + FMFs.ApproxFunc = FMFNew.approxFunc(); +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const { switch (OpType) {