diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -1594,6 +1594,52 @@ } }; +/// Convert nested ifs into arith.andi + single if. +/// +/// scf.if %arg0 { +/// scf.if %arg1 { +/// ... +/// scf.yield +/// } +/// scf.yield +/// } +/// becomes +/// +/// %0 = arith.andi %arg0, %arg1 +/// scf.if %0 { +/// ... +/// scf.yield +/// } +struct CombineNestedIfs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp op, + PatternRewriter &rewriter) const override { + // Both if ops must have no results and no else block. + if (op->getNumResults() != 0 || op.elseBlock()) + return failure(); + + auto nestedOps = op.thenBlock()->without_terminator(); + // The nested IfOp must be the only non-terminator op in the then block. + if (!llvm::hasSingleElement(nestedOps)) + return failure(); + + auto nestedIf = dyn_cast(*nestedOps.begin()); + if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock()) + return failure(); + + Location loc = op.getLoc(); + Value newCondition = rewriter.create(loc, op.condition(), + nestedIf.condition()); + auto newIf = rewriter.create(loc, newCondition); + Block *newIfBlock = newIf.thenBlock(); + rewriter.eraseOp(newIfBlock->getTerminator()); + rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock); + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -1601,7 +1647,7 @@ results .add(context); + RemoveEmptyElseBranch, CombineNestedIfs>(context); } Block *IfOp::thenBlock() { return &thenRegion().back(); } diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -429,6 +429,25 @@ // ----- +// CHECK-LABEL: @merge_nested_if +// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1) +func 
@merge_nested_if(%arg0: i1, %arg1: i1) { +// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]] +// CHECK: scf.if %[[COND]] { +// CHECK-NEXT: "test.op"() + scf.if %arg0 { + scf.if %arg1 { + "test.op"() : () -> () + scf.yield + } + scf.yield + } + return +} + + +// ----- + // CHECK-LABEL: @remove_zero_iteration_loop func @remove_zero_iteration_loop() { %c42 = arith.constant 42 : index