diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -736,12 +736,51 @@ SPIR-V compilation should also take into consideration of the execution environment, so we generate SPIR-V modules valid for the target environment. -This is conveyed by the `spv.target_env` attribute. It is a triple of +This is conveyed by the `spv.target_env` attribute. It should be of +`#spv.target_env` attribute kind, which is defined as: -* `version`: a 32-bit integer indicating the target SPIR-V version. -* `extensions`: a string array attribute containing allowed extensions. -* `capabilities`: a 32-bit integer array attribute containing allowed - capabilities. +``` +spirv-version ::= `V_1_0` | `V_1_1` | ... +spirv-extension ::= `SPV_KHR_16bit_storage` | `SPV_EXT_physical_storage_buffer` | ... +spirv-capability ::= `Shader` | `Kernel` | `GroupNonUniform` | ... + +spirv-extension-list ::= `[` (spirv-extension-elements)? `]` +spirv-extension-elements ::= spirv-extension (`,` spirv-extension)* + +spirv-capability-list ::= `[` (spirv-capability-elements)? `]` +spirv-capability-elements ::= spirv-capability (`,` spirv-capability)* + +spirv-resource-limits ::= dictionary-attribute + +spirv-target-env-attribute ::= `#` `spv.target_env` `<` + spirv-version `,` + spirv-extension-list `,` + spirv-capability-list `,` + spirv-resource-limits `>` +``` + +The attribute has four fields: + +* The target SPIR-V version. +* A list of SPIR-V extensions for the target. +* A list of SPIR-V capabilities for the target.
+* A dictionary of target resource limits (see the + [Vulkan spec][VulkanResourceLimits] for explanation): + * `max_compute_workgroup_invocations` + * `max_compute_workgroup_size` + +For example, + +``` +module attributes { +spv.target_env = #spv.target_env< + V_1_3, [SPV_KHR_8bit_storage], [Shader, GroupNonUniform], + { + max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32> + }> +} { ... } +``` Dialect conversion framework will utilize the information in `spv.target_env` to properly filter out patterns and ops not available in the target execution @@ -1219,3 +1258,4 @@ [CustomTypeAttrTutorial]: ../DefiningAttributesAndTypes/ [VulkanSpirv]: https://renderdoc.org/vkspec_chunked/chap40.html#spirvenv [VulkanShaderInterface]: https://renderdoc.org/vkspec_chunked/chap14.html#interfaces-resources +[VulkanResourceLimits]: https://renderdoc.org/vkspec_chunked/chap36.html#limits diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h @@ -26,25 +26,35 @@ static StringRef getDialectNamespace() { return "spv"; } + //===--------------------------------------------------------------------===// + // Type + //===--------------------------------------------------------------------===// + /// Checks if the given `type` is valid in SPIR-V dialect. static bool isValidType(Type type); /// Checks if the given `scalar type` is valid in SPIR-V dialect. static bool isValidScalarType(Type type); - /// Returns the attribute name to use when specifying decorations on results - /// of operations. - static std::string getAttributeName(Decoration decoration); - /// Parses a type registered to this dialect. Type parseType(DialectAsmParser &parser) const override; /// Prints a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override; - /// Provides a hook for materializing a constant to this dialect. - Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, - Location loc) override; + //===--------------------------------------------------------------------===// + // Attribute + //===--------------------------------------------------------------------===// + + /// Returns the attribute name to use when specifying decorations on results + /// of operations. + static std::string getAttributeName(Decoration decoration); + + /// Parses an attribute registered to this dialect. + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; + + /// Prints an attribute registered to this dialect. + void printAttribute(Attribute, DialectAsmPrinter &printer) const override; /// Provides a hook for verifying SPIR-V dialect attributes attached to the /// given op. @@ -62,6 +72,14 @@ LogicalResult verifyRegionResultAttribute(Operation *op, unsigned regionIndex, unsigned resultIndex, NamedAttribute attribute) override; + + //===--------------------------------------------------------------------===// + // Constant + //===--------------------------------------------------------------------===// + + /// Provides a hook for materializing a constant to this dialect. 
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; }; } // end namespace spirv diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -25,7 +25,76 @@ #include "mlir/Dialect/SPIRV/TargetAndABI.h.inc" namespace spirv { +enum class Capability : uint32_t; +enum class Extension; enum class StorageClass : uint32_t; +enum class Version : uint32_t; + +namespace detail { +struct TargetEnvAttributeStorage; +} // namespace detail + +/// SPIR-V dialect-specific attribute kinds. +// TODO(antiagainst): move to a more suitable place if we have more attributes. +namespace AttrKind { +enum Kind { + TargetEnv = Attribute::FIRST_SPIRV_ATTR, +}; +} // namespace AttrKind + +/// An attribute that specifies the target version, allowed extensions and +/// capabilities, and resource limits. Together they describe a SPIR-V +/// target environment. +class TargetEnvAttr + : public Attribute::AttrBase { +public: + using Base::Base; + + /// Gets a TargetEnvAttr instance. + static TargetEnvAttr get(IntegerAttr version, ArrayAttr extensions, + ArrayAttr capabilities, DictionaryAttr limits); + + /// Returns the attribute kind's name (without the 'spv.' prefix). + static StringRef getKindName(); + + /// Returns the target version. + Version getVersion(); + + struct ext_iterator final + : public llvm::mapped_iterator { + explicit ext_iterator(ArrayAttr::iterator it); + }; + using ext_range = llvm::iterator_range; + + /// Returns the target extensions. + ext_range getExtensions(); + /// Returns the target extensions as a string array attribute. + ArrayAttr getExtensionsAttr(); + + struct cap_iterator final + : public llvm::mapped_iterator { + explicit cap_iterator(ArrayAttr::iterator it); + }; + using cap_range = llvm::iterator_range; + + /// Returns the target capabilities.
+ cap_range getCapabilities(); + /// Returns the target capabilities as an integer array attribute. + ArrayAttr getCapabilitiesAttr(); + + /// Returns the target resource limits. + DictionaryAttr getResourceLimits(); + + static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; } + + static LogicalResult + verifyConstructionInvariants(Optional loc, MLIRContext *context, + IntegerAttr version, ArrayAttr extensions, + ArrayAttr capabilities, DictionaryAttr limits); +}; /// Returns the attribute name for specifying argument ABI information. StringRef getInterfaceVarABIAttrName(); diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td @@ -59,13 +59,4 @@ StructFieldAttr<"max_compute_workgroup_size", I32ElementsAttr> ]>; -// For the generated SPIR-V module, this attribute specifies the target version, -// allowed extensions and capabilities, and resource limits. 
-def SPV_TargetEnvAttr : StructAttr<"TargetEnvAttr", SPV_Dialect, [ - StructFieldAttr<"version", SPV_VersionAttr>, - StructFieldAttr<"extensions", SPV_ExtensionArrayAttr>, - StructFieldAttr<"capabilities", SPV_CapabilityArrayAttr>, - StructFieldAttr<"limits", SPV_ResourceLimitsAttr> -]>; - #endif // SPIRV_TARGET_AND_ABI diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -35,7 +35,6 @@ loc, xType, invocation, builder->getI32ArrayAttr({dim})); } - //===----------------------------------------------------------------------===// // Reduction (single workgroup) //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -117,6 +117,8 @@ : Dialect(getDialectNamespace(), context) { addTypes(); + addAttributes(); + // Add SPIR-V ops. addOperations< #define GET_OP_LIST @@ -629,6 +631,175 @@ } //===----------------------------------------------------------------------===// +// Attribute Parsing +//===----------------------------------------------------------------------===// + +/// Parses a comma-separated list of keywords, invokes `processKeyword` on each +/// of the parsed keyword, and returns failure if any error occurs. +static ParseResult parseKeywordList( + DialectAsmParser &parser, + function_ref processKeyword) { + if (parser.parseLSquare()) + return failure(); + + // Special case for empty list. + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + // Keep parsing the keyword and an optional comma following it. If the comma + // is successfully parsed, then we have more keywords to parse. 
+ do { + auto loc = parser.getCurrentLocation(); + StringRef keyword; + if (parser.parseKeyword(&keyword) || failed(processKeyword(loc, keyword))) + return failure(); + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRSquare()) + return failure(); + + return success(); +} + +/// Parses a spirv::TargetEnvAttr. +static Attribute parseTargetAttr(DialectAsmParser &parser) { + if (parser.parseLess()) + return {}; + + Builder &builder = parser.getBuilder(); + + IntegerAttr versionAttr; + { + auto loc = parser.getCurrentLocation(); + StringRef version; + if (parser.parseKeyword(&version) || parser.parseComma()) + return {}; + + if (auto versionSymbol = spirv::symbolizeVersion(version)) { + versionAttr = + builder.getI32IntegerAttr(static_cast(*versionSymbol)); + } else { + parser.emitError(loc, "unknown version: ") << version; + return {}; + } + } + + ArrayAttr extensionsAttr; + { + SmallVector extensions; + llvm::SMLoc errorloc; + StringRef errorKeyword; + + auto processExtension = [&](llvm::SMLoc loc, StringRef extension) { + if (spirv::symbolizeExtension(extension)) { + extensions.push_back(builder.getStringAttr(extension)); + return success(); + } + return errorloc = loc, errorKeyword = extension, failure(); + }; + if (parseKeywordList(parser, processExtension) || parser.parseComma()) { + if (!errorKeyword.empty()) + parser.emitError(errorloc, "unknown extension: ") << errorKeyword; + return {}; + } + + extensionsAttr = builder.getArrayAttr(extensions); + } + + ArrayAttr capabilitiesAttr; + { + SmallVector capabilities; + llvm::SMLoc errorloc; + StringRef errorKeyword; + + auto processCapability = [&](llvm::SMLoc loc, StringRef capability) { + if (auto capSymbol = spirv::symbolizeCapability(capability)) { + capabilities.push_back( + builder.getI32IntegerAttr(static_cast(*capSymbol))); + return success(); + } + return errorloc = loc, errorKeyword = capability, failure(); + }; + if (parseKeywordList(parser, processCapability) || 
parser.parseComma()) { + if (!errorKeyword.empty()) + parser.emitError(errorloc, "unknown capability: ") << errorKeyword; + return {}; + } + + capabilitiesAttr = builder.getArrayAttr(capabilities); + } + + DictionaryAttr limitsAttr; + { + auto loc = parser.getCurrentLocation(); + if (parser.parseAttribute(limitsAttr)) + return {}; + + if (!limitsAttr.isa()) { + parser.emitError( + loc, + "limits must be a dictionary attribute containing two 32-bit integer " + "attributes 'max_compute_workgroup_invocations' and " + "'max_compute_workgroup_size'"); + return {}; + } + } + + if (parser.parseGreater()) + return {}; + + return spirv::TargetEnvAttr::get(versionAttr, extensionsAttr, + capabilitiesAttr, limitsAttr); +} + +Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + // SPIR-V attributes are dictionaries so they do not have type. + if (type) { + parser.emitError(parser.getNameLoc(), "unexpected type"); + return {}; + } + + // Parse the kind keyword first. 
+ StringRef attrKind; + if (parser.parseKeyword(&attrKind)) + return {}; + + if (attrKind == spirv::TargetEnvAttr::getKindName()) + return parseTargetAttr(parser); + + parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ") + << attrKind; + return {}; +} + +//===----------------------------------------------------------------------===// +// Attribute Printing +//===----------------------------------------------------------------------===// + +static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) { + auto &os = printer.getStream(); + printer << spirv::TargetEnvAttr::getKindName() << "<" + << spirv::stringifyVersion(targetEnv.getVersion()) << ", ["; + interleaveComma(targetEnv.getExtensionsAttr(), os, [&](Attribute attr) { + os << attr.cast().getValue(); + }); + printer << "], ["; + interleaveComma(targetEnv.getCapabilities(), os, [&](spirv::Capability cap) { + os << spirv::stringifyCapability(cap); + }); + printer << "], " << targetEnv.getResourceLimits() << ">"; +} + +void SPIRVDialect::printAttribute(Attribute attr, + DialectAsmPrinter &printer) const { + if (auto targetEnv = attr.dyn_cast()) + print(targetEnv, printer); + else + llvm_unreachable("unhandled SPIR-V attribute kind"); +} + +//===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -362,19 +362,15 @@ spirv::SPIRVConversionTarget::SPIRVConversionTarget( spirv::TargetEnvAttr targetEnv, MLIRContext *context) - : ConversionTarget(*context), - givenVersion(static_cast(targetEnv.version().getInt())) { - for (Attribute extAttr : targetEnv.extensions()) - givenExtensions.insert( - *spirv::symbolizeExtension(extAttr.cast().getValue())); + :
ConversionTarget(*context), givenVersion(targetEnv.getVersion()) { + for (spirv::Extension ext : targetEnv.getExtensions()) + givenExtensions.insert(ext); // Add extensions implied by the current version. for (spirv::Extension ext : spirv::getImpliedExtensions(givenVersion)) givenExtensions.insert(ext); - for (Attribute capAttr : targetEnv.capabilities()) { - auto cap = - static_cast(capAttr.cast().getInt()); + for (spirv::Capability cap : targetEnv.getCapabilities()) { givenCapabilities.insert(cap); // Add capabilities implied by the current capability. diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -16,6 +16,119 @@ namespace mlir { #include "mlir/Dialect/SPIRV/TargetAndABI.cpp.inc" + +namespace spirv { +namespace detail { +struct TargetEnvAttributeStorage : public AttributeStorage { + using KeyTy = std::tuple; + + TargetEnvAttributeStorage(Attribute version, Attribute extensions, + Attribute capabilities, Attribute limits) + : version(version), extensions(extensions), capabilities(capabilities), + limits(limits) {} + + bool operator==(const KeyTy &key) const { + return std::get<0>(key) == version && std::get<1>(key) == extensions && + std::get<2>(key) == capabilities && std::get<3>(key) == limits; + } + + static TargetEnvAttributeStorage * + construct(AttributeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key), + std::get<2>(key), std::get<3>(key)); + } + + Attribute version; + Attribute extensions; + Attribute capabilities; + Attribute limits; +}; +} // namespace detail +} // namespace spirv +} // namespace mlir + +spirv::TargetEnvAttr spirv::TargetEnvAttr::get(IntegerAttr version, + ArrayAttr extensions, + ArrayAttr capabilities, + DictionaryAttr limits) { + assert(version && extensions && capabilities && limits); + 
MLIRContext *context = version.getContext(); + return Base::get(context, spirv::AttrKind::TargetEnv, version, extensions, + capabilities, limits); +} + +StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; } + +spirv::Version spirv::TargetEnvAttr::getVersion() { + return static_cast( + getImpl()->version.cast().getValue().getZExtValue()); +} + +spirv::TargetEnvAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it) + : llvm::mapped_iterator( + it, [](Attribute attr) { + return *symbolizeExtension(attr.cast().getValue()); + }) {} + +spirv::TargetEnvAttr::ext_range spirv::TargetEnvAttr::getExtensions() { + auto range = getExtensionsAttr().getValue(); + return {ext_iterator(range.begin()), ext_iterator(range.end())}; +} + +ArrayAttr spirv::TargetEnvAttr::getExtensionsAttr() { + return getImpl()->extensions.cast(); +} + +spirv::TargetEnvAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it) + : llvm::mapped_iterator( + it, [](Attribute attr) { + return *symbolizeCapability( + attr.cast().getValue().getZExtValue()); + }) {} + +spirv::TargetEnvAttr::cap_range spirv::TargetEnvAttr::getCapabilities() { + auto range = getCapabilitiesAttr().getValue(); + return {cap_iterator(range.begin()), cap_iterator(range.end())}; +} + +ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() { + return getImpl()->capabilities.cast(); +} + +DictionaryAttr spirv::TargetEnvAttr::getResourceLimits() { + return getImpl()->limits.cast(); +} + +LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants( + Optional loc, MLIRContext *context, IntegerAttr version, + ArrayAttr extensions, ArrayAttr capabilities, DictionaryAttr limits) { + if (!version.getType().isInteger(32)) + return emitOptionalError(loc, "expected 32-bit integer for version"); + + if (!llvm::all_of(extensions.getValue(), [](Attribute attr) { + if (auto strAttr = attr.dyn_cast()) + if (spirv::symbolizeExtension(strAttr.getValue())) + return true; + return false; + })) + return emitOptionalError(loc, 
"unknown extension in extension list"); + + if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) { + if (auto intAttr = attr.dyn_cast()) + if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue())) + return true; + return false; + })) + return emitOptionalError(loc, "unknown capability in capability list"); + + if (!limits.isa()) + return emitOptionalError(loc, + "expected spirv::ResourceLimitsAttr for limits"); + + return success(); } StringRef spirv::getInterfaceVarABIAttrName() { @@ -65,7 +178,7 @@ builder.getI32ArrayAttr({}), builder.getI32ArrayAttr( {static_cast(spirv::Capability::Shader)}), - spirv::getDefaultResourceLimits(context), context); + spirv::getDefaultResourceLimits(context)); } spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { diff --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir --- a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir +++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir @@ -15,15 +15,12 @@ } module attributes { - spv.target_env = { - version = 3 : i32, - extensions = [], - capabilities = [1: i32, 63: i32], // Shader, GroupNonUniformArithmetic - limits = { + spv.target_env = #spv.target_env< + V_1_3, [], [Shader, GroupNonUniformArithmetic], + { max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32> - } - } + }> } { // CHECK: spv.globalVariable @@ -80,15 +77,12 @@ } module attributes { - spv.target_env = { - version = 3 : i32, - extensions = [], - capabilities = [1: i32, 63: i32], // Shader, GroupNonUniformArithmetic - limits = { + spv.target_env = #spv.target_env< + V_1_3, [], [Shader, GroupNonUniformArithmetic], + { max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32> - } - } + }> } { func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) { // expected-error @+1 
{{failed to legalize operation 'linalg.generic'}} @@ -116,15 +110,12 @@ } module attributes { - spv.target_env = { - version = 3 : i32, - extensions = [], - capabilities = [1: i32, 63: i32], // Shader, GroupNonUniformArithmetic - limits = { + spv.target_env = #spv.target_env< + V_1_3, [], [Shader, GroupNonUniformArithmetic], + { max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32> - } - } + }> } { func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes { spv.entry_point_abi = {local_size = dense<[32, 1, 1]>: vector<3xi32>} @@ -154,15 +145,12 @@ } module attributes { - spv.target_env = { - version = 3 : i32, - extensions = [], - capabilities = [1: i32, 63: i32], // Shader, GroupNonUniformArithmetic - limits = { + spv.target_env = #spv.target_env< + V_1_3, [], [Shader, GroupNonUniformArithmetic], + { max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32> - } - } + }> } { func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi32>) attributes { spv.entry_point_abi = {local_size = dense<[16, 8, 1]>: vector<3xi32>} diff --git a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir --- a/mlir/test/Dialect/SPIRV/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/target-and-abi.mlir @@ -106,50 +106,87 @@ // spv.target_env //===----------------------------------------------------------------------===// -// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', one 32-bit integer array attribute 'capabilities', and one dictionary attribute 'limits'}} func @target_env_wrong_type() attributes { - spv.target_env = 64 + // expected-error @+1 {{expected valid keyword}} + spv.target_env = #spv.target_env<64> } { return } // ----- -// expected-error @+1 
{{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', one 32-bit integer array attribute 'capabilities', and one dictionary attribute 'limits'}} func @target_env_missing_fields() attributes { - spv.target_env = {version = 0: i32} + // expected-error @+1 {{expected ','}} + spv.target_env = #spv.target_env +} { return } + +// ----- + +func @target_env_wrong_version() attributes { + // expected-error @+1 {{unknown version: V_x_y}} + spv.target_env = #spv.target_env } { return } // ----- -// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', one 32-bit integer array attribute 'capabilities', and one dictionary attribute 'limits'}} func @target_env_wrong_extension_type() attributes { - spv.target_env = {version = 0: i32, extensions = [32: i32], capabilities = [1: i32]} + // expected-error @+1 {{expected valid keyword}} + spv.target_env = #spv.target_env } { return } // ----- -// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', one 32-bit integer array attribute 'capabilities', and one dictionary attribute 'limits'}} func @target_env_wrong_extension() attributes { - spv.target_env = {version = 0: i32, extensions = ["SPV_Something"], capabilities = [1: i32]} + // expected-error @+1 {{unknown extension: SPV_Something}} + spv.target_env = #spv.target_env +} { return } + +// ----- + +func @target_env_wrong_capability() attributes { + // expected-error @+1 {{unknown capability: Something}} + spv.target_env = #spv.target_env +} { return } + +// ----- + +func @target_env_missing_limits() attributes { + spv.target_env = #spv.target_env< + V_1_0, [SPV_KHR_storage_buffer_storage_class], [Shader], + // expected-error @+1 {{limits must be a dictionary attribute containing two 
32-bit integer attributes 'max_compute_workgroup_invocations' and 'max_compute_workgroup_size'}} + {max_compute_workgroup_size = dense<[128, 64, 64]> : vector<3xi32>}> +} { return } + +// ----- + +func @target_env_wrong_limits() attributes { + spv.target_env = #spv.target_env< + V_1_0, [SPV_KHR_storage_buffer_storage_class], [Shader], + // expected-error @+1 {{limits must be a dictionary attribute containing two 32-bit integer attributes 'max_compute_workgroup_invocations' and 'max_compute_workgroup_size'}} + {max_compute_workgroup_invocations = 128 : i64, max_compute_workgroup_size = dense<[128, 64, 64]> : vector<3xi32>}> } { return } // ----- func @target_env() attributes { - // CHECK: spv.target_env = {capabilities = [1 : i32], extensions = ["SPV_KHR_storage_buffer_storage_class"], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 64, 64]> : vector<3xi32>}, version = 0 : i32} - spv.target_env = { - version = 0: i32, - extensions = ["SPV_KHR_storage_buffer_storage_class"], - capabilities = [1: i32], - limits = { + + // CHECK: spv.target_env = #spv.target_env : vector<3xi32>}> + spv.target_env = #spv.target_env< + V_1_0, [SPV_KHR_storage_buffer_storage_class], [Shader], + { max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 64, 64]> : vector<3xi32> - } - } + }> } { return } // ----- -// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', one 32-bit integer array attribute 'capabilities', and one dictionary attribute 'limits'}} func @target_env_extra_fields() attributes { - spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_storage_buffer_storage_class"], capabilities = [1: i32], extra = 32} + // expected-error @+6 {{expected '>'}} + spv.target_env = #spv.target_env< + V_1_0, [SPV_KHR_storage_buffer_storage_class], [Shader], + { + max_compute_workgroup_invocations = 
128 : i32, + max_compute_workgroup_size = dense<[128, 64, 64]> : vector<3xi32> + }, + more_stuff + > } { return } diff --git a/mlir/test/Dialect/SPIRV/target-env.mlir b/mlir/test/Dialect/SPIRV/target-env.mlir --- a/mlir/test/Dialect/SPIRV/target-env.mlir +++ b/mlir/test/Dialect/SPIRV/target-env.mlir @@ -29,20 +29,13 @@ // Vulkan memory model is available via extension SPV_KHR_vulkan_memory_model, // which extensions are incorporated into SPIR-V 1.5. -// Enum case symbol (value) map: -// Version: 1.0 (0), 1.1 (1), 1.2 (2), 1.3 (3), 1.4 (4), 1.5 (5) -// Capability: Shader (1), Geometry (2), Kernel (6), AtomicStorage (21), -// GeometryPointSize (24), GroupNonUniformBallot (64), -// SubgroupBallotKHR (4423), VulkanMemoryModel (5345), -// PhysicalStorageBufferAddresses (5347) - //===----------------------------------------------------------------------===// // MaxVersion //===----------------------------------------------------------------------===// // CHECK-LABEL: @cmp_exchange_weak_suitable_version_capabilities func @cmp_exchange_weak_suitable_version_capabilities(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 attributes { - spv.target_env = {version = 1: i32, extensions = [], capabilities = [6: i32, 21: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: spv.AtomicCompareExchangeWeak "Workgroup" "AcquireRelease|AtomicCounterMemory" "Acquire" %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr, i32, i32) -> (i32) @@ -51,7 +44,7 @@ // CHECK-LABEL: @cmp_exchange_weak_unsupported_version func @cmp_exchange_weak_unsupported_version(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 attributes { - spv.target_env = {version = 4: i32, extensions = [], capabilities = [6: i32, 21: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = 
dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: test.convert_to_atomic_compare_exchange_weak_op %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr, i32, i32) -> (i32) @@ -64,7 +57,7 @@ // CHECK-LABEL: @group_non_uniform_ballot_suitable_version func @group_non_uniform_ballot_suitable_version(%predicate: i1) -> vector<4xi32> attributes { - spv.target_env = {version = 4: i32, extensions = [], capabilities = [64: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: spv.GroupNonUniformBallot "Workgroup" %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>) @@ -73,7 +66,7 @@ // CHECK-LABEL: @group_non_uniform_ballot_unsupported_version func @group_non_uniform_ballot_unsupported_version(%predicate: i1) -> vector<4xi32> attributes { - spv.target_env = {version = 1: i32, extensions = [], capabilities = [64: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: test.convert_to_group_non_uniform_ballot_op %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>) @@ -86,7 +79,7 @@ // CHECK-LABEL: @cmp_exchange_weak_missing_capability_kernel func @cmp_exchange_weak_missing_capability_kernel(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 attributes { - spv.target_env = {version = 3: i32, extensions = [], capabilities = [21: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: test.convert_to_atomic_compare_exchange_weak_op %0 = 
"test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr, i32, i32) -> (i32) @@ -95,7 +88,7 @@ // CHECK-LABEL: @cmp_exchange_weak_missing_capability_atomic_storage func @cmp_exchange_weak_missing_capability_atomic_storage(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 attributes { - spv.target_env = {version = 3: i32, extensions = [], capabilities = [6: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: test.convert_to_atomic_compare_exchange_weak_op %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr, i32, i32) -> (i32) @@ -104,7 +97,7 @@ // CHECK-LABEL: @subgroup_ballot_missing_capability func @subgroup_ballot_missing_capability(%predicate: i1) -> vector<4xi32> attributes { - spv.target_env = {version = 4: i32, extensions = ["SPV_KHR_shader_ballot"], capabilities = [], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: test.convert_to_subgroup_ballot_op %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>) @@ -113,7 +106,7 @@ // CHECK-LABEL: @bit_reverse_directly_implied_capability func @bit_reverse_directly_implied_capability(%operand: i32) -> i32 attributes { - spv.target_env = {version = 0: i32, extensions = [], capabilities = [2: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: spv.BitReverse %0 = "test.convert_to_bit_reverse_op"(%operand): (i32) -> (i32) @@ -122,7 +115,7 @@ // CHECK-LABEL: @bit_reverse_recursively_implied_capability func @bit_reverse_recursively_implied_capability(%operand: i32) -> i32 attributes { - 
spv.target_env = {version = 0: i32, extensions = [], capabilities = [24: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: spv.BitReverse %0 = "test.convert_to_bit_reverse_op"(%operand): (i32) -> (i32) @@ -135,7 +128,7 @@ // CHECK-LABEL: @subgroup_ballot_suitable_extension func @subgroup_ballot_suitable_extension(%predicate: i1) -> vector<4xi32> attributes { - spv.target_env = {version = 4: i32, extensions = ["SPV_KHR_shader_ballot"], capabilities = [4423: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: spv.SubgroupBallotKHR %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>) @@ -144,7 +137,7 @@ // CHECK-LABEL: @subgroup_ballot_missing_extension func @subgroup_ballot_missing_extension(%predicate: i1) -> vector<4xi32> attributes { - spv.target_env = {version = 4: i32, extensions = [], capabilities = [4423: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: test.convert_to_subgroup_ballot_op %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>) @@ -153,7 +146,7 @@ // CHECK-LABEL: @module_suitable_extension1 func @module_suitable_extension1() attributes { - spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_vulkan_memory_model", "SPV_EXT_physical_storage_buffer"], capabilities = [5345: i32, 5347: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: spv.module "PhysicalStorageBuffer64" "Vulkan" "test.convert_to_module_op"() : 
() ->() @@ -162,7 +155,7 @@ // CHECK-LABEL: @module_suitable_extension2 func @module_suitable_extension2() attributes { - spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_vulkan_memory_model", "SPV_KHR_physical_storage_buffer"], capabilities = [5345: i32, 5347: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: spv.module "PhysicalStorageBuffer64" "Vulkan" "test.convert_to_module_op"() : () -> () @@ -171,7 +164,7 @@ // CHECK-LABEL: @module_missing_extension_mm func @module_missing_extension_mm() attributes { - spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_physical_storage_buffer"], capabilities = [5345: i32, 5347: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: test.convert_to_module_op "test.convert_to_module_op"() : () -> () @@ -180,7 +173,7 @@ // CHECK-LABEL: @module_missing_extension_am func @module_missing_extension_am() attributes { - spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_vulkan_memory_model"], capabilities = [5345: i32, 5347: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: test.convert_to_module_op "test.convert_to_module_op"() : () -> () @@ -190,7 +183,7 @@ // CHECK-LABEL: @module_implied_extension func @module_implied_extension() attributes { // Version 1.5 implies SPV_KHR_vulkan_memory_model and SPV_KHR_physical_storage_buffer. 
- spv.target_env = {version = 5: i32, extensions = [], capabilities = [5345: i32, 5347: i32], limits = {max_compute_workgroup_invocations = 128 : i32, max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}} + spv.target_env = #spv.target_env : vector<3xi32>}> } { // CHECK: spv.module "PhysicalStorageBuffer64" "Vulkan" "test.convert_to_module_op"() : () -> ()