diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -2329,6 +2329,128 @@
   }
 };
 
+/// Remove loop invariant arguments of scf.while. An argument is considered loop
+/// invariant if the iteration argument value is the same as the corresponding
+/// one being yielded (at the same position) in both before/after block of
+/// scf.while. For the arguments which are removed, their uses inside scf.while
+/// and their corresponding scf.while's result are replaced with their
+/// corresponding initial value.
+///
+/// The pattern only applies when scf.while has as many operands as results,
+/// since it pairs the operand and the result found at the same position. The
+/// comparison is made against the loop's own scf.condition and scf.yield, so
+/// terminators of loops nested inside scf.while are never considered.
+struct RemoveLoopInvariantArgs : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    // Operand `i` is paired with result `i` below, which is only meaningful
+    // when both regions carry the same number of values.
+    if (op.getOperands().size() != op.getNumResults())
+      return failure();
+
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
+    Block::BlockArgListType beforeBlockArgs = beforeBlock.getArguments();
+    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
+    ConditionOp condOp = op.getConditionOp();
+    OperandRange condOpArgs = condOp.getArgs();
+    Operation *yieldOp = afterBlock.getTerminator();
+    ValueRange yieldOpArgs = yieldOp->getOperands();
+
+    // Positions, in the old scf.while, of the arguments that are kept.
+    SmallVector<unsigned> newResultsIndices;
+    // Before-block arguments take the operand types, after-block arguments
+    // and results take the result types; the two may differ.
+    SmallVector<Type> newBeforeArgTypes, newResultTypes;
+    SmallVector<Value> newCondOpArgs, newInitArgs, newYieldOpArgs;
+    // One entry per position of the old scf.while: the value replacing the old
+    // before-block argument, after-block argument and result respectively.
+    // Loop invariant positions map to the initial value; kept positions hold
+    // a null placeholder until the new scf.while has been created.
+    SmallVector<Value> newBeforeBlockArgs, newAfterBlockArgs, newResults;
+    for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), op.getResults(),
+                                             beforeBlockArgs, afterBlockArgs,
+                                             condOpArgs, yieldOpArgs))) {
+      auto index = static_cast<unsigned>(it.index());
+      Value initVal, result, beforeBlockArg, afterBlockArg, condOpArg,
+          yieldOpArg;
+      std::tie(initVal, result, beforeBlockArg, afterBlockArg, condOpArg,
+               yieldOpArg) = it.value();
+
+      // The argument at `index` is loop invariant when this loop's own
+      // scf.condition and scf.yield both forward it unchanged at `index`.
+      // Comparing against the terminators directly (rather than scanning the
+      // users of the block arguments) keeps terminators of nested loops, whose
+      // operand count is unrelated to `index`, out of the picture.
+      if (condOpArg == beforeBlockArg && yieldOpArg == afterBlockArg) {
+        newBeforeBlockArgs.emplace_back(initVal);
+        newAfterBlockArgs.emplace_back(initVal);
+        newResults.emplace_back(initVal);
+        continue;
+      }
+
+      newResultsIndices.emplace_back(index);
+      newBeforeArgTypes.emplace_back(initVal.getType());
+      newResultTypes.emplace_back(result.getType());
+      newCondOpArgs.emplace_back(condOpArg);
+      newInitArgs.emplace_back(initVal);
+      newYieldOpArgs.emplace_back(yieldOpArg);
+      newBeforeBlockArgs.emplace_back(Value());
+      newAfterBlockArgs.emplace_back(Value());
+      newResults.emplace_back(Value());
+    }
+
+    // Every argument is kept: there is nothing to simplify.
+    if (newResultsIndices.size() == op.getNumResults())
+      return failure();
+
+    {
+      // Scope the insertion guard while the terminators (ConditionOp and
+      // YieldOp) of the before/after blocks are rebuilt with kept operands.
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(condOp);
+      rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
+                                               newCondOpArgs);
+
+      rewriter.setInsertionPoint(yieldOp);
+      rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
+    }
+
+    // Creating a new scf.while with new initial operands and result types.
+    auto newWhile =
+        rewriter.create<WhileOp>(op.getLoc(), newResultTypes, newInitArgs);
+
+    // Create before/after blocks in the before/after region of the new
+    // scf.while.
+    Block &newBeforeBlock = *rewriter.createBlock(
+        &newWhile.getBefore(), /*insertPt*/ {}, newBeforeArgTypes);
+    Block &newAfterBlock = *rewriter.createBlock(
+        &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes);
+
+    // Fill the placeholders of the kept positions with the block arguments
+    // and results of the new scf.while.
+    for (auto it : llvm::enumerate(newResultsIndices)) {
+      newBeforeBlockArgs[it.value()] = newBeforeBlock.getArgument(it.index());
+      newAfterBlockArgs[it.value()] = newAfterBlock.getArgument(it.index());
+      newResults[it.value()] = newWhile.getResult(it.index());
+    }
+
+    // Move the operations of before/after blocks of the old scf.while into the
+    // before/after blocks of new scf.while. Merging replaces every use of an
+    // old block argument, including the uses held by the rebuilt terminators,
+    // with its entry in `newBeforeBlockArgs`/`newAfterBlockArgs`.
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
+
+    // Replace the results of the old scf.while through the rewriter: kept
+    // results by those of the new scf.while, the others by the initial value.
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
 /// Remove WhileOp results that are also unused in 'after' block.
 ///
 ///  %0:2 = scf.while () : () -> (i32, i64) {
