diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1266,13 +1266,120 @@
   }
 };
 
+/// Merge any consecutive scf.if's with the same condition.
+///
+///    scf.if %cond {
+///       firstCodeTrue();...
+///    } else {
+///       firstCodeFalse();...
+///    }
+///    %res = scf.if %cond {
+///       secondCodeTrue();...
+///    } else {
+///       secondCodeFalse();...
+///    }
+///
+///  becomes
+///    %res = scf.if %cond {
+///       firstCodeTrue();...
+///       secondCodeTrue();...
+///    } else {
+///       firstCodeFalse();...
+///       secondCodeFalse();...
+///    }
+///
+/// The combined scf.if yields the results of the first op followed by those
+/// of the second. Merging is rejected when a result of the first scf.if is
+/// used anywhere inside the second, as that use would end up nested inside
+/// the op that defines the value.
+struct CombineIfs : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp nextIf,
+                                PatternRewriter &rewriter) const override {
+    // getPrevNode() is null when nextIf is the first op in its block.
+    auto prevIf = dyn_cast_or_null<IfOp>(nextIf->getPrevNode());
+    if (!prevIf)
+      return failure();
+
+    if (nextIf.condition() != prevIf.condition())
+      return failure();
+
+    // Don't permit merging if a result of the first if is used
+    // within the second.
+    if (llvm::any_of(prevIf->getUsers(),
+                     [&](Operation *user) { return nextIf->isAncestor(user); }))
+      return failure();
+
+    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
+    llvm::append_range(mergedTypes, nextIf.getResultTypes());
+
+    // The rewriter inserts before nextIf, i.e. directly after prevIf, so the
+    // combined op dominates every remaining use of both ops' results.
+    auto combinedIf = rewriter.create<IfOp>(
+        nextIf.getLoc(), mergedTypes, nextIf.condition(),
+        /*withElseRegion=*/false);
+    // Drop the then block made by the builder; the original bodies are moved
+    // in below instead.
+    rewriter.eraseBlock(&combinedIf.thenRegion().front());
+
+    // Then branch: take over prevIf's body and append nextIf's to it.
+    Region &combinedThen = combinedIf.thenRegion();
+    rewriter.inlineRegionBefore(prevIf.thenRegion(), combinedThen,
+                                combinedThen.end());
+    appendBodyAndMergeYields(&nextIf.thenRegion().front(),
+                             &combinedThen.front(), rewriter);
+
+    // Else branch: either op may lack one. An scf.if without an else has no
+    // results, so whichever yield remains still matches the merged types.
+    Region &combinedElse = combinedIf.elseRegion();
+    rewriter.inlineRegionBefore(prevIf.elseRegion(), combinedElse,
+                                combinedElse.end());
+    Region &nextElse = nextIf.elseRegion();
+    if (!nextElse.empty()) {
+      if (combinedElse.empty())
+        rewriter.inlineRegionBefore(nextElse, combinedElse,
+                                    combinedElse.end());
+      else
+        appendBodyAndMergeYields(&nextElse.front(), &combinedElse.front(),
+                                 rewriter);
+    }
+
+    // Leading results replace prevIf, trailing results replace nextIf.
+    unsigned numPrevResults = prevIf.getNumResults();
+    ResultRange combinedResults = combinedIf->getResults();
+    rewriter.replaceOp(prevIf, combinedResults.take_front(numPrevResults));
+    rewriter.replaceOp(nextIf, combinedResults.drop_front(numPrevResults));
+    return success();
+  }
+
+private:
+  /// Moves the operations of `src` (a body of the second scf.if) to the end
+  /// of `dest` (the matching body of the combined scf.if, still terminated by
+  /// the first scf.if's yield) and fuses the two scf.yield terminators into
+  /// one that yields the first operand list followed by the second.
+  static void appendBodyAndMergeYields(Block *src, Block *dest,
+                                       PatternRewriter &rewriter) {
+    auto firstYield = cast<YieldOp>(dest->getTerminator());
+    auto secondYield = cast<YieldOp>(src->getTerminator());
+    rewriter.mergeBlocks(src, dest);
+
+    SmallVector<Value> mergedYields(firstYield.getOperands());
+    llvm::append_range(mergedYields, secondYield.getOperands());
+    rewriter.setInsertionPointToEnd(dest);
+    rewriter.create<YieldOp>(secondYield.getLoc(), mergedYields);
+    rewriter.eraseOp(firstYield);
+    rewriter.eraseOp(secondYield);
+  }
+};
+
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results
-      .add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
-           ReplaceIfYieldWithConditionOrValue>(context);
+  results.add<CombineIfs, RemoveUnusedResults, RemoveStaticCondition,
+              ConvertTrivialIfToSelect,
+              ReplaceIfYieldWithConditionOrValue>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -724,3 +724,131 @@
 // CHECK-NEXT:    scf.yield %[[sv2]] : i32
 // CHECK-NEXT:  }
 // CHECK-NEXT:  return %[[if]], %arg1 : i32, i64
