diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -367,21 +367,6 @@
 
 }
 
-// TODO: Autogenerate this from OMP.td in llvm/include/Frontend
-def omp_sync_hint_none: I32EnumAttrCase<"none", 0>;
-def omp_sync_hint_uncontended: I32EnumAttrCase<"uncontended", 1>;
-def omp_sync_hint_contended: I32EnumAttrCase<"contended", 2>;
-def omp_sync_hint_nonspeculative: I32EnumAttrCase<"nonspeculative", 3>;
-def omp_sync_hint_speculative: I32EnumAttrCase<"speculative", 4>;
-
-def SyncHintKind: I32EnumAttr<"SyncHintKind", "OpenMP Sync Hint Kind",
-  [omp_sync_hint_none, omp_sync_hint_uncontended, omp_sync_hint_contended,
-   omp_sync_hint_nonspeculative, omp_sync_hint_speculative]> {
-  let cppNamespace = "::mlir::omp";
-  let stringToSymbolFnName = "ConvertToEnum";
-  let symbolToStringFnName = "ConvertToString";
-}
-
 def CriticalOp : OpenMP_Op<"critical"> {
   let summary = "critical construct";
   let description = [{
@@ -390,12 +375,14 @@
   }];
 
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$name,
-                       OptionalAttr<SyncHintKind>:$hint);
+                       UnitAttr:$has_hint,
+                       DefaultValuedAttr<I64Attr, "0">:$hint);
 
   let regions = (region AnyRegion:$region);
 
   let assemblyFormat = [{
-    (`(` $name^ `)`)? (`hint` `(` $hint^ `)`)? $region attr-dict
+    (`(` $name^ `)`)? (`hint` $has_hint^ custom<SynchronizationHint>($hint))?
+    $region attr-dict
   }];
 
   let verifier = "return ::verifyCriticalOp(*this);";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -971,9 +971,98 @@
   return success();
 }
 
+/// Parses the hint list of a synchronization construct, e.g.
+/// `(uncontended, speculative)`. Each keyword must be one of `none`,
+/// `uncontended`, `contended`, `nonspeculative` or `speculative`; the result
+/// is the bitwise OR of the corresponding omp_sync_hint values (0, 1, 2, 4
+/// and 8), stored in `hintAttr` as an i64 attribute.
+static ParseResult parseSynchronizationHint(OpAsmParser &parser,
+                                            IntegerAttr &hintAttr) {
+  if (failed(parser.parseLParen()))
+    return failure();
+  int64_t hint = 0;
+  do {
+    llvm::SMLoc keywordLoc = parser.getCurrentLocation();
+    StringRef hintKeyword;
+    if (failed(parser.parseKeyword(&hintKeyword)))
+      return failure();
+    // `none` contributes no bits; any other unrecognized keyword is an error
+    // instead of being silently ignored.
+    if (hintKeyword == "uncontended")
+      hint |= 1;
+    else if (hintKeyword == "contended")
+      hint |= 2;
+    else if (hintKeyword == "nonspeculative")
+      hint |= 4;
+    else if (hintKeyword == "speculative")
+      hint |= 8;
+    else if (hintKeyword != "none")
+      return parser.emitError(keywordLoc)
+             << "'" << hintKeyword << "' is not a valid synchronization hint";
+  } while (succeeded(parser.parseOptionalComma()));
+  if (failed(parser.parseRParen()))
+    return failure();
+  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
+  return success();
+}
+
+/// Prints a synchronization hint bit mask in the form accepted by
+/// parseSynchronizationHint. A zero hint is printed as `(none)` so that the
+/// printed form can always be parsed back.
+static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
+                                     IntegerAttr hintAttr) {
+  int64_t hint = hintAttr.getInt();
+
+  SmallVector<StringRef, 4> hints;
+  if (hint & 1)
+    hints.push_back("uncontended");
+  if (hint & 2)
+    hints.push_back("contended");
+  if (hint & 4)
+    hints.push_back("nonspeculative");
+  if (hint & 8)
+    hints.push_back("speculative");
+  // An empty `()` list would not parse back, so spell out the zero hint.
+  if (hints.empty())
+    hints.push_back("none");
+
+  p << "(";
+  llvm::interleaveComma(hints, p);
+  p << ")";
+}
+
+/// Verifies that `hint` is a valid combination of omp_sync_hint values: only
+/// the four defined bits may be set, and the mutually exclusive pairs
+/// uncontended/contended and nonspeculative/speculative may not be combined.
+static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
+  // Bits above the four defined hints could not be printed back.
+  if (hint > 15)
+    return op->emitOpError()
+           << "hint value " << hint << " is not a valid synchronization hint";
+
+  bool uncontended = hint & 1;
+  bool contended = hint & 2;
+  bool nonspeculative = hint & 4;
+  bool speculative = hint & 8;
+
+  if (uncontended && contended)
+    return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
+                                "omp_sync_hint_contended cannot be combined";
+  if (nonspeculative && speculative)
+    return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
+                                "omp_sync_hint_speculative cannot be combined";
+  return success();
+}
+
 static LogicalResult verifyCriticalOp(CriticalOp op) {
-  if (!op.name().hasValue() && op.hint().hasValue() &&
-      (op.hint().getValue() != SyncHintKind::none))
+  if (failed(verifySynchronizationHint(op, op.hint())))
+    return failure();
+  // The assembly format prints the hint only when `has_hint` is set, so a
+  // non-zero hint without it would be silently dropped when round-tripping.
+  if (!op.has_hint() && op.hint() != 0)
+    return op.emitOpError()
+           << "a non-zero hint requires the 'has_hint' unit attribute";
+  if (!op.name().hasValue() && op.hint() != 0)
     return op.emitOpError() << "must specify a name unless the effect is as if "
                                "hint(none) is specified";
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -300,14 +300,8 @@
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
       builder.saveIP(), builder.getCurrentDebugLocation());
   llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
