diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -98,16 +98,22 @@
 ///
 class StaticVerifierFunctionEmitter {
 public:
+  /// Create a constraint uniquer with a unique prefix derived from the record
+  /// keeper with an optional tag.
   StaticVerifierFunctionEmitter(raw_ostream &os,
-                                const llvm::RecordKeeper &records);
+                                const llvm::RecordKeeper &records,
+                                StringRef tag = "");
+
+  /// Collect and unique all the constraints used by operations.
+  void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
 
-  /// Collect and unique all compatible type, attribute, successor, and region
-  /// constraints from the operations in the file and emit them at the top of
-  /// the generated file.
+  /// Emit the type, attribute, successor, and region constraints previously
+  /// uniqued by `collectOpConstraints` at the top of the generated file.
+  /// `opDefs` must be non-empty: its first op provides the C++ namespace.
   ///
   /// Constraints that do not meet the restriction that they can only reference
   /// `$_self` and `$_op` are not uniqued.
-  void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
+  void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
 
   /// Unique all compatible type and attribute constraints from a pattern file
   /// and emit them at the top of the generated file.
@@ -175,8 +181,6 @@
   /// Emit pattern constraints.
   void emitPatternConstraints();
 
-  /// Collect and unique all the constraints used by operations.
-  void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
   /// Collect and unique all pattern constraints.
   void collectPatternConstraints(ArrayRef<DagLeaf> constraints);
 
@@ -222,9 +226,7 @@
   }
 };
 template <> struct stringifier<Twine> {
-  static std::string apply(const Twine &twine) {
-    return twine.str();
-  }
+  static std::string apply(const Twine &twine) { return twine.str(); }
 };
 template <typename OptionalT>
 struct stringifier<Optional<OptionalT>> {
diff --git a/mlir/test/mlir-tblgen/shard-op-defs.td b/mlir/test/mlir-tblgen/shard-op-defs.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/shard-op-defs.td
@@ -0,0 +1,29 @@
+// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+  let cppNamespace = "test";
+}
+
+class Test_Op<string mnemonic, list<Trait> traits = []>
+    : Op<Test_Dialect, mnemonic, traits>;
+
+def OpA : Test_Op<"a">;
+def OpB : Test_Op<"b">;
+def OpC : Test_Op<"c">;
+
+// CHECK-LABEL: GET_OP_LIST_0
+// CHECK: test::OpA
+// CHECK: test::OpB
+
+// CHECK-LABEL: GET_OP_CLASSES_0
+// CHECK: OpAAdaptor
+// CHECK: OpBAdaptor
+
+// CHECK-LABEL: GET_OP_LIST_1
+// CHECK: test::OpC
+
+// CHECK-LABEL: GET_OP_CLASSES_1
+// CHECK: OpCAdaptor
diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
--- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -25,7 +25,8 @@
 
 /// Generate a unique label based on the current file name to prevent name
 /// collisions if multiple generated files are included at once.
-static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
+static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
+                                        StringRef tag) {
   // Use the input file name when generating a unique name.
   std::string inputFilename = records.getInputFilename();
 
@@ -34,7 +35,7 @@
   nameRef.consume_back(".td");
 
   // Sanitize any invalid characters.
-  std::string uniqueName;
+  std::string uniqueName(tag);
   for (char c : nameRef) {
     if (llvm::isAlnum(c) || c == '_')
       uniqueName.push_back(c);
@@ -45,15 +46,11 @@
 }
 
 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
-    raw_ostream &os, const llvm::RecordKeeper &records)
-    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
+    raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
+    : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
 
 void StaticVerifierFunctionEmitter::emitOpConstraints(
-    ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
-  collectOpConstraints(opDefs);
-  if (emitDecl)
-    return;
-
+    ArrayRef<llvm::Record *> opDefs) {
   NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
   emitTypeConstraints();
   emitAttrConstraints();
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2956,32 +2956,15 @@
   OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
 }
 
