Split the affine `reifyValueBound` API into `reifyIndexValueBound` (bounds of
index-typed values) and `reifyShapedValueDimBound` (bounds of one dimension of
a shaped value). Both take an optional `stopCondition`; when it is null, the
bound is reified in terms of the operands of the value's owner op. The former
implementation becomes a file-local `static reifyValueBound` helper, and the
callers in HoistPadding.cpp and TestReifyValueBounds.cpp are updated.

diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -49,20 +49,9 @@
 /// maximally compose chains of AffineApplyOps.
 FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
 
-/// Reify a bound for the given index-typed value or shape dimension size in
-/// terms of the owning op's operands. `dim` must be `nullopt` if and only if
-/// `value` is index-typed.
-///
-/// By default, lower/equal bounds are closed and upper bounds are open. If
-/// `closedUB` is set to "true", upper bounds are also closed.
-FailureOr<OpFoldResult> reifyValueBound(OpBuilder &b, Location loc,
-                                        presburger::BoundType type, Value value,
-                                        std::optional<int64_t> dim,
-                                        bool closedUB = false);
-
-/// Reify a bound for the given index-typed value or shape dimension size in
-/// terms of SSA values for which `stopCondition` is met. `dim` must be
-/// `nullopt` if and only if `value` is index-typed.
+/// Reify a bound for the given index-typed value in terms of SSA values for
+/// which `stopCondition` is met. If no stop condition is specified, reify in
+/// terms of the operands of the owner op.
 ///
 /// By default, lower/equal bounds are closed and upper bounds are open. If
 /// `closedUB` is set to "true", upper bounds are also closed.
@@ -77,11 +66,22 @@
 ///   is an EQ bound for %1.
 /// * Otherwise, if the owners of %a, %b or %c do not implement the
 ///   ValueBoundsOpInterface, no bound can be computed.
-FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
-                Value value, std::optional<int64_t> dim,
-                ValueBoundsConstraintSet::StopConditionFn stopCondition,
-                bool closedUB = false);
+FailureOr<OpFoldResult> reifyIndexValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+    ValueBoundsConstraintSet::StopConditionFn stopCondition = nullptr,
+    bool closedUB = false);
+
+/// Reify a bound for the specified dimension of the given shaped value in terms
+/// of SSA values for which `stopCondition` is met. If no stop condition is
+/// specified, reify in terms of the operands of the owner op.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult> reifyShapedValueDimBound(
+    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+    int64_t dim,
+    ValueBoundsConstraintSet::StopConditionFn stopCondition = nullptr,
+    bool closedUB = false);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -15,25 +15,11 @@
 
 using namespace mlir;
 
-FailureOr<OpFoldResult>
-mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
-                      Value value, std::optional<int64_t> dim, bool closedUB) {
-  // We are trying to reify a bound for `value`. Construct a stop condition that
-  // evaluates to "true" for any SSA value expect for `value`. I.e., the bound
-  // will be computed in terms of any SSA values except for `value`. The first
-  // such values are operands of the owner of `value`.
-  auto stopCondition = [&](Value v, std::optional<int64_t> d) {
-    // Reify in terms of SSA values that are different from `value`.
-    return v != value;
-  };
-  return reifyValueBound(b, loc, type, value, dim, stopCondition, closedUB);
-}
-
-FailureOr<OpFoldResult> mlir::reifyValueBound(
-    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
-    std::optional<int64_t> dim,
-    function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
-    bool closedUB) {
+static FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+                Value value, std::optional<int64_t> dim,
+                function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
+                bool closedUB) {
   // Compute bound.
   AffineMap boundMap;
   ValueDimList mapOperands;
@@ -85,3 +71,31 @@
   return static_cast<OpFoldResult>(
       b.create<AffineApplyOp>(loc, boundMap, operands).getResult());
 }
+
+FailureOr<OpFoldResult> mlir::reifyShapedValueDimBound(
+    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+    int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
+    bool closedUB) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+    // We are trying to reify a bound for `value` in terms of the owning op's
+    // operands. Construct a stop condition that evaluates to "true" for any SSA
+    // value except for `value`. I.e., the bound will be computed in terms of
+    // any SSA values except for `value`. The first such values are operands of
+    // the owner of `value`.
+    return v != value;
+  };
+  return reifyValueBound(b, loc, type, value, dim,
+                         stopCondition ? stopCondition : reifyToOperands,
+                         closedUB);
+}
+
+FailureOr<OpFoldResult> mlir::reifyIndexValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+    ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+    return v != value;
+  };
+  return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+                         stopCondition ? stopCondition : reifyToOperands,
+                         closedUB);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -462,9 +462,9 @@
   // of the enclosing loops.
   for (auto forOp : packingLoops) {
     // Compute an upper bound `ubVal` for the upper bound of `forOp`.
-    FailureOr<OpFoldResult> loopUb = reifyValueBound(
+    FailureOr<OpFoldResult> loopUb = reifyIndexValueBound(
         rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
-        /*dim=*/std::nullopt, /*stopCondition=*/
+        /*stopCondition=*/
         [&](Value v, std::optional<int64_t> d) {
           if (v == forOp.getUpperBound())
             return false;
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -130,8 +130,13 @@
           reified =
               FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
       } else {
-        reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
-                                  dim, stopCondition);
+        if (dim) {
+          reified = reifyShapedValueDimBound(rewriter, op->getLoc(), *boundType,
+                                             value, *dim, stopCondition);
+        } else {
+          reified = reifyIndexValueBound(rewriter, op->getLoc(), *boundType,
+                                         value, stopCondition);
+        }
       }
       if (failed(reified)) {
         op->emitOpError("could not reify bound");