diff --git a/mlir/test/mlir-tblgen/gen-dialect-doc.td b/mlir/test/mlir-tblgen/gen-dialect-doc.td --- a/mlir/test/mlir-tblgen/gen-dialect-doc.td +++ b/mlir/test/mlir-tblgen/gen-dialect-doc.td @@ -1,4 +1,5 @@ -// RUN: mlir-tblgen -gen-dialect-doc -I %S/../../include %s | FileCheck %s +// RUN: mlir-tblgen -gen-dialect-doc -I %S/../../include -dialect=test %s | FileCheck %s +// RUN: mlir-tblgen -gen-dialect-doc -I %S/../../include -dialect=test_toc %s | FileCheck %s --check-prefix=CHECK_TOC include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -55,6 +56,6 @@ } def BOp : Op; -// CHECK: Dialect with -// CHECK: [TOC] -// CHECK: here. +// CHECK_TOC: Dialect with +// CHECK_TOC: [TOC] +// CHECK_TOC: here. diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "DialectGenUtilities.h" #include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" @@ -55,12 +56,15 @@ DialectFilterIterator(records.end(), records.end(), filterFn)}; } -static Optional -findSelectedDialect(ArrayRef dialectDefs) { +Optional tblgen::findDialectToGenerate(ArrayRef dialects) { + if (dialects.empty()) { + llvm::errs() << "no dialect was found\n"; + return llvm::None; + } + // Select the dialect to gen for.
- if (dialectDefs.size() == 1 && selectedDialect.getNumOccurrences() == 0) { - return Dialect(dialectDefs.front()); - } + if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0) + return dialects.front(); if (selectedDialect.getNumOccurrences() == 0) { llvm::errs() << "when more than 1 dialect is present, one must be selected " @@ -68,15 +72,14 @@ return llvm::None; } - const auto *dialectIt = - llvm::find_if(dialectDefs, [](const llvm::Record *def) { - return Dialect(def).getName() == selectedDialect; - }); - if (dialectIt == dialectDefs.end()) { + const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) { + return dialect.getName() == selectedDialect; + }); + if (dialectIt == dialects.end()) { llvm::errs() << "selected dialect with '-dialect' does not exist\n"; return llvm::None; } - return Dialect(*dialectIt); + return *dialectIt; } //===----------------------------------------------------------------------===// @@ -235,7 +238,8 @@ if (dialectDefs.empty()) return false; - Optional dialect = findSelectedDialect(dialectDefs); + SmallVector dialects(dialectDefs.begin(), dialectDefs.end()); + Optional dialect = findDialectToGenerate(dialects); if (!dialect) return true; auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr"); @@ -308,7 +312,8 @@ if (dialectDefs.empty()) return false; - Optional dialect = findSelectedDialect(dialectDefs); + SmallVector dialects(dialectDefs.begin(), dialectDefs.end()); + Optional dialect = findDialectToGenerate(dialects); if (!dialect) return true; emitDialectDef(*dialect, os); diff --git a/mlir/tools/mlir-tblgen/DialectGenUtilities.h b/mlir/tools/mlir-tblgen/DialectGenUtilities.h new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/DialectGenUtilities.h @@ -0,0 +1,26 @@ +//===- DialectGenUtilities.h - Utilities for dialect generation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRTBLGEN_DIALECTGENUTILITIES_H_ +#define MLIR_TOOLS_MLIRTBLGEN_DIALECTGENUTILITIES_H_ + +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace tblgen { +class Dialect; + +/// Find the dialect selected by the user to generate for. Returns None, after +/// printing an error, if no dialect was found, if more than one dialect was +/// found and none was selected with `-dialect`, or if the selected dialect +/// does not exist. +Optional findDialectToGenerate(ArrayRef dialects); +} // namespace tblgen +} // namespace mlir + +#endif // MLIR_TOOLS_MLIRTBLGEN_DIALECTGENUTILITIES_H_ diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "DialectGenUtilities.h" #include "DocGenUtilities.h" #include "OpGenHelpers.h" #include "mlir/Support/IndentedOstream.h" @@ -18,6 +19,7 @@ #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" @@ -35,8 +37,6 @@ using mlir::tblgen::Operator; -extern llvm::cl::opt selectedDialect; - // Emit the description by aligning the text to the left per line (e.g., // removing the minimum indentation across the block).
// @@ -307,9 +307,6 @@ ArrayRef attrDefs, ArrayRef ops, ArrayRef types, ArrayRef typeDefs, raw_ostream &os) { - if (selectedDialect.getNumOccurrences() && - dialect.getName() != selectedDialect) - return; os << "# '" << dialect.getName() << "' Dialect\n\n"; emitIfNotEmpty(dialect.getSummary(), os); emitIfNotEmpty(dialect.getDescription(), os); @@ -351,7 +348,7 @@ } } -static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { +static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { std::vector opDefs = getRequestedOpDefinitions(recordKeeper); std::vector attrDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("DialectAttr"); @@ -362,7 +359,8 @@ std::vector attrDefDefs = recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef"); - std::set dialectsWithDocs; + llvm::SetVector, std::set> + dialectsWithDocs; llvm::StringMap> dialectAttrs; llvm::StringMap> dialectAttrDefs; @@ -399,13 +397,17 @@ dialectsWithDocs.insert(type.getDialect()); } + Optional dialect = + findDialectToGenerate(dialectsWithDocs.getArrayRef()); + if (!dialect) + return true; + os << "\n"; - for (const Dialect &dialect : dialectsWithDocs) { - StringRef dialectName = dialect.getName(); - emitDialectDoc(dialect, dialectAttrs[dialectName], - dialectAttrDefs[dialectName], dialectOps[dialectName], - dialectTypes[dialectName], dialectTypeDefs[dialectName], os); - } + StringRef dialectName = dialect->getName(); + emitDialectDoc(*dialect, dialectAttrs[dialectName], + dialectAttrDefs[dialectName], dialectOps[dialectName], + dialectTypes[dialectName], dialectTypeDefs[dialectName], os); + return false; } //===----------------------------------------------------------------------===// @@ -437,6 +439,5 @@ static mlir::GenRegistration genRegister("gen-dialect-doc", "Generate dialect documentation", [](const RecordKeeper &records, raw_ostream &os) { - emitDialectDoc(records, os); - return false; + return emitDialectDoc(records, os); });