diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -914,8 +914,9 @@
   LogicalResult
   matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto dest = rewriter.create<LLVM::ExtractValueOp>(op->getLoc(),
-                                                      adaptor.getDst(), 1);
+    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
+    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
+                                      adaptor.getDst(), {}, rewriter);
     Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(),
                                    op.getBarrier(), adaptor.getBarrier());

diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -647,3 +647,35 @@
   %tensorMap1d = nvgpu.tma.create.descriptor %devicePtr1d_unranked box[%crd1] : memref<*xf32> -> !tensorMap1d
   func.return
 }
+
+// -----
+
+!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+
+!barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+
+!shmemlhs = memref<128x64xf16, 3>
+!shmemrhs = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>
+
+module @mymodule {
+  // Dynamic shared memory
+  memref.global "private" @dynamicShmem : memref<0xf16, 3>
+
+  func.func @async_tma_load(%lhsTensorMap: !lhsTensorMap, %rhsTensorMap: !rhsTensorMap, %mbarrier: !barrierType) {
+    %c0 = arith.constant 0 : index
+    %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
+    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<0xf16, 3> to !shmemlhs
+    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 64, 128], strides: [8192, 128, 1] : memref<0xf16, 3> to memref<2x64x128xf16, 3>
+    %rhsShmem3 = memref.subview %rhsShmem2[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, 3> to memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3>
+    %rhsShmem = memref.subview %rhsShmem3[0, 0, 0][1, 64, 128][1, 1, 1] : memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3> to !shmemrhs
+    // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global
+    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+    // CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
+    // CHECK: %[[shmemOffset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16
+    // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOffset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}] : !llvm.ptr<3>, !llvm.ptr, !llvm.ptr<3>, i32, i32
+    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+    return
+  }
+}