diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -127,7 +127,7 @@ using Base::Base; /// Signedness semantics. - enum SignednessSemantics { + enum SignednessSemantics : uint32_t { Signless, /// No signedness semantics Signed, /// Signed integer Unsigned, /// Unsigned integer diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -79,16 +79,9 @@ return *abstractType; } - /// Get the subclass data. - unsigned getSubclassData() const { return subclassData; } - - /// Set the subclass data. - void setSubclassData(unsigned val) { subclassData = val; } - protected: /// This constructor is used by derived classes as part of the TypeUniquer. - TypeStorage(unsigned subclassData = 0) - : abstractType(nullptr), subclassData(subclassData) {} + TypeStorage() : abstractType(nullptr) {} private: /// Set the abstract type for this storage instance. This is used by the @@ -99,9 +92,6 @@ /// The abstract description for this type. const AbstractType *abstractType; - - /// Space for subclasses to store data. - unsigned subclassData; }; /// Default storage type for types that require no additional initialization or diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -191,9 +191,6 @@ friend ::llvm::hash_code hash_value(Type arg); - unsigned getSubclassData() const; - void setSubclassData(unsigned val); - /// Methods for supporting PointerLikeTypeTraits. const void *getAsOpaquePointer() const { return static_cast(impl); @@ -263,7 +260,7 @@ MLIRContext *context); // Input types. - unsigned getNumInputs() const { return getSubclassData(); } + unsigned getNumInputs() const; Type getInput(unsigned i) const { return getInputs()[i]; } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -108,14 +108,15 @@ } bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, getSubclassData(), stride); + return key == KeyTy(elementType, elementCount, stride); } ArrayTypeStorage(const KeyTy &key) - : TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)), + : elementType(std::get<0>(key)), elementCount(std::get<1>(key)), stride(std::get<2>(key)) {} Type elementType; + unsigned elementCount; unsigned stride; }; @@ -132,9 +133,7 @@ elementCount, stride); } -unsigned ArrayType::getNumElements() const { - return getImpl()->getSubclassData(); -} +unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; } Type ArrayType::getElementType() const { return getImpl()->elementType; } @@ -321,19 +320,17 @@ } bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, getScope(), rows, columns); + return key == KeyTy(elementType, scope, rows, columns); } CooperativeMatrixTypeStorage(const KeyTy &key) - : TypeStorage(static_cast(std::get<1>(key))), - elementType(std::get<0>(key)), rows(std::get<2>(key)), - columns(std::get<3>(key)) {} - - Scope getScope() const { return static_cast(getSubclassData()); } + : elementType(std::get<0>(key)), rows(std::get<2>(key)), + columns(std::get<3>(key)), scope(std::get<1>(key)) {} Type elementType; unsigned rows; unsigned columns; + Scope scope; }; CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType, @@ -347,9 +344,7 @@ return getImpl()->elementType; } -Scope CooperativeMatrixNVType::getScope() const { - return getImpl()->getScope(); -} +Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; } unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; } @@ -412,20 +407,6 @@ } struct spirv::detail::ImageTypeStorage : public TypeStorage { -private: - /// Define a bit-field struct to pack the enum values - union EnumPack { - struct { - unsigned dimEncoding : getNumBits(); - unsigned depthInfoEncoding : getNumBits(); - unsigned arrayedInfoEncoding : getNumBits(); - unsigned samplingInfoEncoding : getNumBits(); - unsigned samplerUseInfoEncoding : getNumBits(); - unsigned formatEncoding : getNumBits(); - } data; - unsigned storage; - }; - public: using KeyTy = std::tuple; @@ -436,95 +417,23 @@ } bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, getDim(), getDepthInfo(), getArrayedInfo(), - getSamplingInfo(), getSamplerUseInfo(), - getImageFormat()); - } - - Dim getDim() const { - EnumPack v; - v.storage = getSubclassData(); - return static_cast(v.data.dimEncoding); - } - void setDim(Dim dim) { - EnumPack v; - v.storage = getSubclassData(); - v.data.dimEncoding = static_cast(dim); - setSubclassData(v.storage); - } - - ImageDepthInfo getDepthInfo() const { - EnumPack v; - v.storage = getSubclassData(); - return static_cast(v.data.depthInfoEncoding); - } - void setDepthInfo(ImageDepthInfo depthInfo) { - EnumPack v; - v.storage = getSubclassData(); - v.data.depthInfoEncoding = static_cast(depthInfo); - setSubclassData(v.storage); - } - - ImageArrayedInfo getArrayedInfo() const { - EnumPack v; - v.storage = getSubclassData(); - return static_cast(v.data.arrayedInfoEncoding); - } - void setArrayedInfo(ImageArrayedInfo arrayedInfo) { - EnumPack v; - v.storage = getSubclassData(); - v.data.arrayedInfoEncoding = static_cast(arrayedInfo); - setSubclassData(v.storage); + return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo, + samplerUseInfo, format); } - ImageSamplingInfo getSamplingInfo() const { - EnumPack v; - v.storage = getSubclassData(); - return static_cast(v.data.samplingInfoEncoding); - } - void setSamplingInfo(ImageSamplingInfo samplingInfo) { - EnumPack v; - v.storage = getSubclassData(); - v.data.samplingInfoEncoding = static_cast(samplingInfo); - setSubclassData(v.storage); - } - - ImageSamplerUseInfo getSamplerUseInfo() const { - EnumPack v; - v.storage = getSubclassData(); - return static_cast(v.data.samplerUseInfoEncoding); - } - void setSamplerUseInfo(ImageSamplerUseInfo samplerUseInfo) { - EnumPack v; - v.storage = getSubclassData(); - v.data.samplerUseInfoEncoding = static_cast(samplerUseInfo); - setSubclassData(v.storage); - } - - ImageFormat getImageFormat() const { - EnumPack v; - v.storage = getSubclassData(); - return static_cast(v.data.formatEncoding); - } - void setImageFormat(ImageFormat format) { - EnumPack v; - v.storage = getSubclassData(); - v.data.formatEncoding = static_cast(format); - setSubclassData(v.storage); - } - - ImageTypeStorage(const KeyTy &key) : elementType(std::get<0>(key)) { - static_assert(sizeof(EnumPack) <= sizeof(getSubclassData()), - "EnumPack size greater than subClassData type size"); - setDim(std::get<1>(key)); - setDepthInfo(std::get<2>(key)); - setArrayedInfo(std::get<3>(key)); - setSamplingInfo(std::get<4>(key)); - setSamplerUseInfo(std::get<5>(key)); - setImageFormat(std::get<6>(key)); - } + ImageTypeStorage(const KeyTy &key) + : elementType(std::get<0>(key)), dim(std::get<1>(key)), + depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)), + samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)), + format(std::get<6>(key)) {} Type elementType; + Dim dim : getNumBits(); + ImageDepthInfo depthInfo : getNumBits(); + ImageArrayedInfo arrayedInfo : getNumBits(); + ImageSamplingInfo samplingInfo : getNumBits(); + ImageSamplerUseInfo samplerUseInfo : getNumBits(); + ImageFormat format : getNumBits(); }; ImageType @@ -536,27 +445,23 @@ Type ImageType::getElementType() const { return getImpl()->elementType; } -Dim ImageType::getDim() const { return getImpl()->getDim(); } +Dim ImageType::getDim() const { return getImpl()->dim; } -ImageDepthInfo ImageType::getDepthInfo() const { - return getImpl()->getDepthInfo(); -} +ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; } ImageArrayedInfo ImageType::getArrayedInfo() const { - return getImpl()->getArrayedInfo(); + return getImpl()->arrayedInfo; } ImageSamplingInfo ImageType::getSamplingInfo() const { - return getImpl()->getSamplingInfo(); + return getImpl()->samplingInfo; } ImageSamplerUseInfo ImageType::getSamplerUseInfo() const { - return getImpl()->getSamplerUseInfo(); + return getImpl()->samplerUseInfo; } -ImageFormat ImageType::getImageFormat() const { - return getImpl()->getImageFormat(); -} +ImageFormat ImageType::getImageFormat() const { return getImpl()->format; } void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &, Optional) { @@ -588,18 +493,14 @@ } bool operator==(const KeyTy &key) const { - return key == KeyTy(pointeeType, getStorageClass()); + return key == KeyTy(pointeeType, storageClass); } PointerTypeStorage(const KeyTy &key) - : TypeStorage(static_cast(key.second)), pointeeType(key.first) { - } - - StorageClass getStorageClass() const { - return static_cast(getSubclassData()); - } + : pointeeType(key.first), storageClass(key.second) {} Type pointeeType; + StorageClass storageClass; }; PointerType PointerType::get(Type pointeeType, StorageClass storageClass) { @@ -610,7 +511,7 @@ Type PointerType::getPointeeType() const { return getImpl()->pointeeType; } StorageClass PointerType::getStorageClass() const { - return getImpl()->getStorageClass(); + return getImpl()->storageClass; } void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, @@ -650,13 +551,14 @@ } bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, getSubclassData()); + return key == KeyTy(elementType, stride); } RuntimeArrayTypeStorage(const KeyTy &key) - : TypeStorage(key.second), elementType(key.first) {} + : elementType(key.first), stride(key.second) {} Type elementType; + unsigned stride; }; RuntimeArrayType RuntimeArrayType::get(Type elementType) { @@ -671,9 +573,7 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } -unsigned RuntimeArrayType::getArrayStride() const { - return getImpl()->getSubclassData(); -} +unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; } void RuntimeArrayType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, @@ -917,8 +817,8 @@ unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, StructType::MemberDecorationInfo const *memberDecorationsInfo) - : TypeStorage(numMembers), memberTypes(memberTypes), - offsetInfo(layoutInfo), numMemberDecorations(numMemberDecorations), + : memberTypes(memberTypes), offsetInfo(layoutInfo), + numMembers(numMembers), numMemberDecorations(numMemberDecorations), memberDecorationsInfo(memberDecorationsInfo) {} using KeyTy = std::tuple, ArrayRef, @@ -960,12 +860,12 @@ } ArrayRef getMemberTypes() const { - return ArrayRef(memberTypes, getSubclassData()); + return ArrayRef(memberTypes, numMembers); } ArrayRef getOffsetInfo() const { if (offsetInfo) { - return ArrayRef(offsetInfo, getSubclassData()); + return ArrayRef(offsetInfo, numMembers); } return {}; } @@ -980,6 +880,7 @@ Type const *memberTypes; StructType::OffsetInfo const *offsetInfo; + unsigned numMembers; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; }; @@ -1003,9 +904,7 @@ ArrayRef()); } -unsigned StructType::getNumElements() const { - return getImpl()->getSubclassData(); -} +unsigned StructType::getNumElements() const { return getImpl()->numMembers; } Type StructType::getElementType(unsigned index) const { assert(getNumElements() > index && "member index out of range"); diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -130,10 +130,10 @@ return success(); } -unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); } +unsigned IntegerType::getWidth() const { return getImpl()->width; } IntegerType::SignednessSemantics IntegerType::getSignedness() const { - return getImpl()->getSignedness(); + return getImpl()->signedness; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -54,17 +54,17 @@ struct IntegerTypeStorage : public TypeStorage { IntegerTypeStorage(unsigned width, IntegerType::SignednessSemantics signedness) - : TypeStorage(packKeyBits(width, signedness)) {} + : width(width), signedness(signedness) {} /// The hash key used for uniquing. using KeyTy = std::pair; static llvm::hash_code hashKey(const KeyTy &key) { - return llvm::hash_value(packKeyBits(key.first, key.second)); + return llvm::hash_value(key); } bool operator==(const KeyTy &key) const { - return getSubclassData() == packKeyBits(key.first, key.second); + return KeyTy(width, signedness) == key; } static IntegerTypeStorage *construct(TypeStorageAllocator &allocator, @@ -73,35 +73,15 @@ IntegerTypeStorage(key.first, key.second); } - struct KeyBits { - unsigned width : 30; - unsigned signedness : 2; - }; - - /// Pack the given `width` and `signedness` as a key. - static unsigned packKeyBits(unsigned width, - IntegerType::SignednessSemantics signedness) { - KeyBits bits{width, static_cast(signedness)}; - return llvm::bit_cast(bits); - } - - static KeyBits unpackKeyBits(unsigned bits) { - return llvm::bit_cast(bits); - } - - unsigned getWidth() { return unpackKeyBits(getSubclassData()).width; } - - IntegerType::SignednessSemantics getSignedness() { - return static_cast( - unpackKeyBits(getSubclassData()).signedness); - } + unsigned width : 30; + IntegerType::SignednessSemantics signedness : 2; }; /// Function Type Storage and Uniquing. struct FunctionTypeStorage : public TypeStorage { FunctionTypeStorage(unsigned numInputs, unsigned numResults, Type const *inputsAndResults) - : TypeStorage(numInputs), numResults(numResults), + : numInputs(numInputs), numResults(numResults), inputsAndResults(inputsAndResults) {} /// The hash key used for uniquing. @@ -128,20 +108,20 @@ } ArrayRef getInputs() const { - return ArrayRef(inputsAndResults, getSubclassData()); + return ArrayRef(inputsAndResults, numInputs); } ArrayRef getResults() const { - return ArrayRef(inputsAndResults + getSubclassData(), numResults); + return ArrayRef(inputsAndResults + numInputs, numResults); } + unsigned numInputs; unsigned numResults; Type const *inputsAndResults; }; /// Shaped Type Storage. struct ShapedTypeStorage : public TypeStorage { - ShapedTypeStorage(Type elementTy, unsigned subclassData = 0) - : TypeStorage(subclassData), elementType(elementTy) {} + ShapedTypeStorage(Type elementTy) : elementType(elementTy) {} /// The hash key used for uniquing. using KeyTy = Type; @@ -154,7 +134,8 @@ struct VectorTypeStorage : public ShapedTypeStorage { VectorTypeStorage(unsigned shapeSize, Type elementTy, const int64_t *shapeElements) - : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {} + : ShapedTypeStorage(elementTy), shapeElements(shapeElements), + shapeSize(shapeSize) {} /// The hash key used for uniquing. using KeyTy = std::pair, Type>; @@ -174,16 +155,18 @@ } ArrayRef getShape() const { - return ArrayRef(shapeElements, getSubclassData()); + return ArrayRef(shapeElements, shapeSize); } const int64_t *shapeElements; + unsigned shapeSize; }; struct RankedTensorTypeStorage : public ShapedTypeStorage { RankedTensorTypeStorage(unsigned shapeSize, Type elementTy, const int64_t *shapeElements) - : ShapedTypeStorage(elementTy, shapeSize), shapeElements(shapeElements) {} + : ShapedTypeStorage(elementTy), shapeElements(shapeElements), + shapeSize(shapeSize) {} /// The hash key used for uniquing. using KeyTy = std::pair, Type>; @@ -203,10 +186,11 @@ } ArrayRef getShape() const { - return ArrayRef(shapeElements, getSubclassData()); + return ArrayRef(shapeElements, shapeSize); } const int64_t *shapeElements; + unsigned shapeSize; }; struct UnrankedTensorTypeStorage : public ShapedTypeStorage { @@ -225,9 +209,9 @@ MemRefTypeStorage(unsigned shapeSize, Type elementType, const int64_t *shapeElements, const unsigned numAffineMaps, AffineMap const *affineMapList, const unsigned memorySpace) - : ShapedTypeStorage(elementType, shapeSize), shapeElements(shapeElements), - numAffineMaps(numAffineMaps), affineMapList(affineMapList), - memorySpace(memorySpace) {} + : ShapedTypeStorage(elementType), shapeElements(shapeElements), + shapeSize(shapeSize), numAffineMaps(numAffineMaps), + affineMapList(affineMapList), memorySpace(memorySpace) {} /// The hash key used for uniquing. // MemRefs are uniqued based on their shape, element type, affine map @@ -256,7 +240,7 @@ } ArrayRef getShape() const { - return ArrayRef(shapeElements, getSubclassData()); + return ArrayRef(shapeElements, shapeSize); } ArrayRef getAffineMaps() const { @@ -265,6 +249,8 @@ /// An array of integers which stores the shape dimension sizes. const int64_t *shapeElements; + /// The number of shape elements. + unsigned shapeSize; /// The number of affine maps in the 'affineMapList' array. const unsigned numAffineMaps; /// List of affine maps in the memref's layout/index map composition. @@ -322,7 +308,7 @@ public llvm::TrailingObjects { using KeyTy = ArrayRef; - TupleTypeStorage(unsigned numTypes) : TypeStorage(numTypes) {} + TupleTypeStorage(unsigned numTypes) : numElements(numTypes) {} /// Construction. static TupleTypeStorage *construct(TypeStorageAllocator &allocator, @@ -341,12 +327,15 @@ bool operator==(const KeyTy &key) const { return key == getTypes(); } /// Return the number of held types. - unsigned size() const { return getSubclassData(); } + unsigned size() const { return numElements; } /// Return the held types. ArrayRef getTypes() const { return {getTrailingObjects(), size()}; } + + /// The number of tuple elements. + unsigned numElements; }; } // namespace detail diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -27,9 +27,6 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); } -unsigned Type::getSubclassData() const { return impl->getSubclassData(); } -void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); } - //===----------------------------------------------------------------------===// // FunctionType //===----------------------------------------------------------------------===// @@ -39,6 +36,8 @@ return Base::get(context, Type::Kind::Function, inputs, results); } +unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } + ArrayRef FunctionType::getInputs() const { return getImpl()->getInputs(); }