diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -26,6 +26,7 @@
 class RewritePatternSet;
 class RewriterBase;
 class Value;
+class ValueBoundsConstraintSet;
 
 namespace presburger {
 enum class BoundType;
@@ -71,7 +72,9 @@
 FailureOr<OpFoldResult>
 reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
                 Value value, std::optional<int64_t> dim,
-                function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
+                function_ref<bool(Value, std::optional<int64_t>,
+                                  const ValueBoundsConstraintSet &)>
+                    stopCondition,
                 bool closedUB = false);
 
 /// Reify an already computed bound with Affine dialect ops.
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h
@@ -9,10 +9,23 @@
 #ifndef MLIR_DIALECT_ARITH_IR_VALUEBOUNDSOPINTERFACEIMPL_H
 #define MLIR_DIALECT_ARITH_IR_VALUEBOUNDSOPINTERFACEIMPL_H
 
+#include <cstdint>
+#include <optional>
+
 namespace mlir {
 class DialectRegistry;
+class Value;
+class ValueBoundsConstraintSet;
 
 namespace arith {
+/// Populate bounds for an arith.select-like op: `value` (or dimension `dim`
+/// of `value`, if `dim` is set) is the result of an op that evaluates to
+/// `trueValue` if `condition` is "true" and to `falseValue` otherwise. The
+/// bounds are added to `cstr`.
+void populateSelectLikeBounds(Value value, std::optional<int64_t> dim,
+                              Value condition, Value trueValue,
+                              Value falseValue, ValueBoundsConstraintSet &cstr);
+
 void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
 } // namespace arith
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
@@ -18,6 +18,7 @@
 class OpFoldResult;
 class Type;
 class Value;
+class ValueBoundsConstraintSet;
 
 namespace presburger {
 enum class BoundType;
@@ -54,7 +55,9 @@
 FailureOr<OpFoldResult>
 reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
                 Value value, std::optional<int64_t> dim,
-                function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
+                function_ref<bool(Value, std::optional<int64_t>,
+                                  const ValueBoundsConstraintSet &cstr)>
+                    stopCondition,
                 bool closedUB = false, Type resultType = {});
 
 } // namespace arith
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -20,6 +20,8 @@
 
 namespace mlir {
 
+class ValueBoundsConstraintSet;
+
 using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
 
 /// A helper class to be used with `ValueBoundsOpInterface`. This class stores a
@@ -43,7 +45,11 @@
   /// The stop condition when traversing the backward slice of a shaped value/
   /// integer-type value. The traversal continues until the stop condition
   /// evaluates to "true" for a value.
-  using StopConditionFn = function_ref<bool(Value, std::optional<int64_t>)>;
+  ///
+  /// The constraint set that is currently being populated is passed to the
+  /// stop condition as the third argument.
+  using StopConditionFn = function_ref<bool(
+      Value, std::optional<int64_t>, const ValueBoundsConstraintSet &cstr)>;
 
   /// Compute a bound for the given integer-typed value or shape dimension size.
   /// The computed bound is stored in `resultMap`. The operands of the bound are
@@ -133,6 +139,18 @@
   /// Return an expression that represents a constant.
   AffineExpr getExpr(int64_t constant);
 
+  /// Return an expression that is equivalent to `expr`, but where every dim
+  /// and symbol is replaced with the expression that this constraint set uses
+  /// for the corresponding value/dimension (see `getExpr`). `valueDims` lists
+  /// the values for all dims of `expr`, followed by those for all symbols.
+  AffineExpr getAlignedExpr(AffineExpr expr, int64_t numDims,
+                            int64_t numSymbols, ValueDimList valueDims);
+
+  /// Return the stop condition of the worklist processing that is currently
+  /// in progress. Must only be called while the worklist is being processed
+  /// (e.g., from `ValueBoundsOpInterface` implementations).
+  StopConditionFn getStopCondition() const;
+
 protected:
   /// Dimension identifier to indicate a value is integer-typed.
   static constexpr int64_t kIntegerValue = -1;
@@ -180,6 +198,10 @@
 
   /// Builder for constructing affine expressions.
   Builder builder;
+
+  /// The stop condition of the worklist processing that is currently in
+  /// progress. Only set while `processWorklist` is running.
+  StopConditionFn stopCondition;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -18,18 +18,21 @@
 FailureOr<OpFoldResult>
 mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
                       Value value, std::optional<int64_t> dim, bool closedUB) {
-  auto stopCondition = [&](Value v, std::optional<int64_t> d) {
+  auto stopCondition = [&](Value v, std::optional<int64_t> d,
+                           const ValueBoundsConstraintSet &cstr) {
     // Reify in terms of SSA values that are different from `value`.
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, dim, stopCondition, closedUB);
 }
 
-FailureOr<OpFoldResult> mlir::reifyValueBound(
-    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
-    std::optional<int64_t> dim,
-    function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
-    bool closedUB) {
+FailureOr<OpFoldResult>
+mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+                      Value value, std::optional<int64_t> dim,
+                      function_ref<bool(Value, std::optional<int64_t>,
+                                        const ValueBoundsConstraintSet &)>
+                          stopCondition,
+                      bool closedUB) {
   // Compute bound.
   AffineMap boundMap;
   ValueDimList mapOperands;
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -14,6 +14,147 @@
 using namespace mlir;
 using presburger::BoundType;
 
+/// Try to align the first map with the second one. I.e., replace dims/symbols
+/// of `m1` in such a way that both maps have the same ValueDims at the same
+/// dim/symbol positions. `v1`/`v2` are the operands of `m1`/`m2` (dims
+/// followed by symbols). Return failure if an operand of `m1` is not an
+/// operand of `m2`.
+static FailureOr<AffineMap> computeAlignedMap(Builder &b, AffineMap m1,
+                                              const ValueDimList &v1,
+                                              AffineMap m2,
+                                              const ValueDimList &v2) {
+  if (v1.size() != v2.size())
+    return failure();
+  assert(m1.getNumInputs() == v1.size() && m2.getNumInputs() == v2.size() &&
+         "expected one operand per dim/symbol");
+
+  // Return the dim/symbol of `m2` that corresponds to the given operand, or a
+  // null expression if `m2` has no such operand.
+  auto getReplacement = [&](const auto &valueDim) -> AffineExpr {
+    auto it = llvm::find(v2, valueDim);
+    if (it == v2.end())
+      return AffineExpr();
+    unsigned pos = std::distance(v2.begin(), it);
+    if (pos < m2.getNumDims())
+      return b.getAffineDimExpr(pos);
+    return b.getAffineSymbolExpr(pos - m2.getNumDims());
+  };
+
+  SmallVector<AffineExpr> dimReplacements, symbolReplacements;
+  for (unsigned i = 0, e = m1.getNumInputs(); i < e; ++i) {
+    AffineExpr replacement = getReplacement(v1[i]);
+    if (!replacement)
+      return failure();
+    if (i < m1.getNumDims())
+      dimReplacements.push_back(replacement);
+    else
+      symbolReplacements.push_back(replacement);
+  }
+
+  return m1.replaceDimsAndSymbols(dimReplacements, symbolReplacements,
+                                  m2.getNumDims(), m2.getNumSymbols());
+}
+
+/// Return "true" if the two given bounds are equivalent. `v1`/`v2` are the
+/// operands of `m1`/`m2`. Only single-result maps are supported.
+static bool isEquivalentBound(AffineMap m1, const ValueDimList &v1,
+                              AffineMap m2, const ValueDimList &v2) {
+  Builder b(m1.getContext());
+
+  // Only maps with one result are supported at the moment.
+  if (m1.getNumResults() != 1 || m2.getNumResults() != 1)
+    return false;
+
+  // Align the first bound with the second one.
+  FailureOr<AffineMap> alignedMap = computeAlignedMap(b, m1, v1, m2, v2);
+  if (failed(alignedMap))
+    return false;
+
+  // Compute "bound1 - bound2" and simplify. Both bounds are equivalent if this
+  // expression simplifies to "0".
+  AffineExpr sub = alignedMap->getResult(0) - m2.getResult(0);
+  AffineExpr simplified =
+      simplifyAffineExpr(sub, m2.getNumDims(), m2.getNumSymbols());
+  return simplified == b.getAffineConstantExpr(0);
+}
+
+void mlir::arith::populateSelectLikeBounds(Value value,
+                                           std::optional<int64_t> dim,
+                                           Value condition, Value trueValue,
+                                           Value falseValue,
+                                           ValueBoundsConstraintSet &cstr) {
+  auto addBound = [&](BoundType type, AffineExpr expr) {
+    if (dim.has_value()) {
+      cstr.addBound(type, value, *dim, expr);
+    } else {
+      cstr.addBound(type, value, expr);
+    }
+  };
+
+  // Case 0: Shaped condition (element-wise select). A shaped condition of
+  // `arith.select` has the same shape as the selected values, so the result
+  // dimension is equal to the corresponding dimension of `trueValue`. (The
+  // condition cannot be queried as an integer value in this case.)
+  if (condition.getType().isa<ShapedType>()) {
+    assert(dim.has_value() && "expected a dimension of a shaped value");
+    addBound(BoundType::EQ, cstr.getExpr(trueValue, dim));
+    return;
+  }
+
+  // Case 1: Constant condition.
+  FailureOr<int64_t> cond = ValueBoundsConstraintSet::computeConstantBound(
+      presburger::BoundType::EQ, condition);
+  if (succeeded(cond)) {
+    // Treat any non-zero value as "true": depending on how the i1 constant is
+    // extended to int64_t, "true" may be reported as -1 instead of 1.
+    addBound(BoundType::EQ,
+             cstr.getExpr(*cond != 0 ? trueValue : falseValue, dim));
+    return;
+  }
+
+  // Case 2: Both values have a constant bound.
+  auto findConstantBound = [&](BoundType type) {
+    FailureOr<int64_t> trueBound =
+        ValueBoundsConstraintSet::computeConstantBound(type, trueValue, dim);
+    if (failed(trueBound))
+      return false;
+    FailureOr<int64_t> falseBound =
+        ValueBoundsConstraintSet::computeConstantBound(type, falseValue, dim);
+    if (failed(falseBound))
+      return false;
+    // LB: Take the smaller bound.
+    // UB: Take the larger bound.
+    addBound(type, cstr.getExpr(type == BoundType::LB
+                                    ? std::min(*trueBound, *falseBound)
+                                    : std::max(*trueBound, *falseBound)));
+    return true;
+  };
+  // Query both bounds independently (no short-circuit evaluation), so that a
+  // constant UB is still added if no constant LB could be computed.
+  bool foundLB = findConstantBound(BoundType::LB);
+  bool foundUB = findConstantBound(BoundType::UB);
+  if (foundLB && foundUB)
+    return;
+
+  // Case 3: Both values have the same bound.
+  AffineMap mapFalse, mapTrue;
+  ValueDimList mapFalseOperands, mapTrueOperands;
+  LogicalResult status = ValueBoundsConstraintSet::computeBound(
+      mapFalse, mapFalseOperands, BoundType::EQ, falseValue, dim,
+      cstr.getStopCondition());
+  if (failed(status) || mapFalse.getNumResults() != 1)
+    return;
+  status = ValueBoundsConstraintSet::computeBound(mapTrue, mapTrueOperands,
+                                                  BoundType::EQ, trueValue, dim,
+                                                  mapFalseOperands);
+  if (failed(status) || mapTrue.getNumResults() != 1)
+    return;
+  if (isEquivalentBound(mapFalse, mapFalseOperands, mapTrue, mapTrueOperands))
+    addBound(BoundType::EQ,
+             cstr.getAlignedExpr(mapFalse.getResult(0), mapFalse.getNumDims(),
+                                 mapFalse.getNumSymbols(), mapFalseOperands));
+}
+
 namespace mlir {
 namespace arith {
 namespace {
@@ -70,6 +211,29 @@
   }
 };
 
+/// `arith.select`: The bounds are computed by `populateSelectLikeBounds`.
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+  void populateBoundsForIntegerValue(Operation *op, Value value,
+                                     ValueBoundsConstraintSet &cstr) const {
+    auto selectOp = cast<SelectOp>(op);
+    assert(value == selectOp.getResult() && "invalid value");
+    populateSelectLikeBounds(value, /*dim=*/std::nullopt,
+                             selectOp.getCondition(), selectOp.getTrueValue(),
+                             selectOp.getFalseValue(), cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto selectOp = cast<SelectOp>(op);
+    assert(value == selectOp.getResult() && "invalid value");
+    populateSelectLikeBounds(value, dim, selectOp.getCondition(),
+                             selectOp.getTrueValue(), selectOp.getFalseValue(),
+                             cstr);
+  }
+};
+
 } // namespace
 } // namespace arith
 } // namespace mlir
