diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -59,10 +59,10 @@
 /// terms of SSA values for which `stopCondition` is met. `dim` must be
 /// `nullopt` if and only if `value` is index-typed. LB and EQ bounds are
 /// closed, UB bounds are open.
-FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
-                Value value, std::optional<int64_t> dim,
-                function_ref<bool(Value)> stopCondition);
+FailureOr<OpFoldResult> reifyValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+    std::optional<int64_t> dim,
+    function_ref<bool(Value, std::optional<int64_t>)> stopCondition);
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -44,15 +44,26 @@
   /// or a shaped value and a dimension.
   ///
   /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
-  /// is computed in terms of values for which `stopCondition` evaluates to
-  /// "true". To that end, the backward slice (reverse use-def chain) of the
-  /// given value is visited in a worklist-driven manner and the constraint set
-  /// is populated according to `ValueBoundsOpInterface` for each visited value.
+  /// is computed in terms of values/dimensions for which `stopCondition`
+  /// evaluates to "true". To that end, the backward slice (reverse use-def
+  /// chain) of the given value is visited in a worklist-driven manner and the
+  /// constraint set is populated according to `ValueBoundsOpInterface` for each
+  /// visited value.
+  ///
+  /// Note: `stopCondition` should not evaluate to "true" for the input
+  /// value/dimension.
+  static LogicalResult
+  computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+               presburger::BoundType type, Value value,
+               std::optional<int64_t> dim,
+               function_ref<bool(Value, std::optional<int64_t>)> stopCondition);
+
+  /// Compute a bound in terms of the values/dimensions in `dependencies`.
   static LogicalResult
   computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
                presburger::BoundType type, Value value,
                std::optional<int64_t> dim,
-               function_ref<bool(Value)> stopCondition);
+               ValueDimList dependencies);
 
   /// Compute a constant bound for the given index-typed value or shape
   /// dimension size.
@@ -67,10 +78,11 @@
   /// The stop condition is optional: If none is specified, the backward slice
   /// is traversed in a breadth-first manner until a constant bound could be
   /// computed.
-  static FailureOr<int64_t>
-  computeConstantBound(presburger::BoundType type, Value value,
-                       std::optional<int64_t> dim = std::nullopt,
-                       function_ref<bool(Value)> stopCondition = nullptr);
+  static FailureOr<int64_t> computeConstantBound(
+      presburger::BoundType type, Value value,
+      std::optional<int64_t> dim = std::nullopt,
+      function_ref<bool(Value, std::optional<int64_t>)> stopCondition =
+          nullptr);
 
   /// Bound the given index-typed value by the given expression.
   void addBound(presburger::BoundType type, Value value, AffineExpr expr);
@@ -111,7 +123,8 @@
   /// Iteratively process all elements on the worklist until an index-typed
   /// value or shaped value meets `stopCondition`. Such values are not processed
   /// any further.
