diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -403,6 +403,21 @@
   /// Map from instructions to their produced column matrix.
   MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
 
+private:
+  /// Return the fast-math flags to use when lowering \p Inst. Flags are only
+  /// copied from \p Inst if it is an FP operation; the allow-contract flag is
+  /// additionally set whenever AllowContractEnabled is on.
+  static FastMathFlags getFastMathFlags(Instruction *Inst) {
+    FastMathFlags FMF;
+
+    if (isa<FPMathOperator>(*Inst))
+      FMF = Inst->getFastMathFlags();
+
+    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());
+
+    return FMF;
+  }
+
 public:
   LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                         AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
@@ -1149,9 +1164,10 @@
   /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
   /// operand is transposed.
+  /// \p FMF is applied to every emitted FP instruction; its allow-contract
+  /// flag is what createMulAdd receives as AllowContraction.
   void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
-                          const MatrixTy &B, bool AllowContraction,
-                          IRBuilder<> &Builder, bool IsTiled,
-                          bool IsScalarMatrixTransposed) {
+                          const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
+                          bool IsScalarMatrixTransposed, FastMathFlags FMF) {
     const unsigned VF = std::max<unsigned>(
         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                 .getFixedSize() /
@@ -1166,6 +1182,14 @@
            Result.isColumnMajor() == A.isColumnMajor() &&
            "operands must agree on matrix layout");
     unsigned NumComputeOps = 0;
+
+    // Apply FMF to the FP instructions emitted below and restore the caller's
+    // builder flags on return. Callers obtain FMF from getFastMathFlags(),
+    // which already folds AllowContractEnabled into the contract flag.
+    IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+    Builder.setFastMathFlags(FMF);
+    const bool AllowContraction = FMF.allowContract();
+
     if (A.isColumnMajor()) {
       // Multiply columns from the first operand with scalars from the second
       // operand. Then move along the K axes and accumulate the columns. With
@@ -1189,8 +1213,7 @@
                 IsScalarMatrixTransposed ? J : K);
             Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
             Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
-                               Result.getElementType()->isFloatingPointTy(),
-                               Builder, AllowContraction, NumComputeOps);
+                               IsFP, Builder, AllowContraction, NumComputeOps);
           }
           Result.setVector(J,
                            insertVector(Result.getVector(J), I, Sum, Builder));
@@ -1354,8 +1377,7 @@
   }
 
   void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
-                        Value *RPtr, ShapeInfo RShape, StoreInst *Store,
-                        bool AllowContract) {
+                        Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
     auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
 
     // Create the main tiling loop nest.
@@ -1391,7 +1413,8 @@
                             {TileSize, TileSize}, EltType, Builder);
     MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
                             {TileSize, TileSize}, EltType, Builder);
-    emitMatrixMultiply(TileResult, A, B, AllowContract, Builder, true, false);
+    emitMatrixMultiply(TileResult, A, B, Builder, true, false,
+                       getFastMathFlags(MatMul));
     // Store result after the inner loop is done.
     Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
     storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
@@ -1430,11 +1453,8 @@
     Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
     Value *CPtr = Store->getPointerOperand();
 
-    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
-                                                  MatMul->hasAllowContract());
     if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
-      createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store,
-                       AllowContract);
+      createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store);
     else {
       IRBuilder<> Builder(Store);
       for (unsigned J = 0; J < C; J += TileSize)
@@ -1453,7 +1473,8 @@
                 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                            RShape, Builder.getInt64(K), Builder.getInt64(J),
                            {TileM, TileC}, EltType, Builder);
-            emitMatrixMultiply(Res, A, B, AllowContract, Builder, true, false);
+            emitMatrixMultiply(Res, A, B, Builder, true, false,
+                               getFastMathFlags(MatMul));
           }
           storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                       Builder.getInt64(I), Builder.getInt64(J), EltType,
@@ -1520,10 +1541,8 @@
       // Initialize the output
       MatrixTy Result(R, C, EltType);
 
-      bool AllowContract =
-          AllowContractEnabled ||
-          (isa<FPMathOperator>(MatMul) && MatMul->hasAllowContract());
-      emitMatrixMultiply(Result, MA, MB, AllowContract, Builder, false, true);
+      emitMatrixMultiply(Result, MA, MB, Builder, false, true,
+                         getFastMathFlags(MatMul));
 
       FusedInsts.insert(MatMul);
       FusedInsts.insert(cast<Instruction>(Transpose));
