diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -372,53 +372,59 @@ } def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx", - [DeclareOpInterfaceMethods]>, - Results<(outs LLVM_Type:$res)>, + [DeclareOpInterfaceMethods]>, Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> { - let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)"; + let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)"; let extraClassDefinition = [{ - std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"); } + std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); } }]; } def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared", - [DeclareOpInterfaceMethods]>, - Results<(outs LLVM_Type:$res)>, + [DeclareOpInterfaceMethods]>, Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> { - let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)"; + let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands)"; let extraClassDefinition = [{ - std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"); } + std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); } }]; } def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity", - [DeclareOpInterfaceMethods]>, - Results<(outs LLVM_Type:$res)>, - Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> { - let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)"; + [DeclareOpInterfaceMethods]>, + Arguments<(ins LLVM_i64ptr_any:$addr, I32:$phase, I32:$ticks)> { + let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; let
extraClassDefinition = [{ std::string $cppClass::getPtx() { - return std::string("{\n\t" - ".reg .pred P1; \n\t" - "mbarrier.try_wait.parity.b64 P1, [%1], %2; \n\t" - "selp.b32 %0, 1, 0, P1; \n\t" - "}"); + return std::string( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + ); } }]; } def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared", - [DeclareOpInterfaceMethods]>, - Results<(outs LLVM_Type:$res)>, - Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> { - let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)"; + [DeclareOpInterfaceMethods]>, + Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$phase, I32:$ticks)> { + let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)"; let extraClassDefinition = [{ std::string $cppClass::getPtx() { - return std::string("{\n\t" - ".reg .pred P1; \n\t" - "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t" - "selp.b32 %0, 1, 0, P1; \n\t" - "}"); + return std::string( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra.uni DONE; \n\t" + "bra.uni LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + ); } }]; } diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -469,4 +469,44 @@ let assemblyFormat = "$barrier `,` $count attr-dict `:` type($barrier) `->` type($token)"; } +def NVGPU_MBarrierArriveExpectTxOp : NVGPU_Op<"mbarrier.arrive.expect_tx", []> { + let summary = "Performs an expect-tx operation on the `nvgpu.mbarrier`"; + let description = [{ + A thread executing the Op performs an expect-tx operation on the mbarrier + object at the location specified by the address operand $barrier.
The + expect-tx operation, with a $txcount argument, increases the tx-count of + an mbarrier object by the value specified by $txcount. This makes the + current phase of the mbarrier object expect and track the completion of + additional asynchronous transactions. + + `$txcount` specifies the amount the expect-tx operation adds to the tx-count. + + Example: + ```mlir + nvgpu.mbarrier.arrive.expect_tx %barrier, %ic0 : !nvgpu.mbarrier.barrier> + ``` + }]; + let arguments = (ins NVGPU_MBarrier:$barrier, + Index:$txcount); + let assemblyFormat = "$barrier `,` $txcount attr-dict `:` type($barrier)"; +} + +def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> { + let summary = "Waits for the `nvgpu.mbarrier` to complete the phase with parity `$phase`."; + let description = [{ + Checks whether the mbarrier object has completed the phase. It is a + potentially blocking instruction which tests for the completion of the + phase. A suspended thread resumes execution when the specified phase completes + OR before the phase completes following a system-dependent time limit. + + Example: + ```mlir + nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !nvgpu.mbarrier.barrier> + ``` + + }]; + let arguments = (ins NVGPU_MBarrier:$barrier, Index:$phase, Index:$ticks); + let assemblyFormat = "$barrier `,` $phase `,` $ticks attr-dict `:` type($barrier)"; +} + #endif // NVGPU diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -25,6 +25,17 @@ using namespace mlir; +/// GPU has 32-bit registers; this function truncates integer values wider +/// than 32 bits to i32 and returns narrower values unchanged.
+static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc, + Value value) { + Type type = value.getType(); + assert(llvm::isa(type) && "expected an integer Value"); + if (type.getIntOrFloatBitWidth() <= 32) + return value; + return rewriter.create(loc, rewriter.getI32Type(), value); +} + /// Returns the type for the intrinsic given the vectorResultType of the /// `gpu.mma.sync` operation. static Type inferIntrinsicResultType(Type vectorResultType) { @@ -850,6 +861,55 @@ } }; +struct NVGPUMBarrierArriveExpectTxLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + nvgpu::MBarrierArriveExpectTxOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(), + op.getBarrier(), adaptor.getBarrier()); + Value txcount = truncToI32(rewriter, op->getLoc(), adaptor.getTxcount()); + + if (isMbarrierShared(op.getBarrier().getType())) { + rewriter.replaceOpWithNewOp( + op, barrier, txcount); + return success(); + } + + rewriter.replaceOpWithNewOp(op, barrier, + txcount); + return success(); + } +}; + +struct NVGPUMBarrierTryWaitParityLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + nvgpu::MBarrierTryWaitParityOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(), + op.getBarrier(), adaptor.getBarrier()); + Value ticks = truncToI32(rewriter, op->getLoc(), adaptor.getTicks()); + Value phase = truncToI32(rewriter, op->getLoc(), adaptor.getPhase()); + + if (isMbarrierShared(op.getBarrier().getType())) { + rewriter.replaceOpWithNewOp( + op, barrier, phase, ticks); + return success(); + } + + rewriter.replaceOpWithNewOp(op, barrier, + phase, ticks); + return
success(); + } +}; + } // namespace void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, @@ -859,7 +919,9 @@ NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive NVGPUMBarrierArriveNoCompleteLowering, // nvgpu.mbarrier.arrive.no_complete - NVGPUMBarrierTestWaitLowering, // nvgpu.try_wait_parity + NVGPUMBarrierTestWaitLowering, // nvgpu.mbarrier.test.wait + NVGPUMBarrierTryWaitParityLowering, // nvgpu.mbarrier.try_wait.parity + NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, NVGPUMmaSparseSyncLowering>(converter); diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -558,3 +558,49 @@ func.return } + + +// ----- +!barrierType = !nvgpu.mbarrier.barrier> +!tokenType = !nvgpu.mbarrier.token + +// CHECK-LABEL: func @mbarrier_txcount +func.func @mbarrier_txcount() { + %num_threads = arith.constant 128 : index + + // CHECK: %[[barMemref:.+]] = memref.get_global @__mbarrier : memref<1xi64, 3> + %barrier = nvgpu.mbarrier.create -> !barrierType + + // CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[barPtr:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: nvvm.mbarrier.init.shared %[[barPtr]] + nvgpu.mbarrier.init %barrier, %num_threads : !barrierType + + %c0 = arith.constant 0 : index + %tidxreg = nvvm.read.ptx.sreg.tid.x : i32 + %tidx = arith.index_cast %tidxreg : i32 to index + %cnd = arith.cmpi eq, %tidx, %c0 : index + + scf.if %cnd { + %txcount = arith.constant 256 :
index + // CHECK: %[[barPtr2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + nvgpu.mbarrier.arrive.expect_tx %barrier, %txcount : !barrierType + scf.yield + } else { + %txcount = arith.constant 0 : index + // CHECK: %[[barPtr2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + nvgpu.mbarrier.arrive.expect_tx %barrier, %txcount : !barrierType + scf.yield + } + + + %phase = arith.constant 0 : index + %ticks = arith.constant 10000000 : index + // CHECK: %[[barPtr3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !barrierType + + func.return +} \ No newline at end of file diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -1,31 +1,31 @@ // RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s // CHECK-LABEL : @init_mbarrier_arrive_expect_tx -llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i64 { - //CHECK : llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 $0, [$1], $2;", "=l,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i64 - %res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i64 - llvm.return %res : i64 +llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r" +
nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 + llvm.return +} // CHECK-LABEL : @init_mbarrier_arrive_expect_tx_generic -llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32)-> i64 { - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 $0, [$1], $2;", "=l,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i64 - %res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64 - llvm.return %res : i64 +llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 _, [$0], $1;", "l,r" + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 + llvm.return } // CHECK-LABEL : @init_mbarrier_try_wait.parity.shared -llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 { - // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,r,r" %{{.*}}, %{{.*}} : (!llvm.ptr<3>, i32) -> i32 - %res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32 - llvm.return %res : i32 +llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32, %phase : i32) { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09LAB_WAIT: \0A\09mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, $2; \0A\09@P1 bra.uni DONE; \0A\09bra.uni LAB_WAIT; \0A\09DONE: \0A\09}", "r,r,r" + nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32 + llvm.return } // CHECK-LABEL : @init_mbarrier_try_wait.parity -llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %token : i32) -> i32{ - // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1,
[$1], $2; \0A\09selp.b32 $0, 1, 0, P1; \0A\09}", "=r,l,r" %{{.*}}, %{{.*}} : (!llvm.ptr, i32) -> i32 - %res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32 - llvm.return %res : i32 +llvm.func @init_mbarrier_try_wait(%barrier : !llvm.ptr, %ticks : i32, %phase : i32) { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09LAB_WAIT: \0A\09mbarrier.try_wait.parity.b64 P1, [$0], $1, $2; \0A\09@P1 bra.uni DONE; \0A\09bra.uni LAB_WAIT; \0A\09DONE: \0A\09}", "l,r,r" + nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr, i32, i32 + llvm.return } // CHECK-LABEL : @async_cp