diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -28,10 +28,11 @@ let name = "nvgpu"; let cppNamespace = "::mlir::nvgpu"; let description = [{ - This `NVGPU` dialect provides a bridge between the target agnostic GPU and - Vector dialects and the lower level LLVM IR based NVVM dialect. This allow - representing PTX specific operations while using MLIR high level concepts - like memref and 2-D vector. + The `NVGPU` dialect provides a bridge between higher-level target-agnostic + dialects (GPU and Vector) and the lower-level target-specific dialect + (LLVM IR based NVVM dialect) for NVIDIA GPUs. This allow representing PTX + specific operations while using MLIR high level dialects such as Memref + and Vector for memory and target-specific register operands, respectively. }]; let useDefaultTypePrinterParser = 1; @@ -70,20 +71,20 @@ PredOpTrait<"srcMemref and res have same element type", TCresVTEtIsSameAsOp<0, 0>>]> { let description = [{ - The `nvgpu.ldmatrix` op represents loading a matrix fragment from - memory. The load source and result type must be compatible with lowering - to the `nvvm.ldmatrix` instruction. This op is meant to represent - the distributed version of a `vector.transfer_read` as an intermediate - step between lowering from `vector.transfer_read` to `nvvm.ldmatrix`. - - This operation is meant to follow the semantic of described here: - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix - - Example: - ```mlir - %0 = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false} : - memref -> vector<4x2xf16> - ``` + The `nvgpu.ldmatrix` op represents loading a matrix fragment from + memory to registers. The source and result type must be compatible + with lowering to the `nvvm.ldmatrix` instruction. 
This op represents + the distributed version of a `vector.transfer_read` as an intermediate + step between lowering from `vector.transfer_read` to `nvvm.ldmatrix`. + + This operation is meant to follow the semantic of described here: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix + + Example: + ```mlir + %0 = nvgpu.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false} : + memref -> vector<4x2xf16> + ``` }]; let arguments = (ins Arg:$srcMemref, @@ -102,25 +103,24 @@ PredOpTrait<"matrixA and matrixB have same element type", TCopVTEtIsSameAs<0, 1>>]> { let description = [{ - The `nvgpu.mma.sync` op represents the distributed form of a collective - matrix-multiply-and-accumulate (mma) operation that is compatible with - `nvvm.mma.sync`. The operands and results are fragments of the full matrix - operands. The full shape of the distributed mma operation is given by the - `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`. - - This operation is meant to be lowered to the `nvvm.mma.sync` instruction, and - is an intermediate point between lowering from `vector.contract` to - `nvvm.mma.sync`. + The `nvgpu.mma.sync` op represents the warp-level matrix-multiply-and- + accumulate (mma) operation that is compatible with `nvvm.mma.sync`. + The operands and results vector sizes are thread-level onwership to + the warp-level mma operation shape. `mmaShape` attribute holds the + warp-level matrix-multiply shape. + + The `nvgpu.mma.sync` op serves as an intermediate point between lowering from + `vector.contract` to `nvvm.mma.sync`. 
+ This operation is meant to follow the semantics described here:
+ `srcElements` argument is the total number of elements read from the source (global memory).
- Example: + Example: - ```mlir - %0 = nvgpu.device_async_create_group + ```mlir + %0 = nvgpu.device_async_create_group ``` }]; let results = (outs NVGPU_DeviceAsyncToken:$asyncToken); @@ -243,16 +243,17 @@ def NVGPU_DeviceAsyncWaitOp : NVGPU_Op<"device_async_wait", []> { let summary = "Wait for async gpu ops to complete."; let description = [{ - The `nvgpu.device_async_wait` op will block the execution thread until the group - associated with the source token is fully completed. + The `nvgpu.device_async_wait` op will block the execution thread until the group + associated with the source token is fully completed. The optional `$numGroup` attribute gives a lower bound of the number of groups uncompleted when the wait can unblock the thread. - Example: - ```mlir - nvgpu.device_async_wait %0 - ``` + Example: + + ```mlir + nvgpu.device_async_wait %0 + ``` }]; let arguments = (ins NVGPU_DeviceAsyncToken:$asyncDependencies, OptionalAttr:$numGroups);