diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -13,13 +13,22 @@
 #ifndef MLIR_TABLEGEN_CODEGENHELPERS_H
 #define MLIR_TABLEGEN_CODEGENHELPERS_H
 
+#include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 
+namespace llvm {
+class Record;
+class RecordKeeper;
+} // namespace llvm
+
 namespace mlir {
 namespace tblgen {
 
+class Constraint;
+
 // Simple RAII helper for defining ifdef-undef-endif scopes.
 class IfDefScope {
 public:
@@ -62,6 +71,91 @@
   SmallVector<StringRef, 2> namespaces;
 };
 
+/// This class deduplicates shared operation verification code by emitting
+/// static functions alongside the op definitions. These methods are local to
+/// the definition file, and are invoked within the operation verify methods.
+/// An example is shown below:
+///
+/// static LogicalResult localVerify(...)
+///
+/// LogicalResult OpA::verify(...) {
+///  if (failed(localVerify(...)))
+///    return failure();
+///  ...
+/// }
+///
+/// LogicalResult OpB::verify(...) {
+///  if (failed(localVerify(...)))
+///    return failure();
+///  ...
+/// }
+///
+class StaticVerifierFunctionEmitter {
+public:
+  /// Create an emitter that writes to `os`. The name of the input file of
+  /// `records` is used to make the generated function names unique.
+  StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
+                                raw_ostream &os);
+
+  /// Emit the static verifier functions for the type constraints used by the
+  /// operands and results of `opDefs`. The `signatureFormat` describes the
+  /// required arguments and it must contain a `{0}` placeholder for the
+  /// function name.
+  /// Example,
+  ///   const char *typeVerifierSignature =
+  ///     "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type"
+  ///     " type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
+  ///
+  /// `errorHandlerFormat` is the expression returned when verification fails.
+  /// It may contain a `{0}` placeholder, which is replaced with the summary of
+  /// the failing constraint to make the error message more informative.
+  /// Example,
+  ///   const char *typeVerifierErrorHandler =
+  ///     "op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << "
+  ///     "\" must be {0}, but got \" << type";
+  ///
+  /// `typeArgName` is used to identify the argument that needs to check its
+  /// type. The constraint template will replace `$_self` with it.
+  ///
+  /// When `emitDecl` is true nothing is printed; only the mapping from each
+  /// constraint to its function name is recorded for `getTypeConstraintFn`.
+  void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
+                        StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
+                        bool emitDecl);
+
+  /// Get the name of the local function used for the given type constraint.
+  /// These functions are used for operand and result constraints and have the
+  /// form:
+  ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
+  ///                 unsigned valueGroupStartIndex);
+  /// The constraint must have been registered by `emitFunctionsFor`. The
+  /// returned reference may be invalidated by later calls to it.
+  StringRef getTypeConstraintFn(const Constraint &constraint) const;
+
+private:
+  /// Returns a unique name to use when generating local methods.
+  static std::string getUniqueName(const llvm::RecordKeeper &records);
+
+  /// Emit local methods for the type constraints used within the provided op
+  /// definitions.
+  void emitTypeConstraintMethods(StringRef signatureFormat,
+                                 StringRef errorHandlerFormat,
+                                 StringRef typeArgName,
+                                 ArrayRef<llvm::Record *> opDefs,
+                                 bool emitDecl);
+
+  /// The stream that all generated code is written to.
+  raw_indented_ostream os;
+
+  /// A unique label for the file currently being generated. This is used to
+  /// ensure that the local functions have a unique name.
+  std::string uniqueOutputLabel;
+
+  /// Maps each type constraint (by opaque pointer) to the name of the local
+  /// function implementing it, used for operand and result verification.
+  llvm::DenseMap<const void *, std::string> localTypeConstraints;
+};
+
 } // namespace tblgen
 } // namespace mlir
 
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -6,6 +6,7 @@
 
 add_tablegen(mlir-tblgen MLIR
   AttrOrTypeDefGen.cpp
+  CodeGenHelpers.cpp
   DialectGen.cpp
   DirectiveCommonGen.cpp
   EnumsGen.cpp
diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
@@ -0,0 +1,139 @@
+//===- CodeGenHelpers.cpp - MLIR TableGen code generation helpers ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines common utilities for generating C++ code from TableGen
+// records, such as the deduplicated static op verifier functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Path.h"
+#include "llvm/TableGen/Record.h"
+
+#include <map>
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::tblgen;
+
+StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
+    const llvm::RecordKeeper &records, raw_ostream &os)
+    : os(os), uniqueOutputLabel(getUniqueName(records)) {}
+
+void StaticVerifierFunctionEmitter::emitFunctionsFor(
+    StringRef signatureFormat, StringRef errorHandlerFormat,
+    StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
+  // Nothing to do; this also guards the `opDefs[0]` access below.
+  if (opDefs.empty())
+    return;
+
+  // Definitions are emitted inside the C++ namespace of the first op. Nothing
+  // is printed for declarations, so no namespace is opened for them.
+  llvm::Optional<NamespaceEmitter> namespaceEmitter;
+  if (!emitDecl)
+    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
+
+  emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
+                            opDefs, emitDecl);
+}
+
+StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
+    const Constraint &constraint) const {
+  auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
+  assert(it != localTypeConstraints.end() && "expected valid constraint fn");
+  return it->second;
+}
+
+std::string StaticVerifierFunctionEmitter::getUniqueName(
+    const llvm::RecordKeeper &records) {
+  // Use the input file name when generating a unique name.
+  std::string inputFilename = records.getInputFilename();
+
+  // Drop all but the base filename.
+  StringRef nameRef = llvm::sys::path::filename(inputFilename);
+  nameRef.consume_back(".td");
+
+  // Sanitize any invalid characters.
+  std::string uniqueName;
+  for (char c : nameRef) {
+    if (llvm::isAlnum(c) || c == '_')
+      uniqueName.push_back(c);
+    else
+      uniqueName.append(llvm::utohexstr(static_cast<unsigned char>(c)));
+  }
+  return uniqueName;
+}
+
+void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
+    StringRef signatureFormat, StringRef errorHandlerFormat,
+    StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
+  // Collect a set of all of the used type constraints within the operation
+  // definitions.
+  llvm::SetVector<const void *> typeConstraints;
+  for (Record *def : opDefs) {
+    Operator op(*def);
+    for (NamedTypeConstraint &operand : op.getOperands())
+      if (operand.hasPredicate())
+        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
+    for (NamedTypeConstraint &result : op.getResults())
+      if (result.hasPredicate())
+        typeConstraints.insert(result.constraint.getAsOpaquePointer());
+  }
+
+  // formatv() reads its format as a NUL-terminated C string, which
+  // StringRef::data() does not guarantee, so work on owned copies.
+  std::string signatureFmt = signatureFormat.str();
+  std::string errorHandlerFmt = errorHandlerFormat.str();
+
+  // Different constraints may share one function when they have the same
+  // condition and summary, e.g. ConstraintA and Variadic<ConstraintA>, as
+  // Variadic<> doesn't introduce a new predicate. The summary is part of the
+  // key because it is embedded in the error message; sharing a function
+  // across different summaries would report the wrong message on failure.
+  // Maps each (condition, summary) pair to the generated function's name.
+  std::map<std::pair<std::string, std::string>, std::string> sharedFns;
+  FmtContext fctx;
+  for (auto it : llvm::enumerate(typeConstraints)) {
+    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
+    auto insertion = sharedFns.emplace(
+        std::make_pair(constraint.getConditionTemplate(),
+                       std::string(constraint.getSummary())),
+        std::string());
+    std::string &name = insertion.first->second;
+    if (!insertion.second) {
+      // An identical function has already been generated; reuse it.
+      localTypeConstraints.try_emplace(it.value(), name);
+      continue;
+    }
+
+    // Generate an obscure and unique name for this type constraint.
+    name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
+            Twine(it.index()))
+               .str();
+    localTypeConstraints.try_emplace(it.value(), name);
+
+    // Only generate the methods if we are generating definitions.
+    if (emitDecl)
+      continue;
+
+    os << formatv(signatureFmt.c_str(), name) << " {\n";
+    os.indent() << "if (!("
+                << tgfmt(constraint.getConditionTemplate(),
+                         &fctx.withSelf(typeArgName))
+                << ")) {\n";
+    os.indent() << "return "
+                << formatv(errorHandlerFmt.c_str(), constraint.getSummary())
+                << ";\n";
+    os.unindent() << "}\nreturn ::mlir::success();\n";
+    os.unindent() << "}\n\n";
+  }
+}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -24,7 +24,6 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
-#include "llvm/Support/Path.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -101,6 +100,14 @@
            std::next({0}, valueRange.first + valueRange.second)};
 )";
 
