diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -154,6 +154,7 @@ add_subdirectory(tools/mlir-tblgen) add_subdirectory(tools/mlir-linalg-ods-gen) add_subdirectory(tools/mlir-pdll) +add_subdirectory(tools/mlir-src-sharder) add_subdirectory(include/mlir) add_subdirectory(lib) diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -9,7 +9,7 @@ tablegen(MLIR ${ARGV}) set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn} PARENT_SCOPE) - + # Get the current set of include paths for this td file. cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN}) get_directory_property(tblgen_includes INCLUDE_DIRECTORIES) @@ -66,7 +66,7 @@ " filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n" " includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n" ) - + add_public_tablegen_target(${target}) endfunction() @@ -83,6 +83,25 @@ add_dependencies(mlir-headers MLIR${dialect}IncGen) endfunction() +# Declare sharded dialect operation declarations and definitions. The per-shard +# copies of ${ops_target}.cpp are returned to the caller in SHARDED_SRCS. +function(add_sharded_ops ops_target shard_count) + set(LLVM_TARGET_DEFINITIONS ${ops_target}.td) + mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count}) + mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count}) + set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp) + # foreach(RANGE <stop>) is inclusive of <stop>; shards are indexed from 0. + math(EXPR last_shard_index "${shard_count} - 1") + foreach(index RANGE ${last_shard_index}) + set(SHARDED_SRC ${ops_target}.${index}.cpp) + list(APPEND SHARDED_SRCS ${SHARDED_SRC}) + tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index}) + set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC}) + endforeach() + add_public_tablegen_target(MLIR${ops_target}ShardGen) + set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE) +endfunction() + # Declare a dialect in the include directory function(add_mlir_interface interface) set(LLVM_TARGET_DEFINITIONS 
${interface}.td) diff --git a/mlir/cmake/modules/CMakeLists.txt b/mlir/cmake/modules/CMakeLists.txt --- a/mlir/cmake/modules/CMakeLists.txt +++ b/mlir/cmake/modules/CMakeLists.txt @@ -31,6 +31,7 @@ # Refer to the best host mlir-tbgen, which might be a host-optimized version set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}") set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}") +set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}") configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in @@ -62,6 +63,7 @@ # if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN). set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen) set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll) +set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder) configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in diff --git a/mlir/cmake/modules/MLIRConfig.cmake.in b/mlir/cmake/modules/MLIRConfig.cmake.in --- a/mlir/cmake/modules/MLIRConfig.cmake.in +++ b/mlir/cmake/modules/MLIRConfig.cmake.in @@ -10,6 +10,7 @@ set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@") set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@") set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@") +set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@") set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@") set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@") diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -98,8 +98,14 @@ /// class StaticVerifierFunctionEmitter { public: + /// Create a constraint uniquer with a unique prefix derived from the record + /// keeper with an optional tag. 
StaticVerifierFunctionEmitter(raw_ostream &os, - const llvm::RecordKeeper &records); + const llvm::RecordKeeper &records, + StringRef tag = ""); + + /// Collect and unique all the constraints used by operations. + void collectOpConstraints(ArrayRef opDefs); /// Collect and unique all compatible type, attribute, successor, and region /// constraints from the operations in the file and emit them at the top of @@ -107,7 +113,7 @@ /// /// Constraints that do not meet the restriction that they can only reference /// `$_self` and `$_op` are not uniqued. - void emitOpConstraints(ArrayRef opDefs, bool emitDecl); + void emitOpConstraints(ArrayRef opDefs); /// Unique all compatible type and attribute constraints from a pattern file /// and emit them at the top of the generated file. @@ -175,8 +181,6 @@ /// Emit pattern constraints. void emitPatternConstraints(); - /// Collect and unique all the constraints used by operations. - void collectOpConstraints(ArrayRef opDefs); /// Collect and unique all pattern constraints. 
void collectPatternConstraints(ArrayRef constraints); @@ -222,9 +226,7 @@ } }; template <> struct stringifier { - static std::string apply(const Twine &twine) { - return twine.str(); - } + static std::string apply(const Twine &twine) { return twine.str(); } }; template struct stringifier> { diff --git a/mlir/test/mlir-tblgen/shard-op-defs.td b/mlir/test/mlir-tblgen/shard-op-defs.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/shard-op-defs.td @@ -0,0 +1,33 @@ +// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS +// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; + let cppNamespace = "test"; +} + +class Test_Op traits = []> + : Op; + +def OpA : Test_Op<"a">; +def OpB : Test_Op<"b">; +def OpC : Test_Op<"c">; + +// DECLS: OpA +// DECLS: OpB +// DECLS: OpC +// DECLS: registerTestDialectOperations( +// DECLS: registerTestDialectOperations0( +// DECLS: registerTestDialectOperations1( + +// DEFS-LABEL: GET_OP_DEFS_0 +// DEFS: void test::registerTestDialectOperations( +// DEFS: void test::registerTestDialectOperations0( +// DEFS: OpAAdaptor +// DEFS: OpBAdaptor + +// DEFS-LABEL: GET_OP_DEFS_1 +// DEFS: void test::registerTestDialectOperations1( +// DEFS: OpCAdaptor diff --git a/mlir/tools/mlir-src-sharder/CMakeLists.txt b/mlir/tools/mlir-src-sharder/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-src-sharder/CMakeLists.txt @@ -0,0 +1,14 @@ +set(LLVM_LINK_COMPONENTS Support) +set(LIBS MLIRSupport) + +add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER + mlir-src-sharder.cpp + + DEPENDS + ${LIBS} + ) + +set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning") +target_link_libraries(mlir-src-sharder PRIVATE ${LIBS}) + +mlir_check_all_link_libraries(mlir-src-sharder) diff --git a/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp 
b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp @@ -0,0 +1,114 @@ +//===- mlir-src-sharder.cpp - A tool for sharding generated source files --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/ToolOutputFile.h" + +using namespace mlir; + +/// Create a dependency file for the `-d` option. +/// +/// This functionality is generally only for the benefit of the build system, +/// and is modeled after the same option in TableGen. +static LogicalResult createDependencyFile(StringRef outputFilename, + StringRef dependencyFile) { + if (outputFilename == "-") { + llvm::errs() << "error: the option -d must be used together with -o\n"; + return failure(); + } + + std::string errorMessage; + std::unique_ptr outputFile = + openOutputFile(dependencyFile, &errorMessage); + if (!outputFile) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + outputFile->os() << outputFilename << ":\n"; + outputFile->keep(); + return success(); +} + +int main(int argc, char **argv) { + // FIXME: This is necessary because we link in TableGen, which defines its + // options as static variables, some of which overlap with our options. 
+ llvm::cl::ResetCommandLineParser(); + + llvm::cl::opt opShardIndex( + "op-shard-index", llvm::cl::desc("The current shard index")); + llvm::cl::opt inputFilename(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + llvm::cl::opt outputFilename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + llvm::cl::list includeDirs( + "I", llvm::cl::desc("Directory of include files"), + llvm::cl::value_desc("directory"), llvm::cl::Prefix); + llvm::cl::opt dependencyFilename( + "d", llvm::cl::desc("Dependency filename"), + llvm::cl::value_desc("filename"), llvm::cl::init("")); + llvm::cl::opt writeIfChanged( + "write-if-changed", + llvm::cl::desc("Only write to the output file if it changed")); + + llvm::InitLLVM y(argc, argv); + llvm::cl::ParseCommandLineOptions(argc, argv); + + // Open the input file. + std::string errorMessage; + std::unique_ptr inputFile = + openInputFile(inputFilename, &errorMessage); + if (!inputFile) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + + // Write the output to a buffer. + std::string outputStr; + llvm::raw_string_ostream os(outputStr); + os << "#define GET_OP_DEFS_" << opShardIndex << "\n" + << inputFile->getBuffer(); + + // Determine whether we need to write the output file. + bool shouldWriteOutput = true; + if (writeIfChanged) { + // Only update the real output file if there are any differences. This + // prevents recompilation of all the files depending on it if there aren't + // any. + if (auto existingOrErr = + llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true)) + if (std::move(existingOrErr.get())->getBuffer() == os.str()) + shouldWriteOutput = false; + } + + // Populate the output file if necessary. 
+ if (shouldWriteOutput) { + std::unique_ptr outputFile = + openOutputFile(outputFilename, &errorMessage); + if (!outputFile) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + outputFile->os() << os.str(); + outputFile->keep(); + } + + // Always write the depfile, even if the main output hasn't changed. If it's + // missing, Ninja considers the output dirty. + if (!dependencyFilename.empty()) + if (failed(createDependencyFile(outputFilename, dependencyFilename))) + return 1; + + return 0; +} diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp --- a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp +++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp @@ -25,7 +25,8 @@ /// Generate a unique label based on the current file name to prevent name /// collisions if multiple generated files are included at once. -static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) { +static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records, + StringRef tag) { // Use the input file name when generating a unique name. std::string inputFilename = records.getInputFilename(); @@ -34,7 +35,7 @@ nameRef.consume_back(".td"); // Sanitize any invalid characters. 
- std::string uniqueName; + std::string uniqueName(tag); for (char c : nameRef) { if (llvm::isAlnum(c) || c == '_') uniqueName.push_back(c); @@ -45,15 +46,11 @@ } StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( - raw_ostream &os, const llvm::RecordKeeper &records) - : os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {} + raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag) + : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} void StaticVerifierFunctionEmitter::emitOpConstraints( - ArrayRef opDefs, bool emitDecl) { - collectOpConstraints(opDefs); - if (emitDecl) - return; - + ArrayRef opDefs) { NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); emitTypeConstraints(); emitAttrConstraints(); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2956,32 +2956,15 @@ OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os); } -// Emits the opcode enum and op classes. -static void emitOpClasses(const RecordKeeper &recordKeeper, - const std::vector &defs, raw_ostream &os, - bool emitDecl) { - // First emit forward declaration for each class, this allows them to refer - // to each others in traits for example. - if (emitDecl) { - os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n"; - os << "#undef GET_OP_FWD_DEFINES\n"; - for (auto *def : defs) { - Operator op(*def); - NamespaceEmitter emitter(os, op.getCppNamespace()); - os << "class " << op.getCppClassName() << ";\n"; - } - os << "#endif\n\n"; - } - - IfDefScope scope("GET_OP_CLASSES", os); +/// Emit the class declarations or definitions for the given op defs. 
+static void +emitOpClasses(const RecordKeeper &recordKeeper, + const std::vector &defs, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, + bool emitDecl) { if (defs.empty()) return; - // Generate all of the locally instantiated methods first. - StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper); - os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); - staticVerifierEmitter.emitOpConstraints(defs, emitDecl); - for (auto *def : defs) { Operator op(*def); if (emitDecl) { @@ -3011,34 +2994,146 @@ } } -// Emits a comma-separated list of the ops. -static void emitOpList(const std::vector &defs, raw_ostream &os) { - IfDefScope scope("GET_OP_LIST", os); +/// Emit the declarations for the provided op classes. +static void emitOpClassDecls(const RecordKeeper &recordKeeper, + const std::vector &defs, + raw_ostream &os) { + // First emit forward declaration for each class, this allows them to refer + // to each others in traits for example. + for (auto *def : defs) { + Operator op(*def); + NamespaceEmitter emitter(os, op.getCppNamespace()); + os << "class " << op.getCppClassName() << ";\n"; + } + + // Emit the op class declarations. + IfDefScope scope("GET_OP_CLASSES", os); + if (defs.empty()) + return; + StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper); + staticVerifierEmitter.collectOpConstraints(defs); + emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter, + /*emitDecl=*/true); +} + +/// Emit the definitions for the provided op classes. +static void emitOpClassDefs(const RecordKeeper &recordKeeper, + ArrayRef defs, raw_ostream &os, + StringRef constraintPrefix = "") { + if (defs.empty()) + return; + + // Generate all of the locally instantiated methods first. 
+ StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper, + constraintPrefix); + os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); + staticVerifierEmitter.collectOpConstraints(defs); + staticVerifierEmitter.emitOpConstraints(defs); - interleave( - // TODO: We are constructing the Operator wrapper instance just for - // getting it's qualified class name here. Reduce the overhead by having a - // lightweight version of Operator class just for that purpose. - defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, - [&os]() { os << ",\n"; }); + // Emit the classes. + emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter, + /*emitDecl=*/false); } +/// Emit op declarations for all op records. static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Op Declarations", os); std::vector defs = getRequestedOpDefinitions(recordKeeper); - emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true); + emitOpClassDecls(recordKeeper, defs, os); + + // If we are generating sharded op definitions, emit the sharded op + // registration hooks. + SmallVector, 4> shardedDefs; + shardOpDefinitions(defs, shardedDefs); + if (defs.empty() || shardedDefs.size() <= 1) + return false; + + Dialect dialect = Operator(defs.front()).getDialect(); + NamespaceEmitter ns(os, dialect); + + const char *const opRegistrationHook = + "void register{0}Operations{1}({2}::{0} *dialect);\n"; + os << formatv(opRegistrationHook, dialect.getCppClassName(), "", + dialect.getCppNamespace()); + for (unsigned i = 0; i < shardedDefs.size(); ++i) { + os << formatv(opRegistrationHook, dialect.getCppClassName(), i, + dialect.getCppNamespace()); + } return false; } +/// Generate the dialect op registration hook and the op class definitions for a +/// shard of ops. 
+static void emitOpDefShard(const RecordKeeper &recordKeeper, + ArrayRef defs, const Dialect &dialect, + unsigned shardIndex, unsigned shardCount, + raw_ostream &os) { + std::string shardGuard = "GET_OP_DEFS_"; + std::string indexStr = std::to_string(shardIndex); + shardGuard += indexStr; + IfDefScope scope(shardGuard, os); + + // Emit the op registration hook in the first shard. + const char *const opRegistrationHook = + "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n"; + if (shardIndex == 0) { + os << formatv(opRegistrationHook, dialect.getCppNamespace(), + dialect.getCppClassName(), ""); + for (unsigned i = 0; i < shardCount; ++i) { + os << formatv(" {0}::register{1}Operations{2}(dialect);\n", + dialect.getCppNamespace(), dialect.getCppClassName(), i); + } + os << "}\n"; + } + + // Generate the per-shard op registration hook. + os << formatv(opCommentHeader, dialect.getCppClassName(), + "Op Registration Hook") + << formatv(opRegistrationHook, dialect.getCppNamespace(), + dialect.getCppClassName(), shardIndex); + for (Record *def : defs) { + os << formatv(" ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n", + Operator(def).getQualCppClassName()); + } + os << "}\n"; + + // Generate the per-shard op definitions. + emitOpClassDefs(recordKeeper, defs, os, indexStr); +} + +/// Emit op definitions for all op records. static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Op Definitions", os); std::vector defs = getRequestedOpDefinitions(recordKeeper); - emitOpList(defs, os); - emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false); + SmallVector, 4> shardedDefs; + shardOpDefinitions(defs, shardedDefs); + + // If no shard was requested, emit the regular op list and class definitions. 
+ if (shardedDefs.size() == 1) { + { + IfDefScope scope("GET_OP_LIST", os); + interleave( + defs, os, + [&](Record *def) { os << Operator(def).getQualCppClassName(); }, + ",\n"); + } + { + IfDefScope scope("GET_OP_CLASSES", os); + emitOpClassDefs(recordKeeper, defs, os); + } + return false; + } + if (defs.empty()) + return false; + Dialect dialect = Operator(defs.front()).getDialect(); + for (auto &it : llvm::enumerate(shardedDefs)) { + emitOpDefShard(recordKeeper, it.value(), dialect, it.index(), + shardedDefs.size(), os); + } return false; } diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h --- a/mlir/tools/mlir-tblgen/OpGenHelpers.h +++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h @@ -13,6 +13,7 @@ #ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_ #define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_ +#include "mlir/Support/LLVM.h" #include "llvm/TableGen/Record.h" #include @@ -24,6 +25,10 @@ std::vector getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper); +/// Shard the op definitions into the number of shards set by "op-shard-count". 
+void shardOpDefinitions(ArrayRef defs, + SmallVectorImpl> &shardedDefs); + } // namespace tblgen } // namespace mlir diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp --- a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp +++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp @@ -30,6 +30,10 @@ "op-exclude-regex", cl::desc("Regex of name of op's to exclude (no filter if empty)"), cl::cat(opDefGenCat)); +static cl::opt opShardCount( + "op-shard-count", + cl::desc("The number of shards into which the op classes will be divided"), + cl::cat(opDefGenCat), cl::init(1)); static std::string getOperationName(const Record &def) { auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name"); @@ -63,3 +67,22 @@ return defs; } + +void mlir::tblgen::shardOpDefinitions( + ArrayRef defs, + SmallVectorImpl> &shardedDefs) { + assert(opShardCount > 0 && "expected a positive shard count"); + if (opShardCount == 1) { + shardedDefs.push_back(defs); + return; + } + + unsigned minShardSize = defs.size() / opShardCount; + unsigned numMissing = defs.size() - minShardSize * opShardCount; + shardedDefs.reserve(opShardCount); + for (unsigned i = 0, start = 0; i < opShardCount; ++i) { + unsigned size = minShardSize + (i < numMissing); + shardedDefs.push_back(defs.slice(start, size)); + start += size; + } +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6566,6 +6566,15 @@ ], ) +cc_binary( + name = "mlir-src-sharder", + srcs = ["tools/mlir-src-sharder/mlir-src-sharder.cpp"], + deps = [ + ":Support", + "//llvm:Support", + ], +) + cc_binary( name = "mlir-linalg-ods-yaml-gen", srcs = [ diff --git a/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl b/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl --- a/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl +++ 
b/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl @@ -431,3 +431,136 @@ textual_hdrs = [":" + filegroup_name], **kwargs ) + +def _gentbl_shard_impl(ctx): + args = ctx.actions.args() + args.add(ctx.file.src_file) + args.add("-op-shard-index", ctx.attr.index) + args.add("-o", ctx.outputs.out.path) + ctx.actions.run( + outputs = [ctx.outputs.out], + inputs = [ctx.file.src_file], + executable = ctx.executable.sharder, + arguments = [args], + use_default_shell_env = True, + mnemonic = "ShardGenerate", + ) + +gentbl_shard_rule = rule( + _gentbl_shard_impl, + doc = "", + output_to_genfiles = True, + attrs = { + "index": attr.int(mandatory = True, doc = ""), + "sharder": attr.label( + doc = "", + executable = True, + cfg = "exec", + ), + "src_file": attr.label( + doc = "", + allow_single_file = True, + mandatory = True, + ), + "out": attr.output( + doc = "", + mandatory = True, + ), + }, +) + +def gentbl_sharded_ops( + name, + tblgen, + sharder, + td_file, + shard_count, + src_file, + src_out, + hdr_out, + test = False, + includes = [], + strip_include_prefix = None, + deps = []): + """Generate sharded op declarations and definitions. + + This special build rule shards op definitions in a TableGen file and generates multiple copies + of a template source file for including and compiling each shard. The rule defines a filegroup + consisting of the source shards, the generated source file, and the generated header file. + + Args: + name: The name of the filegroup. + tblgen: The binary used to produce the output. + sharder: The source file sharder to use. + td_file: The primary table definitions file. + shard_count: The number of op definition shards to produce. + src_file: The source file template. + src_out: The generated source file. + hdr_out: The generated header file. + test: Whether this is a test target. + includes: See gentbl_rule.includes + deps: See gentbl_rule.deps + strip_include_prefix: Attribute to pass through to cc_library. 
+ """ + cc_lib_name = name + "__gentbl_cc_lib" + gentbl_cc_library( + name = cc_lib_name, + strip_include_prefix = strip_include_prefix, + includes = includes, + tbl_outs = [ + ( + [ + "-gen-op-defs", + "-op-shard-count=" + str(shard_count), + ], + src_out, + ), + ( + [ + "-gen-op-decls", + "-op-shard-count=" + str(shard_count), + ], + hdr_out, + ), + ], + tblgen = tblgen, + td_file = td_file, + test = test, + deps = deps, + ) + all_files = [hdr_out, src_out] + for i in range(0, shard_count): + out_file = "shard_copy_" + str(i) + "_" + src_file + gentbl_shard_rule( + index = i, + name = name + "__src_shard" + str(i), + testonly = test, + out = out_file, + sharder = sharder, + src_file = src_file, + ) + all_files.append(out_file) + native.filegroup(name = name, srcs = all_files) + +def gentbl_sharded_op_defs(name, source_file, shard_count): + """Generates multiple copies of a source file that includes sharded op definitions. + + Args: + name: The name of the rule. + source_file: The source to copy. + shard_count: The number of shards. + + Returns: + A list of the copied filenames to be included in the dialect library. + """ + copies = [] + for i in range(0, shard_count): + out_file = "shard_copy_" + str(i) + "_" + source_file + copies.append(out_file) + native.genrule( + name = name + "_shard_" + str(i), + srcs = [source_file], + outs = [out_file], + cmd = "echo -e \"#define GET_OP_DEFS_" + str(i) + "\n$$(cat $(SRCS))\" > $(OUTS)", + ) + return copies