diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -10123,7 +10123,8 @@
 or more :ref:`fast-math-flags <fastmath>`. These are optimization hints
 to enable otherwise unsafe floating-point optimizations. Fast-math-flags
 are only valid for phis that return a floating-point scalar or vector
-type.
+type, or an array (nested to any depth) of floating-point scalar or vector
+types.
 
 Semantics:
 """"""""""
@@ -10172,7 +10173,8 @@
 #. The optional ``fast-math flags`` marker indicates that the select has one or more
    :ref:`fast-math flags <fastmath>`. These are optimization hints to enable
    otherwise unsafe floating-point optimizations. Fast-math flags are only valid
-   for selects that return a floating-point scalar or vector type.
+   for selects that return a floating-point scalar or vector type, or an array
+   (nested to any depth) of floating-point scalar or vector types.
 
 Semantics:
 """"""""""
@@ -10271,7 +10273,8 @@
 #. The optional ``fast-math flags`` marker indicates that the call has one or more
    :ref:`fast-math flags <fastmath>`, which are optimization hints to enable
    otherwise unsafe floating-point optimizations. Fast-math flags are only valid
-   for calls that return a floating-point scalar or vector type.
+   for calls that return a floating-point scalar or vector type, or an array
+   (nested to any depth) of floating-point scalar or vector types.
 
 #. The optional "cconv" marker indicates which :ref:`calling
    convention <callingconv>` the call should use.
If none is
diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -394,8 +394,14 @@
       return true;
     case Instruction::PHI:
     case Instruction::Select:
-    case Instruction::Call:
-      return V->getType()->isFPOrFPVectorTy();
+    case Instruction::Call: {
+      // Look through any (possibly nested) array types so that arrays whose
+      // innermost element type is FP or a vector of FP also qualify.
+      Type *Ty = V->getType();
+      while (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty))
+        Ty = ArrTy->getElementType();
+      return Ty->isFPOrFPVectorTy();
+    }
     default:
       return false;
     }
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -5793,7 +5793,7 @@
     if (Res != 0)
       return Res;
     if (FMF.any()) {
-      if (!Inst->getType()->isFPOrFPVectorTy())
+      if (!isa<FPMathOperator>(Inst))
         return Error(Loc, "fast-math-flags specified for select without "
                           "floating-point scalar or vector return type");
       Inst->setFastMathFlags(FMF);
@@ -5810,7 +5810,7 @@
     if (Res != 0)
       return Res;
     if (FMF.any()) {
-      if (!Inst->getType()->isFPOrFPVectorTy())
+      if (!isa<FPMathOperator>(Inst))
         return Error(Loc, "fast-math-flags specified for phi without "
                           "floating-point scalar or vector return type");
       Inst->setFastMathFlags(FMF);
@@ -6781,10 +6781,6 @@
       ParseOptionalOperandBundles(BundleList, PFS))
     return true;
 
-  if (FMF.any() && !RetType->isFPOrFPVectorTy())
-    return Error(CallLoc, "fast-math-flags specified for call without "
-                          "floating-point scalar or vector return type");
-
   // If RetType is a non-function pointer type, then this is the short syntax
   // for the call, which means that RetType is just the return type. Infer the
   // rest of the function argument types from the arguments that are present.
@@ -6847,8 +6843,15 @@
   CallInst *CI = CallInst::Create(Ty, Callee, Args, BundleList);
   CI->setTailCallKind(TCK);
   CI->setCallingConv(CC);
-  if (FMF.any())
+  if (FMF.any()) {
+    if (!isa<FPMathOperator>(CI)) {
+      // CI has not been handed off to Inst yet; free it before bailing out.
+      CI->deleteValue();
+      return Error(CallLoc, "fast-math-flags specified for call without "
+                            "floating-point scalar or vector return type");
+    }
     CI->setFastMathFlags(FMF);
+  }
   CI->setAttributes(PAL);
   ForwardRefAttrGroups[CI] = FwdRefAttrGrps;
   Inst = CI;
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -4641,10 +4641,12 @@
       // There is an optional final record for fast-math-flags if this phi has a
       // floating-point type.
       size_t NumArgs = (Record.size() - 1) / 2;
-      if ((Record.size() - 1) % 2 == 1 && !Ty->isFPOrFPVectorTy())
-        return error("Invalid record");
-
       PHINode *PN = PHINode::Create(Ty, NumArgs);
+      if ((Record.size() - 1) % 2 == 1 && !isa<FPMathOperator>(PN)) {
+        // PN has not been added to InstructionList yet; free it here.
+        PN->deleteValue();
+        return error("Invalid record");
+      }
       InstructionList.push_back(PN);
 
       for (unsigned i = 0; i != NumArgs; i++) {
diff --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll
--- a/llvm/test/Bitcode/compatibility.ll
+++ b/llvm/test/Bitcode/compatibility.ll
@@ -861,6 +861,14 @@
   ret void
 }
 
+define void @fastmathflags_array_select(i1 %cond, [2 x double] %op1, [2 x double] %op2) {
+  %f.nnan.nsz = select nnan nsz i1 %cond, [2 x double] %op1, [2 x double] %op2
+  ; CHECK: %f.nnan.nsz = select nnan nsz i1 %cond, [2 x double] %op1, [2 x double] %op2
+  %f.fast = select fast i1 %cond, [2 x double] %op1, [2 x double] %op2
+  ; CHECK: %f.fast = select fast i1 %cond, [2 x double] %op1, [2 x double] %op2
+  ret void
+}
+
 define void @fastmathflags_phi(i1 %cond, float %f1, float %f2, double %d1, double %d2, half %h1, half %h2) {
 entry:
   br i1 %cond, label %L1, label %L2
@@ -903,6 +911,27 @@
   ret void
 }
 
+define void @fastmathflags_array_phi(i1 %cond, [4 x float] %f1, [4 x float] %f2, [2 x double] %d1, [2 x double] %d2, [8 x half] %h1, [8 x half] %h2) {
+entry:
+  br i1 %cond, label %L1, label %L2
+L1:
+  br label %exit
+L2:
+  br label %exit
+exit:
+  %p.nnan = phi nnan [4 x float] [ %f1, %L1 ], [ %f2, %L2 ]
+  ; CHECK: %p.nnan = phi nnan [4 x float] [ %f1, %L1 ], [ %f2, %L2 ]
+  %p.ninf = phi ninf [2 x double] [ %d1, %L1 ], [ %d2, %L2 ]
+  ; CHECK: %p.ninf = phi ninf [2 x double] [ %d1, %L1 ], [ %d2, %L2 ]
+  %p.contract = phi contract [8 x half] [ %h1, %L1 ], [ %h2, %L2 ]
+  ; CHECK: %p.contract = phi contract [8 x half] [ %h1, %L1 ], [ %h2, %L2 ]
+  %p.nsz.reassoc = phi reassoc nsz [4 x float] [ %f1, %L1 ], [ %f2, %L2 ]
+  ; CHECK: %p.nsz.reassoc = phi reassoc nsz [4 x float] [ %f1, %L1 ], [ %f2, %L2 ]
+  %p.fast = phi fast [8 x half] [ %h2, %L1 ], [ %h1, %L2 ]
+  ; CHECK: %p.fast = phi fast [8 x half] [ %h2, %L1 ], [ %h1, %L2 ]
+  ret void
+}
+
 ; Check various fast math flags and floating-point types on calls.
 
 declare float @fmf1()
diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -1046,6 +1046,60 @@
   FP->deleteValue();
 }
 
+TEST(InstructionsTest, FPCallIsFPMathOperator) {
+  LLVMContext C;
+
+  Type *ITy = Type::getInt32Ty(C);
+  FunctionType *IFnTy = FunctionType::get(ITy, {});
+  Value *ICallee = Constant::getNullValue(IFnTy->getPointerTo());
+  std::unique_ptr<CallInst> ICall(CallInst::Create(IFnTy, ICallee, {}, ""));
+  EXPECT_FALSE(isa<FPMathOperator>(ICall));
+
+  Type *VITy = VectorType::get(ITy, 2);
+  FunctionType *VIFnTy = FunctionType::get(VITy, {});
+  Value *VICallee = Constant::getNullValue(VIFnTy->getPointerTo());
+  std::unique_ptr<CallInst> VICall(CallInst::Create(VIFnTy, VICallee, {}, ""));
+  EXPECT_FALSE(isa<FPMathOperator>(VICall));
+
+  Type *AITy = ArrayType::get(ITy, 2);
+  FunctionType *AIFnTy = FunctionType::get(AITy, {});
+  Value *AICallee = Constant::getNullValue(AIFnTy->getPointerTo());
+  std::unique_ptr<CallInst> AICall(CallInst::Create(AIFnTy, AICallee, {}, ""));
+  EXPECT_FALSE(isa<FPMathOperator>(AICall));
+
+  Type *FTy = Type::getFloatTy(C);
+  FunctionType *FFnTy = FunctionType::get(FTy, {});
+  Value *FCallee = Constant::getNullValue(FFnTy->getPointerTo());
+  std::unique_ptr<CallInst> FCall(CallInst::Create(FFnTy, FCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(FCall));
+
+  Type *VFTy = VectorType::get(FTy, 2);
+  FunctionType *VFFnTy = FunctionType::get(VFTy, {});
+  Value *VFCallee = Constant::getNullValue(VFFnTy->getPointerTo());
+  std::unique_ptr<CallInst> VFCall(CallInst::Create(VFFnTy, VFCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(VFCall));
+
+  Type *AFTy = ArrayType::get(FTy, 2);
+  FunctionType *AFFnTy = FunctionType::get(AFTy, {});
+  Value *AFCallee = Constant::getNullValue(AFFnTy->getPointerTo());
+  std::unique_ptr<CallInst> AFCall(CallInst::Create(AFFnTy, AFCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(AFCall));
+
+  Type *AVFTy = ArrayType::get(VFTy, 2);
+  FunctionType *AVFFnTy = FunctionType::get(AVFTy, {});
+  Value *AVFCallee = Constant::getNullValue(AVFFnTy->getPointerTo());
+  std::unique_ptr<CallInst> AVFCall(
+      CallInst::Create(AVFFnTy, AVFCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(AVFCall));
+
+  Type *AAVFTy = ArrayType::get(AVFTy, 2);
+  FunctionType *AAVFFnTy = FunctionType::get(AAVFTy, {});
+  Value *AAVFCallee = Constant::getNullValue(AAVFFnTy->getPointerTo());
+  std::unique_ptr<CallInst> AAVFCall(
+      CallInst::Create(AAVFFnTy, AAVFCallee, {}, ""));
+  EXPECT_TRUE(isa<FPMathOperator>(AAVFCall));
+}
+
 TEST(InstructionsTest, FNegInstruction) {
   LLVMContext Context;
   Type *FltTy = Type::getFloatTy(Context);