Changeset View
Changeset View
Standalone View
Standalone View
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Show First 20 Lines • Show All 712 Lines • ▼ Show 20 Lines | for (Instruction *Inst : MatrixInsts) { | ||||
if (CallInst *CInst = dyn_cast<CallInst>(Inst)) | if (CallInst *CInst = dyn_cast<CallInst>(Inst)) | ||||
Changed |= VisitCallInst(CInst); | Changed |= VisitCallInst(CInst); | ||||
Value *Op1; | Value *Op1; | ||||
Value *Op2; | Value *Op2; | ||||
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) | if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) | ||||
Changed |= VisitBinaryOperator(BinOp); | Changed |= VisitBinaryOperator(BinOp); | ||||
if (match(Inst, m_Load(m_Value(Op1)))) | if (match(Inst, m_Load(m_Value(Op1)))) | ||||
Changed |= VisitLoad(Inst, Op1, Builder); | Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); | ||||
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) | else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) | ||||
Changed |= VisitStore(Inst, Op1, Op2, Builder); | Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder); | ||||
} | } | ||||
RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func); | RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func); | ||||
RemarkGen.emitRemarks(); | RemarkGen.emitRemarks(); | ||||
for (Instruction *Inst : reverse(ToRemove)) | for (Instruction *Inst : reverse(ToRemove)) | ||||
Inst->eraseFromParent(); | Inst->eraseFromParent(); | ||||
return Changed; | return Changed; | ||||
} | } | ||||
LoadInst *createVectorLoad(Value *ColumnPtr, Type *EltType, | LoadInst *createVectorLoad(Value *ColumnPtr, Type *EltType, bool IsVolatile, | ||||
IRBuilder<> &Builder) { | IRBuilder<> &Builder) { | ||||
return Builder.CreateAlignedLoad( | return Builder.CreateAlignedLoad(ColumnPtr, | ||||
ColumnPtr, Align(DL.getABITypeAlignment(EltType)), "col.load"); | Align(DL.getABITypeAlignment(EltType)), | ||||
IsVolatile, "col.load"); | |||||
} | } | ||||
StoreInst *createVectorStore(Value *ColumnValue, Value *ColumnPtr, | StoreInst *createVectorStore(Value *ColumnValue, Value *ColumnPtr, | ||||
Type *EltType, IRBuilder<> &Builder) { | Type *EltType, bool IsVolatile, | ||||
IRBuilder<> &Builder) { | |||||
return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, | return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, | ||||
DL.getABITypeAlign(EltType)); | DL.getABITypeAlign(EltType), IsVolatile); | ||||
} | } | ||||
/// Turns \p BasePtr into an elementwise pointer to \p EltType. | /// Turns \p BasePtr into an elementwise pointer to \p EltType. | ||||
Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { | Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) { | ||||
unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); | unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); | ||||
Type *EltPtrType = PointerType::get(EltType, AS); | Type *EltPtrType = PointerType::get(EltType, AS); | ||||
return Builder.CreatePointerCast(BasePtr, EltPtrType); | return Builder.CreatePointerCast(BasePtr, EltPtrType); | ||||
} | } | ||||
Show All 19 Lines | bool VisitCallInst(CallInst *Inst) { | ||||
default: | default: | ||||
return false; | return false; | ||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
/// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between | /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between | ||||
/// vectors. | /// vectors. | ||||
MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, ShapeInfo Shape, | MatrixTy loadMatrix(Type *Ty, Value *Ptr, Value *Stride, bool IsVolatile, | ||||
IRBuilder<> &Builder) { | ShapeInfo Shape, IRBuilder<> &Builder) { | ||||
auto VType = cast<VectorType>(Ty); | auto VType = cast<VectorType>(Ty); | ||||
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); | Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); | ||||
MatrixTy Result; | MatrixTy Result; | ||||
for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { | for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { | ||||
Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, | Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, | ||||
Shape.getStride(), VType->getElementType(), | Shape.getStride(), VType->getElementType(), | ||||
Builder); | Builder); | ||||
Value *Vector = createVectorLoad(GEP, VType->getElementType(), Builder); | Value *Vector = | ||||
createVectorLoad(GEP, VType->getElementType(), IsVolatile, Builder); | |||||
Result.addVector(Vector); | Result.addVector(Vector); | ||||
} | } | ||||
return Result.addNumLoads(getNumOps(Result.getVectorTy()) * | return Result.addNumLoads(getNumOps(Result.getVectorTy()) * | ||||
Result.getNumVectors()); | Result.getNumVectors()); | ||||
} | } | ||||
/// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, | /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, | ||||
/// starting at \p MatrixPtr[I][J]. | /// starting at \p MatrixPtr[I][J]. | ||||
MatrixTy loadMatrix(Value *MatrixPtr, ShapeInfo MatrixShape, Value *I, | MatrixTy loadMatrix(Value *MatrixPtr, bool IsVolatile, ShapeInfo MatrixShape, | ||||
Value *J, ShapeInfo ResultShape, Type *EltTy, | Value *I, Value *J, ShapeInfo ResultShape, Type *EltTy, | ||||
IRBuilder<> &Builder) { | IRBuilder<> &Builder) { | ||||
Value *Offset = Builder.CreateAdd( | Value *Offset = Builder.CreateAdd( | ||||
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); | Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); | ||||
unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); | unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); | ||||
Value *EltPtr = | Value *EltPtr = | ||||
Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); | Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); | ||||
Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); | Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); | ||||
auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * | auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * | ||||
ResultShape.NumColumns); | ResultShape.NumColumns); | ||||
Type *TilePtrTy = PointerType::get(TileTy, AS); | Type *TilePtrTy = PointerType::get(TileTy, AS); | ||||
Value *TilePtr = | Value *TilePtr = | ||||
Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); | Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); | ||||
return loadMatrix(TileTy, TilePtr, | return loadMatrix(TileTy, TilePtr, | ||||
Builder.getInt64(MatrixShape.getStride()), ResultShape, | Builder.getInt64(MatrixShape.getStride()), IsVolatile, | ||||
Builder); | ResultShape, Builder); | ||||
} | } | ||||
/// Lower a load instruction with shape information. | /// Lower a load instruction with shape information. | ||||
void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, | void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, bool IsVolatile, | ||||
ShapeInfo Shape) { | ShapeInfo Shape) { | ||||
IRBuilder<> Builder(Inst); | IRBuilder<> Builder(Inst); | ||||
finalizeLowering(Inst, | finalizeLowering( | ||||
loadMatrix(Inst->getType(), Ptr, Stride, Shape, Builder), | Inst, | ||||
loadMatrix(Inst->getType(), Ptr, Stride, IsVolatile, Shape, Builder), | |||||
Builder); | Builder); | ||||
} | } | ||||
/// Lowers llvm.matrix.column.major.load. | /// Lowers llvm.matrix.column.major.load. | ||||
/// | /// | ||||
/// The intrinsic loads a matrix from memory using a stride between columns. | /// The intrinsic loads a matrix from memory using a stride between columns. | ||||
void LowerColumnMajorLoad(CallInst *Inst) { | void LowerColumnMajorLoad(CallInst *Inst) { | ||||
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && | assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && | ||||
"Intrinsic only supports column-major layout!"); | "Intrinsic only supports column-major layout!"); | ||||
Value *Ptr = Inst->getArgOperand(0); | Value *Ptr = Inst->getArgOperand(0); | ||||
Value *Stride = Inst->getArgOperand(1); | Value *Stride = Inst->getArgOperand(1); | ||||
LowerLoad(Inst, Ptr, Stride, | LowerLoad(Inst, Ptr, Stride, | ||||
cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), | |||||
{Inst->getArgOperand(3), Inst->getArgOperand(4)}); | {Inst->getArgOperand(3), Inst->getArgOperand(4)}); | ||||
} | } | ||||
/// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p | /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p | ||||
/// MatrixPtr[I][J]. | /// MatrixPtr[I][J]. | ||||
void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, | void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, bool IsVolatile, | ||||
ShapeInfo MatrixShape, Value *I, Value *J, Type *EltTy, | ShapeInfo MatrixShape, Value *I, Value *J, Type *EltTy, | ||||
IRBuilder<> &Builder) { | IRBuilder<> &Builder) { | ||||
Value *Offset = Builder.CreateAdd( | Value *Offset = Builder.CreateAdd( | ||||
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); | Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); | ||||
unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); | unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); | ||||
Value *EltPtr = | Value *EltPtr = | ||||
Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); | Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); | ||||
Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); | Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); | ||||
auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * | auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * | ||||
StoreVal.getNumColumns()); | StoreVal.getNumColumns()); | ||||
Type *TilePtrTy = PointerType::get(TileTy, AS); | Type *TilePtrTy = PointerType::get(TileTy, AS); | ||||
Value *TilePtr = | Value *TilePtr = | ||||
Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); | Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); | ||||
storeMatrix(TileTy, StoreVal, TilePtr, | storeMatrix(TileTy, StoreVal, TilePtr, | ||||
Builder.getInt64(MatrixShape.getStride()), Builder); | Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); | ||||
} | } | ||||
/// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between | /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between | ||||
/// vectors. | /// vectors. | ||||
MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride, | MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, Value *Stride, | ||||
IRBuilder<> &Builder) { | bool IsVolatile, IRBuilder<> &Builder) { | ||||
auto VType = cast<VectorType>(Ty); | auto VType = cast<VectorType>(Ty); | ||||
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); | Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); | ||||
for (auto Vec : enumerate(StoreVal.vectors())) { | for (auto Vec : enumerate(StoreVal.vectors())) { | ||||
Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()), | Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()), | ||||
Stride, StoreVal.getStride(), | Stride, StoreVal.getStride(), | ||||
VType->getElementType(), Builder); | VType->getElementType(), Builder); | ||||
createVectorStore(Vec.value(), GEP, VType->getElementType(), Builder); | createVectorStore(Vec.value(), GEP, VType->getElementType(), IsVolatile, | ||||
Builder); | |||||
} | } | ||||
return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * | return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * | ||||
StoreVal.getNumVectors()); | StoreVal.getNumVectors()); | ||||
} | } | ||||
/// Lower a store instruction with shape information. | /// Lower a store instruction with shape information. | ||||
void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride, | void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride, | ||||
ShapeInfo Shape) { | bool IsVolatile, ShapeInfo Shape) { | ||||
IRBuilder<> Builder(Inst); | IRBuilder<> Builder(Inst); | ||||
auto StoreVal = getMatrix(Matrix, Shape, Builder); | auto StoreVal = getMatrix(Matrix, Shape, Builder); | ||||
finalizeLowering( | finalizeLowering(Inst, | ||||
Inst, storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride, Builder), | storeMatrix(Matrix->getType(), StoreVal, Ptr, Stride, | ||||
IsVolatile, Builder), | |||||
Builder); | Builder); | ||||
} | } | ||||
/// Lowers llvm.matrix.column.major.store. | /// Lowers llvm.matrix.column.major.store. | ||||
/// | /// | ||||
/// The intrinsic store a matrix back memory using a stride between columns. | /// The intrinsic store a matrix back memory using a stride between columns. | ||||
void LowerColumnMajorStore(CallInst *Inst) { | void LowerColumnMajorStore(CallInst *Inst) { | ||||
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && | assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && | ||||
"Intrinsic only supports column-major layout!"); | "Intrinsic only supports column-major layout!"); | ||||
Value *Matrix = Inst->getArgOperand(0); | Value *Matrix = Inst->getArgOperand(0); | ||||
Value *Ptr = Inst->getArgOperand(1); | Value *Ptr = Inst->getArgOperand(1); | ||||
Value *Stride = Inst->getArgOperand(2); | Value *Stride = Inst->getArgOperand(2); | ||||
LowerStore(Inst, Matrix, Ptr, Stride, | LowerStore(Inst, Matrix, Ptr, Stride, | ||||
cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), | |||||
{Inst->getArgOperand(4), Inst->getArgOperand(5)}); | {Inst->getArgOperand(4), Inst->getArgOperand(5)}); | ||||
} | } | ||||
// Set elements I..I+NumElts-1 to Block | // Set elements I..I+NumElts-1 to Block | ||||
Value *insertVector(Value *Col, unsigned I, Value *Block, | Value *insertVector(Value *Col, unsigned I, Value *Block, | ||||
IRBuilder<> &Builder) { | IRBuilder<> &Builder) { | ||||
// First, bring Block to the same size as Col | // First, bring Block to the same size as Col | ||||
▲ Show 20 Lines • Show All 294 Lines • ▼ Show 20 Lines | void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, | ||||
IRBuilder<> Builder(Store); | IRBuilder<> Builder(Store); | ||||
for (unsigned J = 0; J < C; J += TileSize) | for (unsigned J = 0; J < C; J += TileSize) | ||||
for (unsigned I = 0; I < R; I += TileSize) { | for (unsigned I = 0; I < R; I += TileSize) { | ||||
const unsigned TileR = std::min(R - I, unsigned(TileSize)); | const unsigned TileR = std::min(R - I, unsigned(TileSize)); | ||||
const unsigned TileC = std::min(C - J, unsigned(TileSize)); | const unsigned TileC = std::min(C - J, unsigned(TileSize)); | ||||
MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); | MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); | ||||
for (unsigned K = 0; K < M; K += TileSize) { | for (unsigned K = 0; K < M; K += TileSize) { | ||||
const unsigned TileM = std::min(M - K, unsigned(TileSize)); | const unsigned TileM = std::min(M - K, unsigned(TileSize)); | ||||
MatrixTy A = | MatrixTy A = loadMatrix(APtr, LoadOp0->isVolatile(), LShape, | ||||
LuoYuanke: Why we set volatile as false? Is it possible that the original load instruction is volatile… | |||||
Yes indeed! I wanted to share the code ASAP, fixed now (with additional test case) fhahn: Yes indeed! I wanted to share the code ASAP, fixed now (with additional test case) | |||||
loadMatrix(APtr, LShape, Builder.getInt64(I), Builder.getInt64(K), | Builder.getInt64(I), Builder.getInt64(K), | ||||
{TileR, TileM}, EltType, Builder); | {TileR, TileM}, EltType, Builder); | ||||
MatrixTy B = | MatrixTy B = loadMatrix(BPtr, LoadOp1->isVolatile(), RShape, | ||||
loadMatrix(BPtr, RShape, Builder.getInt64(K), Builder.getInt64(J), | Builder.getInt64(K), Builder.getInt64(J), | ||||
{TileM, TileC}, EltType, Builder); | {TileM, TileC}, EltType, Builder); | ||||
emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); | emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); | ||||
} | } | ||||
storeMatrix(Res, CPtr, {R, M}, Builder.getInt64(I), Builder.getInt64(J), | storeMatrix(Res, CPtr, Store->isVolatile(), {R, M}, Builder.getInt64(I), | ||||
EltType, Builder); | Builder.getInt64(J), EltType, Builder); | ||||
} | } | ||||
// Mark eliminated instructions as fused and remove them. | // Mark eliminated instructions as fused and remove them. | ||||
FusedInsts.insert(Store); | FusedInsts.insert(Store); | ||||
FusedInsts.insert(MatMul); | FusedInsts.insert(MatMul); | ||||
Store->eraseFromParent(); | Store->eraseFromParent(); | ||||
MatMul->eraseFromParent(); | MatMul->eraseFromParent(); | ||||
if (LoadOp0->hasNUses(0)) { | if (LoadOp0->hasNUses(0)) { | ||||
▲ Show 20 Lines • Show All 91 Lines • ▼ Show 20 Lines | void LowerTranspose(CallInst *Inst) { | ||||
// account for later simplifications/combines. | // account for later simplifications/combines. | ||||
finalizeLowering( | finalizeLowering( | ||||
Inst, | Inst, | ||||
Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), | Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), | ||||
Builder); | Builder); | ||||
} | } | ||||
/// Lower load instructions, if shape information is available. | /// Lower load instructions, if shape information is available. | ||||
bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) { | bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) { | ||||
auto I = ShapeMap.find(Inst); | auto I = ShapeMap.find(Inst); | ||||
if (I == ShapeMap.end()) | if (I == ShapeMap.end()) | ||||
return false; | return false; | ||||
LowerLoad(Inst, Ptr, Builder.getInt64(I->second.getStride()), I->second); | LowerLoad(Inst, Ptr, Builder.getInt64(I->second.getStride()), | ||||
Inst->isVolatile(), I->second); | |||||
return true; | return true; | ||||
} | } | ||||
bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr, | bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr, | ||||
IRBuilder<> &Builder) { | IRBuilder<> &Builder) { | ||||
auto I = ShapeMap.find(StoredVal); | auto I = ShapeMap.find(StoredVal); | ||||
if (I == ShapeMap.end()) | if (I == ShapeMap.end()) | ||||
return false; | return false; | ||||
LowerStore(Inst, StoredVal, Ptr, Builder.getInt64(I->second.getStride()), | LowerStore(Inst, StoredVal, Ptr, Builder.getInt64(I->second.getStride()), | ||||
I->second); | Inst->isVolatile(), I->second); | ||||
return true; | return true; | ||||
} | } | ||||
/// Lower binary operators, if shape information is available. | /// Lower binary operators, if shape information is available. | ||||
bool VisitBinaryOperator(BinaryOperator *Inst) { | bool VisitBinaryOperator(BinaryOperator *Inst) { | ||||
auto I = ShapeMap.find(Inst); | auto I = ShapeMap.find(Inst); | ||||
if (I == ShapeMap.end()) | if (I == ShapeMap.end()) | ||||
return false; | return false; | ||||
▲ Show 20 Lines • Show All 545 Lines • Show Last 20 Lines |
Why we set volatile as false? Is it possible that the original load instruction is volatile load?
Do we allow fusion if the load and store instruction is volatile?