diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -52,6 +52,9 @@
 /// loop range.
 std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 
+/// Creates a pass which lowers for loops into while loops.
+std::unique_ptr<Pass> createForToWhileLoopPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -78,4 +78,39 @@
   let constructor = "mlir::createForLoopRangeFoldingPass()";
 }
 
+def SCFForToWhileLoop
+    : FunctionPass<"scf-for-to-while"> {
+  let summary = "Convert SCF for loops to SCF while loops";
+  let constructor = "mlir::createForToWhileLoopPass()";
+  let description = [{
+    This pass transforms SCF.ForOp operations to SCF.WhileOp. The For loop
+    condition is placed in the 'before' region of the while operation, and the
+    induction variable incrementation and loop body in the 'after' region.
+    The loop carried values of the while op are the induction variable (IV) of
+    the for-loop + any iter_args specified for the for-loop.
+    Any 'yield' ops in the for-loop are rewritten to additionally yield the
+    (incremented) induction variable.
+
+    ```mlir
+      // Before:
+      scf.for %i = %c0 to %arg1 step %c1 {
+        %0 = addi %arg2, %arg2 : i32
+        memref.store %0, %arg0[%i] : memref<?xi32>
+      }
+
+      // After:
+      %0 = scf.while (%i = %c0) : (index) -> index {
+        %1 = cmpi slt, %i, %arg1 : index
+        scf.condition(%1) %i : index
+      } do {
+      ^bb0(%i: index):  // no predecessors
+        %1 = addi %i, %c1 : index
+        %2 = addi %arg2, %arg2 : i32
+        memref.store %2, %arg0[%i] : memref<?xi32>
+        scf.yield %1 : index
+      }
+    ```
+  }];
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSCFTransforms
   Bufferize.cpp
+  ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -0,0 +1,110 @@
+//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForOp's into SCF.WhileOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace llvm;
+using namespace mlir;
+using scf::ForOp;
+using scf::WhileOp;
+
+namespace {
+/// Lowers scf.for to scf.while, carrying the IV as the first loop value.
+struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    // Generate type signature for the loop-carried values. The induction
+    // variable is placed first, followed by the forOp.iterArgs.
+    SmallVector<Type> lcvTypes;
+    lcvTypes.push_back(forOp.getInductionVar().getType());
+    llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes),
+                    [&](auto v) { return v.getType(); });
+
+    // Build scf.WhileOp
+    SmallVector<Value> initArgs;
+    initArgs.push_back(forOp.lowerBound());
+    llvm::append_range(initArgs, forOp.initArgs());
+    auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
+                                            forOp->getAttrs());
+
+    // 'before' region contains the loop condition and forwarding of iteration
+    // arguments to the 'after' region.
+    auto *beforeBlock = rewriter.createBlock(
+        &whileOp.before(), whileOp.before().begin(), lcvTypes, {});
+    rewriter.setInsertionPointToStart(&whileOp.before().front());
+    auto cmpOp = rewriter.create<CmpIOp>(whileOp.getLoc(), CmpIPredicate::slt,
+                                         beforeBlock->getArgument(0),
+                                         forOp.upperBound());
+    rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
+                                      beforeBlock->getArguments());
+
+    // Create the 'after' region block that will receive the for-loop body.
+    // Its arguments mirror the loop-carried values: the induction variable
+    // first, followed by the iter_args.
+    auto *afterBlock = rewriter.createBlock(
+        &whileOp.after(), whileOp.after().begin(), lcvTypes, {});
+
+    // Add induction variable incrementation
+    rewriter.setInsertionPointToEnd(afterBlock);
+    auto ivIncOp = rewriter.create<AddIOp>(
+        whileOp.getLoc(), afterBlock->getArgument(0), forOp.step());
+
+    // Rewrite uses of the for-loop block arguments to the new while-loop
+    // "after" arguments
+    for (auto barg : enumerate(forOp.getBody(0)->getArguments()))
+      barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
+
+    // Inline for-loop body operations into 'after' region.
+    for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
+      arg.moveBefore(afterBlock, afterBlock->end());
+
+    // Add incremented IV to yield operations
+    for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
+      SmallVector<Value> yieldOperands = yieldOp.getOperands();
+      yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
+      yieldOp->setOperands(yieldOperands);
+    }
+
+    // We cannot do a direct replacement of the forOp since the while op returns
+    // an extra value (the induction variable escapes the loop through being
+    // carried in the set of iterargs). Instead, rewrite uses of the forOp
+    // results.
+    for (auto arg : llvm::enumerate(forOp.getResults()))
+      arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
+
+    rewriter.eraseOp(forOp);
+    return success();
+  }
+};
+
+struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    MLIRContext *ctx = funcOp.getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ForLoopLoweringPattern>(ctx);
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
+  return std::make_unique<ForToWhileLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.func(scf-for-to-while)' -split-input-file | FileCheck %s
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// CHECK-LABEL:   builtin.func @single_loop(
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<?xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: i32) {
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
+// CHECK:             %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK:             scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
+// CHECK:             %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_10:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
+// CHECK:             memref.store %[[VAL_10]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
+// CHECK:             scf.yield %[[VAL_9]] : index
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    %0 = addi %arg2, %arg2 : i32
+    memref.store %0, 
%arg0[%i] : memref<?xi32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   builtin.func @nested_loop(
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<?xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: i32) {
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
+// CHECK:             %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK:             scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
+// CHECK:             %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_3]]) : (index) -> index {
+// CHECK:               %[[VAL_12:.*]] = cmpi slt, %[[VAL_11]], %[[VAL_1]] : index
+// CHECK:               scf.condition(%[[VAL_12]]) %[[VAL_11]] : index
+// CHECK:             } do {
+// CHECK:             ^bb0(%[[VAL_13:.*]]: index):
+// CHECK:               %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_4]] : index
+// CHECK:               %[[VAL_15:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
+// CHECK:               memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
+// CHECK:               memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_13]]] : memref<?xi32>
+// CHECK:               scf.yield %[[VAL_14]] : index
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_9]] : index
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    scf.for %j = %c0 to %arg1 step %c1 {
+      %0 = addi %arg2, %arg2 : i32
+      memref.store %0, %arg0[%i] : memref<?xi32>
+      memref.store %0, %arg0[%j] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   builtin.func @for_iter_args(
+// CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: index) -> f32 {
+// CHECK:           %[[VAL_3:.*]] = constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]]:3 = scf.while (%[[VAL_5:.*]] = %[[VAL_0]], %[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_3]]) : (index, f32, f32) -> (index, f32, f32) {
+// CHECK:             %[[VAL_8:.*]] = cmpi slt, %[[VAL_5]], %[[VAL_1]] : index
+// CHECK:             scf.condition(%[[VAL_8]]) %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : index, f32, f32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
+// CHECK:             %[[VAL_12:.*]] = addi %[[VAL_9]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_13:.*]] = addf %[[VAL_10]], %[[VAL_11]] : f32
+// CHECK:             scf.yield %[[VAL_12]], %[[VAL_13]], %[[VAL_13]] : index, f32, f32
+// CHECK:           }
+// CHECK:           return %[[VAL_4]]#2 : f32
+// CHECK:         }
+func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 {
+  %s0 = constant 0.0 : f32
+  %result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iarg0 = %s0, %iarg1 = %s0) -> (f32, f32) {
+    %sn = addf %iarg0, %iarg1 : f32
+    scf.yield %sn, %sn : f32, f32
+  }
+  return %result#1 : f32
+}
+
+// -----
+
+// CHECK-LABEL:   builtin.func @exec_region_multiple_yields(
+// CHECK-SAME:      %[[VAL_0:.*]]: i32,
+// CHECK-SAME:      %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]]:2 = scf.while (%[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_0]]) : (index, i32) -> (index, i32) {
+// CHECK:             %[[VAL_8:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK:             scf.condition(%[[VAL_8]]) %[[VAL_6]], %[[VAL_7]] : index, i32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: i32):
+// CHECK:             %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_12:.*]] = scf.execute_region -> i32 {
+// CHECK:               %[[VAL_13:.*]] = cmpi slt, %[[VAL_9]], %[[VAL_4]] : index
+// CHECK:               cond_br %[[VAL_13]], ^bb1, ^bb2
+// CHECK:             ^bb1:
+// CHECK:               %[[VAL_14:.*]] = subi %[[VAL_10]], %[[VAL_0]] : i32
+// CHECK:               scf.yield %[[VAL_14]] : i32
+// CHECK:             ^bb2:
+// CHECK:               %[[VAL_15:.*]] = muli %[[VAL_10]], %[[VAL_2]] : i32
+// CHECK:               scf.yield %[[VAL_15]] : i32
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_11]], %[[VAL_12]] : index, i32
+// CHECK:           }
+// CHECK:           return %[[VAL_5]]#1 : i32
+// CHECK:         }
+func @exec_region_multiple_yields(%arg0: i32, %arg1: index, %arg2: i32) -> i32 {
+  %c1_i32 = constant 1 : i32
+  %c2_i32 = constant 2 : i32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c5 = constant 5 : index
+  %0 = scf.for %i = %c0 to %arg1 step %c1 iter_args(%iarg0 = %arg0) -> i32 {
+    %2 = scf.execute_region -> i32 {
+      %1 = cmpi slt, %i, %c1 : index
+      cond_br %1, ^bb1, ^bb2
+    ^bb1:
+      %2 = subi %iarg0, %arg0 : i32
+      scf.yield %2 : i32
+    ^bb2:
+      %3 = muli %iarg0, %arg2 : i32
+      scf.yield %3 : i32
+    }
+    scf.yield %2 : i32
+  }
+  return %0 : i32
+}