diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -456,126 +456,76 @@
     auto sym = getAffineSymbolExpr(nSymbols++, context);
     expr = expr ? expr + d * sym : d * sym;
   }
-  return expr;
-}
-
-// Factored out common logic to update `strides` and `seen` for `dim` with value
-// `val`. This handles both saturated and unsaturated cases.
-static void accumulateStrides(MutableArrayRef<int64_t> strides,
-                              MutableArrayRef<bool> seen, unsigned pos,
-                              int64_t val) {
-  if (!seen[pos]) {
-    // Newly seen case, sets value
-    strides[pos] = val;
-    seen[pos] = true;
-    return;
-  }
-  if (strides[pos] != MemRefType::getDynamicStrideOrOffset())
-    // Already seen case accumulates unless they are already saturated.
-    strides[pos] += val;
-}
-
-// This sums multiple offsets as they are seen. In the particular case of
-// accumulating a dynamic offset with either a static of dynamic one, this
-// saturates to MemRefType::getDynamicStrideOrOffset().
-static void accumulateOffset(int64_t &offset, bool &seenOffset, int64_t val) {
-  if (!seenOffset) {
-    // Newly seen case, sets value
-    offset = val;
-    seenOffset = true;
-    return;
-  }
-  if (offset != MemRefType::getDynamicStrideOrOffset())
-    // Already seen case accumulates unless they are already saturated.
-    offset += val;
+  // `expr` is null if no term was emitted (e.g. 0-D); don't simplify it.
+  return expr ? simplifyAffineExpr(expr, rank, nSymbols) : expr;
 }
 
-/// Takes a single AffineExpr `e` and populates the `strides` and `seen` arrays
-/// with the strides values for each dim position and whether a value exists at
-/// that position, respectively.
+// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
+// i.e. single term). Accumulate the AffineExpr into the existing one.
+static void extractStridesFromTerm(AffineExpr e,
+                                   AffineExpr multiplicativeFactor,
+                                   MutableArrayRef<AffineExpr> strides,
+                                   AffineExpr &offset) {
+  if (auto dim = e.dyn_cast<AffineDimExpr>())
+    strides[dim.getPosition()] =
+        strides[dim.getPosition()] + multiplicativeFactor;
+  else
+    offset = offset + e * multiplicativeFactor;
+}
+
+/// Takes a single AffineExpr `e` and populates the `strides` array with the
+/// strides expressions for each dim position.
 /// The convention is that the strides for dimensions d0, .. dn appear in
 /// order to make indexing intuitive into the result.
-static void extractStrides(AffineExpr e, MutableArrayRef<int64_t> strides,
-                           int64_t &offset, MutableArrayRef<bool> seen,
-                           bool &seenOffset, bool &failed) {
+static LogicalResult extractStrides(AffineExpr e,
+                                    AffineExpr multiplicativeFactor,
+                                    MutableArrayRef<AffineExpr> strides,
+                                    AffineExpr &offset) {
   auto bin = e.dyn_cast<AffineBinaryOpExpr>();
-  if (!bin)
-    return;
+  if (!bin) {
+    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
+    return success();
+  }
 
   if (bin.getKind() == AffineExprKind::CeilDiv ||
       bin.getKind() == AffineExprKind::FloorDiv ||
-      bin.getKind() == AffineExprKind::Mod) {
-    failed = true;
-    return;
-  }
+      bin.getKind() == AffineExprKind::Mod)
+    return failure();
+
   if (bin.getKind() == AffineExprKind::Mul) {
-    // LHS may be more complex than just a single dim (e.g. multiple syms and
-    // dims). Bail out for now and revisit when we have evidence this is needed.
     auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
-    if (!dim) {
-      failed = true;
-      return;
-    }
-    auto cst = bin.getRHS().dyn_cast<AffineConstantExpr>();
-    if (!cst) {
-      strides[dim.getPosition()] = MemRefType::getDynamicStrideOrOffset();
-      seen[dim.getPosition()] = true;
-    } else {
-      accumulateStrides(strides, seen, dim.getPosition(), cst.getValue());
+    if (dim) {
+      strides[dim.getPosition()] =
+          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
+      return success();
     }
-    return;
+    // LHS and RHS may both be compound expressions. In an affine product at
+    // most one side depends on dims, so recurse into that side and fold the
+    // symbolic/constant side into the multiplicative factor. The recursion
+    // may still fail (e.g. on a nested mod or div).
+    if (bin.getLHS().isSymbolicOrConstant())
+      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
+                            strides, offset);
+    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
+                          strides, offset);
   }
