diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -88,6 +88,11 @@
   static constexpr int64_t kDynamicStrideOrOffset =
       std::numeric_limits<int64_t>::min();
 
+  /// Return clone of this type with new shape and element type.
+  ShapedType clone(ArrayRef<int64_t> shape, Type elementType);
+  ShapedType clone(ArrayRef<int64_t> shape);
+  ShapedType clone(Type elementType);
+
   /// Return the element type.
   Type getElementType() const;
 
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -197,6 +197,75 @@
 constexpr int64_t ShapedType::kDynamicSize;
 constexpr int64_t ShapedType::kDynamicStrideOrOffset;
 
+ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
+  if (auto other = dyn_cast<MemRefType>()) {
+    MemRefType::Builder b(other);
+    b.setShape(shape);
+    b.setElementType(elementType);
+    return b;
+  }
+
+  if (auto other = dyn_cast<UnrankedMemRefType>()) {
+    MemRefType::Builder b(shape, elementType);
+    b.setMemorySpace(other.getMemorySpace());
+    return b;
+  }
+
+  if (isa<TensorType>())
+    return RankedTensorType::get(shape, elementType);
+
+  if (isa<VectorType>())
+    return VectorType::get(shape, elementType);
+
+  llvm_unreachable("Unhandled ShapedType clone case");
+}
+
+ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
+  if (auto other = dyn_cast<MemRefType>()) {
+    MemRefType::Builder b(other);
+    b.setShape(shape);
+    return b;
+  }
+
+  if (auto other = dyn_cast<UnrankedMemRefType>()) {
+    MemRefType::Builder b(shape, other.getElementType());
+    b.setShape(shape);
+    b.setMemorySpace(other.getMemorySpace());
+    return b;
+  }
+
+  if (isa<TensorType>())
+    return RankedTensorType::get(shape, getElementType());
+
+  if (isa<VectorType>())
+    return VectorType::get(shape, getElementType());
+
+  llvm_unreachable("Unhandled ShapedType clone case");
+}
+
+ShapedType ShapedType::clone(Type elementType) {
+  if (auto other = dyn_cast<MemRefType>()) {
+    MemRefType::Builder b(other);
+    b.setElementType(elementType);
+    return b;
+  }
+
+  if (auto other = dyn_cast<UnrankedMemRefType>()) {
+    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
+  }
+
+  if (isa<TensorType>()) {
+    if (hasRank())
+      return RankedTensorType::get(getShape(), elementType);
+    return UnrankedTensorType::get(elementType);
+  }
+
+  if (isa<VectorType>()) {
+    return VectorType::get(getShape(), elementType);
+  }
+  llvm_unreachable("Unhandled ShapedType clone hit");
+}
+
 Type ShapedType::getElementType() const {
   return static_cast<ShapedTypeStorage *>(impl)->elementType;
 }
 
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -2,6 +2,7 @@
   AttributeTest.cpp
   DialectTest.cpp
   OperationSupportTest.cpp
+  ShapedTypeTest.cpp
 )
 target_link_libraries(MLIRIRTests
   PRIVATE
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -0,0 +1,129 @@
+//===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectInterface.h"
+#include "llvm/ADT/SmallVector.h"
+#include "gtest/gtest.h"
+#include <cstdint>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace {
+TEST(ShapedTypeTest, CloneMemref) {
+  MLIRContext context;
+
+  Type i32 = IntegerType::get(&context, 32);
+  Type f32 = FloatType::getF32(&context);
+  int memSpace = 7;
+  Type memrefOriginalType = i32;
+  llvm::SmallVector<int64_t, 2> memrefOriginalShape({10, 20});
+  AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
+
+  ShapedType memrefType =
+      MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
+          .setMemorySpace(memSpace)
+          .setAffineMaps(map);
+  // Update shape.
+  llvm::SmallVector<int64_t, 2> memrefNewShape({30, 40});
+  ASSERT_NE(memrefOriginalShape, memrefNewShape);
+  ASSERT_EQ(memrefType.clone(memrefNewShape),
+            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
+                .setMemorySpace(memSpace)
+                .setAffineMaps(map));
+  // Update type.
+  Type memrefNewType = f32;
+  ASSERT_NE(memrefOriginalType, memrefNewType);
+  ASSERT_EQ(memrefType.clone(memrefNewType),
+            (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
+                .setMemorySpace(memSpace)
+                .setAffineMaps(map));
+  // Update both.
+  ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
+            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
+                .setMemorySpace(memSpace)
+                .setAffineMaps(map));
+
+  // Test unranked memref cloning.
+  ShapedType unrankedMemrefType =
+      UnrankedMemRefType::get(memrefOriginalType, memSpace);
+  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape),
+            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
+                .setMemorySpace(memSpace));
+  ASSERT_EQ(unrankedMemrefType.clone(memrefNewType),
+            UnrankedMemRefType::get(memrefNewType, memSpace));
+  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape, memrefNewType),
+            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
+                .setMemorySpace(memSpace));
+}
+
+TEST(ShapedTypeTest, CloneTensor) {
+  MLIRContext context;
+
+  Type i32 = IntegerType::get(&context, 32);
+  Type f32 = FloatType::getF32(&context);
+
+  Type tensorOriginalType = i32;
+  llvm::SmallVector<int64_t, 2> tensorOriginalShape({10, 20});
+
+  // Test ranked tensor cloning.
+  ShapedType tensorType =
+      RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
+  // Update shape.
+  llvm::SmallVector<int64_t, 2> tensorNewShape({30, 40});
+  ASSERT_NE(tensorOriginalShape, tensorNewShape);
+  ASSERT_EQ(tensorType.clone(tensorNewShape),
+            RankedTensorType::get(tensorNewShape, tensorOriginalType));
+  // Update type.
+  Type tensorNewType = f32;
+  ASSERT_NE(tensorOriginalType, tensorNewType);
+  ASSERT_EQ(tensorType.clone(tensorNewType),
+            RankedTensorType::get(tensorOriginalShape, tensorNewType));
+  // Update both.
+  ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
+            RankedTensorType::get(tensorNewShape, tensorNewType));
+
+  // Test unranked tensor cloning.
+  ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
+  ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
+            RankedTensorType::get(tensorNewShape, tensorOriginalType));
+  ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
+            UnrankedTensorType::get(tensorNewType));
+  ASSERT_EQ(unrankedTensorType.clone(tensorNewShape, tensorNewType),
+            RankedTensorType::get(tensorNewShape, tensorNewType));
+}
+
+TEST(ShapedTypeTest, CloneVector) {
+  MLIRContext context;
+
+  Type i32 = IntegerType::get(&context, 32);
+  Type f32 = FloatType::getF32(&context);
+
+  Type vectorOriginalType = i32;
+  llvm::SmallVector<int64_t, 2> vectorOriginalShape({10, 20});
+  ShapedType vectorType =
+      VectorType::get(vectorOriginalShape, vectorOriginalType);
+  // Update shape.
+  llvm::SmallVector<int64_t, 2> vectorNewShape({30, 40});
+  ASSERT_NE(vectorOriginalShape, vectorNewShape);
+  ASSERT_EQ(vectorType.clone(vectorNewShape),
+            VectorType::get(vectorNewShape, vectorOriginalType));
+  // Update type.
+  Type vectorNewType = f32;
+  ASSERT_NE(vectorOriginalType, vectorNewType);
+  ASSERT_EQ(vectorType.clone(vectorNewType),
+            VectorType::get(vectorOriginalShape, vectorNewType));
+  // Update both.
+  ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
+            VectorType::get(vectorNewShape, vectorNewType));
+}
+
+} // end namespace