diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -456,38 +456,25 @@ auto sym = getAffineSymbolExpr(nSymbols++, context); expr = expr ? expr + d * sym : d * sym; } - return expr; -} - -// Factored out common logic to update `strides` and `seen` for `dim` with value -// `val`. This handles both saturated and unsaturated cases. -static void accumulateStrides(MutableArrayRef strides, - MutableArrayRef seen, unsigned pos, - int64_t val) { - if (!seen[pos]) { - // Newly seen case, sets value - strides[pos] = val; - seen[pos] = true; - return; - } - if (strides[pos] != MemRefType::getDynamicStrideOrOffset()) - // Already seen case accumulates unless they are already saturated. - strides[pos] += val; -} - -// This sums multiple offsets as they are seen. In the particular case of -// accumulating a dynamic offset with either a static of dynamic one, this -// saturates to MemRefType::getDynamicStrideOrOffset(). -static void accumulateOffset(int64_t &offset, bool &seenOffset, int64_t val) { - if (!seenOffset) { - // Newly seen case, sets value - offset = val; - seenOffset = true; + return simplifyAffineExpr(expr, rank, nSymbols); +} + +// Fallback cases for terminal dim/sym/cst that are not part of a binary op ( +// i.e. single term). +static void extractStridesFromTerm(AffineExpr e, + MutableArrayRef strides, + AffineExpr &offset, + MutableArrayRef seen, + bool &seenOffset) { + if (auto dim = e.dyn_cast()) { + assert(!seen[dim.getPosition()] && + "unexpected `seen` bit with single term"); + strides[dim.getPosition()] = strides[dim.getPosition()] + 1; + seen[dim.getPosition()] = true; return; } - if (offset != MemRefType::getDynamicStrideOrOffset()) - // Already seen case accumulates unless they are already saturated. - offset += val; + offset = offset + e; + seenOffset = true; } /// Takes a single AffineExpr `e` and populates the `strides` and `seen` arrays @@ -495,12 +482,14 @@ /// that position, respectively. /// The convention is that the strides for dimensions d0, .. dn appear in /// order to make indexing intuitive into the result. -static void extractStrides(AffineExpr e, MutableArrayRef strides, - int64_t &offset, MutableArrayRef seen, +static void extractStrides(AffineExpr e, MutableArrayRef strides, + AffineExpr &offset, MutableArrayRef seen, bool &seenOffset, bool &failed) { auto bin = e.dyn_cast(); - if (!bin) + if (!bin) { + extractStridesFromTerm(e, strides, offset, seen, seenOffset); return; + } if (bin.getKind() == AffineExprKind::CeilDiv || bin.getKind() == AffineExprKind::FloorDiv || @@ -516,66 +505,21 @@ failed = true; return; } - auto cst = bin.getRHS().dyn_cast(); - if (!cst) { - strides[dim.getPosition()] = MemRefType::getDynamicStrideOrOffset(); - seen[dim.getPosition()] = true; - } else { - accumulateStrides(strides, seen, dim.getPosition(), cst.getValue()); - } + strides[dim.getPosition()] = strides[dim.getPosition()] + bin.getRHS(); + seen[dim.getPosition()] = true; return; } if (bin.getKind() == AffineExprKind::Add) { - for (auto e : {bin.getLHS(), bin.getRHS()}) { - if (auto cst = e.dyn_cast()) { - // Independent constants cumulate. - accumulateOffset(offset, seenOffset, cst.getValue()); - } else if (auto sym = e.dyn_cast()) { - // Independent symbols saturate. - offset = MemRefType::getDynamicStrideOrOffset(); - seenOffset = true; - } else if (auto dim = e.dyn_cast()) { - // Independent symbols cumulate 1. - accumulateStrides(strides, seen, dim.getPosition(), 1); - } - // Sum of binary ops dispatch to the respective exprs. - } + extractStrides(bin.getLHS(), strides, offset, seen, seenOffset, failed); + extractStrides(bin.getRHS(), strides, offset, seen, seenOffset, failed); return; } llvm_unreachable("unexpected binary operation"); } -// Fallback cases for terminal dim/sym/cst that are not part of a binary op ( -// i.e. single term). -static void extractStridesFromTerm(AffineExpr e, - MutableArrayRef strides, - int64_t &offset, MutableArrayRef seen, - bool &seenOffset) { - if (auto cst = e.dyn_cast()) { - assert(!seenOffset && "unexpected `seen` bit with single term"); - offset = cst.getValue(); - seenOffset = true; - return; - } - if (auto sym = e.dyn_cast()) { - assert(!seenOffset && "unexpected `seen` bit with single term"); - offset = MemRefType::getDynamicStrideOrOffset(); - seenOffset = true; - return; - } - if (auto dim = e.dyn_cast()) { - assert(!seen[dim.getPosition()] && - "unexpected `seen` bit with single term"); - strides[dim.getPosition()] = 1; - seen[dim.getPosition()] = true; - return; - } - llvm_unreachable("unexpected binary operation"); -} - -LogicalResult mlir::getStridesAndOffset(MemRefType t, - SmallVectorImpl &strides, - int64_t &offset) { +static LogicalResult getStridesAndOffset(MemRefType t, + SmallVectorImpl &strides, + AffineExpr &offset) { auto affineMaps = t.getAffineMaps(); // For now strides are only computed on a single affine map with a single // result (i.e. the closed subset of linearization maps that are compatible @@ -583,37 +527,29 @@ // TODO(ntv): support more forms on a per-need basis. if (affineMaps.size() > 1) return failure(); + if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) + return failure(); + + auto zero = getAffineConstantExpr(0, t.getContext()); + offset = zero; + strides.assign(t.getRank(), zero); + AffineExpr stridedExpr; - if (affineMaps.empty() || affineMaps[0].isIdentity()) { - if (t.getRank() == 0) { - // Handle 0-D corner case. - offset = 0; + if (affineMaps.empty()) { + // 0-D corner case, offset is already 0. + if (t.getRank() == 0) return success(); - } stridedExpr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); - } else if (affineMaps[0].getNumResults() == 1) { - stridedExpr = affineMaps[0].getResult(0); + } else { + assert(affineMaps[0].getNumResults() == 1); + stridedExpr = affineMaps.front().getResult(0); } - if (!stridedExpr) - return failure(); + assert(stridedExpr); bool failed = false; - strides = SmallVector(t.getRank(), 0); bool seenOffset = false; SmallVector seen(t.getRank(), false); - if (stridedExpr.isa()) { - stridedExpr.walk([&](AffineExpr e) { - if (!failed) - extractStrides(e, strides, offset, seen, seenOffset, failed); - }); - } else { - extractStridesFromTerm(stridedExpr, strides, offset, seen, seenOffset); - } - - // Constant offset may not be present in `stridedExpr` which means it is - // implicitly 0. - if (!seenOffset) - offset = 0; + extractStrides(stridedExpr, strides, offset, seen, seenOffset, failed); if (failed || !llvm::all_of(seen, [](bool b) { return b; })) { strides.clear(); @@ -623,6 +559,25 @@ return success(); } +LogicalResult mlir::getStridesAndOffset(MemRefType t, + SmallVectorImpl &strides, + int64_t &offset) { + AffineExpr offsetExpr; + SmallVector strideExprs; + if (!succeeded(::getStridesAndOffset(t, strideExprs, offsetExpr))) + return failure(); + if (auto cst = offsetExpr.dyn_cast()) + offset = cst.getValue(); + else + offset = ShapedType::kDynamicStrideOrOffset; + for (auto e : strideExprs) { + if (auto c = e.dyn_cast()) + strides.push_back(c.getValue()); + else + strides.push_back(ShapedType::kDynamicStrideOrOffset); + } + return success(); +} //===----------------------------------------------------------------------===// /// ComplexType //===----------------------------------------------------------------------===//