diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h @@ -27,6 +27,7 @@ namespace mlir { class AffineForOp; +class AffineValueMap; class Block; class Location; struct MemRefAccess; @@ -384,6 +385,13 @@ ArrayRef ops, SmallVectorImpl *surroundingLoops = nullptr); +/// Try to simplify the given affine.min or affine.max op to a single-result +/// affine map plus operands, taking the given constraint set into account. +/// `constraints` is taken by value. Return failure if no simplification found. +FailureOr +simplifyConstrainedMinMaxOp(Operation *op, + FlatAffineValueConstraints constraints); + } // namespace mlir #endif // MLIR_DIALECT_AFFINE_ANALYSIS_UTILS_H diff --git a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h @@ -39,70 +39,23 @@ using LoopMatcherFn = function_ref; -/// Try to canonicalize an min/max operations in the context of for `loops` with -/// a known range. +/// Try to canonicalize the given affine.min/max operation in the context of +/// `for` loops with a known range. /// -/// `map` is the body of the min/max operation and `operands` are the SSA values -/// that the dimensions and symbols are bound to; dimensions are listed first. -/// If `isMin`, the operation is a min operation; otherwise, a max operation. /// `loopMatcher` is used to retrieve loop bounds and the step size for a given /// iteration variable. /// /// Note: `loopMatcher` allows this function to be used with any "for loop"-like /// operation (scf.for, scf.parallel and even ops defined in other dialects).
LogicalResult canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, Operation *op, - AffineMap map, ValueRange operands, - bool isMin, LoopMatcherFn loopMatcher); + LoopMatcherFn loopMatcher); -/// Attempt to canonicalize min/max operations by proving that their value is -/// bounded by the same lower and upper bound. In such cases, the operation can -/// be folded away. -/// -/// Bounds are computed by FlatAffineValueConstraints. Invariants required for -/// finding/proving bounds should be supplied via `constraints`. -/// -/// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`). -/// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in -/// case of `!isMin`) and bind it to `opBound`. SSA values that are used in -/// `op` but are not part of `constraints`, are added as extra symbols. -/// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that: -/// * If `isMin`: r_i >= opBound -/// * If `isMax`: r_i <= opBound -/// If this is the case, ub(op) == lb(op). -/// 4. Replace `op` with `opBound`. -/// -/// In summary, the following constraints are added throughout this function. -/// Note: `invar` are dimensions added by the caller to express the invariants. -/// (Showing only the case where `isMin`.) -/// -/// invar | op | opBound | r_i | extra syms... | const | eq/ineq -/// ------+-------+---------+-----+---------------+-------+------------------- -/// (various eq./ineq. constraining `invar`, added by the caller) -/// ... | 0 | 0 | 0 | 0 | ... | ... -/// ------+-------+---------+-----+---------------+-------+------------------- -/// (various ineq. constraining `op` in terms of `op` operands (`invar` and -/// extra `op` operands "extra syms" that are not in `invar`)). -/// ... | -1 | 0 | 0 | ... | ... | >= 0 -/// ------+-------+---------+-----+---------------+-------+------------------- -/// (set `opBound` to `op` upper bound in terms of `invar` and "extra syms") -/// ... | 0 | -1 | 0 | ... | ...
| = 0 -/// ------+-------+---------+-----+---------------+-------+------------------- -/// (for each `op` map result r_i: set r_i to corresponding map result, -/// prove that r_i >= minOpUb via contradiction) -/// ... | 0 | 0 | -1 | ... | ... | = 0 -/// 0 | 0 | 1 | -1 | 0 | -1 | >= 0 -/// -FailureOr -canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map, - ValueRange operands, bool isMin, - FlatAffineValueConstraints constraints); - -/// Try to simplify a min/max operation `op` after loop peeling. This function -/// can simplify min/max operations such as (ub is the previous upper bound of -/// the unpeeled loop): +/// Try to simplify the given affine.min/max operation `op` after loop peeling. +/// This function can simplify min/max operations such as (ub is the previous +/// upper bound of the unpeeled loop): /// ``` /// #map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)> -/// %r = affine.min #affine.min #map(%iv)[%step, %ub] +/// %r = affine.min #map(%iv)[%step, %ub] /// ``` /// and rewrites them into (in the case the peeled loop): /// ``` @@ -111,8 +64,7 @@ /// min/max operations inside the partial iteration are rewritten in a similar /// way.
LogicalResult rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op, - AffineMap map, ValueRange operands, - bool isMin, Value iv, Value ub, Value step, + Value iv, Value ub, Value step, bool insideLoop); } // namespace scf diff --git a/mlir/lib/Dialect/Affine/Analysis/CMakeLists.txt b/mlir/lib/Dialect/Affine/Analysis/CMakeLists.txt --- a/mlir/lib/Dialect/Affine/Analysis/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Analysis/CMakeLists.txt @@ -13,6 +13,7 @@ MLIRAnalysis MLIRCallInterfaces MLIRControlFlowInterfaces + MLIRDialectUtils MLIRInferTypeOpInterface MLIRSideEffectInterfaces MLIRPresburger diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/IntegerSet.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" @@ -1362,3 +1363,184 @@ assert(simplifiedSet && "guaranteed to succeed while roundtripping"); return simplifiedSet; } + +static void unpackOptionalValues(ArrayRef> source, + SmallVector &target) { + target = llvm::to_vector<4>(llvm::map_range(source, [](Optional val) { + return val.has_value() ? *val : Value(); + })); +} + +/// Bound the variable at position `pos` in a FlatAffineValueConstraints with +/// constraints drawn from an affine map. Before adding the constraint, the +/// dimensions/symbols of the affine map are aligned with `constraints`. +/// `operands` are the SSA Value operands used with the affine map. +/// Note: This function adds a new symbol column to the `constraints` for each +/// dimension/symbol that exists in the affine map but not in `constraints`.
+static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints, + IntegerPolyhedron::BoundType type, + unsigned pos, AffineMap map, + ValueRange operands) { + SmallVector dims, syms, newSyms; + unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims); + unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms); + + AffineMap alignedMap = + alignAffineMapWithValues(map, operands, dims, syms, &newSyms); + for (unsigned i = syms.size(); i < newSyms.size(); ++i) + constraints.appendSymbolVar(newSyms[i]); + return constraints.addBound(type, pos, alignedMap); +} + +/// Add `val` to each result of `map`. +static AffineMap addConstToResults(AffineMap map, int64_t val) { + SmallVector newResults; + for (AffineExpr r : map.getResults()) + newResults.push_back(r + val); + return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults, + map.getContext()); +} + +// Attempt to simplify the given min/max operation by proving that its value is +// bounded by the same lower and upper bound. +// +// Bounds are computed by FlatAffineValueConstraints. Invariants required for +// finding/proving bounds should be supplied via `constraints`. +// +// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`). +// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in +// case of `!isMin`) and bind it to `opBound`. SSA values that are used in +// `op` but are not part of `constraints`, are added as extra symbols. +// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that: +// * If `isMin`: r_i >= opBound +// * If `isMax`: r_i <= opBound +// If this is the case, ub(op) == lb(op). +// 4. Return `opBound` as the simplified value map of `op`. +// +// In summary, the following constraints are added throughout this function. +// Note: `invar` are dimensions added by the caller to express the invariants. +// (Showing only the case where `isMin`.) +// +// invar | op | opBound | r_i | extra syms...
| const | eq/ineq +// ------+-------+---------+-----+---------------+-------+------------------- +// (various eq./ineq. constraining `invar`, added by the caller) +// ... | 0 | 0 | 0 | 0 | ... | ... +// ------+-------+---------+-----+---------------+-------+------------------- +// (various ineq. constraining `op` in terms of `op` operands (`invar` and +// extra `op` operands "extra syms" that are not in `invar`)). +// ... | -1 | 0 | 0 | ... | ... | >= 0 +// ------+-------+---------+-----+---------------+-------+------------------- +// (set `opBound` to `op` upper bound in terms of `invar` and "extra syms") +// ... | 0 | -1 | 0 | ... | ... | = 0 +// ------+-------+---------+-----+---------------+-------+------------------- +// (for each `op` map result r_i: set r_i to corresponding map result, +// prove that r_i >= opBound via contradiction) +// ... | 0 | 0 | -1 | ... | ... | = 0 +// 0 | 0 | 1 | -1 | 0 | -1 | >= 0 +// +FailureOr +mlir::simplifyConstrainedMinMaxOp(Operation *op, + FlatAffineValueConstraints constraints) { + bool isMin = isa(op); + assert((isMin || isa(op)) && "expect AffineMin/MaxOp"); + MLIRContext *ctx = op->getContext(); + Builder builder(ctx); + AffineMap map = + isMin ? cast(op).getMap() : cast(op).getMap(); + ValueRange operands = op->getOperands(); + unsigned numResults = map.getNumResults(); + + // Add a few extra dimensions. + unsigned dimOp = constraints.appendDimVar(); // `op` + unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound + unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults); + + // Add an inequality for each result expr_i of map: + // isMin: op <= expr_i, !isMin: op >= expr_i + auto boundType = isMin ? IntegerPolyhedron::UB : IntegerPolyhedron::LB; + // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.) + AffineMap mapLbUb = isMin ?
addConstToResults(map, 1) : map; + if (failed( + alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands))) + return failure(); + + // Try to compute a lower/upper bound for op, expressed in terms of the other + // `dims` and extra symbols. + SmallVector opLb(1), opUb(1); + constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb); + AffineMap sliceBound = isMin ? opUb[0] : opLb[0]; + // TODO: `getSliceBounds` may return multiple bounds at the moment. This is + // a TODO of `getSliceBounds` and not handled here. + if (!sliceBound || sliceBound.getNumResults() != 1) + return failure(); // No or multiple bounds found. + // Recover the inclusive UB in the case of an `affine.min`. + AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound; + + // Add an equality: Set dimOpBound to computed bound. + // Add back dimension for op. (Was removed by `getSliceBounds`.) + AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp); + if (failed(constraints.addBound(IntegerPolyhedron::EQ, dimOpBound, + alignedBoundMap))) + return failure(); + + // If the constraint system is empty, there is an inconsistency. (E.g., this + // can happen if loop lb > ub.) + if (constraints.isEmpty()) + return failure(); + + // In the case of `isMin` (`!isMin` is inversed): + // Prove that each result of `map` has a lower bound that is equal to (or + // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op` + // can be replaced with the bound. I.e., prove that for each result + // expr_i (represented by dimension r_i): + // + // r_i >= opBound + // + // To prove this inequality, add its negation to the constraint set and prove + // that the constraint set is empty. + for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) { + FlatAffineValueConstraints newConstr(constraints); + + // Add an equality: r_i = expr_i + // Note: These equalities could have been added earlier and used to express + // op <= expr_i.
However, then we run the risk that `getSliceBounds` + // computes opBound in terms of r_i dims, which is not desired. + if (failed(alignAndAddBound(newConstr, IntegerPolyhedron::EQ, i, + map.getSubMap({i - resultDimStart}), operands))) + return failure(); + + // If `isMin`: Add inequality: r_i < opBound + // equiv.: opBound - r_i - 1 >= 0 + // If `!isMin`: Add inequality: r_i > opBound + // equiv.: -opBound + r_i - 1 >= 0 + SmallVector ineq(newConstr.getNumCols(), 0); + ineq[dimOpBound] = isMin ? 1 : -1; + ineq[i] = isMin ? -1 : 1; + ineq[newConstr.getNumCols() - 1] = -1; + newConstr.addInequality(ineq); + if (!newConstr.isEmpty()) + return failure(); + } + + // Lower and upper bound of `op` are equal. Return that bound. + AffineMap newMap = alignedBoundMap; + SmallVector newOperands; + unpackOptionalValues(constraints.getMaybeValues(), newOperands); + // If dims/symbols have known constant values, use those to simplify the map. + // Only dim/symbol vars carry SSA values and map inputs; skip local vars. + for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) { + // Skip unused operands and operands that are already constants. + if (!newOperands[i] || getConstantIntValue(newOperands[i])) + continue; + if (auto bound = constraints.getConstantBound64(IntegerPolyhedron::EQ, i)) { + AffineExpr expr = + i < newMap.getNumDims() + ?
builder.getAffineDimExpr(i) + : builder.getAffineSymbolExpr(i - newMap.getNumDims()); + newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound), + newMap.getNumDims(), newMap.getNumSymbols()); + } + } + mlir::canonicalizeMapAndOperands(&newMap, &newOperands); + return AffineValueMap(newMap, newOperands); +} diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -152,7 +152,7 @@ /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for /// and scf.parallel loops with a known range. -template +template struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -192,8 +192,7 @@ return failure(); }; - return scf::canonicalizeMinMaxOpInLoop( - rewriter, op, op.getAffineMap(), op.getOperands(), IsMin, loopMatcher); + return scf::canonicalizeMinMaxOpInLoop(rewriter, op, loopMatcher); } }; @@ -214,8 +213,8 @@ RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); patterns - .add, - AffineOpSCFCanonicalizationPattern, + .add, + AffineOpSCFCanonicalizationPattern, DimOfIterArgFolder, DimOfIterArgFolder, DimOfLoopResultFolder, DimOfLoopResultFolder>(ctx); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -154,7 +154,6 @@ return success(); } -template static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp, ForOp partialIteration, Value previousUb) { @@ -164,18 +163,20 @@ "expected same step in main and partial loop"); Value step = forOp.getStep(); - forOp.walk([&](OpTy affineOp) { - AffineMap map = affineOp.getAffineMap(); -
(void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, - affineOp.getOperands(), IsMin, mainIv, - previousUb, step, + forOp.walk([&](Operation *affineOp) { + if (!isa(affineOp)) + return WalkResult::advance(); + (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, mainIv, previousUb, + step, /*insideLoop=*/true); + return WalkResult::advance(); }); - partialIteration.walk([&](OpTy affineOp) { - AffineMap map = affineOp.getAffineMap(); - (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, - affineOp.getOperands(), IsMin, partialIv, - previousUb, step, /*insideLoop=*/false); + partialIteration.walk([&](Operation *affineOp) { + if (!isa(affineOp)) + return WalkResult::advance(); + (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb, + step, /*insideLoop=*/false); + return WalkResult::advance(); }); } @@ -188,10 +189,7 @@ return failure(); // Rewrite affine.min and affine.max ops. - rewriteAffineOpAfterPeeling( - rewriter, forOp, partialIteration, previousUb); - rewriteAffineOpAfterPeeling( - rewriter, forOp, partialIteration, previousUb); + rewriteAffineOpAfterPeeling(rewriter, forOp, partialIteration, previousUb); return success(); } diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp --- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp @@ -12,7 +12,9 @@ #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" @@ -25,139 +27,17 @@ using namespace mlir; using namespace presburger; -static void unpackOptionalValues(ArrayRef> source, - SmallVector &target) { -
target = llvm::to_vector<4>(llvm::map_range(source, [](Optional val) { - return val.has_value() ? *val : Value(); - })); -} - -/// Bound an identifier `pos` in a given FlatAffineValueConstraints with -/// constraints drawn from an affine map. Before adding the constraint, the -/// dimensions/symbols of the affine map are aligned with `constraints`. -/// `operands` are the SSA Value operands used with the affine map. -/// Note: This function adds a new symbol column to the `constraints` for each -/// dimension/symbol that exists in the affine map but not in `constraints`. -static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints, - IntegerPolyhedron::BoundType type, - unsigned pos, AffineMap map, - ValueRange operands) { - SmallVector dims, syms, newSyms; - unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims); - unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms); - - AffineMap alignedMap = - alignAffineMapWithValues(map, operands, dims, syms, &newSyms); - for (unsigned i = syms.size(); i < newSyms.size(); ++i) - constraints.appendSymbolVar(newSyms[i]); - return constraints.addBound(type, pos, alignedMap); -} - -/// Add `val` to each result of `map`. -static AffineMap addConstToResults(AffineMap map, int64_t val) { - SmallVector newResults; - for (AffineExpr r : map.getResults()) - newResults.push_back(r + val); - return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults, - map.getContext()); -} - -FailureOr -scf::canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map, - ValueRange operands, bool isMin, - FlatAffineValueConstraints constraints) { +static FailureOr +canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, + FlatAffineValueConstraints constraints) { RewriterBase::InsertionGuard guard(rewriter); - unsigned numResults = map.getNumResults(); - - // Add a few extra dimensions.
- unsigned dimOp = constraints.appendDimVar(); // `op` - unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound - unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults); - - // Add an inequality for each result expr_i of map: - // isMin: op <= expr_i, !isMin: op >= expr_i - auto boundType = isMin ? IntegerPolyhedron::UB : IntegerPolyhedron::LB; - // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.) - AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map; - if (failed( - alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands))) - return failure(); - - // Try to compute a lower/upper bound for op, expressed in terms of the other - // `dims` and extra symbols. - SmallVector opLb(1), opUb(1); - constraints.getSliceBounds(dimOp, 1, rewriter.getContext(), &opLb, &opUb); - AffineMap sliceBound = isMin ? opUb[0] : opLb[0]; - // TODO: `getSliceBounds` may return multiple bounds at the moment. This is - // a TODO of `getSliceBounds` and not handled here. - if (!sliceBound || sliceBound.getNumResults() != 1) - return failure(); // No or multiple bounds found. - // Recover the inclusive UB in the case of an `affine.min`. - AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound; - - // Add an equality: Set dimOpBound to computed bound. - // Add back dimension for op. (Was removed by `getSliceBounds`.) - AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp); - if (failed(constraints.addBound(IntegerPolyhedron::EQ, dimOpBound, - alignedBoundMap))) - return failure(); - - // If the constraint system is empty, there is an inconsistency. (E.g., this - // can happen if loop lb > ub.) - if (constraints.isEmpty()) - return failure(); - - // In the case of `isMin` (`!isMin` is inversed): - // Prove that each result of `map` has a lower bound that is equal to (or - // greater than) the upper bound of `op` (`dimOpBound`).
In that case, `op` - // can be replaced with the bound. I.e., prove that for each result - // expr_i (represented by dimension r_i): - // - // r_i >= opBound - // - // To prove this inequality, add its negation to the constraint set and prove - // that the constraint set is empty. - for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) { - FlatAffineValueConstraints newConstr(constraints); - - // Add an equality: r_i = expr_i - // Note: These equalities could have been added earlier and used to express - // minOp <= expr_i. However, then we run the risk that `getSliceBounds` - // computes minOpUb in terms of r_i dims, which is not desired. - if (failed(alignAndAddBound(newConstr, IntegerPolyhedron::EQ, i, - map.getSubMap({i - resultDimStart}), operands))) - return failure(); - - // If `isMin`: Add inequality: r_i < opBound - // equiv.: opBound - r_i - 1 >= 0 - // If `!isMin`: Add inequality: r_i > opBound - // equiv.: -opBound + r_i - 1 >= 0 - SmallVector ineq(newConstr.getNumCols(), 0); - ineq[dimOpBound] = isMin ? 1 : -1; - ineq[i] = isMin ? -1 : 1; - ineq[newConstr.getNumCols() - 1] = -1; - newConstr.addInequality(ineq); - if (!newConstr.isEmpty()) - return failure(); - } - - // Lower and upper bound of `op` are equal. Replace `minOp` with its bound. - AffineMap newMap = alignedBoundMap; - SmallVector newOperands; - unpackOptionalValues(constraints.getMaybeValues(), newOperands); - // If dims/symbols have known constant values, use those in order to simplify - // the affine map further. - for (int64_t i = 0, e = constraints.getNumVars(); i < e; ++i) { - // Skip unused operands and operands that are already constants.
- if (!newOperands[i] || getConstantIntValue(newOperands[i])) - continue; - if (auto bound = constraints.getConstantBound64(IntegerPolyhedron::EQ, i)) - newOperands[i] = - rewriter.create(op->getLoc(), *bound); - } - mlir::canonicalizeMapAndOperands(&newMap, &newOperands); rewriter.setInsertionPoint(op); - return rewriter.replaceOpWithNewOp(op, newMap, newOperands); + FailureOr simplified = + mlir::simplifyConstrainedMinMaxOp(op, constraints); + if (failed(simplified)) + return failure(); + return rewriter.replaceOpWithNewOp( + op, simplified->getAffineMap(), simplified->getOperands()); } static LogicalResult @@ -231,14 +111,13 @@ /// Note: Due to limitations of IntegerPolyhedron, only constant step sizes /// are currently supported. LogicalResult scf::canonicalizeMinMaxOpInLoop(RewriterBase &rewriter, - Operation *op, AffineMap map, - ValueRange operands, bool isMin, + Operation *op, LoopMatcherFn loopMatcher) { FlatAffineValueConstraints constraints; DenseSet allIvs; - // Find all iteration variables among `minOp`'s operands add constrain them. + // Find all iteration variables among `op`'s operands and constrain them. - for (Value operand : operands) { + for (Value operand : op->getOperands()) { // Skip duplicate ivs. if (llvm::is_contained(allIvs, operand)) continue; @@ -256,12 +135,12 @@ return failure(); } - return canonicalizeMinMaxOp(rewriter, op, map, operands, isMin, constraints); + return canonicalizeMinMaxOp(rewriter, op, constraints); } -/// Try to simplify a min/max operation `op` after loop peeling. This function -/// can simplify min/max operations such as (ub is the previous upper bound of -/// the unpeeled loop): +/// Try to simplify the given affine.min/max operation `op` after loop peeling. +/// This function can simplify min/max operations such as (ub is the previous +/// upper bound of the unpeeled loop): /// ``` /// #map = affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)> -/// %r = affine.min #affine.min #map(%iv)[%step, %ub] +/// %r = affine.min #map(%iv)[%step, %ub] @@ -285,9 +164,8 @@ /// affine.min ops inside the partial iteration.
For an explanation of the other /// parameters, see comment of `canonicalizeMinMaxOpInLoop`. LogicalResult scf::rewritePeeledMinMaxOp(RewriterBase &rewriter, Operation *op, - AffineMap map, ValueRange operands, - bool isMin, Value iv, Value ub, - Value step, bool insideLoop) { + Value iv, Value ub, Value step, + bool insideLoop) { FlatAffineValueConstraints constraints; constraints.appendDimVar({iv, ub, step}); if (auto constUb = getConstantIntValue(ub)) @@ -311,5 +189,5 @@ constraints.addInequality({1, -1, 1, -1}); } - return canonicalizeMinMaxOp(rewriter, op, map, operands, isMin, constraints); + return canonicalizeMinMaxOp(rewriter, op, constraints); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2641,6 +2641,7 @@ ":AffineDialect", ":Analysis", ":ArithDialect", + ":DialectUtils", ":FuncDialect", ":IR", ":SideEffectInterfaces",