diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -35,6 +35,10 @@
 std::unique_ptr<Pass>
 createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
 
+/// Creates a pass which folds addi/muli ops on the induction variable of
+/// scf.for loops into the loop range.
+std::unique_ptr<Pass> createForLoopRangeFoldingPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -45,4 +45,10 @@
   let dependentDialects = ["AffineDialect"];
 }
 
+def SCFForLoopRangeFolding
+    : Pass<"for-loop-range-folding"> {
+  let summary = "Fold add/mul ops into loop range";
+  let constructor = "mlir::createForLoopRangeFoldingPass()";
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSCFTransforms
   Bufferize.cpp
+  LoopRangeFolding.cpp
   LoopSpecialization.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
@@ -0,0 +1,106 @@
+//===- LoopRangeFolding.cpp - Code to perform loop range folding ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that folds addi/muli ops on the induction
+// variable of scf.for loops into the loop range.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+/// Pass that folds `addi`/`muli` ops applied to the induction variable of an
+/// `scf.for` into the loop's bounds and step.
+struct ForLoopRangeFolding
+    : public SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Folds `addi`/`muli` ops on the induction variable of `op` into the loop
+/// range, one op at a time, until no further folding applies:
+///   * `iv + c` rewrites the range to [lb + c, ub + c) with an unchanged step;
+///   * `iv * c` rewrites the range to [lb * c, ub * c) with step `step * c`.
+/// An op is folded only if it is the sole use of the induction variable and
+/// each of its operands is either the induction variable or defined outside
+/// of the loop. Always returns success.
+///
+/// NOTE(review): folding `muli` presumably requires a strictly positive
+/// multiplier (scf.for expects a positive step) and no index overflow; neither
+/// is checked here -- confirm before relying on it for arbitrary multipliers.
+static LogicalResult foldRanges(ForOp op) {
+  Value indVar = op.getInductionVar();
+
+  auto canBeFolded = [&](Value value) {
+    return op.isDefinedOutsideOfLoop(value) || value == indVar;
+  };
+
+  // Fold until a fixed point is reached.
+  while (true) {
+    // If the induction variable is used more than once, we can't fold its
+    // arith ops into the loop range.
+    if (!indVar.hasOneUse())
+      break;
+
+    Operation *use = *indVar.getUsers().begin();
+    if (!isa<AddIOp, MulIOp>(use))
+      break;
+
+    if (!llvm::all_of(use->getOperands(), canBeFolded))
+      break;
+
+    OpBuilder b(op);
+    // Clones `use` in front of the loop with the induction variable replaced
+    // by `bound`, and returns the clone's result.
+    auto foldInto = [&](Value bound) -> Value {
+      BlockAndValueMapping mapping;
+      mapping.map(indVar, bound);
+      return b.clone(*use, mapping)->getResult(0);
+    };
+
+    // Read all original range values before updating any of them.
+    Value lowerBound = op.lowerBound();
+    Value upperBound = op.upperBound();
+    Value step = op.step();
+
+    // Both ops transform the two bounds. In particular `muli` must scale the
+    // lower bound as well, otherwise the folded loop would start at `lb`
+    // instead of `lb * c`.
+    op.setLowerBound(foldInto(lowerBound));
+    op.setUpperBound(foldInto(upperBound));
+    // Only a multiplication changes the distance between iterations.
+    if (isa<MulIOp>(use))
+      op.setStep(foldInto(step));
+
+    // The loop now yields the folded value directly.
+    use->getResult(0).replaceAllUsesWith(indVar);
+    use->erase();
+  }
+
+  return success();
+}
+
+void ForLoopRangeFolding::runOnOperation() {
+  getOperation()->walk([&](ForOp forOp) {
+    if (failed(foldRanges(forOp)))
+      signalPassFailure();
+  });
+}
+
+std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
+  return std::make_unique<ForLoopRangeFolding>();
+}
diff --git a/mlir/test/Dialect/SCF/loop-range.mlir b/mlir/test/Dialect/SCF/loop-range.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-range.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -pass-pipeline='func(for-loop-range-folding,canonicalize)' -split-input-file | FileCheck %s
+
+func @fold_one_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    %0 = addi %arg2, %i : index
+    %1 = muli %0, %c4 : index
+    %2 = memref.load %arg0[%1] : memref<?xi32>
+    %3 = muli %2, %2 : i32
+    memref.store %3, %arg0[%1] : memref<?xi32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_one_loop
+// CHECK-SAME: ([[ARG0:%.*]]: {{.*}}, [[ARG1:%.*]]: {{.*}}, [[ARG2:%.*]]: {{.*}}
+// CHECK: [[C4:%.*]] = constant 4 : index
+// CHECK: [[I0:%.*]] = addi [[ARG2]], [[ARG1]] : index
+// CHECK: [[LB:%.*]] = muli [[ARG2]], [[C4]] : index
+// CHECK: [[UB:%.*]] = muli [[I0]], [[C4]] : index
+// CHECK: scf.for [[I:%.*]] = [[LB]] to [[UB]] step [[C4]] {
+// CHECK: [[I2:%.*]] = memref.load [[ARG0]]{{\[}}[[I]]
+// CHECK: [[I3:%.*]] = muli [[I2]], [[I2]] : i32
+// CHECK: memref.store [[I3]], [[ARG0]]{{\[}}[[I]]
+
+func @fold_one_loop2(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c10 = constant 10 : index
+  scf.for %j = %c0 to %c10 step %c1 {
+    scf.for %i = %c0 to %arg1 step %c1 {
+      %0 = addi %arg2, %i : index
+      %1 = muli %0, %c4 : index
+      %2 = memref.load %arg0[%1] : memref<?xi32>
+      %3 = muli %2, %2 : i32
+      memref.store %3, %arg0[%1] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_one_loop2
+// CHECK-SAME: ([[ARG0:%.*]]: {{.*}}, [[ARG1:%.*]]: {{.*}}, [[ARG2:%.*]]: {{.*}}
+// CHECK: [[C4:%.*]] = constant 4 : index
+// CHECK: [[I0:%.*]] = addi [[ARG2]], [[ARG1]] : index
+// CHECK: [[LB:%.*]] = muli [[ARG2]], [[C4]] : index
+// CHECK: [[UB:%.*]] = muli [[I0]], [[C4]] : index
+// CHECK: scf.for [[I:%.*]] = [[LB]] to [[UB]] step [[C4]] {
+// CHECK: [[I2:%.*]] = memref.load [[ARG0]]{{\[}}[[I]]
+// CHECK: [[I3:%.*]] = muli [[I2]], [[I2]] : i32
+// CHECK: memref.store [[I3]], [[ARG0]]{{\[}}[[I]]
+
+func @fold_two_loops(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c10 = constant 10 : index
+  scf.for %j = %c0 to %c10 step %c1 {
+    scf.for %i = %j to %arg1 step %c1 {
+      %0 = addi %arg2, %i : index
+      %1 = muli %0, %c4 : index
+      %2 = memref.load %arg0[%1] : memref<?xi32>
+      %3 = muli %2, %2 : i32
+      memref.store %3, %arg0[%1] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_two_loops
+// CHECK-SAME: ([[ARG0:%.*]]: {{.*}}, [[ARG1:%.*]]: {{.*}}, [[ARG2:%.*]]: {{.*}}
+// CHECK-DAG: [[C10:%.*]] = constant 10 : index
+// CHECK-DAG: [[C4:%.*]] = constant 4 : index
+// CHECK: [[I0:%.*]] = addi [[ARG2]], [[C10]] : index
+// CHECK: [[LB0:%.*]] = muli [[ARG2]], [[C4]] : index
+// CHECK: [[UB0:%.*]] = muli [[I0]], [[C4]] : index
+// CHECK: scf.for [[J:%.*]] = [[LB0]] to [[UB0]] step [[C4]] {
+// CHECK: [[I1:%.*]] = addi [[ARG2]], [[ARG1]] : index
+// CHECK: [[UB1:%.*]] = muli [[I1]], [[C4]] : index
+// CHECK: scf.for [[I:%.*]] = [[J]] to [[UB1]] step [[C4]] {
+// CHECK: [[I3:%.*]] = memref.load [[ARG0]]{{\[}}[[I]]
+// CHECK: [[I4:%.*]] = muli [[I3]], [[I3]] : i32
+// CHECK: memref.store [[I4]], [[ARG0]]{{\[}}[[I]]