# Patch: always project the layout map of a rank-reduced subview type, and add
# unit tests for it. (Text before the first `diff --git` line is ignored by
# `git apply` and `patch`.)
#
# - mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp: the `if (!map.isIdentity())`
#   guard is removed. The inferred type's layout map is now always passed
#   through getProjectedMap(..., dimsToProject) before the rank-reduced
#   MemRefType is built. Behaviour changes only when that layout map is the
#   identity, because non-identity maps were already projected.
# - mlir/unittests/Dialect/CMakeLists.txt: adds the MemRef subdirectory.
# - mlir/unittests/Dialect/MemRef/CMakeLists.txt: new MLIRMemRefTests unit
#   test (InferShapeTest.cpp), linked PRIVATE against MLIRMemRefDialect.
# - mlir/unittests/Dialect/MemRef/InferShapeTest.cpp: three tests that call
#   SubViewOp::inferRankReducedResultType on a 10x5 memref of index.
#   Each call passes {2, 3}, then {1, 2} or {1, 1}, then {1, 1}; presumably
#   these are offsets, sizes and strides.
#     * identity source layout, {1, 2}, result rank 1
#       -> expected layout (d0) -> (d0 + 13)
#     * source layout (d0, d1) -> (1000 * d0 + d1), {1, 2}, result rank 1
#       -> expected layout (d0) -> (d0 + 2003)
#     * same source layout, {1, 1}, result rank 0
#       -> expected layout () -> (2003)
# - utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel: matching
#   `memref_tests` cc_test (srcs glob over Dialect/MemRef/*.cpp and *.h).
#
# NOTE(review): all three expected layouts carry a constant offset (+13 or
#   +2003). That suggests the layout map reaching getProjectedMap is already
#   non-identity in every test, so the old `isIdentity()` path is never
#   exercised. Confirm the tests fail without the MemRefOps.cpp hunk. If they
#   do not, add a case whose inferred subview layout is the identity map
#   (TODO: find the offsets/sizes/strides that produce one).
# NOTE(review): identity layout maps were previously never passed to
#   getProjectedMap here. Confirm it returns a layout whose dim/result counts
#   are valid for the reduced shape when given one.
# Nits: the third test (inferRankReducedShapeToScalar) has no leading comment,
#   unlike the other two. The unittests/Dialect/CMakeLists.txt hunk adds a
#   blank line after add_subdirectory(MemRef) that neighbouring entries do not
#   have.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2164,9 +2164,8 @@ if (!dimsToProject.test(pos)) projectedShape.push_back(shape[pos]); - AffineMap map = inferredType.getLayout().getAffineMap(); - if (!map.isIdentity()) - map = getProjectedMap(map, dimsToProject); + AffineMap map = + getProjectedMap(inferredType.getLayout().getAffineMap(), dimsToProject); inferredType = MemRefType::get(projectedShape, inferredType.getElementType(), map, inferredType.getMemorySpace()); diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -7,6 +7,8 @@ MLIRDialect) add_subdirectory(Affine) +add_subdirectory(MemRef) + add_subdirectory(Quant) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRMemRefTests + InferShapeTest.cpp +) +target_link_libraries(MLIRMemRefTests + PRIVATE + MLIRMemRefDialect + ) diff --git a/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Dialect/MemRef/InferShapeTest.cpp @@ -0,0 +1,60 @@ +//===- InferShapeTest.cpp - unit tests for shape inference ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::memref; + +// Source memref has identity layout. +TEST(InferShapeTest, inferRankReducedShapeIdentity) { + MLIRContext ctx; + OpBuilder b(&ctx); + auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType()); + auto reducedType = SubViewOp::inferRankReducedResultType( + /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1}); + AffineExpr dim0; + bindDims(&ctx, dim0); + auto expectedType = + MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 13)); + EXPECT_EQ(reducedType, expectedType); +} + +// Source memref has non-identity layout. +TEST(InferShapeTest, inferRankReducedShapeNonIdentity) { + MLIRContext ctx; + OpBuilder b(&ctx); + AffineExpr dim0, dim1; + bindDims(&ctx, dim0, dim1); + auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(), + AffineMap::get(2, 0, 1000 * dim0 + dim1)); + auto reducedType = SubViewOp::inferRankReducedResultType( + /*resultRank=*/1, sourceMemref, {2, 3}, {1, 2}, {1, 1}); + auto expectedType = + MemRefType::get({2}, b.getIndexType(), AffineMap::get(1, 0, dim0 + 2003)); + EXPECT_EQ(reducedType, expectedType); +} + +TEST(InferShapeTest, inferRankReducedShapeToScalar) { + MLIRContext ctx; + OpBuilder b(&ctx); + AffineExpr dim0, dim1; + bindDims(&ctx, dim0, dim1); + auto sourceMemref = MemRefType::get({10, 5}, b.getIndexType(), + AffineMap::get(2, 0, 1000 * dim0 + dim1)); + auto reducedType = SubViewOp::inferRankReducedResultType( + /*resultRank=*/0, sourceMemref, {2, 3}, {1, 1}, {1, 1}); + auto expectedType = + MemRefType::get({}, b.getIndexType(), + AffineMap::get(0, 0, b.getAffineConstantExpr(2003))); + EXPECT_EQ(reducedType, expectedType); +} diff --git 
a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel @@ -105,6 +105,20 @@ ], ) +cc_test( + name = "memref_tests", + size = "small", + srcs = glob([ + "Dialect/MemRef/*.cpp", + "Dialect/MemRef/*.h", + ]), + deps = [ + "//llvm:TestingSupport", + "//llvm:gtest_main", + "//mlir:MemRefDialect", + ], +) + cc_test( name = "quantops_tests", size = "small",