diff --git a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 using namespace mlir;
@@ -336,6 +337,42 @@
 
 } // namespace
 } // namespace arith
+
+namespace scf {
+namespace {
+
+/// External model that attaches ValueBoundsOpInterface to scf.for.
+/// Bounds are populated for the induction variable only.
+struct ForOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
+  /// Constrains `value` by the loop's lower and upper bound operands if it is
+  /// the induction variable of `op`; adds nothing for any other value.
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto forOp = cast<ForOp>(op);
+    // Only IV is supported at the moment.
+    if (value != forOp.getInductionVar())
+      return;
+
+    // TODO: Take into account step size.
+    cstr.addBound(IntegerPolyhedron::BoundType::LB, value,
+                  cstr.getExpr(forOp.getLowerBound()));
+    // NOTE(review): presumably UB is an open (exclusive) bound here, matching
+    // the half-open iteration range of scf.for -- TODO confirm.
+    cstr.addBound(IntegerPolyhedron::BoundType::UB, value,
+                  cstr.getExpr(forOp.getUpperBound()));
+  }
+
+  /// Intentionally adds no constraints for shaped values.
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    // iter_arg / return value not supported.
+    return;
+  }
+};
+
+} // namespace
+} // namespace scf
 } // namespace mlir
 
 void mlir::linalg::registerValueBoundsOpInterfaceExternalModels(
@@ -377,4 +414,8 @@
     arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
     arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
   });
+
+  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
+    scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
+  });
 }
diff --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
--- a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir
@@ -323,3 +323,17 @@
   %2 = "test.reify_bound"(%1) : (index) -> (index)
   return %2 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_for(
+//  CHECK-SAME:     %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+//       CHECK:   "test.some_use"(%[[a]], %[[b]])
+func.func @scf_for(%a: index, %b: index, %c: index) {
+  scf.for %iv = %a to %b step %c {
+    %0 = "test.reify_bound"(%iv) {type = "LB"} : (index) -> (index)
+    %1 = "test.reify_bound"(%iv) {type = "UB"} : (index) -> (index)
+    "test.some_use"(%0, %1) : (index, index) -> ()
+  }
+  return
+}