diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -194,6 +194,31 @@ return mlirStringRefCreate(s.data(), s.size()); } +/// Create a block, using the current location context if no locations are +/// specified. +static MlirBlock createBlock(const py::sequence &pyArgTypes, + const std::optional &pyArgLocs) { + SmallVector argTypes; + argTypes.reserve(pyArgTypes.size()); + for (const auto &pyType : pyArgTypes) + argTypes.push_back(pyType.cast()); + + SmallVector argLocs; + if (pyArgLocs) { + argLocs.reserve(pyArgLocs->size()); + for (const auto &pyLoc : *pyArgLocs) + argLocs.push_back(pyLoc.cast()); + } else if (!argTypes.empty()) { + argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); + } + + if (argTypes.size() != argLocs.size()) + throw py::value_error(("Expected " + Twine(argTypes.size()) + + " locations, got: " + Twine(argLocs.size())) + .str()); + return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); +} + /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } @@ -363,21 +388,10 @@ throw py::index_error("attempt to access out of bounds block"); } - PyBlock appendBlock(const py::args &pyArgTypes) { + PyBlock appendBlock(const py::args &pyArgTypes, + const std::optional &pyArgLocs) { operation->checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - // TODO: Pass in a proper location here. - argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - - MlirBlock block = - mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionAppendOwnedBlock(region, block); return PyBlock(operation, block); } @@ -387,7 +401,8 @@ .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) .def("__len__", &PyBlockList::dunderLen) - .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); + .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, + py::arg("arg_locs") = std::nullopt); } private: @@ -2978,27 +2993,17 @@ "Returns a forward-optimized sequence of operations.") .def_static( "create_at_start", - [](PyRegion &parent, py::list pyArgTypes) { + [](PyRegion &parent, const py::list &pyArgTypes, + const std::optional &pyArgLocs) { parent.checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - // TODO: Pass in a proper location here. - argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), - argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, py::arg("parent"), py::arg("arg_types") = py::list(), + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block at the beginning of the given " - "region (with given argument types).") + "region (with given argument types and locations).") .def( "append_to", [](PyBlock &self, PyRegion ®ion) { @@ -3010,50 +3015,30 @@ "Append this block to a region, transferring ownership if necessary") .def( "create_before", - [](PyBlock &self, py::args pyArgTypes) { + [](PyBlock &self, const py::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - // TODO: Pass in a proper location here. - argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), - argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block before this block " - "(with given argument types).") + "(with given argument types and locations).") .def( "create_after", - [](PyBlock &self, py::args pyArgTypes) { + [](PyBlock &self, const py::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - - // TODO: Pass in a proper location here. - argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), - argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block after this block " - "(with given argument types).") + "(with given argument types and locations).") .def( "__iter__", [](PyBlock &self) { diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -90,7 +90,7 @@ raise IndexError('External function does not have a body') return self.regions[0].blocks[0] - def add_entry_block(self): + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): """ Add an entry block to the function body using the function signature to infer block arguments. @@ -98,7 +98,7 @@ """ if not self.is_external: raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs) + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) return self.body.blocks[0] @property diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py --- a/mlir/test/python/ir/blocks.py +++ b/mlir/test/python/ir/blocks.py @@ -18,28 +18,28 @@ # CHECK-LABEL: TEST: testBlockCreation -# CHECK: func @test(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i16) +# CHECK: func @test(%[[ARG0:.*]]: i32 loc("arg0"), %[[ARG1:.*]]: i16 loc("arg1")) # CHECK: cf.br ^bb1(%[[ARG1]] : i16) -# CHECK: ^bb1(%[[PHI0:.*]]: i16): +# CHECK: ^bb1(%[[PHI0:.*]]: i16 loc("middle")): # CHECK: cf.br ^bb2(%[[ARG0]] : i32) -# CHECK: ^bb2(%[[PHI1:.*]]: i32): +# CHECK: ^bb2(%[[PHI1:.*]]: i32 loc("successor")): # CHECK: return @run def testBlockCreation(): with Context() as ctx, Location.unknown(): - module = Module.create() + module = builtin.ModuleOp() with InsertionPoint(module.body): f_type = FunctionType.get( [IntegerType.get_signless(32), IntegerType.get_signless(16)], []) f_op = func.FuncOp("test", f_type) - entry_block = f_op.add_entry_block() + entry_block = f_op.add_entry_block([Location.name("arg0"), Location.name("arg1")]) i32_arg, i16_arg = entry_block.arguments - successor_block = entry_block.create_after(i32_arg.type) + successor_block = entry_block.create_after(i32_arg.type, arg_locs=[Location.name("successor")]) with InsertionPoint(successor_block) as successor_ip: assert successor_ip.block == successor_block func.ReturnOp([]) - middle_block = successor_block.create_before(i16_arg.type) + middle_block = successor_block.create_before(i16_arg.type, arg_locs=[Location.name("middle")]) with InsertionPoint(entry_block) as entry_ip: assert entry_ip.block == entry_block @@ -48,27 +48,57 @@ with InsertionPoint(middle_block) as middle_ip: assert middle_ip.block == middle_block cf.BranchOp([i32_arg], dest=successor_block) - print(module.operation) + module.print(enable_debug_info=True) # Ensure region back references are coherent. assert entry_block.region == middle_block.region == successor_block.region +# CHECK-LABEL: TEST: testBlockCreationArgLocs +@run +def testBlockCreationArgLocs(): + with Context() as ctx: + ctx.allow_unregistered_dialects = True + f32 = F32Type.get() + op = Operation.create("test", regions=1, loc=Location.unknown()) + blocks = op.regions[0].blocks + + with Location.name("default_loc"): + blocks.append(f32) + blocks.append() + # CHECK: ^bb0(%{{.+}}: f32 loc("default_loc")): + # CHECK-NEXT: ^bb1: + op.print(enable_debug_info=True) + + try: + blocks.append(f32) + except RuntimeError as err: + # CHECK: Missing loc: An MLIR function requires a Location but none was provided + print("Missing loc:", err) + + try: + blocks.append(f32, f32, arg_locs=[Location.unknown()]) + except ValueError as err: + # CHECK: Wrong loc count: Expected 2 locations, got: 1 + print("Wrong loc count:", err) + + # CHECK-LABEL: TEST: testFirstBlockCreation -# CHECK: func @test(%{{.*}}: f32) +# CHECK: func @test(%{{.*}}: f32 loc("arg_loc")) # CHECK: return @run def testFirstBlockCreation(): with Context() as ctx, Location.unknown(): - module = Module.create() + module = builtin.ModuleOp() f32 = F32Type.get() with InsertionPoint(module.body): f = func.FuncOp("test", ([f32], [])) - entry_block = Block.create_at_start(f.operation.regions[0], [f32]) + entry_block = Block.create_at_start(f.operation.regions[0], + [f32], [Location.name("arg_loc")]) with InsertionPoint(entry_block): func.ReturnOp([]) - print(module) - assert module.operation.verify() + module.print(enable_debug_info=True) + assert module.verify() assert f.body.blocks[0] == entry_block