diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -708,6 +708,50 @@
   let verifier = [{ return verifyAtomicUpdateOp(*this); }];
 }
 
+def AtomicCaptureOp : OpenMP_Op<"atomic.capture",
+    [SingleBlockImplicitTerminator<"TerminatorOp">]> {
+  let summary = "performs an atomic capture";
+  let description = [{
+    This operation performs an atomic capture.
+
+    `hint` is the value of hint (as used in the hint clause). It is a compile
+    time constant. As the name suggests, this is just a hint for optimization.
+
+    `memory_order` indicates the memory ordering behavior of the construct. It
+    can be one of `seq_cst`, `acq_rel`, `release`, `acquire` or `relaxed`.
+
+    The region has the following allowed forms:
+
+    ```
+      omp.atomic.capture {
+        omp.atomic.update ...
+        omp.atomic.read ...
+        omp.terminator
+      }
+
+      omp.atomic.capture {
+        omp.atomic.read ...
+        omp.atomic.update ...
+        omp.terminator
+      }
+
+      omp.atomic.capture {
+        omp.atomic.read ...
+        omp.atomic.write ...
+        omp.terminator
+      }
+    ```
+
+  }];
+
+  let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$hint,
+                       OptionalAttr<MemoryOrderKind>:$memory_order);
+  let regions = (region SizedRegion<1>:$region);
+  let parser = [{ return parseAtomicCaptureOp(parser, result); }];
+  let printer = [{ return printAtomicCaptureOp(p, *this); }];
+  let verifier = [{ return verifyAtomicCaptureOp(*this); }];
+}
+
 //===----------------------------------------------------------------------===//
 // 2.19.5.7 declare reduction Directive
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1539,6 +1539,78 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtomicCaptureOp
+//===----------------------------------------------------------------------===//
+
+/// Parser for AtomicCaptureOp
+static LogicalResult parseAtomicCaptureOp(OpAsmParser &parser,
+                                          OperationState &result) {
+  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
+  SmallVector<int> segments;
+  if (parseClauses(parser, result, clauses, segments) ||
+      parser.parseRegion(*result.addRegion()))
+    return failure();
+  return success();
+}
+
+/// Printer for AtomicCaptureOp
+static void printAtomicCaptureOp(OpAsmPrinter &p, AtomicCaptureOp op) {
+  // The op name is followed directly by this custom form, so the separating
+  // space has to be printed here; otherwise a memory_order clause is glued
+  // onto the op name ("omp.atomic.capturememory_order(...)") and the output
+  // no longer parses.
+  p << " ";
+  if (op.memory_order())
+    p << "memory_order(" << op.memory_order() << ") ";
+  if (op.hintAttr())
+    printSynchronizationHint(p, op, op.hintAttr());
+  p.printRegion(op.region());
+}
+
+/// Verifier for AtomicCaptureOp. The region must contain exactly two atomic
+/// operations followed by the terminator, in one of the orders update/read,
+/// read/update or read/write, and both operations must act on the same
+/// variable.
+static LogicalResult verifyAtomicCaptureOp(AtomicCaptureOp op) {
+  Block::OpListType &ops = op.region().front().getOperations();
+  if (ops.size() != 3)
+    return emitError(op.getLoc())
+           << "expected three operations in omp.atomic.capture region (one "
+              "terminator, and two atomic ops)";
+  // SingleBlockImplicitTerminator guarantees the last op is the terminator,
+  // so the two atomic operations are the first and the second op.
+  auto &firstOp = ops.front();
+  auto &secondOp = *ops.getNextNode(firstOp);
+  auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
+  auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
+  auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
+  auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
+  auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
+
+  if (!((firstUpdateStmt && secondReadStmt) ||
+        (firstReadStmt && secondUpdateStmt) ||
+        (firstReadStmt && secondWriteStmt)))
+    return emitError(ops.front().getLoc())
+           << "invalid sequence of operations in the capture region";
+  if (firstUpdateStmt && secondReadStmt &&
+      firstUpdateStmt.x() != secondReadStmt.x())
+    return emitError(firstUpdateStmt.getLoc())
+           << "updated variable in omp.atomic.update must be captured in "
+              "second operation";
+  if (firstReadStmt && secondUpdateStmt &&
+      firstReadStmt.x() != secondUpdateStmt.x())
+    return emitError(firstReadStmt.getLoc())
+           << "captured variable in omp.atomic.read must be updated in second "
+              "operation";
+  if (firstReadStmt && secondWriteStmt &&
+      firstReadStmt.x() != secondWriteStmt.address())
+    return emitError(firstReadStmt.getLoc())
+           << "captured variable in omp.atomic.read must be updated in "
+              "second operation";
+  return success();
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -650,6 +650,125 @@
 
 // -----
 
+func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
+  // expected-error @below {{expected three operations in omp.atomic.capture region}}
+  omp.atomic.capture {
+    omp.atomic.read %v = %x : memref<i32>
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
+  omp.atomic.capture {
+    // expected-error @below {{invalid sequence of operations in the capture region}}
+    omp.atomic.read %v = %x : memref<i32>
+    omp.atomic.read %v = %x : memref<i32>
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
+  omp.atomic.capture {
+    // expected-error @below {{invalid sequence of operations in the capture region}}
+    omp.atomic.update %x = %x add %expr : memref<i32>, i32
+    omp.atomic.update %x = %x sub %expr : memref<i32>, i32
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
+  omp.atomic.capture {
+    // expected-error @below {{invalid sequence of operations in the capture region}}
+    omp.atomic.write %x = %expr : memref<i32>, i32
+    omp.atomic.write %x = %expr : memref<i32>, i32
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
+  omp.atomic.capture {
+    // expected-error @below {{invalid sequence of operations in the capture region}}
+    omp.atomic.write %x = %expr : memref<i32>, i32
+    omp.atomic.update %x = %x add %expr : memref<i32>, i32
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
+  omp.atomic.capture {
+    // expected-error @below {{invalid sequence of operations in the capture region}}
+    omp.atomic.update %x = %x add %expr : memref<i32>, i32
+    omp.atomic.write %x = %expr : memref<i32>, i32
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_atomic_capture(%x: memref<i32>, %v: memref<i32>, %expr: i32) {
+  omp.atomic.capture {
+    // expected-error @below {{invalid sequence of operations in the capture region}}
+    omp.atomic.write %x = %expr : memref<i32>, i32
+    omp.atomic.read %v = %x : memref<i32>
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>, %expr: i32) {
+  omp.atomic.capture {
+    // expected-error @below {{updated variable in omp.atomic.update must be captured in second operation}}
+    omp.atomic.update %x = %x add %expr : memref<i32>, i32
+    omp.atomic.read %v = %y : memref<i32>
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>, %expr: i32) {
+  omp.atomic.capture {
+    // expected-error @below {{captured variable in omp.atomic.read must be updated in second operation}}
+    omp.atomic.read %v = %y : memref<i32>
+    omp.atomic.update %x = %x add %expr : memref<i32>, i32
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func @omp_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>, %expr: i32) {
+  omp.atomic.capture {
+    // expected-error @below {{captured variable in omp.atomic.read must be updated in second operation}}
+    omp.atomic.read %v = %x : memref<i32>
+    omp.atomic.write %y = %expr : memref<i32>, i32
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 func @omp_sections(%data_var1 : memref<i32>, %data_var2 : memref<i32>, %data_var3 : memref<i32>) -> () {
   // expected-error @below {{operand used in both private and firstprivate clauses}}
   omp.sections private(%data_var1 : memref<i32>) firstprivate(%data_var1 : memref<i32>) {
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -584,6 +584,52 @@
   return
 }
 
+// CHECK-LABEL: omp_atomic_capture
+// CHECK-SAME: (%[[v:.*]]: memref<i32>, %[[x:.*]]: memref<i32>, %[[expr:.*]]: i32)
+func @omp_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
+  // CHECK: omp.atomic.capture {
+  // CHECK-NEXT: omp.atomic.update %[[x]] = %[[expr]] add %[[x]] : memref<i32>, i32
+  // CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
+  // CHECK-NEXT: omp.terminator
+  // CHECK-NEXT: }
+  omp.atomic.capture {
+    omp.atomic.update %x = %expr add %x : memref<i32>, i32
+    omp.atomic.read %v = %x : memref<i32>
+    omp.terminator
+  }
+  // CHECK: omp.atomic.capture {
+  // CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
+  // CHECK-NEXT: omp.atomic.update %[[x]] = %[[expr]] add %[[x]] : memref<i32>, i32
+  // CHECK-NEXT: omp.terminator
+  // CHECK-NEXT: }
+  omp.atomic.capture {
+    omp.atomic.read %v = %x : memref<i32>
+    omp.atomic.update %x = %expr add %x : memref<i32>, i32
+    omp.terminator
+  }
+  // CHECK: omp.atomic.capture {
+  // CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
+  // CHECK-NEXT: omp.atomic.write %[[x]] = %[[expr]] : memref<i32>, i32
+  // CHECK-NEXT: omp.terminator
+  // CHECK-NEXT: }
+  omp.atomic.capture {
+    omp.atomic.read %v = %x : memref<i32>
+    omp.atomic.write %x = %expr : memref<i32>, i32
+    omp.terminator
+  }
+  // CHECK: omp.atomic.capture memory_order(seq_cst) {
+  // CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
+  // CHECK-NEXT: omp.atomic.write %[[x]] = %[[expr]] : memref<i32>, i32
+  // CHECK-NEXT: omp.terminator
+  // CHECK-NEXT: }
+  omp.atomic.capture memory_order(seq_cst) {
+    omp.atomic.read %v = %x : memref<i32>
+    omp.atomic.write %x = %expr : memref<i32>, i32
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: omp_sectionsop
 func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
                      %data_var3 : memref<i32>, %redn_var : !llvm.ptr<f32>) {