diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -10123,7 +10123,8 @@ or more :ref:`fast-math-flags <fastmath>`. These are optimization hints to enable otherwise unsafe floating-point optimizations. Fast-math-flags are only valid for phis that return a floating-point scalar or vector -type. +type, or an array (nested to any depth) of floating-point scalar or vector +types. Semantics: """""""""" @@ -10172,7 +10173,8 @@ #. The optional ``fast-math flags`` marker indicates that the select has one or more :ref:`fast-math flags <fastmath>`. These are optimization hints to enable otherwise unsafe floating-point optimizations. Fast-math flags are only valid - for selects that return a floating-point scalar or vector type. + for selects that return a floating-point scalar or vector type, or an array + (nested to any depth) of floating-point scalar or vector types. Semantics: """""""""" @@ -10271,7 +10273,8 @@ #. The optional ``fast-math flags`` marker indicates that the call has one or more :ref:`fast-math flags <fastmath>`, which are optimization hints to enable otherwise unsafe floating-point optimizations. Fast-math flags are only valid - for calls that return a floating-point scalar or vector type. + for calls that return a floating-point scalar or vector type, or an array + (nested to any depth) of floating-point scalar or vector types. #. The optional "cconv" marker indicates which :ref:`calling convention <callingconv>` the call should use. 
If none is
diff --git a/llvm/include/llvm/IR/Operator.h b/llvm/include/llvm/IR/Operator.h
--- a/llvm/include/llvm/IR/Operator.h
+++ b/llvm/include/llvm/IR/Operator.h
@@ -394,8 +394,14 @@
       return true;
     case Instruction::PHI:
     case Instruction::Select:
-    case Instruction::Call:
-      return V->getType()->isFPOrFPVectorTy();
+    case Instruction::Call: {
+      // Look through (possibly nested) arrays: an aggregate of FP scalars or
+      // vectors may carry fast-math flags just like its element type.
+      Type *Ty = V->getType();
+      while (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty))
+        Ty = ArrTy->getElementType();
+      return Ty->isFPOrFPVectorTy();
+    }
     default:
       return false;
     }
diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -1046,6 +1046,42 @@
   FP->deleteValue();
 }
 
+TEST(InstructionsTest, FPCallIsFPMathOperator) {
+  LLVMContext C;
+
+  // Creates a detached call to a null callee of type `RetTy ()` and reports
+  // whether it is classified as an FPMathOperator, i.e. whether it may carry
+  // fast-math flags.
+  auto IsFPMathCall = [](Type *RetTy) {
+    FunctionType *FnTy = FunctionType::get(RetTy, {});
+    Value *Callee = Constant::getNullValue(FnTy->getPointerTo());
+    std::unique_ptr<CallInst> Call(CallInst::Create(FnTy, Callee, {}, ""));
+    return isa<FPMathOperator>(Call);
+  };
+
+  // Integer scalars, vectors and arrays of them never carry fast-math flags.
+  Type *ITy = Type::getInt32Ty(C);
+  Type *VITy = VectorType::get(ITy, 2);
+  Type *AITy = ArrayType::get(ITy, 2);
+  Type *AAVITy = ArrayType::get(ArrayType::get(VITy, 2), 2);
+  EXPECT_FALSE(IsFPMathCall(ITy));
+  EXPECT_FALSE(IsFPMathCall(VITy));
+  EXPECT_FALSE(IsFPMathCall(AITy));
+  EXPECT_FALSE(IsFPMathCall(AAVITy));
+
+  // FP scalars and vectors, and arrays of them nested to any depth, do.
+  Type *FTy = Type::getFloatTy(C);
+  Type *VFTy = VectorType::get(FTy, 2);
+  Type *AFTy = ArrayType::get(FTy, 2);
+  Type *AVFTy = ArrayType::get(VFTy, 2);
+  Type *AAVFTy = ArrayType::get(AVFTy, 2);
+  EXPECT_TRUE(IsFPMathCall(FTy));
+  EXPECT_TRUE(IsFPMathCall(VFTy));
+  EXPECT_TRUE(IsFPMathCall(AFTy));
+  EXPECT_TRUE(IsFPMathCall(AVFTy));
+  EXPECT_TRUE(IsFPMathCall(AAVFTy));
+}
+
 TEST(InstructionsTest, FNegInstruction) {
   LLVMContext Context;
   Type *FltTy = Type::getFloatTy(Context);