diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -456,126 +456,59 @@
     auto sym = getAffineSymbolExpr(nSymbols++, context);
     expr = expr ? expr + d * sym : d * sym;
   }
-  return expr;
-}
-
-// Factored out common logic to update `strides` and `seen` for `dim` with value
-// `val`. This handles both saturated and unsaturated cases.
-static void accumulateStrides(MutableArrayRef<int64_t> strides,
-                              MutableArrayRef<bool> seen, unsigned pos,
-                              int64_t val) {
-  if (!seen[pos]) {
-    // Newly seen case, sets value
-    strides[pos] = val;
-    seen[pos] = true;
-    return;
-  }
-  if (strides[pos] != MemRefType::getDynamicStrideOrOffset())
-    // Already seen case accumulates unless they are already saturated.
-    strides[pos] += val;
-}
-
-// This sums multiple offsets as they are seen. In the particular case of
-// accumulating a dynamic offset with either a static of dynamic one, this
-// saturates to MemRefType::getDynamicStrideOrOffset().
-static void accumulateOffset(int64_t &offset, bool &seenOffset, int64_t val) {
-  if (!seenOffset) {
-    // Newly seen case, sets value
-    offset = val;
-    seenOffset = true;
-    return;
-  }
-  if (offset != MemRefType::getDynamicStrideOrOffset())
-    // Already seen case accumulates unless they are already saturated.
-    offset += val;
+  return simplifyAffineExpr(expr, rank, nSymbols);
+}
+
+// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
+// i.e. single term). Accumulate the AffineExpr into the existing one.
+static LogicalResult extractStridesFromTerm(AffineExpr e,
+                                            MutableArrayRef<AffineExpr> strides,
+                                            AffineExpr &offset) {
+  if (auto dim = e.dyn_cast<AffineDimExpr>())
+    strides[dim.getPosition()] = strides[dim.getPosition()] + 1;
+  else
+    offset = offset + e;
+  return success();
 }
 
-/// Takes a single AffineExpr `e` and populates the `strides` and `seen` arrays
-/// with the strides values for each dim position and whether a value exists at
-/// that position, respectively.
+/// Takes a single AffineExpr `e` and populates the `strides` array with the
+/// strides expressions for each dim position.
 /// The convention is that the strides for dimensions d0, .. dn appear in
 /// order to make indexing intuitive into the result.
-static void extractStrides(AffineExpr e, MutableArrayRef<int64_t> strides,
-                           int64_t &offset, MutableArrayRef<bool> seen,
-                           bool &seenOffset, bool &failed) {
+static LogicalResult extractStrides(AffineExpr e,
+                                    MutableArrayRef<AffineExpr> strides,
+                                    AffineExpr &offset) {
   auto bin = e.dyn_cast<AffineBinaryOpExpr>();
   if (!bin)
-    return;
+    return extractStridesFromTerm(e, strides, offset);
 
   if (bin.getKind() == AffineExprKind::CeilDiv ||
       bin.getKind() == AffineExprKind::FloorDiv ||
-      bin.getKind() == AffineExprKind::Mod) {
-    failed = true;
-    return;
-  }
+      bin.getKind() == AffineExprKind::Mod)
+    return failure();
+
   if (bin.getKind() == AffineExprKind::Mul) {
     // LHS may be more complex than just a single dim (e.g. multiple syms and
     // dims). Bail out for now and revisit when we have evidence this is needed.
     auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
-    if (!dim) {
-      failed = true;
-      return;
-    }
-    auto cst = bin.getRHS().dyn_cast<AffineConstantExpr>();
-    if (!cst) {
-      strides[dim.getPosition()] = MemRefType::getDynamicStrideOrOffset();
-      seen[dim.getPosition()] = true;
-    } else {
-      accumulateStrides(strides, seen, dim.getPosition(), cst.getValue());
-    }
-    return;
+    if (!dim)
+      return failure();
+    strides[dim.getPosition()] = strides[dim.getPosition()] + bin.getRHS();
+    return success();
   }
+
   if (bin.getKind() == AffineExprKind::Add) {
-    for (auto e : {bin.getLHS(), bin.getRHS()}) {
-      if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
-        // Independent constants cumulate.
-        accumulateOffset(offset, seenOffset, cst.getValue());
-      } else if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
-        // Independent symbols saturate.
-        offset = MemRefType::getDynamicStrideOrOffset();
-        seenOffset = true;
-      } else if (auto dim = e.dyn_cast<AffineDimExpr>()) {
-        // Independent symbols cumulate 1.
-        accumulateStrides(strides, seen, dim.getPosition(), 1);
-      }
-      // Sum of binary ops dispatch to the respective exprs.
-    }
-    return;
+    auto res1 = extractStrides(bin.getLHS(), strides, offset);
+    auto res2 = extractStrides(bin.getRHS(), strides, offset);
+    return success(succeeded(res1) && succeeded(res2));
   }
-  llvm_unreachable("unexpected binary operation");
-}
 
-// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
-// i.e. single term).
-static void extractStridesFromTerm(AffineExpr e,
-                                   MutableArrayRef<int64_t> strides,
-                                   int64_t &offset, MutableArrayRef<bool> seen,
-                                   bool &seenOffset) {
-  if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
-    assert(!seenOffset && "unexpected `seen` bit with single term");
-    offset = cst.getValue();
-    seenOffset = true;
-    return;
-  }
-  if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
-    assert(!seenOffset && "unexpected `seen` bit with single term");
-    offset = MemRefType::getDynamicStrideOrOffset();
-    seenOffset = true;
-    return;
-  }
-  if (auto dim = e.dyn_cast<AffineDimExpr>()) {
-    assert(!seen[dim.getPosition()] &&
-           "unexpected `seen` bit with single term");
-    strides[dim.getPosition()] = 1;
-    seen[dim.getPosition()] = true;
-    return;
-  }
   llvm_unreachable("unexpected binary operation");
 }
 
-LogicalResult mlir::getStridesAndOffset(MemRefType t,
-                                        SmallVectorImpl<int64_t> &strides,
-                                        int64_t &offset) {
+static LogicalResult getStridesAndOffset(MemRefType t,
+                                         SmallVectorImpl<AffineExpr> &strides,
+                                         AffineExpr &offset) {
   auto affineMaps = t.getAffineMaps();
   // For now strides are only computed on a single affine map with a single
   // result (i.e. the closed subset of linearization maps that are compatible
@@ -583,39 +516,58 @@
   // TODO(ntv): support more forms on a per-need basis.
   if (affineMaps.size() > 1)
     return failure();
-  AffineExpr stridedExpr;
-  if (affineMaps.empty() || affineMaps[0].isIdentity()) {
-    if (t.getRank() == 0) {
-      // Handle 0-D corner case.
-      offset = 0;
-      return success();
-    }
-    stridedExpr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
-  } else if (affineMaps[0].getNumResults() == 1) {
-    stridedExpr = affineMaps[0].getResult(0);
-  }
-  if (!stridedExpr)
+  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
     return failure();
-  bool failed = false;
-  strides = SmallVector<int64_t, 4>(t.getRank(), 0);
-  bool seenOffset = false;
-  SmallVector<bool, 4> seen(t.getRank(), false);
-  if (stridedExpr.isa<AffineBinaryOpExpr>()) {
-    stridedExpr.walk([&](AffineExpr e) {
-      if (!failed)
-        extractStrides(e, strides, offset, seen, seenOffset, failed);
-    });
-  } else {
-    extractStridesFromTerm(stridedExpr, strides, offset, seen, seenOffset);
+  auto zero = getAffineConstantExpr(0, t.getContext());
+  offset = zero;
+  strides.assign(t.getRank(), zero);
+
+  AffineMap m;
+  if (!affineMaps.empty())
+    m = affineMaps.front();
+
+  // Canonical case for empty/identity map.
+  if (!m || m.isIdentity()) {
+    // 0-D corner case, offset is already 0.
+    if (t.getRank() == 0)
+      return success();
+    auto stridedExpr =
+        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
+    if (succeeded(extractStrides(stridedExpr, strides, offset)))
+      return success();
+    assert(false && "unexpected failure: extract strides in canonical layout");
   }
-  // Constant offset may not be present in `stridedExpr` which means it is
-  // implicitly 0.
-  if (!seenOffset)
-    offset = 0;
+
+  // Non-canonical case requires more work.
+  auto stridedExpr =
+      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+  if (failed(extractStrides(stridedExpr, strides, offset))) {
+    offset = AffineExpr();
+    strides.clear();
+    return failure();
+  }
 
-  if (failed || !llvm::all_of(seen, [](bool b) { return b; })) {
+  if (m.isIdentity())
+    return success();
+
+  // Simplify results to allow folding to constants and simple checks.
+  unsigned numDims = m.getNumDims();
+  unsigned numSymbols = m.getNumSymbols();
+  offset = simplifyAffineExpr(offset, numDims, numSymbols);
+  for (auto &stride : strides)
+    stride = simplifyAffineExpr(stride, numDims, numSymbols);
+
+  /// In practice, a strided memref must be internally non-aliasing. Test
+  /// against 0 as a proxy.
+  /// TODO(ntv) static cases can have more advanced checks.
+  /// TODO(ntv) dynamic cases would require a way to compare symbolic
+  /// expressions and would probably need an affine set context propagated
+  /// everywhere.
+  if (llvm::any_of(strides, [](AffineExpr e) {
+        return e == getAffineConstantExpr(0, e.getContext());
+      })) {
+    offset = AffineExpr();
     strides.clear();
     return failure();
   }
@@ -623,6 +575,26 @@
   return success();
 }
 
+LogicalResult mlir::getStridesAndOffset(MemRefType t,
+                                        SmallVectorImpl<int64_t> &strides,
+                                        int64_t &offset) {
+  AffineExpr offsetExpr;
+  SmallVector<AffineExpr, 4> strideExprs;
+  if (!succeeded(::getStridesAndOffset(t, strideExprs, offsetExpr)))
+    return failure();
+  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
+    offset = cst.getValue();
+  else
+    offset = ShapedType::kDynamicStrideOrOffset;
+  for (auto e : strideExprs) {
+    if (auto c = e.dyn_cast<AffineConstantExpr>())
+      strides.push_back(c.getValue());
+    else
+      strides.push_back(ShapedType::kDynamicStrideOrOffset);
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 /// ComplexType
 //===----------------------------------------------------------------------===//