This patch handles native mma.sync sizes and enables issuing ldmatrix on the largest possible tiles for matrixB. It requires handling vector.extract_strided_slice in the vector-to-nvgpu lowering.
Looks great, made a few minor comments. Thanks!
mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h | ||
---|---|---|
29 | ||
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | ||
197 | Maybe just return false where there are currently assert statements. What will happen in release builds if this occurs? The purpose of the function is to validate what is expected, after all. I know we're in a weird place here in terms of emitting diagnostic information. IMO, going forward we should convert these validation functions to return LogicalResult and replace the asserts with return op.emitWarning(msg). Currently, if a user is messing with mlir-opt and tries this pass with invalid input IR, they get no feedback on why it fails to convert or, in the worst case, they get a crash. For the useNvGpu parameter, it might be simpler to omit it and return useNvGpu && extractSlicedSliceSupportsMMAMatrixType(...) in the supportsMMaMatrixType function. | |
219 | Can we omit the case for operand A and simplify to the below? if(operandB) return op->getResult() == contractOp->getRhs(); if(operandC) return op->getResult() == contractOp->getAcc(); return false; | |
730 | if(!transferReadOp) return failure(); | |
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp | ||
50 | Move to VectorToGPU.cpp as a static method? |
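
To make the two suggestions on lines 197 and 219 concrete, here is a minimal standalone sketch, not the patch itself. It assumes hypothetical stand-in types (`Operand`, `Contract`, plain integer value ids) in place of the real MLIR ops, and the function names `sliceFeedsOperand` and `validateSlice` are made up for illustration. The first function mirrors the simplified operand check; the second returns a message instead of asserting, which is roughly what returning LogicalResult with an emitted warning would give the user.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-ins for illustration only; these are not MLIR types.
enum class Operand { A, B, C };

// Models a contraction by the ids of the values it consumes.
struct Contract {
  int lhs; // operand A
  int rhs; // operand B
  int acc; // operand C
};

// Simplified check in the shape suggested for line 219: only operands B
// and C are handled, and anything else (including operand A) is false.
bool sliceFeedsOperand(int sliceResult, const Contract &contract,
                       Operand role) {
  if (role == Operand::B)
    return sliceResult == contract.rhs;
  if (role == Operand::C)
    return sliceResult == contract.acc;
  return false;
}

// Validation that reports a reason instead of asserting. An empty optional
// means the slice is accepted; otherwise the string says why it was not.
std::optional<std::string> validateSlice(int sliceResult,
                                         const Contract &contract,
                                         Operand role) {
  if (role == Operand::A)
    return std::string("slices feeding operand A are not handled");
  if (!sliceFeedsOperand(sliceResult, contract, role))
    return std::string("slice result does not feed the requested operand");
  return std::nullopt;
}
```

With this shape, a release build rejects unexpected input with a reason rather than running past a compiled-out assert.
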
mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp | ||
---|---|---|
50 | Can this helper function also be used in getWarpMatrixInfo? |
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | ||
---|---|---|
219 | Done! thanks for the suggestion. |
Thanks Manish, looks great!
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | ||
---|---|---|
216 | nit: this comment is a bit confusing, as the code could reach here; it is just that this cannot be converted to SIMT code. I would remove it or rephrase it to say that we only handle matrixB and matrixC. |