@@ -79,6 +243,7 @@
   registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
     arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
     arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
+    arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
     arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
     arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
   });
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -66,7 +66,8 @@
 FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     std::optional<int64_t> dim, bool closedUB, Type resultType) {
-  auto stopCondition = [&](Value v, std::optional<int64_t> d) {
+  auto stopCondition = [&](Value v, std::optional<int64_t> d,
+                           const ValueBoundsConstraintSet &cstr) {
     // Reify in terms of SSA values that are different from `value`.
     return v != value;
   };
@@ -77,8 +78,8 @@
 FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     std::optional<int64_t> dim,
-    function_ref<bool(Value, std::optional<int64_t>)> stopCondition,
-    bool closedUB, Type resultType) {
+    ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB,
+    Type resultType) {
   // Default result type if not specified: same as value.
   if (!resultType) {
     resultType = dim.has_value() ? b.getIndexType() : value.getType();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -382,7 +382,8 @@
   FailureOr<OpFoldResult> loopUb = reifyValueBound(
       rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
       /*dim=*/std::nullopt, /*stopCondition=*/
-      [&](Value v, std::optional<int64_t> d) {
+      [&](Value v, std::optional<int64_t> d,
+          const ValueBoundsConstraintSet &cstr) {
         if (v == forOp.getUpperBound())
           return false;
         // Compute a bound that is independent of any affine op results.
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -11,6 +11,7 @@
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRArithValueBoundsOpInterfaceImpl
   MLIRBufferizationDialect
   MLIRControlFlowDialect
   MLIRIR
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
 
+#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
@@ -61,7 +62,8 @@
     ValueDimList boundOperands;
     LogicalResult status = ValueBoundsConstraintSet::computeBound(
         bound, boundOperands, BoundType::EQ, yieldedValue, dim,
-        [&](Value v, std::optional<int64_t> d) {
+        [&](Value v, std::optional<int64_t> d,
+            const ValueBoundsConstraintSet & /*cstr*/) {
           // Stop when reaching a block argument of the loop body.
           if (auto bbArg = v.dyn_cast<BlockArgument>())
             return bbArg.getOwner()->getParentOp() == forOp;
@@ -114,6 +116,33 @@
   }
 };
 
+struct IfOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
+  /// Populate bounds for `value` (or dimension `dim` of `value`), which must
+  /// be a result of `ifOp`. The then/else values are the operands of the
+  /// respective terminators at the position of `value`.
+  static void populateBounds(IfOp ifOp, Value value,
+                             std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    assert(value.getDefiningOp() == ifOp.getOperation() && "invalid value");
+    unsigned resultNum = value.cast<OpResult>().getResultNumber();
+    Value thenValue = ifOp.thenYield().getOperand(resultNum);
+    Value elseValue = ifOp.elseYield().getOperand(resultNum);
+    arith::populateSelectLikeBounds(value, dim, ifOp.getCondition(), thenValue,
+                                    elseValue, cstr);
+  }
+
+  void populateBoundsForIntegerValue(Operation *op, Value value,
+                                     ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<IfOp>(op), value, dim, cstr);
+  }
+};
+
 } // namespace
 } // namespace scf
 } // namespace mlir
@@ -122,5 +151,6 @@
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
     scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
+    scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
   });
 }
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/ScopeExit.h"
 
 using namespace mlir;
 using presburger::BoundType;
