diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2255,11 +2255,83 @@
     return success(replaced);
   }
 };
+
+/// Remove unused results of `scf.while`: a result is dropped when neither the
+/// op result nor the matching "after" block argument has any uses. The value
+/// that fed it through `scf.condition` is no longer forwarded.
+struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto term = op.getConditionOp();
+    auto afterArgs = op.getAfterArguments();
+    auto termArgs = term.args();
+
+    // Partition result positions into those to drop and those to keep,
+    // recording the terminator operand and type of every kept position.
+    SmallVector<unsigned> argsToRemove;
+    SmallVector<unsigned> resultsInd;
+    SmallVector<Type> newResultTypes;
+    SmallVector<Value> newTermArgs;
+    for (auto it :
+         llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
+      auto i = static_cast<unsigned>(it.index());
+      Value result = std::get<0>(it.value());
+      Value afterArg = std::get<1>(it.value());
+      Value termArg = std::get<2>(it.value());
+      if (result.use_empty() && afterArg.use_empty()) {
+        argsToRemove.emplace_back(i);
+      } else {
+        resultsInd.emplace_back(i);
+        newTermArgs.emplace_back(termArg);
+        newResultTypes.emplace_back(result.getType());
+      }
+    }
+
+    // Nothing to drop: the pattern does not apply.
+    if (argsToRemove.empty())
+      return failure();
+
+    // Rebuild the "before" terminator so it forwards only the kept values.
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(term);
+      rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(),
+                                               newTermArgs);
+    }
+
+    // Drop the dead "after" block arguments. The mutation is wrapped in
+    // updateRootInPlace so that it is performed through the rewriter.
+    auto &afterBlock = op.after().front();
+    rewriter.updateRootInPlace(
+        op, [&] { afterBlock.eraseArguments(argsToRemove); });
+
+    auto newWhile =
+        rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.inits());
+
+    // Map old result positions to the new op's results. Entries for dropped
+    // results stay default-constructed; those results have no uses.
+    SmallVector<Value> newResults(op.getNumResults());
+    for (auto it : llvm::enumerate(resultsInd))
+      newResults[it.value()] = newWhile.getResult(it.index());
+
+    // Move both regions from the old op into the new one.
+    rewriter.inlineRegionBefore(op.before(), newWhile.before(),
+                                newWhile.before().begin());
+
+    rewriter.inlineRegionBefore(op.after(), newWhile.after(),
+                                newWhile.after().begin());
+
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  results.insert<WhileConditionTruth>(context);
+  results.insert<WhileConditionTruth, WhileUnusedResult>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -782,7 +782,7 @@
 // -----
 
 // CHECK-LABEL: @while_cond_true
-func @while_cond_true() {
+func @while_cond_true() -> i1 {
   %0 = scf.while () : () -> i1 {
     %condition = "test.condition"() : () -> i1
     scf.condition(%condition) %condition : i1
@@ -791,7 +791,7 @@
     "test.use"(%arg0) : (i1) -> ()
     scf.yield
   }
-  return
+  return %0 : i1
 }
 // CHECK-NEXT:   %[[true:.+]] = arith.constant true
 // CHECK-NEXT:   %{{.+}} = scf.while : () -> i1 {
@@ -805,6 +805,34 @@
 
 // -----
 
+// CHECK-LABEL: @while_unused_result
+func @while_unused_result() -> i32 {
+  %0:2 = scf.while () : () -> (i32, i64) {
+    %condition = "test.condition"() : () -> i1
+    %v1 = "test.get_some_value"() : () -> i32
+    %v2 = "test.get_some_value"() : () -> i64
+    scf.condition(%condition) %v1, %v2 : i32, i64
+  } do {
+  ^bb0(%arg0: i32, %arg1: i64):
+    "test.use"(%arg0) : (i32) -> ()
+    scf.yield
+  }
+  return %0#0 : i32
+}
+// CHECK-NEXT:   %[[res:.+]] = scf.while : () -> i32 {
+// CHECK-NEXT:     %[[cmp:.+]] = "test.condition"() : () -> i1
+// CHECK-NEXT:     %[[val:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:     %{{.+}} = "test.get_some_value"() : () -> i64
+// CHECK-NEXT:     scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT:   } do {
+// CHECK-NEXT:   ^bb0(%arg0: i32): // no predecessors
+// CHECK-NEXT:     "test.use"(%arg0) : (i32) -> ()
+// CHECK-NEXT:     scf.yield
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[res]] : i32
+
+// -----
+
 // CHECK-LABEL: @combineIfs
 func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
   %res = scf.if %arg0 -> i32 {