+
+// -----
+
+// CHECK-LABEL: @combineIfs
+func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.firstCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.firstCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  %res2 = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.secondCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  return %res, %res2 : i32, i32
+}
+// CHECK-NEXT:     %[[res:.+]]:2 = scf.if %arg0 -> (i32, i32) {
+// CHECK-NEXT:       %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32
+// CHECK-NEXT:       %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[tval0]], %[[tval]] : i32, i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32
+// CHECK-NEXT:       %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[fval0]], %[[fval]] : i32, i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[res]]#0, %[[res]]#1 : i32, i32
+
+
+// CHECK-LABEL: @combineIfs2
+func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
+  scf.if %arg0 {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  }
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.secondCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  return %res : i32
+}
+// CHECK-NEXT:     %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT:       %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[tval]] : i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[fval]] : i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[res]] : i32
+
+
+// CHECK-LABEL: @combineIfs3
+func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.firstCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.firstCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  scf.if %arg0 {
+    "test.secondCodeTrue"() : () -> ()
+    scf.yield
+  }
+  return %res : i32
+}
+// CHECK-NEXT:     %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:       %[[tval:.+]] = "test.firstCodeTrue"() : () -> i32
+// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT:       scf.yield %[[tval]] : i32
+// CHECK-NEXT:     } else {
+// CHECK-NEXT:       %[[fval:.+]] = "test.firstCodeFalse"() : () -> i32
+// CHECK-NEXT:       scf.yield %[[fval]] : i32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     return %[[res]] : i32
+
+// CHECK-LABEL: @combineIfs4
+func @combineIfs4(%arg0 : i1, %arg2: i64) {
+  scf.if %arg0 {
+    "test.firstCodeTrue"() : () -> ()
+    scf.yield
+  }
+  scf.if %arg0 {
+    "test.secondCodeTrue"() : () -> ()
+    scf.yield
+  }
+  return
+}
+// CHECK-NEXT:     scf.if %arg0 {
+// CHECK-NEXT:       "test.firstCodeTrue"() : () -> ()
+// CHECK-NEXT:       "test.secondCodeTrue"() : () -> ()
+// CHECK-NEXT:     }
+
+// -----
+
+// CHECK-LABEL: @noCombineIfsWhenResultUsed
+func @noCombineIfsWhenResultUsed(%arg0 : i1) -> i32 {
+  %res = scf.if %arg0 -> i32 {
+    %v = "test.firstCodeTrue"() : () -> i32
+    scf.yield %v : i32
+  } else {
+    %v2 = "test.firstCodeFalse"() : () -> i32
+    scf.yield %v2 : i32
+  }
+  %res2 = scf.if %arg0 -> i32 {
+    %v = "test.secondCodeTrue"(%res) : (i32) -> i32
+    scf.yield %v : i32
+  } else {
+    scf.yield %res : i32
+  }
+  return %res2 : i32
+}
+// CHECK:      %[[res:.+]] = scf.if %arg0 -> (i32) {
+// CHECK:        "test.firstCodeTrue"
+// CHECK:      %[[res2:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:   "test.secondCodeTrue"(%[[res]])
+// CHECK:      } else {
+// CHECK-NEXT:   scf.yield %[[res]] : i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res2]] : i32