diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -10134,7 +10134,8 @@ or more :ref:`fast-math-flags <fastmath>`. These are optimization hints to enable otherwise unsafe floating-point optimizations. Fast-math-flags are only valid for phis that return a floating-point scalar or vector -type. +type, or an array (nested to any depth) of floating-point scalar or vector +types. Semantics: """""""""" @@ -10183,7 +10184,8 @@ #. The optional ``fast-math flags`` marker indicates that the select has one or more :ref:`fast-math flags <fastmath>`. These are optimization hints to enable otherwise unsafe floating-point optimizations. Fast-math flags are only valid - for selects that return a floating-point scalar or vector type. + for selects that return a floating-point scalar or vector type, or an array + (nested to any depth) of floating-point scalar or vector types. Semantics: """""""""" @@ -10282,7 +10284,8 @@ #. The optional ``fast-math flags`` marker indicates that the call has one or more :ref:`fast-math flags <fastmath>`, which are optimization hints to enable otherwise unsafe floating-point optimizations. Fast-math flags are only valid - for calls that return a floating-point scalar or vector type. + for calls that return a floating-point scalar or vector type, or an array + (nested to any depth) of floating-point scalar or vector types. #. The optional "cconv" marker indicates which :ref:`calling convention <callingconv>` the call should use. 
If none is diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h --- a/llvm/include/llvm/IR/Operator.h +++ b/llvm/include/llvm/IR/Operator.h @@ -394,8 +394,13 @@ return true; case Instruction::PHI: case Instruction::Select: - case Instruction::Call: - return V->getType()->isFPOrFPVectorTy(); + case Instruction::Call: { + Type *Ty = V->getType(); + // Strip any nesting of array types to reach the innermost element type. + while (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) + Ty = ArrTy->getElementType(); + return Ty->isFPOrFPVectorTy(); + } default: return false; } diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -5799,7 +5799,7 @@ if (Res != 0) return Res; if (FMF.any()) { - if (!Inst->getType()->isFPOrFPVectorTy()) + if (!isa<FPMathOperator>(Inst)) return Error(Loc, "fast-math-flags specified for select without " "floating-point scalar or vector return type"); Inst->setFastMathFlags(FMF); @@ -5816,7 +5816,7 @@ if (Res != 0) return Res; if (FMF.any()) { - if (!Inst->getType()->isFPOrFPVectorTy()) + if (!isa<FPMathOperator>(Inst)) return Error(Loc, "fast-math-flags specified for phi without " "floating-point scalar or vector return type"); Inst->setFastMathFlags(FMF); @@ -6787,10 +6787,6 @@ ParseOptionalOperandBundles(BundleList, PFS)) return true; - if (FMF.any() && !RetType->isFPOrFPVectorTy()) - return Error(CallLoc, "fast-math-flags specified for call without " - "floating-point scalar or vector return type"); - // If RetType is a non-function pointer type, then this is the short syntax // for the call, which means that RetType is just the return type. Infer the // rest of the function argument types from the arguments that are present. 
@@ -6853,8 +6849,15 @@ CallInst *CI = CallInst::Create(Ty, Callee, Args, BundleList); CI->setTailCallKind(TCK); CI->setCallingConv(CC); - if (FMF.any()) + if (FMF.any()) { + if (!isa<FPMathOperator>(CI)) { + // Nothing owns CI yet; free it before bailing out. + CI->deleteValue(); + return Error(CallLoc, "fast-math-flags specified for call without " + "floating-point scalar or vector return type"); + } CI->setFastMathFlags(FMF); + } CI->setAttributes(PAL); ForwardRefAttrGroups[CI] = FwdRefAttrGrps; Inst = CI; diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -4641,10 +4641,12 @@ // There is an optional final record for fast-math-flags if this phi has a // floating-point type. size_t NumArgs = (Record.size() - 1) / 2; - if ((Record.size() - 1) % 2 == 1 && !Ty->isFPOrFPVectorTy()) - return error("Invalid record"); - PHINode *PN = PHINode::Create(Ty, NumArgs); + if ((Record.size() - 1) % 2 == 1 && !isa<FPMathOperator>(PN)) { + // PN is not in InstructionList yet; free it before bailing out. + PN->deleteValue(); + return error("Invalid record"); + } InstructionList.push_back(PN); for (unsigned i = 0; i != NumArgs; i++) { diff --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll --- a/llvm/test/Bitcode/compatibility.ll +++ b/llvm/test/Bitcode/compatibility.ll @@ -861,6 +861,14 @@ ret void } +define void @fastmathflags_array_select(i1 %cond, [2 x double] %op1, [2 x double] %op2) { + %f.nnan.nsz = select nnan nsz i1 %cond, [2 x double] %op1, [2 x double] %op2 + ; CHECK: %f.nnan.nsz = select nnan nsz i1 %cond, [2 x double] %op1, [2 x double] %op2 + %f.fast = select fast i1 %cond, [2 x double] %op1, [2 x double] %op2 + ; CHECK: %f.fast = select fast i1 %cond, [2 x double] %op1, [2 x double] %op2 + ret void +} + define void @fastmathflags_phi(i1 %cond, float %f1, float %f2, double %d1, double %d2, half %h1, half %h2) { entry: br i1 %cond, label %L1, label %L2 @@ -903,24 +911,65 @@ ret void } +define void @fastmathflags_array_phi(i1 %cond, [4 x float] %f1, [4 x float] %f2, [2 x double] %d1, [2 x double] %d2, 
[8 x half] %h1, [8 x half] %h2) { +entry: + br i1 %cond, label %L1, label %L2 +L1: + br label %exit +L2: + br label %exit +exit: + %p.nnan = phi nnan [4 x float] [ %f1, %L1 ], [ %f2, %L2 ] + ; CHECK: %p.nnan = phi nnan [4 x float] [ %f1, %L1 ], [ %f2, %L2 ] + %p.ninf = phi ninf [2 x double] [ %d1, %L1 ], [ %d2, %L2 ] + ; CHECK: %p.ninf = phi ninf [2 x double] [ %d1, %L1 ], [ %d2, %L2 ] + %p.contract = phi contract [8 x half] [ %h1, %L1 ], [ %h2, %L2 ] + ; CHECK: %p.contract = phi contract [8 x half] [ %h1, %L1 ], [ %h2, %L2 ] + %p.nsz.reassoc = phi reassoc nsz [4 x float] [ %f1, %L1 ], [ %f2, %L2 ] + ; CHECK: %p.nsz.reassoc = phi reassoc nsz [4 x float] [ %f1, %L1 ], [ %f2, %L2 ] + %p.fast = phi fast [8 x half] [ %h2, %L1 ], [ %h1, %L2 ] + ; CHECK: %p.fast = phi fast [8 x half] [ %h2, %L1 ], [ %h1, %L2 ] + ret void +} + ; Check various fast math flags and floating-point types on calls. -declare float @fmf1() -declare double @fmf2() -declare <4 x double> @fmf3() +declare float @fmf_f32() +declare double @fmf_f64() +declare <4 x double> @fmf_v4f64() ; CHECK-LABEL: fastMathFlagsForCalls( define void @fastMathFlagsForCalls(float %f, double %d1, <4 x double> %d2) { - %call.fast = call fast float @fmf1() - ; CHECK: %call.fast = call fast float @fmf1() + %call.fast = call fast float @fmf_f32() + ; CHECK: %call.fast = call fast float @fmf_f32() + + ; Throw in some other attributes to make sure those stay in the right places. 
+ + %call.nsz.arcp = notail call nsz arcp double @fmf_f64() + ; CHECK: %call.nsz.arcp = notail call nsz arcp double @fmf_f64() + + %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf_v4f64() + ; CHECK: %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf_v4f64() + + ret void +} + +declare [2 x float] @fmf_a2f32() +declare [2 x double] @fmf_a2f64() +declare [2 x <4 x double>] @fmf_a2v4f64() + +; CHECK-LABEL: fastMathFlagsForArrayCalls( +define void @fastMathFlagsForArrayCalls([2 x float] %f, [2 x double] %d1, [2 x <4 x double>] %d2) { + %call.fast = call fast [2 x float] @fmf_a2f32() + ; CHECK: %call.fast = call fast [2 x float] @fmf_a2f32() ; Throw in some other attributes to make sure those stay in the right places. - %call.nsz.arcp = notail call nsz arcp double @fmf2() - ; CHECK: %call.nsz.arcp = notail call nsz arcp double @fmf2() + %call.nsz.arcp = notail call nsz arcp [2 x double] @fmf_a2f64() + ; CHECK: %call.nsz.arcp = notail call nsz arcp [2 x double] @fmf_a2f64() - %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf3() - ; CHECK: %call.nnan.ninf = tail call nnan ninf fastcc <4 x double> @fmf3() + %call.nnan.ninf = tail call nnan ninf fastcc [2 x <4 x double>] @fmf_a2v4f64() + ; CHECK: %call.nnan.ninf = tail call nnan ninf fastcc [2 x <4 x double>] @fmf_a2v4f64() ret void } diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp --- a/llvm/unittests/IR/InstructionsTest.cpp +++ b/llvm/unittests/IR/InstructionsTest.cpp @@ -1046,6 +1046,61 @@ FP->deleteValue(); } +// Only calls returning FP scalars/vectors (or arrays of them) are FP math ops. +TEST(InstructionsTest, FPCallIsFPMathOperator) { + LLVMContext C; + + Type *ITy = Type::getInt32Ty(C); + FunctionType *IFnTy = FunctionType::get(ITy, {}); + Value *ICallee = Constant::getNullValue(IFnTy->getPointerTo()); + std::unique_ptr<CallInst> ICall(CallInst::Create(IFnTy, ICallee, {}, "")); + EXPECT_FALSE(isa<FPMathOperator>(ICall)); + + Type *VITy = VectorType::get(ITy, 2); + FunctionType *VIFnTy = FunctionType::get(VITy, {}); + Value *VICallee = 
Constant::getNullValue(VIFnTy->getPointerTo()); + std::unique_ptr<CallInst> VICall(CallInst::Create(VIFnTy, VICallee, {}, "")); + EXPECT_FALSE(isa<FPMathOperator>(VICall)); + + Type *AITy = ArrayType::get(ITy, 2); + FunctionType *AIFnTy = FunctionType::get(AITy, {}); + Value *AICallee = Constant::getNullValue(AIFnTy->getPointerTo()); + std::unique_ptr<CallInst> AICall(CallInst::Create(AIFnTy, AICallee, {}, "")); + EXPECT_FALSE(isa<FPMathOperator>(AICall)); + + Type *FTy = Type::getFloatTy(C); + FunctionType *FFnTy = FunctionType::get(FTy, {}); + Value *FCallee = Constant::getNullValue(FFnTy->getPointerTo()); + std::unique_ptr<CallInst> FCall(CallInst::Create(FFnTy, FCallee, {}, "")); + EXPECT_TRUE(isa<FPMathOperator>(FCall)); + + Type *VFTy = VectorType::get(FTy, 2); + FunctionType *VFFnTy = FunctionType::get(VFTy, {}); + Value *VFCallee = Constant::getNullValue(VFFnTy->getPointerTo()); + std::unique_ptr<CallInst> VFCall(CallInst::Create(VFFnTy, VFCallee, {}, "")); + EXPECT_TRUE(isa<FPMathOperator>(VFCall)); + + Type *AFTy = ArrayType::get(FTy, 2); + FunctionType *AFFnTy = FunctionType::get(AFTy, {}); + Value *AFCallee = Constant::getNullValue(AFFnTy->getPointerTo()); + std::unique_ptr<CallInst> AFCall(CallInst::Create(AFFnTy, AFCallee, {}, "")); + EXPECT_TRUE(isa<FPMathOperator>(AFCall)); + + Type *AVFTy = ArrayType::get(VFTy, 2); + FunctionType *AVFFnTy = FunctionType::get(AVFTy, {}); + Value *AVFCallee = Constant::getNullValue(AVFFnTy->getPointerTo()); + std::unique_ptr<CallInst> AVFCall( + CallInst::Create(AVFFnTy, AVFCallee, {}, "")); + EXPECT_TRUE(isa<FPMathOperator>(AVFCall)); + + Type *AAVFTy = ArrayType::get(AVFTy, 2); + FunctionType *AAVFFnTy = FunctionType::get(AAVFTy, {}); + Value *AAVFCallee = Constant::getNullValue(AAVFFnTy->getPointerTo()); + std::unique_ptr<CallInst> AAVFCall( + CallInst::Create(AAVFFnTy, AAVFCallee, {}, "")); + EXPECT_TRUE(isa<FPMathOperator>(AAVFCall)); +} + TEST(InstructionsTest, FNegInstruction) { LLVMContext Context; Type *FltTy = Type::getFloatTy(Context);