diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -100,6 +100,17 @@
 def NVGPU_MBarrierToken : NVGPU_Type<"MBarrierToken", "mbarrier.token", []> {
 }
 
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-map
+def NVGPU_TmaDescriptor : NVGPU_Type<"TmaDescriptor", "tma.descriptor", []> {
+  let summary = "TMA descriptor";
+  let parameters = (ins "MemRefType":$tensor);
+  let description = [{
+    `nvgpu.tma.descriptor` is a type that represents a TMA descriptor. It is a
+    128-byte object that lives either in constant space or in a kernel parameter.
+  }];
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -469,4 +480,21 @@
   let assemblyFormat = "$barrier `,` $count attr-dict `:` type($barrier) `->` type($token)";
 }
 
+
+def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", []> {
+  let summary = "TMA asynchronous load";
+  let description = [{
+    Loads the tile described by the given TMA descriptor from global memory to shared memory asynchronously.
+  }];
+  let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$dst,
+                       NVGPU_MBarrier:$barrier,
+                       NVGPU_TmaDescriptor:$tmaDescriptor,
+                       Variadic<Index>:$coordinates);
+  let assemblyFormat = [{
+    $tmaDescriptor `[` $coordinates `]` `,` $barrier `to` $dst
+      attr-dict `:` type($tmaDescriptor) `,` type($barrier) `,` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -402,6 +402,9 @@
   converter.addConversion([&](nvgpu::MBarrierType type) -> Type {
     return converter.convertType(createMBarrierMemrefType(rewriter, type));
   });
+  converter.addConversion([&](nvgpu::TmaDescriptorType type) -> Type {
+    return converter.getPointerType(type.getTensor().getElementType());
+  });
   populateNVGPUToNVVMConversionPatterns(converter, patterns);
   LLVMConversionTarget target(getContext());
   target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
@@ -850,6 +853,29 @@
   }
 };
 
+struct TmaAsyncLoadOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::TmaAsyncLoadOp> {
+  using ConvertOpToLLVMPattern<nvgpu::TmaAsyncLoadOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dest = rewriter.create<LLVM::ExtractValueOp>(op->getLoc(),
+                                                      adaptor.getDst(), 1);
+    Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
+                                   op.getBarrier(), adaptor.getBarrier());
+
+    SmallVector<Value> coords = adaptor.getCoordinates();
+    for (auto [index, value] : llvm::enumerate(coords)) {
+      if (!value.getType().isInteger(32)) {
+        coords[index] = rewriter.create<LLVM::TruncOp>(
+            op->getLoc(), rewriter.getI32Type(), value);
+      }
+    }
+    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
+        op, dest, adaptor.getTmaDescriptor(), barrier, coords);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -860,6 +886,7 @@
       NVGPUMBarrierArriveLowering,           // nvgpu.mbarrier.arrive
       NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete
       NVGPUMBarrierTestWaitLowering,         // nvgpu.try_wait_parity
+      TmaAsyncLoadOpLowering,                // nvgpu.tma.async.load
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering,
       NVGPUAsyncWaitLowering, NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -317,6 +317,31 @@
   return success();
 }
+//===----------------------------------------------------------------------===//
+// NVGPU_TmaAsyncLoadOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TmaAsyncLoadOp::verify() {
+  // Destination memref
+  auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
+  if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) {
+    return emitError()
+           << "The operation stores data to shared memory, but "
+              "the destination memref does not have a memory space of "
+           << NVGPUDialect::kSharedMemoryAddressSpace;
+  }
+  if (getCoordinates().size() > 5) {
+    return emitError() << "Maximum 5 coordinates are supported.";
+  }
+  if (getCoordinates().size() != size_t(dstMemref.getRank())) {
+    return emitError() << "Destination memref rank is "
+                       << size_t(dstMemref.getRank()) << " but there are "
+                       << getCoordinates().size()
+                       << " coordinates. They must match.";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -558,3 +558,17 @@
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @async_tma_load
+
+func.func @async_tma_load(%tmaDescriptor: !nvgpu.tma.descriptor<tensor = memref<128x128xf32, 3>>,
+                          %buffer: memref<128x128xf32, 3>,
+                          %mbarrier : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>) {
+  %crd0 = arith.constant 256 : index
+  %crd1 = arith.constant 256 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %{{.*}}, %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
+  nvgpu.tma.async.load %tmaDescriptor[%crd0, %crd1], %mbarrier to %buffer : !nvgpu.tma.descriptor<tensor = memref<128x128xf32, 3>>, !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>, memref<128x128xf32, 3>
+  func.return
+}
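
Note (not part of the patch): a rough usage sketch of the new op. Only nvgpu.tma.async.load is introduced here; the nvgpu.mbarrier.create form, the mbarrier type parameters, and the function name below are assumed from the existing mbarrier support referenced in the lowering table above, and completion would be observed through the existing mbarrier wait ops (e.g. nvgpu.try_wait_parity) before the shared-memory buffer is read.

  // Hypothetical usage sketch; the TMA descriptor arrives as a kernel argument
  // and the destination is a shared-memory (space 3) buffer.
  func.func @tma_load_sketch(
      %desc: !nvgpu.tma.descriptor<tensor = memref<128x128xf32, 3>>,
      %dst: memref<128x128xf32, 3>) {
    // Barrier that the TMA unit signals once the copy has landed in shared memory.
    %bar = nvgpu.mbarrier.create -> !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
    %c0 = arith.constant 0 : index
    // Issue the asynchronous copy; the coordinates pick the tile origin in the
    // tensor described by the descriptor (rank must match the destination memref).
    nvgpu.tma.async.load %desc[%c0, %c0], %bar to %dst
      : !nvgpu.tma.descriptor<tensor = memref<128x128xf32, 3>>,
        !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>,
        memref<128x128xf32, 3>
    // ... wait on %bar (mbarrier arrive / try_wait) before reading %dst ...
    func.return
  }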