diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -567,17 +567,16 @@ #### Custom builder methods However, if the above cases cannot satisfy all needs, you can define additional -convenience build methods with `OpBuilder`. - -`OpBuilder` is a class that takes the parameter list and the optional `build()` -method body. They are separated because we need to generate op declaration and -definition into separate files. The parameter list should not include `OpBuilder -&builder, OperationState &state` as they will be inserted automatically and the -placeholders `$_builder` and `$_state` used. For legacy/to be deprecated reason -if the `OpBuilder` parameter starts with `OpBuilder` param, then the parameter -is used. If the `body` is not provided, only the builder declaration will be -generated; this provides a way to define complicated builders entirely in C++ -files. +convenience build methods in the `builders` field. + +`OpBuilderDAG` is a class that takes the parameter list as a TableGen `dag` as +well as the optional `build()` method body. Trailing parameters (like in C++) +can have default values defined using the `CArg` class, which takes two string +arguments: the type and the default value. The parameter list should not include +`OpBuilder &, OperationState &` as they will be inserted automatically and the +placeholders `$_builder` and `$_state` can be used to refer to these objects. If +the `body` is not provided, only the builder declaration will be generated; this +provides a way to define complicated builders entirely in C++ files. For example, for the following op: @@ -597,7 +596,7 @@ ... 
let builders = [ - OpBuilder<"float val = 0.5f", [{ + OpBuilderDAG<(ins CArg<"float", "0.5f">:$val), [{ $_state.addAttribute("attr", $_builder.getF32FloatAttr(val)); }]> ]; @@ -612,6 +611,15 @@ } ``` +A version of the builder with no default parameter values looks similar to +other C method declarations: `OpBuilderDAG<(ins "float":$val)>`. + +**Deprecated:** The `OpBuilder` class allows one to specify the custom builder +signature as a raw string, without separating parameters into different `dag` +arguments. It also supports leading parameters of `OpBuilder &` and +`OperationState &` types, which will be used instead of the autogenerated ones +if present. + ### Custom parser and printer methods Functions to parse and print the operation's custom assembly form. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1803,6 +1803,13 @@ // Marker used to identify the argument list for an op or interface method. def ins; +// This class represents a typed argument with an optional default value for C +// function signatures, e.g. builders or methods. +class CArg { + string type = ty; + string defaultValue = value; +} + // OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in // C++. The purpose to wrap around C++ symbol string with this class is to make // interfaces specified for ops in TableGen less alien and more integrated. @@ -1922,6 +1929,15 @@ // Marker used to identify the successor list for an op. def successor; +// Base class for custom builders. This is a transient class that will go away +// when the transition to the DAG form of builder declaration is complete. +// Should not be used directly. +class OpBuilderBase { + string params = sp; + dag dagParams = dp; + code body = b; +} + // Class for defining a custom builder. 
// // TableGen generates several generic builders for each op by default (see @@ -1931,23 +1947,27 @@ // The signature of the builder is always // // ```c++ -// static void build(OpBuilder &builder, OperationState &state, +// static void build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, // ...) { // ... // } // ``` // -// To define a custom builder, the parameter list (*including* the `Builder -// *builder, OperationState &state` part) and body should be passed in -// as separate template arguments to this class. This is because we generate -// op declaration and definition into separate files. If an empty string is -// passed in for `body`, then *only* the builder declaration will be -// generated; this provides a way to define complicated builders entirely -// in C++. -class OpBuilder { - string params = p; - code body = b; -} +// To define a custom builder, the parameter list (*NOT including* the +// `OpBuilder &builder, OperationState &state` part) and body should be passed +// in as separate template arguments to this class. The parameter list is a +// TableGen DAG with the `ins` operator and named arguments, each of which is either: +// - a string initializer ("Type":$name) to represent a typed parameter, or +// - a CArg-typed initializer (CArg<"Type", "default">:$name) to represent a +// typed parameter that has a default value. +// +// If an empty string is passed in for `body`, then *only* the builder +// declaration will be generated; this provides a way to define complicated +// builders entirely in C++. +class OpBuilderDAG : OpBuilderBase<"<<>>", p, b>; + +// Deprecated version of OpBuilder that takes the builder signature as a string. +class OpBuilder : OpBuilderBase; // A base decorator class that may optionally be added to OpVariables. class OpVariableDecorator; @@ -2024,7 +2044,7 @@ // ValueRange operands, // ArrayRef attributes); // ``` - list builders = ?; + list builders = ?; // Avoid generating default build functions. 
Custom builders must be // provided. diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -33,7 +33,9 @@ AnyRegion:$someRegion, VariadicRegion:$someRegions ); - let builders = [OpBuilder<"Value val">]; + let builders = [OpBuilderDAG<(ins "Value":$val)>, + OpBuilderDAG<(ins CArg<"int", "0">:$integer)>, + OpBuilder<"double deprecatedForm">]; let parser = [{ foo }]; let printer = [{ bar }]; let verifier = [{ baz }]; @@ -81,6 +83,8 @@ // CHECK: ::mlir::FloatAttr attr2Attr() // CHECK: ::llvm::Optional< ::llvm::APFloat > attr2(); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value val); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, int integer = 0); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, double deprecatedForm); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) @@ -250,7 +254,7 @@ def NS_SkipDefaultBuildersOp : NS_Op<"skip_default_builders", []> { let skipDefaultBuilders = 1; - let builders = [OpBuilder<"Value val">]; + let builders = [OpBuilderDAG<(ins "Value":$val)>]; } // CHECK-LABEL: NS::SkipDefaultBuildersOp declarations diff --git 
a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1153,6 +1153,85 @@ body << " }\n"; } +/// Returns a signature of the builder as defined by a dag-typed initializer. +/// Updates the context `fctx` to enable replacement of $_builder and $_state +/// in the body. Reports errors at `loc`. +static llvm::Optional +builderSignatureFromDAG(const DagInit *init, ArrayRef loc, + FmtContext &fctx) { + auto *defInit = dyn_cast(init->getOperator()); + if (!defInit || !defInit->getDef()->getName().equals("ins")) { + PrintFatalError(loc, "expected 'ins'"); + return llvm::None; + } + + // Inject builder and state arguments. + llvm::SmallVector arguments; + arguments.reserve(init->getNumArgs() + 2); + arguments.push_back(llvm::formatv("::mlir::OpBuilder &{0}", builder).str()); + arguments.push_back( + llvm::formatv("::mlir::OperationState &{0}", builderOpState).str()); + + // Accept either a StringInit or a DefInit with two string values as dag + // arguments. The former corresponds to the type, the latter to the type and + // the default value. Similarly to C++, once an argument with a default value + // is detected, the following arguments must have default values as well. NOTE(review): this rule is only checked for DefInit (CArg) arguments below; a plain string argument that follows a defaulted one is not diagnosed. + bool seenDefaultValue = false; + for (unsigned i = 0, e = init->getNumArgs(); i < e; ++i) { + // If no name is provided, generate one. + StringInit *argName = init->getArgName(i); + std::string name = + argName ? 
argName->getValue().str() : "odsArg" + std::to_string(i); + + Init *argInit = init->getArg(i); + StringRef type, defaultValue; + if (StringInit *strType = dyn_cast(argInit)) { + type = strType->getValue(); + } else { + const Record *typeAndDefaultValue = cast(argInit)->getDef(); + type = typeAndDefaultValue->getValueAsString("type"); + defaultValue = typeAndDefaultValue->getValueAsString("defaultValue"); + if (!defaultValue.empty()) { + seenDefaultValue = true; + } else if (seenDefaultValue) { + PrintFatalError(typeAndDefaultValue->getLoc(), + "expected an argument with default value after other " + "arguments with default values"); + return llvm::None; + } + } + std::string defaultValueString = + defaultValue.empty() ? "" : llvm::formatv(" = {0}", defaultValue).str(); + arguments.push_back( + llvm::formatv("{0} {1}{2}", type, name, defaultValueString).str()); + } + + fctx.withBuilder(builder); + fctx.addSubst("_state", builderOpState); + + return llvm::join(arguments, ", "); +} + +// Returns a signature of the builder as defined by a string initializer, +// optionally injecting the builder and state arguments. If `params` already starts with an `OpBuilder` parameter, it is returned unchanged and `fctx` is left untouched. +// TODO: to be removed after the transition is complete. +static llvm::Optional +builderSignatureFromString(StringRef params, FmtContext &fctx) { + bool skipParamGen = params.startswith("OpBuilder") || + params.startswith("mlir::OpBuilder") || + params.startswith("::mlir::OpBuilder"); + if (skipParamGen) + return params.str(); + + fctx.withBuilder(builder); + fctx.addSubst("_state", builderOpState); + return llvm::formatv("::mlir::OpBuilder &{0}, " + "::mlir::OperationState &{1}{2}{3}", + builder, builderOpState, params.empty() ? "" : ", ", + params) + .str(); +} + +void OpEmitter::genBuilder() { // Handle custom builders if provided. 
// TODO: Create wrapper class for OpBuilder to hide the native @@ -1163,34 +1242,24 @@ for (Init *init : listInit->getValues()) { Record *builderDef = cast(init)->getDef(); StringRef params = builderDef->getValueAsString("params").trim(); - // TODO: Remove this and just generate the builder/state always. - bool skipParamGen = params.startswith("OpBuilder") || - params.startswith("mlir::OpBuilder") || - params.startswith("::mlir::OpBuilder"); + FmtContext fctx; + llvm::Optional paramStr = + (params == "<<>>") + ? builderSignatureFromDAG( + builderDef->getValueAsDag("dagParams"), + builderDef->getLoc(), fctx) + : builderSignatureFromString(params, fctx); + if (!paramStr.hasValue()) + return; + StringRef body = builderDef->getValueAsString("body"); bool hasBody = !body.empty(); - OpMethod::Property properties = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; - std::string paramStr = - skipParamGen ? params.str() - : llvm::formatv("::mlir::OpBuilder &{0}, " - "::mlir::OperationState &{1}{2}{3}", - builder, builderOpState, - params.empty() ? "" : ", ", params) - .str(); auto *method = - opClass.addMethodAndPrune("void", "build", properties, paramStr); - if (hasBody) { - if (skipParamGen) { - method->body() << body; - } else { - FmtContext fctx; - fctx.withBuilder(builder); - fctx.addSubst("_state", builderOpState); - method->body() << tgfmt(body, &fctx); - } - } + opClass.addMethodAndPrune("void", "build", properties, *paramStr); + if (hasBody) + method->body() << tgfmt(body, &fctx); } } if (op.skipDefaultBuilders()) {