diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -35,6 +35,10 @@
 std::unique_ptr<Pass>
 createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
 
+/// Creates a pass which folds add/mul ops on the induction variable of an
+/// scf.for into the loop range (lower bound, upper bound and step).
+std::unique_ptr<Pass> createForLoopRangeFoldingPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -45,4 +45,10 @@
   let dependentDialects = ["AffineDialect"];
 }
 
+def SCFForLoopRangeFolding
+    : Pass<"for-loop-range-folding"> {
+  let summary = "Fold add/mul ops into loop range";
+  let constructor = "mlir::createForLoopRangeFoldingPass()";
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSCFTransforms
   Bufferize.cpp
+  LoopRangeFolding.cpp
   LoopSpecialization.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
@@ -0,0 +1,106 @@
+//===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop range folding.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+/// Folds an AddIOp/MulIOp applied to the induction variable of an scf.for
+/// into the bounds and step of that loop.
+struct ForLoopRangeFolding
+    : public SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Walks every scf.for nested under the root operation. While the induction
+/// variable has a single use, that user is an AddIOp or a MulIOp, and every
+/// operand of the user is the induction variable or defined outside the loop,
+/// the user is applied to the loop range instead and erased:
+///   iv + c  ==>  for iv' = lb + c to ub + c step s
+///   iv * c  ==>  for iv' = lb * c to ub * c step s * c
+void ForLoopRangeFolding::runOnOperation() {
+  getOperation()->walk([&](ForOp op) {
+    Value indVar = op.getInductionVar();
+
+    auto canBeFolded = [&](Value value) {
+      return op.isDefinedOutsideOfLoop(value) || value == indVar;
+    };
+
+    // Fold until a fixed point is reached
+    while (true) {
+
+      // If the induction variable is used more than once, we can't fold its
+      // arith ops into the loop range
+      if (!indVar.hasOneUse())
+        break;
+
+      Operation *user = *indVar.getUsers().begin();
+      if (!isa<AddIOp, MulIOp>(user))
+        break;
+
+      if (!llvm::all_of(user->getOperands(), canBeFolded))
+        break;
+
+      // Clones are inserted right before the loop. Each mapping substitutes
+      // one of the current loop range values for the induction variable.
+      OpBuilder b(op);
+      BlockAndValueMapping lbMap;
+      lbMap.map(indVar, op.lowerBound());
+      BlockAndValueMapping ubMap;
+      ubMap.map(indVar, op.upperBound());
+      BlockAndValueMapping stepMap;
+      stepMap.map(indVar, op.step());
+
+      if (isa<AddIOp>(user)) {
+        // An addition shifts both bounds; the step is unchanged.
+        Operation *lbFold = b.clone(*user, lbMap);
+        Operation *ubFold = b.clone(*user, ubMap);
+
+        op.setLowerBound(lbFold->getResult(0));
+        op.setUpperBound(ubFold->getResult(0));
+
+      } else if (isa<MulIOp>(user)) {
+        // A multiplication scales the lower bound, the upper bound and the
+        // step. Leaving the lower bound unscaled would only be correct for
+        // loops that start at zero.
+        // NOTE(review): the multiplier is not checked to be positive, while
+        // scf.for is documented to require a positive step - confirm.
+        Operation *lbFold = b.clone(*user, lbMap);
+        Operation *ubFold = b.clone(*user, ubMap);
+        Operation *stepFold = b.clone(*user, stepMap);
+
+        op.setLowerBound(lbFold->getResult(0));
+        op.setUpperBound(ubFold->getResult(0));
+        op.setStep(stepFold->getResult(0));
+      }
+
+      // The folded op is now redundant; its users read the induction variable
+      // directly.
+      ValueRange wrapIndvar(indVar);
+      user->replaceAllUsesWith(wrapIndvar);
+      user->erase();
+    }
+  });
+}
+
+std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
+  return std::make_unique<ForLoopRangeFolding>();
+}
diff --git a/mlir/test/Dialect/SCF/loop-range.mlir b/mlir/test/Dialect/SCF/loop-range.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-range.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -pass-pipeline='func(for-loop-range-folding)' -split-input-file | FileCheck %s
+
+func @fold_one_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    %0 = addi %arg2, %i : index
+    %1 = muli %0, %c4 : index
+    %2 = memref.load %arg0[%1] : memref<?xi32>
+    %3 = muli %2, %2 : i32
+    memref.store %3, %arg0[%1] : memref<?xi32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_one_loop
+// CHECK-SAME:   (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C1:.*]] = constant 1 : index
+// CHECK:       %[[C4:.*]] = constant 4 : index
+// CHECK:       %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK:       %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK:       %[[I2:.*]] = muli %[[I0]], %[[C4]] : index
+// CHECK:       %[[I3:.*]] = muli %[[I1]], %[[C4]] : index
+// CHECK:       %[[I4:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK:       scf.for %[[I:.*]] = %[[I2]] to %[[I3]] step %[[I4]] {
+// CHECK:         %[[I5:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]]
+// CHECK:         %[[I6:.*]] = muli %[[I5]], %[[I5]] : i32
+// CHECK:         memref.store %[[I6]], %[[ARG0]]{{\[}}%[[I]]
+
+func @fold_one_loop2(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c10 = constant 10 : index
+  scf.for %j = %c0 to %c10 step %c1 {
+    scf.for %i = %c0 to %arg1 step %c1 {
+      %0 = addi %arg2, %i : index
+      %1 = muli %0, %c4 : index
+      %2 = memref.load %arg0[%1] : memref<?xi32>
+      %3 = muli %2, %2 : i32
+      memref.store %3, %arg0[%1] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_one_loop2
+// CHECK-SAME:   (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C1:.*]] = constant 1 : index
+// CHECK:       %[[C4:.*]] = constant 4 : index
+// CHECK:       %[[C10:.*]] = constant 10 : index
+// CHECK:       scf.for %[[J:.*]] = %[[C0]] to %[[C10]] step %[[C1]] {
+// CHECK:         %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK:         %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK:         %[[I2:.*]] = muli %[[I0]], %[[C4]] : index
+// CHECK:         %[[I3:.*]] = muli %[[I1]], %[[C4]] : index
+// CHECK:         %[[I4:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK:         scf.for %[[I:.*]] = %[[I2]] to %[[I3]] step %[[I4]] {
+// CHECK:           %[[I5:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]]
+// CHECK:           %[[I6:.*]] = muli %[[I5]], %[[I5]] : i32
+// CHECK:           memref.store %[[I6]], %[[ARG0]]{{\[}}%[[I]]
+
+func @fold_two_loops(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c10 = constant 10 : index
+  scf.for %j = %c0 to %c10 step %c1 {
+    scf.for %i = %j to %arg1 step %c1 {
+      %0 = addi %arg2, %i : index
+      %1 = muli %0, %c4 : index
+      %2 = memref.load %arg0[%1] : memref<?xi32>
+      %3 = muli %2, %2 : i32
+      memref.store %3, %arg0[%1] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_two_loops
+// CHECK-SAME:   (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C1:.*]] = constant 1 : index
+// CHECK:       %[[C4:.*]] = constant 4 : index
+// CHECK:       %[[C10:.*]] = constant 10 : index
+// CHECK:       %[[O0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK:       %[[O1:.*]] = addi %[[ARG2]], %[[C10]] : index
+// CHECK:       %[[O2:.*]] = muli %[[O0]], %[[C4]] : index
+// CHECK:       %[[O3:.*]] = muli %[[O1]], %[[C4]] : index
+// CHECK:       %[[O4:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK:       scf.for %[[J:.*]] = %[[O2]] to %[[O3]] step %[[O4]] {
+// CHECK:         %[[I0:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK:         %[[I1:.*]] = muli %[[I0]], %[[C4]] : index
+// CHECK:         %[[I2:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK:         scf.for %[[I:.*]] = %[[J]] to %[[I1]] step %[[I2]] {
+// CHECK:           %[[I3:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]]
+// CHECK:           %[[I4:.*]] = muli %[[I3]], %[[I3]] : i32
+// CHECK:           memref.store %[[I4]], %[[ARG0]]{{\[}}%[[I]]
+
+// If an instruction's operands are not defined outside the loop, we cannot
+// perform the optimization, as is the case with the muli below. (If paired
+// with loop invariant code motion we can continue.)
+func @fold_only_first_add(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    %0 = addi %arg2, %i : index
+    %1 = addi %arg2, %c4 : index
+    %2 = muli %0, %1 : index
+    %3 = memref.load %arg0[%2] : memref<?xi32>
+    %4 = muli %3, %3 : i32
+    memref.store %4, %arg0[%2] : memref<?xi32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_only_first_add
+// CHECK-SAME:   (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C1:.*]] = constant 1 : index
+// CHECK:       %[[C4:.*]] = constant 4 : index
+// CHECK:       %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK:       %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK:       scf.for %[[I:.*]] = %[[I0]] to %[[I1]] step %[[C1]] {
+// CHECK:         %[[I2:.*]] = addi %[[ARG2]], %[[C4]] : index
+// CHECK:         %[[I3:.*]] = muli %[[I]], %[[I2]] : index
+// CHECK:         %[[I4:.*]] = memref.load %[[ARG0]]{{\[}}%[[I3]]
+// CHECK:         %[[I5:.*]] = muli %[[I4]], %[[I4]] : i32
+// CHECK:         memref.store %[[I5]], %[[ARG0]]{{\[}}%[[I3]]