diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -14408,6 +14408,101 @@ """""""""" The argument to this intrinsic must be a vector of floating-point values. +Matrix Intrinsics +----------------- + +Operations on matrices requiring shape information (like number of rows/columns +or the memory layout) can be expressed using the matrix intrinsics. Matrices are +embedded in a flat vector and the intrinsics take the dimensions as arguments. +Currently column-major layout is assumed. The intrinsics support both integer +and floating-point matrices. + + +'``llvm.matrix.transpose.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +:: + + declare vectorty @llvm.matrix.transpose.*(vectorty %in, i32 <Rows>, i32 <Cols>) + +Overview: +""""""""" + +The '``llvm.matrix.transpose.*``' intrinsic treats %in as containing a matrix +with <Rows> rows and <Cols> columns and returns the transposed matrix embedded in +the result vector. + +Arguments: +"""""""""" + +The <Rows> and <Cols> arguments must be constant integers. The vector argument %in and the returned vector must have <Rows> * <Cols> elements. + +'``llvm.matrix.multiply.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +:: + + declare vectorty @llvm.matrix.multiply.*(vectorty %A, vectorty %B, i32 <OuterRows>, i32 <Inner>, i32 <OuterColumns>) + +Overview: +""""""""" + +The '``llvm.matrix.multiply.*``' intrinsic treats %A as a matrix with <OuterRows> rows and <Inner> columns, %B as +a matrix with <Inner> rows and <OuterColumns> columns, and multiplies them. The result matrix is returned embedded in the +result vector. + +Arguments: +"""""""""" + +The <OuterRows>, <Inner> and <OuterColumns> arguments must be constant integers. %A must have <OuterRows> * <Inner> elements, %B must have <Inner> * <OuterColumns> elements and the returned vector must have <OuterRows> * <OuterColumns> elements. + +'``llvm.matrix.columnwise.load.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +:: + + declare vectorty @llvm.matrix.columnwise.load.*(ptrty %Ptr, i32 %Stride, i32 <Rows>, i32 <Cols>) + +Overview: +""""""""" + +The '``llvm.matrix.columnwise.load.*``' intrinsic loads a matrix with <Rows> +rows and <Cols> columns from %Ptr, skipping %Stride elements between the end of one column and the start of the next (consecutive columns start <Rows> + %Stride elements apart).
The result +matrix is returned embedded in the result vector. This allows for convenient +loading of sub-matrices. + +'``llvm.matrix.columnwise.store.*``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +:: + + declare void @llvm.matrix.columnwise.store.*(vectorty %In, ptrty %Ptr, i32 %Stride, i32 <Rows>, i32 <Cols>) + +Overview: +""""""""" + +The '``llvm.matrix.columnwise.store.*``' intrinsic stores the matrix with +<Rows> rows and <Cols> columns embedded in %In to %Ptr, skipping %Stride elements between +columns in the same way as '``llvm.matrix.columnwise.load.*``'. This allows for convenient storing of sub-matrices. + + +Arguments: +"""""""""" + +The <Rows> and <Cols> arguments must be constant integers. + Half Precision Floating-Point Intrinsics ---------------------------------------- diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1232,6 +1232,37 @@ [llvm_anyvector_ty]>; } +//===----- Matrix intrinsics ---------------------------------------------===// + +def int_matrix_transpose : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, + llvm_i32_ty, + llvm_i32_ty], + [IntrNoMem, IntrSpeculatable, IntrWillReturn, ImmArg<1>, ImmArg<2>]>; + +def int_matrix_multiply : Intrinsic<[llvm_anyvector_ty], + [llvm_anyvector_ty, + llvm_anyvector_ty, + llvm_i32_ty, + llvm_i32_ty, + llvm_i32_ty], + [IntrNoMem, IntrSpeculatable, IntrWillReturn, ImmArg<2>, ImmArg<3>, ImmArg<4>]>; + +def int_matrix_columnwise_load : Intrinsic<[llvm_anyvector_ty], + [LLVMAnyPointerType<LLVMMatchType<0>>, + llvm_i32_ty, + llvm_i32_ty, + llvm_i32_ty], + [IntrArgMemOnly, IntrReadMem, NoCapture<0>, ImmArg<2>, ImmArg<3>]>; + +def int_matrix_columnwise_store : Intrinsic<[], + [llvm_anyvector_ty, + LLVMAnyPointerType<LLVMMatchType<0>>, + llvm_i32_ty, + llvm_i32_ty, + llvm_i32_ty], + [IntrArgMemOnly, WriteOnly<1>, NoCapture<1>, ImmArg<3>, ImmArg<4>]>; + //===---------- Intrinsics to control hardware supported loops ----------===// // Specify that the value given is the number of iterations that the next loop diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h --- a/llvm/include/llvm/InitializePasses.h +++
b/llvm/include/llvm/InitializePasses.h @@ -255,6 +255,7 @@ void initializeLowerInvokeLegacyPassPass(PassRegistry&); void initializeLowerSwitchPass(PassRegistry&); void initializeLowerTypeTestsPass(PassRegistry&); +void initializeLowerMatrixIntrinsicsLegacyPassPass(PassRegistry &); void initializeMIRCanonicalizerPass(PassRegistry &); void initializeMIRNamerPass(PassRegistry &); void initializeMIRPrintingPassPass(PassRegistry&); diff --git a/llvm/include/llvm/Transforms/Scalar.h b/llvm/include/llvm/Transforms/Scalar.h --- a/llvm/include/llvm/Transforms/Scalar.h +++ b/llvm/include/llvm/Transforms/Scalar.h @@ -359,6 +359,8 @@ // Pass *createLowerGuardIntrinsicPass(); +Pass *createLowerMatrixIntrinsicsPass(); + //===----------------------------------------------------------------------===// // // LowerWidenableCondition - Lower widenable condition to i1 true. diff --git a/llvm/include/llvm/Transforms/Scalar/LowerMatrixIntrinsics.h b/llvm/include/llvm/Transforms/Scalar/LowerMatrixIntrinsics.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/Transforms/Scalar/LowerMatrixIntrinsics.h @@ -0,0 +1,24 @@ +//===- LowerMatrixIntrinsics.h - Lower matrix intrinsics. -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass lowers matrix intrinsics down to vector operations.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_SCALAR_LOWERMATRIXINTRINSICSPASS_H +#define LLVM_TRANSFORMS_SCALAR_LOWERMATRIXINTRINSICSPASS_H + +#include "llvm/IR/PassManager.h" + +namespace llvm { +struct LowerMatrixIntrinsicsPass : PassInfoMixin<LowerMatrixIntrinsicsPass> { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); +}; +} // namespace llvm + +#endif // LLVM_TRANSFORMS_SCALAR_LOWERMATRIXINTRINSICSPASS_H diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -146,6 +146,7 @@ #include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h" #include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h" #include "llvm/Transforms/Scalar/LowerGuardIntrinsic.h" +#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" #include "llvm/Transforms/Scalar/LowerWidenableCondition.h" #include "llvm/Transforms/Scalar/MakeGuardsExplicit.h" #include "llvm/Transforms/Scalar/MemCpyOptimizer.h" diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -189,6 +189,7 @@ FUNCTION_PASS("lower-expect", LowerExpectIntrinsicPass()) FUNCTION_PASS("lower-guard-intrinsic", LowerGuardIntrinsicPass()) FUNCTION_PASS("lower-constant-intrinsics", LowerConstantIntrinsicsPass()) +FUNCTION_PASS("lower-matrix-intrinsics", LowerMatrixIntrinsicsPass()) FUNCTION_PASS("lower-widenable-condition", LowerWidenableConditionPass()) FUNCTION_PASS("guard-widening", GuardWideningPass()) FUNCTION_PASS("gvn", GVN()) diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp --- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -147,6 +147,10 @@ "enable-order-file-instrumentation", cl::init(false), cl::Hidden, cl::desc("Enable order file instrumentation (default = off)")); +static cl::opt<bool>
    EnableMatrix("enable-matrix", cl::init(false), cl::Hidden, + cl::desc("Enable lowering of the matrix intrinsics")); + PassManagerBuilder::PassManagerBuilder() { OptLevel = 2; SizeLevel = 0; @@ -656,6 +660,14 @@ MPM.add(createFloat2IntPass()); MPM.add(createLowerConstantIntrinsicsPass()); + if (EnableMatrix) { + MPM.add(createLowerMatrixIntrinsicsPass()); + // CSE the pointer arithmetic of the column vectors. This allows alias + // analysis to establish no-aliasing between loads and stores of different + // columns of the same matrix. + MPM.add(createEarlyCSEPass(/*UseMemorySSA=*/false)); + } + addExtensionsToPM(EP_VectorizerStart, MPM); // Re-rotate loops in all our loop nests. These may have fallout out of diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt --- a/llvm/lib/Transforms/Scalar/CMakeLists.txt +++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt @@ -47,6 +47,7 @@ LowerConstantIntrinsics.cpp LowerExpectIntrinsic.cpp LowerGuardIntrinsic.cpp + LowerMatrixIntrinsics.cpp LowerWidenableCondition.cpp MakeGuardsExplicit.cpp MemCpyOptimizer.cpp diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -0,0 +1,516 @@ +//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Lower matrix intrinsics to vector operations. +// +// TODO: +// * Implement multiply & add fusion +// * Implement shape propagation +// * Implement optimizations to reduce shufflevector uses by using shape +// information.
+// * Add remark, summarizing the available matrix optimization opportunities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "lower-matrix-intrinsics"
+
+// Given a \p MatrixPtr for the in-memory representation of a matrix, compute
+// the address of the element at index (\p Row, \p Col). \p Stride is the
+// number of elements between the start of one column and the start of the
+// next (this is the number of rows, unless a sub-matrix is accessed).
+//
+// \p Row, \p Col and \p Stride may have different integer types; they are
+// zero-extended to the widest of them, i.e. treated as unsigned values.
+static Value *computeEltAddr(Value *MatrixPtr, Value *Row, Value *Col,
+                             Type *EltType, Value *Stride,
+                             IRBuilder<> &Builder) {
+  Type *EltPtrType = PointerType::get(
+      EltType, cast<PointerType>(MatrixPtr->getType())->getAddressSpace());
+  Value *Base = Builder.CreatePointerCast(MatrixPtr, EltPtrType);
+
+  unsigned RowBitWidth = cast<IntegerType>(Row->getType())->getBitWidth();
+  unsigned ColBitWidth = cast<IntegerType>(Col->getType())->getBitWidth();
+  unsigned StrideBitWidth = cast<IntegerType>(Stride->getType())->getBitWidth();
+  unsigned WidestBitWidth =
+      std::max(std::max(RowBitWidth, ColBitWidth), StrideBitWidth);
+
+  // Bring all operands to a common width. CreateZExt hands back its operand
+  // unchanged when it already has the requested type, so operands that are
+  // already the widest do not get an extra instruction.
+  Type *IndexTy = IntegerType::get(Builder.getContext(), WidestBitWidth);
+  Row = Builder.CreateZExt(Row, IndexTy);
+  Col = Builder.CreateZExt(Col, IndexTy);
+  Stride = Builder.CreateZExt(Stride, IndexTy);
+
+  // Element index = Row + Col * Stride.
+  Value *ColumnOffset = Builder.CreateMul(Col, Stride);
+  Value *EltIndex = Builder.CreateAdd(Row, ColumnOffset);
+
+  // No GEP is needed when the index folded to the constant 0.
+  if (isa<ConstantInt>(EltIndex) && cast<ConstantInt>(EltIndex)->isZero())
+    return Base;
+  return Builder.CreateGEP(EltType, Base, EltIndex);
+}
+
+namespace {
+/// LowerMatrixIntrinsics class contains the methods used to lower
+/// instructions and intrinsic calls on matrices.
+class LowerMatrixIntrinsics {
+  /// Inline capacity of the column-vector list kept per lowered matrix.
+  /// Matrices with more columns still work; their list just lives on the heap.
+  static constexpr unsigned DefaultVectorSize = 16;
+
+  Function &Func;
+  const DataLayout &DL;
+  const TargetTransformInfo &TTI;
+
+  using LoweredMatrixVec = SmallVector<Value *, DefaultVectorSize>;
+
+  /// A matrix lowered to a list of column vectors, one IR value per column.
+  class MatrixTy {
+    LoweredMatrixVec Columns;
+
+  public:
+    MatrixTy() = default;
+    MatrixTy(ArrayRef<Value *> Cols) : Columns(Cols.begin(), Cols.end()) {}
+
+    Value *getColumn(unsigned i) const { return Columns[i]; }
+
+    void setColumn(unsigned i, Value *V) { Columns[i] = V; }
+
+    size_t getNumColumns() const { return Columns.size(); }
+
+    const LoweredMatrixVec &getColumnVectors() const { return Columns; }
+
+    LoweredMatrixVec &getColumnVectors() { return Columns; }
+
+    void addColumn(Value *V) { Columns.push_back(V); }
+
+    iterator_range<LoweredMatrixVec::iterator> columns() {
+      return make_range(Columns.begin(), Columns.end());
+    }
+
+    /// Concatenate the column vectors back into one flat vector. A matrix
+    /// with a single column is returned unchanged.
+    Value *embeddInVector(IRBuilder<> &Builder) const {
+      return Columns.size() == 1 ? Columns[0]
+                                 : concatenateVectors(Builder, Columns);
+    }
+  };
+
+  /// The list of instructions that need to be removed after lowering.
+  SmallVector<Instruction *, 16> ToRemove;
+
+  /// Dimensions of a matrix: number of rows and number of columns.
+  struct ShapeInfo {
+    unsigned NumRows;
+    unsigned NumColumns;
+
+    ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
+        : NumRows(NumRows), NumColumns(NumColumns) {}
+
+    /// Build a shape from the constant dimension operands of an intrinsic.
+    ShapeInfo(ConstantInt *NumRows, ConstantInt *NumColumns)
+        : NumRows(NumRows->getZExtValue()),
+          NumColumns(NumColumns->getZExtValue()) {}
+
+    /// A shape counts as set once it has a non-zero number of rows.
+    explicit operator bool() const { return NumRows != 0; }
+    bool operator==(const ShapeInfo &Other) const {
+      return NumRows == Other.NumRows && NumColumns == Other.NumColumns;
+    }
+    bool operator!=(const ShapeInfo &Other) const { return !(*this == Other); }
+  };
+
+public:
+  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI)
+      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {}
+
+  /// Return the set of column vectors that a matrix value is lowered to.
+  ///
+  /// The pass does not yet remember the column vectors it produced for
+  /// earlier intrinsics, so every matrix value is split afresh with
+  /// shufflevectors into \p SI.NumRows-element columns. Constants go through
+  /// splitToColumnVectors, all other values through splitVector.
+  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
+                     IRBuilder<> &Builder) {
+    auto *VType = dyn_cast<VectorType>(MatrixVal->getType());
+    assert(VType && "MatrixVal must be a matrix type");
+    if (isa<Constant>(MatrixVal))
+      return splitToColumnVectors(MatrixVal, VType->getElementType(), SI,
+                                  Builder);
+
+    return splitVector(MatrixVal, SI.NumRows, Builder);
+  }
+
+  /// Split a matrix into its column vectors.
+  ///
+  /// This works by casting it to a vector spanning the entire matrix and then
+  /// splitting this into column-sized vectors. The cast is a no-op when
+  /// \p MatrixVal already has that vector type.
+  MatrixTy splitToColumnVectors(Value *MatrixVal, Type *EltTy,
+                                const ShapeInfo &SI, IRBuilder<> &Builder) {
+    auto *Vec = Builder.CreateBitCast(
+        MatrixVal, VectorType::get(EltTy, SI.NumRows * SI.NumColumns),
+        MatrixVal->getName());
+    return splitVector(Vec, SI.NumRows, Builder);
+  }
+
+  /// Split up \p Vec into consecutive vectors of \p SplitSize elements each.
+  MatrixTy splitVector(Value *Vec, unsigned SplitSize, IRBuilder<> &Builder) {
+    unsigned UnsplitSize = cast<VectorType>(Vec->getType())->getNumElements();
+    // A zero split size would never advance the loop below, and a size that
+    // does not divide the vector would make the last column read past the end
+    // of Vec.
+    assert(SplitSize > 0 && UnsplitSize % SplitSize == 0 &&
+           "vector length must be a multiple of the column length");
+
+    Value *Undef = UndefValue::get(Vec->getType());
+    LoweredMatrixVec SplitVecs;
+    for (unsigned MaskStart = 0; MaskStart < UnsplitSize;
+         MaskStart += SplitSize) {
+      Constant *Mask = createSequentialMask(Builder, MaskStart, SplitSize, 0);
+      SplitVecs.push_back(
+          Builder.CreateShuffleVector(Vec, Undef, Mask, "split"));
+    }
+
+    return {SplitVecs};
+  }
+
+  /// Lower \p Inst if it calls one of the matrix intrinsics. On success the
+  /// call is queued for removal and true is returned.
+  bool VisitCallInst(CallInst *Inst) {
+    Function *Callee = Inst->getCalledFunction();
+    if (!Callee || !Callee->isIntrinsic())
+      return false;
+
+    switch (Callee->getIntrinsicID()) {
+    case Intrinsic::matrix_multiply:
+      LowerMatrixMultiply(Inst);
+      break;
+    case Intrinsic::matrix_transpose:
+      LowerMatrixTranspose(Inst);
+      break;
+    case Intrinsic::matrix_columnwise_load:
+      LowerMatrixLoad(Inst);
+      break;
+    case Intrinsic::matrix_columnwise_store:
+      LowerMatrixStore(Inst);
+      break;
+    default:
+      return false;
+    }
+    ToRemove.push_back(Inst);
+    return true;
+  }
+
+  /// Lower all matrix intrinsic calls in the blocks visited by a reverse
+  /// post-order traversal of the function, then erase the lowered calls.
+  /// Returns true if anything was changed.
+  bool Visit() {
+    ReversePostOrderTraversal<Function *> RPOT(&Func);
+    bool Changed = false;
+    for (BasicBlock *BB : RPOT)
+      for (Instruction &Inst : *BB)
+        if (auto *CInst = dyn_cast<CallInst>(&Inst))
+          Changed |= VisitCallInst(CInst);
+
+    // Any uses of the lowered calls were already replaced; erase the calls in
+    // reverse visitation order. Clear the list afterwards so that a repeated
+    // Visit() never touches instructions that were already erased.
+    for (Instruction *Inst : llvm::reverse(ToRemove))
+      Inst->eraseFromParent();
+    ToRemove.clear();
+    return Changed;
+  }
+
+  /// Return the address of a column vector (\p EltType x \p Rows) at index (\p
+ Value *computeColumnAddr(Value *Base, unsigned Row, unsigned Col, + Value *Stride, Type *EltType, unsigned Rows, + IRBuilder<> &Builder) { + Value *EltPtr = + computeEltAddr(Base, Builder.getInt32(Row), Builder.getInt32(Col), + EltType, Stride, Builder); + Type *ColumnType = VectorType::get(EltType, Rows); + Type *ColumnPtrType = PointerType::get( + ColumnType, cast(Base->getType())->getAddressSpace()); + return Builder.CreatePointerCast(EltPtr, ColumnPtrType); + } + + Value *computeColumnAddr(Value *Base, unsigned Col, Value *Skip, + VectorType *VType, ShapeInfo Shape, + IRBuilder<> &Builder) { + Value *Rows = Builder.getIntN( + cast(Skip->getType())->getBitWidth(), Shape.NumRows); + Value *Stride = Builder.CreateAdd(Rows, Skip); + return computeColumnAddr(Base, 0, Col, Stride, VType->getElementType(), + Shape.NumRows, Builder); + } + + Value *computeColumnAddr(Value *Base, unsigned Col, VectorType *VType, + ShapeInfo Shape, IRBuilder<> &Builder) { + return computeColumnAddr(Base, Col, Builder.getInt32(0), VType, Shape, + Builder); + } + + LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType, + IRBuilder<> Builder) { + unsigned Align = DL.getABITypeAlignment(EltType); + return Builder.CreateAlignedLoad(ColumnPtr, Align); + } + + StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr, + Type *EltType, IRBuilder<> Builder) { + unsigned Align = DL.getABITypeAlignment(EltType); + return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align); + } + + /// Handles lowering the non-contiguous matrix load + /// + /// The intrinsic loads a matrix from an array with an offset from the + /// initial element and a stride between columns + /// + /// See VisitLoadInst for handling actual load instructions + void LowerMatrixLoad(CallInst *Inst) { + IRBuilder<> Builder(Inst); + Value *Ptr = Inst->getArgOperand(0); + Value *Stride = Inst->getArgOperand(1); + auto VType = cast(Inst->getType()); + + ShapeInfo Shape(cast(Inst->getArgOperand(2)), + 
cast(Inst->getArgOperand(3))); + + MatrixTy Result; + + // Distance between start of one column and the start of the next + for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) { + Value *GEP = computeColumnAddr(Ptr, C, Stride, VType, Shape, Builder); + Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder); + Result.addColumn(Column); + } + + Inst->replaceAllUsesWith(Result.embeddInVector(Builder)); + } + + /// Handles lowering the non-contiguous matrix store + /// + /// The intrinsic store a matrix back to an array with an offset from + /// the initial element pointer and a stride between columns. + /// + /// See VisitStoreInst for handling actual store instructions + void LowerMatrixStore(CallInst *Inst) { + IRBuilder<> Builder(Inst); + Value *Matrix = Inst->getArgOperand(0); + Value *Ptr = Inst->getArgOperand(1); + Value *Stride = Inst->getArgOperand(2); + ShapeInfo Shape(cast(Inst->getArgOperand(3)), + cast(Inst->getArgOperand(4))); + + auto VType = cast(Matrix->getType()); + + auto LM = getMatrix(Matrix, Shape, Builder); + + for (auto C : enumerate(LM.columns())) { + Value *GEP = + computeColumnAddr(Ptr, C.index(), Stride, VType, Shape, Builder); + createColumnStore(C.value(), GEP, VType->getElementType(), Builder); + } + } + + /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from + /// the matrix \p Cols represented as a vector of column vectors. 
+  Value *extractVector(const MatrixTy &LM, unsigned I, unsigned J,
+                       unsigned NumElts, IRBuilder<> &Builder) {
+    Value *Col = LM.getColumn(J);
+    Value *Undef = UndefValue::get(Col->getType());
+    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0);
+    return Builder.CreateShuffleVector(Col, Undef, Mask, "block");
+  }
+
+  /// Return a copy of column vector \p Col in which the elements starting at
+  /// index \p I are replaced by the elements of \p Block.
+  Value *insertVector(Value *Col, unsigned I, Value *Block,
+                      IRBuilder<> &Builder) {
+    unsigned BlockNumElts =
+        cast<VectorType>(Block->getType())->getNumElements();
+    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements();
+    assert(I + BlockNumElts <= NumElts && "block does not fit into column");
+
+    // First, widen Block to the length of Col; the extra lanes are undef.
+    Value *ExtendMask =
+        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);
+    Value *Undef = UndefValue::get(Block->getType());
+    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask);
+
+    // Then take every lane from Col, except lanes I .. I + BlockNumElts - 1,
+    // which come from the widened Block (the second shuffle operand, hence the
+    // offset of NumElts). If Col is 7 long, I is 2 and BlockNumElts is 2, the
+    // mask is: 0, 1, 7, 8, 4, 5, 6.
+    SmallVector<Constant *, 16> Mask;
+    for (unsigned Idx = 0; Idx < NumElts; ++Idx) {
+      bool FromBlock = Idx >= I && Idx < I + BlockNumElts;
+      Mask.push_back(Builder.getInt32(FromBlock ? Idx - I + NumElts : Idx));
+    }
+
+    Value *MaskVal = ConstantVector::get(Mask);
+    return Builder.CreateShuffleVector(Col, Block, MaskVal);
+  }
+
+  /// Return Sum + A * B, or just A * B when \p Sum is null. Uses fmul/fadd when
+  /// \p UseFPOp is set and mul/add otherwise.
+  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
+                      IRBuilder<> &Builder) {
+    Value *Mul = UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
+    if (!Sum)
+      return Mul;
+
+    return UseFPOp ? Builder.CreateFAdd(Sum, Mul) : Builder.CreateAdd(Sum, Mul);
+  }
+
+  /// Lower llvm.matrix.multiply of an R x M matrix by an M x C matrix.
+  ///
+  /// Each result column J is built in blocks of up to VF rows. For a block,
+  /// the matching rows of every Lhs column K are multiplied by the splatted
+  /// element (K, J) of Rhs and accumulated.
+  void LowerMatrixMultiply(CallInst *MatMul) {
+    IRBuilder<> Builder(MatMul);
+    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+    ShapeInfo LShape(cast<ConstantInt>(MatMul->getArgOperand(2)),
+                     cast<ConstantInt>(MatMul->getArgOperand(3)));
+    ShapeInfo RShape(cast<ConstantInt>(MatMul->getArgOperand(3)),
+                     cast<ConstantInt>(MatMul->getArgOperand(4)));
+
+    const MatrixTy Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
+    const MatrixTy Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
+
+    const unsigned R = LShape.NumRows;
+    const unsigned M = LShape.NumColumns;
+    const unsigned C = RShape.NumColumns;
+    assert(M == RShape.NumRows && "inner dimensions must agree");
+    assert(Lhs.getNumColumns() == M && Rhs.getNumColumns() == C &&
+           "operand lengths do not match the requested shapes");
+
+    // Initialize the output
+    MatrixTy Result;
+    for (unsigned J = 0; J < C; ++J)
+      Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));
+
+    // Number of elements that fit into one vector register, but at least 1.
+    // The explicit template argument is needed because the quotient's type
+    // follows the return type of getPrimitiveSizeInBits(); deducing against
+    // 1ULL breaks where uint64_t is unsigned long (e.g. LP64 Linux).
+    const unsigned VF = std::max<uint64_t>(
+        TTI.getRegisterBitWidth(true) / EltType->getPrimitiveSizeInBits(), 1);
+
+    // Multiply columns from the first operand with scalars from the second
+    // operand. Then move along the K axes and accumulate the columns. With
+    // this the adds can be vectorized without reassociation.
+    for (unsigned J = 0; J < C; ++J) {
+      unsigned BlockSize = VF;
+      for (unsigned I = 0; I < R; I += BlockSize) {
+        // Gradually lower the vectorization factor to cover the remainder.
+        while (I + BlockSize > R)
+          BlockSize /= 2;
+
+        Value *Sum = nullptr;
+        for (unsigned K = 0; K < M; ++K) {
+          Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
+          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
+          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
+          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
+                             Builder);
+        }
+        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
+      }
+    }
+
+    MatMul->replaceAllUsesWith(Result.embeddInVector(Builder));
+  }
+
+  /// Lower llvm.matrix.transpose: row Row of the input becomes column Row of
+  /// the result.
+  void LowerMatrixTranspose(CallInst *Inst) {
+    MatrixTy Result;
+    IRBuilder<> Builder(Inst);
+    Value *InputVal = Inst->getArgOperand(0);
+    VectorType *VectorTy = cast<VectorType>(InputVal->getType());
+    ShapeInfo ArgShape(cast<ConstantInt>(Inst->getArgOperand(1)),
+                       cast<ConstantInt>(Inst->getArgOperand(2)));
+    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
+
+    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) {
+      // Build a single column vector for this row. First initialize it.
+      Value *ResultColumn = UndefValue::get(
+          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns));
+
+      // Go through the elements of this row and insert it into the resulting
+      // column vector.
+      for (auto C : enumerate(InputMatrix.columns())) {
+        Value *Elt = Builder.CreateExtractElement(C.value(), Row);
+        // We insert at index Column since that is the row index after the
+        // transpose.
+        ResultColumn =
+            Builder.CreateInsertElement(ResultColumn, Elt, C.index());
+      }
+      Result.addColumn(ResultColumn);
+    }
+
+    Inst->replaceAllUsesWith(Result.embeddInVector(Builder));
+  }
+};
+} // namespace
+
+/// New pass manager entry point: lower all matrix intrinsics in \p F. Only
+/// instructions are rewritten, so the CFG analyses stay valid.
+PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
+                                                 FunctionAnalysisManager &AM) {
+  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  LowerMatrixIntrinsics LMT(F, TTI);
+  if (LMT.Visit()) {
+    PreservedAnalyses PA;
+    PA.preserveSet<CFGAnalyses>();
+    return PA;
+  }
+  return PreservedAnalyses::all();
+}
+
+namespace {
+
+/// Legacy pass manager wrapper around LowerMatrixIntrinsics.
+class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
+    initializeLowerMatrixIntrinsicsLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    LowerMatrixIntrinsics LMT(F, TTI);
+    return LMT.Visit();
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+};
+} // namespace
+
+static const char pass_name[] = "Lower operations on the matrix type";
+char LowerMatrixIntrinsicsLegacyPass::ID = 0;
+// The TTI wrapper is required in getAnalysisUsage, so it must also be
+// declared as a dependency to be initialized before this pass.
+INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
+                    false, false)
+
+Pass *llvm::createLowerMatrixIntrinsicsPass() {
+  return new LowerMatrixIntrinsicsLegacyPass();
+}
diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp
--- a/llvm/lib/Transforms/Scalar/Scalar.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalar.cpp
@@ -82,6 +82,7 @@
   initializeLowerConstantIntrinsicsPass(Registry);
   initializeLowerExpectIntrinsicPass(Registry);
   initializeLowerGuardIntrinsicLegacyPassPass(Registry);
