diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -34,12 +34,15 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+#include "llvm/ADT/Triple.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/MatrixUtils.h"
 
 using namespace llvm;
 using namespace PatternMatch;
@@ -1178,6 +1181,139 @@
     return Res;
   }
 
+  /// Try to use AArch64's udot instruction to lower
+  ///   store(multiply(transpose(load A), load B))
+  /// using 4x4 tiles.
+  ///
+  /// <4 x i32> @llvm.aarch64.neon.udot(<4 x i32> %acc, <16 x i8> %a,
+  ///                                   <16 x i8> %b)
+  /// works as follows: The first element of the result vector is equal to the
+  /// dot product of the first 4 elements of %a and %b added to the first
+  /// element of %acc. The second element is the dot product of the second 4
+  /// elements of %a and %b added to the second element of %acc and so on.
+  ///
+  /// \p MatMul must be multiply(transpose(\p LoadOp0), \p LoadOp1, ...) and
+  /// \p Store must store its result. On success the matched instructions are
+  /// replaced by a tiled loop nest, recorded in \p FusedInsts and erased, and
+  /// true is returned. If the pattern cannot be handled, false is returned
+  /// before any IR is modified.
+  bool tryToEmitDotTiling(CallInst *MatMul, LoadInst *LoadOp0,
+                          LoadInst *LoadOp1, StoreInst *Store,
+                          SmallPtrSetImpl<Instruction *> &FusedInsts) {
+    // udot is an AArch64 intrinsic; other targets cannot select it.
+    // FIXME: udot also requires the dotprod extension, which is not checked.
+    if (!Triple(MatMul->getModule()->getTargetTriple()).isAArch64())
+      return false;
+
+    auto *EltType = cast<VectorType>(LoadOp0->getType())->getElementType();
+    // LShape is the shape of transpose(A) (M x K), RShape that of B (K x N).
+    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+    // A itself is loaded untransposed: its shape is K x M and its column
+    // stride is K. LShape's stride (M) is only correct for square A.
+    ShapeInfo AShape(LShape.NumColumns, LShape.NumRows);
+
+    // Only process multiplies with i8 matrixes with shapes that fit the tile
+    // size perfectly.
+    if (!EltType->isIntegerTy(8) || LShape.NumRows % 4 != 0 ||
+        LShape.NumColumns % 4 != 0 || RShape.NumColumns % 4 != 0)
+      return false;
+
+    // MatMul and the transpose are erased below, so their only users must be
+    // Store and MatMul respectively.
+    auto *Trans = cast<Instruction>(MatMul->getArgOperand(0));
+    if (!MatMul->hasOneUse() || !Trans->hasOneUse())
+      return false;
+
+    // The result tiles are stored at the position of MatMul, so the store
+    // address must already be available there.
+    auto *AddrI = dyn_cast<Instruction>(Store->getPointerOperand());
+    if (AddrI && !DT.dominates(AddrI, MatMul))
+      return false;
+
+    // FIXME: Result tiles are stored while tiles of A and B are still being
+    // loaded, but nothing here checks that the store does not alias either
+    // operand.
+
+    // Create the main tiling loop nest in place of MatMul.
+    TileInfo TI(LShape.NumRows, LShape.NumColumns, RShape.NumColumns, 4);
+    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+    BasicBlock *Start = MatMul->getParent();
+    BasicBlock *End =
+        SplitBlock(Start, MatMul, &DT, nullptr, nullptr, "continue");
+    IRBuilder<> Builder(MatMul);
+    BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, Builder, DTU);
+
+    Type *DotResultColTy = FixedVectorType::get(Builder.getInt32Ty(), 4);
+    Type *DotArgTy = FixedVectorType::get(Builder.getInt8Ty(), 16);
+
+    // One accumulator per result column of the current tile. Each starts at
+    // zero on entry to the inner (K) loop and carries the partial dot products
+    // across its iterations.
+    Builder.SetInsertPoint(TI.InnerLoopHeader->getTerminator());
+    BasicBlock *InnerLoopEntry = TI.RowLoopHeader->getSingleSuccessor();
+    SmallVector<PHINode *, 4> ColumnPhis;
+    for (unsigned I = 0; I < 4; ++I) {
+      PHINode *Phi = Builder.CreatePHI(DotResultColTy, 2);
+      Phi->addIncoming(ConstantAggregateZero::get(DotResultColTy),
+                       InnerLoopEntry);
+      ColumnPhis.push_back(Phi);
+    }
+
+    // The inner loop body accumulates transpose(A tile) * (B tile) for the
+    // tiles at (CurrentK, CurrentRow) of A and (CurrentK, CurrentCol) of B.
+    Builder.SetInsertPoint(InnerBody->getTerminator());
+    MatrixTy A = loadMatrix(LoadOp0->getPointerOperand(), AShape, TI.CurrentK,
+                            TI.CurrentRow, {4, 4}, EltType, Builder);
+    MatrixTy B = loadMatrix(LoadOp1->getPointerOperand(), RShape, TI.CurrentK,
+                            TI.CurrentCol, {4, 4}, EltType, Builder);
+    // Column J of the A tile is row J of the transposed tile, so the J-th
+    // group of 4 elements in FlatA feeds lane J of every udot below.
+    Value *FlatA = A.embedInVector(Builder);
+    SmallVector<Value *, 4> Results;
+    for (unsigned I = 0; I < 4; ++I) {
+      // Column I of the result tile: repeat column I of B in every group of 4
+      // elements so that each lane dots one row of transpose(A) with it.
+      Value *BCol = B.getColumn(I);
+      Value *SplatCol = concatenateVectors(Builder, {BCol, BCol, BCol, BCol});
+      Value *Dot = Builder.CreateIntrinsic(Intrinsic::aarch64_neon_udot,
+                                           {DotResultColTy, DotArgTy},
+                                           {ColumnPhis[I], FlatA, SplatCol});
+      ColumnPhis[I]->addIncoming(Dot, TI.InnerLoopLatch);
+      Results.push_back(Dot);
+    }
+
+    // After the inner loop, truncate the i32 accumulators back to i8 and store
+    // the tile. The low 8 bits of the i32 sums equal the sums computed with
+    // wrapping 8-bit arithmetic.
+    Builder.SetInsertPoint(TI.RowLoopLatch->getTerminator());
+    MatrixTy Result;
+    for (Value *Dot : Results)
+      Result.addVector(Builder.CreateTrunc(Dot, A.getColumnTy()));
+    storeMatrix(Result, Store->getPointerOperand(),
+                {LShape.NumRows, RShape.NumColumns}, TI.CurrentRow,
+                TI.CurrentCol, EltType, Builder);
+
+    // Mark eliminated instructions as fused and remove them, users before
+    // their operands.
+    FusedInsts.insert(Store);
+    FusedInsts.insert(MatMul);
+    FusedInsts.insert(Trans);
+    Store->eraseFromParent();
+    MatMul->eraseFromParent();
+    Trans->eraseFromParent();
+    if (LoadOp0->hasNUses(0)) {
+      FusedInsts.insert(LoadOp0);
+      LoadOp0->eraseFromParent();
+    }
+    if (LoadOp1->hasNUses(0)) {
+      FusedInsts.insert(LoadOp1);
+      LoadOp1->eraseFromParent();
+    }
+
+    return true;
+  }
+
   void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
                       StoreInst *Store,
                       SmallPtrSetImpl<Instruction *> &FusedInsts) {
@@ -1249,6 +1385,23 @@
         MatrixLayout != MatrixLayoutTy::ColumnMajor)
       return;
 
