diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -34,12 +34,14 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/MatrixUtils.h"
 
 using namespace llvm;
 using namespace PatternMatch;
@@ -1178,6 +1180,123 @@
     return Res;
   }
 
+  /// Try to use AArch64's udot instruction to compute the multiply in 4x4
+  /// tiles. For <4 x i32> @llvm.aarch64.neon.udot(<4 x i32> %acc,
+  /// <16 x i8> %a, <16 x i8> %b), lane I of the result is the dot product of
+  /// bytes [4*I, 4*I+3] of %a and %b, added to lane I of %acc.
+  ///
+  /// Handles store(multiply(transpose(load A), load B)) with i8 elements and
+  /// all dimensions multiples of 4. On success the multiply, transpose, store
+  /// and any loads left without users are erased and added to \p FusedInsts
+  /// and true is returned; otherwise false is returned and the IR is left
+  /// unchanged.
+  ///
+  /// FIXME: This does not check that the target is AArch64 with the dotprod
+  /// feature, that the stored-to memory does not alias the loaded operands,
+  /// or that the store address is available at the multiply.
+  bool tryToEmitDotTiling(CallInst *MatMul, LoadInst *LoadOp0,
+                          LoadInst *LoadOp1, StoreInst *Store,
+                          SmallPtrSetImpl<Instruction *> &FusedInsts) {
+    // udot works on 4 lanes of 4 bytes each, which fixes the tile size.
+    const unsigned DotTile = 4;
+    auto *EltType = cast<VectorType>(LoadOp0->getType())->getElementType();
+    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+    // Only process multiplies of i8 matrices whose dimensions are multiples
+    // of the tile size.
+    if (!EltType->isIntegerTy(8) || LShape.NumRows % DotTile != 0 ||
+        LShape.NumColumns % DotTile != 0 || RShape.NumColumns % DotTile != 0)
+      return false;
+
+    // LoadOp0 loads A, not transpose(A), so its memory has the transposed
+    // shape (and therefore a different stride unless LShape is square).
+    ShapeInfo AShape(LShape.NumColumns, LShape.NumRows);
+
+    // Create the loop nest: rows and columns of the result, plus the inner
+    // (shared) dimension of the multiply.
+    TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, DotTile);
+    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+    BasicBlock *Start = MatMul->getParent();
+    BasicBlock *End =
+        SplitBlock(Start, MatMul, &DT, nullptr, nullptr, "continue");
+    IRBuilder<> Builder(MatMul);
+    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU);
+
+    Type *DotResultColTy = FixedVectorType::get(Builder.getInt32Ty(), DotTile);
+    Type *DotArgTy =
+        FixedVectorType::get(Builder.getInt8Ty(), DotTile * DotTile);
+
+    // Accumulate each result column in a PHI in the inner loop header,
+    // starting from zero when entering from the row loop body.
+    Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
+    SmallVector<PHINode *, 4> ColumnPhis;
+    for (unsigned I = 0; I < DotTile; ++I) {
+      PHINode *Phi = Builder.CreatePHI(DotResultColTy, 2);
+      Phi->addIncoming(ConstantAggregateZero::get(DotResultColTy),
+                       TI.RowLoopHeader->getSingleSuccessor());
+      ColumnPhis.push_back(Phi);
+    }
+
+    // Inner loop body: load 4x4 tiles of A and B and compute
+    // Res += transpose(A tile) * B tile.
+    Builder.SetInsertPoint(InnerBody->getTerminator());
+    MatrixTy A =
+        loadMatrix(LoadOp0->getPointerOperand(), AShape, TI.CurrentK,
+                   TI.CurrentRow, {DotTile, DotTile}, EltType, Builder);
+    MatrixTy B =
+        loadMatrix(LoadOp1->getPointerOperand(), RShape, TI.CurrentK,
+                   TI.CurrentCol, {DotTile, DotTile}, EltType, Builder);
+    // Column R of the A tile is row R of the transposed tile; flattened, it
+    // occupies bytes [4*R, 4*R+3] of the udot %a operand.
+    Value *FlatA = A.embedInVector(Builder);
+    SmallVector<Value *, 4> Results;
+    for (unsigned I = 0; I < DotTile; ++I) {
+      // Repeat column I of B in every 4-byte group so that lane R of the
+      // result is (row R of transposed A) . (column I of B).
+      Value *ColB = B.getColumn(I);
+      Value *BB = concatenateVectors(Builder, {ColB, ColB, ColB, ColB});
+      Value *Result = Builder.CreateIntrinsic(Intrinsic::aarch64_neon_udot,
+                                              {DotResultColTy, DotArgTy},
+                                              {ColumnPhis[I], FlatA, BB});
+      ColumnPhis[I]->addIncoming(Result, TI.InnerLoopLatch);
+      Results.push_back(Result);
+    }
+
+    // After the inner loop, store the result tile. Truncating the i32 sums
+    // yields the same low 8 bits as wrapping i8 arithmetic.
+    Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
+    MatrixTy Result;
+    for (Value *Col : Results)
+      Result.addVector(Builder.CreateTrunc(Col, A.getColumnTy()));
+    storeMatrix(Result, Store->getPointerOperand(),
+                {LShape.NumRows, RShape.NumColumns}, TI.CurrentRow,
+                TI.CurrentCol, EltType, Builder);
+
+    // Mark eliminated instructions as fused and remove them. The transpose
+    // and the loads may still have other users, and LoadOp0 and LoadOp1 may
+    // be the same load (e.g. transpose(A) * A), so only erase dead ones, once.
+    auto *Trans = cast<Instruction>(MatMul->getArgOperand(0));
+    FusedInsts.insert(Store);
+    FusedInsts.insert(MatMul);
+    Store->eraseFromParent();
+    MatMul->eraseFromParent();
+    if (Trans->hasNUses(0)) {
+      FusedInsts.insert(Trans);
+      Trans->eraseFromParent();
+    }
+    if (LoadOp0->hasNUses(0)) {
+      FusedInsts.insert(LoadOp0);
+      LoadOp0->eraseFromParent();
+    }
+    if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(0)) {
+      FusedInsts.insert(LoadOp1);
+      LoadOp1->eraseFromParent();
+    }
+
+    return true;
+  }
+
   void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                       StoreInst *Store,
                       SmallPtrSetImpl<Instruction *> &FusedInsts) {
@@ -1249,6 +1368,21 @@
         MatrixLayout != MatrixLayoutTy::ColumnMajor)
       return;
 
+    // Try the AArch64 udot lowering for store(transpose(load A) * load B).
+    Value *MatrixA;
+    Value *MatrixB;
+    if (match(MatMul,
+              m_Intrinsic<Intrinsic::matrix_multiply>(
+                  m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(MatrixA)),
+                  m_Value(MatrixB), m_Value(), m_Value(), m_Value()))) {
+      auto *LoadOp0 = dyn_cast<LoadInst>(MatrixA);
+      auto *LoadOp1 = dyn_cast<LoadInst>(MatrixB);
+      auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
+      if (LoadOp0 && LoadOp1 && Store &&
+          tryToEmitDotTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts))
+        return;
+    }
+
     auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0));
     auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1));
     auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
@@ -1877,7 +2011,7 @@
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
-    AU.addPreserved<LoopInfoWrapperPass>();
+    // LoopInfo is not preserved: tiling creates loops without updating it.
   }
 };
 } // namespace