diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1266,13 +1266,82 @@
   }
 };
 
+/// Merge two consecutive scf.if ops that test the same condition, provided
+/// the first one produces no results.
+///
+///    scf.if %cond {
+///       firstCodeTrue();...
+///    } else {
+///       firstCodeFalse();...
+///    }
+///    %res = scf.if %cond {
+///       secondCodeTrue();...
+///    } else {
+///       secondCodeFalse();...
+///    }
+///
+///  becomes
+///    %res = scf.if %cond {
+///       firstCodeTrue();...
+///       secondCodeTrue();...
+///    } else {
+///       firstCodeFalse();...
+///       secondCodeFalse();...
+///    }
+struct CombineIfs : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp nextIf,
+                                PatternRewriter &rewriter) const override {
+    // The first op of a block has no predecessor to merge with.
+    Block *parent = nextIf->getBlock();
+    if (nextIf == &parent->front())
+      return failure();
+
+    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
+    if (!prevIf)
+      return failure();
+
+    // Only a result-free predecessor is handled: its yield is erased below
+    // and the op itself is erased without replacement values.
+    if (prevIf->getNumResults() != 0)
+      return failure();
+
+    if (nextIf.condition() != prevIf.condition())
+      return failure();
+
+    // Moves the contents of `prev` in front of the contents of `next`.
+    auto prependRegion = [&](Region &prev, Region &next) {
+      if (prev.empty())
+        return;
+      if (!next.empty()) {
+        // Drop the terminator of `prev` and append the body of `next` to it,
+        // so that the terminator of `next` ends the combined block.
+        Block &prevBlock = prev.back();
+        assert(isa<scf::YieldOp>(&prevBlock.back()));
+        rewriter.eraseOp(&prevBlock.back());
+        rewriter.mergeBlocks(&next.front(), &prevBlock);
+      }
+      // Splice the blocks of `prev` in front of whatever remains in `next`.
+      next.getBlocks().splice(next.getBlocks().begin(), prev.getBlocks());
+    };
+
+    rewriter.updateRootInPlace(nextIf, [&]() {
+      prependRegion(prevIf.thenRegion(), nextIf.thenRegion());
+      prependRegion(prevIf.elseRegion(), nextIf.elseRegion());
+    });
+    rewriter.eraseOp(prevIf);
+    return success();
+  }
+};
+
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results
-      .add(context);
+  results.add(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -724,3 +724,100 @@
 // CHECK-NEXT: scf.yield %[[sv2]] : i32
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[if]], %arg1 : i32, i64
+
+// -----
+// Both ifs have an else branch: the bodies are concatenated branch by branch.
+// CHECK-LABEL: @combineIfs
+func @combineIfs(%arg0 : i1, %arg2: i64) -> i32 {
+  scf.if %arg0 {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  } else {
+    "test.firstCodeFalse"() : () -> ()
+    scf.yield
+  }
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.secondCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  return %res : i32
+}
+// CHECK-NEXT: %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT: %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
+// CHECK-NEXT: scf.yield %[[tval]] : i32
+// CHECK-NEXT: } else {
+// CHECK-NEXT: "test.firstCodeFalse"() : () -> ()
+// CHECK-NEXT: %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
+// CHECK-NEXT: scf.yield %[[fval]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res]] : i32
+
+// First if has no else branch: the merged else keeps only the second if's code.
+// CHECK-LABEL: @combineIfs2
+func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
+  scf.if %arg0 {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  }
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.secondCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  return %res : i32
+}
+// CHECK-NEXT: %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT: %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
+// CHECK-NEXT: scf.yield %[[tval]] : i32
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
+// CHECK-NEXT: scf.yield %[[fval]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res]] : i32
+
+// Second if has no else branch: the merged else keeps only the first if's code.
+// CHECK-LABEL: @combineIfs3
+func @combineIfs3(%arg0 : i1, %arg2: i64) {
+  scf.if %arg0 {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  } else {
+    "test.firstCodeFalse"() : () -> ()
+    scf.yield
+  }
+  scf.if %arg0 {
+    "test.secondCodeTrue"() : () -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-NEXT: scf.if %arg0 {
+// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT: } else {
+// CHECK-NEXT: "test.firstCodeFalse"() : () -> ()
+// CHECK-NEXT: }
+// Neither if has an else branch: the merged if has none either.
+// CHECK-LABEL: @combineIfs4
+func @combineIfs4(%arg0 : i1, %arg2: i64) {
+  scf.if %arg0 {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  }
+  scf.if %arg0 {
+    "test.secondCodeTrue"() : () -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-NEXT: scf.if %arg0 {
+// CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT: }
\ No newline at end of file