@@ -46,6 +47,7 @@
   if (value.getType().isIntOrIndex()) {
     assert(!dim.has_value() && "invalid dim value");
   } else if (auto shapedType = value.getType().dyn_cast<ShapedType>()) {
+    assert(dim.has_value() && "invalid dim value");
     assert(*dim >= 0 && "invalid dim value");
     if (shapedType.hasRank())
       assert(*dim < shapedType.getRank() && "invalid dim value");
@@ -139,6 +141,31 @@
   return builder.getAffineConstantExpr(constant);
 }
 
+/// Replace every dim and symbol of `expr` with the expression of the
+/// corresponding value/dimension in `valueDims` (dims first, then symbols).
+AffineExpr ValueBoundsConstraintSet::getAlignedExpr(AffineExpr expr,
+                                                    int64_t numDims,
+                                                    int64_t numSymbols,
+                                                    ValueDimList valueDims) {
+  assert(static_cast<int64_t>(valueDims.size()) == numDims + numSymbols &&
+         "expected one value/dim per dim/symbol");
+  SmallVector<AffineExpr> dimReplacements, symbolReplacements;
+  for (int64_t i = 0; i < numDims; ++i)
+    dimReplacements.push_back(getExpr(valueDims[i].first, valueDims[i].second));
+  for (int64_t i = numDims; i < numDims + numSymbols; ++i)
+    symbolReplacements.push_back(
+        getExpr(valueDims[i].first, valueDims[i].second));
+  return expr.replaceDimsAndSymbols(dimReplacements, symbolReplacements);
+}
+
+/// Return the stop condition of the worklist processing that is currently in
+/// progress. Must not be called outside of `processWorklist` (asserts).
+ValueBoundsConstraintSet::StopConditionFn
+ValueBoundsConstraintSet::getStopCondition() const {
+  assert(stopCondition && "stop condition not set");
+  return stopCondition;
+}
+
 int64_t ValueBoundsConstraintSet::insert(ValueDim valueDim, bool isSymbol) {
   assert((valueDimToPosition.find(valueDim) == valueDimToPosition.end()) &&
          "already mapped");
@@ -172,6 +199,13 @@
 }
 
 void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+  // Make the stop condition available to `ValueBoundsOpInterface`
