Index: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -80,6 +80,9 @@
                clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                           "Use row-major layout")));
 
+static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
+                                            cl::init(false));
+
 /// Helper function to either return Scope, if it is a subprogram or the
 /// attached subprogram for a local scope.
 static DISubprogram *getSubprogram(DIScope *Scope) {
@@ -88,6 +91,20 @@
   return cast<DILocalScope>(Scope)->getSubprogram();
 }
 
+/// Return true if V is a splat of a value (which is used when multiplying a
+/// matrix with a scalar).
+static bool isSplat(Value *V) {
+  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
+    return SV->isZeroEltSplat();
+  return false;
+}
+
+/// Match any mul operation (fp or integer).
+template <typename LTy, typename RTy>
+auto m_AnyMul(const LTy &L, const RTy &R) {
+  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
+}
+
 namespace {
 
 // Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
@@ -747,8 +764,8 @@
 
         Value *TA, *TAMA, *TAMB;
         ConstantInt *R, *K, *C;
-        if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
-
+        if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
+                          m_Value(TA), m_ConstantInt(R), m_ConstantInt(C)))) {
           // Transpose of a transpose is a nop
           Value *TATA;
           if (match(TA,
@@ -757,7 +774,11 @@
             EraseFromParent(&I);
             EraseFromParent(TA);
           }
-
+          // k^T -> k
+          else if (isSplat(TA)) {
+            ReplaceAllUsesWith(I, TA);
+            EraseFromParent(&I);
+          }
           // (A * B)^t -> B^t * A^t
           // RxK KxC      CxK   KxR
           else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
@@ -773,6 +794,28 @@
             ReplaceAllUsesWith(I, NewInst);
             EraseFromParent(&I);
             EraseFromParent(TA);
+            // Same as above, but with a mul, which occurs when multiplied
+            // with a scalar.
+            // (A * k)^t -> A^t * k
+            //  R  x  C     RxC
+          } else if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
+                     (isSplat(TAMA) || isSplat(TAMB))) {
+            IRBuilder<> LocalBuilder(&I);
+            // We know that the transposed operand is of shape RxC.
+            // And when multiplied with a scalar, the shape is preserved.
+            NewInst = distributeTransposes(
+                TAMA, {R, C}, TAMB, {R, C}, Builder,
+                [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
+                  bool IsFP = I.getType()->isFPOrFPVectorTy();
+                  auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
+                                   : LocalBuilder.CreateMul(T0, T1, "mmul");
+                  auto *Result = cast<Instruction>(Mul);
+                  setShapeInfo(Result, Shape0);
+                  return Result;
+                });
+            ReplaceAllUsesWith(I, NewInst);
+            EraseFromParent(&I);
+            EraseFromParent(TA);
           }
         }
 
@@ -848,10 +891,10 @@
 
     if (!isMinimal()) {
       optimizeTransposes();
-      LLVM_DEBUG({
+      if (PrintAfterTransposeOpt) {
         dbgs() << "Dump after matrix transpose optimization:\n";
         Func.dump();
-      });
+      }
     }
 
     bool Changed = false;
Index: llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
===================================================================
--- llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
+++ llvm/test/Transforms/LowerMatrixIntrinsics/transpose-opts.ll
@@ -5,6 +5,7 @@
 
 ; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
 ; RUN: opt -passes='lower-matrix-intrinsics' -S -o /dev/null -pass-remarks-output=%t < %s && FileCheck --input-file %t --check-prefix=REMARK %s
+; RUN: opt -passes='lower-matrix-intrinsics' -matrix-print-after-transpose-opt -S -o /dev/null %s 2>&1 | FileCheck --check-prefix=AFTER-TRANSPOSE %s
 
 target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "aarch64-apple-ios"
@@ -1130,11 +1131,92 @@
   ret <12 x double> %m
 }
 
