diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -156,6 +156,35 @@
                        StopConditionFn stopCondition = nullptr,
                        bool closedUB = false);
 
+  /// Compute a constant bound for the given affine map, where dims and symbols
+  /// are bound to the given operands (dims first, then symbols). The affine map
+  /// must have exactly one result and exactly one input per operand.
+  ///
+  /// This function traverses the backward slice of the given operands in a
+  /// worklist-driven manner until `stopCondition` evaluates to "true". The
+  /// constraint set is populated according to `ValueBoundsOpInterface` for each
+  /// visited value. (No constraints are added for values for which the stop
+  /// condition evaluates to "true".)
+  ///
+  /// The stop condition is optional: If none is specified, the backward slice
+  /// is traversed in a breadth-first manner until a constant bound could be
+  /// computed.
+  ///
+  /// By default, lower/equal bounds are closed and upper bounds are open. If
+  /// `closedUB` is set to "true", upper bounds are also closed.
+  static FailureOr<int64_t> computeConstantBound(
+      presburger::BoundType type, AffineMap map, ValueDimList operands,
+      StopConditionFn stopCondition = nullptr, bool closedUB = false);
+
+  /// Compute whether the given values/dimensions are equal. Return "failure" if
+  /// equality could not be determined.
+  ///
+  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
+  /// index-typed.
+  static FailureOr<bool> areEqual(Value value1, Value value2,
+                                  std::optional<int64_t> dim1 = std::nullopt,
+                                  std::optional<int64_t> dim2 = std::nullopt);
+
   /// Add a bound for the given index-typed value or shaped value. This function
   /// returns a builder that adds the bound.
   BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
@@ -199,13 +228,23 @@
   int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
 
   /// Insert a value/dimension into the constraint set. If `isSymbol` is set to
-  /// "false", a dimension is added.
+  /// "false", a dimension is added. The value/dimension is added to the
+  /// worklist.
   ///
   /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
   /// cannot be multiplied. Furthermore, bounds can only be queried for
   /// dimensions but not for symbols.
   int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
 
+  /// Insert an anonymous column into the constraint set. The column is not
+  /// bound to any value/dimension. If `isSymbol` is set to "false", a dimension
+  /// is added. The column is not added to the worklist.
+  ///
+  /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
+  /// cannot be multiplied. Furthermore, bounds can only be queried for
+  /// dimensions but not for symbols.
+  int64_t insert(bool isSymbol = true);
+
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
 
@@ -213,7 +252,8 @@
   void projectOut(function_ref<bool(ValueDim)> condition);
 
-  /// Mapping of columns to values/shape dimensions.
-  SmallVector<ValueDim> positionToValueDim;
+  /// Mapping of columns to values/shape dimensions. Anonymous columns are
+  /// mapped to `std::nullopt`.
+  SmallVector<std::optional<ValueDim>> positionToValueDim;
   /// Reverse mapping of values/shape dimensions to columns.
   DenseMap<ValueDim, int64_t> valueDimToPosition;
 
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -124,12 +124,24 @@
   positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
-    valueDimToPosition[positionToValueDim[i]] = i;
+    if (positionToValueDim[i].has_value())
+      valueDimToPosition[*positionToValueDim[i]] = i;
 
   worklist.push(pos);
   return pos;
 }
 
+int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
+  int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
+                         : cstr.appendVar(VarKind::SetDim);
+  positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
+  // Update reverse mapping.
+  for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
+    if (positionToValueDim[i].has_value())
+      valueDimToPosition[*positionToValueDim[i]] = i;
+  return pos;
+}
+
 int64_t ValueBoundsConstraintSet::getPos(Value value,
                                          std::optional<int64_t> dim) const {
 #ifndef NDEBUG
@@ -155,7 +167,9 @@
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
-    ValueDim valueDim = positionToValueDim[pos];
+    assert(positionToValueDim[pos].has_value() &&
+           "did not expect std::nullopt on worklist");
+    ValueDim valueDim = *positionToValueDim[pos];
     Value value = valueDim.first;
     int64_t dim = valueDim.second;
 
@@ -191,20 +205,24 @@
   assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
          "invalid position");
   cstr.projectOut(pos);
-  bool erased = valueDimToPosition.erase(positionToValueDim[pos]);
-  (void)erased;
-  assert(erased && "inconsistent reverse mapping");
+  if (positionToValueDim[pos].has_value()) {
+    bool erased = valueDimToPosition.erase(*positionToValueDim[pos]);
+    (void)erased;
+    assert(erased && "inconsistent reverse mapping");
+  }
   positionToValueDim.erase(positionToValueDim.begin() + pos);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
-    valueDimToPosition[positionToValueDim[i]] = i;
+    if (positionToValueDim[i].has_value())
+      valueDimToPosition[*positionToValueDim[i]] = i;
 }
 
 void ValueBoundsConstraintSet::projectOut(
     function_ref<bool(ValueDim)> condition) {
   int64_t nextPos = 0;
   while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
-    if (condition(positionToValueDim[nextPos])) {
+    if (positionToValueDim[nextPos].has_value() &&
+        condition(*positionToValueDim[nextPos])) {
       projectOut(nextPos);
       // The column was projected out so another column is now at that position.
       // Do not increase the counter.
@@ -332,7 +350,9 @@
       replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++));
     }
 
