diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -169,6 +169,30 @@
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "wgmma.descriptor", []> {
+  let summary = "Warpgroup matrix descriptor type";
+  let description = [{
+    The descriptor specifies the properties of the matrix in shared memory that
+    is a multiplicand in the matrix multiply and accumulate operation.
+
+    The descriptor is a 64-bit value contained in a register with the following layout:
+    ```
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |  14bits |2bits|   14bits  |2bits|   14bits  |2bits|3bits|   10bits  |2bits|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    ```
+
+    [See the PTX ISA for more details](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor)
+
+  }];
+  let parameters = (ins "MemRefType":$tensor);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -628,32 +652,18 @@
 def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
   let summary = "Generate a wgmma matrix descriptor";
   let description = [{
-    This Op builds a wgmma descriptor that is used by wgmma matrix multiply
-    and accumulate.
+    This Op builds an `nvgpu.wgmma.descriptor` that is used by warpgroup-level
+    matrix multiply and accumulate.
 
     The descriptor specifies the properties of the matrix in shared memory that
     is a multiplicand in the matrix multiply and accumulate operation.
-
-    The descriptor is a 64-bit value contained in a register with the following
-    ```
-    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-    |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
-    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-    |  14bits |2bits|   14bits  |2bits|   14bits  |2bits|3bits|   10bits  |2bits|
-    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-    | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
-    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
-    ```
-
-    See for more details:
-    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
-
   }];
-  let results = (outs I64:$descriptor);
+  let results = (outs NVGPU_WarpgroupMatrixDescriptor:$descriptor);
   let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$tensor, NVGPU_TensorMapDescriptor:$tensorMap);
-  let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap)}];
+  let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap) `->` type($descriptor)}];
   let hasVerifier = 1;
 }
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -417,6 +417,10 @@
   converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
     return converter.convertType(IntegerType::get(type.getContext(), 64));
   });
+  converter.addConversion(
+      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
+        return converter.convertType(IntegerType::get(type.getContext(), 64));
+      });
   converter.addConversion([&](nvgpu::MBarrierType type) -> Type {
     return converter.convertType(
         nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -634,13 +634,13 @@
 !tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
 memref.global "private" @dynamicShmem : memref<0xf16,3>
 // CHECK-LABEL: func @create_wgmma_descriptor(
-func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64{
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>{
   %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
   %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
   // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
   // CHECK: %[[Sre:.+]] = memref.reinterpret_cast %[[S0]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
   // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[Sre]] : memref<128x64xf16, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK: %[[c64:.+]] = llvm.mlir.constant(64 : i64) : i64 
+  // CHECK: %[[c64:.+]] = llvm.mlir.constant(64 : i64) : i64
   // CHECK: %[[c1024:.+]] = llvm.mlir.constant(1024 : i64) : i64
   // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[S3:.+]] = llvm.ptrtoint %[[S2]] : !llvm.ptr<3> to i64
@@ -659,19 +659,19 @@
   // CHECK: %[[S16:.+]] = llvm.or %[[S12]], %[[S15]] : i64
   // CHECK: %[[S18:.+]] = llvm.mlir.constant(32 : i64) : i64
   // CHECK: %[[S19:.+]] = llvm.shl %[[c64]], %[[S18]] : i64
-  // CHECK: %[[S20:.+]] = llvm.or %[[S16]], %[[S19]] : i64 
+  // CHECK: %[[S20:.+]] = llvm.or %[[S16]], %[[S19]] : i64
   // CHECK: %[[S22:.+]] = llvm.mlir.constant(16 : i64) : i64
   // CHECK: %[[S23:.+]] = llvm.shl %[[c1024]], %[[S22]] : i64
   // CHECK: %[[S24:.+]] = llvm.or %[[S20]], %[[S23]] : i64
   // CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
   // CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]] : i64
   // CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]] : i64
-  // CHECK: return %[[S27]] : i64
-  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap
-  func.return %descA : i64
+  // CHECK: %[[ret:.+]] = builtin.unrealized_conversion_cast %[[S27]] : i64 to !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+  // CHECK: return %[[ret]]
+  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap -> !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
+  func.return %descA : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
 }
-
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1
     : (!transform.any_op) -> !transform.any_op
@@ -682,4 +682,4 @@
       transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter {use_opaque_pointers = true}
   } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op
-}
+}
\ No newline at end of file
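
For reference, below is a minimal standalone sketch (not part of the patch) of what the updated assembly format looks like once the op returns the new descriptor type. The function and value names are hypothetical, and the tensor map parameters simply mirror the test case above:

```mlir
// Hypothetical usage of nvgpu.wgmma.generate.descriptor after this change:
// the result is the dedicated !nvgpu.wgmma.descriptor type rather than i64.
!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>,
                                         swizzle = swizzle_128b,
                                         l2promo = none, oob = zero,
                                         interleave = none>

func.func @generate_lhs_descriptor(%lhsShmem : memref<128x64xf16, 3>,
                                   %tensorMap : !tensorMap)
    -> !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>> {
  // Build the shared-memory matrix descriptor for the LHS multiplicand.
  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap
    : memref<128x64xf16, 3>, !tensorMap
      -> !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
  func.return %descA : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>
}
```

During lowering to NVVM the descriptor type converts to a plain i64 (see the type converter hunk above), so the packed 64-bit layout described in the table is what ultimately reaches the wgmma instruction.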