diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2012,10 +2012,17 @@ // ``` list builders = ?; - // Avoid generating default build functions. Custom builders must be - // provided. + // Avoid generating default build functions. Custom builders must be + // provided. Implies skipDefaultSeparateParamBuilders = 1 and + // skipDefaultCollectiveParamBuilders = 1. bit skipDefaultBuilders = 0; + // Avoid generating default separate param builders. + bit skipDefaultSeparateParamBuilders = 0; + + // Avoid generating default collective param builders. + bit skipDefaultCollectiveParamBuilders = 0; + // Custom parser. code parser = ?; diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -97,6 +97,10 @@ // Returns true if default builders should not be generated. bool skipDefaultBuilders() const; + // Returns true if default separate param builders should not be generated. + bool skipDefaultSeparateParamBuilders() const; + // Returns true if default collective param builders should not be generated. + bool skipDefaultCollectiveParamBuilders() const; // Op result iterators.
value_iterator result_begin(); diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -94,6 +94,16 @@ return def.getValueAsBit("skipDefaultBuilders"); } +bool Operator::skipDefaultSeparateParamBuilders() const { + return def.getValueAsBit("skipDefaultSeparateParamBuilders") || + skipDefaultBuilders(); +} + +bool Operator::skipDefaultCollectiveParamBuilders() const { + return def.getValueAsBit("skipDefaultCollectiveParamBuilders") || + skipDefaultBuilders(); +} + auto Operator::result_begin() -> value_iterator { return results.begin(); } auto Operator::result_end() -> value_iterator { return results.end(); } diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -81,9 +81,9 @@ // CHECK: ::mlir::FloatAttr attr2Attr() // CHECK: ::llvm::Optional< ::llvm::APFloat > attr2(); // CHECK: static void build(Value val); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, 
/*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) // CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); // CHECK: void print(::mlir::OpAsmPrinter &p); // CHECK: ::mlir::LogicalResult verify(); @@ -180,8 +180,8 @@ // CHECK_LABEL: class NS_HCollectiveParamsOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}) // Check suppression of "separate arg, separate result" build method for an op // with single variadic arg and single variadic result (since it will be @@ -192,8 +192,8 @@ } // CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op : -// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, 
::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check suppression of "separate arg, collective result" build method for an op // with single variadic arg and non variadic result (since it will be @@ -204,8 +204,8 @@ } // CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op : -// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check suppression of "separate arg, collective result" build method for an op // with single variadic arg and > 1 variadic result (since it will be @@ -217,9 +217,9 @@ let results = (outs Variadic:$b, Variadic:$c); } // CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op : -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::llvm::ArrayRef<::mlir::Type> c, ::mlir::ValueRange a); -// 
CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check default value of `attributes` for the `genUseOperandAsResultTypeCollectiveParamBuilder` builder def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperandsAndResultType]> { @@ -228,8 +228,8 @@ } // CHECK_LABEL: class NS_IOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, 
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); @@ -241,8 +241,8 @@ // CHECK_LABEL: class NS_JOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check that default builders can be suppressed.
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -14,7 +14,7 @@ } // CHECK-LABEL: void OpA::build -// CHECK: ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands +// CHECK: ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types"); // CHECK-NEXT: odsState.addTypes(resultTypes); @@ -39,7 +39,7 @@ // CHECK-NEXT: odsState.addTypes(resultType1) // CHECK-NEXT: odsState.addTypes(z) -// CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes) { +// CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes) { // CHECK-NEXT: assert(resultTypes.size() == 3u && "mismatched number of results"); // CHECK-NEXT: odsState.addTypes(resultTypes); @@ -67,7 +67,7 @@ } // CHECK-LABEL: void OpF::build -// CHECK-SAME: ::llvm::ArrayRef<::mlir::Type> x +// CHECK-SAME: ::mlir::TypeRange x // CHECK-NOT: assert // CHECK: odsState.addTypes(x); @@ -78,12 +78,12 @@ // CHECK-LABEL: OpG definitions -// CHECK: void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::llvm::ArrayRef<::mlir::Type> y) +// CHECK: void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::mlir::TypeRange y) // CHECK-NEXT: odsState.addTypes(x); // CHECK-NEXT: odsState.addTypes(y); // CHECK: void OpG::build -// CHECK: ::llvm::ArrayRef<::mlir::Type> resultTypes +// CHECK: ::mlir::TypeRange resultTypes // CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types"); // CHECK-NEXT: odsState.addTypes(resultTypes); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- 
a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1455,8 +1455,9 @@ let results = (outs Variadic:$output_tensors); let regions = (region AnyRegion:$region); + let skipDefaultSeparateParamBuilders = 1; let builders = [ OpBuilder< - "OpBuilder &b, OperationState &result," + "OpBuilder &b, OperationState &result, " "ValueRange inputs, ValueRange outputBuffers", [{{ result.addOperands(inputs); @@ -1474,7 +1475,7 @@ TypeRange(), TypeRange()); }]>, OpBuilder< - "OpBuilder &b, OperationState &result, TypeRange resultTensorTypes," + "OpBuilder &b, OperationState &result, TypeRange resultTensorTypes, " "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors", [{{ result.addOperands(inputs); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1172,9 +1172,11 @@ // Handle custom builders if provided. // TODO: Create wrapper class for OpBuilder to hide the native // TableGen API calls here. + bool hasBuilder = false; { auto *listInit = dyn_cast_or_null(def.getValueInit("builders")); - if (listInit) { + if (listInit && !listInit->empty()) { + hasBuilder = true; for (Init *init : listInit->getValues()) { Record *builderDef = cast(init)->getDef(); StringRef params = builderDef->getValueAsString("params"); @@ -1189,13 +1191,6 @@ method->body() << body; } } - if (op.skipDefaultBuilders()) { - if (!listInit || listInit->empty()) - PrintFatalError( - op.getLoc(), - "default builders are skipped and no custom builders provided"); - return; - } } // Generate default builders that requires all result type, operands, and @@ -1203,21 +1198,46 @@ // We generate three classes of builders here: // 1.
one having a stand-alone parameter for each operand / attribute, and - genSeparateArgParamBuilder(); + if (!op.skipDefaultSeparateParamBuilders()) { + genSeparateArgParamBuilder(); + hasBuilder = true; + } + // 2. one having an aggregated parameter for all result types / operands / // attributes, and - genCollectiveParamBuilder(); + if (!op.skipDefaultCollectiveParamBuilders()) { + genCollectiveParamBuilder(); + hasBuilder = true; + } + // 3. one having a stand-alone parameter for each operand and attribute, // use the first operand or attribute's type as all result types // to facilitate different call patterns. if (op.getNumVariableLengthResults() == 0) { if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { - genUseOperandAsResultTypeSeparateParamBuilder(); - genUseOperandAsResultTypeCollectiveParamBuilder(); + if (!op.skipDefaultSeparateParamBuilders()) { + genUseOperandAsResultTypeSeparateParamBuilder(); + hasBuilder = true; + } + + if (!op.skipDefaultCollectiveParamBuilders()) { + genUseOperandAsResultTypeCollectiveParamBuilder(); + hasBuilder = true; + } } - if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) + if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && + !op.skipDefaultCollectiveParamBuilders()) { genUseAttrAsResultTypeBuilder(); + hasBuilder = true; + } } + + if (hasBuilder) + return; + + PrintFatalError( + op.getLoc(), + "default builders are skipped and no custom builders provided"); } void OpEmitter::genCollectiveParamBuilder() { @@ -1232,7 +1252,7 @@ SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", ""); paramList.emplace_back("::mlir::OperationState &", builderOpState); - paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes"); + paramList.emplace_back("::mlir::TypeRange", "resultTypes"); paramList.emplace_back("::mlir::ValueRange", "operands"); // Provide default value for `attributes` when its the last parameter StringRef attributesDefaultValue = op.getNumVariadicRegions() ? 
"" : "{}"; @@ -1302,8 +1322,8 @@ if (resultName.empty()) resultName = std::string(formatv("resultType{0}", i)); - StringRef type = result.isVariadic() ? "::llvm::ArrayRef<::mlir::Type>" - : "::mlir::Type"; + StringRef type = + result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type"; OpMethodParameter::Property properties = OpMethodParameter::PP_None; if (result.isOptional()) properties = OpMethodParameter::PP_Optional; @@ -1313,7 +1333,7 @@ } } break; case TypeParamKind::Collective: { - paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes"); + paramList.emplace_back("::mlir::TypeRange", "resultTypes"); resultTypeNames.push_back("resultTypes"); } break; }