diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -311,6 +311,10 @@ return results().empty() ? OpBuilder::atBlockTerminator(body, listener) : OpBuilder::atBlockEnd(body, listener); } + Block* thenBlock(); + YieldOp thenYield(); + Block* elseBlock(); + YieldOp elseYield(); }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -1266,15 +1266,125 @@ } }; +/// Merge an scf.if into the scf.if immediately preceding it when both use the same condition value. +/// +/// scf.if %cond { +/// firstCodeTrue();... +/// } else { +/// firstCodeFalse();... +/// } +/// %res = scf.if %cond { +/// secondCodeTrue();... +/// } else { +/// secondCodeFalse();... +/// } +/// +/// becomes +/// %res = scf.if %cond { +/// firstCodeTrue();... +/// secondCodeTrue();... +/// } else { +/// firstCodeFalse();... +/// secondCodeFalse();... +/// } +struct CombineIfs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp nextIf, + PatternRewriter &rewriter) const override { + Block *parent = nextIf->getBlock(); + if (nextIf == &parent->front()) + return failure(); + + auto prevIf = dyn_cast(nextIf->getPrevNode()); + if (!prevIf) + return failure(); + + if (nextIf.condition() != prevIf.condition()) + return failure(); + + // Don't permit merging if a result of the first if is used + // within the second: it would become a use of the combined op inside its own body. 
+ if (llvm::any_of(prevIf->getUsers(), + [&](Operation *user) { return nextIf->isAncestor(user); })) + return failure(); + + SmallVector mergedTypes(prevIf.getResultTypes()); + llvm::append_range(mergedTypes, nextIf.getResultTypes()); + + IfOp combinedIf = rewriter.create( + nextIf.getLoc(), mergedTypes, nextIf.condition(), /*hasElse=*/false); + rewriter.eraseBlock(&combinedIf.thenRegion().back()); + + YieldOp thenYield = prevIf.thenYield(); + YieldOp thenYield2 = nextIf.thenYield(); + + combinedIf.thenRegion().getBlocks().splice( + combinedIf.thenRegion().getBlocks().begin(), + prevIf.thenRegion().getBlocks()); + + rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock()); + rewriter.setInsertionPointToEnd(combinedIf.thenBlock()); + + SmallVector mergedYields(thenYield.getOperands()); + llvm::append_range(mergedYields, thenYield2.getOperands()); + rewriter.create(thenYield2.getLoc(), mergedYields); + rewriter.eraseOp(thenYield); + rewriter.eraseOp(thenYield2); + + combinedIf.elseRegion().getBlocks().splice( + combinedIf.elseRegion().getBlocks().begin(), + prevIf.elseRegion().getBlocks()); + + if (!nextIf.elseRegion().empty()) { + if (combinedIf.elseRegion().empty()) { + combinedIf.elseRegion().getBlocks().splice( + combinedIf.elseRegion().getBlocks().begin(), + nextIf.elseRegion().getBlocks()); + } else { + YieldOp elseYield = combinedIf.elseYield(); + YieldOp elseYield2 = nextIf.elseYield(); + rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock()); + + rewriter.setInsertionPointToEnd(combinedIf.elseBlock()); + + SmallVector mergedElseYields(elseYield.getOperands()); + llvm::append_range(mergedElseYields, elseYield2.getOperands()); + + rewriter.create(elseYield2.getLoc(), mergedElseYields); + rewriter.eraseOp(elseYield); + rewriter.eraseOp(elseYield2); + } + } + + SmallVector prevValues; + SmallVector nextValues; + for (auto pair : llvm::enumerate(combinedIf.getResults())) { + if (pair.index() < prevIf.getNumResults()) + 
prevValues.push_back(pair.value()); + else + nextValues.push_back(pair.value()); + } + rewriter.replaceOp(prevIf, prevValues); + rewriter.replaceOp(nextIf, nextValues); + return success(); + } +}; + } // namespace void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results - .add(context); + results.add(context); } +Block *IfOp::thenBlock() { return &thenRegion().back(); } +YieldOp IfOp::thenYield() { return cast(&thenBlock()->back()); } +Block *IfOp::elseBlock() { return &elseRegion().back(); } +YieldOp IfOp::elseYield() { return cast(&elseBlock()->back()); } + //===----------------------------------------------------------------------===// // ParallelOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -724,3 +724,103 @@ // CHECK-NEXT: scf.yield %[[sv2]] : i32 // CHECK-NEXT: } // CHECK-NEXT: return %[[if]], %arg1 : i32, i64 + +// ----- + +// CHECK-LABEL: @combineIfs +func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) { + %res = scf.if %arg0 -> i32 { + %v = "test.firstCodeTrue"() : () -> i32 + scf.yield %v : i32 + } else { + %v2 = "test.firstCodeFalse"() : () -> i32 + scf.yield %v2 : i32 + } + %res2 = scf.if %arg0 -> i32 { + %v = "test.secondCodeTrue"() : () -> i32 + scf.yield %v : i32 + } else { + %v2 = "test.secondCodeFalse"() : () -> i32 + scf.yield %v2 : i32 + } + return %res, %res2 : i32, i32 +} +// CHECK-NEXT: %[[res:.+]]:2 = scf.if %arg0 -> (i32, i32) { +// CHECK-NEXT: %[[tval0:.+]] = "test.firstCodeTrue"() : () -> i32 +// CHECK-NEXT: %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32 +// CHECK-NEXT: scf.yield %[[tval0]], %[[tval]] : i32, i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[fval0:.+]] = "test.firstCodeFalse"() : () -> i32 +// CHECK-NEXT: %[[fval:.+]] = "test.secondCodeFalse"() : () -> 
i32 +// CHECK-NEXT: scf.yield %[[fval0]], %[[fval]] : i32, i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[res]]#0, %[[res]]#1 : i32, i32 + + +// CHECK-LABEL: @combineIfs2 +func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 { + scf.if %arg0 { + "test.firstCodeTrue"() : () -> () + scf.yield + } + %res = scf.if %arg0 -> i32 { + %v = "test.secondCodeTrue"() : () -> i32 + scf.yield %v : i32 + } else { + %v2 = "test.secondCodeFalse"() : () -> i32 + scf.yield %v2 : i32 + } + return %res : i32 +} +// CHECK-NEXT: %[[res:.+]] = scf.if %arg0 -> (i32) { +// CHECK-NEXT: "test.firstCodeTrue"() : () -> () +// CHECK-NEXT: %[[tval:.+]] = "test.secondCodeTrue"() : () -> i32 +// CHECK-NEXT: scf.yield %[[tval]] : i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[fval:.+]] = "test.secondCodeFalse"() : () -> i32 +// CHECK-NEXT: scf.yield %[[fval]] : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[res]] : i32 + + +// CHECK-LABEL: @combineIfs3 +func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 { + %res = scf.if %arg0 -> i32 { + %v = "test.firstCodeTrue"() : () -> i32 + scf.yield %v : i32 + } else { + %v2 = "test.firstCodeFalse"() : () -> i32 + scf.yield %v2 : i32 + } + scf.if %arg0 { + "test.secondCodeTrue"() : () -> () + scf.yield + } + return %res : i32 +} +// CHECK-NEXT: %[[res:.+]] = scf.if %arg0 -> (i32) { +// CHECK-NEXT: %[[tval:.+]] = "test.firstCodeTrue"() : () -> i32 +// CHECK-NEXT: "test.secondCodeTrue"() : () -> () +// CHECK-NEXT: scf.yield %[[tval]] : i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[fval:.+]] = "test.firstCodeFalse"() : () -> i32 +// CHECK-NEXT: scf.yield %[[fval]] : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[res]] : i32 + +// CHECK-LABEL: @combineIfs4 +func @combineIfs4(%arg0 : i1, %arg2: i64) { + scf.if %arg0 { + "test.firstCodeTrue"() : () -> () + scf.yield + } + scf.if %arg0 { + "test.secondCodeTrue"() : () -> () + scf.yield + } + return +} +// CHECK-NEXT: scf.if %arg0 { +// CHECK-NEXT: "test.firstCodeTrue"() : () -> () +// CHECK-NEXT: 
"test.secondCodeTrue"() : () -> () +// CHECK-NEXT: }