diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -249,6 +249,28 @@
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
+// Lowers to llvm.nvvm.mbarrier.test.wait on the mbarrier at `$addr` (pointer
+// in any address space) with `$token`; `$res` is the intrinsic's result.
+def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {
+  string llvmBuilder = [{
+      $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_test_wait, {$addr, $token});
+  }];
+  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+}
+
+// Shared-memory variant: lowers to llvm.nvvm.mbarrier.test.wait.shared and
+// requires `$addr` to be a shared (address space 3) pointer.
+def NVVM_MBarrierTestWaitSharedOp : NVVM_Op<"mbarrier.test.wait.shared">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> {
+  string llvmBuilder = [{
+      $res = createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_test_wait_shared, {$addr, $token});
+  }];
+  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -394,3 +394,15 @@
   %0 = nvvm.mbarrier.arrive.nocomplete.shared %barrier, %count : !llvm.ptr<3>, i32 -> i64
   llvm.return
 }
+
+llvm.func private @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
+  // CHECK: nvvm.mbarrier.test.wait %{{.*}}, %{{.*}} : !llvm.ptr, i64 -> i1
+  %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
+  llvm.return %isComplete : i1
+}
+
+llvm.func private @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) -> i1 {
+  // CHECK: nvvm.mbarrier.test.wait.shared %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+  %isComplete = nvvm.mbarrier.test.wait.shared %barrier, %token : !llvm.ptr<3>, i64 -> i1
+  llvm.return %isComplete : i1
+}