diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -625,4 +625,35 @@
   let hasVerifier = 1;
 }
 
+def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
+  let summary = "Generate a wgmma matrix descriptor";
+  let description = [{
+    This Op builds a wgmma descriptor that is used by wgmma matrix multiply
+    and accumulate.
+
+    The descriptor specifies the properties of the matrix in shared memory that
+    is a multiplicand in the matrix multiply and accumulate operation.
+
+    The descriptor is a 64-bit value contained in a register with the following layout:
+    ```
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |  14bits |2bits|   14bits  |2bits|   14bits  |3bits|3bits|   10bits  |2bits|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    ```
+
+    For more details, see:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
+
+  }];
+  let results = (outs I64:$descriptor);
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$tensor,
+                       NVGPU_TensorMapDescriptor:$tensorMap);
+  let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap)}];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
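The layout documented above can be cross-checked in isolation. Below is a minimal standalone C++ sketch, not taken from this patch; the helper name `encodeMatrixDescriptor` is hypothetical, and the example values correspond to the 128x64xf16 tile with 128-byte swizzle exercised in the lit test at the end of this patch.

```
#include <cassert>
#include <cstdint>

// Packs the fields from the layout table above. Each 14-bit field is stored
// with its 4 least-significant bits dropped, matching the "exclude 4 LSB"
// convention used by the lowering.
static uint64_t encodeMatrixDescriptor(uint64_t baseAddr,
                                       uint64_t leadingDimBytes,
                                       uint64_t strideBytes,
                                       uint64_t baseOffset,
                                       uint64_t swizzleMode) {
  uint64_t desc = 0;
  desc |= (baseAddr >> 4) & 0x3FFF;                // bits [0,14)  : base address
  desc |= ((leadingDimBytes >> 4) & 0x3FFF) << 16; // bits [16,30) : leading dim
  desc |= ((strideBytes >> 4) & 0x3FFF) << 32;     // bits [32,46) : stride
  desc |= (baseOffset & 0x7) << 49;                // bits [49,52) : base offset
  desc |= (swizzleMode & 0x3) << 62;               // bits [62,64) : swizzle mode
  return desc;
}

int main() {
  // 128x64xf16 tile with 128-byte swizzle: leading dimension = 128 * 128 bytes,
  // stride = 128 * 8 bytes, swizzle mode 1 (the value the lowering emits).
  uint64_t desc = encodeMatrixDescriptor(/*baseAddr=*/0,
                                         /*leadingDimBytes=*/128 * 128,
                                         /*strideBytes=*/128 * 8,
                                         /*baseOffset=*/0,
                                         /*swizzleMode=*/1);
  assert(((desc >> 16) & 0x3FFF) == 1024); // same constant the lowering materializes
  assert(((desc >> 32) & 0x3FFF) == 64);   // ditto for the stride field
  assert((desc >> 62) == 1);
  return 0;
}
```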
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -934,6 +934,78 @@
     return success();
   }
 };
+struct NVGPUGenerateGmmaDescriptorLowering
+    : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    constexpr int startLayoutBit = 62;
+    constexpr int startOffsetBit = 49;
+    constexpr int startStrideBit = 32;
+    constexpr int startLeadingDimBit = 16;
+    constexpr int startBaseAddrBit = 50;
+    constexpr int excludeLSBit = 4;
+
+    Location loc = op->getLoc();
+
+    nvgpu::TensorMapSwizzleKind swizzleKind =
+        op.getTensorMap().getType().getSwizzle();
+
+    unsigned swizzle =
+        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 128
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
+                                                                    : 1;
+
+    auto ti64 = rewriter.getIntegerType(64);
+    auto makeConst = [&](uint64_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, ti64, rewriter.getI64IntegerAttr(index));
+    };
+    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto shiftRight = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto makeOr = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::OrOp>(loc, ti64, lhs, rhs);
+    };
+
+    Value desc = makeConst(0);
+    // [62,64) layout type
+    // 6 bits unused, 2 bits [6,8)
+    desc = makeOr(desc, shiftLeft(makeConst(uint64_t(1)), startLayoutBit));
+    // [49,52) base_offset
+    // 1 bit unused, 3 bits [1,4), 4 bits unused
+    // Valid only for SWIZZLE_128B and SWIZZLE_64B
+    desc = makeOr(desc, shiftLeft(makeConst(0), startOffsetBit));
+    // [32,46) stride
+    // 14 bits [0,14), 2 bits unused (exclude 4 LSB)
+    Value strideDim = makeConst((swizzle << 3) >> excludeLSBit);
+    desc = makeOr(desc, shiftLeft(strideDim, startStrideBit));
+    // [16,30) leading dimension
+    // 14 bits [0,14), 2 bits unused (exclude 4 LSB)
+    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
+    Value leadDim = makeConst((sizeN * swizzle) >> excludeLSBit);
+    desc = makeOr(desc, shiftLeft(leadDim, startLeadingDimBit));
+    // [0,14) start_address
+    // 14 bits [0,14), 2 bits unused
+    Value basePtr = rewriter.create<LLVM::ExtractValueOp>(
+        op->getLoc(), adaptor.getTensor(), 1);
+    Value ptri64 = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, basePtr);
+    // Exclude the 4 LSB of the address.
+    Value startAddress = shiftRight(
+        shiftLeft(ptri64, (startBaseAddrBit - excludeLSBit)), startBaseAddrBit);
+    desc = makeOr(desc, startAddress);
+
+    rewriter.replaceOp(op, desc);
+    return success();
+  }
+};
 
 static Value makeI64Const(RewriterBase &rewriter, Operation *op,
                           int32_t index) {
@@ -1064,6 +1136,7 @@
       NVGPUTmaCreateDescriptorOpLowering,     // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering,    // nvgpu.mbarrier.arrive.expect_tx
       NVGPUTmaAsyncLoadOpLowering,            // nvgpu.tma.async.load
+      NVGPUGenerateGmmaDescriptorLowering,    // nvgpu.wgmma.generate.descriptor
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
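The least obvious step in the lowering above is the base-address encoding: the pointer is shifted left by startBaseAddrBit - excludeLSBit (46) and then logically shifted right by startBaseAddrBit (50), which is how the constants 46 and 50 appear in the lit test below. For an unsigned 64-bit value this is equivalent to dropping the 4 least-significant bits and masking the result to 14 bits, i.e. the BaseAddr field of the descriptor. A small standalone check, illustrative only and not part of the patch:

```
#include <cassert>
#include <cstdint>

int main() {
  // (addr << 46) >> 50 keeps address bits [4,18) in positions [0,14),
  // which is the same as (addr >> 4) & 0x3FFF.
  const uint64_t addrs[] = {0x0, 0x4000, 0xDEADBEEF, ~uint64_t(0)};
  for (uint64_t addr : addrs) {
    uint64_t viaShifts = (addr << 46) >> 50;
    uint64_t viaMask = (addr >> 4) & 0x3FFF;
    assert(viaShifts == viaMask);
  }
  return 0;
}
```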
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -366,6 +366,42 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_GenerateGmmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GenerateGmmaDescriptorOp::verify() {
+  MemRefType memrefType = getTensor().getType();
+  MemRefType tensorMapType = getTensorMap().getType().getTensor();
+
+  if (memrefType != tensorMapType)
+    return emitError() << "memref and tensor map type mismatch";
+
+  if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
+    return emitError() << "supports only static shapes";
+
+  if (memrefType.getRank() != 2)
+    return emitError() << "only 2d memref is supported for now";
+
+  if (getTensorMap().getType().getSwizzle() !=
+      TensorMapSwizzleKind::SWIZZLE_128B) {
+    return emitError() << "only "
+                       << stringifyTensorMapSwizzleKind(
+                              TensorMapSwizzleKind::SWIZZLE_128B)
+                       << " is supported for the time being";
+  }
+
+  if (getTensorMap().getType().getInterleave() !=
+      TensorMapInterleaveKind::INTERLEAVE_NONE) {
+    return emitError() << "only "
+                       << stringifyTensorMapInterleaveKind(
+                              TensorMapInterleaveKind::INTERLEAVE_NONE)
+                       << " is supported for the time being";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -631,6 +631,45 @@
   }
 }
 
+!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16,3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
+memref.global "private" @dynamicShmem : memref<0xf16,3>
+// CHECK-LABEL: func @create_wgmma_descriptor(
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64 {
+  %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+  // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  // CHECK: %[[Sreinterpret_cast:.+]] = memref.reinterpret_cast %[[S0]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
+  // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[Sreinterpret_cast]] : memref<128x64xf16, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[S2:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[S3:.+]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: %[[S4:.+]] = llvm.mlir.constant(62 : i64) : i64
+  // CHECK: %[[S5:.+]] = llvm.shl %[[S3]], %[[S4]] : i64
+  // CHECK: %[[S6:.+]] = llvm.or %[[S2]], %[[S5]] : i64
+  // CHECK: %[[S7:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[S8:.+]] = llvm.mlir.constant(49 : i64) : i64
+  // CHECK: %[[S9:.+]] = llvm.shl %[[S7]], %[[S8]] : i64
+  // CHECK: %[[S10:.+]] = llvm.or %[[S6]], %[[S9]] : i64
+  // CHECK: %[[S11:.+]] = llvm.mlir.constant(64 : i64) : i64
+  // CHECK: %[[S12:.+]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK: %[[S13:.+]] = llvm.shl %[[S11]], %[[S12]] : i64
+  // CHECK: %[[S14:.+]] = llvm.or %[[S10]], %[[S13]] : i64
+  // CHECK: %[[S15:.+]] = llvm.mlir.constant(1024 : i64) : i64
+  // CHECK: %[[S16:.+]] = llvm.mlir.constant(16 : i64) : i64
+  // CHECK: %[[S17:.+]] = llvm.shl %[[S15]], %[[S16]] : i64
+  // CHECK: %[[S18:.+]] = llvm.or %[[S14]], %[[S17]] : i64
+  // CHECK: %[[S19:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[S20:.+]] = llvm.ptrtoint %[[S19]] : !llvm.ptr<3> to i64
+  // CHECK: %[[S21:.+]] = llvm.mlir.constant(46 : i64) : i64
+  // CHECK: %[[S22:.+]] = llvm.shl %[[S20]], %[[S21]] : i64
+  // CHECK: %[[S23:.+]] = llvm.mlir.constant(50 : i64) : i64
+  // CHECK: %[[S24:.+]] = llvm.lshr %[[S22]], %[[S23]] : i64
+  // CHECK: %[[S25:.+]] = llvm.or %[[S18]], %[[S24]] : i64
+  // CHECK: return %[[S25]] : i64
+  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap
+  func.return %descA : i64
+}
+
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1