-  void processWorklist(function_ref<bool(Value)> stopCondition);
+  void processWorklist(
+      function_ref<bool(Value, std::optional<int64_t>)> stopCondition);
 
   /// Bound the given column in the underlying constraint set by the given
   /// expression.
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -19,17 +19,17 @@
                                               presburger::BoundType type,
                                               Value value,
                                               std::optional<int64_t> dim) {
-  auto stopCondition = [&](Value v) {
+  auto stopCondition = [&](Value v, std::optional<int64_t> d) {
     // Reify in terms of SSA values that are different from `value`.
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, dim, stopCondition);
 }
 
-FailureOr<OpFoldResult>
-mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
-                      Value value, std::optional<int64_t> dim,
-                      function_ref<bool(Value)> stopCondition) {
+FailureOr<OpFoldResult> mlir::reifyValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type, Value value,
+    std::optional<int64_t> dim,
+    function_ref<bool(Value, std::optional<int64_t>)> stopCondition) {
   // Compute bound.
   AffineMap boundMap;
   ValueDimList mapOperands;
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -23,19 +23,109 @@
   void populateBoundsForIndexValue(Operation *op, Value value,
                                    ValueBoundsConstraintSet &cstr) const {
     auto forOp = cast<ForOp>(op);
-    // Only IV is supported at the moment.
-    if (value != forOp.getInductionVar())
+
+    if (value == forOp.getInductionVar()) {
+      // TODO: Take into account step size.
+      cstr.addBound(BoundType::LB, value, cstr.getExpr(forOp.getLowerBound()));
+      cstr.addBound(BoundType::UB, value, cstr.getExpr(forOp.getUpperBound()));
       return;
+    }
+
+    // `value` is an iter_arg or an OpResult.
+    int64_t iterArgIdx;
+    if (auto iterArg = value.dyn_cast<BlockArgument>()) {
+      iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
+    } else {
+      iterArgIdx = value.cast<OpResult>().getResultNumber();
+    }
 
-    // TODO: Take into account step size.
-    cstr.addBound(BoundType::LB, value, cstr.getExpr(forOp.getLowerBound()));
-    cstr.addBound(BoundType::UB, value, cstr.getExpr(forOp.getUpperBound()));
+    // An EQ constraint can be added if the yielded value is the same value as
+    // the corresponding block argument.
+    assert(forOp.getLoopBody().hasOneBlock() &&
+           "multiple blocks not supported");
+    Value yieldedValue =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator())
+            .getOperand(iterArgIdx);
+    Value iterArg = forOp.getRegionIterArg(iterArgIdx);
+    Value initArg = forOp.getInitArgs()[iterArgIdx];
+
+    // Compute EQ bound for yielded value.
+    AffineMap bound;
+    ValueDimList boundOperands;
+    LogicalResult status = ValueBoundsConstraintSet::computeBound(
+        bound, boundOperands, BoundType::EQ, yieldedValue,
+        /*dim=*/std::nullopt, {{iterArg, std::nullopt}});
+    if (failed(status))
+      return;
+    if (bound.getNumResults() != 1)
+      return;
+
+    // Check if computed bound equals the corresponding iter_arg.
+    Value singleValue = nullptr;
+    if (auto dimExpr = bound.getResult(0).dyn_cast<AffineDimExpr>()) {
+      int64_t idx = dimExpr.getPosition();
+      if (!boundOperands[idx].second.has_value())
+        singleValue = boundOperands[idx].first;
+    } else if (auto symExpr = bound.getResult(0).dyn_cast<AffineSymbolExpr>()) {
+      int64_t idx = symExpr.getPosition() + bound.getNumDims();
+      if (!boundOperands[idx].second.has_value())
+        singleValue = boundOperands[idx].first;
+    }
+    if (singleValue == iterArg)
+      cstr.addBound(BoundType::EQ, value, cstr.getExpr(initArg));
   }
 
   void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                        ValueBoundsConstraintSet &cstr) const {
-    // iter_arg / return value not supported.
-    return;
+    auto forOp = cast<ForOp>(op);
+
+    // `value` is an iter_arg or an OpResult.
+    int64_t iterArgIdx;
+    if (auto iterArg = value.dyn_cast<BlockArgument>()) {
+      iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
+    } else {
+      iterArgIdx = value.cast<OpResult>().getResultNumber();
+    }
+
+    // An EQ constraint can be added if the dimension size of the yielded value
+    // equals the dimension size of the corresponding block argument.
+    assert(forOp.getLoopBody().hasOneBlock() &&
+           "multiple blocks not supported");
+    Value yieldedValue =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator())
+            .getOperand(iterArgIdx);
+    Value iterArg = forOp.getRegionIterArg(iterArgIdx);
+    Value initArg = forOp.getInitArgs()[iterArgIdx];
+
+    // Compute EQ bound for yielded value.
+    AffineMap bound;
+    ValueDimList boundOperands;
+    LogicalResult status = ValueBoundsConstraintSet::computeBound(
+        bound, boundOperands, BoundType::EQ, yieldedValue, dim,
+        {{iterArg, dim}});
+    if (failed(status))
+      return;
+    if (bound.getNumResults() != 1)
+      return;
+
+    // Check if computed bound equals the corresponding iter_arg.
+    Value singleValue = nullptr;
+    int64_t singleDim = -1;
+    if (auto dimExpr = bound.getResult(0).dyn_cast<AffineDimExpr>()) {
+      int64_t idx = dimExpr.getPosition();
+      if (boundOperands[idx].second.has_value()) {
+        singleValue = boundOperands[idx].first;
+        singleDim = *boundOperands[idx].second;
+      }
+    } else if (auto symExpr = bound.getResult(0).dyn_cast<AffineSymbolExpr>()) {
+      int64_t idx = symExpr.getPosition() + bound.getNumDims();
+      if (boundOperands[idx].second.has_value()) {
+        singleValue = boundOperands[idx].first;
+        singleDim = *boundOperands[idx].second;
+      }
+    }
+    if (singleValue == iterArg && singleDim == dim)
+      cstr.addBound(BoundType::EQ, value, dim, cstr.getExpr(initArg, dim));
   }
 };
 
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -172,7 +172,7 @@
 }
 
 void ValueBoundsConstraintSet::processWorklist(
-    function_ref<bool(Value)> stopCondition) {
+    function_ref<bool(Value, std::optional<int64_t>)> stopCondition) {
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
@@ -191,7 +191,8 @@
     }
 
     // Do not process any further if the stop condition is met.
