Patch: add an `scf.if` canonicalization pattern, CombineNestedIfs, that rewrites
`scf.if %a { scf.if %b { ... } }` (no results, no else blocks) into a single
`scf.if` on `arith.andi %a, %b`. Also adds the FileCheck test `@merge_nested_if`.

NOTE(review): this copy of the patch looks lossy and will not apply or compile
as-is. Text between `<` and `>` appears to have been stripped during extraction.
Affected spots include the template arguments of `OpRewritePattern`,
`dyn_cast(...)` and `rewriter.create(...)`, and the pattern list passed to
`results.add(context)`. Some stripped spans seem to have covered whole lines,
because the visible lines of the first hunk do not add up to its
`-1596,14 +1596,60` header counts. Restore it from the upstream commit before
use; TODO confirm. Line breaks and indentation below were reconstructed from a
whitespace-collapsed copy.

Review notes on CombineNestedIfs (derived from the code below):
- It matches only when all of these hold:
  - neither the outer nor the nested `scf.if` has results;
  - neither has an else block;
  - the nested `if` is the only non-terminator op in the outer then-block.
  Otherwise it returns failure().
- No other op lives in the outer then-block, so nothing there can produce the
  nested condition. This is why the new `arith.andi` may use that condition
  outside the outer `if`.
- The freshly built `if` has a then-block terminator. It is erased before
  `mergeBlocks` so that the nested block's own terminator (`scf.yield` in the
  example) closes the merged block.
- Erasing the outer `if` also removes the nested `if`. The nested `if`'s
  then-block contents were already moved out by `mergeBlocks`.

diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1596,14 +1596,60 @@
   }
 };
 
+/// Convert nested `if`s into `arith.andi` + single `if`.
+///
+///    scf.if %arg0 {
+///      scf.if %arg1 {
+///        ...
+///        scf.yield
+///      }
+///      scf.yield
+///    }
+///  becomes
+///
+///    %0 = arith.andi %arg0, %arg1
+///    scf.if %0 {
+///      ...
+///      scf.yield
+///    }
+struct CombineNestedIfs : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Both `if` ops must not yield results and have only `then` block.
+    if (op->getNumResults() != 0 || op.elseBlock())
+      return failure();
+
+    auto nestedOps = op.thenBlock()->without_terminator();
+    // Nested `if` must be the only op in block.
+    if (!llvm::hasSingleElement(nestedOps))
+      return failure();
+
+    auto nestedIf = dyn_cast(*nestedOps.begin());
+    if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
+      return failure();
+
+    Location loc = op.getLoc();
+    Value newCondition = rewriter.create(loc, op.condition(),
+                                         nestedIf.condition());
+    auto newIf = rewriter.create(loc, newCondition);
+    Block *newIfBlock = newIf.thenBlock();
+    rewriter.eraseOp(newIfBlock->getTerminator());
+    rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results
-      .add(context);
+  results.add(context);
 }
 
 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -429,6 +429,24 @@
 
 // -----
 
+// CHECK-LABEL: @merge_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_nested_if(%arg0: i1, %arg1: i1) {
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: scf.if %[[COND]] {
+// CHECK-NEXT: "test.op"()
+  scf.if %arg0 {
+    scf.if %arg1 {
+      "test.op"() : () -> ()
+      scf.yield
+    }
+    scf.yield
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @remove_zero_iteration_loop
 func @remove_zero_iteration_loop() {
   %c42 = arith.constant 42 : index