diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -9167,10 +9167,10 @@ // need to create an fmul recipe (multiplying the first two operands of // the fmuladd together) to use as the vector operand for the fadd // reduction. - VPInstruction *FMulRecipe = - new VPInstruction(Instruction::FMul, {CurrentLink->getOperand(0), - CurrentLink->getOperand(1)}); - FMulRecipe->setFastMathFlags(CurrentLinkI->getFastMathFlags()); + VPInstruction *FMulRecipe = new VPInstruction( + Instruction::FMul, + {CurrentLink->getOperand(0), CurrentLink->getOperand(1)}, + CurrentLinkI->getFastMathFlags()); LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator()); VecOp = FMulRecipe; } else { diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -844,6 +844,8 @@ char AllowReciprocal : 1; char AllowContract : 1; char ApproxFunc : 1; + + FastMathFlagsTy(const FastMathFlags &FMF); }; OperationType OpType; @@ -878,14 +880,7 @@ GEPFlags.IsInBounds = GEP->isInBounds(); } else if (auto *Op = dyn_cast<FPMathOperator>(&I)) { OpType = OperationType::FPMathOp; - FastMathFlags FMF = Op->getFastMathFlags(); - FMFs.AllowReassoc = FMF.allowReassoc(); - FMFs.NoNaNs = FMF.noNaNs(); - FMFs.NoInfs = FMF.noInfs(); - FMFs.NoSignedZeros = FMF.noSignedZeros(); - FMFs.AllowReciprocal = FMF.allowReciprocal(); - FMFs.AllowContract = FMF.allowContract(); - FMFs.ApproxFunc = FMF.approxFunc(); + FMFs = Op->getFastMathFlags(); } } @@ -895,6 +890,12 @@ : VPRecipeBase(SC, Operands), OpType(OperationType::OverflowingBinOp), WrapFlags(WrapFlags) {} + template <typename IterT> + VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, + FastMathFlags FMFs) + : VPRecipeBase(SC, Operands), OpType(OperationType::FPMathOp), + FMFs(FMFs) {} + static inline bool 
classof(const VPRecipeBase *R) { return R->getVPDefID() == VPRecipeBase::VPInstructionSC || R->getVPDefID() == VPRecipeBase::VPWidenSC || @@ -959,6 +960,9 @@ return GEPFlags.IsInBounds; } + /// Returns true if the recipe has fast-math flags. + bool hasFastMathFlags() const { return OpType == OperationType::FPMathOp; } + FastMathFlags getFastMathFlags() const; bool hasNoUnsignedWrap() const { @@ -1008,7 +1012,6 @@ private: typedef unsigned char OpcodeTy; OpcodeTy Opcode; - FastMathFlags FMF; DebugLoc DL; /// An optional name that can be used for the generated IR instruction. @@ -1020,6 +1023,10 @@ /// one. Value *generateInstruction(VPTransformState &State, unsigned Part); + /// Return true if the VPInstruction is a floating point math operation, i.e. + /// has fast-math flags. + bool isFPMathOp() const; + protected: void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); } @@ -1038,6 +1045,9 @@ : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, WrapFlags), VPValue(this), Opcode(Opcode), DL(DL), Name(Name.str()) {} + VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands, + FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = ""); + VP_CLASSOF_IMPL(VPDef::VPInstructionSC) VPInstruction *clone() const { @@ -1091,9 +1101,6 @@ } } - /// Set the fast-math flags. - void setFastMathFlags(FastMathFlags FMFNew); - /// Returns true if the recipe only uses the first lane of operand \p Op. 
bool onlyFirstLaneUsed(const VPValue *Op) const override { assert(is_contained(operands(), Op) && diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -230,6 +230,16 @@ return Res; } +VPInstruction::VPInstruction(unsigned Opcode, + std::initializer_list<VPValue *> Operands, + + FastMathFlags FMFs, DebugLoc DL, const Twine &Name) + : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, FMFs), + VPValue(this), Opcode(Opcode), DL(DL), Name(Name.str()) { + // Make sure the VPInstruction is a floating-point operation. + assert(isFPMathOp() && "this op can't take fast-math flags"); +} + Value *VPInstruction::generateInstruction(VPTransformState &State, unsigned Part) { IRBuilderBase &Builder = State.Builder; @@ -373,10 +383,22 @@ } } +bool VPInstruction::isFPMathOp() const { + // Inspired by FPMathOperator::classof. Notable differences are that we don't + // support Call, PHI and Select opcodes here yet. 
+ return Opcode == Instruction::FAdd || Opcode == Instruction::FMul || + Opcode == Instruction::FNeg || Opcode == Instruction::FSub || + Opcode == Instruction::FDiv || Opcode == Instruction::FRem || + Opcode == Instruction::FCmp; +} + void VPInstruction::execute(VPTransformState &State) { assert(!State.Instance && "VPInstruction executing an Instance"); IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); - State.Builder.setFastMathFlags(FMF); + if (hasFastMathFlags()) { + assert(isFPMathOp() && "Recipe not a FPMathOp but has fast-math flags?"); + State.Builder.setFastMathFlags(getFastMathFlags()); + } for (unsigned Part = 0; Part < State.UF; ++Part) { Value *GeneratedValue = generateInstruction(State, Part); if (!hasResult()) @@ -439,7 +461,6 @@ O << Instruction::getOpcodeName(getOpcode()); } - O << FMF; printFlags(O); printOperands(O, SlotTracker); @@ -450,16 +471,6 @@ } #endif -void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) { - // Make sure the VPInstruction is a floating-point operation. - assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul || - Opcode == Instruction::FNeg || Opcode == Instruction::FSub || - Opcode == Instruction::FDiv || Opcode == Instruction::FRem || - Opcode == Instruction::FCmp) && - "this op can't take fast-math flags"); - FMF = FMFNew; -} - void VPWidenCallRecipe::execute(VPTransformState &State) { assert(State.VF.isVector() && "not widening"); auto &CI = *cast<CallInst>(getUnderlyingInstr()); @@ -576,6 +587,17 @@ } } +VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy( + const FastMathFlags &FMF) { + AllowReassoc = FMF.allowReassoc(); + NoNaNs = FMF.noNaNs(); + NoInfs = FMF.noInfs(); + NoSignedZeros = FMF.noSignedZeros(); + AllowReciprocal = FMF.allowReciprocal(); + AllowContract = FMF.allowContract(); + ApproxFunc = FMF.approxFunc(); +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPRecipeWithIRFlags::printFlags(raw_ostream &O) const { switch (OpType) {