diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -181,8 +181,8 @@
 
     void setColumn(unsigned i, Value *V) { Columns[i] = V; }
 
-    size_t getNumColumns() const { return Columns.size(); }
-    size_t getNumRows() const {
+    unsigned getNumColumns() const { return Columns.size(); }
+    unsigned getNumRows() const {
       assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
       return cast<VectorType>(Columns[0]->getType())->getNumElements();
     }
@@ -634,10 +634,11 @@
     return true;
   }
 
-  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
-                 ShapeInfo Shape) {
-    IRBuilder<> Builder(Inst);
-    auto VType = cast<VectorType>(Inst->getType());
+  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
+  /// columns. The returned matrix also records the number of load ops.
+  ColumnMatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride,
+                            ShapeInfo Shape, IRBuilder<> &Builder) {
+    auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
     ColumnMatrixTy Result;
     // Distance between start of one column and the start of the next
@@ -648,10 +649,41 @@
       Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder);
       Result.addColumn(Column);
     }
+    return Result.addNumLoads(getNumOps(Result.getColumnTy()) *
+                              Result.getNumColumns());
+  }
+
+  /// Loads a sub-matrix with shape \p ResultShape from the matrix with shape
+  /// \p MatrixShape at \p MatrixPtr, starting at row \p I and column \p J.
+  ColumnMatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, unsigned I,
+                            unsigned J, ShapeInfo ResultShape, Type *EltTy,
+                            IRBuilder<> &Builder) {
+    // Element offset of row I, column J: columns are MatrixShape.NumRows apart.
+    Value *Offset = Builder.CreateAdd(
+        Builder.CreateMul(Builder.getInt32(J),
+                          Builder.getInt32(MatrixShape.NumRows)),
+        Builder.getInt32(I));
+
+    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+    Value *EltPtr =
+        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+    Type *TileTy =
+        VectorType::get(EltTy, ResultShape.NumRows * ResultShape.NumColumns);
+    Type *TilePtrTy = PointerType::get(TileTy, AS);
+    Value *TilePtr =
+        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+    // The tile's columns keep the stride of the enclosing matrix.
+    return loadMatrix(TileTy, TilePtr, Builder.getInt32(MatrixShape.NumRows),
+                      ResultShape, Builder);
+  }
 
+  /// Lower a load instruction with shape information.
+  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
+                 ShapeInfo Shape) {
+    IRBuilder<> Builder(Inst);
     finalizeLowering(Inst,
-                     Result.addNumLoads(getNumOps(Result.getColumnTy()) *
-                                        Result.getNumColumns()),
+                     loadMatrix(Inst->getType(), Ptr, Stride, Shape, Builder),
                      Builder);
   }
 
@@ -665,22 +697,54 @@
               {Inst->getArgOperand(2), Inst->getArgOperand(3)});
   }
 
-  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
-                  ShapeInfo Shape) {
-    IRBuilder<> Builder(Inst);
-    auto VType = cast<VectorType>(Matrix->getType());
+  /// Stores a sub-matrix \p StoreVal into the matrix with shape \p MatrixShape
+  /// at \p MatrixPtr, starting at row \p I and column \p J.
+  void storeMatrix(const ColumnMatrixTy &StoreVal, Value *MatrixPtr,
+                   ShapeInfo MatrixShape, unsigned I, unsigned J, Type *EltTy,
+                   IRBuilder<> &Builder) {
+    Value *Offset = Builder.CreateAdd(
+        Builder.CreateMul(Builder.getInt32(J),
+                          Builder.getInt32(MatrixShape.NumRows)),
+        Builder.getInt32(I));
+
+    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+    Value *EltPtr =
+        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+    Type *TileTy = VectorType::get(EltTy, StoreVal.getNumRows() *
+                                              StoreVal.getNumColumns());
+    Type *TilePtrTy = PointerType::get(TileTy, AS);
+    Value *TilePtr =
+        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+    // The tile's columns keep the stride of the enclosing matrix.
+    storeMatrix(TileTy, StoreVal, TilePtr,
+                Builder.getInt32(MatrixShape.NumRows), Builder);
+  }
+
+  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
+  /// columns. Returns an empty matrix that only records the number of stores.
+  ColumnMatrixTy storeMatrix(Type *Ty, ColumnMatrixTy StoreVal, Value *Ptr,
+                             Value *Stride, IRBuilder<> &Builder) {
+    auto VType = cast<VectorType>(Ty);
     Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
-    auto LM = getMatrix(Matrix, Shape, Builder);
-    for (auto C : enumerate(LM.columns())) {
-      Value *GEP =
-          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride,
-                            Shape.NumRows, VType->getElementType(), Builder);
+    for (auto C : enumerate(StoreVal.columns())) {
+      Value *GEP = computeColumnAddr(EltPtr, Builder.getInt32(C.index()),
+                                     Stride, StoreVal.getNumRows(),
+                                     VType->getElementType(), Builder);
       createColumnStore(C.value(), GEP, VType->getElementType(), Builder);
     }
-    Inst2ColumnMatrix[Inst] = ColumnMatrixTy().addNumStores(
-        getNumOps(LM.getColumnTy()) * LM.getNumColumns());
+    return ColumnMatrixTy().addNumStores(getNumOps(StoreVal.getColumnTy()) *
+                                         StoreVal.getNumColumns());
+  }
 
-    ToRemove.push_back(Inst);
+  /// Lower a store instruction with shape information.
+  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
+                  ShapeInfo Shape) {
+    IRBuilder<> Builder(Inst);
+    auto ColumnsToStore = getMatrix(Matrix, Shape, Builder);
+    ColumnMatrixTy StoreInfo =
+        storeMatrix(Matrix->getType(), ColumnsToStore, Ptr, Stride, Builder);
+    finalizeLowering(Inst, StoreInfo, Builder);
   }
 
   /// Lowers llvm.matrix.columnwise.store.