+  initializeLowerMatrixIntrinsicsLegacyPassPass(Registry);
   initializeLowerWidenableConditionLegacyPassPass(Registry);
   initializeMemCpyOptLegacyPassPass(Registry);
initializeMergeICmpsLegacyPassPass(Registry); diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply.ll @@ -0,0 +1,272 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + + +define void @multiply_2x2(<4 x double> * %Ptr.A, <4 x double> * %Ptr.B, <4 x double>* %Ptr.C) { +; CHECK-LABEL: @multiply_2x2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <4 x double>, <4 x double>* [[PTR_A:%.*]], align 16 +; CHECK-NEXT: [[B:%.*]] = load <4 x double>, <4 x double>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x double> [[A]], <4 x double> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x double> [[A]], <4 x double> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <4 x double> [[B]], <4 x double> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <4 x double> [[B]], <4 x double> undef, <2 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x double> undef, double [[TMP0]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP1:%.*]] = fmul <1 x double> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[BLOCK4:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT5:%.*]] = insertelement <1 x double> undef, double [[TMP2]], i32 0 +; 
CHECK-NEXT: [[SPLAT_SPLAT6:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT5]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = fmul <1 x double> [[BLOCK4]], [[SPLAT_SPLAT6]] +; CHECK-NEXT: [[TMP4:%.*]] = fadd <1 x double> [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <1 x double> [[TMP4]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP5]], <2 x i32> +; CHECK-NEXT: [[BLOCK7:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT8:%.*]] = insertelement <1 x double> undef, double [[TMP7]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT9:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT8]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = fmul <1 x double> [[BLOCK7]], [[SPLAT_SPLAT9]] +; CHECK-NEXT: [[BLOCK10:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT11:%.*]] = insertelement <1 x double> undef, double [[TMP9]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT12:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT11]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP10:%.*]] = fmul <1 x double> [[BLOCK10]], [[SPLAT_SPLAT12]] +; CHECK-NEXT: [[TMP11:%.*]] = fadd <1 x double> [[TMP8]], [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <1 x double> [[TMP11]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> [[TMP12]], <2 x i32> +; CHECK-NEXT: [[BLOCK13:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT14:%.*]] = insertelement <1 x double> undef, 
double [[TMP14]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT15:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT14]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = fmul <1 x double> [[BLOCK13]], [[SPLAT_SPLAT15]] +; CHECK-NEXT: [[BLOCK16:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT17:%.*]] = insertelement <1 x double> undef, double [[TMP16]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT18:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT17]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP17:%.*]] = fmul <1 x double> [[BLOCK16]], [[SPLAT_SPLAT18]] +; CHECK-NEXT: [[TMP18:%.*]] = fadd <1 x double> [[TMP15]], [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <1 x double> [[TMP18]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP19]], <2 x i32> +; CHECK-NEXT: [[BLOCK19:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT20:%.*]] = insertelement <1 x double> undef, double [[TMP21]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT21:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT20]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP22:%.*]] = fmul <1 x double> [[BLOCK19]], [[SPLAT_SPLAT21]] +; CHECK-NEXT: [[BLOCK22:%.*]] = shufflevector <2 x double> [[SPLIT1]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT23:%.*]] = insertelement <1 x double> undef, double [[TMP23]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT24:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT23]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP24:%.*]] = fmul <1 x double> 
[[BLOCK22]], [[SPLAT_SPLAT24]] +; CHECK-NEXT: [[TMP25:%.*]] = fadd <1 x double> [[TMP22]], [[TMP24]] +; CHECK-NEXT: [[TMP26:%.*]] = shufflevector <1 x double> [[TMP25]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <2 x double> [[TMP20]], <2 x double> [[TMP26]], <2 x i32> +; CHECK-NEXT: [[TMP28:%.*]] = shufflevector <2 x double> [[TMP13]], <2 x double> [[TMP27]], <4 x i32> +; CHECK-NEXT: store <4 x double> [[TMP28]], <4 x double>* [[PTR_C:%.*]], align 16 +; CHECK-NEXT: ret void +; +entry: + %a = load <4 x double>, <4 x double>* %Ptr.A, align 16 + %b = load <4 x double>, <4 x double>* %Ptr.B, align 16 + %c = call <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(<4 x double> %a, <4 x double> %b, i32 2, i32 2, i32 2) + store <4 x double> %c, <4 x double>* %Ptr.C, align 16 + ret void +} + +declare <4 x double> @llvm.matrix.multiply.v4f64.v4f64.v4f64(<4 x double>, <4 x double>, i32, i32, i32) + +define void @multiply_1x2(<2 x double> * %Ptr.A, <2 x double> * %Ptr.B, <4 x double>* %Ptr.C) { +; CHECK-LABEL: @multiply_1x2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <2 x double>, <2 x double>* [[PTR_A:%.*]], align 16 +; CHECK-NEXT: [[B:%.*]] = load <2 x double>, <2 x double>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <2 x double> [[A]], <2 x double> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <2 x double> [[B]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <2 x double> [[B]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <1 x double> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x double> undef, double [[TMP0]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: 
[[TMP1:%.*]] = fmul <1 x double> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <1 x double> [[TMP1]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP2]], <2 x i32> +; CHECK-NEXT: [[BLOCK3:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <1 x double> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT4:%.*]] = insertelement <1 x double> undef, double [[TMP4]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT5:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT4]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = fmul <1 x double> [[BLOCK3]], [[SPLAT_SPLAT5]] +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <1 x double> [[TMP5]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <2 x double> [[TMP3]], <2 x double> [[TMP6]], <2 x i32> +; CHECK-NEXT: [[BLOCK6:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <1 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT7:%.*]] = insertelement <1 x double> undef, double [[TMP8]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT8:%.*]] = shufflevector <1 x double> [[SPLAT_SPLATINSERT7]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP9:%.*]] = fmul <1 x double> [[BLOCK6]], [[SPLAT_SPLAT8]] +; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <1 x double> [[TMP9]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <2 x double> undef, <2 x double> [[TMP10]], <2 x i32> +; CHECK-NEXT: [[BLOCK9:%.*]] = shufflevector <2 x double> [[SPLIT]], <2 x double> undef, <1 x i32> +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <1 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT10:%.*]] = insertelement <1 x double> undef, double [[TMP12]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT11:%.*]] = shufflevector <1 x double> 
[[SPLAT_SPLATINSERT10]], <1 x double> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP13:%.*]] = fmul <1 x double> [[BLOCK9]], [[SPLAT_SPLAT11]] +; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <1 x double> [[TMP13]], <1 x double> undef, <2 x i32> +; CHECK-NEXT: [[TMP15:%.*]] = shufflevector <2 x double> [[TMP11]], <2 x double> [[TMP14]], <2 x i32> +; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <2 x double> [[TMP7]], <2 x double> [[TMP15]], <4 x i32> +; CHECK-NEXT: store <4 x double> [[TMP16]], <4 x double>* [[PTR_C:%.*]], align 16 +; CHECK-NEXT: ret void +; +entry: + %a = load <2 x double>, <2 x double>* %Ptr.A, align 16 + %b = load <2 x double>, <2 x double>* %Ptr.B, align 16 + %c = call <4 x double> @llvm.matrix.multiply.v4f64.v2f64.v2f64(<2 x double> %a, <2 x double> %b, i32 2, i32 1, i32 2) + store <4 x double> %c, <4 x double>* %Ptr.C, align 16 + ret void +} + +declare <4 x double> @llvm.matrix.multiply.v4f64.v2f64.v2f64(<2 x double>, <2 x double>, i32, i32, i32) + +define void @multiply_i32_2x3(<6 x i32> * %Ptr.A, <6 x i32> * %Ptr.B, <9 x i32>* %Ptr.C) { +; CHECK-LABEL: @multiply_i32_2x3( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <6 x i32>, <6 x i32>* [[PTR_A:%.*]], align 16 +; CHECK-NEXT: [[B:%.*]] = load <6 x i32>, <6 x i32>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <6 x i32> [[A]], <6 x i32> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <6 x i32> [[A]], <6 x i32> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <6 x i32> [[B]], <6 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <6 x i32> [[B]], <6 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <6 x i32> [[B]], <6 x i32> undef, <2 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x i32> undef, i32 
[[TMP0]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP1:%.*]] = mul <1 x i32> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <1 x i32> undef, i32 [[TMP2]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT6]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = mul <1 x i32> [[BLOCK5]], [[SPLAT_SPLAT7]] +; CHECK-NEXT: [[TMP4:%.*]] = add <1 x i32> [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <3 x i32> undef, <3 x i32> [[TMP5]], <3 x i32> +; CHECK-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <1 x i32> undef, i32 [[TMP7]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT9]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = mul <1 x i32> [[BLOCK8]], [[SPLAT_SPLAT10]] +; CHECK-NEXT: [[BLOCK11:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x i32> undef, i32 [[TMP9]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT12]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP10:%.*]] = mul <1 x i32> [[BLOCK11]], [[SPLAT_SPLAT13]] +; CHECK-NEXT: [[TMP11:%.*]] = add <1 x i32> [[TMP8]], [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <1 x i32> 
[[TMP11]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <3 x i32> [[TMP6]], <3 x i32> [[TMP12]], <3 x i32> +; CHECK-NEXT: [[BLOCK14:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x i32> undef, i32 [[TMP14]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT15]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = mul <1 x i32> [[BLOCK14]], [[SPLAT_SPLAT16]] +; CHECK-NEXT: [[BLOCK17:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x i32> undef, i32 [[TMP16]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT18]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP17:%.*]] = mul <1 x i32> [[BLOCK17]], [[SPLAT_SPLAT19]] +; CHECK-NEXT: [[TMP18:%.*]] = add <1 x i32> [[TMP15]], [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <1 x i32> [[TMP18]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <3 x i32> [[TMP13]], <3 x i32> [[TMP19]], <3 x i32> +; CHECK-NEXT: [[BLOCK20:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT21:%.*]] = insertelement <1 x i32> undef, i32 [[TMP21]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT22:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT21]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP22:%.*]] = mul <1 x i32> [[BLOCK20]], [[SPLAT_SPLAT22]] +; CHECK-NEXT: [[BLOCK23:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <2 x i32> 
[[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT24:%.*]] = insertelement <1 x i32> undef, i32 [[TMP23]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT25:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT24]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP24:%.*]] = mul <1 x i32> [[BLOCK23]], [[SPLAT_SPLAT25]] +; CHECK-NEXT: [[TMP25:%.*]] = add <1 x i32> [[TMP22]], [[TMP24]] +; CHECK-NEXT: [[TMP26:%.*]] = shufflevector <1 x i32> [[TMP25]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <3 x i32> undef, <3 x i32> [[TMP26]], <3 x i32> +; CHECK-NEXT: [[BLOCK26:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP28:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT27:%.*]] = insertelement <1 x i32> undef, i32 [[TMP28]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT28:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT27]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP29:%.*]] = mul <1 x i32> [[BLOCK26]], [[SPLAT_SPLAT28]] +; CHECK-NEXT: [[BLOCK29:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP30:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT30:%.*]] = insertelement <1 x i32> undef, i32 [[TMP30]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT31:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT30]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP31:%.*]] = mul <1 x i32> [[BLOCK29]], [[SPLAT_SPLAT31]] +; CHECK-NEXT: [[TMP32:%.*]] = add <1 x i32> [[TMP29]], [[TMP31]] +; CHECK-NEXT: [[TMP33:%.*]] = shufflevector <1 x i32> [[TMP32]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP34:%.*]] = shufflevector <3 x i32> [[TMP27]], <3 x i32> [[TMP33]], <3 x i32> +; CHECK-NEXT: [[BLOCK32:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP35:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT33:%.*]] = 
insertelement <1 x i32> undef, i32 [[TMP35]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT34:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT33]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP36:%.*]] = mul <1 x i32> [[BLOCK32]], [[SPLAT_SPLAT34]] +; CHECK-NEXT: [[BLOCK35:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP37:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT36:%.*]] = insertelement <1 x i32> undef, i32 [[TMP37]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT37:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT36]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP38:%.*]] = mul <1 x i32> [[BLOCK35]], [[SPLAT_SPLAT37]] +; CHECK-NEXT: [[TMP39:%.*]] = add <1 x i32> [[TMP36]], [[TMP38]] +; CHECK-NEXT: [[TMP40:%.*]] = shufflevector <1 x i32> [[TMP39]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP41:%.*]] = shufflevector <3 x i32> [[TMP34]], <3 x i32> [[TMP40]], <3 x i32> +; CHECK-NEXT: [[BLOCK38:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP42:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT39:%.*]] = insertelement <1 x i32> undef, i32 [[TMP42]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT40:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT39]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP43:%.*]] = mul <1 x i32> [[BLOCK38]], [[SPLAT_SPLAT40]] +; CHECK-NEXT: [[BLOCK41:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP44:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT42:%.*]] = insertelement <1 x i32> undef, i32 [[TMP44]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT43:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT42]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP45:%.*]] = mul <1 x i32> [[BLOCK41]], [[SPLAT_SPLAT43]] +; CHECK-NEXT: [[TMP46:%.*]] = 
add <1 x i32> [[TMP43]], [[TMP45]] +; CHECK-NEXT: [[TMP47:%.*]] = shufflevector <1 x i32> [[TMP46]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP48:%.*]] = shufflevector <3 x i32> undef, <3 x i32> [[TMP47]], <3 x i32> +; CHECK-NEXT: [[BLOCK44:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP49:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT45:%.*]] = insertelement <1 x i32> undef, i32 [[TMP49]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT46:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT45]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP50:%.*]] = mul <1 x i32> [[BLOCK44]], [[SPLAT_SPLAT46]] +; CHECK-NEXT: [[BLOCK47:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP51:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT48:%.*]] = insertelement <1 x i32> undef, i32 [[TMP51]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT49:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT48]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP52:%.*]] = mul <1 x i32> [[BLOCK47]], [[SPLAT_SPLAT49]] +; CHECK-NEXT: [[TMP53:%.*]] = add <1 x i32> [[TMP50]], [[TMP52]] +; CHECK-NEXT: [[TMP54:%.*]] = shufflevector <1 x i32> [[TMP53]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP55:%.*]] = shufflevector <3 x i32> [[TMP48]], <3 x i32> [[TMP54]], <3 x i32> +; CHECK-NEXT: [[BLOCK50:%.*]] = shufflevector <3 x i32> [[SPLIT]], <3 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP56:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT51:%.*]] = insertelement <1 x i32> undef, i32 [[TMP56]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT52:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT51]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP57:%.*]] = mul <1 x i32> [[BLOCK50]], [[SPLAT_SPLAT52]] +; CHECK-NEXT: [[BLOCK53:%.*]] = shufflevector <3 x i32> [[SPLIT1]], <3 x i32> undef, <1 x i32> +; 
CHECK-NEXT: [[TMP58:%.*]] = extractelement <2 x i32> [[SPLIT4]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT54:%.*]] = insertelement <1 x i32> undef, i32 [[TMP58]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT55:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT54]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP59:%.*]] = mul <1 x i32> [[BLOCK53]], [[SPLAT_SPLAT55]] +; CHECK-NEXT: [[TMP60:%.*]] = add <1 x i32> [[TMP57]], [[TMP59]] +; CHECK-NEXT: [[TMP61:%.*]] = shufflevector <1 x i32> [[TMP60]], <1 x i32> undef, <3 x i32> +; CHECK-NEXT: [[TMP62:%.*]] = shufflevector <3 x i32> [[TMP55]], <3 x i32> [[TMP61]], <3 x i32> +; CHECK-NEXT: [[TMP63:%.*]] = shufflevector <3 x i32> [[TMP20]], <3 x i32> [[TMP41]], <6 x i32> +; CHECK-NEXT: [[TMP64:%.*]] = shufflevector <3 x i32> [[TMP62]], <3 x i32> undef, <6 x i32> +; CHECK-NEXT: [[TMP65:%.*]] = shufflevector <6 x i32> [[TMP63]], <6 x i32> [[TMP64]], <9 x i32> +; CHECK-NEXT: store <9 x i32> [[TMP65]], <9 x i32>* [[PTR_C:%.*]], align 16 +; CHECK-NEXT: ret void +; +entry: + %a = load <6 x i32>, <6 x i32>* %Ptr.A, align 16 + %b = load <6 x i32>, <6 x i32>* %Ptr.B, align 16 + %c = call <9 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> %a, <6 x i32> %b, i32 3, i32 2, i32 3) + store <9 x i32> %c, <9 x i32>* %Ptr.C, align 16 + ret void +} + +declare <9 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32>, <6 x i32>, i32, i32, i32) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-load.ll @@ -0,0 +1,83 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + +define <9 x double> @strided_load_3x3(<9 x double>* %in, i32 %stride) { +; CHECK-LABEL: @strided_load_3x3( +; CHECK-NEXT: entry: +; 
CHECK-NEXT: [[TMP0:%.*]] = add i32 3, [[STRIDE:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <9 x double>* [[IN:%.*]] to double* +; CHECK-NEXT: [[TMP2:%.*]] = mul i32 0, [[TMP0]] +; CHECK-NEXT: [[TMP3:%.*]] = add i32 0, [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr double, double* [[TMP1]], i32 [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = bitcast double* [[TMP4]] to <3 x double>* +; CHECK-NEXT: [[TMP6:%.*]] = load <3 x double>, <3 x double>* [[TMP5]], align 8 +; CHECK-NEXT: [[TMP7:%.*]] = add i32 3, [[STRIDE]] +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <9 x double>* [[IN]] to double* +; CHECK-NEXT: [[TMP9:%.*]] = mul i32 1, [[TMP7]] +; CHECK-NEXT: [[TMP10:%.*]] = add i32 0, [[TMP9]] +; CHECK-NEXT: [[TMP11:%.*]] = getelementptr double, double* [[TMP8]], i32 [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = bitcast double* [[TMP11]] to <3 x double>* +; CHECK-NEXT: [[TMP13:%.*]] = load <3 x double>, <3 x double>* [[TMP12]], align 8 +; CHECK-NEXT: [[TMP14:%.*]] = add i32 3, [[STRIDE]] +; CHECK-NEXT: [[TMP15:%.*]] = bitcast <9 x double>* [[IN]] to double* +; CHECK-NEXT: [[TMP16:%.*]] = mul i32 2, [[TMP14]] +; CHECK-NEXT: [[TMP17:%.*]] = add i32 0, [[TMP16]] +; CHECK-NEXT: [[TMP18:%.*]] = getelementptr double, double* [[TMP15]], i32 [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = bitcast double* [[TMP18]] to <3 x double>* +; CHECK-NEXT: [[TMP20:%.*]] = load <3 x double>, <3 x double>* [[TMP19]], align 8 +; CHECK-NEXT: [[TMP21:%.*]] = shufflevector <3 x double> [[TMP6]], <3 x double> [[TMP13]], <6 x i32> +; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <3 x double> [[TMP20]], <3 x double> undef, <6 x i32> +; CHECK-NEXT: [[TMP23:%.*]] = shufflevector <6 x double> [[TMP21]], <6 x double> [[TMP22]], <9 x i32> +; CHECK-NEXT: ret <9 x double> [[TMP23]] +; +entry: + %load = call <9 x double> @llvm.matrix.columnwise.load(<9 x double>* %in, i32 %stride, i32 3, i32 3) + ret <9 x double> %load +} + +declare <9 x double> @llvm.matrix.columnwise.load(<9 x double>*, i32, i32, i32) + +define <9 x double> 
@strided_load_9x1(<9 x double>* %in, i32 %stride) { +; CHECK-LABEL: @strided_load_9x1( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = add i32 9, [[STRIDE:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <9 x double>* [[IN:%.*]] to double* +; CHECK-NEXT: [[TMP2:%.*]] = mul i32 0, [[TMP0]] +; CHECK-NEXT: [[TMP3:%.*]] = add i32 0, [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr double, double* [[TMP1]], i32 [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = bitcast double* [[TMP4]] to <9 x double>* +; CHECK-NEXT: [[TMP6:%.*]] = load <9 x double>, <9 x double>* [[TMP5]], align 8 +; CHECK-NEXT: ret <9 x double> [[TMP6]] +; +entry: + %load = call <9 x double> @llvm.matrix.columnwise.load(<9 x double>* %in, i32 %stride, i32 9, i32 1) + ret <9 x double> %load +} + +declare <8 x i64> @llvm.matrix.columnwise.load.v8i64(<8 x i64>*, i32, i32, i32) + +define <8 x i64> @strided_load_i64_4x2(<8 x i64>* %in, i32 %stride) { +; CHECK-LABEL: @strided_load_i64_4x2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = add i32 4, [[STRIDE:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i64>* [[IN:%.*]] to i64* +; CHECK-NEXT: [[TMP2:%.*]] = mul i32 0, [[TMP0]] +; CHECK-NEXT: [[TMP3:%.*]] = add i32 0, [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i64, i64* [[TMP1]], i32 [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = bitcast i64* [[TMP4]] to <4 x i64>* +; CHECK-NEXT: [[TMP6:%.*]] = load <4 x i64>, <4 x i64>* [[TMP5]], align 4 +; CHECK-NEXT: [[TMP7:%.*]] = add i32 4, [[STRIDE]] +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <8 x i64>* [[IN]] to i64* +; CHECK-NEXT: [[TMP9:%.*]] = mul i32 1, [[TMP7]] +; CHECK-NEXT: [[TMP10:%.*]] = add i32 0, [[TMP9]] +; CHECK-NEXT: [[TMP11:%.*]] = getelementptr i64, i64* [[TMP8]], i32 [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = bitcast i64* [[TMP11]] to <4 x i64>* +; CHECK-NEXT: [[TMP13:%.*]] = load <4 x i64>, <4 x i64>* [[TMP12]], align 4 +; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i64> [[TMP6]], <4 x i64> [[TMP13]], <8 x i32> +; CHECK-NEXT: ret <8 x i64> [[TMP14]] +; 
+entry: + %load = call <8 x i64> @llvm.matrix.columnwise.load.v8i64(<8 x i64>* %in, i32 %stride, i32 4, i32 2) + ret <8 x i64> %load +} diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/strided-store.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-store.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/strided-store.ll @@ -0,0 +1,87 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + +define void @strided_store_double_3x2(<6 x double>* %in.addr, double* %out) { +; CHECK-LABEL: @strided_store_double_3x2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[IN:%.*]] = load <6 x double>, <6 x double>* [[IN_ADDR:%.*]], align 8 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <6 x double> [[IN]], <6 x double> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <6 x double> [[IN]], <6 x double> undef, <3 x i32> +; CHECK-NEXT: [[TMP0:%.*]] = bitcast double* [[OUT:%.*]] to <3 x double>* +; CHECK-NEXT: store <3 x double> [[SPLIT]], <3 x double>* [[TMP0]], align 8 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr double, double* [[OUT]], i32 5 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast double* [[TMP1]] to <3 x double>* +; CHECK-NEXT: store <3 x double> [[SPLIT1]], <3 x double>* [[TMP2]], align 8 +; CHECK-NEXT: ret void +; +entry: + + %in = load <6 x double>, <6 x double>* %in.addr, align 8 + call void @llvm.matrix.columnwise.store(<6 x double> %in, double* %out, i32 2, i32 3, i32 2) + ret void +} + +define void @strided_store_double_3x2_nonconst_stride(<6 x double>* %in.addr, i32 %stride, double* %out) { +; CHECK-LABEL: @strided_store_double_3x2_nonconst_stride( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[IN:%.*]] = load <6 x double>, <6 x double>* [[IN_ADDR:%.*]], align 8 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <6 x double> [[IN]], <6 x double> undef, <3 x i32> +; CHECK-NEXT: 
[[SPLIT1:%.*]] = shufflevector <6 x double> [[IN]], <6 x double> undef, <3 x i32> +; CHECK-NEXT: [[TMP0:%.*]] = add i32 3, [[STRIDE:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = mul i32 0, [[TMP0]] +; CHECK-NEXT: [[TMP2:%.*]] = add i32 0, [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = getelementptr double, double* [[OUT:%.*]], i32 [[TMP2]] +; CHECK-NEXT: [[TMP4:%.*]] = bitcast double* [[TMP3]] to <3 x double>* +; CHECK-NEXT: store <3 x double> [[SPLIT]], <3 x double>* [[TMP4]], align 8 +; CHECK-NEXT: [[TMP5:%.*]] = add i32 3, [[STRIDE]] +; CHECK-NEXT: [[TMP6:%.*]] = mul i32 1, [[TMP5]] +; CHECK-NEXT: [[TMP7:%.*]] = add i32 0, [[TMP6]] +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr double, double* [[OUT]], i32 [[TMP7]] +; CHECK-NEXT: [[TMP9:%.*]] = bitcast double* [[TMP8]] to <3 x double>* +; CHECK-NEXT: store <3 x double> [[SPLIT1]], <3 x double>* [[TMP9]], align 8 +; CHECK-NEXT: ret void +; +entry: + + %in = load <6 x double>, <6 x double>* %in.addr, align 8 + call void @llvm.matrix.columnwise.store(<6 x double> %in, double* %out, i32 %stride, i32 3, i32 2) + ret void +} + + +declare void @llvm.matrix.columnwise.store(<6 x double>, double*, i32, i32, i32) + +define void @strided_store_i8_2x3(<10 x i8>* %in.addr, double* %out) { +; CHECK-LABEL: @strided_store_i8_2x3( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[IN:%.*]] = load <10 x i8>, <10 x i8>* [[IN_ADDR:%.*]], align 8 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <10 x i8> [[IN]], <10 x i8> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <10 x i8> [[IN]], <10 x i8> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <10 x i8> [[IN]], <10 x i8> undef, <3 x i32> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <10 x i8> [[IN]], <10 x i8> undef, <3 x i32> +; CHECK-NEXT: [[TMP0:%.*]] = bitcast double* [[OUT:%.*]] to i8* +; CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <3 x i8>* +; CHECK-NEXT: store <3 x i8> [[SPLIT]], <3 x i8>* [[TMP1]], align 1 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast double* [[OUT]] to i8* +; 
CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, i8* [[TMP2]], i32 5 +; CHECK-NEXT: [[TMP4:%.*]] = bitcast i8* [[TMP3]] to <3 x i8>* +; CHECK-NEXT: store <3 x i8> [[SPLIT1]], <3 x i8>* [[TMP4]], align 1 +; CHECK-NEXT: [[TMP5:%.*]] = bitcast double* [[OUT]] to i8* +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr i8, i8* [[TMP5]], i32 10 +; CHECK-NEXT: [[TMP7:%.*]] = bitcast i8* [[TMP6]] to <3 x i8>* +; CHECK-NEXT: store <3 x i8> [[SPLIT2]], <3 x i8>* [[TMP7]], align 1 +; CHECK-NEXT: [[TMP8:%.*]] = bitcast double* [[OUT]] to i8* +; CHECK-NEXT: [[TMP9:%.*]] = getelementptr i8, i8* [[TMP8]], i32 15 +; CHECK-NEXT: [[TMP10:%.*]] = bitcast i8* [[TMP9]] to <3 x i8>* +; CHECK-NEXT: store <3 x i8> [[SPLIT3]], <3 x i8>* [[TMP10]], align 1 +; CHECK-NEXT: ret void +; +entry: + + %in = load <10 x i8>, <10 x i8>* %in.addr, align 8 + call void @llvm.matrix.columnwise.store.v10i8(<10 x i8> %in, double* %out, i32 2, i32 3, i32 2) + ret void +} + +declare void @llvm.matrix.columnwise.store.v10i8(<10 x i8>, double*, i32, i32, i32) diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/transpose.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/transpose.ll @@ -0,0 +1,129 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + + +define void @transpose(<8 x double>* %Ptr.A, <8 x double>* %Ptr.B) { +; CHECK-LABEL: @transpose( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <8 x double>, <8 x double>* [[PTR_A:%.*]], align 16 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 0, i32 1> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5> +; CHECK-NEXT: [[SPLIT3:%.*]] =
shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7> +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1 +; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0 +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1 +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2 +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3 +; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP15]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: store <8 x double> [[TMP16]], <8 x double>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: ret void +; +entry: + %a = load <8 x double>, <8 x double> *%Ptr.A, align 16 + %c = call <8 x double> @llvm.matrix.transpose.v8f64(<8 x double> %a, i32 2, i32 4) ; column-major 2x4 (rows x columns) -> 4x2 + + store <8 x double> %c, <8 x double> *%Ptr.B, align 16 + ret void +} + +declare <8 x double> @llvm.matrix.transpose.v8f64(<8 x double>, i32, i32) + +define void @transpose_single_column(<8 x double>* %Ptr.A, <8 x double>* %Ptr.B) { +; CHECK-LABEL:
@transpose_single_column( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <8 x double>, <8 x double>* [[PTR_A:%.*]], align 16 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <8 x double> [[SPLIT]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <1 x double> undef, double [[TMP0]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <8 x double> [[SPLIT]], i64 1 +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <1 x double> undef, double [[TMP2]], i64 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <8 x double> [[SPLIT]], i64 2 +; CHECK-NEXT: [[TMP5:%.*]] = insertelement <1 x double> undef, double [[TMP4]], i64 0 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <8 x double> [[SPLIT]], i64 3 +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <1 x double> undef, double [[TMP6]], i64 0 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <8 x double> [[SPLIT]], i64 4 +; CHECK-NEXT: [[TMP9:%.*]] = insertelement <1 x double> undef, double [[TMP8]], i64 0 +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <8 x double> [[SPLIT]], i64 5 +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <1 x double> undef, double [[TMP10]], i64 0 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <8 x double> [[SPLIT]], i64 6 +; CHECK-NEXT: [[TMP13:%.*]] = insertelement <1 x double> undef, double [[TMP12]], i64 0 +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <8 x double> [[SPLIT]], i64 7 +; CHECK-NEXT: [[TMP15:%.*]] = insertelement <1 x double> undef, double [[TMP14]], i64 0 +; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <1 x double> [[TMP1]], <1 x double> [[TMP3]], <2 x i32> <i32 0, i32 1> +; CHECK-NEXT: [[TMP17:%.*]] = shufflevector <1 x double> [[TMP5]], <1 x double> [[TMP7]], <2 x i32> <i32 0, i32 1> +; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <1 x double> [[TMP9]], <1 x double> [[TMP11]], <2 x i32> <i32 0, i32 1> +; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <1 x double> [[TMP13]], <1 x double> [[TMP15]], <2 x i32> <i32 0, i32 1> +; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x double>
[[TMP16]], <2 x double> [[TMP17]], <4 x i32> <i32 0, i32 1, i32 2, i32 3> +; CHECK-NEXT: [[TMP21:%.*]] = shufflevector <2 x double> [[TMP18]], <2 x double> [[TMP19]], <4 x i32> <i32 0, i32 1, i32 2, i32 3> +; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <4 x double> [[TMP20]], <4 x double> [[TMP21]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: store <8 x double> [[TMP22]], <8 x double>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: ret void +; +entry: + %a = load <8 x double>, <8 x double> *%Ptr.A, align 16 + %c = call <8 x double> @llvm.matrix.transpose.v8f64(<8 x double> %a, i32 8, i32 1) ; column-major 8x1 (single column) -> 1x8 + + store <8 x double> %c, <8 x double> *%Ptr.B, align 16 + ret void +} + +declare <12 x i16> @llvm.matrix.transpose.v12i16(<12 x i16>, i32, i32) + +define void @transpose_i16_3x4(<12 x i16>* %Ptr.A, <12 x i16>* %Ptr.B) { +; CHECK-LABEL: @transpose_i16_3x4( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[A:%.*]] = load <12 x i16>, <12 x i16>* [[PTR_A:%.*]], align 16 +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <12 x i16> [[A]], <12 x i16> undef, <3 x i32> <i32 0, i32 1, i32 2> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <12 x i16> [[A]], <12 x i16> undef, <3 x i32> <i32 3, i32 4, i32 5> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <12 x i16> [[A]], <12 x i16> undef, <3 x i32> <i32 6, i32 7, i32 8> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <12 x i16> [[A]], <12 x i16> undef, <3 x i32> <i32 9, i32 10, i32 11> +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x i16> [[SPLIT]], i64 0 +; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x i16> undef, i16 [[TMP0]], i64 0 +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <3 x i16> [[SPLIT1]], i64 0 +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i16> [[TMP1]], i16 [[TMP2]], i64 1 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <3 x i16> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x i16> [[TMP3]], i16 [[TMP4]], i64 2 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <3 x i16> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i16> [[TMP5]], i16 [[TMP6]], i64 3 +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <3 x i16> [[SPLIT]], i64 1 +; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x i16> undef, i16
[[TMP8]], i64 0 +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <3 x i16> [[SPLIT1]], i64 1 +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x i16> [[TMP9]], i16 [[TMP10]], i64 1 +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <3 x i16> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x i16> [[TMP11]], i16 [[TMP12]], i64 2 +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <3 x i16> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x i16> [[TMP13]], i16 [[TMP14]], i64 3 +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <3 x i16> [[SPLIT]], i64 2 +; CHECK-NEXT: [[TMP17:%.*]] = insertelement <4 x i16> undef, i16 [[TMP16]], i64 0 +; CHECK-NEXT: [[TMP18:%.*]] = extractelement <3 x i16> [[SPLIT1]], i64 2 +; CHECK-NEXT: [[TMP19:%.*]] = insertelement <4 x i16> [[TMP17]], i16 [[TMP18]], i64 1 +; CHECK-NEXT: [[TMP20:%.*]] = extractelement <3 x i16> [[SPLIT2]], i64 2 +; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x i16> [[TMP19]], i16 [[TMP20]], i64 2 +; CHECK-NEXT: [[TMP22:%.*]] = extractelement <3 x i16> [[SPLIT3]], i64 2 +; CHECK-NEXT: [[TMP23:%.*]] = insertelement <4 x i16> [[TMP21]], i16 [[TMP22]], i64 3 +; CHECK-NEXT: [[TMP24:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> [[TMP15]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> +; CHECK-NEXT: [[TMP25:%.*]] = shufflevector <4 x i16> [[TMP23]], <4 x i16> undef, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef> +; CHECK-NEXT: [[TMP26:%.*]] = shufflevector <8 x i16> [[TMP24]], <8 x i16> [[TMP25]], <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11> +; CHECK-NEXT: store <12 x i16> [[TMP26]], <12 x i16>* [[PTR_B:%.*]], align 16 +; CHECK-NEXT: ret void +; +entry: + %a = load <12 x i16>, <12 x i16> *%Ptr.A, align 16 + %c = call <12 x i16> @llvm.matrix.transpose.v12i16(<12 x i16> %a, i32 3, i32 4) ; column-major 3x4 (rows x columns) -> 4x3 + + store <12 x i16> %c, <12 x i16> *%Ptr.B, align 16 + ret void +}