diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -47,8 +47,8 @@ constexpr const static ::llvm::StringLiteral kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps"; - using RegionBuilderFunType = - llvm::function_ref; + using RegionBuilderFunType = llvm::function_ref< + void(ImplicitLocOpBuilder &b, Block &, ArrayRef)>; RegionBuilderFunType getRegionBuilder(StringRef name) { return namedStructuredOpRegionBuilders.lookup(name); } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1025,7 +1025,7 @@ Returns a null function if this named op does not define a region builder. }], - /*retTy=*/"std::function", + /*retTy=*/"std::function)>", /*methodName=*/"getRegionBuilder", (ins), [{ return ConcreteOp::getRegionBuilder(); }] diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -83,8 +83,10 @@ extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } - static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); - static std::function + static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, + ArrayRef attrs); + static std::function)> getRegionBuilder() { return &regionBuilder; } @@ -254,7 +256,8 @@ library_call()->str() : "op_has_no_registered_library_name"; } - static std::function + static std::function)> getRegionBuilder() { return nullptr; } diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp --- 
a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -38,7 +38,7 @@ Region &region = op->getRegion(0); Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs); b.setInsertionPointToStart(body); - fun(b, *body); + fun(b, *body, op->getAttrs()); } MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -49,7 +49,7 @@ template static void fillStructuredOpRegion( OpBuilder &opBuilder, Region &region, TypeRange inputTypes, - TypeRange outputTypes, + TypeRange outputTypes, ArrayRef attrs, llvm::function_ref errorHandler = nullptr); /// Generic entry point to create both the region and the block of a LinalgOp. @@ -72,7 +72,8 @@ template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, - TypeRange inputTypes, TypeRange outputTypes); + TypeRange inputTypes, TypeRange outputTypes, + ArrayRef attrs); static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, @@ -375,7 +376,8 @@ //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// -void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { +void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, + ArrayRef attrs) { assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args"); b.create(block.getArgument(0)); } @@ -384,16 +386,16 @@ Value output) { build(builder, result, output.getType().dyn_cast(), value, output); - fillStructuredOpRegion(builder, *result.regions.front(), - TypeRange{value.getType()}, - TypeRange{output.getType()}, {}); + fillStructuredOpRegion( + builder, *result.regions.front(), TypeRange{value.getType()}, + TypeRange{output.getType()}, result.attributes.getAttrs(), {}); } 
ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType, Type outputType) { OpBuilder opBuilder(parser.getContext()); fillStructuredOpRegion(opBuilder, r, TypeRange{valueType}, - TypeRange{outputType}); + TypeRange{outputType}, {}); return success(); } @@ -1820,7 +1822,7 @@ template static void fillStructuredOpRegion( OpBuilder &opBuilder, Region &region, TypeRange inputTypes, - TypeRange outputTypes, + TypeRange outputTypes, ArrayRef attrs, llvm::function_ref errorHandler) { assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); @@ -1851,7 +1853,7 @@ opBuilder.setInsertionPointToStart(body); ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); - NamedStructuredOpType::regionBuilder(b, *body); + NamedStructuredOpType::regionBuilder(b, *body, attrs); // indexing_maps is an auto-generated method. @@ -1866,7 +1868,7 @@ TypeRange outputTypes) { Region &region = *result.addRegion(); fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, + opBuilder, region, inputTypes, outputTypes, result.attributes.getAttrs(), [&](unsigned expected, unsigned actual) { assert(expected != actual && "incorrect number of arguments"); }); @@ -1929,14 +1931,15 @@ template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, - TypeRange inputTypes, TypeRange outputTypes) { + TypeRange inputTypes, TypeRange outputTypes, + ArrayRef attrs) { ParseResult res = success(); OpBuilder opBuilder(parser.getContext()); // Resolve `captures` into `capturedValues` at parse time so we can build the // region with captures. 
SmallVector capturedValues; fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, + opBuilder, region, inputTypes, outputTypes, attrs, [&](unsigned expected, unsigned actual) { res = parser.emitError( parser.getCurrentLocation(), @@ -1973,7 +1976,8 @@ std::unique_ptr region = std::make_unique(); if (parseNamedStructuredOpRegion( - parser, *region, inputTypes, outputTypes)) + parser, *region, inputTypes, outputTypes, + result.attributes.getAttrs())) return failure(); result.addRegion(std::move(region)); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2555,11 +2555,13 @@ let extraClassDeclaration = [{ bool hasIndexSemantics() { return false; } - static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block) { + static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block, + mlir::ArrayRef attrs) { b.create(block.getArguments().back()); } - static std::function + static std::function)> getRegionBuilder() { return &regionBuilder; } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -83,8 +83,8 @@ # ODS-NEXT: TypeRange(inputs), # ODS-NEXT: TypeRange(outputs) -# IMPL-LABEL: void Test1Op::regionBuilder( -# IMPL: ImplicitLocOpBuilder &b, Block &block) +# IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, +# IMPL-NEXT: Block &block, ArrayRef attrs) # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); # IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); @@ -174,7 +174,8 @@ # IMPL: auto attr = op->getAttrOfType("strides") # IMPL: "incorrect element type for 
index attribute 'strides'" # IMPL: "incorrect shape for index attribute 'strides'" -# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block) +# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, +# IMPL-NEXT: Block &block, ArrayRef attrs) # IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 && # IMPL: yields.push_back(block.getArgument(0)); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -523,8 +523,10 @@ // Auto-generated. ArrayAttr iterator_types(); ArrayAttr indexing_maps(); - static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); - static std::function + static void regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ArrayRef attrs); + static std::function)> getRegionBuilder() {{ return regionBuilder; } @@ -952,7 +954,8 @@ // {1}: Number of args // {2}: Statements static const char structuredOpRegionBuilderFormat[] = R"FMT( -void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{ +void {0}::regionBuilder(ImplicitLocOpBuilder &b, + Block &block, ArrayRef attrs) {{ assert({1} > 0 && block.getNumArguments() == {1} && "{0} regionBuilder expects {1} (>=0) args"); RegionBuilderHelper helper(block.getArgument(0).getContext(), block);