Index: mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td =================================================================== --- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -151,6 +151,14 @@ `bypassL1` attribute is hint to the backend and hardware that the copy should by pass the L1 cache, this may be dropped by the backend or hardware. + `dstElements` attribute is the total number of elements written to + destination (shared memory). + `srcElements` argument is the total number of elements read from + source (global memory). + + `srcElements` is an optional argument; when present, the copy only reads + `srcElements` elements from the source global memory and zero-fills + the rest of the elements in the destination shared memory. In order to do a copy and wait for the result we need the following combination: @@ -183,10 +191,11 @@ Variadic:$dstIndices, Arg:$src, Variadic:$srcIndices, - IndexAttr:$numElements, + IndexAttr:$dstElements, + Optional:$srcElements, OptionalAttr:$bypassL1); let assemblyFormat = [{ - $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $numElements + $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements (`,` $srcElements^)? 
attr-dict `:` type($src) `to` type($dst) }]; let hasVerifier = 1; Index: mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp =================================================================== --- mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -348,6 +348,39 @@ } }; +/// Emits inline PTX `cp.async.cg.shared.global [dst], [src], cp-size, src-size`. +/// `dstBytes` is passed as cp-size; src-size is `srcElements` converted to bytes. +static void emitAsyncCopyZfillInlineAsm(Location loc, Value dstPtr, + Value srcPtr, Value dstBytes, + Value srcElements, + mlir::MemRefType elementType, + ConversionPatternRewriter &rewriter) { + + auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(), + LLVM::AsmDialect::AD_ATT); + const char *asmStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n"; + const char *asmConstraints = "r,l,n,r"; + + Value c3I32 = + rewriter.create(loc, rewriter.getI32Type(), 3); + Value bitwidth = rewriter.create( + loc, rewriter.getI32Type(), elementType.getElementTypeBitWidth()); + Value srcElementsI32 = + rewriter.create(loc, rewriter.getI32Type(), srcElements); + Value srcBytes = rewriter.create( + loc, rewriter.create(loc, bitwidth, srcElementsI32), c3I32); + + // Operand order must follow the asm string: $0 = destination, $1 = source. + SmallVector asmVals{dstPtr, srcPtr, dstBytes, srcBytes}; + + rewriter.create( + loc, LLVM::LLVMVoidType::get(rewriter.getContext()), /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/asmConstraints, /*has_side_effects=*/true, + /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr, + /*operand_attrs=*/ArrayAttr()); +} + struct NVGPUAsyncCopyLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -356,6 +389,7 @@ LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); auto dstMemrefType = op.getDst().getType().cast(); Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(), @@ -377,15 +411,27 @@ i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace); scrPtr = rewriter.create(loc, srcPointerGlobalType, scrPtr); - int64_t numElements = 
adaptor.getNumElements().getZExtValue(); + int64_t dstElements = adaptor.getDstElements().getZExtValue(); int64_t sizeInBytes = - (dstMemrefType.getElementTypeBitWidth() * numElements) / 8; + (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8; // bypass L1 is only supported for byte sizes of 16, we drop the hint // otherwise. UnitAttr bypassL1 = sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr(); - rewriter.create( - loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1); + + // cp_async_zfill when the optional SrcElement is present. + if (op.getSrcElements()) + emitAsyncCopyZfillInlineAsm(loc, dstPtr, scrPtr, + rewriter.create( + loc, rewriter.getI32Type(), sizeInBytes), + adaptor.getSrcElements(), srcMemrefType, + rewriter); + + // cp_async when the optional SrcElement is not present. + else + rewriter.create(loc, dstPtr, scrPtr, + rewriter.getI32IntegerAttr(sizeInBytes), + bypassL1); // Drop the result token. Value zero = rewriter.create( Index: mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir =================================================================== --- mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -297,3 +297,19 @@ return %0 : !nvgpu.device.async.token } +// ----- + +// CHECK-LABEL: @async_cp_zfill( +// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index) +func.func @async_cp_zfill( + %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) { + + // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[SRCPTR:.*]], %[[DSTPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr, !llvm.ptr, i32, i32) -> !llvm.void
 + %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3> + // CHECK: nvvm.cp.async.commit.group + %1 = 
nvgpu.device_async_create_group %0 + // CHECK: nvvm.cp.async.wait.group 1 + nvgpu.device_async_wait %1 { numGroups = 1 : i32 } + + return +}