diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1371,18 +1371,44 @@
   }
 };
 
+/// Pattern to remove an empty else branch.
+struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp ifOp,
+                                PatternRewriter &rewriter) const override {
+    // Cannot remove else region when there are operation results.
+    if (ifOp.getNumResults())
+      return failure();
+    Block *elseBlock = ifOp.elseBlock();
+    if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
+      return failure();
+    auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
+    rewriter.inlineRegionBefore(ifOp.thenRegion(), newIfOp.thenRegion(),
+                                newIfOp.thenRegion().begin());
+    rewriter.eraseOp(ifOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results.add(context);
+  results
+      .add(context);
 }
 
 Block *IfOp::thenBlock() { return &thenRegion().back(); }
 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
-Block *IfOp::elseBlock() { return &elseRegion().back(); }
+Block *IfOp::elseBlock() {
+  Region &r = elseRegion();
+  if (r.empty())
+    return nullptr;
+  return &r.back();
+}
 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -250,6 +250,20 @@
 // CHECK-NOT: scf.if
 // CHECK: return
 
+// -----
+
+func @empty_else(%cond: i1, %v : memref<i1>) {
+  scf.if %cond {
+    memref.store %cond, %v[] : memref<i1>
+  } else {
+  }
+  return
+}
+
+// CHECK-LABEL: func @empty_else
+// CHECK: scf.if
+// CHECK-NOT: else
+
 // -----
 
 func @to_select1(%cond: i1) -> index {
@@ -475,9 +489,9 @@
 // CHECK-LABEL: @replace_single_iteration_loop_2
 func @replace_single_iteration_loop_2() {
   // CHECK: %[[LB:.*]] = constant 5
-  %c5 = constant 5 : index
-  %c6 = constant 6 : index
-  %c11 = constant 11 : index
+  %c5 = constant 5 : index
+  %c6 = constant 6 : index
+  %c11 = constant 11 : index
   // CHECK: %[[INIT:.*]] = "test.init"
   %init = "test.init"() : () -> i32
   // CHECK-NOT: scf.for