diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -625,4 +625,35 @@
   let hasVerifier = 1;
 }
 
+def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
+  let summary = "Generate a wgmma matrix descriptor";
+  let description = [{
+    This Op builds a wgmma descriptor that is used by the wgmma matrix
+    multiply and accumulate operation.
+
+    The descriptor specifies the properties of the matrix in shared memory that
+    is a multiplicand in the matrix multiply and accumulate operation.
+
+    The descriptor is a 64-bit value contained in a register with the following layout:
+    ```
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |   0-13  |14-15|   16-29   |30-31|   32-45   |46-48|49-51|   52-61   |62-63|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    |  14bits |2bits|   14bits  |2bits|   14bits  |3bits|3bits|   10bits  |2bits|
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    | BaseAddr|  0  | LeadingDim|  0  |   Stride  |  0  |Offst|     0     | Mod |
+    +---------+-----+-----------+-----+-----------+-----+-----+-----------+-----+
+    ```
+
+    For more details, see:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
+
+  }];
+  let results = (outs I64:$descriptor);
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$tensor,
+                       NVGPU_TensorMapDescriptor:$tensorMap);
+  let assemblyFormat = [{$tensor `,` $tensorMap attr-dict `:` type($tensor) `,` type($tensorMap)}];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -929,6 +929,73 @@
     return success();
   }
 };
+struct NVGPUGenerateGmmaDescriptorLowering
+    : public ConvertOpToLLVMPattern<nvgpu::GenerateGmmaDescriptorOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::GenerateGmmaDescriptorOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    nvgpu::TensorMapSwizzleKind shmemLayout =
+        op.getTensorMap().getType().getSwizzle();
+
+    MemRefType memrefType = op.getTensor().getType();
+    int64_t sizeH = memrefType.getDimSize(0);
+    // Note: currently only 128b swizzling is supported. Also not sure what to
+    // put without swizzling.
+    unsigned swizzleSize =
+        (shmemLayout == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)   ? 128
+        : (shmemLayout == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B)  ? 64
+        : (shmemLayout == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B)  ? 32
+                                                                      : 1;
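+
+    // Note: the wgmma descriptor encodes the start address and the leading
+    // dimension / stride byte offsets divided by 16, i.e. the PTX
+    // matrix-descriptor-encode(x) = (x & 0x3FFFF) >> 4; the shift arithmetic
+    // below is assumed to follow that encoding.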
+
+    SmallVector<Value> extractedValue = getTypeConverter()->promoteOperands(
+        op->getLoc(), {op.getTensor()}, {adaptor.getTensor()}, rewriter);
+    if (extractedValue.empty())
+      return failure();
+
+    auto ti64 = rewriter.getIntegerType(64);
+    auto makeConst = [&](int32_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, ti64, rewriter.getI64IntegerAttr(index));
+    };
+    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto shiftRight = [&](Value value, unsigned shift) -> Value {
+      return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
+    };
+    auto makeOr = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::OrOp>(loc, ti64, lhs, rhs);
+    };
+
+    // Extract the 14-bit start address field: (ptr >> 4) & 0x3FFF.
+    Value basePtr = extractedValue.front();
+    Value ptri64 = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, basePtr);
+    Value startAddress = shiftRight(shiftLeft(ptri64, 46), 50);
+    // Note: not sure this calculation is always correct.
+    Value strideDim = shiftRight(shiftLeft(makeConst(swizzleSize), 3), 4);
+    Value leadingDim = shiftRight(makeConst(swizzleSize * sizeH), 4);
+
+    Value descZero = makeConst(0);
+    // [62,64) layout type
+    Value desc1 = makeOr(descZero, shiftLeft(makeConst(int(shmemLayout)), 62));
+    // [32,46) stride
+    Value desc2 = makeOr(desc1, shiftLeft(strideDim, 32));
+    // [16,30) leading dimension
+    Value desc3 = makeOr(desc2, shiftLeft(leadingDim, 16));
+    // [49,52) base_offset
+    Value desc4 = makeOr(desc3, shiftLeft(makeConst(0), 49));
+    // [0,14) start_address
+    Value desc = makeOr(desc4, startAddress);
+
+    rewriter.replaceOp(op, desc);
+    return success();
+  }
+};
 
 static Value makeI64Const(RewriterBase &rewriter, Operation *op,
                           int32_t index) {
@@ -1059,6 +1126,7 @@
       NVGPUTmaCreateDescriptorOpLowering,     // nvgpu.tma.create.descriptor
       NVGPUMBarrierArriveExpectTxLowering,    // nvgpu.mbarrier.arrive.expect_tx
       NVGPUTmaAsyncLoadOpLowering,            // nvgpu.tma.async.load
+      NVGPUGenerateGmmaDescriptorLowering,    // nvgpu.wgmma.generate.descriptor
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -366,6 +366,42 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_GenerateGmmaDescriptorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GenerateGmmaDescriptorOp::verify() {
+  MemRefType memrefType = getTensor().getType();
+  MemRefType tensorMapType = getTensorMap().getType().getTensor();
+
+  if (memrefType != tensorMapType)
+    return emitError() << "memref and tensor map type mismatch";
+
+  if (!memrefType.hasStaticShape() || !tensorMapType.hasStaticShape())
+    return emitError() << "supports only static shapes";
+
+  if (memrefType.getRank() != 2)
+    return emitError() << "only 2d memref is supported for now";
+
+  if (getTensorMap().getType().getSwizzle() !=
+      TensorMapSwizzleKind::SWIZZLE_128B) {
+    return emitError() << "only "
+                       << stringifyTensorMapSwizzleKind(
+                              TensorMapSwizzleKind::SWIZZLE_128B)
+                       << " is supported for the time being";
+  }
+
+  if (getTensorMap().getType().getInterleave() !=
+      TensorMapInterleaveKind::INTERLEAVE_NONE) {
+    return emitError() << "only "
+                       << stringifyTensorMapInterleaveKind(
+                              TensorMapInterleaveKind::INTERLEAVE_NONE)
+                       << " is supported for the time being";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --convert-nvgpu-to-nvvm='use-opaque-pointers=1' --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-nvgpu-to-nvvm='use-opaque-pointers=1' --split-input-file -cse -canonicalize %s | FileCheck %s
 
 // CHECK-LABEL: @m16n8k16_fp16
 func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
@@ -647,3 +647,48 @@
   %tensorMap1d = nvgpu.tma.create.descriptor %devicePtr1d_unranked box[%crd1] : memref<*xf32> -> !tensorMap1d
   func.return
 }
+
+// -----
+
+!tensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
+memref.global "private" @dynamicShmem : memref<0xf16,3>
+func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> i64 {
+  %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+  // CHECK: %[[S0:.+]] = memref.get_global @dynamicShmem : memref<0xf16, 3>
+  // CHECK: %[[R0:.+]] = memref.reinterpret_cast %[[S0]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to memref<128x64xf16, 3>
+  // CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[R0]] : memref<128x64xf16, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[PTR:.+]] = llvm.ptrtoint %[[S2]] : !llvm.ptr<3> to i64
+  // CHECK: %[[S10:.+]] = llvm.mlir.constant(46 : i64) : i64
+  // CHECK: %[[S11:.+]] = llvm.shl %[[PTR]], %[[S10]] : i64
+  // CHECK: %[[S12:.+]] = llvm.mlir.constant(50 : i64) : i64
+  // CHECK: %[[S13:.+]] = llvm.lshr %[[S11]], %[[S12]] : i64
+  // CHECK: %[[S14:.+]] = llvm.mlir.constant(128 : i64) : i64
+  // CHECK: %[[S15:.+]] = llvm.mlir.constant(3 : i64) : i64
+  // CHECK: %[[S16:.+]] = llvm.shl %[[S14]], %[[S15]] : i64
+  // CHECK: %[[S17:.+]] = llvm.mlir.constant(4 : i64) : i64
+  // CHECK: %[[S18:.+]] = llvm.lshr %[[S16]], %[[S17]] : i64
+  // CHECK: %[[S19:.+]] = llvm.mlir.constant(16384 : i64) : i64
+  // CHECK: %[[S20:.+]] = llvm.mlir.constant(4 : i64) : i64
+  // CHECK: %[[S21:.+]] = llvm.lshr %[[S19]], %[[S20]] : i64
+  // CHECK: %[[S22:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[S23:.+]] = llvm.mlir.constant(3 : i64) : i64
+  // CHECK: %[[S24:.+]] = llvm.mlir.constant(62 : i64) : i64
+  // CHECK: %[[S25:.+]] = llvm.shl %[[S23]], %[[S24]] : i64
+  // CHECK: %[[S26:.+]] = llvm.or %[[S22]], %[[S25]] : i64
+  // CHECK: %[[S27:.+]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK: %[[S28:.+]] = llvm.shl %[[S18]], %[[S27]] : i64
+  // CHECK: %[[S29:.+]] = llvm.or %[[S26]], %[[S28]] : i64
+  // CHECK: %[[S30:.+]] = llvm.mlir.constant(16 : i64) : i64
+  // CHECK: %[[S31:.+]] = llvm.shl %[[S21]], %[[S30]] : i64
+  // CHECK: %[[S32:.+]] = llvm.or %[[S29]], %[[S31]] : i64
+  // CHECK: %[[S33:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[S34:.+]] = llvm.mlir.constant(49 : i64) : i64
+  // CHECK: %[[S35:.+]] = llvm.shl %[[S33]], %[[S34]] : i64
+  // CHECK: %[[S36:.+]] = llvm.or %[[S32]], %[[S35]] : i64
+  // CHECK: %[[DESC:.+]] = llvm.or %[[S36]], %[[S13]] : i64
+  // CHECK: return %[[DESC]] : i64
+  %descA = nvgpu.wgmma.generate.descriptor %lhsShmem, %tensorMap : memref<128x64xf16,3>, !tensorMap
+  func.return %descA : i64
+}
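+
+// Note (informational, not FileCheck-verified): for this 128x64xf16 tile with
+// 128B swizzle, the packed fields are expected to be leading_dim =
+// (128 * 128) >> 4 = 1024, stride = (128 << 3) >> 4 = 64, base_offset = 0, and
+// the layout field in bits [62,64) holds the raw nvgpu swizzle enum value (3).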