diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -367,21 +367,6 @@
 }
 
 
-// TODO: Autogenerate this from OMP.td in llvm/include/Frontend
-def omp_sync_hint_none: I32EnumAttrCase<"none", 0>;
-def omp_sync_hint_uncontended: I32EnumAttrCase<"uncontended", 1>;
-def omp_sync_hint_contended: I32EnumAttrCase<"contended", 2>;
-def omp_sync_hint_nonspeculative: I32EnumAttrCase<"nonspeculative", 3>;
-def omp_sync_hint_speculative: I32EnumAttrCase<"speculative", 4>;
-
-def SyncHintKind: I32EnumAttr<"SyncHintKind", "OpenMP Sync Hint Kind",
-  [omp_sync_hint_none, omp_sync_hint_uncontended, omp_sync_hint_contended,
-   omp_sync_hint_nonspeculative, omp_sync_hint_speculative]> {
-  let cppNamespace = "::mlir::omp";
-  let stringToSymbolFnName = "ConvertToEnum";
-  let symbolToStringFnName = "ConvertToString";
-}
-
 def CriticalOp : OpenMP_Op<"critical"> {
   let summary = "critical construct";
   let description = [{
@@ -390,12 +375,12 @@
   }];
 
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$name,
-                       OptionalAttr<SyncHintKind>:$hint);
+                       DefaultValuedAttr<I64Attr, "0">:$hint);
 
   let regions = (region AnyRegion:$region);
 
   let assemblyFormat = [{
-    (`(` $name^ `)`)? (`hint` `(` $hint^ `)`)? $region attr-dict
+    (`(` $name^ `)`)? custom<SynchronizationHint>($hint) $region attr-dict
   }];
 
   let verifier = "return ::verifyCriticalOp(*this);";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -971,11 +971,124 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Parser, printer and verifier for Synchronization Hint (2.17.12)
+//===----------------------------------------------------------------------===//
+
+/// Parses a Synchronization Hint clause. The value of hint is an integer
+/// which is a combination of different hints from `omp_sync_hint_t`.
+///
+/// hint-clause = `hint` `(` hint-value `)`
+static ParseResult parseSynchronizationHint(OpAsmParser &parser,
+                                            IntegerAttr &hintAttr) {
+  // An absent clause is equivalent to `omp_sync_hint_none` (0).
+  if (failed(parser.parseOptionalKeyword("hint"))) {
+    hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
+    return success();
+  }
+
+  if (failed(parser.parseLParen()))
+    return failure();
+  StringRef hintKeyword;
+  int64_t hint = 0;
+  do {
+    // Capture the location before consuming the keyword so that the
+    // diagnostic points at the offending hint, not at the token after it.
+    auto keywordLoc = parser.getCurrentLocation();
+    if (failed(parser.parseKeyword(&hintKeyword)))
+      return failure();
+    if (hintKeyword == "uncontended")
+      hint |= 1;
+    else if (hintKeyword == "contended")
+      hint |= 2;
+    else if (hintKeyword == "nonspeculative")
+      hint |= 4;
+    else if (hintKeyword == "speculative")
+      hint |= 8;
+    else
+      return parser.emitError(keywordLoc)
+             << hintKeyword << " is not a valid hint";
+  } while (succeeded(parser.parseOptionalComma()));
+  if (failed(parser.parseRParen()))
+    return failure();
+  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
+  return success();
+}
+
+/// Prints a Synchronization Hint clause. Nothing is printed for a hint of 0,
+/// which is also the value the parser produces when the clause is absent.
+static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
+                                     IntegerAttr hintAttr) {
+  int64_t hint = hintAttr.getInt();
+
+  if (hint == 0)
+    return;
+
+  // Helper function to get n-th bit from the right end of `value`. It takes
+  // the full 64-bit value so that the attribute is never truncated.
+  auto bitn = [](int64_t value, int n) -> bool {
+    return value & (int64_t{1} << n);
+  };
+
+  bool uncontended = bitn(hint, 0);
+  bool contended = bitn(hint, 1);
+  bool nonspeculative = bitn(hint, 2);
+  bool speculative = bitn(hint, 3);
+
+  SmallVector<StringRef> hints;
+  if (uncontended)
+    hints.push_back("uncontended");
+  if (contended)
+    hints.push_back("contended");
+  if (nonspeculative)
+    hints.push_back("nonspeculative");
+  if (speculative)
+    hints.push_back("speculative");
+
+  p << "hint(";
+  llvm::interleaveComma(hints, p);
+  p << ")";
+}
+
+/// Verifies a Synchronization Hint clause. Only the four hints understood by
+/// the parser and printer may be set, and mutually exclusive hints cannot be
+/// combined.
+static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
+  // Bits 0-3 encode uncontended, contended, nonspeculative and speculative.
+  // Any other bit could not be printed back in the custom assembly format.
+  constexpr uint64_t validHintMask = 0xF;
+  if (hint & ~validHintMask)
+    return op->emitOpError() << "invalid synchronization hint value " << hint;
+
+  // Helper function to get n-th bit from the right end of `value`
+  auto bitn = [](uint64_t value, int n) -> bool {
+    return value & (uint64_t{1} << n);
+  };
+
+  bool uncontended = bitn(hint, 0);
+  bool contended = bitn(hint, 1);
+  bool nonspeculative = bitn(hint, 2);
+  bool speculative = bitn(hint, 3);
+
+  if (uncontended && contended)
+    return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
+                                "omp_sync_hint_contended cannot be combined";
+  if (nonspeculative && speculative)
+    return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
+                                "omp_sync_hint_speculative cannot be combined";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Verifier for critical construct (2.17.1)
+//===----------------------------------------------------------------------===//
+
 static LogicalResult verifyCriticalOp(CriticalOp op) {
-  if (!op.name().hasValue() && op.hint().hasValue() &&
-      (op.hint().getValue() != SyncHintKind::none))
+  if (failed(verifySynchronizationHint(op, op.hint())))
+    return failure();
+  if (!op.name().hasValue() && op.hint() != 0)
     return op.emitOpError() << "must specify a name unless the effect is as if "
-                               "hint(none) is specified";
+                               "no hint is specified";
 
   if (op.nameAttr()) {
     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -300,14 +300,8 @@
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
       builder.saveIP(), builder.getCurrentDebugLocation());
   llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
-  llvm::Constant *hint = nullptr;
-  if (criticalOp.hint().hasValue()) {
-    hint =
-        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
-                               static_cast<int>(criticalOp.hint().getValue()));
-  } else {
-    hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
-  }
+  llvm::Constant *hint = llvm::ConstantInt::get(
+      llvm::Type::getInt32Ty(llvmContext), static_cast<int>(criticalOp.hint()));
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
       ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
   return success();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -297,7 +297,7 @@
 // -----
 
 func @omp_critical1() -> () {
-  // expected-error @below {{must specify a name unless the effect is as if hint(none) is specified}}
+  // expected-error @below {{must specify a name unless the effect is as if no hint is specified}}
   omp.critical hint(nonspeculative) {
     omp.terminator
   }
@@ -313,3 +313,47 @@
   }
   return
 }
