diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -139,11 +139,29 @@
   const TargetTransformInfo &TTI;
   OptimizationRemarkEmitter &ORE;
 
+  struct OpInfoTy {
+    /// Number of stores emitted to generate this matrix.
+    unsigned NumStores = 0;
+    /// Number of loads emitted to generate this matrix.
+    unsigned NumLoads = 0;
+    /// Number of compute operations emitted to generate this matrix.
+    unsigned NumComputeOps = 0;
+
+    OpInfoTy &operator+=(const OpInfoTy &RHS) {
+      NumStores += RHS.NumStores;
+      NumLoads += RHS.NumLoads;
+      NumComputeOps += RHS.NumComputeOps;
+      return *this;
+    }
+  };
+
   /// Wrapper class representing a matrix as a set of column vectors.
   /// All column vectors must have the same vector type.
   class ColumnMatrixTy {
     SmallVector<Value *, 16> Columns;
 
+    OpInfoTy OpInfo;
+
   public:
     ColumnMatrixTy() : Columns() {}
     ColumnMatrixTy(ArrayRef<Value *> Cols)
@@ -165,6 +183,10 @@
     void addColumn(Value *V) { Columns.push_back(V); }
 
+    VectorType *getColumnTy() {
+      return cast<VectorType>(Columns[0]->getType());
+    }
+
     iterator_range<SmallVector<Value *, 8>::iterator> columns() {
       return make_range(Columns.begin(), Columns.end());
     }
 
@@ -175,6 +197,29 @@
       return Columns.size() == 1 ? Columns[0]
                                  : concatenateVectors(Builder, Columns);
     }
+
+    ColumnMatrixTy &addNumLoads(unsigned N) {
+      OpInfo.NumLoads += N;
+      return *this;
+    }
+
+    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
+
+    ColumnMatrixTy &addNumStores(unsigned N) {
+      OpInfo.NumStores += N;
+      return *this;
+    }
+
+    ColumnMatrixTy &addNumComputeOps(unsigned N) {
+      OpInfo.NumComputeOps += N;
+      return *this;
+    }
+
+    unsigned getNumStores() const { return OpInfo.NumStores; }
+    unsigned getNumLoads() const { return OpInfo.NumLoads; }
+    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
+
+    const OpInfoTy &getOpInfo() const { return OpInfo; }
   };
 
   struct ShapeInfo {
@@ -222,6 +267,20 @@
                          OptimizationRemarkEmitter &ORE)
       : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), ORE(ORE) {}
 
+  unsigned getNumOps(Type *VT) {
+    assert(isa<VectorType>(VT) && "Expected vector type");
+    return getNumOps(VT->getScalarType(),
+                     cast<VectorType>(VT)->getNumElements());
+  }
+
+  //
+  /// Return the estimated number of vector ops required for an operation on
+  /// \p VT * N.
+  unsigned getNumOps(Type *ST, unsigned N) {
+    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
+                     double(TTI.getRegisterBitWidth(true)));
+  }
+
   /// Return the set of column vectors that a matrix value is lowered to.
   ///
   /// If we lowered \p MatrixVal, just return the cache result column matrix.
@@ -580,7 +639,10 @@
       Result.addColumn(Column);
     }
 
-    finalizeLowering(Inst, Result, Builder);
+    finalizeLowering(Inst,
+                     Result.addNumLoads(getNumOps(Result.getColumnTy()) *
+                                        Result.getNumColumns()),
+                     Builder);
   }
 
   /// Lowers llvm.matrix.columnwise.load.
@@ -605,7 +667,8 @@
                                Shape.NumRows, VType->getElementType(), Builder);
       createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
     }
-    Inst2ColumnMatrix[Inst] = ColumnMatrixTy();
+    Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
+        getNumOps(LM.getColumnTy()) * LM.getNumColumns());
 
     ToRemove.push_back(Inst);
   }
@@ -666,8 +729,9 @@
   }
 
   Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
-                      IRBuilder<> &Builder, bool AllowContraction) {
-
+                      IRBuilder<> &Builder, bool AllowContraction,
+                      unsigned &NumComputeOps) {
+    NumComputeOps += getNumOps(A->getType());
     if (!Sum)
       return UseFPOp ? Builder.CreateFMul(A, B)
                      : Builder.CreateMul(A, B);
@@ -679,10 +743,12 @@
             Func.getParent(), Intrinsic::fmuladd, A->getType());
         return Builder.CreateCall(FMulAdd, {A, B, Sum});
       }
+      NumComputeOps += getNumOps(A->getType());
       Value *Mul = Builder.CreateFMul(A, B);
       return Builder.CreateFAdd(Sum, Mul);
     }
 
+    NumComputeOps += getNumOps(A->getType());
     Value *Mul = Builder.CreateMul(A, B);
     return Builder.CreateAdd(Sum, Mul);
   }
@@ -736,6 +802,7 @@
     bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                   MatMul->hasAllowContract());
 
+    unsigned NumComputeOps = 0;
     // Multiply columns from the first operand with scalars from the second
    // operand. Then move along the K axes and accumulate the columns. With
     // this the adds can be vectorized without reassociation.
@@ -752,11 +819,12 @@
           Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
           Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
           Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
-                             Builder, AllowContract);
+                             Builder, AllowContract, NumComputeOps);
         }
         Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
       }
     }
 
+    Result.addNumComputeOps(NumComputeOps);
     finalizeLowering(MatMul, Result, Builder);
   }
@@ -786,7 +854,13 @@
       Result.addColumn(ResultColumn);
     }
 
-    finalizeLowering(Inst, Result, Builder);
+    // TODO: Improve estimate of operations needed for transposes. Currently we
+    // just count the insertelement/extractelement instructions, but do not
+    // account for later simplifications/combines.
+    finalizeLowering(
+        Inst,
+        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
+        Builder);
   }
 
   /// Lower load instructions, if shape information is available.
@@ -848,7 +922,10 @@
       Result.addColumn(
           BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
 
-    finalizeLowering(Inst, Result, Builder);
+    finalizeLowering(Inst,
+                     Result.addNumComputeOps(getNumOps(Result.getColumnTy()) *
+                                             Result.getNumColumns()),
+                     Builder);
     return true;
   }
@@ -1114,6 +1191,23 @@
     return Leaves;
   }
 
+  /// Calculate the number of exclusive and shared op counts for expression
+  /// starting at \p Root. Expressions used multiple times are counted once.
+  OpInfoTy sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs) {
+    auto CM = Inst2ColumnMatrix.find(Root);
+    if (CM == Inst2ColumnMatrix.end())
+      return {};
+
+    // Already counted this expression. Stop.
+    if (!ReusedExprs.insert(Root).second)
+      return {};
+
+    OpInfoTy Count = CM->second.getOpInfo();
+    for (Value *Op : cast<Instruction>(Root)->operand_values())
+      Count += sumOpInfos(Op, ReusedExprs);
+    return Count;
+  }
+
   void emitRemarks() {
     if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
       return;
@@ -1123,10 +1217,16 @@
 
     // Generate remarks for each leaf.
     for (auto *L : Leaves) {
+      SmallPtrSet<Value *, 8> ReusedExprs;
+      auto Counts = sumOpInfos(L, ReusedExprs);
       OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered",
                              cast<Instruction>(L)->getDebugLoc(),
                              cast<Instruction>(L)->getParent());
-      Rem << "Lowered matrix expression ";
+      Rem << "Lowered with ";
+      Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
+          << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
+          << ore::NV("NumComputeOps", Counts.NumComputeOps) << " compute ops";
+      Rem << ("\n" + linearize(L, DL));
       ORE.emit(Rem);
     }
   }
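For intuition on where the op counts in the test update below come from, here is a minimal standalone sketch of the estimate implemented by getNumOps above. It assumes TTI.getRegisterBitWidth(true) returns 128, the vector register width for the aarch64 triple used in the test; the free function numOps is hypothetical and only mirrors the in-tree logic:

  #include <cmath>
  #include <cstdio>

  // Mirrors getNumOps(Type *ST, unsigned N) under the assumption of
  // 128-bit vector registers: vector ops needed to cover N scalars.
  unsigned numOps(unsigned ScalarBits, unsigned NumElts) {
    return std::ceil(ScalarBits * NumElts / 128.0);
  }

  int main() {
    // A 3x3 double matrix is lowered to 3 columns of 3 doubles; each
    // 192-bit column needs ceil(192 / 128) = 2 vector ops, so loading or
    // storing it is counted as 3 * 2 = 6 loads/stores in the remarks.
    std::printf("3x3 double load: %u ops\n", 3 * numOps(64, 3));
    // The transpose estimate is simply 2 * NumRows * NumColumns per the
    // TODO above, so the 2x6 transpose below reports 2 * 2 * 6 = 24
    // compute ops.
    std::printf("2x6 transpose: %d ops\n", 2 * 2 * 6);
    return 0;
  }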
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/remarks.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/remarks.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/remarks.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/remarks.ll
@@ -3,7 +3,7 @@
 target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "aarch64-apple-ios"
 
-; CHECK-LABEL: remark: test.h:40:20: Lowered matrix expression
+; CHECK-LABEL: remark: test.h:40:20: Lowered with 6 stores, 6 loads, 24 compute ops
 ; CHECK-NEXT: store(
 ; CHECK-NEXT:  transpose.2x6.double(load(addr %A)),
 ; CHECK-NEXT:  addr %B)
@@ -17,7 +17,7 @@
 
 declare <12 x double> @llvm.matrix.transpose.v12f64.v12f64(<12 x double>, i32, i32)
 
-; CHECK-LABEL: remark: test.h:50:20: Lowered matrix expression
+; CHECK-LABEL: remark: test.h:50:20: Lowered with 2 stores, 12 loads, 22 compute ops
 ; CHECK-NEXT: store(
 ; CHECK-NEXT:  multiply.2x6.6x2.double(
 ; CHECK-NEXT:   load(addr %A),
@@ -33,7 +33,7 @@
 
 declare <4 x double> @llvm.matrix.multiply(<12 x double>, <12 x double>, i32, i32, i32)
 
-; CHECK-LABEL: remark: test.h:60:20: Lowered matrix expression
+; CHECK-LABEL: remark: test.h:60:20: Lowered with 6 stores, 6 loads, 0 compute ops
 ; CHECK-NEXT: store(
 ; CHECK-NEXT:  columnwise.load.3x3.double(addr %A, 5),
 ; CHECK-NEXT:  addr %B)
@@ -45,7 +45,7 @@
 
 declare <9 x double> @llvm.matrix.columnwise.load(<9 x double>*, i32, i32, i32)
 
-; CHECK-LABEL: remark: test.h:70:20: Lowered matrix expression
+; CHECK-LABEL: remark: test.h:70:20: Lowered with 6 stores, 6 loads, 0 compute ops
 ; CHECK-NEXT: columnwise.store.3x3.double(
 ; CHECK-NEXT:  columnwise.load.3x3.double(addr %A, 5),
 ; CHECK-NEXT:  addr %B,
@@ -58,7 +58,7 @@
 
 declare void @llvm.matrix.columnwise.store(<9 x double>, <9 x double>*, i32, i32, i32)
 
-; CHECK-LABEL: remark: test.h:80:20: Lowered matrix expression
+; CHECK-LABEL: remark: test.h:80:20: Lowered with 6 stores, 6 loads, 12 compute ops
 ; CHECK-NEXT: columnwise.store.3x3.double(
 ; CHECK-NEXT:  fmul(
 ; CHECK-NEXT:   fadd(
@@ -76,7 +76,7 @@
   ret void
 }
 
-; CHECK-LABEL: remark: test.h:90:20: Lowered matrix expression
+; CHECK-LABEL: remark: test.h:90:20: Lowered with 6 stores, 6 loads, 12 compute ops
 ; CHECK-NEXT: columnwise.store.3x3.double(
 ; CHECK-NEXT:  fmul(
 ; CHECK-NEXT:   fadd(
@@ -85,7 +85,7 @@
 ; CHECK-NEXT:    (reused) columnwise.load.3x3.double(addr %A, 5)),
 ; CHECK-NEXT:  addr %B,
 ; CHECK-NEXT:  10)
-; CHECK-NEXT: remark: test.h:90:20: Lowered matrix expression
+; CHECK-NEXT: remark: test.h:90:20: Lowered with 2 stores, 12 loads, 22 compute ops
 ; CHECK-NEXT: store(
 ; CHECK-NEXT:  multiply.2x6.6x2.double(
 ; CHECK-NEXT:   load(addr %C),
@@ -106,7 +106,7 @@
   ret void
 }
 
-; CHECK-LABEL: remark: test.h:100:20: Lowered matrix expression
+; CHECK-LABEL: remark: test.h:100:20: Lowered with 6 stores, 6 loads, 12 compute ops
 ; CHECK-NEXT: columnwise.store.3x3.double(
 ; CHECK-NEXT:  fmul(
 ; CHECK-NEXT:   fadd(
@@ -124,7 +124,7 @@
   ret void
 }
 
-; CHECK-LABEL: remark: test.h:30:20: Lowered matrix expression
+; CHECK-LABEL: remark: test.h:30:20: Lowered with 10 stores, 9 loads, 30 compute ops
 ; CHECK-NEXT: store(
 ; CHECK-NEXT:  transpose.5x3.double(load(addr %A)),
 ; CHECK-NEXT:  stack addr %s1)
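A note on the (reused) case in the test.h:90:20 remark above: the fadd uses the columnwise.load of %A twice, but sumOpInfos only counts an expression the first time it is visited (via ReusedExprs), so the remark reports 6 loads rather than 12. The 12 compute ops decompose the same way: 3 columns * 2 vector ops for the fadd plus 3 columns * 2 vector ops for the fmul, again assuming 128-bit vector registers.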