diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h --- a/mlir/include/mlir/Dialect/SCF/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Passes.h @@ -28,6 +28,10 @@ /// better vectorization. std::unique_ptr createForLoopPeelingPass(); +/// Creates a pass that canonicalizes affine.min ops in scf.for loops with +/// known lower and upper bounds. +std::unique_ptr createAffineMinSCFCanonicalizationPass(); + /// Creates a loop fusion pass which fuses parallel loops. std::unique_ptr createParallelLoopFusionPass(); diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -17,6 +17,16 @@ let dependentDialects = ["memref::MemRefDialect"]; } +// Note: Making this a canonicalization pattern would require a dependency +// of the SCF dialect on the Affine dialect or vice versa. +def AffineMinSCFCanonicalization + : FunctionPass<"canonicalize-scf-affine-min"> { + let summary = "Canonicalize affine.min ops in the context of SCF loops with " + "known bounds"; + let constructor = "mlir::createAffineMinSCFCanonicalizationPass()"; + let dependentDialects = ["AffineDialect"]; +} + def SCFForLoopPeeling : FunctionPass<"for-loop-peeling"> { let summary = "Peel `for` loops at their upper bounds."; diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -13,10 +13,12 @@ #ifndef MLIR_DIALECT_SCF_TRANSFORMS_H_ #define MLIR_DIALECT_SCF_TRANSFORMS_H_ +#include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" namespace mlir { +class AffineMinOp; class ConversionTarget; struct LogicalResult; class MLIRContext; @@ -26,6 +28,7 @@ class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; class Operation; +class Value; namespace scf { @@ -34,6 +37,21 @@ class 
ParallelOp; class ForOp; +/// Try to canonicalize an affine.min operation in the context of `for` loops +/// with a known range. +/// +/// `loopMatcher` is used to retrieve loop bounds and step size for a given +/// iteration variable: If the first parameter is an iteration variable, return +/// lower/upper bounds via the second/third parameter and the step size via the +/// last parameter. The function should return `success` in that case. If the +/// first parameter is not an iteration variable, return `failure`. +/// +/// Note: `loopMatcher` allows this function to be used with any "for loop"-like +/// operation (scf.for, scf.parallel and even ops defined in other dialects). +LogicalResult canonicalizeAffineMinOpInLoop( + AffineMinOp minOp, RewriterBase &rewriter, + function_ref loopMatcher); + /// Fuses all adjacent scf.parallel operations with identical bounds and step /// into one scf.parallel operations. Uses a naive aliasing and dependency /// analysis. @@ -149,6 +167,11 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options); +/// Populate patterns for canonicalizing operations inside SCF loop bodies. +/// At the moment, only affine.min computations with iteration variables, +/// loop bounds and loop steps are canonicalized. +void populateSCFLoopBodyCanonicalizationPatterns(RewritePatternSet &patterns); + } // namespace scf } // namespace mlir diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -382,6 +382,77 @@ return success(); } +/// Canonicalize AffineMinOp operations in the context of for loops with a known +/// range. 
Call `canonicalizeAffineMinOp` and add the following constraints to +/// the constraint system (along with the missing dimensions): +/// +/// * iv >= lb +/// * iv < lb + step * ((ub - lb - 1) floorDiv step) + 1 +/// +/// Note: Due to limitations of FlatAffineConstraints, only constant step sizes +/// are currently supported. +LogicalResult mlir::scf::canonicalizeAffineMinOpInLoop( + AffineMinOp minOp, RewriterBase &rewriter, + function_ref loopMatcher) { + FlatAffineValueConstraints constraints; + DenseSet allIvs; + + // Find all iteration variables among `minOp`'s operands and constrain them. + for (Value operand : minOp.operands()) { + // Skip duplicate ivs. + if (llvm::find(allIvs, operand) != allIvs.end()) + continue; + + // If `operand` is an iteration variable: Find corresponding loop + // bounds and step. + Value iv = operand; + Value lb, ub, step; + if (failed(loopMatcher(operand, lb, ub, step))) + continue; + allIvs.insert(iv); + + // FlatAffineConstraints does not support semi-affine expressions. + // Therefore, only constant step values are supported. + auto stepInt = getConstantIntValue(step); + if (!stepInt) + continue; + + unsigned dimIv = constraints.addDimId(iv); + unsigned dimLb = constraints.addDimId(lb); + unsigned dimUb = constraints.addDimId(ub); + + // If loop lower/upper bounds are constant: Add EQ constraint. + Optional lbInt = getConstantIntValue(lb); + Optional ubInt = getConstantIntValue(ub); + if (lbInt) + constraints.addBound(FlatAffineConstraints::EQ, dimLb, *lbInt); + if (ubInt) + constraints.addBound(FlatAffineConstraints::EQ, dimUb, *ubInt); + + // iv >= lb (equiv.: iv - lb >= 0) + SmallVector ineqLb(constraints.getNumCols(), 0); + ineqLb[dimIv] = 1; + ineqLb[dimLb] = -1; + constraints.addInequality(ineqLb); + + // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1 + AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt) + : rewriter.getAffineDimExpr(dimLb); + AffineExpr exprUb = ubInt ? 
rewriter.getAffineConstantExpr(*ubInt) + : rewriter.getAffineDimExpr(dimUb); + AffineExpr ivUb = + exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt))); + auto map = AffineMap::get( + /*dimCount=*/constraints.getNumDimIds(), + /*symbolCount=*/constraints.getNumSymbolIds(), /*result=*/ivUb); + + if (failed(constraints.addBound(FlatAffineConstraints::UB, dimIv, map))) + return failure(); + } + + return canonicalizeAffineMinOp(rewriter, minOp, constraints); +} + static constexpr char kPeeledLoopLabel[] = "__peeled_loop__"; static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; @@ -423,6 +494,39 @@ /// the direct parent. bool skipPartial; }; + +/// Canonicalize AffineMinOp operations in the context of scf.for and +/// scf.parallel loops with a known range. +struct AffineMinSCFCanonicalizationPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineMinOp minOp, + PatternRewriter &rewriter) const override { + auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) { + if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { + lb = forOp.lowerBound(); + ub = forOp.upperBound(); + step = forOp.step(); + return success(); + } + if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { + for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { + if (parOp.getInductionVars()[idx] == iv) { + lb = parOp.lowerBound()[idx]; + ub = parOp.upperBound()[idx]; + step = parOp.step()[idx]; + return success(); + } + } + return failure(); + } + return failure(); + }; + + return scf::canonicalizeAffineMinOpInLoop(minOp, rewriter, loopMatcher); + } +}; } // namespace namespace { @@ -456,8 +560,24 @@ }); } }; + +struct AffineMinSCFCanonicalization + : public AffineMinSCFCanonicalizationBase { + void runOnFunction() override { + FuncOp funcOp = getFunction(); + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if 
(failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) + signalPassFailure(); + } +}; } // namespace +std::unique_ptr mlir::createAffineMinSCFCanonicalizationPass() { + return std::make_unique(); +} + std::unique_ptr mlir::createParallelLoopSpecializationPass() { return std::make_unique(); } @@ -469,3 +589,8 @@ std::unique_ptr mlir::createForLoopPeelingPass() { return std::make_unique(); } + +void mlir::scf::populateSCFLoopBodyCanonicalizationPatterns( + RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} diff --git a/mlir/test/Dialect/SCF/canonicalize-scf-affine-min.mlir b/mlir/test/Dialect/SCF/canonicalize-scf-affine-min.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/canonicalize-scf-affine-min.mlir @@ -0,0 +1,189 @@ +// RUN: mlir-opt %s -canonicalize-scf-affine-min -split-input-file | FileCheck %s + +// Note: This is mostly a copy of test/Dialect/Linalg/fold-affine-min-scf.mlir + +// CHECK-LABEL: func @scf_for_canonicalize_min +// CHECK: %[[C2:.*]] = constant 2 : i64 +// CHECK: scf.for +// CHECK: memref.store %[[C2]], %{{.*}}[] : memref +func @scf_for_canonicalize_min(%A : memref) { + %c0 = constant 0 : index + %c2 = constant 2 : index + %c4 = constant 4 : index + + scf.for %i = %c0 to %c4 step %c2 { + %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4) + %2 = index_cast %1: index to i64 + memref.store %2, %A[]: memref + } + return +} + +// ----- + +// CHECK-LABEL: func @scf_for_loop_nest_canonicalize_min +// CHECK: %[[C5:.*]] = constant 5 : i64 +// CHECK: scf.for +// CHECK: scf.for +// CHECK: memref.store %[[C5]], %{{.*}}[] : memref +func @scf_for_loop_nest_canonicalize_min(%A : memref) { + %c0 = constant 0 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c6 = constant 6 : index + + scf.for %i = %c0 to %c4 step %c2 { + scf.for %j = %c0 to %c6 step %c3 { + %1 = affine.min affine_map<(d0, d1, d2, d3)[] -> (5, d1 + d3 - d0 - d2)> (%i, %c4, %j, %c6) + 
%2 = index_cast %1: index to i64 + memref.store %2, %A[]: memref + } + } + return +} + +// ----- + +// CHECK-LABEL: func @scf_for_not_canonicalizable_1 +// CHECK: scf.for +// CHECK: affine.min +// CHECK: index_cast +func @scf_for_not_canonicalizable_1(%A : memref) { + // This should not canonicalize because: 4 - %i may take the value 1 < 2. + %c1 = constant 1 : index + %c2 = constant 2 : index + %c4 = constant 4 : index + + scf.for %i = %c1 to %c4 step %c2 { + %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c4] + %2 = index_cast %1: index to i64 + memref.store %2, %A[]: memref + } + return +} + +// ----- + +// CHECK-LABEL: func @scf_for_canonicalize_partly +// CHECK: scf.for +// CHECK: affine.apply +// CHECK: index_cast +func @scf_for_canonicalize_partly(%A : memref) { + // This should canonicalize only partly: 256 - %i <= 256. + %c1 = constant 1 : index + %c16 = constant 16 : index + %c256 = constant 256 : index + + scf.for %i = %c1 to %c256 step %c16 { + %1 = affine.min affine_map<(d0) -> (256, 256 - d0)> (%i) + %2 = index_cast %1: index to i64 + memref.store %2, %A[]: memref + } + return +} + +// ----- + +// CHECK-LABEL: func @scf_for_not_canonicalizable_2 +// CHECK: scf.for +// CHECK: affine.min +// CHECK: index_cast +func @scf_for_not_canonicalizable_2(%A : memref, %step : index) { + // This example should simplify but affine_map is currently missing + // semi-affine canonicalizations: `((s0 * 42 - 1) floordiv s0) * s0` + // should evaluate to 41 * s0. + // Note that this may require positivity assumptions on `s0`. + // Revisit when support is added. 
+ %c0 = constant 0 : index + + %ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step) + scf.for %i = %c0 to %ub step %step { + %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d1 - d2)> (%step, %ub, %i) + %2 = index_cast %1: index to i64 + memref.store %2, %A[]: memref + } + return +} + +// ----- + +// CHECK-LABEL: func @scf_for_not_canonicalizable_3 +// CHECK: scf.for +// CHECK: affine.min +// CHECK: index_cast +func @scf_for_not_canonicalizable_3(%A : memref, %step : index) { + // This example should simplify but affine_map is currently missing + // semi-affine canonicalizations: `-(((s0 * s0 - 1) floordiv s0) * s0)` + // should evaluate to (s0 - 1) * s0. + // Note that this may require positivity assumptions on `s0`. + // Revisit when support is added. + %c0 = constant 0 : index + + %ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step] + scf.for %i = %c0 to %ub2 step %step { + %1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2) + %2 = index_cast %1: index to i64 + memref.store %2, %A[]: memref + } + return +} + +// ----- + +// CHECK-LABEL: func @scf_for_invalid_loop +// CHECK: scf.for +// CHECK: affine.min +// CHECK: index_cast +func @scf_for_invalid_loop(%A : memref, %step : index) { + // This is an invalid loop. It should not be touched by the canonicalization + // pattern. 
+ %c1 = constant 1 : index + %c7 = constant 7 : index + %c256 = constant 256 : index + + scf.for %i = %c256 to %c1 step %c1 { + %1 = affine.min affine_map<(d0)[s0] -> (s0 + d0, 0)> (%i)[%c7] + %2 = index_cast %1: index to i64 + memref.store %2, %A[]: memref + } + return +} + +// ----- + +// CHECK-LABEL: func @scf_parallel_canonicalize_min_1 +// CHECK: %[[C2:.*]] = constant 2 : i64 +// CHECK: scf.parallel +// CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref +func @scf_parallel_canonicalize_min_1(%A : memref) { + %c0 = constant 0 : index + %c2 = constant 2 : index + %c4 = constant 4 : index + + scf.parallel (%i) = (%c0) to (%c4) step (%c2) { + %1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4) + %2 = index_cast %1: index to i64 + memref.store %2, %A[]: memref + } + return +} + +// ----- + +// CHECK-LABEL: func @scf_parallel_canonicalize_min_2 +// CHECK: %[[C2:.*]] = constant 2 : i64 +// CHECK: scf.parallel +// CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref +func @scf_parallel_canonicalize_min_2(%A : memref) { + %c1 = constant 1 : index + %c2 = constant 2 : index + %c7 = constant 7 : index + + scf.parallel (%i) = (%c1) to (%c7) step (%c2) { + %1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7] + %2 = index_cast %1: index to i64 + memref.store %2, %A[]: memref + } + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1475,7 +1475,10 @@ "lib/Dialect/SCF/Transforms/*.cpp", "lib/Dialect/SCF/Transforms/*.h", ]), - hdrs = ["include/mlir/Dialect/SCF/Passes.h"], + hdrs = [ + "include/mlir/Dialect/SCF/Passes.h", + "include/mlir/Dialect/SCF/Transforms.h", + ], includes = ["include"], deps = [ ":Affine",