diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -15,6 +15,7 @@
   OpDefinitionsGen.cpp
   OpDocGen.cpp
   OpFormatGen.cpp
+  OpGenHelpers.cpp
   OpInterfacesGen.cpp
   OpPythonBindingGen.cpp
   PassCAPIGen.cpp
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "OpFormatGen.h"
+#include "OpGenHelpers.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
@@ -22,8 +23,7 @@
 #include "mlir/TableGen/SideEffects.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Regex.h"
+#include "llvm/Support/Path.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -35,17 +35,6 @@
 using namespace mlir;
 using namespace mlir::tblgen;
 
-cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
-
-static cl::opt<std::string> opIncFilter(
-    "op-include-regex",
-    cl::desc("Regex of name of op's to include (no filter if empty)"),
-    cl::cat(opDefGenCat));
-static cl::opt<std::string> opExcFilter(
-    "op-exclude-regex",
-    cl::desc("Regex of name of op's to exclude (no filter if empty)"),
-    cl::cat(opDefGenCat));
-
 static const char *const tblgenNamePrefix = "tblgen_";
 static const char *const generatedArgName = "odsArg";
 static const char *const odsBuilder = "odsBuilder";
@@ -2472,44 +2461,10 @@
       [&os]() { os << ",\n"; });
 }
 
-static std::string getOperationName(const Record &def) {
-  auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
-  auto opName = def.getValueAsString("opName");
-  if (prefix.empty())
-    return std::string(opName);
-  return std::string(llvm::formatv("{0}.{1}", prefix, opName));
-}
-
-static std::vector<Record *>
-getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
-                         StringRef className) {
-  Record *classDef = recordKeeper.getClass(className);
-  if (!classDef)
-    PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
-
-  llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
-  std::vector<Record *> defs;
-  for (const auto &def : recordKeeper.getDefs()) {
-    if (!def.second->isSubClassOf(classDef))
-      continue;
-    // Include if no include filter or include filter matches.
-    if (!opIncFilter.empty() &&
-        !includeRegex.match(getOperationName(*def.second)))
-      continue;
-    // Unless there is an exclude filter and it matches.
-    if (!opExcFilter.empty() &&
-        excludeRegex.match(getOperationName(*def.second)))
-      continue;
-    defs.push_back(def.second.get());
-  }
-
-  return defs;
-}
-
 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Declarations", os);
 
-  const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
+  std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
 
   return false;
@@ -2518,7 +2473,7 @@
 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Definitions", os);
 
-  const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
+  std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
   emitOpList(defs, os);
   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
 
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DocGenUtilities.h"
+#include "OpGenHelpers.h"
 #include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/GenInfo.h"
@@ -141,7 +142,7 @@
 }
 
 static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
-  auto opDefs = recordKeeper.getAllDerivedDefinitions("Op");
+  auto opDefs = getRequestedOpDefinitions(recordKeeper);
 
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
   for (const llvm::Record *opDef : opDefs)
@@ -269,7 +270,7 @@
 }
 
 static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
-  std::vector<Record *> opDefs = recordKeeper.getAllDerivedDefinitions("Op");
+  std::vector<Record *> opDefs = getRequestedOpDefinitions(recordKeeper);
   std::vector<Record *> typeDefs =
       recordKeeper.getAllDerivedDefinitions("DialectType");
   std::vector<Record *> typeDefDefs =
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.h b/mlir/tools/mlir-tblgen/OpGenHelpers.h
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.h
@@ -0,0 +1,30 @@
+//===- OpGenHelpers.h - MLIR operation generator helpers --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines helpers used in the op generators.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
+#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
+
+#include "llvm/TableGen/Record.h"
+#include <vector>
+
+namespace mlir {
+namespace tblgen {
+
+/// Returns all the op definitions filtered by the user. The filtering is via
+/// the command-line options "op-include-regex" and "op-exclude-regex".
+std::vector<llvm::Record *>
+getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
diff --git a/mlir/tools/mlir-tblgen/OpGenHelpers.cpp b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/OpGenHelpers.cpp
@@ -0,0 +1,93 @@
+//===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines helpers used in the op generators.
+//
+//===----------------------------------------------------------------------===//
+
+#include "OpGenHelpers.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/TableGen/Error.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::tblgen;
+
+cl::OptionCategory opDefGenCat("Options for op definition generators");
+
+static cl::opt<std::string> opIncFilter(
+    "op-include-regex",
+    cl::desc("Regex of name of op's to include (no filter if empty)"),
+    cl::cat(opDefGenCat));
+static cl::opt<std::string> opExcFilter(
+    "op-exclude-regex",
+    cl::desc("Regex of name of op's to exclude (no filter if empty)"),
+    cl::cat(opDefGenCat));
+
+/// Returns the name of the op defined by `def`: "<dialect>.<opName>", where
+/// <dialect> is the `name` of the op's `opDialect`, or just "<opName>" when
+/// that dialect name is empty.
+static std::string getOperationName(const Record &def) {
+  auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
+  auto opName = def.getValueAsString("opName");
+  if (prefix.empty())
+    return std::string(opName);
+  return std::string(llvm::formatv("{0}.{1}", prefix, opName));
+}
+
+/// Reports a fatal error if the non-empty filter `pattern`, passed via the
+/// command-line option `optionName`, is not a valid regex. Without this check
+/// an invalid pattern silently fails every match (an include filter would then
+/// drop every op). Empty patterns mean "no filter" and are never matched, so
+/// they are not checked.
+static void verifyFilterRegex(const llvm::Regex &regex, StringRef pattern,
+                              StringRef optionName) {
+  std::string error;
+  if (!pattern.empty() && !regex.isValid(error))
+    PrintFatalError("Invalid regex '" + pattern + "' passed to -" +
+                    optionName + ": " + error);
+}
+
+/// Returns every record deriving from the TableGen `Op` class whose operation
+/// name passes the user's filters: a record is kept if "op-include-regex" is
+/// empty or matches its name, unless "op-exclude-regex" is non-empty and
+/// matches it too. Reports a fatal error if the `Op` class is missing or a
+/// filter is not a valid regex.
+std::vector<Record *>
+mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
+  Record *classDef = recordKeeper.getClass("Op");
+  if (!classDef)
+    PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
+
+  llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
+  verifyFilterRegex(includeRegex, opIncFilter, "op-include-regex");
+  verifyFilterRegex(excludeRegex, opExcFilter, "op-exclude-regex");
+  const bool hasFilter = !opIncFilter.empty() || !opExcFilter.empty();
+
+  std::vector<Record *> defs;
+  for (const auto &def : recordKeeper.getDefs()) {
+    Record *record = def.second.get();
+    if (!record->isSubClassOf(classDef))
+      continue;
+    // Only build the op name when a filter needs it, and then only once.
+    if (hasFilter) {
+      std::string opName = getOperationName(*record);
+      // Include if no include filter or include filter matches.
+      if (!opIncFilter.empty() && !includeRegex.match(opName))
+        continue;
+      // Unless there is an exclude filter and it matches.
+      if (!opExcFilter.empty() && excludeRegex.match(opName))
+        continue;
+    }
+    defs.push_back(record);
+  }
+
+  return defs;
+}