diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -536,20 +536,6 @@ } }]; - let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &result, " - "ArrayRef resultTypes, ValueRange args, " - "int64_t inputCount, int64_t outputCount, " - "ArrayRef indexingMaps, " - "ArrayRef iteratorTypes", [{ - return build(builder, result, resultTypes, args, - builder.getI64IntegerAttr(inputCount), - builder.getI64IntegerAttr(outputCount), - builder.getAffineMapArrayAttr(indexingMaps), - builder.getStrArrayAttr(iteratorTypes), - /*doc=*/nullptr, /*library_call=*/nullptr); - }]>]; - let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseGenericOp(parser, result); }]; } @@ -665,6 +651,14 @@ future. }]; + let builders = [ + OpBuilder< + "OpBuilder &builder, OperationState &result, ArrayRef resultTypes, " + "ValueRange args, int64_t inputCount, int64_t outputCount, " + "ArrayRef indexingMaps, ArrayRef iteratorTypes, " + "function_ref = nullptr"> /* Declaration-only builder: unlike the builder removed in the first hunk, no inline [{ ... }] body is given here. The trailing function_ref parameter defaults to nullptr. NOTE(review): the template argument lists in these signature strings (after ArrayRef and function_ref) appear to have been stripped from this text; confirm against the original patch. */ + ]; + + let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; @@ -791,6 +785,16 @@ future. 
}]; + let builders = [ + OpBuilder< + "OpBuilder &builder, OperationState &result, ArrayRef resultTypes, " + "ValueRange args, int64_t inputCount, int64_t outputCount, " + "ArrayRef indexingMaps, ArrayRef iteratorTypes, " + "function_ref " + "= nullptr"> /* Declaration-only builder (no inline body); the trailing function_ref parameter defaults to nullptr. */ + ]; + + let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -70,6 +70,62 @@ // GenericOps //===----------------------------------------------------------------------===// +void GenericOp::build( + OpBuilder &builder, OperationState &result, ArrayRef resultTypes, + ValueRange args, int64_t inputCount, int64_t outputCount, + ArrayRef indexingMaps, ArrayRef iteratorTypes, + function_ref bodyBuild) { /* Forwards to another build overload, converting inputCount/outputCount with getI64IntegerAttr, indexingMaps with getAffineMapArrayAttr and iteratorTypes with getStrArrayAttr, and passing nullptr for doc and library_call. If bodyBuild is non-null it then creates and fills the body block. NOTE(review): template arguments (on ArrayRef, function_ref and cast) appear to have been stripped from this text; confirm against the original patch. */ + build(builder, result, resultTypes, args, + builder.getI64IntegerAttr(inputCount), + builder.getI64IntegerAttr(outputCount), + builder.getAffineMapArrayAttr(indexingMaps), + builder.getStrArrayAttr(iteratorTypes), + /*doc=*/nullptr, /*library_call=*/nullptr); + if (bodyBuild) { /* A body block is only created when a callback was supplied. */ + Region &bodyRegion = *result.regions.front(); + bodyRegion.push_back(new Block); + Block &bodyBlock = bodyRegion.front(); /* front() is the block just pushed only if the region was empty beforehand - presumably the case here; verify. */ + + for (Value arg : args) + bodyBlock.addArgument(arg.getType().cast().getElementType()); /* One block argument per value in args, typed with the element type of that value's type. */ + + OpBuilder::InsertionGuard guard(builder); /* Scoped guard; presumably restores the builder's insertion point on scope exit - confirm against the OpBuilder documentation. */ + builder.setInsertionPointToStart(&bodyBlock); + bodyBuild(builder, result.location, bodyBlock.getArguments()); /* The callback is given the builder (positioned at the start of bodyBlock), the location stored in result, and all block arguments. */ + } +} + +void IndexedGenericOp::build( + OpBuilder &builder, OperationState &result, ArrayRef resultTypes, + ValueRange args, int64_t inputCount, int64_t outputCount, + ArrayRef indexingMaps, ArrayRef iteratorTypes, + function_ref + bodyBuild) { /* Same attribute-wrapping delegation as GenericOp::build above; the body block additionally receives leading index-typed arguments before the per-operand ones. */ + build(builder, result, resultTypes, args, + builder.getI64IntegerAttr(inputCount), + builder.getI64IntegerAttr(outputCount), + builder.getAffineMapArrayAttr(indexingMaps), + builder.getStrArrayAttr(iteratorTypes), + /*doc=*/nullptr, 
/*library_call=*/nullptr); + if (bodyBuild) { /* A body block is only created when a callback was supplied. */ + Region &bodyRegion = *result.regions.front(); + bodyRegion.push_back(new Block); + Block &bodyBlock = bodyRegion.front(); + + unsigned nLoops = iteratorTypes.size(); + for (unsigned i = 0; i < nLoops; ++i) + bodyBlock.addArgument(builder.getIndexType()); /* First nLoops block arguments: one index-typed argument per entry of iteratorTypes. */ + for (Value arg : args) + bodyBlock.addArgument(arg.getType().cast().getElementType()); /* Then one argument per value in args, typed with the element type of that value's type. */ + + OpBuilder::InsertionGuard guard(builder); /* Scoped guard; presumably restores the builder's insertion point on scope exit - confirm against the OpBuilder documentation. */ + builder.setInsertionPointToStart(&bodyBlock); + bodyBuild(builder, result.location, + bodyBlock.getArguments().take_front(nLoops), + bodyBlock.getArguments().drop_front(nLoops)); /* The callback receives the leading nLoops arguments and the remaining arguments as two separate ranges. */ + } +} + template static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { auto attrNames = op.linalgTraitAttrNames();