diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h --- a/mlir/include/mlir/Dialect/SCF/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Passes.h @@ -48,6 +48,9 @@ /// loop range. std::unique_ptr createForLoopRangeFoldingPass(); +/// Creates a pass which converts scf.for loops into scf.while loops. +std::unique_ptr createForToWhileLoopPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -62,4 +62,10 @@ let constructor = "mlir::createForLoopRangeFoldingPass()"; } +def SCFForToWhileLoop + : FunctionPass<"scf-for-to-while"> { + let summary = "Convert SCF for loops to SCF while loops"; + let constructor = "mlir::createForToWhileLoopPass()"; +} + #endif // MLIR_DIALECT_SCF_PASSES diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRSCFTransforms Bufferize.cpp + ForToWhile.cpp LoopPipelining.cpp LoopRangeFolding.cpp LoopSpecialization.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -0,0 +1,116 @@ +//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Rewrites scf.for operations into equivalent scf.while operations. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/SCF/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using scf::ForOp; +using scf::WhileOp; + +namespace { + +struct ForLoopLoweringPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const override { + // Generate type signature for the loop-carried values. The induction + // variable is placed first, followed by the forOp.iterArgs. + SmallVector lcvTypes; + lcvTypes.push_back(forOp.getInductionVar().getType()); + llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes), + [](Value v) { return v.getType(); }); + + // Build the scf.while op that replaces the for loop. + rewriter.setInsertionPoint(forOp); + SmallVector initArgs; + initArgs.push_back(forOp.lowerBound()); + llvm::append_range(initArgs, forOp.initArgs()); + auto whileOp = rewriter.create(forOp.getLoc(), lcvTypes, initArgs, + forOp->getAttrs()); + + // 'before' region: keep iterating while iv < upperBound (signed compare) + // and forward all loop-carried values to the 'after' region.
+ auto *beforeBlock = rewriter.createBlock( + &whileOp.before(), whileOp.before().begin(), lcvTypes, {}); + rewriter.setInsertionPointToStart(beforeBlock); + auto cmpOp = rewriter.create(whileOp.getLoc(), CmpIPredicate::slt, + beforeBlock->getArgument(0), + forOp.upperBound()); + rewriter.create(whileOp.getLoc(), cmpOp.getResult(), + beforeBlock->getArguments()); + + // Inline for-loop body into an executeRegion operation in the "after" + // region. The return type of the execRegionOp does not contain the + // iv - yields in the source for-loop contain only iterArgs. + auto *afterBlock = rewriter.createBlock( + &whileOp.after(), whileOp.after().begin(), lcvTypes, {}); + auto execRegionOp = rewriter.create( + whileOp.getLoc(), ArrayRef(lcvTypes).slice(1)); + rewriter.inlineRegionBefore(forOp.getBodyRegion(), execRegionOp.getRegion(), + execRegionOp.getRegion().begin()); + + // Rewrite uses of the inlined for-loop entry block arguments to use the + // loop-carried arguments of the 'after' block.
+ auto *execRegionEntryBlock = &execRegionOp.getRegion().front(); + SmallVector argsToErase; + for (auto arg : llvm::enumerate(execRegionEntryBlock->getArguments())) { + arg.value().replaceAllUsesWith(afterBlock->getArgument(arg.index())); + argsToErase.push_back(arg.index()); + } + execRegionEntryBlock->eraseArguments(argsToErase); + + // Increment the induction variable by the loop step. + rewriter.setInsertionPointToEnd(afterBlock); + auto ivIncOp = rewriter.create( + whileOp.getLoc(), afterBlock->getArgument(0), forOp.step()); + + // Add yield terminator, forwarding the incremented induction variable + // followed by the ExecuteRegionOp results (the loop body's yields). + SmallVector yieldArgs; + yieldArgs.push_back(ivIncOp.getResult()); + llvm::append_range(yieldArgs, execRegionOp.getResults()); + rewriter.create(whileOp.getLoc(), yieldArgs); + + // We cannot do a direct replacement of the forOp since the while op returns + // one more value (the induction variable escapes the loop through being + // carried in the set of iterargs). Instead, rewrite uses of the forOp + // results, and erase the for op.
+ for (auto arg : llvm::enumerate(forOp.getResults())) { + arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1)); + } + rewriter.eraseOp(forOp); + + return success(); + } +}; + +struct ForToWhileLoop : public SCFForToWhileLoopBase { + void runOnFunction() override { + FuncOp funcOp = getFunction(); + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } +}; +} // namespace + +std::unique_ptr mlir::createForToWhileLoopPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir @@ -0,0 +1,111 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.func(scf-for-to-while)' -split-input-file | FileCheck %s +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK-LABEL: builtin.func @single_loop( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: index, +// CHECK-SAME: %[[VAL_2:.*]]: i32) { +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index { +// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index +// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_8:.*]]: index): +// CHECK: scf.execute_region { +// CHECK: %[[VAL_9:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32 +// CHECK: memref.store %[[VAL_9]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[VAL_10:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index +// CHECK: scf.yield %[[VAL_10]] : index +// CHECK: } +// CHECK: return +// CHECK: } +func @single_loop(%arg0: memref, %arg1: index, %arg2: i32) { + %c0 = constant 0 : index + %c1 = constant 1 : index + scf.for %i = %c0 to
%arg1 step %c1 { + %0 = addi %arg2, %arg2 : i32 + memref.store %0, %arg0[%i] : memref + } + return +} + +// ----- + +// CHECK-LABEL: builtin.func @nested_loop( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: index, +// CHECK-SAME: %[[VAL_2:.*]]: i32) { +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index { +// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index +// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_8:.*]]: index): +// CHECK: scf.execute_region { +// CHECK: %[[VAL_9:.*]] = scf.while (%[[VAL_10:.*]] = %[[VAL_3]]) : (index) -> index { +// CHECK: %[[VAL_11:.*]] = cmpi slt, %[[VAL_10]], %[[VAL_1]] : index +// CHECK: scf.condition(%[[VAL_11]]) %[[VAL_10]] : index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_12:.*]]: index): +// CHECK: scf.execute_region { +// CHECK: %[[VAL_13:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32 +// CHECK: memref.store %[[VAL_13]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref +// CHECK: memref.store %[[VAL_13]], %[[VAL_0]]{{\[}}%[[VAL_12]]] : memref +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_4]] : index +// CHECK: scf.yield %[[VAL_14]] : index +// CHECK: } +// CHECK: scf.yield +// CHECK: } +// CHECK: %[[VAL_15:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index +// CHECK: scf.yield %[[VAL_15]] : index +// CHECK: } +// CHECK: return +// CHECK: } +func @nested_loop(%arg0: memref, %arg1: index, %arg2: i32) { + %c0 = constant 0 : index + %c1 = constant 1 : index + scf.for %i = %c0 to %arg1 step %c1 { + scf.for %j = %c0 to %arg1 step %c1 { + %0 = addi %arg2, %arg2 : i32 + memref.store %0, %arg0[%i] : memref + memref.store %0, %arg0[%j] : memref + } + } + return +} + +// ----- + +// CHECK-LABEL: builtin.func @for_iter_args( +// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, +// CHECK-SAME: %[[VAL_2:.*]]: index)
-> f32 { +// CHECK: %[[VAL_3:.*]] = constant 0.000000e+00 : f32 +// CHECK: %[[VAL_4:.*]]:3 = scf.while (%[[VAL_5:.*]] = %[[VAL_0]], %[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_3]]) : (index, f32, f32) -> (index, f32, f32) { +// CHECK: %[[VAL_8:.*]] = cmpi slt, %[[VAL_5]], %[[VAL_1]] : index +// CHECK: scf.condition(%[[VAL_8]]) %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : index, f32, f32 +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32): +// CHECK: %[[VAL_12:.*]]:2 = scf.execute_region -> (f32, f32) { +// CHECK: %[[VAL_13:.*]] = addf %[[VAL_10]], %[[VAL_11]] : f32 +// CHECK: scf.yield %[[VAL_13]], %[[VAL_13]] : f32, f32 +// CHECK: } +// CHECK: %[[VAL_14:.*]] = addi %[[VAL_9]], %[[VAL_2]] : index +// CHECK: scf.yield %[[VAL_14]], %[[VAL_12]]#0, %[[VAL_12]]#1 : index, f32, f32 +// CHECK: } +// CHECK: return %[[VAL_4]]#2 : f32 +// CHECK: } + +func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 { + %s0 = constant 0.0 : f32 + %result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iarg0 = %s0, %iarg1 = %s0) -> (f32, f32) { + %sn = addf %iarg0, %iarg1 : f32 + scf.yield %sn, %sn : f32, f32 + } + return %result#1 : f32 +}