diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h @@ -28,14 +28,16 @@ class ValueBoundsConstraintSet { public: /// Reify a bound for the given index-typed value or shape dimension size in - /// terms of the owning op's operands. + /// terms of the owning op's operands. LB and EQ bounds are closed, UB bounds + /// are open. static FailureOr reifyBound(OpBuilder &b, Location loc, presburger::IntegerPolyhedron::BoundType type, Value value, int64_t dim = kIndexValue); /// Reify a bound for the given index-typed value or shape dimension size in - /// terms of SSA values for which `stopCondition` is met. + /// terms of SSA values for which `stopCondition` is met. LB and EQ bounds are + /// closed, UB bounds are open. static FailureOr reifyBound(OpBuilder &b, Location loc, presburger::IntegerPolyhedron::BoundType type, Value value, diff --git a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterface.cpp @@ -214,10 +214,6 @@ assertValidValueDim(value, dim); #endif // NDEBUG - // Only EQ bounds are supported at the moment. - assert(type == presburger::IntegerPolyhedron::BoundType::EQ && - "unsupported bound type"); - // Process the backward slice of `value` (i.e., reverse use-def chain) until // `stopCondition` is met. ValueBoundsConstraintSet cstr(std::make_pair(value, dim)); @@ -246,16 +242,32 @@ SmallVector lb(1), ub(1); cstr.cstr.getSliceBounds(pos, 1, b.getContext(), &lb, &ub, /*getClosedUB=*/true); + // Note: There are TODOs in the implementation of `getSliceBounds`. In such a // case, no lower/upper bound can be computed at the moment. - if (lb.empty() || !lb[0] || ub.empty() || !ub[0] || - lb[0].getNumResults() != 1 || ub[0].getNumResults() != 1) + // EQ, UB bounds: upper bound is needed. + if ((type != presburger::IntegerPolyhedron::BoundType::LB) && + (ub.empty() || !ub[0] || ub[0].getNumResults() != 1)) + return failure(); + // EQ, LB bounds: lower bound is needed. + if ((type != presburger::IntegerPolyhedron::BoundType::UB) && + (lb.empty() || !lb[0] || lb[0].getNumResults() != 1)) return failure(); - // Look for same lower and upper bound: EQ bound. - if (ub[0] != lb[0]) + // EQ bound: lower and upper bound must match. + if (type == presburger::IntegerPolyhedron::BoundType::EQ && ub[0] != lb[0]) return failure(); + AffineMap bound; + if (type == presburger::IntegerPolyhedron::BoundType::EQ || + type == presburger::IntegerPolyhedron::BoundType::LB) { + bound = lb[0]; + } else { + // Computed UB is a closed bound. Turn into an open bound. + bound = AffineMap::get(ub[0].getNumDims(), ub[0].getNumSymbols(), + ub[0].getResult(0) + 1); + } + // Gather all SSA values that are used in the computed bound. SmallVector operands; assert(cstr.cstr.getNumDimAndSymbolVars() == cstr.positionToValueDim.size() && @@ -268,10 +280,10 @@ // be included in the generated affine.apply op. bool used = false; if (i < cstr.cstr.getNumDimVars()) { - if (lb[0].isFunctionOfDim(i)) + if (bound.isFunctionOfDim(i)) used = true; } else { - if (lb[0].isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) + if (bound.isFunctionOfSymbol(i - cstr.cstr.getNumDimVars())) used = true; } @@ -305,20 +317,20 @@ } } - mlir::canonicalizeMapAndOperands(&lb[0], &operands); + mlir::canonicalizeMapAndOperands(&bound, &operands); // Check for special cases where no affine.apply op is needed. - if (lb[0].isSingleConstant()) { + if (bound.isSingleConstant()) { // Bound is a constant: return an IntegerAttr. return static_cast( - b.getIndexAttr(lb[0].getSingleConstantResult())); + b.getIndexAttr(bound.getSingleConstantResult())); } // No affine.apply op is needed if the bound is a single SSA value. - if (auto expr = lb[0].getResult(0).dyn_cast()) + if (auto expr = bound.getResult(0).dyn_cast()) return static_cast(operands[expr.getPosition()]); - if (auto expr = lb[0].getResult(0).dyn_cast()) + if (auto expr = bound.getResult(0).dyn_cast()) return static_cast( operands[expr.getPosition() + cstr.cstr.getNumDimVars() - 1]); // General case: build affine.apply op. return static_cast( - b.create(loc, lb[0], operands).getResult()); + b.create(loc, bound, operands).getResult()); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" using namespace mlir; @@ -143,6 +144,49 @@ cstr.addBound(IntegerPolyhedron::BoundType::EQ, value, bound); }; }; + +struct AffineMinOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto minOp = cast(op); + assert(value == minOp.getResult() && "invalid value"); + + // Align affine map results with dims/symbols in the constraint set. + for (AffineExpr expr : minOp.getAffineMap().getResults()) { + SmallVector dimReplacements = llvm::to_vector(llvm::map_range( + minOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); })); + SmallVector symReplacements = llvm::to_vector(llvm::map_range( + minOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); })); + AffineExpr bound = + expr.replaceDimsAndSymbols(dimReplacements, symReplacements); + AffineExpr openBound = bound + 1; + cstr.addBound(IntegerPolyhedron::BoundType::UB, value, openBound); + } + }; +}; + +struct AffineMaxOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto maxOp = cast(op); + assert(value == maxOp.getResult() && "invalid value"); + + // Align affine map results with dims/symbols in the constraint set. + for (AffineExpr expr : maxOp.getAffineMap().getResults()) { + SmallVector dimReplacements = llvm::to_vector(llvm::map_range( + maxOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); })); + SmallVector symReplacements = llvm::to_vector(llvm::map_range( + maxOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); })); + AffineExpr bound = + expr.replaceDimsAndSymbols(dimReplacements, symReplacements); + cstr.addBound(IntegerPolyhedron::BoundType::LB, value, bound); + } + }; +}; } // namespace namespace memref { @@ -313,6 +357,8 @@ registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) { AffineApplyOp::attachInterface(*ctx); + AffineMaxOp::attachInterface(*ctx); + AffineMinOp::attachInterface(*ctx); }); registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { diff --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir --- a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir @@ -152,6 +152,52 @@ // ----- +// CHECK-LABEL: func @affine_max_lb( +// CHECK-SAME: %[[a:.*]]: index +// CHECK: %[[c2:.*]] = arith.constant 2 : index +// CHECK: return %[[c2]] +func.func @affine_max_lb(%a: index) -> (index) { + // Note: There are two LBs: s0 and 2. FlatAffineValueConstraints always + // returns the constant one at the moment. + %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a] + %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index) + return %2 : index +} + +// ----- + +func.func @affine_max_ub(%a: index) -> (index) { + %1 = affine.max affine_map<()[s0] -> (s0, 2)>()[%a] + // expected-error @below{{could not reify bound}} + %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index) + return %2 : index +} + +// ----- + +// CHECK-LABEL: func @affine_min_ub( +// CHECK-SAME: %[[a:.*]]: index +// CHECK: %[[c3:.*]] = arith.constant 3 : index +// CHECK: return %[[c3]] +func.func @affine_min_ub(%a: index) -> (index) { + // Note: There are two UBs: s0 + 1 and 3. FlatAffineValueConstraints always + // returns the constant one at the moment. + %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a] + %2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index) + return %2 : index +} + +// ----- + +func.func @affine_min_lb(%a: index) -> (index) { + %1 = affine.min affine_map<()[s0] -> (s0, 2)>()[%a] + // expected-error @below{{could not reify bound}} + %2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index) + return %2 : index +} + +// ----- + // CHECK-LABEL: func @memref_alloc( // CHECK-SAME: %[[sz:.*]]: index // CHECK: %[[c6:.*]] = arith.constant 6 : index diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -224,6 +224,17 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +FailureOr +parseBoundType(std::string type) { + if (type == "EQ") + return presburger::IntegerPolyhedron::BoundType::EQ; + if (type == "LB") + return presburger::IntegerPolyhedron::BoundType::LB; + if (type == "UB") + return presburger::IntegerPolyhedron::BoundType::UB; + return failure(); +} + static LogicalResult reifyValueBounds(func::FuncOp funcOp, bool reifyToFuncArgs) { IRRewriter rewriter(funcOp.getContext()); @@ -242,6 +253,15 @@ op->emitOpError("invalid op"); return WalkResult::skip(); } + std::string boundTypeStr = "EQ"; + if (auto boundTypeAttr = op->getAttrOfType("type")) + boundTypeStr = boundTypeAttr.str(); + auto boundType = parseBoundType(boundTypeStr); + if (failed(boundType)) { + op->emitOpError("invalid op"); + return WalkResult::interrupt(); + } + int64_t dim = value.getType().isIndex() ? ValueBoundsConstraintSet::kIndexValue : op->getAttrOfType("dim").getInt(); @@ -249,9 +269,8 @@ rewriter.setInsertionPointAfterValue(value); FailureOr reified; if (!reifyToFuncArgs) { - reified = ValueBoundsConstraintSet::reifyBound( - rewriter, op->getLoc(), - presburger::IntegerPolyhedron::BoundType::EQ, value, dim); + reified = ValueBoundsConstraintSet::reifyBound(rewriter, op->getLoc(), + *boundType, value, dim); } else { auto stopCondition = [](Value v) { auto bbArg = v.dyn_cast(); @@ -261,9 +280,7 @@ bbArg.getParentBlock()->getParentOp()); }; reified = ValueBoundsConstraintSet::reifyBound( - rewriter, op->getLoc(), - presburger::IntegerPolyhedron::BoundType::EQ, value, dim, - stopCondition); + rewriter, op->getLoc(), *boundType, value, dim, stopCondition); } if (failed(reified)) { op->emitOpError("could not reify bound");