-    if (stopCondition(value))
+    auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
+    if (stopCondition(value, maybeDim))
       continue;
 
     // Query `ValueBoundsOpInterface` for constraints. New items may be added to
@@ -237,9 +238,11 @@
 LogicalResult ValueBoundsConstraintSet::computeBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
     Value value, std::optional<int64_t> dim,
-    function_ref<bool(Value)> stopCondition) {
+    function_ref<bool(Value, std::optional<int64_t>)> stopCondition) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
+  assert(!stopCondition(value, dim) &&
+         "stop condition should not be satisfied for starting point");
 #endif // NDEBUG
 
   // Process the backward slice of `value` (i.e., reverse use-def chain) until
@@ -254,7 +257,9 @@
     // Do not project out `valueDim`.
     if (valueDim == p)
       return false;
-    return !stopCondition(p.first);
+    auto maybeDim =
+        p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
+    return !stopCondition(p.first, maybeDim);
   });
 
   // Compute lower and upper bounds for `valueDim`.
@@ -355,9 +360,27 @@
   return success();
 }
 
+LogicalResult ValueBoundsConstraintSet::computeBound(
+    AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
+    Value value, std::optional<int64_t> dim, ValueDimList dependencies) {
+  if (llvm::is_contained(dependencies, std::make_pair(value, dim))) {
+    resultMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+                               Builder(value.getContext()).getAffineDimExpr(0));
+    mapOperands.clear();
+    mapOperands.push_back(std::make_pair(value, dim));
+    return success();
+  }
+
+  return computeBound(resultMap, mapOperands, type, value, dim,
+                      [&](Value v, std::optional<int64_t> d) {
+                        return llvm::is_contained(dependencies,
+                                                  std::make_pair(v, d));
+                      });
+}
+
 FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
     presburger::BoundType type, Value value, std::optional<int64_t> dim,
-    function_ref<bool(Value)> stopCondition) {
+    function_ref<bool(Value, std::optional<int64_t>)> stopCondition) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
 #endif // NDEBUG
