diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -144,6 +144,26 @@
   ];
 }
 
+def OMP_MEMORY_ORDER_SeqCst : ClauseVal<"seq_cst", 1, 1> {}
+def OMP_MEMORY_ORDER_AcqRel : ClauseVal<"acq_rel", 2, 1> {}
+def OMP_MEMORY_ORDER_Acquire : ClauseVal<"acquire", 3, 1> {}
+def OMP_MEMORY_ORDER_Release : ClauseVal<"release", 4, 1> {}
+def OMP_MEMORY_ORDER_Relaxed : ClauseVal<"relaxed", 5, 1> {}
+def OMP_MEMORY_ORDER_Default : ClauseVal<"default", 6, 0> {
+  let isDefault = 1;
+}
+def OMPC_MemoryOrder : Clause<"memory_order"> {
+  let enumClauseValue = "MemoryOrderKind";
+  let allowedClauseValues = [
+    OMP_MEMORY_ORDER_SeqCst,
+    OMP_MEMORY_ORDER_AcqRel,
+    OMP_MEMORY_ORDER_Acquire,
+    OMP_MEMORY_ORDER_Release,
+    OMP_MEMORY_ORDER_Relaxed,
+    OMP_MEMORY_ORDER_Default
+  ];
+}
+
 def OMPC_Ordered : Clause<"ordered"> {
   let clangClass = "OMPOrderedClause";
   let flangClass = "ScalarIntConstantExpr";
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -415,6 +415,37 @@
   let assemblyFormat = "attr-dict";
 }
 
