diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -18,11 +18,9 @@ #endif /// Apply the special region builder for the builtin named Linalg op. -/// The list of `capture` MlirValue is passed as-is to the region builder. /// Assert that `op` is a builtin named Linalg op. MLIR_CAPI_EXPORTED void -mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op, - intptr_t n, MlirValue const *mlirCaptures); +mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -49,7 +49,7 @@ kInplaceableAttrName = "linalg.inplaceable"; using RegionBuilderFunType = - llvm::function_ref; + llvm::function_ref; RegionBuilderFunType getRegionBuilder(StringRef name) { return namedStructuredOpRegionBuilders.lookup(name); } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -901,7 +901,7 @@ Returns a null function if this named op does not define a region builder. }], - /*retTy=*/"std::function", + /*retTy=*/"std::function", /*methodName=*/"getRegionBuilder", (ins), [{ return ConcreteOp::getRegionBuilder(); }] diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -153,10 +153,8 @@ Value getSource() { return input();} Value getTarget() { return output(); } - static void regionBuilder( - ImplicitLocOpBuilder &b, Block &block, ValueRange captures); - static std::function< - void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)> + static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); + static std::function getRegionBuilder() { return ®ionBuilder; } @@ -200,10 +198,8 @@ extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)}); } - static void regionBuilder( - ImplicitLocOpBuilder &b, Block &block, ValueRange captures); - static std::function< - void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)> + static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); + static std::function getRegionBuilder() { return ®ionBuilder; } @@ -291,8 +287,7 @@ return padding().getValue().getValue({i, 1}); } - static std::function< - void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)> + static std::function getRegionBuilder() { return nullptr; } @@ -533,8 +528,7 @@ library_call()->str() : "op_has_no_registered_library_name"; } - static std::function< - void(ImplicitLocOpBuilder &b, Block &block, ValueRange captures)> + static std::function getRegionBuilder() { return nullptr; } diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -21,15 +21,10 @@ void mlir::python::populateDialectLinalgSubmodule(py::module m) { m.def( "fill_builtin_region", - [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) { - llvm::SmallVector mlirOperands; - mlirOperands.reserve(captures.size()); - for (auto v : captures) - mlirOperands.push_back(py::cast(v)->get()); - mlirLinalgFillBuiltinNamedOpRegion( - dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data()); + [](PyDialectDescriptor &dialect, PyOperation &op) { + mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get()); }, - py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(), + py::arg("dialect"), py::arg("op"), "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); } diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -16,13 +16,8 @@ /// Apply the special region builder for the builtin named Linalg op. /// Assert that `op` is a builtin named Linalg op. void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, - MlirOperation mlirOp, intptr_t n, - MlirValue const *mlirCaptures) { + MlirOperation mlirOp) { Operation *op = unwrap(mlirOp); - SmallVector captures; - captures.reserve(n); - for (unsigned idx = 0; idx < n; ++idx) - captures.push_back(unwrap(mlirCaptures[idx])); LinalgDialect::RegionBuilderFunType fun = static_cast(unwrap(linalgDialect)) @@ -41,7 +36,7 @@ Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes); b.setInsertionPointToStart(body); - fun(b, *body, captures); + fun(b, *body); } MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -43,20 +43,19 @@ /// defined C++ ops. /// This is used by both builders and parsers. /// This function creates the block in the region with arguments corresponding -/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted -/// to be ShapedType. +/// to the elemental types of `inputTypes` and `outputTypes`. The latter are +/// asserted to be of ShapedType. template static void fillStructuredOpRegion( OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, - TypeRange outputTypes, ValueRange captures = {}, + TypeRange outputTypes, std::function errorHandler = nullptr); /// Generic entry point to create both the region and the block of a LinalgOp. template static void createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, - TypeRange inputTypes, TypeRange outputTypes, - ValueRange captures = {}); + TypeRange inputTypes, TypeRange outputTypes); /// Common parsing and printing used for both named structured ops created by /// ods-gen and by manually defined C++ ops. Does not handle regions. @@ -72,17 +71,15 @@ template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes, - ArrayRef captures = {}); + TypeRange inputTypes, TypeRange outputTypes); static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl &resultTypes); template -static ParseResult -parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, - ArrayRef captures = {}); +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result); static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes); @@ -323,8 +320,7 @@ //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// -void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ValueRange captures) { +void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args"); b.create(block.getArgument(0)); } @@ -403,8 +399,7 @@ //===----------------------------------------------------------------------===// // FillOp //===----------------------------------------------------------------------===// -void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block, - ValueRange captures) { +void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args"); b.create(block.getArgument(0)); } @@ -2799,7 +2794,6 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, - ValueRange captures, std::function errorHandler) { assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); @@ -2823,7 +2817,7 @@ opBuilder.setInsertionPointToStart(body); ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); - NamedStructuredOpType::regionBuilder(b, *body, captures); + NamedStructuredOpType::regionBuilder(b, *body); // indexing_maps is an auto-generated method. @@ -2835,11 +2829,10 @@ void createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes, - TypeRange outputTypes, - ValueRange captures) { + TypeRange outputTypes) { Region ®ion = *result.addRegion(); fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, captures, + opBuilder, region, inputTypes, outputTypes, [&](unsigned expected, unsigned actual) { assert(expected != actual && "incorrect number of arguments"); }); @@ -2902,15 +2895,14 @@ template static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes, - ArrayRef captures) { + TypeRange inputTypes, TypeRange outputTypes) { ParseResult res = success(); OpBuilder opBuilder(parser.getBuilder().getContext()); // Resolve `captures` into `capturedValues` at parse time so we can build the // region with captures. SmallVector capturedValues; fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, capturedValues, + opBuilder, region, inputTypes, outputTypes, [&](unsigned expected, unsigned actual) { res = parser.emitError( parser.getCurrentLocation(), @@ -2931,11 +2923,9 @@ } template -static ParseResult -parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, - ArrayRef captures) { +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result) { // TODO: Enable when ods-gen supports captures. - assert(captures.empty() && "unexpected captures for named structured ops"); SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) return failure(); @@ -2949,7 +2939,7 @@ std::unique_ptr region = std::make_unique(); if (parseNamedStructuredOpRegion( - parser, *region, inputTypes, outputTypes, captures)) + parser, *region, inputTypes, outputTypes)) return failure(); result.addRegion(std::move(region)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -63,8 +63,7 @@ iterators, [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { ImplicitLocOpBuilder b(loc, bodyBuilder); - regionBuilder(b, *bodyBuilder.getBlock(), - /*captures=*/{}); + regionBuilder(b, *bodyBuilder.getBlock()); }); } diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -33,7 +33,7 @@ ip=ip) OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") - fill_builtin_region(linalgDialect, self.operation, []) + fill_builtin_region(linalgDialect, self.operation) # TODO: self.result is None. When len(results) == 1 we expect it to be # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug # in the generator of _linalg_ops_gen.py where we have: diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -24,7 +24,7 @@ // IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 }); // // IMPL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block, ValueRange captures) { +// IMPL: Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); // IMPL: Value [[e:.*]] = b.create([[c]], [[d]]); @@ -49,7 +49,7 @@ // IMPL: AffineMap::get(3, 3, {d0, d1}, context) // // IMPL: Test2Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block, ValueRange captures) { +// IMPL: Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); // IMPL: Value [[e:.*]] = b.create([[c]], [[d]]); @@ -74,7 +74,7 @@ // IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context) // // IMPL: Test3Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block, ValueRange captures) { +// IMPL: Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); // IMPL: Value [[e:.*]] = b.create([[c]], [[d]]); @@ -182,7 +182,7 @@ // Test output arg order. // IMPL-LABEL: void Test8Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block, ValueRange captures) { +// IMPL: Block &block) { // IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); // IMPL: Value [[d:.*]] = b.create([[a]], [[b]]); // IMPL: Value [[e:.*]] = b.create([[d]], [[c]]); @@ -199,7 +199,7 @@ // IMPL: auto map1 = AffineMap::get(2, 2, {d1}, context); // IMPL: auto map2 = AffineMap::get(2, 2, {d0}, context); // IMPL-LABEL: void Test9Op::regionBuilder(ImplicitLocOpBuilder &b, -// IMPL: Block &block, ValueRange captures) { +// IMPL: Block &block) { // IMPL: Value [[a:.*]](args[0]), [[c:.*]](args[2]); ods_def: def test9(A: f32(M, K), B: f32(K)) -> (C: f32(M)) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -76,7 +76,7 @@ # ODS-NEXT: TypeRange(outputs) # IMPL-LABEL: void Test1Op::regionBuilder( -# IMPL: ImplicitLocOpBuilder &b, Block &block, ValueRange captures) +# IMPL: ImplicitLocOpBuilder &b, Block &block) # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); # IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); @@ -163,8 +163,7 @@ # IMPL: auto attr = op->getAttrOfType("strides") # IMPL: "missing indexing map required attribute 'strides'" -# IMPL: void Test2Op::regionBuilder( -# IMPL-NEXT: ImplicitLocOpBuilder &b, Block &block, ValueRange captures) +# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block) # IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 && # IMPL: yields.push_back(block.getArgument(0)); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1923,7 +1923,7 @@ $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs)); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, @@ -1941,7 +1941,7 @@ $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs)); }]>, OpBuilder< (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, @@ -1956,7 +1956,7 @@ ]; let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; let parser = [{{ - return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/); + return ::parseNamedStructuredOp<{0}>(parser, result); }]; let hasFolder = 1; @@ -1964,10 +1964,9 @@ // Auto-generated. ArrayAttr iterator_types(); ArrayAttr indexing_maps(); - static void regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ValueRange captures); - static std::function getRegionBuilder() {{ + static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); + static std::function + getRegionBuilder() {{ return regionBuilder; } @@ -2035,7 +2034,7 @@ $_builder, $_state, TypeRange(inputs), - TypeRange(outputs)/*, TODO: support captures*/); + TypeRange(outputs)); {2} }]> )FMT"; @@ -2354,8 +2353,7 @@ }; const char *regionBuilderFmt = R"FMT( - void {0}::regionBuilder(ImplicitLocOpBuilder &b, - Block &block, ValueRange captures) { + void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { auto args = block.getArguments(); Value {1}; {2} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -511,10 +511,8 @@ // Auto-generated. ArrayAttr iterator_types(); ArrayAttr indexing_maps(); - static void regionBuilder( - ImplicitLocOpBuilder &b, Block &block, ValueRange captures); - static std::function< - void(ImplicitLocOpBuilder &b, Block &, ValueRange)> + static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); + static std::function getRegionBuilder() {{ return regionBuilder; } @@ -883,8 +881,7 @@ // {1}: Number of args // {2}: Statements static const char structuredOpRegionBuilderFormat[] = R"FMT( -void {0}::regionBuilder( - ImplicitLocOpBuilder &b, Block &block, ValueRange captures) {{ +void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{ assert({1} > 0 && block.getNumArguments() == {1} && "{0} regionBuilder expects {1} (>=0) args"); RegionBuilderHelper helper(block.getArgument(0).getContext(), block);