diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -572,10 +572,13 @@ `OpBuilder` is a class that takes the parameter list and the optional `build()` method body. They are separated because we need to generate op declaration and -definition into separate files. The parameter list should _include_ `Builder -*builder, OperationState &state`. If the `body` is not provided, _only_ the -builder declaration will be generated; this provides a way to define complicated -builders entirely in C++ files. +definition into separate files. The parameter list should not include `OpBuilder +&builder, OperationState &state`; these are inserted automatically and are +available in the body as the placeholders `$_builder` and `$_state`. For legacy +reasons (to be deprecated), if the parameter list starts with `OpBuilder`, it is +used verbatim and nothing is inserted. If the `body` is not provided, only the +builder declaration will be generated; this provides a way to define complicated +builders entirely in C++ files. For example, for the following op: @@ -595,8 +598,8 @@ ... let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &state, float val = 0.5f", [{ - state.addAttribute("attr", builder.getF32FloatAttr(val)); + OpBuilder<"float val = 0.5f", [{ + $_state.addAttribute("attr", $_builder.getF32FloatAttr(val)); }]> ]; } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1062,14 +1062,14 @@ // result type. So need to provide a builder not requiring result types. 
let builders = [ OpBuilder< - "OpBuilder &builder, OperationState &state, IntegerAttr count", + "IntegerAttr count", [{ - auto i32Type = builder.getIntegerType(32); - state.addTypes(i32Type); // $output1 + auto i32Type = $_builder.getIntegerType(32); + $_state.addTypes(i32Type); // $output1 SmallVector<Type, 4> types(count.getInt(), i32Type); - state.addTypes(types); // $output2 - state.addTypes(types); // $output3 - state.addAttribute("count", count); + $_state.addTypes(types); // $output2 + $_state.addTypes(types); // $output3 + $_state.addAttribute("count", count); }]> ]; } diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -80,7 +80,7 @@ // CHECK: uint32_t attr1(); // CHECK: ::mlir::FloatAttr attr2Attr() // CHECK: ::llvm::Optional< ::llvm::APFloat > attr2(); -// CHECK: static void build(Value val); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value val); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) @@ -256,7 +256,7 @@ // CHECK-LABEL: NS::SkipDefaultBuildersOp declarations // CHECK: class SkipDefaultBuildersOp // CHECK-NOT: static void build(::mlir::Builder -// CHECK: static void build(Value +// CHECK: 
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value // Check leading underscore in op name // --- diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -47,6 +47,7 @@ static const char *const tblgenNamePrefix = "tblgen_"; static const char *const generatedArgName = "odsArg"; +static const char *const builder = "odsBuilder"; static const char *const builderOpState = "odsState"; // The logic to calculate the actual value range for a declared operand/result @@ -1177,16 +1178,37 @@ if (listInit) { for (Init *init : listInit->getValues()) { Record *builderDef = cast<DefInit>(init)->getDef(); - StringRef params = builderDef->getValueAsString("params"); + StringRef params = builderDef->getValueAsString("params").trim(); + // TODO: Remove this and just generate the builder/state always. + bool skipParamGen = params.startswith("OpBuilder") || + params.startswith("mlir::OpBuilder") || + params.startswith("::mlir::OpBuilder"); StringRef body = builderDef->getValueAsString("body"); bool hasBody = !body.empty(); OpMethod::Property properties = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; + // Only emit the ", " separator when the builder declares parameters of its + // own; otherwise the generated signature would end in a dangling comma. + std::string paramStr = + skipParamGen ? params.str() + : llvm::formatv("::mlir::OpBuilder &{0}, " + "::mlir::OperationState &{1}{2}{3}", + builder, builderOpState, + params.empty() ? "" : ", ", params) + .str(); auto *method = - opClass.addMethodAndPrune("void", "build", properties, params); - if (hasBody) - method->body() << body; + opClass.addMethodAndPrune("void", "build", properties, paramStr); + if (hasBody) { + if (skipParamGen) { + method->body() << body; + } else { + FmtContext fctx; + fctx.withBuilder(builder); + fctx.addSubst("_state", builderOpState); + method->body() << tgfmt(body, &fctx); + } + } } } if (op.skipDefaultBuilders()) {