diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -22,6 +22,7 @@ #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -3738,7 +3739,8 @@ } }; -struct WhileUnusedArg : public OpRewritePattern { +/// Remove unused init/yield args. +struct WhileRemoveUnusedArgs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(WhileOp op, @@ -3746,42 +3748,52 @@ if (!llvm::any_of(op.getBeforeArguments(), [](Value arg) { return arg.use_empty(); })) - return failure(); + return rewriter.notifyMatchFailure(op, "No args to remove"); YieldOp yield = op.getYieldOp(); // Collect results mapping, new terminator args and new result types. SmallVector newYields; SmallVector newInits; - llvm::BitVector argsToErase(op.getBeforeArguments().size()); - for (const auto &it : llvm::enumerate(llvm::zip( - op.getBeforeArguments(), yield.getOperands(), op.getInits()))) { - Value beforeArg = std::get<0>(it.value()); - Value yieldValue = std::get<1>(it.value()); - Value initValue = std::get<2>(it.value()); + llvm::BitVector argsToErase; + + size_t argsCount = op.getBeforeArguments().size(); + newYields.reserve(argsCount); + newInits.reserve(argsCount); + argsToErase.reserve(argsCount); + for (auto &&[beforeArg, yieldValue, initValue] : llvm::zip( + op.getBeforeArguments(), yield.getOperands(), op.getInits())) { if (beforeArg.use_empty()) { - argsToErase.set(it.index()); + argsToErase.push_back(true); } else { + argsToErase.push_back(false); newYields.emplace_back(yieldValue); newInits.emplace_back(initValue); } } - if (argsToErase.none()) - return failure(); + Block &beforeBlock = op.getBefore().front(); + Block &afterBlock = op.getAfter().front(); - rewriter.startRootUpdate(op); - op.getBefore().front().eraseArguments(argsToErase); - rewriter.finalizeRootUpdate(op); + beforeBlock.eraseArguments(argsToErase); - WhileOp replacement = - rewriter.create(op.getLoc(), op.getResultTypes(), newInits); - replacement.getBefore().takeBody(op.getBefore()); - replacement.getAfter().takeBody(op.getAfter()); - rewriter.replaceOp(op, replacement.getResults()); + Location loc = op.getLoc(); + auto newWhileOp = + rewriter.create(loc, op.getResultTypes(), newInits, + /*beforeBody*/ nullptr, /*afterBody*/ nullptr); + Block &newBeforeBlock = newWhileOp.getBefore().front(); + Block &newAfterBlock = newWhileOp.getAfter().front(); + OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(yield); rewriter.replaceOpWithNewOp(yield, newYields); + + rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, + newBeforeBlock.getArguments()); + rewriter.mergeBlocks(&afterBlock, &newAfterBlock, + newAfterBlock.getArguments()); + + rewriter.replaceOp(op, newWhileOp.getResults()); return success(); } }; @@ -3792,14 +3804,21 @@ LogicalResult matchAndRewrite(WhileOp op, PatternRewriter &rewriter) const override { - Block &beforeBlock = op.getBefore().front(); - Block &afterBlock = op.getAfter().front(); - - auto condOp = cast(beforeBlock.getTerminator()); + ConditionOp condOp = op.getConditionOp(); ValueRange condOpArgs = condOp.getArgs(); + + llvm::SmallPtrSet argsSet; + for (Value arg : condOpArgs) + argsSet.insert(arg); + + if (argsSet.size() == condOpArgs.size()) + return rewriter.notifyMatchFailure(op, "No results to remove"); + llvm::SmallDenseMap argsMap; SmallVector newArgs; - for (auto arg : condOpArgs) { + argsMap.reserve(condOpArgs.size()); + newArgs.reserve(condOpArgs.size()); + for (Value arg : condOpArgs) { if (!argsMap.count(arg)) { auto pos = static_cast(argsMap.size()); argsMap.insert({arg, pos}); @@ -3807,9 +3826,6 @@ } } - if (argsMap.size() == condOpArgs.size()) - return rewriter.notifyMatchFailure(op, "No results to remove"); - ValueRange argsRange(newArgs); Location loc = op.getLoc(); @@ -3834,64 +3850,13 @@ rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), argsRange); - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, - newBeforeBlock.getArguments()); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping); - rewriter.replaceOp(op, resultsMapping); - return success(); - } -}; - -/// Remove unused init/yield args. -struct WhileRemoveUnusedArgs : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp op, - PatternRewriter &rewriter) const override { Block &beforeBlock = op.getBefore().front(); Block &afterBlock = op.getAfter().front(); - auto yield = cast(afterBlock.getTerminator()); - - llvm::BitVector argsToRemove; - SmallVector newInits; - SmallVector newYieldArgs; - - bool changed = false; - for (auto &&[arg, init, yieldArg] : llvm::zip( - beforeBlock.getArguments(), op.getInits(), yield.getResults())) { - bool empty = arg.use_empty(); - argsToRemove.push_back(empty); - if (empty) { - changed = true; - continue; - } - - newInits.emplace_back(init); - newYieldArgs.emplace_back(yieldArg); - } - - if (!changed) - return rewriter.notifyMatchFailure(op, "No args to remove"); - - beforeBlock.eraseArguments(argsToRemove); - - Location loc = op.getLoc(); - auto newWhileOp = - rewriter.create(loc, op->getResultTypes(), newInits, - /*beforeBody*/ nullptr, /*afterBody*/ nullptr); - Block &newBeforeBlock = newWhileOp.getBefore().front(); - Block &newAfterBlock = newWhileOp.getAfter().front(); - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(yield); - rewriter.replaceOpWithNewOp(yield, newYieldArgs); - rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlock.getArguments()); - rewriter.mergeBlocks(&afterBlock, &newAfterBlock, - newAfterBlock.getArguments()); - rewriter.replaceOp(op, newWhileOp.getResults()); + rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping); + rewriter.replaceOp(op, resultsMapping); return success(); } }; diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1019,30 +1019,6 @@ // ----- -// CHECK-LABEL: @while_unused_arg -func.func @while_unused_arg(%x : i32, %y : f64) -> i32 { - %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) { - %condition = "test.condition"(%arg1) : (i32) -> i1 - scf.condition(%condition) %arg1 : i32 - } do { - ^bb0(%arg1: i32): - %next = "test.use"(%arg1) : (i32) -> (i32) - scf.yield %next, %y : i32, f64 - } - return %0 : i32 -} -// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.+]] = %{{.*}}) : (i32) -> i32 { -// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1 -// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32 -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[post:.+]]: i32): -// CHECK-NEXT: %[[next:.+]] = "test.use"(%[[post]]) : (i32) -> i32 -// CHECK-NEXT: scf.yield %[[next]] : i32 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[res]] : i32 - -// ----- - // CHECK-LABEL: @invariant_loop_args_in_same_order // CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor) func.func @invariant_loop_args_in_same_order(%f_arg0: tensor) -> (tensor, tensor, tensor, tensor, tensor) { @@ -1221,10 +1197,36 @@ // CHECK: } // CHECK: return %[[RES]], %[[RES]] : i32, i32 + +// ----- + +// CHECK-LABEL: @while_unused_arg1 +func.func @while_unused_arg1(%x : i32, %y : f64) -> i32 { + %0 = scf.while (%arg1 = %x, %arg2 = %y) : (i32, f64) -> (i32) { + %condition = "test.condition"(%arg1) : (i32) -> i1 + scf.condition(%condition) %arg1 : i32 + } do { + ^bb0(%arg1: i32): + %next = "test.use"(%arg1) : (i32) -> (i32) + scf.yield %next, %y : i32, f64 + } + return %0 : i32 +} +// CHECK-NEXT: %[[res:.*]] = scf.while (%[[arg2:.*]] = %{{.*}}) : (i32) -> i32 { +// CHECK-NEXT: %[[cmp:.*]] = "test.condition"(%[[arg2]]) : (i32) -> i1 +// CHECK-NEXT: scf.condition(%[[cmp]]) %[[arg2]] : i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[post:.*]]: i32): +// CHECK-NEXT: %[[next:.*]] = "test.use"(%[[post]]) : (i32) -> i32 +// CHECK-NEXT: scf.yield %[[next]] : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[res]] : i32 + + // ----- -// CHECK-LABEL: @while_unused_arg -func.func @while_unused_arg(%val0: i32) -> i32 { +// CHECK-LABEL: @while_unused_arg2 +func.func @while_unused_arg2(%val0: i32) -> i32 { %0 = scf.while (%val1 = %val0) : (i32) -> i32 { %val = "test.val"() : () -> i32 %condition = "test.condition"() : () -> i1