diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt --- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -8,6 +8,7 @@ set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) mlir_tablegen(SPIRVEnumAvailability.h.inc -gen-spirv-enum-avail-decls) mlir_tablegen(SPIRVEnumAvailability.cpp.inc -gen-spirv-enum-avail-defs) +mlir_tablegen(SPIRVCapabilityImplication.inc -gen-spirv-capability-implication) add_public_tablegen_target(MLIRSPIRVEnumAvailabilityIncGen) set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -367,25 +367,21 @@ } def SPV_C_StorageBuffer16BitAccess : I32EnumAttrCase<"StorageBuffer16BitAccess", 4433> { list availability = [ - MinVersion, Extension<[SPV_KHR_16bit_storage]> ]; } def SPV_C_StoragePushConstant16 : I32EnumAttrCase<"StoragePushConstant16", 4435> { list availability = [ - MinVersion, Extension<[SPV_KHR_16bit_storage]> ]; } def SPV_C_StorageInputOutput16 : I32EnumAttrCase<"StorageInputOutput16", 4436> { list availability = [ - MinVersion, Extension<[SPV_KHR_16bit_storage]> ]; } def SPV_C_DeviceGroup : I32EnumAttrCase<"DeviceGroup", 4437> { list availability = [ - MinVersion, Extension<[SPV_KHR_device_group]> ]; } @@ -401,43 +397,36 @@ } def SPV_C_StorageBuffer8BitAccess : I32EnumAttrCase<"StorageBuffer8BitAccess", 4448> { list availability = [ - MinVersion, Extension<[SPV_KHR_8bit_storage]> ]; } def SPV_C_StoragePushConstant8 : I32EnumAttrCase<"StoragePushConstant8", 4450> { list availability = [ - MinVersion, Extension<[SPV_KHR_8bit_storage]> ]; } def SPV_C_DenormPreserve : I32EnumAttrCase<"DenormPreserve", 4464> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]> ]; } def SPV_C_DenormFlushToZero : 
I32EnumAttrCase<"DenormFlushToZero", 4465> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]> ]; } def SPV_C_SignedZeroInfNanPreserve : I32EnumAttrCase<"SignedZeroInfNanPreserve", 4466> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]> ]; } def SPV_C_RoundingModeRTE : I32EnumAttrCase<"RoundingModeRTE", 4467> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]> ]; } def SPV_C_RoundingModeRTZ : I32EnumAttrCase<"RoundingModeRTZ", 4468> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]> ]; } @@ -595,14 +584,12 @@ def SPV_C_StorageUniform16 : I32EnumAttrCase<"StorageUniform16", 4434> { list implies = [SPV_C_StorageBuffer16BitAccess]; list availability = [ - MinVersion, Extension<[SPV_KHR_16bit_storage]> ]; } def SPV_C_UniformAndStorageBuffer8BitAccess : I32EnumAttrCase<"UniformAndStorageBuffer8BitAccess", 4449> { list implies = [SPV_C_StorageBuffer8BitAccess]; list availability = [ - MinVersion, Extension<[SPV_KHR_8bit_storage]> ]; } @@ -708,21 +695,18 @@ def SPV_C_DrawParameters : I32EnumAttrCase<"DrawParameters", 4427> { list implies = [SPV_C_Shader]; list availability = [ - MinVersion, Extension<[SPV_KHR_shader_draw_parameters]> ]; } def SPV_C_MultiView : I32EnumAttrCase<"MultiView", 4439> { list implies = [SPV_C_Shader]; list availability = [ - MinVersion, Extension<[SPV_KHR_multiview]> ]; } def SPV_C_VariablePointersStorageBuffer : I32EnumAttrCase<"VariablePointersStorageBuffer", 4441> { list implies = [SPV_C_Shader]; list availability = [ - MinVersion, Extension<[SPV_KHR_variable_pointers]> ]; } @@ -807,7 +791,6 @@ def SPV_C_PhysicalStorageBufferAddresses : I32EnumAttrCase<"PhysicalStorageBufferAddresses", 5347> { list implies = [SPV_C_Shader]; list availability = [ - MinVersion, Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]> ]; } @@ -874,7 +857,6 @@ def SPV_C_VariablePointers : I32EnumAttrCase<"VariablePointers", 4442> { list implies = 
[SPV_C_VariablePointersStorageBuffer]; list availability = [ - MinVersion, Extension<[SPV_KHR_variable_pointers]> ]; } @@ -1041,7 +1023,6 @@ } def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64", 5348> { list availability = [ - MinVersion, Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>, Capability<[SPV_C_PhysicalStorageBufferAddresses]> ]; @@ -1266,35 +1247,30 @@ } def SPV_BI_BaseVertex : I32EnumAttrCase<"BaseVertex", 4424> { list availability = [ - MinVersion, Extension<[SPV_KHR_shader_draw_parameters]>, Capability<[SPV_C_DrawParameters]> ]; } def SPV_BI_BaseInstance : I32EnumAttrCase<"BaseInstance", 4425> { list availability = [ - MinVersion, Extension<[SPV_KHR_shader_draw_parameters]>, Capability<[SPV_C_DrawParameters]> ]; } def SPV_BI_DrawIndex : I32EnumAttrCase<"DrawIndex", 4426> { list availability = [ - MinVersion, Extension<[SPV_KHR_shader_draw_parameters, SPV_NV_mesh_shader]>, Capability<[SPV_C_DrawParameters, SPV_C_MeshShadingNV]> ]; } def SPV_BI_DeviceIndex : I32EnumAttrCase<"DeviceIndex", 4438> { list availability = [ - MinVersion, Extension<[SPV_KHR_device_group]>, Capability<[SPV_C_DeviceGroup]> ]; } def SPV_BI_ViewIndex : I32EnumAttrCase<"ViewIndex", 4440> { list availability = [ - MinVersion, Extension<[SPV_KHR_multiview]>, Capability<[SPV_C_MultiView]> ]; @@ -1803,13 +1779,11 @@ } def SPV_D_NoSignedWrap : I32EnumAttrCase<"NoSignedWrap", 4469> { list availability = [ - MinVersion, Extension<[SPV_KHR_no_integer_wrap_decoration]> ]; } def SPV_D_NoUnsignedWrap : I32EnumAttrCase<"NoUnsignedWrap", 4470> { list availability = [ - MinVersion, Extension<[SPV_KHR_no_integer_wrap_decoration]> ]; } @@ -1873,14 +1847,12 @@ } def SPV_D_RestrictPointer : I32EnumAttrCase<"RestrictPointer", 5355> { list availability = [ - MinVersion, Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>, Capability<[SPV_C_PhysicalStorageBufferAddresses]> ]; } def SPV_D_AliasedPointer : 
I32EnumAttrCase<"AliasedPointer", 5356> { list availability = [ - MinVersion, Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>, Capability<[SPV_C_PhysicalStorageBufferAddresses]> ]; @@ -2161,35 +2133,30 @@ } def SPV_EM_DenormPreserve : I32EnumAttrCase<"DenormPreserve", 4459> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]>, Capability<[SPV_C_DenormPreserve]> ]; } def SPV_EM_DenormFlushToZero : I32EnumAttrCase<"DenormFlushToZero", 4460> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]>, Capability<[SPV_C_DenormFlushToZero]> ]; } def SPV_EM_SignedZeroInfNanPreserve : I32EnumAttrCase<"SignedZeroInfNanPreserve", 4461> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]>, Capability<[SPV_C_SignedZeroInfNanPreserve]> ]; } def SPV_EM_RoundingModeRTE : I32EnumAttrCase<"RoundingModeRTE", 4462> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]>, Capability<[SPV_C_RoundingModeRTE]> ]; } def SPV_EM_RoundingModeRTZ : I32EnumAttrCase<"RoundingModeRTZ", 4463> { list availability = [ - MinVersion, Extension<[SPV_KHR_float_controls]>, Capability<[SPV_C_RoundingModeRTZ]> ]; @@ -2705,7 +2672,6 @@ } def SPV_MM_Vulkan : I32EnumAttrCase<"Vulkan", 3> { list availability = [ - MinVersion, Extension<[SPV_KHR_vulkan_memory_model]>, Capability<[SPV_C_VulkanMemoryModel]> ]; @@ -2755,7 +2721,6 @@ } def SPV_MS_Volatile : BitEnumAttrCase<"Volatile", 0x8000> { list availability = [ - MinVersion, Extension<[SPV_KHR_vulkan_memory_model]>, Capability<[SPV_C_VulkanMemoryModel]> ]; @@ -2835,7 +2800,6 @@ def SPV_SC_Image : I32EnumAttrCase<"Image", 11>; def SPV_SC_StorageBuffer : I32EnumAttrCase<"StorageBuffer", 12> { list availability = [ - MinVersion, Extension<[SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, Capability<[SPV_C_Shader]> ]; @@ -2878,7 +2842,6 @@ } def SPV_SC_PhysicalStorageBuffer : I32EnumAttrCase<"PhysicalStorageBuffer", 5349> { list availability 
= [ - MinVersion, Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>, Capability<[SPV_C_PhysicalStorageBufferAddresses]> ]; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -17,6 +17,8 @@ #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" +#include + // Forward declare enum classes related to op availability. Their definitions // are in the TableGen'erated SPIRVEnums.h.inc and can be referenced by other // declarations in SPIRVEnums.h.inc. @@ -33,10 +35,22 @@ // Pull in all enum type availability query function declarations #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.h.inc" -#include - namespace mlir { namespace spirv { +/// Returns the implied extensions for the given version. These extensions are +/// incorporated into the current version so they are implicitly declared when +/// targeting the given version. +ArrayRef getImpliedExtensions(Version version); + +/// Returns the directly implied capabilities for the given capability. These +/// capabilities are implicitly declared by the given capability. +ArrayRef getDirectImpliedCapabilities(Capability cap); +/// Returns the recursively implied capabilities for the given capability. These +/// capabilities are implicitly declared by the given capability. Compared to +/// the above function, this function collects implied capabilities recursively: +/// if an implicitly declared capability implicitly declares a third one, the +/// third one will also be returned. +SmallVector getRecursiveImpliedCapabilities(Capability cap); namespace detail { struct ArrayTypeStorage; diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -125,6 +125,7 @@ // StringAttr and IntegerAttr. 
class EnumAttrCase : public Attribute { public: + explicit EnumAttrCase(const llvm::Record *record); explicit EnumAttrCase(const llvm::DefInit *init); // Returns true if this EnumAttrCase is backed by a StringAttr. diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -247,9 +247,19 @@ givenExtensions.insert( *spirv::symbolizeExtension(extAttr.cast().getValue())); - for (Attribute capAttr : targetEnv.capabilities()) - givenCapabilities.insert( - static_cast(capAttr.cast().getInt())); + // Add extensions implied by the current version. + for (spirv::Extension ext : spirv::getImpliedExtensions(givenVersion)) + givenExtensions.insert(ext); + + for (Attribute capAttr : targetEnv.capabilities()) { + auto cap = + static_cast(capAttr.cast().getInt()); + givenCapabilities.insert(cap); + + // Add capabilities implied by the current capability. + for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap)) + givenCapabilities.insert(c); + } } bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" @@ -26,6 +27,73 @@ #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.cpp.inc" //===----------------------------------------------------------------------===// +// Availability relationship +//===----------------------------------------------------------------------===// + +ArrayRef spirv::getImpliedExtensions(Version version) { + // Note: the following lists are from "Appendix A: Changes" of the spec. 
+ +#define V_1_3_IMPLIED_EXTS \ + Extension::SPV_KHR_shader_draw_parameters, Extension::SPV_KHR_16bit_storage, \ + Extension::SPV_KHR_device_group, Extension::SPV_KHR_multiview, \ + Extension::SPV_KHR_storage_buffer_storage_class, \ + Extension::SPV_KHR_variable_pointers + +#define V_1_4_IMPLIED_EXTS \ + Extension::SPV_KHR_no_integer_wrap_decoration, \ + Extension::SPV_GOOGLE_decorate_string, \ + Extension::SPV_GOOGLE_hlsl_functionality1, \ + Extension::SPV_KHR_float_controls + +#define V_1_5_IMPLIED_EXTS \ + Extension::SPV_KHR_8bit_storage, Extension::SPV_EXT_descriptor_indexing, \ + Extension::SPV_EXT_shader_viewport_index_layer, \ + Extension::SPV_EXT_physical_storage_buffer, \ + Extension::SPV_KHR_physical_storage_buffer, \ + Extension::SPV_KHR_vulkan_memory_model + + switch (version) { + default: + return {}; + case Version::V_1_3: { + static Extension exts[] = {V_1_3_IMPLIED_EXTS}; + return exts; + } + case Version::V_1_4: { + static Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS}; + return exts; + } + case Version::V_1_5: { + static Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS, + V_1_5_IMPLIED_EXTS}; + return exts; + } + } + +#undef V_1_5_IMPLIED_EXTS +#undef V_1_4_IMPLIED_EXTS +#undef V_1_3_IMPLIED_EXTS +} + +// Pull in utility function definition for implied capabilities +#include "mlir/Dialect/SPIRV/SPIRVCapabilityImplication.inc" + +SmallVector +spirv::getRecursiveImpliedCapabilities(Capability cap) { + ArrayRef directCaps = getDirectImpliedCapabilities(cap); + llvm::SetVector> allCaps( + directCaps.begin(), directCaps.end()); + + // TODO(antiagainst): This is inefficient; find a better way to handle this + // (e.g., using static lists) if this turns out to be a bottleneck.
+ for (unsigned i = 0; i < allCaps.size(); ++i) + for (Capability c : getDirectImpliedCapabilities(allCaps[i])) + allCaps.insert(c); + + return allCaps.takeVector(); +} + +//===----------------------------------------------------------------------===// // ArrayType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -137,12 +137,15 @@ return def->getValueAsString("value"); } -tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) - : Attribute(init) { +tblgen::EnumAttrCase::EnumAttrCase(const llvm::Record *record) + : Attribute(record) { assert(isSubClassOf("EnumAttrCaseInfo") && "must be subclass of TableGen 'EnumAttrInfo' class"); } +tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init) + : EnumAttrCase(init->getDef()) {} + bool tblgen::EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); } diff --git a/mlir/test/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/Dialect/SPIRV/TestAvailability.cpp @@ -95,12 +95,24 @@ PatternRewriter &rewriter) const override; }; +struct ConvertToBitReverse : public RewritePattern { + ConvertToBitReverse(MLIRContext *context); + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + struct ConvertToGroupNonUniformBallot : public RewritePattern { ConvertToGroupNonUniformBallot(MLIRContext *context); PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; +struct ConvertToModule : public RewritePattern { + ConvertToModule(MLIRContext *context); + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + struct ConvertToSubgroupBallot : public RewritePattern { 
ConvertToSubgroupBallot(MLIRContext *context); PatternMatchResult matchAndRewrite(Operation *op, @@ -118,7 +130,8 @@ auto target = spirv::SPIRVConversionTarget::get(targetEnv, context); OwningRewritePatternList patterns; - patterns.insert(context); if (failed(applyPartialConversion(fn, *target, patterns))) @@ -146,6 +159,20 @@ return matchSuccess(); } +ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context) + : RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1, + context) {} + +PatternMatchResult +ConvertToBitReverse::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + Value predicate = op->getOperand(0); + + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), predicate); + return matchSuccess(); +} + ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot( MLIRContext *context) : RewritePattern("test.convert_to_group_non_uniform_ballot_op", @@ -160,6 +187,18 @@ return matchSuccess(); } +ConvertToModule::ConvertToModule(MLIRContext *context) + : RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {} + +PatternMatchResult +ConvertToModule::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + rewriter.replaceOpWithNewOp( + op, spirv::AddressingModel::PhysicalStorageBuffer64, + spirv::MemoryModel::Vulkan); + return matchSuccess(); +} + ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context) : RewritePattern("test.convert_to_subgroup_ballot_op", {"spv.SubgroupBallotKHR"}, 1, context) {} diff --git a/mlir/test/Dialect/SPIRV/availability.mlir b/mlir/test/Dialect/SPIRV/availability.mlir --- a/mlir/test/Dialect/SPIRV/availability.mlir +++ b/mlir/test/Dialect/SPIRV/availability.mlir @@ -42,7 +42,7 @@ // CHECK-LABEL: module_physical_storage_buffer64_vulkan func @module_physical_storage_buffer64_vulkan() { - // CHECK: spv.module min version: V_1_5 + // CHECK: spv.module min version: V_1_0 // CHECK: spv.module max version: V_1_5 // CHECK: spv.module extensions: [ 
[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer] [SPV_KHR_vulkan_memory_model] ] // CHECK: spv.module capabilities: [ [PhysicalStorageBufferAddresses] [VulkanMemoryModel] ] diff --git a/mlir/test/Dialect/SPIRV/target-env.mlir b/mlir/test/Dialect/SPIRV/target-env.mlir --- a/mlir/test/Dialect/SPIRV/target-env.mlir +++ b/mlir/test/Dialect/SPIRV/target-env.mlir @@ -11,16 +11,30 @@ // whose value, if containing AtomicCounterMemory bit, additionally requires // AtomicStorage capability. +// spv.BitReverse is available in all SPIR-V versions under Shader capability. + // spv.GroupNonUniformBallot is available starting from SPIR-V 1.3 under // GroupNonUniform capability. // spv.SubgroupBallotKHR is available under in all SPIR-V versions under // SubgroupBallotKHR capability and SPV_KHR_shader_ballot extension. +// The GeometryPointSize capability implies the Geometry capability, which +// implies the Shader capability. + +// PhysicalStorageBuffer64 addressing model is available via extension +// SPV_EXT_physical_storage_buffer or SPV_KHR_physical_storage_buffer; +// both extensions are incorporated into SPIR-V 1.5. + +// Vulkan memory model is available via extension SPV_KHR_vulkan_memory_model, +// which is incorporated into SPIR-V 1.5.
+ // Enum case symbol (value) map: -// Version: 1.0 (0), 1.1 (1), 1.2 (2), 1.3 (3), 1.4 (4) -// Capability: Kernel (6), AtomicStorage (21), GroupNonUniformBallot (64), -// SubgroupBallotKHR (4423) +// Version: 1.0 (0), 1.1 (1), 1.2 (2), 1.3 (3), 1.4 (4), 1.5 (5) +// Capability: Shader (1), Geometry (2), Kernel (6), AtomicStorage (21), +// GeometryPointSize (24), GroupNonUniformBallot (64), +// SubgroupBallotKHR (4423), VulkanMemoryModel (5345), +// PhysicalStorageBufferAddresses (5347) //===----------------------------------------------------------------------===// // MaxVersion @@ -97,6 +111,24 @@ return %0: vector<4xi32> } +// CHECK-LABEL: @bit_reverse_directly_implied_capability +func @bit_reverse_directly_implied_capability(%operand: i32) -> i32 attributes { + spv.target_env = {version = 0: i32, extensions = [], capabilities = [2: i32]} +} { + // CHECK: spv.BitReverse + %0 = "test.convert_to_bit_reverse_op"(%operand): (i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @bit_reverse_recursively_implied_capability +func @bit_reverse_recursively_implied_capability(%operand: i32) -> i32 attributes { + spv.target_env = {version = 0: i32, extensions = [], capabilities = [24: i32]} +} { + // CHECK: spv.BitReverse + %0 = "test.convert_to_bit_reverse_op"(%operand): (i32) -> (i32) + return %0: i32 +} + //===----------------------------------------------------------------------===// // Extension //===----------------------------------------------------------------------===// @@ -118,3 +150,49 @@ %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>) return %0: vector<4xi32> } + +// CHECK-LABEL: @module_suitable_extension1 +func @module_suitable_extension1() attributes { + spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_vulkan_memory_model", "SPV_EXT_physical_storage_buffer"], capabilities = [5345: i32, 5347: i32]} +} { + // CHECK: spv.module "PhysicalStorageBuffer64" "Vulkan" + "test.convert_to_module_op"() : () ->() + return +} + +// 
CHECK-LABEL: @module_suitable_extension2 +func @module_suitable_extension2() attributes { + spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_vulkan_memory_model", "SPV_KHR_physical_storage_buffer"], capabilities = [5345: i32, 5347: i32]} +} { + // CHECK: spv.module "PhysicalStorageBuffer64" "Vulkan" + "test.convert_to_module_op"() : () -> () + return +} + +// CHECK-LABEL: @module_missing_extension_mm +func @module_missing_extension_mm() attributes { + spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_physical_storage_buffer"], capabilities = [5345: i32, 5347: i32]} +} { + // CHECK: test.convert_to_module_op + "test.convert_to_module_op"() : () -> () + return +} + +// CHECK-LABEL: @module_missing_extension_am +func @module_missing_extension_am() attributes { + spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_vulkan_memory_model"], capabilities = [5345: i32, 5347: i32]} +} { + // CHECK: test.convert_to_module_op + "test.convert_to_module_op"() : () -> () + return +} + +// CHECK-LABEL: @module_implied_extension +func @module_implied_extension() attributes { + // Version 1.5 implies SPV_KHR_vulkan_memory_model and SPV_KHR_physical_storage_buffer. 
+ spv.target_env = {version = 5: i32, extensions = [], capabilities = [5345: i32, 5347: i32]} +} { + // CHECK: spv.module "PhysicalStorageBuffer64" "Vulkan" + "test.convert_to_module_op"() : () -> () + return +} diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Support/STLExtras.h" #include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Format.h" @@ -1283,3 +1284,48 @@ [](const RecordKeeper &records, raw_ostream &os) { return emitAvailabilityImpl(records, os); }); + +//===----------------------------------------------------------------------===// +// SPIR-V Capability Implication AutoGen +//===----------------------------------------------------------------------===// + +static bool emitCapabilityImplication(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Capability Implication", os); + + EnumAttr enumAttr(recordKeeper.getDef("SPV_CapabilityAttr")); + + os << "ArrayRef " + "spirv::getDirectImpliedCapabilities(Capability cap) {\n" + << " switch (cap) {\n" + << " default: return {};\n"; + for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) { + const Record &def = enumerant.getDef(); + if (!def.getValue("implies")) + continue; + + os << " case Capability::" << enumerant.getSymbol() + << ": {static Capability implies[] = {"; + std::vector impliedCapsDefs = def.getValueAsListOfDefs("implies"); + mlir::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) { + os << "Capability::" << EnumAttrCase(capDef).getSymbol(); + }); + os << "}; return implies; }\n"; + } + os << " }\n"; + os << "}\n"; + + return false; +} + +//===----------------------------------------------------------------------===// +// SPIR-V 
Capability Implication Hook Registration +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genCapabilityImplication("gen-spirv-capability-implication", + "Generate utilty function to return implied " + "capabilities for a given capability", + [](const RecordKeeper &records, raw_ostream &os) { + return emitCapabilityImplication(records, os); + }); diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -266,6 +266,13 @@ exts = enum_case.get('extensions', []) if exts: exts = 'Extension<[{}]>'.format(', '.join(sorted(set(exts)))) + # We need to strip the minimal version requirement if this symbol is + # available via an extension, which means *any* SPIR-V version can support + # it as long as the extension is provided. The grammar's 'version' field + # under such case should be interpreted as this symbol is introduced as + # a core symbol since the given version, rather than a minimal version + # requirement. + min_version = 'MinVersion' if for_op else '' # TODO(antiagainst): delete this once ODS can support dialect-specific content # and we can use omission to mean no requirements. if for_op and not exts: