Index: include/llvm/IR/FPState.h
===================================================================
--- /dev/null
+++ include/llvm/IR/FPState.h
@@ -0,0 +1,76 @@
+//===- llvm/IR/FPState.h - Floating-point option state ----------*- C++ -*-===//
+//
+// Translates compiler floating-point options into the constrained-FP
+// settings of an IRBuilderBase.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_FPSTATE_H
+#define LLVM_IR_FPSTATE_H
+
+namespace llvm {
+
+class MDNode;
+class IRBuilderBase;
+
+/// Records the floating-point model, exception model and speculation model
+/// that were requested and pushes the resulting constrained-FP configuration
+/// to an IRBuilderBase.
+class FPState {
+public:
+  /// Overall floating-point model. Only FPM_Strict turns on constrained
+  /// intrinsics by itself.
+  enum FPModelKind {
+    FPM_Off,
+    FPM_Precise,
+    FPM_Strict,
+    FPM_Fast
+  };
+
+  /// Exception model. Any value other than FPME_Off turns on constrained
+  /// intrinsics.
+  enum FPModelExceptKind {
+    FPME_Off,
+    FPME_Except,
+    FPME_NoExcept
+  };
+
+  /// Speculation model. Any value other than FPS_Off turns on constrained
+  /// intrinsics.
+  enum FPSpeculationKind {
+    FPS_Off,
+    FPS_Fast,
+    FPS_Strict,
+    FPS_Safe
+  };
+
+  /// Exception behavior attached to constrained intrinsics. The builder emits
+  /// "fpexcept.strict" for CE_Off as well as for CE_Strict.
+  enum ConstrainedExceptKind {CE_Off, CE_Strict, CE_Ignore, CE_MayTrap};
+
+  /// Rounding mode attached to constrained intrinsics. The builder emits
+  /// "round.dynamic" for CR_Off as well as for CR_Dynamic.
+  enum ConstrainedRoundingKind {CR_Off, CR_Dynamic, CR_ToNearest, CR_Downward,
+                                CR_Upward, CR_ToZero};
+
+private:
+  IRBuilderBase &Builder;
+  FPModelKind FPM;
+  FPModelExceptKind FPME;
+  FPSpeculationKind FPS;
+
+public:
+  /// Bind to \p B with every option off. \p B must outlive this object
+  /// because only a reference to it is stored.
+  explicit FPState(IRBuilderBase &B);
+
+  /// Record the three option values and update the bound builder: enable or
+  /// disable constrained FP and, when enabled, set its default rounding and
+  /// exception kinds.
+  void updateBuilder(FPModelKind fpModel,
+                     FPModelExceptKind fpModelExcept,
+                     FPSpeculationKind fpSpeculation);
+};
+} // end namespace llvm
+
+#endif // LLVM_IR_FPSTATE_H
Index: include/llvm/IR/IRBuilder.h
===================================================================
--- include/llvm/IR/IRBuilder.h
+++ include/llvm/IR/IRBuilder.h
@@ -26,6 +26,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/FPState.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/InstrTypes.h"
@@ -87,7 +88,6 @@
 /// Common base class shared among various IRBuilders.
class IRBuilderBase {
   DebugLoc CurDbgLocation;
-
 protected:
   BasicBlock *BB;
   BasicBlock::iterator InsertPt;
@@ -96,12 +96,18 @@
   MDNode *DefaultFPMathTag;
   FastMathFlags FMF;
 
+  bool IsFPConstrained;
+  FPState::ConstrainedExceptKind DefaultConstrainedExcept;
+  FPState::ConstrainedRoundingKind DefaultConstrainedRounding;
+
   ArrayRef<OperandBundleDef> DefaultOperandBundles;
 
