diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -506,12 +506,30 @@
   let assemblyFormat = "$mask attr-dict `:` type($mask)";
 }
 
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#id62
+def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
+def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
+def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;
+def LoadCacheModifierLU : I32EnumAttrCase<"LU", 3, "lu">;
+def LoadCacheModifierCV : I32EnumAttrCase<"CV", 4, "cv">;
 
-def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
+/// Enum attribute of the different load cache modifier kinds.
+def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
+                                "NVVM load cache modifier kind",
+  [LoadCacheModifierCA, LoadCacheModifierCG, LoadCacheModifierCS,
+   LoadCacheModifierLU, LoadCacheModifierCV]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+
+def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind,
+                                     "load_cache_modifier">;
+
+def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global",
+                        [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
   Arguments<(ins LLVM_i8Ptr_shared:$dst,
                  LLVM_i8Ptr_global:$src,
                  I32Attr:$size,
-                 OptionalAttr<UnitAttr>:$bypass_l1)> {
+                 LoadCacheModifierAttr:$modifier,
+                 Optional<I32>:$cpSize)> {
   string llvmBuilder = [{
       llvm::Intrinsic::ID id;
       switch ($size) {
@@ -522,18 +540,40 @@
         id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
         break;
       case 16:
-        if(static_cast<bool>($bypass_l1))
+        if($modifier == NVVM::LoadCacheModifierKind::CG)
           id = llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16;
-        else
+        else if($modifier == NVVM::LoadCacheModifierKind::CA)
           id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
+        else
+          llvm_unreachable("unsupported cache modifier");
         break;
       default:
         llvm_unreachable("unsupported async copy size");
     }
     createIntrinsicCall(builder, id, {$dst, $src});
   }];
-  let assemblyFormat = "$dst `,` $src `,` $size attr-dict `:` type(operands)";
+  let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)";
  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    bool hasIntrinsic() { return !getCpSize(); }
+
+    void getAsmValues(RewriterBase &rewriter,
+        llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues) {
+      asmValues.push_back({getDst(), PTXRegisterMod::Read});
+      asmValues.push_back({getSrc(), PTXRegisterMod::Read});
+      asmValues.push_back({makeConstantI32(rewriter, getSize()), PTXRegisterMod::Read});
+      asmValues.push_back({getCpSize(), PTXRegisterMod::Read});
+    }
+  }];
+  let extraClassDefinition = [{
+    const char *$cppClass::getPtx() {
+      if(getModifier() == NVVM::LoadCacheModifierKind::CG)
+        return "cp.async.cg.shared.global [%0], [%1], %2, %3;\n";
+      if(getModifier() == NVVM::LoadCacheModifierKind::CA)
+        return "cp.async.ca.shared.global [%0], [%1], %2, %3;\n";
+      llvm_unreachable("unsupported cache modifier");
+    }
+  }];
 }
 
 def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
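For illustration, the reworked op now has two spellings, a sketch mirroring the round-trip tests further down (%dst, %src, and %cpSize are placeholder SSA values):

    // Intrinsic-backed form: without the optional cpSize operand, the op
    // lowers through llvmBuilder to the llvm.nvvm.cp.async.{ca,cg}.shared.global.* intrinsics.
    nvvm.cp.async.shared.global %dst, %src, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>

    // Zfill form: the trailing cpSize operand makes hasIntrinsic() return
    // false, so the op is emitted through the inline-PTX path (getPtx).
    nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
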
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -361,51 +361,6 @@
   }
 };
 
-static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
-                                  Value dstBytes, Value srcElements,
-                                  mlir::MemRefType elementType,
-                                  ConversionPatternRewriter &rewriter) {
-  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-                                                  LLVM::AsmDialect::AD_ATT);
-
-  const char *cpAsyncCgStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
-  const char *cpAsyncCaStr = "cp.async.ca.shared.global [$0], [$1], $2, $3;\n";
-  const char *asmConstraints = "r,l,n,r";
-
-  Value c3I32 = rewriter.create<LLVM::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
-  Value bitwidth = rewriter.create<LLVM::ConstantOp>(
-      loc, rewriter.getI32Type(),
-      rewriter.getI32IntegerAttr(elementType.getElementTypeBitWidth()));
-  Value srcElementsI32 =
-      rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcElements);
-  Value srcBytes = rewriter.create<LLVM::LShrOp>(
-      loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32), c3I32);
-
-  SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
-
-  // Pick the right asm string based on the dstBytes which is a compile-time
-  // constant.
-  auto dstByteConstOp =
-      dyn_cast<LLVM::ConstantOp>(dstBytes.getDefiningOp());
-  auto dstByteAttr = dyn_cast<IntegerAttr>(dstByteConstOp.getValue());
-  int64_t dstByteVal = dstByteAttr.getValue().getSExtValue();
-
-  assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) &&
-         "cp.async byte copy size must be 4, 8 or 16");
-  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
-  // 16 dst bytes.
-  const char *asmStr = (dstByteVal == 16) ? cpAsyncCgStr : cpAsyncCaStr;
-
-  rewriter.create<LLVM::InlineAsmOp>(
-      loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
-      /*operands=*/asmVals,
-      /*asm_string=*/asmStr,
-      /*constraints=*/asmConstraints, /*has_side_effects=*/true,
-      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
-      /*operand_attrs=*/ArrayAttr());
-}
-
 /// Returns the constraints for the sparse MMA inline assembly instruction.
 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                      unsigned matBSize,
@@ -620,30 +575,38 @@
     int64_t dstElements = adaptor.getDstElements().getZExtValue();
     int64_t sizeInBytes =
         (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
-    // bypass L1 is only supported for byte sizes of 16, we drop the hint
-    // otherwise.
-    UnitAttr bypassL1 =
-        sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
-
-    // When the optional SrcElements argument is present, the source (global
-    // memory) of CpAsyncOp is read only for SrcElements number of elements. The
-    // rest of the DstElements in the destination (shared memory) are filled
-    // with zeros.
-    if (op.getSrcElements())
-      emitCpAsyncOpZfillAsm(loc, dstPtr, scrPtr,
-                            rewriter.create<LLVM::ConstantOp>(
-                                loc, rewriter.getI32Type(),
-                                rewriter.getI32IntegerAttr(sizeInBytes)),
-                            adaptor.getSrcElements(), srcMemrefType, rewriter);
-
     // When the optional SrcElements argument is *not* present, the regular
     // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
-    // memory) to fill DstElements number of elements in the destination (shared
-    // memory).
-    else
-      rewriter.create<NVVM::CpAsyncOp>(loc, dstPtr, scrPtr,
-                                       rewriter.getI32IntegerAttr(sizeInBytes),
-                                       bypassL1);
+    // memory) to fill DstElements number of elements in the destination
+    // (shared memory).
+    Value srcBytes = adaptor.getSrcElements();
+    if (srcBytes) {
+      // When the optional SrcElements argument is present, the source (global
+      // memory) of CpAsyncOp is read only for SrcElements number of elements.
+      // The rest of the DstElements in the destination (shared memory) are
+      // filled with zeros.
+      // srcBytes = (elementBitwidth * srcElements) / 8, emitted as a multiply
+      // followed by a right shift by 3.
+      Value c3I32 = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
+      Value bitwidth = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32Type(),
+          rewriter.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
+      Value srcElementsI32 =
+          rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcBytes);
+      srcBytes = rewriter.create<LLVM::LShrOp>(
+          loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32),
+          c3I32);
+    }
+    // Cache global (.cg) when the copy is 16 dst bytes and the bypassL1 hint
+    // is set; cache all (.ca) otherwise.
+    NVVM::LoadCacheModifierKind cacheModifier =
+        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
+            ? NVVM::LoadCacheModifierKind::CG
+            : NVVM::LoadCacheModifierKind::CA;
+
+    rewriter.create<NVVM::CpAsyncOp>(
+        loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
+        srcBytes);
 
     // Drop the result token.
     Value zero = rewriter.create<LLVM::ConstantOp>(
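To make the rewrite concrete, a sketch of what this pattern now produces for a zfill copy (SSA names invented; compare the nvgpu-to-nvvm.mlir tests below):

    // nvgpu input: copy 4 f32 elements (16 bytes) per thread, zero-filling
    // the destination past the first %n source elements.
    %t = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %n {bypassL1}
        : memref<128x128xf32> to memref<3x16x128xf32, 3>

    // nvvm output (schematically): %srcBytes = (32 * %n) >> 3, and the
    // 16-byte bypassL1 copy selects the cg modifier.
    nvvm.cp.async.shared.global %dstPtr, %srcPtr, 16, cache = cg, %srcBytes
        : !llvm.ptr<3>, !llvm.ptr<1>, i32
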
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -68,10 +68,13 @@
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
 LogicalResult CpAsyncOp::verify() {
+  if (getModifier() != LoadCacheModifierKind::CG &&
+      getModifier() != LoadCacheModifierKind::CA)
+    return emitError("only CG and CA cache modifiers are supported.");
   if (getSize() != 4 && getSize() != 8 && getSize() != 16)
     return emitError("expected byte size to be either 4, 8 or 16.");
-  if (getBypassL1() && getSize() != 16)
-    return emitError("bypass l1 is only support for 16 bytes copy.");
+  if (getModifier() == LoadCacheModifierKind::CG && getSize() != 16)
+    return emitError("CG cache modifier is only supported for 16 bytes copy.");
   return success();
 }
 
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -83,6 +83,20 @@
     return emitOpError() << "expected " << dstMemref.getRank()
                          << " destination indices, got "
                          << getDstIndices().size();
 
+  if (getBypassL1().has_value()) {
+    int64_t dstElements = getDstElements().getZExtValue();
+    int64_t sizeInBytes =
+        (dstMemref.getElementTypeBitWidth() * dstElements) / 8;
+    // bypassL1 lowers to the cg form of cp.async, which only exists for
+    // 16-byte copies; req is the element count that yields exactly 16 bytes
+    // (e.g. 4 elements for a 32-bit element type).
+    int64_t req = 16 * 8 / dstMemref.getElementTypeBitWidth();
+    if (getBypassL1().value() && sizeInBytes != 16) {
+      return emitOpError() << "bypassL1 does not satisfy alignment for "
+                           << dstMemref << " with destination element "
+                           << dstElements
+                           << ". Unset bypassL1, or set "
+                              "destination element to "
+                           << req;
+    }
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -258,14 +258,14 @@
   // CHECK-DAG: %[[FI4:.*]] = llvm.add %[[FI3]], %[[IDX1]] : i64
   // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr, i64) -> !llvm.ptr
   // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1>
-  // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16
+  // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16, cache = ca
   %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 : memref<128x128xf32> to memref<3x16x128xf32, 3>
   // CHECK: nvvm.cp.async.commit.group
   %1 = nvgpu.device_async_create_group %0
   // CHECK: nvvm.cp.async.wait.group 1
   nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
-  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg
   %2 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
   return
 }
@@ -288,7 +288,7 @@
   // CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64
   // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr
   // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1>
-  // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16
+  // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16, cache = ca
   %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i], 32 : memref<128x64xi4> to 
memref<128x128xi4, 3> return %0 : !nvgpu.device.async.token } @@ -296,11 +296,31 @@ // ----- // CHECK-LABEL: @async_cp_zfill_f32_align4( -// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index) +// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index func.func @async_cp_zfill_f32_align4( %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) { - // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void + // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK: %[[SRC1:.*]] = builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64 + // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)> + // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64 + // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64 + // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(128 : index) : i64 + // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64 + // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI1]], %[[LI]] : i64 + // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX1]] : i64 + // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 + // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-DAG: %[[S2:.*]] = llvm.mlir.constant(128 : index) : i64 + // CHECK-DAG: %[[FI2:.*]] = llvm.mul %[[IDX1]], %[[S2]] : i64 + // CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64 + // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1> + // CHECK-DAG: %[[c1:.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK-DAG: %[[c2:.*]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK-DAG: %[[c3:.*]] = llvm.trunc %[[SRC1]] : i64 to i32 + // CHECK-DAG: %[[c4:.*]] = llvm.mul %[[c2]], %[[c3]] : i32 + // CHECK-DAG: %[[c5:.*]] = llvm.lshr %[[c4]], %[[c1]] : i32 + // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16, cache = cg, %[[c5]] %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3> // CHECK: nvvm.cp.async.commit.group %1 = nvgpu.device_async_create_group %0 @@ -316,9 +336,29 @@ // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index) func.func @async_cp_zfill_f32_align1( %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) { - // CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(4 : i32) : i32 - // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void - %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3> + // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK: %[[SRC1:.*]] = 
builtin.unrealized_conversion_cast %[[SRCELEMENTS]] : index to i64 + // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)> + // CHECK-DAG: %[[S2048:.*]] = llvm.mlir.constant(2048 : index) : i64 + // CHECK-DAG: %[[LI1:.*]] = llvm.mul %[[IDX1]], %[[S2048]] : i64 + // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(128 : index) : i64 + // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64 + // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI1]], %[[LI]] : i64 + // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX1]] : i64 + // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32 + // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-DAG: %[[S2:.*]] = llvm.mlir.constant(128 : index) : i64 + // CHECK-DAG: %[[FI2:.*]] = llvm.mul %[[IDX1]], %[[S2]] : i64 + // CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64 + // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1> + // CHECK-DAG: %[[c1:.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK-DAG: %[[c2:.*]] = llvm.mlir.constant(32 : i32) : i32 + // CHECK-DAG: %[[c3:.*]] = llvm.trunc %[[SRC1]] : i64 to i32 + // CHECK-DAG: %[[c4:.*]] = llvm.mul %[[c2]], %[[c3]] : i32 + // CHECK-DAG: %[[c5:.*]] = llvm.lshr %[[c4]], %[[c1]] : i32 + // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 4, cache = ca, %[[c5]] + %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements : memref<128x128xf32> to memref<3x16x128xf32, 3> // CHECK: nvvm.cp.async.commit.group %1 = nvgpu.device_async_create_group %0 // CHECK: nvvm.cp.async.wait.group 1 diff --git a/mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir b/mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir --- a/mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir @@ -21,14 +21,14 @@ // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr to !llvm.ptr - // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16 + // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16, cache = ca %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 : memref<128x128xf32> to memref<3x16x128xf32, 3> // CHECK: nvvm.cp.async.commit.group %1 = nvgpu.device_async_create_group %0 // CHECK: nvvm.cp.async.wait.group 1 nvgpu.device_async_wait %1 { numGroups = 1 : i32 } - // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1} + // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg %2 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3> return } @@ -53,7 +53,7 @@ // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr to !llvm.ptr - // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16 + // CHECK-DAG: nvvm.cp.async.shared.global 
%[[CAST0]], %[[CAST2]], 16, cache = ca %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i], 32 : memref<128x64xi4> to memref<128x128xi4, 3> return %0 : !nvgpu.device.async.token } diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -2,28 +2,46 @@ // CHECK-LABEL : @init_mbarrier_arrive_expect_tx llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i32{ - //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=r,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i32 + //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32 %res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i32 llvm.return %res : i32 } // CHECK-LABEL : @init_mbarrier_arrive_expect_tx_generic llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32)-> i32 { - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32 %res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i32 llvm.return %res : i32 } // CHECK-LABEL : @init_mbarrier_try_wait.parity.shared llvm.func @init_mbarrier_try_wait.parity.shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 { - // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i32 + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32 %res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32 llvm.return %res : i32 } // CHECK-LABEL : @init_mbarrier_try_wait.parity llvm.func @init_mbarrier_try_wait.parity(%barrier : !llvm.ptr, %token : i32) -> i32{ - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32 + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32 %res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32 llvm.return %res : i32 } + +// CHECK-LABEL : @async_cp +func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) { + // CHECK : nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1> + nvvm.cp.async.shared.global %dst, %src, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1> + // CHECK : nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1> + nvvm.cp.async.shared.global %dst, %src, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1> + return +} + +// CHECK-LABEL : @async_cp_zfill +func.func 
@async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [%0], [%1], %2, %3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
+  nvvm.cp.async.shared.global %dst, %src, 16, cache = cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [%0], [%1], %2, %3;\0A", "r,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32) -> !llvm.void
+  nvvm.cp.async.shared.global %dst, %src, 4, cache = ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  return
+}
diff --git a/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid-typed-pointers.mlir
@@ -278,15 +278,15 @@
 
 func.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
   // expected-error @below {{expected byte size to be either 4, 8 or 16.}}
-  nvvm.cp.async.shared.global %arg0, %arg1, 32 : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+  nvvm.cp.async.shared.global %arg0, %arg1, 32, cache = ca : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
   return
 }
 
 // -----
 
 func.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
-  // expected-error @below {{bypass l1 is only support for 16 bytes copy.}}
-  nvvm.cp.async.shared.global %arg0, %arg1, 8 {bypass_l1} : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+  // expected-error @below {{CG cache modifier is only supported for 16 bytes copy.}}
+  nvvm.cp.async.shared.global %arg0, %arg1, 8, cache = cg : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
   return
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1281,15 +1281,15 @@
 
 func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
   // expected-error @below {{expected byte size to be either 4, 8 or 16.}}
-  nvvm.cp.async.shared.global %arg0, %arg1, 32 : !llvm.ptr<3>, !llvm.ptr<1>
+  nvvm.cp.async.shared.global %arg0, %arg1, 32, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
   return
 }
 
 // -----
 
 func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
-  // expected-error @below {{bypass l1 is only support for 16 bytes copy.}}
-  nvvm.cp.async.shared.global %arg0, %arg1, 8 {bypass_l1} : !llvm.ptr<3>, !llvm.ptr<1>
+  // expected-error @below {{CG cache modifier is only supported for 16 bytes copy.}}
+  nvvm.cp.async.shared.global %arg0, %arg1, 8, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
   return
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-typed-pointers.mlir b/mlir/test/Dialect/LLVMIR/nvvm-typed-pointers.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm-typed-pointers.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm-typed-pointers.mlir
@@ -11,10 +11,10 @@
 // CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
-// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
-  nvvm.cp.async.shared.global %arg0, %arg1, 16 : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
-// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
-  nvvm.cp.async.shared.global %arg0, %arg1, 16 {bypass_l1} : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca
+  nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = ca : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
+// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg
+  nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cg : !llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>
 // CHECK: nvvm.cp.async.commit.group
   nvvm.cp.async.commit.group
 // CHECK: nvvm.cp.async.wait.group 0
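As a side note on the verifier behavior introduced above: only the ca and cg cache modifiers have cp.async forms, so the remaining enum cases (cs, lu, cv) are rejected. A hypothetical snippet, not part of the patch, illustrating the rejected case:

    func.func @cp_async_bad_modifier(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
      // Rejected by CpAsyncOp::verify(): only CG and CA cache modifiers are supported.
      nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cv : !llvm.ptr<3>, !llvm.ptr<1>
      return
    }
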
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -289,10 +289,10 @@
 // CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
-// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
-  nvvm.cp.async.shared.global %arg0, %arg1, 16 : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
-  nvvm.cp.async.shared.global %arg0, %arg1, 16 {bypass_l1} : !llvm.ptr<3>, !llvm.ptr<1>
+// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = ca
+  nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
+// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache = cg
+  nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
 // CHECK: nvvm.cp.async.commit.group
   nvvm.cp.async.commit.group
 // CHECK: nvvm.cp.async.wait.group 0
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -185,3 +185,12 @@
     (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
+
+// -----
+
+func.func @async_cp_zfill_f32_align1(
+  %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+  // expected-error @+1 {{'nvgpu.device_async_copy' op bypassL1 does not satisfy alignment for 'memref<3x16x128xf32, 3>' with destination element 1. Unset bypassL1, or set destination element to 4}}
+  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1} : memref<128x128xf32> to memref<3x16x128xf32, 3>
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -309,13 +309,13 @@
 // CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
   // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
-  nvvm.cp.async.shared.global %arg0, %arg1, 4 : !llvm.ptr<3>, !llvm.ptr<1>
+  nvvm.cp.async.shared.global %arg0, %arg1, 4, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
   // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
-  nvvm.cp.async.shared.global %arg0, %arg1, 8 : !llvm.ptr<3>, !llvm.ptr<1>
+  nvvm.cp.async.shared.global %arg0, %arg1, 8, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
   // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
-  nvvm.cp.async.shared.global %arg0, %arg1, 16 : !llvm.ptr<3>, !llvm.ptr<1>
+  nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = ca : !llvm.ptr<3>, !llvm.ptr<1>
   // CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
-  nvvm.cp.async.shared.global %arg0, %arg1, 16 {bypass_l1} : !llvm.ptr<3>, !llvm.ptr<1>
+  nvvm.cp.async.shared.global %arg0, %arg1, 16, cache = cg : !llvm.ptr<3>, !llvm.ptr<1>
   // CHECK: call void @llvm.nvvm.cp.async.commit.group()
   nvvm.cp.async.commit.group
   // CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
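For context, a minimal end-to-end usage sketch at the nvgpu level (function name and values illustrative), combining the three ops exercised by this patch; it lowers through the reworked nvvm.cp.async.shared.global:

    func.func @copy_example(%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i: index) {
      // 4 f32 elements = 16 bytes per thread; bypassL1 lowers to cache = cg.
      %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 {bypassL1}
          : memref<128x128xf32> to memref<3x16x128xf32, 3>
      // Group the in-flight copy and wait until it has completed.
      %1 = nvgpu.device_async_create_group %0
      nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
      return
    }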