diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -2329,6 +2329,242 @@ } }; +/// Remove loop invariant arguments from `before` block of scf.while. +/// An argument is considered loop invariant if the iteration argument value is +/// the same as the corresponding one being yielded by the `after` block of +/// scf.while. For the arguments which are removed, their uses inside scf.while +/// are replaced with their corresponding initial value. +/// +/// Eg: +/// INPUT :- +/// %res = scf.while <...> iter_args(%arg0 = %a, %arg1 = %b, ..., %argN = %N) +/// { +/// ... +/// scf.condition(%cond) %arg1, %arg0, %arg2, %arg0, ..., %argK +/// } do { +/// ^bb0(%arg0, %arg1, %arg2, ..., %argK): +/// ... +/// scf.yield %arg3, %arg1, ..., %argN +/// } +/// +/// OUTPUT :- +/// %res = scf.while <...> iter_args(%arg1 = %b, ..., %argN = %N) { +/// ... +/// scf.condition(%cond) %arg1, %a, %arg2, %a, ..., %argK +/// } do { +/// ^bb0(%arg0, %arg1, %arg2, ..., %argK): +/// ... +/// scf.yield %arg1, ..., %argN +/// } +/// +/// EXPLANATION: +/// %arg0 in scf.condition was used at position 1 and 3. These will then +/// become %arg1 and %arg3 respectively at the `after` block. We need any of +/// %arg1 or %arg3 to be yielded at position 0 for us to claim %arg0 to be a +/// loop invariant. 
+struct RemoveLoopInvariantArgsFromBeforeBlock
+    : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  /// Rebuilds `op` without its loop invariant iteration arguments. Uses of a
+  /// removed `before` block argument are replaced with the matching initial
+  /// value. Returns failure() without touching the IR when no argument is
+  /// loop invariant.
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
+    Block::BlockArgListType beforeBlockArgs = beforeBlock.getArguments();
+    ConditionOp condOp = op.getConditionOp();
+    OperandRange condOpArgs = condOp.getArgs();
+    Operation *yieldOp = afterBlock.getTerminator();
+    ValueRange yieldOpArgs = yieldOp->getOperands();
+
+    // Initial values, yielded values and argument types of the iteration
+    // arguments that are kept.
+    SmallVector<Value> newInitArgs, newYieldOpArgs;
+    SmallVector<Type> newBeforeBlockType;
+    // Position of every loop invariant `before` block argument, mapped to the
+    // initial value that replaces it.
+    DenseMap<unsigned, Value> beforeBlockInitValMap;
+    for (auto it : llvm::enumerate(
+             llvm::zip(op.getOperands(), beforeBlockArgs, yieldOpArgs))) {
+      auto index = static_cast<unsigned>(it.index());
+      Value initVal, beforeBlockArg, yieldOpArg;
+      std::tie(initVal, beforeBlockArg, yieldOpArg) = it.value();
+
+      // The index-th `before` block argument is loop invariant when the value
+      // yielded at position `index` is an `after` block argument whose
+      // scf.condition operand is that very `before` block argument.
+      bool isLoopInvariant = false;
+      for (auto condOpIt : llvm::enumerate(condOpArgs)) {
+        if (condOpIt.value() == beforeBlockArg &&
+            afterBlock.getArgument(condOpIt.index()) == yieldOpArg) {
+          isLoopInvariant = true;
+          break;
+        }
+      }
+
+      if (isLoopInvariant) {
+        beforeBlockInitValMap.insert({index, initVal});
+        continue;
+      }
+
+      newInitArgs.emplace_back(initVal);
+      newYieldOpArgs.emplace_back(yieldOpArg);
+      newBeforeBlockType.emplace_back(yieldOpArg.getType());
+    }
+
+    // Nothing to remove: bail out before any IR is modified.
+    if (beforeBlockInitValMap.empty())
+      return failure();
+
+    {
+      // Shrink the terminator of the `after` block first; the guard restores
+      // the insertion point for the creation of the new scf.while below.
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(yieldOp);
+      rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
+    }
+
+    auto newWhile =
+        rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
+
+    Block &newBeforeBlock = *rewriter.createBlock(
+        &newWhile.getBefore(), /*insertPt*/ {}, newBeforeBlockType);
+
+    // Replacement for each argument of the old `before` block: the initial
+    // value for a loop invariant argument, otherwise the next argument of the
+    // new `before` block.
+    SmallVector<Value> newBeforeBlockArgs;
+    newBeforeBlockArgs.reserve(beforeBlock.getNumArguments());
+    unsigned nextKeptArg = 0;
+    for (unsigned i = 0, n = beforeBlock.getNumArguments(); i < n; ++i) {
+      auto invariantIt = beforeBlockInitValMap.find(i);
+      if (invariantIt != beforeBlockInitValMap.end())
+        newBeforeBlockArgs.push_back(invariantIt->second);
+      else
+        newBeforeBlockArgs.push_back(newBeforeBlock.getArgument(nextKeptArg++));
+    }
+
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
+    rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
+                                newWhile.getAfter().begin());
+
+    rewriter.replaceOp(op, newWhile.getResults());
+    return success();
+  }
+};
+
+/// Remove loop invariant value from result of scf.while.
+/// A value is considered loop invariant if the final value yielded by
+/// scf.condition is defined outside of the `before` block. We remove the
+/// corresponding argument in `after` block and replace the use with the value.
+/// We also replace the use of the corresponding result of scf.while with the
+/// value.
+///
+/// Eg:
+///    INPUT :-
+///    %res = scf.while <...> iter_args(%arg0 = %init0, ..., %argN = %N) {
+///        ...
+///        scf.condition(%cond) %arg0, %a, %a, %arg1, ..., %argK
+///    } do {
+///     ^bb0(%arg0, %arg1, %arg2, ..., %argK):
+///        ...
+///        some_func(%arg1)
+///        ...
+///        scf.yield %arg0, %arg2, ..., %argN
+///    }
+///
+///    OUTPUT :-
+///    %res = scf.while <...> iter_args(%arg0 = %init0, ..., %argN = %N) {
+///        ...
+///        scf.condition(%cond) %arg0, %arg1, ..., %argK
+///    } do {
+///     ^bb0(%arg0, %arg3, ..., %argK):
+///        ...
+///        some_func(%a)
+///        ...
+///        scf.yield %arg0, %a, ..., %argN
+///    }
+///
+///    EXPLANATION:
+///      In scf.condition at positions 1 and 2 we have a value, %a, which is not
+///      defined within the `before` block of scf.while. The corresponding
+///      arguments of the `after` block, %arg1 and %arg2, and the corresponding
+///      results of scf.while can be replaced with %a.
+struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  /// Rebuilds `op` without the scf.condition operands that are defined
+  /// outside of the `before` block. Uses of the matching `after` block
+  /// arguments and scf.while results are replaced with the value itself.
+  /// Returns failure() without touching the IR when there is no such operand.
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    Block &afterBlock = op.getAfter().front();
+    ConditionOp condOp = op.getConditionOp();
+    OperandRange condOpArgs = condOp.getArgs();
+
+    // scf.condition operands that are kept, together with their types. The
+    // types double as the `after` block argument types and the result types
+    // of the new scf.while.
+    SmallVector<Value> newCondOpArgs;
+    SmallVector<Type> newAfterBlockType;
+    // Position of every loop invariant scf.condition operand, mapped to the
+    // operand itself. The i-th operand corresponds to the i-th `after` block
+    // argument and the i-th result.
+    DenseMap<unsigned, Value> condOpInitValMap;
+    for (auto it : llvm::enumerate(condOpArgs)) {
+      Value condOpArg = it.value();
+      // Values not defined within the `before` block are considered loop
+      // invariant.
+      if (condOpArg.getParentBlock() != &beforeBlock) {
+        condOpInitValMap.insert({static_cast<unsigned>(it.index()), condOpArg});
+        continue;
+      }
+      newCondOpArgs.emplace_back(condOpArg);
+      newAfterBlockType.emplace_back(condOpArg.getType());
+    }
+
+    // Nothing to remove: bail out before any IR is modified.
+    if (condOpInitValMap.empty())
+      return failure();
+
+    {
+      // Shrink the terminator of the `before` block first; the guard restores
+      // the insertion point for the creation of the new scf.while below.
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(condOp);
+      rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
+                                               newCondOpArgs);
+    }
+
+    auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
+                                             op.getOperands());
+
+    Block &newAfterBlock = *rewriter.createBlock(
+        &newWhile.getAfter(), /*insertPt*/ {}, newAfterBlockType);
+
+    // Replacements for each argument of the old `after` block and for each
+    // result of the old scf.while: the loop invariant value fetched from
+    // `condOpInitValMap`, otherwise the next argument of the new `after`
+    // block and the next result of the new scf.while respectively.
+    SmallVector<Value> newAfterBlockArgs, newWhileResults;
+    newAfterBlockArgs.reserve(afterBlock.getNumArguments());
+    newWhileResults.reserve(afterBlock.getNumArguments());
+    unsigned nextKept = 0;
+    for (unsigned i = 0, n = afterBlock.getNumArguments(); i < n; ++i) {
+      auto invariantIt = condOpInitValMap.find(i);
+      if (invariantIt != condOpInitValMap.end()) {
+        newAfterBlockArgs.push_back(invariantIt->second);
+        newWhileResults.push_back(invariantIt->second);
+        continue;
+      }
+      newAfterBlockArgs.push_back(newAfterBlock.getArgument(nextKept));
+      newWhileResults.push_back(newWhile.getResult(nextKept));
+      ++nextKept;
+    }
+
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
+    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
+                                newWhile.getBefore().begin());
+
+    rewriter.replaceOp(op, newWhileResults);
+    return success();
+  }
+};
+
 /// Remove WhileOp results that are also unused in 'after' block.