 public:
   IRBuilderBase(LLVMContext &context, MDNode *FPMathTag = nullptr,
                 ArrayRef<OperandBundleDef> OpBundles = None)
       : Context(context), DefaultFPMathTag(FPMathTag),
+        IsFPConstrained(false), DefaultConstrainedExcept(FPState::CE_Off),
+        DefaultConstrainedRounding(FPState::CR_Off),
         DefaultOperandBundles(OpBundles) {
     ClearInsertionPoint();
   }
@@ -218,6 +224,55 @@
   /// Set the fast-math flags to be used with generated fp-math operators
   void setFastMathFlags(FastMathFlags NewFMF) { FMF = NewFMF; }
 
+  /// Select constrained intrinsics (true) or plain FP instructions (false).
+  void setIsConstrainedFP(bool IsCon) { IsFPConstrained = IsCon; }
+
+  /// Return to emitting plain floating point instructions.
+  void clearIsConstrainedFP() { setIsConstrainedFP(false); }
+
+  /// Set the exception behavior used when a constrained call passes none.
+  void setDefaultConstrainedExcept(FPState::ConstrainedExceptKind NewExcept) {
+    DefaultConstrainedExcept = NewExcept;
+  }
+
+  /// Set the rounding mode used when a constrained call passes none.
+  void
+  setDefaultConstrainedRounding(FPState::ConstrainedRoundingKind NewRounding) {
+    DefaultConstrainedRounding = NewRounding;
+  }
+
+  /// Get the default exception behavior as metadata (CE_Off yields strict).
+  MDNode *getDefaultConstrainedExcept() {
+    switch (DefaultConstrainedExcept) {
+    case FPState::CE_Off:
+    case FPState::CE_Strict:
+      return MDNode::get(Context, MDString::get(Context, "fpexcept.strict"));
+    case FPState::CE_Ignore:
+      return MDNode::get(Context, MDString::get(Context, "fpexcept.ignore"));
+    case FPState::CE_MayTrap:
+      return MDNode::get(Context, MDString::get(Context, "fpexcept.maytrap"));
+    }
+    llvm_unreachable("Invalid constrained except kind");
+  }
+
+  /// Get the default rounding mode as metadata (CR_Off yields dynamic).
+  MDNode *getDefaultConstrainedRounding() {
+    switch (DefaultConstrainedRounding) {
+    case FPState::CR_Off:
+    case FPState::CR_Dynamic:
+      return MDNode::get(Context, MDString::get(Context, "round.dynamic"));
+    case FPState::CR_ToNearest:
+      return MDNode::get(Context, MDString::get(Context, "round.tonearest"));
+    case FPState::CR_Downward:
+      return MDNode::get(Context, MDString::get(Context, "round.downward"));
+    case FPState::CR_Upward:
+      return MDNode::get(Context, MDString::get(Context, "round.upward"));
+    case FPState::CR_ToZero:
+      return MDNode::get(Context, MDString::get(Context, "round.tozero"));
+    }
+    llvm_unreachable("Invalid constrained rounding kind");
+  }
+
   //===--------------------------------------------------------------------===//
   // RAII helpers.
   //===--------------------------------------------------------------------===//
@@ -1045,6 +1100,28 @@
     return (LC && RC) ? Insert(Folder.CreateBinOp(Opc, LC, RC), Name) : nullptr;
   }
 
