diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -48,6 +48,9 @@
 /// loop range.
 std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 
+/// Creates a pass which lowers scf.for loops into scf.while loops.
+std::unique_ptr<Pass> createForToWhileLoopPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -62,4 +62,18 @@
   let constructor = "mlir::createForLoopRangeFoldingPass()";
 }
 
+def SCFForToWhileLoop
+    : FunctionPass<"scf-for-to-while"> {
+  let summary = "Convert SCF for loops to SCF while loops";
+  let description = [{
+    Rewrites each `scf.for` in the function into an `scf.while`. The `before`
+    region compares the induction variable against the upper bound (signed
+    less-than) and forwards the loop-carried values. The `after` region
+    increments the induction variable by the step and runs the original loop
+    body inside an `scf.execute_region`. The loop-carried values are the
+    induction variable followed by the `iter_args` of the original loop.
+  }];
+  let constructor = "mlir::createForToWhileLoopPass()";
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSCFTransforms
   Bufferize.cpp
+  ForToWhile.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp
   LoopSpecialization.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -0,0 +1,117 @@
+//===- ForToWhile.cpp - scf.for to scf.while loop lowering ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForOps into SCF.WhileOps.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using scf::ForOp;
+using scf::WhileOp;
+
+namespace {
+/// Lowers an scf.for to an scf.while that carries the induction variable.
+struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    // Generate type signature for the loop-carried values. The induction
+    // variable is placed first, followed by the forOp.iterArgs.
+    SmallVector<Type> lcvTypes;
+    lcvTypes.push_back(forOp.getInductionVar().getType());
+    llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes),
+                    [&](auto v) { return v.getType(); });
+
+    // Build scf.WhileOp
+    rewriter.setInsertionPoint(forOp);
+    SmallVector<Value> initArgs;
+    initArgs.push_back(forOp.lowerBound());
+    llvm::append_range(initArgs, forOp.initArgs());
+    auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
+                                            forOp->getAttrs());
+
+    // 'before' region contains the loop condition and forwarding of iteration
+    // arguments to the 'after' region.
+    auto *beforeBlock = rewriter.createBlock(
+        &whileOp.before(), whileOp.before().begin(), lcvTypes, {});
+    rewriter.setInsertionPointToStart(&whileOp.before().front());
+    auto cmpOp = rewriter.create<CmpIOp>(whileOp.getLoc(), CmpIPredicate::slt,
+                                         beforeBlock->getArgument(0),
+                                         forOp.upperBound());
+    rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
+                                      beforeBlock->getArguments());
+
+    // Inline for-loop body into an executeRegion operation in the "after"
+    // region. The return type of the execRegionOp does not contain the
+    // iv - yields in the source for-loop contain only iterArgs.
+    auto *afterBlock = rewriter.createBlock(
+        &whileOp.after(), whileOp.after().begin(), lcvTypes, {});
+
+    // Add induction variable incrementation
+    rewriter.setInsertionPointToEnd(afterBlock);
+    auto ivIncOp = rewriter.create<AddIOp>(
+        whileOp.getLoc(), afterBlock->getArgument(0), forOp.step());
+
+    auto execRegionOp = rewriter.create<scf::ExecuteRegionOp>(
+        whileOp.getLoc(), ArrayRef<Type>(lcvTypes).slice(1));
+    rewriter.inlineRegionBefore(forOp.getBodyRegion(), execRegionOp.getRegion(),
+                                execRegionOp.getRegion().begin());
+
+    // Rewrite uses of the inlined for-loop entry block arguments to use the
+    // loop-carried arguments of the 'after' block.
+    auto *execRegionEntryBlock = &execRegionOp.getRegion().front();
+    SmallVector<unsigned> argsToErase;
+    for (auto arg : llvm::enumerate(execRegionEntryBlock->getArguments())) {
+      arg.value().replaceAllUsesWith(afterBlock->getArgument(arg.index()));
+      argsToErase.push_back(arg.index());
+    }
+    execRegionEntryBlock->eraseArguments(argsToErase);
+
+    // Add yield terminator, forwarding the incremented induction variable
+    // followed by the ExecuteRegionOp results from the loop body.
+    SmallVector<Value> yieldArgs;
+    yieldArgs.push_back(ivIncOp->getResult(0));
+    llvm::append_range(yieldArgs, execRegionOp.getResults());
+    rewriter.create<scf::YieldOp>(whileOp.getLoc(), yieldArgs);
+
+    // We cannot do a direct replacement of the forOp since the while op returns
+    // one more value (the induction variable escapes the loop through being
+    // carried in the set of iterargs). Instead, rewrite uses of the forOp
+    // results, and erase it.
+    for (auto arg : llvm::enumerate(forOp.getResults())) {
+      arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
+    }
+    rewriter.eraseOp(forOp);
+
+    return success();
+  }
+};
+/// Function pass that greedily applies ForLoopLoweringPattern.
+struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    MLIRContext *ctx = funcOp.getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ForLoopLoweringPattern>(ctx);
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
+  return std::make_unique<ForToWhileLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.func(scf-for-to-while)' -split-input-file | FileCheck %s
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// CHECK:      builtin.func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+// CHECK-NEXT:   %c0 = constant 0 : index
+// CHECK-NEXT:   %c1 = constant 1 : index
+// CHECK-NEXT:   %0 = scf.while (%arg3 = %c0) : (index) -> index {
+// CHECK-NEXT:     %1 = cmpi slt, %arg3, %arg1 : index
+// CHECK-NEXT:     scf.condition(%1) %arg3 : index
+// CHECK-NEXT:   } do {
+// CHECK-NEXT:   ^bb0(%arg3: index): // no predecessors
+// CHECK-NEXT:     %1 = addi %arg3, %c1 : index
+// CHECK-NEXT:     scf.execute_region {
+// CHECK-NEXT:       %2 = addi %arg2, %arg2 : i32
+// CHECK-NEXT:       memref.store %2, %arg0[%arg3] : memref<?xi32>
+// CHECK-NEXT:       scf.yield
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %1 : index
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    %0 = addi %arg2, %arg2 : i32
+    memref.store %0, %arg0[%i] : memref<?xi32>
+  }
+  return
+}
+
+// -----
+
+// CHECK:      builtin.func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+// CHECK-NEXT:   %c0 = constant 0 : index
+// CHECK-NEXT:   %c1 = constant 1 : index
+// CHECK-NEXT:   %0 = scf.while (%arg3 = %c0) : (index) -> index {
+// CHECK-NEXT:     %1 = cmpi slt, %arg3, %arg1 : index
+// CHECK-NEXT:     scf.condition(%1) %arg3 : index
+// CHECK-NEXT:   } do {
+// CHECK-NEXT:   ^bb0(%arg3: index): // no predecessors
+// CHECK-NEXT:     %1 = addi %arg3, %c1 : index
+// CHECK-NEXT:     scf.execute_region {
+// CHECK-NEXT:       %2 = scf.while (%arg4 = %c0) : (index) -> index {
+// CHECK-NEXT:         %3 = cmpi slt, %arg4, %arg1 : index
+// CHECK-NEXT:         scf.condition(%3) %arg4 : index
+// CHECK-NEXT:       } do {
+// CHECK-NEXT:       ^bb0(%arg4: index): // no predecessors
+// CHECK-NEXT:         %3 = addi %arg4, %c1 : index
+// CHECK-NEXT:         scf.execute_region {
+// CHECK-NEXT:           %4 = addi %arg2, %arg2 : i32
+// CHECK-NEXT:           memref.store %4, %arg0[%arg3] : memref<?xi32>
+// CHECK-NEXT:           memref.store %4, %arg0[%arg4] : memref<?xi32>
+// CHECK-NEXT:           scf.yield
+// CHECK-NEXT:         }
+// CHECK-NEXT:         scf.yield %3 : index
+// CHECK-NEXT:       }
+// CHECK-NEXT:       scf.yield
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %1 : index
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    scf.for %j = %c0 to %arg1 step %c1 {
+      %0 = addi %arg2, %arg2 : i32
+      memref.store %0, %arg0[%i] : memref<?xi32>
+      memref.store %0, %arg0[%j] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK:      builtin.func @for_iter_args(%arg0: index, %arg1: index, %arg2: index) -> f32 {
+// CHECK-NEXT:   %cst = constant 0.000000e+00 : f32
+// CHECK-NEXT:   %0:3 = scf.while (%arg3 = %arg0, %arg4 = %cst, %arg5 = %cst) : (index, f32, f32) -> (index, f32, f32) {
+// CHECK-NEXT:     %1 = cmpi slt, %arg3, %arg1 : index
+// CHECK-NEXT:     scf.condition(%1) %arg3, %arg4, %arg5 : index, f32, f32
+// CHECK-NEXT:   } do {
+// CHECK-NEXT:   ^bb0(%arg3: index, %arg4: f32, %arg5: f32): // no predecessors
+// CHECK-NEXT:     %1 = addi %arg3, %arg2 : index
+// CHECK-NEXT:     %2:2 = scf.execute_region -> (f32, f32) {
+// CHECK-NEXT:       %3 = addf %arg4, %arg5 : f32
+// CHECK-NEXT:       scf.yield %3, %3 : f32, f32
+// CHECK-NEXT:     }
+// CHECK-NEXT:     scf.yield %1, %2#0, %2#1 : index, f32, f32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %0#2 : f32
+// CHECK-NEXT: }
+func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 {
+  %s0 = constant 0.0 : f32
+  %result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iarg0 = %s0, %iarg1 = %s0) -> (f32, f32) {
+    %sn = addf %iarg0, %iarg1 : f32
+    scf.yield %sn, %sn : f32, f32
+  }
+  return %result#1 : f32
+}