Add a new type to the SPIR-V dialect for cooperative matrices and a new op for cooperative matrix load. Most of the instructions needed to support the cooperative matrix extension are still missing; this is a stop-gap patch to keep the review small.
Repository: rG LLVM Github Monorepo
Event Timeline
Could you also add some roundtrip tests to https://github.com/llvm/llvm-project/blob/master/mlir/test/Dialect/SPIRV/ops.mlir ? (Actually, a separate test file for these ops would also be fine.)
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | ||
---|---|---|
3228 | The opcodes in this file are auto-generated using this script: https://github.com/llvm/llvm-project/blob/master/mlir/utils/spirv/define_opcodes.sh . Please use that instead of defining them here. (Sorry if the documentation doesn't point you in the right direction.) | |
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp | ||
282 | Nit: instead of using parseKeyword here, could we make this method ( https://github.com/llvm/llvm-project/blob/7ee479a760e0a4402b4eb7fb6168768a44f66945/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp#L147 ) a utility function and use it here? | |
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp | ||
1237 | I wonder if the automatic deserialization generation code would work here? (AFAIR the same code must already be generated by tblgen, but is being intercepted early here) |
Thanks Mahesh, I added roundtrip tests. Please take another look.
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | ||
---|---|---|
3228 | My bad, it was in the documentation, I had missed this part. Updated using the script. | |
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp | ||
282 | Done. | |
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp | ||
1237 | My understanding is that we get auto-generated code for the ops, but this is for deserialization of the new type. The type is not defined in tblgen at all, unless I'm missing something. |
Sorry, I should have been clearer. These tests go through the serialization and deserialization of the ops. So while they exercise the parser and printer as well, they are primarily for testing that serialization and deserialization work. Typically ops are tested both for serialization/deserialization through the SPIR-V binary and for plain IR roundtrip. The IR roundtrip tests live here: https://github.com/llvm/llvm-project/tree/master/mlir/test/Dialect/SPIRV . Regressions there indicate that IR parsing/printing failed.
The serialization tests are within the Serialization sub-directory here: https://github.com/llvm/llvm-project/tree/master/mlir/test/Dialect/SPIRV/Serialization . These catch regressions in serialization/deserialization.
So I was asking for a roundtrip IR test.
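To make the distinction concrete, here is a minimal sketch of what a serialization roundtrip checks: an in-memory value is encoded to a word stream, decoded again, and compared with the original. Everything in it is hypothetical (the `ToyCoopMatrix` struct, the `kToyOpcode` constant, and the five-word layout); it is not the SPIR-V binary encoding and not the MLIR serializer API.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical in-memory form of a cooperative matrix type.
struct ToyCoopMatrix {
  uint32_t elementTypeId;
  uint32_t scope;
  uint32_t rows;
  uint32_t columns;
};

bool operator==(const ToyCoopMatrix &a, const ToyCoopMatrix &b) {
  return a.elementTypeId == b.elementTypeId && a.scope == b.scope &&
         a.rows == b.rows && a.columns == b.columns;
}

// Invented opcode for this sketch only; not a real SPIR-V opcode.
constexpr uint32_t kToyOpcode = 1;

// Encodes the value as [opcode, element, scope, rows, columns].
std::vector<uint32_t> serialize(const ToyCoopMatrix &m) {
  return {kToyOpcode, m.elementTypeId, m.scope, m.rows, m.columns};
}

// Decodes the word stream; returns std::nullopt if the layout is unexpected.
std::optional<ToyCoopMatrix> deserialize(const std::vector<uint32_t> &words) {
  if (words.size() != 5 || words[0] != kToyOpcode)
    return std::nullopt;
  return ToyCoopMatrix{words[1], words[2], words[3], words[4]};
}
```

A serialization test then asserts that `deserialize(serialize(m))` equals `m`, whereas an IR roundtrip test only compares the printed text with the parsed text and never touches the binary form.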
Rest of the changes look fine to me.
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp | ||
---|---|---|
1237 | Ah, you're right. My bad. |
Fantastic!! This is really great, Thomas! Thanks for taking this on! Overall it looks good to me; I just have a few nits. Also, can we add tests for:
- The coop mat type itself: https://github.com/llvm/llvm-project/blob/master/mlir/test/Dialect/SPIRV/types.mlir
- (De)serialization for coop mat load op in the directory as pointed out by Mahesh: https://github.com/llvm/llvm-project/tree/master/mlir/test/Dialect/SPIRV/Serialization
mlir/include/mlir/Dialect/SPIRV/ParserUtils.h | ||
---|---|---|
2 | The header explanation needs to be updated. | |
41 | Nit: let's have an empty line before #endif to improve readability. | |
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | ||
2993–2994 | Nit: can we keep this subsection alphabetically sorted? | |
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td | ||
23 | What about copying the original doc for coopmat load from https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/NV/SPV_NV_cooperative_matrix.asciidoc#3328-memory-instructions here? We have been mostly following that convention for SPIR-V ops defined. | |
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h | ||
355 | Nit: move this empty line after the function decl? | |
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp | ||
269 | spv.coopmatrix? | |
281 | Can we put the scope ahead of the element type? I think it's more natural that way: element types typically appear together with row/col sizes. (The spec puts it in an order like this, but we don't need to follow that order.) | |
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | ||
2623 | spv.CooperativeMatrixLoadNV | |
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp | ||
286 | Nit: leave an empty line to separate the methods with fields. | |
315 | We need to push SPV_NV_cooperative_matrix to extensions here. | |
320 | We need to push CooperativeMatrixNV capability to capabilities here. | |
mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir | ||
4 | We can drop the requires ... part. It is incorrect here, given that we actually need more capabilities and extensions for the following to be used. |
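The two comments on SPIRVTypes.cpp ask the new type to report its own requirements. A minimal standalone sketch of that idea follows; `ElementType`, `CoopMatrixType`, and the string-based lists are hypothetical stand-ins rather than the MLIR classes or signatures, and forwarding the element type's requirements first is an assumption made for illustration.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical element type carrying its own requirements.
struct ElementType {
  std::string name;
  std::vector<std::string> extensions;
  std::vector<std::string> capabilities;
};

// Hypothetical miniature of the cooperative matrix type.
struct CoopMatrixType {
  ElementType element;

  // Appends the element type's extensions, then the one this type needs.
  void getExtensions(std::vector<std::string> &extensions) const {
    extensions.insert(extensions.end(), element.extensions.begin(),
                      element.extensions.end());
    extensions.push_back("SPV_NV_cooperative_matrix");
  }

  // Appends the element type's capabilities, then the one this type needs.
  void getCapabilities(std::vector<std::string> &capabilities) const {
    capabilities.insert(capabilities.end(), element.capabilities.begin(),
                        element.capabilities.end());
    capabilities.push_back("CooperativeMatrixNV");
  }
};
```

With this shape, a caller that collects requirements for a whole module can query each type once and end up with both the element-level and the matrix-level entries.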
Thanks Lei. I added some tests for the type in types.mlir. About the second point: I already had serialization/deserialization tests; am I missing something there?
mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir | ||
---|---|---|
4 | I get the following error if I remove the requires part. Maybe I missed something? error: module must have 'vce_triple' attribute to be serializeable |
Thanks Thomas! Just two final comments about the type assembly and tests. Feel free to land after addressing them. :)
mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir | ||
---|---|---|
4 | Ah, sorry. You are right. I mistakenly treated this test as one for op parsing/printing round-trip. (Hence also my comment to request tests for serialization.) Then can we fix the required extension and capability? At the moment it's wrong. (We don't have validation for this yet but should have.) We need to have spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> here. | |
mlir/test/Dialect/SPIRV/types.mlir | ||
338 | Hmm, sorry I didn't point this out previously; I feel it's actually better to have the format spv.coopmatrix<8x8xf32, Workgroup>. That way it is more consistent with vectors/tensors (vector<3x4xi32>) and pointers (spv.ptr<f32, PushConstant>). What do you think? |
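As a rough illustration of the proposed assembly format, the sketch below parses and prints text of the form spv.coopmatrix<8x8xf32, Workgroup> with plain standard-library code. The `CoopMatrixType` struct and the `parseCoopMatrix` / `printCoopMatrix` helpers are hypothetical names for this example; the real dialect parser is built on MLIR's parsing utilities and will look different.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical stand-in for the cooperative matrix type; not the MLIR class.
struct CoopMatrixType {
  unsigned rows = 0;
  unsigned columns = 0;
  std::string elementType; // e.g. "f32"
  std::string scope;       // e.g. "Workgroup"
};

// Parses text such as "spv.coopmatrix<8x8xf32, Workgroup>".
// Returns std::nullopt on malformed input.
std::optional<CoopMatrixType> parseCoopMatrix(const std::string &text) {
  const std::string prefix = "spv.coopmatrix<";
  if (text.size() <= prefix.size() ||
      text.compare(0, prefix.size(), prefix) != 0 || text.back() != '>')
    return std::nullopt;
  // Strip the prefix and the trailing '>'.
  std::string body =
      text.substr(prefix.size(), text.size() - prefix.size() - 1);

  std::size_t pos = 0;
  // Reads one "<digits>x" dimension and advances pos past the 'x'.
  auto parseDim = [&](unsigned &dim) -> bool {
    std::size_t start = pos;
    unsigned value = 0;
    while (pos < body.size() &&
           std::isdigit(static_cast<unsigned char>(body[pos]))) {
      value = value * 10 + static_cast<unsigned>(body[pos] - '0');
      ++pos;
    }
    if (pos == start || pos >= body.size() || body[pos] != 'x')
      return false;
    ++pos;
    dim = value;
    return true;
  };

  CoopMatrixType type;
  if (!parseDim(type.rows) || !parseDim(type.columns))
    return std::nullopt;

  // The element type runs up to the comma; the scope follows it.
  std::size_t comma = body.find(',', pos);
  if (comma == std::string::npos || comma == pos)
    return std::nullopt;
  type.elementType = body.substr(pos, comma - pos);

  std::size_t scopeStart = body.find_first_not_of(' ', comma + 1);
  if (scopeStart == std::string::npos)
    return std::nullopt;
  type.scope = body.substr(scopeStart);
  return type;
}

// Prints the type back in the same shape-first, scope-last order.
std::string printCoopMatrix(const CoopMatrixType &type) {
  return "spv.coopmatrix<" + std::to_string(type.rows) + "x" +
         std::to_string(type.columns) + "x" + type.elementType + ", " +
         type.scope + ">";
}
```

Keeping the dimensions and element type together, with the scope as a trailing parameter, mirrors how vector<3x4xi32> and spv.ptr<f32, PushConstant> read, which is the consistency argument made in the comment.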