diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/IR/CMakeLists.txt @@ -8,8 +8,8 @@ add_dependencies(mlir-headers MLIRSPIRVEnumsIncGen) set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) -mlir_tablegen(SPIRVEnumAvailability.h.inc -gen-spirv-enum-avail-decls) -mlir_tablegen(SPIRVEnumAvailability.cpp.inc -gen-spirv-enum-avail-defs) +mlir_tablegen(SPIRVEnumAvailability.h.inc -gen-enum-avail-decls) +mlir_tablegen(SPIRVEnumAvailability.cpp.inc -gen-enum-avail-defs) mlir_tablegen(SPIRVCapabilityImplication.inc -gen-spirv-capability-implication) add_public_tablegen_target(MLIRSPIRVEnumAvailabilityIncGen) add_dependencies(mlir-headers MLIRSPIRVEnumAvailabilityIncGen) @@ -17,7 +17,6 @@ set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) mlir_tablegen(SPIRVAvailability.h.inc -gen-avail-interface-decls) mlir_tablegen(SPIRVAvailability.cpp.inc -gen-avail-interface-defs) -mlir_tablegen(SPIRVOpAvailabilityImpl.inc -gen-spirv-avail-impls) add_public_tablegen_target(MLIRSPIRVAvailabilityIncGen) add_dependencies(mlir-headers MLIRSPIRVAvailabilityIncGen) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td +++ /dev/null @@ -1,97 +0,0 @@ -//===- SPIRVAvailability.td - Op Availability Base file ----*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_SPIRV_IR_AVAILABILITY -#define MLIR_DIALECT_SPIRV_IR_AVAILABILITY - -include "mlir/IR/OpBase.td" - -//===----------------------------------------------------------------------===// -// Op availability definitions -//===----------------------------------------------------------------------===// - -// The base class for defining op availability dimensions. -class Availability { - // The following are fields for controlling the generated C++ OpInterface. - - // The namespace for the generated C++ OpInterface subclass. - string interfaceNamespace = ?; - // The name for the generated C++ OpInterface subclass. - string interfaceName = ?; - // The documentation for the generated C++ OpInterface subclass. - string interfaceDescription = ""; - - // The following are fields for controlling the query function signature. - - // The query function's return type in the generated C++ OpInterface subclass. - string queryFnRetType = ?; - // The query function's name in the generated C++ OpInterface subclass. - string queryFnName = ?; - - // The following are fields for controlling the query function implementation. - - // The logic for merging two availability requirements. This is used to derive - // the final availability requirement when, for example, an op has two - // operands and these two operands have different availability requirements. - // - // The code should use `$overall` as the placeholder for the final requirement - // and `$instance` for the current availability requirement instance. - code mergeAction = ?; - // The initializer for the final availability requirement. - string initializer = ?; - // An availability instance's type. - string instanceType = ?; - - // The following are fields for a concrete availability instance. - - // The code for preparing a concrete instance. 
This should be C++ statements - // and will be generated before the `mergeAction` logic. - code instancePreparation = ""; - // The availability requirement carried by a concrete instance. - string instance = ?; -} - -class MinVersionBase - : Availability { - let interfaceName = name; - - let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">"; - let queryFnName = "getMinVersion"; - - let mergeAction = "{ " - "if ($overall.hasValue()) { " - "$overall = static_cast<" # scheme.returnType # ">(" - "std::max(*$overall, $instance)); " - "} else { $overall = $instance; }}"; - let initializer = "::llvm::None"; - let instanceType = scheme.cppNamespace # "::" # scheme.className; - - let instance = scheme.cppNamespace # "::" # scheme.className # "::" # - min.symbol; -} - -class MaxVersionBase - : Availability { - let interfaceName = name; - - let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">"; - let queryFnName = "getMaxVersion"; - - let mergeAction = "{ " - "if ($overall.hasValue()) { " - "$overall = static_cast<" # scheme.returnType # ">(" - "std::min(*$overall, $instance)); " - "} else { $overall = $instance; }}"; - let initializer = "::llvm::None"; - let instanceType = scheme.cppNamespace # "::" # scheme.className; - - let instance = scheme.cppNamespace # "::" # scheme.className # "::" # - max.symbol; -} - -#endif // MLIR_DIALECT_SPIRV_IR_AVAILABILITY diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -16,7 +16,6 @@ #define MLIR_DIALECT_SPIRV_IR_BASE include "mlir/IR/OpBase.td" -include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.td" //===----------------------------------------------------------------------===// // SPIR-V dialect definitions diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ 
b/mlir/include/mlir/IR/OpBase.td @@ -2281,6 +2281,89 @@ dag results = rets; } +//===----------------------------------------------------------------------===// +// Availability definitions +//===----------------------------------------------------------------------===// + +// The base class for defining op availability dimensions. +class Availability { + // The following are fields for controlling the generated C++ OpInterface. + + // The namespace for the generated C++ OpInterface subclass. + string interfaceNamespace = ?; + // The name for the generated C++ OpInterface subclass. + string interfaceName = ?; + // The documentation for the generated C++ OpInterface subclass. + string interfaceDescription = ""; + + // The following are fields for controlling the query function signature. + + // The query function's return type in the generated C++ OpInterface subclass. + string queryFnRetType = ?; + // The query function's name in the generated C++ OpInterface subclass. + string queryFnName = ?; + + // The following are fields for controlling the query function implementation. + + // The logic for merging two availability requirements. This is used to derive + // the final availability requirement when, for example, an op has two + // operands and these two operands have different availability requirements. + // + // The code should use `$overall` as the placeholder for the final requirement + // and `$instance` for the current availability requirement instance. + code mergeAction = ?; + // The initializer for the final availability requirement. + string initializer = ?; + // An availability instance's type. + string instanceType = ?; + + // The following are fields for a concrete availability instance. + + // The code for preparing a concrete instance. This should be C++ statements + // and will be generated before the `mergeAction` logic. + code instancePreparation = ""; + // The availability requirement carried by a concrete instance. 
+ string instance = ?; +} + +class MinVersionBase + : Availability { + let interfaceName = name; + + let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">"; + let queryFnName = "getMinVersion"; + + let mergeAction = "{ " + "if ($overall.hasValue()) { " + "$overall = static_cast<" # scheme.returnType # ">(" + "std::max(*$overall, $instance)); " + "} else { $overall = $instance; }}"; + let initializer = "::llvm::None"; + let instanceType = scheme.cppNamespace # "::" # scheme.className; + + let instance = scheme.cppNamespace # "::" # scheme.className # "::" # + min.symbol; +} + +class MaxVersionBase + : Availability { + let interfaceName = name; + + let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">"; + let queryFnName = "getMaxVersion"; + + let mergeAction = "{ " + "if ($overall.hasValue()) { " + "$overall = static_cast<" # scheme.returnType # ">(" + "std::min(*$overall, $instance)); " + "} else { $overall = $instance; }}"; + let initializer = "::llvm::None"; + let instanceType = scheme.cppNamespace # "::" # scheme.className; + + let instance = scheme.cppNamespace # "::" # scheme.className # "::" # + max.symbol; +} + //===----------------------------------------------------------------------===// // Common value constraints //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Availability.h b/mlir/include/mlir/TableGen/Availability.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/TableGen/Availability.h @@ -0,0 +1,79 @@ +//===- Availability.h - Availability wrapper class --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Availability wrapper to simplify using TableGen Record defining a MLIR +// Availability. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_AVAILABILITY_H_ +#define MLIR_TABLEGEN_AVAILABILITY_H_ + +#include "mlir/Support/LLVM.h" + +namespace llvm { +class Record; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +// Wrapper class with helper methods for accessing availability defined in +// TableGen. +class Availability { +public: + explicit Availability(const llvm::Record *def); + + // Returns the name of the direct TableGen class for this availability + // instance. + StringRef getClass() const; + + // Returns the generated C++ interface's class namespace. + StringRef getInterfaceClassNamespace() const; + + // Returns the generated C++ interface's class name. + StringRef getInterfaceClassName() const; + + // Returns the generated C++ interface's description. + StringRef getInterfaceDescription() const; + + // Returns the name of the query function inside the generated C++ interface. + StringRef getQueryFnName() const; + + // Returns the return type of the query function inside the generated C++ + // interface. + StringRef getQueryFnRetType() const; + + // Returns the code for merging availability requirements. + StringRef getMergeActionCode() const; + + // Returns the initializer expression for initializing the final availability + // requirements. + StringRef getMergeInitializer() const; + + // Returns the C++ type for an availability instance. + StringRef getMergeInstanceType() const; + + // Returns the C++ statements for preparing availability instance. + StringRef getMergeInstancePreparation() const; + + // Returns the concrete availability instance carried in this case. 
+ StringRef getMergeInstance() const; + + // Returns the underlying LLVM TableGen Record. + const llvm::Record *getDef() const { return def; } + +private: + // The TableGen definition of this availability. + const llvm::Record *def; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_AVAILABILITY_H_ diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -3777,10 +3777,3 @@ // TablenGen'erated operation definitions. #define GET_OP_CLASSES #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" - -namespace mlir { -namespace spirv { -// TableGen'erated operation availability interface implementations. -#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc" -} // namespace spirv -} // namespace mlir diff --git a/mlir/lib/TableGen/Availability.cpp b/mlir/lib/TableGen/Availability.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/TableGen/Availability.cpp @@ -0,0 +1,70 @@ +//===- Availability.cpp - Availability definitions ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Availability.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +Availability::Availability(const llvm::Record *def) : def(def) { + assert(def->isSubClassOf("Availability") && + "must be subclass of TableGen 'Availability' class"); +} + +StringRef Availability::getClass() const { + SmallVector parentClass; + def->getDirectSuperClasses(parentClass); + if (parentClass.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + return parentClass.front()->getName(); +} + +StringRef Availability::getInterfaceClassNamespace() const { + return def->getValueAsString("interfaceNamespace"); +} + +StringRef Availability::getInterfaceClassName() const { + return def->getValueAsString("interfaceName"); +} + +StringRef Availability::getInterfaceDescription() const { + return def->getValueAsString("interfaceDescription"); +} + +StringRef Availability::getQueryFnRetType() const { + return def->getValueAsString("queryFnRetType"); +} + +StringRef Availability::getQueryFnName() const { + return def->getValueAsString("queryFnName"); +} + +StringRef Availability::getMergeActionCode() const { + return def->getValueAsString("mergeAction"); +} + +StringRef Availability::getMergeInitializer() const { + return def->getValueAsString("initializer"); +} + +StringRef Availability::getMergeInstanceType() const { + return def->getValueAsString("instanceType"); +} + +StringRef Availability::getMergeInstancePreparation() const { + return def->getValueAsString("instancePreparation"); +} + +StringRef Availability::getMergeInstance() const { + return def->getValueAsString("instance"); +} diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- 
a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -12,6 +12,7 @@ Argument.cpp Attribute.cpp AttrOrTypeDef.cpp + Availability.cpp Builder.cpp Constraint.cpp Dialect.cpp diff --git a/mlir/tools/mlir-tblgen/AvailabilityInterfaceGen.cpp b/mlir/tools/mlir-tblgen/AvailabilityInterfaceGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/AvailabilityInterfaceGen.cpp @@ -0,0 +1,182 @@ +//===- AvailabilityInterfaceGen.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// AvailabilityInterfaceGen generates definitions for availability interfaces. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Availability.h" +#include "mlir/TableGen/GenInfo.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using namespace mlir; +using llvm::Record; +using llvm::RecordKeeper; +using mlir::tblgen::Availability; + +//===----------------------------------------------------------------------===// +// Availability Interface Declarations AutoGen +//===----------------------------------------------------------------------===// + +static void emitConceptDecl(const Availability &availability, raw_ostream &os) { + os << " class Concept {\n" + << " public:\n" + << " virtual ~Concept() = default;\n" + << " virtual " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() + << "(const Concept *impl, Operation *tblgen_opaque_op) const = 
0;\n" + << " };\n"; +} + +static void emitModelDecl(const Availability &availability, raw_ostream &os) { + for (const char *modelClass : {"Model", "FallbackModel"}) { + os << " template\n"; + os << " class " << modelClass << " : public Concept {\n" + << " public:\n" + << " " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() + << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n" + << " auto op = llvm::cast(tblgen_opaque_op);\n" + << " (void)op;\n" + // Forward to the method on the concrete operation type. + << " return op." << availability.getQueryFnName() << "();\n" + << " }\n" + << " };\n"; + } + os << " template\n"; + os << " class ExternalModel : public FallbackModel {};\n"; +} + +static void emitInterfaceDecl(const Availability &availability, + raw_ostream &os) { + StringRef interfaceName = availability.getInterfaceClassName(); + std::string interfaceTraitsName = + std::string(formatv("{0}Traits", interfaceName)); + + StringRef cppNamespace = availability.getInterfaceClassNamespace(); + cppNamespace.consume_front("::"); + SmallVector nsSegments; + llvm::SplitString(cppNamespace, nsSegments, "::"); + for (StringRef segment : nsSegments) + os << "namespace " << segment << " {\n"; + + // Emit the traits struct containing the concept and model declarations. + os << "namespace detail {\n" + << "struct " << interfaceTraitsName << " {\n"; + emitConceptDecl(availability, os); + os << '\n'; + emitModelDecl(availability, os); + os << "};\n} // namespace detail\n\n"; + + // Emit the main interface class declaration. + os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n"; + os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n" + "public:\n" + " using OpInterface<{1}, detail::{2}>::OpInterface;\n", + interfaceName, interfaceName, interfaceTraitsName); + + // Emit query function declaration. 
+ os << " " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() << "();\n"; + os << "};\n\n"; + + for (StringRef segment : llvm::reverse(nsSegments)) + os << "} // namespace " << segment << "\n"; +} + +static bool emitInterfaceDecls(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("Availability Interface Declarations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("Availability"); + SmallVector handledClasses; + for (const Record *def : defs) { + SmallVector parent; + def->getDirectSuperClasses(parent); + if (parent.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + if (llvm::is_contained(handledClasses, parent.front())) + continue; + + Availability avail(def); + emitInterfaceDecl(avail, os); + handledClasses.push_back(parent.front()); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Availability Interface Definitions AutoGen +//===----------------------------------------------------------------------===// + +static void emitInterfaceDef(const Availability &availability, + raw_ostream &os) { + os << availability.getQueryFnRetType() << " "; + + StringRef cppNamespace = availability.getInterfaceClassNamespace(); + cppNamespace.consume_front("::"); + if (!cppNamespace.empty()) + os << cppNamespace << "::"; + + StringRef methodName = availability.getQueryFnName(); + os << availability.getInterfaceClassName() << "::" << methodName << "() {\n" + << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n" + << "}\n"; +} + +static bool emitInterfaceDefs(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("Availability Interface Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("Availability"); + SmallVector handledClasses; + for (const Record *def : defs) { + SmallVector parent; + def->getDirectSuperClasses(parent); + 
if (parent.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + if (llvm::is_contained(handledClasses, parent.front())) + continue; + + Availability availability(def); + emitInterfaceDef(availability, os); + handledClasses.push_back(parent.front()); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Availability Interface Hook Registration +//===----------------------------------------------------------------------===// + +// Registers the operation interface generator to mlir-tblgen. +static mlir::GenRegistration + genInterfaceDecls("gen-avail-interface-decls", + "Generate availability interface declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitInterfaceDecls(records, os); + }); + +// Registers the operation interface generator to mlir-tblgen. +static mlir::GenRegistration + genInterfaceDefs("gen-avail-interface-defs", + "Generate op interface definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitInterfaceDefs(records, os); + }); diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -6,6 +6,7 @@ add_tablegen(mlir-tblgen MLIR AttrOrTypeDefGen.cpp + AvailabilityInterfaceGen.cpp DialectGen.cpp DirectiveCommonGen.cpp EnumsGen.cpp diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -11,10 +11,12 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Availability.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" #include 
"llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" @@ -29,6 +31,7 @@ using llvm::RecordKeeper; using llvm::StringRef; using mlir::tblgen::Attribute; +using mlir::tblgen::Availability; using mlir::tblgen::EnumAttr; using mlir::tblgen::EnumAttrCase; using mlir::tblgen::FmtContext; @@ -42,6 +45,10 @@ return str.str(); } +//===----------------------------------------------------------------------===// +// Enum Utility AutoGen +//===----------------------------------------------------------------------===// + static void emitEnumClass(const Record &enumDef, StringRef enumName, StringRef underlyingType, StringRef description, const std::vector &enumerants, @@ -540,6 +547,10 @@ return false; } +//===----------------------------------------------------------------------===// +// Enum Utility Query Hook Registration +//===----------------------------------------------------------------------===// + // Registers the enum utility generator to mlir-tblgen. static mlir::GenRegistration genEnumDecls("gen-enum-decls", "Generate enum utility declarations", @@ -553,3 +564,202 @@ [](const RecordKeeper &records, raw_ostream &os) { return emitEnumDefs(records, os); }); + +//===----------------------------------------------------------------------===// +// Enum Availability Query AutoGen +//===----------------------------------------------------------------------===// + +// Returns the availability spec of the given `def`. 
+static std::vector getAvailabilities(const llvm::Record &def) { + std::vector availabilities; + + if (def.getValue("availability")) { + std::vector availDefs = + def.getValueAsListOfDefs("availability"); + availabilities.reserve(availDefs.size()); + for (const llvm::Record *avail : availDefs) + availabilities.emplace_back(avail); + } + + return availabilities; +} + +static void emitAvailabilityQueryForIntEnum(const Record &enumDef, + raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::vector enumerants = enumAttr.getAllCases(); + + // Mapping from availability class name to (enumerant, availability + // specification) pairs. + llvm::StringMap, 1>> + classCaseMap; + + // Place all availability specifications to their corresponding + // availability classes. + for (const EnumAttrCase &enumerant : enumerants) + for (const Availability &avail : getAvailabilities(enumerant.getDef())) + classCaseMap[avail.getClass()].push_back({enumerant, avail}); + + for (const auto &classCasePair : classCaseMap) { + Availability avail = classCasePair.getValue().front().second; + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + + os << " switch (value) {\n"; + for (const auto &caseSpecPair : classCasePair.getValue()) { + EnumAttrCase enumerant = caseSpecPair.first; + Availability avail = caseSpecPair.second; + os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, + enumerant.getSymbol(), avail.getMergeInstancePreparation(), + avail.getMergeInstanceType(), avail.getMergeInstance()); + } + // Only emit default if uncovered cases. 
+ if (classCasePair.getValue().size() < enumAttr.getAllCases().size()) + os << " default: break;\n"; + os << " }\n" + << " return llvm::None;\n" + << "}\n"; + } +} + +static void emitAvailabilityQueryForBitEnum(const Record &enumDef, + raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = std::string(enumAttr.getUnderlyingType()); + std::vector enumerants = enumAttr.getAllCases(); + + // Mapping from availability class name to (enumerant, availability + // specification) pairs. + llvm::StringMap, 1>> + classCaseMap; + + // Place all availability specifications to their corresponding + // availability classes. + for (const EnumAttrCase &enumerant : enumerants) + for (const Availability &avail : getAvailabilities(enumerant.getDef())) + classCaseMap[avail.getClass()].push_back({enumerant, avail}); + + for (const auto &classCasePair : classCaseMap) { + Availability avail = classCasePair.getValue().front().second; + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + + os << formatv( + " assert(::llvm::countPopulation(static_cast<{0}>(value)) <= 1" + " && \"cannot have more than one bit set\");\n", + underlyingType); + + os << " switch (value) {\n"; + for (const auto &caseSpecPair : classCasePair.getValue()) { + EnumAttrCase enumerant = caseSpecPair.first; + Availability avail = caseSpecPair.second; + os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, + enumerant.getSymbol(), avail.getMergeInstancePreparation(), + avail.getMergeInstanceType(), avail.getMergeInstance()); + } + os << " default: break;\n"; + os << " }\n" + << " return llvm::None;\n" + << "}\n"; + } +} + +static void emitEnumAvailDecl(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef cppNamespace = enumAttr.getCppNamespace(); + auto enumerants = 
enumAttr.getAllCases(); + + llvm::SmallVector namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + llvm::StringSet<> handledClasses; + + // Place all availability specifications to their corresponding + // availability classes. + for (const EnumAttrCase &enumerant : enumerants) + for (const Availability &avail : getAvailabilities(enumerant.getDef())) { + StringRef className = avail.getClass(); + if (handledClasses.count(className)) + continue; + os << formatv("llvm::Optional<{0}> {1}({2} value);\n", + avail.getMergeInstanceType(), avail.getQueryFnName(), + enumName); + handledClasses.insert(className); + } + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; +} + +static bool emitEnumAvailDecls(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("Enum Availability Declarations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); + for (const auto *def : defs) + emitEnumAvailDecl(*def, os); + + return false; +} + +static void emitEnumAvailDef(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef cppNamespace = enumAttr.getCppNamespace(); + + llvm::SmallVector namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + if (enumAttr.isBitEnum()) { + emitAvailabilityQueryForBitEnum(enumDef, os); + } else { + emitAvailabilityQueryForIntEnum(enumDef, os); + } + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; + os << "\n"; +} + +static bool emitEnumAvailDefs(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("Enum Availability Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); + for (const auto *def : defs) + emitEnumAvailDef(*def, os); + + return false; +} + 
+//===----------------------------------------------------------------------===// +// Enum Availability Query Hook Registration +//===----------------------------------------------------------------------===// + +// Registers the enum utility generator to mlir-tblgen. +static mlir::GenRegistration + genEnumAvailDecls("gen-enum-avail-decls", + "Generate enum availability declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitEnumAvailDecls(records, os); + }); + +// Registers the enum utility generator to mlir-tblgen. +static mlir::GenRegistration + genEnumAvailDefs("gen-enum-avail-defs", + "Generate enum availability definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitEnumAvailDefs(records, os); + }); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -13,6 +13,7 @@ #include "OpFormatGen.h" #include "OpGenHelpers.h" +#include "mlir/TableGen/Availability.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" @@ -478,6 +479,9 @@ // Generate the type inference interface methods. void genTypeInterfaceMethods(); + // Generate availability interface methods. + void genAvailabilityInterfaceMethods(); + private: // The TableGen record for this op. // TODO: OpEmitter should not have a Record directly, @@ -641,6 +645,7 @@ genVerifier(); genCanonicalizerDecls(); genFolderDecls(); + genAvailabilityInterfaceMethods(); genTypeInterfaceMethods(); genOpInterfaceMethods(); generateOpFormat(op, opClass); @@ -1987,6 +1992,133 @@ } } +// Returns the availability spec of the given `def`. 
+static std::vector getAvailabilities(const llvm::Record &def) { + std::vector availabilities; + + if (def.getValue("availability")) { + std::vector availDefs = + def.getValueAsListOfDefs("availability"); + availabilities.reserve(availDefs.size()); + for (const llvm::Record *avail : availDefs) + availabilities.emplace_back(avail); + } + + return availabilities; +} + +void OpEmitter::genAvailabilityInterfaceMethods() { + mlir::tblgen::FmtContext fctx; + fctx.addSubst("overall", "tblgen_overall"); + + std::vector opAvailabilities = getAvailabilities(op.getDef()); + + // First collect all availability classes this op should implement. + // All availability instances keep information for the generated interface and + // the instance's specific requirement. Here we remember a random instance so + // we can get the information regarding the generated interface. + llvm::StringMap availClasses; + for (const Availability &avail : opAvailabilities) + availClasses.try_emplace(avail.getClass(), avail); + for (const NamedAttribute &namedAttr : op.getAttributes()) { + const auto *enumAttr = llvm::dyn_cast(&namedAttr.attr); + if (!enumAttr) + continue; + + for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) + for (const Availability &caseAvail : + getAvailabilities(enumerant.getDef())) + availClasses.try_emplace(caseAvail.getClass(), caseAvail); + } + + // Then generate implementation for each availability class. + for (const auto &availClass : availClasses) { + StringRef availClassName = availClass.getKey(); + Availability avail = availClass.getValue(); + + auto *method = opClass.addMethodAndPrune(avail.getQueryFnRetType(), + avail.getQueryFnName()); + assert(method && "method already registered!"); + auto &body = method->body(); + + // Create the variable for the final requirement and initialize it. + body << formatv(" {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(), + avail.getMergeInitializer()); + + // Update with the op's specific availability spec. 
+ for (const Availability &avail : opAvailabilities) { + if (avail.getClass() == availClassName && + (!avail.getMergeInstancePreparation().empty() || + !avail.getMergeActionCode().empty())) { + body << " {\n " + // Prepare this instance. + << avail.getMergeInstancePreparation() + << "\n " + // Merge this instance. + << std::string( + tgfmt(avail.getMergeActionCode(), + &fctx.addSubst("instance", avail.getMergeInstance()))) + << ";\n }\n"; + } + } + + // Update with enum attributes' specific availability spec. + for (const NamedAttribute &namedAttr : op.getAttributes()) { + const auto *enumAttr = llvm::dyn_cast(&namedAttr.attr); + if (!enumAttr) + continue; + + // (enumerant, availability specification) pairs for this availability + // class. + SmallVector, 1> caseSpecs; + + // Collect all cases' availability specs. + for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) + for (const Availability &caseAvail : + getAvailabilities(enumerant.getDef())) + if (availClassName == caseAvail.getClass()) + caseSpecs.push_back({enumerant, caseAvail}); + + // If this attribute kind does not have any availability spec from any of + // its cases, no more work to do. + if (caseSpecs.empty()) + continue; + + if (enumAttr->isBitEnum()) { + // For BitEnumAttr, we need to iterate over each bit to query its + // availability spec. + body << formatv(" for (unsigned i = 0; " + "i < std::numeric_limits<{0}>::digits; ++i) {{\n", + enumAttr->getUnderlyingType()); + body << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & " + "static_cast<{0}::{1}>(1 << i);\n", + enumAttr->getCppNamespace(), + enumAttr->getEnumClassName(), namedAttr.name); + body << formatv( + " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n", + enumAttr->getUnderlyingType()); + } else { + // For IntEnumAttr, we just need to query the value as a whole. 
+ body << " {\n"; + body << formatv(" auto tblgen_attrVal = this->{0}();\n", + namedAttr.name); + } + body << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n", + enumAttr->getCppNamespace(), avail.getQueryFnName()); + body << " if (tblgen_instance) " + // TODO` here once ODS supports + // dialect-specific contents so that we can use not implementing the + // availability interface as indication of no requirements. + << std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(), + &fctx.addSubst("instance", "*tblgen_instance"))) + << ";\n"; + body << " }\n"; + } + + body << " return tblgen_overall;\n"; + } +} + void OpEmitter::genTypeInterfaceMethods() { if (!op.allResultTypesKnown()) return; diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -35,7 +35,6 @@ using llvm::raw_string_ostream; using llvm::Record; using llvm::RecordKeeper; -using llvm::SmallVector; using llvm::SMLoc; using llvm::StringMap; using llvm::StringRef; @@ -47,469 +46,6 @@ using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; -//===----------------------------------------------------------------------===// -// Availability Wrapper Class -//===----------------------------------------------------------------------===// - -namespace { -// Wrapper class with helper methods for accessing availability defined in -// TableGen. -class Availability { -public: - explicit Availability(const Record *def); - - // Returns the name of the direct TableGen class for this availability - // instance. - StringRef getClass() const; - - // Returns the generated C++ interface's class namespace. - StringRef getInterfaceClassNamespace() const; - - // Returns the generated C++ interface's class name. - StringRef getInterfaceClassName() const; - - // Returns the generated C++ interface's description. 
- StringRef getInterfaceDescription() const; - - // Returns the name of the query function insided the generated C++ interface. - StringRef getQueryFnName() const; - - // Returns the return type of the query function insided the generated C++ - // interface. - StringRef getQueryFnRetType() const; - - // Returns the code for merging availability requirements. - StringRef getMergeActionCode() const; - - // Returns the initializer expression for initializing the final availability - // requirements. - StringRef getMergeInitializer() const; - - // Returns the C++ type for an availability instance. - StringRef getMergeInstanceType() const; - - // Returns the C++ statements for preparing availability instance. - StringRef getMergeInstancePreparation() const; - - // Returns the concrete availability instance carried in this case. - StringRef getMergeInstance() const; - - // Returns the underlying LLVM TableGen Record. - const llvm::Record *getDef() const { return def; } - -private: - // The TableGen definition of this availability. 
- const llvm::Record *def; -}; -} // namespace - -Availability::Availability(const llvm::Record *def) : def(def) { - assert(def->isSubClassOf("Availability") && - "must be subclass of TableGen 'Availability' class"); -} - -StringRef Availability::getClass() const { - SmallVector parentClass; - def->getDirectSuperClasses(parentClass); - if (parentClass.size() != 1) { - PrintFatalError(def->getLoc(), - "expected to only have one direct superclass"); - } - return parentClass.front()->getName(); -} - -StringRef Availability::getInterfaceClassNamespace() const { - return def->getValueAsString("interfaceNamespace"); -} - -StringRef Availability::getInterfaceClassName() const { - return def->getValueAsString("interfaceName"); -} - -StringRef Availability::getInterfaceDescription() const { - return def->getValueAsString("interfaceDescription"); -} - -StringRef Availability::getQueryFnRetType() const { - return def->getValueAsString("queryFnRetType"); -} - -StringRef Availability::getQueryFnName() const { - return def->getValueAsString("queryFnName"); -} - -StringRef Availability::getMergeActionCode() const { - return def->getValueAsString("mergeAction"); -} - -StringRef Availability::getMergeInitializer() const { - return def->getValueAsString("initializer"); -} - -StringRef Availability::getMergeInstanceType() const { - return def->getValueAsString("instanceType"); -} - -StringRef Availability::getMergeInstancePreparation() const { - return def->getValueAsString("instancePreparation"); -} - -StringRef Availability::getMergeInstance() const { - return def->getValueAsString("instance"); -} - -// Returns the availability spec of the given `def`. 
-std::vector getAvailabilities(const Record &def) { - std::vector availabilities; - - if (def.getValue("availability")) { - std::vector availDefs = def.getValueAsListOfDefs("availability"); - availabilities.reserve(availDefs.size()); - for (const Record *avail : availDefs) - availabilities.emplace_back(avail); - } - - return availabilities; -} - -//===----------------------------------------------------------------------===// -// Availability Interface Definitions AutoGen -//===----------------------------------------------------------------------===// - -static void emitInterfaceDef(const Availability &availability, - raw_ostream &os) { - - os << availability.getQueryFnRetType() << " "; - - StringRef cppNamespace = availability.getInterfaceClassNamespace(); - cppNamespace.consume_front("::"); - if (!cppNamespace.empty()) - os << cppNamespace << "::"; - - StringRef methodName = availability.getQueryFnName(); - os << availability.getInterfaceClassName() << "::" << methodName << "() {\n" - << " return getImpl()->" << methodName << "(getImpl(), getOperation());\n" - << "}\n"; -} - -static bool emitInterfaceDefs(const RecordKeeper &recordKeeper, - raw_ostream &os) { - llvm::emitSourceFileHeader("Availability Interface Definitions", os); - - auto defs = recordKeeper.getAllDerivedDefinitions("Availability"); - SmallVector handledClasses; - for (const Record *def : defs) { - SmallVector parent; - def->getDirectSuperClasses(parent); - if (parent.size() != 1) { - PrintFatalError(def->getLoc(), - "expected to only have one direct superclass"); - } - if (llvm::is_contained(handledClasses, parent.front())) - continue; - - Availability availability(def); - emitInterfaceDef(availability, os); - handledClasses.push_back(parent.front()); - } - return false; -} - -//===----------------------------------------------------------------------===// -// Availability Interface Declarations AutoGen -//===----------------------------------------------------------------------===// - -static 
void emitConceptDecl(const Availability &availability, raw_ostream &os) { - os << " class Concept {\n" - << " public:\n" - << " virtual ~Concept() = default;\n" - << " virtual " << availability.getQueryFnRetType() << " " - << availability.getQueryFnName() - << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n" - << " };\n"; -} - -static void emitModelDecl(const Availability &availability, raw_ostream &os) { - for (const char *modelClass : {"Model", "FallbackModel"}) { - os << " template\n"; - os << " class " << modelClass << " : public Concept {\n" - << " public:\n" - << " " << availability.getQueryFnRetType() << " " - << availability.getQueryFnName() - << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n" - << " auto op = llvm::cast(tblgen_opaque_op);\n" - << " (void)op;\n" - // Forward to the method on the concrete operation type. - << " return op." << availability.getQueryFnName() << "();\n" - << " }\n" - << " };\n"; - } - os << " template\n"; - os << " class ExternalModel : public FallbackModel {};\n"; -} - -static void emitInterfaceDecl(const Availability &availability, - raw_ostream &os) { - StringRef interfaceName = availability.getInterfaceClassName(); - std::string interfaceTraitsName = - std::string(formatv("{0}Traits", interfaceName)); - - StringRef cppNamespace = availability.getInterfaceClassNamespace(); - cppNamespace.consume_front("::"); - SmallVector nsSegments; - llvm::SplitString(cppNamespace, nsSegments, "::"); - for (StringRef segment : nsSegments) - os << "namespace " << segment << " {\n"; - - // Emit the traits struct containing the concept and model declarations. - os << "namespace detail {\n" - << "struct " << interfaceTraitsName << " {\n"; - emitConceptDecl(availability, os); - os << '\n'; - emitModelDecl(availability, os); - os << "};\n} // namespace detail\n\n"; - - // Emit the main interface class declaration. 
- os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n"; - os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n" - "public:\n" - " using OpInterface<{1}, detail::{2}>::OpInterface;\n", - interfaceName, interfaceName, interfaceTraitsName); - - // Emit query function declaration. - os << " " << availability.getQueryFnRetType() << " " - << availability.getQueryFnName() << "();\n"; - os << "};\n\n"; - - for (StringRef segment : llvm::reverse(nsSegments)) - os << "} // namespace " << segment << "\n"; -} - -static bool emitInterfaceDecls(const RecordKeeper &recordKeeper, - raw_ostream &os) { - llvm::emitSourceFileHeader("Availability Interface Declarations", os); - - auto defs = recordKeeper.getAllDerivedDefinitions("Availability"); - SmallVector handledClasses; - for (const Record *def : defs) { - SmallVector parent; - def->getDirectSuperClasses(parent); - if (parent.size() != 1) { - PrintFatalError(def->getLoc(), - "expected to only have one direct superclass"); - } - if (llvm::is_contained(handledClasses, parent.front())) - continue; - - Availability avail(def); - emitInterfaceDecl(avail, os); - handledClasses.push_back(parent.front()); - } - return false; -} - -//===----------------------------------------------------------------------===// -// Availability Interface Hook Registration -//===----------------------------------------------------------------------===// - -// Registers the operation interface generator to mlir-tblgen. -static mlir::GenRegistration - genInterfaceDecls("gen-avail-interface-decls", - "Generate availability interface declarations", - [](const RecordKeeper &records, raw_ostream &os) { - return emitInterfaceDecls(records, os); - }); - -// Registers the operation interface generator to mlir-tblgen. 
-static mlir::GenRegistration - genInterfaceDefs("gen-avail-interface-defs", - "Generate op interface definitions", - [](const RecordKeeper &records, raw_ostream &os) { - return emitInterfaceDefs(records, os); - }); - -//===----------------------------------------------------------------------===// -// Enum Availability Query AutoGen -//===----------------------------------------------------------------------===// - -static void emitAvailabilityQueryForIntEnum(const Record &enumDef, - raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - std::vector enumerants = enumAttr.getAllCases(); - - // Mapping from availability class name to (enumerant, availability - // specification) pairs. - llvm::StringMap, 1>> - classCaseMap; - - // Place all availability specifications to their corresponding - // availability classes. - for (const EnumAttrCase &enumerant : enumerants) - for (const Availability &avail : getAvailabilities(enumerant.getDef())) - classCaseMap[avail.getClass()].push_back({enumerant, avail}); - - for (const auto &classCasePair : classCaseMap) { - Availability avail = classCasePair.getValue().front().second; - - os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", - avail.getMergeInstanceType(), avail.getQueryFnName(), - enumName); - - os << " switch (value) {\n"; - for (const auto &caseSpecPair : classCasePair.getValue()) { - EnumAttrCase enumerant = caseSpecPair.first; - Availability avail = caseSpecPair.second; - os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, - enumerant.getSymbol(), avail.getMergeInstancePreparation(), - avail.getMergeInstanceType(), avail.getMergeInstance()); - } - // Only emit default if uncovered cases. 
- if (classCasePair.getValue().size() < enumAttr.getAllCases().size()) - os << " default: break;\n"; - os << " }\n" - << " return llvm::None;\n" - << "}\n"; - } -} - -static void emitAvailabilityQueryForBitEnum(const Record &enumDef, - raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - std::string underlyingType = std::string(enumAttr.getUnderlyingType()); - std::vector enumerants = enumAttr.getAllCases(); - - // Mapping from availability class name to (enumerant, availability - // specification) pairs. - llvm::StringMap, 1>> - classCaseMap; - - // Place all availability specifications to their corresponding - // availability classes. - for (const EnumAttrCase &enumerant : enumerants) - for (const Availability &avail : getAvailabilities(enumerant.getDef())) - classCaseMap[avail.getClass()].push_back({enumerant, avail}); - - for (const auto &classCasePair : classCaseMap) { - Availability avail = classCasePair.getValue().front().second; - - os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", - avail.getMergeInstanceType(), avail.getQueryFnName(), - enumName); - - os << formatv( - " assert(::llvm::countPopulation(static_cast<{0}>(value)) <= 1" - " && \"cannot have more than one bit set\");\n", - underlyingType); - - os << " switch (value) {\n"; - for (const auto &caseSpecPair : classCasePair.getValue()) { - EnumAttrCase enumerant = caseSpecPair.first; - Availability avail = caseSpecPair.second; - os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, - enumerant.getSymbol(), avail.getMergeInstancePreparation(), - avail.getMergeInstanceType(), avail.getMergeInstance()); - } - os << " default: break;\n"; - os << " }\n" - << " return llvm::None;\n" - << "}\n"; - } -} - -static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef enumName = enumAttr.getEnumClassName(); - StringRef cppNamespace = enumAttr.getCppNamespace(); - auto enumerants = 
enumAttr.getAllCases(); - - llvm::SmallVector namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; - - llvm::StringSet<> handledClasses; - - // Place all availability specifications to their corresponding - // availability classes. - for (const EnumAttrCase &enumerant : enumerants) - for (const Availability &avail : getAvailabilities(enumerant.getDef())) { - StringRef className = avail.getClass(); - if (handledClasses.count(className)) - continue; - os << formatv("llvm::Optional<{0}> {1}({2} value);\n", - avail.getMergeInstanceType(), avail.getQueryFnName(), - enumName); - handledClasses.insert(className); - } - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; -} - -static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { - llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os); - - auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); - for (const auto *def : defs) - emitEnumDecl(*def, os); - - return false; -} - -static void emitEnumDef(const Record &enumDef, raw_ostream &os) { - EnumAttr enumAttr(enumDef); - StringRef cppNamespace = enumAttr.getCppNamespace(); - - llvm::SmallVector namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; - - if (enumAttr.isBitEnum()) { - emitAvailabilityQueryForBitEnum(enumDef, os); - } else { - emitAvailabilityQueryForIntEnum(enumDef, os); - } - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; - os << "\n"; -} - -static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { - llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os); - - auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); - for (const auto *def : defs) - emitEnumDef(*def, os); - - return false; -} - 
-//===----------------------------------------------------------------------===// -// Enum Availability Query Hook Registration -//===----------------------------------------------------------------------===// - -// Registers the enum utility generator to mlir-tblgen. -static mlir::GenRegistration - genEnumDecls("gen-spirv-enum-avail-decls", - "Generate SPIR-V enum availability declarations", - [](const RecordKeeper &records, raw_ostream &os) { - return emitEnumDecls(records, os); - }); - -// Registers the enum utility generator to mlir-tblgen. -static mlir::GenRegistration - genEnumDefs("gen-spirv-enum-avail-defs", - "Generate SPIR-V enum availability definitions", - [](const RecordKeeper &records, raw_ostream &os) { - return emitEnumDefs(records, os); - }); - //===----------------------------------------------------------------------===// // Serialization AutoGen //===----------------------------------------------------------------------===// @@ -1245,145 +781,6 @@ return emitAttrUtils(records, os); }); -//===----------------------------------------------------------------------===// -// SPIR-V Availability Impl AutoGen -//===----------------------------------------------------------------------===// - -static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { - mlir::tblgen::FmtContext fctx; - fctx.addSubst("overall", "tblgen_overall"); - - std::vector opAvailabilities = - getAvailabilities(srcOp.getDef()); - - // First collect all availability classes this op should implement. - // All availability instances keep information for the generated interface and - // the instance's specific requirement. Here we remember a random instance so - // we can get the information regarding the generated interface. 
- llvm::StringMap availClasses; - for (const Availability &avail : opAvailabilities) - availClasses.try_emplace(avail.getClass(), avail); - for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { - const auto *enumAttr = llvm::dyn_cast(&namedAttr.attr); - if (!enumAttr) - continue; - - for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) - for (const Availability &caseAvail : - getAvailabilities(enumerant.getDef())) - availClasses.try_emplace(caseAvail.getClass(), caseAvail); - } - - // Then generate implementation for each availability class. - for (const auto &availClass : availClasses) { - StringRef availClassName = availClass.getKey(); - Availability avail = availClass.getValue(); - - // Generate the implementation method signature. - os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(), - srcOp.getCppClassName(), avail.getQueryFnName()); - - // Create the variable for the final requirement and initialize it. - os << formatv(" {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(), - avail.getMergeInitializer()); - - // Update with the op's specific availability spec. - for (const Availability &avail : opAvailabilities) - if (avail.getClass() == availClassName && - (!avail.getMergeInstancePreparation().empty() || - !avail.getMergeActionCode().empty())) { - os << " {\n " - // Prepare this instance. - << avail.getMergeInstancePreparation() - << "\n " - // Merge this instance. - << std::string( - tgfmt(avail.getMergeActionCode(), - &fctx.addSubst("instance", avail.getMergeInstance()))) - << ";\n }\n"; - } - - // Update with enum attributes' specific availability spec. - for (const NamedAttribute &namedAttr : srcOp.getAttributes()) { - const auto *enumAttr = llvm::dyn_cast(&namedAttr.attr); - if (!enumAttr) - continue; - - // (enumerant, availability specification) pairs for this availability - // class. - SmallVector, 1> caseSpecs; - - // Collect all cases' availability specs. 
- for (const EnumAttrCase &enumerant : enumAttr->getAllCases()) - for (const Availability &caseAvail : - getAvailabilities(enumerant.getDef())) - if (availClassName == caseAvail.getClass()) - caseSpecs.push_back({enumerant, caseAvail}); - - // If this attribute kind does not have any availability spec from any of - // its cases, no more work to do. - if (caseSpecs.empty()) - continue; - - if (enumAttr->isBitEnum()) { - // For BitEnumAttr, we need to iterate over each bit to query its - // availability spec. - os << formatv(" for (unsigned i = 0; " - "i < std::numeric_limits<{0}>::digits; ++i) {{\n", - enumAttr->getUnderlyingType()); - os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & " - "static_cast<{0}::{1}>(1 << i);\n", - enumAttr->getCppNamespace(), enumAttr->getEnumClassName(), - namedAttr.name); - os << formatv( - " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n", - enumAttr->getUnderlyingType()); - } else { - // For IntEnumAttr, we just need to query the value as a whole. - os << " {\n"; - os << formatv(" auto tblgen_attrVal = this->{0}();\n", - namedAttr.name); - } - os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n", - enumAttr->getCppNamespace(), avail.getQueryFnName()); - os << " if (tblgen_instance) " - // TODO` here once ODS supports - // dialect-specific contents so that we can use not implementing the - // availability interface as indication of no requirements. 
- << std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(), - &fctx.addSubst("instance", "*tblgen_instance"))) - << ";\n"; - os << " }\n"; - } - - os << " return tblgen_overall;\n"; - os << "}\n"; - } -} - -static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper, - raw_ostream &os) { - llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os); - - auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op"); - for (const auto *def : defs) { - Operator op(def); - emitAvailabilityImpl(op, os); - } - return false; -} - -//===----------------------------------------------------------------------===// -// Op Availability Implementation Hook Registration -//===----------------------------------------------------------------------===// - -static mlir::GenRegistration - genOpAvailabilityImpl("gen-spirv-avail-impls", - "Generate SPIR-V operation utility definitions", - [](const RecordKeeper &records, raw_ostream &os) { - return emitAvailabilityImpl(records, os); - }); - //===----------------------------------------------------------------------===// // SPIR-V Capability Implication AutoGen //===----------------------------------------------------------------------===// diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3473,11 +3473,11 @@ "include/mlir/Dialect/SPIRV/IR/SPIRVEnums.cpp.inc", ), ( - ["-gen-spirv-enum-avail-decls"], + ["-gen-enum-avail-decls"], "include/mlir/Dialect/SPIRV/IR/SPIRVEnumAvailability.h.inc", ), ( - ["-gen-spirv-enum-avail-defs"], + ["-gen-enum-avail-defs"], "include/mlir/Dialect/SPIRV/IR/SPIRVEnumAvailability.cpp.inc", ), ( @@ -3516,10 +3516,6 @@ ["-gen-avail-interface-defs"], "include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc", ), - ( - ["-gen-spirv-avail-impls"], - "include/mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc", - ), 
], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/SPIRV/IR/SPIRVOps.td",