diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -109,6 +109,11 @@
   /// times.
   void getCapabilities(CapabilityArrayRefVector &capabilities,
                        Optional<StorageClass> storage = llvm::None);
+
+  /// Returns the size in bytes for each type. If no size can be calculated,
+  /// returns `llvm::None`. Note that if the type has explicit layout, it is
+  /// also taken into account in calculation.
+  Optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V scalar type: bool type, integer type, floating point type.
@@ -127,6 +132,8 @@
                      Optional<StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        Optional<StorageClass> storage = llvm::None);
+
+  Optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
@@ -153,6 +160,8 @@
                      Optional<StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        Optional<StorageClass> storage = llvm::None);
+
+  Optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V array type
@@ -181,6 +190,8 @@
                      Optional<StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        Optional<StorageClass> storage = llvm::None);
+
+  Optional<int64_t> getSizeInBytes();
 };
 
 // SPIR-V image type
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -151,6 +151,16 @@
   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
 }
 
+Optional<int64_t> ArrayType::getSizeInBytes() {
+  auto elementType = getElementType().cast<SPIRVType>();
+  Optional<int64_t> size = elementType.getSizeInBytes();
+  if (!size)
+    return llvm::None;
+  // An explicit stride is the byte distance between consecutive elements, so
+  // it replaces (not adds to) the element size. Assumes 0 means "no stride".
+  return (getArrayStride() ? getArrayStride() : *size) * getNumElements();
+}
+
 //===----------------------------------------------------------------------===//
 // CompositeType
 //===----------------------------------------------------------------------===//
@@ -281,6 +291,24 @@
   }
 }
 
+Optional<int64_t> CompositeType::getSizeInBytes() {
+  switch (getKind()) {
+  case spirv::TypeKind::Array:
+    return cast<ArrayType>().getSizeInBytes();
+  case spirv::TypeKind::Struct:
+    return cast<StructType>().getSizeInBytes();
+  case StandardTypes::Vector: {
+    auto elementSize =
+      cast<VectorType>().getElementType().cast<ScalarType>().getSizeInBytes();
+    if (!elementSize)
+      return llvm::None;
+    return *elementSize * cast<VectorType>().getNumElements();
+  }
+  default:
+    return llvm::None;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // CooperativeMatrixType
 //===----------------------------------------------------------------------===//
@@ -806,6 +834,19 @@
 #undef WIDTH_CASE
 }
 
+Optional<int64_t> ScalarType::getSizeInBytes() {
+  auto bitWidth = getIntOrFloatBitWidth();
+  // According to the SPIR-V spec:
+  // "There is no physical size or bit pattern defined for values with boolean
+  // type. If they are stored (in conjunction with OpVariable), they can only
+  // be used with logical addressing operations, not physical, and only with
+  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
+  // Private, Function, Input, and Output."
+  if (bitWidth == 1)
+    return llvm::None;
+  return bitWidth / 8;
+}
+
 //===----------------------------------------------------------------------===//
 // SPIRVType
 //===----------------------------------------------------------------------===//
@@ -861,6 +902,14 @@
   }
 }
 
+Optional<int64_t> SPIRVType::getSizeInBytes() {
+  if (auto scalarType = dyn_cast<ScalarType>())
+    return scalarType.getSizeInBytes();
+  if (auto compositeType = dyn_cast<CompositeType>())
+    return compositeType.getSizeInBytes();
+  return llvm::None;
+}
+
 //===----------------------------------------------------------------------===//
 // StructType
 //===----------------------------------------------------------------------===//