-// Emits the opcode enum and op classes.
-static void emitOpClasses(const RecordKeeper &recordKeeper,
-                          const std::vector<Record *> &defs, raw_ostream &os,
-                          bool emitDecl) {
-  // First emit forward declaration for each class, this allows them to refer
-  // to each others in traits for example.
-  if (emitDecl) {
-    os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
-    os << "#undef GET_OP_FWD_DEFINES\n";
-    for (auto *def : defs) {
-      Operator op(*def);
-      NamespaceEmitter emitter(os, op.getCppNamespace());
-      os << "class " << op.getCppClassName() << ";\n";
-    }
-    os << "#endif\n\n";
-  }
-
-  IfDefScope scope("GET_OP_CLASSES", os);
+/// Emit the class declarations or definitions for the given op defs.
+static void
+emitOpClasses(const RecordKeeper &recordKeeper, ArrayRef<Record *> defs,
+              raw_ostream &os,
+              const StaticVerifierFunctionEmitter &staticVerifierEmitter,
+              bool emitDecl) {
   if (defs.empty())
     return;
 
-  // Generate all of the locally instantiated methods first.
-  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
-  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
-  staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
-
   for (auto *def : defs) {
     Operator op(*def);
     if (emitDecl) {
@@ -3011,23 +2994,65 @@
   }
 }
 
-// Emits a comma-separated list of the ops.
-static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
-  IfDefScope scope("GET_OP_LIST", os);
+/// Emit the declarations for the provided op classes.
+static void emitOpClassDecls(const RecordKeeper &recordKeeper,
+                             const std::vector<Record *> &defs,
+                             raw_ostream &os) {
+  // First emit a forward declaration for each class; this allows them to refer
+  // to each other in traits, for example.
+  os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
+  os << "#undef GET_OP_FWD_DEFINES\n";
+  for (auto *def : defs) {
+    Operator op(*def);
+    NamespaceEmitter emitter(os, op.getCppNamespace());
+    os << "class " << op.getCppClassName() << ";\n";
+  }
+  os << "#endif\n\n";
 
+  // Emit the op class declarations.
+  IfDefScope scope("GET_OP_CLASSES", os);
+  if (defs.empty())
+    return;
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
+  staticVerifierEmitter.collectOpConstraints(defs);
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/true);
+}
+
+/// Emit the operation list for inclusion inside `Dialect::addOperations`.
+static void emitOpList(ArrayRef<Record *> defs, raw_ostream &os,
+                       StringRef listGuard) {
+  IfDefScope scope(listGuard, os);
   interleave(
-      // TODO: We are constructing the Operator wrapper instance just for
-      // getting it's qualified class name here. Reduce the overhead by having a
-      // lightweight version of Operator class just for that purpose.
-      defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
-      [&os]() { os << ",\n"; });
+      defs, os, [&](Record *def) { os << Operator(def).getQualCppClassName(); },
+      ",\n");
+}
+
+/// Emit the definitions for the provided op classes.
+static void emitOpClassDefs(const RecordKeeper &recordKeeper,
+                            ArrayRef<Record *> defs, raw_ostream &os,
+                            StringRef defGuard, StringRef constraintPrefix) {
+  IfDefScope scope(defGuard, os);
+  if (defs.empty())
+    return;
+
+  // Generate all of the locally instantiated methods first.
+  StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper,
+                                                      constraintPrefix);
+  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+  staticVerifierEmitter.collectOpConstraints(defs);
+  staticVerifierEmitter.emitOpConstraints(defs);
+
+  // Emit the classes.
+  emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
+                /*emitDecl=*/false);
 }
 
 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Declarations", os);
 
   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
-  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
+  emitOpClassDecls(recordKeeper, defs, os);
 
   return false;
 }
@@ -3036,9 +3061,24 @@
   emitSourceFileHeader("Op Definitions", os);
 
   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
-  emitOpList(defs, os);
-  emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
-
+  SmallVector<std::vector<Record *>, 1> shardedDefs;
+  shardOpDefinitions(std::move(defs), shardedDefs);
+
+  // When sharding, suffix the guards with the shard index (e.g.
+  // `GET_OP_LIST_0`) and use it to tag the uniqued constraint functions so
+  // that several shards can be included at once without name collisions.
+  for (const auto &it : llvm::enumerate(shardedDefs)) {
+    std::string defGuard = "GET_OP_CLASSES";
+    std::string listGuard = "GET_OP_LIST";
+    std::string indexStr = ("_" + Twine(it.index()).str());
+    if (shardedDefs.size() > 1) {
+      defGuard += indexStr;
+      listGuard += indexStr;
+    }
+    emitOpList(it.value(), os, listGuard);
+    emitOpClassDefs(recordKeeper, it.value(), os, defGuard,
+                    shardedDefs.size() > 1 ? indexStr : "");
+  }
   return false;
 }
 
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.h
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
 #define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
 
+#include "mlir/Support/LLVM.h"
 #include "llvm/TableGen/Record.h"
 #include <vector>
 
@@ -24,6 +25,11 @@
 std::vector<llvm::Record *>
 getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
 
+/// Shard the op definitions into the number of shards set by "op-shard-count".
+void shardOpDefinitions(
+    std::vector<llvm::Record *> &&defs,
+    SmallVectorImpl<std::vector<llvm::Record *>> &shardedDefs);
+
 } // namespace tblgen
 } // namespace mlir
 
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
--- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -30,6 +30,10 @@
     "op-exclude-regex",
     cl::desc("Regex of name of op's to exclude (no filter if empty)"),
     cl::cat(opDefGenCat));
+static cl::opt<int> opShardCount(
+    "op-shard-count",
+    cl::desc("The number of shards into which the op classes will be divided"),
+    cl::cat(opDefGenCat), cl::init(1));
 
 static std::string getOperationName(const Record &def) {
   auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
@@ -63,3 +67,30 @@
 
   return defs;
 }
+
+void mlir::tblgen::shardOpDefinitions(
+    std::vector<Record *> &&defs,
+    SmallVectorImpl<std::vector<Record *>> &shardedDefs) {
+  // The shard count comes from the command line, so diagnose bad values even
+  // in builds where asserts are compiled out.
+  if (opShardCount < 1)
+    PrintFatalError("expected a positive value for -op-shard-count");
+  if (opShardCount == 1) {
+    shardedDefs.push_back(std::move(defs));
+    return;
+  }
+
+  // Distribute the ops as evenly as possible: the first `numLarger` shards
+  // receive one extra op, so shard sizes differ by at most one and no shard is
+  // left empty unless there are fewer ops than shards.
+  unsigned numShards = opShardCount;
+  size_t minShardSize = defs.size() / numShards;
+  size_t numLarger = defs.size() % numShards;
+  shardedDefs.assign(numShards, {});
+  auto it = defs.begin();
+  for (unsigned i = 0; i < numShards; ++i) {
+    size_t shardSize = minShardSize + (i < numLarger ? 1 : 0);
+    shardedDefs[i].assign(it, it + shardSize);
+    it += shardSize;
+  }
+}