diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -1,4 +1,6 @@
 // RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s
+// RUN: mlir-tblgen -gen-op-decls -op-regex="test.a_op" -I %S/../../include %s | FileCheck %s --check-prefix=REDUCE
+// RUN: not mlir-tblgen -gen-op-decls -op-regex="(" -I %S/../../include %s 2>&1 | FileCheck %s --check-prefix=BADREGEX
 
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -195,3 +197,7 @@
 
 // CHECK-LABEL: _BOp declarations
 // CHECK: class _BOp : public Op<_BOp
+// REDUCE-LABEL: NS::AOp declarations
+// REDUCE-NOT: NS::BOp declarations
+
+// BADREGEX: invalid -op-regex '('
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -21,6 +21,8 @@
 #include "mlir/TableGen/SideEffects.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Regex.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -32,6 +34,16 @@
 using namespace mlir;
 using namespace mlir::tblgen;
 
+static cl::OptionCategory
+    opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
+
+// When non-empty, only ops whose full name ("dialect.op") matches this regex
+// (unanchored search) are emitted; see getAllDerivedDefinitions.
+static cl::opt<std::string>
+    opFilter("op-regex",
+             cl::desc("Regex of name of op's to filter (no filter if empty)"),
+             cl::cat(opDefGenCat));
+
 static const char *const tblgenNamePrefix = "tblgen_";
 static const char *const generatedArgName = "odsArg";
 static const char *const builderOpState = "odsState";
@@ -2081,10 +2093,51 @@
       [&os]() { os << ",\n"; });
 }
 
+/// Returns the full operation name of an Op record: "<dialect>.<opName>", or
+/// just "<opName>" when the dialect name is empty. This is the string that
+/// -op-regex is matched against.
+static std::string getOperationName(const Record &def) {
+  auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
+  auto opName = def.getValueAsString("opName");
+  if (prefix.empty())
+    return std::string(opName);
+  return std::string(llvm::formatv("{0}.{1}", prefix, opName));
+}
+
+/// Returns every def deriving from the TableGen class `className`. When
+/// -op-regex is set, only defs whose operation name matches it (unanchored
+/// search) are kept. Reports a fatal error if the class does not exist or the
+/// regex does not compile.
+static std::vector<Record *>
+getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
+                         StringRef className) {
+  Record *classDef = recordKeeper.getClass(className);
+  if (!classDef)
+    PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
+
+  // Validate a user-supplied filter up front; otherwise a malformed pattern
+  // makes every match() fail and silently drops all ops from the output.
+  llvm::Regex includeRegex(opFilter);
+  std::string regexError;
+  if (!opFilter.empty() && !includeRegex.isValid(regexError))
+    PrintFatalError("invalid -op-regex '" + opFilter.getValue() +
+                    "': " + regexError);
+
+  std::vector<Record *> defs;
+  for (const auto &def : recordKeeper.getDefs()) {
+    if (!def.second->isSubClassOf(classDef))
+      continue;
+    if (opFilter.empty() || includeRegex.match(getOperationName(*def.second)))
+      defs.push_back(def.second.get());
+  }
+
+  return defs;
+}
+
 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Declarations", os);
 
-  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+  const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
   emitOpClasses(defs, os, /*emitDecl=*/true);
 
   return false;
@@ -2093,7 +2146,7 @@
 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Op Definitions", os);
 
-  const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
+  const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
   emitOpList(defs, os);
   emitOpClasses(defs, os, /*emitDecl=*/false);
 