Support IR that is generated by the vector-to-scf lowering of 2D vector transfers with a mask. Only 2D transfers that were fully unrolled are supported at the moment.
Repository: rG LLVM Github Monorepo
mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp

| Line | Comment |
|---|---|
| 63 | Can this return a failure rather than crash? I see the original impl has an assert higher up; can we fail here and assert at the same place rather than sink the assertion? |
| 83 | Then should this return failure? |
| 90 | Let's propagate the error up and fail at the caller. |
| 106 | It seems like we could refactor these preconditions into `getMaskOp` so that the transform could fail gracefully with proper error messages rather than crash in a few places. |
| 244 | It seems like we could refactor this to precompute the masks and perform the rewrites only if all of them are valid. This way the transform could fail more gracefully, with tested, readable error messages. |
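The restructuring suggested in the comment on line 244 (compute every mask up front, and rewrite only once all of them are valid) can be sketched in plain C++. All names here (`TransferOp`, `MaskInfo`, `computeMask`, `rewriteAll`) are hypothetical stand-ins, not MLIR APIs; in MLIR the analysis step would more naturally return a `FailureOr` and emit a diagnostic instead of returning a string.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for a masked transfer op; not an MLIR type.
struct TransferOp {
  int maskSourceRank;      // rank of the mask the 1D mask comes from (1 or 2 supported here)
  bool rewritten = false;  // set once the rewrite has been applied
};

// Precomputed, validated information needed to rewrite one op.
struct MaskInfo {
  TransferOp *op;
  bool isExtractFrom2D;  // true if the 1D mask is an extract from a 2D mask
};

// Phase 1 for a single op: return either the mask info or a readable error,
// instead of asserting deep inside the rewrite.
std::variant<MaskInfo, std::string> computeMask(TransferOp &op) {
  if (op.maskSourceRank == 1) return MaskInfo{&op, false};
  if (op.maskSourceRank == 2) return MaskInfo{&op, true};
  return std::string("unsupported mask: source rank ") +
         std::to_string(op.maskSourceRank);
}

// Analyze every op first; rewrite only if all of them are supported.
// On failure nothing has been modified and the first error is returned.
std::optional<std::string> rewriteAll(std::vector<TransferOp> &ops) {
  std::vector<MaskInfo> plans;
  plans.reserve(ops.size());
  for (TransferOp &op : ops) {
    auto result = computeMask(op);
    if (auto *err = std::get_if<std::string>(&result)) return *err;
    plans.push_back(std::get<MaskInfo>(result));
  }
  // Phase 2: every plan was validated above, so this loop cannot fail.
  for (const MaskInfo &plan : plans) {
    assert(!plan.op->rewritten && "each op is planned exactly once");
    plan.op->rewritten = true;  // stand-in for the actual rewrite
  }
  return std::nullopt;
}
```

Because the analysis finishes before any op is touched, a failure leaves the input unchanged, which is what makes a readable error message (and a test for it) possible.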
Thanks for working on this. It looks good.
> Only 2D transfers that were fully unrolled are supported at the moment.
Out of curiosity, why do we have this restriction? Unrolling isn't always beneficial.
What I mean by this is that only certain IR is supported, e.g., only IR where the mask is a vector.extract(...) from a 2D mask. That is the IR that vector-to-scf generates. This transformation itself does not unroll. If unrolling is not beneficial, we do not have to unroll (but then we cannot use copy_async).
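A minimal model of the quantity involved, assuming the 2D mask comes from `vector.create_mask %rows, %cols`: row `i` of that mask (the 1D mask of one unrolled transfer) has `cols` active lanes if `i < rows` and none otherwise. `CreateMask2D`, `numActiveInRow`, and `materialize` are hypothetical names, not MLIR APIs, and the sketch assumes mask sizes outside `[0, dim]` are clamped to the vector shape.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model of `vector.create_mask %rows, %cols`: element (i, j) of the
// 2D mask is active iff i < rows and j < cols. Not an MLIR type.
struct CreateMask2D {
  int64_t rows;  // first create_mask operand
  int64_t cols;  // second create_mask operand
};

// Active lanes in the 1D row `vector.extract %mask[row]`: the (clamped) column
// bound if the row lies inside the masked region, otherwise zero.
int64_t numActiveInRow(const CreateMask2D &mask, int64_t row, int64_t numCols) {
  assert(row >= 0 && numCols >= 0 && "static, in-bounds extract position");
  int64_t cols = std::clamp<int64_t>(mask.cols, 0, numCols);
  return row < mask.rows ? cols : 0;
}

// Reference implementation: build the full 2D mask element by element,
// so the closed form above can be checked against it.
std::vector<std::vector<bool>> materialize(const CreateMask2D &mask,
                                           int64_t numRows, int64_t numCols) {
  std::vector<std::vector<bool>> m(numRows, std::vector<bool>(numCols, false));
  for (int64_t i = 0; i < numRows; ++i)
    for (int64_t j = 0; j < numCols; ++j)
      m[i][j] = i < mask.rows && j < mask.cols;
  return m;
}
```

For example, with `%rows = 2`, `%cols = 3` and a 4-wide row, rows 0 and 1 have 3 active lanes and row 2 has none, so each unrolled transfer only needs a per-row element count rather than a per-lane mask.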
This code is mostly copied from IREE and refactored a bit. To keep it simple, that implementation only supports 2D masks, but it could be extended to higher dimensions if needed. We just haven't needed that so far, so it is not implemented.
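As a rough idea of what an extension to higher dimensions might involve at the arithmetic level: for an n-D `vector.create_mask %s0, ..., %s(n-1)` and an extract of the innermost 1D slice at leading indices `p0, ..., p(n-2)`, the slice is active only if every `pk < sk`, and then it has `s(n-1)` active lanes. The sketch below is hypothetical (`numActiveInSlice` is not an existing helper), assumes clamping to the vector shape, and only models the arithmetic, not the IR matching.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical generalization: active lanes in the innermost 1D slice
// `vector.extract %mask[p0, ..., p(n-2)]` of an n-D create_mask with sizes
// `maskSizes`. `innerDim` is the width of the innermost vector dimension.
int64_t numActiveInSlice(const std::vector<int64_t> &maskSizes,
                         const std::vector<int64_t> &position,
                         int64_t innerDim) {
  assert(!maskSizes.empty() && position.size() + 1 == maskSizes.size());
  // The slice is inside the masked region only if every leading index is.
  for (std::size_t k = 0; k < position.size(); ++k)
    if (position[k] >= maskSizes[k]) return 0;
  // Innermost size, clamped to [0, innerDim].
  int64_t inner = maskSizes.back();
  return inner < 0 ? 0 : (inner > innerDim ? innerDim : inner);
}
```

With a single leading index this reduces to the supported 2D case, and with no leading index it is just a 1D `vector.create_mask`.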
mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp

| Line | Comment |
|---|---|
| 63 | This is already handled by this piece of code: `if (cast<VectorType>(vectorVal.getType()).getRank() != 1) return;`. The assert here just makes sure that `createAsyncGroups` and this helper function stay in sync. It should never be triggered. |
| 90 | Same here: this is already handled by `if (cast<VectorType>(vectorVal.getType()).getRank() != 1) return;`. The assertion will never fail unless someone changes the code. |
| 106 | This is also already handled by a check in `getMaskOp`: `if (extractOp.getPosition().size() == 1 && extractOp.getSourceVectorType().getRank() == 2)`. Note that `getMaskOp` is called twice. The first call happens when looking for "eligible" ops; if `getMaskOp` returns failure at that point, we never make it to `buildNumReadElements`. All the assertions here just highlight what is and is not supported; they cannot fail at the moment. |
| 244 | That happens during the call to `getMaskOp` further up in this function: `// Look for compatible mask and padding.` If the mask or padding is not supported, the op is skipped (no error is reported). |
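The structure described in these replies (filter eligible ops first, then assert invariants that the filter already guarantees) can be modeled in plain C++. `TransferOp`, `isSupported`, `collectEligible`, and `createGroups` are hypothetical stand-ins, not the MLIR code; the rank checks only mirror the conditions quoted above.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-in for a transfer op; not an MLIR type.
struct TransferOp {
  int vectorRank;                  // rank of the transferred vector
  std::optional<int> maskSrcRank;  // rank of the mask's source (nullopt = unmasked)
};

// Shared precondition: 1D vectors only; a mask, if present, must come from
// a 1D mask or be extracted from a 2D mask.
bool isSupported(const TransferOp &op) {
  if (op.vectorRank != 1) return false;
  return !op.maskSrcRank || *op.maskSrcRank == 1 || *op.maskSrcRank == 2;
}

// First pass: collect eligible ops; unsupported ones are skipped without an error.
std::vector<const TransferOp *> collectEligible(const std::vector<TransferOp> &ops) {
  std::vector<const TransferOp *> eligible;
  for (const TransferOp &op : ops)
    if (isSupported(op)) eligible.push_back(&op);
  return eligible;
}

// Second pass: returns how many ops were rewritten.
int createGroups(const std::vector<TransferOp> &ops) {
  int rewritten = 0;
  for (const TransferOp *op : collectEligible(ops)) {
    // These asserts restate invariants the filter already guarantees;
    // they can only fire if the filter and this code drift apart.
    assert(op->vectorRank == 1 && "eligible ops transfer 1D vectors");
    assert(isSupported(*op) && "eligible ops have a supported mask");
    ++rewritten;  // stand-in for the actual rewrite
  }
  return rewritten;
}
```

In this design an unsupported op is skipped silently: nothing reports why it was not converted, which is the gap the review comments point at.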
> can this return a failure rather than crash?
> I see the original impl has an assert higher up, can we fail here and assert at the same place rather than sink the assertion?