diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -784,6 +784,7 @@
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -292,6 +292,61 @@
     return success();
   }
 };
+
+/// Replaces the given op with the contents of the given single-block region,
+/// using the operands of the block terminator to replace operation results.
+static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
+                                Region &region, ValueRange blockArgs = {}) {
+  assert(llvm::hasSingleElement(region) && "expected single-block region");
+  Block *block = &region.front();
+  Operation *terminator = block->getTerminator();
+  ValueRange results = terminator->getOperands();
+  rewriter.inlineBlockBefore(block, op, blockArgs);
+  rewriter.replaceOp(op, results);
+  rewriter.eraseOp(terminator);
+}
+
+/// Pattern that folds a constant `ifCond` on an operation with a region:
+///   - constant true: the condition is redundant and is dropped from the op;
+///   - constant false: the op is removed and its region is inlined in its
+///     place (or the op is simply erased if the region is empty).
+template <typename OpTy>
+struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Early return if there is no condition.
+    Value ifCond = op.getIfCond();
+    if (!ifCond)
+      return failure();
+
+    IntegerAttr constAttr;
+    if (!matchPattern(ifCond, m_Constant(&constAttr)))
+      return failure();
+
+    if (constAttr.getInt()) {
+      rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
+      return success();
+    }
+
+    Region &region = op.getRegion();
+    // The region may be empty (e.g. `acc.host_data ... {}`); with nothing to
+    // inline and no results to replace, the op can simply be erased.
+    if (region.empty() && op->getNumResults() == 0) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+    // Inlining requires exactly one block; leave other shapes untouched
+    // instead of hitting the assertion/UB in replaceOpWithRegion.
+    if (!llvm::hasSingleElement(region))
+      return failure();
+
+    replaceOpWithRegion(rewriter, op, region);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -386,6 +441,11 @@
   return success();
 }
 
+void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
--- a/mlir/test/Dialect/OpenACC/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -105,3 +105,53 @@
 
 // CHECK: func @testupdateop(%{{.*}}: memref<f32>, [[IFCOND:%.*]]: i1)
 // CHECK: acc.update if(%{{.*}}) dataOperands(%{{.*}} : memref<f32>)
+
+// -----
+
+func.func @testhostdataop(%a: memref<f32>, %ifCond: i1) -> () {
+  %0 = acc.use_device varPtr(%a : memref<f32>) -> memref<f32>
+  %false = arith.constant false
+  acc.host_data dataOperands(%0 : memref<f32>) if(%false) {
+    acc.loop {
+      acc.yield
+    }
+    acc.loop {
+      acc.yield
+    }
+    acc.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @testhostdataop
+// CHECK-NOT: acc.host_data
+// CHECK: acc.loop
+// CHECK: acc.yield
+// CHECK: acc.loop
+// CHECK: acc.yield
+
+// -----
+
+func.func @testhostdataop_true(%a: memref<f32>, %ifCond: i1) -> () {
+  %0 = acc.use_device varPtr(%a : memref<f32>) -> memref<f32>
+  %true = arith.constant true
+  acc.host_data dataOperands(%0 : memref<f32>) if(%true) {
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @testhostdataop_true
+// CHECK: acc.host_data dataOperands(%{{.*}} : memref<f32>) {
+
+// -----
+
+func.func @testhostdataop_emptyregion(%a: memref<f32>) -> () {
+  %0 = acc.use_device varPtr(%a : memref<f32>) -> memref<f32>
+  %false = arith.constant false
+  acc.host_data dataOperands(%0 : memref<f32>) if(%false) {
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @testhostdataop_emptyregion
+// CHECK-NOT: acc.host_data