diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2398,6 +2398,15 @@ list decorators = []> : OpVariable; +// Marker to group ops together for documentation purposes. +class OpDocGroup { + // Single line summary of the group of ops. + string summary; + + // Longer description of documentation group. + string description; +} + // Base class for all ops. class Op props = []> { // The dialect of the op. @@ -2415,6 +2424,9 @@ // Additional, longer human-readable description of what the op does. string description = ""; + // Optional. The group of ops this op is part of. + OpDocGroup opDocGroup = ?; + // Dag containing the arguments of the op. Default to 0 arguments. dag arguments = (ins); diff --git a/mlir/test/mlir-tblgen/gen-dialect-doc.td b/mlir/test/mlir-tblgen/gen-dialect-doc.td --- a/mlir/test/mlir-tblgen/gen-dialect-doc.td +++ b/mlir/test/mlir-tblgen/gen-dialect-doc.td @@ -14,7 +14,28 @@ }]; let cppNamespace = "NS"; } -def AOp : Op]>; + +def OpGroupA : OpDocGroup { + let summary = "Group of ops"; + let description = "Grouped for some reason."; +} + +let opDocGroup = OpGroupA in { +def ADOp : Op]>; +def AAOp : Op]>; +} + +def OpGroupB : OpDocGroup { + let summary = "Other group of ops"; + let description = "Grouped for some other reason."; +} + +let opDocGroup = OpGroupB in { +def ACOp : Op]>; +def ABOp : Op]>; +} + +def AEOp : Op]>; def TestAttr : DialectAttr> { let summary = "attribute summary"; @@ -53,6 +74,13 @@ // CHECK: [TOC] // CHECK-NOT: [TOC] +// CHECK: test.e +// CHECK: Group of ops +// CHECK: test.a +// CHECK: test.d +// CHECK: Other group +// CHECK: test.b +// CHECK: test.c // CHECK: Traits: SingleBlockImplicitTerminator // CHECK: Interfaces: NoMemoryEffect (MemoryEffectOpInterface) // CHECK: Effects: MemoryEffects::Effect{} diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp --- 
a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -36,11 +36,12 @@ //===----------------------------------------------------------------------===// // Commandline Options //===----------------------------------------------------------------------===// -static llvm::cl::OptionCategory docCat("Options for -gen-(attrdef|typedef|op|dialect)-doc"); -llvm::cl::opt stripPrefix( - "strip-prefix", - llvm::cl::desc("Strip prefix of the fully qualified names"), - llvm::cl::init("::mlir::"), llvm::cl::cat(docCat)); +static llvm::cl::OptionCategory + docCat("Options for -gen-(attrdef|typedef|op|dialect)-doc"); +llvm::cl::opt + stripPrefix("strip-prefix", + llvm::cl::desc("Strip prefix of the fully qualified names"), + llvm::cl::init("::mlir::"), llvm::cl::cat(docCat)); using namespace llvm; using namespace mlir; @@ -276,7 +277,8 @@ } os << "\nSyntax:\n\n```\n" - << prefix << def.getDialect().getName() << "." << def.getMnemonic() << "<\n"; + << prefix << def.getDialect().getName() << "." << def.getMnemonic() + << "<\n"; for (const auto &it : llvm::enumerate(parameters)) { const AttrOrTypeParameter &param = it.value(); os << " " << param.getSyntax(); @@ -334,24 +336,53 @@ // Dialect Documentation //===----------------------------------------------------------------------===// -static void emitDialectDoc(const Dialect &dialect, - ArrayRef attributes, - ArrayRef attrDefs, ArrayRef ops, - ArrayRef types, ArrayRef typeDefs, - raw_ostream &os) { - os << "# '" << dialect.getName() << "' Dialect\n\n"; - emitIfNotEmpty(dialect.getSummary(), os); - emitIfNotEmpty(dialect.getDescription(), os); +struct OpDocGroup { + const Dialect &getDialect() const { return ops.front().getDialect(); } - // Generate a TOC marker except if description already contains one. - llvm::Regex r("^[[:space:]]*\\[TOC\\]$", llvm::Regex::RegexFlags::Newline); - if (!r.match(dialect.getDescription())) - os << "[TOC]\n\n"; + // Single-line summary of the section. 
+ std::string summary = ""; + + // Longer description of the section. + StringRef description = ""; + + // Ops that belong to the section. + std::vector ops; +}; + +static void maybeNest(bool nest, llvm::function_ref fn, + raw_ostream &os) { + std::string str; + llvm::raw_string_ostream ss(str); + fn(ss); + for (StringRef x : llvm::split(ss.str(), "\n")) { + if (nest && x.starts_with("#")) + os << "#"; + os << x << "\n"; + } +} +static void emitBlock(ArrayRef attributes, + ArrayRef attrDefs, ArrayRef ops, + ArrayRef types, ArrayRef typeDefs, + raw_ostream &os) { if (!ops.empty()) { os << "## Operation definition\n\n"; - for (const Operator &op : ops) - emitOpDoc(op, os); + for (const OpDocGroup &grouping : ops) { + bool nested = grouping.ops.size() > 1; + maybeNest( + nested, + [&](raw_ostream &os) { + if (nested) { + os << "## " << StringRef(grouping.summary).trim() << "\n\n"; + emitDescription(grouping.description, os); + os << "\n\n"; + } + for (const Operator &op : grouping.ops) { + emitOpDoc(op, os); + } + }, + os); + } } if (!attributes.empty()) { @@ -380,6 +411,23 @@ } } +static void emitDialectDoc(const Dialect &dialect, + ArrayRef attributes, + ArrayRef attrDefs, ArrayRef ops, + ArrayRef types, ArrayRef typeDefs, + raw_ostream &os) { + os << "# '" << dialect.getName() << "' Dialect\n\n"; + emitIfNotEmpty(dialect.getSummary(), os); + emitIfNotEmpty(dialect.getDescription(), os); + + // Generate a TOC marker except if description already contains one. 
+ llvm::Regex r("^[[:space:]]*\\[TOC\\]$", llvm::Regex::RegexFlags::Newline); + if (!r.match(dialect.getDescription())) + os << "[TOC]\n\n"; + + emitBlock(attributes, attrDefs, ops, types, typeDefs, os); +} + static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { std::vector dialectDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("Dialect"); @@ -400,26 +448,62 @@ std::vector dialectAttrs; std::vector dialectAttrDefs; - std::vector dialectOps; + std::vector dialectOps; std::vector dialectTypes; std::vector dialectTypeDefs; + llvm::SmallDenseSet seen; auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) { - if (seen.insert(record).second && def.getDialect() == *dialect) + if (seen.insert(record).second && def.getDialect() == *dialect) { vec.push_back(def); + return true; + } + return false; }; + SmallDenseMap opDocGroup; + for (Record *def : attrDefDefs) addIfInDialect(def, AttrDef(def), dialectAttrDefs); for (Record *def : attrDefs) addIfInDialect(def, Attribute(def), dialectAttrs); - for (Record *def : opDefs) - addIfInDialect(def, Operator(def), dialectOps); + for (Record *def : opDefs) { + if (Record *group = def->getValueAsOptionalDef("opDocGroup")) { + OpDocGroup &op = opDocGroup[group]; + addIfInDialect(def, Operator(def), op.ops); + } else { + OpDocGroup op; + op.ops.emplace_back(def); + addIfInDialect(def, op, dialectOps); + } + } + for (Record *rec : + recordKeeper.getAllDerivedDefinitionsIfDefined("OpDocGroup")) { + if (opDocGroup[rec].ops.empty()) + continue; + opDocGroup[rec].summary = rec->getValueAsString("summary"); + opDocGroup[rec].description = rec->getValueAsString("description"); + dialectOps.push_back(opDocGroup[rec]); + } for (Record *def : typeDefDefs) addIfInDialect(def, TypeDef(def), dialectTypeDefs); for (Record *def : typeDefs) addIfInDialect(def, Type(def), dialectTypes); + // Sort alphabetically, ignoring dialect for ops and using the section name for + // sections. 
+ // TODO: The sorting order could be revised; currently this attempts to + // keep entries in roughly alphabetical order. + std::sort(dialectOps.begin(), dialectOps.end(), + [](const OpDocGroup &lhs, const OpDocGroup &rhs) { + auto getDesc = [](const OpDocGroup &arg) -> StringRef { + if (!arg.summary.empty()) + return arg.summary; + return arg.ops.front().getDef().getValueAsString("opName"); + }; + return getDesc(lhs).compare_insensitive(getDesc(rhs)) < 0; + }); + os << "\n"; emitDialectDoc(*dialect, dialectAttrs, dialectAttrDefs, dialectOps, dialectTypes, dialectTypeDefs, os);