+
   if (bin.getKind() == AffineExprKind::Add) {
-    for (auto e : {bin.getLHS(), bin.getRHS()}) {
-      if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
-        // Independent constants cumulate.
-        accumulateOffset(offset, seenOffset, cst.getValue());
-      } else if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
-        // Independent symbols saturate.
-        offset = MemRefType::getDynamicStrideOrOffset();
-        seenOffset = true;
-      } else if (auto dim = e.dyn_cast<AffineDimExpr>()) {
-        // Independent symbols cumulate 1.
-        accumulateStrides(strides, seen, dim.getPosition(), 1);
-      }
-      // Sum of binary ops dispatch to the respective exprs.
-    }
-    return;
+    auto res1 =
+        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
+    auto res2 =
+        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
+    return success(succeeded(res1) && succeeded(res2));
   }
-  llvm_unreachable("unexpected binary operation");
-}
 
-// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
-// i.e. single term).
-static void extractStridesFromTerm(AffineExpr e,
-                                   MutableArrayRef<int64_t> strides,
-                                   int64_t &offset, MutableArrayRef<bool> seen,
-                                   bool &seenOffset) {
-  if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
-    assert(!seenOffset && "unexpected `seen` bit with single term");
-    offset = cst.getValue();
-    seenOffset = true;
-    return;
-  }
-  if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
-    assert(!seenOffset && "unexpected `seen` bit with single term");
-    offset = MemRefType::getDynamicStrideOrOffset();
-    seenOffset = true;
-    return;
-  }
-  if (auto dim = e.dyn_cast<AffineDimExpr>()) {
-    assert(!seen[dim.getPosition()] &&
-           "unexpected `seen` bit with single term");
-    strides[dim.getPosition()] = 1;
-    seen[dim.getPosition()] = true;
-    return;
-  }
   llvm_unreachable("unexpected binary operation");
 }
 
