diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1112,7 +1112,8 @@ } def AtomicUpdateOp : OpenMP_Op<"atomic.update", - [SingleBlockImplicitTerminator<"YieldOp">]> { + [SingleBlockImplicitTerminator<"YieldOp">, + RecursiveSideEffects]> { let summary = "performs an atomic update"; @@ -1145,7 +1146,9 @@ operations. }]; - let arguments = (ins OpenMP_PointerLikeType:$x, + let arguments = (ins Arg<OpenMP_PointerLikeType, + "Address of variable to be updated", + [MemRead, MemWrite]>:$x, DefaultValuedAttr<I64Attr, "0">:$hint_val, OptionalAttr<MemoryOrderKindAttr>:$memory_order_val); let regions = (region SizedRegion<1>:$region); @@ -1156,10 +1159,19 @@ }]; let hasVerifier = 1; let hasRegionVerifier = 1; + let hasCanonicalizeMethod = 1; let extraClassDeclaration = [{ Operation* getFirstOp() { return &getRegion().front().getOperations().front(); } + + /// Returns true if the new value is same as old value and the operation is + /// a no-op, false otherwise. + bool isNoOp(); + + /// Returns the new value if the operation is equivalent to just a write + /// operation. Otherwise, returns nullptr. 
+ Value getWriteOpVal(); }]; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -836,6 +836,34 @@ // Verifier for AtomicUpdateOp //===----------------------------------------------------------------------===// +bool AtomicUpdateOp::isNoOp() { + YieldOp yieldOp = dyn_cast<YieldOp>(getFirstOp()); + return (yieldOp && + yieldOp.results().front() == getRegion().front().getArgument(0)); +} + +Value AtomicUpdateOp::getWriteOpVal() { + YieldOp yieldOp = dyn_cast<YieldOp>(getFirstOp()); + if (yieldOp && + yieldOp.results().front() != getRegion().front().getArgument(0)) + return yieldOp.results().front(); + return nullptr; +} + +LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op, + PatternRewriter &rewriter) { + if (op.isNoOp()) { + rewriter.eraseOp(op); + return success(); + } + if (Value writeVal = op.getWriteOpVal()) { + rewriter.replaceOpWithNewOp<AtomicWriteOp>( + op, op.x(), writeVal, op.hint_valAttr(), op.memory_order_valAttr()); + return success(); + } + return failure(); +} + LogicalResult AtomicUpdateOp::verify() { if (auto mo = memory_order_val()) { if (*mo == ClauseMemoryOrderKind::Acq_rel || @@ -845,6 +873,9 @@ } } + if (region().getNumArguments() != 1) + return emitError("the region must accept exactly one argument"); + if (x().getType().cast<PointerLikeType>().getElementType() != region().getArgument(0).getType()) { return emitError("the type of the operand must be a pointer type whose " @@ -855,12 +886,6 @@ } LogicalResult AtomicUpdateOp::verifyRegions() { - if (region().getNumArguments() != 1) - return emitError("the region must accept exactly one argument"); - - if (region().front().getOperations().size() < 2) - return emitError() << "the update region must have at least two operations " - "(binop and terminator)"; YieldOp yieldOp = *region().getOps<YieldOp>().begin(); diff --git a/mlir/test/Dialect/OpenMP/canonicalize.mlir 
b/mlir/test/Dialect/OpenMP/canonicalize.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/OpenMP/canonicalize.mlir @@ -0,0 +1,74 @@ +// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s + +func.func @update_no_op(%x : memref<i32>) { + omp.atomic.update %x : memref<i32> { + ^bb0(%xval : i32): + omp.yield(%xval : i32) + } + return +} + +// CHECK-LABEL: func.func @update_no_op +// CHECK-NOT: omp.atomic.update + +// ----- + +func.func @update_write_op(%x : memref<i32>, %value: i32) { + omp.atomic.update %x : memref<i32> { + ^bb0(%xval : i32): + omp.yield(%value : i32) + } + return +} + +// CHECK-LABEL: func.func @update_write_op +// CHECK-SAME: (%[[X:.+]]: memref<i32>, %[[VALUE:.+]]: i32) +// CHECK: omp.atomic.write %[[X]] = %[[VALUE]] : memref<i32>, i32 +// CHECK-NOT: omp.atomic.update + +// ----- + +func.func @update_normal(%x : memref<i32>, %value: i32) { + omp.atomic.update %x : memref<i32> { + ^bb0(%xval : i32): + %newval = arith.addi %xval, %value : i32 + omp.yield(%newval : i32) + } + return +} + +// CHECK-LABEL: func.func @update_normal +// CHECK: omp.atomic.update +// CHECK: arith.addi +// CHECK: omp.yield + +// ----- + +func.func @update_unnecessary_computations(%x: memref<i32>) { + %c0 = arith.constant 0 : i32 + omp.atomic.update %x : memref<i32> { + ^bb0(%xval: i32): + %newval = arith.addi %xval, %c0 : i32 + omp.yield(%newval: i32) + } + return +} + +// CHECK-LABEL: func.func @update_unnecessary_computations +// CHECK-NOT: omp.atomic.update + +// ----- + +func.func @update_unnecessary_computations(%x: memref<i32>) { + %c0 = arith.constant 0 : i32 + omp.atomic.update %x : memref<i32> { + ^bb0(%xval: i32): + %newval = arith.muli %xval, %c0 : i32 + omp.yield(%newval: i32) + } + return +} + +// CHECK-LABEL: func.func @update_unnecessary_computations +// CHECK-NOT: omp.atomic.update +// CHECK: omp.atomic.write diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ 
-705,17 +705,6 @@ // ----- -func.func @omp_atomic_update9(%x: memref<i32>, %expr: i32) { - // expected-error @below {{the update region must have at least two operations (binop and terminator)}} - omp.atomic.update %x : memref<i32> { - ^bb0(%xval: i32): - omp.yield (%xval : i32) - } - return -} - -// ----- - func.func @omp_atomic_update(%x: memref<i32>, %expr: i32) { // expected-error @below {{the hints omp_sync_hint_uncontended and omp_sync_hint_contended cannot be combined}} omp.atomic.update hint(uncontended, contended) %x : memref<i32> { diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -757,6 +757,25 @@ omp.yield(%newval : i1) } + // CHECK: omp.atomic.update %[[X]] : memref<i32> { + // CHECK-NEXT: (%[[XVAL:.*]]: i32): + // CHECK-NEXT: omp.yield(%[[XVAL]] : i32) + // CHECK-NEXT: } + omp.atomic.update %x : memref<i32> { + ^bb0(%xval:i32): + omp.yield(%xval:i32) + } + + // CHECK: omp.atomic.update %[[X]] : memref<i32> { + // CHECK-NEXT: (%[[XVAL:.*]]: i32): + // CHECK-NEXT: omp.yield(%{{.+}} : i32) + // CHECK-NEXT: } + %const = arith.constant 42 : i32 + omp.atomic.update %x : memref<i32> { + ^bb0(%xval:i32): + omp.yield(%const:i32) + } + // CHECK: omp.atomic.update hint(none) %[[X]] : memref<i32> // CHECK-NEXT: (%[[XVAL:.*]]: i32): // CHECK-NEXT: %[[NEWVAL:.*]] = llvm.add %[[XVAL]], %[[EXPR]] : i32