diff --git a/mlir/include/mlir/TableGen/ODSDialectHook.h b/mlir/include/mlir/TableGen/ODSDialectHook.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/TableGen/ODSDialectHook.h
@@ -0,0 +1,44 @@
+//===- ODSDialectHook.h - Dialect customization hooks into ODS --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ODS customization hooks for dialects to programmatically
+// emit dialect specific contents in ODS C++ code emission.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ODSDIALECTHOOK_H_
+#define MLIR_TABLEGEN_ODSDIALECTHOOK_H_
+
+#include <functional>
+
+namespace llvm {
+class StringRef;
+}
+
+namespace mlir {
+namespace tblgen {
+class Operator;
+class OpClass;
+
+// The emission function for dialect specific content. It takes in an Operator
+// and updates the OpClass accordingly. ODS invokes it for every op of the
+// registered dialect, after it has generated its own methods for that op.
+using DialectEmitFunction =
+    std::function<void(const Operator &srcOp, OpClass &emitClass)>;
+
+// ODSDialectHookRegistration provides a global initializer that registers a
+// dialect specific content emission function. Registering more than one hook
+// for the same dialect name trips an assertion.
+struct ODSDialectHookRegistration {
+  ODSDialectHookRegistration(llvm::StringRef dialectName,
+                             DialectEmitFunction emitFn);
+};
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_ODSDIALECTHOOK_H_
diff --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/TableGen/OpClass.h
@@ -0,0 +1,167 @@
+//===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines several classes for Op C++ code emission. They are only
+// expected to be used by MLIR TableGen backends.
+//
+// We emit the op declaration and definition into separate files: *Ops.h.inc
+// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
+// the latter for dialect *Ops.cpp. This split provides a cleaner interface.
+//
+// In order to do this split, we need to track method signature and
+// implementation logic separately. Signature information is used for both
+// declaration and definition, while implementation logic is only for
+// definition. So we have the following classes for C++ code emission.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_OPCLASS_H_
+#define MLIR_TABLEGEN_OPCLASS_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <string>
+
+namespace mlir {
+namespace tblgen {
+class FmtObjectBase;
+
+// Class for holding the signature of an op's method for C++ code emission
+class OpMethodSignature {
+public:
+  OpMethodSignature(StringRef retType, StringRef name, StringRef params);
+
+  // Writes the signature as a method declaration to the given `os`.
+  void writeDeclTo(raw_ostream &os) const;
+  // Writes the signature as the start of a method definition to the given `os`.
+  // `namePrefix` is the prefix to be prepended to the method name (typically
+  // namespaces for qualifying the method definition).
+  void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+
+private:
+  // Returns true if the given C++ `type` ends with '&' or '*', or is empty.
+  static bool elideSpaceAfterType(StringRef type);
+
+  std::string returnType;
+  std::string methodName;
+  std::string parameters;
+};
+
+// Class for holding the body of an op's method for C++ code emission
+class OpMethodBody {
+public:
+  explicit OpMethodBody(bool declOnly);
+
+  OpMethodBody &operator<<(Twine content);
+  OpMethodBody &operator<<(int content);
+  OpMethodBody &operator<<(const FmtObjectBase &content);
+
+  void writeTo(raw_ostream &os) const;
+
+private:
+  // Whether appended content is recorded; false for declaration-only methods.
+  bool isEffective;
+  std::string body;
+};
+
+// Class for holding an op's method for C++ code emission
+class OpMethod {
+public:
+  // Properties (qualifiers) of class methods. Each value is a distinct bit so
+  // that a property can be tested with a bitwise AND.
+  enum Property {
+    MP_None = 0x0,
+    MP_Static = 0x1,      // Static method
+    MP_Constructor = 0x2, // Constructor
+    MP_Private = 0x4,     // Private method
+  };
+
+  OpMethod(StringRef retType, StringRef name, StringRef params,
+           Property property, bool declOnly);
+
+  OpMethodBody &body();
+
+  // Returns true if this is a static method.
+  bool isStatic() const;
+
+  // Returns true if this is a private method.
+  bool isPrivate() const;
+
+  // Writes the method as a declaration to the given `os`.
+  void writeDeclTo(raw_ostream &os) const;
+  // Writes the method as a definition to the given `os`. `namePrefix` is the
+  // prefix to be prepended to the method name (typically namespaces for
+  // qualifying the method definition).
+  void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
+
+private:
+  Property properties;
+  // Whether this method only contains a declaration.
+  bool isDeclOnly;
+  OpMethodSignature methodSignature;
+  OpMethodBody methodBody;
+};
+
+// A class used to emit C++ classes from Tablegen. Contains a list of public
+// methods and a list of private fields to be emitted.
+class Class {
+public:
+  explicit Class(StringRef name);
+
+  // Creates a new method in this class.
+  OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "",
+                      OpMethod::Property = OpMethod::MP_None,
+                      bool declOnly = false);
+
+  OpMethod &newConstructor(StringRef params = "", bool declOnly = false);
+
+  // Creates a new field in this class.
+  void newField(StringRef type, StringRef name, StringRef defaultValue = "");
+
+  // Writes this op's class as a declaration to the given `os`.
+  void writeDeclTo(raw_ostream &os) const;
+  // Writes the method definitions in this op's class to the given `os`.
+  void writeDefTo(raw_ostream &os) const;
+
+  // Returns the C++ class name of the op.
+  StringRef getClassName() const { return className; }
+
+protected:
+  std::string className;
+  SmallVector<OpMethod, 8> methods;
+  SmallVector<std::string, 4> fields;
+};
+
+// Class for holding an op for C++ code emission
+class OpClass : public Class {
+public:
+  explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
+
+  // Sets whether this OpClass should generate the using directive for its
+  // associated operand adaptor class.
+  void setHasOperandAdaptorClass(bool has);
+
+  // Adds an op trait.
+  void addTrait(Twine trait);
+
+  // Writes this op's class as a declaration to the given `os`. Redefines
+  // Class::writeDeclTo to also emit traits and extra class declarations.
+  void writeDeclTo(raw_ostream &os) const;
+
+private:
+  std::string extraClassDeclaration; // Owned; the ctor arg may be a temporary.
+  SmallVector<std::string, 4> traits;
+  bool hasOperandAdaptor;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_OPCLASS_H_
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -46,6 +46,9 @@
   // Returns this op's dialect name.
   StringRef getDialectName() const;
 
+  // Returns the dialect of the op.
+  const Dialect &getDialect() const { return dialect; }
+
   // Returns the operation name. The name will follow the "<dialect>.<op-name>"
   // format if its dialect name is not empty.
   std::string getOperationName() const;
@@ -156,14 +159,8 @@
   StringRef getExtraClassDeclaration() const;
 
   // Returns the Tablegen definition this operator was constructed from.
-  // TODO(antiagainst,zinenko): do not expose the TableGen record, this is a
-  // temporary solution to OpEmitter requiring a Record because Operator does
-  // not provide enough methods.
   const llvm::Record &getDef() const;
 
-  // Returns the dialect of the op.
-  const Dialect &getDialect() const { return dialect; }
-
   // Prints the contents in this operator to the given `os`. This is used for
   // debugging purposes.
   void print(llvm::raw_ostream &os) const;
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -5,6 +5,7 @@
   Dialect.cpp
   Format.cpp
   Operator.cpp
+  OpClass.cpp
   OpInterfaces.cpp
   OpTrait.cpp
   Pattern.cpp
diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/TableGen/OpClass.cpp
@@ -0,0 +1,235 @@
+//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/OpClass.h"
+
+#include "mlir/TableGen/Format.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// OpMethodSignature definitions
+//===----------------------------------------------------------------------===//
+
+tblgen::OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name,
+                                             StringRef params)
+    : returnType(retType), methodName(name), parameters(params) {}
+
+void tblgen::OpMethodSignature::writeDeclTo(raw_ostream &os) const {
+  os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName
+     << "(" << parameters << ")";
+}
+
+void tblgen::OpMethodSignature::writeDefTo(raw_ostream &os,
+                                           StringRef namePrefix) const {
+  // We need to remove the default values for parameters in method definition.
+  // TODO(antiagainst): We are using '=' and ',' as delimiters for parameter
+  // initializers. This is incorrect for initializer list with more than one
+  // element. Change to a more robust approach.
+  auto removeParamDefaultValue = [](StringRef params) {
+    std::string result;
+    std::pair<StringRef, StringRef> parts;
+    while (!params.empty()) {
+      parts = params.split("=");
+      result.append(result.empty() ? "" : ", ");
+      result.append(parts.first);
+      params = parts.second.split(",").second;
+    }
+    return result;
+  };
+
+  os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << namePrefix
+     << (namePrefix.empty() ? "" : "::") << methodName << "("
+     << removeParamDefaultValue(parameters) << ")";
+}
+
+bool tblgen::OpMethodSignature::elideSpaceAfterType(StringRef type) {
+  return type.empty() || type.endswith("&") || type.endswith("*");
+}
+
+//===----------------------------------------------------------------------===//
+// OpMethodBody definitions
+//===----------------------------------------------------------------------===//
+
+tblgen::OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
+
+tblgen::OpMethodBody &tblgen::OpMethodBody::operator<<(Twine content) {
+  if (isEffective)
+    body.append(content.str());
+  return *this;
+}
+
+tblgen::OpMethodBody &tblgen::OpMethodBody::operator<<(int content) {
+  if (isEffective)
+    body.append(std::to_string(content));
+  return *this;
+}
+
+tblgen::OpMethodBody &
+tblgen::OpMethodBody::operator<<(const FmtObjectBase &content) {
+  if (isEffective)
+    body.append(content.str());
+  return *this;
+}
+
+void tblgen::OpMethodBody::writeTo(raw_ostream &os) const {
+  auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
+  os << bodyRef;
+  if (bodyRef.empty() || bodyRef.back() != '\n')
+    os << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// OpMethod definitions
+//===----------------------------------------------------------------------===//
+
+tblgen::OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params,
+                           OpMethod::Property property, bool declOnly)
+    : properties(property), isDeclOnly(declOnly),
+      methodSignature(retType, name, params), methodBody(declOnly) {}
+
+tblgen::OpMethodBody &tblgen::OpMethod::body() { return methodBody; }
+
+bool tblgen::OpMethod::isStatic() const { return properties & MP_Static; }
+
+bool tblgen::OpMethod::isPrivate() const { return properties & MP_Private; }
+
+void tblgen::OpMethod::writeDeclTo(raw_ostream &os) const {
+  os.indent(2);
+  if (isStatic())
+    os << "static ";
+  methodSignature.writeDeclTo(os);
+  os << ";";
+}
+
+void tblgen::OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
+  if (isDeclOnly)
+    return;
+
+  methodSignature.writeDefTo(os, namePrefix);
+  os << " {\n";
+  methodBody.writeTo(os);
+  os << "}";
+}
+
+//===----------------------------------------------------------------------===//
+// Class definitions
+//===----------------------------------------------------------------------===//
+
+tblgen::Class::Class(StringRef name) : className(name) {}
+
+tblgen::OpMethod &tblgen::Class::newMethod(StringRef retType, StringRef name,
+                                           StringRef params,
+                                           OpMethod::Property property,
+                                           bool declOnly) {
+  methods.emplace_back(retType, name, params, property, declOnly);
+  return methods.back();
+}
+
+tblgen::OpMethod &tblgen::Class::newConstructor(StringRef params,
+                                                bool declOnly) {
+  return newMethod("", getClassName(), params, OpMethod::MP_Constructor,
+                   declOnly);
+}
+
+void tblgen::Class::newField(StringRef type, StringRef name,
+                             StringRef defaultValue) {
+  std::string varName = formatv("{0} {1}", type, name).str();
+  std::string field = defaultValue.empty()
+                          ? varName
+                          : formatv("{0} = {1}", varName, defaultValue).str();
+  fields.push_back(std::move(field));
+}
+
+void tblgen::Class::writeDeclTo(raw_ostream &os) const {
+  bool hasPrivateMethod = false;
+  os << "class " << className << " {\n";
+  os << "public:\n";
+  for (const auto &method : methods) {
+    if (!method.isPrivate()) {
+      method.writeDeclTo(os);
+      os << '\n';
+    } else {
+      hasPrivateMethod = true;
+    }
+  }
+  if (hasPrivateMethod || !fields.empty())
+    os << "\nprivate:\n";
+  if (hasPrivateMethod) {
+    for (const auto &method : methods) {
+      if (method.isPrivate()) {
+        method.writeDeclTo(os);
+        os << '\n';
+      }
+    }
+    os << '\n';
+  }
+  for (const auto &field : fields)
+    os.indent(2) << field << ";\n";
+  os << "};\n";
+}
+
+void tblgen::Class::writeDefTo(raw_ostream &os) const {
+  for (const auto &method : methods) {
+    method.writeDefTo(os, className);
+    os << "\n\n";
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// OpClass definitions
+//===----------------------------------------------------------------------===//
+
+tblgen::OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
+    : Class(name), extraClassDeclaration(extraClassDeclaration),
+      hasOperandAdaptor(true) {}
+
+void tblgen::OpClass::setHasOperandAdaptorClass(bool has) {
+  hasOperandAdaptor = has;
+}
+
+// Adds the given trait to this op.
+void tblgen::OpClass::addTrait(Twine trait) { traits.push_back(trait.str()); }
+
+void tblgen::OpClass::writeDeclTo(raw_ostream &os) const {
+  os << "class " << className << " : public Op<" << className;
+  for (const auto &trait : traits)
+    os << ", " << trait;
+  os << "> {\npublic:\n";
+  os << "  using Op::Op;\n";
+  if (hasOperandAdaptor)
+    os << "  using OperandAdaptor = " << className << "OperandAdaptor;\n";
+
+  bool hasPrivateMethod = false;
+  for (const auto &method : methods) {
+    if (!method.isPrivate()) {
+      method.writeDeclTo(os);
+      os << "\n";
+    } else {
+      hasPrivateMethod = true;
+    }
+  }
+
+  // TODO: Add line control markers to make errors easier to debug.
+  if (!extraClassDeclaration.empty())
+    os << extraClassDeclaration << "\n";
+
+  if (hasPrivateMethod) {
+    os << "\nprivate:\n";
+    for (const auto &method : methods) {
+      if (method.isPrivate()) {
+        method.writeDeclTo(os);
+        os << "\n";
+      }
+    }
+  }
+
+  os << "};\n";
+}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -14,10 +14,13 @@
 #include "mlir/Support/STLExtras.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/ODSDialectHook.h"
+#include "mlir/TableGen/OpClass.h"
 #include "mlir/TableGen/OpInterfaces.h"
 #include "mlir/TableGen/OpTrait.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -25,10 +28,35 @@
 
 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
 
-using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
 
+using llvm::CodeInit;
+using llvm::DefInit;
+using llvm::formatv;
+using llvm::Init;
+using llvm::ListInit;
+using llvm::Record;
+using llvm::RecordKeeper;
+using llvm::StringInit;
+
+//===----------------------------------------------------------------------===//
+// Dialect hook registration
+//===----------------------------------------------------------------------===//
+
+static llvm::ManagedStatic<llvm::StringMap<DialectEmitFunction>> dialectHooks;
+
+ODSDialectHookRegistration::ODSDialectHookRegistration(
+    StringRef dialectName, DialectEmitFunction emitFn) {
+  auto result = dialectHooks->try_emplace(dialectName, std::move(emitFn));
+  assert(result.second && "Multiple ODS hooks for the same dialect!");
+  (void)result;
+}
+
+//===----------------------------------------------------------------------===//
+// Static string definitions
+//===----------------------------------------------------------------------===//
+
 static const char *const tblgenNamePrefix = "tblgen_";
 static const char *const generatedArgName = "tblgen_arg";
 static const char *const builderOpState = "tblgen_state";
@@ -114,6 +142,10 @@
          !attr.getConstBuilderTemplate().empty();
 }
 
+//===----------------------------------------------------------------------===//
+// Op emitter
+//===----------------------------------------------------------------------===//
+
 namespace {
 // Simple RAII helper for defining ifdef-undef-endif scopes.
 class IfDefScope {
@@ -131,346 +163,6 @@
 };
 } // end anonymous namespace
 
-//===----------------------------------------------------------------------===//
-// Classes for C++ code emission
-//===----------------------------------------------------------------------===//
-
-// We emit the op declaration and definition into separate files: *Ops.h.inc
-// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and
-// the latter for dialect *Ops.cpp. This way provides a cleaner interface.
-//
-// In order to do this split, we need to track method signature and
-// implementation logic separately. Signature information is used for both
-// declaration and definition, while implementation logic is only for
-// definition. So we have the following classes for C++ code emission.
-
-namespace {
-// Class for holding the signature of an op's method for C++ code emission
-class OpMethodSignature {
-public:
-  OpMethodSignature(StringRef retType, StringRef name, StringRef params);
-
-  // Writes the signature as a method declaration to the given `os`.
-  void writeDeclTo(raw_ostream &os) const;
-  // Writes the signature as the start of a method definition to the given `os`.
-  // `namePrefix` is the prefix to be prepended to the method name (typically
-  // namespaces for qualifying the method definition).
-  void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
-
-private:
-  // Returns true if the given C++ `type` ends with '&' or '*', or is empty.
-  static bool elideSpaceAfterType(StringRef type);
-
-  std::string returnType;
-  std::string methodName;
-  std::string parameters;
-};
-
-// Class for holding the body of an op's method for C++ code emission
-class OpMethodBody {
-public:
-  explicit OpMethodBody(bool declOnly);
-
-  OpMethodBody &operator<<(Twine content);
-  OpMethodBody &operator<<(int content);
-  OpMethodBody &operator<<(const FmtObjectBase &content);
-
-  void writeTo(raw_ostream &os) const;
-
-private:
-  // Whether this class should record method body.
-  bool isEffective;
-  std::string body;
-};
-
-// Class for holding an op's method for C++ code emission
-class OpMethod {
-public:
-  // Properties (qualifiers) of class methods. Bitfield is used here to help
-  // querying properties.
-  enum Property {
-    MP_None = 0x0,
-    MP_Static = 0x1,      // Static method
-    MP_Constructor = 0x2, // Constructor
-    MP_Private = 0x4,     // Private method
-  };
-
-  OpMethod(StringRef retType, StringRef name, StringRef params,
-           Property property, bool declOnly);
-
-  OpMethodBody &body();
-
-  // Returns true if this is a static method.
-  bool isStatic() const;
-
-  // Returns true if this is a private method.
-  bool isPrivate() const;
-
-  // Writes the method as a declaration to the given `os`.
-  void writeDeclTo(raw_ostream &os) const;
-  // Writes the method as a definition to the given `os`. `namePrefix` is the
-  // prefix to be prepended to the method name (typically namespaces for
-  // qualifying the method definition).
-  void writeDefTo(raw_ostream &os, StringRef namePrefix) const;
-
-private:
-  Property properties;
-  // Whether this method only contains a declaration.
-  bool isDeclOnly;
-  OpMethodSignature methodSignature;
-  OpMethodBody methodBody;
-};
-
-// A class used to emit C++ classes from Tablegen. Contains a list of public
-// methods and a list of private fields to be emitted.
-class Class {
-public:
-  explicit Class(StringRef name);
-
-  // Creates a new method in this class.
-  OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "",
-                      OpMethod::Property = OpMethod::MP_None,
-                      bool declOnly = false);
-
-  OpMethod &newConstructor(StringRef params = "", bool declOnly = false);
-
-  // Creates a new field in this class.
-  void newField(StringRef type, StringRef name, StringRef defaultValue = "");
-
-  // Writes this op's class as a declaration to the given `os`.
-  void writeDeclTo(raw_ostream &os) const;
-  // Writes the method definitions in this op's class to the given `os`.
-  void writeDefTo(raw_ostream &os) const;
-
-  // Returns the C++ class name of the op.
-  StringRef getClassName() const { return className; }
-
-protected:
-  std::string className;
-  SmallVector<OpMethod, 8> methods;
-  SmallVector<std::string, 4> fields;
-};
-
-// Class for holding an op for C++ code emission
-class OpClass : public Class {
-public:
-  explicit OpClass(StringRef name, StringRef extraClassDeclaration = "");
-
-  // Sets whether this OpClass should generate the using directive for its
-  // associate operand adaptor class.
-  void setHasOperandAdaptorClass(bool has);
-
-  // Adds an op trait.
-  void addTrait(Twine trait);
-
-  // Writes this op's class as a declaration to the given `os`. Redefines
-  // Class::writeDeclTo to also emit traits and extra class declarations.
-  void writeDeclTo(raw_ostream &os) const;
-
-private:
-  StringRef extraClassDeclaration;
-  SmallVector<std::string, 4> traits;
-  bool hasOperandAdaptor;
-};
-} // end anonymous namespace
-
-OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name,
-                                     StringRef params)
-    : returnType(retType), methodName(name), parameters(params) {}
-
-void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
-  os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName
-     << "(" << parameters << ")";
-}
-
-void OpMethodSignature::writeDefTo(raw_ostream &os,
-                                   StringRef namePrefix) const {
-  // We need to remove the default values for parameters in method definition.
-  // TODO(antiagainst): We are using '=' and ',' as delimiters for parameter
-  // initializers. This is incorrect for initializer list with more than one
-  // element. Change to a more robust approach.
-  auto removeParamDefaultValue = [](StringRef params) {
-    std::string result;
-    std::pair<StringRef, StringRef> parts;
-    while (!params.empty()) {
-      parts = params.split("=");
-      result.append(result.empty() ? "" : ", ");
-      result.append(parts.first);
-      params = parts.second.split(",").second;
-    }
-    return result;
-  };
-
-  os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << namePrefix
-     << (namePrefix.empty() ? "" : "::") << methodName << "("
-     << removeParamDefaultValue(parameters) << ")";
-}
-
-bool OpMethodSignature::elideSpaceAfterType(StringRef type) {
-  return type.empty() || type.endswith("&") || type.endswith("*");
-}
-
-OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
-
-OpMethodBody &OpMethodBody::operator<<(Twine content) {
-  if (isEffective)
-    body.append(content.str());
-  return *this;
-}
-
-OpMethodBody &OpMethodBody::operator<<(int content) {
-  if (isEffective)
-    body.append(std::to_string(content));
-  return *this;
-}
-
-OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
-  if (isEffective)
-    body.append(content.str());
-  return *this;
-}
-
-void OpMethodBody::writeTo(raw_ostream &os) const {
-  auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
-  os << bodyRef;
-  if (bodyRef.empty() || bodyRef.back() != '\n')
-    os << "\n";
-}
-
-OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params,
-                   OpMethod::Property property, bool declOnly)
-    : properties(property), isDeclOnly(declOnly),
-      methodSignature(retType, name, params), methodBody(declOnly) {}
-
-OpMethodBody &OpMethod::body() { return methodBody; }
-
-bool OpMethod::isStatic() const { return properties & MP_Static; }
-
-bool OpMethod::isPrivate() const { return properties & MP_Private; }
-
-void OpMethod::writeDeclTo(raw_ostream &os) const {
-  os.indent(2);
-  if (isStatic())
-    os << "static ";
-  methodSignature.writeDeclTo(os);
-  os << ";";
-}
-
-void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
-  if (isDeclOnly)
-    return;
-
-  methodSignature.writeDefTo(os, namePrefix);
-  os << " {\n";
-  methodBody.writeTo(os);
-  os << "}";
-}
-
-Class::Class(StringRef name) : className(name) {}
-
-OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params,
-                           OpMethod::Property property, bool declOnly) {
-  methods.emplace_back(retType, name, params, property, declOnly);
-  return methods.back();
-}
-
-OpMethod &Class::newConstructor(StringRef params, bool declOnly) {
-  return newMethod("", getClassName(), params, OpMethod::MP_Constructor,
-                   declOnly);
-}
-
-void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
-  std::string varName = formatv("{0} {1}", type, name).str();
-  std::string field = defaultValue.empty()
-                          ? varName
-                          : formatv("{0} = {1}", varName, defaultValue).str();
-  fields.push_back(std::move(field));
-}
-
-void Class::writeDeclTo(raw_ostream &os) const {
-  bool hasPrivateMethod = false;
-  os << "class " << className << " {\n";
-  os << "public:\n";
-  for (const auto &method : methods) {
-    if (!method.isPrivate()) {
-      method.writeDeclTo(os);
-      os << '\n';
-    } else {
-      hasPrivateMethod = true;
-    }
-  }
-  os << '\n';
-  os << "private:\n";
-  if (hasPrivateMethod) {
-    for (const auto &method : methods) {
-      if (method.isPrivate()) {
-        method.writeDeclTo(os);
-        os << '\n';
-      }
-    }
-    os << '\n';
-  }
-  for (const auto &field : fields)
-    os.indent(2) << field << ";\n";
-  os << "};\n";
-}
-
-void Class::writeDefTo(raw_ostream &os) const {
-  for (const auto &method : methods) {
-    method.writeDefTo(os, className);
-    os << "\n\n";
-  }
-}
-
-OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
-    : Class(name), extraClassDeclaration(extraClassDeclaration),
-      hasOperandAdaptor(true) {}
-
-void OpClass::setHasOperandAdaptorClass(bool has) { hasOperandAdaptor = has; }
-
-// Adds the given trait to this op.
-void OpClass::addTrait(Twine trait) { traits.push_back(trait.str()); }
-
-void OpClass::writeDeclTo(raw_ostream &os) const {
-  os << "class " << className << " : public Op<" << className;
-  for (const auto &trait : traits)
-    os << ", " << trait;
-  os << "> {\npublic:\n";
-  os << "  using Op::Op;\n";
-  if (hasOperandAdaptor)
-    os << "  using OperandAdaptor = " << className << "OperandAdaptor;\n";
-
-  bool hasPrivateMethod = false;
-  for (const auto &method : methods) {
-    if (!method.isPrivate()) {
-      method.writeDeclTo(os);
-      os << "\n";
-    } else {
-      hasPrivateMethod = true;
-    }
-  }
-
-  // TODO: Add line control markers to make errors easier to debug.
-  if (!extraClassDeclaration.empty())
-    os << extraClassDeclaration << "\n";
-
-  if (hasPrivateMethod) {
-    os << "\nprivate:\n";
-    for (const auto &method : methods) {
-      if (method.isPrivate()) {
-        method.writeDeclTo(os);
-        os << "\n";
-      }
-    }
-  }
-
-  os << "};\n";
-}
-
-//===----------------------------------------------------------------------===//
-// Op emitter
-//===----------------------------------------------------------------------===//
-
 namespace {
 // Helper class to emit a record into the given output stream.
 class OpEmitter {
@@ -614,6 +306,7 @@
   verifyCtx.withOp("(*this->getOperation())");
 
   genTraits();
+
   // Generate C++ code for various op methods. The order here determines the
   // methods in the generated file.
   genOpAsmInterface();
@@ -629,6 +322,13 @@
   genCanonicalizerDecls();
   genFolderDecls();
   genOpInterfaceMethods();
+
+  // If a dialect hook is registered for this op's dialect, let it emit dialect
+  // specific content. This runs after all of the ODS-generated methods above
+  // have been added to the op class.
+  auto dialectHookIt = dialectHooks->find(op.getDialectName());
+  if (dialectHookIt != dialectHooks->end())
+    dialectHookIt->second(op, opClass);
 }
 
 void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {