diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -94,12 +94,15 @@
   "LLVM pointer to " # pointee.summary>;
 
 // Type constraints accepting LLVM pointer type to integer of a specific width.
+// `addressSpace` is forwarded to the buildable pointer type.
+// NOTE(review): the predicate is not parameterized on `addressSpace`; confirm
+// whether pointers in other address spaces should be rejected.
-class LLVM_IntPtrBase<int width> : Type<
+class LLVM_IntPtrBase<int width, int addressSpace = 0> : Type<
   LLVM_PointerTo<I<width>>.predicate,
   "LLVM pointer to " # I<width>.summary>,
   BuildableType<"::mlir::LLVM::LLVMPointerType::get("
                 "::mlir::IntegerType::get($_builder.getContext(), "
-                # width #"))">;
+                # width #"), "# addressSpace #")">;
 
 def LLVM_i8Ptr : LLVM_IntPtrBase<8>;
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -16,6 +16,11 @@
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+// Pointers to i8 in address space 1 (global) and 3 (shared), as taken by the
+// cp.async.shared.global op below.
+def LLVM_i8Ptr_global : LLVM_IntPtrBase<8, 1>;
+def LLVM_i8Ptr_shared : LLVM_IntPtrBase<8, 3>;
+
 //===----------------------------------------------------------------------===//
 // NVVM dialect definitions
 //===----------------------------------------------------------------------===//
@@ -157,6 +165,61 @@
   let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
 }
 
+
+// Translates to the llvm.nvvm.cp.async.ca.shared.global.{4,8,16} intrinsic
+// selected by `size`, called with `dst` and `src`.
+def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
+  Arguments<(ins LLVM_i8Ptr_shared:$dst,
+                 LLVM_i8Ptr_global:$src,
+                 I32Attr:$size)> {
+  string llvmBuilder = [{
+      llvm::Intrinsic::ID id;
+      switch ($size) {
+        case 4:
+          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4;
+          break;
+        case 8:
+          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
+          break;
+        case 16:
+          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
+          break;
+        default:
+          // Every other size is rejected by the verifier below.
+          llvm_unreachable("unsupported async copy size");
+      }
+      createIntrinsicCall(builder, id, {$dst, $src});
+  }];
+  let verifier = [{
+    if (size() != 4 && size() != 8 && size() != 16)
+      return emitError("expected byte size to be either 4, 8 or 16.");
+    return success();
+  }];
+  let assemblyFormat = "$dst `,` $src `,` $size attr-dict";
+}
+
+// Translates to a call to llvm.nvvm.cp.async.commit.group.
+def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
+  string llvmBuilder = [{
+      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_commit_group);
+  }];
+  let assemblyFormat = "attr-dict";
+}
+
+// Translates to llvm.nvvm.cp.async.wait.group with `n` as an i32 constant.
+def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
+  Arguments<(ins I32Attr:$n)> {
+  string llvmBuilder = [{
+      createIntrinsicCall(
+        builder,
+        llvm::Intrinsic::nvvm_cp_async_wait_group,
+        llvm::ConstantInt::get(
+          llvm::Type::getInt32Ty(moduleTranslation.getLLVMContext()),
+          $n));
+  }];
+  let assemblyFormat = "$n attr-dict";
+}
+
 def NVVM_MmaOp :
   NVVM_Op<"mma.sync">,
   Results<(outs LLVM_Type:$res)>,
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1226,3 +1226,11 @@
   llvm.bitcast %arg0 : vector<2x3xf32> to vector<2x3xi32>
   return
 }
+
+// -----
+
+func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
+  // expected-error @below {{expected byte size to be either 4, 8 or 16.}}
+  nvvm.cp.async.shared.global %arg0, %arg1, 32
+  return
+}
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -95,6 +95,15 @@
   llvm.return %r : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
+// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
+  nvvm.cp.async.shared.global %arg0, %arg1, 16
+// CHECK: nvvm.cp.async.commit.group
+  nvvm.cp.async.commit.group
+// CHECK: nvvm.cp.async.wait.group 0
+  nvvm.cp.async.wait.group 0
+  llvm.return
+}
 
 // -----
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -162,6 +162,20 @@
   llvm.return
 }
 
+llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
+// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %arg0, %arg1, 4
+// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %arg0, %arg1, 8
+// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
+  nvvm.cp.async.shared.global %arg0, %arg1, 16
+// CHECK: call void @llvm.nvvm.cp.async.commit.group()
+  nvvm.cp.async.commit.group
+// CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
+  nvvm.cp.async.wait.group 0
+  llvm.return
+}
+
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
 llvm.func @kernel_func() attributes {nvvm.kernel} {