NOTE(review): this patch was recovered from a copy flattened onto two lines
with every <...> span stripped. Line breaks, template arguments, the
`llvm::SmallDenseSet<Record *> seen;` context line and the autogenerated-comment
`os << ...` context line were restored by inference; check them against
mlir/tools/mlir-tblgen/OpDocGen.cpp before running `git apply`.

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,2 +1,2 @@
 add_mlir_dialect(MemRefOps memref)
-add_mlir_doc(MemRefOps MemRefOps Dialects/ -gen-dialect-doc)
+add_mlir_doc(MemRefOps MemRefOps Dialects/ -gen-dialect-doc -dialect=memref)
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -360,6 +360,13 @@
 }
 
 static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  std::vector<Record *> dialectDefs =
+      recordKeeper.getAllDerivedDefinitionsIfDefined("Dialect");
+  SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
+  Optional<Dialect> dialect = findDialectToGenerate(dialects);
+  if (!dialect)
+    return true;
+
   std::vector<Record *> opDefs = getRequestedOpDefinitions(recordKeeper);
   std::vector<Record *> attrDefs =
       recordKeeper.getAllDerivedDefinitionsIfDefined("DialectAttr");
@@ -370,61 +377,34 @@
   std::vector<Record *> attrDefDefs =
       recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef");
 
-  llvm::SetVector<Dialect, SmallVector<Dialect>, std::set<Dialect>>
-      dialectsWithDocs;
-
-  llvm::StringMap<std::vector<Attribute>> dialectAttrs;
-  llvm::StringMap<std::vector<AttrDef>> dialectAttrDefs;
-  llvm::StringMap<std::vector<Operator>> dialectOps;
-  llvm::StringMap<std::vector<Type>> dialectTypes;
-  llvm::StringMap<std::vector<TypeDef>> dialectTypeDefs;
+  std::vector<Attribute> dialectAttrs;
+  std::vector<AttrDef> dialectAttrDefs;
+  std::vector<Operator> dialectOps;
+  std::vector<Type> dialectTypes;
+  std::vector<TypeDef> dialectTypeDefs;
 
   llvm::SmallDenseSet<Record *> seen;
-  for (Record *attrDef : attrDefDefs) {
-    AttrDef attr(attrDef);
-    dialectAttrDefs[attr.getDialect().getName()].push_back(attr);
-    dialectsWithDocs.insert(attr.getDialect());
-    seen.insert(attrDef);
-  }
-  for (Record *attrDef : attrDefs) {
-    if (seen.count(attrDef))
-      continue;
-    Attribute attr(attrDef);
-    if (const Dialect &dialect = attr.getDialect()) {
-      dialectAttrs[dialect.getName()].push_back(attr);
-      dialectsWithDocs.insert(dialect);
-    }
-  }
-  for (Record *opDef : opDefs) {
-    Operator op(opDef);
-    dialectOps[op.getDialect().getName()].push_back(op);
-    dialectsWithDocs.insert(op.getDialect());
-  }
-  for (Record *typeDef : typeDefDefs) {
-    TypeDef type(typeDef);
-    dialectTypeDefs[type.getDialect().getName()].push_back(type);
-    dialectsWithDocs.insert(type.getDialect());
-    seen.insert(typeDef);
-  }
-  for (Record *typeDef : typeDefs) {
-    if (seen.count(typeDef))
-      continue;
-    Type type(typeDef);
-    if (const Dialect &dialect = type.getDialect()) {
-      dialectTypes[dialect.getName()].push_back(type);
-      dialectsWithDocs.insert(dialect);
-    }
-  }
-
-  Optional<Dialect> dialect =
-      findDialectToGenerate(dialectsWithDocs.getArrayRef());
-  if (!dialect)
-    return true;
+  // Visit every record at most once -- the AttrDef/TypeDef loops run before
+  // the attrDefs/typeDefs loops, whose lists may repeat those records -- and
+  // keep only the entities that belong to the selected dialect.
+  auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) {
+    if (seen.insert(record).second && def.getDialect() == *dialect)
+      vec.push_back(def);
+  };
+
+  for (Record *def : attrDefDefs)
+    addIfInDialect(def, AttrDef(def), dialectAttrDefs);
+  for (Record *def : attrDefs)
+    addIfInDialect(def, Attribute(def), dialectAttrs);
+  for (Record *def : opDefs)
+    addIfInDialect(def, Operator(def), dialectOps);
+  for (Record *def : typeDefDefs)
+    addIfInDialect(def, TypeDef(def), dialectTypeDefs);
+  for (Record *def : typeDefs)
+    addIfInDialect(def, Type(def), dialectTypes);
 
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
-  StringRef dialectName = dialect->getName();
-  emitDialectDoc(*dialect, dialectAttrs[dialectName],
-                 dialectAttrDefs[dialectName], dialectOps[dialectName],
-                 dialectTypes[dialectName], dialectTypeDefs[dialectName], os);
+  emitDialectDoc(*dialect, dialectAttrs, dialectAttrDefs, dialectOps,
+                 dialectTypes, dialectTypeDefs, os);
   return false;
 }