diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -619,8 +619,22 @@ Blocks can be created within a given region and inserted before or after another block of the same region using `create_before()`, `create_after()` methods of -the `Block` class. They are not expected to exist outside of regions (unlike in -C++ that supports detached blocks). +the `Block` class, or the `create_at_start()` static method of the same class. +They are not expected to exist outside of regions (unlike in C++, which supports +detached blocks). + +```python +from mlir.ir import Block, Context, Location, Operation + +with Context(), Location.unknown(): + op = Operation.create("generic.op", regions=1) + + # Create the first block in the region. + entry_block = Block.create_at_start(op.regions[0]) + + # Create further blocks. + other_block = entry_block.create_after() +``` Blocks can be used to create `InsertionPoint`s, which can point to the beginning or the end of the block, or just before its terminator.
It is common for diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -99,6 +99,9 @@ static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); cls.def(py::init<PyAffineExpr &>()); + cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { // Exposes DerivedTy::isaFunction to Python as a static isinstance(expr). + return DerivedTy::isaFunction(otherAffineExpr); + }); DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1541,6 +1541,9 @@ static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyValue &otherValue) -> bool { // Exposes DerivedTy::isaFunction to Python as a static isinstance(value). + return DerivedTy::isaFunction(otherValue); + }); DerivedTy::bindDerived(cls); } @@ -2213,6 +2216,12 @@ return PyBlockList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of blocks.") + .def_property_readonly( + "owner", + [](PyRegion &self) { + return self.getParentOperation()->createOpView(); + }, + "Returns the operation owning this region.") .def( "__iter__", [](PyRegion &self) { @@ -2256,6 +2265,23 @@ return PyOperationList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of operations.") + .def_static( + "create_at_start", + [](PyRegion &parent, py::list pyArgTypes) { + parent.checkValid(); + llvm::SmallVector<MlirType, 4> argTypes; + argTypes.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast<PyType &>()); + } + + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + mlirRegionInsertOwnedBlock(parent, 0, block); // Position 0 puts the block first; `parent` converts via PyRegion::operator MlirRegion. + return PyBlock(parent.getParentOperation(), block); + }, + py::arg("parent"), py::arg("pyArgTypes") = py::list(), // NOTE(review): camelCase Python keyword name; renaming would break keyword callers. + "Creates and returns a new Block at the
beginning of the given " "region (with given argument types).") .def( "create_before", [](PyBlock &self, py::args pyArgTypes) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -533,6 +533,7 @@ : parentOperation(std::move(parentOperation)), region(region) { assert(!mlirRegionIsNull(region) && "python region cannot be null"); } + operator MlirRegion() const { return region; } // Lets a PyRegion be passed where an MlirRegion is expected (e.g. mlirRegionInsertOwnedBlock). MlirRegion get() { return region; } PyOperationRef &getParentOperation() { return parentOperation; } @@ -681,6 +682,9 @@ auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), pybind11::module_local()); cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool { // Exposes DerivedTy::isaFunction to Python as a static isinstance(attr). + return DerivedTy::isaFunction(otherAttr); + }); DerivedTy::bindDerived(cls); } @@ -764,6 +768,7 @@ public: PyValue(PyOperationRef parentOperation, MlirValue value) : parentOperation(parentOperation), value(value) {} + operator MlirValue() const { return value; } // Lets a PyValue be passed where an MlirValue is expected. MlirValue get() { return value; } PyOperationRef &getParentOperation() { return parentOperation; } diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -89,6 +89,18 @@ print("a1 == a2:", a1 == a2) +# CHECK-LABEL: TEST: testAttrIsInstance +@run +def testAttrIsInstance(): + with Context(): + a1 = Attribute.parse("42") + a2 = Attribute.parse("[42]") + assert IntegerAttr.isinstance(a1) + assert not IntegerAttr.isinstance(a2) + assert not ArrayAttr.isinstance(a1) + assert ArrayAttr.isinstance(a2) + + # CHECK-LABEL: TEST: testAttrEqDoesNotRaise @run def testAttrEqDoesNotRaise(): diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py --- a/mlir/test/python/ir/blocks.py +++ b/mlir/test/python/ir/blocks.py @@ -51,3 +51,22 @@
print(module.operation) # Ensure region back references are coherent. assert entry_block.region == middle_block.region == successor_block.region + + +# CHECK-LABEL: TEST: testFirstBlockCreation +# CHECK: func @test(%{{.*}}: f32) +# CHECK: return +@run +def testFirstBlockCreation(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + func = builtin.FuncOp("test", ([f32], [])) + entry_block = Block.create_at_start(func.operation.regions[0], [f32])  # One f32 block argument, matching the ([f32], []) signature above. + with InsertionPoint(entry_block): + std.ReturnOp([]) + + print(module) + assert module.operation.verify() + assert func.body.blocks[0] == entry_block  # The created block must be the region's first block. diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -10,10 +10,12 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f  # Returning f lets run() be used as the @run decorator without rebinding the test name to None. # Verify iterator based traversal of the op/region/block hierarchy. # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators +@run def testTraverseOpRegionBlockIterators(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -66,11 +68,10 @@ # CHECK: OP 1: return walk_operations("", op) -run(testTraverseOpRegionBlockIterators) - # Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices +@run def testTraverseOpRegionBlockIndices(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -106,10 +107,31 @@ # CHECK: OP 1: parent builtin.func walk_operations("", module.operation) -run(testTraverseOpRegionBlockIndices) + +# CHECK-LABEL: TEST: testBlockAndRegionOwners +@run +def testBlockAndRegionOwners(): + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" + builtin.module { + builtin.func @f() { + std.return + } + } + """, ctx) + + assert module.operation.regions[0].owner == module.operation  # Region.owner reports the operation holding the region. + assert module.operation.regions[0].blocks[0].owner == module.operation + + func = module.body.operations[0]  # The builtin.func @f nested in the module above. + assert func.operation.regions[0].owner == func + assert func.operation.regions[0].blocks[0].owner == func # CHECK-LABEL: TEST: testBlockArgumentList +@run def testBlockArgumentList(): with Context() as ctx: module = Module.parse(r""" @@ -152,10 +174,8 @@ print("Type: ", t) -run(testBlockArgumentList) - - # CHECK-LABEL: TEST: testOperationOperands +@run def testOperationOperands(): with Context() as ctx: ctx.allow_unregistered_dialects = True @@ -175,10 +195,10 @@ print(f"Operand {i}, type {operand.type}") -run(testOperationOperands) # CHECK-LABEL: TEST: testOperationOperandsSlice +@run def testOperationOperandsSlice(): with Context() as ctx: ctx.allow_unregistered_dialects = True @@ -233,10 +253,10 @@ print(operand) -run(testOperationOperandsSlice) # CHECK-LABEL: TEST: testOperationOperandsSet +@run def testOperationOperandsSet(): with Context() as ctx, Location.unknown(ctx): ctx.allow_unregistered_dialects = True @@ -265,10 +285,10 @@ print(consumer.operands[0]) -run(testOperationOperandsSet) # CHECK-LABEL: TEST: testDetachedOperation +@run def testDetachedOperation(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -285,10 +305,10 @@ # TODO: Check successors once enough infra exists to do it properly.
-run(testDetachedOperation) # CHECK-LABEL: TEST: testOperationInsertionPoint +@run def testOperationInsertionPoint(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -323,10 +343,10 @@ else: assert False, "expected insert of attached op to raise" -run(testOperationInsertionPoint) # CHECK-LABEL: TEST: testOperationWithRegion +@run def testOperationWithRegion(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -364,10 +384,10 @@ # CHECK: %0 = "custom.addi" print(module) -run(testOperationWithRegion) # CHECK-LABEL: TEST: testOperationResultList +@run def testOperationResultList(): ctx = Context() module = Module.parse(r""" @@ -393,10 +413,10 @@ print(f"Result type {t}") -run(testOperationResultList) # CHECK-LABEL: TEST: testOperationResultListSlice +@run def testOperationResultListSlice(): with Context() as ctx: ctx.allow_unregistered_dialects = True @@ -444,10 +464,10 @@ print(f"Result {res.result_number}, type {res.type}") -run(testOperationResultListSlice) # CHECK-LABEL: TEST: testOperationAttributes +@run def testOperationAttributes(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -491,10 +511,10 @@ assert False, "expected IndexError on accessing an out-of-bounds attribute" -run(testOperationAttributes) # CHECK-LABEL: TEST: testOperationPrint +@run def testOperationPrint(): ctx = Context() module = Module.parse(r""" @@ -532,10 +552,10 @@ module.operation.print(large_elements_limit=2, enable_debug_info=True, pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True) -run(testOperationPrint) # CHECK-LABEL: TEST: testKnownOpView +@run def testKnownOpView(): with Context(), Location.unknown(): Context.current.allow_unregistered_dialects = True @@ -564,10 +584,10 @@ # CHECK: OpView object print(repr(custom)) -run(testKnownOpView) # CHECK-LABEL: TEST: testSingleResultProperty +@run def testSingleResultProperty(): with Context(), Location.unknown(): Context.current.allow_unregistered_dialects = True @@ -597,9 +617,9 @@ #
CHECK: %1 = "custom.one_result"() : () -> f32 print(module.body.operations[2]) -run(testSingleResultProperty) # CHECK-LABEL: TEST: testPrintInvalidOperation +@run def testPrintInvalidOperation(): ctx = Context() with Location.unknown(ctx): @@ -613,10 +633,10 @@ print(module) # CHECK: .verify = False print(f".verify = {module.operation.verify()}") -run(testPrintInvalidOperation) # CHECK-LABEL: TEST: testCreateWithInvalidAttributes +@run def testCreateWithInvalidAttributes(): ctx = Context() with Location.unknown(ctx): @@ -642,10 +662,10 @@ except Exception as e: # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module" print(e) -run(testCreateWithInvalidAttributes) # CHECK-LABEL: TEST: testOperationName +@run def testOperationName(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -661,9 +681,9 @@ for op in module.body.operations: print(op.operation.name) -run(testOperationName) # CHECK-LABEL: TEST: testCapsuleConversions +@run def testCapsuleConversions(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -674,9 +694,9 @@ m2 = Operation._CAPICreate(m_capsule) assert m2 is m -run(testCapsuleConversions) # CHECK-LABEL: TEST: testOperationErase +@run def testOperationErase(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -695,5 +715,3 @@ # Ensure we can create another operation Operation.create("custom.op2") - -run(testOperationErase) diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -9,9 +9,11 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f  # Returning f lets run() be used as the @run decorator without rebinding the test name to None. # CHECK-LABEL: TEST: testCapsuleConversions +@run def testCapsuleConversions(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -24,10 +26,8 @@ assert value2 == value -run(testCapsuleConversions) - - # CHECK-LABEL: TEST: testOpResultOwner +@run def testOpResultOwner(): ctx = Context()
ctx.allow_unregistered_dialects = True @@ -37,4 +37,21 @@ assert op.result.owner == op -run(testOpResultOwner) +# CHECK-LABEL: TEST: testValueIsInstance +@run +def testValueIsInstance(): + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" + func @foo(%arg0: f32) { + %0 = "some_dialect.some_op"() : () -> f64 + return + }""", ctx) + func = module.body.operations[0]  # The func @foo parsed above. + assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])  # %arg0 is a block argument, not an op result. + assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0]) + + op = func.regions[0].blocks[0].operations[0]  # some_dialect.some_op, whose result %0 is an OpResult. + assert not BlockArgument.isinstance(op.results[0]) + assert OpResult.isinstance(op.results[0])