This is an archive of the discontinued LLVM Phabricator instance.

[MLIR][NVGPU] Introduction of wgmma.generate.descriptor Op
ClosedPublic

Authored by guraypp on Aug 8 2023, 4:22 AM.

Details

Summary

This work introduces a new op, wgmma.generate.descriptor, designed to create a wgmma descriptor for the inputs of matrix multiply-and-accumulate operations performed with the wgmma.mma_async PTX instruction.

The descriptor format specifications can be found in the following link:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor

Note that this op is in its initial phase and comes with certain limitations: it supports only 128b swizzling and does not implement interleaving. The remaining layout calculations will be addressed in separate follow-up work, expanding the capabilities of the op.
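As a rough illustration (not the patch's code), the descriptor packing described in the linked PTX documentation can be sketched in Python. The bit positions and the matrix-descriptor-encode formula follow the spec; the function names and the default swizzle value are illustrative assumptions:

```python
def matrix_descriptor_encode(x: int) -> int:
    """Per the PTX spec: matrix-descriptor-encode(x) = (x & 0x3FFFF) >> 4."""
    return (x & 0x3FFFF) >> 4

def make_wgmma_descriptor(start_addr: int,
                          lead_dim_offset: int,
                          stride_dim_offset: int,
                          base_offset: int = 0,
                          swizzle_mode: int = 1) -> int:
    """Hypothetical sketch of the 64-bit shared-memory matrix descriptor.

    Field positions (per the PTX documentation linked above):
      bits  0-13: encoded start address
      bits 16-29: encoded leading dimension byte offset
      bits 32-45: encoded stride dimension byte offset
      bits 49-51: matrix base offset
      bits 62-63: swizzle mode (1 is assumed here to mean 128B swizzle)
    """
    desc = 0
    desc |= matrix_descriptor_encode(start_addr)                # bits 0-13
    desc |= matrix_descriptor_encode(lead_dim_offset) << 16     # bits 16-29
    desc |= matrix_descriptor_encode(stride_dim_offset) << 32   # bits 32-45
    desc |= (base_offset & 0x7) << 49                           # bits 49-51
    desc |= (swizzle_mode & 0x3) << 62                          # bits 62-63
    return desc
```

With all offsets zero, only the swizzle-mode bits are set, which makes it easy to eyeball the field placement in a hex dump of the descriptor.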

Diff Detail

Event Timeline

guraypp created this revision.Aug 8 2023, 4:22 AM
Herald added a project: Restricted Project.Aug 8 2023, 4:22 AM
guraypp requested review of this revision.Aug 8 2023, 4:22 AM
qcolombet added inline comments.Aug 10 2023, 12:40 AM
mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
646

Maybe add: Where Mod is the swizzling mode.

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
939

Stupid question, do we have to unwrap the descriptor, as opposed to using it directly?

978

That's not enough to get the start address, is it?

The first three fields (start address, leading dim, and offset) are stored as matrix-descriptor-encode(x) = (x & 0x3FFFF) >> 0x4, if I'm not mistaken. (Though the spec is weird because I feel we would lose some information in the encoding process for the offset and leading dim at least.)

Anyhow, this leads back to my other question: should we even unwrap the descriptor?

981

Could you introduce constants for the various shifts and sizes?
E.g., StartAddrSizeInBits = ...
StartAddrBitStartPos = ...
And use that in the shifts and masks etc.

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
654

I found it strange to see a 0xf16.
Could we use f16 instead?

Also, generally speaking, I would expect something like memref<?xf16> for this kind of global variable, wouldn't we?

qcolombet added inline comments.Aug 10 2023, 1:05 AM
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
939

Scratch that, I'm stupid, we're wrapping the value in here, not unwrapping it, so we need to create that thing.
I thought there was a dedicated nvidia instruction that creates the matrix descriptor, but no :).

guraypp updated this revision to Diff 548935.Aug 10 2023, 2:20 AM
guraypp marked 5 inline comments as done.

address comments

guraypp added a subscriber: ftynse.Aug 10 2023, 2:20 AM
guraypp added inline comments.
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
939

We had a chat offline. Here we create the descriptor from scratch, so we need to fill in the bits.

978

Good catch. I added this to the other fields as well. Thanks!

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
654

Very good question.

0xf16,3 is needed for dynamic shared memory; that is how the NVPTX backend of LLVM implements it.

memref<?xf16,3> IR does not pass verification as a global object.

One could consider memref<8192xf16,3>. However, there is a CUDA limitation here: sized shared memory is generated as static shared memory, which is limited to 48KB, whereas dynamic shared memory (0xf16) has a larger limit.

I had a chat about that with @ftynse. I am planning to model dynamic shared memory in MLIR.

guraypp updated this revision to Diff 548947.Aug 10 2023, 3:10 AM

calculate stride at compile-time

guraypp updated this revision to Diff 548953.Aug 10 2023, 3:31 AM

calculate leading dim also at compile time, and exclude 4 LSB

guraypp updated this revision to Diff 548954.Aug 10 2023, 3:32 AM

use const for LSB bit

guraypp updated this revision to Diff 548974.Aug 10 2023, 4:46 AM

fix the test

qcolombet added inline comments.Aug 10 2023, 9:53 AM
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
940

Stick to "swizzle mode", or say "layout mode" for the swizzle bits, because it is difficult to connect the two when the comment explaining the struct of the bits says "swizzle" while the variable is named layoutBit.

944

That should be 0 I believe.

982

I would suggest a higher-level API (to hide the left/right shifts used for masking) of the form:

val = insertBits(dst, field, start, size)

And chain that together, e.g.:

addressWo4LSB = makeShl(start_address, 4)
desc = insertBits(zero, addressWo4LSB, startAddressPos /*i.e., 0*/, startAddressSize /*i.e., 14*/)
strideDimWo4LSB = makeShl(strideDim, 4)
desc = insertBits(desc, strideDimWo4LSB, startStridePos, strideSize)
...

Where insertBits would produce something like:

insertBits(dst, field, start, size):
   mask = (1 << size) - 1 // computed as a constant at compile time
   masked_field = field & mask
   res = dst | (masked_field << start)
   return res

Note: at this point, you can simply implement exclude4LSB with makeShl(..., 4). (I.e., keep that function, since the name helps understanding, but the implementation becomes simpler.)
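The insertBits helper sketched above can be made concrete. The following is a hypothetical Python rendering of the reviewer's pseudocode, not the C++ that landed in the patch; the chained-usage field positions are illustrative:

```python
def insert_bits(dst: int, field: int, start: int, size: int) -> int:
    """Insert the low `size` bits of `field` into `dst` at bit `start`."""
    mask = (1 << size) - 1          # computable as a constant at compile time
    return dst | ((field & mask) << start)

# Illustrative chaining, dropping the 4 LSBs of each value before insertion
# (positions 0 and 32, and the 14-bit width, are assumptions for this sketch):
def pack_two_fields(start_address: int, stride: int) -> int:
    desc = insert_bits(0, start_address >> 4, 0, 14)
    desc = insert_bits(desc, stride >> 4, 32, 14)
    return desc
```

The mask keeps an out-of-range field from clobbering neighboring bits, which is the "proper code sequence that gets optimized away" argument made later in this review.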

qcolombet added inline comments.Aug 10 2023, 9:54 AM
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
654

That's fine for now.
Maybe just add a comment saying that's how dynamic shared memory is currently represented.

guraypp updated this revision to Diff 550842.Aug 16 2023, 12:28 PM
guraypp marked 2 inline comments as done.

address comments

guraypp marked 3 inline comments as done.Aug 16 2023, 12:29 PM
qcolombet added inline comments.Aug 22 2023, 3:12 AM
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
963

We'll also need the size of the field (e.g., 14 for BaseAddr) to mask the bits appropriately.
Technically we may not need that because our values should always be in range (i.e., the masking should be a no-op), but I'd rather we emit the proper code sequence and have it optimized away.

Alternatively, we could emit llvm.assume(val < ((1 << size) - 1)) if you believe the masking is overkill.

guraypp updated this revision to Diff 552300.Aug 22 2023, 3:33 AM

simplify the pattern, return ssa values

guraypp added inline comments.Aug 22 2023, 5:02 AM
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
963

For the baseAddr, I do the following. Is it not enough?

Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
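The shift pair above is a bitfield extraction in disguise. A sketch of the equivalence (assuming a 64-bit value and a logical right shift, as in the patch's i64 lowering; the function name is illustrative):

```python
MASK64 = (1 << 64) - 1  # model i64 wrap-around in Python's unbounded ints

def extract_base_addr_field(base_ptr: int) -> int:
    """Mimic `shiftRight(shiftLeft(basePtr, 46), 50)` on an i64.

    Shifting left by 46 discards bits 18 and above; shifting the result
    logically right by 50 leaves original bits [4..17] in positions [0..13],
    i.e. a 14-bit field with the 4 LSBs dropped.
    """
    return ((base_ptr << 46) & MASK64) >> 50
```

This matches the masked form `(x >> 4) & 0x3FFF`, which is exactly matrix-descriptor-encode(x) = (x & 0x3FFFF) >> 4 from the PTX spec, so the in-range masking qcolombet asked about is already implied by the shifts.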
qcolombet accepted this revision.Aug 22 2023, 6:25 AM
qcolombet added inline comments.
mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
963

Ah right.
It is enough, though it is not super readable (e.g., I missed it :)).

I would have hidden the left/right shift directly in insertBits by passing a size directly. Anyhow, aside from BaseAddr I see that everything is known at compile time, so let's go with what you have here. The systematic masking is not necessary / overkill at this point.

This revision is now accepted and ready to land.Aug 22 2023, 6:25 AM