diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -625,4 +625,35 @@
   let hasVerifier = 1;
 }
 
+def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
+  let summary = "Generate a wgmma matrix descriptor";
+  let description = [{
+    This Op builds a wgmma descriptor that is used by the wgmma matrix multiply
+    and accumulate operation.
+
+    The descriptor specifies the properties of the matrix in shared memory that
+    is a multiplicand in the matrix multiply and accumulate operation.
+
+    The descriptor is a 64-bit value contained in a register with the following layout:
+    ```
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |  14bits |2bits|   14bits  |2bits|   14bits  |3bits|3bits|   10bits  |2bits|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    ```
+
+    For more details, see:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
+
+  }];
+  let results = (outs I64:$descriptor);
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$tensor,
+                       NVGPU_TensorMapDescriptor:$tensorMap);
+  let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap)}];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
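Note (reviewer aside, not part of the patch): the bit layout documented above can be checked with a small standalone sketch. `packGmmaDescriptor` and its parameters are hypothetical names used only for illustration; the sample values (leading dimension 1024, stride 64, swizzle mode 1) match what the new test below produces for a 128x64 f16 tile with 128B swizzle.

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative only: packs the fields from the table in the op description.
// Field widths are 14, 14, 14, 3, and 2 bits at offsets 0, 16, 32, 49, 62.
uint64_t packGmmaDescriptor(uint64_t baseAddr14, uint64_t leadDim14,
                            uint64_t stride14, uint64_t baseOffset3,
                            uint64_t swizzle2) {
  uint64_t desc = 0;
  desc |= (baseAddr14 & 0x3FFF);       // bits [0,14)  start address (addr >> 4)
  desc |= (leadDim14 & 0x3FFF) << 16;  // bits [16,30) leading dimension
  desc |= (stride14 & 0x3FFF) << 32;   // bits [32,46) stride
  desc |= (baseOffset3 & 0x7) << 49;   // bits [49,52) base offset
  desc |= (swizzle2 & 0x3) << 62;      // bits [62,64) swizzle mode
  return desc;
}

int main() {
  // 128B swizzle (mode 1), stride 64, leading dimension 1024, zero offset/address.
  std::printf("0x%016llx\n",
              (unsigned long long)packGmmaDescriptor(0, 1024, 64, 0, 1));
  return 0;
}
```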
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -934,6 +934,76 @@
     return success();
   }
 };
+struct NVGPUGenerateGmmaDescriptorLowering
+    : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+
+    nvgpu::TensorMapSwizzleKind swizzleKind =
+        op.getTensorMap().getType().getSwizzle();
+
+    unsigned layout =
+        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
+                                                                    : 1;
+    unsigned swizzle =
+        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 1
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
+                                                                    : 0;
+
+    auto ti64 = rewriter.getIntegerType(64);
+    auto makeConst = [&](uint64_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, ti64, rewriter.getI64IntegerAttr(index));
+    };
+    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto shiftRight = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto insertBit = [&](Value desc, Value val, int startBit) {
+      return rewriter.create<LLVM::OrOp>(loc, ti64, desc,
+                                         shiftLeft(val, startBit));
+    };
+
+    int ex4LSB = 4;
+    Value strideDim = makeConst((layout << 3) >> ex4LSB);
+    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
+    Value leadDim = makeConst((sizeN * layout) >> ex4LSB);
+    Value baseAddr = getStridedElementPtr(
+        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
+        adaptor.getTensor(), {}, rewriter);
+    Value basePtr = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, baseAddr);
+    // Keep only 14 bits of the base address (the descriptor stores addr >> 4).
+    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
+
+    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
+        startLeadBit = 16, startBaseAddrBit = 0;
+    Value dsc = makeConst(0);
+    // [62,64) swizzle type
+    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
+    // [49,52) base_offset
+    dsc = insertBit(dsc, makeConst(0), startOffsetBit);
+    // [32,46) stride
+    dsc = insertBit(dsc, strideDim, startStrideBit);
+    // [16,30) leading dimension
+    dsc = insertBit(dsc, leadDim, startLeadBit);
+    // [0,14) start_address
+    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
+
+    rewriter.replaceOp(op, dsc);
+    return success();
+  }
+};
 
 static Value makeI64Const(RewriterBase &rewriter, Operation *op,
                           int32_t index) {
@@ -1064,6 +1134,7 @@
       NVGPUTmaCreateDescriptorOpLowering,   // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering,  // nvgpu.mbarrier.arrive.expect_tx
       NVGPUTmaAsyncLoadOpLowering,          // nvgpu.tma.async.load
+      NVGPUGenerateGmmaDescriptorLowering,  // nvgpu.wgmma.generate.descriptor
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
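Note (reviewer aside, not part of the patch): for the memref<128x64xf16, 3> tile with swizzle_128b exercised by the test added below, the lowering's arithmetic gives a stride field of (128 << 3) >> 4 = 64 and a leading-dimension field of (128 * 128) >> 4 = 1024, which is where the `c64` and `c1024` constants in the CHECK lines come from. A minimal compile-time check of that arithmetic, with the values assumed from the test:

```cpp
#include <cstdint>

// Mirrors the field computation in NVGPUGenerateGmmaDescriptorLowering for the
// swizzle_128b case; `layout` and `sizeN` are assumed from the test below.
constexpr unsigned layout = 128;  // SWIZZLE_128B
constexpr int64_t sizeN = 128;    // dim 0 of memref<128x64xf16, 3>

static_assert(((layout << 3) >> 4) == 64, "stride dimension field");
static_assert(((sizeN * layout) >> 4) == 1024, "leading dimension field");

int main() { return 0; }
```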
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -366,6 +366,42 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_GenerateGmmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GenerateGmmaDescriptorOp::verify() {
+  MemRefType memrefType = getTensor().getType();
+  MemRefType tensorMapType = getTensorMap().getType().getTensor();
+
+  if (memrefType != tensorMapType)
+    return emitError() << "memref and tensor map type mismatch";
+
+  if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
+    return emitError() << "supports only static shapes";
+
+  if (memrefType.getRank() != 2)
+    return emitError() << "only 2d memref is supported for now";
+
+  if (getTensorMap().getType().getSwizzle() !=
+      TensorMapSwizzleKind::SWIZZLE_128B) {
+    return emitError() << "only "
+                       << stringifyTensorMapSwizzleKind(
+                              TensorMapSwizzleKind::SWIZZLE_128B)
+                       << " swizzle is supported for the time being";
+  }
+
+  if (getTensorMap().getType().getInterleave() !=
+      TensorMapInterleaveKind::INTERLEAVE_NONE) {
+    return emitError() << "only "
+                       << stringifyTensorMapInterleaveKind(
+                              TensorMapInterleaveKind::INTERLEAVE_NONE)
+                       << " interleave is supported for the time being";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -631,6 +631,47 @@
   }
 }
 
+!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
+memref.global "private" @dynamicShmem : memref<0xf16,3>
+// CHECK-LABEL: func @create_wgmma_descriptor(
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64 {
+  %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+  // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  // CHECK: %[[Sre:.+]] = memref.reinterpret_cast %[[S0]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
+  // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[Sre]] : memref<128x64xf16, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[c64:.+]] = llvm.mlir.constant(64 : i64) : i64
+  // CHECK: %[[c1024:.+]] = llvm.mlir.constant(1024 : i64) : i64
+  // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[S3:.+]] = llvm.ptrtoint %[[S2]] : !llvm.ptr<3> to i64
+  // CHECK: %[[S4:.+]] = llvm.mlir.constant(46 : i64) : i64
+  // CHECK: %[[S5:.+]] = llvm.shl %[[S3]], %[[S4]] : i64
+  // CHECK: %[[S6:.+]] = llvm.mlir.constant(50 : i64) : i64
+  // CHECK: %[[S7:.+]] = llvm.lshr %[[S5]], %[[S6]] : i64
+  // CHECK: %[[S8:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[S9:.+]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: %[[S10:.+]] = llvm.mlir.constant(62 : i64) : i64
+  // CHECK: %[[S11:.+]] = llvm.shl %[[S9]], %[[S10]] : i64
+  // CHECK: %[[S12:.+]] = llvm.or %[[S8]], %[[S11]] : i64
+  // CHECK: %[[S13:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[S14:.+]] = llvm.mlir.constant(49 : i64) : i64
+  // CHECK: %[[S15:.+]] = llvm.shl %[[S13]], %[[S14]] : i64
+  // CHECK: %[[S16:.+]] = llvm.or %[[S12]], %[[S15]] : i64
+  // CHECK: %[[S18:.+]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK: %[[S19:.+]] = llvm.shl %[[c64]], %[[S18]] : i64
+  // CHECK: %[[S20:.+]] = llvm.or %[[S16]], %[[S19]] : i64
+  // CHECK: %[[S22:.+]] = llvm.mlir.constant(16 : i64) : i64
+  // CHECK: %[[S23:.+]] = llvm.shl %[[c1024]], %[[S22]] : i64
+  // CHECK: %[[S24:.+]] = llvm.or %[[S20]], %[[S23]] : i64
+  // CHECK: %[[S25:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[S26:.+]] = llvm.shl %[[S7]], %[[S25]] : i64
+  // CHECK: %[[S27:.+]] = llvm.or %[[S24]], %[[S26]] : i64
+  // CHECK: return %[[S27]] : i64
+  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap
+  func.return %descA : i64
+}
+
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1
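Note (reviewer aside, not part of the patch): the `llvm.shl` by 46 followed by `llvm.lshr` by 50 in the CHECK lines above keeps only bits [4,18) of the shared-memory address, so the 14-bit start-address field holds the address shifted right by 4, as the PTX matrix-descriptor encoding expects. A small sanity check of that identity, assuming a 64-bit unsigned address:

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  for (uint64_t addr : {0x0ull, 0x4000ull, 0xdeadbeefull, ~0ull}) {
    // (addr << 46) >> 50 extracts bits [4,18), i.e. (addr >> 4) & 0x3FFF.
    assert(((addr << 46) >> 50) == ((addr >> 4) & 0x3FFF));
  }
  return 0;
}
```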