diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -686,6 +686,8 @@ let extraClassDeclaration = [{ OperandRange getSuccessorEntryOperands(unsigned index); ConditionOp getConditionOp(); + YieldOp getYieldOp(); + Block::BlockArgListType getBeforeArguments(); Block::BlockArgListType getAfterArguments(); }]; diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -2171,6 +2171,14 @@ return cast(getBefore().front().getTerminator()); } +YieldOp WhileOp::getYieldOp() { + return cast(getAfter().front().getTerminator()); +} + +Block::BlockArgListType WhileOp::getBeforeArguments() { + return getBefore().front().getArguments(); +} + Block::BlockArgListType WhileOp::getAfterArguments() { return getAfter().front().getArguments(); } @@ -2508,11 +2516,60 @@ return success(changed); } }; + +struct WhileUnusedArg : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp op, + PatternRewriter &rewriter) const override { + // Only rewrite when at least one before-region argument is unused. + if (!llvm::any_of(op.getBeforeArguments(), + [](Value arg) { return arg.use_empty(); })) + return failure(); + + YieldOp yield = op.getYieldOp(); + + // Keep init/yield operands of used args; collect indices of unused ones.
+ SmallVector newYields; + SmallVector newInits; + SmallVector argsToErase; + for (const auto &it : llvm::enumerate(llvm::zip( + op.getBeforeArguments(), yield.getOperands(), op.getInits()))) { + Value beforeArg = std::get<0>(it.value()); + Value yieldValue = std::get<1>(it.value()); + Value initValue = std::get<2>(it.value()); + if (beforeArg.use_empty()) { + argsToErase.push_back(it.index()); + } else { + newYields.emplace_back(yieldValue); + newInits.emplace_back(initValue); + } + } + + if (argsToErase.size() == 0) + return failure(); + + rewriter.startRootUpdate(op); + op.getBefore().front().eraseArguments(argsToErase); + rewriter.finalizeRootUpdate(op); + + WhileOp replacement = + rewriter.create(op.getLoc(), op.getResultTypes(), newInits); + replacement.getBefore().takeBody(op.getBefore()); + replacement.getAfter().takeBody(op.getAfter()); + rewriter.replaceOp(op, replacement.getResults()); + // `yield` was moved into the replacement by takeBody; trim its operands. + rewriter.setInsertionPoint(yield); + rewriter.replaceOpWithNewOp(yield, newYields); + return success(); + } +}; } // namespace void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -846,6 +846,30 @@ // ----- +// CHECK-LABEL: @while_unused_arg +func @while_unused_arg(%x : i32, %y : f64) -> i32 { + %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) { + %condition = "test.condition"(%arg1) : (i32) -> i1 + scf.condition(%condition) %arg1 : i32 + } do { + ^bb0(%arg1: i32): + %next = "test.use"(%arg1) : (i32) -> (i32) + scf.yield %next, %y : i32, f64 + } + return %0 : i32 +} +// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.+]] = %{{.*}}) : (i32) -> i32 { +// CHECK-NEXT: %[[cmp:.*]] = 
"test.condition"(%[[arg2]]) : (i32) -> i1 +// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[post:.+]]: i32): // no predecessors +// CHECK-NEXT: %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32 +// CHECK-NEXT: scf.yield %[[next]] : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[res]] : i32 + +// ----- + // CHECK-LABEL: @while_unused_result func @while_unused_result() -> i32 { %0:2 = scf.while () : () -> (i32, i64) {