+  // implementations (via `getStopCondition`) while the worklist is processed.
+  // It is reset when this function returns.
+  auto resetStopCondition =
+      llvm::make_scope_exit([&]() { this->stopCondition = nullptr; });
+  this->stopCondition = stopCondition;
+
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
@@ -192,7 +226,7 @@
     // Do not process any further if the stop condition is met.
     auto maybeDim =
         dim == kIntegerValue ? std::nullopt : std::make_optional(dim);
-    if (stopCondition(value, maybeDim))
+    if (stopCondition(value, maybeDim, *this))
       continue;
 
     // Query `ValueBoundsOpInterface` for constraints. New items may be added to
@@ -241,15 +275,17 @@
                                       bool closedUB) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
-  assert(!stopCondition(value, dim) &&
-         "stop condition should not be satisfied for starting point");
 #endif // NDEBUG
 
   int64_t ubAdjustment = closedUB ? 0 : 1;
   Builder b(value.getContext());
   mapOperands.clear();
 
-  if (stopCondition(value, dim)) {
+  // Process the backward slice of `value` (i.e., reverse use-def chain) until
+  // `stopCondition` is met.
+  ValueDim valueDim = std::make_pair(value, dim.value_or(kIntegerValue));
+  ValueBoundsConstraintSet cstr(valueDim);
+  if (stopCondition(value, dim, cstr)) {
     // Special case: If the stop condition is satisfied for the input
     // value/dimension, directly return it.
     mapOperands.push_back(std::make_pair(value, dim));
@@ -260,11 +296,6 @@
                               b.getAffineDimExpr(0));
     return success();
   }
