Index: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -34,6 +34,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/MatrixBuilder.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -84,6 +85,9 @@
         clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                    "Use row-major layout")));
 
+static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
+                                            cl::init(false));
+
 /// Helper function to either return Scope, if it is a subprogram or the
 /// attached subprogram for a local scope.
 static DISubprogram *getSubprogram(DIScope *Scope) {
@@ -200,11 +204,16 @@
     unsigned NumLoads = 0;
     /// Number of compute operations emitted to generate this matrix.
     unsigned NumComputeOps = 0;
+    // Most of the time transposes can be fused with matrix multiplies or can be
+    // folded away via algebraic simplifications. This is the number of
+    // transposes that we failed to make "free" via such optimizations.
+    unsigned NumExposedTransposes = 0;
 
     OpInfoTy &operator+=(const OpInfoTy &RHS) {
       NumStores += RHS.NumStores;
       NumLoads += RHS.NumLoads;
       NumComputeOps += RHS.NumComputeOps;
+      NumExposedTransposes += RHS.NumExposedTransposes;
       return *this;
     }
   };
@@ -309,6 +318,11 @@
       return *this;
     }
 
+    MatrixTy &addNumExposedTransposes(unsigned N) {
+      OpInfo.NumExposedTransposes += N;
+      return *this;
+    }
+
     MatrixTy &addNumComputeOps(unsigned N) {
       OpInfo.NumComputeOps += N;
       return *this;
@@ -384,8 +398,10 @@
   /// the result value of the instruction, with the only exceptions being store
   /// instructions and the matrix_column_major_store intrinsics. For those, the
   /// shape information indicates that those instructions should be lowered
-  /// using shape information as well.
-  DenseMap<Value *, ShapeInfo> ShapeMap;
+  /// using shape information as well. A ValueMap is used so that when
+  /// sub-passes like optimizeTransposes perform RAUW the map stays
+  /// up-to-date.
+  ValueMap<Value *, ShapeInfo> ShapeMap;
 
   /// List of instructions to remove. While lowering, we are not replacing all
   /// users of a lowered instruction, if shape information is available and
@@ -659,6 +675,109 @@
     return NewWorkList;
   }
 
+  /// Try moving transposes in order to fold them away or into multiplies.
+  void optimizeTransposes() {
+    // First sink all transposes inside matmuls, hoping that we end up with NN,
+    // NT or TN variants.
+    for (BasicBlock &BB : reverse(Func)) {
+      for (auto II = BB.rbegin(); II != BB.rend();) {
+        Instruction &I = *II;
+        // We may remove II. By default continue on the next/prev instruction.
+        ++II;
+        // If we were to erase II, move again.
+        auto eraseFromParent = [&](Value *V) {
+          auto *Inst = cast<Instruction>(V);
+          if (Inst->use_empty()) {
+            if (Inst == &*II) {
+              ++II;
+            }
+            Inst->eraseFromParent();
+          }
+        };
+
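+        // Illustrative sketch of the sink rewrite performed below (the value
+        // names are made up for the example): given
+        //   %t = transpose(multiply(%a, %b, R, K, C))  ; %a is RxK, %b is KxC
+        // we emit
+        //   %b_t  = transpose(%b, K, C)                ; CxK
+        //   %a_t  = transpose(%a, R, K)                ; KxR
+        //   %mmul = multiply(%b_t, %a_t, C, K, R)      ; CxR, replaces %t
+        // i.e. the transpose of a product becomes a product of transposes,
+        // which later stages can fuse into the multiply or fold away.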
+        // If we're creating a new instruction, continue from there.
+        Instruction *NewInst = nullptr;
+
+        IRBuilder<> IB(&I);
+        MatrixBuilder<IRBuilder<>> Builder(IB);
+
+        Value *TA, *TAMA, *TAMB;
+        ConstantInt *R, *K, *C;
+        if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {
+
+          // Transpose of a transpose is a nop
+          Value *TATA;
+          if (match(TA,
+                    m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
+            I.replaceAllUsesWith(TATA);
+            eraseFromParent(&I);
+            eraseFromParent(TA);
+          }
+
+          // (A * B)^t -> B^t * A^t
+          // RxK KxC      CxK   KxR
+          else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
+                                 m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
+                                 m_ConstantInt(K), m_ConstantInt(C)))) {
+            Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
+                                                      C->getZExtValue(),
+                                                      TAMB->getName() + "_t");
+            // We are being run after shape prop, add shape for newly created
+            // instructions so that we lower them later.
+            setShapeInfo(T0, {C, K});
+            Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
+                                                      K->getZExtValue(),
+                                                      TAMA->getName() + "_t");
+            setShapeInfo(T1, {K, R});
+            NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
+                                                   K->getZExtValue(),
+                                                   R->getZExtValue(), "mmul");
+            setShapeInfo(NewInst, {C, R});
+            I.replaceAllUsesWith(NewInst);
+            eraseFromParent(&I);
+            eraseFromParent(TA);
+          }
+        }
+
+        // If we replaced I with a new instruction, continue from there.
+        if (NewInst)
+          II = std::next(BasicBlock::reverse_iterator(NewInst));
+      }
+    }
+
+    // If we have a TT matmul, lift the transpose until we have a non-TT
+    // situation.
+    for (BasicBlock &BB : Func) {
+      for (BasicBlock::iterator II = BB.begin(); II != BB.end();) {
+        Instruction *I = &*II;
+        // We may remove I.
+        ++II;
+        Value *A, *B, *AT, *BT;
+        ConstantInt *R, *K, *C;
+        if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
+                           m_Value(A), m_Value(B), m_ConstantInt(R),
+                           m_ConstantInt(K), m_ConstantInt(C))) &&
+            match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
+            match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
+          IRBuilder<> IB(&*I);
+          MatrixBuilder<IRBuilder<>> Builder(IB);
+          Value *M = Builder.CreateMatrixMultiply(
+              BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
+          setShapeInfo(M, {C, R});
+          Value *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
+                                                         R->getZExtValue());
+          setShapeInfo(NewInst, {R, C});
+          I->replaceAllUsesWith(NewInst);
+          if (I->use_empty())
+            I->eraseFromParent();
+          if (A->use_empty())
+            cast<Instruction>(A)->eraseFromParent();
+          if (B->use_empty())
+            cast<Instruction>(B)->eraseFromParent();
+        }
+      }
+    }
+  }
+
   bool Visit() {
     if (EnableShapePropagation) {
       SmallVector<Instruction *, 32> WorkList;
@@ -689,6 +808,12 @@
       }
     }
 
+    optimizeTransposes();
+    if (PrintAfterTransposeOpt) {
+      dbgs() << "Dump after matrix transpose optimization:\n";
+      Func.dump();
+    }
+
     bool Changed = false;
     SmallVector<CallInst *, 16> MaybeFusableInsts;
     SmallVector<Instruction *, 16> MatrixInsts;
@@ -1490,7 +1615,8 @@
     // account for later simplifications/combines.
     finalizeLowering(
         Inst,
-        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
+        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
+            .addNumExposedTransposes(1),
         Builder);
   }
 
@@ -2005,7 +2131,9 @@
       Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
           << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
          << ore::NV("NumComputeOps", Counts.NumComputeOps)
-          << " compute ops";
+          << " compute ops, "
+          << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
+          << " exposed transposes";
 
       if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
           SharedCounts.NumComputeOps > 0) {
Index: llvm/test/Transforms/LowerMatrixIntrinsics/remarks-inlining.ll
===================================================================
--- llvm/test/Transforms/LowerMatrixIntrinsics/remarks-inlining.ll
+++ llvm/test/Transforms/LowerMatrixIntrinsics/remarks-inlining.ll
@@ -47,50 +47,50 @@
 target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "aarch64-apple-ios"
 
-; CHECK-LABEL: remark: load.h:41:43: Lowered with 0 stores, 10 loads, 0 compute ops
+; CHECK-LABEL: remark: load.h:41:43: Lowered with 0 stores, 10 loads, 0 compute ops, 0 exposed transposes
 ; CHECK-NEXT: load(addr %A)
 
-; CHECK-LABEL: remark: load.h:41:43: Lowered with 0 stores, 10 loads, 0 compute ops
+; CHECK-LABEL: remark: load.h:41:43: Lowered with 0 stores, 10 loads, 0 compute ops, 0 exposed transposes
 ; CHECK-NEXT: column.major.load.3x5.double(addr %B, 5)
 
-; CHECK-LABEL: remark: load.h:41:11: Lowered with 0 stores, 1 loads, 0 compute ops
+; CHECK-LABEL: remark: load.h:41:11: Lowered with 0 stores, 1 loads, 0 compute ops, 0 exposed transposes
 ; CHECK-NEXT: load(addr %D)
 
-; CHECK-LABEL: remark: assign.h:32:43: Lowered with 0 stores, 10 loads, 0 compute ops
+; CHECK-LABEL: remark: assign.h:32:43: Lowered with 0 stores, 10 loads, 0 compute ops, 0 exposed transposes
 ; CHECK-NEXT: load(addr %A)
 
-; CHECK-LABEL: remark: assign.h:32:43: Lowered with 0 stores, 10 loads, 0 compute ops
+; CHECK-LABEL: remark: assign.h:32:43: Lowered with 0 stores, 10 loads, 0 compute ops, 0 exposed transposes
 ; CHECK-NEXT: column.major.load.3x5.double(addr %B, 5)
 
-; CHECK-LABEL: remark: toplevel.c:410:0: Lowered with 10 stores, 20 loads, 10 compute ops
+; CHECK-LABEL: remark: toplevel.c:410:0: Lowered with 10 stores, 20 loads, 10 compute ops, 0 exposed transposes
 ; CHECK-NEXT: store(
 ; CHECK-NEXT:  fadd(
 ; CHECK-NEXT:   load(addr %A),
 ; CHECK-NEXT:   column.major.load.3x5.double(addr %B, 5)),
 ; CHECK-NEXT:  addr %C)
 
-; CHECK-LABEL: remark: toplevel.c:510:0: Lowered with 1 stores, 1 loads, 8 compute ops
+; CHECK-LABEL: remark: toplevel.c:510:0: Lowered with 2 stores, 1 loads, 4 compute ops, 1 exposed transposes
 ; CHECK-NEXT: store(
-; CHECK-NEXT:  transpose.1x2.float(transpose.2x1.float(load(addr %D))),
+; CHECK-NEXT:  transpose.2x1.float(load(addr %D)),
 ; CHECK-NEXT:  addr %D)
 
-; CHECK-LABEL: remark: add.h:66:11: Lowered with 0 stores, 0 loads, 10 compute ops
+; CHECK-LABEL: remark: add.h:66:11: Lowered with 0 stores, 0 loads, 10 compute ops, 0 exposed transposes
 ; CHECK-NEXT: fadd(
 ; CHECK-NEXT:  addr %A,
 ; CHECK-NEXT:  scalar)
 
-; CHECK-LABEL: remark: store.h:10:11: Lowered with 10 stores, 0 loads, 0 compute ops
+; CHECK-LABEL: remark: store.h:10:11: Lowered with 10 stores, 0 loads, 0 compute ops, 0 exposed transposes
 ; CHECK-NEXT: store(
 ; CHECK-NEXT:  scalar,
 ; CHECK-NEXT:  addr %C)
 
-; CHECK-LABEL: remark: store.h:66:11: Lowered with 1 stores, 0 loads, 0 compute ops
+; CHECK-LABEL: remark: store.h:66:11: Lowered with 2 stores, 0 loads, 0 compute ops, 0 exposed transposes
 ; CHECK-NEXT: store(
 ; CHECK-NEXT:  scalar,
 ; CHECK-NEXT:  addr %D)
 
-; CHECK-LABEL: remark: transpose.h:13:11: Lowered with 0 stores, 0 loads, 8 compute ops
-; CHECK-NEXT: transpose.1x2.float(transpose.2x1.float(addr %D))
+; CHECK-LABEL: remark: transpose.h:13:11: Lowered with 0 stores, 0 loads, 4 compute ops, 1 exposed transposes
+; CHECK-NEXT: transpose.2x1.float(addr %D)
 
 define void @toplevel(<15 x double>* %A, double* %B, <15 x double>* %C, <2 x float>* %D) !dbg !16 {
 entry:
@@ -101,8 +101,7 @@
 
   %load = load <2 x float>, <2 x float>* %D, !dbg !104
   %t1 = call <2 x float> @llvm.matrix.transpose(<2 x float> %load, i32 2, i32 1), !dbg !106
-  %t2 = call <2 x float> @llvm.matrix.transpose(<2 x float> %t1, i32 1, i32 2), !dbg !106
-  store <2 x float> %t2, <2 x float>* %D, !dbg !108
+  store <2 x float> %t1, <2 x float>* %D, !dbg !108
   ret void
 }
 
Index: llvm/test/Transforms/LowerMatrixIntrinsics/remarks-shared-subtrees.ll
===================================================================
--- llvm/test/Transforms/LowerMatrixIntrinsics/remarks-shared-subtrees.ll
+++ llvm/test/Transforms/LowerMatrixIntrinsics/remarks-shared-subtrees.ll
@@ -17,7 +17,9 @@
 ; YAML-NEXT: - NumLoads: '0'
 ; YAML-NEXT: - String: ' loads, '
 ; YAML-NEXT: - NumComputeOps: '0'
-; YAML-NEXT: - String: ' compute ops'
+; YAML-NEXT: - String: ' compute ops, '
+; YAML-NEXT: - NumExposedTransposes: '0'
+; YAML-NEXT: - String: ' exposed transposes'
 ; YAML-NEXT: - String: ",\nadditionally "
 ; YAML-NEXT: - NumStores: '0'
 ; YAML-NEXT: - String: ' stores, '
@@ -45,7 +47,9 @@
 ; YAML-NEXT: - NumLoads: '45'
 ; YAML-NEXT: - String: ' loads, '
 ; YAML-NEXT: - NumComputeOps: '120'
-; YAML-NEXT: - String: ' compute ops'
+; YAML-NEXT: - String: ' compute ops, '
+; YAML-NEXT: - NumExposedTransposes: '0'
+; YAML-NEXT: - String: ' exposed transposes'
 ; YAML-NEXT: - String: ",\nadditionally "
 ; YAML-NEXT: - NumStores: '0'
 ; YAML-NEXT: - String: ' stores, '
Index: llvm/test/Transforms/LowerMatrixIntrinsics/transpose-and-multiply-fold.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LowerMatrixIntrinsics/transpose-and-multiply-fold.ll
@@ -0,0 +1,168 @@
+; REQUIRES: aarch64-registered-target
+
+; This test needs to be target specific due to the cost estimate in the output.
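+;
+; For example, in @double_transpose below the two back-to-back transposes fold
+; away completely, so the remark only reports the loads of %A and the stores to
+; %B (a plain copy): 0 compute ops and 0 exposed transposes.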
+
+; RUN: opt -lower-matrix-intrinsics -S -o /dev/null -pass-remarks-output=%t < %s && FileCheck --input-file %t %s
+; RUN: opt -passes='lower-matrix-intrinsics' -S -o /dev/null -pass-remarks-output=%t < %s && FileCheck --input-file %t %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+define void @double_transpose(<9 x double>* %A, <9 x double>* %B) {
+; CHECK: Pass: lower-matrix-intrinsics
+; CHECK-NEXT: Name: matrix-lowered
+; CHECK-NEXT: Function: double_transpose
+; CHECK-NEXT: Args:
+; CHECK-NEXT: - String: 'Lowered with '
+; CHECK-NEXT: - NumStores: '6'
+; CHECK-NEXT: - String: ' stores, '
+; CHECK-NEXT: - NumLoads: '6'
+; CHECK-NEXT: - String: ' loads, '
+; CHECK-NEXT: - NumComputeOps: '0'
+; CHECK-NEXT: - String: ' compute ops, '
+; CHECK-NEXT: - NumExposedTransposes: '0'
+; CHECK-NEXT: - String: ' exposed transposes'
+; CHECK-NEXT: - String: |
+; CHECK: store(
+; CHECK-NEXT:  load(addr %A),
+; CHECK-NEXT:  addr %B)
+entry:
+  %a = load <9 x double>, <9 x double>* %A, align 16
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %att = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %at, i32 3, i32 3)
+  store <9 x double> %att, <9 x double>* %B, align 16
+  ret void
+}
+
+define void @multiply_3x3x3_ntt(<9 x double>* %A, <9 x double>* %B, <9 x double>* %C, <9 x double>* %R) {
+; CHECK: Pass: lower-matrix-intrinsics
+; CHECK-NEXT: Name: matrix-lowered
+; CHECK-NEXT: Function: multiply_3x3x3_ntt
+; CHECK-NEXT: Args:
+; CHECK-NEXT: - String: 'Lowered with '
+; CHECK-NEXT: - NumStores: '6'
+; CHECK-NEXT: - String: ' stores, '
+; CHECK-NEXT: - NumLoads: '18'
+; CHECK-NEXT: - String: ' loads, '
+; CHECK-NEXT: - NumComputeOps: '60'
+; CHECK-NEXT: - String: ' compute ops, '
+; CHECK-NEXT: - NumExposedTransposes: '0'
+; CHECK-NEXT: - String: ' exposed transposes'
+; CHECK-NEXT: - String: |
+; CHECK: store(
+; CHECK-NEXT:  multiply.3x3.3x3.double(
+; CHECK-NEXT:   load(addr %A),
+; CHECK-NEXT:   transpose.3x3.double(multiply.3x3.3x3.double(
+; CHECK-NEXT:    load(addr %C),
+; CHECK-NEXT:    load(addr %B)))),
+; CHECK-NEXT:  addr %R)
entry:
+  %a = load <9 x double>, <9 x double>* %A, align 16
+  %b = load <9 x double>, <9 x double>* %B, align 16
+  %c = load <9 x double>, <9 x double>* %C, align 16
+  %b_t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %c_t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
+  %m1 = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %b_t, <9 x double> %c_t, i32 3, i32 3, i32 3)
+  %m2 = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %a, <9 x double> %m1, i32 3, i32 3, i32 3)
+  store <9 x double> %m2, <9 x double>* %R, align 16
+  ret void
+}
+
+define void @multiply_3x3x3_tt_t(<9 x double>* %A, <9 x double>* %B, <9 x double>* %C) {
+; CHECK: Pass: lower-matrix-intrinsics
+; CHECK-NEXT: Name: matrix-lowered
+; CHECK-NEXT: Function: multiply_3x3x3_tt_t
+; CHECK-NEXT: Args:
+; CHECK-NEXT: - String: 'Lowered with '
+; CHECK-NEXT: - NumStores: '6'
+; CHECK-NEXT: - String: ' stores, '
+; CHECK-NEXT: - NumLoads: '12'
+; CHECK-NEXT: - String: ' loads, '
+; CHECK-NEXT: - NumComputeOps: '30'
+; CHECK-NEXT: - String: ' compute ops, '
+; CHECK-NEXT: - NumExposedTransposes: '0'
+; CHECK-NEXT: - String: ' exposed transposes'
+; CHECK-NEXT: - String: |
+; CHECK: store(
+; CHECK-NEXT:  multiply.3x3.3x3.double(
+; CHECK-NEXT:   load(addr %B),
+; CHECK-NEXT:   load(addr %A)),
+; CHECK-NEXT:  addr %C)
+entry:
+  %a = load <9 x double>, <9 x double>* %A, align 16
+  %at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
+  %b = load <9 x double>, <9 x double>* %B, align 16
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %c = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
+  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
+  store <9 x double> %ct, <9 x double>* %C, align 16
+  ret void
+}
+
+define void @multiply_3x3x3_nt_t(<9 x double>* %A, <9 x double>* %B, <9 x double>* %C) {
+; CHECK: Pass: lower-matrix-intrinsics
+; CHECK-NEXT: Name: matrix-lowered
+; CHECK-NEXT: Function: multiply_3x3x3_nt_t
+; CHECK-NEXT: Args:
+; CHECK-NEXT: - String: 'Lowered with '
+; CHECK-NEXT: - NumStores: '6'
+; CHECK-NEXT: - String: ' stores, '
+; CHECK-NEXT: - NumLoads: '12'
+; CHECK-NEXT: - String: ' loads, '
+; CHECK-NEXT: - NumComputeOps: '30'
+; CHECK-NEXT: - String: ' compute ops, '
+; CHECK-NEXT: - NumExposedTransposes: '0'
+; CHECK-NEXT: - String: ' exposed transposes'
+; CHECK-NEXT: - String: |
+; CHECK: store(
+; CHECK-NEXT:  multiply.3x3.3x3.double(
+; CHECK-NEXT:   load(addr %B),
+; CHECK-NEXT:   transpose.3x3.double(load(addr %A))),
+; CHECK-NEXT:  addr %C)
+entry:
+  %a = load <9 x double>, <9 x double>* %A, align 16
+  %b = load <9 x double>, <9 x double>* %B, align 16
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %c = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %a, <9 x double> %bt, i32 3, i32 3, i32 3)
+  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
+  store <9 x double> %ct, <9 x double>* %C, align 16
+  ret void
+}
+
+define void @multiply_ntt_t(<9 x double>* %A, <9 x double>* %B, <9 x double>* %C, <9 x double>* %R) {
+; CHECK: Pass: lower-matrix-intrinsics
+; CHECK-NEXT: Name: matrix-lowered
+; CHECK-NEXT: Function: multiply_ntt_t
+; CHECK-NEXT: Args:
+; CHECK-NEXT: - String: 'Lowered with '
+; CHECK-NEXT: - NumStores: '6'
+; CHECK-NEXT: - String: ' stores, '
+; CHECK-NEXT: - NumLoads: '18'
+; CHECK-NEXT: - String: ' loads, '
+; CHECK-NEXT: - NumComputeOps: '60'
+; CHECK-NEXT: - String: ' compute ops, '
+; CHECK-NEXT: - NumExposedTransposes: '0'
+; CHECK-NEXT: - String: ' exposed transposes'
+; CHECK-NEXT: - String: |
+; CHECK: store(
+; CHECK-NEXT:  multiply.3x3.3x3.double(
+; CHECK-NEXT:   multiply.3x3.3x3.double(
+; CHECK-NEXT:    load(addr %C),
+; CHECK-NEXT:    load(addr %B)),
+; CHECK-NEXT:   transpose.3x3.double(load(addr %A))),
+; CHECK-NEXT:  addr %R)
+entry:
+  %a = load <9 x double>, <9 x double>* %A, align 16
+  %b = load <9 x double>, <9 x double>* %B, align 16
+  %bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
+  %c = load <9 x double>, <9 x double>* %C, align 16
+  %ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
+  %btct = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %bt, <9 x double> %ct, i32 3, i32 3, i32 3)
+  %abtct = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %a, <9 x double> %btct, i32 3, i32 3, i32 3)
+  %abtct_t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %abtct, i32 3, i32 3)
+  store <9 x double> %abtct_t, <9 x double>* %R, align 16
+  ret void
+}
+
+declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32 immarg, i32 immarg, i32 immarg)
+declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32 immarg, i32 immarg)
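+
+; Sketch of the rewrite behind the expressions checked above (the value names
+; are illustrative): in @multiply_3x3x3_ntt the TT product
+;   %m1 = multiply(transpose(%b), transpose(%c), 3, 3, 3)
+; is rewritten to
+;   %m1 = transpose(multiply(%c, %b, 3, 3, 3))
+; so the outer multiply sees at most one transposed operand. The IR right after
+; this rewrite can be inspected with -matrix-print-after-transpose-opt.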