diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -60,6 +60,40 @@
 /// Otherwise return nullptr.
 IntegerAttr getSmallestBoundingIndex(Value size);
 
+/// Computes an upper bound for the result `value` of an index computation.
+/// Translates AffineMinOps and AffineApplyOps along the use-def chains of the
+/// index computation to affine constraints and projects out intermediate
+/// values. The method sets `boundMap` to an affine map that given
+/// `boundOperands` evaluates to an upper bound for the index computation.
+///
+/// Example:
+/// ```
+/// %dim0 = dim %tensor, %c0
+/// %dim1 = dim %tensor, %c1
+/// %0 = affine.min affine.map<(d0) -> (40, d0)> (%dim0)
+/// %1 = affine.apply affine.map<(d0, d1) -> (d0 + d1)> (%0, %dim1)
+/// ```
+/// getUpperBoundForIndex(%1, boundMap, boundOperands)
+/// sets the output parameters to:
+/// - boundMap = affine.map<(d0) -> (d0 + 40)>
+/// - boundOperands = [%dim1]
+void getUpperBoundForIndex(Value value, AffineMap &boundMap,
+                           SmallVectorImpl<Value> &boundOperands);
+
+/// Returns a constant upper bound for the result `value` of an index
+/// computation. Calls `getUpperBoundForIndex` and returns a constant upper
+/// bound if the result of `boundMap` is a constant expression and failure
+/// otherwise.
+///
+/// Example:
+/// ```
+/// %0 = affine.min affine.map<(d0) -> (40, d0)> (%d0)
+/// %1 = affine.apply affine.map<(d0) -> (d0 + 2)> (%0)
+/// ```
+/// getConstantUpperBoundForIndex(%1) returns 42
+/// (boundMap = affine.map<() -> (42)>)
+FailureOr<int64_t> getConstantUpperBoundForIndex(Value value);
+
 /// Create an ExtractSliceOp and, if `source` is defined by an ExtractSliceOp,
 /// fold it by adding the offsets.
 ///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -171,16 +171,20 @@
   staticSizes.reserve(opToPad.getRank(opOperand));
   auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
   for (auto size : shapedOp.getMixedSizes()) {
-    auto indexAttr = size.is<Attribute>()
-                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
-                         : linalg::getSmallestBoundingIndex(size.get<Value>());
-    // SmallestBoundingIndex must exist for all sizes.
-    // For now return an error if we can't find it.
-    if (!indexAttr) {
+    // If the size is an attribute add it directly to `staticSizes`.
+    if (size.is<Attribute>()) {
+      staticSizes.push_back(
+          size.get<Attribute>().cast<IntegerAttr>().getInt());
+      continue;
+    }
+    // Otherwise, try to compute a constant upper bound for the size value.
+    FailureOr<int64_t> upperBound =
+        getConstantUpperBoundForIndex(size.get<Value>());
+    if (failed(upperBound)) {
       LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
       return failure();
     }
-    staticSizes.push_back(indexAttr.getInt());
+    staticSizes.push_back(upperBound.getValue());
   }
   auto staticTensorType = RankedTensorType::get(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -12,7 +12,10 @@
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
@@ -209,6 +212,106 @@
   return nullptr;
 }
 
+void getUpperBoundForIndex(Value value, AffineMap &boundMap,
+                           SmallVectorImpl<Value> &boundOperands) {
+  // Initialize `boundMap` and `boundOperands` to the identity returning
+  // `value`. This combination is the default result of the method if no
+  // simplification is possible.
+  assert(value.getType().isIndex() && "expect value to have index type");
+  boundMap = AffineMap::getMultiDimIdentityMap(1, value.getContext());
+  boundOperands.assign({value});
+  canonicalizeMapAndOperands(&boundMap, &boundOperands);
+
+  // Continue only if there is an affine index computation to simplify.
+  Operation *definingOp = value.getDefiningOp();
+  if (!definingOp || !isa<AffineApplyOp, AffineMinOp>(definingOp))
+    return;
+
+  // Get the backward slice containing the affine index computation.
+  SetVector<Operation *> backwardSlice;
+  getBackwardSlice(definingOp, &backwardSlice, [](Operation *op) {
+    return isa<AffineApplyOp, AffineMinOp>(op);
+  });
+  backwardSlice.insert(definingOp);
+
+  // Setup a system of affine constraints that describe the index computation.
+  FlatAffineValueConstraints constraints;
+
+  // Helper to find or create an identifier for the given value.
+  auto findOrCreateId = [&](Value value) {
+    if (!constraints.containsId(value)) {
+      constraints.appendDimId(value);
+      return true;
+    }
+    unsigned pos;
+    constraints.findId(value, &pos);
+    return pos < constraints.getNumDimIds();
+  };
+  // Helper to get the position for the given value.
+  auto getPosition = [&](Value value) {
+    unsigned pos;
+    bool exists = constraints.findId(value, &pos);
+    (void)exists;
+    assert(exists && "expect to find the identifier");
+    return pos;
+  };
+
+  // Add the affine operations in `backwardSlice` to the constraints.
+  for (Operation *op : llvm::reverse(backwardSlice)) {
+    // Add an identifier for all op results and operands.
+    if (!(llvm::all_of(op->getResults(), findOrCreateId) &&
+          llvm::all_of(op->getOperands(), findOrCreateId)))
+      return;
+    // Add AffineApplyOps to the constraints.
+    if (auto applyOp = dyn_cast<AffineApplyOp>(op)) {
+      AffineValueMap valueMap(applyOp.getAffineMap(), applyOp.getOperands(),
+                              applyOp.getResult());
+      if (failed(constraints.composeMap(&valueMap)))
+        return;
+      continue;
+    }
+    // Add AffineMinOps to the constraints.
+    auto minOp = cast<AffineMinOp>(op);
+    AffineMap map = constraints.computeAlignedMap(minOp.getAffineMap(),
+                                                  minOp.getOperands());
+    if (failed(constraints.addBound(FlatAffineConstraints::UB,
+                                    getPosition(minOp.getResult()), map)))
+      return;
+  }
+
+  // Obtain an upper bound for the affine index computation by projecting out
+  // all temporary results and expressing the upper bound for `value` in terms
+  // of the terminals of the index computation.
+  SmallVector<AffineMap> lowerBounds(1), upperBounds(1);
+  constraints.getSliceBounds(getPosition(value), 1, value.getContext(),
+                             &lowerBounds, &upperBounds);
+
+  // Verify `upperBounds[0]` is valid and has at least one result.
+  if (!upperBounds[0] || upperBounds[0].getNumResults() == 0)
+    return;
+
+  // Set `boundMap` and `boundOperands` to the computed upper bound.
+  boundMap = upperBounds[0];
+  constraints.getAllValues(&boundOperands);
+  erase_value(boundOperands, value);
+  canonicalizeMapAndOperands(&boundMap, &boundOperands);
+}
+
+FailureOr<int64_t> getConstantUpperBoundForIndex(Value value) {
+  // Compute an upper bound for `value`.
+  AffineMap boundMap;
+  SmallVector<Value> boundOperands;
+  getUpperBoundForIndex(value, boundMap, boundOperands);
+
+  // Return a constant upper bound if `boundMap` has a single constant result
+  // dimension, otherwise return failure.
+  if (boundMap.getNumResults() != 1)
+    return failure();
+  if (auto constExpr = boundMap.getResult(0).dyn_cast<AffineConstantExpr>())
+    return constExpr.getValue();
+  return failure();
+}
+
 tensor::ExtractSliceOp makeComposedExtractSliceOp(
     OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {