diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt --- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -5,6 +5,11 @@ mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRSPIRVEnumsIncGen) +set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) +mlir_tablegen(SPIRVEnumAvailability.h.inc -gen-spirv-enum-avail-decls) +mlir_tablegen(SPIRVEnumAvailability.cpp.inc -gen-spirv-enum-avail-defs) +add_public_tablegen_target(MLIRSPIRVEnumAvailabilityIncGen) + set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) mlir_tablegen(SPIRVAvailability.h.inc -gen-avail-interface-decls) mlir_tablegen(SPIRVAvailability.cpp.inc -gen-avail-interface-defs) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -467,7 +467,13 @@ def SPV_AM_Logical : I32EnumAttrCase<"Logical", 0>; def SPV_AM_Physical32 : I32EnumAttrCase<"Physical32", 1>; def SPV_AM_Physical64 : I32EnumAttrCase<"Physical64", 2>; -def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64", 5348>; +def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64", 5348> { + list availability = [ + MinVersion, + Extension<[SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer]>, + Capability<[SPV_C_PhysicalStorageBufferAddresses]> + ]; +} def SPV_AddressingModelAttr : I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [ @@ -944,7 +950,12 @@ def SPV_MM_Simple : I32EnumAttrCase<"Simple", 0>; def SPV_MM_GLSL450 : I32EnumAttrCase<"GLSL450", 1>; def SPV_MM_OpenCL : I32EnumAttrCase<"OpenCL", 2>; -def SPV_MM_Vulkan : I32EnumAttrCase<"Vulkan", 3>; +def SPV_MM_Vulkan : I32EnumAttrCase<"Vulkan", 3> { + list availability = [ + MinVersion, + Capability<[SPV_C_VulkanMemoryModel]> + ]; +} def SPV_MemoryModelAttr : 
I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -17,8 +17,21 @@ #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" +// Forward declare enum classes related to op availability. Their definitions +// are in the TableGen'erated SPIRVEnums.h.inc and can be referenced by other +// declarations in SPIRVEnums.h.inc. +namespace mlir { +namespace spirv { +enum class Version : uint32_t; +enum class Extension; +enum class Capability : uint32_t; +} // namespace spirv +} // namespace mlir + // Pull in all enum type definitions and utility function declarations #include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc" +// Pull in all enum type availability query function declarations +#include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.h.inc" #include diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -135,6 +135,9 @@ // Returns the value of this enum attribute case. int64_t getValue() const; + + // Returns the TableGen definition this EnumAttrCase was constructed from. + const llvm::Record &getDef() const; }; // Wrapper class providing helper methods for accessing enum attributes defined @@ -146,6 +149,8 @@ explicit EnumAttr(const llvm::Record &record); explicit EnumAttr(const llvm::DefInit *init); + static bool classof(const Attribute *attr); + // Returns true if this is a bit enum attribute. 
bool isBitEnum() const; diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -18,6 +18,7 @@ add_dependencies(MLIRSPIRV MLIRSPIRVAvailabilityIncGen MLIRSPIRVCanonicalizationIncGen + MLIRSPIRVEnumAvailabilityIncGen MLIRSPIRVEnumsIncGen MLIRSPIRVOpsIncGen MLIRSPIRVOpUtilsGen diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -22,6 +22,8 @@ // Pull in all enum utility function definitions #include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc" +// Pull in all enum type availability query function definitions +#include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.cpp.inc" //===----------------------------------------------------------------------===// // ArrayType diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -155,6 +155,8 @@ return def->getValueAsInt("value"); } +const llvm::Record &tblgen::EnumAttrCase::getDef() const { return *def; } + tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) { assert(isSubClassOf("EnumAttrInfo") && "must be subclass of TableGen 'EnumAttr' class"); @@ -165,6 +167,10 @@ tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} +bool tblgen::EnumAttr::classof(const Attribute *attr) { + return attr->isSubClassOf("EnumAttrInfo"); +} + bool tblgen::EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); } StringRef tblgen::EnumAttr::getEnumClassName() const { diff --git a/mlir/test/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/Dialect/SPIRV/TestAvailability.cpp @@ -30,18 +30,19 @@ if (op->getDialect() != spvDialect) return 
WalkResult::advance(); + auto opName = op->getName(); auto &os = llvm::outs(); if (auto minVersion = dyn_cast(op)) - os << " min version: " + os << opName << " min version: " << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n"; if (auto maxVersion = dyn_cast(op)) - os << " max version: " + os << opName << " max version: " << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n"; if (auto extension = dyn_cast(op)) { - os << " extensions: ["; + os << opName << " extensions: ["; for (const auto &exts : extension.getExtensions()) { os << " ["; interleaveComma(exts, os, [&](spirv::Extension ext) { @@ -53,7 +54,7 @@ } if (auto capability = dyn_cast(op)) { - os << " capabilities: ["; + os << opName << " capabilities: ["; for (const auto &caps : capability.getCapabilities()) { os << " ["; interleaveComma(caps, os, [&](spirv::Capability cap) { diff --git a/mlir/test/Dialect/SPIRV/availability.mlir b/mlir/test/Dialect/SPIRV/availability.mlir --- a/mlir/test/Dialect/SPIRV/availability.mlir +++ b/mlir/test/Dialect/SPIRV/availability.mlir @@ -29,3 +29,23 @@ %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32> return %0: vector<4xi32> } + +// CHECK-LABEL: module_logical_glsl450 +func @module_logical_glsl450() { + // CHECK: spv.module min version: V_1_0 + // CHECK: spv.module max version: V_1_5 + // CHECK: spv.module extensions: [ ] + // CHECK: spv.module capabilities: [ ] + spv.module "Logical" "GLSL450" { } + return +} + +// CHECK-LABEL: module_physical_storage_buffer64_vulkan +func @module_physical_storage_buffer64_vulkan() { + // CHECK: spv.module min version: V_1_5 + // CHECK: spv.module max version: V_1_5 + // CHECK: spv.module extensions: [ [SPV_EXT_physical_storage_buffer, SPV_KHR_physical_storage_buffer] ] + // CHECK: spv.module capabilities: [ [PhysicalStorageBufferAddresses] [VulkanMemoryModel] ] + spv.module "PhysicalStorageBuffer64" "Vulkan" { } + return +} diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp 
b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" @@ -40,6 +41,7 @@ using llvm::Twine; using mlir::tblgen::Attribute; using mlir::tblgen::EnumAttr; +using mlir::tblgen::EnumAttrCase; using mlir::tblgen::NamedAttribute; using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; @@ -138,6 +140,20 @@ return def->getValueAsString("instance"); } +// Returns the `availability` specs of `def`; empty if it has no such field. +std::vector getAvailabilities(const Record &def) { + std::vector availabilities; + + if (def.getValue("availability")) { + std::vector availDefs = def.getValueAsListOfDefs("availability"); + availabilities.reserve(availDefs.size()); + for (const Record *avail : availDefs) + availabilities.emplace_back(avail); + } + + return availabilities; +} + //===----------------------------------------------------------------------===// // Availability Interface Definitions AutoGen //===----------------------------------------------------------------------===// @@ -272,6 +288,186 @@ }); //===----------------------------------------------------------------------===// +// Enum Availability Query AutoGen +//===----------------------------------------------------------------------===// + +static void emitAvailabilityQueryForIntEnum(const Record &enumDef, + raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::vector enumerants = enumAttr.getAllCases(); + + // Mapping from availability class name to (enumerant, availability + // specification) pairs. + llvm::StringMap, 1>> + classCaseMap; + + // Place all availability specifications into their corresponding + // availability classes.
+ for (const EnumAttrCase &enumerant : enumerants) + for (const Availability &avail : getAvailabilities(enumerant.getDef())) + classCaseMap[avail.getClass()].push_back({enumerant, avail}); + + for (const auto &classCasePair : classCaseMap) { + Availability avail = classCasePair.getValue().front().second; + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + + os << " switch (value) {\n"; + for (const auto &caseSpecPair : classCasePair.getValue()) { + EnumAttrCase enumerant = caseSpecPair.first; + Availability avail = caseSpecPair.second; + os << formatv(" case {0}::{1}: return {2}({3});\n", enumName, + enumerant.getSymbol(), avail.getMergeInstanceType(), + avail.getMergeInstance()); + } + os << " default: break;\n"; + os << " }\n" + << " return llvm::None;\n" + << "}\n"; + } +} + +static void emitAvailabilityQueryForBitEnum(const Record &enumDef, + raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + std::vector enumerants = enumAttr.getAllCases(); + + // Mapping from availability class name to (enumerant, availability + // specification) pairs. + llvm::StringMap, 1>> + classCaseMap; + + // Place all availability specifications into their corresponding + // availability classes.
+ for (const EnumAttrCase &enumerant : enumerants) + for (const Availability &avail : getAvailabilities(enumerant.getDef())) + classCaseMap[avail.getClass()].push_back({enumerant, avail}); + + for (const auto &classCasePair : classCaseMap) { + Availability avail = classCasePair.getValue().front().second; + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + + os << formatv( + " assert(::llvm::countPopulation(static_cast<{0}>(value)) <= 1" + " && \"cannot have more than one bit set\");\n", + underlyingType); + + os << " switch (value) {\n"; + for (const auto &caseSpecPair : classCasePair.getValue()) { + EnumAttrCase enumerant = caseSpecPair.first; + Availability avail = caseSpecPair.second; + os << formatv(" case {0}::{1}: return {2}({3});\n", enumName, + enumerant.getSymbol(), avail.getMergeInstanceType(), + avail.getMergeInstance()); + } + os << " default: break;\n"; + os << " }\n" + << " return llvm::None;\n" + << "}\n"; + } +} + +static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef cppNamespace = enumAttr.getCppNamespace(); + auto enumerants = enumAttr.getAllCases(); + + llvm::SmallVector namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + llvm::StringSet<> handledClasses; + + // Declare one query function per availability class used by any of this + // enum's cases.
+ for (const EnumAttrCase &enumerant : enumerants) + for (const Availability &avail : getAvailabilities(enumerant.getDef())) { + StringRef className = avail.getClass(); + if (handledClasses.count(className)) + continue; + os << formatv("llvm::Optional<{0}> {1}({2} value);\n", + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + handledClasses.insert(className); + } + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; +} + +static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); + for (const auto *def : defs) + emitEnumDecl(*def, os); + + return false; +} + +static void emitEnumDef(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef cppNamespace = enumAttr.getCppNamespace(); + + llvm::SmallVector namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + if (enumAttr.isBitEnum()) { + emitAvailabilityQueryForBitEnum(enumDef, os); + } else { + emitAvailabilityQueryForIntEnum(enumDef, os); + } + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; + os << "\n"; +} + +static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); + for (const auto *def : defs) + emitEnumDef(*def, os); + + return false; +} + +//===----------------------------------------------------------------------===// +// Enum Availability Query Hook Registration +//===----------------------------------------------------------------------===// + +// Registers the enum availability declaration generator to mlir-tblgen.
+static mlir::GenRegistration + genEnumDecls("gen-spirv-enum-avail-decls", + "Generate SPIR-V enum availability declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitEnumDecls(records, os); + }); + +// Registers the enum availability definition generator to mlir-tblgen. +static mlir::GenRegistration + genEnumDefs("gen-spirv-enum-avail-defs", + "Generate SPIR-V enum availability definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitEnumDefs(records, os); + }); + +//===----------------------------------------------------------------------===// // Serialization AutoGen //===----------------------------------------------------------------------===// @@ -960,18 +1156,6 @@ // SPIR-V Availability Impl AutoGen //===----------------------------------------------------------------------===// -// Returns the availability spec of the given `def`. -std::vector getAvailabilities(const Record &def) { - std::vector availabilities; - if (auto *availListInit = def.getValueAsListInit("availability")) { - availabilities.reserve(availListInit->size()); - for (auto *availInit : *availListInit) - availabilities.emplace_back( - llvm::cast(availInit)->getDef()); - } - return availabilities; -} - static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { mlir::tblgen::FmtContext fctx; fctx.addSubst("overall", "overall"); @@ -986,6 +1170,16 @@ llvm::StringMap availClasses; for (const Availability &avail : opAvailabilities) availClasses.try_emplace(avail.getClass(), avail); + for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { + const auto *enumAttr = llvm::dyn_cast(&namedAttr.attr); + if (!enumAttr) + continue; + + for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) + for (const Availability &caseAvail : + getAvailabilities(enumerant.getDef())) + availClasses.try_emplace(caseAvail.getClass(), caseAvail); + } // Then generate implementation for each availability class.
for (const auto &availClass : availClasses) { @@ -1008,6 +1202,58 @@ &fctx.addSubst("instance", avail.getMergeInstance())) << ";\n"; } + + // Update with enum attributes' specific availability spec. + for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { + const auto *enumAttr = llvm::dyn_cast(&namedAttr.attr); + if (!enumAttr) + continue; + + // (enumerant, availability specification) pairs for this availability + // class. + SmallVector, 1> caseSpecs; + + // Collect all cases' availability specs. + for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) + for (const Availability &caseAvail : + getAvailabilities(enumerant.getDef())) + if (availClassName == caseAvail.getClass()) + caseSpecs.push_back({enumerant, caseAvail}); + + // If this attribute kind does not have any availability spec from any of + // its cases, no more work to do. + if (caseSpecs.empty()) + continue; + + if (enumAttr->isBitEnum()) { + // For BitEnumAttr, we need to iterate over each bit to query its + // availability spec. + os << formatv(" for (unsigned i = 0; " + "i < std::numeric_limits<{0}>::digits; ++i) {{\n", + enumAttr->getUnderlyingType()); + os << formatv(" {0}::{1} attrVal = this->{2}() & " + "static_cast<{0}::{1}>(static_cast<{3}>(1) << i);\n", + enumAttr->getCppNamespace(), enumAttr->getEnumClassName(), + namedAttr.name, enumAttr->getUnderlyingType()); + os << formatv(" if (static_cast<{0}>(attrVal) == 0) continue;\n", + enumAttr->getUnderlyingType()); + } else { + // For IntEnumAttr, we just need to query the value as a whole. + os << " {\n"; + os << formatv(" auto attrVal = this->{0}();\n", namedAttr.name); + } + os << formatv(" auto instance = {0}::{1}(attrVal);\n", + enumAttr->getCppNamespace(), avail.getQueryFnName()); + os << " if (instance) " + // TODO(antiagainst): use `avail.getMergeCode()` here once ODS supports + // dialect-specific contents so that we can treat not implementing the + // availability interface as an indication of no requirements.
+ << tgfmt(caseSpecs.front().second.getMergeActionCode(), + &fctx.addSubst("instance", "*instance")) + << ";\n"; + os << " }\n"; + } + os << " return overall;\n"; os << "}\n"; }