diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -16,6 +16,8 @@
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/SetVector.h"
 
+#include <queue>
+
 namespace mlir {
 
 using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
@@ -100,6 +102,28 @@
                                     std::optional<int64_t> dim,
                                     StopConditionFn stopCondition);
 
+  /// Compute a constant bound for the given index-typed value or shape
+  /// dimension size.
+  ///
+  /// `dim` must be `nullopt` if and only if `value` is index-typed. This
+  /// function traverses the backward slice of the given value in a
+  /// worklist-driven manner until `stopCondition` evaluates to "true". The
+  /// constraint set is populated according to `ValueBoundsOpInterface` for each
+  /// visited value. (No constraints are added for values for which the stop
+  /// condition evaluates to "true".)
+  ///
+  /// The stop condition is optional: If none is specified, the backward slice
+  /// is traversed in a breadth-first manner until a constant bound could be
+  /// computed.
+  ///
+  /// Upper bounds (`BoundType::UB`) are exclusive: the returned value is one
+  /// greater than the largest value that `value` (or its `dim` dimension) may
+  /// take. Returns failure if no constant bound could be computed.
+  static FailureOr<int64_t>
+  computeConstantBound(presburger::BoundType type, Value value,
+                       std::optional<int64_t> dim = std::nullopt,
+                       StopConditionFn stopCondition = nullptr);
+
   /// Add a bound for the given index-typed value or shaped value. This function
   /// returns a builder that adds the bound.
   BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
@@ -162,7 +186,7 @@
   DenseMap<ValueDim, int64_t> valueDimToPosition;
 
   /// Worklist of values/shape dimensions that have not been processed yet.
-  SetVector<int64_t> worklist;
+  std::queue<int64_t> worklist;
 
   /// Constraint system of equalities and inequalities.
   FlatLinearConstraints cstr;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -121,7 +121,7 @@
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
     valueDimToPosition[positionToValueDim[i]] = i;
 
-  worklist.insert(pos);
+  worklist.push(pos);
   return pos;
 }
 
@@ -148,7 +148,8 @@
 
 void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
   while (!worklist.empty()) {
-    int64_t pos = worklist.pop_back_val();
+    int64_t pos = worklist.front();
+    worklist.pop();
     ValueDim valueDim = positionToValueDim[pos];
     Value value = valueDim.first;
     int64_t dim = valueDim.second;
@@ -337,6 +338,33 @@
   return success();
 }
 
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+    presburger::BoundType type, Value value, std::optional<int64_t> dim,
+    StopConditionFn stopCondition) {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+  // Process the backward slice of `value` (i.e., reverse use-def chain) until
+  // `stopCondition` is met.
+  ValueBoundsConstraintSet cstr(value, dim);
+  int64_t pos = cstr.getPos(value, dim);
+  if (stopCondition) {
+    cstr.processWorklist(stopCondition);
+  } else {
+    // No stop condition specified: Keep adding constraints until a bound could
+    // be computed.
+    cstr.processWorklist(/*stopCondition=*/[&](Value v) {
+      return cstr.cstr.getConstantBound64(type, pos).has_value();
+    });
+  }
+
+  // Compute constant bound for `valueDim`.
+  if (auto bound = cstr.cstr.getConstantBound64(type, pos))
+    return type == BoundType::UB ? *bound + 1 : *bound;
+  return failure();
+}
+
 ValueBoundsConstraintSet::BoundBuilder &
 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
   assert(!this->dim.has_value() && "dim was already set");
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -26,6 +26,8 @@
 // CHECK-LABEL: func @reify_slice_bound(
 //       CHECK:   %[[c5:.*]] = arith.constant 5 : index
 //       CHECK:   "test.some_use"(%[[c5]])
+//       CHECK:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK:   "test.some_use"(%[[c5]])
 func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
   %c0 = arith.constant 0 : index
   %c4 = arith.constant 4 : index
@@ -33,8 +35,12 @@
     %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
     %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
     %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
+
     %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound) : (index) -> ()
+
+    %bound_const = "test.reify_constant_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+    "test.some_use"(%bound_const) : (index) -> ()
   }
   return
 }
