diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -151,6 +151,14 @@ `bypassL1` attribute is hint to the backend and hardware that the copy should by pass the L1 cache, this may be dropped by the backend or hardware. + `dstElements` attribute is the total number of elements written to + destination (shared memory). + `srcElements` argument is the total number of elements read from + source (global memory). + + `srcElements` is an optional argument; when present, only `srcElements` + elements are read from the source (global memory) and the remaining + elements of the destination (shared memory) are filled with zeros. In order to do a copy and wait for the result we need the following combination: @@ -183,10 +191,11 @@ Variadic:$dstIndices, Arg:$src, Variadic:$srcIndices, - IndexAttr:$numElements, + IndexAttr:$dstElements, + Optional:$srcElements, OptionalAttr:$bypassL1); let assemblyFormat = [{ - $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $numElements + $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements (`,` $srcElements^)?
attr-dict `:` type($src) `to` type($dst) }]; let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -354,6 +354,56 @@
   }
 };
 
+/// Emits `cp.async` with an explicit `src-size` operand using inline PTX.
+/// Only the first `srcElements` elements are read from `srcPtr` (global
+/// memory); the remaining bytes of the `dstBytes`-sized destination at
+/// `dstPtr` (shared memory) are filled with zeros. `dstBytes` must be an
+/// LLVM constant because it is bound to the immediate (`n`) `cp-size` operand.
+/// `srcElements` is not clamped here: PTX leaves `src-size > cp-size`
+/// undefined, so callers must keep it within the destination size.
+/// `bypassL1` selects `.cg`, which PTX only accepts for 16-byte copies;
+/// otherwise `.ca` is emitted.
+static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
+                                  Value dstBytes, Value srcElements,
+                                  mlir::MemRefType elementType, bool bypassL1,
+                                  ConversionPatternRewriter &rewriter) {
+  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
+                                                  LLVM::AsmDialect::AD_ATT);
+  const char *asmStr =
+      bypassL1 ? "cp.async.cg.shared.global [$0], [$1], $2, $3;\n"
+               : "cp.async.ca.shared.global [$0], [$1], $2, $3;\n";
+  const char *asmConstraints = "r,l,n,r";
+
+  // `src-size` is a 32-bit register operand. The converted index value may
+  // already be i32 (e.g. when index is lowered to i32), in which case an
+  // llvm.trunc would be invalid, so only truncate/extend when widths differ.
+  Type i32Ty = rewriter.getI32Type();
+  Value srcElementsI32 = srcElements;
+  unsigned srcWidth = srcElements.getType().getIntOrFloatBitWidth();
+  if (srcWidth > 32)
+    srcElementsI32 = rewriter.create<LLVM::TruncOp>(loc, i32Ty, srcElements);
+  else if (srcWidth < 32)
+    srcElementsI32 = rewriter.create<LLVM::ZExtOp>(loc, i32Ty, srcElements);
+
+  // srcBytes = (elementBitwidth * srcElements) >> 3
+  Value c3I32 = rewriter.create<LLVM::ConstantOp>(
+      loc, i32Ty, rewriter.getI32IntegerAttr(3));
+  Value bitwidth = rewriter.create<LLVM::ConstantOp>(
+      loc, i32Ty,
+      rewriter.getI32IntegerAttr(elementType.getElementTypeBitWidth()));
+  Value srcBytes = rewriter.create<LLVM::LShrOp>(
+      loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32), c3I32);
+
+  SmallVector<Value, 4> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
+
+  rewriter.create<LLVM::InlineAsmOp>(
+      loc, LLVM::LLVMVoidType::get(rewriter.getContext()), /*operands=*/asmVals,
+      /*asm_string=*/asmStr,
+      /*constraints=*/asmConstraints, /*has_side_effects=*/true,
+      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
+      /*operand_attrs=*/ArrayAttr());
+}
+
 struct NVGPUAsyncCopyLowering
     : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
   using ConvertOpToLLVMPattern<
@@ -383,15 +433,32 @@
         i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
     scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                     scrPtr);
-    int64_t numElements = adaptor.getNumElements().getZExtValue();
+    int64_t dstElements = adaptor.getDstElements().getZExtValue();
     int64_t sizeInBytes =
-        (dstMemrefType.getElementTypeBitWidth() * numElements) / 8;
+        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
     // bypass L1 is only supported for byte sizes of 16, we drop the hint
     // otherwise.
     UnitAttr bypassL1 =
         sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
-    rewriter.create<NVVM::CpAsyncOp>(
-        loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes), bypassL1);
+
+    if (adaptor.getSrcElements()) {
+      // When the optional `srcElements` operand is present, only that many
+      // elements are read from the source (global memory) and the rest of the
+      // `dstElements` elements in the destination (shared memory) are filled
+      // with zeros. The NVVM::CpAsyncOp builder used below takes no
+      // source-size operand, so this form is emitted as inline PTX.
+      Value dstBytes = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(sizeInBytes));
+      emitCpAsyncOpZfillAsm(loc, dstPtr, scrPtr, dstBytes,
+                            adaptor.getSrcElements(), srcMemrefType,
+                            static_cast<bool>(bypassL1), rewriter);
+    } else {
+      // Otherwise the regular CpAsyncOp copies all `dstElements` elements
+      // from the source (global memory) to the destination (shared memory).
+      rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, scrPtr,
+                                       rewriter.getI32IntegerAttr(sizeInBytes),
+                                       bypassL1);
+    }
 
     // Drop the result token.
     Value zero = rewriter.create<LLVM::ConstantOp>(
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -297,3 +297,32 @@
   return %0 : !nvgpu.device.async.token
 }
 
+// -----
+
+// CHECK-LABEL: @async_cp_zfill(
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
+func.func @async_cp_zfill(
+  %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+
+  // 16 bytes with bypassL1: the cache-global (.cg) variant is emitted.
+  // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : ({{.*}}, i32, i32) -> !llvm.void
+  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+  // CHECK: nvvm.cp.async.commit.group
+  %1 = nvgpu.device_async_create_group %0
+  // CHECK: nvvm.cp.async.wait.group 1
+  nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @async_cp_zfill_ca(
+func.func @async_cp_zfill_ca(
+  %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+
+  // Without bypassL1 the zero-filling copy must use the cache-all (.ca) variant.
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r"
+  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements : memref<128x128xf32> to memref<3x16x128xf32, 3>
+  return
+}