@@ -2423,7 +2545,9 @@
 
 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
-  results.insert<WhileConditionTruth, WhileUnusedResult>(context);
+  results
+      .insert<RemoveLoopInvariantArgs, WhileConditionTruth, WhileUnusedResult>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -823,6 +823,43 @@
 
 // -----
 
+// CHECK-LABEL: @invariant_loop_args
+func @invariant_loop_args() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+  %cst_0 = arith.constant dense<0> : tensor<i32>
+  %cst_1 = arith.constant dense<1> : tensor<i32>
+  %cst_42 = arith.constant dense<42> : tensor<i32>
+
+  // %arg1 and %arg4 are invariant loop args.
+  %0:5 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+    %1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
+    %2 = tensor.extract %1[] : tensor<i1>
+    scf.condition(%2) %arg0, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+  } do {
+  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
+    // %arg1 here will get replaced by %cst_1
+    %1 = arith.addi %arg0, %arg1 : tensor<i32>
+    %2 = arith.addi %arg2, %arg3 : tensor<i32>
+    scf.yield %1, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+  }
+  return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+}
+// CHECK-DAG:     %[[CST42:.*]] = arith.constant dense<42> : tensor<i32>
+// CHECK-DAG:     %[[ONE:.*]] = arith.constant dense<1> : tensor<i32>
+// CHECK-DAG:     %[[ZERO:.*]] = arith.constant dense<0> : tensor<i32>
+// CHECK-NEXT:    %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]]) : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>) {
+// CHECK-NEXT:      arith.cmpi slt, %[[ARG0]], %{{.*}} : tensor<i32>
+// CHECK-NEXT:      tensor.extract %{{.*}}[] : tensor<i1>
+// CHECK-NEXT:      scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]] : tensor<i32>, tensor<i32>, tensor<i32>
+// CHECK-NEXT:    } do {
+// CHECK-NEXT:    ^bb0(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>): // no predecessors
+// CHECK-NEXT:      %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[ONE]] : tensor<i32>
+// CHECK-NEXT:      %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]] : tensor<i32>
+// CHECK-NEXT:      scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]] : tensor<i32>, tensor<i32>, tensor<i32>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[WHILE]]#0, %[[ONE]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]] : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+
+// -----
+
 // CHECK-LABEL: @while_unused_result
 func @while_unused_result() -> i32 {
   %0:2 = scf.while () : () -> (i32, i64) {