diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -108,14 +108,9 @@
 
   let regions = (region AnyRegion:$region);
 
-  // TODO: If the parent is a func like op (which would be the case if all other
-  // ops are from the std dialect), the inliner logic could be readily used to
-  // inline.
   let hasCanonicalizer = 1;
 
   // TODO: can fold if it returns a constant.
-  // TODO: Single block execute_region ops can be readily inlined irrespective
-  // of which op is a parent. Add a fold for this.
   let hasFolder = 0;
 }
 
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -143,23 +143,105 @@
 //
 //   "test.foo"() : () -> ()
 //   %x = "test.val"() : () -> i64
-//   "test.bar"(%v) : (i64) -> ()
+//   "test.bar"(%x) : (i64) -> ()
 //
 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExecuteRegionOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.region().getBlocks().size() != 1)
+    if (!llvm::hasSingleElement(op.region()))
       return failure();
     replaceOpWithRegion(rewriter, op, op.region());
     return success();
   }
 };
 
+// Inline an ExecuteRegionOp whose parent is a FuncOp, whose body region may
+// contain multiple blocks. The enclosing block is split at the op, a branch to
+// the region's entry block is appended to the first half, and every scf.yield
+// becomes a branch to the continuation block, whose new arguments replace the
+// op's results.
+// TODO: Generalize to other parents whose regions can contain multiple blocks.
+//
+// func @func_execute_region_elim() {
+//   "test.foo"() : () -> ()
+//   %v = scf.execute_region -> i64 {
+//     %c = "test.cmp"() : () -> i1
+//     cond_br %c, ^bb2, ^bb3
+//   ^bb2:
+//     %x = "test.val1"() : () -> i64
+//     br ^bb4(%x : i64)
+//   ^bb3:
+//     %y = "test.val2"() : () -> i64
+//     br ^bb4(%y : i64)
+//   ^bb4(%z : i64):
+//     scf.yield %z : i64
+//   }
+//   "test.bar"(%v) : (i64) -> ()
+//   return
+// }
+//
+// becomes
+//
+// func @func_execute_region_elim() {
+//   "test.foo"() : () -> ()
+//   %c = "test.cmp"() : () -> i1
+//   cond_br %c, ^bb1, ^bb2
+// ^bb1:  // pred: ^bb0
+//   %x = "test.val1"() : () -> i64
+//   br ^bb3(%x : i64)
+// ^bb2:  // pred: ^bb0
+//   %y = "test.val2"() : () -> i64
+//   br ^bb3(%y : i64)
+// ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
+//   "test.bar"(%z) : (i64) -> ()
+//   return
+// }
+//
+struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isa<FuncOp>(op->getParentOp()))
+      return failure();
+
+    // Split the enclosing block at the op: the op and everything after it
+    // move into `postBlock`, which becomes the continuation block.
+    Block *prevBlock = op->getBlock();
+    Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
+    rewriter.setInsertionPointToEnd(prevBlock);
+
+    // Enter the region through its entry block.
+    rewriter.create<BranchOp>(op.getLoc(), &op.region().front());
+
+    // Every scf.yield exits the region; forward its operands to postBlock.
+    for (Block &blk : op.region()) {
+      if (auto yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
+        rewriter.setInsertionPoint(yieldOp);
+        rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock,
+                                  yieldOp.results());
+        rewriter.eraseOp(yieldOp);
+      }
+    }
+
+    rewriter.inlineRegionBefore(op.region(), postBlock);
+
+    // The yielded values arrive as postBlock arguments and replace the op's
+    // results.
+    SmallVector<Value> blockArgs;
+    for (Value res : op.getResults())
+      blockArgs.push_back(postBlock->addArgument(res.getType()));
+
+    rewriter.replaceOp(op, blockArgs);
+    return success();
+  }
+};
+
 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SingleBlockExecuteInliner>(context);
+  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -948,3 +948,70 @@
 // CHECK-NEXT: "test.bar"(%[[VAL]]) : (i64) -> ()
 // CHECK-NEXT: }
 
+
+// -----
+
+// CHECK-LABEL: func @func_execute_region_elim
+func @func_execute_region_elim() {
+  "test.foo"() : () -> ()
+  %v = scf.execute_region -> i64 {
+    %c = "test.cmp"() : () -> i1
+    cond_br %c, ^bb2, ^bb3
+  ^bb2:
+    %x = "test.val1"() : () -> i64
+    br ^bb4(%x : i64)
+  ^bb3:
+    %y = "test.val2"() : () -> i64
+    br ^bb4(%y : i64)
+  ^bb4(%z : i64):
+    scf.yield %z : i64
+  }
+  "test.bar"(%v) : (i64) -> ()
+  return
+}
+
+// CHECK:     "test.foo"
+// CHECK:     %[[cmp:.+]] = "test.cmp"
+// CHECK:     cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
+// CHECK:   ^[[bb1]]:  // pred: ^bb0
+// CHECK:     %[[x:.+]] = "test.val1"
+// CHECK:     br ^[[bb3:.+]](%[[x]] : i64)
+// CHECK:   ^[[bb2]]:  // pred: ^bb0
+// CHECK:     %[[y:.+]] = "test.val2"
+// CHECK:     br ^[[bb3]](%[[y]] : i64)
+// CHECK:   ^[[bb3]](%[[z:.+]]: i64):
+// CHECK:     "test.bar"(%[[z]])
+// CHECK:     return
+
+
+// -----
+
+// CHECK-LABEL: func @func_execute_region_elim2
+func @func_execute_region_elim2() {
+  "test.foo"() : () -> ()
+  %v = scf.execute_region -> i64 {
+    %c = "test.cmp"() : () -> i1
+    cond_br %c, ^bb2, ^bb3
+  ^bb2:
+    %x = "test.val1"() : () -> i64
+    scf.yield %x : i64
+  ^bb3:
+    %y = "test.val2"() : () -> i64
+    scf.yield %y : i64
+  }
+  "test.bar"(%v) : (i64) -> ()
+  return
+}
+
+// CHECK:     "test.foo"
+// CHECK:     %[[cmp:.+]] = "test.cmp"
+// CHECK:     cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
+// CHECK:   ^[[bb1]]:  // pred: ^bb0
+// CHECK:     %[[x:.+]] = "test.val1"
+// CHECK:     br ^[[bb3:.+]](%[[x]] : i64)
+// CHECK:   ^[[bb2]]:  // pred: ^bb0
+// CHECK:     %[[y:.+]] = "test.val2"
+// CHECK:     br ^[[bb3]](%[[y]] : i64)
+// CHECK:   ^[[bb3]](%[[z:.+]]: i64):
+// CHECK:     "test.bar"(%[[z]])
+// CHECK:     return