
[mlir][amx] Add Intel AMX dialect (architectural-specific vector dialect)
Closed, Public

Authored by aartbik on Mar 11 2021, 4:35 PM.

Details

Summary

Intel Advanced Matrix Extensions (AMX) provide a tile matrix
multiply unit (TMUL), a tile control register (TILECFG), and eight
tile registers TMM0 through TMM7 (TILEDATA). This new MLIR dialect
provides a bridge between MLIR concepts like vectors and memrefs
and the lower-level LLVM IR details of AMX.
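For readers unfamiliar with TILECFG, here is a small C++ model of the 64-byte memory image that the LDTILECFG instruction loads into it. This is a sketch, assuming palette 1 (at most 16 rows of at most 64 bytes per tile) and the byte layout as we understand it from Intel's ISA extensions reference; please verify against the current manual. `TileConfig` and `setTile` are hypothetical names and not part of the dialect.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical model of the 64-byte memory image consumed by LDTILECFG,
// assuming palette 1 and this layout:
//   byte 0       palette_id
//   byte 1       start_row
//   bytes 2-15   reserved, must be zero
//   bytes 16-47  colsb: bytes per row of tiles 0..15 (uint16, little-endian)
//   bytes 48-63  rows: number of rows of tiles 0..15 (uint8)
struct TileConfig {
  std::array<uint8_t, 64> bytes{};  // value-initialized to all zeros

  TileConfig() { bytes[0] = 1; }  // palette 1; start_row stays 0

  // Configure tile register TMM<tile> as `rows` rows of `colsb` bytes each.
  void setTile(int tile, uint8_t rows, uint16_t colsb) {
    assert(tile >= 0 && tile < 8 && "palette 1 exposes TMM0..TMM7");
    assert(rows <= 16 && colsb <= 64 && "exceeds palette 1 tile limits");
    // Store colsb explicitly as little-endian, independent of the host.
    bytes[16 + 2 * tile] = static_cast<uint8_t>(colsb & 0xFF);
    bytes[17 + 2 * tile] = static_cast<uint8_t>(colsb >> 8);
    bytes[48 + tile] = rows;
  }
};
```

For example, a `vector<2x4xbf16>` value occupies 2 rows of 8 bytes, which corresponds to `setTile(t, 2, 8)` for some tile register `t`.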


Event Timeline

aartbik created this revision. Mar 11 2021, 4:35 PM
aartbik requested review of this revision. Mar 11 2021, 4:35 PM

This is a large drop, so I added some background on Discourse:

https://llvm.discourse.group/t/intel-amx-vector-dialect/2984/6

Can you split this patch up? It seems like there are several separable components: amx, llvm_amx, the lowering to LLVM, etc.

silvas added a subscriber: silvas. Mar 11 2021, 7:55 PM
silvas added inline comments.
mlir/include/mlir/Dialect/AMX/AMX.td
10

Drive-by comment: can you provide links to the "source of truth" documentation (insofar as it is available) for folks who want to dig deeper?

nicolasvasilache requested changes to this revision. Mar 12 2021, 12:08 AM

Putting a blocker for double-checking the stride and maybe the error messages.
The rest is good to go and can be improved incrementally (e.g. dropping the LLVMAMX dialect if we can add a type hook).

Nice!

mlir/include/mlir/Dialect/AMX/AMX.td
78

+1 on pointing to official doc so we can dig deeper.
It is unclear to me whether the bytes need to be contiguous in memory or whether there is a way to accept strides, and what alignment constraints are required for correctness or perf.
IIRC you mentioned there is a configuration mechanism for the sizes; does it also support striding for memory accesses?

133

You may also want to add some type checking on lhs/rhs via TypesMatchWith; here is an example from AVX512 for syntax purposes:

def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
  AllTypesMatch<["src", "a", "dst"]>,
  TypesMatchWith<"imm has the same number of bits as elements in dst",
                 "dst", "imm",
                 "IntegerType::get($_self.getContext(), "
                 "($_self.cast<VectorType>().getShape()[0]))">]> {
...

Edit: ah OK, I see it appears in the C++ part. Feel free to ignore this comment and leave it as-is in C++, or lift some of that into TypesMatchWith.

mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
16 ↗(On Diff #330103)

At the moment this is required because 2-D vectors are not native in LLVM IR, correct?

Since the AMX vector type is a native 2-D type, is it reasonable to extend the LLVM type definition with target-specific hooks that would allow this dialect to disappear (@ftynse)?

mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp
31 ↗(On Diff #330103)

Returning a pair or a struct with 2 values would be more idiomatic, IMO.

35 ↗(On Diff #330103)

Please add an assert here that the IntOrFloatBitWidth is a power of 2.
While I don't expect we'll ever be able to see i33 here, the bug would be so nasty to debug that a line of defense makes sense to me.
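Taken together, the two suggestions above (return a pair, assert a power-of-2 bit width) could look roughly like the sketch below. The name `getTileSizes` and its plain-integer parameters are hypothetical stand-ins; the real code works on an MLIR `VectorType` and its element type.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// True if x is a positive power of two (1, 2, 4, 8, ...).
constexpr bool isPowerOfTwo(uint64_t x) { return x != 0 && (x & (x - 1)) == 0; }

// Hypothetical helper: given the shape [rows, cols] of a 2-D vector and the
// bit width of its element type, return {rows, bytes per row}.
std::pair<int64_t, int64_t> getTileSizes(int64_t rows, int64_t cols,
                                         unsigned elemBitWidth) {
  // Line of defense: an odd type such as i33 would silently produce a
  // wrong byte count below, which would be nasty to debug.
  assert(isPowerOfTwo(elemBitWidth) && elemBitWidth >= 8 &&
         "element bit width must be a power of 2 and at least one byte");
  return {rows, cols * (elemBitWidth / 8)};
}
```

With this shape, `vector<2x4xbf16>` and `vector<2x2xf32>` both yield `{2, 8}`, and a caller can compare the `first` members directly when it needs the m dimensions to agree.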

52 ↗(On Diff #330103)

The dynamic stride, if needed, should already be available to you in the descriptor.
Is MemRefDescriptorView what we use these days to get this information @ftynse ?

I haven't followed whether the refactorings still guarantee that static constants are visible when we pass function boundaries; I'd double-check it by eyeballing the LLVM IR that gets emitted.

In any case, the stride is not just the last size; strides have a life of their own.
They happen to be related to sizes only in a particular case (a contiguous, row-major layout).
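To make the distinction concrete, here is a small C++ sketch (plain integers, not the MLIR API) that computes the strides of a contiguous row-major layout from its sizes. The function name `rowMajorStrides` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Strides (in elements) of a contiguous row-major array with the given sizes:
// stride[i] = product of sizes[i+1..n-1], and the innermost stride is 1.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * sizes[i + 1];
  return strides;
}
```

For a standalone `4x8` buffer the row stride is 8, which equals the last size. A `4x8` subview of a `16x64` buffer keeps the parent's strides `{64, 1}`, so its row stride (64) has nothing to do with its own last size (8).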

102 ↗(On Diff #330103)

Ah OK, I see now that the stride is opaque to MLIR and part of the AMX intrinsic.
Can you just document that in the op definition, please?

102 ↗(On Diff #330103)

Some extra checks are needed.
If the most minor stride obtained by calling the following on the memref type:

LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<int64_t> &strides,
                                  int64_t &offset);

is not 1, then we should fail the conversion.
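Once the strides are known (for instance from `getStridesAndOffset`), the check itself is small. The sketch below models it with plain integers, assuming a dynamic stride is represented by a sentinel value; `kDynamic` and `canLowerToTileAccess` are hypothetical names, and the real pattern would return `failure()` from the rewrite rather than `false`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel for a stride that is only known at runtime.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// The tile load/store can only be used when elements within a row are
// contiguous, i.e. the most minor stride is statically known to be 1.
// A dynamic innermost stride is rejected conservatively; outer strides
// may be dynamic.
bool canLowerToTileAccess(const std::vector<int64_t> &strides) {
  if (strides.empty())
    return false;
  return strides.back() == 1;
}
```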

114 ↗(On Diff #330103)

Same comments as above.

148 ↗(On Diff #330103)

If you returned a pair, you could also more naturally assert that the m's agree.

mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
47

I'd go for nicer error messages here: the ops expect a certain vector layout, so spelling out the error a bit more would be useful.

54

If you moved these emitError calls inside the function performing the verification, we could have nicer error messages.
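A minimal sketch of that idea, where the verifier builds the diagnostic itself instead of returning a bare success flag. It uses `std::optional<std::string>` as a stand-in for MLIR's `emitError()` diagnostics; the name `verifyTileShape`, the 16-row/64-byte limits (AMX palette 1), and the message wording are assumptions for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical verifier: returns std::nullopt on success, or a message that
// says exactly which limit was violated (the real code would emit the error
// on the op at each failure point instead).
std::optional<std::string> verifyTileShape(int64_t rows, int64_t cols,
                                           int64_t elemBytes) {
  constexpr int64_t kMaxRows = 16, kMaxRowBytes = 64;
  if (rows > kMaxRows)
    return "bad row height: " + std::to_string(rows) + " exceeds " +
           std::to_string(kMaxRows) + " rows";
  int64_t rowBytes = cols * elemBytes;
  if (rowBytes > kMaxRowBytes)
    return "bad column width: " + std::to_string(rowBytes) +
           " bytes per row exceeds " + std::to_string(kMaxRowBytes);
  return std::nullopt;
}
```

Because each failure point knows which check failed, the message can name the offending quantity rather than a generic "invalid tile".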

mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir
37

Haha, I love these types of tests.
In polyhedral land, one would skew the tile and cut it at the boundaries: seeing that the pattern is correct is a must.

This revision now requires changes to proceed. Mar 12 2021, 12:08 AM
ftynse added inline comments. Mar 12 2021, 2:39 AM
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
26–28

It would be nice if we could reconsider this trick. It was introduced to make sure the type system change between built-in vectors and llvm vectors was smooth, but the type system difference is (almost) gone. It feels like we only need some casting/packing between nD and 1D vectors to make vector-to-llvm conversion separate from "ISA dialect"-to-llvm conversion. Not for this commit though.

mlir/include/mlir/Dialect/AMX/AMX.td
2

Nit: this should match the file name

27–28

This should also match the filename, I suppose

136

Nit: putting quotes or backticks around variables, e.g., an `m x k` tile with a `k x n` tile, would make it more readable

151

Nit: something went wrong with whitespace

mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
1 ↗(On Diff #330103)

I am going in the inverse direction and removing the X/LLVM_X separation between dialects. It is the legacy of there being two completely disjoint type systems. Only ArmSVE is still there just because I haven't had time. So I would appreciate it if this didn't introduce another pair of dialects. I think all of these operations can live in the "main" AMX dialect; the patterns can be an in-dialect conversion given a list of "lower-level" ops. AVX512 already follows this model and can serve as an example. https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/AVX512/AVX512.td This will make it easy to "group" high-level ops and low-level ops so that they can share the description, at least in the comments. We can also think about infrastructural support for defining pairs/groups of ops when that becomes necessary (as a matter of fact, I have prototyped it, but it was more code and complexity than just keeping separate op definitions).

16 ↗(On Diff #330103)

I'm not sure I understand what you suggest here @nicolasvasilache. A dialect is merely a collection of ops. Absolutely nothing prevents these ops from living in the "main" AMX dialect. The conversion configuration will be slightly longer, but that's pretty much it.

These ops match LLVM IR intrinsics 1-1, and the conversion from "main" AMX ops to these is non-trivial. I am quite favorable to keeping these ops and the conversion, rather than somehow extending the translation or LLVM types.

LLVM_Type is actually LLVM_AnyCompatibleType; we shouldn't be using it anymore, since more specific type constraints have been available for several months. If you want operations to accept different types, there's always AnyTypeOf<[]>.

28 ↗(On Diff #330103)

"MLIR LLVM Dialect type system" no longer exists

49 ↗(On Diff #330103)

LLVM_Type accepts any type potentially usable in the LLVM dialect, could we put tighter type constraints? This should be easy if we have these ops next to the higher-level op definition :)

61 ↗(On Diff #330103)

Could we have at least a comment about the semantics of this op? tdpbssd isn't very intuitive. This should be easy if this op lives next to the higher-level op that has a detailed description :)

mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp
52 ↗(On Diff #330103)

It should be possible to just call memRefDescriptor.stride(position) here. There's no embedded folding, though; it always reads the descriptor. I'm fine with adding the folding there if desirable.

https://github.com/llvm/llvm-project/blob/cfe8f8e0f010077f5942bce88a2fd331b90ccea7/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h#L304

69 ↗(On Diff #330103)

Nit: the context is available from both ptr and loc, so there is no need to pass the type converter to access it.

mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
2

AMXDialect.cpp

mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
16 ↗(On Diff #330103)

A dialect is merely a collection of ops. Absolutely nothing prevents these ops from living in the "main" AMX dialect.

Ah yes, I keep conflating "the dialect goes away" with "we must use the same type", which is definitely not true.
My brain is taking some time to get rewired :)

mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp
52 ↗(On Diff #330103)

For the folding part, I am just asking about the time the descriptor LLVM struct is created and filled (i.e. at function boundaries and allocs).
When I wrote that, the constants were properly propagated at construction time, and I don't think the refactoring changed it (otherwise things would have broken in noticeable ways).

As long as this is still true, LLVM should be able to canonicalize / fold it away for us.
It would still be nice to confirm by looking at the LLVM IR post -O3; if for some reason it does not, we may want to work a little harder on our end.

ftynse added inline comments. Mar 12 2021, 3:49 AM
mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp
52 ↗(On Diff #330103)

It doesn't sound like something that could have been broken by the refactorings; they were mostly moving code around. As long as there is a dialect conversion somewhere, it will try folding any operation before converting it. If we need more, we can always start adding canonicalizations on the LLVM dialect.

Thanks for working on this, Aart! I think the progressive lowering approach that you are taking here is very on point! I’m not working with AMX but it would be great if you could add me as a reviewer to the related code reviews. It would be very educational for me since this approach is also applicable to similar internal problems we have.

aartbik added inline comments. Mar 12 2021, 9:56 AM
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
1 ↗(On Diff #330103)

I was of course very aware of the direction you were taking, and in fact started out this AMX dialect following this approach. But I got stuck on the fact that the 2-D types need to be LLVM IR types.

So I am uncertain on how you see this work without the LLVM IR dialect. Could you please sketch your vision in a bit more detail here? Also note that the AMX lowering uses quite a few non-trivial lowerings which work really well at the moment (e.g. getStridedElementPtr), which I am unsure would work without the intermediate.


Thanks Diego! Absolutely, I will start adding you to vector related stuff.

aartbik added inline comments. Mar 12 2021, 11:58 AM
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
1 ↗(On Diff #330103)

I forgot to quote in my reply above, so tagging you here explicitly @ftynse

mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
1 ↗(On Diff #330103)

AFAIU @ftynse suggests just moving the ops and dropping the extra dialect.
The same conversions would still exist, but they would be within the dialect.
This is independent of whether we can automatically convert without worrying about the type.

Am I missing something?

aartbik added inline comments. Mar 12 2021, 1:19 PM
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
1 ↗(On Diff #330103)

Just move the ops and drop the extra dialect.

Sorry for being slow.

I don't see how to easily lower

%4 = amx.tilemulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>

to

call x86_amx @llvm.x86.tdpbf16ps.internal(i16 2, i16 8, i16 8, x86_amx %51, x86_amx %46, x86_amx %50)

without going through

%52 = "llvm_amx.tdpbf16ps"(%50, %49, %51, %47, %34, %44) : (i16, i16, i16, !llvm.array<2 x vector<2xf32>>, !llvm.array<2 x vector<4xbf16>>, !llvm.array<2 x vector<4xbf16>>)

first
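The three `i16` constants in the intrinsic call above follow from the vector types: the accumulator `vector<2x2xf32>` gives 2 rows of 8 bytes, and the lhs `vector<2x4xbf16>` gives 8 bytes per row. A minimal sketch of that arithmetic, assuming the operand order is (result rows, bytes per result row, bytes per lhs row); `TileMulShape` and `computeTileMulShape` are hypothetical names, and the assert is one way to check that the m dimensions agree.

```cpp
#include <cassert>
#include <cstdint>

// Shape operands of the tile multiply intrinsic (assumed order m, n, k).
struct TileMulShape {
  int16_t m;       // rows of the accumulator (and of the lhs)
  int16_t nBytes;  // bytes per accumulator row
  int16_t kBytes;  // bytes per lhs row
};

// Hypothetical helper deriving the shape operands from
//   lhs : vector<lhsRows x lhsCols x T>  with T of lhsElemBits bits
//   acc : vector<accRows x accCols x U>  with U of accElemBits bits
TileMulShape computeTileMulShape(int64_t lhsRows, int64_t lhsCols,
                                 int64_t lhsElemBits, int64_t accRows,
                                 int64_t accCols, int64_t accElemBits) {
  // The m dimension of lhs and accumulator must agree.
  assert(lhsRows == accRows && "lhs and accumulator must have the same rows");
  return {static_cast<int16_t>(accRows),
          static_cast<int16_t>(accCols * accElemBits / 8),
          static_cast<int16_t>(lhsCols * lhsElemBits / 8)};
}
```

For the example above, `computeTileMulShape(2, 4, 16, 2, 2, 32)` yields `{2, 8, 8}`, matching `(i16 2, i16 8, i16 8)`.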

aartbik added inline comments. Mar 12 2021, 1:49 PM
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
1 ↗(On Diff #330103)

Ah, wait, I guess I see what you are getting at.

(1) Add the LLVM IR part into the AMX dialect
(2) Make the conversion a legalization where half the AMX ops are valid and half are invalid

I suppose that would work, yes. But unlike the previous ARM case, where ops were literal 1:1 mappings with no type changes, this feels like a very subjective aesthetic choice. I find the separate dialect more intuitive for this case.

But I will do as you requested....

aartbik updated this revision to Diff 330416. Mar 12 2021, 7:42 PM

merged LLVM IR AMX dialect with AMX dialect (other comments still to be addressed....)

bondhugula added inline comments.
mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
191

Drop commented out code?

197

Drop commented out code?

nicolasvasilache accepted this revision. Mar 13 2021, 5:04 AM
nicolasvasilache added inline comments.
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
1 ↗(On Diff #330103)

Right, I was essentially in the same mental model; see @ftynse's comment about types being orthogonal to dialects.

Re: automatic 1-1 patterns, we also have a bit of precedent now in the Arm dialect: https://reviews.llvm.org/D98198
There the 1-1 aspect additionally involves 2-D -> 1-D flattening considerations.

This revision is now accepted and ready to land. Mar 13 2021, 5:04 AM

Accepted, conditional on addressing the rest. Thanks Aart!

aartbik marked 30 inline comments as done. Mon, Mar 15, 3:32 PM
aartbik added inline comments.
mlir/include/mlir/Dialect/AMX/AMX.td
10

I added a link (with the caveat that Intel URLs are notorious for changing all the time).

78

The row elements are contiguous, the column starting points are defined by a stride, hardcoded in the instructions. I added a comment to the intrinsics that have that stride.
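In other words, element `(i, j)` of a tile lives at `base + i * stride + j * elemBytes`: elements within a row are adjacent, and the start of each row is `stride` bytes after the previous one. A minimal sketch of that addressing, assuming the stride is given in bytes; `tileElementAddress` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Byte address of element (row, col) in a tile whose rows are contiguous and
// whose row starting points are `strideBytes` apart (strideBytes may exceed
// the row length, e.g. when the tile is a window into a larger matrix).
uintptr_t tileElementAddress(uintptr_t base, int64_t row, int64_t col,
                             int64_t strideBytes, int64_t elemBytes) {
  return base + static_cast<uintptr_t>(row * strideBytes + col * elemBytes);
}
```

For a bf16 tile (2 bytes per element) read from a buffer with a 128-byte row stride, element `(2, 1)` sits 258 bytes past the base.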

mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td
28 ↗(On Diff #330103)

removed the full dialect, so including this comment ;-)

61 ↗(On Diff #330103)

Added comment, made type more precise.

mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp
35 ↗(On Diff #330103)

Added (note that we already have type restrictions on the op, but for future extensions this never hurts, of course).

aartbik updated this revision to Diff 330829. Mon, Mar 15, 3:48 PM
aartbik marked 5 inline comments as done.

better error messages, more doc on ops, new asserts, stride checks

aartbik updated this revision to Diff 330832. Mon, Mar 15, 4:02 PM

removed commented out code

This revision was landed with ongoing or failed builds. Mon, Mar 15, 5:59 PM
This revision was automatically updated to reflect the committed changes.