diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -14401,6 +14401,118 @@ """""""""" The argument to this intrinsic must be a vector of floating-point values. +Matrix Intrinsics +----------------- + +Operations on matrices requiring shape information (like the number of rows/columns +or the memory layout) can be expressed using the matrix intrinsics. Matrices are +embedded in a flat vector and the intrinsics take the dimensions as arguments. +Currently, column-major layout is assumed. The intrinsics support both integer +and floating-point matrices. + + +'``llvm.matrix.transpose.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +:: + + declare vectorty @llvm.matrix.transpose.*(vectorty %In, i32 <Rows>, i32 <Cols>) + +Overview: +""""""""" + +The '``llvm.matrix.transpose.*``' intrinsic treats %In as containing a matrix +with <Rows> rows and <Cols> columns and returns the transposed matrix (with <Cols> rows and <Rows> columns) embedded in +the result vector. + +Arguments: +"""""""""" + +The <Rows> and <Cols> arguments must be constant integers. The vector argument +%In and the returned vector must have <Rows> * <Cols> elements. + +'``llvm.matrix.multiply.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +:: + + declare vectorty @llvm.matrix.multiply.*(vectorty %A, vectorty %B, i32 <OuterRows>, i32 <Inner>, i32 <OuterColumns>) + +Overview: +""""""""" + +The '``llvm.matrix.multiply.*``' intrinsic treats %A as a matrix with <OuterRows> rows and <Inner> columns, %B as +a matrix with <Inner> rows and <OuterColumns> columns, and multiplies them. The result matrix is returned embedded in the +result vector. + +Arguments: +"""""""""" + +The <OuterRows>, <Inner> and <OuterColumns> arguments must be constant integers. The vector argument %A +must have <OuterRows> * <Inner> elements, %B must have <Inner> * <OuterColumns> elements and the returned +vector must have <OuterRows> * <OuterColumns> elements.
+ + '``llvm.matrix.columnwise.load.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +:: + + declare vectorty @llvm.matrix.columnwise.load.*(ptrty %Ptr, i32 %Stride, i32 <Rows>, i32 <Cols>) + +Overview: +""""""""" + +The '``llvm.matrix.columnwise.load.*``' intrinsic loads a matrix with <Rows> +rows and <Cols> columns, using a stride of %Stride between columns. For two +consecutive columns A and B, %Stride refers to the distance (in elements) between the end of +column A and the beginning of column B. Given the start address of column A, +the start address of column B is computed as A + <Rows> + %Stride. +The result matrix is returned embedded in the result vector. This allows for +convenient loading of submatrices. + + +Arguments: +"""""""""" + +The <Rows> and <Cols> arguments must be constant integers. The returned vector +must have <Rows> * <Cols> elements. + +'``llvm.matrix.columnwise.store.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +:: + + declare void @llvm.matrix.columnwise.store.*(vectorty %In, ptrty %Ptr, i32 %Stride, i32 <Rows>, i32 <Cols>) + +Overview: +""""""""" + +The '``llvm.matrix.columnwise.store.*``' intrinsic stores the matrix with +<Rows> rows and <Cols> columns embedded in %In, using a stride of %Stride +between columns. For two consecutive columns A and B, %Stride refers to the +distance (in elements) between the end of column A and the beginning of column B. Given the +start address of column A, the start address of column B is computed as +A + <Rows> + %Stride. + +Arguments: +"""""""""" + +The <Rows> and <Cols> arguments must be constant integers. The vector argument +%In must have <Rows> * <Cols> elements.
+
 Half Precision Floating-Point Intrinsics
 ----------------------------------------
 
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1232,6 +1232,44 @@
                                                     [llvm_anyvector_ty]>;
 }
 
+//===----- Matrix intrinsics ---------------------------------------------===//
+
+def int_matrix_transpose : Intrinsic<[llvm_anyvector_ty],
+                                     [LLVMMatchType<0>,
+                                      llvm_i32_ty,
+                                      llvm_i32_ty],
+                                     [IntrNoMem, IntrSpeculatable,
+                                      IntrWillReturn, ImmArg<1>, ImmArg<2>]>;
+
+def int_matrix_multiply : Intrinsic<[llvm_anyvector_ty],
+                                    [llvm_anyvector_ty,
+                                     llvm_anyvector_ty,
+                                     llvm_i32_ty,
+                                     llvm_i32_ty,
+                                     llvm_i32_ty],
+                                    [IntrNoMem, IntrSpeculatable,
+                                     IntrWillReturn, ImmArg<2>, ImmArg<3>,
+                                     ImmArg<4>]>;
+
+def int_matrix_columnwise_load : Intrinsic<[llvm_anyvector_ty],
+                                           [LLVMAnyPointerType<LLVMMatchType<0>>,
+                                            llvm_i32_ty,
+                                            llvm_i32_ty,
+                                            llvm_i32_ty],
+                                           [IntrArgMemOnly, IntrReadMem,
+                                            IntrWillReturn, NoCapture<0>,
+                                            ImmArg<2>, ImmArg<3>]>;
+
+def int_matrix_columnwise_store : Intrinsic<[],
+                                            [llvm_anyvector_ty,
+                                             LLVMAnyPointerType<LLVMMatchType<0>>,
+                                             llvm_i32_ty,
+                                             llvm_i32_ty,
+                                             llvm_i32_ty],
+                                            [IntrArgMemOnly, WriteOnly<1>,
+                                             IntrWillReturn, NoCapture<1>,
+                                             ImmArg<3>, ImmArg<4>]>;
+
 //===---------- Intrinsics to control hardware supported loops ----------===//
 
 // Specify that the value given is the number of iterations that the next loop
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -255,6 +255,7 @@
 void initializeLowerInvokeLegacyPassPass(PassRegistry&);
+void initializeLowerMatrixIntrinsicsLegacyPassPass(PassRegistry &);
 void initializeLowerSwitchPass(PassRegistry&);
 void initializeLowerTypeTestsPass(PassRegistry&);
 void initializeMIRCanonicalizerPass(PassRegistry &);
 void initializeMIRNamerPass(PassRegistry &);
 void initializeMIRPrintingPassPass(PassRegistry&);
diff --git a/llvm/include/llvm/Transforms/Scalar.h
b/llvm/include/llvm/Transforms/Scalar.h
--- a/llvm/include/llvm/Transforms/Scalar.h
+++ b/llvm/include/llvm/Transforms/Scalar.h
@@ -359,6 +359,12 @@
 //
 Pass *createLowerGuardIntrinsicPass();
 
+//===----------------------------------------------------------------------===//
+//
+// LowerMatrixIntrinsics - Lower matrix intrinsics to vector operations.
+//
+Pass *createLowerMatrixIntrinsicsPass();
+
 //===----------------------------------------------------------------------===//
 //
 // LowerWidenableCondition - Lower widenable condition to i1 true.
diff --git a/llvm/include/llvm/Transforms/Scalar/LowerMatrixIntrinsics.h b/llvm/include/llvm/Transforms/Scalar/LowerMatrixIntrinsics.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Scalar/LowerMatrixIntrinsics.h
@@ -0,0 +1,25 @@
+//===- LowerMatrixIntrinsics.h - Lower matrix intrinsics. -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers matrix intrinsics down to vector operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SCALAR_LOWERMATRIXINTRINSICS_H
+#define LLVM_TRANSFORMS_SCALAR_LOWERMATRIXINTRINSICS_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+/// New-pass-manager function pass lowering llvm.matrix.* intrinsics.
+struct LowerMatrixIntrinsicsPass : PassInfoMixin<LowerMatrixIntrinsicsPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+} // namespace llvm
+
+#endif // LLVM_TRANSFORMS_SCALAR_LOWERMATRIXINTRINSICS_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -146,6 +146,7 @@
 #include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h"
 #include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h"
 #include "llvm/Transforms/Scalar/LowerGuardIntrinsic.h"
+#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
 #include "llvm/Transforms/Scalar/LowerWidenableCondition.h"
 #include "llvm/Transforms/Scalar/MakeGuardsExplicit.h"
 #include "llvm/Transforms/Scalar/MemCpyOptimizer.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -189,6 +189,7 @@
 FUNCTION_PASS("lower-expect", LowerExpectIntrinsicPass())
 FUNCTION_PASS("lower-guard-intrinsic", LowerGuardIntrinsicPass())
 FUNCTION_PASS("lower-constant-intrinsics", LowerConstantIntrinsicsPass())
+FUNCTION_PASS("lower-matrix-intrinsics", LowerMatrixIntrinsicsPass())
 FUNCTION_PASS("lower-widenable-condition", LowerWidenableConditionPass())
 FUNCTION_PASS("guard-widening", GuardWideningPass())
 FUNCTION_PASS("gvn", GVN())
diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
--- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
+++ b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -147,6 +147,10 @@
     "enable-order-file-instrumentation", cl::init(false), cl::Hidden,
     cl::desc("Enable order file instrumentation (default = off)"));
 
+static cl::opt<bool>
+    EnableMatrix("enable-matrix", cl::init(false), cl::Hidden,
+                 cl::desc("Enable lowering of the matrix intrinsics"));
+
 PassManagerBuilder::PassManagerBuilder() {
     OptLevel = 2;
     SizeLevel = 0;
@@ -656,6 +660,14 @@
   MPM.add(createFloat2IntPass());
   MPM.add(createLowerConstantIntrinsicsPass());
 
+  if (EnableMatrix) {
+    MPM.add(createLowerMatrixIntrinsicsPass());
+    // CSE the pointer arithmetic of the column vectors. This allows alias
+    // analysis to establish no-aliasing between loads and stores of different
+    // columns of the same matrix.
+    MPM.add(createEarlyCSEPass(false));
+  }
+
   addExtensionsToPM(EP_VectorizerStart, MPM);
 
   // Re-rotate loops in all our loop nests. These may have fallout out of
diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt
--- a/llvm/lib/Transforms/Scalar/CMakeLists.txt
+++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt
@@ -47,6 +47,7 @@
   LowerConstantIntrinsics.cpp
   LowerExpectIntrinsic.cpp
   LowerGuardIntrinsic.cpp
+  LowerMatrixIntrinsics.cpp
   LowerWidenableCondition.cpp
   MakeGuardsExplicit.cpp
   MemCpyOptimizer.cpp
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -0,0 +1,483 @@
+//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lower matrix intrinsics to vector operations.
+//
+// TODO:
+//  * Implement multiply & add fusion
+//  * Implement shape propagation
+//  * Implement optimizations to reduce shufflevector uses by using shape
+//    information.
+//  * Add remark, summarizing the available matrix optimization opportunities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "lower-matrix-intrinsics"
+
+namespace {
+// Given a \p MatrixPtr for the in-memory representation of a matrix,
+// compute the address of the element at index \p Row, \p Col. \p Offset is
+// the number of elements to skip to move to the same row in the next column
+// (the number of rows, unless a submatrix of a larger matrix is accessed).
+Value *computeEltAddr(Value *MatrixPtr, Value *Row, Value *Col, Type *EltType,
+                      Value *Offset, IRBuilder<> &Builder) {
+  Type *EltPtrType = PointerType::get(
+      EltType, cast<PointerType>(MatrixPtr->getType())->getAddressSpace());
+  Value *Base = Builder.CreatePointerCast(MatrixPtr, EltPtrType);
+
+  unsigned RowBitWidth = cast<IntegerType>(Row->getType())->getBitWidth();
+  unsigned ColBitWidth = cast<IntegerType>(Col->getType())->getBitWidth();
+  unsigned OffsetBitWidth = cast<IntegerType>(Offset->getType())->getBitWidth();
+
+  unsigned WidestBitWidth =
+      std::max(std::max(RowBitWidth, ColBitWidth), OffsetBitWidth);
+
+  // Deliberately not named "IntegerType": that would shadow llvm::IntegerType.
+  Type *IntTy = IntegerType::get(Builder.getContext(), WidestBitWidth);
+
+  // If they are not the same width, extend all to be the same width.
+  if (RowBitWidth != WidestBitWidth || ColBitWidth != WidestBitWidth ||
+      OffsetBitWidth != WidestBitWidth) {
+    Row = Builder.CreateZExt(Row, IntTy);
+    Col = Builder.CreateZExt(Col, IntTy);
+    Offset = Builder.CreateZExt(Offset, IntTy);
+  }
+
+  // Element index relative to Base: Row + Col * Offset.
+  Value *ColumnOffset = Builder.CreateMul(Col, Offset);
+  Value *EltIndex = Builder.CreateAdd(Row, ColumnOffset);
+
+  // No GEP is needed when the index folded to the constant 0.
+  if (auto *ConstIndex = dyn_cast<ConstantInt>(EltIndex))
+    if (ConstIndex->isZero())
+      return Base;
+  return Builder.CreateGEP(EltType, Base, EltIndex);
+}
+
+// Return the address of a column vector (\p NumRows elements of \p VType's
+// element type) at index (\p Row, \p Col) of \p Base, assuming \p Stride
+// elements between the end of a column and the start of the next one.
+Value *computeColumnAddr(Value *Base, unsigned Row, unsigned Col, Value *Stride,
+                         VectorType *VType, unsigned NumRows,
+                         IRBuilder<> &Builder) {
+  Type *EltType = VType->getElementType();
+
+  // Stride is the number of elements between the end of a column and the
+  // beginning of the next one. Add the number of rows to it, to get the offset
+  // between the start of 2 columns.
+  Value *Offset = Builder.CreateAdd(
+      Builder.getIntN(cast<IntegerType>(Stride->getType())->getBitWidth(),
+                      NumRows),
+      Stride);
+  Value *EltPtr =
+      computeEltAddr(Base, Builder.getInt32(Row), Builder.getInt32(Col),
+                     EltType, Offset, Builder);
+
+  Type *ColumnType = VectorType::get(EltType, NumRows);
+  Type *ColumnPtrType = PointerType::get(
+      ColumnType, cast<PointerType>(Base->getType())->getAddressSpace());
+  return Builder.CreatePointerCast(EltPtr, ColumnPtrType);
+}
+
+/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
+///
+/// Currently, the lowering for each matrix intrinsic is done as follows:
+/// 1. Split the operand vectors containing an embedded matrix into a set of
+///    column vectors, based on the shape information from the intrinsic.
+/// 2. Apply the transformation described by the intrinsic on the column
+///    vectors, which yields a set of column vectors containing result matrix.
+/// 3. Embed the columns of the result matrix in a flat vector and replace all
+///    uses of the intrinsic result with it.
+class LowerMatrixIntrinsics {
+  Function &Func;
+  const DataLayout &DL;
+  const TargetTransformInfo &TTI;
+
+  /// Wrapper class representing a matrix as a set of column vectors.
+  /// All column vectors must have the same vector type.
+  class ColumnMatrixTy {
+    SmallVector<Value *, 16> Columns;
+
+  public:
+    ColumnMatrixTy() : Columns() {}
+    ColumnMatrixTy(ArrayRef<Value *> Cols)
+        : Columns(Cols.begin(), Cols.end()) {}
+
+    Value *getColumn(unsigned i) const { return Columns[i]; }
+
+    void setColumn(unsigned i, Value *V) { Columns[i] = V; }
+
+    size_t getNumColumns() const { return Columns.size(); }
+
+    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; }
+
+    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; }
+
+    void addColumn(Value *V) { Columns.push_back(V); }
+
+    iterator_range<SmallVectorImpl<Value *>::iterator> columns() {
+      return make_range(Columns.begin(), Columns.end());
+    }
+
+    /// Embed the columns of the matrix into a flat vector by concatenating
+    /// them.
+    Value *embedInVector(IRBuilder<> &Builder) const {
+      return Columns.size() == 1 ? Columns[0]
+                                 : concatenateVectors(Builder, Columns);
+    }
+  };
+
+  /// The list of instructions that need to be removed after
+  /// lowering.
+  SmallVector<Instruction *, 16> ToRemove;
+
+  /// Number of rows and columns of a matrix, as given by the constant shape
+  /// arguments of the intrinsics.
+  struct ShapeInfo {
+    unsigned NumRows;
+    unsigned NumColumns;
+
+    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
+        : NumRows(NumRows), NumColumns(NumColumns) {}
+
+    ShapeInfo(ConstantInt *NumRows, ConstantInt *NumColumns)
+        : NumRows(NumRows->getZExtValue()),
+          NumColumns(NumColumns->getZExtValue()) {}
+  };
+
+public:
+  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
+      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}
+
+  /// Return the set of column vectors that a matrix value is lowered to.
+  ///
+  /// We split the flat vector \p MatrixVal containing a matrix with shape \p SI
+  /// into column vectors.
+  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
+                           IRBuilder<> &Builder) {
+    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
+    assert(VType && "MatrixVal must be a vector type");
+    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns &&
+           "The vector size must match the number of matrix elements");
+    SmallVector<Value *, 16> SplitVecs;
+    Value *Undef = UndefValue::get(VType);
+
+    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements();
+         MaskStart += SI.NumRows) {
+      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0);
+      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split");
+      SplitVecs.push_back(V);
+    }
+
+    return {SplitVecs};
+  }
+
+  /// Lower \p Inst if it is a call to one of the matrix intrinsics and queue
+  /// it for removal. Returns true if \p Inst was lowered.
+  bool VisitCallInst(CallInst *Inst) {
+    Function *Callee = Inst->getCalledFunction();
+    if (!Callee || !Callee->isIntrinsic())
+      return false;
+
+    switch (Callee->getIntrinsicID()) {
+    case Intrinsic::matrix_multiply:
+      LowerMatrixMultiply(Inst);
+      break;
+    case Intrinsic::matrix_transpose:
+      LowerMatrixTranspose(Inst);
+      break;
+    case Intrinsic::matrix_columnwise_load:
+      LowerMatrixLoad(Inst);
+      break;
+    case Intrinsic::matrix_columnwise_store:
+      LowerMatrixStore(Inst);
+      break;
+    default:
+      return false;
+    }
+    ToRemove.push_back(Inst);
+    return true;
+  }
+
+  /// Lower all matrix intrinsics in blocks reachable from the entry block
+  /// (visited in reverse post-order), then erase the lowered calls. Returns
+  /// true if anything was lowered.
+  bool Visit() {
+    ReversePostOrderTraversal<Function *> RPOT(&Func);
+    bool Changed = false;
+    for (auto *BB : RPOT) {
+      for (Instruction &Inst : *BB) {
+        if (CallInst *CInst = dyn_cast<CallInst>(&Inst))
+          Changed |= VisitCallInst(CInst);
+      }
+    }
+
+    // The lowering replaced all uses of the lowered calls, so they can be
+    // erased now. Clear the list so a repeated Visit() does not erase twice.
+    std::reverse(ToRemove.begin(), ToRemove.end());
+    for (auto *Inst : ToRemove)
+      Inst->eraseFromParent();
+    ToRemove.clear();
+    return Changed;
+  }
+
+  /// Emit a load of a column vector from \p ColumnPtr, aligned to the ABI
+  /// alignment of the element type \p EltType.
+  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
+                             IRBuilder<> &Builder) {
+    unsigned Align = DL.getABITypeAlignment(EltType);
+    return Builder.CreateAlignedLoad(ColumnPtr, Align);
+  }
+
+  /// Emit a store of \p ColumnValue to \p ColumnPtr, aligned to the ABI
+  /// alignment of the element type \p EltType.
+  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
+                               Type *EltType, IRBuilder<> &Builder) {
+    unsigned Align = DL.getABITypeAlignment(EltType);
+    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align);
+  }
+
+  /// Handles lowering the non-contiguous matrix load.
+  ///
+  /// The intrinsic loads a matrix from memory using a stride between columns.
+  void LowerMatrixLoad(CallInst *Inst) {
+    IRBuilder<> Builder(Inst);
+    Value *Ptr = Inst->getArgOperand(0);
+    Value *Stride = Inst->getArgOperand(1);
+    auto *VType = cast<VectorType>(Inst->getType());
+
+    ShapeInfo Shape(cast<ConstantInt>(Inst->getArgOperand(2)),
+                    cast<ConstantInt>(Inst->getArgOperand(3)));
+
+    // Load each column separately; consecutive columns start
+    // Shape.NumRows + Stride elements apart (see computeColumnAddr).
+    ColumnMatrixTy Result;
+    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
+      Value *GEP =
+          computeColumnAddr(Ptr, 0, C, Stride, VType, Shape.NumRows, Builder);
+      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
+      Result.addColumn(Column);
+    }
+
+    Inst->replaceAllUsesWith(Result.embedInVector(Builder));
+  }
+
+  /// Handles lowering the non-contiguous matrix store.
+  ///
+  /// The intrinsic stores a matrix back to memory using a stride between
+  /// columns.
+  void LowerMatrixStore(CallInst *Inst) {
+    IRBuilder<> Builder(Inst);
+    Value *Matrix = Inst->getArgOperand(0);
+    Value *Ptr = Inst->getArgOperand(1);
+    Value *Stride = Inst->getArgOperand(2);
+    ShapeInfo Shape(cast<ConstantInt>(Inst->getArgOperand(3)),
+                    cast<ConstantInt>(Inst->getArgOperand(4)));
+
+    auto *VType = cast<VectorType>(Matrix->getType());
+
+    auto LM = getMatrix(Matrix, Shape, Builder);
+
+    for (auto C : enumerate(LM.columns())) {
+      Value *GEP = computeColumnAddr(Ptr, 0, C.index(), Stride, VType,
+                                     Shape.NumRows, Builder);
+      createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
+    }
+  }
+
+  /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from
+  /// the matrix \p LM represented as a vector of column vectors.
+  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J,
+                       unsigned NumElts, IRBuilder<> &Builder) {
+    Value *Col = LM.getColumn(J);
+    Value *Undef = UndefValue::get(Col->getType());
+    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
+    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
+  }
+
+  /// Return \p Col with its elements I .. I + (number of elements of \p Block)
+  /// - 1 replaced by the elements of \p Block.
+  Value *insertVector(Value *Col, unsigned I, Value *Block,
+                      IRBuilder<> &Builder) {
+    // First, bring Block to the same size as Col by padding it with undef
+    // lanes.
+    unsigned BlockNumElts =
+        cast<VectorType>(Block->getType())->getNumElements();
+    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
+    assert(I + BlockNumElts <= NumElts && "Block does not fit into Col");
+
+    Value *ExtendMask =
+        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
+    Value *Undef = UndefValue::get(Block->getType());
+    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);
+
+    // Take lanes outside [I, I + BlockNumElts) from Col and the lanes inside
+    // it from the extended Block (mask indices >= NumElts select from the
+    // second operand). If Col is 7 long and I is 2 and BlockNumElts is 2 the
+    // mask is: 0, 1, 7, 8, 4, 5, 6
+    SmallVector<Constant *, 16> Mask;
+    unsigned i;
+    for (i = 0; i < I; i++)
+      Mask.push_back(Builder.getInt32(i));
+
+    for (; i < I + BlockNumElts; i++)
+      Mask.push_back(Builder.getInt32(i - I + NumElts));
+
+    for (; i < NumElts; i++)
+      Mask.push_back(Builder.getInt32(i));
+
+    Value *MaskVal = ConstantVector::get(Mask);
+
+    return Builder.CreateShuffleVector(Col, Block, MaskVal);
+  }
+
+  /// Return \p Sum + \p A * \p B, or just \p A * \p B when \p Sum is null,
+  /// using floating-point or integer arithmetic depending on \p UseFPOp.
+  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
+                      IRBuilder<> &Builder) {
+    Value *Mul = UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
+    if (!Sum)
+      return Mul;
+
+    return UseFPOp ? Builder.CreateFAdd(Sum, Mul) : Builder.CreateAdd(Sum, Mul);
+  }
+
+  /// Lower a matrix multiply. Each result column J is accumulated, in row
+  /// blocks of at most VF elements, as the sum over K of
+  /// (block of LHS column K) * splat(RHS element (K, J)).
+  void LowerMatrixMultiply(CallInst *MatMul) {
+    IRBuilder<> Builder(MatMul);
+    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+    ShapeInfo LShape(cast<ConstantInt>(MatMul->getArgOperand(2)),
+                     cast<ConstantInt>(MatMul->getArgOperand(3)));
+    ShapeInfo RShape(cast<ConstantInt>(MatMul->getArgOperand(3)),
+                     cast<ConstantInt>(MatMul->getArgOperand(4)));
+
+    const ColumnMatrixTy Lhs =
+        getMatrix(MatMul->getArgOperand(0), LShape, Builder);
+    const ColumnMatrixTy Rhs =
+        getMatrix(MatMul->getArgOperand(1), RShape, Builder);
+
+    const unsigned R = LShape.NumRows;
+    const unsigned M = LShape.NumColumns;
+    const unsigned C = RShape.NumColumns;
+    assert(M == RShape.NumRows && "Inner matrix dimensions must match");
+
+    // Initialize the output
+    ColumnMatrixTy Result;
+    for (unsigned J = 0; J < C; ++J)
+      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
+
+    // Number of elements that fit into a vector register, but at least 1.
+    const unsigned VF = std::max<uint64_t>(TTI.getRegisterBitWidth(true) /
+                                               EltType->getPrimitiveSizeInBits(),
+                                           1ULL);
+
+    // Multiply columns from the first operand with scalars from the second
+    // operand. Then move along the K axes and accumulate the columns.  With
+    // this the adds can be vectorized without reassociation.
+    for (unsigned J = 0; J < C; ++J) {
+      unsigned BlockSize = VF;
+      for (unsigned I = 0; I < R; I += BlockSize) {
+        // Gradually lower the vectorization factor to cover the remainder.
+        while (I + BlockSize > R)
+          BlockSize /= 2;
+
+        Value *Sum = nullptr;
+        for (unsigned K = 0; K < M; ++K) {
+          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
+          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
+          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
+          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
+                             Builder);
+        }
+        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
+      }
+    }
+
+    MatMul->replaceAllUsesWith(Result.embedInVector(Builder));
+  }
+
+  /// Lower a matrix transpose: result column Row is built from element Row of
+  /// every input column.
+  void LowerMatrixTranspose(CallInst *Inst) {
+    ColumnMatrixTy Result;
+    IRBuilder<> Builder(Inst);
+    Value *InputVal = Inst->getArgOperand(0);
+    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
+    ShapeInfo ArgShape(cast<ConstantInt>(Inst->getArgOperand(1)),
+                       cast<ConstantInt>(Inst->getArgOperand(2)));
+    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
+
+    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
+      // Build a single column vector for this row. First initialize it.
+      Value *ResultColumn = UndefValue::get(
+          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
+
+      // Go through the elements of this row and insert it into the resulting
+      // column vector.
+      for (auto C : enumerate(InputMatrix.columns())) {
+        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
+        // We insert at index Column since that is the row index after the
+        // transpose.
+        ResultColumn =
+            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
+      }
+      Result.addColumn(ResultColumn);
+    }
+
+    Inst->replaceAllUsesWith(Result.embedInVector(Builder));
+  }
+};
+} // namespace
+
+PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
+                                                 FunctionAnalysisManager &AM) {
+  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  LowerMatrixIntrinsics LMT(F, TTI);
+  if (LMT.Visit()) {
+    // Only instructions were replaced; no blocks or edges were changed.
+    PreservedAnalyses PA;
+    PA.preserveSet<CFGAnalyses>();
+    return PA;
+  }
+  return PreservedAnalyses::all();
+}
+
+namespace {
+
+/// Legacy-pass-manager wrapper around LowerMatrixIntrinsics.
+class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
+    initializeLowerMatrixIntrinsicsLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    LowerMatrixIntrinsics LMT(F, TTI);
+    return LMT.Visit();
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+};
+} // namespace
+
+static const char pass_name[] = "Lower the matrix intrinsics";
+char LowerMatrixIntrinsicsLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
+                    false, false)
+
+Pass *llvm::createLowerMatrixIntrinsicsPass() {
+  return new LowerMatrixIntrinsicsLegacyPass();
+}
diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp
--- a/llvm/lib/Transforms/Scalar/Scalar.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalar.cpp
@@ -82,6 +82,7 @@
   initializeLowerConstantIntrinsicsPass(Registry);
   initializeLowerExpectIntrinsicPass(Registry);
   initializeLowerGuardIntrinsicLegacyPassPass(Registry);
+  initializeLowerMatrixIntrinsicsLegacyPassPass(Registry);
   initializeLowerWidenableConditionLegacyPassPass(Registry);
   initializeMemCpyOptLegacyPassPass(Registry);
initializeMergeICmpsLegacyPassPass(Registry); diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply.ll @@ -0,0 +1,272 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + + +define void @multiply_2x2(<4 x double> * %Ptr.A, <4 x double> * %Ptr.B, <4 x double>* %Ptr.C) { +; CHECK-LABEL: @multiply_2x2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <4 x double>, <4 x double>* [[PTR_A:%.*]], align 16 +; CHECK-NEXT: [[B:%.*]] = load <4 x double>, <4 x double>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x double> [[A]], <4 x double> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x double> [[A]], <4 x double> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <4 x double> [[B]], <4 x double> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <4 x double> [[B]], <4 x double> undef, <2 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x double> undef, double [[TMP0]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP1:%.*]] = fmul <1 x double> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[BLOCK4:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT5:%.*]] = insertelement <1 x double> undef, double [[TMP2]], i32 0 +; 
CHECK-NEXT: [[SPLAT_SPLAT6:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT5]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = fmul <1 x double> [[BLOCK4]], [[SPLAT_SPLAT6]] +; CHECK-NEXT: [[TMP4:%.*]] = fadd <1 x double> [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <1 x double> [[TMP4]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP5]], <2 x i32> +; CHECK-NEXT: [[BLOCK7:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT8:%.*]] = insertelement <1 x double> undef, double [[TMP7]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT9:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT8]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = fmul <1 x double> [[BLOCK7]], [[SPLAT_SPLAT9]] +; CHECK-NEXT: [[BLOCK10:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT11:%.*]] = insertelement <1 x double> undef, double [[TMP9]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT12:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT11]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP10:%.*]] = fmul <1 x double> [[BLOCK10]], [[SPLAT_SPLAT12]] +; CHECK-NEXT: [[TMP11:%.*]] = fadd <1 x double> [[TMP8]], [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <1 x double> [[TMP11]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> [[TMP12]], <2 x i32> +; CHECK-NEXT: [[BLOCK13:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT14:%.*]] = insertelement <1 x double> undef, 
double [[TMP14]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT15:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT14]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = fmul <1 x double> [[BLOCK13]], [[SPLAT_SPLAT15]] +; CHECK-NEXT: [[BLOCK16:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT17:%.*]] = insertelement <1 x double> undef, double [[TMP16]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT18:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT17]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP17:%.*]] = fmul <1 x double> [[BLOCK16]], [[SPLAT_SPLAT18]] +; CHECK-NEXT: [[TMP18:%.*]] = fadd <1 x double> [[TMP15]], [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <1 x double> [[TMP18]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP19]], <2 x i32> +; CHECK-NEXT: [[BLOCK19:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT20:%.*]] = insertelement <1 x double> undef, double [[TMP21]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT21:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT20]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP22:%.*]] = fmul <1 x double> [[BLOCK19]], [[SPLAT_SPLAT21]] +; CHECK-NEXT: [[BLOCK22:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT23:%.*]] = insertelement <1 x double> undef, double [[TMP23]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT24:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT23]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP24:%.*]] = fmul <1 x double> 
[[BLOCK22]], [[SPLAT_SPLAT24]] +; CHECK-NEXT: [[TMP25:%.*]] = fadd <1 x double> [[TMP22]], [[TMP24]] +; CHECK-NEXT: [[TMP26:%.*]] = shufflevector <1 x double> [[TMP25]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <2 x double> [[TMP20]], <2 x double> [[TMP26]], <2 x i32> +; CHECK-NEXT: [[TMP28:%.*]] = shufflevector <2 x double> [[TMP13]], <2 x double> [[TMP27]], <4 x i32> +; CHECK-NEXT: store <4 x double> [[TMP28]], <4 x double>* [[PTR_C:%.*]], align 16 +; CHECK-NEXT: ret void +; +entry: + %a = load <4 x double>, <4 x double>* %Ptr.A, align 16 + %b = load <4 x double>, <4 x double>* %Ptr.B, align 16 + %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(<4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2) + store <4 x double> %c, <4 x double>* %Ptr.C, align 16 + ret void +} + +declare <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(<4 x double>, <4 x double>, i32, i32, i32) + +define void @multiply_1x2(<2 x double> * %Ptr.A, <2 x double> * %Ptr.B, <4 x double>* %Ptr.C) { +; CHECK-LABEL: @multiply_1x2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <2 x double>, <2 x double>* [[PTR_A:%.*]], align 16 +; CHECK-NEXT: [[B:%.*]] = load <2 x double>, <2 x double>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <2 x double> [[A]], <2 x double> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <2 x double> [[B]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <2 x double> [[B]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <1 x double> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x double> undef, double [[TMP0]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: 
[[TMP1:%.*]] = fmul <1 x double> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <1 x double> [[TMP1]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP2]], <2 x i32> +; CHECK-NEXT: [[BLOCK3:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <1 x double> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT4:%.*]] = insertelement <1 x double> undef, double [[TMP4]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT5:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT4]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = fmul <1 x double> [[BLOCK3]], [[SPLAT_SPLAT5]] +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <1 x double> [[TMP5]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <2 x double> [[TMP3]], <2 x double> [[TMP6]], <2 x i32> +; CHECK-NEXT: [[BLOCK6:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <1 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT7:%.*]] = insertelement <1 x double> undef, double [[TMP8]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT8:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT7]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP9:%.*]] = fmul <1 x double> [[BLOCK6]], [[SPLAT_SPLAT8]] +; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x double> [[TMP9]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP10]], <2 x i32> +; CHECK-NEXT: [[BLOCK9:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <1 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT10:%.*]] = insertelement <1 x double> undef, double [[TMP12]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT11:%.*]] = shufflevector <1 x double> 
[[SPLAT_SPLATINSERT10]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP13:%.*]] = fmul <1 x double> [[BLOCK9]], [[SPLAT_SPLAT11]] +; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <1 x double> [[TMP13]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP15:%.*]] = shufflevector <2 x double> [[TMP11]], <2 x double> [[TMP14]], <2 x i32> +; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <2 x double> [[TMP7]], <2 x double> [[TMP15]], <4 x i32> +; CHECK-NEXT: store <4 x double> [[TMP16]], <4 x double>* [[PTR_C:%.*]], align 16 +; CHECK-NEXT: ret void +; +entry: + %a = load <2 x double>, <2 x double>* %Ptr.A, align 16 + %b = load <2 x double>, <2 x double>* %Ptr.B, align 16 + %c = call <4 x double> @llvm.matrix.multiply.v4f64.v2f64.v2f64(<2 x double> %a, <2 x double> %b, i32 2, i32 1, i32 2) + store <4 x double> %c, <4 x double>* %Ptr.C, align 16 + ret void +} + +declare <4 x double> @llvm.matrix.multiply.v4f64.v2f64.v2f64(<2 x double>, <2 x double>, i32, i32, i32) + +define void @multiply_i32_2x3(<6 x i32> * %Ptr.A, <6 x i32> * %Ptr.B, <9 x i32>* %Ptr.C) { +; CHECK-LABEL: @multiply_i32_2x3( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <6 x i32>, <6 x i32>* [[PTR_A:%.*]], align 16 +; CHECK-NEXT: [[B:%.*]] = load <6 x i32>, <6 x i32>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <6 x i32> [[A]], <6 x i32> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <6 x i32> [[A]], <6 x i32> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <6 x i32> [[B]], <6 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <6 x i32> [[B]], <6 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <6 x i32> [[B]], <6 x i32> undef, <2 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x i32> undef, i32 
[[TMP0]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP1:%.*]] = mul <1 x i32> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <1 x i32> undef, i32 [[TMP2]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT6]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = mul <1 x i32> [[BLOCK5]], [[SPLAT_SPLAT7]] +; CHECK-NEXT: [[TMP4:%.*]] = add <1 x i32> [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <3 x i32> undef, <3 x i32> [[TMP5]], <3 x i32> +; CHECK-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <1 x i32> undef, i32 [[TMP7]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT9]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = mul <1 x i32> [[BLOCK8]], [[SPLAT_SPLAT10]] +; CHECK-NEXT: [[BLOCK11:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x i32> undef, i32 [[TMP9]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT12]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP10:%.*]] = mul <1 x i32> [[BLOCK11]], [[SPLAT_SPLAT13]] +; CHECK-NEXT: [[TMP11:%.*]] = add <1 x i32> [[TMP8]], [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <1 x i32> 
[[TMP11]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <3 x i32> [[TMP6]], <3 x i32> [[TMP12]], <3 x i32> +; CHECK-NEXT: [[BLOCK14:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x i32> undef, i32 [[TMP14]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT15]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = mul <1 x i32> [[BLOCK14]], [[SPLAT_SPLAT16]] +; CHECK-NEXT: [[BLOCK17:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x i32> undef, i32 [[TMP16]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT18]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP17:%.*]] = mul <1 x i32> [[BLOCK17]], [[SPLAT_SPLAT19]] +; CHECK-NEXT: [[TMP18:%.*]] = add <1 x i32> [[TMP15]], [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <1 x i32> [[TMP18]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <3 x i32> [[TMP13]], <3 x i32> [[TMP19]], <3 x i32> +; CHECK-NEXT: [[BLOCK20:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT21:%.*]] = insertelement <1 x i32> undef, i32 [[TMP21]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT22:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT21]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP22:%.*]] = mul <1 x i32> [[BLOCK20]], [[SPLAT_SPLAT22]] +; CHECK-NEXT: [[BLOCK23:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <2 x i32> 
[[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT24:%.*]] = insertelement <1 x i32> undef, i32 [[TMP23]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT25:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT24]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP24:%.*]] = mul <1 x i32> [[BLOCK23]], [[SPLAT_SPLAT25]] +; CHECK-NEXT: [[TMP25:%.*]] = add <1 x i32> [[TMP22]], [[TMP24]] +; CHECK-NEXT: [[TMP26:%.*]] = shufflevector <1 x i32> [[TMP25]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <3 x i32> undef, <3 x i32> [[TMP26]], <3 x i32> +; CHECK-NEXT: [[BLOCK26:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP28:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT27:%.*]] = insertelement <1 x i32> undef, i32 [[TMP28]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT28:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT27]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP29:%.*]] = mul <1 x i32> [[BLOCK26]], [[SPLAT_SPLAT28]] +; CHECK-NEXT: [[BLOCK29:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP30:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT30:%.*]] = insertelement <1 x i32> undef, i32 [[TMP30]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT31:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT30]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP31:%.*]] = mul <1 x i32> [[BLOCK29]], [[SPLAT_SPLAT31]] +; CHECK-NEXT: [[TMP32:%.*]] = add <1 x i32> [[TMP29]], [[TMP31]] +; CHECK-NEXT: [[TMP33:%.*]] = shufflevector <1 x i32> [[TMP32]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP34:%.*]] = shufflevector <3 x i32> [[TMP27]], <3 x i32> [[TMP33]], <3 x i32> +; CHECK-NEXT: [[BLOCK32:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP35:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT33:%.*]] = 
insertelement <1 x i32> undef, i32 [[TMP35]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT34:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT33]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP36:%.*]] = mul <1 x i32> [[BLOCK32]], [[SPLAT_SPLAT34]] +; CHECK-NEXT: [[BLOCK35:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP37:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT36:%.*]] = insertelement <1 x i32> undef, i32 [[TMP37]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT37:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT36]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP38:%.*]] = mul <1 x i32> [[BLOCK35]], [[SPLAT_SPLAT37]] +; CHECK-NEXT: [[TMP39:%.*]] = add <1 x i32> [[TMP36]], [[TMP38]] +; CHECK-NEXT: [[TMP40:%.*]] = shufflevector <1 x i32> [[TMP39]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP41:%.*]] = shufflevector <3 x i32> [[TMP34]], <3 x i32> [[TMP40]], <3 x i32> +; CHECK-NEXT: [[BLOCK38:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP42:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT39:%.*]] = insertelement <1 x i32> undef, i32 [[TMP42]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT40:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT39]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP43:%.*]] = mul <1 x i32> [[BLOCK38]], [[SPLAT_SPLAT40]] +; CHECK-NEXT: [[BLOCK41:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP44:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT42:%.*]] = insertelement <1 x i32> undef, i32 [[TMP44]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT43:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT42]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP45:%.*]] = mul <1 x i32> [[BLOCK41]], [[SPLAT_SPLAT43]] +; CHECK-NEXT: [[TMP46:%.*]] = 
add <1 x i32> [[TMP43]], [[TMP45]] +; CHECK-NEXT: [[TMP47:%.*]] = shufflevector <1 x i32> [[TMP46]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP48:%.*]] = shufflevector <3 x i32> undef, <3 x i32> [[TMP47]], <3 x i32> +; CHECK-NEXT: [[BLOCK44:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP49:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT45:%.*]] = insertelement <1 x i32> undef, i32 [[TMP49]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT46:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT45]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP50:%.*]] = mul <1 x i32> [[BLOCK44]], [[SPLAT_SPLAT46]] +; CHECK-NEXT: [[BLOCK47:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP51:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT48:%.*]] = insertelement <1 x i32> undef, i32 [[TMP51]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT49:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT48]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP52:%.*]] = mul <1 x i32> [[BLOCK47]], [[SPLAT_SPLAT49]] +; CHECK-NEXT: [[TMP53:%.*]] = add <1 x i32> [[TMP50]], [[TMP52]] +; CHECK-NEXT: [[TMP54:%.*]] = shufflevector <1 x i32> [[TMP53]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP55:%.*]] = shufflevector <3 x i32> [[TMP48]], <3 x i32> [[TMP54]], <3 x i32> +; CHECK-NEXT: [[BLOCK50:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP56:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT51:%.*]] = insertelement <1 x i32> undef, i32 [[TMP56]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT52:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT51]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP57:%.*]] = mul <1 x i32> [[BLOCK50]], [[SPLAT_SPLAT52]] +; CHECK-NEXT: [[BLOCK53:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; 
CHECK-NEXT:    [[TMP58:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 1
+; CHECK-NEXT:    [[SPLAT_SPLATINSERT54:%.*]] = insertelement <1 x i32> undef, i32 [[TMP58]], i32 0
+; CHECK-NEXT:    [[SPLAT_SPLAT55:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT54]], <1 x i32> undef, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP59:%.*]] = mul <1 x i32> [[BLOCK53]], [[SPLAT_SPLAT55]]
+; CHECK-NEXT:    [[TMP60:%.*]] = add <1 x i32> [[TMP57]], [[TMP59]]
+; CHECK-NEXT:    [[TMP61:%.*]] = shufflevector <1 x i32> [[TMP60]], <1 x i32> undef, <3 x i32>
+; CHECK-NEXT:    [[TMP62:%.*]] = shufflevector <3 x i32> [[TMP55]], <3 x i32> [[TMP61]], <3 x i32>
+; CHECK-NEXT:    [[TMP63:%.*]] = shufflevector <3 x i32> [[TMP20]], <3 x i32> [[TMP41]], <6 x i32>
+; CHECK-NEXT:    [[TMP64:%.*]] = shufflevector <3 x i32> [[TMP62]], <3 x i32> undef, <6 x i32>
+; CHECK-NEXT:    [[TMP65:%.*]] = shufflevector <6 x i32> [[TMP63]], <6 x i32> [[TMP64]], <9 x i32>
+; CHECK-NEXT:    store <9 x i32> [[TMP65]], <9 x i32>* [[PTR_C:%.*]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <6 x i32>, <6 x i32>* %Ptr.A, align 16
+  %b = load <6 x i32>, <6 x i32>* %Ptr.B, align 16
+  %c = call <9 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> %a, <6 x i32> %b, i32 3, i32 2, i32 3)
+  store <9 x i32> %c, <9 x i32>* %Ptr.C, align 16
+  ret void
+}
+
+declare <9 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32>, <6 x i32>, i32, i32, i32)
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load.ll
@@ -0,0 +1,83 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+; Column j of the 3x3 result is loaded from element offset j * (3 + %stride) of %in.
+define <9 x double> @strided_load_3x3(<9 x double>* %in, i32 %stride) {
+; CHECK-LABEL: @strided_load_3x3(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i32 3, [[STRIDE:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <9 x double>* [[IN:%.*]] to double*
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i32 0, [[TMP0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = add i32 0, [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr double, double* [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast double* [[TMP4]] to <3 x double>*
+; CHECK-NEXT:    [[TMP6:%.*]] = load <3 x double>, <3 x double>* [[TMP5]], align 8
+; CHECK-NEXT:    [[TMP7:%.*]] = add i32 3, [[STRIDE]]
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <9 x double>* [[IN]] to double*
+; CHECK-NEXT:    [[TMP9:%.*]] = mul i32 1, [[TMP7]]
+; CHECK-NEXT:    [[TMP10:%.*]] = add i32 0, [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr double, double* [[TMP8]], i32 [[TMP10]]
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast double* [[TMP11]] to <3 x double>*
+; CHECK-NEXT:    [[TMP13:%.*]] = load <3 x double>, <3 x double>* [[TMP12]], align 8
+; CHECK-NEXT:    [[TMP14:%.*]] = add i32 3, [[STRIDE]]
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <9 x double>* [[IN]] to double*
+; CHECK-NEXT:    [[TMP16:%.*]] = mul i32 2, [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = add i32 0, [[TMP16]]
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr double, double* [[TMP15]], i32 [[TMP17]]
+; CHECK-NEXT:    [[TMP19:%.*]] = bitcast double* [[TMP18]] to <3 x double>*
+; CHECK-NEXT:    [[TMP20:%.*]] = load <3 x double>, <3 x double>* [[TMP19]], align 8
+; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP13]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <3 x double> [[TMP20]], <3 x double> undef, <6 x i32> <i32 0, i32 1, i32 2, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP23:%.*]] = shufflevector <6 x double> [[TMP21]], <6 x double> [[TMP22]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; CHECK-NEXT:    ret <9 x double> [[TMP23]]
+;
+entry:
+  %load = call <9 x double> @llvm.matrix.columnwise.load(<9 x double>* %in, i32 %stride, i32 3, i32 3)
+  ret <9 x double> %load
+}
+
+declare <9 x double> @llvm.matrix.columnwise.load(<9 x double>*, i32, i32, i32)
+; A single 9-element column: one <9 x double> load at offset 0 * (9 + %stride), i.e. %stride does not move it.
+define <9 x double> @strided_load_9x1(<9 x double>* %in, i32 %stride) {
+; CHECK-LABEL: @strided_load_9x1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i32 9, [[STRIDE:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <9 x double>* [[IN:%.*]] to double*
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i32 0, [[TMP0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = add i32 0, [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr double, double* [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast double* [[TMP4]] to <9 x double>*
+; CHECK-NEXT:    [[TMP6:%.*]] = load <9 x double>, <9 x double>* [[TMP5]], align 8
+; CHECK-NEXT:    ret <9 x double> [[TMP6]]
+;
+entry:
+  %load = call <9 x double> @llvm.matrix.columnwise.load(<9 x double>* %in, i32 %stride, i32 9, i32 1)
+  ret <9 x double> %load
+}
+
+declare <8 x i64> @llvm.matrix.columnwise.load.v8i64(<8 x i64>*, i32, i32, i32)
+; 4x2 i64 matrix: the second column starts at element offset 4 + %stride of %in.
+define <8 x i64> @strided_load_i64_4x2(<8 x i64>* %in, i32 %stride) {
+; CHECK-LABEL: @strided_load_i64_4x2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i32 4, [[STRIDE:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i64>* [[IN:%.*]] to i64*
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i32 0, [[TMP0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = add i32 0, [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i64, i64* [[TMP1]], i32 [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast i64* [[TMP4]] to <4 x i64>*
+; CHECK-NEXT:    [[TMP6:%.*]] = load <4 x i64>, <4 x i64>* [[TMP5]], align 4
+; CHECK-NEXT:    [[TMP7:%.*]] = add i32 4, [[STRIDE]]
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <8 x i64>* [[IN]] to i64*
+; CHECK-NEXT:    [[TMP9:%.*]] = mul i32 1, [[TMP7]]
+; CHECK-NEXT:    [[TMP10:%.*]] = add i32 0, [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i64, i64* [[TMP8]], i32 [[TMP10]]
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast i64* [[TMP11]] to <4 x i64>*
+; CHECK-NEXT:    [[TMP13:%.*]] = load <4 x i64>, <4 x i64>* [[TMP12]], align 4
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i64> [[TMP6]], <4 x i64> [[TMP13]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    ret <8 x i64> [[TMP14]]
+;
+entry:
+  %load = call <8 x i64> @llvm.matrix.columnwise.load.v8i64(<8 x i64>* %in, i32 %stride, i32 4, i32 2)
+  ret <8 x i64> %load
+}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-store.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-store.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-store.ll
@@ -0,0 +1,92 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define void @strided_store_double_3x2(<6 x double>* %in.addr, double* %out) {
+; CHECK-LABEL: @strided_store_double_3x2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IN:%.*]] = load <6 x double>, <6 x double>* [[IN_ADDR:%.*]], align 8
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <6 x double> [[IN]], <6 x double> undef, <3 x i32> <i32 0, i32 1, i32 2>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <6 x double> [[IN]], <6 x double> undef, <3 x i32> <i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast double* [[OUT:%.*]] to <3 x double>*
+; CHECK-NEXT:    store <3 x double> [[SPLIT]], <3 x double>* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, double* [[OUT]], i32 5
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[TMP1]] to <3 x double>*
+; CHECK-NEXT:    store <3 x double> [[SPLIT1]], <3 x double>* [[TMP2]], align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  ; 3x2 matrix with constant stride 2: column 1 starts at element 3 + 2 = 5 of %out.
+  %in = load <6 x double>, <6 x double>* %in.addr, align 8
+  call void @llvm.matrix.columnwise.store(<6 x double> %in, double* %out, i32 2, i32 3, i32 2)
+  ret void
+}
+
+define void @strided_store_double_3x2_nonconst_stride(<6 x double>* %in.addr, i32 %stride, double* %out) {
+; CHECK-LABEL: @strided_store_double_3x2_nonconst_stride(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IN:%.*]] = load <6 x double>, <6 x double>* [[IN_ADDR:%.*]], align 8
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <6 x double> [[IN]], <6 x double> undef, <3 x i32> <i32 0, i32 1, i32 2>
+; CHECK-NEXT:
[[SPLIT1:%.*]] = shufflevector <6 x double> [[IN]], <6 x double> undef, <3 x i32> <i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP0:%.*]] = add i32 3, [[STRIDE:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 0, [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = add i32 0, [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr double, double* [[OUT:%.*]], i32 [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast double* [[TMP3]] to <3 x double>*
+; CHECK-NEXT:    store <3 x double> [[SPLIT]], <3 x double>* [[TMP4]], align 8
+; CHECK-NEXT:    [[TMP5:%.*]] = add i32 3, [[STRIDE]]
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i32 1, [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = add i32 0, [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr double, double* [[OUT]], i32 [[TMP7]]
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast double* [[TMP8]] to <3 x double>*
+; CHECK-NEXT:    store <3 x double> [[SPLIT1]], <3 x double>* [[TMP9]], align 8
+; CHECK-NEXT:    ret void
+;
+entry:
+  ; Runtime stride: column j starts at element j * (3 + %stride) of %out.
+  %in = load <6 x double>, <6 x double>* %in.addr, align 8
+  call void @llvm.matrix.columnwise.store(<6 x double> %in, double* %out, i32 %stride, i32 3, i32 2)
+  ret void
+}
+
+
+declare void @llvm.matrix.columnwise.store(<6 x double>, double*, i32, i32, i32)
+
+define void @strided_store_i8_2x5(<10 x i8>* %in.addr, double* %out) {
+; CHECK-LABEL: @strided_store_i8_2x5(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[IN:%.*]] = load <10 x i8>, <10 x i8>* [[IN_ADDR:%.*]], align 8
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <10 x i8> [[IN]], <10 x i8> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <10 x i8> [[IN]], <10 x i8> undef, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <10 x i8> [[IN]], <10 x i8> undef, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <10 x i8> [[IN]], <10 x i8> undef, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <10 x i8> [[IN]], <10 x i8> undef, <2 x i32> <i32 8, i32 9>
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast double* [[OUT:%.*]] to i8*
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <2 x i8>*
+; CHECK-NEXT:    store <2 x i8> [[SPLIT]], <2 x i8>* [[TMP1]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast double* [[OUT]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, i8* [[TMP2]], i32 4
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8* [[TMP3]] to <2 x i8>*
+; CHECK-NEXT:    store <2 x i8> [[SPLIT1]], <2 x i8>* [[TMP4]], align 1
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast double* [[OUT]] to i8*
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, i8* [[TMP5]], i32 8
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast i8* [[TMP6]] to <2 x i8>*
+; CHECK-NEXT:    store <2 x i8> [[SPLIT2]], <2 x i8>* [[TMP7]], align 1
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast double* [[OUT]] to i8*
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, i8* [[TMP8]], i32 12
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast i8* [[TMP9]] to <2 x i8>*
+; CHECK-NEXT:    store <2 x i8> [[SPLIT3]], <2 x i8>* [[TMP10]], align 1
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast double* [[OUT]] to i8*
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr i8, i8* [[TMP11]], i32 16
+; CHECK-NEXT:    [[TMP13:%.*]] = bitcast i8* [[TMP12]] to <2 x i8>*
+; CHECK-NEXT:    store <2 x i8> [[SPLIT4]], <2 x i8>* [[TMP13]], align 1
+; CHECK-NEXT:    ret void
+;
+entry:
+  ; 2 rows x 5 columns of i8, stride 2: column j starts at i8 offset j * (2 + 2); %out is double* but accessed via i8* bitcasts.
+  %in = load <10 x i8>, <10 x i8>* %in.addr, align 8
+  call void @llvm.matrix.columnwise.store.v10i8(<10 x i8> %in, double* %out, i32 2, i32 2, i32 5)
+  ret void
+}
+
+declare void @llvm.matrix.columnwise.store.v10i8(<10 x i8>, double*, i32, i32, i32)
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose.ll
@@ -0,0 +1,129 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+
+define void @transpose(<8 x double>* %Ptr.A, <8 x double>* %Ptr.B) {
+; CHECK-LABEL: @transpose(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <8 x double>, <8 x
double>* [[PTR_A:%.*]], align 16
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
+; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP15]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    store <8 x double> [[TMP16]], <8 x double>* [[PTR_B:%.*]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <8 x double>, <8 x double> *%Ptr.A, align 16
+  %c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
+  ; 2x4 input: the 4x2 result is assembled from two 4-element columns.
+  store <8 x double> %c, <8 x double> *%Ptr.B, align 16
+  ret void
+}
+
+declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)
+
+define void @transpose_single_column(<8 x double>* %Ptr.A, <8 x double>* %Ptr.B) {
+; CHECK-LABEL: @transpose_single_column(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <8 x double>, <8 x double>* [[PTR_A:%.*]], align 16
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <8 x double> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <1 x double> undef, double [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <8 x double> [[SPLIT]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <1 x double> undef, double [[TMP2]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <8 x double> [[SPLIT]], i64 2
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <1 x double> undef, double [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <8 x double> [[SPLIT]], i64 3
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <1 x double> undef, double [[TMP6]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <8 x double> [[SPLIT]], i64 4
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <1 x double> undef, double [[TMP8]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <8 x double> [[SPLIT]], i64 5
+; CHECK-NEXT:    [[TMP11:%.*]] = insertelement <1 x double> undef, double [[TMP10]], i64 0
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <8 x double> [[SPLIT]], i64 6
+; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <1 x double> undef, double [[TMP12]], i64 0
+; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <8 x double> [[SPLIT]], i64 7
+; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <1 x double> undef, double [[TMP14]], i64 0
+; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <1 x double> [[TMP1]], <1 x double> [[TMP3]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP17:%.*]] = shufflevector <1 x double> [[TMP5]], <1 x double> [[TMP7]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <1 x double> [[TMP9]], <1 x double> [[TMP11]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP19:%.*]] = shufflevector <1 x double> [[TMP13]], <1 x double> [[TMP15]], <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <2 x double> [[TMP16]], <2 x double> [[TMP17]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP21:%.*]] = shufflevector <2 x double> [[TMP18]], <2 x double> [[TMP19]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <4 x double> [[TMP20]], <4 x double> [[TMP21]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    store <8 x double> [[TMP22]], <8 x double>* [[PTR_B:%.*]], align 16
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a = load <8 x double>, <8 x double> *%Ptr.A, align 16
+  %c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 8, i32 1)
+  ; 8x1 input: the 1x8 result has eight single-element columns, concatenated pairwise.
+  store <8 x double> %c, <8 x double> *%Ptr.B, align 16
+  ret void
+}
+
+declare <12 x i16> @llvm.matrix.transpose.v12i16(<12 x i16>, i32, i32)
+
+define void @transpose_i16_3x4(<12 x i16>* %Ptr.A, <12 x i16>* %Ptr.B) {
+; CHECK-LABEL: @transpose_i16_3x4(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A:%.*]] = load <12 x i16>, <12 x i16>* [[PTR_A:%.*]], align 16
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <12 x i16> [[A]], <12 x i16> undef, <3 x i32> <i32 0, i32 1, i32 2>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <12 x i16> [[A]], <12 x i16> undef, <3 x i32> <i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <12 x i16> [[A]], <12 x i16> undef, <3 x i32> <i32 6, i32 7, i32 8>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <12 x i16> [[A]], <12 x i16> undef, <3 x i32> <i32 9, i32 10, i32 11>
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <3 x i16> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x i16> undef, i16 [[TMP0]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <3 x i16> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x i16> [[TMP1]], i16 [[TMP2]], i64 1
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <3 x i16> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP5:%.*]] =
insertelement <4 x i16> [[TMP3]], i16 [[TMP4]], i64 2 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <3 x i16> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i16> [[TMP5]], i16 [[TMP6]], i64 3 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <3 x i16> [[SPLIT]], i64 1 +; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x i16> undef, i16 [[TMP8]], i64 0 +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <3 x i16> [[SPLIT1]], i64 1 +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x i16> [[TMP9]], i16 [[TMP10]], i64 1 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <3 x i16> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x i16> [[TMP11]], i16 [[TMP12]], i64 2 +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <3 x i16> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x i16> [[TMP13]], i16 [[TMP14]], i64 3 +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <3 x i16> [[SPLIT]], i64 2 +; CHECK-NEXT: [[TMP17:%.*]] = insertelement <4 x i16> undef, i16 [[TMP16]], i64 0 +; CHECK-NEXT: [[TMP18:%.*]] = extractelement <3 x i16> [[SPLIT1]], i64 2 +; CHECK-NEXT: [[TMP19:%.*]] = insertelement <4 x i16> [[TMP17]], i16 [[TMP18]], i64 1 +; CHECK-NEXT: [[TMP20:%.*]] = extractelement <3 x i16> [[SPLIT2]], i64 2 +; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x i16> [[TMP19]], i16 [[TMP20]], i64 2 +; CHECK-NEXT: [[TMP22:%.*]] = extractelement <3 x i16> [[SPLIT3]], i64 2 +; CHECK-NEXT: [[TMP23:%.*]] = insertelement <4 x i16> [[TMP21]], i16 [[TMP22]], i64 3 +; CHECK-NEXT: [[TMP24:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> [[TMP15]], <8 x i32> +; CHECK-NEXT: [[TMP25:%.*]] = shufflevector <4 x i16> [[TMP23]], <4 x i16> undef, <8 x i32> +; CHECK-NEXT: [[TMP26:%.*]] = shufflevector <8 x i16> [[TMP24]], <8 x i16> [[TMP25]], <12 x i32> +; CHECK-NEXT: store <12 x i16> [[TMP26]], <12 x i16>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: ret void +; +entry: + %a = load <12 x i16>, <12 x i16> *%Ptr.A, align 16 + %c = call <12 x i16> @llvm.matrix.transpose.v12i16(<12 x 
i16> %a, i32 3, i32 4) + + store <12 x i16> %c, <12 x i16> *%Ptr.B, align 16 + ret void +}