@@ -372,9 +395,10 @@
   } else {
     // No stop condition specified: Keep adding constraints until a bound could
     // be computed
-    cstr.processWorklist(/*stopCondition=*/[&](Value v) {
-      return cstr.cstr.getConstantBound64(type, pos).has_value();
-    });
+    cstr.processWorklist(
+        /*stopCondition=*/[&](Value v, std::optional<int64_t> d) {
+          return cstr.cstr.getConstantBound64(type, pos).has_value();
+        });
   }
 
   // Compute constant bound for `valueDim`.
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -12,3 +12,80 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_for_index_result_small(
+// CHECK-SAME: %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+// CHECK: "test.some_use"(%[[i]])
+// CHECK: "test.some_use"(%[[i]])
+func.func @scf_for_index_result_small(%i: index, %a: index, %b: index, %c: index) {
+  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
+    %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
+    "test.some_use"(%1) : (index) -> ()
+    scf.yield %arg : index
+  }
+  %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+  "test.some_use"(%2) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_index_result(
+// CHECK-SAME: %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+// CHECK: "test.some_use"(%[[i]])
+// CHECK: "test.some_use"(%[[i]])
+func.func @scf_for_index_result(%i: index, %a: index, %b: index, %c: index) {
+  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index {
+    %add = arith.addi %arg, %a : index
+    %sub = arith.subi %add, %a : index
+
+    %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index)
+    "test.some_use"(%1) : (index) -> ()
+    scf.yield %sub : index
+  }
+  %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
+  "test.some_use"(%2) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_tensor_result_small(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: "test.some_use"(%[[dim]])
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: "test.some_use"(%[[dim]])
+func.func @scf_for_tensor_result_small(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
+  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
+    %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+    "test.some_use"(%1) : (index) -> ()
+    scf.yield %arg : tensor<?xf32>
+  }
+  %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+  "test.some_use"(%2) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_tensor_result(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: "test.some_use"(%[[dim]])
+// CHECK: %[[dim:.*]] = tensor.dim %[[t]]
+// CHECK: "test.some_use"(%[[dim]])
+func.func @scf_for_tensor_result(%t: tensor<?xf32>, %a: index, %b: index, %c: index) {
+  %cst = arith.constant 5.0 : f32
+  %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor<?xf32> {
+    %filled = linalg.fill ins(%cst : f32) outs(%arg : tensor<?xf32>) -> tensor<?xf32>
+    %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+    "test.some_use"(%1) : (index) -> ()
+    scf.yield %filled : tensor<?xf32>
+  }
+  %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor<?xf32>) -> (index)
+  "test.some_use"(%2) : (index) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -25,5 +25,7 @@
   MLIRIR
   MLIRPass
   MLIRSupport
+  MLIRMemRefDialect
+  MLIRTensorDialect
   MLIRVectorUtils
   )
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -9,6 +9,8 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Pass/Pass.h"
@@ -33,7 +35,8 @@
   TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect>();
+    registry
+        .insert<AffineDialect, tensor::TensorDialect, memref::MemRefDialect>();
   }
 
   void runOnOperation() override;
@@ -99,13 +102,14 @@
 
       // Prepare stop condition. By default, reify in terms of the op's
       // operands. No stop condition is used when a constant was requested.
-      std::function<bool(Value)> stopCondition = [&](Value v) {
-        // Reify in terms of SSA values that are different from `value`.
-        return v != value;
-      };
+      std::function<bool(Value, std::optional<int64_t>)> stopCondition =
+          [&](Value v, std::optional<int64_t> d) {
+            // Reify in terms of SSA values that are different from `value`.
+            return v != value;
+          };
       if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
-        stopCondition = stopCondition = [](Value v) {
+        stopCondition = [](Value v, std::optional<int64_t> d) {
           auto bbArg = v.dyn_cast<BlockArgument>();
           if (!bbArg)
             return false;
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -540,6 +540,7 @@
         "//mlir:Pass",
         "//mlir:SCFDialect",
         "//mlir:Support",
+        "//mlir:TensorDialect",
         "//mlir:Transforms",
         "//mlir:ValueBoundsOpInterface",
         "//mlir:VectorDialect",