diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1019,6 +1019,31 @@
     unsigned C = Result.getNumColumns();
     unsigned M = A.getNumColumns();
 
+    // Dot product fast path: when R and C are both 1 the result is a single
+    // element, so multiply the flattened operands element-wise and reduce.
+    // Restricted to floating-point elements because it emits fmul and
+    // vector.reduce.fadd; other element types use the generic lowering below.
+    // NOTE(review): the reduction is unconditionally marked `fast`; confirm
+    // this is acceptable regardless of the flags on the original multiply.
+    if (R == 1 && C == 1 && Result.getElementType()->isFloatingPointTy()) {
+      auto *VecA = A.embedInVector(Builder);
+      auto *VecB = B.embedInVector(Builder);
+      auto *Mul = Builder.CreateFMul(VecA, VecB);
+      Function *Reduce = Intrinsic::getDeclaration(
+          Func.getParent(), Intrinsic::vector_reduce_fadd, VecA->getType());
+      auto *Res = Builder.CreateCall(
+          Reduce,
+          {ConstantFP::get(cast<VectorType>(VecA->getType())->getElementType(),
+                           0.0),
+           Mul});
+      FastMathFlags FMF;
+      FMF.setFast();
+      cast<Instruction>(Res)->setFastMathFlags(FMF);
+      Result.setVector(0, Builder.CreateInsertElement(Result.getVector(0), Res,
+                                                      uint64_t(0)));
+      return;
+    }
+
     bool IsFP = Result.getElementType()->isFloatingPointTy();
     assert(A.isColumnMajor() == B.isColumnMajor() &&
            Result.isColumnMajor() == A.isColumnMajor() &&
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-dot-float.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-dot-float.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-dot-float.ll
@@ -0,0 +1,73 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
+
+define <1 x float> @dot_3x_float(<3 x float> %a, <3 x float> %b) {
+; CHECK-LABEL: @dot_3x_float(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <3 x float> [[A:%.*]], <3 x float> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <3 x float> [[A]], <3 x float> poison, <1 x i32> <i32 1>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <3 x float> [[A]], <3 x float> poison, <1 x i32> <i32 2>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <3 x float> [[B:%.*]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
+; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <1 x float> [[SPLIT]], <1 x float> [[SPLIT1]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <1 x float> [[SPLIT2]], <1 x float> poison, <2 x i32> <i32 0, i32 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x float> [[TMP0]], <2 x float> [[TMP1]], <3 x i32> <i32 0, i32 1, i32 2>
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul <3 x float> [[TMP2]], [[SPLIT3]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call fast float @llvm.vector.reduce.fadd.v3f32(float 0.000000e+00, <3 x float> [[TMP3]])
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <1 x float> undef, float [[TMP4]], i64 0
+; CHECK-NEXT:    ret <1 x float> [[TMP5]]
+;
+entry:
+  %c = call <1 x float> @llvm.matrix.multiply.v1f32.v3f32.v3f32(<3 x float> %a, <3 x float> %b, i32 1, i32 3, i32 1)
+  ret <1 x float> %c
+}
+
+declare <1 x float> @llvm.matrix.multiply.v1f32.v3f32.v3f32(<3 x float>, <3 x float>, i32, i32, i32)
+
+define <1 x float> @dot_4x_float(<4 x float> %a, <4 x float> %b) {
+; CHECK-LABEL: @dot_4x_float(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x float> [[A:%.*]], <4 x float> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <4 x float> [[A]], <4 x float> poison, <1 x i32> <i32 1>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <4 x float> [[A]], <4 x float> poison, <1 x i32> <i32 2>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <4 x float> [[A]], <4 x float> poison, <1 x i32> <i32 3>
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <4 x float> [[B:%.*]], <4 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <1 x float> [[SPLIT]], <1 x float> [[SPLIT1]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <1 x float> [[SPLIT2]], <1 x float> [[SPLIT3]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x float> [[TMP0]], <2 x float> [[TMP1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul <4 x float> [[TMP2]], [[SPLIT4]]
+; CHECK-NEXT:    [[TMP4:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[TMP3]])
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <1 x float> undef, float [[TMP4]], i64 0
+; CHECK-NEXT:    ret <1 x float> [[TMP5]]
+;
+entry:
+  %c = call <1 x float> @llvm.matrix.multiply.v1f32.v4f32.v4f32(<4 x float> %a, <4 x float> %b, i32 1, i32 4, i32 1)
+  ret <1 x float> %c
+}
+
+declare <1 x float> @llvm.matrix.multiply.v1f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)
+
+define <1 x float> @dot_5x_float(<5 x float> %a, <5 x float> %b) {
+; CHECK-LABEL: @dot_5x_float(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <5 x float> [[A:%.*]], <5 x float> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <5 x float> [[A]], <5 x float> poison, <1 x i32> <i32 1>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <5 x float> [[A]], <5 x float> poison, <1 x i32> <i32 2>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <5 x float> [[A]], <5 x float> poison, <1 x i32> <i32 3>
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <5 x float> [[A]], <5 x float> poison, <1 x i32> <i32 4>
+; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <5 x float> [[B:%.*]], <5 x float> poison, <5 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4>
+; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <1 x float> [[SPLIT]], <1 x float> [[SPLIT1]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <1 x float> [[SPLIT2]], <1 x float> [[SPLIT3]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x float> [[TMP0]], <2 x float> [[TMP1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <1 x float> [[SPLIT4]], <1 x float> poison, <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x float> [[TMP2]], <4 x float> [[TMP3]], <5 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4>
+; CHECK-NEXT:    [[TMP5:%.*]] = fmul <5 x float> [[TMP4]], [[SPLIT5]]
+; CHECK-NEXT:    [[TMP6:%.*]] = call fast float @llvm.vector.reduce.fadd.v5f32(float 0.000000e+00, <5 x float> [[TMP5]])
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <1 x float> undef, float [[TMP6]], i64 0
+; CHECK-NEXT:    ret <1 x float> [[TMP7]]
+;
+entry:
+  %c = call <1 x float> @llvm.matrix.multiply.v1f32.v5f32.v5f32(<5 x float> %a, <5 x float> %b, i32 1, i32 5, i32 1)
+  ret <1 x float> %c
+}
+
+declare <1 x float> @llvm.matrix.multiply.v1f32.v5f32.v5f32(<5 x float>, <5 x float>, i32, i32, i32)