diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -36,57 +36,194 @@ using namespace mlir; using namespace mlir::linalg; -/// Forward declarations. - -/// Generic entry point to create the block for the region of a LinalgOp. -/// This is used by both named structured ops created by ods-gen and by manually -/// defined C++ ops. -/// This is used by both builders and parsers. -/// This function creates the block in the region with arguments corresponding -/// to the elemental types of `inputTypes` and `outputTypes`. The latter are -/// asserted to be of ShapedType. -template -static void fillStructuredOpRegion( - OpBuilder &opBuilder, Region &region, TypeRange inputTypes, - TypeRange outputTypes, ArrayRef attrs, - llvm::function_ref errorHandler = nullptr); - -/// Generic entry point to create both the region and the block of a LinalgOp. -template -static void -createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, - TypeRange inputTypes, TypeRange outputTypes); - -/// Common parsing and printing used for both named structured ops created by -/// ods-gen and by manually defined C++ ops. Does not handle regions. +//===----------------------------------------------------------------------===// +// Support for named Linalg ops defined in ods-gen. +//===----------------------------------------------------------------------===// + +using RegionBuilderFn = llvm::function_ref)>; + +/// Fills the region of a structured operation using the provided +/// `regionBuilder`. The method is used by both named structured ops created by +/// ods-gen and by manually defined C++ ops. It is called by both builders and +/// parsers and creates a block with arguments corresponding to the elemental +/// types of `inputTypes` and `outputTypes`. All output types are asserted to be +/// ShapedType. 
+static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, + TypeRange inputTypes, TypeRange outputTypes, + ArrayRef attrs, + RegionBuilderFn regionBuilder) { + assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); + + // TODO: atm all operands go through getElementTypeOrSelf, + // reconsider when we have evidence we need to. + SmallVector argTypes; + SmallVector argLocs; + for (auto containers : {inputTypes, outputTypes}) { + for (auto t : containers) { + argTypes.push_back(getElementTypeOrSelf(t)); + + // TODO: Pass in a proper location here. + argLocs.push_back(opBuilder.getUnknownLoc()); + } + } + + // RAII. + OpBuilder::InsertionGuard guard(opBuilder); + Block *body = + opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs); + + opBuilder.setInsertionPointToStart(body); + ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); + regionBuilder(b, *body, attrs); + + // indexing_maps is an auto-generated method. + + // iterator_types is an auto-generated method. +} + +/// Create the region and fill the block of a structured operation given +/// `inputTypes` and `outputTypes` as well as a `regionBuilder`. +void createAndFillStructuredOpRegion(OpBuilder &opBuilder, + OperationState &result, + TypeRange inputTypes, + TypeRange outputTypes, + RegionBuilderFn regionBuilder) { + Region &region = *result.addRegion(); + fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, + result.attributes.getAttrs(), regionBuilder); +} + +/// Common parsing used for both named structured ops created by ods-gen and by +/// manually defined C++ ops. Does not handle regions. 
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, - SmallVectorImpl &outputTypes); -template -static void printCommonStructuredOpParts(OpAsmPrinter &p, - NamedStructuredOpType op); + SmallVectorImpl &outputTypes) { + SMLoc inputsOperandsLoc, outputsOperandsLoc; + SmallVector inputsOperands, + outputsOperands; -/// Specific parsing and printing for named structured ops created by ods-gen. -template -static ParseResult -parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes, - ArrayRef attrs); + parser.parseOptionalAttrDict(result.attributes); + + if (succeeded(parser.parseOptionalKeyword("ins"))) { + if (parser.parseLParen()) + return failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands) || + parser.parseColonTypeList(inputTypes) || parser.parseRParen()) + return failure(); + } + + if (succeeded(parser.parseOptionalKeyword("outs"))) { + outputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || + parser.parseColonTypeList(outputTypes) || parser.parseRParen()) + return failure(); + } + + if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, + result.operands) || + parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, + result.operands)) + return failure(); + + result.addAttribute("operand_segment_sizes", + parser.getBuilder().getI32VectorAttr( + {static_cast(inputsOperands.size()), + static_cast(outputsOperands.size())})); + return success(); +} + +static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, + ValueRange outputs) { + if (!inputs.empty()) + p << " ins(" << inputs << " : " << inputs.getTypes() << ")"; + if (!outputs.empty()) + p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; +} + 
+//===----------------------------------------------------------------------===// +// Specific parsing and printing for named structured ops created by ods-gen. +//===----------------------------------------------------------------------===// + +static ParseResult parseNamedStructuredOpRegion( + OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, + TypeRange inputTypes, TypeRange outputTypes, ArrayRef attrs, + RegionBuilderFn regionBuilder) { + if (numRegionArgs != inputTypes.size() + outputTypes.size()) { + return parser.emitError( + parser.getCurrentLocation(), + llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " + "region expects {0} args, got {1}", + numRegionArgs, inputTypes.size() + outputTypes.size())); + } + + OpBuilder opBuilder(parser.getContext()); + fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs, + regionBuilder); + return success(); +} static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, - SmallVectorImpl &resultTypes); + SmallVectorImpl &resultTypes) { + if (parser.parseOptionalArrowTypeList(resultTypes)) + return failure(); + return success(); +} -template static ParseResult parseNamedStructuredOp(OpAsmParser &parser, - OperationState &result); + OperationState &result, + unsigned numRegionArgs, + RegionBuilderFn regionBuilder) { + // TODO: Enable when ods-gen supports captures. + SmallVector inputTypes, outputTypes; + if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) + return failure(); + + // TODO: consider merging results parsing into region parsing. + // Need to wait for declarative assembly resolution to decide. 
+ SmallVector outputTensorsTypes; + if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) + return failure(); + result.addTypes(outputTensorsTypes); + + std::unique_ptr region = std::make_unique(); + if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes, + outputTypes, result.attributes.getAttrs(), + regionBuilder)) + return failure(); + result.addRegion(std::move(region)); + + return success(); +} static void printNamedStructuredOpResults(OpAsmPrinter &p, - TypeRange resultTypes); + TypeRange resultTypes) { + if (resultTypes.empty()) + return; + p.printOptionalArrowTypeList(resultTypes); +} + +static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, + ValueRange inputs, ValueRange outputs) { + p.printOptionalAttrDict( + op->getAttrs(), + /*elidedAttrs=*/{"operand_segment_sizes", + // See generated code in mlir-linalg-yaml-gen.cpp + "linalg.memoized_indexing_maps"}); -template -static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); + // Printing is shared with generic ops, except for the region and + // attributes. + printCommonStructuredOpParts(p, inputs, outputs); + + // Results printing. + printNamedStructuredOpResults(p, op->getResultTypes()); + + // Region is elided. +} /// This is a common class used for patterns of the form /// ``` @@ -590,7 +727,7 @@ } // Printing is shared with named ops, except for the region and attributes - printCommonStructuredOpParts(p, *this); + printCommonStructuredOpParts(p, inputs(), outputs()); genericAttrNames.push_back("operand_segment_sizes"); genericAttrNamesSet.insert(genericAttrNames.back()); @@ -682,12 +819,7 @@ outputBuffers); } -template -static LogicalResult verifyGenericOp(GenericOpType op) { - return success(); -} - -LogicalResult GenericOp::verify() { return verifyGenericOp(*this); } +LogicalResult GenericOp::verify() { return success(); } namespace { // Deduplicate redundant args of a linalg generic op. 
@@ -1365,213 +1497,6 @@ return ss.str(); } -//===----------------------------------------------------------------------===// -// Support for named Linalg ops defined in ods-gen. -//===----------------------------------------------------------------------===// - -/// Generic entry point to create the block for the region of a LinalgOp. -/// This is used by both named structured ops created by ods-gen and by manually -/// defined C++ ops. -/// This is used by both builders and parsers. -/// This function creates the block in the region with arguments corresponding -/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted -/// to be ShapedType. -template -static void fillStructuredOpRegion( - OpBuilder &opBuilder, Region &region, TypeRange inputTypes, - TypeRange outputTypes, ArrayRef attrs, - llvm::function_ref errorHandler) { - assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); - - // TODO: atm all operands go through getElementTypeOrSelf, - // reconsider when we have evidence we need to. - SmallVector argTypes; - SmallVector argLocs; - for (auto containers : {inputTypes, outputTypes}) { - for (auto t : containers) { - argTypes.push_back(getElementTypeOrSelf(t)); - - // TODO: Pass in a proper location here. - argLocs.push_back(opBuilder.getUnknownLoc()); - } - } - - // RAII. - OpBuilder::InsertionGuard guard(opBuilder); - Block *body = - opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs); - unsigned actual = body->getNumArguments(); - unsigned expected = NamedStructuredOpType::getNumRegionArgs(); - if (expected != actual) { - if (errorHandler) - errorHandler(expected, actual); - return; - } - - opBuilder.setInsertionPointToStart(body); - ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); - NamedStructuredOpType::regionBuilder(b, *body, attrs); - - // indexing_maps is an auto-generated method. - - // iterator_types is an auto-generated method. 
-} - -/// Generic entry point to create both the region and the block of a LinalgOp. -template -void createAndFillStructuredOpRegion(OpBuilder &opBuilder, - OperationState &result, - TypeRange inputTypes, - TypeRange outputTypes) { - Region &region = *result.addRegion(); - fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, result.attributes.getAttrs(), - [&](unsigned expected, unsigned actual) { - assert(expected != actual && "incorrect number of arguments"); - }); -} - -/// Common parsing used for both named structured ops created by ods-gen and by -/// manually defined C++ ops. Does not handle regions. -static ParseResult -parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, - SmallVectorImpl &inputTypes, - SmallVectorImpl &outputTypes) { - SMLoc inputsOperandsLoc, outputsOperandsLoc; - SmallVector inputsOperands, - outputsOperands; - - parser.parseOptionalAttrDict(result.attributes); - - if (succeeded(parser.parseOptionalKeyword("ins"))) { - if (parser.parseLParen()) - return failure(); - - inputsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(inputsOperands) || - parser.parseColonTypeList(inputTypes) || parser.parseRParen()) - return failure(); - } - - if (succeeded(parser.parseOptionalKeyword("outs"))) { - outputsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || - parser.parseColonTypeList(outputTypes) || parser.parseRParen()) - return failure(); - } - - if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, - result.operands) || - parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, - result.operands)) - return failure(); - - result.addAttribute("operand_segment_sizes", - parser.getBuilder().getI32VectorAttr( - {static_cast(inputsOperands.size()), - static_cast(outputsOperands.size())})); - return success(); -} - -template -static void printCommonStructuredOpParts(OpAsmPrinter &p, - 
NamedStructuredOpType op) { - if (!op.inputs().empty()) - p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; - if (!op.outputs().empty()) - p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; -} - -//===----------------------------------------------------------------------===// -// Specific parsing and printing for named structured ops created by ods-gen. -//===----------------------------------------------------------------------===// - -template -static ParseResult -parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, - TypeRange inputTypes, TypeRange outputTypes, - ArrayRef attrs) { - ParseResult res = success(); - OpBuilder opBuilder(parser.getContext()); - // Resolve `captures` into `capturedValues` at parse time so we can build the - // region with captures. - SmallVector capturedValues; - fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, attrs, - [&](unsigned expected, unsigned actual) { - res = parser.emitError( - parser.getCurrentLocation(), - llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " - "region expects {0} args, got {1}", - expected, actual)); - region.front().dump(); - }); - return res; -} - -static ParseResult -parseNamedStructuredOpResults(OpAsmParser &parser, - SmallVectorImpl &resultTypes) { - if (parser.parseOptionalArrowTypeList(resultTypes)) - return failure(); - return success(); -} - -template -static ParseResult parseNamedStructuredOp(OpAsmParser &parser, - OperationState &result) { - // TODO: Enable when ods-gen supports captures. - SmallVector inputTypes, outputTypes; - if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) - return failure(); - - // TODO: consider merging results parsing into region parsing. - // Need to wait for declarative assembly resolution to decide. 
- SmallVector outputTensorsTypes; - if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) - return failure(); - result.addTypes(outputTensorsTypes); - - std::unique_ptr region = std::make_unique(); - if (parseNamedStructuredOpRegion( - parser, *region, inputTypes, outputTypes, - result.attributes.getAttrs())) - return failure(); - result.addRegion(std::move(region)); - - return success(); -} - -static void printNamedStructuredOpResults(OpAsmPrinter &p, - TypeRange resultTypes) { - if (resultTypes.empty()) - return; - p.printOptionalArrowTypeList(resultTypes); -} - -template -static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { - p.printOptionalAttrDict( - op->getAttrs(), - /*elidedAttrs=*/{"operand_segment_sizes", - // See generated code in mlir-linalg-yaml-gen.cpp - "linalg.memoized_indexing_maps"}); - - // Printing is shared with generic ops, except for the region and - // attributes. - printCommonStructuredOpParts(p, op); - - // Results printing. - printNamedStructuredOpResults(p, op.result_tensors().getTypes()); - - // Region is elided. -} - -template -static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { - return verifyGenericOp(op); -} - //===----------------------------------------------------------------------===// // Canonicalizers and Folders. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -91,11 +91,13 @@ # ODS-NEXT: $_builder.getI32VectorAttr({ # ODS-NEXT: static_cast(inputs.size()), # ODS-NEXT: static_cast(outputs.size())})); -# ODS-NEXT: createAndFillStructuredOpRegion( +# ODS-NEXT: createAndFillStructuredOpRegion( # ODS-NEXT: $_builder, # ODS-NEXT: $_state, # ODS-NEXT: TypeRange(inputs), -# ODS-NEXT: TypeRange(outputs) +# ODS-NEXT: TypeRange(outputs), +# ODS-NEXT: Test1Op::getRegionBuilder() + # IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, # IMPL-NEXT: Block &block, ArrayRef attrs) diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -537,11 +537,12 @@ static_cast(inputs.size()), static_cast(outputs.size())})); $_state.addAttributes(attributes); - createAndFillStructuredOpRegion<{0}>( + createAndFillStructuredOpRegion( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)); + TypeRange(outputs), + {0}::getRegionBuilder()); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, @@ -557,11 +558,12 @@ $_builder.getI32VectorAttr({{ static_cast(inputs.size()), static_cast(outputs.size())})); - createAndFillStructuredOpRegion<{0}>( + createAndFillStructuredOpRegion( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)); + TypeRange(outputs), + {0}::getRegionBuilder()); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, @@ -618,11 +620,12 @@ $_builder.getI32VectorAttr({{ static_cast(inputs.size()), 
static_cast(outputs.size())})); - createAndFillStructuredOpRegion<{0}>( + createAndFillStructuredOpRegion( $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)); + TypeRange(outputs), + {0}::getRegionBuilder()); }]> )FMT"; @@ -705,10 +708,11 @@ // {0}: Class name static const char structuredOpParserFormat[] = R"FMT( ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{ - return ::parseNamedStructuredOp<{0}>(parser, result); + return ::parseNamedStructuredOp(parser, result, + {0}::getNumRegionArgs(), {0}::getRegionBuilder()); } void {0}::print(OpAsmPrinter &p) {{ - ::printNamedStructuredOp(p, *this); + ::printNamedStructuredOp(p, getOperation(), inputs(), outputs()); } )FMT";