diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -352,6 +352,21 @@ //===----------------------------------------------------------------------===// // 2.17.1 critical Construct //===----------------------------------------------------------------------===// +def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> { + let summary = "declares a named critical section."; + + let description = [{ + Declares a named critical section. + + The name can be used in critical constructs in the dialect. + }]; + + let arguments = (ins SymbolNameAttr:$sym_name); + + let assemblyFormat = "$sym_name attr-dict"; +} + + // TODO: Autogenerate this from OMP.td in llvm/include/Frontend def omp_sync_hint_none: I32EnumAttrCase<"none", 0>; def omp_sync_hint_uncontended: I32EnumAttrCase<"uncontended", 1>; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -976,6 +976,17 @@ (op.hint().getValue() != SyncHintKind::none)) return op.emitOpError() << "must specify a name unless the effect is as if " "hint(none) is specified"; + + if (op.nameAttr()) { + auto symbolRef = op.nameAttr().cast(); + auto decl = + SymbolTable::lookupNearestSymbolFrom(op, symbolRef); + if (!decl) { + return op.emitOpError() << "expected symbol reference " << symbolRef + << " to point to a critical declaration"; + } + } + return success(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -805,14 +805,18 @@ .Case([&](omp::WsLoopOp) { return 
convertOmpWsLoop(*op, builder, moduleTranslation); }) - .Case( - [](auto op) { - // `yield` and `terminator` can be just omitted. The block structure - // was created in the region that handles their parent operation. - // `reduction.declare` will be used by reductions and is not - // converted directly, skip it. - return success(); - }) + .Case([](auto op) { + // `yield` and `terminator` can be just omitted. The block structure + // was created in the region that handles their parent operation. + // `reduction.declare` will be used by reductions and is not + // converted directly, skip it. + // `critical.declare` is only used to declare names of critical + // sections which will be used by `critical` ops and hence can be + // ignored for lowering. The OpenMP IRBuilder will create unique + // names for critical sections. + return success(); + }) .Default([&](Operation *inst) { return inst->emitError("unsupported OpenMP operation: ") << inst->getName(); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -296,10 +296,20 @@ // ----- -func @omp_critical() -> () { +func @omp_critical1() -> () { // expected-error @below {{must specify a name unless the effect is as if hint(none) is specified}} omp.critical hint(nonspeculative) { omp.terminator } return } + +// ----- + +func @omp_critical2() -> () { + // expected-error @below {{expected symbol reference @excl to point to a critical declaration}} + omp.critical(@excl) hint(speculative) { + omp.terminator + } + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -369,6 +369,10 @@ return } +// CHECK: omp.critical.declare +// CHECK-LABEL: @mutex +omp.critical.declare @mutex + // CHECK-LABEL: omp_critical func @omp_critical() -> () { omp.critical { diff --git
a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -469,6 +469,8 @@ // ----- +omp.critical.declare @mutex + // CHECK-LABEL: @omp_critical llvm.func @omp_critical(%x : !llvm.ptr, %xval : i32) -> () { // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0)