diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -349,6 +349,50 @@
   let assemblyFormat = "$region attr-dict";
 }
 
+//===----------------------------------------------------------------------===//
+// 2.17.1 critical Construct
+//===----------------------------------------------------------------------===//
+// TODO: Autogenerate this from OMP.td in llvm/include/Frontend
+// The case values mirror omp_sync_hint_t from the OpenMP specification, where
+// each hint is a distinct bit (0x1, 0x2, 0x4, 0x8) and none is 0. The value is
+// forwarded unchanged to the runtime, so it must not be renumbered.
+def omp_sync_hint_none: I32EnumAttrCase<"none", 0>;
+def omp_sync_hint_uncontended: I32EnumAttrCase<"uncontended", 1>;
+def omp_sync_hint_contended: I32EnumAttrCase<"contended", 2>;
+def omp_sync_hint_nonspeculative: I32EnumAttrCase<"nonspeculative", 4>;
+def omp_sync_hint_speculative: I32EnumAttrCase<"speculative", 8>;
+
+def SyncHintKind: I32EnumAttr<"SyncHintKind", "OpenMP Sync Hint Kind",
+    [omp_sync_hint_none, omp_sync_hint_uncontended, omp_sync_hint_contended,
+     omp_sync_hint_nonspeculative, omp_sync_hint_speculative]> {
+  let cppNamespace = "::mlir::omp";
+  let stringToSymbolFnName = "ConvertToEnum";
+  let symbolToStringFnName = "ConvertToString";
+}
+
+def CriticalOp : OpenMP_Op<"critical"> {
+  let summary = "critical construct";
+  let description = [{
+    The critical construct imposes a restriction on the associated structured
+    block (region) to be executed by only a single thread at a time.
+
+    The optional `name` is a symbol that names the critical construct and the
+    optional `hint` is a `SyncHintKind`. A hint other than `none` is only
+    accepted together with a name.
+  }];
+
+  let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$name,
+                       OptionalAttr<SyncHintKind>:$hint);
+
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    (`(` $name^ `)`)? (`hint` `(` $hint^ `)`)? $region attr-dict
+  }];
+
+  let verifier = "return ::verifyCriticalOp(*this);";
+}
+
 //===----------------------------------------------------------------------===//
 // 2.17.2 barrier Construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -974,5 +974,15 @@
   return success();
 }
 
+/// Verifies that a critical op carrying a hint other than `none` also has a
+/// name; an unnamed critical op must behave as if hint(none) was specified.
+static LogicalResult verifyCriticalOp(CriticalOp op) {
+  if (!op.name().hasValue() && op.hint().hasValue() &&
+      (op.hint().getValue() != SyncHintKind::none))
+    return op.emitOpError() << "must specify a name unless the effect is as if "
+                               "hint(none) is specified";
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -204,6 +204,47 @@
   return success();
 }
 
+/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
+/// Returns failure if the translation of the critical region's body failed.
+static LogicalResult
+convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) {
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+  auto criticalOp = cast<omp::CriticalOp>(opInst);
+  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
+  // relying on captured variables.
+  LogicalResult bodyGenStatus = success();
+
+  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+                       llvm::BasicBlock &continuationBlock) {
+    // CriticalOp has only one region associated with it.
+    auto &region = criticalOp.getRegion();
+    convertOmpOpRegions(region, "omp.critical.region", *codeGenIP.getBlock(),
+                        continuationBlock, builder, moduleTranslation,
+                        bodyGenStatus);
+  };
+
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(
+      builder.saveIP(), builder.getCurrentDebugLocation());
+
+  // The hint is passed as an i32 constant holding the enum's integer value;
+  // an absent hint is lowered as 0, the value of SyncHintKind::none.
+  int hintValue = 0;
+  if (criticalOp.hint().hasValue())
+    hintValue = static_cast<int>(criticalOp.hint().getValue());
+  llvm::Constant *hint = llvm::ConstantInt::get(
+      llvm::Type::getInt32Ty(moduleTranslation.getLLVMContext()), hintValue);
+
+  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
+      ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
+  // Propagate a failure recorded by the body callback instead of dropping it.
+  return bodyGenStatus;
+}
+
 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -365,6 +406,9 @@
       .Case([&](omp::MasterOp) {
         return convertOmpMaster(*op, builder, moduleTranslation);
       })
+      .Case([&](omp::CriticalOp) {
+        return convertOmpCritical(*op, builder, moduleTranslation);
+      })
       .Case([&](omp::WsLoopOp) {
         return convertOmpWsLoop(*op, builder, moduleTranslation);
       })
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -293,3 +293,13 @@
   }
   return
 }
+
+// -----
+
+func @omp_critical() -> () {
+  // expected-error @below {{must specify a name unless the effect is as if hint(none) is specified}}
+  omp.critical hint(nonspeculative) {
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -369,3 +369,16 @@
   return
 }
 
+// CHECK-LABEL: omp_critical
+func @omp_critical() -> () {
+  // CHECK: omp.critical
+  omp.critical {
+    omp.terminator
+  }
+
+  // CHECK: omp.critical(@mutex) hint(nonspeculative)
+  omp.critical(@mutex) hint(nonspeculative) {
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -466,3 +466,28 @@
   }
   llvm.return
 }
+
+// CHECK-LABEL: @omp_critical
+llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
+  // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0)
+  // CHECK: br label %omp.critical.region
+  // CHECK: omp.critical.region
+  omp.critical {
+    // CHECK: store
+    llvm.store %xval, %x : !llvm.ptr<i32>
+    omp.terminator
+  }
+  // CHECK: call void @__kmpc_end_critical({{.*}}critical_user_.var{{.*}})
+
+  // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_mutex.var{{.*}}, i32 2)
+  // CHECK: br label %omp.critical.region
+  // CHECK: omp.critical.region
+  omp.critical(@mutex) hint(contended) {
+    // CHECK: store
+    llvm.store %xval, %x : !llvm.ptr<i32>
+    omp.terminator
+  }
+  // CHECK: call void @__kmpc_end_critical({{.*}}critical_user_mutex.var{{.*}})
+
+  llvm.return
+}