+static const char *const typeVerifierSignature =
+    "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type "
+    "type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
+
+static const char *const typeVerifierErrorHandler =
+    "op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must "
+    "be {0}, but got \" << type";
+
 static const char *const opCommentHeader = R"(
 //===----------------------------------------------------------------------===//
 // {0} {1}
@@ -108,175 +115,6 @@
 
 )";
 
-//===----------------------------------------------------------------------===//
-// StaticVerifierFunctionEmitter
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class deduplicates shared operation verification code by emitting
-/// static functions alongside the op definitions. These methods are local to
-/// the definition file, and are invoked within the operation verify methods.
-/// An example is shown below:
-///
-/// static LogicalResult localVerify(...)
-///
-/// LogicalResult OpA::verify(...) {
-///  if (failed(localVerify(...)))
-///    return failure();
-///  ...
-/// }
-///
-/// LogicalResult OpB::verify(...) {
-///  if (failed(localVerify(...)))
-///    return failure();
-///  ...
-/// }
-///
-class StaticVerifierFunctionEmitter {
-public:
-  StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
-                                ArrayRef<llvm::Record *> opDefs,
-                                raw_ostream &os, bool emitDecl);
-
-  /// Get the name of the local function used for the given type constraint.
-  /// These functions are used for operand and result constraints and have the
-  /// form:
-  ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
-  ///                 unsigned valueGroupStartIndex);
-  StringRef getTypeConstraintFn(const Constraint &constraint) const {
-    auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
-    assert(it != localTypeConstraints.end() && "expected valid constraint fn");
-    return it->second;
-  }
-
-private:
-  /// Returns a unique name to use when generating local methods.
-  static std::string getUniqueName(const llvm::RecordKeeper &records);
-
-  /// Emit local methods for the type constraints used within the provided op
-  /// definitions.
-  void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
-                                 raw_ostream &os, bool emitDecl);
-
-  /// A unique label for the file currently being generated. This is used to
-  /// ensure that the local functions have a unique name.
-  std::string uniqueOutputLabel;
-
-  /// A set of functions implementing type constraints, used for operand and
-  /// result verification.
-  llvm::DenseMap<const void *, std::string> localTypeConstraints;
-};
-} // namespace
-
-StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
-    const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
-    raw_ostream &os, bool emitDecl)
-    : uniqueOutputLabel(getUniqueName(records)) {
-  llvm::Optional<NamespaceEmitter> namespaceEmitter;
-  if (!emitDecl) {
-    os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
-    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
-  }
-
-  emitTypeConstraintMethods(opDefs, os, emitDecl);
-}
-
-std::string StaticVerifierFunctionEmitter::getUniqueName(
-    const llvm::RecordKeeper &records) {
-  // Use the input file name when generating a unique name.
-  std::string inputFilename = records.getInputFilename();
-
-  // Drop all but the base filename.
-  StringRef nameRef = llvm::sys::path::filename(inputFilename);
-  nameRef.consume_back(".td");
-
-  // Sanitize any invalid characters.
-  std::string uniqueName;
-  for (char c : nameRef) {
-    if (llvm::isAlnum(c) || c == '_')
-      uniqueName.push_back(c);
-    else
-      uniqueName.append(llvm::utohexstr((unsigned char)c));
-  }
-  return uniqueName;
-}
-
-void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
-    ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
-  // Collect a set of all of the used type constraints within the operation
-  // definitions.
-  llvm::SetVector<const void *> typeConstraints;
-  for (Record *def : opDefs) {
-    Operator op(*def);
-    for (NamedTypeConstraint &operand : op.getOperands())
-      if (operand.hasPredicate())
-        typeConstraints.insert(operand.constraint.getAsOpaquePointer());
-    for (NamedTypeConstraint &result : op.getResults())
-      if (result.hasPredicate())
-        typeConstraints.insert(result.constraint.getAsOpaquePointer());
-  }
-
-  // Record the mapping from predicate to constraint. If two constraints has the
-  // same predicate and constraint summary, they can share the same verification
-  // function.
-  llvm::DenseMap<Pred, const void *> predToConstraint;
-  FmtContext fctx;
-  for (auto it : llvm::enumerate(typeConstraints)) {
-    std::string name;
-    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
-    Pred pred = constraint.getPredicate();
-    auto iter = predToConstraint.find(pred);
-    if (iter != predToConstraint.end()) {
-      do {
-        Constraint built = Constraint::getFromOpaquePointer(iter->second);
-        // We may have the different constraints but have the same predicate,
-        // for example, ConstraintA and Variadic<ConstraintA>, note that
-        // Variadic<> doesn't introduce new predicate. In this case, we can
-        // share the same predicate function if they also have consistent
-        // summary, otherwise we may report the wrong message while verification
-        // fails.
-        if (constraint.getSummary() == built.getSummary()) {
-          name = getTypeConstraintFn(built).str();
-          break;
-        }
-        ++iter;
-      } while (iter != predToConstraint.end() && iter->first == pred);
-    }
-
-    if (!name.empty()) {
-      localTypeConstraints.try_emplace(it.value(), name);
-      continue;
-    }
-
-    // Generate an obscure and unique name for this type constraint.
-    name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
-            Twine(it.index()))
-               .str();
-    predToConstraint.insert(
-        std::make_pair(constraint.getPredicate(), it.value()));
-    localTypeConstraints.try_emplace(it.value(), name);
-
-    // Only generate the methods if we are generating definitions.
-    if (emitDecl)
-      continue;
-
-    os << "static ::mlir::LogicalResult " << name
-       << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
-          "valueKind, unsigned valueGroupStartIndex) {\n";
-
-    os << "  if (!("
-       << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
-       << ")) {\n"
-       << formatv(
-              "    return op->emitOpError(valueKind) << \" #\" << "
-              "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
-              constraint.getSummary())
-       << "  }\n"
-       << "  return ::mlir::success();\n"
-       << "}\n\n";
-  }
-}
-
 //===----------------------------------------------------------------------===//
 // Utility structs and functions
 //===----------------------------------------------------------------------===//
@@ -2560,8 +2398,14 @@
     return;
 
   // Generate all of the locally instantiated methods first.
-  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
-                                                      emitDecl);
+  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
+  // Only definitions contain local functions, so only they get the banner.
+  if (!emitDecl)
+    os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
+  staticVerifierEmitter.emitFunctionsFor(
+      typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
+      defs, emitDecl);
+
   for (auto *def : defs) {
     Operator op(*def);
     if (emitDecl) {