diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -362,9 +362,14 @@ The name can be used in critical constructs in the dialect. }]; - let arguments = (ins SymbolNameAttr:$sym_name); + let arguments = (ins SymbolNameAttr:$sym_name, + DefaultValuedAttr:$hint); + + let assemblyFormat = [{ + $sym_name custom($hint) attr-dict + }]; - let assemblyFormat = "$sym_name attr-dict"; + let verifier = "return verifyCriticalDeclareOp(*this);"; } @@ -375,13 +380,12 @@ block (region) to be executed by only a single thread at a time. }]; - let arguments = (ins OptionalAttr:$name, - DefaultValuedAttr:$hint); + let arguments = (ins OptionalAttr:$name); let regions = (region AnyRegion:$region); let assemblyFormat = [{ - (`(` $name^ `)`)? custom($hint) $region attr-dict + (`(` $name^ `)`)? $region attr-dict }]; let verifier = "return ::verifyCriticalOp(*this);"; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1110,14 +1110,11 @@ // Verifier for critical construct (2.17.1) //===----------------------------------------------------------------------===// -static LogicalResult verifyCriticalOp(CriticalOp op) { +static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { + return verifySynchronizationHint(op, op.hint()); +} - if (failed(verifySynchronizationHint(op, op.hint()))) { - return failure(); - } - if (!op.name().hasValue() && (op.hint() != 0)) - return op.emitOpError() << "must specify a name unless the effect is as if " - "no hint is specified"; +static LogicalResult verifyCriticalOp(CriticalOp op) { if (op.nameAttr()) { auto symbolRef = op.nameAttr().cast(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -300,8 +300,19 @@ llvm::OpenMPIRBuilder::LocationDescription ompLoc( builder.saveIP(), builder.getCurrentDebugLocation()); llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); - llvm::Constant *hint = llvm::ConstantInt::get( - llvm::Type::getInt32Ty(llvmContext), static_cast(criticalOp.hint())); + llvm::Constant *hint = nullptr; + + // If it has a name, it probably has a hint too. + if (criticalOp.nameAttr()) { + // The verifiers in OpenMP Dialect guarantee that all the pointers are + // non-null + auto symbolRef = criticalOp.nameAttr().cast(); + auto criticalDeclareOp = + SymbolTable::lookupNearestSymbolFrom(criticalOp, + symbolRef); + hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), + static_cast(criticalDeclareOp.hint())); + } builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical( ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint)); return success(); diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -296,19 +296,9 @@ // ----- -func @omp_critical1() -> () { - // expected-error @below {{must specify a name unless the effect is as if no hint is specified}} - omp.critical hint(nonspeculative) { - omp.terminator - } - return -} - -// ----- - func @omp_critical2() -> () { // expected-error @below {{expected symbol reference @excl to point to a critical declaration}} - omp.critical(@excl) hint(speculative) { + omp.critical(@excl) { omp.terminator } return @@ -316,32 +306,15 @@ // ----- -omp.critical.declare @mutex -func @omp_critical() -> () { - // expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}} - 
omp.critical(@mutex) hint(uncontended, contended) { - omp.terminator - } - return -} +// expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}} +omp.critical.declare @mutex hint(uncontended, contended) // ----- -omp.critical.declare @mutex -func @omp_critical() -> () { - // expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined}} - omp.critical(@mutex) hint(nonspeculative, speculative) { - omp.terminator - } - return -} +// expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined}} +omp.critical.declare @mutex hint(nonspeculative, speculative) // ----- -omp.critical.declare @mutex -func @omp_critica() -> () { - // expected-error @below {{invalid_hint is not a valid hint}} - omp.critical(@mutex) hint(invalid_hint) { - omp.terminator - } -} +// expected-error @below {{invalid_hint is not a valid hint}} +omp.critical.declare @mutex hint(invalid_hint) diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -369,9 +369,23 @@ return } -// CHECK: omp.critical.declare -// CHECK-LABEL: @mutex -omp.critical.declare @mutex +// CHECK: omp.critical.declare @mutex1 hint(uncontended) +omp.critical.declare @mutex1 hint(uncontended) +// CHECK: omp.critical.declare @mutex2 hint(contended) +omp.critical.declare @mutex2 hint(contended) +// CHECK: omp.critical.declare @mutex3 hint(nonspeculative) +omp.critical.declare @mutex3 hint(nonspeculative) +// CHECK: omp.critical.declare @mutex4 hint(speculative) +omp.critical.declare @mutex4 hint(speculative) +// CHECK: omp.critical.declare @mutex5 hint(uncontended, nonspeculative) +omp.critical.declare @mutex5 hint(uncontended, nonspeculative) +// CHECK: omp.critical.declare @mutex6 hint(contended, nonspeculative) +omp.critical.declare @mutex6 hint(contended, 
nonspeculative) +// CHECK: omp.critical.declare @mutex7 hint(uncontended, speculative) +omp.critical.declare @mutex7 hint(uncontended, speculative) +// CHECK: omp.critical.declare @mutex8 hint(contended, speculative) +omp.critical.declare @mutex8 hint(contended, speculative) + // CHECK-LABEL: omp_critical func @omp_critical() -> () { @@ -380,36 +394,8 @@ omp.terminator } - // CHECK: omp.critical(@{{.*}}) hint(uncontended) - omp.critical(@mutex) hint(uncontended) { - omp.terminator - } - // CHECK: omp.critical(@{{.*}}) hint(contended) - omp.critical(@mutex) hint(contended) { - omp.terminator - } - // CHECK: omp.critical(@{{.*}}) hint(nonspeculative) - omp.critical(@mutex) hint(nonspeculative) { - omp.terminator - } - // CHECK: omp.critical(@{{.*}}) hint(uncontended, nonspeculative) - omp.critical(@mutex) hint(uncontended, nonspeculative) { - omp.terminator - } - // CHECK: omp.critical(@{{.*}}) hint(contended, nonspeculative) - omp.critical(@mutex) hint(nonspeculative, contended) { - omp.terminator - } - // CHECK: omp.critical(@{{.*}}) hint(speculative) - omp.critical(@mutex) hint(speculative) { - omp.terminator - } - // CHECK: omp.critical(@{{.*}}) hint(uncontended, speculative) - omp.critical(@mutex) hint(uncontended, speculative) { - omp.terminator - } - // CHECK: omp.critical(@{{.*}}) hint(contended, speculative) - omp.critical(@mutex) hint(speculative, contended) { + // CHECK: omp.critical(@{{.*}}) + omp.critical(@mutex1) { omp.terminator } return diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -469,11 +469,11 @@ // ----- -omp.critical.declare @mutex +omp.critical.declare @mutex hint(contended) // CHECK-LABEL: @omp_critical llvm.func @omp_critical(%x : !llvm.ptr, %xval : i32) -> () { - // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0) + // CHECK: call void 
@__kmpc_critical({{.*}}critical_user_.var{{.*}}) // CHECK: br label %omp.critical.region // CHECK: omp.critical.region omp.critical { @@ -486,7 +486,7 @@ // CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_mutex.var{{.*}}, i32 2) // CHECK: br label %omp.critical.region // CHECK: omp.critical.region - omp.critical(@mutex) hint(contended) { + omp.critical(@mutex) { // CHECK: store llvm.store %xval, %x : !llvm.ptr omp.terminator