-LogicalResult mlir::getStridesAndOffset(MemRefType t,
-                                        SmallVectorImpl<int64_t> &strides,
-                                        int64_t &offset) {
+/// Computes the strides and offset of `t` as AffineExprs. Fails when the
+/// layout is not a single-result affine map expressible in strided form.
+static LogicalResult getStridesAndOffset(MemRefType t,
+                                         SmallVectorImpl<AffineExpr> &strides,
+                                         AffineExpr &offset) {
   auto affineMaps = t.getAffineMaps();
   // For now strides are only computed on a single affine map with a single
   // result (i.e. the closed subset of linearization maps that are compatible
@@ -583,39 +533,63 @@
   // TODO(ntv): support more forms on a per-need basis.
   if (affineMaps.size() > 1)
     return failure();
-  AffineExpr stridedExpr;
-  if (affineMaps.empty() || affineMaps[0].isIdentity()) {
-    if (t.getRank() == 0) {
-      // Handle 0-D corner case.
-      offset = 0;
-      return success();
-    }
-    stridedExpr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
-  } else if (affineMaps[0].getNumResults() == 1) {
-    stridedExpr = affineMaps[0].getResult(0);
-  }
-  if (!stridedExpr)
+  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
     return failure();
 
-  bool failed = false;
-  strides = SmallVector<int64_t, 4>(t.getRank(), 0);
-  bool seenOffset = false;
-  SmallVector<bool, 4> seen(t.getRank(), false);
-  if (stridedExpr.isa<AffineBinaryOpExpr>()) {
-    stridedExpr.walk([&](AffineExpr e) {
-      if (!failed)
-        extractStrides(e, strides, offset, seen, seenOffset, failed);
-    });
-  } else {
-    extractStridesFromTerm(stridedExpr, strides, offset, seen, seenOffset);
+  auto zero = getAffineConstantExpr(0, t.getContext());
+  auto one = getAffineConstantExpr(1, t.getContext());
+  offset = zero;
+  strides.assign(t.getRank(), zero);
+
+  AffineMap m;
+  if (!affineMaps.empty()) {
+    m = affineMaps.front();
+    assert(!m.isIdentity() && "unexpected identity map");
   }
 
-  // Constant offset may not be present in `stridedExpr` which means it is
-  // implicitly 0.
-  if (!seenOffset)
-    offset = 0;
+  // Canonical case for empty map.
+  if (!m) {
+    // 0-D corner case, offset is already 0.
+    if (t.getRank() == 0)
+      return success();
+    auto stridedExpr =
+        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
+    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
+      return success();
+    // Should be unreachable for a canonical layout. Fail explicitly rather
+    // than fall through to the non-canonical path, which uses the null `m`.
+    assert(false && "unexpected failure: extract strides in canonical layout");
+    offset = AffineExpr();
+    strides.clear();
+    return failure();
+  }
+
+  // Non-canonical case requires more work.
+  auto stridedExpr =
+      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
+    offset = AffineExpr();
+    strides.clear();
+    return failure();
+  }
 
-  if (failed || !llvm::all_of(seen, [](bool b) { return b; })) {
+  // Simplify results to allow folding to constants and simple checks.
+  unsigned numDims = m.getNumDims();
+  unsigned numSymbols = m.getNumSymbols();
+  offset = simplifyAffineExpr(offset, numDims, numSymbols);
+  for (auto &stride : strides)
+    stride = simplifyAffineExpr(stride, numDims, numSymbols);
+
+  /// In practice, a strided memref must be internally non-aliasing. Test
+  /// against 0 as a proxy.
+  /// TODO(ntv) static cases can have more advanced checks.
+  /// TODO(ntv) dynamic cases would require a way to compare symbolic
+  /// expressions and would probably need an affine set context propagated
+  /// everywhere.
+  if (llvm::any_of(strides, [](AffineExpr e) {
+        return e == getAffineConstantExpr(0, e.getContext());
+      })) {
+    offset = AffineExpr();
     strides.clear();
     return failure();
   }
@@ -623,6 +597,32 @@
   return success();
 }
 
+/// Integer view of the AffineExpr-based overload: constant strides and offset
+/// are returned by value, symbolic ones as ShapedType::kDynamicStrideOrOffset.
+LogicalResult mlir::getStridesAndOffset(MemRefType t,
+                                        SmallVectorImpl<int64_t> &strides,
+                                        int64_t &offset) {
+  // `strides` is a pure output: overwrite it instead of appending to whatever
+  // the caller passed in.
+  strides.clear();
+  AffineExpr offsetExpr;
+  SmallVector<AffineExpr, 4> strideExprs;
+  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
+    return failure();
+  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
+    offset = cst.getValue();
+  else
+    offset = ShapedType::kDynamicStrideOrOffset;
+  strides.reserve(strideExprs.size());
+  for (auto e : strideExprs) {
+    if (auto c = e.dyn_cast<AffineConstantExpr>())
+      strides.push_back(c.getValue());
+    else
+      strides.push_back(ShapedType::kDynamicStrideOrOffset);
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 /// ComplexType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/AffineOps/memref-stride-calculation.mlir b/mlir/test/AffineOps/memref-stride-calculation.mlir
--- a/mlir/test/AffineOps/memref-stride-calculation.mlir
+++ b/mlir/test/AffineOps/memref-stride-calculation.mlir
@@ -67,5 +67,15 @@
 // CHECK: MemRefType memref<3x4x5xf32, (d0, d1, d2) -> (d0 ceildiv 4 + d1 + d2)> cannot be converted to strided form
   %103 = alloc() : memref<3x4x5xf32, (i, j, k)->(i mod 4 + j + k)>
 // CHECK: MemRefType memref<3x4x5xf32, (d0, d1, d2) -> (d0 mod 4 + d1 + d2)> cannot be converted to strided form
+
+  %200 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M * i + N * i + N * j + K * k - (M + N - 20)* i)>
+  // CHECK: MemRefType offset: 0 strides: 20, ?, ?
+  %201 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M * i + N * i + N * K * j + K * K * k - (M + N - 20) * (i + 1))>
+  // CHECK: MemRefType offset: ? strides: 20, ?, ?
+  %202 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M * (i + 1) + j + k - M)>
+  // CHECK: MemRefType offset: 0 strides: ?, 1, 1
+  %203 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M + M * (i + N * (j + K * k)))>
+  // CHECK: MemRefType offset: ? strides: ?, ?, ?
+
   return
 }