These changes build on D80043, which added the basic cooperative matrix infrastructure. This patch adds cooperative matrix support for arithmetic and cast instructions. It also adds the cooperative matrix store, muladd and matrixlength instructions, which are part of the extension.
Nice! Thanks Thomas!
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | ||
---|---|---|
3006 ↗ | (On Diff #265113) | Nit: can this be ordered after SPV_Integer? |
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td | ||
56 ↗ | (On Diff #265113) | I think we can add let assemblyFormat = "attr-dict `:` type($result)"; So that we don't need to manually write the parser and printer? |
133 ↗ | (On Diff #265113) | I think we need verification on $pointer. At the moment every pointer will be accepted. |
175 ↗ | (On Diff #265113) | ssa-use |
176 ↗ | (On Diff #265113) | One general rule in MLIR regarding assembly is that it should be parsable on its own. I think we need to give at least two types (for $a and $b) for this? Otherwise by only looking at this op, I'm not able to tell what the type is for $a and $b. |
203 ↗ | (On Diff #265113) | I see we don't do any validation on these cooperative matrix ops. Do you plan to add them later? Here we should at least additionally check that 1) $a, $b, $c, and $result have consistent dimensions: MxK, KxN, MxN, MxN, and 2) they are at the same scope. Basically, go through the description and verify that what's written there is checked. :) |
204 ↗ | (On Diff #265113) | Similarly here I think we can do let assemblyFormat = "$a, $b, $c `:` type($a), type($b), type($c, $result)"; (type($c) is kinda redundant given that we can deduce it, but I don't care that much. Also it makes the mapping very clear.) |
265 ↗ | (On Diff #265113) | Similarly we need to verify the requirements on $pointer. |
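The $pointer checks requested for the load and store ops (rows 133 and 265 above) could look roughly like the sketch below. It is standalone C++ and does not use the MLIR API: `StorageClass`, `PointerType` and `verifyCoopMatrixPointer` are hypothetical stand-ins. The specific requirements (scalar or vector pointee, a small set of allowed storage classes) are assumed from the SPV_NV_cooperative_matrix extension text and should be double-checked against it.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-ins for illustration; the real patch would inspect
// MLIR's SPIR-V pointer type instead.
enum class StorageClass {
  Function,
  Private,
  Workgroup,
  StorageBuffer,
  PhysicalStorageBuffer
};

struct PointerType {
  StorageClass storageClass;
  bool pointeeIsScalarOrVector;
};

// Returns an error message when the pointer is not acceptable for a
// cooperative matrix load/store, std::nullopt otherwise.
// Assumption: the allowed pointee kinds and storage classes below mirror
// the extension spec.
std::optional<std::string> verifyCoopMatrixPointer(const PointerType &pointer) {
  if (!pointer.pointeeIsScalarOrVector)
    return "pointer operand must point to a scalar or vector";

  switch (pointer.storageClass) {
  case StorageClass::Workgroup:
  case StorageClass::StorageBuffer:
  case StorageClass::PhysicalStorageBuffer:
    return std::nullopt;
  default:
    return "pointer storage class must be Workgroup, StorageBuffer or "
           "PhysicalStorageBuffer";
  }
}
```

With a check like this in place, a pointer in the Function storage class would no longer be silently accepted.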
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td | ||
---|---|---|
56 ↗ | (On Diff #265113) | Here type is not the type of $result, so this syntax wouldn't work. Not sure if there is a way to get it to pick up the type of the argument? |
133 ↗ | (On Diff #265113) | I added some verification for both the load and store pointer types; however, I didn't find an easy way to test the verification, as the invalid IR would fail parsing. |
176 ↗ | (On Diff #265113) | My bad, I had wrongly assumed the types were matching. I added 2 types and a result type even though the result type could be deduced from the operand types. I also added a verify function for muladd and some tests along with it. I was planning to add those later to keep the patch small, but it is better to add them now so that I don't forget. |
Cool! Sorry just a few more nits.
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td | ||
---|---|---|
179 ↗ | (On Diff #265398) | This needs to be updated, right? |
56 ↗ | (On Diff #265113) | Ah you are right. Actually the op takes in the id for a type directly. Sorry missed that. |
176 ↗ | (On Diff #265113) | SGTM. The assembly format is a bit opaque; River added support for it so he knows all the nitty gritty. Feel free to ping him later if you have questions there. (But he is OOO ATM I think.) |
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | ||
2672 ↗ | (On Diff #265398) | Nit: ...And... ? |
2677 ↗ | (On Diff #265398) | s/snd/and/ For error messages it's better to be clear, so spelling things out completely is preferable. |
2766 ↗ | (On Diff #265398) | Nit: In general MLIR does not use capitalized variable names. So I'd suggest to s/M/coopMatrix/ or something. |
2777 ↗ | (On Diff #265398) | result and the third ... ? No need to have the trailing dot here. |
2782 ↗ | (On Diff #265398) | We can do an unconditional cast in the above and omit the check here. These are guaranteed by ODS. Normally the constraints in ODS do not need to be repeated here. We just need to check additional constraints using C++ here. |
2789 ↗ | (On Diff #265398) | In general error messages should stay all lower-case in MLIR. I think that's also the case for Clang/etc. So here: "matrix ..." Besides, it's better to be consistent so either "matrix .. mismatch" or "matrix ... must match" for this and the following two cases. |
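Pulling the muladd feedback together (dimensions MxK, KxN, MxN; same scope; lower-case error messages without a trailing dot; consistent "must match" wording; no capitalized variable names), the verifier logic could be sketched as below. This is standalone C++ and not the code from the patch: `Scope`, `CoopMatrixType` and `verifyCoopMatrixMulAdd` are hypothetical names, and the message strings are only illustrative.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-ins for illustration; not the MLIR SPIR-V classes.
enum class Scope { Subgroup, Workgroup };

struct CoopMatrixType {
  Scope scope;
  unsigned rows;
  unsigned columns;
};

// Returns an error message if the operand types are inconsistent,
// std::nullopt otherwise. Only constraints that go beyond what ODS would
// already guarantee are checked here.
std::optional<std::string>
verifyCoopMatrixMulAdd(const CoopMatrixType &a, const CoopMatrixType &b,
                       const CoopMatrixType &c, const CoopMatrixType &result) {
  // The result has the same type as the accumulator operand.
  if (c.scope != result.scope || c.rows != result.rows ||
      c.columns != result.columns)
    return "result and third operand must have the same type";

  // a is MxK, b is KxN, c is MxN.
  if (a.rows != c.rows || a.columns != b.rows || b.columns != c.columns)
    return "matrix size must match";

  // All matrices must live at the same scope.
  if (a.scope != b.scope || a.scope != c.scope)
    return "matrix scope must match";

  return std::nullopt;
}
```

For example, an 8x16 `a`, a 16x8 `b` and an 8x8 `c` at the same scope pass, while changing `b` to 32x8 is reported as a size mismatch.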
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td | ||
---|---|---|
176 ↗ | (On Diff #265113) | Sounds good. I'll chat with him when I get a chance to see if this can be improved. |
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | ||
---|---|---|
2738 ↗ | (On Diff #265566) | nit: Drop trivial braces. |
2746 ↗ | (On Diff #265566) | This looks like it should use the declarative format: https://mlir.llvm.org/docs/OpDefinitions/#declarative-assembly-format |
2756 ↗ | (On Diff #265566) | Same here and a few others. |
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | ||
---|---|---|
2738 ↗ | (On Diff #265566) | I'll send a patch to address this. |
2746 ↗ | (On Diff #265566) | River, I wasn't able to get a syntax working for this instruction, as the type itself is an argument and is unrelated to the destination. Do you have any suggestion to make it work? |
2756 ↗ | (On Diff #265566) | Here also I couldn't get the automatic print/parser to work as I have an instruction |
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | ||
---|---|---|
2746 ↗ | (On Diff #265566) | If the type of the result is inferable, you don't have to specify it. Types are inferable either via traits (e.g., SameOperands) or when they are buildable (i.e., statically constructible, like index, i32, etc.). In this case, the result type is SPV_Int32 (i.e., just an alias of i32), which is a buildable type. To get the same syntax, it seems all you really need is something like: attr-dict `:` $type (You would have seen this if you read the section that I sent you: https://mlir.llvm.org/docs/OpDefinitions/#type-inference) |
2756 ↗ | (On Diff #265566) | https://mlir.llvm.org/docs/OpDefinitions/#type-inference lists several traits that are supported for expressing type equality constraints. You can do simple equality using AllTypesMatch<["c", "r"]>, where c and r are the names of the operands/results/etc. With that, all of the types can be inferred as long as one of them can be. I would expect the following to work: def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<..., [..., AllTypesMatch<["c", "result"]>]> { ... let assemblyFormat = "operands `:` type($a), type($b) -> type($c)"; } |