diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -508,20 +508,6 @@
     }
   }];
 
-  let builders = [
-    OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "ArrayRef<Type> resultTypes, ValueRange args, "
-              "int64_t inputCount, int64_t outputCount, "
-              "ArrayRef<AffineMap> indexingMaps, "
-              "ArrayRef<StringRef> iteratorTypes", [{
-      return build(builder, result, resultTypes, args,
-                   builder.getI64IntegerAttr(inputCount),
-                   builder.getI64IntegerAttr(outputCount),
-                   builder.getAffineMapArrayAttr(indexingMaps),
-                   builder.getStrArrayAttr(iteratorTypes),
-                   /*doc=*/nullptr, /*library_call=*/nullptr);
-    }]>];
-
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parseGenericOp(parser, result); }];
 }
@@ -637,6 +623,16 @@
     future.
   }];
 
+  // Builder from plain C++ values. If the callback is non-null, it is invoked
+  // with one block argument per operand to populate the body.
+  let builders = [
+    OpBuilder<
+      "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
+      "ValueRange args, int64_t inputCount, int64_t outputCount, "
+      "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+      "function_ref<void(OpBuilder &, Location, ValueRange)> = nullptr">
+  ];
+
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
@@ -763,6 +759,17 @@
     future.
   }];
 
+  // Builder from plain C++ values. If the callback is non-null, it is invoked
+  // with the loop-index block arguments and the operand block arguments.
+  let builders = [
+    OpBuilder<
+      "OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes, "
+      "ValueRange args, int64_t inputCount, int64_t outputCount, "
+      "ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, "
+      "function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> "
+      "= nullptr">
+  ];
+
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -70,6 +70,76 @@
 // GenericOps
 //===----------------------------------------------------------------------===//
 
+/// Builds a GenericOp from plain C++ values. The operand counts, indexing
+/// maps and iterator types are wrapped into the corresponding attributes;
+/// `doc` and `library_call` are left unset. When `bodyBuild` is provided, a
+/// body block is created with one argument per value in `args`, typed with
+/// that value's element type (each value is cast to ShapedType), and
+/// `bodyBuild` is called with the block arguments to populate it.
+void GenericOp::build(
+    OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
+    ValueRange args, int64_t inputCount, int64_t outputCount,
+    ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
+  build(builder, result, resultTypes, args,
+        builder.getI64IntegerAttr(inputCount),
+        builder.getI64IntegerAttr(outputCount),
+        builder.getAffineMapArrayAttr(indexingMaps),
+        builder.getStrArrayAttr(iteratorTypes),
+        /*doc=*/nullptr, /*library_call=*/nullptr);
+  if (!bodyBuild)
+    return;
+
+  SmallVector<Type, 4> blockArgTypes;
+  for (Value arg : args)
+    blockArgTypes.push_back(arg.getType().cast<ShapedType>().getElementType());
+
+  // createBlock moves the insertion point into the new block; the guard
+  // restores the caller's insertion point when this function returns.
+  OpBuilder::InsertionGuard guard(builder);
+  auto &region = *result.regions.front();
+  Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
+  bodyBuild(builder, result.location, bodyBlock->getArguments());
+}
+
+/// Builds an IndexedGenericOp from plain C++ values, wrapping the operand
+/// counts, indexing maps and iterator types into attributes and leaving
+/// `doc` and `library_call` unset. When `bodyBuild` is provided, a body block
+/// is created whose arguments are one `index` value per loop (one per entry
+/// of `iteratorTypes`) followed by the element type of each value in `args`;
+/// `bodyBuild` receives the loop indices and the element arguments as two
+/// separate ranges.
+void IndexedGenericOp::build(
+    OpBuilder &builder, OperationState &result, ArrayRef<Type> resultTypes,
+    ValueRange args, int64_t inputCount, int64_t outputCount,
+    ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes,
+    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
+        bodyBuild) {
+  build(builder, result, resultTypes, args,
+        builder.getI64IntegerAttr(inputCount),
+        builder.getI64IntegerAttr(outputCount),
+        builder.getAffineMapArrayAttr(indexingMaps),
+        builder.getStrArrayAttr(iteratorTypes),
+        /*doc=*/nullptr, /*library_call=*/nullptr);
+  if (!bodyBuild)
+    return;
+
+  // Leading block arguments are the loop indices, one per iterator.
+  unsigned nLoops = iteratorTypes.size();
+  SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
+  for (Value arg : args)
+    blockArgTypes.push_back(arg.getType().cast<ShapedType>().getElementType());
+
+  // createBlock moves the insertion point into the new block; the guard
+  // restores the caller's insertion point when this function returns.
+  OpBuilder::InsertionGuard guard(builder);
+  auto &region = *result.regions.front();
+  Block *bodyBlock = builder.createBlock(&region, region.end(), blockArgTypes);
+  bodyBuild(builder, result.location,
+            bodyBlock->getArguments().take_front(nLoops),
+            bodyBlock->getArguments().drop_front(nLoops));
+}
+
 template <typename GenericOpType>
 static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
   auto attrNames = op.linalgTraitAttrNames();