diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -84,7 +84,12 @@ /// The stop condition when traversing the backward slice of a shaped value/ /// index-type value. The traversal continues until the stop condition /// evaluates to "true" for a value. - using StopConditionFn = function_ref; + /// + /// The first parameter of the function is the shaped value/index-typed + /// value. The second parameter is the dimension in case of a shaped value + /// and `std::nullopt` in case of an index-typed value. + using StopConditionFn = + function_ref /*dim*/)>; /// Compute a bound for the given index-typed value or shape dimension size. /// The computed bound is stored in `resultMap`. The operands of the bound are @@ -92,16 +97,26 @@ /// or a shaped value and a dimension. /// /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound - /// is computed in terms of values for which `stopCondition` evaluates to - /// "true". To that end, the backward slice (reverse use-def chain) of the - /// given value is visited in a worklist-driven manner and the constraint set - /// is populated according to `ValueBoundsOpInterface` for each visited value. + /// is computed in terms of values/dimensions for which `stopCondition` + /// evaluates to "true". To that end, the backward slice (reverse use-def + /// chain) of the given value is visited in a worklist-driven manner and the + /// constraint set is populated according to `ValueBoundsOpInterface` for each + /// visited value. static LogicalResult computeBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, Value value, std::optional dim, StopConditionFn stopCondition); + /// Compute a bound in terms of the values/dimensions in `dependencies`.
The + /// computed bound consists of only constant terms and dependent values (or + /// dimension sizes thereof). + static LogicalResult computeBound(AffineMap &resultMap, + ValueDimList &mapOperands, + presburger::BoundType type, Value value, + std::optional dim, + ValueDimList dependencies); + /// Compute a constant bound for the given index-typed value or shape /// dimension size. /// diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -21,19 +21,19 @@ std::optional dim) { // We are trying to reify a bound for `value`. Construct a stop condition that - // evaluates to "true" for any SSA value expect for `value`. I.e., the bound - // will be computed in terms of any SSA values expect for `value`. The first + // evaluates to "true" for any SSA value except for `value`. I.e., the bound + // will be computed in terms of any SSA values except for `value`. The first // such values are operands of the owner of `value`. - auto stopCondition = [&](Value v) { + auto stopCondition = [&](Value v, std::optional d) { // Reify in terms of SSA values that are different from `value`. return v != value; }; return reifyValueBound(b, loc, type, value, dim, stopCondition); } -FailureOr -mlir::reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, - Value value, std::optional dim, - function_ref stopCondition) { +FailureOr mlir::reifyValueBound( + OpBuilder &b, Location loc, presburger::BoundType type, Value value, + std::optional dim, + function_ref)> stopCondition) { // Compute bound.
AffineMap boundMap; ValueDimList mapOperands; diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -12,6 +12,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" using namespace mlir; +using presburger::BoundType; namespace mlir { namespace scf { @@ -19,22 +20,99 @@ struct ForOpInterface : public ValueBoundsOpInterface::ExternalModel { + + /// Populate bounds of values/dimensions for iter_args/OpResults. + static void populateIterArgBounds(scf::ForOp forOp, Value value, + std::optional dim, + ValueBoundsConstraintSet &cstr) { + // `value` is an iter_arg or an OpResult. + int64_t iterArgIdx; + if (auto iterArg = value.dyn_cast()) { + iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars(); + } else { + iterArgIdx = value.cast().getResultNumber(); + } + + // An EQ constraint can be added if the yielded value (dimension size) + // equals the corresponding block argument (dimension size). + assert(forOp.getLoopBody().hasOneBlock() && + "multiple blocks not supported"); + Value yieldedValue = + cast(forOp.getLoopBody().front().getTerminator()) + .getOperand(iterArgIdx); + Value iterArg = forOp.getRegionIterArg(iterArgIdx); + Value initArg = forOp.getInitArgs()[iterArgIdx]; + + auto addEqBound = [&]() { + if (dim.has_value()) { + cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); + } else { + cstr.bound(value) == initArg; + } + }; + + if (yieldedValue == iterArg) { + addEqBound(); + return; + } + + // Compute EQ bound for yielded value. If `yieldedValue` itself satisfies + // the stop condition (e.g., it is a block argument of the loop body), the + // bound is `yieldedValue` itself. + AffineMap bound; + ValueDimList boundOperands; + LogicalResult status = ValueBoundsConstraintSet::computeBound( + bound, boundOperands, BoundType::EQ, yieldedValue, dim, + [&](Value v, std::optional d) { + // Stop when reaching a block argument of the loop body.
+ if (auto bbArg = v.dyn_cast()) + return bbArg.getOwner()->getParentOp() == forOp; + // Stop when reaching a value that is defined outside of the loop. It + // is impossible to reach an iter_arg from there. + Operation *op = v.getDefiningOp(); + return forOp.getLoopBody().findAncestorOpInRegion(*op) == nullptr; + }); + if (failed(status)) + return; + if (bound.getNumResults() != 1) + return; + + // Check if computed bound equals the corresponding iter_arg. + Value singleValue = nullptr; + std::optional singleDim = std::nullopt; + if (auto dimExpr = bound.getResult(0).dyn_cast()) { + int64_t idx = dimExpr.getPosition(); + singleValue = boundOperands[idx].first; + singleDim = boundOperands[idx].second; + } else if (auto symExpr = bound.getResult(0).dyn_cast()) { + int64_t idx = symExpr.getPosition() + bound.getNumDims(); + singleValue = boundOperands[idx].first; + singleDim = boundOperands[idx].second; + } + if (singleValue == iterArg && singleDim == dim) + addEqBound(); + } + void populateBoundsForIndexValue(Operation *op, Value value, ValueBoundsConstraintSet &cstr) const { auto forOp = cast(op); - // Only IV is supported at the moment. - if (value != forOp.getInductionVar()) + + if (value == forOp.getInductionVar()) { + // TODO: Take into account step size. + cstr.bound(value) >= forOp.getLowerBound(); + cstr.bound(value) < forOp.getUpperBound(); return; + } - // TODO: Take into account step size. - cstr.bound(value) >= forOp.getLowerBound(); - cstr.bound(value) < forOp.getUpperBound(); + // Handle iter_args and OpResults. + populateIterArgBounds(forOp, value, std::nullopt, cstr); } void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, ValueBoundsConstraintSet &cstr) const { - // iter_arg / return value not supported. - return; + auto forOp = cast(op); + // Handle iter_args and OpResults.
+ populateIterArgBounds(forOp, value, dim, cstr); } }; diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -164,7 +164,8 @@ } // Do not process any further if the stop condition is met. - if (stopCondition(value)) + auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim); + if (stopCondition(value, maybeDim)) continue; // Query `ValueBoundsOpInterface` for constraints. New items may be added to @@ -213,12 +214,12 @@ Value value, std::optional dim, StopConditionFn stopCondition) { #ifndef NDEBUG assertValidValueDim(value, dim); #endif // NDEBUG Builder b(value.getContext()); mapOperands.clear(); - if (stopCondition(value)) { + if (stopCondition(value, dim)) { // Special case: If the stop condition is satisfied for the input // value/dimension, directly return it. mapOperands.push_back(std::make_pair(value, dim)); @@ -239,7 +240,9 @@ // Do not project out `valueDim`. if (valueDim == p) return false; - return !stopCondition(p.first); + auto maybeDim = + p.second == kIndexValue ? std::nullopt : std::make_optional(p.second); + return !stopCondition(p.first, maybeDim); }); // Compute lower and upper bounds for `valueDim`.
@@ -338,6 +341,16 @@ return success(); } +LogicalResult ValueBoundsConstraintSet::computeBound( + AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, + Value value, std::optional dim, ValueDimList dependencies) { + return computeBound(resultMap, mapOperands, type, value, dim, + [&](Value v, std::optional d) { + return llvm::is_contained(dependencies, + std::make_pair(v, d)); + }); +} + FailureOr ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType type, Value value, std::optional dim, StopConditionFn stopCondition) { @@ -354,9 +367,10 @@ } else { // No stop condition specified: Keep adding constraints until a bound could // be computed. - cstr.processWorklist(/*stopCondition=*/[&](Value v) { - return cstr.cstr.getConstantBound64(type, pos).has_value(); - }); + cstr.processWorklist( + /*stopCondition=*/[&](Value v, std::optional d) { + return cstr.cstr.getConstantBound64(type, pos).has_value(); + }); } // Compute constant bound for `valueDim`. diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir --- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir @@ -12,3 +12,112 @@ } return } + +// ----- + +// CHECK-LABEL: func @scf_for_index_result_small( +// CHECK-SAME: %[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index +// CHECK: "test.some_use"(%[[i]]) +// CHECK: "test.some_use"(%[[i]]) +func.func @scf_for_index_result_small(%i: index, %a: index, %b: index, %c: index) { + %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index { + %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index) + "test.some_use"(%1) : (index) -> () + scf.yield %arg : index + } + %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index) + "test.some_use"(%2) : (index) -> () + return +} + +// ----- + +// CHECK-LABEL: func @scf_for_index_result( +// CHECK-SAME:
%[[i:.*]]: index, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index +// CHECK: "test.some_use"(%[[i]]) +// CHECK: "test.some_use"(%[[i]]) +func.func @scf_for_index_result(%i: index, %a: index, %b: index, %c: index) { + %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %i) -> index { + %add = arith.addi %arg, %a : index + %sub = arith.subi %add, %a : index + + %1 = "test.reify_bound"(%arg) {type = "EQ"} : (index) -> (index) + "test.some_use"(%1) : (index) -> () + scf.yield %sub : index + } + %2 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index) + "test.some_use"(%2) : (index) -> () + return +} + +// ----- + +// CHECK-LABEL: func @scf_for_tensor_result_small( +// CHECK-SAME: %[[t:.*]]: tensor, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index +// CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// CHECK: "test.some_use"(%[[dim]]) +// CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// CHECK: "test.some_use"(%[[dim]]) +func.func @scf_for_tensor_result_small(%t: tensor, %a: index, %b: index, %c: index) { + %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor { + %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor) -> (index) + "test.some_use"(%1) : (index) -> () + scf.yield %arg : tensor + } + %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor) -> (index) + "test.some_use"(%2) : (index) -> () + return +} + +// ----- + +// CHECK-LABEL: func @scf_for_tensor_result( +// CHECK-SAME: %[[t:.*]]: tensor, %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index +// CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// CHECK: "test.some_use"(%[[dim]]) +// CHECK: %[[dim:.*]] = tensor.dim %[[t]] +// CHECK: "test.some_use"(%[[dim]]) +func.func @scf_for_tensor_result(%t: tensor, %a: index, %b: index, %c: index) { + %cst = arith.constant 5.0 : f32 + %0 = scf.for %iv = %a to %b step %c iter_args(%arg = %t) -> tensor { + %filled = linalg.fill ins(%cst : f32) outs(%arg : tensor) -> tensor + %1 = "test.reify_bound"(%arg) {type = "EQ", dim = 0} : (tensor) -> (index) +
"test.some_use"(%1) : (index) -> () + scf.yield %filled : tensor + } + %2 = "test.reify_bound"(%0) {type = "EQ", dim = 0} : (tensor) -> (index) + "test.some_use"(%2) : (index) -> () + return +} + +// ----- + +// The yielded values are swapped between the iter_args, so no EQ bound can +// be derived for the loop results. +func.func @scf_for_swapping_yield(%t1: tensor, %t2: tensor, %a: index, %b: index, %c: index) { + %cst = arith.constant 5.0 : f32 + %r1, %r2 = scf.for %iv = %a to %b step %c iter_args(%arg1 = %t1, %arg2 = %t2) -> (tensor, tensor) { + %filled1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor) -> tensor + %filled2 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor) -> tensor + scf.yield %filled2, %filled1 : tensor, tensor + } + // expected-error @below{{could not reify bound}} + %reify1 = "test.reify_bound"(%r1) {type = "EQ", dim = 0} : (tensor) -> (index) + "test.some_use"(%reify1) : (index) -> () + return +} + +// ----- + +// The yielded values are block arguments of the loop body, so the stop +// condition is already satisfied for the starting point of the bound +// computation. No EQ bound can be derived, but this must not crash. +func.func @scf_for_swapping_yield_bbargs(%i: index, %j: index, %a: index, %b: index, %c: index) { + %r1, %r2 = scf.for %iv = %a to %b step %c iter_args(%arg1 = %i, %arg2 = %j) -> (index, index) { + scf.yield %arg2, %arg1 : index, index + } + // expected-error @below{{could not reify bound}} + %reify1 = "test.reify_bound"(%r1) {type = "EQ"} : (index) -> (index) + "test.some_use"(%reify1) : (index) -> () + return +} diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt --- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt @@ -25,5 +25,7 @@ MLIRIR MLIRPass MLIRSupport + MLIRMemRefDialect + MLIRTensorDialect MLIRVectorUtils ) diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -9,6 +9,8 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Pass/Pass.h" @@ -33,7 +35,8 @@ TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){}; void getDependentDialects(DialectRegistry &registry) const override { - registry.insert(); + registry + .insert(); } void runOnOperation() override; @@ -101,13 +104,14 @@
// Prepare stop condition. By default, reify in terms of the op's // operands. No stop condition is used when a constant was requested. - std::function stopCondition = [&](Value v) { - // Reify in terms of SSA values that are different from `value`. - return v != value; - }; + std::function)> stopCondition = + [&](Value v, std::optional d) { + // Reify in terms of SSA values that are different from `value`. + return v != value; + }; if (reifyToFuncArgs) { // Reify in terms of function block arguments. - stopCondition = stopCondition = [](Value v) { + stopCondition = [](Value v, std::optional d) { auto bbArg = v.dyn_cast(); if (!bbArg) return false; diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -558,6 +558,7 @@ "//mlir:Pass", "//mlir:SCFDialect", "//mlir:Support", + "//mlir:TensorDialect", "//mlir:Transforms", "//mlir:ValueBoundsOpInterface", "//mlir:VectorDialect",