diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -237,10 +237,13 @@ /// Adds a lower or an upper bound for the identifier at the specified /// position with constraints being drawn from the specified bound map and /// operands. If `eq` is true, add a single equality equal to the bound map's - /// first result expr. + /// first result expr. If `composeMapAndOperands` is true (the default), the + /// bound map is fully composed with the operands before adding any bounds, so + /// bounds may be expressed in terms of values that are used by the operands. LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap, ValueRange operands, bool eq, - bool lower = true); + bool lower = true, + bool composeMapAndOperands = true); /// Returns the bound for the identifier at `pos` from the inequality at /// `ineqPos` as a 1-d affine value map (affine map + operands). The returned diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1941,10 +1941,9 @@ } } -LogicalResult -FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, - ValueRange boundOperands, bool eq, - bool lower) { +LogicalResult FlatAffineConstraints::addLowerOrUpperBound( + unsigned pos, AffineMap boundMap, ValueRange boundOperands, bool eq, + bool lower, bool composeMapAndOperands) { assert(pos < getNumDimAndSymbolIds() && "invalid position"); // Equality follows the logic of lower bound except that we add an equality // instead of an inequality. @@ -1956,7 +1955,8 @@ // transitively get to terminal symbols or loop IVs. 
auto map = boundMap; SmallVector operands(boundOperands.begin(), boundOperands.end()); - fullyComposeAffineMapAndOperands(&map, &operands); + if (composeMapAndOperands) + fullyComposeAffineMapAndOperands(&map, &operands); map = simplifyAffineMap(map); canonicalizeMapAndOperands(&map, &operands); for (auto operand : operands)