+//===----------------------------------------------------------------------===//
+// 2.17.7 atomic construct
+//===----------------------------------------------------------------------===//
+
+// In the OpenMP Specification, the atomic construct takes an `atomic-clause`
+// whose value is one of `read`, `write`, `update` or `capture`. These four
+// kinds of atomic constructs are fundamentally independent and are lowered
+// separately, so each clause value is modeled as its own operation here.
+// Handling the construct thus takes two steps: select the operation from the
+// clause value, then process that operation independently of the others.
+
+def AtomicReadOp : OpenMP_Op<"atomic.read"> {
+  let arguments = (ins OpenMP_PointerLikeType:$address,
+                       DefaultValuedAttr:$hint,
+                       OptionalAttr:$memory_order);
+  let results = (outs AnyType);
+  let parser = [{ return parseAtomicReadOp(parser, result); }];
+  let printer = [{ return printAtomicReadOp(p, *this); }];
+  let verifier = [{ return verifyAtomicReadOp(*this); }];
+}
+
+def AtomicWriteOp : OpenMP_Op<"atomic.write"> {
+  let arguments = (ins OpenMP_PointerLikeType:$address,
+                       AnyType:$value,
+                       DefaultValuedAttr:$hint,
+                       OptionalAttr:$memory_order);
+  let parser = [{ return parseAtomicWriteOp(parser, result); }];
+  let printer = [{ return printAtomicWriteOp(p, *this); }];
+  let verifier = [{ return verifyAtomicWriteOp(*this); }];
+}
+
 //===----------------------------------------------------------------------===//
 // 2.19.5.7 declare reduction Directive
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -370,6 +370,98 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Parser, printer and verifier for Synchronization Hint (2.17.12)
+//===----------------------------------------------------------------------===//
+
+/// Parses a Synchronization Hint clause. The hint value is an integer formed
+/// by OR-ing together flags from `omp_sync_hint_t`.
+/// `parseKeyword` = false means the caller already consumed the `hint` token.
+/// hint-clause = `hint` `(` hint-value (`,` hint-value)* `)`
+static ParseResult parseSynchronizationHint(OpAsmParser &parser,
+                                            IntegerAttr &hintAttr,
+                                            bool parseKeyword = true) {
+  if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
+    hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
+    return success();
+  }
+  // Parse the parenthesized, comma-separated list of hint keywords.
+  if (failed(parser.parseLParen()))
+    return failure();
+  StringRef hintKeyword;
+  int64_t hint = 0;
+  do {
+    if (failed(parser.parseKeyword(&hintKeyword)))
+      return failure();
+    if (hintKeyword == "uncontended")
+      hint |= 1;
+    else if (hintKeyword == "contended")
+      hint |= 2;
+    else if (hintKeyword == "nonspeculative")
+      hint |= 4;
+    else if (hintKeyword == "speculative")
+      hint |= 8;
+    else
+      return parser.emitError(parser.getCurrentLocation())
+             << hintKeyword << " is not a valid hint";
+  } while (succeeded(parser.parseOptionalComma()));
+  if (failed(parser.parseRParen()))
+    return failure();
+  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
+  return success();
+}
+
+/// Prints `hint(...) ` (with a trailing space), or nothing when hint is 0.
+static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
+                                     IntegerAttr hintAttr) {
+  int64_t hint = hintAttr.getInt();
+
+  if (hint == 0)
+    return;
+
+  // Helper function to get n-th bit from the right end of `value`
+  auto bitn = [](int64_t value, int n) -> bool { return (value >> n) & 1; };
+
+  bool uncontended = bitn(hint, 0);
+  bool contended = bitn(hint, 1);
+  bool nonspeculative = bitn(hint, 2);
+  bool speculative = bitn(hint, 3);
+
+  SmallVector<StringRef> hints;
+  if (uncontended)
+    hints.push_back("uncontended");
+  if (contended)
+    hints.push_back("contended");
+  if (nonspeculative)
+    hints.push_back("nonspeculative");
+  if (speculative)
+    hints.push_back("speculative");
+
+  p << "hint(";
+  llvm::interleaveComma(hints, p);
+  p << ") ";
+}
+
+/// Verifies that `hint` does not combine mutually exclusive sync hints.
+static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
+  // Only bits 0-3 (the four omp_sync_hint_* flags) are inspected here.
+  // Helper function to get n-th bit from the right end of `value`
+  auto bitn = [](uint64_t value, int n) -> bool { return (value >> n) & 1; };
+
+  bool uncontended = bitn(hint, 0);
+  bool contended = bitn(hint, 1);
+  bool nonspeculative = bitn(hint, 2);
+  bool speculative = bitn(hint, 3);
+
+  if (uncontended && contended)
+    return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
+                                "omp_sync_hint_contended cannot be combined";
+  if (nonspeculative && speculative)
+    return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
+                                "omp_sync_hint_speculative cannot be combined.";
+  return success();
+}
+
 enum ClauseType {
   ifClause,
   numThreadsClause,
@@ -389,6 +481,8 @@
   orderClause,
   orderedClause,
   inclusiveClause,
+  memoryOrderClause,
+  hintClause,
   COUNT
 };
 
@@ -616,6 +710,19 @@
         return failure();
       auto attr = UnitAttr::get(parser.getBuilder().getContext());
       result.addAttribute("inclusive", attr);
+    } else if (clauseKeyword == "memory_order") {
+      StringRef memoryOrder;
+      if (checkAllowed(memoryOrderClause) || parser.parseLParen() ||
+          parser.parseKeyword(&memoryOrder) || parser.parseRParen())
+        return failure();
+      result.addAttribute("memory_order",
+                          parser.getBuilder().getStringAttr(memoryOrder));
+    } else if (clauseKeyword == "hint") {
+      IntegerAttr hint;
+      if (checkAllowed(hintClause) ||
+          parseSynchronizationHint(parser, hint, false))
+        return failure();
+      result.addAttribute("hint", hint);
     } else {
       return parser.emitError(parser.getNameLoc())
              << clauseKeyword << " is not a valid clause";
@@ -1015,97 +1122,6 @@
   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
 }
 
