diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -673,26 +673,23 @@ SmallVectorImpl &strides, AffineExpr &offset) { auto affineMaps = t.getAffineMaps(); - // For now strides are only computed on a single affine map with a single - // result (i.e. the closed subset of linearization maps that are compatible - // with striding semantics). - // TODO: support more forms on a per-need basis. - if (affineMaps.size() > 1) - return failure(); - if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1) - return failure(); - auto zero = getAffineConstantExpr(0, t.getContext()); - auto one = getAffineConstantExpr(1, t.getContext()); - offset = zero; - strides.assign(t.getRank(), zero); + if (!affineMaps.empty() && affineMaps.back().getNumResults() != 1) + return failure(); AffineMap m; if (!affineMaps.empty()) { - m = affineMaps.front(); + m = affineMaps.back(); + for (size_t i = affineMaps.size() - 1; i > 0; --i) + m = m.compose(affineMaps[i - 1]); assert(!m.isIdentity() && "unexpected identity map"); } + auto zero = getAffineConstantExpr(0, t.getContext()); + auto one = getAffineConstantExpr(1, t.getContext()); + offset = zero; + strides.assign(t.getRank(), zero); + // Canonical case for empty map. if (!m) { // 0-D corner case, offset is already 0. 
diff --git a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir --- a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir +++ b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir @@ -60,7 +60,8 @@ // CHECK: MemRefType offset: 123 strides: %100 = memref.alloc(%0, %0)[%0, %0] : memref(i + j, j, k)>, affine_map<(i, j, k)[M, N]->(M * i + N * j + k + 1)>> -// CHECK: MemRefType memref (d0 + d1, d1, d2)>, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2 + 1)>> cannot be converted to strided form + // CHECK: MemRefType offset: 1 strides: ?, ?, 1 + %101 = memref.alloc() : memref<3x4x5xf32, affine_map<(i, j, k)->(i floordiv 4 + j + k)>> // CHECK: MemRefType memref<3x4x5xf32, affine_map<(d0, d1, d2) -> (d0 floordiv 4 + d1 + d2)>> cannot be converted to strided form %102 = memref.alloc() : memref<3x4x5xf32, affine_map<(i, j, k)->(i ceildiv 4 + j + k)>> diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRIRTests AttributeTest.cpp DialectTest.cpp + MemRefTypeTest.cpp OperationSupportTest.cpp ShapedTypeTest.cpp ) diff --git a/mlir/unittests/IR/MemRefTypeTest.cpp b/mlir/unittests/IR/MemRefTypeTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/IR/MemRefTypeTest.cpp @@ -0,0 +1,50 @@ +//===- MemRefTypeTest.cpp - MemRefType unit tests -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace {
+// Checks strides and offset for single and composed memref layout maps.
+TEST(MemRefTypeTest, GetStridesAndOffset) {
+  MLIRContext context;
+
+  SmallVector<int64_t, 3> shape({2, 3, 4});
+  Type f32 = FloatType::getF32(&context);
+  // Case 1: a single strided layout, strides [12, 4, 1] and offset 5.
+  AffineMap map1 = makeStridedLinearLayoutMap({12, 4, 1}, 5, &context);
+  MemRefType type1 = MemRefType::get(shape, f32, {map1});
+  SmallVector<int64_t, 3> strides1;
+  int64_t offset1 = -1;
+  LogicalResult res1 = getStridesAndOffset(type1, strides1, offset1);
+  ASSERT_TRUE(res1.succeeded());
+  ASSERT_EQ(3u, strides1.size());
+  EXPECT_EQ(12, strides1[0]);
+  EXPECT_EQ(4, strides1[1]);
+  EXPECT_EQ(1, strides1[2]);
+  EXPECT_EQ(5, offset1);
+  // Case 2: permute to (d1, d2, d0), then linearize: 8*d1 + 2*d2 + d0.
+  AffineMap map2 = AffineMap::getPermutationMap({1, 2, 0}, &context);
+  AffineMap map3 = makeStridedLinearLayoutMap({8, 2, 1}, 0, &context);
+  MemRefType type2 = MemRefType::get(shape, f32, {map2, map3});
+  SmallVector<int64_t, 3> strides2;
+  int64_t offset2 = -1;
+  LogicalResult res2 = getStridesAndOffset(type2, strides2, offset2);
+  ASSERT_TRUE(res2.succeeded());
+  ASSERT_EQ(3u, strides2.size());
+  EXPECT_EQ(1, strides2[0]);
+  EXPECT_EQ(8, strides2[1]);
+  EXPECT_EQ(2, strides2[2]);
+  EXPECT_EQ(0, offset2);
+}
+
+} // end namespace