diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -167,60 +167,81 @@ - /// Wrapper class representing a matrix as a set of column vectors. - /// All column vectors must have the same vector type. - class ColumnMatrixTy { - SmallVector Columns; + /// Wrapper class representing a matrix as a set of column vectors + /// (column-major) or row vectors (row-major), all of the same vector type. + class MatrixTy { + SmallVector Vectors; OpInfoTy OpInfo; - public: - ColumnMatrixTy() : Columns() {} - ColumnMatrixTy(ArrayRef Cols) - : Columns(Cols.begin(), Cols.end()) {} + bool IsColumnMajor = true; - Value *getColumn(unsigned i) const { return Columns[i]; } + public: + MatrixTy() : Vectors() {} + MatrixTy(ArrayRef Vectors) + : Vectors(Vectors.begin(), Vectors.end()) {} + + Value *getVector(unsigned i) const { return Vectors[i]; } + Value *getColumn(unsigned i) const { + assert(isColumnMajor() && "only supported for column-major matrixes"); + return Vectors[i]; + } - void setColumn(unsigned i, Value *V) { Columns[i] = V; } + void setColumn(unsigned i, Value *V) { Vectors[i] = V; } - size_t getNumColumns() const { return Columns.size(); } + size_t getNumColumns() const { + if (isColumnMajor()) + return Vectors.size(); + // Row-major: each vector is one row, so the number of columns is the + // number of elements in a row vector. + assert(Vectors.size() > 0 && "Cannot call getNumColumns without rows"); + return cast(Vectors[0]->getType())->getNumElements(); + } size_t getNumRows() const { - assert(Columns.size() > 0 && "Cannot call getNumRows without columns"); - return cast(Columns[0]->getType())->getNumElements(); + if (isColumnMajor()) { + assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); + return cast(Vectors[0]->getType())->getNumElements(); + } + return Vectors.size(); } - const SmallVectorImpl &getColumnVectors() const { return Columns; } + const SmallVectorImpl &getColumnVectors() const { return Vectors; } - SmallVectorImpl &getColumnVectors() { return Columns; } + SmallVectorImpl &getColumnVectors() 
{ return Vectors; } - void addColumn(Value *V) { Columns.push_back(V); } + void addColumn(Value *V) { Vectors.push_back(V); } VectorType *getColumnTy() { - return cast(Columns[0]->getType()); + assert(isColumnMajor() && "only supported for column-major matrixes"); + return getVectorTy(); + } + + VectorType *getVectorTy() { + return cast(Vectors[0]->getType()); } iterator_range::iterator> columns() { - return make_range(Columns.begin(), Columns.end()); + return make_range(Vectors.begin(), Vectors.end()); } /// Embed the columns of the matrix into a flat vector by concatenating /// them. Value *embedInVector(IRBuilder<> &Builder) const { - return Columns.size() == 1 ? Columns[0] - : concatenateVectors(Builder, Columns); + return Vectors.size() == 1 ? Vectors[0] + : concatenateVectors(Builder, Vectors); } - ColumnMatrixTy &addNumLoads(unsigned N) { + MatrixTy &addNumLoads(unsigned N) { OpInfo.NumLoads += N; return *this; } void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } - ColumnMatrixTy &addNumStores(unsigned N) { + MatrixTy &addNumStores(unsigned N) { OpInfo.NumStores += N; return *this; } - ColumnMatrixTy &addNumComputeOps(unsigned N) { + MatrixTy &addNumComputeOps(unsigned N) { OpInfo.NumComputeOps += N; return *this; } @@ -230,6 +251,8 @@ unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } const OpInfoTy &getOpInfo() const { return OpInfo; } + + bool isColumnMajor() const { return IsColumnMajor; } }; struct ShapeInfo { @@ -270,7 +293,7 @@ SmallVector ToRemove; /// Map from instructions to their produced column matrix. - MapVector Inst2ColumnMatrix; + MapVector Inst2ColumnMatrix; public: LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, @@ -296,7 +319,7 @@ /// If we lowered \p MatrixVal, just return the cache result column matrix. /// Otherwie split the flat vector \p MatrixVal containing a matrix with /// shape \p SI into column vectors.
- ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, + MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, IRBuilder<> &Builder) { VectorType *VType = dyn_cast(MatrixVal->getType()); assert(VType && "MatrixVal must be a vector type"); @@ -309,7 +332,7 @@ // vector and split it later. auto Found = Inst2ColumnMatrix.find(MatrixVal); if (Found != Inst2ColumnMatrix.end()) { - ColumnMatrixTy &M = Found->second; + MatrixTy &M = Found->second; // Return the found matrix, if its shape matches the requested shape // information if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) @@ -639,7 +662,7 @@ IRBuilder<> Builder(Inst); auto VType = cast(Inst->getType()); Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); - ColumnMatrixTy Result; + MatrixTy Result; // Distance between start of one column and the start of the next for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) { Value *GEP = @@ -677,7 +700,7 @@ Shape.NumRows, VType->getElementType(), Builder); createColumnStore(C.value(), GEP, VType->getElementType(), Builder); } - Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores( + Inst2ColumnMatrix[Inst] = MatrixTy().addNumStores( getNumOps(LM.getColumnTy()) * LM.getNumColumns()); ToRemove.push_back(Inst); @@ -696,7 +719,7 @@ /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from /// the matrix \p LM represented as a vector of column vectors. - Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J, + Value *extractVector(const MatrixTy &LM, unsigned I, unsigned J, unsigned NumElts, IRBuilder<> &Builder) { Value *Col = LM.getColumn(J); Value *Undef = UndefValue::get(Col->getType()); @@ -768,7 +791,7 @@ /// cached value when they are lowered. For other users, \p Matrix is /// flattened and the uses are updated to use it. Also marks \p Inst for /// deletion.
- void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix, + void finalizeLowering(Instruction *Inst, MatrixTy Matrix, IRBuilder<> &Builder) { Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); @@ -791,9 +814,9 @@ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); - const ColumnMatrixTy &Lhs = + const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); - const ColumnMatrixTy &Rhs = + const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder); const unsigned R = LShape.NumRows; @@ -802,7 +825,7 @@ assert(M == RShape.NumRows); // Initialize the output - ColumnMatrixTy Result; + MatrixTy Result; for (unsigned J = 0; J < C; ++J) Result.addColumn(UndefValue::get(VectorType::get(EltType, R))); @@ -840,12 +863,12 @@ /// Lowers llvm.matrix.transpose. void LowerTranspose(CallInst *Inst) { - ColumnMatrixTy Result; + MatrixTy Result; IRBuilder<> Builder(Inst); Value *InputVal = Inst->getArgOperand(0); VectorType *VectorTy = cast(InputVal->getType()); ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); - ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); + MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) { // Build a single column vector for this row. First initialize it.
@@ -905,11 +928,11 @@ IRBuilder<> Builder(Inst); ShapeInfo &Shape = I->second; - ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder); - ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder); + MatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder); + MatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder); // Add each column and store the result back into the opmapping - ColumnMatrixTy Result; + MatrixTy Result; auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) { switch (Inst->getOpcode()) { case Instruction::Add: @@ -951,7 +974,7 @@ /// Mapping from instructions to column matrixes. It is used to identify /// matrix instructions. - const MapVector &Inst2ColumnMatrix; + const MapVector &Inst2ColumnMatrix; /// Mapping from values to the leaves of all expressions that the value is /// part of. @@ -968,7 +991,7 @@ SmallPtrSet ReusedExprs; ExprLinearizer(const DataLayout &DL, - const MapVector &Inst2ColumnMatrix, + const MapVector &Inst2ColumnMatrix, const DenseMap> &Shared, const SmallSetVector &ExprsInSubprogram, Value *Leaf) @@ -1212,12 +1235,12 @@ /// that multiple leaves can share sub-expressions. Shared subexpressions /// are explicitly marked as shared(). struct RemarkGenerator { - const MapVector &Inst2ColumnMatrix; + const MapVector &Inst2ColumnMatrix; OptimizationRemarkEmitter &ORE; Function &Func; const DataLayout &DL; - RemarkGenerator(const MapVector &Inst2ColumnMatrix, + RemarkGenerator(const MapVector &Inst2ColumnMatrix, OptimizationRemarkEmitter &ORE, Function &Func) : Inst2ColumnMatrix(Inst2ColumnMatrix), ORE(ORE), Func(Func), DL(Func.getParent()->getDataLayout()) {}