diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -44,9 +44,10 @@ /// by an scf.if for the last (partial) iteration (if any). This transformation /// is called "loop peeling". /// -/// Other patterns can simplify/canonicalize operations in the body of the loop -/// and the scf.if. This is beneficial for a wide range of transformations such -/// as vectorization or loop tiling. +/// This transformation is beneficial for a wide range of transformations such +/// as vectorization or loop tiling: It enables additional canonicalizations +/// inside the peeled loop body such as rewriting masked loads into unmasked +/// loads. /// /// E.g., assuming a lower bound of 0 (for illustration purposes): /// ``` @@ -65,11 +66,22 @@ /// } /// ``` /// -/// This function rewrites the given scf.for loop in-place and creates a new -/// scf.if operation (returned via `ifOp`) for the last iteration. +/// After loop peeling, this function tries to simplify/canonicalize affine.min +/// operations in the body of the loop and the scf.if, taking advantage of the +/// fact that every iteration of the peeled loop is a "full" iteration. This +/// canonicalization is expected to enable further canonicalization +/// opportunities through other patterns. /// -/// TODO: Simplify affine.min ops inside the new loop/if statement. -LogicalResult peelForLoop(RewriterBase &b, ForOp forOp, scf::IfOp &ifOp); +/// The return value indicates whether the loop was rewritten or not. Loops are +/// not rewritten if: +/// * Loop step size is 1 or +/// * Loop bounds and step size are static, and step already divides the +/// iteration space evenly. +/// +/// Note: This function rewrites the given scf.for loop in-place and creates a +/// new scf.if operation for the last iteration. 
It replaces all uses of the +/// unpeeled loop with the results of the newly generated scf.if. +LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp); /// Tile a parallel loop of the form /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SCF/SCF.h" @@ -97,8 +98,16 @@ /// Rewrite a for loop with bounds/step that potentially do not divide evenly /// into a for loop where the step divides the iteration space evenly, followed /// by an scf.if for the last (partial) iteration (if any). -LogicalResult mlir::scf::peelForLoop(RewriterBase &b, ForOp forOp, - scf::IfOp &ifOp) { +/// +/// This function rewrites the given scf.for loop in-place and creates a new +/// scf.if operation for the last iteration. It replaces all uses of the +/// unpeeled loop with the results of the newly generated scf.if. +/// +/// The newly generated scf.if operation is returned via `ifOp`. The boundary +/// at which the loop is split (new upper bound) is returned via `splitBound`. +/// The return value indicates whether the loop was rewritten or not. 
+static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp, scf::IfOp &ifOp, + Value &splitBound) { RewriterBase::InsertionGuard guard(b); auto lbInt = getConstantIntValue(forOp.lowerBound()); auto ubInt = getConstantIntValue(forOp.upperBound()); @@ -117,7 +126,7 @@ // New upper bound: %ub - (%ub - %lb) mod %step auto modMap = AffineMap::get(3, 0, {dim1 - ((dim1 - dim0) % dim2)}); b.setInsertionPoint(forOp); - Value splitBound = b.createOrFold( + splitBound = b.createOrFold( loc, modMap, ValueRange{forOp.lowerBound(), forOp.upperBound(), forOp.step()}); @@ -153,6 +162,267 @@ return success(); } +/// Bound an identifier `pos` in a given FlatAffineConstraints with constraints +/// drawn from an affine map. Before adding the constraint, the dimensions/ +/// symbols of the affine map are aligned with the constraint set. `operands` +/// are the SSA Value operands used with the affine map. `dims`/`syms` are the +/// SSA Values associated with the FlatAffineConstraint's dimension/symbol +/// columns. Note that this function adds a new symbol column to the constraint +/// set for each dimension/symbol that exists in the affine map but not in the +/// constraint set. The new symbols are returned via `syms`. +static LogicalResult alignAndAddBound(FlatAffineConstraints &constraints, + FlatAffineConstraints::BoundType type, + unsigned pos, AffineMap map, + ValueRange operands, ValueRange dims, + SmallVector &syms) { + SmallVector newSyms; + AffineMap alignedMap = + alignAffineMapWithValues(map, operands, dims, syms, &newSyms); + for (unsigned i = 0; i < newSyms.size() - syms.size(); ++i) + constraints.addSymbolId(constraints.getNumSymbolIds()); + std::swap(syms, newSyms); + return constraints.addBound(type, pos, alignedMap); +} + +/// This function tries to canonicalize affine.min operations by proving that +/// its value is bounded by the same lower and upper bound. In that case, the +/// operation can be replaced with the bound. 
+/// +/// Bounds are computed by FlatAffineConstraints. Invariants required for +/// finding/proving bounds should be supplied via `constraintBuilder` by adding +/// constraints to the provided FlatAffineConstraints. Only the dimensions in +/// `dims` and constants can be used when adding constraints. Adding new +/// dimensions/symbols is not allowed. (However, local columns may be added.) +/// +/// 1. Set up a constraint system with the dimensions passed as `dims`. +/// 2. Call the builder, which may add new constraints. +/// 3. Add dimensions for `minOp` and `minOpUb` (upper bound of `minOp`). +/// 4. Compute an upper bound of `minOp` and bind it to `minOpUb`. SSA values +/// that are used in `minOp` but are not part of `dims`, are added as extra +/// symbols to the constraint set. +/// 5. For each result of `minOp`: Add result as a dimension `r_i`. Prove that +/// r_i >= minOpUb. If this is the case, ub(minOp) == lb(minOp) and `minOp` +/// can be replaced with that bound. +/// +/// In summary, the following constraints are added throughout this function: +/// +/// dims | minOp | minOpUb | r_i | extra syms... | const | eq/ineq +/// ------+-------+---------+-----+---------------+-------+------------------- +/// (various eq./ineq. constraining `dims`, added by `constraintBuilder`) +/// ... | 0 | 0 | 0 | 0 | ... | ... +/// ------+-------+---------+-----+---------------+-------+------------------- +/// (various ineq. constraining `minOp` in terms of `minOp` operands (`dims` +/// and extra `minOp` operands "extra syms" that are not in `dims`)). +/// ... | -1 | 0 | 0 | ... | ... | >= 0 +/// ------+-------+---------+-----+---------------+-------+------------------- +/// (set `minOpUb` to `minOp` upper bound in terms of `dims` and "extra syms") +/// ... | 0 | -1 | 0 | ... | ... 
| = 0 +/// ------+-------+---------+-----+---------------+-------+------------------- +/// (for each `minOp` map result r_i: copy previous constraints, set r_i to +/// corresponding map result, prove r_i >= minOpUb via contradiction) +/// ... | 0 | 0 | -1 | ... | ... | = 0 +/// 0 | 0 | 1 | -1 | 0 | -1 | >= 0 +/// +static LogicalResult canonicalizeAffineMinOp( + RewriterBase &rewriter, AffineMinOp minOp, ValueRange dims, + function_ref constraintsBuilder) { + RewriterBase::InsertionGuard guard(rewriter); + AffineMap minOpMap = minOp.getAffineMap(); + unsigned numResults = minOpMap.getNumResults(); + FlatAffineConstraints constraints; + // Keep track of SSA values of dims (if any), so that affine maps can be + // aligned with the dims in `constraints`. + SmallVector dimValues, symValues; + + /// Add an SSA value as a dimension to the constraint system. If the SSA value + /// is a constant, set the dimension to the constant value. + auto addDim = [&](Value value = {}) { + unsigned dimId = constraints.getNumDimIds(); + constraints.addDimId(dimId); + if (auto constInt = getConstantIntValue(value)) + constraints.addBound(FlatAffineConstraints::EQ, dimId, *constInt); + dimValues.push_back(value); + return dimId; + }; + + // Set up constraint system and call builder. + for (Value value : dims) + addDim(value); + if (failed(constraintsBuilder(constraints))) + return failure(); + + // Add extra dimensions. + unsigned dimMinOp = addDim(); // `minOp` + unsigned dimMinOpUb = addDim(); // `minOp` upper bound + unsigned resultDimStart = constraints.getNumDimIds(); + for (unsigned i = 0; i < numResults; ++i) + addDim(); + + // Add an inequality for each result expr_i of minOpMap: minOp <= expr_i + if (failed(alignAndAddBound(constraints, FlatAffineConstraints::UB, dimMinOp, + minOpMap, minOp.operands(), dimValues, + symValues))) + return failure(); + + // Try to compute an upper bound for minOp, expressed in terms of the other + // `dims` and extra symbols. 
+ SmallVector minOpValLb(1), minOpValUb(1); + constraints.getSliceBounds(dimMinOp, 1, minOp.getContext(), &minOpValLb, + &minOpValUb); + // TODO: `getSliceBounds` may return multiple bounds at the moment. This is + // a TODO of `getSliceBounds` and not handled here. + if (!minOpValUb[0] || minOpValUb[0].getNumResults() != 1) + return failure(); // No or multiple upper bounds found. + + // Add an equality: dimMinOpUb = minOpValUb[0] + // Add back dimension for minOp. (Was removed by `getSliceBounds`.) + AffineMap alignedUbMap = minOpValUb[0].shiftDims(/*shift=*/1, + /*offset=*/dimMinOp); + if (failed(constraints.addBound(FlatAffineConstraints::EQ, dimMinOpUb, + alignedUbMap))) + return failure(); + + // If the constraint system is empty, there is an inconsistency. (E.g., this + // can happen if loop lb > ub.) + if (constraints.isEmpty()) + return failure(); + + // Prove that each result of minOpMap has a lower bound that is equal to (or + // greater than) the upper bound of minOp (`dimMinOpUb`). In that case, + // minOp can be replaced with the bound. I.e., prove that for each result + // expr_i (represented by dimension r_i): + // + // r_i >= minOpUb + // + // To prove this inequality, add its negation to the constraint set and prove + // that the constraint set is empty. + unsigned symValuesSize = symValues.size(); + (void)symValuesSize; + for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) { + assert(!constraints.isEmpty() && "constraint set should not be empty yet"); + + // Add an equality: r_i = expr_i + // Note: These equalities could have been added earlier and used to express + // minOp <= expr_i. However, then we run the risk that `getSliceBounds` + // computes minOpUb in terms of r_i dims, which is not desired. 
+ if (failed(alignAndAddBound(constraints, FlatAffineConstraints::EQ, i, + minOpMap.getSubMap({i - resultDimStart}), + minOp.operands(), dimValues, symValues))) + return failure(); + assert(symValues.size() == symValuesSize && "no new symbols expected"); + + // Add inequality: r_i < minOpUb (equiv.: minOpUb - r_i - 1 >= 0) + SmallVector ineq(constraints.getNumCols(), 0); + ineq[dimMinOpUb] = 1; + ineq[i] = -1; + ineq[constraints.getNumCols() - 1] = -1; + constraints.addInequality(ineq); + if (!constraints.isEmpty()) + return failure(); + + // Check each result one-by-one, so remove r_i < minOpUb ineq. again. + constraints.removeInequality(constraints.getNumInequalities() - 1); + } + + // Lower and upper bound of `minOp` are equal. Replace `minOp` with its upper + // bound. `dimValues` and `symValues` may have "empty" Values. These must be + // filtered out from the list of Values and from the affine map. + SmallVector newOperands; + SmallVector dimReplacements, symReplacements; + unsigned numDims = 0, numSyms = 0; + for (Value val : dimValues) { + if (val) { + newOperands.push_back(val); + dimReplacements.push_back(rewriter.getAffineDimExpr(numDims++)); + } else { + dimReplacements.push_back(rewriter.getAffineDimExpr(numDims)); + } + } + for (Value val : symValues) { + if (val) { + newOperands.push_back(val); + symReplacements.push_back(rewriter.getAffineSymbolExpr(numSyms++)); + } else { + symReplacements.push_back(rewriter.getAffineSymbolExpr(numSyms)); + } + } + AffineMap newMap = alignedUbMap.replaceDimsAndSymbols( + dimReplacements, symReplacements, numDims, numSyms); + rewriter.setInsertionPoint(minOp); + rewriter.replaceOpWithNewOp(minOp, newMap, newOperands); + return success(); +} + +/// Try to simplify an affine.min operation `minOp` after loop peeling. 
This +/// function detects affine.min operations such as (ub is the previous upper +/// bound of the unpeeled loop): +/// ``` +/// #map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)> +/// %r = affine.min #map(%iv)[%step, %ub] +/// ``` +/// and rewrites them into (in the case of the peeled loop): +/// ``` +/// %r = %step +/// ``` +/// affine.min operations inside the generated scf.if operation are rewritten in +/// a similar way. +/// +/// This function builds up a set of constraints, capable of proving that: +/// * Inside the peeled loop: min(step, ub - iv) == step +/// * Inside the scf.if operation: min(step, ub - iv) == ub - iv +/// +/// Note: `ub` is the previous upper bound of the loop (before peeling). +/// `insideLoop` must be true for affine.min ops inside the loop and false for +/// affine.min ops inside the scf.if op. +static LogicalResult rewritePeeledAffineOp(RewriterBase &rewriter, + AffineMinOp minOp, Value iv, + Value ub, Value step, + bool insideLoop) { + auto constraintsBuilder = [&](FlatAffineConstraints &constraints) { + // Add loop peeling invariant. This is the main piece of knowledge that + // enables AffineMinOp simplification. + if (insideLoop) { + // ub - iv >= step (equiv.: -iv + ub - step + 0 >= 0) + // Intuitively: Inside the peeled loop, every iteration is a "full" + // iteration, i.e., step divides the iteration space `ub - lb` evenly. + constraints.addInequality({-1, 1, -1, 0}); + } else { + // ub - iv < step (equiv.: iv + -ub + step - 1 >= 0) + // Intuitively: `iv` is the split bound here, i.e., the iteration variable + // value of the very last iteration (in the unpeeled loop). At that point, + // there are fewer than `step` elements remaining. (Otherwise, the peeled + // loop would run for at least one more iteration.) 
+ constraints.addInequality({1, -1, 1, -1}); + } + return success(); + }; + + return canonicalizeAffineMinOp(rewriter, minOp, + /*dims=*/ValueRange{iv, ub, step}, + constraintsBuilder); +} + +LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter, + ForOp forOp) { + Value ub = forOp.upperBound(); + scf::IfOp ifOp; + Value splitBound; + if (failed(peelForLoop(rewriter, forOp, ifOp, splitBound))) + return failure(); + + // Rewrite affine.min ops. + forOp.walk([&](AffineMinOp minOp) { + (void)rewritePeeledAffineOp(rewriter, minOp, forOp.getInductionVar(), ub, + forOp.step(), /*insideLoop=*/true); + }); + ifOp.walk([&](AffineMinOp minOp) { + (void)rewritePeeledAffineOp(rewriter, minOp, splitBound, ub, forOp.step(), + /*insideLoop=*/false); + }); + + return success(); +} + static constexpr char kPeeledLoopLabel[] = "__peeled_loop__"; namespace { @@ -163,15 +433,12 @@ PatternRewriter &rewriter) const override { if (forOp->hasAttr(kPeeledLoopLabel)) return failure(); - - scf::IfOp ifOp; - if (failed(peelForLoop(rewriter, forOp, ifOp))) + if (failed(peelAndCanonicalizeForLoop(rewriter, forOp))) return failure(); // Apply label, so that the same loop is not rewritten a second time. 
rewriter.updateRootInPlace(forOp, [&]() { forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr()); }); - return success(); } }; diff --git a/mlir/test/Dialect/SCF/for-loop-peeling.mlir b/mlir/test/Dialect/SCF/for-loop-peeling.mlir --- a/mlir/test/Dialect/SCF/for-loop-peeling.mlir +++ b/mlir/test/Dialect/SCF/for-loop-peeling.mlir @@ -1,22 +1,20 @@ // RUN: mlir-opt %s -for-loop-peeling -canonicalize -split-input-file | FileCheck %s // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s0, s2 - (s2 - (s2 - s1) mod s0))> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - (s1 - s0) mod s2))> // CHECK: func @fully_dynamic_bounds( // CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index // CHECK: %[[C0_I32:.*]] = constant 0 : i32 // CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[LB]], %[[UB]], %[[STEP]]] // CHECK: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[NEW_UB]] // CHECK-SAME: step %[[STEP]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) { -// CHECK: %[[MINOP:.*]] = affine.min #[[MAP1]](%[[IV]])[%[[STEP]], %[[UB]]] -// CHECK: %[[CAST:.*]] = index_cast %[[MINOP]] : index to i32 +// CHECK: %[[CAST:.*]] = index_cast %[[STEP]] : index to i32 // CHECK: %[[ADD:.*]] = addi %[[ACC]], %[[CAST]] : i32 // CHECK: scf.yield %[[ADD]] // CHECK: } // CHECK: %[[HAS_MORE:.*]] = cmpi slt, %[[NEW_UB]], %[[UB]] // CHECK: %[[RESULT:.*]] = scf.if %[[HAS_MORE]] -> (i32) { -// CHECK: %[[REM:.*]] = affine.min #[[MAP2]]()[%[[STEP]], %[[LB]], %[[UB]]] +// CHECK: %[[REM:.*]] = affine.apply #[[MAP1]]()[%[[LB]], %[[UB]], %[[STEP]]] // CHECK: %[[CAST2:.*]] = index_cast %[[REM]] // CHECK: %[[ADD2:.*]] = addi %[[LOOP]], %[[CAST2]] // CHECK: scf.yield %[[ADD2]] @@ -38,18 +36,16 @@ // ----- -// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0) -> (4, -d0 + 17)> // CHECK: func @fully_static_bounds( // 
CHECK-DAG: %[[C0_I32:.*]] = constant 0 : i32 // CHECK-DAG: %[[C1_I32:.*]] = constant 1 : i32 +// CHECK-DAG: %[[C4_I32:.*]] = constant 4 : i32 // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C4:.*]] = constant 4 : index // CHECK-DAG: %[[C16:.*]] = constant 16 : index // CHECK: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C16]] // CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) { -// CHECK: %[[MINOP:.*]] = affine.min #[[MAP]](%[[IV]]) -// CHECK: %[[CAST:.*]] = index_cast %[[MINOP]] : index to i32 -// CHECK: %[[ADD:.*]] = addi %[[ACC]], %[[CAST]] : i32 +// CHECK: %[[ADD:.*]] = addi %[[ACC]], %[[C4_I32]] : i32 // CHECK: scf.yield %[[ADD]] // CHECK: } // CHECK: %[[RESULT:.*]] = addi %[[LOOP]], %[[C1_I32]] : i32 @@ -73,24 +69,22 @@ // ----- // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> ((s0 floordiv 4) * 4)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (4, s0 mod 4)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 mod 4)> // CHECK: func @dynamic_upper_bound( // CHECK-SAME: %[[UB:.*]]: index // CHECK-DAG: %[[C0_I32:.*]] = constant 0 : i32 +// CHECK-DAG: %[[C4_I32:.*]] = constant 4 : i32 // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C4:.*]] = constant 4 : index // CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[UB]]] // CHECK: %[[LOOP:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[NEW_UB]] // CHECK-SAME: step %[[C4]] iter_args(%[[ACC:.*]] = %[[C0_I32]]) -> (i32) { -// CHECK: %[[MINOP:.*]] = affine.min #[[MAP1]](%[[IV]])[%[[UB]]] -// CHECK: %[[CAST:.*]] = index_cast %[[MINOP]] : index to i32 -// CHECK: %[[ADD:.*]] = addi %[[ACC]], %[[CAST]] : i32 +// CHECK: %[[ADD:.*]] = addi %[[ACC]], %[[C4_I32]] : i32 // CHECK: scf.yield %[[ADD]] // CHECK: } // CHECK: %[[HAS_MORE:.*]] = cmpi slt, %[[NEW_UB]], %[[UB]] // CHECK: %[[RESULT:.*]] = scf.if %[[HAS_MORE]] -> (i32) { -// CHECK: %[[REM:.*]] = affine.min #[[MAP2]]()[%[[UB]]] +// CHECK: 
%[[REM:.*]] = affine.apply #[[MAP1]]()[%[[UB]]] // CHECK: %[[CAST2:.*]] = index_cast %[[REM]] // CHECK: %[[ADD2:.*]] = addi %[[LOOP]], %[[CAST2]] // CHECK: scf.yield %[[ADD2]] @@ -116,23 +110,21 @@ // ----- // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> ((s0 floordiv 4) * 4)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (4, s0 mod 4)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 mod 4)> // CHECK: func @no_loop_results( // CHECK-SAME: %[[UB:.*]]: index, %[[MEMREF:.*]]: memref +// CHECK-DAG: %[[C4_I32:.*]] = constant 4 : i32 // CHECK-DAG: %[[C0:.*]] = constant 0 : index // CHECK-DAG: %[[C4:.*]] = constant 4 : index // CHECK: %[[NEW_UB:.*]] = affine.apply #[[MAP0]]()[%[[UB]]] // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[NEW_UB]] step %[[C4]] { -// CHECK: %[[MINOP:.*]] = affine.min #[[MAP1]](%[[IV]])[%[[UB]]] // CHECK: %[[LOAD:.*]] = memref.load %[[MEMREF]][] -// CHECK: %[[CAST:.*]] = index_cast %[[MINOP]] : index to i32 -// CHECK: %[[ADD:.*]] = addi %[[LOAD]], %[[CAST]] : i32 +// CHECK: %[[ADD:.*]] = addi %[[LOAD]], %[[C4_I32]] : i32 // CHECK: memref.store %[[ADD]], %[[MEMREF]] // CHECK: } // CHECK: %[[HAS_MORE:.*]] = cmpi slt, %[[NEW_UB]], %[[UB]] // CHECK: scf.if %[[HAS_MORE]] { -// CHECK: %[[REM:.*]] = affine.min #[[MAP2]]()[%[[UB]]] +// CHECK: %[[REM:.*]] = affine.apply #[[MAP1]]()[%[[UB]]] // CHECK: %[[LOAD2:.*]] = memref.load %[[MEMREF]][] // CHECK: %[[CAST2:.*]] = index_cast %[[REM]] // CHECK: %[[ADD2:.*]] = addi %[[LOAD2]], %[[CAST2]] @@ -153,3 +145,75 @@ } return } + +// ----- + +// Test rewriting of affine.min ops. Make sure that more general cases than +// the ones above are successfully rewritten. Also make sure that the pattern +// does not rewrite affine.min ops that should not be rewritten. 
+ +// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 1)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1 - 1)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0)[s0, s1, s2] -> (s0, -d0 + s1, s2)> +// CHECK-DAG: #[[MAP4:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - (s1 - s0) mod s2))> +// CHECK-DAG: #[[MAP5:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - (s1 - s0) mod s2) + 1)> +// CHECK-DAG: #[[MAP6:.*]] = affine_map<()[s0, s1, s2] -> (s1 - (s1 - (s1 - s0) mod s2) - 1)> +// CHECK-DAG: #[[MAP7:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0, s2 - (s2 - (s2 - s1) mod s0), s3)> +// CHECK: func @test_affine_min_rewrite( +// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, +// CHECK-SAME: %[[MEMREF:.*]]: memref, %[[SOME_VAL:.*]]: index +// CHECK: scf.for %[[IV:.*]] = %[[LB]] to %{{.*}} step %[[STEP]] { +// CHECK: %[[RES2:.*]] = affine.apply #[[MAP1]]()[%[[STEP]]] +// CHECK: %[[RES3:.*]] = affine.min #[[MAP2]](%[[IV]])[%[[STEP]], %[[UB]]] +// CHECK: %[[RES4:.*]] = affine.min #map3(%[[IV]])[%[[STEP]], %[[UB]], %[[SOME_VAL]]] +// CHECK: memref.store %[[STEP]] +// CHECK: memref.store %[[STEP]] +// CHECK: memref.store %[[RES2]] +// CHECK: memref.store %[[RES3]] +// CHECK: memref.store %[[RES4]] +// CHECK: } +// CHECK: scf.if {{.*}} { +// CHECK: %[[RES_IF_0:.*]] = affine.apply #[[MAP4]]()[%[[LB]], %[[UB]], %[[STEP]]] +// CHECK: %[[RES_IF_1:.*]] = affine.apply #[[MAP5]]()[%[[LB]], %[[UB]], %[[STEP]]] +// CHECK: %[[RES_IF_2:.*]] = affine.apply #[[MAP5]]()[%[[LB]], %[[UB]], %[[STEP]]] +// CHECK: %[[RES_IF_3:.*]] = affine.apply #[[MAP6]]()[%[[LB]], %[[UB]], %[[STEP]]] +// CHECK: %[[RES_IF_4:.*]] = affine.min #[[MAP7]]()[%[[STEP]], %[[LB]], %[[UB]], %[[SOME_VAL]]] +// CHECK: memref.store %[[RES_IF_0]] +// CHECK: memref.store %[[RES_IF_1]] +// CHECK: memref.store %[[RES_IF_2]] +// CHECK: memref.store %[[RES_IF_3]] +// CHECK: memref.store %[[RES_IF_4]] +#map0 = affine_map<(d0, d1)[s0] -> (s0, d0 - d1)> +#map1 = affine_map<(d0, d1)[s0] -> 
(d0 - d1 + 1, s0)> +#map2 = affine_map<(d0, d1)[s0] -> (s0 + 1, d0 - d1 + 1)> +#map3 = affine_map<(d0, d1)[s0] -> (s0, d0 - d1 - 1)> +#map4 = affine_map<(d0, d1, d2)[s0] -> (s0, d0 - d1, d2)> +func @test_affine_min_rewrite(%lb : index, %ub: index, + %step: index, %d : memref, + %some_val: index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + scf.for %iv = %lb to %ub step %step { + // Most common case: Rewrite min(%ub - %iv, %step) to %step. + %m0 = affine.min #map0(%ub, %iv)[%step] + // Increase %ub - %iv a little bit, pattern should still apply. + %m1 = affine.min #map1(%ub, %iv)[%step] + // Rewrite min(%ub - %iv + 1, %step + 1) to %step + 1. + %m2 = affine.min #map2(%ub, %iv)[%step] + // min(%ub - %iv - 1, %step) cannot be simplified because %ub - %iv - 1 + // can be smaller than %step. (Can be simplified in if-statement.) + %m3 = affine.min #map3(%ub, %iv)[%step] + // min(%ub - %iv, %step, %some_val) cannot be simplified because the range + // of %some_val is unknown. + %m4 = affine.min #map4(%ub, %iv, %some_val)[%step] + memref.store %m0, %d[%c0] : memref + memref.store %m1, %d[%c1] : memref + memref.store %m2, %d[%c2] : memref + memref.store %m3, %d[%c3] : memref + memref.store %m4, %d[%c4] : memref + } + return +}