+  /// Wrap rounding-mode metadata as a Value operand for a constrained
+  /// intrinsic; a null RoundingMD selects the builder's default.
+  Value *getConstrainedRounding(MDNode *RoundingMD) {
+    if (!RoundingMD)
+      RoundingMD = getDefaultConstrainedRounding();
+
+    // Operand 0 is expected to be the MDString naming the mode.
+    auto *Rounding = cast<MDString>(RoundingMD->getOperand(0));
+    return MetadataAsValue::get(Context, Rounding);
+  }
+
+  /// Wrap exception-behavior metadata as a Value operand for a constrained
+  /// intrinsic; a null ExceptMD selects the builder's default.
+  Value *getConstrainedExcept(MDNode *ExceptMD) {
+    if (!ExceptMD)
+      ExceptMD = getDefaultConstrainedExcept();
+
+    // Operand 0 is expected to be the MDString naming the behavior.
+    auto *Except = cast<MDString>(ExceptMD->getOperand(0));
+    return MetadataAsValue::get(Context, Except);
+  }
+
 public:
   Value *CreateAdd(Value *LHS, Value *RHS, const Twine &Name = "",
                    bool HasNUW = false, bool HasNSW = false) {
@@ -1247,6 +1324,9 @@
 
   Value *CreateFAdd(Value *L, Value *R, const Twine &Name = "",
                     MDNode *FPMD = nullptr) {
+    if (IsFPConstrained) // FPMD is unused on this path.
+      return CreateConstrainedFAdd(L, R, nullptr, Name);
+
     if (Value *V = foldConstant(Instruction::FAdd, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), FPMD, FMF);
     return Insert(I, Name);
@@ -1256,14 +1336,34 @@
   /// default FMF.
Value *CreateFAddFMF(Value *L, Value *R, Instruction *FMFSource,
                        const Twine &Name = "") {
+    if (IsFPConstrained)
+      return CreateConstrainedFAdd(L, R, FMFSource, Name);
+
     if (Value *V = foldConstant(Instruction::FAdd, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), nullptr,
                                 FMFSource->getFastMathFlags());
     return Insert(I, Name);
   }
 
+  /// Build a call to the experimental_constrained_fadd intrinsic on L and R.
+  /// A null RoundingMD or ExceptMD is replaced by the builder's default.
+  CallInst *CreateConstrainedFAdd(Value *L, Value *R,
+                                  Instruction *FMFSource = nullptr,
+                                  const Twine &Name = "",
+                                  MDNode *RoundingMD = nullptr,
+                                  MDNode *ExceptMD = nullptr) {
+    Value *RoundingArg = getConstrainedRounding(RoundingMD);
+    Value *ExceptArg = getConstrainedExcept(ExceptMD);
+    return CreateIntrinsic(Intrinsic::experimental_constrained_fadd,
+                           {L->getType()}, {L, R, RoundingArg, ExceptArg},
+                           FMFSource, Name);
+  }
+
   Value *CreateFSub(Value *L, Value *R, const Twine &Name = "",
                     MDNode *FPMD = nullptr) {
+    if (IsFPConstrained)
+      return CreateConstrainedFSub(L, R, nullptr, Name);
+
     if (Value *V = foldConstant(Instruction::FSub, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), FPMD, FMF);
     return Insert(I, Name);
@@ -1273,14 +1373,35 @@
   /// default FMF.
Value *CreateFSubFMF(Value *L, Value *R, Instruction *FMFSource,
                        const Twine &Name = "") {
+    if (IsFPConstrained)
+      return CreateConstrainedFSub(L, R, FMFSource, Name);
+
     if (Value *V = foldConstant(Instruction::FSub, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), nullptr,
                                 FMFSource->getFastMathFlags());
     return Insert(I, Name);
   }
 
+  /// Build a call to the experimental_constrained_fsub intrinsic on L and R.
+  /// A null RoundingMD or ExceptMD is replaced by the builder's default.
+  CallInst *CreateConstrainedFSub(Value *L, Value *R,
+                                  Instruction *FMFSource = nullptr,
+                                  const Twine &Name = "",
+                                  MDNode *RoundingMD = nullptr,
+                                  MDNode *ExceptMD = nullptr) {
+    Value *RoundingArg = getConstrainedRounding(RoundingMD);
+    Value *ExceptArg = getConstrainedExcept(ExceptMD);
+
+    return CreateIntrinsic(Intrinsic::experimental_constrained_fsub,
+                           {L->getType()}, {L, R, RoundingArg, ExceptArg},
+                           FMFSource, Name);
+  }
+
   Value *CreateFMul(Value *L, Value *R, const Twine &Name = "",
                     MDNode *FPMD = nullptr) {
+    if (IsFPConstrained)
+      return CreateConstrainedFMul(L, R, nullptr, Name);
+
     if (Value *V = foldConstant(Instruction::FMul, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), FPMD, FMF);
     return Insert(I, Name);
@@ -1290,14 +1411,35 @@
   /// default FMF.
Value *CreateFMulFMF(Value *L, Value *R, Instruction *FMFSource,
                        const Twine &Name = "") {
+    if (IsFPConstrained)
+      return CreateConstrainedFMul(L, R, FMFSource, Name);
+
     if (Value *V = foldConstant(Instruction::FMul, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), nullptr,
                                 FMFSource->getFastMathFlags());
     return Insert(I, Name);
   }
 
+  /// Build a call to the experimental_constrained_fmul intrinsic on L and R.
+  /// A null RoundingMD or ExceptMD is replaced by the builder's default.
+  CallInst *CreateConstrainedFMul(Value *L, Value *R,
+                                  Instruction *FMFSource = nullptr,
+                                  const Twine &Name = "",
+                                  MDNode *RoundingMD = nullptr,
+                                  MDNode *ExceptMD = nullptr) {
+    Value *RoundingArg = getConstrainedRounding(RoundingMD);
+    Value *ExceptArg = getConstrainedExcept(ExceptMD);
+
+    return CreateIntrinsic(Intrinsic::experimental_constrained_fmul,
+                           {L->getType()}, {L, R, RoundingArg, ExceptArg},
+                           FMFSource, Name);
+  }
+
   Value *CreateFDiv(Value *L, Value *R, const Twine &Name = "",
                     MDNode *FPMD = nullptr) {
+    if (IsFPConstrained)
+      return CreateConstrainedFDiv(L, R, nullptr, Name);
+
     if (Value *V = foldConstant(Instruction::FDiv, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), FPMD, FMF);
     return Insert(I, Name);
@@ -1307,14 +1449,35 @@
   /// default FMF.
Value *CreateFDivFMF(Value *L, Value *R, Instruction *FMFSource,
                        const Twine &Name = "") {
+    if (IsFPConstrained)
+      return CreateConstrainedFDiv(L, R, FMFSource, Name);
+
     if (Value *V = foldConstant(Instruction::FDiv, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), nullptr,
                                 FMFSource->getFastMathFlags());
     return Insert(I, Name);
   }
 
+  /// Build a call to the experimental_constrained_fdiv intrinsic on L and R.
+  /// A null RoundingMD or ExceptMD is replaced by the builder's default.
+  CallInst *CreateConstrainedFDiv(Value *L, Value *R,
+                                  Instruction *FMFSource = nullptr,
+                                  const Twine &Name = "",
+                                  MDNode *RoundingMD = nullptr,
+                                  MDNode *ExceptMD = nullptr) {
+    Value *RoundingArg = getConstrainedRounding(RoundingMD);
+    Value *ExceptArg = getConstrainedExcept(ExceptMD);
+
+    return CreateIntrinsic(Intrinsic::experimental_constrained_fdiv,
+                           {L->getType()}, {L, R, RoundingArg, ExceptArg},
+                           FMFSource, Name);
+  }
+
   Value *CreateFRem(Value *L, Value *R, const Twine &Name = "",
                     MDNode *FPMD = nullptr) {
+    if (IsFPConstrained)
+      return CreateConstrainedFRem(L, R, nullptr, Name);
+
     if (Value *V = foldConstant(Instruction::FRem, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), FPMD, FMF);
     return Insert(I, Name);
@@ -1324,12 +1487,30 @@
   /// default FMF.
Value *CreateFRemFMF(Value *L, Value *R, Instruction *FMFSource,
                        const Twine &Name = "") {
+    if (IsFPConstrained)
+      return CreateConstrainedFRem(L, R, FMFSource, Name);
+
     if (Value *V = foldConstant(Instruction::FRem, L, R, Name)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), nullptr,
                                 FMFSource->getFastMathFlags());
     return Insert(I, Name);
   }
 
+  /// Build a call to the experimental_constrained_frem intrinsic on L and R.
+  /// A null RoundingMD or ExceptMD is replaced by the builder's default.
+  CallInst *CreateConstrainedFRem(Value *L, Value *R,
+                                  Instruction *FMFSource = nullptr,
+                                  const Twine &Name = "",
+                                  MDNode *RoundingMD = nullptr,
+                                  MDNode *ExceptMD = nullptr) {
+    Value *RoundingArg = getConstrainedRounding(RoundingMD);
+    Value *ExceptArg = getConstrainedExcept(ExceptMD);
+
+    return CreateIntrinsic(Intrinsic::experimental_constrained_frem,
+                           {L->getType()}, {L, R, RoundingArg, ExceptArg},
+                           FMFSource, Name);
+  }
+
   Value *CreateBinOp(Instruction::BinaryOps Opc,
                      Value *LHS, Value *RHS, const Twine &Name = "",
                      MDNode *FPMathTag = nullptr) {
Index: lib/IR/CMakeLists.txt
===================================================================
--- lib/IR/CMakeLists.txt
+++ lib/IR/CMakeLists.txt
@@ -22,6 +22,7 @@
   DiagnosticInfo.cpp
   DiagnosticPrinter.cpp
   Dominators.cpp
+  FPState.cpp
   Function.cpp
   GVMaterializer.cpp
   Globals.cpp
Index: lib/IR/FPState.cpp
===================================================================
--- /dev/null
+++ lib/IR/FPState.cpp
@@ -0,0 +1,87 @@
+//===- FPState.cpp - Floating-point option state --------------------------===//
+//
+// Implements FPState, which maps floating-point options onto the
+// constrained-FP settings of an IRBuilderBase.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/FPState.h"
+#include "llvm/IR/IRBuilder.h"
+
+using namespace llvm;
+
+FPState::FPState(IRBuilderBase &B)
+    : Builder(B), FPM(FPM_Off), FPME(FPME_Off), FPS(FPS_Off) {}
+
+/// Store the three options and reconfigure the bound builder. The options are
+/// applied in the order model, exception model, speculation model, so a later
+/// non-off option overrides the exception kind chosen by an earlier one.
+void FPState::updateBuilder(FPModelKind fpModel,
+                            FPModelExceptKind fpModelExcept,
+                            FPSpeculationKind fpSpeculation) {
+  // Save the new settings in the state variables.
+  FPM = fpModel;
+  FPME = fpModelExcept;
+  FPS = fpSpeculation;
+
+  // Translate the compiler options into the 3 settings that are transmitted
+  // to the IR Builder.
+  bool IsConstrained = false;
+  ConstrainedRoundingKind Rounding = CR_Off;
+  ConstrainedExceptKind Except = CE_Off;
+
+  // Deliberately no default labels: every enumerator is listed, which keeps
+  // -Wswitch able to flag an enumerator that is added later.
+  switch (fpModel) {
+  case FPM_Off:
+  case FPM_Precise:
+  case FPM_Fast:
+    break;
+  case FPM_Strict:
+    IsConstrained = true;
+    Rounding = CR_Dynamic;
+    Except = CE_Strict;
+    break;
+  }
+
+  switch (fpModelExcept) {
+  case FPME_Off:
+    break;
+  case FPME_Except:
+    IsConstrained = true;
+    Except = CE_Strict;
+    break;
+  case FPME_NoExcept:
+    IsConstrained = true;
+    Except = CE_Ignore;
+    break;
+  }
+
+  switch (fpSpeculation) {
+  case FPS_Off:
+    break;
+  case FPS_Fast:
+    IsConstrained = true;
+    Except = CE_Ignore;
+    break;
+  case FPS_Strict:
+    IsConstrained = true;
+    Except = CE_Strict;
+    break;
+  case FPS_Safe:
+    IsConstrained = true;
+    Except = CE_MayTrap;
+    break;
+  }
+
+  // Constrained math with no explicit rounding request rounds to nearest.
+  if (IsConstrained && Rounding == CR_Off)
+    Rounding = CR_ToNearest;
+
+  Builder.setIsConstrainedFP(IsConstrained);
+  // The builder's defaults are left untouched when constrained FP is off.
+  if (IsConstrained) {
+    Builder.setDefaultConstrainedRounding(Rounding);
+    Builder.setDefaultConstrainedExcept(Except);
+  }
+}
Index: unittests/IR/IRBuilderTest.cpp
===================================================================
--- unittests/IR/IRBuilderTest.cpp
+++ unittests/IR/IRBuilderTest.cpp
@@ -10,6 +10,7 @@
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/DIBuilder.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/FPState.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/LLVMContext.h"
@@
-122,6 +123,159 @@
   EXPECT_FALSE(II->hasNoNaNs());
 }
 
+TEST_F(IRBuilderTest, ConstrainedFP) {
+  IRBuilder<> Builder(BB);
+  Value *V;
+  CallInst *Call;
+  IntrinsicInst *II;
+
+  V = Builder.CreateLoad(GV);
+
+  Call = Builder.CreateConstrainedFAdd(V, V);
+  II = cast<IntrinsicInst>(Call);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fadd);
+
+  Call = Builder.CreateConstrainedFSub(V, V);
+  II = cast<IntrinsicInst>(Call);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fsub);
+
+  Call = Builder.CreateConstrainedFMul(V, V);
+  II = cast<IntrinsicInst>(Call);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fmul);
+
+  Call = Builder.CreateConstrainedFDiv(V, V);
+  II = cast<IntrinsicInst>(Call);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fdiv);
+
+  Call = Builder.CreateConstrainedFRem(V, V);
+  II = cast<IntrinsicInst>(Call);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_frem);
+
+  // Now see if we get constrained intrinsics instead of non-constrained
+  // instructions.
+  Builder.setIsConstrainedFP(true);
+
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_TRUE(isa<IntrinsicInst>(V));
+  II = cast<IntrinsicInst>(V);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fadd);
+
+  V = Builder.CreateFSub(V, V);
+  ASSERT_TRUE(isa<IntrinsicInst>(V));
+  II = cast<IntrinsicInst>(V);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fsub);
+
+  V = Builder.CreateFMul(V, V);
+  ASSERT_TRUE(isa<IntrinsicInst>(V));
+  II = cast<IntrinsicInst>(V);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fmul);
+
+  V = Builder.CreateFDiv(V, V);
+  ASSERT_TRUE(isa<IntrinsicInst>(V));
+  II = cast<IntrinsicInst>(V);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fdiv);
+
+  V = Builder.CreateFRem(V, V);
+  ASSERT_TRUE(isa<IntrinsicInst>(V));
+  II = cast<IntrinsicInst>(V);
+  EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_frem);
+
+  // Verify the codepaths for setting and overriding the default metadata.
+  MDNode *ExceptStr = MDNode::get(Builder.getContext(),
+                                  MDString::get(Builder.getContext(),
+                                                "fpexcept.strict"));
+  MDNode *RoundDyn = MDNode::get(Builder.getContext(),
+                                 MDString::get(Builder.getContext(),
+                                               "round.dynamic"));
+
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_TRUE(isa<ConstrainedFPIntrinsic>(V));
+  auto *CII = cast<ConstrainedFPIntrinsic>(V);
+  ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict);
+  ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDynamic);
+
+  Builder.setDefaultConstrainedExcept(FPState::CE_Ignore);
+  Builder.setDefaultConstrainedRounding(FPState::CR_Upward);
+  V = Builder.CreateFAdd(V, V);
+  CII = cast<ConstrainedFPIntrinsic>(V);
+  ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore);
+  ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmUpward);
+
+  // Now override the defaults.
+  Call = Builder.CreateConstrainedFAdd(V, V, nullptr, "", RoundDyn, ExceptStr);
+  CII = cast<ConstrainedFPIntrinsic>(Call);
+  ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict);
+  ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDynamic);
+
+  // Use FPState to update the builder settings
+  FPState fpState(Builder);
+
+  fpState.updateBuilder(FPState::FPM_Off, FPState::FPME_Off, FPState::FPS_Off);
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_FALSE(isa<ConstrainedFPIntrinsic>(V));
+
+  fpState.updateBuilder(FPState::FPM_Precise, FPState::FPME_Off,
+                        FPState::FPS_Off);
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_FALSE(isa<ConstrainedFPIntrinsic>(V));
+
+  fpState.updateBuilder(FPState::FPM_Strict,
+                        FPState::FPME_Off, FPState::FPS_Off);
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_TRUE(isa<ConstrainedFPIntrinsic>(V));
+  CII = cast<ConstrainedFPIntrinsic>(V);
+  ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict);
+  ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDynamic);
+
+  fpState.updateBuilder(FPState::FPM_Off,
+                        FPState::FPME_Except, FPState::FPS_Off);
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_TRUE(isa<ConstrainedFPIntrinsic>(V));
+  CII = cast<ConstrainedFPIntrinsic>(V);
+  ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict);
+  ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmToNearest);
+
+  fpState.updateBuilder(FPState::FPM_Off,
+                        FPState::FPME_NoExcept, FPState::FPS_Off);
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_TRUE(isa<ConstrainedFPIntrinsic>(V));
+  CII = cast<ConstrainedFPIntrinsic>(V);
+  ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore);
+  ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmToNearest);
+
+  fpState.updateBuilder(FPState::FPM_Fast,
+                        FPState::FPME_Off, FPState::FPS_Off);
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_FALSE(isa<ConstrainedFPIntrinsic>(V));
+
+  fpState.updateBuilder(FPState::FPM_Off,
+                        FPState::FPME_Off, FPState::FPS_Fast);
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_TRUE(isa<ConstrainedFPIntrinsic>(V));
+  CII = cast<ConstrainedFPIntrinsic>(V);
+  ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore);
+  ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmToNearest);
+
+  fpState.updateBuilder(FPState::FPM_Off,
+                        FPState::FPME_Off, FPState::FPS_Strict);
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_TRUE(isa<ConstrainedFPIntrinsic>(V));
+  CII = cast<ConstrainedFPIntrinsic>(V);
+  ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict);
+  ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmToNearest);
+
+  fpState.updateBuilder(FPState::FPM_Off,
+                        FPState::FPME_Off, FPState::FPS_Safe);
+  V = Builder.CreateFAdd(V, V);
+  ASSERT_TRUE(isa<ConstrainedFPIntrinsic>(V));
+  CII = cast<ConstrainedFPIntrinsic>(V);
+  ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebMayTrap);
+  ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmToNearest);
+
+  Builder.CreateRetVoid();
+  EXPECT_FALSE(verifyModule(*M));
+}
+
 TEST_F(IRBuilderTest, Lifetime) {
   IRBuilder<> Builder(BB);
   AllocaInst *Var1 = Builder.CreateAlloca(Builder.getInt8Ty());