This is an archive of the discontinued LLVM Phabricator instance.

[mlir][spirv] Add remaining cooperative matrix instructions.
ClosedPublic

Authored by ThomasRaoux on May 18 2020, 9:33 PM.

Details

Summary

These changes build on D80043, which adds the basic cooperative matrix infrastructure. This revision adds cooperative matrix support to the arithmetic and cast instructions. It also adds the cooperative matrix store, muladd, and matrix length instructions, which are part of the extension.

Event Timeline

ThomasRaoux created this revision. May 18 2020, 9:33 PM
Herald added a project: Restricted Project. May 18 2020, 9:33 PM
antiagainst requested changes to this revision. May 20 2020, 2:51 PM

Nice! Thanks Thomas!

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
3006 ↗(On Diff #265113)

Nit: can this be ordered after SPV_Integer?

mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
56 ↗(On Diff #265113)

I think we can add

let assemblyFormat = "attr-dict `:` type($result)";

so that we don't need to write the parser and printer by hand?

133 ↗(On Diff #265113)

I think we need to verify $pointer. At the moment, any pointer is accepted.

175 ↗(On Diff #265113)

ssa-use

176 ↗(On Diff #265113)

One general rule for MLIR assembly is that each op should be parsable on its own. I think we need to spell out at least two types here (for $a and $b)? Otherwise, looking only at this op, I can't tell what the types of $a and $b are.

203 ↗(On Diff #265113)

I see we don't do any validation on these cooperative matrix ops. Do you plan to add them later?

Here we should additionally check at least that 1) $a, $b, $c, and $result have consistent dimensions (MxK, KxN, MxN, and MxN, respectively), and 2) they all have the same scope. Basically, go through the description and verify that everything written there is checked. :)
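The shape and scope rules above can be sketched as a standalone C++ check. This is a hypothetical model, not MLIR code: CoopMatrixType, Scope, and verifyMulAdd are illustrative names, and only the constraints named in this thread are checked (MxK, KxN, MxN, MxN shapes, a common scope, and $c having the same type as $result). The diagnostics follow the lower-case, no-trailing-period style requested later in the review.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for a cooperative matrix type such as
// !spv.coopmatrix<8x16xf16, Subgroup>. This is NOT the MLIR API; it models
// only the fields the check needs.
enum class Scope { Device, Workgroup, Subgroup };

struct CoopMatrixType {
  std::int64_t rows;
  std::int64_t columns;
  std::string elementType;
  Scope scope;
};

// Checks A (MxK) * B (KxN) + C (MxN) = Result (MxN). Returns a lower-case
// diagnostic without a trailing period, or std::nullopt if consistent.
std::optional<std::string> verifyMulAdd(const CoopMatrixType &a,
                                        const CoopMatrixType &b,
                                        const CoopMatrixType &c,
                                        const CoopMatrixType &result) {
  // Positive sizes are assumed to be guaranteed when the types are built
  // (like ODS constraints), so they are asserted here, not re-diagnosed.
  assert(a.rows > 0 && a.columns > 0 && b.rows > 0 && b.columns > 0);
  // C and Result must be exactly the same type.
  if (c.rows != result.rows || c.columns != result.columns ||
      c.elementType != result.elementType || c.scope != result.scope)
    return "result and third operand must have the same type";
  // M: rows of A and Result; K: columns of A and rows of B;
  // N: columns of B and Result.
  if (a.rows != result.rows || a.columns != b.rows ||
      b.columns != result.columns)
    return "matrix size must match";
  // All matrices must share one scope (C already matches Result).
  if (a.scope != result.scope || b.scope != result.scope)
    return "matrix scope must match";
  return std::nullopt;
}
```

With A as 8x32, B as 32x16, and C and Result as 8x16, all at Subgroup scope, the sketch returns no diagnostic; changing B to 16x16 yields "matrix size must match".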

204 ↗(On Diff #265113)

Similarly, here I think we can do

let assemblyFormat = "$a, $b, $c `:` type($a), type($b), type($c, $result)";

(type($c) is kinda redundant given that we can deduce it, but I don't care that much. It also makes the mapping very clear.)

265 ↗(On Diff #265113)

Similarly, we need to verify the requirements on $pointer.
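A shared pointer check for the load and store ops could look like the following sketch. It assumes the rules stated in the SPV_NV_cooperative_matrix extension: the pointee must be a scalar or vector type, and the storage class must be Workgroup, StorageBuffer, or PhysicalStorageBuffer. PointerType, PointeeKind, StorageClass, and verifyCoopMatrixPointer are hypothetical stand-ins, not the MLIR SPIR-V API.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical, simplified model of a SPIR-V pointer type (not the MLIR API).
enum class StorageClass {
  Function, Private, Uniform, Workgroup, StorageBuffer, PhysicalStorageBuffer
};
enum class PointeeKind { Scalar, Vector, Struct, Array };

struct PointerType {
  PointeeKind pointee;
  StorageClass storage;
};

// Assumed rule (from the extension text): the pointer must point to a scalar
// or vector and live in Workgroup, StorageBuffer, or PhysicalStorageBuffer
// storage. Returns a diagnostic, or std::nullopt if the pointer is valid.
std::optional<std::string> verifyCoopMatrixPointer(const PointerType &ptr) {
  if (ptr.pointee != PointeeKind::Scalar && ptr.pointee != PointeeKind::Vector)
    return "pointer must point to a scalar or vector type";
  switch (ptr.storage) {
  case StorageClass::Workgroup:
  case StorageClass::StorageBuffer:
  case StorageClass::PhysicalStorageBuffer:
    return std::nullopt;
  default:
    return "pointer storage class must be Workgroup, StorageBuffer or "
           "PhysicalStorageBuffer";
  }
}
```

In this sketch both load and store call the same helper on their $pointer operand, so the rule is written once.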

This revision now requires changes to proceed. May 20 2020, 2:51 PM
ThomasRaoux marked an inline comment as done.
ThomasRaoux marked 8 inline comments as done. May 20 2020, 6:03 PM
ThomasRaoux added inline comments.
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
56 ↗(On Diff #265113)

Here, type is not the type of $result, so this syntax wouldn't work. Not sure if there is a way to get it to pick up the type of the argument?

133 ↗(On Diff #265113)

I added some verification of the pointer type for both load and store; however, I didn't find an easy way to test it, since the invalid IR fails to parse.

176 ↗(On Diff #265113)

My bad, I had wrongly assumed the types matched. I added two operand types and a result type, even though the result type could be deduced from the operand types.
I wasn't able to get the assemblyFormat to work, as I couldn't find a way to make the type of c the same as the type of result (type($c, $result) doesn't work). I haven't looked very deeply, though; I can investigate more if you want. For now, I left the custom parsing/printing functions.

I also added a verify function for muladd and some tests along with it. I was planning to add those later to keep the patch small, but it is better to add them now so that I don't forget.

antiagainst accepted this revision. May 21 2020, 5:24 AM
antiagainst marked 2 inline comments as done.

Cool! Sorry just a few more nits.

mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
179 ↗(On Diff #265398)

This needs to be updated, right?

56 ↗(On Diff #265113)

Ah, you are right. The op actually takes the id of a type directly. Sorry, I missed that.

176 ↗(On Diff #265113)

SGTM. The assembly format is a bit opaque; River added support for it, so he knows all the nitty-gritty. Feel free to ping him later if you have questions there. (But I think he is OOO ATM.)

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
2672 ↗(On Diff #265398)

Nit: ...And... ?

2677 ↗(On Diff #265398)

s/snd/and/
s/coop/cooperative/

For error messages it's better to be clear, so spelling things out completely is preferable.

2766 ↗(On Diff #265398)

Nit: In general, MLIR does not use capitalized variable names, so I'd suggest s/M/coopMatrix/ or something similar.

2777 ↗(On Diff #265398)

result and the third ... ?

No need to have the trailing dot here.

2782 ↗(On Diff #265398)

We can do an unconditional cast above and omit the check here; these types are guaranteed by ODS. Normally, constraints already expressed in ODS do not need to be repeated in C++. We only need to check the additional constraints here.

2789 ↗(On Diff #265398)

In general, error messages in MLIR should stay lower-case. I think that's also the case for Clang etc. So here: "matrix ..."

Besides, it's better to be consistent, so use either "matrix ... mismatch" or "matrix ... must match" for this and the following two cases.

This revision is now accepted and ready to land. May 21 2020, 5:24 AM
ThomasRaoux marked an inline comment as done.
ThomasRaoux marked 10 inline comments as done. May 21 2020, 10:34 AM
ThomasRaoux added inline comments.
mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
176 ↗(On Diff #265113)

Sounds good. I'll chat with him when I get a chance to see if this can be improved.

This revision was automatically updated to reflect the committed changes.
rriddle added inline comments. May 27 2020, 1:58 PM
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
2738 ↗(On Diff #265566)

nit: Drop trivial braces.

2746 ↗(On Diff #265566)

This looks like it should use the declarative format:

https://mlir.llvm.org/docs/OpDefinitions/#declarative-assembly-format

2756 ↗(On Diff #265566)

Same here and in a few other places.

ThomasRaoux marked 3 inline comments as done. May 27 2020, 5:58 PM
ThomasRaoux added inline comments.
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
2738 ↗(On Diff #265566)

I'll send a patch to address this.

2746 ↗(On Diff #265566)

River, I wasn't able to get the declarative syntax working for this instruction, since the type itself is an argument and is unrelated to the result type. Do you have any suggestions for making it work?

2756 ↗(On Diff #265566)

Here too, I couldn't get the generated printer/parser to work, because the instruction has the form
r = matmuladd a, b, c : type(a), type(b) -> type(r and c)
and I wasn't able to express the fact that c and r must have the same type. Is there a way to make it work?

rriddle added inline comments. May 27 2020, 8:41 PM
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
2746 ↗(On Diff #265566)

If the type of the result is inferable, you don't have to specify it. Types can be inferred either via traits (e.g., SameOperands) or from buildable types (i.e., types that are statically constructible, like index, i32, etc.). In this case, the result type is SPV_Int32 (i.e., just an alias for i32), which is a buildable type. To get the same syntax, it seems all you really need is something like:

attr-dict `:` $type

(You would have seen this if you had read the section that I sent you: https://mlir.llvm.org/docs/OpDefinitions/#type-inference)
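Why the result type can be omitted for the length op can be modeled outside MLIR. This C++ sketch is hypothetical: LengthOp, print, and parse are illustrative names, strings stand in for MLIR types, and the attribute dictionary is assumed empty. It mirrors a format like attr-dict `:` $type, where only the type attribute appears in the text and the i32 result type is recovered from the op definition.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical model of spv.CooperativeMatrixLengthNV: the op carries a type
// attribute (the cooperative matrix type it queries) and always yields i32.
struct LengthOp {
  std::string typeAttr;           // e.g. "!spv.coopmatrix<8x16xi32, Subgroup>"
  std::string resultType = "i32"; // buildable: fixed by the op definition
};

// Only the type attribute is printed; the i32 result type is never written.
std::string print(const LengthOp &op) {
  return "spv.CooperativeMatrixLengthNV : " + op.typeAttr;
}

// Parses text produced by print(). The result type is not read from the
// text; it keeps the default that the op definition fixes to i32.
std::optional<LengthOp> parse(const std::string &text) {
  const std::string prefix = "spv.CooperativeMatrixLengthNV : ";
  // Reject text that does not start with the prefix or has no type after it.
  if (text.rfind(prefix, 0) != 0 || text.size() == prefix.size())
    return std::nullopt;
  LengthOp op;
  op.typeAttr = text.substr(prefix.size());
  return op;
}
```

Round-tripping LengthOp{"!spv.coopmatrix<8x16xi32, Subgroup>"} through print and parse keeps the type attribute and yields a resultType of "i32", even though "i32" never appears in the printed text.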

2756 ↗(On Diff #265566)

https://mlir.llvm.org/docs/OpDefinitions/#type-inference lists several traits that are supported for expressing type equality constraints. You can express simple equality using AllTypesMatch<["c", "r"]>, where c and r are the names of the operands/results/etc. With that, the types of all of them can be inferred as long as one of them can be. I would expect the following to work:

def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<..., [..., AllTypesMatch<["c", "result"]>]> {
  ...
  let assemblyFormat = "operands attr-dict `:` type($a) `,` type($b) `->` type($c)";
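What the type-equality constraint buys the parser can be modeled in a few lines of C++. This is a hypothetical sketch, not the code mlir-tblgen generates: MulAddTypes and inferMulAddTypes are illustrative names, and strings stand in for MLIR types. Only type($a), type($b), and type($c) appear after the colon; the result type is filled in from $c because the constraint says the two must match.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of the types of
//   %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : A, B -> C
struct MulAddTypes {
  std::string a, b, c, result;
};

// `writtenTypes` holds the types printed after the colon, in order:
// type($a), type($b), type($c). Returns std::nullopt if the count is wrong.
std::optional<MulAddTypes>
inferMulAddTypes(const std::vector<std::string> &writtenTypes) {
  if (writtenTypes.size() != 3)
    return std::nullopt;
  MulAddTypes t;
  t.a = writtenTypes[0];
  t.b = writtenTypes[1];
  t.c = writtenTypes[2];
  // The "c and result have the same type" constraint makes the result type
  // recoverable without writing it in the assembly.
  t.result = t.c;
  return t;
}
```

This is the same reasoning that makes a fourth written type (for the result) redundant in the custom parser discussed earlier in the thread.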
}