diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2006,8 +2006,8 @@ // // ```c++ // static void build(OpBuilder &, OperationState &odsState, - // ArrayRef resultTypes, - // ArrayRef operands, + // TypeRange resultTypes, + // ValueRange operands, // ArrayRef attributes); // ``` list builders = ?; diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -81,9 +81,9 @@ // CHECK: ::mlir::FloatAttr attr2Attr() // CHECK: ::llvm::Optional< ::llvm::APFloat > attr2(); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value val); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::llvm::ArrayRef<::mlir::Type> s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) // CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); // CHECK: void print(::mlir::OpAsmPrinter &p); // CHECK: ::mlir::LogicalResult verify(); @@ -180,8 +180,8 @@ // CHECK_LABEL: class NS_HCollectiveParamsOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}) +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}) // Check suppression of "separate arg, separate result" build method for an op // with single variadic arg and single variadic result (since it will be @@ -192,8 +192,8 @@ } // CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op : -// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check suppression of "separate arg, collective result" build method for an op // with single variadic arg and non variadic result (since it will be @@ -204,8 +204,8 @@ } // CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op : -// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check suppression of "separate arg, collective result" build method for an op // with single variadic arg and > 1 variadic result (since it will be @@ -217,9 +217,9 @@ let results = (outs Variadic:$b, Variadic:$c); } // CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op : -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::llvm::ArrayRef<::mlir::Type> c, ::mlir::ValueRange a); -// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> b, ::mlir::ValueRange a); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a); +// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check default value of `attributes` for the `genUseOperandAsResultTypeCollectiveParamBuilder` builder def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperandsAndResultType]> { @@ -228,8 +228,8 @@ } // CHECK_LABEL: class NS_IOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); @@ -241,8 +241,8 @@ // CHECK_LABEL: class NS_JOp : // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::Value a, ::mlir::Value b); -// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // Check that default builders can be suppressed. diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -14,7 +14,7 @@ } // CHECK-LABEL: void OpA::build -// CHECK: ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange operands +// CHECK: ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types"); // CHECK-NEXT: odsState.addTypes(resultTypes); @@ -39,7 +39,7 @@ // CHECK-NEXT: odsState.addTypes(resultType1) // CHECK-NEXT: odsState.addTypes(z) -// CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llvm::ArrayRef<::mlir::Type> resultTypes) { +// CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes) { // CHECK-NEXT: assert(resultTypes.size() == 3u && "mismatched number of results"); // CHECK-NEXT: odsState.addTypes(resultTypes); @@ -67,7 +67,7 @@ } // CHECK-LABEL: void OpF::build -// CHECK-SAME: ::llvm::ArrayRef<::mlir::Type> x +// CHECK-SAME: ::mlir::TypeRange x // CHECK-NOT: assert // CHECK: odsState.addTypes(x); @@ -78,12 +78,12 @@ // CHECK-LABEL: OpG definitions -// CHECK: void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::llvm::ArrayRef<::mlir::Type> y) +// CHECK: void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::mlir::TypeRange y) // CHECK-NEXT: odsState.addTypes(x); // CHECK-NEXT: odsState.addTypes(y); // CHECK: void OpG::build -// CHECK: ::llvm::ArrayRef<::mlir::Type> resultTypes +// CHECK: ::mlir::TypeRange resultTypes // CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types"); // CHECK-NEXT: odsState.addTypes(resultTypes); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1455,8 +1455,9 @@ let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); + let skipDefaultBuilders = 1; let builders = [ OpBuilder< - "OpBuilder &b, OperationState &result," + "OpBuilder &b, OperationState &result, " "ValueRange inputs, ValueRange outputBuffers", [{{ result.addOperands(inputs); @@ -1493,6 +1494,14 @@ TypeRange(outputBuffers), TypeRange(initTensors), resultTensorTypes); + }]>, OpBuilder< + "OpBuilder &b, OperationState &result, TypeRange resultTensorTypes," + "ValueRange operands, ArrayRef attributes = {{}", + [{{ + result.addOperands(operands); + result.addAttributes(attributes); + result.addTypes(resultTensorTypes); + (void)result.addRegion(); }]> ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1252,7 +1252,7 @@ SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", ""); paramList.emplace_back("::mlir::OperationState &", builderOpState); - paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes"); + paramList.emplace_back("::mlir::TypeRange", "resultTypes"); paramList.emplace_back("::mlir::ValueRange", "operands"); // Provide default value for `attributes` when its the last parameter StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; @@ -1322,8 +1322,8 @@ if (resultName.empty()) resultName = std::string(formatv("resultType{0}", i)); - StringRef type = result.isVariadic() ? "::llvm::ArrayRef<::mlir::Type>" - : "::mlir::Type"; + StringRef type = + result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type"; OpMethodParameter::Property properties = OpMethodParameter::PP_None; if (result.isOptional()) properties = OpMethodParameter::PP_Optional; @@ -1333,7 +1333,7 @@ } } break; case TypeParamKind::Collective: { - paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes"); + paramList.emplace_back("::mlir::TypeRange", "resultTypes"); resultTypeNames.push_back("resultTypes"); } break; }