diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -165,66 +165,85 @@ } }; - /// Wrapper class representing a matrix as a set of column vectors. - /// All column vectors must have the same vector type. - class ColumnMatrixTy { - SmallVector Columns; + /// Wrapper class representing a matrix as a set of vectors, either in row or + /// column major layout. All vectors must have the same vector type. + class MatrixTy { + SmallVector Vectors; OpInfoTy OpInfo; + bool IsColumnMajor = true; + public: - ColumnMatrixTy() : Columns() {} - ColumnMatrixTy(ArrayRef Cols) - : Columns(Cols.begin(), Cols.end()) {} + MatrixTy() : Vectors() {} + MatrixTy(ArrayRef Vectors) + : Vectors(Vectors.begin(), Vectors.end()) {} + + Value *getVector(unsigned i) const { return Vectors[i]; } + Value *getColumn(unsigned i) const { + assert(isColumnMajor() && "only supported for column-major matrixes"); + return Vectors[i]; + } - Value *getColumn(unsigned i) const { return Columns[i]; } + void setColumn(unsigned i, Value *V) { Vectors[i] = V; } - void setColumn(unsigned i, Value *V) { Columns[i] = V; } + Type *getElementType() { return getVectorTy()->getElementType(); } - Type *getElementType() { - return cast(Columns[0]->getType())->getElementType(); + unsigned getNumColumns() const { + if (isColumnMajor()) + return Vectors.size(); + else { + assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); + return cast(Vectors[0]->getType())->getNumElements(); + } } - - unsigned getNumColumns() const { return Columns.size(); } unsigned getNumRows() const { - assert(Columns.size() > 0 && "Cannot call getNumRows without columns"); - return cast(Columns[0]->getType())->getNumElements(); + if (isColumnMajor()) { + assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); + return cast(Vectors[0]->getType())->getNumElements(); + } else + return Vectors.size(); } - const SmallVectorImpl &getColumnVectors() const { return Columns; } + const SmallVectorImpl &getColumnVectors() const { return Vectors; } - SmallVectorImpl &getColumnVectors() { return Columns; } + SmallVectorImpl &getColumnVectors() { return Vectors; } - void addColumn(Value *V) { Columns.push_back(V); } + void addColumn(Value *V) { Vectors.push_back(V); } VectorType *getColumnTy() { - return cast(Columns[0]->getType()); + assert(isColumnMajor() && "only supported for column-major matrixes"); + return getVectorTy(); + } + + VectorType *getVectorTy() { + return cast(Vectors[0]->getType()); } iterator_range::iterator> columns() { - return make_range(Columns.begin(), Columns.end()); + return make_range(Vectors.begin(), Vectors.end()); } /// Embed the columns of the matrix into a flat vector by concatenating /// them. Value *embedInVector(IRBuilder<> &Builder) const { - return Columns.size() == 1 ? Columns[0] - : concatenateVectors(Builder, Columns); + return Vectors.size() == 1 ? Vectors[0] + : concatenateVectors(Builder, Vectors); } - ColumnMatrixTy &addNumLoads(unsigned N) { + MatrixTy &addNumLoads(unsigned N) { OpInfo.NumLoads += N; return *this; } void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } - ColumnMatrixTy &addNumStores(unsigned N) { + MatrixTy &addNumStores(unsigned N) { OpInfo.NumStores += N; return *this; } - ColumnMatrixTy &addNumComputeOps(unsigned N) { + MatrixTy &addNumComputeOps(unsigned N) { OpInfo.NumComputeOps += N; return *this; } @@ -234,6 +253,8 @@ unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } const OpInfoTy &getOpInfo() const { return OpInfo; } + + bool isColumnMajor() const { return IsColumnMajor; } }; struct ShapeInfo { @@ -274,7 +295,7 @@ SmallVector ToRemove; /// Map from instructions to their produced column matrix. - MapVector Inst2ColumnMatrix; + MapVector Inst2ColumnMatrix; public: LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, @@ -300,8 +321,8 @@ /// If we lowered \p MatrixVal, just return the cache result column matrix. /// Otherwie split the flat vector \p MatrixVal containing a matrix with /// shape \p SI into column vectors. - ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, - IRBuilder<> &Builder) { + MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, + IRBuilder<> &Builder) { VectorType *VType = dyn_cast(MatrixVal->getType()); assert(VType && "MatrixVal must be a vector type"); assert(VType->getNumElements() == SI.NumRows * SI.NumColumns && @@ -313,7 +334,7 @@ // vector and split it later. auto Found = Inst2ColumnMatrix.find(MatrixVal); if (Found != Inst2ColumnMatrix.end()) { - ColumnMatrixTy &M = Found->second; + MatrixTy &M = Found->second; // Return the found matrix, if its shape matches the requested shape // information if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) @@ -640,11 +661,11 @@ /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between /// columns. - ColumnMatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, - ShapeInfo Shape, IRBuilder<> &Builder) { + MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, ShapeInfo Shape, + IRBuilder<> &Builder) { auto VType = cast(Ty); Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); - ColumnMatrixTy Result; + MatrixTy Result; // Distance between start of one column and the start of the next for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) { Value *GEP = @@ -659,9 +680,9 @@ /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, /// starting at \p MatrixPtr[I][J]. - ColumnMatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I, - unsigned J, ShapeInfo ResultShape, Type *EltTy, - IRBuilder<> &Builder) { + MatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I, + unsigned J, ShapeInfo ResultShape, Type *EltTy, + IRBuilder<> &Builder) { Value *Offset = Builder.CreateAdd( Builder.CreateMul(Builder.getInt32(J), @@ -703,7 +724,7 @@ /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p /// MatrixPtr[I][J]. - void storeMatrix(const ColumnMatrixTy &StoreVal, Value *MatrixPtr, + void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I, unsigned J, Type *EltTy, IRBuilder<> &Builder) { Value *Offset = Builder.CreateAdd( @@ -727,8 +748,8 @@ /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between /// columns. - ColumnMatrixTy storeMatrix(Type *Ty, ColumnMatrixTy StoreVal, Value *Ptr, - Value *Stride, IRBuilder<> &Builder) { + MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride, + IRBuilder<> &Builder) { auto VType = cast(Ty); Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); for (auto C : enumerate(StoreVal.columns())) { @@ -737,8 +758,8 @@ VType->getElementType(), Builder); createColumnStore(C.value(), GEP, VType->getElementType(), Builder); } - return ColumnMatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) * - StoreVal.getNumColumns()); + return MatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) * + StoreVal.getNumColumns()); } /// Lower a store instruction with shape information. @@ -764,7 +785,7 @@ /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from /// the matrix \p LM represented as a vector of column vectors. - Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J, + Value *extractVector(const MatrixTy &LM, unsigned I, unsigned J, unsigned NumElts, IRBuilder<> &Builder) { Value *Col = LM.getColumn(J); Value *Undef = UndefValue::get(Col->getType()); @@ -836,7 +857,7 @@ /// cached value when they are lowered. For other users, \p Matrix is /// flattened and the uses are updated to use it. Also marks \p Inst for /// deletion. - void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix, + void finalizeLowering(Instruction *Inst, MatrixTy Matrix, IRBuilder<> &Builder) { Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); @@ -854,9 +875,8 @@ /// Compute Res += A * B for tile-sized matrices with left-associating /// addition. - void emitChainedMatrixMultiply(ColumnMatrixTy &Result, - const ColumnMatrixTy &A, - const ColumnMatrixTy &B, bool AllowContraction, + void emitChainedMatrixMultiply(MatrixTy &Result, const MatrixTy &A, + const MatrixTy &B, bool AllowContraction, IRBuilder<> &Builder, bool isTiled) { const unsigned VF = std::max( TTI.getRegisterBitWidth(true) / @@ -902,17 +922,15 @@ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); - const ColumnMatrixTy &Lhs = - getMatrix(MatMul->getArgOperand(0), LShape, Builder); - const ColumnMatrixTy &Rhs = - getMatrix(MatMul->getArgOperand(1), RShape, Builder); + const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); + const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); const unsigned R = LShape.NumRows; const unsigned C = RShape.NumColumns; assert(LShape.NumColumns == RShape.NumRows); // Initialize the output - ColumnMatrixTy Result; + MatrixTy Result; for (unsigned J = 0; J < C; ++J) Result.addColumn(UndefValue::get(VectorType::get(EltType, R))); @@ -924,12 +942,12 @@ /// Lowers llvm.matrix.transpose. void LowerTranspose(CallInst *Inst) { - ColumnMatrixTy Result; + MatrixTy Result; IRBuilder<> Builder(Inst); Value *InputVal = Inst->getArgOperand(0); VectorType *VectorTy = cast(InputVal->getType()); ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); - ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); + MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) { // Build a single column vector for this row. First initialize it. @@ -989,11 +1007,11 @@ IRBuilder<> Builder(Inst); ShapeInfo &Shape = I->second; - ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder); - ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder); + MatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder); + MatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder); // Add each column and store the result back into the opmapping - ColumnMatrixTy Result; + MatrixTy Result; auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) { switch (Inst->getOpcode()) { case Instruction::Add: @@ -1035,7 +1053,7 @@ /// Mapping from instructions to column matrixes. It is used to identify /// matrix instructions. - const MapVector &Inst2ColumnMatrix; + const MapVector &Inst2ColumnMatrix; /// Mapping from values to the leaves of all expressions that the value is /// part of. @@ -1052,7 +1070,7 @@ SmallPtrSet ReusedExprs; ExprLinearizer(const DataLayout &DL, - const MapVector &Inst2ColumnMatrix, + const MapVector &Inst2ColumnMatrix, const DenseMap> &Shared, const SmallSetVector &ExprsInSubprogram, Value *Leaf) @@ -1296,12 +1314,12 @@ /// that multiple leaves can share sub-expressions. Shared subexpressions /// are explicitly marked as shared(). struct RemarkGenerator { - const MapVector &Inst2ColumnMatrix; + const MapVector &Inst2ColumnMatrix; OptimizationRemarkEmitter &ORE; Function &Func; const DataLayout &DL; - RemarkGenerator(const MapVector &Inst2ColumnMatrix, + RemarkGenerator(const MapVector &Inst2ColumnMatrix, OptimizationRemarkEmitter &ORE, Function &Func) : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), Func(Func), DL(Func.getParent()->getDataLayout()) {}