// [TableGen] Add a tblgen::Builder wrapper for TableGen builder records.
//
// This patch adds mlir/TableGen/Builder.{h,cpp}. tblgen::Builder wraps a
// builder Record: it reads the record's `dagParams` into Builder::Parameter
// objects (C++ type, optional name, optional default value) and exposes the
// optional `body`. Operator populates these wrappers from its `builders` field
// and exposes them through Operator::getBuilders(). OpDefinitionsGen then emits
// the custom `build` methods from the wrappers instead of reading the raw
// TableGen records.
//
// The following checks move out of OpDefinitionsGen:
//   - "expected 'ins' in builders" and the default-value ordering check now
//     live in the Builder constructor.
//   - "default builders are skipped and no custom builders provided" now lives
//     in Operator.cpp.
// The `builder` name constant in OpDefinitionsGen is renamed to `odsBuilder`.
//
// NOTE(review): in this copy of the patch the line breaks are collapsed. Template
// argument lists (e.g. on `Optional`, `ArrayRef`, `SmallVector`, `dyn_cast`,
// `cast`, `isa`) also appear to be missing. It cannot be applied or compiled as
// shown; regenerate it from the source commit before use.
// NOTE(review): Builder::Parameter::getCppType/getDefaultValue `cast` every
// non-StringInit argument to a DefInit without checking. A dag argument of any
// other kind would presumably trip the cast assertion rather than produce a
// diagnostic. Consider validating argument kinds in the Builder constructor.
// NOTE(review): the Builder constructor reports the 'ins' error at
// def->getLoc() (the builder record). It reports the default-value error at
// the caller-supplied `loc` (Operator passes def.getLoc()). Confirm this split
// is intended.
diff --git a/mlir/include/mlir/TableGen/Builder.h b/mlir/include/mlir/TableGen/Builder.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/TableGen/Builder.h @@ -0,0 +1,85 @@ +//===- Builder.h - Builder classes ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Builder wrapper to simplify using TableGen Record for building +// operations/types/etc. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_BUILDER_H_ +#define MLIR_TABLEGEN_BUILDER_H_ + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +namespace llvm { +class Init; +class Record; +class SMLoc; +} // end namespace llvm + +namespace mlir { +namespace tblgen { + +/// Wrapper class with helper methods for accessing Builders defined in +/// TableGen. +class Builder { +public: + /// This class represents a single parameter to a builder method. + class Parameter { + public: + /// Return a string containing the C++ type of this parameter. + StringRef getCppType() const; + + /// Return an optional string containing the name of this parameter. If + /// None, no name was specified for this parameter by the user. + Optional getName() const { return name; } + + /// Return an optional string containing the default value to use for this + /// parameter. + Optional getDefaultValue() const; + + private: + Parameter(Optional name, const llvm::Init *def) + : name(name), def(def) {} + + /// The optional name of the parameter. + Optional name; + + /// The tablegen definition of the parameter. This is either a StringInit, + /// or a CArg DefInit. 
+ const llvm::Init *def; + + // Allow access to the constructor. + friend Builder; + }; + + /// Construct a builder from the given Record instance. + Builder(const llvm::Record *record, ArrayRef loc); + + /// Return a list of parameters used in this build method. + ArrayRef getParameters() const { return parameters; } + + /// Return an optional string containing the body of the builder. + Optional getBody() const; + +protected: + /// The TableGen definition of this builder. + const llvm::Record *def; + +private: + /// A collection of parameters to the builder. + SmallVector parameters; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_BUILDER_H_ diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -16,6 +16,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/TableGen/Argument.h" #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Builder.h" #include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Region.h" @@ -287,6 +288,9 @@ // Returns the OperandOrAttribute corresponding to the index. OperandOrAttribute getArgToOperandOrAttribute(int index) const; + // Returns the builders of this operation. + ArrayRef getBuilders() const { return builders; } + private: // Populates the vectors containing operands, attributes, results and traits. void populateOpStructure(); @@ -332,6 +336,9 @@ // Map from argument to attribute or operand number. SmallVector attrOrOperandMapping; + // The builders of this operator. + SmallVector builders; + // The number of native attributes stored in the leading positions of // `attributes`. 
int numNativeAttributes; diff --git a/mlir/lib/TableGen/Builder.cpp b/mlir/lib/TableGen/Builder.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/TableGen/Builder.cpp @@ -0,0 +1,74 @@ +//===- Builder.cpp - Builder definitions ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Builder.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// Builder::Parameter +//===----------------------------------------------------------------------===// + +/// Return a string containing the C++ type of this parameter. +StringRef Builder::Parameter::getCppType() const { + if (const auto *stringInit = dyn_cast(def)) + return stringInit->getValue(); + const llvm::Record *record = cast(def)->getDef(); + return record->getValueAsString("type"); +} + +/// Return an optional string containing the default value to use for this +/// parameter. +Optional Builder::Parameter::getDefaultValue() const { + if (isa(def)) + return llvm::None; + const llvm::Record *record = cast(def)->getDef(); + Optional value = record->getValueAsOptionalString("defaultValue"); + return value && !value->empty() ? value : llvm::None; +} + +//===----------------------------------------------------------------------===// +// Builder +//===----------------------------------------------------------------------===// + +Builder::Builder(const llvm::Record *record, ArrayRef loc) + : def(record) { + // Initialize the parameters of the builder. 
+ const llvm::DagInit *dag = def->getValueAsDag("dagParams"); + auto *defInit = dyn_cast(dag->getOperator()); + if (!defInit || !defInit->getDef()->getName().equals("ins")) + PrintFatalError(def->getLoc(), "expected 'ins' in builders"); + + bool seenDefaultValue = false; + for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) { + const llvm::StringInit *paramName = dag->getArgName(i); + const llvm::Init *paramValue = dag->getArg(i); + Parameter param(paramName ? paramName->getValue() : Optional(), + paramValue); + + // Similarly to C++, once an argument with a default value is detected, the + // following arguments must have default values as well. + if (param.getDefaultValue()) { + seenDefaultValue = true; + } else if (seenDefaultValue) { + PrintFatalError(loc, + "expected an argument with default value after other " + "arguments with default values"); + } + parameters.emplace_back(param); + } +} + +/// Return an optional string containing the body of the builder. +Optional Builder::getBody() const { + Optional body = def->getValueAsOptionalString("body"); + return body && !body->empty() ? body : llvm::None; +} diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -11,6 +11,7 @@ llvm_add_library(MLIRTableGen STATIC Argument.cpp Attribute.cpp + Builder.cpp Constraint.cpp Dialect.cpp Format.cpp diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -521,6 +521,18 @@ regions.push_back({name, region}); } + // Populate the builders. 
+ auto *builderList = + dyn_cast_or_null(def.getValueInit("builders")); + if (builderList && !builderList->empty()) { + for (llvm::Init *init : builderList->getValues()) + builders.emplace_back(cast(init)->getDef(), def.getLoc()); + } else if (skipDefaultBuilders()) { + PrintFatalError( + def.getLoc(), + "default builders are skipped and no custom builders provided"); + } + LLVM_DEBUG(print(llvm::dbgs())); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -48,7 +48,7 @@ static const char *const tblgenNamePrefix = "tblgen_"; static const char *const generatedArgName = "odsArg"; -static const char *const builder = "odsBuilder"; +static const char *const odsBuilder = "odsBuilder"; static const char *const builderOpState = "odsState"; // The logic to calculate the actual value range for a declared operand/result @@ -1326,54 +1326,31 @@ body << " }\n"; } -/// Returns a signature of the builder as defined by a dag-typed initializer. -/// Updates the context `fctx` to enable replacement of $_builder and $_state -/// in the body. Reports errors at `loc`. -static std::string builderSignatureFromDAG(const DagInit *init, - ArrayRef loc) { - auto *defInit = dyn_cast(init->getOperator()); - if (!defInit || !defInit->getDef()->getName().equals("ins")) - PrintFatalError(loc, "expected 'ins' in builders"); +/// Returns the comma-separated parameter list for the C++ `build` method of +/// `builder`: OpBuilder and OperationState arguments, then its own params. +static std::string getBuilderSignature(const Builder &builder) { + ArrayRef params(builder.getParameters()); // Inject builder and state arguments. 
+ llvm::SmallVector arguments; - arguments.reserve(init->getNumArgs() + 2); - arguments.push_back(llvm::formatv("::mlir::OpBuilder &{0}", builder).str()); + arguments.reserve(params.size() + 2); + arguments.push_back( + llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str()); arguments.push_back( llvm::formatv("::mlir::OperationState &{0}", builderOpState).str()); - // Accept either a StringInit or a DefInit with two string values as dag - // arguments. The former corresponds to the type, the latter to the type and - // the default value. Similarly to C++, once an argument with a default value - // is detected, the following arguments must have default values as well. - bool seenDefaultValue = false; - for (unsigned i = 0, e = init->getNumArgs(); i < e; ++i) { + for (unsigned i = 0, e = params.size(); i < e; ++i) { // If no name is provided, generate one. - StringInit *argName = init->getArgName(i); + Optional paramName = params[i].getName(); std::string name = - argName ? argName->getValue().str() : "odsArg" + std::to_string(i); + paramName ? 
paramName->str() : "odsArg" + std::to_string(i); - Init *argInit = init->getArg(i); - StringRef type; std::string defaultValue; - if (StringInit *strType = dyn_cast(argInit)) { - type = strType->getValue(); - } else { - const Record *typeAndDefaultValue = cast(argInit)->getDef(); - type = typeAndDefaultValue->getValueAsString("type"); - StringRef defaultValueRef = - typeAndDefaultValue->getValueAsString("defaultValue"); - if (!defaultValueRef.empty()) { - seenDefaultValue = true; - defaultValue = llvm::formatv(" = {0}", defaultValueRef).str(); - } - } - if (seenDefaultValue && defaultValue.empty()) - PrintFatalError(loc, - "expected an argument with default value after other " - "arguments with default values"); + if (Optional defaultParamValue = params[i].getDefaultValue()) + defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str(); arguments.push_back( - llvm::formatv("{0} {1}{2}", type, name, defaultValue).str()); + llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue) + .str()); } return llvm::join(arguments, ", "); @@ -1381,41 +1358,26 @@ void OpEmitter::genBuilder() { // Handle custom builders if provided. - // TODO: Create wrapper class for OpBuilder to hide the native - // TableGen API calls here. - { - auto *listInit = dyn_cast_or_null(def.getValueInit("builders")); - if (listInit) { - for (Init *init : listInit->getValues()) { - Record *builderDef = cast(init)->getDef(); - std::string paramStr = builderSignatureFromDAG( - builderDef->getValueAsDag("dagParams"), op.getLoc()); - - StringRef body = builderDef->getValueAsString("body"); - bool hasBody = !body.empty(); - OpMethod::Property properties = - hasBody ? 
OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; - auto *method = - opClass.addMethodAndPrune("void", "build", properties, paramStr); + for (const Builder &builder : op.getBuilders()) { + std::string paramStr = getBuilderSignature(builder); - FmtContext fctx; - fctx.withBuilder(builder); - fctx.addSubst("_state", builderOpState); - if (hasBody) - method->body() << tgfmt(body, &fctx); - } - } - if (op.skipDefaultBuilders()) { - if (!listInit || listInit->empty()) - PrintFatalError( - op.getLoc(), - "default builders are skipped and no custom builders provided"); - return; - } + Optional body = builder.getBody(); + OpMethod::Property properties = + body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; + auto *method = + opClass.addMethodAndPrune("void", "build", properties, paramStr); + + FmtContext fctx; + fctx.withBuilder(odsBuilder); + fctx.addSubst("_state", builderOpState); + if (body) + method->body() << tgfmt(*body, &fctx); } // Generate default builders that requires all result type, operands, and // attributes as parameters. + if (op.skipDefaultBuilders()) + return; // We generate three classes of builders here: // 1. one having a stand-alone parameter for each operand / attribute, and