diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -620,18 +620,22 @@ "ValueRange":$outputs, "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, "StringRef":$doc, "StringRef":$libraryCall, - CArg<"function_ref", "nullptr">)>, + CArg<"function_ref", "nullptr">, + CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, "StringRef":$doc, "StringRef":$libraryCall, - CArg<"function_ref", "nullptr">)>, + CArg<"function_ref", "nullptr">, + CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, - CArg<"function_ref", "nullptr">)>, + CArg<"function_ref", "nullptr">, + CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, - CArg<"function_ref", "nullptr">)> + CArg<"function_ref", "nullptr">, + CArg<"ArrayRef", "{}">:$attributes)> ]; let extraClassDeclaration = structuredOpsBaseDecls # [{ diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -502,13 +502,15 @@ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, - function_ref bodyBuild) { + function_ref bodyBuild, + ArrayRef attributes) { build(builder, result, resultTensorTypes, inputs, outputs, builder.getAffineMapArrayAttr(indexingMaps), builder.getStrArrayAttr(iteratorTypes), doc.empty() ? StringAttr() : builder.getStringAttr(doc), libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall)); + result.addAttributes(attributes); if (!bodyBuild) return; @@ -527,30 +529,33 @@ OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, StringRef doc, StringRef libraryCall, - function_ref bodyBuild) { + function_ref bodyBuild, + ArrayRef attributes) { build(builder, result, TypeRange{}, inputs, outputs, indexingMaps, - iteratorTypes, doc, libraryCall, bodyBuild); + iteratorTypes, doc, libraryCall, bodyBuild, attributes); } void GenericOp::build( OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, - function_ref bodyBuild) { + function_ref bodyBuild, + ArrayRef attributes) { build(builder, result, inputs, outputs, indexingMaps, iteratorTypes, /*doc=*/"", - /*libraryCall=*/"", bodyBuild); + /*libraryCall=*/"", bodyBuild, attributes); } void GenericOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef indexingMaps, ArrayRef iteratorTypes, - function_ref bodyBuild) { + function_ref bodyBuild, + ArrayRef attributes) { build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, iteratorTypes, /*doc=*/"", - /*libraryCall=*/"", bodyBuild); + /*libraryCall=*/"", bodyBuild, attributes); } static void print(OpAsmPrinter &p, GenericOp op) { diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -169,7 +169,8 @@ // ODS-LABEL: def Test7Op // ODS: OpBuilder< // ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, -// ODS: "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b) +// ODS: "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b, +// ODS: CArg<"ArrayRef", "{}">:$attributes) // ODS: $_state.addAttribute("attr_a", attr_a); // ODS: $_state.addAttribute("attr_b", attr_b); // diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1910,7 +1910,8 @@ let skipDefaultBuilders = 1; let builders = [ OpBuilder< - (ins "ValueRange":$inputs, "ValueRange":$outputs), + (ins "ValueRange":$inputs, "ValueRange":$outputs, + CArg<"ArrayRef", "{{}">:$attributes), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); @@ -1919,6 +1920,7 @@ $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); + $_state.addAttributes(attributes); createAndFillStructuredOpRegion<{0}>( $_builder, $_state, @@ -1927,7 +1929,8 @@ }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs), + "ValueRange":$outputs, + CArg<"ArrayRef", "{{}">:$attributes), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); @@ -1937,6 +1940,7 @@ $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); + $_state.addAttributes(attributes); createAndFillStructuredOpRegion<{0}>( $_builder, $_state, @@ -2020,7 +2024,8 @@ const char *builderFmt = R"FMT( , OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs, {1}), + "ValueRange":$outputs, {1}, + CArg<"ArrayRef", "{{}">:$attributes), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); @@ -2030,6 +2035,7 @@ $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); + $_state.addAttributes(attributes); createAndFillStructuredOpRegion<{0}>( $_builder, $_state, diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -457,7 +457,8 @@ let skipDefaultBuilders = 1; let builders = [ OpBuilder< - (ins "ValueRange":$inputs, "ValueRange":$outputs), + (ins "ValueRange":$inputs, "ValueRange":$outputs, + CArg<"ArrayRef", "{{}">:$attributes), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); @@ -471,6 +472,7 @@ $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); + $_state.addAttributes(attributes); createAndFillStructuredOpRegion<{0}>( $_builder, $_state, @@ -539,7 +541,8 @@ static const char structuredOpBuilderFormat[] = R"FMT( , OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs, {1}), + "ValueRange":$outputs, {1}, + CArg<"ArrayRef", "{{}">:$attributes), [{{ $_state.addOperands(inputs); $_state.addOperands(outputs); @@ -555,6 +558,7 @@ TypeRange(inputs), TypeRange(outputs)); {2} + $_state.addAttributes(attributes); }]> )FMT";