diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3785,13 +3785,146 @@
     return success();
   }
 };
+
+/// Remove duplicated ConditionOp args.
+///
+/// If `scf.condition` forwards the same SSA value several times, the
+/// matching after-block arguments and loop results are redundant. Rebuild
+/// the loop so every distinct value is forwarded exactly once (in order of
+/// first occurrence) and remap the old arguments/results onto the survivors.
+struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
+
+    auto condOp = cast<ConditionOp>(beforeBlock.getTerminator());
+    ValueRange condOpArgs = condOp.getArgs();
+
+    // Give each distinct forwarded value a slot in the deduplicated list.
+    llvm::SmallDenseMap<Value, unsigned> argsMap;
+    SmallVector<Value> newArgs;
+    for (Value arg : condOpArgs) {
+      auto pos = static_cast<unsigned>(newArgs.size());
+      if (argsMap.try_emplace(arg, pos).second)
+        newArgs.push_back(arg);
+    }
+
+    if (newArgs.size() == condOpArgs.size())
+      return rewriter.notifyMatchFailure(op, "No results to remove");
+
+    // The new regions are filled by merging the old blocks in below, so the
+    // body builders have nothing to do.
+    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {};
+
+    ValueRange argsRange(newArgs);
+    auto newWhileOp =
+        rewriter.create<WhileOp>(op.getLoc(), argsRange.getTypes(),
+                                 op.getInits(), emptyBuilder, emptyBuilder);
+    Block &newBeforeBlock = newWhileOp.getBefore().front();
+    Block &newAfterBlock = newWhileOp.getAfter().front();
+
+    // Every old position (duplicates included) maps to the after-block
+    // argument and op result of its value's deduplicated slot.
+    SmallVector<Value> afterArgsMapping;
+    SmallVector<Value> resultsMapping;
+    afterArgsMapping.reserve(condOpArgs.size());
+    resultsMapping.reserve(condOpArgs.size());
+    for (Value arg : condOpArgs) {
+      auto it = argsMap.find(arg);
+      assert(it != argsMap.end() && "every condition arg has a slot");
+      afterArgsMapping.push_back(newAfterBlock.getArgument(it->second));
+      resultsMapping.push_back(newWhileOp->getResult(it->second));
+    }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(condOp);
+    rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
+                                             argsRange);
+
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock,
+                         newBeforeBlock.getArguments());
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, afterArgsMapping);
+    rewriter.replaceOp(op, resultsMapping);
+    return success();
+  }
+};
+
+/// Remove unused init/yield args.
+///
+/// A before-block argument without uses carries a value nobody reads, so the
+/// matching `scf.while` init operand and `scf.yield` operand can be dropped.
+/// Loop results are untouched: they come from the `scf.condition` operands.
+struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
+
+    if (llvm::none_of(beforeBlock.getArguments(),
+                      [](BlockArgument arg) { return arg.use_empty(); }))
+      return rewriter.notifyMatchFailure(op, "No args to remove");
+
+    auto yield = cast<YieldOp>(afterBlock.getTerminator());
+
+    SmallVector<Value> newInits;
+    SmallVector<Value> newYieldArgs;
+    for (auto &&[arg, init, yieldArg] : llvm::zip(
+             beforeBlock.getArguments(), op.getInits(), yield.getResults())) {
+      if (arg.use_empty())
+        continue;
+      newInits.push_back(init);
+      newYieldArgs.push_back(yieldArg);
+    }
+
+    // The new regions are filled by merging the old blocks in below, so the
+    // body builders have nothing to do.
+    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {};
+
+    auto newWhileOp =
+        rewriter.create<WhileOp>(op.getLoc(), op->getResultTypes(), newInits,
+                                 emptyBuilder, emptyBuilder);
+    Block &newBeforeBlock = newWhileOp.getBefore().front();
+    Block &newAfterBlock = newWhileOp.getAfter().front();
+
+    // Kept args map, in order, onto the new before-block args. Dropped args
+    // have no uses, so mapping them to their init value is a no-op; this
+    // lets mergeBlocks consume the old block without erasing its arguments
+    // behind the rewriter's back.
+    SmallVector<Value> beforeArgsMapping;
+    beforeArgsMapping.reserve(beforeBlock.getNumArguments());
+    unsigned nextArg = 0;
+    for (auto &&[arg, init] :
+         llvm::zip(beforeBlock.getArguments(), op.getInits())) {
+      if (arg.use_empty())
+        beforeArgsMapping.push_back(init);
+      else
+        beforeArgsMapping.push_back(newBeforeBlock.getArgument(nextArg++));
+    }
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(yield);
+    rewriter.replaceOpWithNewOp<YieldOp>(yield, newYieldArgs);
+
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, beforeArgsMapping);
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock,
+                         newAfterBlock.getArguments());
+    rewriter.replaceOp(op, newWhileOp.getResults());
+    return success();
+  }
+};
 } // namespace
 
 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<RemoveLoopInvariantArgsFromBeforeBlock,
               RemoveLoopInvariantValueYielded, WhileConditionTruncation,
