Review notes for this patch. This text sits before the first `diff --git` header; patch tools normally skip such leading text.

What the patch does (as shown by the hunks below):

* SPIRVAvailability.td
  - Adds a `code instancePreparation = "";` field to `Availability`.
  - Per the added comment, it holds C++ statements generated before the `mergeAction` logic.

* SPIRVBase.td
  - `Extension` and `Capability` now use `::llvm::SmallVector<::llvm::ArrayRef<T>, 1>` as `queryFnRetType`, replacing `SmallVector<SmallVector<T, 1>, 1>`.
  - `instanceType` becomes `::llvm::ArrayRef<T>`.
  - When the extension/capability list is non-empty, `instancePreparation` emits a function-local `static const` array (`exts[]` / `caps[]`). It then wraps that array in a local `ArrayRef` named `ref`, and `instance` is simply "ref".
  - Because the arrays are `static`, the returned ArrayRefs do not point at stack storage.
  - The ArrayRef is built explicitly from (pointer, `::llvm::array_lengthof`). The in-patch comment says this is to satisfy GCC 5.
  - When the list is empty, both `instancePreparation` and `mergeAction` are "".
  - `QueryExtensionInterface` / `QueryCapabilityInterface` method return types are updated to the new `SmallVector<ArrayRef<...>, 1>` type.

* SPIRVUtilsGen.cpp
  - Adds `Availability::getMergeInstancePreparation()`, which reads the `instancePreparation` TableGen field.
  - Enum availability query switch cases are now emitted as `case E::X: { <preparation> return <Type>(<instance>); }`. Each case's prepared locals therefore get their own scope.
  - `emitAvailabilityImpl` wraps each op-level merge in its own `{ ... }`, with the preparation code first. It skips an availability whose preparation and merge action are both empty.
  - Generated locals are renamed with a `tblgen_` prefix (`tblgen_overall`, `tblgen_attrVal`, `tblgen_instance`). The motivation is not stated in the patch; presumably it avoids clashes with other generated or user names. Confirm with the author.

NOTE(review), unverified, for follow-up:
  1. This copy looks mangled and will not apply as-is.
     - All newlines and indentation are collapsed.
     - Text inside `<...>` seems to have been stripped, e.g. `SmallVector>`, `class Capability capabilities>`, `StrJoin.result` and `std::vector opAvailabilities`.
     - Regenerate the patch from the original commit rather than hand-repairing the context lines.
  2. The new formatv strings `" case {0}::{1}: { {2} return {3}({4}); }\n"` use an unescaped `{` for the literal brace.
     - LLVM's formatv documents `{{` as the escape for a literal `{`.
     - This same file already relies on that escape in `"... ++i) {{\n"`.
     - The current form depends on how formatv treats a `{` that is followed by another `{` before the next `}`. Consider `" case {0}::{1}: {{ {2} return {3}({4}); }\n"`.
  3. The preparation code spells `ArrayRef<::mlir::spirv::...>` unqualified, while every other generated type uses `::llvm::ArrayRef`.
     - This only compiles where `ArrayRef` is visible unqualified at the include site of the generated code.
     - Consider `::llvm::ArrayRef` for consistency.

diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td @@ -47,6 +47,9 @@ // The following are fields for a concrete availability instance. + // The code for preparing a concrete instance. This should be C++ statements + // and will be generated before the `mergeAction` logic. + code instancePreparation = ""; // The availability requirement carried by a concrete instance. string instance = ?; } diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -148,23 +148,27 @@ AND (`Extension::C`) AND (`Extension::D` OR `Extension::E`) is enabled. }]; - // TODO(antiagainst): Using SmallVector> is an anti-pattern. + // TODO(antiagainst): Returning SmallVector> is not recommended. // Find a better way for this. - let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<" - "::mlir::spirv::Extension, 1>, 1>"; + let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<" + "::mlir::spirv::Extension>, 1>"; let queryFnName = "getExtensions"; let mergeAction = !if( !empty(extensions), "", "$overall.emplace_back($instance)"); let initializer = "{}"; - let instanceType = "::llvm::SmallVector<::mlir::spirv::Extension, 1>"; + let instanceType = "::llvm::ArrayRef<::mlir::spirv::Extension>"; - // Compose all capabilities as an C++ initializer list - let instance = "std::initializer_list<::mlir::spirv::Extension>{" # - StrJoin.result # - "}"; + // Pack all extensions as a static array and get its reference. + let instancePreparation = !if(!empty(extensions), "", + "static const ::mlir::spirv::Extension exts[] = {" # + StrJoin.result # + "}; " # + // The following manual ArrayRef constructor call is to satisfy GCC 5.
+ "ArrayRef<::mlir::spirv::Extension> " # + "ref(exts, ::llvm::array_lengthof(exts));"); + let instance = "ref"; } class Capability capabilities> : Availability { @@ -187,21 +191,25 @@ AND (`Capability::C`) AND (`Capability::D` OR `Capability::E`) is enabled. }]; - let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<" - "::mlir::spirv::Capability, 1>, 1>"; + let queryFnRetType = "::llvm::SmallVector<::llvm::ArrayRef<" + "::mlir::spirv::Capability>, 1>"; let queryFnName = "getCapabilities"; let mergeAction = !if( !empty(capabilities), "", "$overall.emplace_back($instance)"); let initializer = "{}"; - let instanceType = "::llvm::SmallVector<::mlir::spirv::Capability, 1>"; + let instanceType = "::llvm::ArrayRef<::mlir::spirv::Capability>"; - // Compose all capabilities as an C++ initializer list - let instance = "std::initializer_list<::mlir::spirv::Capability>{" # - StrJoin.result # - "}"; + // Pack all capabilities as a static array and get its reference. + let instancePreparation = !if(!empty(capabilities), "", + "static const ::mlir::spirv::Capability caps[] = {" # + StrJoin.result # + "}; " # + // The following manual ArrayRef constructor call is to satisfy GCC 5.
+ "ArrayRef<::mlir::spirv::Capability> " # + "ref(caps, ::llvm::array_lengthof(caps));"); + let instance = "ref"; } // TODO(antiagainst): the following interfaces definitions are duplicating with @@ -216,13 +224,13 @@ def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> { let methods = [InterfaceMethod< "", - "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Extension, 1>, 1>", + "::llvm::SmallVector<::llvm::ArrayRef<::mlir::spirv::Extension>, 1>", "getExtensions">]; } def QueryCapabilityInterface : OpInterface<"QueryCapabilityInterface"> { let methods = [InterfaceMethod< "", - "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Capability, 1>, 1>", + "::llvm::SmallVector<::llvm::ArrayRef<::mlir::spirv::Capability>, 1>", "getCapabilities">]; } diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -85,6 +85,9 @@ // Returns the C++ type for an availability instance. StringRef getMergeInstanceType() const; + // Returns the C++ statements for preparing availability instance. + StringRef getMergeInstancePreparation() const; + // Returns the concrete availability instance carried in this case.
StringRef getMergeInstance() const; @@ -137,6 +140,10 @@ return def->getValueAsString("instanceType"); } +StringRef Availability::getMergeInstancePreparation() const { + return def->getValueAsString("instancePreparation"); +} + StringRef Availability::getMergeInstance() const { return def->getValueAsString("instance"); } @@ -321,9 +328,9 @@ for (const auto &caseSpecPair : classCasePair.getValue()) { EnumAttrCase enumerant = caseSpecPair.first; Availability avail = caseSpecPair.second; - os << formatv(" case {0}::{1}: return {2}({3});\n", enumName, - enumerant.getSymbol(), avail.getMergeInstanceType(), - avail.getMergeInstance()); + os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, + enumerant.getSymbol(), avail.getMergeInstancePreparation(), + avail.getMergeInstanceType(), avail.getMergeInstance()); } // Only emit default if uncovered cases. if (classCasePair.getValue().size() < enumAttr.getAllCases().size()) @@ -368,9 +375,9 @@ for (const auto &caseSpecPair : classCasePair.getValue()) { EnumAttrCase enumerant = caseSpecPair.first; Availability avail = caseSpecPair.second; - os << formatv(" case {0}::{1}: return {2}({3});\n", enumName, - enumerant.getSymbol(), avail.getMergeInstanceType(), - avail.getMergeInstance()); + os << formatv(" case {0}::{1}: { {2} return {3}({4}); }\n", enumName, + enumerant.getSymbol(), avail.getMergeInstancePreparation(), + avail.getMergeInstanceType(), avail.getMergeInstance()); } os << " default: break;\n"; os << " }\n" @@ -1162,7 +1169,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) { mlir::tblgen::FmtContext fctx; - fctx.addSubst("overall", "overall"); + fctx.addSubst("overall", "tblgen_overall"); std::vector opAvailabilities = getAvailabilities(srcOp.getDef()); @@ -1195,17 +1202,23 @@ srcOp.getCppClassName(), avail.getQueryFnName()); // Create the variable for the final requirement and initialize it.
- os << formatv(" {0} overall = {1};\n", avail.getQueryFnRetType(), + os << formatv(" {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(), avail.getMergeInitializer()); // Update with the op's specific availability spec. for (const Availability &avail : opAvailabilities) - if (avail.getClass() == availClassName) { - os << " " + if (avail.getClass() == availClassName && + (!avail.getMergeInstancePreparation().empty() || + !avail.getMergeActionCode().empty())) { + os << " {\n " + // Prepare this instance. + << avail.getMergeInstancePreparation() + << "\n " + // Merge this instance. << std::string( tgfmt(avail.getMergeActionCode(), &fctx.addSubst("instance", avail.getMergeInstance()))) - << ";\n"; + << ";\n }\n"; } // Update with enum attributes' specific availability spec. @@ -1236,30 +1249,32 @@ os << formatv(" for (unsigned i = 0; " "i < std::numeric_limits<{0}>::digits; ++i) {{\n", enumAttr->getUnderlyingType()); - os << formatv(" {0}::{1} attrVal = this->{2}() & " + os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & " "static_cast<{0}::{1}>(1 << i);\n", enumAttr->getCppNamespace(), enumAttr->getEnumClassName(), namedAttr.name); - os << formatv(" if (static_cast<{0}>(attrVal) == 0) continue;\n", - enumAttr->getUnderlyingType()); + os << formatv( + " if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n", + enumAttr->getUnderlyingType()); } else { // For IntEnumAttr, we just need to query the value as a whole.
os << " {\n"; - os << formatv(" auto attrVal = this->{0}();\n", namedAttr.name); + os << formatv(" auto tblgen_attrVal = this->{0}();\n", + namedAttr.name); } - os << formatv(" auto instance = {0}::{1}(attrVal);\n", + os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n", enumAttr->getCppNamespace(), avail.getQueryFnName()); - os << " if (instance) " + os << " if (tblgen_instance) " // TODO(antiagainst): use `avail.getMergeCode()` here once ODS supports // dialect-specific contents so that we can use not implementing the // availability interface as indication of no requirements. << std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(), - &fctx.addSubst("instance", "*instance"))) + &fctx.addSubst("instance", "*tblgen_instance"))) << ";\n"; os << " }\n"; } - os << " return overall;\n"; + os << " return tblgen_overall;\n"; os << "}\n"; } }