-//===----------------------------------------------------------------------===//
-// Parser, printer and verifier for Synchronization Hint (2.17.12)
-//===----------------------------------------------------------------------===//
-
-/// Parses a Synchronization Hint clause. The value of hint is an integer
-/// which is a combination of different hints from `omp_sync_hint_t`.
-///
-/// hint-clause = `hint` `(` hint-value `)`
-static ParseResult parseSynchronizationHint(OpAsmParser &parser,
-                                            IntegerAttr &hintAttr) {
-  if (failed(parser.parseOptionalKeyword("hint"))) {
-    hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
-    return success();
-  }
-
-  if (failed(parser.parseLParen()))
-    return failure();
-  StringRef hintKeyword;
-  int64_t hint = 0;
-  do {
-    if (failed(parser.parseKeyword(&hintKeyword)))
-      return failure();
-    if (hintKeyword == "uncontended")
-      hint |= 1;
-    else if (hintKeyword == "contended")
-      hint |= 2;
-    else if (hintKeyword == "nonspeculative")
-      hint |= 4;
-    else if (hintKeyword == "speculative")
-      hint |= 8;
-    else
-      return parser.emitError(parser.getCurrentLocation())
-             << hintKeyword << " is not a valid hint";
-  } while (succeeded(parser.parseOptionalComma()));
-  if (failed(parser.parseRParen()))
-    return failure();
-  hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
-  return success();
-}
-
-/// Prints a Synchronization Hint clause
-static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
-                                     IntegerAttr hintAttr) {
-  int64_t hint = hintAttr.getInt();
-
-  if (hint == 0)
-    return;
-
-  // Helper function to get n-th bit from the right end of `value`
-  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
-
-  bool uncontended = bitn(hint, 0);
-  bool contended = bitn(hint, 1);
-  bool nonspeculative = bitn(hint, 2);
-  bool speculative = bitn(hint, 3);
-
-  SmallVector<StringRef> hints;
-  if (uncontended)
-    hints.push_back("uncontended");
-  if (contended)
-    hints.push_back("contended");
-  if (nonspeculative)
-    hints.push_back("nonspeculative");
-  if (speculative)
-    hints.push_back("speculative");
-
-  p << "hint(";
-  llvm::interleaveComma(hints, p);
-  p << ")";
-}
-
-/// Verifies a synchronization hint clause
-static LogicalResult verifySynchronizationHint(Operation *op, int32_t hint) {
-
-  // Helper function to get n-th bit from the right end of `value`
-  auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
-
-  bool uncontended = bitn(hint, 0);
-  bool contended = bitn(hint, 1);
-  bool nonspeculative = bitn(hint, 2);
-  bool speculative = bitn(hint, 3);
-
-  if (uncontended && contended)
-    return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
-                                "omp_sync_hint_contended cannot be combined";
-  if (nonspeculative && speculative)
-    return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
-                                "omp_sync_hint_speculative cannot be combined.";
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Verifier for critical construct (2.17.1)
 //===----------------------------------------------------------------------===//
@@ -1132,5 +1148,103 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtomicReadOp
+//===----------------------------------------------------------------------===//
+
+/// Parser for AtomicReadOp
+///
+/// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
+/// address ::= operand `:` type
+static ParseResult parseAtomicReadOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  OpAsmParser::OperandType address;
+  Type addressType;
+  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
+  SmallVector<int> segments;
+
+  if (parseClauses(parser, result, clauses, segments) ||
+      parser.parseOperand(address) || parser.parseColonType(addressType) ||
+      parser.resolveOperand(address, addressType, result.operands))
+    return failure();
+
+  Type resultType;
+  if (parser.parseArrow() || parser.parseType(resultType))
+    return failure();
+  result.addTypes(resultType);
+  return success();
+}
+
+/// Printer for AtomicReadOp (inverse of parseAtomicReadOp).
+static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
+  p << " ";
+  if (op.memory_order())
+    p << "memory_order(" << op.memory_order().getValue() << ") ";
+  if (op.hintAttr())
+    printSynchronizationHint(p, op, op.hintAttr());
+  p << op.address() << " : " << op.address().getType();
+  p << " -> " << op.getType();
+}
+
+/// Verifier for AtomicReadOp: rejects acq_rel/release and conflicting hints.
+static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
+  if (op.memory_order()) {
+    StringRef memOrder = op.memory_order().getValue();
+    if (memOrder.equals("acq_rel") || memOrder.equals("release"))
+      return op.emitError(
+          "memory-order must not be acq_rel or release for atomic reads");
+  }
+  return verifySynchronizationHint(op, op.hint());
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicWriteOp
+//===----------------------------------------------------------------------===//
+
+/// Parser for AtomicWriteOp
+///
+/// operation ::= `omp.atomic.write` atomic-clause-list operands
+/// operands ::= address `,` value
+/// address ::= operand `:` type
+/// value ::= operand `:` type
+static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
+                                      OperationState &result) {
+  OpAsmParser::OperandType address, value;
+  Type addrType, valueType;
+  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
+  SmallVector<int> segments;
+
+  if (parseClauses(parser, result, clauses, segments) ||
+      parser.parseOperand(address) || parser.parseColonType(addrType) ||
+      parser.resolveOperand(address, addrType, result.operands) ||
+      parser.parseComma() || parser.parseOperand(value) ||
+      parser.parseColonType(valueType) ||
+      parser.resolveOperand(value, valueType, result.operands))
+    return failure();
+  return success();
+}
+
+/// Printer for AtomicWriteOp (inverse of parseAtomicWriteOp).
+static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
+  p << " ";
+  if (op.memory_order())
+    p << "memory_order(" << op.memory_order().getValue() << ") ";
+  if (op.hintAttr())
+    printSynchronizationHint(p, op, op.hintAttr());
+  p << op.address() << " : " << op.address().getType() << ", " << op.value()
+    << " : " << op.value().getType();
+}
+
+/// Verifier for AtomicWriteOp: rejects acq_rel/acquire and conflicting hints.
+static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
+  if (op.memory_order()) {
+    StringRef memoryOrder = op.memory_order().getValue();
+    if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
+      return op.emitError(
+          "memory-order must not be acq_rel or acquire for atomic writes");
+  }
+  return verifySynchronizationHint(op, op.hint());
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -345,3 +345,99 @@
     omp.terminator
   }
 }
