diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1151,7 +1151,7 @@
   void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                           const MatrixTy &B, bool AllowContraction,
                           IRBuilder<> &Builder, bool IsTiled,
-                          bool IsScalarMatrixTransposed) {
+                          bool IsScalarMatrixTransposed, Instruction *Inst) {
     const unsigned VF = std::max<unsigned>(
         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                 .getFixedSize() /
@@ -1166,6 +1166,13 @@
            Result.isColumnMajor() == A.isColumnMajor() &&
            "operands must agree on matrix layout");
     unsigned NumComputeOps = 0;
+
+    // Scope the fast-math flags to this call so they do not leak into
+    // instructions the caller emits later through the same builder.
+    IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+    if (auto *FPOp = dyn_cast<FPMathOperator>(Inst))
+      Builder.setFastMathFlags(FPOp->getFastMathFlags());
+
     if (A.isColumnMajor()) {
       // Multiply columns from the first operand with scalars from the second
       // operand. Then move along the K axes and accumulate the columns.  With
@@ -1189,8 +1196,7 @@
                 IsScalarMatrixTransposed ? J : K);
             Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
             Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
-                               Result.getElementType()->isFloatingPointTy(),
-                               Builder, AllowContraction, NumComputeOps);
+                               IsFP, Builder, AllowContraction, NumComputeOps);
           }
           Result.setVector(J,
                            insertVector(Result.getVector(J), I, Sum, Builder));
@@ -1391,7 +1397,8 @@
                             {TileSize, TileSize}, EltType, Builder);
     MatrixTy B = loadMatrix(RPtr, {}, false, RShape, TI.CurrentK, TI.CurrentCol,
                             {TileSize, TileSize}, EltType, Builder);
-    emitMatrixMultiply(TileResult, A, B, AllowContract, Builder, true, false);
+    emitMatrixMultiply(TileResult, A, B, AllowContract, Builder, true, false,
+                       MatMul);
     // Store result after the inner loop is done.
     Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
     storeMatrix(TileResult, Store->getPointerOperand(), Store->getAlign(),
@@ -1453,7 +1460,8 @@
                 loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
                            RShape, Builder.getInt64(K), Builder.getInt64(J),
                            {TileM, TileC}, EltType, Builder);
-            emitMatrixMultiply(Res, A, B, AllowContract, Builder, true, false);
+            emitMatrixMultiply(Res, A, B, AllowContract, Builder, true, false,
+                               MatMul);
           }
           storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
                       Builder.getInt64(I), Builder.getInt64(J), EltType,
@@ -1523,7 +1531,8 @@
       bool AllowContract =
           AllowContractEnabled ||
           (isa<FPMathOperator>(MatMul) && MatMul->hasAllowContract());
-      emitMatrixMultiply(Result, MA, MB, AllowContract, Builder, false, true);
+      emitMatrixMultiply(Result, MA, MB, AllowContract, Builder, false, true,
+                         MatMul);
 
       FusedInsts.insert(MatMul);
       FusedInsts.insert(cast<Instruction>(Transpose));
@@ -1580,7 +1589,8 @@
     bool AllowContract =
         AllowContractEnabled ||
         (isa<FPMathOperator>(MatMul) && MatMul->hasAllowContract());
-    emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false, false);
+    emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false, false,
+                       MatMul);
     finalizeLowering(MatMul, Result, Builder);
   }
 
@@ -1675,11 +1685,11 @@
       case Instruction::Sub:
         return Builder.CreateSub(LHS, RHS);
       case Instruction::FAdd:
-        return Builder.CreateFAdd(LHS, RHS);
+        return Builder.CreateFAddFMF(LHS, RHS, Inst);
       case Instruction::FMul:
-        return Builder.CreateFMul(LHS, RHS);
+        return Builder.CreateFMulFMF(LHS, RHS, Inst);
       case Instruction::FSub:
-        return Builder.CreateFSub(LHS, RHS);
+        return Builder.CreateFSubFMF(LHS, RHS, Inst);
       default:
         llvm_unreachable("Unsupported binary operator for matrix");
       }
@@ -1713,7 +1723,7 @@
     auto BuildVectorOp = [&Builder, Inst](Value *Op) {
       switch (Inst->getOpcode()) {
       case Instruction::FNeg:
-        return Builder.CreateFNeg(Op);
+        return Builder.CreateFNegFMF(Op, Inst);
       default:
         llvm_unreachable("Unsupported unary operator for matrix");
       }
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/preserve-existing-fast-math-flags.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/preserve-existing-fast-math-flags.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/preserve-existing-fast-math-flags.ll
@@ -0,0 +1,27 @@
+; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
+
+; Fast-math flags on the matrix multiply and on the element-wise binary and
+; unary operations must be carried over to the lowered instructions.
+define <4 x float> @preserve_fmf(<4 x float> %m, float %x, float %y) {
+; CHECK-LABEL: @preserve_fmf(
+; CHECK:         fmul fast <1 x float>
+; CHECK:         call fast <1 x float> @llvm.fmuladd.v1f32(
+; CHECK:         fmul fast <1 x float>
+; CHECK:         call fast <1 x float> @llvm.fmuladd.v1f32(
+; CHECK:         fmul fast <1 x float>
+; CHECK:         call fast <1 x float> @llvm.fmuladd.v1f32(
+; CHECK:         fmul fast <1 x float>
+; CHECK:         call fast <1 x float> @llvm.fmuladd.v1f32(
+; CHECK:         fadd fast <2 x float>
+; CHECK-NEXT:    fadd fast <2 x float>
+; CHECK:         fneg fast <2 x float>
+; CHECK-NEXT:    fneg fast <2 x float>
+  %i1 = insertelement <4 x float> <float undef, float 1.000000e+00, float 2.000000e+00, float undef>, float %x, i64 0
+  %i2 = insertelement <4 x float> %i1, float %y, i64 3
+  %res = tail call fast <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %m, <4 x float> %i2, i32 2, i32 2, i32 2)
+  %res.2 = fadd fast <4 x float> %res, %m
+  %res.3 = fneg fast <4 x float> %res.2
+  ret <4 x float> %res.3
+}
+
+declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)