diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -108,14 +108,9 @@
 
   let regions = (region AnyRegion:$region);
 
-  // TODO: If the parent is a func like op (which would be the case if all other
-  // ops are from the std dialect), the inliner logic could be readily used to
-  // inline.
   let hasCanonicalizer = 1;
 
-  // TODO: can fold if it returns a constant.
-  // TODO: Single block execute_region ops can be readily inlined irrespective
-  // of which op is a parent. Add a fold for this.
+  // TODO: Change single block inlining to a folder.
   let hasFolder = 0;
 }
 
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -143,7 +143,7 @@
 //
 //   "test.foo"() : () -> ()
 //   %x = "test.val"() : () -> i64
-//   "test.bar"(%v) : (i64) -> ()
+//   "test.bar"(%x) : (i64) -> ()
 //
 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
@@ -157,9 +157,88 @@
   }
 };
 
+// Inline an ExecuteRegionOp whose parent is a FuncOp; the function body can
+// hold multiple blocks, so the region's CFG is spliced in directly.
+// TODO: generalize the conditions for operations which can be inlined into.
+//
+// func @func_execute_region_elim() {
+//     "test.foo"() : () -> ()
+//     %v = scf.execute_region -> i64 {
+//       %c = "test.cmp"() : () -> i1
+//       cond_br %c, ^bb2, ^bb3
+//     ^bb2:
+//       %x = "test.val1"() : () -> i64
+//       br ^bb4(%x : i64)
+//     ^bb3:
+//       %y = "test.val2"() : () -> i64
+//       br ^bb4(%y : i64)
+//     ^bb4(%z : i64):
+//       scf.yield %z : i64
+//     }
+//     "test.bar"(%v) : (i64) -> ()
+//   return
+// }
+//
+// becomes
+//
+// func @func_execute_region_elim() {
+//    "test.foo"() : () -> ()
+//    %c = "test.cmp"() : () -> i1
+//    cond_br %c, ^bb1, ^bb2
+//  ^bb1:  // pred: ^bb0
+//    %x = "test.val1"() : () -> i64
+//    br ^bb3(%x : i64)
+//  ^bb2:  // pred: ^bb0
+//    %y = "test.val2"() : () -> i64
+//    br ^bb3(%y : i64)
+//  ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
+//    "test.bar"(%z) : (i64) -> ()
+//    return
+// }
+//
+struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isa<FuncOp>(op->getParentOp()))
+      return failure();
+
+    // Split the parent block before the op; the op and everything after it
+    // move into the continuation block that the region's yields branch to.
+    Block *prevBlock = op->getBlock();
+    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
+    rewriter.setInsertionPointToEnd(prevBlock);
+    rewriter.create<BranchOp>(op.getLoc(), &op.region().front());
+
+    // Any block of the region may end in scf.yield, not only the last one.
+    // Rewrite each yield into a branch to the continuation that forwards the
+    // yielded values as successor operands.
+    for (Block &blk : op.region()) {
+      if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
+        rewriter.setInsertionPoint(yieldOp);
+        rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock,
+                                  yieldOp.results());
+        rewriter.eraseOp(yieldOp);
+      }
+    }
+
+    rewriter.inlineRegionBefore(op.region(), postBlock);
+
+    // The continuation receives the yielded values as block arguments, which
+    // take over all uses of the op's results.
+    SmallVector<Value> blockArgs;
+    for (Value result : op.getResults())
+      blockArgs.push_back(postBlock->addArgument(result.getType()));
+
+    rewriter.replaceOp(op, blockArgs);
+    return success();
+  }
+};
+
 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SingleBlockExecuteInliner>(context);
+  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -948,3 +948,71 @@
 // CHECK-NEXT: "test.bar"(%[[VAL]]) : (i64) -> ()
 // CHECK-NEXT: }
 
+
+// -----
+
+// CHECK-LABEL: func @func_execute_region_elim
+func @func_execute_region_elim() {
+    "test.foo"() : () -> ()
+    %v = scf.execute_region -> i64 {
+      %c = "test.cmp"() : () -> i1
+      cond_br %c, ^bb2, ^bb3
+    ^bb2:
+      %x = "test.val1"() : () -> i64
+      br ^bb4(%x : i64)
+    ^bb3:
+      %y = "test.val2"() : () -> i64
+      br ^bb4(%y : i64)
+    ^bb4(%z : i64):
+      scf.yield %z : i64
+    }
+    "test.bar"(%v) : (i64) -> ()
+  return
+}
+
+// CHECK-NEXT:     "test.foo"() : () -> ()
+// CHECK-NEXT:     %[[cmp:.+]] = "test.cmp"() : () -> i1
+// CHECK-NEXT:     cond_br %[[cmp]], ^bb1, ^bb2
+// CHECK-NEXT:   ^bb1:  // pred: ^bb0
+// CHECK-NEXT:     %[[x:.+]] = "test.val1"() : () -> i64
+// CHECK-NEXT:     br ^bb3(%[[x]] : i64)
+// CHECK-NEXT:   ^bb2:  // pred: ^bb0
+// CHECK-NEXT:     %[[y:.+]] = "test.val2"() : () -> i64
+// CHECK-NEXT:     br ^bb3(%[[y]] : i64)
+// CHECK-NEXT:   ^bb3(%[[z:.+]]: i64):  // 2 preds: ^bb1, ^bb2
+// CHECK-NEXT:     "test.bar"(%[[z]]) : (i64) -> ()
+// CHECK-NEXT:     return
+// CHECK-NEXT:   }
+
+// -----
+
+// CHECK-LABEL: func @func_execute_region_elim_multi_yield
+func @func_execute_region_elim_multi_yield() {
+    "test.foo"() : () -> ()
+    %v = scf.execute_region -> i64 {
+      %c = "test.cmp"() : () -> i1
+      cond_br %c, ^bb2, ^bb3
+    ^bb2:
+      %x = "test.val1"() : () -> i64
+      scf.yield %x : i64
+    ^bb3:
+      %y = "test.val2"() : () -> i64
+      scf.yield %y : i64
+    }
+    "test.bar"(%v) : (i64) -> ()
+  return
+}
+
+// CHECK-NEXT:     "test.foo"() : () -> ()
+// CHECK-NEXT:     %[[cmp:.+]] = "test.cmp"() : () -> i1
+// CHECK-NEXT:     cond_br %[[cmp]], ^bb1, ^bb2
+// CHECK-NEXT:   ^bb1:  // pred: ^bb0
+// CHECK-NEXT:     %[[x:.+]] = "test.val1"() : () -> i64
+// CHECK-NEXT:     br ^bb3(%[[x]] : i64)
+// CHECK-NEXT:   ^bb2:  // pred: ^bb0
+// CHECK-NEXT:     %[[y:.+]] = "test.val2"() : () -> i64
+// CHECK-NEXT:     br ^bb3(%[[y]] : i64)
+// CHECK-NEXT:   ^bb3(%[[z:.+]]: i64):  // 2 preds: ^bb1, ^bb2
+// CHECK-NEXT:     "test.bar"(%[[z]]) : (i64) -> ()
+// CHECK-NEXT:     return
+// CHECK-NEXT:   }