+
+// -----
+
+omp.critical.declare @mutex
+func @omp_critical() -> () {
+  // expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}}
+  omp.critical(@mutex) hint(uncontended, contended) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+omp.critical.declare @mutex
+func @omp_critical() -> () {
+  // expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined}}
+  omp.critical(@mutex) hint(nonspeculative, speculative) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+omp.critical.declare @mutex
+func @omp_critical() -> () {
+  // expected-error @below {{invalid_hint is not a valid hint}}
+  omp.critical(@mutex) hint(invalid_hint) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+omp.critical.declare @mutex
+func @omp_critical() -> () {
+  // expected-error @below {{invalid synchronization hint value 16}}
+  "omp.critical"() ({
+    omp.terminator
+  }) {hint = 16 : i64, name = @mutex} : () -> ()
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -375,12 +375,42 @@
 
 // CHECK-LABEL: omp_critical
 func @omp_critical() -> () {
+  // CHECK: omp.critical
   omp.critical {
     omp.terminator
   }
 
+  // CHECK: omp.critical(@{{.*}}) hint(uncontended)
+  omp.critical(@mutex) hint(uncontended) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(contended)
+  omp.critical(@mutex) hint(contended) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(nonspeculative)
   omp.critical(@mutex) hint(nonspeculative) {
     omp.terminator
   }
+  // CHECK: omp.critical(@{{.*}}) hint(uncontended, nonspeculative)
+  omp.critical(@mutex) hint(uncontended, nonspeculative) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(contended, nonspeculative)
+  omp.critical(@mutex) hint(nonspeculative, contended) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(speculative)
+  omp.critical(@mutex) hint(speculative) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(uncontended, speculative)
+  omp.critical(@mutex) hint(uncontended, speculative) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint(contended, speculative)
+  omp.critical(@mutex) hint(speculative, contended) {
+    omp.terminator
+  }
   return
 }