diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -741,19 +741,28 @@
   matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Operation *funcOp = op->getParentOp();
-    Operation *mOp = funcOp->getParentOp();
     MemRefType barrierType =
         createMBarrierMemrefType(rewriter, op.getBarrier().getType());
 
-    memref::GlobalOp global;
-    if (auto moduleOp = dyn_cast<gpu::GPUModuleOp>(mOp))
-      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
-    else if (auto moduleOp = dyn_cast<ModuleOp>(mOp))
-      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
-
+    std::optional<memref::GlobalOp> global = std::nullopt;
+    Operation *mOp = funcOp->getParentOp();
+    // Walk up from the grandparent of `op` until an enclosing module op
+    // produces the global; `op` may be nested (e.g. inside gpu.launch).
+    while (mOp != nullptr && !global.has_value()) {
+      if (auto moduleOp = dyn_cast<gpu::GPUModuleOp>(mOp))
+        global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
+      else if (auto moduleOp = dyn_cast<ModuleOp>(mOp))
+        global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
+      mOp = mOp->getParentOp();
+    }
+    if (!global.has_value()) {
+      return op->emitError() << "Expected parent op to be either "
+                                "gpu::GPUModuleOp or ModuleOp "
+                                "to generate mbarrier object.";
+    }
     rewriter.setInsertionPoint(op);
     rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
-                                                     global.getName());
+                                                     global->getName());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -603,4 +603,19 @@
   nvgpu.mbarrier.try_wait.parity %barrier, %phase, %ticks : !barrierType
   func.return
+}
+// -----
+
+!barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+module @find_parent{
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+               threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+      // CHECK: memref.get_global @__mbarrier : memref<1xi64, 3>
+      %barrier = nvgpu.mbarrier.create -> !barrierType
+      gpu.terminator
+    }
+    func.return
+  }
 }
\ No newline at end of file