diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -80,6 +80,9 @@ clEnumValN(MatrixLayoutTy::RowMajor, "row-major", "Use row-major layout"))); +static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt", + cl::init(false)); + /// Helper function to either return Scope, if it is a subprogram or the /// attached subprogram for a local scope. static DISubprogram *getSubprogram(DIScope *Scope) { @@ -88,6 +91,20 @@ return cast<DILocalScope>(Scope)->getSubprogram(); } +/// Return true if V is a splat of a value (which is used when multiplying a +/// matrix with a scalar). +static bool isSplat(Value *V) { + if (auto *SV = dyn_cast<ShuffleVectorInst>(V)) + return SV->isZeroEltSplat(); + return false; +} + +/// Match any mul operation (fp or integer). +template <typename LTy, typename RTy> +auto m_AnyMul(const LTy &L, const RTy &R) { + return m_CombineOr(m_Mul(L, R), m_FMul(L, R)); +} + namespace { // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute @@ -747,8 +764,8 @@ Value *TA, *TAMA, *TAMB; ConstantInt *R, *K, *C; - if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) { - + if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>( + m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) { // Transpose of a transpose is a nop Value *TATA; if (match(TA, @@ -757,7 +774,11 @@ EraseFromParent(&I); EraseFromParent(TA); } - + // k^T -> k + else if (isSplat(TA)) { + ReplaceAllUsesWith(I, TA); + EraseFromParent(&I); + } // (A * B)^t -> B^t * A^t // RxK KxC CxK KxR else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>( @@ -773,6 +794,28 @@ ReplaceAllUsesWith(I, NewInst); EraseFromParent(&I); EraseFromParent(TA); + // Same as above, but with a mul, which occurs when multiplied + // with a scalar. 
+ // (A * k)^t -> A^t * k + // R x C RxC + } else if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) && + (isSplat(TAMA) || isSplat(TAMB))) { + IRBuilder<> LocalBuilder(&I); + // We know that the transposed operand is of shape RxC. + // And when multiplied with a scalar, the shape is preserved. + NewInst = distributeTransposes( + TAMA, {R, C}, TAMB, {R, C}, Builder, + [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) { + bool IsFP = I.getType()->isFPOrFPVectorTy(); + auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul") + : LocalBuilder.CreateMul(T0, T1, "mmul"); + auto *Result = cast<Instruction>(Mul); + setShapeInfo(Result, Shape0); + return Result; + }); + ReplaceAllUsesWith(I, NewInst); + EraseFromParent(&I); + EraseFromParent(TA); } } @@ -848,10 +891,10 @@ if (!isMinimal()) { optimizeTransposes(); - LLVM_DEBUG({ + if (PrintAfterTransposeOpt) { dbgs() << "Dump after matrix transpose optimization:\n"; Func.dump(); - }); + } } bool Changed = false; diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/after-transpose-opts.ll @@ -0,0 +1,99 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; REQUIRES: aarch64-registered-target + +; RUN: opt -passes='lower-matrix-intrinsics' -matrix-print-after-transpose-opt -disable-output %s 2>&1 | FileCheck %s + +target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128" +target triple = "aarch64-apple-ios" + +; k * A^T +define void @kat(<9 x double>* %Aptr, double %k, <9 x double>* %C) { +; CHECK-LABEL: @kat( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128 +; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0 +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer +; 
CHECK-NEXT: [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3) +; CHECK-NEXT: [[MUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[AT]], i32 3, i32 3, i32 3) +; CHECK-NEXT: store <9 x double> [[MUL]], <9 x double>* [[C:%.*]], align 128 +; CHECK-NEXT: ret void +; +entry: + %a = load <9 x double>, <9 x double>* %Aptr + %veck = insertelement <9 x double> poison, double %k, i64 0 + %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer + %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3) + %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %at, i32 3, i32 3, i32 3) + store <9 x double> %mul, <9 x double>* %C + ret void +} + +; (k * A)^T -> A^T * k +define void @ka_t(<9 x double>* %Aptr, double %k, <9 x double>* %C) { +; CHECK-LABEL: @ka_t( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128 +; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0 +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer +; CHECK-NEXT: [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3) +; CHECK-NEXT: [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[A_T]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3) +; CHECK-NEXT: store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128 +; CHECK-NEXT: ret void +; +entry: + %a = load <9 x double>, <9 x double>* %Aptr + %veck = insertelement <9 x double> poison, double %k, i64 0 + %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer + %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %a, i32 3, i32 3, i32 3) + %t = call <9 
x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3) + store <9 x double> %t, <9 x double>* %C + ret void +} + +; (k * A)^T -> A^T * k with fmul +define void @ka_t_fmul(<9 x double>* %Aptr, double %k, <9 x double>* %C) { +; CHECK-LABEL: @ka_t_fmul( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128 +; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0 +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer +; CHECK-NEXT: [[A_T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3) +; CHECK-NEXT: [[MMUL:%.*]] = fmul <9 x double> [[SPLAT]], [[A_T]] +; CHECK-NEXT: store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128 +; CHECK-NEXT: ret void +; +entry: + %a = load <9 x double>, <9 x double>* %Aptr + %veck = insertelement <9 x double> poison, double %k, i64 0 + %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer + %mul = fmul <9 x double> %splat, %a + %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3) + store <9 x double> %t, <9 x double>* %C + ret void +} + +; (k * A)^T -> A^T * k with mul (non-fp types) +define void @ka_t_mul(<9 x i32>* %Aptr, i32 %k, <9 x i32>* %C) { +; CHECK-LABEL: @ka_t_mul( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <9 x i32>, <9 x i32>* [[APTR:%.*]], align 64 +; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x i32> poison, i32 [[K:%.*]], i64 0 +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x i32> [[VECK]], <9 x i32> poison, <9 x i32> zeroinitializer +; CHECK-NEXT: [[A_T:%.*]] = call <9 x i32> @llvm.matrix.transpose.v9i32(<9 x i32> [[A]], i32 3, i32 3) +; CHECK-NEXT: [[MMUL:%.*]] = mul <9 x i32> [[SPLAT]], [[A_T]] +; CHECK-NEXT: store <9 x i32> [[MMUL]], <9 x i32>* [[C:%.*]], align 64 +; CHECK-NEXT: ret void +; +entry: + %a = load <9 x i32>, <9 x 
i32>* %Aptr + %veck = insertelement <9 x i32> poison, i32 %k, i64 0 + %splat = shufflevector <9 x i32> %veck, <9 x i32> poison, <9 x i32> zeroinitializer + %mul = mul <9 x i32> %splat, %a + %t = call <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32> %mul, i32 3, i32 3) + store <9 x i32> %t, <9 x i32>* %C + ret void +} + +declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32) +declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32) +declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)