diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2514,6 +2514,13 @@ code extraClassDefinition = ?; } +// Marker to group ops together for documentation purposes. +class OpDocGroup { + string summary; + string description; + list ops; +} + // The arguments of an op. class Arguments { dag arguments = args; diff --git a/mlir/test/mlir-tblgen/gen-dialect-doc.td b/mlir/test/mlir-tblgen/gen-dialect-doc.td --- a/mlir/test/mlir-tblgen/gen-dialect-doc.td +++ b/mlir/test/mlir-tblgen/gen-dialect-doc.td @@ -14,7 +14,17 @@ }]; let cppNamespace = "NS"; } -def AOp : Op]>; +def AAOp : Op]>; +def ABOp : Op]>; +def ACOp : Op]>; + +def : OpDocGroup { + let summary = "Group of ops"; + let description = "Grouped for some reason."; + let ops = [AAOp, ACOp]; +} + +def ADOp : Op]>; def TestAttr : DialectAttr> { let summary = "attribute summary"; @@ -53,6 +63,10 @@ // CHECK: [TOC] // CHECK-NOT: [TOC] +// CHECK: test.b +// CHECK: test.d +// CHECK: Group of ops +// CHECK: test.a // CHECK: Traits: SingleBlockImplicitTerminator // CHECK: Interfaces: NoMemoryEffect (MemoryEffectOpInterface) // CHECK: Effects: MemoryEffects::Effect{} diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -36,11 +36,12 @@ //===----------------------------------------------------------------------===// // Commandline Options //===----------------------------------------------------------------------===// -static llvm::cl::OptionCategory docCat("Options for -gen-(attrdef|typedef|op|dialect)-doc"); -llvm::cl::opt stripPrefix( - "strip-prefix", - llvm::cl::desc("Strip prefix of the fully qualified names"), - llvm::cl::init("::mlir::"), llvm::cl::cat(docCat)); +static llvm::cl::OptionCategory + docCat("Options for -gen-(attrdef|typedef|op|dialect)-doc"); +llvm::cl::opt + 
stripPrefix("strip-prefix", + llvm::cl::desc("Strip prefix of the fully qualified names"), + llvm::cl::init("::mlir::"), llvm::cl::cat(docCat)); using namespace llvm; using namespace mlir; @@ -276,7 +277,8 @@ } os << "\nSyntax:\n\n```\n" - << prefix << def.getDialect().getName() << "." << def.getMnemonic() << "<\n"; + << prefix << def.getDialect().getName() << "." << def.getMnemonic() + << "<\n"; for (const auto &it : llvm::enumerate(parameters)) { const AttrOrTypeParameter ¶m = it.value(); os << " " << param.getSyntax(); @@ -334,24 +336,52 @@ // Dialect Documentation //===----------------------------------------------------------------------===// -static void emitDialectDoc(const Dialect &dialect, - ArrayRef attributes, - ArrayRef attrDefs, ArrayRef ops, - ArrayRef types, ArrayRef typeDefs, - raw_ostream &os) { - os << "# '" << dialect.getName() << "' Dialect\n\n"; - emitIfNotEmpty(dialect.getSummary(), os); - emitIfNotEmpty(dialect.getDescription(), os); +struct OpDocGroup { + const Dialect &getDialect() const { return ops.front().getDialect(); } - // Generate a TOC marker except if description already contains one. - llvm::Regex r("^[[:space:]]*\\[TOC\\]$", llvm::Regex::RegexFlags::Newline); - if (!r.match(dialect.getDescription())) - os << "[TOC]\n\n"; + // Summary of the section (used as its heading). + std::string summary; + + // Description of the section. + StringRef description = ""; + + // Ops inside the section. 
+ std::vector ops; +}; +static void maybeNest(bool nest, llvm::function_ref fn, + raw_ostream &os) { + std::string str; + llvm::raw_string_ostream ss(str); + fn(ss); + for (StringRef x : llvm::split(ss.str(), "\n")) { + if (nest && x.starts_with("#")) + os << "#"; + os << x << "\n"; + } +} + +static void emitBlock(ArrayRef attributes, + ArrayRef attrDefs, ArrayRef ops, + ArrayRef types, ArrayRef typeDefs, + raw_ostream &os) { if (!ops.empty()) { os << "## Operation definition\n\n"; - for (const Operator &op : ops) - emitOpDoc(op, os); + for (const OpDocGroup &grouping : ops) { + bool nested = grouping.ops.size() > 1; + maybeNest( + nested, + [&](raw_ostream &os) { + if (nested) { + os << "## " << grouping.summary << "\n\n"; + os << grouping.description << "\n\n"; + } + for (const Operator &op : grouping.ops) { + emitOpDoc(op, os); + } + }, + os); + } } if (!attributes.empty()) { @@ -380,6 +410,23 @@ } } +static void emitDialectDoc(const Dialect &dialect, + ArrayRef attributes, + ArrayRef attrDefs, ArrayRef ops, + ArrayRef types, ArrayRef typeDefs, + raw_ostream &os) { + os << "# '" << dialect.getName() << "' Dialect\n\n"; + emitIfNotEmpty(dialect.getSummary(), os); + emitIfNotEmpty(dialect.getDescription(), os); + + // Generate a TOC marker except if description already contains one. 
+ llvm::Regex r("^[[:space:]]*\\[TOC\\]$", llvm::Regex::RegexFlags::Newline); + if (!r.match(dialect.getDescription())) + os << "[TOC]\n\n"; + + emitBlock(attributes, attrDefs, ops, types, typeDefs, os); +} + static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { std::vector dialectDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("Dialect"); @@ -388,6 +435,9 @@ if (!dialect) return true; + std::vector opDocGroup = + recordKeeper.getAllDerivedDefinitionsIfDefined("OpDocGroup"); + std::vector opDefs = getRequestedOpDefinitions(recordKeeper); std::vector attrDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("DialectAttr"); @@ -400,26 +450,57 @@ std::vector dialectAttrs; std::vector dialectAttrDefs; - std::vector dialectOps; + std::vector dialectOps; std::vector dialectTypes; std::vector dialectTypeDefs; llvm::SmallDenseSet seen; auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) { - if (seen.insert(record).second && def.getDialect() == *dialect) + if (seen.insert(record).second && def.getDialect() == *dialect) { vec.push_back(def); + return true; + } + return false; }; + // Process op groupings first so their ops are marked as seen. 
+ for (const auto &sec : opDocGroup) { + OpDocGroup op; + op.summary = sec->getValueAsString("summary"); + op.description = sec->getValueAsString("description"); + for (auto nestedRec : sec->getValueAsListOfDefs("ops")) + addIfInDialect(nestedRec, Operator(nestedRec), op.ops); + if (!op.ops.empty()) + dialectOps.push_back(op); + } + for (Record *def : attrDefDefs) addIfInDialect(def, AttrDef(def), dialectAttrDefs); for (Record *def : attrDefs) addIfInDialect(def, Attribute(def), dialectAttrs); - for (Record *def : opDefs) - addIfInDialect(def, Operator(def), dialectOps); + for (Record *def : opDefs) { + OpDocGroup op; + op.ops.emplace_back(def); + addIfInDialect(def, op, dialectOps); + } for (Record *def : typeDefDefs) addIfInDialect(def, TypeDef(def), dialectTypeDefs); for (Record *def : typeDefs) addIfInDialect(def, Type(def), dialectTypes); + // Sort alphabetically, ignoring the dialect for ops and using the section + // summary for sections. + // TODO: The sorting order could be revised; currently this attempts to + // roughly keep things in alphabetical order. + std::sort(dialectOps.begin(), dialectOps.end(), + [](const OpDocGroup &lhs, const OpDocGroup &rhs) { + auto getDesc = [](const OpDocGroup &arg) -> StringRef { + if (!arg.summary.empty()) + return arg.summary; + return arg.ops.front().getDef().getValueAsString("opName"); + }; + return getDesc(lhs).compare_insensitive(getDesc(rhs)) < 0; + }); + os << "\n"; emitDialectDoc(*dialect, dialectAttrs, dialectAttrDefs, dialectOps, dialectTypes, dialectTypeDefs, os);