diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -267,6 +267,12 @@
   let results = (outs Variadic<AnyTensor>:$result);
   let regions = (region SizedRegion<1>:$mapper);
 
+  let builders = [
+    OpBuilder<(ins "ValueRange":$inputs, "Value":$init,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuild,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
+
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Implement functions necessary for LinalgStructuredInterface.
     SmallVector<StringRef> getIteratorTypesArray();
@@ -341,6 +347,13 @@
   let results = (outs Variadic<AnyTensor>);
   let regions = (region SizedRegion<1>:$combiner);
 
+  let builders = [
+    OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$inits,
+      "ArrayRef<int64_t>":$dimensions,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuild,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
+  ];
+
   let extraClassDeclaration = structuredOpsBaseDecls # [{
     // Declare functions necessary for LinalgStructuredInterface.
     SmallVector<StringRef> getIteratorTypesArray();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -649,6 +649,31 @@
 // GenericOp
 //===----------------------------------------------------------------------===//
 
+/// Creates a block in the first region of `result` with one argument per
+/// value in `inputs` followed by one per value in `outputs`, each typed with
+/// that value's element type, and populates it through `bodyBuild`. The
+/// region must already have been added to `result`.
+static void createGenericRegion(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange outputs,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+  assert(!result.regions.empty() && "expected a region in `result`");
+  SmallVector<Type, 4> blockArgTypes;
+  SmallVector<Location, 4> blockArgLocs;
+  for (ValueRange container : {inputs, outputs}) {
+    for (Value v : container) {
+      blockArgTypes.push_back(getElementTypeOrSelf(v));
+      blockArgLocs.push_back(v.getLoc());
+    }
+  }
+
+  OpBuilder::InsertionGuard guard(builder);
+  auto &region = *result.regions.front();
+  Block *bodyBlock =
+      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
+  bodyBuild(builder, result.location, bodyBlock->getArguments());
+}
+
 void GenericOp::getAsmBlockArgumentNames(Region &region,
                                          OpAsmSetValueNameFn setNameFn) {
   for (Value v : getRegionInputArgs())
@@ -666,23 +691,8 @@
   build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
         iteratorTypes, doc, libraryCall);
   result.addAttributes(attributes);
-  if (!bodyBuild)
-    return;
-
-  SmallVector<Type, 4> blockArgTypes;
-  SmallVector<Location, 4> blockArgLocs;
-  for (ValueRange container : {inputs, outputs}) {
-    for (Value v : container) {
-      blockArgTypes.push_back(getElementTypeOrSelf(v));
-      blockArgLocs.push_back(v.getLoc());
-    }
-  }
-
-  OpBuilder::InsertionGuard guard(builder);
-  auto &region = *result.regions.front();
-  Block *bodyBlock =
-      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
-  bodyBuild(builder, result.location, bodyBlock->getArguments());
+  if (bodyBuild)
+    createGenericRegion(builder, result, inputs, outputs, bodyBuild);
 }
 
 void GenericOp::build(
@@ -1317,6 +1327,22 @@
   setNameFn(getResults().front(), "mapped");
 }
 
+void MapOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, TypeRange{}, inputs, init);
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  Type initType = init.getType();
+  if (initType.isa<RankedTensorType>())
+    result.addTypes(initType);
+
+  if (bodyBuild)
+    createGenericRegion(builder, result, inputs, /*outputs=*/{}, bodyBuild);
+}
+
 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseDstStyleOp(parser, result))
     return failure();
@@ -1420,6 +1446,25 @@
   setNameFn(getResults().front(), "reduced");
 }
 
+void ReduceOp::build(
+    OpBuilder &builder, OperationState &result, ValueRange inputs,
+    ValueRange inits, ArrayRef<int64_t> dimensions,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, TypeRange{}, inputs, inits, dimensions);
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  for (Value init : inits) {
+    Type initType = init.getType();
+    if (initType.isa<RankedTensorType>())
+      result.addTypes(initType);
+  }
+
+  if (bodyBuild)
+    createGenericRegion(builder, result, inputs, inits, bodyBuild);
+}
+
 SmallVector<StringRef> ReduceOp::getIteratorTypesArray() {
   int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
   SmallVector<StringRef> iteratorTypes(inputRank,
@@ -1597,45 +1642,34 @@
   };
 }
 
-void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
-                               ::mlir::OperationState &odsState) {
-  Region *region = odsState.addRegion();
-
-  SmallVector<Type> argTypes;
-  SmallVector<Location> argLocs;
-  for (auto t : odsState.operands) {
-    argTypes.push_back(getElementTypeOrSelf(t));
-    argLocs.push_back(opBuilder.getUnknownLoc());
-  }
-
-  // RAII.
-  OpBuilder::InsertionGuard guard(opBuilder);
-  Block *body =
-      opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
-
-  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
-  getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
-}
-
-void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
-                        ::mlir::OperationState &odsState, Value input,
-                        Value init, DenseI64ArrayAttr permutation,
+void TransposeOp::build(::mlir::OpBuilder &builder,
+                        ::mlir::OperationState &result, Value input, Value init,
+                        DenseI64ArrayAttr permutation,
                         ArrayRef<NamedAttribute> attributes) {
-  odsState.addOperands(input);
-  odsState.addOperands(init);
-  odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
-  odsState.addAttributes(attributes);
-  odsState.addTypes(init.getType());
+  result.addOperands(input);
+  result.addOperands(init);
+  result.addAttribute(getPermutationAttrName(result.name), permutation);
+  result.addAttributes(attributes);
+
+  // Add output types for `RankedTensorType` output arguments.
+  Type initType = init.getType();
+  if (initType.isa<RankedTensorType>())
+    result.addTypes(initType);
 
-  createRegion(odsBuilder, odsState);
+  // Not an ODS-generated builder, so the region must be added explicitly.
+  result.addRegion();
+  createGenericRegion(builder, result, input, init,
+                      [&](OpBuilder &b, Location loc, ValueRange args) {
+                        b.create<linalg::YieldOp>(loc, args[0]);
+                      });
 }
 
-void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
-                        ::mlir::OperationState &odsState, Value input,
-                        Value init, ArrayRef<int64_t> permutation,
+void TransposeOp::build(::mlir::OpBuilder &builder,
+                        ::mlir::OperationState &result, Value input, Value init,
+                        ArrayRef<int64_t> permutation,
                         ArrayRef<NamedAttribute> attributes) {
-  build(odsBuilder, odsState, input, init,
-        odsBuilder.getDenseI64ArrayAttr(permutation), attributes);
+  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
+        attributes);
 }
 
 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -1645,8 +1679,14 @@
           })))
     return failure();
 
-  OpBuilder opBuilder(parser.getContext());
-  createRegion(opBuilder, result);
+  OpBuilder builder(parser.getContext());
+  // Add the region before createGenericRegion populates it.
+  result.addRegion();
+  createGenericRegion(builder, result, /*inputs=*/result.operands,
+                      /*outputs=*/{},
+                      [&](OpBuilder &b, Location loc, ValueRange args) {
+                        b.create<linalg::YieldOp>(loc, args[0]);
+                      });
   return success();
 }
 