The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
multiply unit (TMUL), a tile control register (TILECFG), and eight
tile registers TMM0 through TMM7 (TILEDATA). This new MLIR dialect
provides a bridge between MLIR concepts like vectors and memrefs
and the lower-level LLVM IR details of AMX.
- Repository
- rG LLVM Github Monorepo
This is a large drop, so I added some background on Discourse:
https://llvm.discourse.group/t/intel-amx-vector-dialect/2984/6
Can you split this patch up? Seems like there are several separable components: amx, llvm_amx, the lowering to LLVM, etc.
mlir/include/mlir/Dialect/AMX/AMX.td | ||
---|---|---|
10 | Drive-by comment: can you provide links to the "source of truth" documentation (insofar as it is available) for folks who want to dig deeper? |
Putting a blocker for double-checking the stride and maybe the error messages.
The rest is good to go and can be improved incrementally (e.g. dropping the LLVMAMX dialect if we can add a type hook).
Nice!
mlir/include/mlir/Dialect/AMX/AMX.td | ||
---|---|---|
78 | +1 on pointing to official doc so we can dig deeper. | |
133 | You also want to add some type checking on lhs/rhs via `TypesMatchWith`; here is an example from AVX512 for syntax purposes: `def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect, AllTypesMatch<["src", "a", "dst"]>, TypesMatchWith<"imm has the same number of bits as elements in dst", "dst", "imm", "IntegerType::get($_self.getContext(), " "($_self.cast<VectorType>().getShape()[0]))">]> { ...` Edit: ah, OK, I see it appears in the C++ part. Feel free to ignore this comment and leave it as is in C++, or lift some of that into `TypesMatchWith`. | |
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td | ||
16 ↗ | (On Diff #330103) | At the moment this is required because 2-D vectors are not native in LLVM IR, correct? Since the AMX vector type is a native 2-D type, is it reasonable to extend the LLVM type definition with target-specific hooks that would allow this dialect to disappear (@ftynse)? |
mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp | ||
31 ↗ | (On Diff #330103) | returning a pair or a struct with 2 values would be more idiomatic IMO |
35 ↗ | (On Diff #330103) | Please add an assert here that the IntOrFloatBitWidth is a power of 2. |
52 ↗ | (On Diff #330103) | The dynamic stride, if needed, should already be available to you in the descriptor. I haven't followed whether the refactorings still guarantee that static constants are visible when we pass function boundaries; I'd double-check by eyeballing the LLVM IR that gets emitted. In any case, the stride is not just the last size: strides have a life of their own. |
102 ↗ | (On Diff #330103) | Ah ok I see now that the stride is opaque to MLIR and part of the AMX intrinsic. |
102 ↗ | (On Diff #330103) | Some extra checks are needed. Use `LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl<int64_t> &strides, int64_t &offset);`, and if the innermost stride is not 1, then we should fail the conversion. |
114 ↗ | (On Diff #330103) | Same comments as above. |
148 ↗ | (On Diff #330103) | If you returned a pair, you could also more naturally assert that the m's agree. |
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | ||
47 | I'd go for nicer error messages here: the ops expect a certain vector layout so spelling the error a bit more would be useful. | |
54 | If you moved these `emitError` calls inside the function performing the verification, we could have nicer error messages. | |
mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir | ||
37 | Haha I love these types of test. |
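To illustrate the `AMXDialect.cpp` suggestions (perform the check and build the error in the same function, and spell out the expected vector layout), here is a minimal, framework-free C++ sketch. The name `verifyTileShape` and the message wording are hypothetical, not the patch's code. The limits encode the AMX tile geometry of at most 16 rows of at most 64 bytes each.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// AMX tile geometry: at most 16 rows, each at most 64 bytes wide.
constexpr int64_t kMaxTileRows = 16;
constexpr int64_t kMaxTileRowBytes = 64;

// Hypothetical verifier helper: checks a 2-D vector shape against the tile
// limits and, on failure, returns a message naming the violated constraint.
// Returns std::nullopt when the shape fits in a tile.
std::optional<std::string> verifyTileShape(int64_t rows, int64_t cols,
                                           unsigned elementBits) {
  if (rows < 1 || rows > kMaxTileRows)
    return "bad tile shape: expected 1 to " + std::to_string(kMaxTileRows) +
           " rows, got " + std::to_string(rows);
  const int64_t rowBytes = cols * static_cast<int64_t>(elementBits) / 8;
  if (cols < 1 || rowBytes > kMaxTileRowBytes)
    return "bad tile shape: expected 1 to " +
           std::to_string(kMaxTileRowBytes) + " bytes per row, got " +
           std::to_string(cols) + " x " + std::to_string(elementBits) +
           "-bit elements (" + std::to_string(rowBytes) + " bytes)";
  return std::nullopt;
}
```

In an actual MLIR verifier, a non-empty result would be forwarded to `emitOpError(...)` and the verifier would return `failure()`, so the message and the check stay side by side.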
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h | ||
---|---|---|
26–28 | It would be nice if we could reconsider this trick. It was introduced to make sure the type system change between built-in vectors and llvm vectors was smooth, but the type system difference is (almost) gone. It feels like we only need some casting/packing between nD and 1D vectors to make vector-to-llvm conversion separate from "ISA dialect"-to-llvm conversion. Not for this commit though. | |
mlir/include/mlir/Dialect/AMX/AMX.td | ||
2 | Nit: this should match the file name | |
27–28 | This should also match the filename, I suppose | |
136 | Nit: putting quotes or backticks around variables, e.g., an `m x k` tile with a `k x n` tile, would make it more readable. | |
151 | Nit: something went wrong with whitespace | |
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td | ||
1 ↗ | (On Diff #330103) | I am going in the inverse direction and removing the X/LLVM_X separation between dialects. It is the legacy of there being two completely disjoint type systems. Only ArmSVE is still there just because I haven't had time. So I would appreciate if this didn't introduce another pair of dialects. I think all of these operations can live in the "main" AMX dialect, the patterns can be an in-dialect conversion given a list of "lower-level" ops. AVX512 already follows this model and can serve as an example. https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/AVX512/AVX512.td This will make it easy to "group" high-level ops and low-level ops so that they can share the description, at least in the comments. We can also think about infrastructural support for defining pairs/groups of ops when that becomes necessary (as a matter of fact, I have prototyped it, but it was more code and complexity than just keeping separate op definitions). |
16 ↗ | (On Diff #330103) | I'm not sure I understand what you suggest here, @nicolasvasilache. A dialect is merely a collection of ops. Absolutely nothing prevents these ops from living in the "main" AMX dialect. The conversion configuration will be slightly longer, but that's pretty much it. These ops match LLVM IR intrinsics 1-1, and the conversion from "main" AMX ops to these is non-trivial. I am quite favorable to keeping these ops and the conversion, rather than somehow extending the translation or LLVM types. LLVM_Type is actually LLVM_AnyCompatibleType; we shouldn't be using it anymore, since more specific type constraints have been available for several months. If you want operations to accept different types, there's always AnyTypeOf<[]>. |
28 ↗ | (On Diff #330103) | "MLIR LLVM Dialect type system" no longer exists |
49 ↗ | (On Diff #330103) | LLVM_Type accepts any type potentially usable in the LLVM dialect, could we put tighter type constraints? This should be easy if we have these ops next to the higher-level op definition :) |
61 ↗ | (On Diff #330103) | Could we have at least a comment about the semantics of this op? `tdpbssd` isn't very intuitive. This should be easy if this op lives next to the higher-level op that has a detailed description :) |
mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp | ||
52 ↗ | (On Diff #330103) | It should be possible to just call memRefDescriptor.stride(position) here. There's no embedded folding though, it always reads the descriptor. I'm fine adding the folding there if desirable. |
69 ↗ | (On Diff #330103) | Nit: context is available in both ptr and loc, no need to pass the type converter to access it. |
mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | ||
2 | AMXDialect.cpp |
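On the request for a comment on the semantics of `tdpbssd`: the instruction computes signed int8 dot products that accumulate into int32 tile elements. The scalar C++ model below is a sketch written from Intel's published description of TDPBSSD, not code from the patch. It assumes row-major flat storage, K a multiple of 4, and the second operand in the packed layout where the four consecutive k values for column n sit in row k/4 at bytes 4n to 4n+3. The name `tdpbssdReference` is hypothetical, and int32 overflow is not modeled.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar reference model of TDPBSSD (signed int8 x signed int8 -> int32).
//   acc: M x N int32 accumulator tile (row-major)
//   a:   M x K int8 tile, K a multiple of 4 (row-major)
//   b:   (K/4) x (4*N) int8 tile in packed layout: the 4 consecutive k values
//        of column n live in row k/4 at byte offsets 4n .. 4n+3
// For every (m, n): acc[m][n] += sum over k of a[m][k] * b[k/4][4n + k%4].
// Overflow of the int32 accumulator is not modeled.
void tdpbssdReference(std::vector<int32_t> &acc, const std::vector<int8_t> &a,
                      const std::vector<int8_t> &b, std::size_t M,
                      std::size_t N, std::size_t K) {
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t k = 0; k < K; ++k)
        acc[m * N + n] +=
            static_cast<int32_t>(a[m * K + k]) *
            static_cast<int32_t>(b[(k / 4) * (4 * N) + 4 * n + k % 4]);
}
```

With `a = {1, 2, 3, 4}` (1 x 4), `b = {1, 1, 1, 1, -2, -2, -2, -2}` (1 x 8) and `acc = {10, 0}` (1 x 2), the accumulator becomes `{20, -20}`.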
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td | ||
---|---|---|
16 ↗ | (On Diff #330103) | Re: "A dialect is merely a collection of ops. Absolutely nothing prevents these ops from living in the "main" AMX dialect." Ah yes, I keep conflating "the dialect goes away" with "we must use the same types", which is definitely not true. |
mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp | ||
52 ↗ | (On Diff #330103) | For the folding part, I am just asking about the time the descriptor LLVM struct is created and filled (i.e. function boundary and alloc). As long as this is still true, LLVM should be able to canonicalize / fold away for us. |
mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp | ||
---|---|---|
52 ↗ | (On Diff #330103) | It doesn't sound like something could have been broken by refactorings; they were mostly moving code around. As long as there is a dialect conversion somewhere, it will try folding any operation before converting it. If we need more, we can always start adding canonicalizations on the LLVM dialect. |
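The stride comments on `ConvertAMXToLLVM.cpp` boil down to two requirements: the innermost stride must be 1 (AMX tile rows are contiguous), and the row stride is not simply the last size. Below is a minimal sketch of that check, assuming static strides have already been extracted (e.g. by `getStridesAndOffset`) into a plain vector and that, as with the TILELOADD instruction, the row stride is expressed in bytes. The name `rowStrideInBytes` is hypothetical; a dynamic stride would instead be read from the memref descriptor at runtime.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper: given the static element strides of a memref (outermost
// first) and the element size in bytes, return the byte distance between
// consecutive tile rows, or std::nullopt if the memref cannot be accessed as a
// tile and the conversion should fail.
std::optional<int64_t> rowStrideInBytes(const std::vector<int64_t> &strides,
                                        int64_t elementBytes) {
  // A tile needs a 2-D view: a row stride plus an in-row stride.
  if (strides.size() < 2)
    return std::nullopt;
  // Elements within a row must be contiguous (innermost stride of 1).
  if (strides.back() != 1)
    return std::nullopt;
  // The stride between rows is the second-innermost stride, which need not
  // equal the innermost size (e.g. for a subview of a wider buffer).
  return strides[strides.size() - 2] * elementBytes;
}
```

For example, a `memref<4x16xf32>` with strides [16, 1] gives 64 bytes, while a 16-column subview of a `memref<4x32xf32>` (strides [32, 1]) gives 128 bytes, even though the innermost size is 16 in both cases.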
Thanks for working on this, Aart! I think the progressive lowering approach that you are taking here is very on point! I’m not working with AMX but it would be great if you could add me as a reviewer to the related code reviews. It would be very educational for me since this approach is also applicable to similar internal problems we have.
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td | ||
---|---|---|
1 ↗ | (On Diff #330103) | I was of course very aware of the direction you were taking and in fact started out this AMX dialect following this approach. But I got stuck on the fact that the 2-D types need to be LLVM IR types. So I am uncertain how you see this working without the LLVM IR dialect. Could you please sketch your vision in a bit more detail here? Also note that the AMX lowering uses quite a few non-trivial lowerings that work really well at the moment (e.g. getStridedElementPtr), which I am unsure would work without the intermediate dialect. |
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td | ||
---|---|---|
1 ↗ | (On Diff #330103) | AFAIU @ftynse suggests to just move the ops and drop the extra dialect. Am I missing something? |
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td | ||
---|---|---|
1 ↗ | (On Diff #330103) | Sorry for being slow. I don't see how to easily lower `%4 = amx.tilemulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32>` to `call x86_amx @llvm.x86.tdpbf16ps.internal(i16 2, i16 8, i16 8, x86_amx %51, x86_amx %46, x86_amx %50)` without first going through `%52 = "llvm_amx.tdpbf16ps"(%50, %49, %51, %47, %34, %44) : (i16, i16, i16, !llvm.array<2 x vector<2xf32>>, !llvm.array<2 x vector<4xbf16>>, !llvm.array<2 x vector<4xbf16>>)`. |
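The IR in this comment also shows how the size operands relate to the vector types: for `vector<2x4xbf16>` operands and a `vector<2x2xf32>` accumulator, the intrinsic receives `i16 2, i16 8, i16 8`, i.e. a row count and row widths in bytes (2 f32 x 4 bytes = 8, 4 bf16 x 2 bytes = 8). Below is a minimal C++ sketch of that computation which also follows the earlier review suggestions for `ConvertAMXToLLVM.cpp`: return the two sizes as a pair, assert that the element bit width is a power of 2, and assert that the m's of lhs and accumulator agree. All names are hypothetical, and the (m, n, k) order is assumed from the operand order of the intrinsic call above.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical helper: describe an R x C vector of `elementBits`-wide
// elements as an AMX tile, i.e. (number of rows, row width in bytes).
std::pair<int16_t, int16_t> tileRowsAndColBytes(int64_t rows, int64_t cols,
                                                unsigned elementBits) {
  // Byte-addressable power-of-2 element widths only (8, 16, 32, ...).
  assert(elementBits >= 8 && (elementBits & (elementBits - 1)) == 0 &&
         "element bit width must be a power of 2 of at least 8");
  return {static_cast<int16_t>(rows),
          static_cast<int16_t>(cols * (elementBits / 8))};
}

// Size operands (m, n, k) of a tile multiply, derived from the lhs and the
// accumulator shapes.
struct TileMulSizes {
  int16_t m, n, k;
};

TileMulSizes tileMulSizes(int64_t lhsRows, int64_t lhsCols, unsigned lhsBits,
                          int64_t accRows, int64_t accCols, unsigned accBits) {
  auto [lhsM, k] = tileRowsAndColBytes(lhsRows, lhsCols, lhsBits);
  auto [accM, n] = tileRowsAndColBytes(accRows, accCols, accBits);
  // With the sizes returned as a pair, checking that the m's agree is direct.
  assert(lhsM == accM && "lhs and accumulator must have the same row count");
  (void)lhsM; // avoid an unused-variable warning when NDEBUG is defined
  return {accM, n, k};
}
```

For the shapes in the comment, `tileMulSizes(2, 4, 16, 2, 2, 32)` yields m = 2, n = 8, k = 8.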
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td | ||
---|---|---|
1 ↗ | (On Diff #330103) | Ah, wait, I guess I see what you are getting at. (1) Add the LLVM IR part into the AMX dialect. I suppose that would work, yes. But unlike the previous ARM case, where ops were literal 1:1 mappings with no type changes, this feels like a very subjective aesthetic choice. I find the separate dialect more intuitive for this case. But I will do as you requested... |
Merged the LLVM IR AMX dialect into the AMX dialect (other comments still to be addressed...).
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td | ||
---|---|---|
1 ↗ | (On Diff #330103) | Right, I was essentially in the same mental model; see @ftynse's comment about types being orthogonal to dialects. Re: automatic 1-1 patterns, we also have a bit of precedent now in the arm dialect: https://reviews.llvm.org/D98198 |
mlir/include/mlir/Dialect/AMX/AMX.td | ||
---|---|---|
10 | I added a link (with the caveat that Intel URLs are notorious for changing all the time). | |
78 | The row elements are contiguous, the column starting points are defined by a stride, hardcoded in the instructions. I added a comment to the intrinsics that have that stride. | |
mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td | ||
28 ↗ | (On Diff #330103) | removed the full dialect, so including this comment ;-) |
61 ↗ | (On Diff #330103) | Added comment, made type more precise. |
mlir/lib/Conversion/AMXToLLVM/ConvertAMXToLLVM.cpp | ||
35 ↗ | (On Diff #330103) | Added (note that we also have type restrictions on the op already, but for future extensions this never hurts, of course). |
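The layout described in the reply about tile rows (row elements are contiguous, and row starting points are separated by a stride) can be modeled with a small scalar C++ sketch of a tile load: each tile row is one contiguous copy of `rowBytes` bytes, and consecutive source rows start `strideBytes` apart. The function `loadTileReference` is hypothetical and only illustrates the addressing; it is not how the intrinsic is implemented.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical scalar model of a strided tile load: row r of the tile is the
// `rowBytes` contiguous bytes starting at base + r * strideBytes.
std::vector<uint8_t> loadTileReference(const uint8_t *base, std::size_t rows,
                                       std::size_t rowBytes,
                                       std::size_t strideBytes) {
  std::vector<uint8_t> tile(rows * rowBytes);
  for (std::size_t r = 0; r < rows; ++r)
    std::memcpy(tile.data() + r * rowBytes, base + r * strideBytes, rowBytes);
  return tile;
}
```

For instance, reading a 2 x 2 byte block from the top-left corner of a 4 x 4 byte buffer holding 0 to 15 uses a 4-byte stride, so the tile holds bytes {0, 1, 4, 5}.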