diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -111,7 +111,7 @@
   // TODO: If the parent is a func like op (which would be the case if all other
   // ops are from the std dialect), the inliner logic could be readily used to
   // inline.
-  let hasCanonicalizer = 0;
+  let hasCanonicalizer = 1;
 
   // TODO: can fold if it returns a constant.
   // TODO: Single block execute_region ops can be readily inlined irrespective
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -73,6 +73,19 @@
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
 
+/// Replaces the given op with the contents of the given single-block region,
+/// using the operands of the block terminator to replace operation results.
+static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
+                                Region &region, ValueRange blockArgs = {}) {
+  assert(llvm::hasSingleElement(region) && "expected single-block region");
+  Block *block = &region.front();
+  Operation *terminator = block->getTerminator();
+  ValueRange results = terminator->getOperands();
+  rewriter.mergeBlockBefore(block, op, blockArgs);
+  rewriter.replaceOp(op, results);
+  rewriter.eraseOp(terminator);
+}
+
 ///
 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
 ///    block+
@@ -118,6 +131,37 @@
   return success();
 }
 
+// Inline an ExecuteRegionOp if its region consists of a single block.
+//     "test.foo"() : () -> ()
+//      %v = scf.execute_region -> i64 {
+//        %x = "test.val"() : () -> i64
+//        scf.yield %x : i64
+//      }
+//      "test.bar"(%v) : (i64) -> ()
+//
+//  becomes
+//
+//     "test.foo"() : () -> ()
+//     %x = "test.val"() : () -> i64
+//     "test.bar"(%x) : (i64) -> ()
+//
+struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::hasSingleElement(op.region()))
+      return failure();
+    replaceOpWithRegion(rewriter, op, op.region());
+    return success();
+  }
+};
+
+void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<SingleBlockExecuteInliner>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
@@ -444,19 +488,6 @@
   });
 }
 
-/// Replaces the given op with the contents of the given single-block region,
-/// using the operands of the block terminator to replace operation results.
-static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
-                                Region &region, ValueRange blockArgs = {}) {
-  assert(llvm::hasSingleElement(region) && "expected single-region block");
-  Block *block = &region.front();
-  Operation *terminator = block->getTerminator();
-  ValueRange results = terminator->getOperands();
-  rewriter.mergeBlockBefore(block, op, blockArgs);
-  rewriter.replaceOp(op, results);
-  rewriter.eraseOp(terminator);
-}
-
 namespace {
 // Fold away ForOp iter arguments when:
 // 1) The op yields the iter arguments.
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -921,9 +921,31 @@
     }
     "test.bar"(%v) : (i64) -> ()
     // CHECK: %[[C2:.*]] = constant 2 : i64
-    // CHECK: scf.execute_region -> i64 {
-    // CHECK-NEXT: scf.yield %[[C2]] : i64
-    // CHECK-NEXT: }
+    // CHECK: "test.foo"
+    // CHECK-NEXT: "test.bar"(%[[C2]]) : (i64) -> ()
   }
   return
 }
+
+// -----
+
+// Check that a single-block scf.execute_region is inlined into its parent.
+// CHECK-LABEL: func @execute_region_elim
+func @execute_region_elim() {
+  affine.for %i = 0 to 100 {
+    "test.foo"() : () -> ()
+    %v = scf.execute_region -> i64 {
+      %x = "test.val"() : () -> i64
+      scf.yield %x : i64
+    }
+    "test.bar"(%v) : (i64) -> ()
+  }
+  return
+}
+
+// CHECK-NEXT: affine.for %arg0 = 0 to 100 {
+// CHECK-NEXT: "test.foo"() : () -> ()
+// CHECK-NEXT: %[[VAL:.*]] = "test.val"() : () -> i64
+// CHECK-NEXT: "test.bar"(%[[VAL]]) : (i64) -> ()
+// CHECK-NEXT: }
+