diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -109,6 +109,10 @@
   /// times.
   void getCapabilities(CapabilityArrayRefVector &capabilities,
                        Optional<StorageClass> storage = llvm::None);
+
+  /// Returns the size of this type in bytes, or llvm::None if this type has
+  /// no statically computable size.
+  Optional<int64_t> getTypeNumBytes();
 };
 
 // SPIR-V scalar type: bool type, integer type, floating point type.
@@ -127,6 +131,8 @@
                      Optional<StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        Optional<StorageClass> storage = llvm::None);
+
+  Optional<int64_t> getTypeNumBytes();
 };
 
 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
@@ -153,6 +159,8 @@
                      Optional<StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        Optional<StorageClass> storage = llvm::None);
+
+  Optional<int64_t> getTypeNumBytes();
 };
 
 // SPIR-V array type
@@ -181,6 +189,8 @@
                      Optional<StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        Optional<StorageClass> storage = llvm::None);
+
+  Optional<int64_t> getTypeNumBytes();
 };
 
 // SPIR-V image type
@@ -242,6 +252,8 @@
                      Optional<StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        Optional<StorageClass> storage = llvm::None);
+
+  Optional<int64_t> getTypeNumBytes();
 };
 
 // SPIR-V run-time array type
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -151,6 +151,14 @@
   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
 }
 
+Optional<int64_t> ArrayType::getTypeNumBytes() {
+  // ArrayStride is the byte distance between elements and includes their size.
+  Optional<int64_t> size = getElementType().cast<SPIRVType>().getTypeNumBytes();
+  if (!size)
+    return llvm::None;
+  return (getArrayStride() ? getArrayStride() : *size) * getNumElements();
+}
+
 //===----------------------------------------------------------------------===//
 // CompositeType
//===----------------------------------------------------------------------===// @@ -281,6 +289,20 @@ } } +Optional CompositeType::getTypeNumBytes() { + switch (getKind()) { + case spirv::TypeKind::Array: + return cast().getTypeNumBytes(); + case StandardTypes::Vector: + return cast() + .getElementType() + .cast() + .getTypeNumBytes(); + default: + return llvm::None; + } +} + //===----------------------------------------------------------------------===// // CooperativeMatrixType //===----------------------------------------------------------------------===// @@ -616,6 +638,11 @@ capabilities.push_back(*scCaps); } +Optional PointerType::getTypeNumBytes() { + // Memory is byte-addressable so we let the size of the pointer to be 8 bytes. + return 8; +} + //===----------------------------------------------------------------------===// // RuntimeArrayType //===----------------------------------------------------------------------===// @@ -812,6 +839,19 @@ #undef WIDTH_CASE } +Optional ScalarType::getTypeNumBytes() { + auto bitWidth = getIntOrFloatBitWidth(); + // According to the SPIR-V spec: + // "There is no physical size or bit pattern defined for values with boolean + // type. If they are stored (in conjunction with OpVariable), they can only + // be used with logical addressing operations, not physical, and only with + // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, + // Private, Function, Input, and Output." 
+ if (bitWidth == 1) + return llvm::None; + return bitWidth / 8; +} + //===----------------------------------------------------------------------===// // SPIRVType //===----------------------------------------------------------------------===// @@ -867,6 +907,18 @@ } } +Optional SPIRVType::getTypeNumBytes() { + if (auto scalarType = dyn_cast()) { + return scalarType.getTypeNumBytes(); + } else if (auto compositeType = dyn_cast()) { + return compositeType.getTypeNumBytes(); + } else if (auto ptrType = dyn_cast()) { + return ptrType.getTypeNumBytes(); + } else { + return llvm::None; + } +} + //===----------------------------------------------------------------------===// // StructType //===----------------------------------------------------------------------===//