-  llvm::Constant *hint = nullptr;
-  if (criticalOp.hint().hasValue()) {
-    hint =
-        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
-                               static_cast<int>(criticalOp.hint().getValue()));
-  } else {
-    hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
-  }
+  llvm::Constant *hint = llvm::ConstantInt::get(
+      llvm::Type::getInt32Ty(llvmContext), static_cast<int>(criticalOp.hint()));
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
       ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
   return success();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -313,3 +313,43 @@
   }
   return
 }
+
+// -----
+
+func @omp_critical() -> () {
+  // expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}}
+  omp.critical(@mutex) hint(uncontended, contended) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_critical() -> () {
+  // expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined}}
+  omp.critical(@mutex) hint(nonspeculative, speculative) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_critical() -> () {
+  // expected-error @below {{'unknown' is not a valid synchronization hint}}
+  omp.critical(@mutex) hint(unknown) {
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_critical() -> () {
+  // expected-error @below {{a non-zero hint requires the 'has_hint' unit attribute}}
+  "omp.critical"() ({
+    omp.terminator
+  }) {hint = 2 : i64, name = @mutex} : () -> ()
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -375,11 +375,49 @@
 
 // CHECK-LABEL: omp_critical
 func @omp_critical() -> () {
+  // CHECK: omp.critical
   omp.critical {
     omp.terminator
   }
 
-  omp.critical(@mutex) hint(nonspeculative) {
+  // CHECK: omp.critical hint (none)
+  omp.critical hint (none) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint (none)
+  omp.critical(@mutex) hint (none) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint (uncontended)
+  omp.critical(@mutex) hint (uncontended) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint (contended)
+  omp.critical(@mutex) hint (contended) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint (nonspeculative)
+  omp.critical(@mutex) hint (nonspeculative) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint (uncontended, nonspeculative)
+  omp.critical(@mutex) hint (uncontended, nonspeculative) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint (contended, nonspeculative)
+  omp.critical(@mutex) hint (nonspeculative, contended) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint (speculative)
+  omp.critical(@mutex) hint (speculative) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint (uncontended, speculative)
+  omp.critical(@mutex) hint (uncontended, speculative) {
+    omp.terminator
+  }
+  // CHECK: omp.critical(@{{.*}}) hint (contended, speculative)
+  omp.critical(@mutex) hint (speculative, contended) {
     omp.terminator
   }
   return