diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -744,14 +744,16 @@
   matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Operation *funcOp = op->getParentOp();
-    Operation *mOp = funcOp->getParentOp();
     MemRefType barrierType =
         createMBarrierMemrefType(rewriter, op.getBarrier().getType());
 
     memref::GlobalOp global;
-    if (auto moduleOp = dyn_cast<gpu::GPUModuleOp>(mOp))
+    // The op may be nested arbitrarily deep (e.g. inside a gpu.launch body),
+    // so look up the closest enclosing gpu.module, falling back to the
+    // closest builtin module, rather than assuming a fixed nesting depth.
+    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
       global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
-    else if (auto moduleOp = dyn_cast<ModuleOp>(mOp))
+    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
       global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
 
     rewriter.setInsertionPoint(op);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -635,3 +635,19 @@
   func.return
 }
 
+// -----
+
+!barrierType = !nvgpu.mbarrier.barrier<memorySpace = #gpu.address_space<workgroup>>
+module @find_parent {
+  // CHECK: memref.global {{.*}}@__mbarrier : memref<1xi64, 3>
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+               threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
+      // CHECK: memref.get_global @__mbarrier : memref<1xi64, 3>
+      %barrier = nvgpu.mbarrier.create -> !barrierType
+      gpu.terminator
+    }
+    func.return
+  }
+}