diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2255,11 +2255,102 @@
     return success(replaced);
   }
 };
+
+/// Remove unused WhileOp results whose 'after' block argument is also unused.
+///
+///  %0:2 = scf.while () : () -> (i32, i64) {
+///    %condition = "test.condition"() : () -> i1
+///    %v1 = "test.get_some_value"() : () -> i32
+///    %v2 = "test.get_some_value"() : () -> i64
+///    scf.condition(%condition) %v1, %v2 : i32, i64
+///  } do {
+///  ^bb0(%arg0: i32, %arg1: i64):
+///    "test.use"(%arg0) : (i32) -> ()
+///    scf.yield
+///  }
+///  return %0#0 : i32
+///
+/// becomes
+///  %0 = scf.while () : () -> (i32) {
+///    %condition = "test.condition"() : () -> i1
+///    %v1 = "test.get_some_value"() : () -> i32
+///    %v2 = "test.get_some_value"() : () -> i64
+///    scf.condition(%condition) %v1 : i32
+///  } do {
+///  ^bb0(%arg0: i32):
+///    "test.use"(%arg0) : (i32) -> ()
+///    scf.yield
+///  }
+///  return %0 : i32
+struct WhileUnusedResult : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    auto term = op.getConditionOp();
+    auto afterArgs = op.getAfterArguments();
+    auto termArgs = term.args();
+
+    // Collect results mapping, new terminator args and new result types.
+    SmallVector<unsigned> newResultsIndices;
+    SmallVector<Type> newResultTypes;
+    SmallVector<Value> newTermArgs;
+    bool needUpdate = false;
+    for (auto it :
+         llvm::enumerate(llvm::zip(op.getResults(), afterArgs, termArgs))) {
+      auto i = static_cast<unsigned>(it.index());
+      Value result = std::get<0>(it.value());
+      Value afterArg = std::get<1>(it.value());
+      Value termArg = std::get<2>(it.value());
+      if (result.use_empty() && afterArg.use_empty()) {
+        needUpdate = true;
+      } else {
+        newResultsIndices.emplace_back(i);
+        newTermArgs.emplace_back(termArg);
+        newResultTypes.emplace_back(result.getType());
+      }
+    }
+
+    if (!needUpdate)
+      return failure();
+
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(term);
+      rewriter.replaceOpWithNewOp<ConditionOp>(term, term.condition(),
+                                               newTermArgs);
+    }
+
+    auto newWhile =
+        rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.inits());
+
+    Block &newAfterBlock = *rewriter.createBlock(
+        &newWhile.after(), /*insertPt*/ {}, newResultTypes);
+
+    // Build new results list and new after block args (unused entries will be
+    // null).
+    SmallVector<Value> newResults(op.getNumResults());
+    SmallVector<Value> newAfterBlockArgs(op.getNumResults());
+    for (auto it : llvm::enumerate(newResultsIndices)) {
+      newResults[it.value()] = newWhile.getResult(it.index());
+      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
+    }
+
+    rewriter.inlineRegionBefore(op.before(), newWhile.before(),
+                                newWhile.before().begin());
+
+    Block &afterBlock = op.after().front();
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
+
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  results.insert<WhileConditionTruth>(context);
+  results.insert<WhileConditionTruth, WhileUnusedResult>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -782,7 +782,7 @@
 // -----
 
 // CHECK-LABEL: @while_cond_true
-func @while_cond_true() {
+func @while_cond_true() -> i1 {
   %0 = scf.while () : () -> i1 {
     %condition = "test.condition"() : () -> i1
     scf.condition(%condition) %condition : i1
@@ -791,7 +791,7 @@
     "test.use"(%arg0) : (i1) -> ()
     scf.yield
   }
-  return
+  return %0 : i1
 }
 // CHECK-NEXT: %[[true:.+]] = arith.constant true
 // CHECK-NEXT: %{{.+}} = scf.while : () -> i1 {
@@ -805,6 +805,34 @@
 
 // -----
 
+// CHECK-LABEL: @while_unused_result
+func @while_unused_result() -> i32 {
+  %0:2 = scf.while () : () -> (i32, i64) {
+    %condition = "test.condition"() : () -> i1
+    %v1 = "test.get_some_value"() : () -> i32
+    %v2 = "test.get_some_value"() : () -> i64
+    scf.condition(%condition) %v1, %v2 : i32, i64
+  } do {
+  ^bb0(%arg0: i32, %arg1: i64):
+    "test.use"(%arg0) : (i32) -> ()
+    scf.yield
+  }
+  return %0#0 : i32
+}
+// CHECK-NEXT: %[[res:.*]] = scf.while : () -> i32 {
+// CHECK-NEXT: %[[cmp:.*]] = "test.condition"() : () -> i1
+// CHECK-NEXT: %[[val:.*]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT: %{{.*}} = "test.get_some_value"() : () -> i64
+// CHECK-NEXT: scf.condition(%[[cmp]]) %[[val]] : i32
+// CHECK-NEXT: } do {
+// CHECK-NEXT: ^bb0(%[[arg:.*]]: i32): // no predecessors
+// CHECK-NEXT: "test.use"(%[[arg]]) : (i32) -> ()
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return %[[res]] : i32
+
+// -----
+
 // CHECK-LABEL: @combineIfs
 func @combineIfs(%arg0 : i1, %arg2: i64) -> (i32, i32) {
   %res = scf.if %arg0 -> i32 {