diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -82,16 +82,35 @@ // iterator_types is an auto-generated method. } -/// Create the region and fill the block of a structured operation given -/// `inputTypes` and `outputTypes` as well as a `regionBuilder`. -void createAndFillStructuredOpRegion(OpBuilder &opBuilder, - OperationState &result, - TypeRange inputTypes, - TypeRange outputTypes, - RegionBuilderFn regionBuilder) { - Region &region = *result.addRegion(); - fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, - result.attributes.getAttrs(), regionBuilder); +/// Creates a structured operation given `inputs`, `outputs`, and `attributes`. +/// The result types are derived automatically if `resultTensorTypes` is none. +/// The body of the operation is filled using `regionBuilder`. All ods-gen +/// created structured operations use this method to implement their builders. +static void buildStructuredOp(OpBuilder &b, OperationState &state, + llvm::Optional<TypeRange> resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes, + RegionBuilderFn regionBuilder) { + // Derive the result types if needed. + SmallVector<Type> derivedResultTypes = + resultTensorTypes.getValueOr(TypeRange()); + if (!resultTensorTypes.hasValue()) + copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes), + [](Type type) { return type.isa<RankedTensorType>(); }); + + state.addOperands(inputs); + state.addOperands(outputs); + state.addTypes(derivedResultTypes); + state.addAttributes(attributes); + state.addAttribute( + "operand_segment_sizes", + b.getI32VectorAttr({static_cast<int32_t>(inputs.size()), + static_cast<int32_t>(outputs.size())})); + + // Create and fill the region of the structured operation. 
+ Region &region = *state.addRegion(); + fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs), + state.attributes.getAttrs(), regionBuilder); } /// Common parsing used for both named structured ops created by ods-gen and by diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -81,22 +81,9 @@ # ODS-NEXT: "ValueRange":$outputs, "Attribute":$cast, # ODS-NEXT: CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes), -# ODS: $_state.addOperands(inputs); -# ODS-NEXT: $_state.addOperands(outputs); -# ODS-NEXT: $_state.addTypes(resultTensorTypes); -# ODS-NEXT: $_state.addAttribute("cast", cast) -# ODS-NEXT: $_state.addAttributes(attributes); -# ODS-NEXT: $_state.addAttribute( -# ODS-NEXT: "operand_segment_sizes", -# ODS-NEXT: $_builder.getI32VectorAttr({ -# ODS-NEXT: static_cast<int32_t>(inputs.size()), -# ODS-NEXT: static_cast<int32_t>(outputs.size())})); -# ODS-NEXT: createAndFillStructuredOpRegion( -# ODS-NEXT: $_builder, -# ODS-NEXT: $_state, -# ODS-NEXT: TypeRange(inputs), -# ODS-NEXT: TypeRange(outputs), -# ODS-NEXT: Test1Op::getRegionBuilder() +# ODS: $_state.addAttribute("cast", cast) +# ODS: buildStructuredOp($_builder, $_state, resultTensorTypes, +# ODS-NEXT: attributes, Test1Op::getRegionBuilder()) # IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -524,46 +524,16 @@ (ins "ValueRange":$inputs, "ValueRange":$outputs, CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes), [{{ - $_state.addOperands(inputs); - $_state.addOperands(outputs); - SmallVector<Type> resultTensorTypes; - copy_if(outputs.getTypes(), - std::back_inserter(resultTensorTypes), - [](Type 
type) {{ return type.isa<RankedTensorType>(); }); - $_state.addTypes(resultTensorTypes); - $_state.addAttribute( - "operand_segment_sizes", - $_builder.getI32VectorAttr({{ - static_cast<int32_t>(inputs.size()), - static_cast<int32_t>(outputs.size())})); - $_state.addAttributes(attributes); - createAndFillStructuredOpRegion( - $_builder, - $_state, - TypeRange(inputs), - TypeRange(outputs), - {0}::getRegionBuilder()); + buildStructuredOp($_builder, $_state, llvm::None, inputs, outputs, + attributes, {0}::getRegionBuilder()); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes), [{{ - $_state.addOperands(inputs); - $_state.addOperands(outputs); - $_state.addTypes(resultTensorTypes); - $_state.addAttributes(attributes); - $_state.addAttribute( - "operand_segment_sizes", - $_builder.getI32VectorAttr({{ - static_cast<int32_t>(inputs.size()), - static_cast<int32_t>(outputs.size())})); - createAndFillStructuredOpRegion( - $_builder, - $_state, - TypeRange(inputs), - TypeRange(outputs), - {0}::getRegionBuilder()); + buildStructuredOp($_builder, $_state, resultTensorTypes, + inputs, outputs, attributes, {0}::getRegionBuilder()); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, @@ -610,22 +580,9 @@ "ValueRange":$outputs, {1}, CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes), [{{ - $_state.addOperands(inputs); - $_state.addOperands(outputs); - $_state.addTypes(resultTensorTypes); {2} - $_state.addAttributes(attributes); - $_state.addAttribute( - "operand_segment_sizes", - $_builder.getI32VectorAttr({{ - static_cast<int32_t>(inputs.size()), - static_cast<int32_t>(outputs.size())})); - createAndFillStructuredOpRegion( - $_builder, - $_state, - TypeRange(inputs), - TypeRange(outputs), - {0}::getRegionBuilder()); + buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs, + attributes, {0}::getRegionBuilder()); }]> )FMT";