diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt --- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -1,8 +1,3 @@ -set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td) -mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls) -mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs) -add_public_tablegen_target(MLIRSPIRVLoweringStructGen) - add_mlir_dialect(SPIRVOps SPIRVOps) set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) @@ -10,6 +5,12 @@ mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRSPIRVEnumsIncGen) +set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) +mlir_tablegen(SPIRVAvailability.h.inc -gen-avail-interface-decls) +mlir_tablegen(SPIRVAvailability.cpp.inc -gen-avail-interface-defs) +mlir_tablegen(SPIRVOpAvailabilityImpl.inc -gen-spirv-avail-impls) +add_public_tablegen_target(MLIRSPIRVAvailabilityIncGen) + set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization) add_public_tablegen_target(MLIRSPIRVSerializationGen) @@ -17,3 +18,8 @@ set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) mlir_tablegen(SPIRVOpUtils.inc -gen-spirv-op-utils) add_public_tablegen_target(MLIRSPIRVOpUtilsGen) + +set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td) +mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls) +mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRSPIRVLoweringStructGen) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td @@ -120,6 +120,13 @@ ``` }]; + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Kernel]> + ]; + let arguments = (ins SPV_AnyPtr:$pointer, SPV_ScopeAttr:$memory_scope, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td 
b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -45,6 +45,214 @@ let cppNamespace = "spirv"; } +//===----------------------------------------------------------------------===// +// Op availability definitions +//===----------------------------------------------------------------------===// + +// The base class for defining op availability dimensions. +class Availability { + // The following are fields for controlling the generated C++ OpInterface. + + // The name for the generated C++ OpInterface subclass. + string interfaceName = ?; + // The documentation for the generated C++ OpInterface subclass. + string interfaceDescription = ""; + + // The following are fields for controlling the query function signature. + + // The query function's return type in the generated C++ OpInterface subclass. + string queryFnRetType = ?; + // The query function's name in the generated C++ OpInterface subclass. + string queryFnName = ?; + + // The following are fields for controlling the query function implementation. + + // The logic for merging two availability requirements. This is used to derive + // the final availability requirement when, for example, an op has two + // operands and these two operands have different availability requirements. + // + // The code should use `$overall` as the placeholder for the final requirement + // and `$instance` for the current availability requirement instance. + code mergeAction = ?; + // The initializer for the final availability requirement. + string initializer = ?; + // An availability instance's type. + string instanceType = ?; + + // The following are fields for a concrete availability instance. + + // The availability requirement carried by a concrete instance. 
+ string instance = ?; +} + +class MinVersionBase + : Availability { + let interfaceName = name; + + let queryFnRetType = scheme.returnType; + let queryFnName = "getMinVersion"; + + let mergeAction = "$overall = static_cast<" # scheme.returnType # ">(" + "std::max($overall, $instance))"; + let initializer = "static_cast<" # scheme.returnType # ">(uint32_t(0))"; + let instanceType = scheme.cppNamespace # "::" # scheme.className; + + let instance = scheme.cppNamespace # "::" # scheme.className # "::" # + min.symbol; +} + +class MaxVersionBase + : Availability { + let interfaceName = name; + + let queryFnRetType = scheme.returnType; + let queryFnName = "getMaxVersion"; + + let mergeAction = "$overall = static_cast<" # scheme.returnType # ">(" + "std::min($overall, $instance))"; + let initializer = "static_cast<" # scheme.returnType # ">(~uint32_t(0))"; + let instanceType = scheme.cppNamespace # "::" # scheme.className; + + let instance = scheme.cppNamespace # "::" # scheme.className # "::" # + max.symbol; +} + +//===----------------------------------------------------------------------===// +// SPIR-V availability definitions +//===----------------------------------------------------------------------===// + +def SPV_V_1_0 : I32EnumAttrCase<"V_1_0", 0>; +def SPV_V_1_1 : I32EnumAttrCase<"V_1_1", 1>; +def SPV_V_1_2 : I32EnumAttrCase<"V_1_2", 2>; +def SPV_V_1_3 : I32EnumAttrCase<"V_1_3", 3>; +def SPV_V_1_4 : I32EnumAttrCase<"V_1_4", 4>; +def SPV_V_1_5 : I32EnumAttrCase<"V_1_5", 5>; + +def SPV_VersionAttr : I32EnumAttr<"Version", "valid SPIR-V version", [ + SPV_V_1_0, SPV_V_1_1, SPV_V_1_2, SPV_V_1_3, SPV_V_1_4, SPV_V_1_5]> { + let cppNamespace = "::mlir::spirv"; +} + +class MinVersion : MinVersionBase< + "QueryMinVersionInterface", SPV_VersionAttr, min> { + let interfaceDescription = [{ + Querying interface for minimal required SPIR-V version. 
+ + This interface provides a `getMinVersion()` method to query the minimal + required version for the implementing SPIR-V operation. The returned value + is a `mlir::spirv::Version` enumerant. + }]; +} + +class MaxVersion : MaxVersionBase< + "QueryMaxVersionInterface", SPV_VersionAttr, max> { + let interfaceDescription = [{ + Querying interface for maximal supported SPIR-V version. + + This interface provides a `getMaxVersion()` method to query the maximal + supported version for the implementing SPIR-V operation. The returned value + is a `mlir::spirv::Version` enumerant. + }]; +} + +class Extension extensions> : Availability { + let interfaceName = "QueryExtensionInterface"; + let interfaceDescription = [{ + Querying interface for required SPIR-V extensions. + + This interface provides a `getExtensions()` method to query the required + extensions for the implementing SPIR-V operation. The returned value + is a nested vector of `mlir::spirv::Extension`s. The outer + vector's elements (which are vectors) should be interpreted as conjunction + while the inner vector's elements (which are `mlir::spirv::Extension`s) + should be interpreted as disjunction. For example, given + + ``` + {{Extension::A, Extension::B}, {Extension::C}, {Extension::D, Extension::E}} + ``` + + The operation instance is available when (`Extension::A` OR `Extension::B`) + AND (`Extension::C`) AND (`Extension::D` OR `Extension::E`) is enabled. + }]; + + // TODO(antiagainst): Using SmallVector> is an anti-pattern. + // Find a better way for this. 
+ let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<" + "::mlir::spirv::Extension, 1>, 1>"; + let queryFnName = "getExtensions"; + + let mergeAction = !if( + !empty(extensions), "", "$overall.emplace_back($instance)"); + let initializer = "{}"; + let instanceType = "::llvm::SmallVector<::mlir::spirv::Extension, 1>"; + + // Compose all extensions as a C++ initializer list + let instance = "std::initializer_list<::mlir::spirv::Extension>{" # + StrJoin.result # + "}"; +} + +class Capability capabilities> : Availability { + let interfaceName = "QueryCapabilityInterface"; + let interfaceDescription = [{ + Querying interface for required SPIR-V capabilities. + + This interface provides a `getCapabilities()` method to query the required + capabilities for the implementing SPIR-V operation. The returned value + is a nested vector of `mlir::spirv::Capability`s. The outer + vector's elements (which are vectors) should be interpreted as conjunction + while the inner vector's elements (which are `mlir::spirv::Capability`s) + should be interpreted as disjunction. For example, given + + ``` + {{Capability::A, Capability::B}, {Capability::C}, {Capability::D, Capability::E}} + ``` + + The operation instance is available when (`Capability::A` OR `Capability::B`) + AND (`Capability::C`) AND (`Capability::D` OR `Capability::E`) is enabled. + }]; + + let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<" + "::mlir::spirv::Capability, 1>, 1>"; + let queryFnName = "getCapabilities"; + + let mergeAction = !if( + !empty(capabilities), "", "$overall.emplace_back($instance)"); + let initializer = "{}"; + let instanceType = "::llvm::SmallVector<::mlir::spirv::Capability, 1>"; + + // Compose all capabilities as a C++ initializer list + let instance = "std::initializer_list<::mlir::spirv::Capability>{" # + StrJoin.result # + "}"; +} + +// TODO(antiagainst): the following interface definitions duplicate +// the above. 
Remove them once we are able to support dialect-specific contents +// in ODS. +def QueryMinVersionInterface : OpInterface<"QueryMinVersionInterface"> { + let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMinVersion">]; +} +def QueryMaxVersionInterface : OpInterface<"QueryMaxVersionInterface"> { + let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMaxVersion">]; +} +def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> { + let methods = [InterfaceMethod< + "", + "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Extension, 1>, 1>", + "getExtensions">]; +} +def QueryCapabilityInterface : OpInterface<"QueryCapabilityInterface"> { + let methods = [InterfaceMethod< + "", + "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Capability, 1>, 1>", + "getCapabilities">]; +} + //===----------------------------------------------------------------------===// // SPIR-V extension definitions //===----------------------------------------------------------------------===// @@ -1216,7 +1424,22 @@ // Base class for all SPIR-V ops. class SPV_Op traits = []> : - Op { + Op, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ])> { + // Availability specification for this op itself. 
+ list availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[]> + ]; // For each SPIR-V op, the following static functions need to be defined // in SPVOps.cpp: diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td @@ -53,6 +53,13 @@ ``` }]; + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_GroupNonUniformBallot]> + ]; + let arguments = (ins SPV_ScopeAttr:$execution_scope, SPV_Bool:$predicate diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h @@ -21,18 +21,23 @@ namespace spirv { +// TableGen'erated operation interfaces for querying versions, extensions, and +// capabilities. +#include "mlir/Dialect/SPIRV/SPIRVAvailability.h.inc" + +// TableGen'erated operation declarations. #define GET_OP_CLASSES #include "mlir/Dialect/SPIRV/SPIRVOps.h.inc" -/// Following methods are auto-generated. -/// -/// Get the name used in the Op to refer to an enum value of the given -/// `EnumClass`. -/// template StringRef attributeName(); -/// -/// Get the function that can be used to symbolize an enum value. -/// template -/// Optional (*)(StringRef) symbolizeEnum(); +// TableGen'erated helper functions. +// +// Get the name used in the Op to refer to an enum value of the given +// `EnumClass`. +// template StringRef attributeName(); +// +// Get the function that can be used to symbolize an enum value. 
+// template +// Optional (*)(StringRef) symbolizeEnum(); #include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc" } // end namespace spirv diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -15,6 +15,7 @@ ) add_dependencies(MLIRSPIRV + MLIRSPIRVAvailabilityIncGen MLIRSPIRVCanonicalizationIncGen MLIRSPIRVEnumsIncGen MLIRSPIRVLoweringStructGen diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -3063,8 +3063,16 @@ namespace mlir { namespace spirv { +// TableGen'erated operation interfaces for querying versions, extensions, and +// capabilities. +#include "mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc" + +// TableGen'erated operation definitions. #define GET_OP_CLASSES #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc" +// TableGen'erated operation availability interface implementations. 
+#include "mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc" + } // namespace spirv } // namespace mlir diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Dialect) add_subdirectory(EDSC) add_subdirectory(mlir-cpu-runner) add_subdirectory(SDBM) diff --git a/mlir/test/Dialect/CMakeLists.txt b/mlir/test/Dialect/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(SPIRV) diff --git a/mlir/test/Dialect/SPIRV/CMakeLists.txt b/mlir/test/Dialect/SPIRV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/CMakeLists.txt @@ -0,0 +1,14 @@ +add_llvm_library(MLIRSPIRVTestPasses + TestAvailability.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV + ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR + ) + +target_link_libraries(MLIRSPIRVTestPasses + MLIRIR + MLIRPass + MLIRSPIRV + MLIRSupport + ) diff --git a/mlir/test/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/Dialect/SPIRV/TestAvailability.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/TestAvailability.cpp @@ -0,0 +1,73 @@ +//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/IR/Function.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// A pass for testing SPIR-V op availability. 
+struct TestAvailability : public FunctionPass { + void runOnFunction() override; +}; +} // end anonymous namespace + +void TestAvailability::runOnFunction() { + auto f = getFunction(); + llvm::outs() << f.getName() << "\n"; + + Dialect *spvDialect = getContext().getRegisteredDialect("spv"); + + f.getOperation()->walk([&](Operation *op) { + if (op->getDialect() != spvDialect) + return WalkResult::advance(); + + auto &os = llvm::outs(); + + if (auto minVersion = dyn_cast(op)) + os << " min version: " + << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n"; + + if (auto maxVersion = dyn_cast(op)) + os << " max version: " + << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n"; + + if (auto extension = dyn_cast(op)) { + os << " extensions: ["; + for (const auto &exts : extension.getExtensions()) { + os << " ["; + interleaveComma(exts, os, [&](spirv::Extension ext) { + os << spirv::stringifyExtension(ext); + }); + os << "]"; + } + os << " ]\n"; + } + + if (auto capability = dyn_cast(op)) { + os << " capabilities: ["; + for (const auto &caps : capability.getCapabilities()) { + os << " ["; + interleaveComma(caps, os, [&](spirv::Capability cap) { + os << spirv::stringifyCapability(cap); + }); + os << "]"; + } + os << " ]\n"; + } + os.flush(); + + return WalkResult::advance(); + }); +} + +static PassRegistration pass("test-spirv-op-availability", + "Test SPIR-V op availability"); diff --git a/mlir/test/Dialect/SPIRV/availability.mlir b/mlir/test/Dialect/SPIRV/availability.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/availability.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt -disable-pass-threading -test-spirv-op-availability %s | FileCheck %s + +// CHECK-LABEL: iadd +func @iadd(%arg: i32) -> i32 { + // CHECK: min version: V_1_0 + // CHECK: max version: V_1_5 + // CHECK: extensions: [ ] + // CHECK: capabilities: [ ] + %0 = spv.IAdd %arg, %arg: i32 + return %0: i32 +} + +// CHECK-LABEL: atomic_compare_exchange_weak +func 
@atomic_compare_exchange_weak(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 { + // CHECK: min version: V_1_0 + // CHECK: max version: V_1_3 + // CHECK: extensions: [ ] + // CHECK: capabilities: [ [Kernel] ] + %0 = spv.AtomicCompareExchangeWeak "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr + return %0: i32 +} + +// CHECK-LABEL: subgroup_ballot +func @subgroup_ballot(%predicate: i1) -> vector<4xi32> { + // CHECK: min version: V_1_3 + // CHECK: max version: V_1_5 + // CHECK: extensions: [ ] + // CHECK: capabilities: [ [GroupNonUniformBallot] ] + %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32> + return %0: vector<4xi32> +} diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -41,6 +41,7 @@ MLIRROCDLIR MLIRSPIRV MLIRStandardToSPIRVTransforms + MLIRSPIRVTestPasses MLIRSPIRVTransforms MLIRStandardOps MLIRStandardToLLVM diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -13,6 +13,7 @@ #include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/Sequence.h" @@ -43,6 +44,233 @@ using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; +//===----------------------------------------------------------------------===// +// Availability Wrapper Class +//===----------------------------------------------------------------------===// + +namespace { +// Wrapper class with helper methods for accessing availability defined in +// TableGen. +class Availability { +public: + explicit Availability(const Record *def); + + // Returns the name of the direct TableGen class for this availability + // instance. 
+ StringRef getClass() const; + + // Returns the generated C++ interface's class name. + StringRef getInterfaceClassName() const; + + // Returns the generated C++ interface's description. + StringRef getInterfaceDescription() const; + + // Returns the name of the query function inside the generated C++ interface. + StringRef getQueryFnName() const; + + // Returns the return type of the query function inside the generated C++ + // interface. + StringRef getQueryFnRetType() const; + + // Returns the code for merging availability requirements. + StringRef getMergeActionCode() const; + + // Returns the initializer expression for initializing the final availability + // requirements. + StringRef getMergeInitializer() const; + + // Returns the C++ type for an availability instance. + StringRef getMergeInstanceType() const; + + // Returns the concrete availability instance carried in this case. + StringRef getMergeInstance() const; + +private: + // The TableGen definition of this availability. 
+ const llvm::Record *def; +}; +} // namespace + +Availability::Availability(const llvm::Record *def) : def(def) { + assert(def->isSubClassOf("Availability") && + "must be subclass of TableGen 'Availability' class"); +} + +StringRef Availability::getClass() const { + SmallVector parentClass; + def->getDirectSuperClasses(parentClass); + if (parentClass.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + return parentClass.front()->getName(); +} + +StringRef Availability::getInterfaceClassName() const { + return def->getValueAsString("interfaceName"); +} + +StringRef Availability::getInterfaceDescription() const { + return def->getValueAsString("interfaceDescription"); +} + +StringRef Availability::getQueryFnRetType() const { + return def->getValueAsString("queryFnRetType"); +} + +StringRef Availability::getQueryFnName() const { + return def->getValueAsString("queryFnName"); +} + +StringRef Availability::getMergeActionCode() const { + return def->getValueAsString("mergeAction"); +} + +StringRef Availability::getMergeInitializer() const { + return def->getValueAsString("initializer"); +} + +StringRef Availability::getMergeInstanceType() const { + return def->getValueAsString("instanceType"); +} + +StringRef Availability::getMergeInstance() const { + return def->getValueAsString("instance"); +} + +//===----------------------------------------------------------------------===// +// Availability Interface Definitions AutoGen +//===----------------------------------------------------------------------===// + +static void emitInterfaceDef(const Availability &availability, + raw_ostream &os) { + StringRef methodName = availability.getQueryFnName(); + os << availability.getQueryFnRetType() << " " + << availability.getInterfaceClassName() << "::" << methodName << "() {\n" + << " return getImpl()->" << methodName << "(getOperation());\n" + << "}\n"; +} + +static bool emitInterfaceDefs(const RecordKeeper &recordKeeper, 
+ raw_ostream &os) { + llvm::emitSourceFileHeader("Availability Interface Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("Availability"); + SmallVector handledClasses; + for (const Record *def : defs) { + SmallVector parent; + def->getDirectSuperClasses(parent); + if (parent.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + if (llvm::is_contained(handledClasses, parent.front())) + continue; + + Availability availability(def); + emitInterfaceDef(availability, os); + handledClasses.push_back(parent.front()); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Availability Interface Declarations AutoGen +//===----------------------------------------------------------------------===// + +static void emitConceptDecl(const Availability &availability, raw_ostream &os) { + os << " class Concept {\n" + << " public:\n" + << " virtual ~Concept() = default;\n" + << " virtual " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() << "(Operation *tblgen_opaque_op) = 0;\n" + << " };\n"; +} + +static void emitModelDecl(const Availability &availability, raw_ostream &os) { + os << " template\n"; + os << " class Model : public Concept {\n" + << " public:\n" + << " " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() + << "(Operation *tblgen_opaque_op) final {\n" + << " auto op = llvm::cast(tblgen_opaque_op);\n" + << " (void)op;\n" + // Forward to the method on the concrete operation type. + << " return op." << availability.getQueryFnName() << "();\n" + << " }\n" + << " };\n"; +} + +static void emitInterfaceDecl(const Availability &availability, + raw_ostream &os) { + StringRef interfaceName = availability.getInterfaceClassName(); + std::string interfaceTraitsName = formatv("{0}Traits", interfaceName); + + // Emit the traits struct containing the concept and model declarations. 
+ os << "namespace detail {\n" + << "struct " << interfaceTraitsName << " {\n"; + emitConceptDecl(availability, os); + os << '\n'; + emitModelDecl(availability, os); + os << "};\n} // end namespace detail\n\n"; + + // Emit the main interface class declaration. + os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n"; + os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {{\n" + "public:\n" + " using OpInterface<{1}, detail::{2}>::OpInterface;\n", + interfaceName, interfaceName, interfaceTraitsName); + + // Emit query function declaration. + os << " " << availability.getQueryFnRetType() << " " + << availability.getQueryFnName() << "();\n"; + os << "};\n\n"; +} + +static bool emitInterfaceDecls(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("Availability Interface Declarations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("Availability"); + SmallVector handledClasses; + for (const Record *def : defs) { + SmallVector parent; + def->getDirectSuperClasses(parent); + if (parent.size() != 1) { + PrintFatalError(def->getLoc(), + "expected to only have one direct superclass"); + } + if (llvm::is_contained(handledClasses, parent.front())) + continue; + + Availability avail(def); + emitInterfaceDecl(avail, os); + handledClasses.push_back(parent.front()); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Availability Interface Hook Registration +//===----------------------------------------------------------------------===// + +// Registers the availability interface declaration generator to mlir-tblgen. +static mlir::GenRegistration + genInterfaceDecls("gen-avail-interface-decls", + "Generate availability interface declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitInterfaceDecls(records, os); + }); + +// Registers the availability interface definition generator to mlir-tblgen. 
+static mlir::GenRegistration + genInterfaceDefs("gen-avail-interface-defs", + "Generate availability interface definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitInterfaceDefs(records, os); + }); + //===----------------------------------------------------------------------===// // Serialization AutoGen //===----------------------------------------------------------------------===// @@ -650,6 +878,17 @@ return false; } +//===----------------------------------------------------------------------===// +// Serialization Hook Registration +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration genSerialization( + "gen-spirv-serialization", + "Generate SPIR-V (de)serialization utilities and functions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitSerializationFns(records, os); + }); + //===----------------------------------------------------------------------===// // Op Utils AutoGen //===----------------------------------------------------------------------===// @@ -707,19 +946,92 @@ } //===----------------------------------------------------------------------===// -// Hook Registration +// Op Utils Hook Registration //===----------------------------------------------------------------------===// -static mlir::GenRegistration genSerialization( - "gen-spirv-serialization", - "Generate SPIR-V (de)serialization utilities and functions", - [](const RecordKeeper &records, raw_ostream &os) { - return emitSerializationFns(records, os); - }); - static mlir::GenRegistration genOpUtils("gen-spirv-op-utils", "Generate SPIR-V operation utility definitions", [](const RecordKeeper &records, raw_ostream &os) { return emitOpUtils(records, os); }); + +//===----------------------------------------------------------------------===// +// SPIR-V Availability Impl AutoGen +//===----------------------------------------------------------------------===// + +// Returns the availability spec of the 
given `def`. +static std::vector getAvailabilities(const Record &def) { + std::vector availabilities; + if (auto *availListInit = def.getValueAsListInit("availability")) { + availabilities.reserve(availListInit->size()); + for (auto *availInit : *availListInit) + availabilities.emplace_back( + llvm::cast(availInit)->getDef()); + } + return availabilities; +} + +static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { + mlir::tblgen::FmtContext fctx; + fctx.addSubst("overall", "overall"); + + std::vector opAvailabilities = + getAvailabilities(srcOp.getDef()); + + // First collect all availability classes this op should implement. + // All availability instances keep information for the generated interface and + // the instance's specific requirement. Here we remember the first instance of + // each class so we can get the information regarding the generated interface. + llvm::StringMap availClasses; + for (const Availability &avail : opAvailabilities) + availClasses.try_emplace(avail.getClass(), avail); + + // Then generate implementation for each availability class. + for (const auto &availClass : availClasses) { + StringRef availClassName = availClass.getKey(); + Availability avail = availClass.getValue(); + + // Generate the implementation method signature. + os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(), + srcOp.getCppClassName(), avail.getQueryFnName()); + + // Create the variable for the final requirement and initialize it. + os << formatv(" {0} overall = {1};\n", avail.getQueryFnRetType(), + avail.getMergeInitializer()); + + // Update with the op's specific availability spec. 
+ for (const Availability &opAvail : opAvailabilities) + if (opAvail.getClass() == availClassName) { + os << " " + << tgfmt(opAvail.getMergeActionCode(), + &fctx.addSubst("instance", opAvail.getMergeInstance())) + << ";\n"; + } + os << " return overall;\n"; + os << "}\n"; + } +} + +static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op"); + for (const auto *def : defs) { + Operator op(def); + emitAvailabilityImpl(op, os); + } + return false; +} + +//===----------------------------------------------------------------------===// +// Op Availability Implementation Hook Registration +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genOpAvailabilityImpl("gen-spirv-avail-impls", + "Generate SPIR-V operation availability implementations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitAvailabilityImpl(records, os); + });