+
+// -----
+
+func @omp_atomic_read1(%addr : memref<i32>) {
+  // expected-error @below {{the hints omp_sync_hint_nonspeculative and omp_sync_hint_speculative cannot be combined.}}
+  %1 = omp.atomic.read hint(speculative, nonspeculative) %addr : memref<i32> -> i32
+  return
+}
+
+// -----
+
+func @omp_atomic_read2(%addr : memref<i32>) {
+  // expected-error @below {{attribute 'memory_order' failed to satisfy constraint: MemoryOrderKind Clause}}
+  %1 = omp.atomic.read memory_order(xyz) %addr : memref<i32> -> i32
+  return
+}
+
+// -----
+
+func @omp_atomic_read3(%addr : memref<i32>) {
+  // expected-error @below {{memory-order must not be acq_rel or release for atomic reads}}
+  %1 = omp.atomic.read memory_order(acq_rel) %addr : memref<i32> -> i32
+  return
+}
+
+// -----
+
+func @omp_atomic_read4(%addr : memref<i32>) {
+  // expected-error @below {{memory-order must not be acq_rel or release for atomic reads}}
+  %1 = omp.atomic.read memory_order(release) %addr : memref<i32> -> i32
+  return
+}
+
+// -----
+
+func @omp_atomic_read5(%addr : memref<i32>) {
+  // expected-error @below {{at most one memory_order clause can appear on the omp.atomic.read operation}}
+  %1 = omp.atomic.read memory_order(acquire) memory_order(relaxed) %addr : memref<i32> -> i32
+  return
+}
+
+// -----
+
+func @omp_atomic_read6(%addr : memref<i32>) {
+  // expected-error @below {{at most one hint clause can appear on the omp.atomic.read operation}}
+  %1 = omp.atomic.read hint(speculative) hint(contended) %addr : memref<i32> -> i32
+  return
+}
+
+// -----
+
+func @omp_atomic_write1(%addr : memref<i32>, %val : i32) {
+  // expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}}
+  omp.atomic.write hint(contended, uncontended) %addr : memref<i32>, %val : i32
+  return
+}
+
+// -----
+
+func @omp_atomic_write2(%addr : memref<i32>, %val : i32) {
+  // expected-error @below {{memory-order must not be acq_rel or acquire for atomic writes}}
+  omp.atomic.write memory_order(acq_rel) %addr : memref<i32>, %val : i32
+  return
+}
+
+// -----
+
+func @omp_atomic_write3(%addr : memref<i32>, %val : i32) {
+  // expected-error @below {{memory-order must not be acq_rel or acquire for atomic writes}}
+  omp.atomic.write memory_order(acquire) %addr : memref<i32>, %val : i32
+  return
+}
+
+// -----
+
+func @omp_atomic_write4(%addr : memref<i32>, %val : i32) {
+  // expected-error @below {{at most one memory_order clause can appear on the omp.atomic.write operation}}
+  omp.atomic.write memory_order(release) memory_order(seq_cst) %addr : memref<i32>, %val : i32
+  return
+}
+
+// -----
+
+func @omp_atomic_write5(%addr : memref<i32>, %val : i32) {
+  // expected-error @below {{at most one hint clause can appear on the omp.atomic.write operation}}
+  omp.atomic.write hint(contended) hint(speculative) %addr : memref<i32>, %val : i32
+  return
+}
+
+// -----
+
+func @omp_atomic_write6(%addr : memref<i32>, %val : i32) {
+  // expected-error @below {{attribute 'memory_order' failed to satisfy constraint: MemoryOrderKind Clause}}
+  omp.atomic.write memory_order(xyz) %addr : memref<i32>, %val : i32
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -414,3 +414,35 @@
   }
   return
 }
