diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -237,10 +237,16 @@
   /// Adds a lower or an upper bound for the identifier at the specified
   /// position with constraints being drawn from the specified bound map and
   /// operands. If `eq` is true, add a single equality equal to the bound map's
-  /// first result expr.
+  /// first result expr. By default, the bound map is fully composed with the
+  /// operands before adding any bounds. This allows for bounds being expressed
+  /// in terms of values that are used by the operands. If
+  /// `composeMapAndOperands` is false, the map and operands are not composed
+  /// (only simplified and canonicalized), and operands are added as
+  /// identifiers even if they are not terminal symbols or loop IVs.
   LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
                                      ValueRange operands, bool eq,
-                                     bool lower = true);
+                                     bool lower = true,
+                                     bool composeMapAndOperands = true);
 
   /// Returns the bound for the identifier at `pos` from the inequality at
   /// `ineqPos` as a 1-d affine value map (affine map + operands). The returned
@@ -334,7 +340,10 @@
   /// symbols or loop IVs. The identifier is added to the end of the existing
   /// dims or symbols. Additional information on the identifier is extracted
   /// from the IR and added to the constraint system.
-  void addInductionVarOrTerminalSymbol(Value id);
+  /// If `allowNonTerminal` is true, `id` is not required to be a terminal
+  /// symbol or a loop IV; any symbol (incl. potentially non-terminal ones) is
+  /// allowed.
+  void addInductionVarOrTerminalSymbol(Value id, bool allowNonTerminal = false);
 
   /// Composes the affine value map with this FlatAffineConstrains, adding the
   /// results of the map as dimensions at the front [0, vMap->getNumResults())
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -560,12 +560,14 @@
   }
 }
 
-void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
+void FlatAffineConstraints::addInductionVarOrTerminalSymbol(
+    Value id, bool allowNonTerminal) {
   if (containsId(id))
     return;
 
-  // Caller is expected to fully compose map/operands if necessary.
-  assert((isTopLevelValue(id) || isForInductionVar(id)) &&
+  // Unless `allowNonTerminal` is set, the caller is expected to fully compose
+  // map/operands if necessary so that `id` is a terminal symbol or loop IV.
+  assert((allowNonTerminal || isTopLevelValue(id) || isForInductionVar(id)) &&
          "non-terminal symbol / loop IV expected");
   // Outer loop IVs could be used in forOp's bounds.
   if (auto loop = getForInductionVarOwner(id)) {
@@ -1941,10 +1943,9 @@
   }
 }
 
-LogicalResult
-FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
-                                            ValueRange boundOperands, bool eq,
-                                            bool lower) {
+LogicalResult FlatAffineConstraints::addLowerOrUpperBound(
+    unsigned pos, AffineMap boundMap, ValueRange boundOperands, bool eq,
+    bool lower, bool composeMapAndOperands) {
   assert(pos < getNumDimAndSymbolIds() && "invalid position");
   // Equality follows the logic of lower bound except that we add an equality
   // instead of an inequality.
@@ -1956,11 +1957,15 @@
   // transitively get to terminal symbols or loop IVs.
   auto map = boundMap;
   SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
-  fullyComposeAffineMapAndOperands(&map, &operands);
+  // Without composition, the operands are used as given and may include
+  // non-terminal values; those are then added as identifiers as well.
+  if (composeMapAndOperands)
+    fullyComposeAffineMapAndOperands(&map, &operands);
   map = simplifyAffineMap(map);
   canonicalizeMapAndOperands(&map, &operands);
   for (auto operand : operands)
-    addInductionVarOrTerminalSymbol(operand);
+    addInductionVarOrTerminalSymbol(
+        operand, /*allowNonTerminal=*/!composeMapAndOperands);
 
   FlatAffineConstraints localVarCst;
   std::vector<SmallVector<int64_t, 8>> flatExprs;