diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -367,21 +367,6 @@
 
 }
 
-// TODO: Autogenerate this from OMP.td in llvm/include/Frontend
-def omp_sync_hint_none: I32EnumAttrCase<"none", 0>;
-def omp_sync_hint_uncontended: I32EnumAttrCase<"uncontended", 1>;
-def omp_sync_hint_contended: I32EnumAttrCase<"contended", 2>;
-def omp_sync_hint_nonspeculative: I32EnumAttrCase<"nonspeculative", 3>;
-def omp_sync_hint_speculative: I32EnumAttrCase<"speculative", 4>;
-
-def SyncHintKind: I32EnumAttr<"SyncHintKind", "OpenMP Sync Hint Kind",
-  [omp_sync_hint_none, omp_sync_hint_uncontended, omp_sync_hint_contended,
-   omp_sync_hint_nonspeculative, omp_sync_hint_speculative]> {
-  let cppNamespace = "::mlir::omp";
-  let stringToSymbolFnName = "ConvertToEnum";
-  let symbolToStringFnName = "ConvertToString";
-}
-
 def CriticalOp : OpenMP_Op<"critical"> {
   let summary = "critical construct";
   let description = [{
@@ -390,12 +375,13 @@
   }];
 
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$name,
-                       OptionalAttr<SyncHintKind>:$hint);
+                       UnitAttr:$has_hint,
+                       DefaultValuedAttr<I32Attr, "0">:$hint);
 
   let regions = (region AnyRegion:$region);
 
   let assemblyFormat = [{
-    (`(` $name^ `)`)? (`hint` `(` $hint^ `)`)? $region attr-dict
+    (`(` $name^ `)`)? ( `hint` $has_hint^ `(` $hint `)` )? $region attr-dict
   }];
   let verifier = "return ::verifyCriticalOp(*this);";
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -971,9 +971,38 @@
   return success();
 }
 
+/// Verifies that the synchronization hint `hint` does not combine mutually
+/// exclusive hints. `hint` is interpreted as a bitmask where bit 0 is
+/// omp_sync_hint_uncontended, bit 1 omp_sync_hint_contended, bit 2
+/// omp_sync_hint_nonspeculative and bit 3 omp_sync_hint_speculative.
+static LogicalResult verifySynchronizationHint(Operation *op, int32_t hint) {
+  auto bitn = [](int n, int idx) -> bool { return n & (1 << idx); };
+
+  bool uncontended = bitn(hint, 0);
+  bool contended = bitn(hint, 1);
+  bool nonspeculative = bitn(hint, 2);
+  bool speculative = bitn(hint, 3);
+
+  if (uncontended && contended)
+    return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
+                                "omp_sync_hint_contended cannot be combined";
+  if (nonspeculative && speculative)
+    return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
+                                "omp_sync_hint_speculative cannot be combined";
+  return success();
+}
+
 static LogicalResult verifyCriticalOp(CriticalOp op) {
-  if (!op.name().hasValue() && op.hint().hasValue() &&
-      (op.hint().getValue() != SyncHintKind::none))
+  if (failed(verifySynchronizationHint(op, op.hint())))
+    return failure();
+  // `hint` is only printed inside the group anchored on `has_hint` (and is
+  // always elided from attr-dict), so a non-zero hint without it would be lost.
+  if (!op.has_hint() && op.hint() != 0)
+    return op.emitOpError()
+           << "'has_hint' must be set when a non-zero hint is specified";
+  // Only hint(none), i.e. a hint of 0, is permitted on an unnamed critical
+  // section.
+  if (!op.name().hasValue() && (op.hint() != 0))
     return op.emitOpError() << "must specify a name unless the effect is as if "
                                "hint(none) is specified";
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -300,14 +300,8 @@
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
       builder.saveIP(), builder.getCurrentDebugLocation());
   llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
-  llvm::Constant *hint = nullptr;
-  if (criticalOp.hint().hasValue()) {
-    hint =
-        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
-                               static_cast<int>(criticalOp.hint().getValue()));
-  } else {
-    hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
-  }
+  llvm::Constant *hint = llvm::ConstantInt::get(
+      llvm::Type::getInt32Ty(llvmContext), static_cast<int>(criticalOp.hint()));
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
       ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
   return success();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -298,7 +298,7 @@
 
 func @omp_critical1() -> () {
   // expected-error @below {{must specify a name unless the effect is as if hint(none) is specified}}
-  omp.critical hint(nonspeculative) {
+  omp.critical hint(4) {
     omp.terminator
   }
   return
@@ -308,7 +308,37 @@
 
 func @omp_critical2() -> () {
   // expected-error @below {{expected symbol reference @excl to point to a critical declaration}}
-  omp.critical(@excl) hint(speculative) {
+  omp.critical(@excl) hint(8) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_critical3() -> () {
+  // expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}}
+  omp.critical hint(3) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_critical4() -> () {
+  // expected-error @below {{'has_hint' must be set when a non-zero hint is specified}}
+  omp.critical {
+    omp.terminator
+  } {hint = 2 : i32}
+  return
+}
+
+// -----
+
+func @omp_critical5() -> () {
+  // expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined}}
+  omp.critical hint(12) {
     omp.terminator
   }
   return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -375,11 +375,41 @@
 
 // CHECK-LABEL: omp_critical
 func @omp_critical() -> () {
+  // CHECK: omp.critical
   omp.critical {
     omp.terminator
   }
 
-  omp.critical(@mutex) hint(nonspeculative) {
+  // CHECK: omp.critical(@{{.*}}) hint(1)
+  omp.critical(@mutex) hint(1) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(2)
+  omp.critical(@mutex) hint(2) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(4)
+  omp.critical(@mutex) hint(4) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(5)
+  omp.critical(@mutex) hint(5) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(6)
+  omp.critical(@mutex) hint(6) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(8)
+  omp.critical(@mutex) hint(8) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(9)
+  omp.critical(@mutex) hint(9) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(10)
+  omp.critical(@mutex) hint(10) {
     omp.terminator
   }
   return
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -486,7 +486,7 @@
   // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_mutex.var{{.*}}, i32 2)
   // CHECK: br label %omp.critical.region
   // CHECK: omp.critical.region
-  omp.critical(@mutex) hint(contended) {
+  omp.critical(@mutex) hint(2) {
   // CHECK: store
     llvm.store %xval, %x : !llvm.ptr<i32>
     omp.terminator