diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -269,6 +269,8 @@
     ( `attach` `(` $attachOperands^ `:` type($attachOperands) `)` )?
     attr-dict-with-keyword
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -317,6 +319,8 @@
     ( `detach` `(` $detachOperands^ `:` type($detachOperands) `)` )?
     attr-dict-with-keyword
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -515,6 +519,8 @@
     ( `device` `(` $deviceOperands^ `:` type($deviceOperands) `)` )?
     attr-dict-with-keyword
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -8,9 +8,11 @@
 
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 using namespace acc;
@@ -153,6 +155,43 @@
   return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
 }
 
+namespace {
+/// Pattern for operations without a region that carry an optional `ifCond`
+/// operand:
+///   - if `ifCond` is a constant false, the operation never executes and is
+///     erased;
+///   - if `ifCond` is a constant true, the condition is redundant and is
+///     dropped from the operation.
+template <typename OpTy>
+struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // No condition: nothing to simplify. Report failure so the greedy driver
+    // does not assume the IR changed.
+    Value ifCond = op.ifCond();
+    if (!ifCond)
+      return failure();
+
+    // Only a condition produced by a constant can be folded away.
+    auto constOp = ifCond.getDefiningOp<ConstantOp>();
+    if (!constOp)
+      return failure();
+
+    if (constOp.getValue().cast<IntegerAttr>().getInt()) {
+      // Always true: drop the operand. In-place modifications must go through
+      // the rewriter so that it is notified of the change.
+      rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
+    } else {
+      // Always false: the operation never executes.
+      rewriter.eraseOp(op);
+    }
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
@@ -694,6 +733,11 @@
   return getOperand(waitOperands().size() + numOptional + i);
 }
 
+void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                             MLIRContext *context) {
+  results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // EnterDataOp
 //===----------------------------------------------------------------------===//
@@ -736,6 +780,11 @@
   return getOperand(waitOperands().size() + numOptional + i);
 }
 
+void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // InitOp
 //===----------------------------------------------------------------------===//
@@ -802,6 +851,11 @@
                       numOptional + i);
 }
 
+void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<RemoveConstantIfCondition<UpdateOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // WaitOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -0,0 +1,92 @@
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
+
+func @testenterdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: acc.enter_data create(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testenterdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testenterdataop
+// CHECK-NOT: acc.enter_data
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: acc.exit_data delete(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testexitdataop
+// CHECK-NOT: acc.exit_data
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>) -> () {
+  %ifCond = constant true
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: acc.update host(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>) -> () {
+  %ifCond = constant false
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testupdateop
+// CHECK-NOT: acc.update
+
+// -----
+
+func @testenterdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.enter_data if(%ifCond) create(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testenterdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK: acc.enter_data if([[IFCOND]]) create(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.exit_data if(%ifCond) delete(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testexitdataop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK: acc.exit_data if([[IFCOND]]) delete(%{{.*}} : memref<10xf32>)
+
+// -----
+
+func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
+  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  return
+}
+
+// CHECK: func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK: acc.update if([[IFCOND]]) host(%{{.*}} : memref<10xf32>)