+; k * A^T
+define void @kat(<9 x double>* %A, double %k, <9 x double>* %C) {
+; AFTER-TRANSPOSE-LABEL: @kat(
+; AFTER-TRANSPOSE: load
+; AFTER-TRANSPOSE: insertelement
+; AFTER-TRANSPOSE: shufflevector
+; AFTER-TRANSPOSE: llvm.matrix.transpose
+; AFTER-TRANSPOSE-NOT: llvm.matrix.transpose
+; AFTER-TRANSPOSE: llvm.matrix.multiply
+; AFTER-TRANSPOSE: store
+entry:
+  %a = load <9 x double>, <9 x double>* %A
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %at, i32 3, i32 3, i32 3)
+  store <9 x double> %mul, <9 x double>* %C
+  ret void
+}
+
+; (k * A)^T -> A^T * k
+define void @ka_t(<9 x double>* %A, double %k, <9 x double>* %C) {
+; AFTER-TRANSPOSE-LABEL: @ka_t(
+; AFTER-TRANSPOSE: load
+; AFTER-TRANSPOSE: insertelement
+; AFTER-TRANSPOSE: shufflevector
+; AFTER-TRANSPOSE: llvm.matrix.transpose
+; AFTER-TRANSPOSE-NOT: llvm.matrix.transpose
+; AFTER-TRANSPOSE: llvm.matrix.multiply
+; AFTER-TRANSPOSE: store
+entry:
+  %a = load <9 x double>, <9 x double>* %A
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %mul = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %a, i32 3, i32 3, i32 3)
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %C
+  ret void
+}
+
+; (k * A)^T -> A^T * k with fmul
+define void @ka_t_fmul(<9 x double>* %A, double %k, <9 x double>* %C) {
+; AFTER-TRANSPOSE-LABEL: @ka_t_fmul(
+; AFTER-TRANSPOSE: load
+; AFTER-TRANSPOSE: insertelement
+; AFTER-TRANSPOSE: shufflevector
+; AFTER-TRANSPOSE: llvm.matrix.transpose
+; AFTER-TRANSPOSE-NOT: llvm.matrix.transpose
+; AFTER-TRANSPOSE: fmul
+; AFTER-TRANSPOSE: store
+entry:
+  %a = load <9 x double>, <9 x double>* %A
+  %veck = insertelement <9 x double> poison, double %k, i64 0
+  %splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
+  %mul = fmul <9 x double> %splat, %a
+  %t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %mul, i32 3, i32 3)
+  store <9 x double> %t, <9 x double>* %C
+  ret void
+}
+
+; (k * A)^T -> A^T * k with mul (non-fp types)
+define void @ka_t_mul(<9 x i32>* %A, i32 %k, <9 x i32>* %C) {
+; AFTER-TRANSPOSE-LABEL: @ka_t_mul(
+; AFTER-TRANSPOSE: load
+; AFTER-TRANSPOSE: insertelement
+; AFTER-TRANSPOSE: shufflevector
+; AFTER-TRANSPOSE: llvm.matrix.transpose
+; AFTER-TRANSPOSE-NOT: llvm.matrix.transpose
+; AFTER-TRANSPOSE: mul
+; AFTER-TRANSPOSE: store
+entry:
+  %a = load <9 x i32>, <9 x i32>* %A
+  %veck = insertelement <9 x i32> poison, i32 %k, i64 0
+  %splat = shufflevector <9 x i32> %veck, <9 x i32> poison, <9 x i32> zeroinitializer
+  %mul = mul <9 x i32> %splat, %a
+  %t = call <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32> %mul, i32 3, i32 3)
+  store <9 x i32> %t, <9 x i32>* %C
+  ret void
+}
+
 declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
 declare <12 x double> @llvm.matrix.multiply.v12f64.v6f64.v8f64(<6 x double>, <8 x double>, i32, i32, i32)
 declare <8 x double> @llvm.matrix.multiply.v8f64.v6f64.v12f64(<6 x double> %a, <12 x double>, i32, i32, i32)
 declare <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float>, <4 x float>, i32, i32, i32)
 declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
+declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)
 declare <6 x double> @llvm.matrix.transpose.v6f64.v6f64(<6 x double>, i32, i32)
 declare <8 x double> @llvm.matrix.transpose.v8f64.v8f64(<8 x double>, i32, i32)
 declare <12 x double> @llvm.matrix.transpose.v12f64.v12f64(<12 x double>, i32, i32)