diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -230,9 +230,75 @@
   }
 };
 
+// Hoist the leading operations of an execute_region's entry block in front of
+// the execute_region. The region runs exactly once and every operation that
+// precedes the entry block's terminator runs unconditionally, so executing it
+// immediately before the op instead is equivalent. For example:
+//   scf.execute_region {
+//     op1()
+//     ...
+//   }
+// becomes
+//   op1()
+//   scf.execute_region {
+//     ...
+//   }
+// When the entry block is then left holding only an unconditional branch to a
+// block that has no other predecessor, that block is merged into the entry
+// block (its arguments are replaced by the branch operands) and hoisting
+// continues, so a chain of such blocks is handled in a single application.
+struct ExecuteRegionCodeMotion : public OpRewritePattern<ExecuteRegionOp> {
+  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExecuteRegionOp op,
+                                PatternRewriter &rewriter) const override {
+    Region &region = op.getRegion();
+    // Nothing to hoist, and no terminator to inspect.
+    if (region.empty() || region.front().empty())
+      return failure();
+    // Successors are merged *into* the entry block, so this reference stays
+    // valid for the whole rewrite.
+    Block &entry = region.front();
+
+    bool changed = false;
+    while (true) {
+      // Hoist everything but the terminator. Comparing front() with back() is
+      // O(1), whereas the operation list's size() walks the whole list.
+      // NOTE: assumes the entry block has no arguments (expected to be
+      // enforced by the ExecuteRegionOp verifier), so a hoisted op only uses
+      // values defined above the op or by previously hoisted ops.
+      while (&entry.front() != &entry.back()) {
+        Operation *hoisted = &entry.front();
+        rewriter.startRootUpdate(op);
+        rewriter.startRootUpdate(hoisted);
+        hoisted->moveBefore(op);
+        rewriter.finalizeRootUpdate(hoisted);
+        rewriter.finalizeRootUpdate(op);
+        changed = true;
+      }
+
+      // Only the terminator is left. If it is an unconditional branch to a
+      // block reached from nowhere else, splice that block into the entry
+      // block (forwarding the branch operands to its arguments) and loop to
+      // hoist its operations too.
+      auto br = dyn_cast<BranchOp>(&entry.back());
+      if (!br)
+        break;
+      Block *dest = br.getDest();
+      if (dest == &entry || dest->getSinglePredecessor() != &entry)
+        break;
+      rewriter.mergeBlocks(dest, &entry, br.getOperands());
+      rewriter.eraseOp(br);
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+
 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
+              ExecuteRegionCodeMotion>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1069,3 +1069,73 @@
 // CHECK: ^[[bb3]](%[[z:.+]]: i64):
 // CHECK: "test.bar"(%[[z]])
 // CHECK: return
+
+// -----
+
+// CHECK-LABEL: func @execute_region_code_motion
+func @execute_region_code_motion() {
+  affine.for %i = 0 to 100 {
+    scf.execute_region {
+      "test.foo"() : () -> ()
+      br ^bb1
+    ^bb1:
+      %c = "test.cmp"() : () -> i1
+      cond_br %c, ^bb2, ^bb3
+    ^bb2:
+      %x = "test.val1"() : () -> i64
+      scf.yield
+    ^bb3:
+      %y = "test.val2"() : () -> i64
+      scf.yield
+    }
+  }
+  return
+}
+
+// CHECK-NEXT: affine.for %arg0 = 0 to 100 {
+// CHECK-NEXT:   "test.foo"() : () -> ()
+// CHECK-NEXT:   %0 = "test.cmp"() : () -> i1
+// CHECK-NEXT:   scf.execute_region {
+// CHECK-NEXT:     cond_br %0, ^bb1, ^bb2
+// CHECK-NEXT:   ^bb1:  // pred: ^bb0
+// CHECK-NEXT:     %1 = "test.val1"() : () -> i64
+// CHECK-NEXT:     scf.yield
+// CHECK-NEXT:   ^bb2:  // pred: ^bb0
+// CHECK-NEXT:     %2 = "test.val2"() : () -> i64
+// CHECK-NEXT:     scf.yield
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+// -----
+
+// CHECK-LABEL: func @execute_region_hoist_through_block_args
+func @execute_region_hoist_through_block_args() {
+  affine.for %i = 0 to 100 {
+    scf.execute_region {
+      %c = "test.cmp"() : () -> i1
+      br ^bb1(%c : i1)
+    ^bb1(%cond: i1):
+      cond_br %cond, ^bb2, ^bb3
+    ^bb2:
+      "test.foo"() : () -> ()
+      scf.yield
+    ^bb3:
+      "test.bar"() : () -> ()
+      scf.yield
+    }
+  }
+  return
+}
+
+// CHECK-NEXT: affine.for %arg0 = 0 to 100 {
+// CHECK-NEXT:   %0 = "test.cmp"() : () -> i1
+// CHECK-NEXT:   scf.execute_region {
+// CHECK-NEXT:     cond_br %0, ^bb1, ^bb2
+// CHECK-NEXT:   ^bb1:  // pred: ^bb0
+// CHECK-NEXT:     "test.foo"() : () -> ()
+// CHECK-NEXT:     scf.yield
+// CHECK-NEXT:   ^bb2:  // pred: ^bb0
+// CHECK-NEXT:     "test.bar"() : () -> ()
+// CHECK-NEXT:     scf.yield
+// CHECK-NEXT:   }
+// CHECK-NEXT: }