diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -55,6 +55,11 @@
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
+namespace scf {
+// Creates a pass which converts suitable scf.while loops into scf.for loops.
+std::unique_ptr<Pass> createWhileToForLoopPass();
+} // namespace scf
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -114,4 +114,38 @@
   }];
 }
 
+def SCFWhileToForLoop : FunctionPass<"scf-while-to-for"> {
+  let summary = "Convert SCF while loops to SCF for loops";
+  let constructor = "mlir::scf::createWhileToForLoopPass()";
+  let description = [{
+    This pass converts scf.while operations that are controlled by a single
+    induction variable into scf.for operations. The loop condition must be an
+    `arith.cmpi slt` of the induction variable against a loop-invariant upper
+    bound, and the induction variable must be advanced by `arith.addi` with a
+    loop-invariant step, which is assumed to be positive.
+
+    ```mlir
+    // Input:
+    %0 = scf.while (%i = %c0) : (index) -> index {
+      %1 = arith.cmpi slt, %i, %arg1 : index
+      scf.condition(%1) %i : index
+    } do {
+    ^bb0(%i: index):  // no predecessors
+      %1 = arith.addi %i, %c1 : index
+      %2 = arith.addi %arg2, %arg2 : i32
+      memref.store %2, %arg0[%i] : memref<?xi32>
+      scf.yield %1 : index
+    }
+
+    // Output:
+    %0 = scf.for %iv = %c0 to %arg1 step %c1 iter_args(%i = %c0) -> (index) {
+      %1 = arith.addi %iv, %c1 : index
+      %2 = arith.addi %arg2, %arg2 : i32
+      memref.store %2, %arg0[%iv] : memref<?xi32>
+      scf.yield %1 : index
+    }
+    ```
+  }];
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@
   ParallelLoopTiling.cpp
   StructuralTypeConversions.cpp
   Utils.cpp
+  WhileToFor.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/WhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/WhileToFor.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/WhileToFor.cpp
@@ -0,0 +1,284 @@
+//===- WhileToFor.cpp - scf.while to scf.for loop conversion --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms scf.while loops controlled by a single induction variable into
+// scf.for loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Passes.h"
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using scf::ForOp;
+using scf::WhileOp;
+
+namespace {
+
+/// Converts an scf.while loop controlled by a single induction variable into
+/// an scf.for loop. The loop condition must be `arith.cmpi slt` of the
+/// induction variable (a "before" block argument of type `index`, or a 0-d
+/// tensor of integers whose comparison is read with `tensor.extract`) against
+/// a loop-invariant upper bound, and the induction variable must be advanced
+/// by `arith.addi` with a loop-invariant step. The step is assumed to be
+/// positive, as scf.for requires. Two loop shapes are supported:
+///
+/// * while-do: the "before" block only evaluates the condition and forwards
+///   its arguments unchanged; the "after" block is the loop body.
+/// * do-while: the "after" block only forwards its arguments unchanged; the
+///   "before" block is the loop body. That body runs one extra time, when the
+///   condition fails, and this last run produces the scf.while results, so a
+///   copy of it is emitted after the scf.for.
+///
+/// Input:
+///   %0 = scf.while (%i = %c0) : (index) -> index {
+///     %1 = arith.cmpi slt, %i, %arg1 : index
+///     scf.condition(%1) %i : index
+///   } do {
+///   ^bb0(%i: index):  // no predecessors
+///     %1 = arith.addi %i, %c1 : index
+///     %2 = arith.addi %arg2, %arg2 : i32
+///     memref.store %2, %arg0[%i] : memref<?xi32>
+///     scf.yield %1 : index
+///   }
+///
+/// Output:
+///   %0 = scf.for %iv = %c0 to %arg1 step %c1 iter_args(%i = %c0)
+///       -> (index) {
+///     %1 = arith.addi %iv, %c1 : index
+///     %2 = arith.addi %arg2, %arg2 : i32
+///     memref.store %2, %arg0[%iv] : memref<?xi32>
+///     scf.yield %1 : index
+///   }
+struct ConvertWhileToFor : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  /// Returns true if `value` is defined outside of `op`. Such a value, when
+  /// used inside `op`, dominates `op` and may be used by its replacement.
+  static bool isDefinedOutsideOf(Value value, Operation *op) {
+    return !op->isAncestor(value.getParentRegion()->getParentOp());
+  }
+
+  /// Returns true if `type` can be the type of the induction variable: either
+  /// `index`, or a 0-d tensor of signless integers, whose element can be cast
+  /// to and from `index`.
+  static bool isSupportedIvType(Type type) {
+    if (type.isIndex())
+      return true;
+    auto tensorType = type.dyn_cast<RankedTensorType>();
+    return tensorType && tensorType.getRank() == 0 &&
+           tensorType.getElementType().isSignlessInteger();
+  }
+
+  /// Returns `input` as an `index` value. A 0-d tensor is read with
+  /// `tensor.extract` and its element cast with `arith.index_cast`; any other
+  /// value is returned unchanged. New ops are created at the current
+  /// insertion point of `rewriter`.
+  static Value extractScalarValueFromTensor(Value input,
+                                            PatternRewriter &rewriter) {
+    auto tensorType = input.getType().dyn_cast<RankedTensorType>();
+    if (!tensorType)
+      return input;
+    assert(tensorType.getRank() == 0 && "0-d tensor expected");
+    Location loc = input.getLoc();
+    Value scalar = rewriter.create<tensor::ExtractOp>(loc, input,
+                                                      /*indices=*/ValueRange());
+    return rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                               scalar);
+  }
+
+  /// Returns a value of type `ivType`, the type of the scf.while induction
+  /// variable, that holds the scf.for induction variable `iv`. For a 0-d
+  /// tensor type, `iv` is cast to the element type and wrapped into a tensor.
+  /// New ops are created at the current insertion point of `rewriter`.
+  static Value materializeIv(Value iv, Type ivType, PatternRewriter &rewriter) {
+    if (ivType.isIndex())
+      return iv;
+    Type eltType = ivType.cast<RankedTensorType>().getElementType();
+    Location loc = iv.getLoc();
+    Value scalar = rewriter.create<arith::IndexCastOp>(loc, eltType, iv);
+    // tensor.from_elements builds a 1-d tensor; collapse it to a 0-d one.
+    Value oneElement = rewriter.create<tensor::FromElementsOp>(loc, scalar);
+    return rewriter.create<tensor::CollapseShapeOp>(
+        loc, ivType, oneElement, rewriter.getArrayAttr({}));
+  }
+
+  /// Moves the ops of `sourceBlock` (the loop body of the scf.while) to the
+  /// end of `forBody` (the body of the new scf.for), whose arguments are the
+  /// induction variable followed by one iter_arg per argument of
+  /// `sourceBlock`. The argument of `sourceBlock` at `ivIndex` is replaced by
+  /// the induction variable, converted to its original type; every other
+  /// argument is replaced by the corresponding iter_arg.
+  static void moveBlock(Block *sourceBlock, Block *forBody, unsigned ivIndex,
+                        PatternRewriter &rewriter) {
+    // Erase any terminator the builder may have created, so that the merged
+    // block ends with the single terminator coming from `sourceBlock`.
+    if (!forBody->empty())
+      rewriter.eraseOp(forBody->getTerminator());
+
+    auto iterArgs = forBody->getArguments().drop_front();
+    SmallVector<Value> replacements(iterArgs.begin(), iterArgs.end());
+    rewriter.setInsertionPointToStart(forBody);
+    replacements[ivIndex] =
+        materializeIv(forBody->getArgument(0),
+                      sourceBlock->getArgument(ivIndex).getType(), rewriter);
+    // Note: `mergeBlocks` erases `sourceBlock`, which must not be used after.
+    rewriter.mergeBlocks(sourceBlock, forBody, replacements);
+  }
+
+  /// Returns true if the "before" block only forwards its arguments,
+  /// unchanged and at their own positions, to `condOp` and uses them nowhere
+  /// else, except for the induction variable at `ivIndex`, which has exactly
+  /// one more use (the comparison computing the loop condition).
+  static bool isWhileStyleLoop(Block &beforeBlock, scf::ConditionOp condOp,
+                               unsigned ivIndex) {
+    OperandRange forwarded = condOp.getArgs();
+    if (forwarded.size() != beforeBlock.getNumArguments())
+      return false;
+    return llvm::all_of(beforeBlock.getArguments(), [&](BlockArgument arg) {
+      unsigned index = arg.getArgNumber();
+      unsigned expectedUses = index == ivIndex ? 2 : 1;
+      return forwarded[index] == arg &&
+             llvm::hasNItems(arg.getUses(), expectedUses);
+    });
+  }
+
+  /// Returns true if the "after" block contains nothing but its terminator,
+  /// which forwards the block arguments unchanged and at their own positions,
+  /// and the arguments have no other use.
+  static bool isDoWhileStyleLoop(Block &afterBlock) {
+    Operation *yieldOp = afterBlock.getTerminator();
+    if (&afterBlock.front() != yieldOp ||
+        yieldOp->getNumOperands() != afterBlock.getNumArguments())
+      return false;
+    return llvm::all_of(afterBlock.getArguments(), [&](BlockArgument arg) {
+      return yieldOp->getOperand(arg.getArgNumber()) == arg && arg.hasOneUse();
+    });
+  }
+
+  LogicalResult matchAndRewrite(WhileOp whileOp,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = whileOp.getBefore().front();
+    Block &afterBlock = whileOp.getAfter().front();
+    scf::ConditionOp condOp = whileOp.getConditionOp();
+
+    // The loop condition must be `arith.cmpi slt`, used either directly or
+    // through a `tensor.extract` of the 0-d boolean tensor it produces.
+    // TODO: Support other predicates.
+    Operation *condDef = condOp.getCondition().getDefiningOp();
+    if (!condDef)
+      return failure();
+    // Number of "before" block ops that only compute the condition: the
+    // comparison, the optional extract, and `condOp` itself.
+    unsigned numConditionOps = 2;
+    if (auto extractOp = dyn_cast<tensor::ExtractOp>(condDef)) {
+      condDef = extractOp.tensor().getDefiningOp();
+      ++numConditionOps;
+    }
+    auto cmpOp = dyn_cast_or_null<arith::CmpIOp>(condDef);
+    if (!cmpOp || cmpOp.getPredicate() != arith::CmpIPredicate::slt)
+      return failure();
+
+    // The induction variable is the "before" block argument on the left-hand
+    // side of the comparison; the upper bound must be loop-invariant.
+    auto ivArg = cmpOp.getLhs().dyn_cast<BlockArgument>();
+    if (!ivArg || ivArg.getOwner() != &beforeBlock ||
+        !isSupportedIvType(ivArg.getType()))
+      return failure();
+    unsigned ivIndex = ivArg.getArgNumber();
+    Value upperBound = cmpOp.getRhs();
+    if (!isDefinedOutsideOf(upperBound, whileOp))
+      return failure();
+
+    // In a while-do loop the "before" block only computes the condition and
+    // forwards its arguments unchanged; in a do-while loop the "after" block
+    // forwards its arguments unchanged. Anything else is left alone.
+    bool isWhileDo = beforeBlock.getOperations().size() == numConditionOps &&
+                     isWhileStyleLoop(beforeBlock, condOp, ivIndex);
+    if (!isWhileDo && !isDoWhileStyleLoop(afterBlock))
+      return failure();
+
+    // The value carried to the next iteration in the induction variable slot
+    // must be `iv + step` or `step + iv`, where `iv` is the induction variable
+    // as seen by the loop body and `step` is loop-invariant.
+    Value bodyIv = isWhileDo ? afterBlock.getArgument(ivIndex) : ivArg;
+    Value nextIv = isWhileDo ? afterBlock.getTerminator()->getOperand(ivIndex)
+                             : condOp.getArgs()[ivIndex];
+    auto addOp = nextIv.getDefiningOp<arith::AddIOp>();
+    if (!addOp)
+      return failure();
+    Value step;
+    if (addOp.getLhs() == bodyIv)
+      step = addOp.getRhs();
+    else if (addOp.getRhs() == bodyIv)
+      step = addOp.getLhs();
+    else
+      return failure();
+    if (!isDefinedOutsideOf(step, whileOp))
+      return failure();
+
+    // The match succeeded; the IR is modified only from here on. The bounds
+    // are converted to `index` right before the loop.
+    rewriter.setInsertionPoint(whileOp);
+    Value lowerBound =
+        extractScalarValueFromTensor(whileOp.getInits()[ivIndex], rewriter);
+    upperBound = extractScalarValueFromTensor(upperBound, rewriter);
+    step = extractScalarValueFromTensor(step, rewriter);
+    auto forOp = rewriter.create<ForOp>(whileOp.getLoc(), lowerBound,
+                                        upperBound, step, whileOp.getInits());
+    Block *forBody = &forOp.getRegion().front();
+    ResultRange forResults = forOp.getResults();
+    SmallVector<Value> replacements(forResults.begin(), forResults.end());
+
+    if (isWhileDo) {
+      moveBlock(&afterBlock, forBody, ivIndex, rewriter);
+    } else {
+      // The scf.while results come from a last run of the "before" block in
+      // which the condition fails: emit a copy of the block after the loop.
+      rewriter.setInsertionPointAfter(forOp);
+      BlockAndValueMapping mapping;
+      for (auto it : llvm::zip(beforeBlock.getArguments(), forResults))
+        mapping.map(std::get<0>(it), std::get<1>(it));
+      for (Operation &op : beforeBlock.without_terminator())
+        rewriter.clone(op, mapping);
+      for (auto it : llvm::enumerate(condOp.getArgs()))
+        replacements[it.index()] = mapping.lookupOrDefault(it.value());
+
+      // Inside the loop, the condition terminator becomes a plain yield.
+      rewriter.setInsertionPoint(condOp);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(condOp, condOp.getArgs());
+      moveBlock(&beforeBlock, forBody, ivIndex, rewriter);
+    }
+
+    rewriter.replaceOp(whileOp, replacements);
+    return success();
+  }
+};
+
+struct WhileToForLoop : public SCFWhileToForLoopBase<WhileToForLoop> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    MLIRContext *ctx = funcOp.getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ConvertWhileToFor>(ctx);
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::scf::createWhileToForLoopPass() {
+  return std::make_unique<WhileToForLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/while-loop-to-for-loop.mlir b/mlir/test/Dialect/SCF/while-loop-to-for-loop.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/while-loop-to-for-loop.mlir
@@ -0,0 +1,228 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.func(scf-while-to-for)' -allow-unregistered-dialect -split-input-file | FileCheck %s
+
+// Tests scf.while to scf.for conversion for a simple scf.while loop.
+// CHECK-LABEL: func @single_loop
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32)
+func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.while (%arg3 = %c0) : (index) -> index {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):  // no predecessors
+    %1 = arith.addi %arg3, %c1 : index
+    %2 = arith.addi %arg2, %arg2 : i32
+    memref.store %2, %arg0[%arg3] : memref<?xi32>
+    scf.yield %1 : index
+  }
+  return
+}
+// CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[ONE:.*]] = arith.constant 1 : index
+// CHECK-NEXT:  %{{.*}} = scf.for %[[IV:.*]] = %[[ZERO]] to %[[ARG1]] step %[[ONE]] iter_args(%[[ARG4:.*]] = %[[ZERO]]) -> (index) {
+// CHECK-NEXT:    %[[TMP0:.*]] = arith.addi %[[IV]], %[[ONE]] : index
+// CHECK-NEXT:    %[[TMP1:.*]] = arith.addi %[[ARG2]], %[[ARG2]] : i32
+// CHECK-NEXT:    memref.store %[[TMP1]], %[[ARG0]][%[[IV]]] : memref<?xi32>
+// CHECK-NEXT:    scf.yield %[[TMP0]] : index
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return
+
+// -----
+
+// Tests scf.while to scf.for conversion for nested simple scf.while loop.
+// CHECK-LABEL: func @nested_loop
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xi32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: i32)
+func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.while (%arg3 = %c0) : (index) -> index {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):  // no predecessors
+    %1 = arith.addi %arg3, %c1 : index
+    %2 = scf.while (%arg4 = %c0) : (index) -> index {
+      %3 = arith.cmpi slt, %arg4, %arg1 : index
+      scf.condition(%3) %arg4 : index
+    } do {
+    ^bb0(%arg4: index):  // no predecessors
+      %3 = arith.addi %arg4, %c1 : index
+      %4 = arith.addi %arg2, %arg2 : i32
+      memref.store %4, %arg0[%arg3] : memref<?xi32>
+      memref.store %4, %arg0[%arg4] : memref<?xi32>
+      scf.yield %3 : index
+    }
+    scf.yield %1 : index
+  }
+  return
+}
+// CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[ONE:.*]] = arith.constant 1 : index
+// CHECK-NEXT:  %{{.*}} = scf.for %[[IV0:.*]] = %[[ZERO]] to %[[ARG1]] step %[[ONE]] iter_args(%[[ARG4:.*]] = %[[ZERO]]) -> (index) {
+// CHECK-NEXT:    %[[TMP0:.*]] = arith.addi %[[IV0]], %[[ONE]] : index
+// CHECK-NEXT:    %{{.*}} = scf.for %[[IV1:.*]] = %[[ZERO]] to %[[ARG1]] step %[[ONE]] iter_args(%[[ARG6:.*]] = %[[ZERO]]) -> (index) {
+// CHECK-NEXT:      %[[TMP1:.*]] = arith.addi %[[IV1]], %[[ONE]] : index
+// CHECK-NEXT:      %[[VAL:.*]] = arith.addi %[[ARG2]], %[[ARG2]] : i32
+// CHECK-NEXT:      memref.store %[[VAL]], %[[ARG0]][%[[IV0]]] : memref<?xi32>
+// CHECK-NEXT:      memref.store %[[VAL]], %[[ARG0]][%[[IV1]]] : memref<?xi32>
+// CHECK-NEXT:      scf.yield %[[TMP1]] : index
+// CHECK-NEXT:    }
+// CHECK-NEXT:    scf.yield %[[TMP0]] : index
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return
+
+// -----
+
+// Tests scf.while to scf.for conversion for a while-do style scf.while loop with iter arguments other than induction variable.
+// CHECK-LABEL: func @for_iter_args
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<i32>, %[[ARG1:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>)
+func @for_iter_args(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<f32> {
+  %cst = arith.constant dense<0.000000e+00> : tensor<f32>
+  %0:3 = scf.while (%arg3 = %arg0, %arg4 = %cst, %arg5 = %cst) : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<i32>, tensor<f32>, tensor<f32>) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : tensor<i32>
+    %2 = tensor.extract %1[] : tensor<i1>
+    scf.condition(%2) %arg3, %arg4, %arg5 : tensor<i32>, tensor<f32>, tensor<f32>
+  } do {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<f32>):  // no predecessors
+    %1 = arith.addi %arg3, %arg2 : tensor<i32>
+    %2 = arith.addf %arg4, %arg5 : tensor<f32>
+    scf.yield %1, %2, %2 : tensor<i32>, tensor<f32>, tensor<f32>
+  }
+  return %0#2 : tensor<f32>
+}
+// CHECK-DAG:   %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT:  %[[ARG0I32:.*]] = tensor.extract %[[ARG0]][] : tensor<i32>
+// CHECK-NEXT:  %[[ARG0INDEX:.*]] = arith.index_cast %[[ARG0I32]] : i32 to index
+// CHECK-NEXT:  %[[ARG1I32:.*]] = tensor.extract %[[ARG1]][] : tensor<i32>
+// CHECK-NEXT:  %[[ARG1INDEX:.*]] = arith.index_cast %[[ARG1I32]] : i32 to index
+// CHECK-NEXT:  %[[ARG2I32:.*]] = tensor.extract %[[ARG2]][] : tensor<i32>
+// CHECK-NEXT:  %[[ARG2INDEX:.*]] = arith.index_cast %[[ARG2I32]] : i32 to index
+// CHECK-NEXT:  %[[RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[ARG0INDEX]] to %[[ARG1INDEX]] step %[[ARG2INDEX]] iter_args(%[[ARG4:.*]] = %[[ARG0]], %[[ARG5:.*]] = %[[ZERO]], %[[ARG6:.*]] = %[[ZERO]]) -> (tensor<i32>, tensor<f32>, tensor<f32>) {
+// CHECK-NEXT:    %[[IVI32:.*]] = arith.index_cast %[[IV]] : index to i32
+// CHECK-NEXT:    %[[IVTENSOR:.*]] = tensor.from_elements %[[IVI32]] : tensor<1xi32>
+// CHECK-NEXT:    %[[IV0DTENSOR:.*]] = tensor.collapse_shape %[[IVTENSOR]] [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:    %[[TMP0:.*]] = arith.addi %[[IV0DTENSOR]], %[[ARG2]] : tensor<i32>
+// CHECK-NEXT:    %[[TMP1:.*]] = arith.addf %[[ARG5]], %[[ARG6]] : tensor<f32>
+// CHECK-NEXT:    scf.yield %[[TMP0]], %[[TMP1]], %[[TMP1]] : tensor<i32>, tensor<f32>, tensor<f32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[RESULT]]#2 : tensor<f32>
+
+// -----
+
+// Tests scf.while to scf.for conversion for a do-while style scf.while loop with iter arguments other than induction variable.
+// The "before" region runs once more when the condition fails and that run produces the scf.while results, so it is peeled after the scf.for.
+// CHECK-LABEL: func @for_iter_args_do_while
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<i32>, %[[ARG1:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>)
+func @for_iter_args_do_while(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<f32> {
+  %cst = arith.constant dense<0.000000e+00> : tensor<f32>
+  %0:3 = scf.while (%arg3 = %arg0, %arg4 = %cst, %arg5 = %cst) : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<i32>, tensor<f32>, tensor<f32>) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : tensor<i32>
+    %2 = tensor.extract %1[] : tensor<i1>
+    %3 = arith.addi %arg3, %arg2 : tensor<i32>
+    %4 = arith.addf %arg4, %arg5 : tensor<f32>
+    scf.condition(%2) %3, %4, %4 : tensor<i32>, tensor<f32>, tensor<f32>
+  } do {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<f32>):  // no predecessors
+    scf.yield %arg3, %arg4, %arg5 : tensor<i32>, tensor<f32>, tensor<f32>
+  }
+  return %0#2 : tensor<f32>
+}
+// CHECK-DAG:   %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT:  %[[ARG0I32:.*]] = tensor.extract %[[ARG0]][] : tensor<i32>
+// CHECK-NEXT:  %[[ARG0INDEX:.*]] = arith.index_cast %[[ARG0I32]] : i32 to index
+// CHECK-NEXT:  %[[ARG1I32:.*]] = tensor.extract %[[ARG1]][] : tensor<i32>
+// CHECK-NEXT:  %[[ARG1INDEX:.*]] = arith.index_cast %[[ARG1I32]] : i32 to index
+// CHECK-NEXT:  %[[ARG2I32:.*]] = tensor.extract %[[ARG2]][] : tensor<i32>
+// CHECK-NEXT:  %[[ARG2INDEX:.*]] = arith.index_cast %[[ARG2I32]] : i32 to index
+// CHECK-NEXT:  %[[RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[ARG0INDEX]] to %[[ARG1INDEX]] step %[[ARG2INDEX]] iter_args(%[[ARG4:.*]] = %[[ARG0]], %[[ARG5:.*]] = %[[ZERO]], %[[ARG6:.*]] = %[[ZERO]]) -> (tensor<i32>, tensor<f32>, tensor<f32>) {
+// CHECK-NEXT:    %[[IVI32:.*]] = arith.index_cast %[[IV]] : index to i32
+// CHECK-NEXT:    %[[IVTENSOR:.*]] = tensor.from_elements %[[IVI32]] : tensor<1xi32>
+// CHECK-NEXT:    %[[IV0DTENSOR:.*]] = tensor.collapse_shape %[[IVTENSOR]] [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:    %[[TMP0:.*]] = arith.addi %[[IV0DTENSOR]], %[[ARG2]] : tensor<i32>
+// CHECK-NEXT:    %[[TMP1:.*]] = arith.addf %[[ARG5]], %[[ARG6]] : tensor<f32>
+// CHECK-NEXT:    scf.yield %[[TMP0]], %[[TMP1]], %[[TMP1]] : tensor<i32>, tensor<f32>, tensor<f32>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[LAST:.*]] = arith.addf %[[RESULT]]#1, %[[RESULT]]#2 : tensor<f32>
+// CHECK-NEXT:  return %[[LAST]] : tensor<f32>
+
+// -----
+
+// Do not convert scf.while to scf.for, as the induction variable value may get modified in the before block.
+// CHECK-LABEL: func @single_loop_test2
+func @single_loop_test2(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.while (%arg3 = %c0) : (index) -> index {
+    "foo"(%arg3) : (index) -> ()
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):  // no predecessors
+    %1 = arith.addi %arg3, %c1 : index
+    %2 = arith.addi %arg2, %arg2 : i32
+    memref.store %2, %arg0[%arg3] : memref<?xi32>
+    scf.yield %1 : index
+  }
+  return
+}
+// CHECK: scf.while
+
+// -----
+
+// Do not convert scf.while to scf.for, as the loop condition does not depend on a loop-carried value.
+// CHECK-LABEL: func @invariant_condition
+func @invariant_condition(%arg0: index, %arg1: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = scf.while (%arg2 = %c0) : (index) -> index {
+    "foo"(%arg2) : (index) -> ()
+    %1 = arith.cmpi slt, %arg0, %arg1 : index
+    scf.condition(%1) %arg2 : index
+  } do {
+  ^bb0(%arg2: index):  // no predecessors
+    %1 = arith.addi %arg2, %c1 : index
+    scf.yield %1 : index
+  }
+  return
+}
+// CHECK: scf.while
+
+// -----
+
+// Do not convert scf.while to scf.for, as the step is not loop-invariant.
+// CHECK-LABEL: func @variant_step
+func @variant_step(%arg0: index) {
+  %c0 = arith.constant 0 : index
+  %0 = scf.while (%arg1 = %c0) : (index) -> index {
+    %1 = arith.cmpi slt, %arg1, %arg0 : index
+    scf.condition(%1) %arg1 : index
+  } do {
+  ^bb0(%arg1: index):  // no predecessors
+    %1 = "foo"() : () -> index
+    %2 = arith.addi %arg1, %1 : index
+    scf.yield %2 : index
+  }
+  return
+}
+// CHECK: scf.while
+
+// -----
+
+// Do not convert scf.while to scf.for, as scf.for requires an induction variable of index type.
+// CHECK-LABEL: func @i32_induction_variable
+func @i32_induction_variable(%arg0: i32) {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %0 = scf.while (%arg1 = %c0) : (i32) -> i32 {
+    %1 = arith.cmpi slt, %arg1, %arg0 : i32
+    scf.condition(%1) %arg1 : i32
+  } do {
+  ^bb0(%arg1: i32):  // no predecessors
+    "foo"(%arg1) : (i32) -> ()
+    %1 = arith.addi %arg1, %c1 : i32
+    scf.yield %1 : i32
+  }
+  return
+}
+// CHECK: scf.while