diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -457,4 +457,47 @@ let assemblyFormat = "$barrier `,` $count attr-dict `:` type($barrier) `->` type($token)"; } +def NVGPU_MBarrierArriveExpectTxOp : NVGPU_Op<"mbarrier.arrive.expect_tx", []> { + let summary = "Performs expect_tx operation on the `nvgpu.mbarrier.arrive`"; + let description = [{ + A thread executing the Op performs an expect-tx operation on the mbarrier + object at the location specified by the address operand $barrier. The + expect-tx operation, with a $txcount argument, increases the tx-count of + an mbarrier object by the value specified by $txcount. This makes the + current phase of the mbarrier object expect and track the completion of + additional asynchronous transactions. + + The `$txcount` operand specifies the number of elements for the expect-tx operation. + + Example: + ```mlir + %token = nvgpu.mbarrier.arrive.expect_tx %barrier, %txcount : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>> -> !nvgpu.mbarrier.token + ``` + }]; + let arguments = (ins NVGPU_MBarrier:$barrier, + Index:$txcount, + Optional:$state); + let results = (outs NVGPU_MBarrierToken:$token); + let assemblyFormat = "$barrier `,` $txcount (`,` $state^)? attr-dict `:` type($barrier) `->` type($token)"; +} + +def NVGPU_MBarrierTryWaitParityOp : NVGPU_Op<"mbarrier.try_wait.parity", []> { + let summary = "Waits for the `nvgpu.mbarrier` to complete its current phase."; + let description = [{ + Checks whether the mbarrier object has completed the phase. It is a + potentially blocking instruction which tests for the completion of the + phase. The suspended thread resumes execution when the specified phase completes + OR before the phase completes following a system-dependent time limit. 
+ + Example: + ```mlir + %waitComplete = nvgpu.mbarrier.try_wait.parity %barrier, %token : !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>, !nvgpu.mbarrier.token + ``` + + }]; + let arguments = (ins NVGPU_MBarrier:$barrier, NVGPU_MBarrierToken:$token); + let results = (outs I1:$waitComplete); + let assemblyFormat = "$barrier `,` $token attr-dict `:` type($barrier) `,` type($token)"; +} + #endif // NVGPU diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -878,12 +878,69 @@ } }; +struct NVGPUMBarrierArriveExpectTxLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + nvgpu::MBarrierArriveExpectTxOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(), + op.getBarrier(), adaptor.getBarrier()); + Type tokenType = getTypeConverter()->convertType( + nvgpu::MBarrierTokenType::get(op->getContext())); + Value txcount = adaptor.getTxcount(); + + if (!adaptor.getTxcount().getType().isInteger(32)) { + txcount = rewriter.create(op->getLoc(), + rewriter.getI32Type(), txcount); + } + if (isMbarrierShared(op.getBarrier().getType())) { + rewriter.replaceOpWithNewOp( + op, tokenType, barrier, txcount); + } else { + rewriter.replaceOpWithNewOp( + op, tokenType, barrier, txcount); + } + + return success(); + } +}; + +struct NVGPUMBarrierTryWaitParityLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + nvgpu::MBarrierTryWaitParityOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value barrier = getMbarrierPtr(rewriter, *getTypeConverter(), + op.getBarrier(), adaptor.getBarrier()); + Type 
tokenType = getTypeConverter()->convertType( + nvgpu::MBarrierTokenType::get(op->getContext())); + if (isMbarrierShared(op.getBarrier().getType())) { + rewriter.replaceOpWithNewOp( + op, tokenType, barrier, adaptor.getToken()); + } else { + rewriter.replaceOpWithNewOp( + op, tokenType, barrier, adaptor.getToken()); + } + + return success(); + } +}; + } // namespace void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add> +!tokenType = !nvgpu.mbarrier.token + +// CHECK-LABEL: func @mbarrier_txcount +func.func @mbarrier_txcount() { + %num_threads = arith.constant 128 : index + + // CHECK: %[[barMemref:.+]] = memref.get_global @__mbarrier : memref<1xi64, 3> + %barrier = nvgpu.mbarrier.create -> !barrierType + + // CHECK: %[[barStr:.+]] = builtin.unrealized_conversion_cast %[[barMemref]] : memref<1xi64, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[barPtr:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: nvvm.mbarrier.init.shared %[[barPtr]] + nvgpu.mbarrier.init %barrier, %num_threads : !barrierType + + %c0 = arith.constant 0 : index + %tidxreg = nvvm.read.ptx.sreg.tid.x : i32 + %tidx = arith.index_cast %tidxreg : i32 to index + %cnd = arith.cmpi eq, %tidx, %c0 : index + + %token = scf.if %cnd -> !tokenType{ + %txcount = arith.constant 256 : index + // Here one can do a TMA request + + // CHECK: %[[barPtr2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + %wtoken = nvgpu.mbarrier.arrive.expect_tx %barrier, %txcount : !barrierType -> !tokenType + scf.yield %wtoken : !tokenType + } else { + %txcount = arith.constant 0 : index + // CHECK: %[[barPtr2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x 
i64>)> + // CHECK: %[[token:.+]] = nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]] + %wtoken = nvgpu.mbarrier.arrive.expect_tx %barrier, %txcount : !barrierType -> !tokenType + scf.yield %wtoken : !tokenType + } + + + scf.while () : () -> () { + // CHECK: %[[barPtr3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]] + %1 = nvgpu.mbarrier.try_wait.parity %barrier, %token : !barrierType, !tokenType + scf.condition(%1) + } do { + ^bb0(): + scf.yield + } + + func.return } \ No newline at end of file