diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -393,6 +393,19 @@ }); } +/// Replaces the given op with the contents of the given single-block region, +/// using the operands of the block terminator to replace operation results. +static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, + Region &region, ValueRange blockArgs = {}) { + assert(llvm::hasSingleElement(region) && "expected single-block region"); + Block *block = &region.front(); + Operation *terminator = block->getTerminator(); + ValueRange results = terminator->getOperands(); + rewriter.mergeBlockBefore(block, op, blockArgs); + rewriter.replaceOp(op, results); + rewriter.eraseOp(terminator); +} + namespace { // Fold away ForOp iter arguments that are also yielded by the op. // These arguments must be defined outside of the ForOp region and can just be @@ -500,11 +513,51 @@ return success(); } }; + +/// Rewriting pattern that erases loops that are known not to iterate and +/// replaces single-iteration loops with their bodies. +struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { + using OpRewritePattern<ForOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ForOp op, + PatternRewriter &rewriter) const override { + auto lb = op.lowerBound().getDefiningOp<ConstantOp>(); + auto ub = op.upperBound().getDefiningOp<ConstantOp>(); + if (!lb || !ub) + return failure(); + + // If the loop is known to have 0 iterations, remove it. + llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue(); + llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue(); + if (lbValue.sge(ubValue)) { + rewriter.replaceOp(op, op.getIterOperands()); + return success(); + } + + auto step = op.step().getDefiningOp<ConstantOp>(); + if (!step) + return failure(); + + // If the loop is known to have 1 iteration (lb + step >= ub), inline its + // body and remove the loop.
+ llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue(); + if ((lbValue + stepValue).sge(ubValue)) { + SmallVector<Value, 4> blockArgs; + blockArgs.reserve(op.getNumIterOperands() + 1); + blockArgs.push_back(op.lowerBound()); + llvm::append_range(blockArgs, op.getIterOperands()); + replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs); + return success(); + } + + return failure(); + } +}; } // namespace void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert<ForOpIterArgsFolder>(context); + results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops>(context); } //===----------------------------------------------------------------------===// @@ -710,11 +763,34 @@ return success(); } }; + +/// Rewriting pattern that replaces an scf.if with a constant condition by the +/// contents of the region it statically selects, or erases the op when the +/// condition is false and there is no else region. +struct RemoveStaticCondition : public OpRewritePattern<IfOp> { + using OpRewritePattern<IfOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp op, + PatternRewriter &rewriter) const override { + auto constant = op.condition().getDefiningOp<ConstantOp>(); + if (!constant) + return failure(); + + if (constant.getValue().cast<BoolAttr>().getValue()) + replaceOpWithRegion(rewriter, op, op.thenRegion()); + else if (!op.elseRegion().empty()) + replaceOpWithRegion(rewriter, op, op.elseRegion()); + else + rewriter.eraseOp(op); + + return success(); + } +}; } // namespace void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert<RemoveUnusedResults>(context); + results.insert<RemoveUnusedResults, RemoveStaticCondition>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -56,11 +56,10 @@ // ----- -func @one_unused() -> (index) { +func @one_unused(%cond: i1) -> (index) { %c0 = constant 0 : index %c1 = constant 1 : index - %true = constant true - %0, %1 = scf.if %true -> (index, index) { + %0, %1 = scf.if %cond -> (index, index) { scf.yield %c0, %c1 : index, index } else { scf.yield %c0, %c1 : index,
index @@ -70,8 +69,7 @@ // CHECK-LABEL: func @one_unused // CHECK: [[C0:%.*]] = constant 1 : index -// CHECK: [[C1:%.*]] = constant true -// CHECK: [[V0:%.*]] = scf.if [[C1]] -> (index) { +// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) { // CHECK: scf.yield [[C0]] : index // CHECK: } else // CHECK: scf.yield [[C0]] : index @@ -80,12 +78,11 @@ // ----- -func @nested_unused() -> (index) { +func @nested_unused(%cond1: i1, %cond2: i1) -> (index) { %c0 = constant 0 : index %c1 = constant 1 : index - %true = constant true - %0, %1 = scf.if %true -> (index, index) { - %2, %3 = scf.if %true -> (index, index) { + %0, %1 = scf.if %cond1 -> (index, index) { + %2, %3 = scf.if %cond2 -> (index, index) { scf.yield %c0, %c1 : index, index } else { scf.yield %c0, %c1 : index, index @@ -99,9 +96,8 @@ // CHECK-LABEL: func @nested_unused // CHECK: [[C0:%.*]] = constant 1 : index -// CHECK: [[C1:%.*]] = constant true -// CHECK: [[V0:%.*]] = scf.if [[C1]] -> (index) { -// CHECK: [[V1:%.*]] = scf.if [[C1]] -> (index) { +// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) { +// CHECK: [[V1:%.*]] = scf.if %{{.*}} -> (index) { // CHECK: scf.yield [[C0]] : index // CHECK: } else // CHECK: scf.yield [[C0]] : index @@ -115,11 +111,10 @@ // ----- func private @side_effect() {} -func @all_unused() { +func @all_unused(%cond: i1) { %c0 = constant 0 : index %c1 = constant 1 : index - %true = constant true - %0, %1 = scf.if %true -> (index, index) { + %0, %1 = scf.if %cond -> (index, index) { call @side_effect() : () -> () scf.yield %c0, %c1 : index, index } else { @@ -130,8 +125,7 @@ } // CHECK-LABEL: func @all_unused -// CHECK: [[C1:%.*]] = constant true -// CHECK: scf.if [[C1]] { +// CHECK: scf.if %{{.*}} { // CHECK: call @side_effect() : () -> () // CHECK: } else // CHECK: call @side_effect() : () -> () @@ -172,3 +166,131 @@ // CHECK-NEXT: scf.yield %[[c]] : i32 // CHECK-NEXT: } // CHECK-NEXT: return %[[a]], %[[r1]], %[[b]] : i32, i32, i32 + +// CHECK-LABEL: @replace_true_if +func
@replace_true_if() { + %true = constant true + // CHECK-NOT: scf.if + // CHECK: "test.op" + scf.if %true { + "test.op"() : () -> () + scf.yield + } + return +} + +// CHECK-LABEL: @remove_false_if +func @remove_false_if() { + %false = constant false + // CHECK-NOT: scf.if + // CHECK-NOT: "test.op" + scf.if %false { + "test.op"() : () -> () + scf.yield + } + return +} + +// CHECK-LABEL: @replace_true_if_with_values +func @replace_true_if_with_values() { + %true = constant true + // CHECK-NOT: scf.if + // CHECK: %[[VAL:.*]] = "test.op" + %0 = scf.if %true -> (i32) { + %1 = "test.op"() : () -> i32 + scf.yield %1 : i32 + } else { + %2 = "test.other_op"() : () -> i32 + scf.yield %2 : i32 + } + // CHECK: "test.consume"(%[[VAL]]) + "test.consume"(%0) : (i32) -> () + return +} + +// CHECK-LABEL: @replace_false_if_with_values +func @replace_false_if_with_values() { + %false = constant false + // CHECK-NOT: scf.if + // CHECK: %[[VAL:.*]] = "test.other_op" + %0 = scf.if %false -> (i32) { + %1 = "test.op"() : () -> i32 + scf.yield %1 : i32 + } else { + %2 = "test.other_op"() : () -> i32 + scf.yield %2 : i32 + } + // CHECK: "test.consume"(%[[VAL]]) + "test.consume"(%0) : (i32) -> () + return +} + +// CHECK-LABEL: @remove_zero_iteration_loop +func @remove_zero_iteration_loop() { + %c42 = constant 42 : index + %c1 = constant 1 : index + // CHECK: %[[INIT:.*]] = "test.init" + %init = "test.init"() : () -> i32 + // CHECK-NOT: scf.for + %0 = scf.for %i = %c42 to %c1 step %c1 iter_args(%arg = %init) -> (i32) { + %1 = "test.op"(%i, %arg) : (index, i32) -> i32 + scf.yield %1 : i32 + } + // CHECK: "test.consume"(%[[INIT]]) + "test.consume"(%0) : (i32) -> () + return +} + +// CHECK-LABEL: @replace_single_iteration_loop +func @replace_single_iteration_loop() { + // CHECK: %[[LB:.*]] = constant 42 + %c42 = constant 42 : index + %c43 = constant 43 : index + %c1 = constant 1 : index + // CHECK: %[[INIT:.*]] = "test.init" + %init = "test.init"() : () -> i32 + // CHECK-NOT: scf.for + // CHECK:
%[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]]) + %0 = scf.for %i = %c42 to %c43 step %c1 iter_args(%arg = %init) -> (i32) { + %1 = "test.op"(%i, %arg) : (index, i32) -> i32 + scf.yield %1 : i32 + } + // CHECK: "test.consume"(%[[VAL]]) + "test.consume"(%0) : (i32) -> () + return +} + +// CHECK-LABEL: @replace_single_iteration_loop_non_unit_step +func @replace_single_iteration_loop_non_unit_step() { + // CHECK: %[[LB:.*]] = constant 42 + %c42 = constant 42 : index + %c47 = constant 47 : index + %c5 = constant 5 : index + // CHECK: %[[INIT:.*]] = "test.init" + %init = "test.init"() : () -> i32 + // CHECK-NOT: scf.for + // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]]) + %0 = scf.for %i = %c42 to %c47 step %c5 iter_args(%arg = %init) -> (i32) { + %1 = "test.op"(%i, %arg) : (index, i32) -> i32 + scf.yield %1 : i32 + } + // CHECK: "test.consume"(%[[VAL]]) + "test.consume"(%0) : (i32) -> () + return +} + +// CHECK-LABEL: @keep_multi_iteration_loop +func @keep_multi_iteration_loop() { + %c5 = constant 5 : index + %c8 = constant 8 : index + %c1 = constant 1 : index + %init = "test.init"() : () -> i32 + // Three iterations (5, 6, 7): lb + step < ub, so the loop must be kept. + // CHECK: scf.for + %0 = scf.for %i = %c5 to %c8 step %c1 iter_args(%arg = %init) -> (i32) { + %1 = "test.op"(%i, %arg) : (index, i32) -> i32 + scf.yield %1 : i32 + } + "test.consume"(%0) : (i32) -> () + return +}