-
-  // Process the backward slice of `value` (i.e., reverse use-def chain) until
-  // `stopCondition` is met.
-  ValueDim valueDim = std::make_pair(value, dim.value_or(kIntegerValue));
-  ValueBoundsConstraintSet cstr(valueDim);
   cstr.processWorklist(stopCondition);
 
   // Project out all variables (apart from `valueDim`) that do not match the
@@ -275,7 +306,7 @@
       return false;
     auto maybeDim =
         p.second == kIntegerValue ? std::nullopt : std::make_optional(p.second);
-    return !stopCondition(p.first, maybeDim);
+    return !stopCondition(p.first, maybeDim, cstr);
   });
 
   // Compute lower and upper bounds for `valueDim`.
@@ -380,7 +411,8 @@
                                         bool closedUB) {
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) {
+      [&](Value v, std::optional<int64_t> d,
+          const ValueBoundsConstraintSet &cstr) {
         return llvm::is_contained(dependencies, std::make_pair(v, d));
       },
       closedUB);
@@ -416,7 +448,8 @@
   // Reify bounds in terms of any independent values.
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
+      [&](Value v, std::optional<int64_t> d,
+          const ValueBoundsConstraintSet &cstr) { return isIndependent(v); },
       closedUB);
 }
 
@@ -438,7 +471,8 @@
     // No stop condition specified: Keep adding constraints until a bound could
     // be computed
     cstr.processWorklist(
-        /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
+        /*stopCondition=*/[&](Value v, std::optional<int64_t> dim,
+                              const ValueBoundsConstraintSet & /*cstr*/) {
           return cstr.cstr.getConstantBound64(type, pos).has_value();
         });
   }