-              WhileCmpCond, WhileUnusedResult>(context);
+              WhileCmpCond, WhileUnusedResult, WhileRemoveDuplicatedResults,
+              WhileRemoveUnusedArgs>(context);
 }
 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1195,6 +1195,60 @@ // CHECK-NEXT: scf.yield // CHECK-NEXT: } +// ----- + +// CHECK-LABEL: @while_duplicated_res +func.func @while_duplicated_res() -> (i32, i32) { + %0:2 = scf.while () : () -> (i32, i32) { + %val = "test.val"() : () -> i32 + %condition = "test.condition"() : () -> i1 + scf.condition(%condition) %val, %val : i32, i32 + } do { + ^bb0(%val2: i32, %val3: i32): + "test.use"(%val2, %val3) : (i32, i32) -> () + scf.yield + } + return %0#0, %0#1: i32, i32 +} +// CHECK: %[[RES:.*]] = scf.while : () -> i32 { +// CHECK: %[[VAL:.*]] = "test.val"() : () -> i32 +// CHECK: %[[COND:.*]] = "test.condition"() : () -> i1 +// CHECK: scf.condition(%[[COND]]) %[[VAL]] : i32 +// CHECK: } do { +// CHECK: ^bb0(%[[ARG:.*]]: i32): +// CHECK: "test.use"(%[[ARG]], %[[ARG]]) : (i32, i32) -> () +// CHECK: scf.yield +// CHECK: } +// CHECK: return %[[RES]], %[[RES]] : i32, i32 + +// ----- + +// CHECK-LABEL: @while_unused_arg +func.func @while_unused_arg(%val0: i32) -> i32 { + %0 = scf.while (%val1 = %val0) : (i32) -> i32 { + %val = "test.val"() : () -> i32 + %condition = "test.condition"() : () -> i1 + scf.condition(%condition) %val: i32 + } do { + ^bb0(%val2: i32): + "test.use"(%val2) : (i32) -> () + %val1 = "test.val1"() : () -> i32 + scf.yield %val1 : i32 + } + return %0 : i32 +} +// CHECK: %[[RES:.*]] = scf.while : () -> i32 { +// CHECK: %[[VAL:.*]] = "test.val"() : () -> i32 +// CHECK: %[[COND:.*]] = "test.condition"() : () -> i1 +// CHECK: scf.condition(%[[COND]]) %[[VAL]] : i32 +// CHECK: } do { +// CHECK: ^bb0(%[[ARG:.*]]: i32): +// CHECK: "test.use"(%[[ARG]]) : (i32) -> () +// CHECK: scf.yield +// CHECK: } +// CHECK: return %[[RES]] : i32 + + // ----- // CHECK-LABEL: @combineIfs diff --git 
a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -518,6 +518,8 @@ %arg2: index) { scf.while (%arg3 = %arg1) : (tensor<5xi1>) -> () { %0 = tensor.extract %arg0[%arg2] : tensor<5xi1> + %1 = tensor.extract %arg3[%arg2] : tensor<5xi1> + "dummy.use"(%1) : (i1) -> () scf.condition(%0) } do { %0 = "dummy.some_op"() : () -> index