diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -625,4 +625,35 @@
   let hasVerifier = 1;
 }
 
+def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
+  let summary = "Generate a wgmma matrix descriptor";
+  let description = [{
+    This Op builds a wgmma descriptor that is used by the wgmma matrix multiply
+    and accumulate operation.
+
+    The descriptor specifies the properties of the matrix in shared memory that
+    is a multiplicand in the matrix multiply and accumulate operation.
+
+    The descriptor is a 64-bit value contained in a register with the following
+    layout:
+    ```
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |  14bits |2bits|   14bits  |2bits|   14bits  |3bits|3bits|   10bits  |2bits|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     |Swzle|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    ```
+
+    For more details, see:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
+  }];
+  let results = (outs I64:$descriptor);
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$tensor,
+                       NVGPU_TensorMapDescriptor:$tensorMap);
+  let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap)}];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -934,6 +934,79 @@
     return success();
   }
 };
+struct NVGPUGenerateGmmaDescriptorLowering
+    : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    constexpr int startLayoutBit = 62;
+    constexpr int startOffsetBit = 49;
+    constexpr int startStrideBit = 32;
+    constexpr int startLeadingDimBit = 16;
+    constexpr int startBaseAddrBit = 50;
+
+    Location loc = op->getLoc();
+
+    nvgpu::TensorMapSwizzleKind swizzleKind =
+        op.getTensorMap().getType().getSwizzle();
+
+    unsigned swizzle =
+        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B) ? 128
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
+        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
+                                                                    : 1;
+
+    auto ti64 = rewriter.getIntegerType(64);
+    auto makeConst = [&](uint64_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, ti64, rewriter.getI64IntegerAttr(index));
+    };
+    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto shiftRight = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto makeOr = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::OrOp>(loc, ti64, lhs, rhs);
+    };
+    auto exclude4LSB = [&](Value value, unsigned startBit) -> Value {
+      return shiftRight(shiftLeft(value, (startBit - 4)), startBit);
+    };
+
+    Value desc = makeConst(0);
+    // [62,64) layout type
+    // 6 bits unused, 2 bits [6,8)
+    desc = makeOr(desc, shiftLeft(makeConst(uint64_t(1)), startLayoutBit));
+    // [49,52) base_offset
+    // 1 bit unused, 3 bits [1,4), 4 bits unused
+    // Valid only for SWIZZLE_128B and SWIZZLE_64B
+    desc = makeOr(desc, shiftLeft(makeConst(0), startOffsetBit));
+    // [32,46) stride
+    // 14 bits [0,14), 2 bits unused
+    Value strideDim = shiftRight(shiftLeft(makeConst(swizzle), 3), 4);
+    // Exclude 4LSB
+    desc = makeOr(desc, exclude4LSB(strideDim, startStrideBit));
+    // [16,30) leading dimension
+    // 14 bits [0,14), 2 bits unused
+    // Not used with swizzling. Exclude 4LSB
+    desc = makeOr(desc, exclude4LSB(makeConst(1), startLeadingDimBit));
+    // [0,14) start_address
+    // 14 bits [0,14), 2 bits unused
+    Value basePtr = rewriter.create<LLVM::ExtractValueOp>(
+        op->getLoc(), adaptor.getTensor(), 1);
+    Value ptri64 = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, basePtr);
+    // Exclude 4LSB
+    Value startAddress = exclude4LSB(ptri64, startBaseAddrBit);
+    desc = makeOr(desc, startAddress);
+
+    rewriter.replaceOp(op, desc);
+    return success();
+  }
+};
 static Value makeI64Const(RewriterBase &rewriter, Operation *op, int32_t index) {
@@ -1064,6 +1137,7 @@
       NVGPUTmaCreateDescriptorOpLowering,  // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx
       NVGPUTmaAsyncLoadOpLowering,         // nvgpu.tma.async.load
+      NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -366,6 +366,42 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_GenerateGmmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GenerateGmmaDescriptorOp::verify() {
+  MemRefType memrefType = getTensor().getType();
+  MemRefType tensorMapType = getTensorMap().getType().getTensor();
+
+  if (memrefType != tensorMapType)
+    return emitError() << "memref and tensor map type mismatch";
+
+  if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
+    return emitError() << "supports only static shapes";
+
+  if (memrefType.getRank() != 2)
+    return emitError() << "only 2d memref is supported for now";
+
+  if (getTensorMap().getType().getSwizzle() !=
+      TensorMapSwizzleKind::SWIZZLE_128B) {
+    return emitError() << "only "
+                       << stringifyTensorMapSwizzleKind(
+                              TensorMapSwizzleKind::SWIZZLE_128B)
+                       << " is supported for the time being";
+  }
+
+  if (getTensorMap().getType().getInterleave() !=
+      TensorMapInterleaveKind::INTERLEAVE_NONE) {
+    return emitError() << "only "
+                       << stringifyTensorMapInterleaveKind(
+                              TensorMapInterleaveKind::INTERLEAVE_NONE)
+                       << " is supported for the time being";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -convert-nvgpu-to-nvvm='use-opaque-pointers=1' | FileCheck %s
 // RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s
+// RUN: mlir-opt --convert-nvgpu-to-nvvm='use-opaque-pointers=1' --split-input-file -cse -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: @m16n8k16_fp16
 func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
@@ -641,4 +642,48 @@
     transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter {use_opaque_pointers = true}
   } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op
+
+// -----
+
+!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+memref.global "private" @dynamicShmem : memref<0xf16, 3>
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64 {
+  %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
+  // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  // CHECK: %[[R0:.+]] = memref.reinterpret_cast %[[S0]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
+  // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[R0]] : memref<128x64xf16, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[PTR:.+]] = llvm.ptrtoint %[[S2]] : !llvm.ptr<3> to i64
+  // CHECK: %[[S10:.+]] = llvm.mlir.constant(46 : i64) : i64
+  // CHECK: %[[S11:.+]] = llvm.shl %[[PTR]], %[[S10]] : i64
+  // CHECK: %[[S12:.+]] = llvm.mlir.constant(50 : i64) : i64
+  // CHECK: %[[S13:.+]] = llvm.lshr %[[S11]], %[[S12]] : i64
+  // CHECK: %[[S14:.+]] = llvm.mlir.constant(128 : i64) : i64
+  // CHECK: %[[S15:.+]] = llvm.mlir.constant(3 : i64) : i64
+  // CHECK: %[[S16:.+]] = llvm.shl %[[S14]], %[[S15]] : i64
+  // CHECK: %[[S17:.+]] = llvm.mlir.constant(4 : i64) : i64
+  // CHECK: %[[S18:.+]] = llvm.lshr %[[S16]], %[[S17]] : i64
+  // CHECK: %[[S19:.+]] = llvm.mlir.constant(16384 : i64) : i64
+  // CHECK: %[[S20:.+]] = llvm.mlir.constant(4 : i64) : i64
+  // CHECK: %[[S21:.+]] = llvm.lshr %[[S19]], %[[S20]] : i64
+  // CHECK: %[[S22:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[S23:.+]] = llvm.mlir.constant(3 : i64) : i64
+  // CHECK: %[[S24:.+]] = llvm.mlir.constant(62 : i64) : i64
+  // CHECK: %[[S25:.+]] = llvm.shl %[[S23]], %[[S24]] : i64
+  // CHECK: %[[S26:.+]] = llvm.or %[[S22]], %[[S25]] : i64
+  // CHECK: %[[S27:.+]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK: %[[S28:.+]] = llvm.shl %[[S18]], %[[S27]] : i64
+  // CHECK: %[[S29:.+]] = llvm.or %[[S26]], %[[S28]] : i64
+  // CHECK: %[[S30:.+]] = llvm.mlir.constant(16 : i64) : i64
+  // CHECK: %[[S31:.+]] = llvm.shl %[[S21]], %[[S30]] : i64
+  // CHECK: %[[S32:.+]] = llvm.or %[[S29]], %[[S31]] : i64
+  // CHECK: %[[S33:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[S34:.+]] = llvm.mlir.constant(49 : i64) : i64
+  // CHECK: %[[S35:.+]] = llvm.shl %[[S33]], %[[S34]] : i64
+  // CHECK: %[[S36:.+]] = llvm.or %[[S32]], %[[S35]] : i64
+  // CHECK: %[[DESC:.+]] = llvm.or %[[S36]], %[[S13]] : i64
+  // CHECK: return %[[DESC]] : i64
+  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16, 3>, !tensorMap
+  func.return %descA : i64
+}
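
For reference, below is a minimal standalone sketch (not part of the patch) of the bit packing that the op description documents. The helper name `packWgmmaDescriptor` and the sample field values (base address, stride and leading-dimension byte offsets, base offset, swizzle-mode code) are illustrative assumptions, not values produced by the lowering above:

```cpp
#include <cstdint>
#include <cstdio>

// Packs a 64-bit wgmma shared-memory matrix descriptor following the layout
// documented in the op description: bits [0,14) base address, [16,30) leading
// dimension, [32,46) stride (each stored with the 4 LSBs dropped), [49,52)
// base offset, and [62,64) swizzle mode.
static uint64_t packWgmmaDescriptor(uint64_t baseAddr, uint64_t leadingDimBytes,
                                    uint64_t strideBytes, uint64_t baseOffset,
                                    uint64_t swizzleMode) {
  // Drop the 4 least significant bits and keep a 14-bit field.
  auto field14 = [](uint64_t v) { return (v >> 4) & 0x3FFFu; };
  uint64_t desc = 0;
  desc |= field14(baseAddr);               // [0,14)  start address
  desc |= field14(leadingDimBytes) << 16;  // [16,30) leading dimension
  desc |= field14(strideBytes) << 32;      // [32,46) stride
  desc |= (baseOffset & 0x7u) << 49;       // [49,52) base offset
  desc |= (swizzleMode & 0x3u) << 62;      // [62,64) swizzle mode
  return desc;
}

int main() {
  // Illustrative values only; they are not derived from the patch.
  uint64_t desc = packWgmmaDescriptor(/*baseAddr=*/0x4000,
                                      /*leadingDimBytes=*/1024,
                                      /*strideBytes=*/1024,
                                      /*baseOffset=*/0,
                                      /*swizzleMode=*/1);
  std::printf("descriptor = 0x%016llx\n",
              static_cast<unsigned long long>(desc));
  return 0;
}
```

In the actual lowering, the base address comes from the converted memref, and the stride and swizzle mode are derived from the tensor map's swizzle kind; the sketch only illustrates where each field lands in the 64-bit word.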