diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -319,7 +319,7 @@ //===----------------------------------------------------------------------===// /// Returns the strides of the MemRef if the layout map is in strided form. -/// MemRefs with layout maps in strided form include: +/// MemRefs with a layout map in strided form include: /// 1. empty or identity layout map, in which case the stride information is /// the canonical form computed from sizes; /// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`, @@ -328,14 +328,13 @@ /// A stride specification is a list of integer values that are either static /// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the /// distance in the number of elements between successive entries along a -/// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>` -/// specifies a view into a non-contiguous memory region of `42` by `16` `f32` -/// elements in which the distance between two consecutive elements along the -/// outer dimension is `1` and the distance between two consecutive elements -/// along the inner dimension is `64`. +/// particular dimension. /// -/// Returns whether a simple strided form can be extracted from the composition -/// of the layout map. +/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a +/// non-contiguous memory region of `42` by `16` `f32` elements in which the +/// distance between two consecutive elements along the outer dimension is `64` +/// and the distance between two consecutive elements along the inner dimension +/// is `1`. /// /// The convention is that the strides for dimensions d0, .. dn appear in /// order to make indexing intuitive into the result. 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -509,14 +509,16 @@ ##### Strided MemRef - A memref may specify strides as part of its type. A stride specification is - a list of integer values that are either static or `?` (dynamic case). + A memref may specify a strided layout as part of its type. A stride + specification is a list of integer values that are either static or `?` + (dynamic case). Strides encode the distance, in number of elements, in (linear) memory between successive entries along a particular dimension. A stride specification is syntactic sugar for an equivalent strided memref - representation using semi-affine maps. For example, - `memref<42x16xf32, offset: 33, strides: [1, 64]>` specifies a non-contiguous - memory region of `42` by `16` `f32` elements such that: + representation with a *single* semi-affine map. + + For example, `memref<42x16xf32, offset: 33, strides: [1, 64]>` specifies a + non-contiguous memory region of `42` by `16` `f32` elements such that: 1. the minimal size of the enclosing memory region must be `33 + 42 * 1 + 16 * 64 = 1066` elements; diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -773,16 +773,13 @@ AffineExpr &offset) { auto affineMaps = t.getAffineMaps(); + if (affineMaps.size() > 1) + return failure(); + if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1) return failure(); - AffineMap m; - if (!affineMaps.empty()) { - m = affineMaps.back(); - for (size_t i = affineMaps.size() - 1; i > 0; --i) - m = m.compose(affineMaps[i - 1]); - assert(!m.isIdentity() && "unexpected identity map"); - } + AffineMap m = affineMaps.empty() ? 
AffineMap() : affineMaps.back(); auto zero = getAffineConstantExpr(0, t.getContext()); auto one = getAffineConstantExpr(1, t.getContext()); @@ -790,7 +787,7 @@ strides.assign(t.getRank(), zero); // Canonical case for empty map. - if (!m) { + if (!m || m.isIdentity()) { // 0-D corner case, offset is already 0. if (t.getRank() == 0) return success(); diff --git a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir --- a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir +++ b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir @@ -59,9 +59,6 @@ %30 = memref.alloc()[%0] : memref(123)>> // CHECK: MemRefType offset: 123 strides: - %100 = memref.alloc(%0, %0)[%0, %0] : memref(i + j, j, k)>, affine_map<(i, j, k)[M, N]->(M * i + N * j + k + 1)>> - // CHECK: MemRefType offset: 1 strides: ?, ?, 1 - %101 = memref.alloc() : memref<3x4x5xf32, affine_map<(i, j, k)->(i floordiv 4 + j + k)>> // CHECK: MemRefType memref<3x4x5xf32, affine_map<(d0, d1, d2) -> (d0 floordiv 4 + d1 + d2)>> cannot be converted to strided form %102 = memref.alloc() : memref<3x4x5xf32, affine_map<(i, j, k)->(i ceildiv 4 + j + k)>> diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -2,7 +2,6 @@ AttributeTest.cpp DialectTest.cpp InterfaceAttachmentTest.cpp - MemRefTypeTest.cpp OperationSupportTest.cpp ShapedTypeTest.cpp SubElementInterfaceTest.cpp diff --git a/mlir/unittests/IR/MemRefTypeTest.cpp b/mlir/unittests/IR/MemRefTypeTest.cpp deleted file mode 100644 --- a/mlir/unittests/IR/MemRefTypeTest.cpp +++ /dev/null @@ -1,50 +0,0 @@ -//===- MemRefTypeTest.cpp - MemRefType unit tests -------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinTypes.h" -#include "gtest/gtest.h" - -using namespace mlir; -using namespace mlir::detail; - -namespace { - -TEST(MemRefTypeTest, GetStridesAndOffset) { - MLIRContext context; - - SmallVector shape({2, 3, 4}); - Type f32 = FloatType::getF32(&context); - - AffineMap map1 = makeStridedLinearLayoutMap({12, 4, 1}, 5, &context); - MemRefType type1 = MemRefType::get(shape, f32, {map1}); - SmallVector strides1; - int64_t offset1 = -1; - LogicalResult res1 = getStridesAndOffset(type1, strides1, offset1); - ASSERT_TRUE(res1.succeeded()); - ASSERT_EQ(3u, strides1.size()); - EXPECT_EQ(12, strides1[0]); - EXPECT_EQ(4, strides1[1]); - EXPECT_EQ(1, strides1[2]); - ASSERT_EQ(5, offset1); - - AffineMap map2 = AffineMap::getPermutationMap({1, 2, 0}, &context); - AffineMap map3 = makeStridedLinearLayoutMap({8, 2, 1}, 0, &context); - MemRefType type2 = MemRefType::get(shape, f32, {map2, map3}); - SmallVector strides2; - int64_t offset2 = -1; - LogicalResult res2 = getStridesAndOffset(type2, strides2, offset2); - ASSERT_TRUE(res2.succeeded()); - ASSERT_EQ(3u, strides2.size()); - EXPECT_EQ(1, strides2[0]); - EXPECT_EQ(8, strides2[1]); - EXPECT_EQ(2, strides2[2]); - ASSERT_EQ(0, offset2); -} - -} // end namespace