diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h --- a/mlir/include/mlir/Dialect/SCF/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Passes.h @@ -35,9 +35,10 @@ std::unique_ptr createParallelLoopTilingPass(llvm::ArrayRef tileSize = {}); +/// Creates a pass which folds add/mul ops on the induction variable of scf.for +/// loops into the loop range (bounds and step). std::unique_ptr createForLoopRangeFoldingPass(); - //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -46,8 +46,8 @@ } def SCFForLoopRangeFolding - : Pass<"for-loop-range-folding"> { - let summary = ""; + : FunctionPass<"for-loop-range-folding"> { + let summary = "Fold add/mul ops into loop range"; let constructor = "mlir::createForLoopRangeFoldingPass()"; } diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp @@ -1,3 +1,15 @@ +//===- LoopRangeFolding.cpp - Code to perform loop range folding ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file folds add/mul ops on scf.for induction variables into loop ranges.
+// +//===----------------------------------------------------------------------===// + #include "PassDetail.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SCF/SCF.h" @@ -6,18 +18,19 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" - using namespace mlir; using namespace mlir::scf; namespace { struct ForLoopRangeFolding : public SCFForLoopRangeFoldingBase { - void runOnOperation() override; + void runOnFunction() override; }; -} +} // namespace -LogicalResult foldRanges(ForOp op) { +/// Folds an add/mul op that is the only user of `op`'s induction variable into +/// the loop range, repeating until no further fold applies. +static LogicalResult foldRanges(ForOp op) { // Fold until a fixed point is reached Value indVar = op.getInductionVar(); @@ -32,59 +45,51 @@ if (!indVar.hasOneUse()) break; - Operation *use = indVar.getUses().begin().getUser(); - if (!isa(use)) + Operation *user = indVar.getUses().begin().getUser(); + if (!isa(user)) break; - if (!llvm::all_of(use->getOperands(), canBeFolded)) + if (!llvm::all_of(user->getOperands(), canBeFolded)) break; OpBuilder b(op); - BlockAndValueMapping lbMap; lbMap.map(indVar, op.lowerBound()); - BlockAndValueMapping ubMap; ubMap.map(indVar, op.upperBound()); - BlockAndValueMapping stepMap; stepMap.map(indVar, op.step()); - - if (auto addOp = dyn_cast(use)) { - auto lbFold = b.create( - op.getLoc(), - lbMap.lookupOrDefault(addOp.getOperand(0)), - lbMap.lookupOrDefault(addOp.getOperand(1))); - - auto ubFold = b.create( - op.getLoc(), - ubMap.lookupOrDefault(addOp.getOperand(0)), - ubMap.lookupOrDefault(addOp.getOperand(1))); - - op.setLowerBound(lbFold); - op.setUpperBound(ubFold); - - addOp.replaceAllUsesWith(indVar); - addOp.erase(); - - } else if (auto mulOp = dyn_cast(use)) { - auto ubFold = b.create( - op.getLoc(), - ubMap.lookupOrDefault(mulOp.getOperand(0)), - ubMap.lookupOrDefault(mulOp.getOperand(1))); - - auto stepFold = b.create( - op.getLoc(), - stepMap.lookupOrDefault(mulOp.getOperand(0)), - stepMap.lookupOrDefault(mulOp.getOperand(1))); - - op.setUpperBound(ubFold); - op.setStep(stepFold); - - mulOp.replaceAllUsesWith(indVar); - mulOp.erase(); - } + BlockAndValueMapping lbMap; + lbMap.map(indVar, op.lowerBound()); + BlockAndValueMapping ubMap; + ubMap.map(indVar, op.upperBound()); + BlockAndValueMapping stepMap; + stepMap.map(indVar, op.step()); + + if (isa(user)) { + // The sum takes the values lb+c, lb+s+c, ..., so lb and ub shift by c. + Operation *lbFold = b.clone(*user, lbMap); + Operation *ubFold = b.clone(*user, ubMap); + + op.setLowerBound(lbFold->getResult(0)); + op.setUpperBound(ubFold->getResult(0)); + + } else if (isa(user)) { + // The product takes the values lb*c, (lb+s)*c, ..., so lb, ub and step + // are all scaled by c. NOTE(review): only valid for a positive factor c. + Operation *lbFold = b.clone(*user, lbMap); + Operation *ubFold = b.clone(*user, ubMap); + Operation *stepFold = b.clone(*user, stepMap); + + op.setLowerBound(lbFold->getResult(0)); + op.setUpperBound(ubFold->getResult(0)); + op.setStep(stepFold->getResult(0)); + } + + ValueRange wrapIndvar(indVar); + user->replaceAllUsesWith(wrapIndvar); + user->erase(); } return success(); } -void ForLoopRangeFolding::runOnOperation() { - getOperation()->walk([&](ForOp forOp) { +void ForLoopRangeFolding::runOnFunction() { + getFunction().walk([&](ForOp forOp) { if (failed(foldRanges(forOp))) signalPassFailure(); }); diff --git a/mlir/test/Dialect/SCF/loop-range.mlir b/mlir/test/Dialect/SCF/loop-range.mlir --- a/mlir/test/Dialect/SCF/loop-range.mlir +++ b/mlir/test/Dialect/SCF/loop-range.mlir @@ -15,14 +15,15 @@ } // CHECK-LABEL: func @fold_one_loop -// CHECK-SAME: ([[ARG0:%.*]]: {{.*}}, [[ARG1:%.*]]: {{.*}}, [[ARG2:%.*]]: {{.*}} -// CHECK: [[C4:%.*]] = constant 4 : index -// CHECK: [[I0:%.*]] = addi [[ARG2]], [[ARG1]] : index -// CHECK: [[I1:%.*]] = muli [[I0]], [[C4]] : index -// CHECK: scf.for [[I:%.*]] = [[ARG2]] to [[I1]] step [[C4]] { -// CHECK: [[I2:%.*]] = memref.load [[ARG0]]{{\[}}[[I]] -// CHECK: [[I3:%.*]] = muli [[I2]], [[I2]] : i32 -// CHECK: memref.store [[I3]], [[ARG0]]{{\[}}[[I]] +// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}} +// CHECK: %[[C4:.*]] = constant 4 : index +// CHECK: %[[I0:.*]] = addi %[[ARG2]], %[[ARG1]] : index +// CHECK: %[[LB:.*]] = muli %[[ARG2]], %[[C4]] : index +// CHECK: %[[I1:.*]] = muli %[[I0]], %[[C4]] : index +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[I1]] step %[[C4]] { +// CHECK: %[[I2:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]] +// CHECK: %[[I3:.*]] = muli %[[I2]], %[[I2]] : i32 +// CHECK: memref.store %[[I3]], %[[ARG0]]{{\[}}%[[I]] func @fold_one_loop2(%arg0: memref, %arg1: index, %arg2: index) { %c0 = constant 0 : index @@ -42,14 +43,15 @@ } // CHECK-LABEL: func @fold_one_loop2 -// CHECK-SAME: ([[ARG0:%.*]]: {{.*}}, [[ARG1:%.*]]: {{.*}}, [[ARG2:%.*]]: {{.*}} -// CHECK: [[C4:%.*]] = constant 4 : index -// CHECK: [[I0:%.*]] = addi [[ARG2]], [[ARG1]] : index -// CHECK: [[I1:%.*]] = muli [[I0]], [[C4]] : index -// CHECK: scf.for [[I:%.*]] = [[ARG2]] to [[I1]] step [[C4]] { -// CHECK: [[I2:%.*]] = memref.load [[ARG0]]{{\[}}[[I]] -// CHECK: [[I3:%.*]] = muli [[I2]], [[I2]] : i32 -// CHECK: memref.store [[I3]], [[ARG0]]{{\[}}[[I]] +// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}} +// CHECK: %[[C4:.*]] = constant 4 : index +// CHECK: %[[I0:.*]] = addi %[[ARG2]], %[[ARG1]] : index +// CHECK: %[[LB:.*]] = muli %[[ARG2]], %[[C4]] : index +// CHECK: %[[I1:.*]] = muli %[[I0]], %[[C4]] : index +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[I1]] step %[[C4]] { +// CHECK: %[[I2:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]] +// CHECK: %[[I3:.*]] = muli %[[I2]], %[[I2]] : i32 +// CHECK: memref.store %[[I3]], %[[ARG0]]{{\[}}%[[I]] func @fold_two_loops(%arg0: memref, %arg1: index, %arg2: index) { %c0 = constant 0 : index @@ -69,15 +71,16 @@ } // CHECK-LABEL: func @fold_two_loops -// CHECK-SAME: ([[ARG0:%.*]]: {{.*}}, [[ARG1:%.*]]: {{.*}}, [[ARG2:%.*]]: {{.*}} -// CHECK: [[C10:%.*]] = constant 10 : index -// CHECK: [[C4:%.*]] = constant 4 : index -// CHECK: [[C1:%.*]] = constant 1 : index -// CHECK: [[I0:%.*]] = addi [[ARG2]], [[C10]] : index -// CHECK: scf.for [[J:%.*]] = [[ARG2]] to [[I0]] step [[C1]] { -// CHECK: [[I1:%.*]] = addi [[ARG2]], [[ARG1]] : index -// CHECK: [[I2:%.*]] = muli [[I1]], [[C4]] : index -// CHECK: scf.for [[I:%.*]] = [[J]] to [[I2]] step [[C4]] { -// CHECK: [[I3:%.*]] = memref.load [[ARG0]]{{\[}}[[I]] -// CHECK: [[I4:%.*]] = muli [[I3]], [[I3]] : i32 -// CHECK: memref.store [[I4]], [[ARG0]]{{\[}}[[I]] \ No newline at end of file +// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}} +// CHECK-DAG: %[[C10:.*]] = constant 10 : index +// CHECK-DAG: %[[C4:.*]] = constant 4 : index +// CHECK: %[[I0:.*]] = addi %[[ARG2]], %[[C10]] : index +// CHECK: %[[LB:.*]] = muli %[[ARG2]], %[[C4]] : index +// CHECK: %[[UB:.*]] = muli %[[I0]], %[[C4]] : index +// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[C4]] { +// CHECK: %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index +// CHECK: %[[I2:.*]] = muli %[[I1]], %[[C4]] : index +// CHECK: scf.for %[[I:.*]] = %[[J]] to %[[I2]] step %[[C4]] { +// CHECK: %[[I3:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]] +// CHECK: %[[I4:.*]] = muli %[[I3]], %[[I3]] : i32 +// CHECK: memref.store %[[I4]], %[[ARG0]]{{\[}}%[[I]]