diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -181,6 +181,14 @@
 
     void setColumn(unsigned i, Value *V) { Columns[i] = V; }
 
+    /// Returns the element type of the matrix, taken from the vector type of
+    /// the first column. Requires at least one column.
+    Type *getElementType() const {
+      assert(Columns.size() > 0 &&
+             "Cannot call getElementType without columns");
+      return cast<VectorType>(Columns[0]->getType())->getElementType();
+    }
+
     unsigned getNumColumns() const { return Columns.size(); }
     unsigned getNumRows() const {
       assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
@@ -848,6 +856,62 @@
     }
   }
 
+  /// Emit code computing the matrix product of \p A and \p B into \p Result.
+  /// Columns of \p A are multiplied with splatted scalars of \p B and summed
+  /// along K with left-associating addition, so the adds can be vectorized
+  /// without reassociation.
+  ///
+  /// If \p isTiled is true, the product is accumulated onto the existing
+  /// contents of \p Result (Result += A * B); for columns of \p Result that
+  /// are a ConstantAggregateZero, the K == 0 accumulation is skipped. If
+  /// \p isTiled is false, the previous contents of \p Result are ignored and
+  /// overwritten (Result = A * B). The emitted op count is added to \p Result.
+  void emitChainedMatrixMultiply(ColumnMatrixTy &Result,
+                                 const ColumnMatrixTy &A,
+                                 const ColumnMatrixTy &B, bool AllowContraction,
+                                 IRBuilder<> &Builder, bool isTiled) {
+    // Elements per vector register (TTI.getRegisterBitWidth(true)), clamped
+    // to at least 1. Explicit <unsigned>: the operands' types differ.
+    const unsigned VF = std::max<unsigned>(
+        TTI.getRegisterBitWidth(true) /
+            Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
+        1U);
+    const bool IsFP = Result.getElementType()->isFloatingPointTy();
+    unsigned R = Result.getNumRows();
+    unsigned C = Result.getNumColumns();
+    unsigned M = A.getNumColumns();
+
+    // Multiply columns from the first operand with scalars from the second
+    // operand. Then move along the K axes and accumulate the columns.
+    for (unsigned J = 0; J < C; ++J) {
+      unsigned BlockSize = VF;
+
+      // If this column of Result is all-zero, skip accumulating at K == 0.
+      bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
+
+      unsigned NumOps = 0;
+      for (unsigned I = 0; I < R; I += BlockSize) {
+        // Gradually lower the vectorization factor to cover the remainder.
+        while (I + BlockSize > R)
+          BlockSize /= 2;
+
+        // Only a tiled multiply starts from the existing Result values.
+        Value *Sum =
+            isTiled ? extractVector(Result, I, J, BlockSize, Builder) : nullptr;
+        for (unsigned K = 0; K < M; ++K) {
+          Value *L = extractVector(A, I, K, BlockSize, Builder);
+          Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
+          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
+          Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
+                             IsFP, Builder, AllowContraction, NumOps);
+        }
+        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
+      }
+
+      Result.addNumComputeOps(NumOps);
+    }
+  }
+
   /// Lowers llvm.matrix.multiply.
   void LowerMultiply(CallInst *MatMul) {
     IRBuilder<> Builder(MatMul);
@@ -870,35 +934,10 @@
     for (unsigned J = 0; J < C; ++J)
       Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
 
-    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
-                                     EltType->getPrimitiveSizeInBits(),
-                                 uint64_t(1));
-
     bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
                                                   MatMul->hasAllowContract());
-    unsigned NumComputeOps = 0;
-    // Multiply columns from the first operand with scalars from the second
-    // operand. Then move along the K axes and accumulate the columns. With
-    // this the adds can be vectorized without reassociation.
-    for (unsigned J = 0; J < C; ++J) {
-      unsigned BlockSize = VF;
-      for (unsigned I = 0; I < R; I += BlockSize) {
-        // Gradually lower the vectorization factor to cover the remainder.
-        while (I + BlockSize > R)
-          BlockSize /= 2;
-
-        Value *Sum = nullptr;
-        for (unsigned K = 0; K < M; ++K) {
-          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
-          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
-          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
-          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
-                             Builder, AllowContract, NumComputeOps);
-        }
-        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
-      }
-    }
-    Result.addNumComputeOps(NumComputeOps);
+    emitChainedMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder,
+                              /*isTiled=*/false);
 
     finalizeLowering(MatMul, Result, Builder);
   }