diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -693,7 +693,59 @@
                        OptionalAttr:$memory_order);
   let parser = [{ return parseAtomicUpdateOp(parser, result); }];
   let printer = [{ return printAtomicUpdateOp(p, *this); }];
-  let verifier = [{ return verifyAtomicUpdateOp(*this); }];
+  let verifier = [{ return verifyAtomicUpdateOrCaptureOp(*this); }];
+}
+
+def AtomicCaptureOp : OpenMP_Op<"atomic.capture"> {
+
+  let summary = "performs an atomic capture";
+
+  let description = [{
+    This operation performs an atomic capture.
+
+    An atomic capture is one of the following forms:
+      %[[X]] = %[[X]] binop %[[expr]], %[[V]] = %[[X]]
+      %[[X]] = %[[expr]] binop %[[X]], %[[V]] = %[[X]]
+      %[[V]] = %[[X]], %[[X]] = %[[X]] binop %[[expr]]
+      %[[V]] = %[[X]], %[[X]] = %[[expr]] binop %[[X]]
+      %[[V]] = %[[X]], %[[X]] = %[[expr]]
+
+    The operands `x`, `v` and `expr` are exactly the same as the operands `x`,
+    `v` and `expr` in the OpenMP Standard. The operand `x` is the address of the
+    variable that is being captured and updated. `x` is atomically
+    read/written. The evaluation of `expr` need not be atomic w.r.t. the read
+    or write of the location designated by `x`. In general, type(x) must
+    dereference to type(expr).
+
+    The attribute `binop` is the binary operation being performed atomically.
+    It is an optional attribute - it is absent only in the fifth form of this
+    operation.
+
+    The attribute `is_x_binop_expr` is
+    - true when the update expression is of the form `x binop expr` on RHS
+    - false when the update expression is of the form `expr binop x` on RHS
+
+    The attribute `is_postfix_update` is
+    - true when the capture is before update
+    - false when the capture is after update
+
+    `hint` is the value of hint (as used in the hint clause). It is a compile
+    time constant. As the name suggests, this is just a hint for optimization.
+
+    `memory_order` indicates the memory ordering behavior of the construct. It
+    can be one of `seq_cst`, `acq_rel`, `release`, `acquire` or `relaxed`.
+  }];
+  let arguments = (ins OpenMP_PointerLikeType:$x,
+                       OpenMP_PointerLikeType:$v,
+                       AnyType:$expr,
+                       OptionalAttr:$binop,
+                       UnitAttr:$is_x_binop_expr,
+                       UnitAttr:$is_postfix_update,
+                       DefaultValuedAttr:$hint,
+                       OptionalAttr:$memory_order);
+  let parser = [{ return parseAtomicCaptureOp(parser, result); }];
+  let printer = [{ return printAtomicCaptureOp(p, *this); }];
+  let verifier = [{ return verifyAtomicCaptureOp(*this); }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -53,6 +53,12 @@
   MemRefType::attachInterface>(*getContext());
 }
 
+/// Returns true if the two unresolved operands have the same name and number.
+static bool equal(const OpAsmParser::OperandType &a,
+                  const OpAsmParser::OperandType &b) {
+  return a.name == b.name && a.number == b.number;
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
@@ -1491,16 +1497,195 @@
   p << ": " << op.x().getType() << ", " << op.expr().getType();
 }
 
-/// Verifier for AtomicUpdateOp
-static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
-  if (op.memory_order()) {
-    StringRef memoryOrder = op.memory_order().getValue();
+/// Verifier for AtomicUpdateOp and AtomicCaptureOp
+static LogicalResult verifyAtomicUpdateOrCaptureOp(Operation *op) {
+  if (op->hasAttrOfType("memory_order")) {
+    StringRef memoryOrder =
+        op->getAttrOfType("memory_order").getValue();
     if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
-      return op.emitError(
-          "memory-order must not be acq_rel or acquire for atomic updates");
+      return op->emitError("memory-order must not be acq_rel or acquire for "
+                           "atomic update/capture");
   }
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtomicCaptureOp
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional binary-operation keyword of an atomic capture into
+/// `binOp` and sets `hasBinop` when a keyword was consumed. Fails with a
+/// diagnostic when the keyword does not name an AtomicBinOpKind, so callers
+/// can later unwrap `AtomicBinOpKindToEnum(binOp.upper())` safely.
+static ParseResult parseOptionalAtomicBinOp(OpAsmParser &parser,
+                                            StringRef &binOp, bool &hasBinop) {
+  llvm::SMLoc loc = parser.getCurrentLocation();
+  hasBinop = succeeded(parser.parseOptionalKeyword(&binOp));
+  if (hasBinop && !AtomicBinOpKindToEnum(binOp.upper()))
+    return parser.emitError(loc)
+           << "invalid binary operation '" << binOp << "'";
+  return success();
+}
+
+/// Parse an atomic capture op
+/// operation ::= `omp.atomic.capture` expr `,` expr clauses `:` types
+/// expr ::= operand `=` operand | operand `=` operand binop operand
+/// types ::= captured_variable_type `,` capturing_variable_type `,` expr_type
+///
+/// NOTE(review): the binop is parsed as an arbitrary keyword, so a clause
+/// keyword (`hint`, `memory_order`) directly after `%v = %x, %x = %expr` is
+/// taken for a binop and rejected; clauses are only accepted after the forms
+/// that contain a binary operation.
+static ParseResult parseAtomicCaptureOp(OpAsmParser &parser,
+                                        OperationState &result) {
+  SmallVector clauses = {hintClause, memoryOrderClause};
+  SmallVector segments;
+  // Parse expressions %a = %b [binop|,]
+  OpAsmParser::OperandType x, v, expr;
+  StringRef binOp;
+  bool isPostfixUpdate = false;
+  bool isXLhsInRhsPart = false;
+  bool hasBinop = false;
+
+  OpAsmParser::OperandType a, b;
+  if (parser.parseOperand(a) || parser.parseEqual() || parser.parseOperand(b))
+    return failure();
+
+  // The next token can either be a comma or a binary operation.
+  if (succeeded(parser.parseOptionalComma())) {
+    // %v = %x, %x = ...
+    v = a;
+    isPostfixUpdate = true;
+    OpAsmParser::OperandType lhs, rhs;
+    if (parser.parseOperand(x) || parser.parseEqual() ||
+        parser.parseOperand(lhs))
+      return failure();
+    if (!equal(x, b))
+      return parser.emitError(parser.getCurrentLocation())
+             << "captured variable not found in LHS of second operation";
+    if (parseOptionalAtomicBinOp(parser, binOp, hasBinop))
+      return failure();
+    if (hasBinop) {
+      // %v = %x, %x = %x binop %expr  or  %v = %x, %x = %expr binop %x
+      if (parser.parseOperand(rhs))
+        return failure();
+      if (equal(lhs, x)) {
+        isXLhsInRhsPart = true;
+        expr = rhs;
+      } else if (equal(rhs, x)) {
+        isXLhsInRhsPart = false;
+        expr = lhs;
+      } else {
+        return parser.emitError(parser.getCurrentLocation())
+               << "captured variable not found in RHS of second operation";
+      }
+    } else {
+      // %v = %x, %x = %expr
+      expr = lhs;
+      if (equal(x, lhs))
+        return parser.emitError(parser.getCurrentLocation())
+               << "self assignment without update to captured variable in "
+                  "second operation";
+    }
+  } else {
+    // %x = %x binop %expr, %v = %x  or  %x = %expr binop %x, %v = %x
+    if (parseOptionalAtomicBinOp(parser, binOp, hasBinop))
+      return failure();
+    if (!hasBinop)
+      return parser.emitError(parser.getCurrentLocation())
+             << "expected comma or valid binary operation";
+    x = a;
+    OpAsmParser::OperandType lhs = b, rhs;
+    if (parser.parseOperand(rhs) || parser.parseComma() ||
+        parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(b))
+      return failure();
+    if (equal(lhs, x)) {
+      isXLhsInRhsPart = true;
+      expr = rhs;
+    } else if (equal(rhs, x)) {
+      isXLhsInRhsPart = false;
+      expr = lhs;
+    } else {
+      return parser.emitError(parser.getCurrentLocation())
+             << "captured variable not found in RHS of first operation";
+    }
+    if (!equal(x, b))
+      return parser.emitError(parser.getCurrentLocation())
+             << "updated variable not captured in second operation";
+  }
+
+  if (isPostfixUpdate)
+    result.addAttribute("is_postfix_update",
+                        UnitAttr::get(parser.getContext()));
+  if (hasBinop) {
+    // `binOp` was validated by parseOptionalAtomicBinOp, so the lookup below
+    // always yields a value.
+    result.addAttribute(
+        "binop", parser.getBuilder().getI64IntegerAttr(static_cast<int64_t>(
+                     AtomicBinOpKindToEnum(binOp.upper()).getValue())));
+  }
+  if (isXLhsInRhsPart)
+    result.addAttribute("is_x_binop_expr", UnitAttr::get(parser.getContext()));
+
+  SmallVector types;
+  if (parseClauses(parser, result, clauses, segments) ||
+      parser.parseColonTypeList(types))
+    return failure();
+  if (types.size() != 3)
+    return parser.emitError(parser.getNameLoc()) << "expected three types";
+  if (parser.resolveOperand(x, types[0], result.operands) ||
+      parser.resolveOperand(v, types[1], result.operands) ||
+      parser.resolveOperand(expr, types[2], result.operands))
+    return failure();
+  return success();
+}
+
+/// Print an atomic capture op
+static void printAtomicCaptureOp(OpAsmPrinter &p, AtomicCaptureOp op) {
+  // Held as std::string: lower() returns a temporary, and binding it to a
+  // StringRef would leave the StringRef dangling.
+  std::string binop;
+  if (op.binop())
+    binop = AtomicBinOpKindToString(op.binop().getValue()).lower();
+
+  p << " ";
+  if (op.is_postfix_update()) {
+    p << op.v() << " = " << op.x() << ", " << op.x() << " = ";
+    if (op.binop()) {
+      if (op.is_x_binop_expr())
+        p << op.x() << " " << binop << " " << op.expr() << " ";
+      else
+        p << op.expr() << " " << binop << " " << op.x() << " ";
+    } else {
+      p << op.expr() << " ";
+    }
+  } else {
+    p << op.x() << " = ";
+    if (op.binop()) {
+      if (op.is_x_binop_expr())
+        p << op.x() << " " << binop << " " << op.expr();
+      else
+        p << op.expr() << " " << binop << " " << op.x();
+    }
+    p << ", " << op.v() << " = " << op.x() << " ";
+  }
+  if (op.memory_order())
+    p << "memory_order(" << op.memory_order() << ") ";
+  if (op.hintAttr())
+    printSynchronizationHint(p, op, op.hintAttr());
+  p << ": " << op.x().getType() << ", " << op.v().getType() << ", "
+    << op.expr().getType();
+}
+
+/// Verifier for AtomicCaptureOp. Only the `%v = %x, %x = %expr` form (a
+/// postfix update) may omit `binop`; without this check such an op would be
+/// printed as `%x = , %v = %x`, which the parser rejects.
+static LogicalResult verifyAtomicCaptureOp(AtomicCaptureOp op) {
+  if (!op.binop() && !op.is_postfix_update())
+    return op.emitError(
+        "binop must be present unless is_postfix_update is set");
+  return verifyAtomicUpdateOrCaptureOp(op);
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -618,7
+618,7 @@
 // -----
 
 func @omp_atomic_update3(%x: memref, %expr: i32) {
-  // expected-error @below {{memory-order must not be acq_rel or acquire for atomic updates}}
+  // expected-error @below {{memory-order must not be acq_rel or acquire for atomic update/capture}}
   omp.atomic.update %x = %x add %expr memory_order(acq_rel) : memref, i32
   return
 }
@@ -626,7 +626,7 @@
 // -----
 
 func @omp_atomic_update4(%x: memref, %expr: i32) {
-  // expected-error @below {{memory-order must not be acq_rel or acquire for atomic updates}}
+  // expected-error @below {{memory-order must not be acq_rel or acquire for atomic update/capture}}
   omp.atomic.update %x = %x add %expr memory_order(acquire) : memref, i32
   return
 }
@@ -642,6 +642,86 @@
 
 // -----
 
+func @omp_atomic_capture1(%x: memref, %expr: i32, %v: memref) {
+  // expected-error @below {{op operand #0 must be OpenMP-compatible variable type, but got 'i32'}}
+  omp.atomic.capture %expr = %expr add %x, %v = %expr : i32, memref, memref
+  return
+}
+
+// -----
+
+func @omp_atomic_capture2(%x: memref, %y: memref, %expr: i32, %v: memref) {
+  // expected-error @below {{captured variable not found in LHS of second operation}}
+  omp.atomic.capture %v = %x, %y = %y add %expr : memref, memref, i32
+  return
+}
+
+// -----
+
+func @omp_atomic_capture3(%x: memref, %y: memref, %expr: i32, %v: memref) {
+  // expected-error @below {{captured variable not found in RHS of second operation}}
+  omp.atomic.capture %v = %x, %x = %y add %expr : memref, memref, i32
+  return
+}
+
+// -----
+
+func @omp_atomic_capture4(%x: memref, %y: memref, %expr: i32, %v: memref) {
+  // expected-error @below {{self assignment without update to captured variable in second operation}}
+  omp.atomic.capture %v = %x, %x = %x : memref, memref, i32
+  return
+}
+
+// -----
+
+func @omp_atomic_capture5(%x: memref, %y: memref, %expr: i32, %v: memref) {
+  // expected-error @below {{captured variable not found in RHS of first operation}}
+  omp.atomic.capture %x = %y add %expr, %v = %x : memref, memref, i32
+  return
+}
+
+// -----
+
+func @omp_atomic_capture6(%x: memref, %expr: i32, %v: memref, %y: memref) {
+  // expected-error @below {{updated variable not captured in second operation}}
+  omp.atomic.capture %x = %x add %expr, %v = %y : memref, memref, i32
+  return
+}
+
+// -----
+
+func @omp_atomic_capture7(%x: memref, %v: memref) {
+  // expected-error @below {{expected comma or valid binary operation}}
+  omp.atomic.capture %v = %x : memref, memref, i32
+  return
+}
+
+// -----
+
+func @omp_atomic_capture8(%x: memref, %expr: i32, %v: memref) {
+  // expected-error @below {{expected three types}}
+  omp.atomic.capture %x = %x add %expr, %v = %x : memref, i32
+  return
+}
+
+// -----
+
+func @omp_atomic_capture9(%x: memref, %expr: i32, %v: memref) {
+  // expected-error @below {{memory-order must not be acq_rel or acquire for atomic update/capture}}
+  omp.atomic.capture %x = %x add %expr, %v = %x memory_order(acq_rel) : memref, memref, i32
+  return
+}
+
+// -----
+
+func @omp_atomic_capture10(%x: memref, %expr: i32, %v: memref) {
+  // expected-error @below {{memory-order must not be acq_rel or acquire for atomic update/capture}}
+  omp.atomic.capture %x = %x add %expr, %v = %x memory_order(acquire) : memref, memref, i32
+  return
+}
+
+// -----
+
 func @omp_sections(%data_var1 : memref, %data_var2 : memref, %data_var3 : memref) -> () {
   // expected-error @below {{operand used in both private and firstprivate clauses}}
   omp.sections private(%data_var1 : memref) firstprivate(%data_var1 : memref) {
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -582,6 +582,80 @@
   omp.atomic.update %xBool = %exprBool neqv %xBool : memref, i1
   // CHECK: omp.atomic.update %[[X]] = %[[EXPR]] add %[[X]] memory_order(seq_cst) hint(speculative) : memref, i32
   omp.atomic.update %x = %expr add %x hint(speculative) memory_order(seq_cst) : memref, i32
+  return
+}
+
+// CHECK-LABEL: @omp_atomic_capture
+// CHECK-SAME: (%[[X:.*]]: memref, %[[V:.*]]: memref, %[[E:.*]]: i32)
+func @omp_atomic_capture(%x : memref, %v : memref, %e : i32) {
+
+  // C/C++ expressions:
+  // v = ++x;
+  // v = --x;
+  // v = x binop= expr;
+  // v = x = x binop expr;
+  // x binop= expr; v = x;
+  // x = x binop expr; v = x;
+  // ++x; v = x;
+  // x++; v = x;
+  // --x; v = x;
+  // x--; v = x;
+  // CHECK: omp.atomic.capture %[[X]] = %[[X]] add %[[E]], %[[V]] = %[[X]] : memref, memref, i32
+  omp.atomic.capture %x = %x add %e, %v = %x : memref, memref, i32
+
+  // C/C++ expressions:
+  // v = x = expr binop x;
+  // x = expr binop x; v = x;
+  // CHECK: omp.atomic.capture %[[X]] = %[[E]] add %[[X]], %[[V]] = %[[X]] : memref, memref, i32
+  omp.atomic.capture %x = %e add %x, %v = %x : memref, memref, i32
+
+  // C/C++ expressions:
+  // v = x++;
+  // v = x--;
+  // v = x; x binop= expr;
+  // v = x; x = x binop expr;
+  // v = x; x++;
+  // v = x; ++x;
+  // v = x; x--;
+  // v = x; --x;
+  // CHECK: omp.atomic.capture %[[V]] = %[[X]], %[[X]] = %[[X]] sub %[[E]] : memref, memref, i32
+  omp.atomic.capture %v = %x, %x = %x sub %e : memref, memref, i32
+
+  // C/C++ expressions:
+  // v = x; x = expr binop x;
+  // CHECK: omp.atomic.capture %[[V]] = %[[X]], %[[X]] = %[[E]] sub %[[X]] : memref, memref, i32
+  omp.atomic.capture %v = %x, %x = %e sub %x : memref, memref, i32
+
+  // C/C++ expressions:
+  // v = x; x = expr;
+  // CHECK: omp.atomic.capture %[[V]] = %[[X]], %[[X]] = %[[E]] : memref, memref, i32
+  omp.atomic.capture %v = %x, %x = %e : memref, memref, i32
+
+  // Fortran expressions:
+  // update-statement; capture-statement
+  // CHECK: omp.atomic.capture %[[X]] = %[[X]] add %[[E]], %[[V]] = %[[X]] : memref, memref, i32
+  // CHECK: omp.atomic.capture %[[X]] = %[[E]] add %[[X]], %[[V]] = %[[X]] : memref, memref, i32
+  omp.atomic.capture %x = %x add %e, %v = %x : memref, memref, i32
+  omp.atomic.capture %x = %e add %x, %v = %x : memref, memref, i32
+
+  // Fortran expressions:
+  // capture-statement; update-statement
+  // CHECK: omp.atomic.capture %[[V]] = %[[X]], %[[X]] = %[[X]] sub %[[E]] : memref, memref, i32
+  // CHECK: omp.atomic.capture %[[V]] = %[[X]], %[[X]] = %[[E]] sub %[[X]] : memref, memref, i32
+  omp.atomic.capture %v = %x, %x = %x sub %e : memref, memref, i32
+  omp.atomic.capture %v = %x, %x = %e sub %x : memref, memref, i32
+
+  // Fortran expressions:
+  // capture-statement; write-statement
+  // CHECK: omp.atomic.capture %[[V]] = %[[X]], %[[X]] = %[[E]] : memref, memref, i32
+  omp.atomic.capture %v = %x, %x = %e : memref, memref, i32
+
+  // Clauses are accepted after the forms that contain a binary operation.
+  // CHECK: omp.atomic.capture %[[X]] = %[[X]] add %[[E]], %[[V]] = %[[X]] memory_order(seq_cst) hint(speculative) : memref, memref, i32
+  omp.atomic.capture %x = %x add %e, %v = %x hint(speculative) memory_order(seq_cst) : memref, memref, i32
+  // CHECK: omp.atomic.capture %[[V]] = %[[X]], %[[X]] = %[[X]] sub %[[E]] memory_order(release) hint(speculative) : memref, memref, i32
+  omp.atomic.capture %v = %x, %x = %x sub %e memory_order(release) hint(speculative) : memref, memref, i32
+
   return
 }
 