+
+// CHECK-LABEL: omp_atomic_read
+func @omp_atomic_read(%addr : memref<i32>) {
+  // CHECK: %{{.*}} = omp.atomic.read %{{.*}} : memref<i32> -> i32
+  %1 = omp.atomic.read %addr : memref<i32> -> i32
+  // CHECK: %{{.*}} = omp.atomic.read memory_order(seq_cst) %{{.*}} : memref<i32> -> i32
+  %2 = omp.atomic.read memory_order(seq_cst) %addr : memref<i32> -> i32
+  // CHECK: %{{.*}} = omp.atomic.read memory_order(acquire) %{{.*}} : memref<i32> -> i32
+  %5 = omp.atomic.read memory_order(acquire) %addr : memref<i32> -> i32
+  // CHECK: %{{.*}} = omp.atomic.read memory_order(relaxed) %{{.*}} : memref<i32> -> i32
+  %6 = omp.atomic.read memory_order(relaxed) %addr : memref<i32> -> i32
+  // CHECK: %{{.*}} = omp.atomic.read hint(contended, nonspeculative) %{{.*}} : memref<i32> -> i32
+  %7 = omp.atomic.read hint(nonspeculative, contended) %addr : memref<i32> -> i32
+  // CHECK: %{{.*}} = omp.atomic.read memory_order(seq_cst) hint(contended, speculative) %{{.*}} : memref<i32> -> i32
+  %8 = omp.atomic.read hint(speculative, contended) memory_order(seq_cst) %addr : memref<i32> -> i32
+  return
+}
+
+// CHECK-LABEL: omp_atomic_write
+func @omp_atomic_write(%addr : memref<i32>, %val : i32) {
+  // CHECK: omp.atomic.write %{{.*}} : memref<i32>, %{{.*}} : i32
+  omp.atomic.write %addr : memref<i32>, %val : i32
+  // CHECK: omp.atomic.write memory_order(seq_cst) %{{.*}} : memref<i32>, %{{.*}} : i32
+  omp.atomic.write memory_order(seq_cst) %addr : memref<i32>, %val : i32
+  // CHECK: omp.atomic.write memory_order(release) %{{.*}} : memref<i32>, %{{.*}} : i32
+  omp.atomic.write memory_order(release) %addr : memref<i32>, %val : i32
+  // CHECK: omp.atomic.write memory_order(relaxed) %{{.*}} : memref<i32>, %{{.*}} : i32
+  omp.atomic.write memory_order(relaxed) %addr : memref<i32>, %val : i32
+  // CHECK: omp.atomic.write hint(uncontended, speculative) %{{.*}} : memref<i32>, %{{.*}} : i32
+  omp.atomic.write hint(speculative, uncontended) %addr : memref<i32>, %val : i32
+  return
+}