@@ -1578,9 +1597,8 @@
     assert(Lhs.getElementType() == Result.getElementType() &&
            "Matrix multiply result element type does not match arguments.");
 
-    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
-                                                  MatMul->hasAllowContract());
-    emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false, false);
+    emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
+                       getFastMathFlags(MatMul));
     finalizeLowering(MatMul, Result, Builder);
   }
 
@@ -1675,11 +1693,11 @@
       case Instruction::Sub:
         return Builder.CreateSub(LHS, RHS);
       case Instruction::FAdd:
-        return Builder.CreateFAdd(LHS, RHS);
+        return Builder.CreateFAddFMF(LHS, RHS, Inst);
       case Instruction::FMul:
-        return Builder.CreateFMul(LHS, RHS);
+        return Builder.CreateFMulFMF(LHS, RHS, Inst);
       case Instruction::FSub:
-        return Builder.CreateFSub(LHS, RHS);
+        return Builder.CreateFSubFMF(LHS, RHS, Inst);
       default:
         llvm_unreachable("Unsupported binary operator for matrix");
       }
@@ -1713,7 +1731,7 @@
     auto BuildVectorOp = [&Builder, Inst](Value *Op) {
       switch (Inst->getOpcode()) {
       case Instruction::FNeg:
-        return Builder.CreateFNeg(Op);
+        return Builder.CreateFNegFMF(Op, Inst);
       default:
         llvm_unreachable("Unsupported unary operator for matrix");
       }
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-double-contraction-fmf.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-double-contraction-fmf.ll --- a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-double-contraction-fmf.ll +++ 
b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-double-contraction-fmf.ll @@ -14,48 +14,48 @@ ; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x double> poison, double [[TMP0]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP1:%.*]] = fmul <1 x double> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul contract <1 x double> [[BLOCK]], [[SPLAT_SPLAT]] ; CHECK-NEXT: [[BLOCK4:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT5:%.*]] = insertelement <1 x double> poison, double [[TMP2]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT6:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT5]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP3:%.*]] = call <1 x double> @llvm.fmuladd.v1f64(<1 x double> [[BLOCK4]], <1 x double> [[SPLAT_SPLAT6]], <1 x double> [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = call contract <1 x double> @llvm.fmuladd.v1f64(<1 x double> [[BLOCK4]], <1 x double> [[SPLAT_SPLAT6]], <1 x double> [[TMP1]]) ; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <1 x double> [[TMP3]], <1 x double> poison, <2 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP4]], <2 x i32> ; CHECK-NEXT: [[BLOCK7:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> poison, <1 x i32> ; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT8:%.*]] = insertelement <1 x double> poison, double [[TMP6]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT9:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT8]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP7:%.*]] = fmul <1 x double> [[BLOCK7]], [[SPLAT_SPLAT9]] +; CHECK-NEXT: 
[[TMP7:%.*]] = fmul contract <1 x double> [[BLOCK7]], [[SPLAT_SPLAT9]] ; CHECK-NEXT: [[BLOCK10:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> poison, <1 x i32> ; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT11:%.*]] = insertelement <1 x double> poison, double [[TMP8]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT12:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT11]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP9:%.*]] = call <1 x double> @llvm.fmuladd.v1f64(<1 x double> [[BLOCK10]], <1 x double> [[SPLAT_SPLAT12]], <1 x double> [[TMP7]]) +; CHECK-NEXT: [[TMP9:%.*]] = call contract <1 x double> @llvm.fmuladd.v1f64(<1 x double> [[BLOCK10]], <1 x double> [[SPLAT_SPLAT12]], <1 x double> [[TMP7]]) ; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x double> [[TMP9]], <1 x double> poison, <2 x i32> ; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <2 x double> [[TMP5]], <2 x double> [[TMP10]], <2 x i32> ; CHECK-NEXT: [[BLOCK13:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT14:%.*]] = insertelement <1 x double> poison, double [[TMP12]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT15:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT14]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP13:%.*]] = fmul <1 x double> [[BLOCK13]], [[SPLAT_SPLAT15]] +; CHECK-NEXT: [[TMP13:%.*]] = fmul contract <1 x double> [[BLOCK13]], [[SPLAT_SPLAT15]] ; CHECK-NEXT: [[BLOCK16:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT17:%.*]] = insertelement <1 x double> poison, double [[TMP14]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT18:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT17]], <1 x double> 
poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP15:%.*]] = call <1 x double> @llvm.fmuladd.v1f64(<1 x double> [[BLOCK16]], <1 x double> [[SPLAT_SPLAT18]], <1 x double> [[TMP13]]) +; CHECK-NEXT: [[TMP15:%.*]] = call contract <1 x double> @llvm.fmuladd.v1f64(<1 x double> [[BLOCK16]], <1 x double> [[SPLAT_SPLAT18]], <1 x double> [[TMP13]]) ; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <1 x double> [[TMP15]], <1 x double> poison, <2 x i32> ; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP16]], <2 x i32> ; CHECK-NEXT: [[BLOCK19:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> poison, <1 x i32> ; CHECK-NEXT: [[TMP18:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT20:%.*]] = insertelement <1 x double> poison, double [[TMP18]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT21:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT20]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP19:%.*]] = fmul <1 x double> [[BLOCK19]], [[SPLAT_SPLAT21]] +; CHECK-NEXT: [[TMP19:%.*]] = fmul contract <1 x double> [[BLOCK19]], [[SPLAT_SPLAT21]] ; CHECK-NEXT: [[BLOCK22:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> poison, <1 x i32> ; CHECK-NEXT: [[TMP20:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT23:%.*]] = insertelement <1 x double> poison, double [[TMP20]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT24:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT23]], <1 x double> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP21:%.*]] = call <1 x double> @llvm.fmuladd.v1f64(<1 x double> [[BLOCK22]], <1 x double> [[SPLAT_SPLAT24]], <1 x double> [[TMP19]]) +; CHECK-NEXT: [[TMP21:%.*]] = call contract <1 x double> @llvm.fmuladd.v1f64(<1 x double> [[BLOCK22]], <1 x double> [[SPLAT_SPLAT24]], <1 x double> [[TMP19]]) ; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <1 x double> [[TMP21]], <1 x double> poison, <2 x i32> ; CHECK-NEXT: [[TMP23:%.*]] = shufflevector 
<2 x double> [[TMP17]], <2 x double> [[TMP22]], <2 x i32> ; CHECK-NEXT: [[TMP24:%.*]] = shufflevector <2 x double> [[TMP11]], <2 x double> [[TMP23]], <4 x i32> diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-float-contraction-fmf.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-float-contraction-fmf.ll --- a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-float-contraction-fmf.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-float-contraction-fmf.ll @@ -14,48 +14,48 @@ ; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x float> [[SPLIT2]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x float> poison, float [[TMP0]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP1:%.*]] = fmul <1 x float> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul contract <1 x float> [[BLOCK]], [[SPLAT_SPLAT]] ; CHECK-NEXT: [[BLOCK4:%.*]] = shufflevector <2 x float> [[SPLIT1]], <2 x float> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[SPLIT2]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT5:%.*]] = insertelement <1 x float> poison, float [[TMP2]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT6:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT5]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP3:%.*]] = call <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK4]], <1 x float> [[SPLAT_SPLAT6]], <1 x float> [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = call contract <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK4]], <1 x float> [[SPLAT_SPLAT6]], <1 x float> [[TMP1]]) ; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <1 x float> [[TMP3]], <1 x float> poison, <2 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <2 x float> undef, <2 x float> [[TMP4]], <2 x i32> ; CHECK-NEXT: [[BLOCK7:%.*]] = shufflevector <2 x float> [[SPLIT]], <2 x float> poison, <1 x i32> ; CHECK-NEXT: 
[[TMP6:%.*]] = extractelement <2 x float> [[SPLIT2]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT8:%.*]] = insertelement <1 x float> poison, float [[TMP6]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT9:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT8]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP7:%.*]] = fmul <1 x float> [[BLOCK7]], [[SPLAT_SPLAT9]] +; CHECK-NEXT: [[TMP7:%.*]] = fmul contract <1 x float> [[BLOCK7]], [[SPLAT_SPLAT9]] ; CHECK-NEXT: [[BLOCK10:%.*]] = shufflevector <2 x float> [[SPLIT1]], <2 x float> poison, <1 x i32> ; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x float> [[SPLIT2]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT11:%.*]] = insertelement <1 x float> poison, float [[TMP8]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT12:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT11]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP9:%.*]] = call <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK10]], <1 x float> [[SPLAT_SPLAT12]], <1 x float> [[TMP7]]) +; CHECK-NEXT: [[TMP9:%.*]] = call contract <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK10]], <1 x float> [[SPLAT_SPLAT12]], <1 x float> [[TMP7]]) ; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x float> [[TMP9]], <1 x float> poison, <2 x i32> ; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <2 x float> [[TMP5]], <2 x float> [[TMP10]], <2 x i32> ; CHECK-NEXT: [[BLOCK13:%.*]] = shufflevector <2 x float> [[SPLIT]], <2 x float> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x float> [[SPLIT3]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT14:%.*]] = insertelement <1 x float> poison, float [[TMP12]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT15:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT14]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP13:%.*]] = fmul <1 x float> [[BLOCK13]], [[SPLAT_SPLAT15]] +; CHECK-NEXT: [[TMP13:%.*]] = fmul contract <1 x float> [[BLOCK13]], [[SPLAT_SPLAT15]] ; CHECK-NEXT: [[BLOCK16:%.*]] = shufflevector <2 x 
float> [[SPLIT1]], <2 x float> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x float> [[SPLIT3]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT17:%.*]] = insertelement <1 x float> poison, float [[TMP14]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT18:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT17]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP15:%.*]] = call <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK16]], <1 x float> [[SPLAT_SPLAT18]], <1 x float> [[TMP13]]) +; CHECK-NEXT: [[TMP15:%.*]] = call contract <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK16]], <1 x float> [[SPLAT_SPLAT18]], <1 x float> [[TMP13]]) ; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <1 x float> [[TMP15]], <1 x float> poison, <2 x i32> ; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x float> undef, <2 x float> [[TMP16]], <2 x i32> ; CHECK-NEXT: [[BLOCK19:%.*]] = shufflevector <2 x float> [[SPLIT]], <2 x float> poison, <1 x i32> ; CHECK-NEXT: [[TMP18:%.*]] = extractelement <2 x float> [[SPLIT3]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT20:%.*]] = insertelement <1 x float> poison, float [[TMP18]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT21:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT20]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP19:%.*]] = fmul <1 x float> [[BLOCK19]], [[SPLAT_SPLAT21]] +; CHECK-NEXT: [[TMP19:%.*]] = fmul contract <1 x float> [[BLOCK19]], [[SPLAT_SPLAT21]] ; CHECK-NEXT: [[BLOCK22:%.*]] = shufflevector <2 x float> [[SPLIT1]], <2 x float> poison, <1 x i32> ; CHECK-NEXT: [[TMP20:%.*]] = extractelement <2 x float> [[SPLIT3]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT23:%.*]] = insertelement <1 x float> poison, float [[TMP20]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT24:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT23]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP21:%.*]] = call <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK22]], <1 x float> [[SPLAT_SPLAT24]], <1 x float> 
[[TMP19]]) +; CHECK-NEXT: [[TMP21:%.*]] = call contract <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK22]], <1 x float> [[SPLAT_SPLAT24]], <1 x float> [[TMP19]]) ; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <1 x float> [[TMP21]], <1 x float> poison, <2 x i32> ; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <2 x float> [[TMP17]], <2 x float> [[TMP22]], <2 x i32> ; CHECK-NEXT: [[TMP24:%.*]] = shufflevector <2 x float> [[TMP11]], <2 x float> [[TMP23]], <4 x i32> diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-float-contraction.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-float-contraction.ll --- a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-float-contraction.ll +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-float-contraction.ll @@ -14,48 +14,48 @@ ; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x float> [[SPLIT2]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x float> poison, float [[TMP0]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP1:%.*]] = fmul <1 x float> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul contract <1 x float> [[BLOCK]], [[SPLAT_SPLAT]] ; CHECK-NEXT: [[BLOCK4:%.*]] = shufflevector <2 x float> [[SPLIT1]], <2 x float> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x float> [[SPLIT2]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT5:%.*]] = insertelement <1 x float> poison, float [[TMP2]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT6:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT5]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP3:%.*]] = call <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK4]], <1 x float> [[SPLAT_SPLAT6]], <1 x float> [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = call contract <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK4]], <1 x float> [[SPLAT_SPLAT6]], <1 x float> [[TMP1]]) ; CHECK-NEXT: 
[[TMP4:%.*]] = shufflevector <1 x float> [[TMP3]], <1 x float> poison, <2 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <2 x float> undef, <2 x float> [[TMP4]], <2 x i32> ; CHECK-NEXT: [[BLOCK7:%.*]] = shufflevector <2 x float> [[SPLIT]], <2 x float> poison, <1 x i32> ; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x float> [[SPLIT2]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT8:%.*]] = insertelement <1 x float> poison, float [[TMP6]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT9:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT8]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP7:%.*]] = fmul <1 x float> [[BLOCK7]], [[SPLAT_SPLAT9]] +; CHECK-NEXT: [[TMP7:%.*]] = fmul contract <1 x float> [[BLOCK7]], [[SPLAT_SPLAT9]] ; CHECK-NEXT: [[BLOCK10:%.*]] = shufflevector <2 x float> [[SPLIT1]], <2 x float> poison, <1 x i32> ; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x float> [[SPLIT2]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT11:%.*]] = insertelement <1 x float> poison, float [[TMP8]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT12:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT11]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP9:%.*]] = call <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK10]], <1 x float> [[SPLAT_SPLAT12]], <1 x float> [[TMP7]]) +; CHECK-NEXT: [[TMP9:%.*]] = call contract <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK10]], <1 x float> [[SPLAT_SPLAT12]], <1 x float> [[TMP7]]) ; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x float> [[TMP9]], <1 x float> poison, <2 x i32> ; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <2 x float> [[TMP5]], <2 x float> [[TMP10]], <2 x i32> ; CHECK-NEXT: [[BLOCK13:%.*]] = shufflevector <2 x float> [[SPLIT]], <2 x float> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x float> [[SPLIT3]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT14:%.*]] = insertelement <1 x float> poison, float [[TMP12]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT15:%.*]] = shufflevector <1 x float> 
[[SPLAT_SPLATINSERT14]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP13:%.*]] = fmul <1 x float> [[BLOCK13]], [[SPLAT_SPLAT15]] +; CHECK-NEXT: [[TMP13:%.*]] = fmul contract <1 x float> [[BLOCK13]], [[SPLAT_SPLAT15]] ; CHECK-NEXT: [[BLOCK16:%.*]] = shufflevector <2 x float> [[SPLIT1]], <2 x float> poison, <1 x i32> zeroinitializer ; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x float> [[SPLIT3]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT17:%.*]] = insertelement <1 x float> poison, float [[TMP14]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT18:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT17]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP15:%.*]] = call <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK16]], <1 x float> [[SPLAT_SPLAT18]], <1 x float> [[TMP13]]) +; CHECK-NEXT: [[TMP15:%.*]] = call contract <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK16]], <1 x float> [[SPLAT_SPLAT18]], <1 x float> [[TMP13]]) ; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <1 x float> [[TMP15]], <1 x float> poison, <2 x i32> ; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <2 x float> undef, <2 x float> [[TMP16]], <2 x i32> ; CHECK-NEXT: [[BLOCK19:%.*]] = shufflevector <2 x float> [[SPLIT]], <2 x float> poison, <1 x i32> ; CHECK-NEXT: [[TMP18:%.*]] = extractelement <2 x float> [[SPLIT3]], i64 0 ; CHECK-NEXT: [[SPLAT_SPLATINSERT20:%.*]] = insertelement <1 x float> poison, float [[TMP18]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT21:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT20]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP19:%.*]] = fmul <1 x float> [[BLOCK19]], [[SPLAT_SPLAT21]] +; CHECK-NEXT: [[TMP19:%.*]] = fmul contract <1 x float> [[BLOCK19]], [[SPLAT_SPLAT21]] ; CHECK-NEXT: [[BLOCK22:%.*]] = shufflevector <2 x float> [[SPLIT1]], <2 x float> poison, <1 x i32> ; CHECK-NEXT: [[TMP20:%.*]] = extractelement <2 x float> [[SPLIT3]], i64 1 ; CHECK-NEXT: [[SPLAT_SPLATINSERT23:%.*]] = insertelement <1 x float> poison, float 
[[TMP20]], i32 0 ; CHECK-NEXT: [[SPLAT_SPLAT24:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT23]], <1 x float> poison, <1 x i32> zeroinitializer -; CHECK-NEXT: [[TMP21:%.*]] = call <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK22]], <1 x float> [[SPLAT_SPLAT24]], <1 x float> [[TMP19]]) +; CHECK-NEXT: [[TMP21:%.*]] = call contract <1 x float> @llvm.fmuladd.v1f32(<1 x float> [[BLOCK22]], <1 x float> [[SPLAT_SPLAT24]], <1 x float> [[TMP19]]) ; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <1 x float> [[TMP21]], <1 x float> poison, <2 x i32> ; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <2 x float> [[TMP17]], <2 x float> [[TMP22]], <2 x i32> ; CHECK-NEXT: [[TMP24:%.*]] = shufflevector <2 x float> [[TMP11]], <2 x float> [[TMP23]], <4 x i32> diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/preserve-existing-fast-math-flags.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/preserve-existing-fast-math-flags.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/preserve-existing-fast-math-flags.ll @@ -0,0 +1,44 @@ +; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s + +; Fast-math flags on matrix.multiply calls and element-wise FP ops must carry over to the lowered vector instructions; without -matrix-allow-contract, llvm.fmuladd is only formed when the call has 'contract' (fast or contract reassoc), while plain reassoc yields fmul + fadd. +define <4 x float> @preserve_fmf(<4 x float> %m, <4 x float> %n, float %x, float %y) { +; CHECK-LABEL: @preserve_fmf( +; CHECK: fmul fast <1 x float> +; CHECK: call fast <1 x float> @llvm.fmuladd.v1f32( +; CHECK: fmul fast <1 x float> +; CHECK: call fast <1 x float> @llvm.fmuladd.v1f32( +; CHECK: fmul fast <1 x float> +; CHECK: call fast <1 x float> @llvm.fmuladd.v1f32( +; CHECK: fmul fast <1 x float> +; CHECK: call fast <1 x float> @llvm.fmuladd.v1f32( +; CHECK-COUNT-2: fadd fast <2 x float> +; CHECK-COUNT-2: fsub reassoc <2 x float> +; CHECK-COUNT-2: fneg contract <2 x float> +; CHECK-COUNT-2: fmul reassoc contract <2 x float> +; CHECK-COUNT-2: fmul reassoc <1 x float> +; CHECK: fadd reassoc <1 x float> +; CHECK-COUNT-2: fmul reassoc <1 x float> +; CHECK: fadd reassoc <1 x float> +; 
CHECK-COUNT-2: fmul reassoc <1 x float> +; CHECK: fadd reassoc <1 x float> +; CHECK-COUNT-2: fmul reassoc <1 x float> +; CHECK: fadd reassoc <1 x float> +; CHECK: fmul reassoc contract <1 x float> +; CHECK: call reassoc contract <1 x float> @llvm.fmuladd.v1f32( +; CHECK: fmul reassoc contract <1 x float> +; CHECK: call reassoc contract <1 x float> @llvm.fmuladd.v1f32( +; CHECK: fmul reassoc contract <1 x float> +; CHECK: call reassoc contract <1 x float> @llvm.fmuladd.v1f32( +; CHECK: fmul reassoc contract <1 x float> +; CHECK: call reassoc contract <1 x float> @llvm.fmuladd.v1f32( + %res = tail call fast <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %m, <4 x float> %n, i32 2, i32 2, i32 2) + %res.2 = fadd fast <4 x float> %res, %m + %res.3 = fsub reassoc <4 x float> %res.2, %n + %res.4 = fneg contract <4 x float> %res.3 + %res.5 = fmul reassoc contract <4 x float> %res.3, %res.4 + %res.6 = tail call reassoc <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %res.4, <4 x float> %res.5, i32 2, i32 2, i32 2) + %res.7 = tail call contract reassoc <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %res.5, <4 x float> %res.6, i32 2, i32 2, i32 2) + ret <4 x float> %res.7 +} + +declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)