/// /// %0:2 = scf.while () : () -> (i32, i64) { @@ -2423,7 +2659,9 @@ void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -823,6 +823,74 @@ // ----- +// CHECK-LABEL: @invariant_loop_args_in_same_order +// CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor) +func @invariant_loop_args_in_same_order(%f_arg0: tensor) -> (tensor, tensor, tensor, tensor, tensor) { + %cst_0 = arith.constant dense<0> : tensor + %cst_1 = arith.constant dense<1> : tensor + %cst_42 = arith.constant dense<42> : tensor + + %0:5 = scf.while (%arg0 = %cst_0, %arg1 = %f_arg0, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor) { + %1 = arith.cmpi slt, %arg0, %cst_42 : tensor + %2 = tensor.extract %1[] : tensor + scf.condition(%2) %arg0, %arg1, %arg2, %arg3, %arg4 : tensor, tensor, tensor, tensor, tensor + } do { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): // no predecessors + // %arg1 here will get replaced by %f_arg0 + %1 = arith.addi %arg0, %arg1 : tensor + %2 = arith.addi %arg2, %arg3 : tensor + scf.yield %1, %arg1, %2, %2, %arg4 : tensor, tensor, tensor, tensor, tensor + } + return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor, tensor, tensor, tensor, tensor +} +// CHECK-DAG: %[[CST42:.*]] = arith.constant dense<42> : tensor +// CHECK-DAG: %[[ONE:.*]] = arith.constant dense<1> : tensor +// CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> : tensor +// CHECK-NEXT: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]]) : (tensor, tensor, tensor) -> (tensor, tensor, tensor) { 
+// CHECK-NEXT: arith.cmpi slt, %[[ARG0]], %{{.*}} : tensor +// CHECK-NEXT: tensor.extract %{{.*}}[] : tensor +// CHECK-NEXT: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]] : tensor, tensor, tensor +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[ARG0:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor): // no predecessors +// CHECK-NEXT: %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]] : tensor +// CHECK-NEXT: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]] : tensor +// CHECK-NEXT: scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]] : tensor, tensor, tensor +// CHECK-NEXT: } +// CHECK-NEXT: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]] : tensor, tensor, tensor, tensor, tensor + +// CHECK-LABEL: @while_loop_invariant_argument_different_order +func @while_loop_invariant_argument_different_order() -> (tensor, tensor, tensor, tensor, tensor, tensor) { + %cst_0 = arith.constant dense<0> : tensor + %cst_1 = arith.constant dense<1> : tensor + %cst_42 = arith.constant dense<42> : tensor + + %0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) { + %1 = arith.cmpi slt, %arg0, %cst_42 : tensor + %2 = tensor.extract %1[] : tensor + scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor, tensor, tensor, tensor, tensor, tensor + } do { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor): // no predecessors + // %arg3 here is the first iter_arg of the `before` block forwarded by scf.condition; yielding it at position 0 makes that iter_arg loop invariant, so it will get replaced by %cst_0 + %1 = arith.addi %arg0, %cst_1 : tensor + %2 = arith.addi %arg2, %arg3 : tensor + scf.yield %arg3, %arg1, %2, %2, %arg4 : tensor, tensor, tensor, tensor, tensor + } + return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor, tensor, tensor, tensor, tensor, tensor +} +// CHECK-DAG: %[[CST42:.*]] = arith.constant dense<42> : tensor +// CHECK-DAG: %[[ONE:.*]] = arith.constant dense<1> : tensor +// CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> : tensor 
+// CHECK-NEXT: %[[WHILE:.*]]:3 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG2:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]]) : (tensor, tensor, tensor) -> (tensor, tensor, tensor) { +// CHECK-NEXT: arith.cmpi slt, %[[ZERO]], %[[CST42]] : tensor +// CHECK-NEXT: tensor.extract %{{.*}}[] : tensor +// CHECK-NEXT: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG2]], %[[ARG4]] : tensor, tensor, tensor +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%{{.*}}: tensor, %[[ARG1:.*]]: tensor, %{{.*}}: tensor): // no predecessors +// CHECK-NEXT: scf.yield %[[ZERO]], %[[ONE]], %[[ARG1]] : tensor, tensor, tensor +// CHECK-NEXT: } +// CHECK-NEXT: return %[[WHILE]]#0, %[[ZERO]], %[[ONE]], %[[ZERO]], %[[WHILE]]#1, %[[WHILE]]#2 : tensor, tensor, tensor, tensor, tensor, tensor + +// ----- + // CHECK-LABEL: @while_unused_result func @while_unused_result() -> i32 { %0:2 = scf.while () : () -> (i32, i64) {