diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -2343,6 +2343,297 @@ } }; +/// Remove loop invariant arguments from `before` block of scf.while. +/// A before block argument is considered loop invariant if :- +/// 1. i-th yield operand is equal to the i-th while operand. +/// 2. i-th yield operand is k-th after block argument which is (k+1)-th +/// condition operand AND this (k+1)-th condition operand is equal to i-th +/// iter argument/while operand. +/// For the arguments which are removed, their uses inside scf.while +/// are replaced with their corresponding initial value. +/// +/// Eg: +/// INPUT :- +/// %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b, +/// ..., %argN_before = %N) +/// { +/// ... +/// scf.condition(%cond) %arg1_before, %arg0_before, +/// %arg2_before, %arg0_before, ... +/// } do { +/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, +/// ..., %argK_after): +/// ... +/// scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN +/// } +/// +/// OUTPUT :- +/// %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before = +/// %N) +/// { +/// ... +/// scf.condition(%cond) %b, %a, %arg2_before, %a, ... +/// } do { +/// ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2, +/// ..., %argK_after): +/// ... +/// scf.yield %arg1_after, ..., %argN +/// } +/// +/// EXPLANATION: +/// We iterate over each yield operand. +/// 1. 0-th yield operand %arg0_after_2 is 4-th condition operand +/// %arg0_before, which in turn is the 0-th iter argument. So we +/// remove 0-th before block argument and yield operand, and replace +/// all uses of the 0-th before block argument with its initial value +/// %a. +/// 2. 1-th yield operand %b is equal to the 1-th iter arg's initial +/// value. 
So we remove this operand and the corresponding before +/// block argument and replace all uses of 1-th before block argument +/// with %b. +struct RemoveLoopInvariantArgsFromBeforeBlock + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp op, + PatternRewriter &rewriter) const override { + Block &afterBlock = op.getAfter().front(); + Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments(); + ConditionOp condOp = op.getConditionOp(); + OperandRange condOpArgs = condOp.getArgs(); + Operation *yieldOp = afterBlock.getTerminator(); + ValueRange yieldOpArgs = yieldOp->getOperands(); + + bool canSimplify = false; + for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { + auto index = static_cast(it.index()); + Value initVal, yieldOpArg; + std::tie(initVal, yieldOpArg) = it.value(); + // If i-th yield operand is equal to the i-th operand of the scf.while, + // the i-th before block argument is a loop invariant. + if (yieldOpArg == initVal) { + canSimplify = true; + break; + } + // If the i-th yield operand is k-th after block argument, then we check + // if the (k+1)-th condition op operand is equal to either the i-th before + // block argument or the initial value of i-th before block argument. If + // the comparison results `true`, i-th before block argument is a loop + // invariant. 
+ auto yieldOpBlockArg = yieldOpArg.dyn_cast(); + if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { + Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; + if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { + canSimplify = true; + break; + } + } + } + + if (!canSimplify) + return failure(); + + SmallVector newInitArgs, newYieldOpArgs; + DenseMap beforeBlockInitValMap; + SmallVector newBeforeBlockArgLocs; + for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) { + auto index = static_cast(it.index()); + Value initVal, yieldOpArg; + std::tie(initVal, yieldOpArg) = it.value(); + + // If i-th yield operand is equal to the i-th operand of the scf.while, + // the i-th before block argument is a loop invariant. + if (yieldOpArg == initVal) { + beforeBlockInitValMap.insert({index, initVal}); + continue; + } else { + // If the i-th yield operand is k-th after block argument, then we check + // if the (k+1)-th condition op operand is equal to either the i-th + // before block argument or the initial value of i-th before block + // argument. If the comparison results `true`, i-th before block + // argument is a loop invariant. 
+ auto yieldOpBlockArg = yieldOpArg.dyn_cast(); + if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) { + Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()]; + if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) { + beforeBlockInitValMap.insert({index, initVal}); + continue; + } + } + } + newInitArgs.emplace_back(initVal); + newYieldOpArgs.emplace_back(yieldOpArg); + newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc()); + } + + { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp(yieldOp, newYieldOpArgs); + } + + auto newWhile = + rewriter.create(op.getLoc(), op.getResultTypes(), newInitArgs); + + Block &newBeforeBlock = *rewriter.createBlock( + &newWhile.getBefore(), /*insertPt*/ {}, + ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs); + + Block &beforeBlock = op.getBefore().front(); + SmallVector newBeforeBlockArgs(beforeBlock.getNumArguments()); + // For each i-th before block argument we find it's replacement value as :- + // 1. If i-th before block argument is a loop invariant, we fetch it's + // initial value from `beforeBlockInitValMap` by querying for key `i`. + // 2. Else we fetch j-th new before block argument as the replacement + // value of i-th before block argument. + for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) { + // If the index 'i' argument was a loop invariant we fetch it's initial + // value from `beforeBlockInitValMap`. 
+ if (beforeBlockInitValMap.count(i) != 0) + newBeforeBlockArgs[i] = beforeBlockInitValMap[i]; + else + newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++); + } + + rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs); + rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(), + newWhile.getAfter().begin()); + + rewriter.replaceOp(op, newWhile.getResults()); + return success(); + } +}; + +/// Remove loop invariant value from result (condition op) of scf.while. +/// A value is considered loop invariant if the final value yielded by +/// scf.condition is defined outside of the `before` block. We remove the +/// corresponding argument in `after` block and replace the use with the value. +/// We also replace the use of the corresponding result of scf.while with the +/// value. +/// +/// Eg: +/// INPUT :- +/// %res_input:K = scf.while <...> iter_args(%arg0_before = , ..., +/// %argN_before = %N) { +/// ... +/// scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ... +/// } do { +/// ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after): +/// ... +/// some_func(%arg1_after) +/// ... +/// scf.yield %arg0_after, %arg2_after, ..., %argN_after +/// } +/// +/// OUTPUT :- +/// %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) { +/// ... +/// scf.condition(%cond) %arg0, %arg1, ..., %argM +/// } do { +/// ^bb0(%arg0, %arg3, ..., %argM): +/// ... +/// some_func(%a) +/// ... +/// scf.yield %arg0, %b, ..., %argN +/// } +/// +/// EXPLANATION: +/// 1. The 1-th and 2-th operand of scf.condition are defined outside the +/// before block of scf.while, so they get removed. +/// 2. %res_input#1's uses are replaced by %a and %res_input#2's uses are +/// replaced by %b. +/// 3. The corresponding after block argument %arg1_after's uses are +/// replaced by %a and %arg2_after's uses are replaced by %b. 
+struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  /// Rebuilds `op` so that scf.condition no longer forwards values that are
+  /// not defined in the `before` block. Uses of the matching `after` block
+  /// arguments and scf.while results are replaced with those values directly.
+  /// Returns failure() without touching the IR when there is no such value.
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    ConditionOp condOp = op.getConditionOp();
+    OperandRange condOpArgs = condOp.getArgs();
+
+    // Those values not defined within `before` block will be considered as
+    // loop invariant values. Single definition shared by the detection and
+    // the collection steps below.
+    auto isLoopInvariant = [&](Value condOpArg) {
+      return condOpArg.getParentBlock() != &beforeBlock;
+    };
+
+    // Cheap pre-check: nothing to do if every forwarded value is defined in
+    // the `before` block.
+    if (llvm::none_of(condOpArgs, isLoopInvariant))
+      return failure();
+
+    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
+
+    // Partition the scf.condition operands: invariant ones are recorded by
+    // index, the others (with their type and the location of the matching
+    // `after` block argument) are carried over to the new scf.while.
+    SmallVector<Value> newCondOpArgs;
+    SmallVector<Type> newAfterBlockType;
+    DenseMap<unsigned, Value> condOpInitValMap;
+    SmallVector<Location> newAfterBlockArgLocs;
+    for (auto it : llvm::enumerate(condOpArgs)) {
+      auto index = static_cast<unsigned>(it.index());
+      Value condOpArg = it.value();
+      if (isLoopInvariant(condOpArg)) {
+        condOpInitValMap.insert({index, condOpArg});
+        continue;
+      }
+      newCondOpArgs.emplace_back(condOpArg);
+      newAfterBlockType.emplace_back(condOpArg.getType());
+      newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
+    }
+
+    // Replace the terminator of the before block with one that only forwards
+    // the operands which are kept.
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(condOp);
+      rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
+                                               newCondOpArgs);
+    }
+
+    auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
+                                             op.getOperands());
+
+    Block &newAfterBlock =
+        *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
+                              newAfterBlockType, newAfterBlockArgLocs);
+
+    Block &afterBlock = op.getAfter().front();
+    // Since a new scf.condition op was created, we need to fetch the new
+    // `after` block arguments which will be used while replacing operations of
+    // previous scf.while's `after` blocks. We'd also be fetching new result
+    // values too.
+    SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
+    SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
+    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
+      // If index 'i' argument was loop invariant we fetch its value from the
+      // `condOpInitValMap` map; it replaces both the `after` block argument
+      // and the scf.while result.
+      auto invariantIt = condOpInitValMap.find(i);
+      if (invariantIt != condOpInitValMap.end()) {
+        newAfterBlockArgs[i] = invariantIt->second;
+        newWhileResults[i] = invariantIt->second;
+        continue;
+      }
+      // Otherwise use the j-th argument/result of the new op, where `j`
+      // counts the operands that were kept so far.
+      newAfterBlockArgs[i] = newAfterBlock.getArgument(j);
+      newWhileResults[i] = newWhile.getResult(j);
+      j++;
+    }
+
+    // Move the old after block body into the new block, then move the whole
+    // before region over to the new op.
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
+    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
+                                newWhile.getBefore().begin());
+
+    rewriter.replaceOp(op, newWhileResults);
+    return success();
+  }
+};
+
 /// Remove WhileOp results that are also unused in 'after' block.
