diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -31,6 +31,7 @@
 struct VectorTypeStorage;
 struct RankedTensorTypeStorage;
 struct UnrankedTensorTypeStorage;
+struct BaseMemRefTypeStorage;
 struct MemRefTypeStorage;
 struct UnrankedMemRefTypeStorage;
 struct ComplexTypeStorage;
@@ -451,6 +452,7 @@
 /// Base MemRef for Ranked and Unranked variants
 class BaseMemRefType : public ShapedType {
 public:
+  using ImplType = detail::BaseMemRefTypeStorage;
   using ShapedType::ShapedType;
 
   /// Return true if the specified element type is ok in a memref.
@@ -460,6 +462,9 @@
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
+
+  /// Returns the memory space in which data referred to by this memref resides.
+  unsigned getMemorySpace() const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -544,9 +549,6 @@
   /// map composition.
   ArrayRef<AffineMap> getAffineMaps() const;
 
-  /// Returns the memory space in which data referred to by this memref resides.
-  unsigned getMemorySpace() const;
-
   // TODO: merge these two special values in a single one used everywhere.
   // Unfortunately, uses of `-1` have crept deep into the codebase now and are
   // hard to track.
@@ -592,9 +594,6 @@
                                        unsigned memorySpace);
 
   ArrayRef<int64_t> getShape() const { return llvm::None; }
-
-  /// Returns the memory space in which data referred to by this memref resides.
-  unsigned getMemorySpace() const;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -359,6 +359,14 @@
   return checkTensorElementType(loc, elementType);
 }
 
+//===----------------------------------------------------------------------===//
+// BaseMemRefType
+//===----------------------------------------------------------------------===//
+
+unsigned BaseMemRefType::getMemorySpace() const {
+  return static_cast<ImplType *>(impl)->memorySpace;
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefType
 //===----------------------------------------------------------------------===//
@@ -449,8 +457,6 @@
   return getImpl()->getAffineMaps();
 }
 
-unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
-
 //===----------------------------------------------------------------------===//
 // UnrankedMemRefType
 //===----------------------------------------------------------------------===//
@@ -466,10 +472,6 @@
   return Base::getChecked(location, elementType, memorySpace);
 }
 
-unsigned UnrankedMemRefType::getMemorySpace() const {
-  return getImpl()->memorySpace;
-}
-
 LogicalResult
 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                  unsigned memorySpace) {
diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h
--- a/mlir/lib/IR/TypeDetail.h
+++ b/mlir/lib/IR/TypeDetail.h
@@ -207,13 +207,26 @@
   }
 };
 
-struct MemRefTypeStorage : public ShapedTypeStorage {
+/// Storage shared by ranked and unranked memref types (holds memory space).
+struct BaseMemRefTypeStorage : public ShapedTypeStorage {
+  BaseMemRefTypeStorage(Type elementType, unsigned memorySpace)
+      : ShapedTypeStorage(elementType), memorySpace(memorySpace) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = unsigned;
+  bool operator==(const KeyTy &key) const { return key == memorySpace; }
+
+  /// Memory space in which data referenced by memref resides.
+  const unsigned memorySpace;
+};
+
+struct MemRefTypeStorage : public BaseMemRefTypeStorage {
   MemRefTypeStorage(unsigned shapeSize, Type elementType,
                     const int64_t *shapeElements, const unsigned numAffineMaps,
                     AffineMap const *affineMapList, const unsigned memorySpace)
-      : ShapedTypeStorage(elementType), shapeElements(shapeElements),
-        shapeSize(shapeSize), numAffineMaps(numAffineMaps),
-        affineMapList(affineMapList), memorySpace(memorySpace) {}
+      : BaseMemRefTypeStorage(elementType, memorySpace),
+        shapeElements(shapeElements), shapeSize(shapeSize),
+        numAffineMaps(numAffineMaps), affineMapList(affineMapList) {}
 
   /// The hash key used for uniquing.
   // MemRefs are uniqued based on their shape, element type, affine map
@@ -257,16 +270,14 @@
   const unsigned numAffineMaps;
   /// List of affine maps in the memref's layout/index map composition.
   AffineMap const *affineMapList;
-  /// Memory space in which data referenced by memref resides.
-  const unsigned memorySpace;
 };
 
 /// Unranked MemRef is a MemRef with unknown rank.
 /// Only element type and memory space are known
-struct UnrankedMemRefTypeStorage : public ShapedTypeStorage {
+struct UnrankedMemRefTypeStorage : public BaseMemRefTypeStorage {
 
   UnrankedMemRefTypeStorage(Type elementTy, const unsigned memorySpace)
-      : ShapedTypeStorage(elementTy), memorySpace(memorySpace) {}
+      : BaseMemRefTypeStorage(elementTy, memorySpace) {}
 
   /// The hash key used for uniquing.
   using KeyTy = std::tuple<Type, unsigned>;
@@ -282,8 +293,6 @@
     return new (allocator.allocate<UnrankedMemRefTypeStorage>())
         UnrankedMemRefTypeStorage(std::get<0>(key), std::get<1>(key));
   }
-  /// Memory space in which data referenced by memref resides.
-  const unsigned memorySpace;
 };
 
 /// Complex Type Storage.