diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h @@ -50,14 +50,15 @@ /// Reify a bound for the given index-typed value or shape dimension size in /// terms of the owning op's operands. `dim` must be `nullopt` if and only if -/// `value` is index-typed. +/// `value` is index-typed. LB and EQ bounds are closed, UB bounds are open. FailureOr reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, Value value, std::optional dim); /// Reify a bound for the given index-typed value or shape dimension size in /// terms of SSA values for which `stopCondition` is met. `dim` must be -/// `nullopt` if and only if `value` is index-typed. +/// `nullopt` if and only if `value` is index-typed. LB and EQ bounds are +/// closed, UB bounds are open. FailureOr reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, Value value, std::optional dim, diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -39,6 +39,49 @@ }; }; +struct AffineMinOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto minOp = cast(op); + assert(value == minOp.getResult() && "invalid value"); + + // Align affine map results with dims/symbols in the constraint set. 
+ for (AffineExpr expr : minOp.getAffineMap().getResults()) { + SmallVector dimReplacements = llvm::to_vector(llvm::map_range( + minOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); })); + SmallVector symReplacements = llvm::to_vector(llvm::map_range( + minOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); })); + AffineExpr bound = + expr.replaceDimsAndSymbols(dimReplacements, symReplacements); + AffineExpr openBound = bound + 1; + cstr.addBound(BoundType::UB, value, openBound); + } + }; +}; + +struct AffineMaxOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto maxOp = cast(op); + assert(value == maxOp.getResult() && "invalid value"); + + // Align affine map results with dims/symbols in the constraint set. + for (AffineExpr expr : maxOp.getAffineMap().getResults()) { + SmallVector dimReplacements = llvm::to_vector(llvm::map_range( + maxOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); })); + SmallVector symReplacements = llvm::to_vector(llvm::map_range( + maxOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); })); + AffineExpr bound = + expr.replaceDimsAndSymbols(dimReplacements, symReplacements); + cstr.addBound(BoundType::LB, value, bound); + } + }; +}; + } // namespace } // namespace mlir @@ -46,5 +89,7 @@ DialectRegistry &registry) { registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) { AffineApplyOp::attachInterface(*ctx); + AffineMaxOp::attachInterface(*ctx); + AffineMinOp::attachInterface(*ctx); }); } diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -222,9 +222,6 @@ AffineMap &resultMap, ValueDimList &mapOperands, Location loc, presburger::BoundType type, Value value, std::optional dim, function_ref stopCondition) { 
- // Only EQ bounds are supported at the moment. - assert(type == BoundType::EQ && "unsupported bound type"); - // Process the backward slice of `value` (i.e., reverse use-def chain) until // `stopCondition` is met. ValueBoundsConstraintSet cstr( @@ -254,16 +251,32 @@ SmallVector lb(1), ub(1); cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub, /*getClosedUB=*/true); + // Note: There are TODOs in the implementation of `getSliceBounds`. In such a // case, no lower/upper bound can be computed at the moment. - if (lb.empty() || !lb[0] || ub.empty() || !ub[0] || - lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1) + // EQ, UB bounds: upper bound is needed. + if ((type != BoundType::LB) && + (ub.empty() || !ub[0] || ub[0].getNumResults() != 1)) + return failure(); + // EQ, LB bounds: lower bound is needed. + if ((type != BoundType::UB) && + (lb.empty() || !lb[0] || lb[0].getNumResults() != 1)) return failure(); - // Look for same lower and upper bound: EQ bound. - if (ub[0] != lb[0]) + // EQ bound: lower and upper bound must match. + if (type == BoundType::EQ && ub[0] != lb[0]) return failure(); + AffineMap bound; + if (type == BoundType::EQ || + type == BoundType::LB) { + bound = lb[0]; + } else { + // Computed UB is a closed bound. Turn into an open bound. + bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(), + ub[0].getResult(0) + 1); + } + // Gather all SSA values that are used in the computed bound. 
mapOperands.clear(); assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() && @@ -280,10 +293,10 @@ bool used = false; bool isDim = i < cstr.cstr.getNumDimVars(); if (isDim) { - if (lb[0].isFunctionOfDim(i)) + if (bound.isFunctionOfDim(i)) used = true; } else { - if (lb[0].isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) + if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) used = true; } @@ -319,7 +332,7 @@ mapOperands.push_back(std::make_pair(value, dim)); } - resultMap = lb[0].replaceDimsAndSymbols(replacementDims, replacementSymbols, + resultMap = bound.replaceDimsAndSymbols(replacementDims, replacementSymbols, numDims, numSymbols); return success(); } diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir --- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir @@ -12,3 +12,49 @@ %1 = "test.reify_bound"(%0) : (index) -> (index) return %1 : index } + +// ----- + +// CHECK-LABEL: func @affine_max_lb( +// CHECK-SAME: %[[a:.*]]: index +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: return %[[c2]] +func.func @affine_max_lb(%a: index) -> (index) { + // Note: There are two LBs: s0 and 2. FlatAffineValueConstraints always + // returns the constant one at the moment. 
+ %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a] + %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index) + return %2 : index +} + +// ----- + +func.func @affine_max_ub(%a: index) -> (index) { + %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a] + // expected-error @below{{could not reify bound}} + %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index) + return %2 : index +} + +// ----- + +// CHECK-LABEL: func @affine_min_ub( +// CHECK-SAME: %[[a:.*]]: index +// CHECK: %[[c3:.*]] = arith.constant 3 : index +// CHECK: return %[[c3]] +func.func @affine_min_ub(%a: index) -> (index) { + // Note: There are two UBs: s0 + 1 and 3. FlatAffineValueConstraints always + // returns the constant one at the moment. + %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a] + %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index) + return %2 : index +} + +// ----- + +func.func @affine_min_lb(%a: index) -> (index) { + %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a] + // expected-error @below{{could not reify bound}} + %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index) + return %2 : index +} diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -16,6 +16,7 @@ #define PASS_NAME "test-affine-reify-value-bounds" using namespace mlir; +using mlir::presburger::BoundType; namespace { @@ -45,6 +46,16 @@ } // namespace +FailureOr parseBoundType(std::string type) { + if (type == "EQ") + return BoundType::EQ; + if (type == "LB") + return BoundType::LB; + if (type == "UB") + return BoundType::UB; + return failure(); +} + /// Look for "test.reify_bound" ops in the input and replace their results with /// the reified values. static LogicalResult testReifyValueBounds(func::FuncOp funcOp, @@ -67,6 +78,17 @@ return WalkResult::skip(); } + // Get bound type. 
+ std::string boundTypeStr = "EQ"; + if (auto boundTypeAttr = op->getAttrOfType("type")) + boundTypeStr = boundTypeAttr.str(); + auto boundType = parseBoundType(boundTypeStr); + if (failed(boundType)) { + op->emitOpError("invalid op"); + return WalkResult::interrupt(); + } + + // Get shape dimension (if any). auto dim = value.getType().isIndex() ? std::nullopt : std::make_optional( @@ -77,8 +99,8 @@ FailureOr reified; if (!reifyToFuncArgs) { // Reify in terms of the op's operands. - reified = reifyValueBound(rewriter, op->getLoc(), - presburger::BoundType::EQ, value, dim); + reified = + reifyValueBound(rewriter, op->getLoc(), *boundType, value, dim); } else { // Reify in terms of function block arguments. auto stopCondition = [](Value v) { @@ -88,9 +110,8 @@ return isa( bbArg.getParentBlock()->getParentOp()); }; - reified = - reifyValueBound(rewriter, op->getLoc(), presburger::BoundType::EQ, - value, dim, stopCondition); + reified = reifyValueBound(rewriter, op->getLoc(), *boundType, value, + dim, stopCondition); } if (failed(reified)) { op->emitOpError("could not reify bound");