///
 ///  %0:2 = scf.while () : () -> (i32, i64) {
@@ -2552,8 +2843,9 @@
 
 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.insert(context);
+  // NOTE(review): the pattern list passed to `insert` is not visible in this
+  // hunk; presumably it now also registers the two new loop-invariant
+  // patterns -- confirm against the original change.
+  results.insert(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -870,6 +870,74 @@
 
 // -----
 
+// Loop-invariant iter_args forwarded at the same position: %arg1 and %arg4
+// are passed unchanged through scf.condition and scf.yield, so the expected
+// scf.while keeps only three iter_args.
+// NOTE(review): the tensor element types appear to be missing throughout this
+// test (plain `tensor`) -- confirm against the original file.
+// CHECK-LABEL: @invariant_loop_args_in_same_order
+// CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor)
+func @invariant_loop_args_in_same_order(%f_arg0: tensor) -> (tensor, tensor, tensor, tensor, tensor) {
+  %cst_0 = arith.constant dense<0> : tensor
+  %cst_1 = arith.constant dense<1> : tensor
+  %cst_42 = arith.constant dense<42> : tensor
+
+  %0:5 = scf.while (%arg0 = %cst_0, %arg1 = %f_arg0, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor) {
+    %1 = arith.cmpi slt, %arg0, %cst_42 : tensor
+    %2 = tensor.extract %1[] : tensor
+    scf.condition(%2) %arg0, %arg1, %arg2, %arg3, %arg4 : tensor, tensor, tensor, tensor, tensor
+  } do {
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): // no predecessors
+    // %arg1 here will get replaced by %f_arg0, its initial value (see the
+    // arith.addi check on FUNC_ARG0 below).
+    %1 = arith.addi %arg0, %arg1 : tensor
+    %2 = arith.addi %arg2, %arg3 : tensor
+    scf.yield %1, %arg1, %2, %2, %arg4 : tensor, tensor, tensor, tensor, tensor
+  }
+  return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor, tensor, tensor, tensor, tensor
+}
+// CHECK: %[[CST42:.*]] = arith.constant dense<42>
+// CHECK: %[[ONE:.*]] = arith.constant dense<1>
+// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
+// CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}}
+// CHECK: tensor.extract %{{.*}}[]
+// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]]
+// CHECK: } do {
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor):
+// CHECK: %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]]
+// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]]
+// CHECK: scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]]
+// CHECK: }
+// CHECK: return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
+
+// Same idea, but scf.condition forwards the before block arguments in a
+// different order (and %arg0 twice), so the `after` block has six arguments
+// for five iter_args. The expected scf.while keeps only two iter_args.
+// CHECK-LABEL: @while_loop_invariant_argument_different_order
+func @while_loop_invariant_argument_different_order() -> (tensor, tensor, tensor, tensor, tensor, tensor) {
+  %cst_0 = arith.constant dense<0> : tensor
+  %cst_1 = arith.constant dense<1> : tensor
+  %cst_42 = arith.constant dense<42> : tensor
+
+  %0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) {
+    %1 = arith.cmpi slt, %arg0, %cst_42 : tensor
+    %2 = tensor.extract %1[] : tensor
+    scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor, tensor, tensor, tensor, tensor, tensor
+  } do {
+  ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor): // no predecessors
+    %1 = arith.addi %arg0, %cst_1 : tensor
+    %2 = arith.addi %arg2, %arg3 : tensor
+    scf.yield %arg3, %arg1, %2, %2, %arg4 : tensor, tensor, tensor, tensor, tensor
+  }
+  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor, tensor, tensor, tensor, tensor, tensor
+}
+// CHECK: %[[CST42:.*]] = arith.constant dense<42>
+// CHECK: %[[ONE:.*]] = arith.constant dense<1>
+// CHECK: %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
+// CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]]
+// CHECK: tensor.extract %{{.*}}[]
+// CHECK: scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
+// CHECK: } do {
+// CHECK: ^{{.*}}(%{{.*}}: tensor, %{{.*}}: tensor):
+// CHECK: scf.yield %[[ZERO]], %[[ONE]]
+// CHECK: }
+// CHECK: return %[[WHILE]]#0, %[[ZERO]], %[[ONE]], %[[ZERO]], %[[ONE]], %[[WHILE]]#1
+
+// -----
+
 // CHECK-LABEL: @while_unused_result
 func @while_unused_result() -> i32 {
   %0:2 = scf.while () : () -> (i32, i64) {