diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -96,6 +96,18 @@
     /// Returns true if the result of this operation is a symbol for all its
     /// uses in `region`.
     bool isValidSymbol(Region *region);
+
+    /// Returns all dimension operands.
+    ValueRange getDimOperands() {
+      return OperandRange{getOperands().begin(),
+                          getOperands().begin() + getMap().getNumDims()};
+    }
+
+    /// Returns all symbol operands.
+    ValueRange getSymbolOperands() {
+      return OperandRange{getOperands().begin() + getMap().getNumDims(),
+                          getOperands().end()};
+    }
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.h"
 
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 
@@ -117,6 +118,34 @@
 
 } // namespace
 } // namespace tensor
+
+namespace {
+/// External model that implements ValueBoundsOpInterface for AffineApplyOp.
+struct AffineApplyOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AffineApplyOpInterface,
+                                                   AffineApplyOp> {
+  /// Adds an equality bound between `value` (the op result) and the op's
+  /// single affine map result, rewritten over the operands' `cstr` exprs.
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto applyOp = cast<AffineApplyOp>(op);
+    assert(value == applyOp.getResult() && "invalid value");
+    assert(applyOp.getAffineMap().getNumResults() == 1 &&
+           "expected single result");
+
+    // Align affine map result with dims/symbols in the constraint set.
+    AffineExpr expr = applyOp.getAffineMap().getResult(0);
+    SmallVector<AffineExpr> dimReplacements = llvm::to_vector(llvm::map_range(
+        applyOp.getDimOperands(), [&](Value v) { return cstr.getExpr(v); }));
+    SmallVector<AffineExpr> symReplacements = llvm::to_vector(llvm::map_range(
+        applyOp.getSymbolOperands(), [&](Value v) { return cstr.getExpr(v); }));
+    AffineExpr bound =
+        expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
+    cstr.addBound(IntegerPolyhedron::BoundType::EQ, value, bound);
+  }
+};
+} // namespace
+
 } // namespace mlir
 
 void mlir::linalg::registerValueBoundsOpInterfaceExternalModels(
@@ -134,4 +163,8 @@
     tensor::PadOp::attachInterface(*ctx);
     tensor::RankOp::attachInterface(*ctx);
   });
+
+  registry.addExtension(+[](MLIRContext *ctx, AffineDialect *dialect) {
+    AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
+  });
 }
diff --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
--- a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
@@ -135,3 +135,17 @@
   %1 = "test.reify_bound"(%0) : (index) -> (index)
   return %1 : index
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @affine_apply(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
+//       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
+//       CHECK:   return %[[apply]]
+func.func @affine_apply(%a: index, %b: index) -> index {
+  %0 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%a, %b]
+  %1 = "test.reify_bound"(%0) : (index) -> (index)
+  return %1 : index
+}