diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -25,6 +25,31 @@ // Type Converter //===----------------------------------------------------------------------===// +struct SPIRVConversionOptions { + /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if + /// no native support. + /// + /// Non-32-bit scalar types require special hardware support that may not + /// exist on all GPUs. This is reflected in SPIR-V in that non-32-bit scalar + /// types require special capabilities or extensions. This option controls + /// whether to use 32-bit types to emulate, if a scalar type of a certain + /// bitwidth is not supported in the target environment. This requires the + /// runtime to also feed in data with a matched bitwidth and layout for + /// interface types. The runtime can do that by inspecting the SPIR-V + /// module. + /// + /// If the original scalar type has fewer than 32 bits, multiple values of + /// that type will be packed into one 32-bit value to be memory efficient. + bool emulateNon32BitScalarTypes{true}; + + /// Use 64-bit integers to convert index types. + bool use64bitIndex{false}; + + /// The number of bits used to store a boolean value. It is eight bits by + /// default. + unsigned boolNumBits{8}; +}; + /// Type conversion from builtin types to SPIR-V types for shader interface. /// /// For memref types, this converter additionally performs type wrapping to @@ -32,39 +57,8 @@ /// pointers to structs. class SPIRVTypeConverter : public TypeConverter { public: - struct Options { - /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if - /// no native support. - /// - /// Non-32-bit scalar types require special hardware support that may not - /// exist on all GPUs.
This is reflected in SPIR-V as that non-32-bit scalar - /// types require special capabilities or extensions. This option controls - /// whether to use 32-bit types to emulate, if a scalar type of a certain - /// bitwidth is not supported in the target environment. This requires the - /// runtime to also feed in data with a matched bitwidth and layout for - /// interface types. The runtime can do that by inspecting the SPIR-V - /// module. - /// - /// If the original scalar type has less than 32-bit, a multiple of its - /// values will be packed into one 32-bit value to be memory efficient. - bool emulateNon32BitScalarTypes{true}; - - /// Use 64-bit integers to convert index types. - bool use64bitIndex{false}; - - /// The number of bits to store a boolean value. It is eight bits by - /// default. - unsigned boolNumBits{8}; - - // Note: we need this instead of inline initializers because of - // https://bugs.llvm.org/show_bug.cgi?id=36684 - Options() - - {} - }; - explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, - Options options = {}); + const SPIRVConversionOptions &options = {}); /// Gets the SPIR-V correspondence for the standard index type. Type getIndexType() const; @@ -72,11 +66,11 @@ const spirv::TargetEnv &getTargetEnv() const { return targetEnv; } /// Returns the options controlling the SPIR-V type converter. 
- const Options &getOptions() const { return options; } + const SPIRVConversionOptions &getOptions() const { return options; } private: spirv::TargetEnv targetEnv; - Options options; + SPIRVConversionOptions options; MLIRContext *getContext() const; }; diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -959,7 +959,7 @@ auto targetAttr = spirv::lookupTargetEnvOrDefault(op); auto target = SPIRVConversionTarget::get(targetAttr); - SPIRVTypeConverter::Options options; + SPIRVConversionOptions options; options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; SPIRVTypeConverter typeConverter(targetAttr, options); diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -40,7 +40,7 @@ std::unique_ptr<ConversionTarget> target = SPIRVConversionTarget::get(targetAttr); - SPIRVTypeConverter::Options options; + SPIRVConversionOptions options; options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; SPIRVTypeConverter typeConverter(targetAttr, options); diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -39,7 +39,7 @@ std::unique_ptr<ConversionTarget> target = SPIRVConversionTarget::get(targetAttr); - SPIRVTypeConverter::Options options; + SPIRVConversionOptions options; options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; SPIRVTypeConverter typeConverter(targetAttr, options); diff --git 
a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp @@ -39,7 +39,7 @@ std::unique_ptr<ConversionTarget> target = SPIRVConversionTarget::get(targetAttr); - SPIRVTypeConverter::Options options; + SPIRVConversionOptions options; options.boolNumBits = this->boolNumBits; SPIRVTypeConverter typeConverter(targetAttr, options); diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -37,7 +37,7 @@ std::unique_ptr<ConversionTarget> target = SPIRVConversionTarget::get(targetAttr); - SPIRVTypeConverter::Options options; + SPIRVConversionOptions options; options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; SPIRVTypeConverter typeConverter(targetAttr, options); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -124,8 +124,8 @@ // TODO: This is a utility function that should probably be exposed by the // SPIR-V dialect. Keeping it local till the use case arises. -static Optional<int64_t> -getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) { +static Optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options, + Type type) { if (type.isa<spirv::ScalarType>()) { auto bitWidth = type.getIntOrFloatBitWidth(); // According to the SPIR-V spec: @@ -199,7 +199,7 @@ /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
static Type convertScalarType(const spirv::TargetEnv &targetEnv, - const SPIRVTypeConverter::Options &options, + const SPIRVConversionOptions &options, spirv::ScalarType type, Optional<spirv::StorageClass> storageClass = {}) { // Get extension and capability requirements for the given type. @@ -232,7 +232,7 @@ /// Converts a vector `type` to a suitable type under the given `targetEnv`. static Type convertVectorType(const spirv::TargetEnv &targetEnv, - const SPIRVTypeConverter::Options &options, + const SPIRVConversionOptions &options, VectorType type, Optional<spirv::StorageClass> storageClass = {}) { if (type.getRank() <= 1 && type.getNumElements() == 1) @@ -271,7 +271,7 @@ /// constant values and use OpCompositeExtract and OpCompositeInsert to /// manipulate, like what we do for vectors. static Type convertTensorType(const spirv::TargetEnv &targetEnv, - const SPIRVTypeConverter::Options &options, + const SPIRVConversionOptions &options, TensorType type) { // TODO: Handle dynamic shapes. if (!type.hasStaticShape()) { @@ -310,7 +310,7 @@ } static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, - const SPIRVTypeConverter::Options &options, + const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass) { unsigned numBoolBits = options.boolNumBits; @@ -349,7 +349,7 @@ } static Type convertMemrefType(const spirv::TargetEnv &targetEnv, - const SPIRVTypeConverter::Options &options, + const SPIRVConversionOptions &options, MemRefType type) { auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>(); if (!attr) { @@ -414,7 +414,7 @@ } SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, - Options options) + const SPIRVConversionOptions &options) : targetEnv(targetAttr), options(options) { // Add conversions. The order matters here: later ones will be tried earlier.