-    ValueBoundsConstraintSet::ValueDim valueDim = cstr.positionToValueDim[i];
+    assert(cstr.positionToValueDim[i].has_value() &&
+           "cannot build affine map in terms of anonymous column");
+    ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i];
     Value value = valueDim.first;
     int64_t dim = valueDim.second;
     if (dim == ValueBoundsConstraintSet::kIndexValue) {
@@ -406,10 +426,37 @@
   assertValidValueDim(value, dim);
 #endif // NDEBUG
 
-  // Process the backward slice of `value` (i.e., reverse use-def chain) until
-  // `stopCondition` is met.
-  ValueBoundsConstraintSet cstr(value.getContext());
-  int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
+  AffineMap map =
+      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+                     Builder(value.getContext()).getAffineDimExpr(0));
+  return computeConstantBound(type, map, {{value, dim}}, stopCondition,
+                              closedUB);
+}
+
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+    presburger::BoundType type, AffineMap map, ValueDimList operands,
+    StopConditionFn stopCondition, bool closedUB) {
+  assert(map.getNumResults() == 1 && "expected affine map with one result");
+  assert(map.getNumInputs() == operands.size() &&
+         "expected one operand per map dim/symbol");
+  ValueBoundsConstraintSet cstr(map.getContext());
+  int64_t pos = cstr.insert(/*isSymbol=*/false);
+
+  // Add map and operands to the constraint set. Dimensions are converted to
+  // symbols. All operands are added to the worklist.
+  auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
+    return cstr.getExpr(v.first, v.second);
+  };
+  SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
+  SmallVector<AffineExpr> symReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
+  cstr.addBound(
+      presburger::BoundType::EQ, pos,
+      map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+  // Process the backward slice of `operands` (i.e., reverse use-def chain)
+  // until `stopCondition` is met.
   if (stopCondition) {
     cstr.processWorklist(stopCondition);
   } else {
@@ -428,6 +475,27 @@
   return failure();
 }
 
+FailureOr<bool>
+ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+                                   std::optional<int64_t> dim1,
+                                   std::optional<int64_t> dim2) {
+#ifndef NDEBUG
+  assertValidValueDim(value1, dim1);
+  assertValidValueDim(value2, dim2);
+#endif // NDEBUG
+
+  // Subtract the two values/dimensions from each other. If the result is 0,
+  // both are equal.
+  Builder b(value1.getContext());
+  AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
+                                 b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
+  FailureOr<int64_t> bound = computeConstantBound(
+      presburger::BoundType::EQ, map, {{value1, dim1}, {value2, dim2}});
+  if (failed(bound))
+    return failure();
+  return *bound == 0;
+}
+
 ValueBoundsConstraintSet::BoundBuilder &
 ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
   assert(!this->dim.has_value() && "dim was already set");
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -156,3 +156,49 @@
   %1 = "test.reify_bound"(%0) : (index) -> (index)
   return %1 : index
 }
+
+// -----
+
+func.func @dynamic_dims_are_equal(%t: tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
+  %dim1 = tensor.dim %t, %c0 : tensor<?xf32>
+  // expected-remark @below {{equal}}
+  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @dynamic_dims_are_different(%t: tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
+  %val = arith.addi %dim0, %c1 : index
+  // expected-remark @below {{different}}
+  "test.are_equal"(%dim0, %val) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @dynamic_dims_are_maybe_equal_1(%t: tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c5 = arith.constant 5 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
+  // expected-error @below {{could not determine equality}}
+  "test.are_equal"(%dim0, %c5) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  // expected-error @below {{could not determine equality}}
+  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -175,10 +175,37 @@
   return failure(result.wasInterrupted());
 }
 
+/// Look for "test.are_equal" ops and emit errors/remarks.
+static LogicalResult testEquality(func::FuncOp funcOp) {
+  WalkResult result = funcOp.walk([&](Operation *op) {
+    // Look for test.are_equal ops.
+    if (op->getName().getStringRef() == "test.are_equal") {
+      if (op->getNumOperands() != 2 || !op->getOperand(0).getType().isIndex() ||
+          !op->getOperand(1).getType().isIndex()) {
+        op->emitOpError("invalid op");
+        return WalkResult::skip();
+      }
+      FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
+          op->getOperand(0), op->getOperand(1));
+      if (failed(equal)) {
+        op->emitError("could not determine equality");
+      } else if (*equal) {
+        op->emitRemark("equal");
+      } else {
+        op->emitRemark("different");
+      }
+    }
+    return WalkResult::advance();
+  });
+  return failure(result.wasInterrupted());
+}
+
 void TestReifyValueBounds::runOnOperation() {
   if (failed(
           testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps)))
     signalPassFailure();
+  if (failed(testEquality(getOperation())))
+    signalPassFailure();
 }
 
 namespace mlir {