diff --git a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
--- a/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Arith/value-bounds-op-interface-impl.mlir
@@ -76,3 +76,65 @@
   %0 = "test.reify_bound"(%c5) : (index) -> (index)
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @arith_select_const(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index
+//       CHECK:   return %[[b]]
+func.func @arith_select_const(%a: index, %b: index) -> index {
+  %c0 = arith.constant 0 : i1
+  %0 = arith.select %c0, %a, %b : index
+  %1 = "test.reify_bound"(%0) : (index) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @arith_select_same_bound(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: i1
+//       CHECK:   return %[[a]]
+func.func @arith_select_same_bound(%a: index, %b: index, %c: i1) -> index {
+  %0 = arith.select %c, %a, %a : index
+  %1 = "test.reify_bound"(%0) : (index) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+// CHECK-LABEL: func @arith_select_same_bound_2(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: i1
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[b]], %[[a]]]
+//       CHECK:   return %[[apply]]
+func.func @arith_select_same_bound_2(%a: index, %b: index, %c: i1) -> index {
+  %c2 = arith.constant 2 : index
+  %0 = arith.addi %a, %b : index
+  %1 = arith.addi %0, %a : index
+  %2 = arith.muli %c2, %a : index
+  %3 = arith.addi %2, %b : index
+
+  %selected = arith.select %c, %1, %3 : index
+  %bound = "test.reify_bound"(%selected) {reify_to_func_args}
+      : (index) -> (index)
+  return %bound : index
+}
+
+// -----
+
+// CHECK-LABEL: func @arith_select_merge_const_bounds(
+//   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
+//   CHECK-DAG:   %[[c9:.*]] = arith.constant 9 : index
+//       CHECK:   return %[[c5]], %[[c9]]
+func.func @arith_select_merge_const_bounds(%a: tensor<5xf32>, %b: tensor<8xf32>,
+                                           %c: i1) -> (index, index) {
+  %0 = tensor.cast %a : tensor<5xf32> to tensor<?xf32>
+  %1 = tensor.cast %b : tensor<8xf32> to tensor<?xf32>
+
+  %selected = arith.select %c, %0, %1 : tensor<?xf32>
+  %lb = "test.reify_bound"(%selected) {reify_to_func_args, type = "LB", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  %ub = "test.reify_bound"(%selected) {reify_to_func_args, type = "UB", dim = 0}
+      : (tensor<?xf32>) -> (index)
+  return %lb, %ub : index, index
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -108,14 +108,17 @@
 
     // Prepare stop condition. By default, reify in terms of the op's
     // operands. No stop condition is used when a constant was requested.
-    std::function<bool(Value, std::optional<int64_t>)> stopCondition =
-        [&](Value v, std::optional<int64_t> d) {
+    std::function<bool(Value, std::optional<int64_t>,
+                       const ValueBoundsConstraintSet &cstr)>
+        stopCondition = [&](Value v, std::optional<int64_t> d,
+                            const ValueBoundsConstraintSet &cstr) {
           // Reify in terms of SSA values that are different from `value`.
           return v != value;
         };
-    if (reifyToFuncArgs) {
+    if (op->hasAttr("reify_to_func_args") || reifyToFuncArgs) {
       // Reify in terms of function block arguments.
-      stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
+      stopCondition = [](Value v, std::optional<int64_t> d,
+                         const ValueBoundsConstraintSet &cstr) {
         auto bbArg = v.dyn_cast<BlockArgument>();
         if (!bbArg)
           return false;
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2905,6 +2905,7 @@
     deps = [
         ":ArithDialect",
         ":ArithUtils",
+        ":ArithValueBoundsOpInterfaceImpl",
         ":BufferizationDialect",
         ":ControlFlowDialect",
         ":ControlFlowInterfaces",