@@ -77,6 +83,11 @@
     %lb1_ub = "test.reify_bound"(%lb1) {type = "UB"} : (index) -> (index)
     "test.some_use"(%lb1_ub) : (index) -> ()
 
+    // CHECK: %[[c129:.*]] = arith.constant 129 : index
+    // CHECK: "test.some_use"(%[[c129]])
+    %lb1_ub_const = "test.reify_constant_bound"(%lb1) {type = "UB"} : (index) -> (index)
+    "test.some_use"(%lb1_ub_const) : (index) -> ()
+
     scf.for %iv1 = %lb1 to %ub1 step %c32 {
       // CHECK: %[[c32:.*]] = arith.constant 32 : index
       // CHECK: "test.some_use"(%[[c32]])
@@ -94,6 +105,11 @@
       // CHECK: "test.some_use"(%[[c32]])
       %matmul_ub = "test.reify_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
       "test.some_use"(%matmul_ub) : (index) -> ()
+
+      // CHECK: %[[c32:.*]] = arith.constant 32 : index
+      // CHECK: "test.some_use"(%[[c32]])
+      %matmul_ub_const = "test.reify_constant_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+      "test.some_use"(%matmul_ub_const) : (index) -> ()
     }
   }
 }
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -80,6 +80,27 @@
 
 // -----
 
+func.func @extract_slice_dynamic_constant(%t: tensor<?xf32>, %sz: index) -> index {
+  %0 = tensor.extract_slice %t[2][%sz][1] : tensor<?xf32> to tensor<?xf32>
+  // expected-error @below{{could not reify bound}}
+  %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_slice_static_constant(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+//       CHECK:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK:   return %[[c5]]
+func.func @extract_slice_static_constant(%t: tensor<?xf32>) -> index {
+  %0 = tensor.extract_slice %t[2][5][1] : tensor<?xf32> to tensor<5xf32>
+  %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_slice_rank_reduce(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[sz:.*]]: index
 //       CHECK:   return %[[sz]]
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -63,7 +63,8 @@
   IRRewriter rewriter(funcOp.getContext());
   WalkResult result = funcOp.walk([&](Operation *op) {
     // Look for test.reify_bound ops.
-    if (op->getName().getStringRef() == "test.reify_bound") {
+    if (op->getName().getStringRef() == "test.reify_bound" ||
+        op->getName().getStringRef() == "test.reify_constant_bound") {
       if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
           !op->getResultTypes()[0].isIndex()) {
         op->emitOpError("invalid op");
@@ -94,22 +95,37 @@
                 : std::make_optional<int64_t>(
                       op->getAttrOfType<IntegerAttr>("dim").getInt());
 
-      // Reify value bound.
-      rewriter.setInsertionPointAfter(op);
-      FailureOr<OpFoldResult> reified;
-      if (!reifyToFuncArgs) {
-        // Reify in terms of the op's operands.
-        reified =
-            reifyValueBound(rewriter, op->getLoc(), *boundType, value, dim);
-      } else {
+      // Check if a constant was requested.
+      bool constant =
+          op->getName().getStringRef() == "test.reify_constant_bound";
+
+      // Prepare stop condition. By default, reify in terms of the op's
+      // operands. No stop condition is used when a constant was requested.
+      std::function<bool(Value)> stopCondition = [&](Value v) {
+        // Reify in terms of SSA values that are different from `value`.
+        return v != value;
+      };
+      if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
-        auto stopCondition = [](Value v) {
+        stopCondition = [](Value v) {
           auto bbArg = v.dyn_cast<BlockArgument>();
           if (!bbArg)
             return false;
           return isa<FunctionOpInterface>(
               bbArg.getParentBlock()->getParentOp());
         };
+      }
+
+      // Reify value bound.
+      rewriter.setInsertionPointAfter(op);
+      FailureOr<OpFoldResult> reified = failure();
+      if (constant) {
+        auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
+            *boundType, value, dim, /*stopCondition=*/nullptr);
+        if (succeeded(reifiedConst))
+          reified =
+              FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
+      } else {
         reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value,
                                   dim, stopCondition);
       }