diff --git a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -247,6 +248,56 @@
 
 } // namespace
 } // namespace memref
+
+namespace arith {
+namespace {
+
+/// arith.addi: the result is bounded to be exactly `lhs + rhs`.
+struct AddIOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto addIOp = cast<AddIOp>(op);
+    assert(value == addIOp.getResult() && "invalid value");
+
+    cstr.addBound(IntegerPolyhedron::BoundType::EQ, value,
+                  cstr.getExpr(addIOp.getLhs()) +
+                      cstr.getExpr(addIOp.getRhs()));
+  }
+};
+
+/// arith.subi: the result is bounded to be exactly `lhs - rhs`.
+struct SubIOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto subIOp = cast<SubIOp>(op);
+    assert(value == subIOp.getResult() && "invalid value");
+
+    cstr.addBound(IntegerPolyhedron::BoundType::EQ, value,
+                  cstr.getExpr(subIOp.getLhs()) -
+                      cstr.getExpr(subIOp.getRhs()));
+  }
+};
+
+/// arith.muli: the result is bounded to be exactly `lhs * rhs`.
+/// NOTE(review): when both operands are non-constant, `lhs * rhs` is not a
+/// pure affine expression; confirm that addBound handles that case.
+struct MulIOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto mulIOp = cast<MulIOp>(op);
+    assert(value == mulIOp.getResult() && "invalid value");
+
+    cstr.addBound(IntegerPolyhedron::BoundType::EQ, value,
+                  cstr.getExpr(mulIOp.getLhs()) *
+                      cstr.getExpr(mulIOp.getRhs()));
+  }
+};
+
+} // namespace
+} // namespace arith
 } // namespace mlir
 
 void mlir::linalg::registerValueBoundsOpInterfaceExternalModels(
@@ -280,4 +331,10 @@
     memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
     memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
   });
+
+  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+    arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
+    arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
+    arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
+  });
 }
diff --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
--- a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
@@ -235,3 +235,45 @@
   %1 = "test.reify_bound"(%0) {dim = 0} : (memref>) -> (index)
   return %1 : index
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 5)>
+// CHECK-LABEL: func @arith_addi(
+// CHECK-SAME: %[[a:.*]]: index
+// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
+// CHECK: return %[[apply]]
+func.func @arith_addi(%a: index) -> index {
+  %0 = arith.constant 5 : index
+  %1 = arith.addi %0, %a : index
+  %2 = "test.reify_bound"(%1) : (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (-s0 + 5)>
+// CHECK-LABEL: func @arith_subi(
+// CHECK-SAME: %[[a:.*]]: index
+// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
+// CHECK: return %[[apply]]
+func.func @arith_subi(%a: index) -> index {
+  %0 = arith.constant 5 : index
+  %1 = arith.subi %0, %a : index
+  %2 = "test.reify_bound"(%1) : (index) -> (index)
+  return %2 : index
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 * 5)>
+// CHECK-LABEL: func @arith_muli(
+// CHECK-SAME: %[[a:.*]]: index
+// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[a]]]
+// CHECK: return %[[apply]]
+func.func @arith_muli(%a: index) -> index {
+  %0 = arith.constant 5 : index
+  %1 = arith.muli %0, %a : index
+  %2 = "test.reify_bound"(%1) : (index) -> (index)
+  return %2 : index
+}