+    // Prefer the AArch64 udot lowering for
+    //   store(multiply(transpose(load A), load B))
+    // and fall back to the generic fused lowering below if it does not apply.
+    Value *MatrixA;
+    Value *MatrixB;
+    if (match(MatMul,
+              m_Intrinsic<Intrinsic::matrix_multiply>(
+                  m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(MatrixA)),
+                  m_Value(MatrixB), m_Value(), m_Value(), m_Value()))) {
+      auto *LoadOp0 = dyn_cast<LoadInst>(MatrixA);
+      auto *LoadOp1 = dyn_cast<LoadInst>(MatrixB);
+      auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
+      if (LoadOp0 && LoadOp1 && Store &&
+          tryToEmitDotTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts))
+        return;
+    }
+
     auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0));
     auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1));
     auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
@@ -1877,7 +2030,8 @@
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
-    AU.addPreserved<LoopInfoWrapperPass>();
+    // LoopInfo is not preserved: tryToEmitDotTiling creates new loops without
+    // registering them with LoopInfo.
   }
 };
 } // namespace
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/aarch64-udot-4x4.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/aarch64-udot-4x4.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/aarch64-udot-4x4.ll
@@ -0,0 +1,130 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -S %s | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+define void @multiply(<16 x i8> * %A, <16 x i8> * %B, <16 x i8>* %C) {
+; CHECK-LABEL: @multiply(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[COLS_HEADER:%.*]]
+; CHECK:       cols.header:
+; CHECK-NEXT:    [[COLS_IV:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[COLS_STEP:%.*]], [[COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[COLS_BODY:%.*]]
+; CHECK:       cols.body:
+; CHECK-NEXT:    br label [[ROWS_HEADER:%.*]]
+; CHECK:       rows.header:
+; CHECK-NEXT:    [[ROWS_IV:%.*]] = phi i32 [ 0, [[COLS_BODY]] ], [ [[ROWS_STEP:%.*]], [[ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[ROWS_BODY:%.*]]
+; CHECK:       rows.body:
+; CHECK-NEXT:    br label [[INNER_HEADER:%.*]]
+; CHECK:       inner.header:
+; CHECK-NEXT:    [[INNER_IV:%.*]] = phi i32 [ 0, [[ROWS_BODY]] ], [ [[INNER_STEP:%.*]], [[INNER_LATCH:%.*]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = phi <4 x i32> [ zeroinitializer, [[ROWS_BODY]] ], [ [[TMP20:%.*]], [[INNER_LATCH]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = phi <4 x i32> [ zeroinitializer, [[ROWS_BODY]] ], [ [[TMP24:%.*]], [[INNER_LATCH]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = phi <4 x i32> [ zeroinitializer, [[ROWS_BODY]] ], [ [[TMP28:%.*]], [[INNER_LATCH]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = phi <4 x i32> [ zeroinitializer, [[ROWS_BODY]] ], [ [[TMP32:%.*]], [[INNER_LATCH]] ]
+; CHECK-NEXT:    br label [[INNER_BODY:%.*]]
+; CHECK:       inner.body:
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i32 [[ROWS_IV]], 4
+; CHECK-NEXT:    [[TMP5:%.*]] = add i32 [[TMP4]], [[INNER_IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <16 x i8>* [[A:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, i8* [[TMP6]], i32 [[TMP5]]
+; CHECK-NEXT:    [[COL_CAST:%.*]] = bitcast i8* [[TMP7]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <16 x i8>* [[COL_CAST]] to i8*
+; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i8* [[TMP8]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST]], align 1
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i8, i8* [[TMP8]], i32 4
+; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast i8* [[VEC_GEP]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST1]], align 1
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i8, i8* [[TMP8]], i32 8
+; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast i8* [[VEC_GEP3]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST4]], align 1
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr i8, i8* [[TMP8]], i32 12
+; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast i8* [[VEC_GEP6]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST7]], align 1
+; CHECK-NEXT:    [[TMP9:%.*]] = mul i32 [[COLS_IV]], 4
+; CHECK-NEXT:    [[TMP10:%.*]] = add i32 [[TMP9]], [[INNER_IV]]
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <16 x i8>* [[B:%.*]] to i8*
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr i8, i8* [[TMP11]], i32 [[TMP10]]
+; CHECK-NEXT:    [[COL_CAST9:%.*]] = bitcast i8* [[TMP12]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP13:%.*]] = bitcast <16 x i8>* [[COL_CAST9]] to i8*
+; CHECK-NEXT:    [[VEC_CAST10:%.*]] = bitcast i8* [[TMP13]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD11:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST10]], align 1
+; CHECK-NEXT:    [[VEC_GEP12:%.*]] = getelementptr i8, i8* [[TMP13]], i32 4
+; CHECK-NEXT:    [[VEC_CAST13:%.*]] = bitcast i8* [[VEC_GEP12]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD14:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST13]], align 1
+; CHECK-NEXT:    [[VEC_GEP15:%.*]] = getelementptr i8, i8* [[TMP13]], i32 8
+; CHECK-NEXT:    [[VEC_CAST16:%.*]] = bitcast i8* [[VEC_GEP15]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD17:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST16]], align 1
+; CHECK-NEXT:    [[VEC_GEP18:%.*]] = getelementptr i8, i8* [[TMP13]], i32 12
+; CHECK-NEXT:    [[VEC_CAST19:%.*]] = bitcast i8* [[VEC_GEP18]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD20:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST19]], align 1
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i8> [[COL_LOAD]], <4 x i8> [[COL_LOAD2]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <4 x i8> [[COL_LOAD5]], <4 x i8> [[COL_LOAD8]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <8 x i8> [[TMP14]], <8 x i8> [[TMP15]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP17:%.*]] = shufflevector <4 x i8> [[COL_LOAD11]], <4 x i8> [[COL_LOAD11]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <4 x i8> [[COL_LOAD11]], <4 x i8> [[COL_LOAD11]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP19:%.*]] = shufflevector <8 x i8> [[TMP17]], <8 x i8> [[TMP18]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP20]] = call <4 x i32> @llvm.aarch64.neon.udot.v4i32.v16i8(<4 x i32> [[TMP0]], <16 x i8> [[TMP16]], <16 x i8> [[TMP19]])
+; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <4 x i8> [[COL_LOAD14]], <4 x i8> [[COL_LOAD14]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <4 x i8> [[COL_LOAD14]], <4 x i8> [[COL_LOAD14]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP23:%.*]] = shufflevector <8 x i8> [[TMP21]], <8 x i8> [[TMP22]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP24]] = call <4 x i32> @llvm.aarch64.neon.udot.v4i32.v16i8(<4 x i32> [[TMP1]], <16 x i8> [[TMP16]], <16 x i8> [[TMP23]])
+; CHECK-NEXT:    [[TMP25:%.*]] = shufflevector <4 x i8> [[COL_LOAD17]], <4 x i8> [[COL_LOAD17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP26:%.*]] = shufflevector <4 x i8> [[COL_LOAD17]], <4 x i8> [[COL_LOAD17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP27:%.*]] = shufflevector <8 x i8> [[TMP25]], <8 x i8> [[TMP26]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP28]] = call <4 x i32> @llvm.aarch64.neon.udot.v4i32.v16i8(<4 x i32> [[TMP2]], <16 x i8> [[TMP16]], <16 x i8> [[TMP27]])
+; CHECK-NEXT:    [[TMP29:%.*]] = shufflevector <4 x i8> [[COL_LOAD20]], <4 x i8> [[COL_LOAD20]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP30:%.*]] = shufflevector <4 x i8> [[COL_LOAD20]], <4 x i8> [[COL_LOAD20]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP31:%.*]] = shufflevector <8 x i8> [[TMP29]], <8 x i8> [[TMP30]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP32]] = call <4 x i32> @llvm.aarch64.neon.udot.v4i32.v16i8(<4 x i32> [[TMP3]], <16 x i8> [[TMP16]], <16 x i8> [[TMP31]])
+; CHECK-NEXT:    br label [[INNER_LATCH]]
+; CHECK:       inner.latch:
+; CHECK-NEXT:    [[INNER_STEP]] = add i32 [[INNER_IV]], 4
+; CHECK-NEXT:    [[INNER_COND:%.*]] = icmp ne i32 [[INNER_STEP]], 4
+; CHECK-NEXT:    br i1 [[INNER_COND]], label [[INNER_HEADER]], label [[ROWS_LATCH]]
+; CHECK:       rows.latch:
+; CHECK-NEXT:    [[ROWS_STEP]] = add i32 [[ROWS_IV]], 4
+; CHECK-NEXT:    [[ROWS_COND:%.*]] = icmp ne i32 [[ROWS_STEP]], 4
+; CHECK-NEXT:    [[TMP33:%.*]] = trunc <4 x i32> [[TMP20]] to <4 x i8>
+; CHECK-NEXT:    [[TMP34:%.*]] = trunc <4 x i32> [[TMP24]] to <4 x i8>
+; CHECK-NEXT:    [[TMP35:%.*]] = trunc <4 x i32> [[TMP28]] to <4 x i8>
+; CHECK-NEXT:    [[TMP36:%.*]] = trunc <4 x i32> [[TMP32]] to <4 x i8>
+; CHECK-NEXT:    [[TMP37:%.*]] = mul i32 [[COLS_IV]], 4
+; CHECK-NEXT:    [[TMP38:%.*]] = add i32 [[TMP37]], [[ROWS_IV]]
+; CHECK-NEXT:    [[TMP39:%.*]] = bitcast <16 x i8>* [[C:%.*]] to i8*
+; CHECK-NEXT:    [[TMP40:%.*]] = getelementptr i8, i8* [[TMP39]], i32 [[TMP38]]
+; CHECK-NEXT:    [[COL_CAST21:%.*]] = bitcast i8* [[TMP40]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP41:%.*]] = bitcast <16 x i8>* [[COL_CAST21]] to i8*
+; CHECK-NEXT:    [[VEC_CAST22:%.*]] = bitcast i8* [[TMP41]] to <4 x i8>*
+; CHECK-NEXT:    store <4 x i8> [[TMP33]], <4 x i8>* [[VEC_CAST22]], align 1
+; CHECK-NEXT:    [[VEC_GEP23:%.*]] = getelementptr i8, i8* [[TMP41]], i32 4
+; CHECK-NEXT:    [[VEC_CAST24:%.*]] = bitcast i8* [[VEC_GEP23]] to <4 x i8>*
+; CHECK-NEXT:    store <4 x i8> [[TMP34]], <4 x i8>* [[VEC_CAST24]], align 1
+; CHECK-NEXT:    [[VEC_GEP25:%.*]] = getelementptr i8, i8* [[TMP41]], i32 8
+; CHECK-NEXT:    [[VEC_CAST26:%.*]] = bitcast i8* [[VEC_GEP25]] to <4 x i8>*
+; CHECK-NEXT:    store <4 x i8> [[TMP35]], <4 x i8>* [[VEC_CAST26]], align 1
+; CHECK-NEXT:    [[VEC_GEP27:%.*]] = getelementptr i8, i8* [[TMP41]], i32 12
+; CHECK-NEXT:    [[VEC_CAST28:%.*]] = bitcast i8* [[VEC_GEP27]] to <4 x i8>*
+; CHECK-NEXT:    store <4 x i8> [[TMP36]], <4 x i8>* [[VEC_CAST28]], align 1
+; CHECK-NEXT:    br i1 [[ROWS_COND]], label [[ROWS_HEADER]], label [[COLS_LATCH]]
+; CHECK:       cols.latch:
+; CHECK-NEXT:    [[COLS_STEP]] = add i32 [[COLS_IV]], 4
+; CHECK-NEXT:    [[COLS_COND:%.*]] = icmp ne i32 [[COLS_STEP]], 4
+; CHECK-NEXT:    br i1 [[COLS_COND]], label [[COLS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <16 x i8>, <16 x i8>* %A, align 16
+  %b = load <16 x i8>, <16 x i8>* %B, align 16
+
+  %a.trans = call <16 x i8> @llvm.matrix.transpose.v16i8(<16 x i8> %a, i32 4, i32 4)
+  %c = call <16 x i8> @llvm.matrix.multiply.v16i8.v16i8.v16i8(<16 x i8> %a.trans, <16 x i8> %b, i32 4, i32 4, i32 4)
+  store <16 x i8> %c, <16 x i8>* %C, align 16
+  ret void
+}
+
+declare <16 x i8> @llvm.matrix.multiply.v16i8.v16i8.v16i8(<16 x i8>, <16 x i8>, i32, i32, i32)
+declare <16 x i8> @llvm.matrix.transpose.v16i8(<16 x i8>, i32, i32)
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/aarch64-udot-8x8.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/aarch64-udot-8x8.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/aarch64-udot-8x8.ll
@@ -0,0 +1,130 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -S %s | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "aarch64-apple-ios"
+
+define void @multiply(<64 x i8> * %A, <64 x i8> * %B, <64 x i8>* %C) {
+; CHECK-LABEL: @multiply(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[COLS_HEADER:%.*]]
+; CHECK:       cols.header:
+; CHECK-NEXT:    [[COLS_IV:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[COLS_STEP:%.*]], [[COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[COLS_BODY:%.*]]
+; CHECK:       cols.body:
+; CHECK-NEXT:    br label [[ROWS_HEADER:%.*]]
+; CHECK:       rows.header:
+; CHECK-NEXT:    [[ROWS_IV:%.*]] = phi i32 [ 0, [[COLS_BODY]] ], [ [[ROWS_STEP:%.*]], [[ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[ROWS_BODY:%.*]]
+; CHECK:       rows.body:
+; CHECK-NEXT:    br label [[INNER_HEADER:%.*]]
+; CHECK:       inner.header:
+; CHECK-NEXT:    [[INNER_IV:%.*]] = phi i32 [ 0, [[ROWS_BODY]] ], [ [[INNER_STEP:%.*]], [[INNER_LATCH:%.*]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = phi <4 x i32> [ zeroinitializer, [[ROWS_BODY]] ], [ [[TMP20:%.*]], [[INNER_LATCH]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = phi <4 x i32> [ zeroinitializer, [[ROWS_BODY]] ], [ [[TMP24:%.*]], [[INNER_LATCH]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = phi <4 x i32> [ zeroinitializer, [[ROWS_BODY]] ], [ [[TMP28:%.*]], [[INNER_LATCH]] ]
+; CHECK-NEXT:    [[TMP3:%.*]] = phi <4 x i32> [ zeroinitializer, [[ROWS_BODY]] ], [ [[TMP32:%.*]], [[INNER_LATCH]] ]
+; CHECK-NEXT:    br label [[INNER_BODY:%.*]]
+; CHECK:       inner.body:
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i32 [[ROWS_IV]], 8
+; CHECK-NEXT:    [[TMP5:%.*]] = add i32 [[TMP4]], [[INNER_IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <64 x i8>* [[A:%.*]] to i8*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, i8* [[TMP6]], i32 [[TMP5]]
+; CHECK-NEXT:    [[COL_CAST:%.*]] = bitcast i8* [[TMP7]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <16 x i8>* [[COL_CAST]] to i8*
+; CHECK-NEXT:    [[VEC_CAST:%.*]] = bitcast i8* [[TMP8]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST]], align 1
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i8, i8* [[TMP8]], i32 8
+; CHECK-NEXT:    [[VEC_CAST1:%.*]] = bitcast i8* [[VEC_GEP]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST1]], align 1
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i8, i8* [[TMP8]], i32 16
+; CHECK-NEXT:    [[VEC_CAST4:%.*]] = bitcast i8* [[VEC_GEP3]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST4]], align 1
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr i8, i8* [[TMP8]], i32 24
+; CHECK-NEXT:    [[VEC_CAST7:%.*]] = bitcast i8* [[VEC_GEP6]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST7]], align 1
+; CHECK-NEXT:    [[TMP9:%.*]] = mul i32 [[COLS_IV]], 8
+; CHECK-NEXT:    [[TMP10:%.*]] = add i32 [[TMP9]], [[INNER_IV]]
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <64 x i8>* [[B:%.*]] to i8*
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr i8, i8* [[TMP11]], i32 [[TMP10]]
+; CHECK-NEXT:    [[COL_CAST9:%.*]] = bitcast i8* [[TMP12]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP13:%.*]] = bitcast <16 x i8>* [[COL_CAST9]] to i8*
+; CHECK-NEXT:    [[VEC_CAST10:%.*]] = bitcast i8* [[TMP13]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD11:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST10]], align 1
+; CHECK-NEXT:    [[VEC_GEP12:%.*]] = getelementptr i8, i8* [[TMP13]], i32 8
+; CHECK-NEXT:    [[VEC_CAST13:%.*]] = bitcast i8* [[VEC_GEP12]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD14:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST13]], align 1
+; CHECK-NEXT:    [[VEC_GEP15:%.*]] = getelementptr i8, i8* [[TMP13]], i32 16
+; CHECK-NEXT:    [[VEC_CAST16:%.*]] = bitcast i8* [[VEC_GEP15]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD17:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST16]], align 1
+; CHECK-NEXT:    [[VEC_GEP18:%.*]] = getelementptr i8, i8* [[TMP13]], i32 24
+; CHECK-NEXT:    [[VEC_CAST19:%.*]] = bitcast i8* [[VEC_GEP18]] to <4 x i8>*
+; CHECK-NEXT:    [[COL_LOAD20:%.*]] = load <4 x i8>, <4 x i8>* [[VEC_CAST19]], align 1
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i8> [[COL_LOAD]], <4 x i8> [[COL_LOAD2]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <4 x i8> [[COL_LOAD5]], <4 x i8> [[COL_LOAD8]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <8 x i8> [[TMP14]], <8 x i8> [[TMP15]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP17:%.*]] = shufflevector <4 x i8> [[COL_LOAD11]], <4 x i8> [[COL_LOAD11]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <4 x i8> [[COL_LOAD11]], <4 x i8> [[COL_LOAD11]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP19:%.*]] = shufflevector <8 x i8> [[TMP17]], <8 x i8> [[TMP18]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP20]] = call <4 x i32> @llvm.aarch64.neon.udot.v4i32.v16i8(<4 x i32> [[TMP0]], <16 x i8> [[TMP16]], <16 x i8> [[TMP19]])
+; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <4 x i8> [[COL_LOAD14]], <4 x i8> [[COL_LOAD14]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <4 x i8> [[COL_LOAD14]], <4 x i8> [[COL_LOAD14]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP23:%.*]] = shufflevector <8 x i8> [[TMP21]], <8 x i8> [[TMP22]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP24]] = call <4 x i32> @llvm.aarch64.neon.udot.v4i32.v16i8(<4 x i32> [[TMP1]], <16 x i8> [[TMP16]], <16 x i8> [[TMP23]])
+; CHECK-NEXT:    [[TMP25:%.*]] = shufflevector <4 x i8> [[COL_LOAD17]], <4 x i8> [[COL_LOAD17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP26:%.*]] = shufflevector <4 x i8> [[COL_LOAD17]], <4 x i8> [[COL_LOAD17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP27:%.*]] = shufflevector <8 x i8> [[TMP25]], <8 x i8> [[TMP26]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP28]] = call <4 x i32> @llvm.aarch64.neon.udot.v4i32.v16i8(<4 x i32> [[TMP2]], <16 x i8> [[TMP16]], <16 x i8> [[TMP27]])
+; CHECK-NEXT:    [[TMP29:%.*]] = shufflevector <4 x i8> [[COL_LOAD20]], <4 x i8> [[COL_LOAD20]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP30:%.*]] = shufflevector <4 x i8> [[COL_LOAD20]], <4 x i8> [[COL_LOAD20]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP31:%.*]] = shufflevector <8 x i8> [[TMP29]], <8 x i8> [[TMP30]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP32]] = call <4 x i32> @llvm.aarch64.neon.udot.v4i32.v16i8(<4 x i32> [[TMP3]], <16 x i8> [[TMP16]], <16 x i8> [[TMP31]])
+; CHECK-NEXT:    br label [[INNER_LATCH]]
+; CHECK:       inner.latch:
+; CHECK-NEXT:    [[INNER_STEP]] = add i32 [[INNER_IV]], 4
+; CHECK-NEXT:    [[INNER_COND:%.*]] = icmp ne i32 [[INNER_STEP]], 8
+; CHECK-NEXT:    br i1 [[INNER_COND]], label [[INNER_HEADER]], label [[ROWS_LATCH]]
+; CHECK:       rows.latch:
+; CHECK-NEXT:    [[ROWS_STEP]] = add i32 [[ROWS_IV]], 4
+; CHECK-NEXT:    [[ROWS_COND:%.*]] = icmp ne i32 [[ROWS_STEP]], 8
+; CHECK-NEXT:    [[TMP33:%.*]] = trunc <4 x i32> [[TMP20]] to <4 x i8>
+; CHECK-NEXT:    [[TMP34:%.*]] = trunc <4 x i32> [[TMP24]] to <4 x i8>
+; CHECK-NEXT:    [[TMP35:%.*]] = trunc <4 x i32> [[TMP28]] to <4 x i8>
+; CHECK-NEXT:    [[TMP36:%.*]] = trunc <4 x i32> [[TMP32]] to <4 x i8>
+; CHECK-NEXT:    [[TMP37:%.*]] = mul i32 [[COLS_IV]], 8
+; CHECK-NEXT:    [[TMP38:%.*]] = add i32 [[TMP37]], [[ROWS_IV]]
+; CHECK-NEXT:    [[TMP39:%.*]] = bitcast <64 x i8>* [[C:%.*]] to i8*
+; CHECK-NEXT:    [[TMP40:%.*]] = getelementptr i8, i8* [[TMP39]], i32 [[TMP38]]
+; CHECK-NEXT:    [[COL_CAST21:%.*]] = bitcast i8* [[TMP40]] to <16 x i8>*
+; CHECK-NEXT:    [[TMP41:%.*]] = bitcast <16 x i8>* [[COL_CAST21]] to i8*
+; CHECK-NEXT:    [[VEC_CAST22:%.*]] = bitcast i8* [[TMP41]] to <4 x i8>*
+; CHECK-NEXT:    store <4 x i8> [[TMP33]], <4 x i8>* [[VEC_CAST22]], align 1
+; CHECK-NEXT:    [[VEC_GEP23:%.*]] = getelementptr i8, i8* [[TMP41]], i32 8
+; CHECK-NEXT:    [[VEC_CAST24:%.*]] = bitcast i8* [[VEC_GEP23]] to <4 x i8>*
+; CHECK-NEXT:    store <4 x i8> [[TMP34]], <4 x i8>* [[VEC_CAST24]], align 1
+; CHECK-NEXT:    [[VEC_GEP25:%.*]] = getelementptr i8, i8* [[TMP41]], i32 16
+; CHECK-NEXT:    [[VEC_CAST26:%.*]] = bitcast i8* [[VEC_GEP25]] to <4 x i8>*
+; CHECK-NEXT:    store <4 x i8> [[TMP35]], <4 x i8>* [[VEC_CAST26]], align 1
+; CHECK-NEXT:    [[VEC_GEP27:%.*]] = getelementptr i8, i8* [[TMP41]], i32 24
+; CHECK-NEXT:    [[VEC_CAST28:%.*]] = bitcast i8* [[VEC_GEP27]] to <4 x i8>*
+; CHECK-NEXT:    store <4 x i8> [[TMP36]], <4 x i8>* [[VEC_CAST28]], align 1
+; CHECK-NEXT:    br i1 [[ROWS_COND]], label [[ROWS_HEADER]], label [[COLS_LATCH]]
+; CHECK:       cols.latch:
+; CHECK-NEXT:    [[COLS_STEP]] = add i32 [[COLS_IV]], 4
+; CHECK-NEXT:    [[COLS_COND:%.*]] = icmp ne i32 [[COLS_STEP]], 8
+; CHECK-NEXT:    br i1 [[COLS_COND]], label [[COLS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <64 x i8>, <64 x i8>* %A, align 64
+  %b = load <64 x i8>, <64 x i8>* %B, align 64
+
+  %a.trans = call <64 x i8> @llvm.matrix.transpose.v64i8(<64 x i8> %a, i32 8, i32 8)
+  %c = call <64 x i8> @llvm.matrix.multiply.v64i8.v64i8.v64i8(<64 x i8> %a.trans, <64 x i8> %b, i32 8, i32 8, i32 8)
+  store <64 x i8> %c, <64 x i8>* %C, align 64
+  ret void
+}
+
+declare <64 x i8> @llvm.matrix.multiply.v64i8.v64i8.v64i8(<64 x i8>, <64 x i8>, i32, i32, i32)
+declare <64 x i8> @llvm.matrix.transpose.v64i8(<64 x i8>, i32, i32)