diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -612,8 +612,22 @@ Blocks can be created within a given region and inserted before or after another block of the same region using `create_before()`, `create_after()` methods of -the `Block` class. They are not expected to exist outside of regions (unlike in -C++ that supports detached blocks). +the `Block` class, or the `create_at_start()` static method of the same class. +They are not expected to exist outside of regions (unlike in C++, which supports +detached blocks). + +```python +from mlir.ir import Block, Context, Location, Operation + +with Context(), Location.unknown(): + op = Operation.create("generic.op", regions=1) + + # Create the first block in the region. + entry_block = Block.create_at_start(op.regions[0]) + + # Create further blocks. + other_block = entry_block.create_after() +``` Blocks can be used to create `InsertionPoint`s, which can point to the beginning or the end of the block, or just before its terminator. 
It is common for diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -99,6 +99,9 @@ static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); cls.def(py::init<PyAffineExpr &>()); + cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { + return DerivedTy::isaFunction(otherAffineExpr); + }); DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1548,6 +1548,9 @@ static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyValue &otherValue) -> bool { + return DerivedTy::isaFunction(otherValue); + }); DerivedTy::bindDerived(cls); } @@ -2248,6 +2251,12 @@ return PyBlockList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of blocks.") + .def_property_readonly( + "owner", + [](PyRegion &self) { + return self.getParentOperation()->createOpView(); + }, + "Returns the operation owning this region.") .def( "__iter__", [](PyRegion &self) { @@ -2291,6 +2300,23 @@ return PyOperationList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of operations.") + .def_static( + "create_at_start", + [](PyRegion &parent, py::list pyArgTypes) { + parent.checkValid(); + llvm::SmallVector<MlirType, 4> argTypes; + argTypes.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast<PyType &>()); + } + + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + mlirRegionInsertOwnedBlock(parent, 0, block); // Position 0: the new block becomes the first block of `parent`. + return PyBlock(parent.getParentOperation(), block); + }, + py::arg("parent"), py::arg("pyArgTypes") = py::list(), + "Creates and returns a new Block at the 
beginning of the given " + "region (with given argument types).") .def( "create_before", [](PyBlock &self, py::args pyArgTypes) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -533,6 +533,7 @@ : parentOperation(std::move(parentOperation)), region(region) { assert(!mlirRegionIsNull(region) && "python region cannot be null"); } + operator MlirRegion() const { return region; } MlirRegion get() { return region; } PyOperationRef &getParentOperation() { return parentOperation; } @@ -681,6 +682,9 @@ auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), pybind11::module_local()); cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool { + return DerivedTy::isaFunction(otherAttr); + }); DerivedTy::bindDerived(cls); } @@ -764,6 +768,7 @@ public: PyValue(PyOperationRef parentOperation, MlirValue value) : parentOperation(parentOperation), value(value) {} + operator MlirValue() const { return value; } MlirValue get() { return value; } PyOperationRef &getParentOperation() { return parentOperation; } diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py --- a/mlir/test/python/ir/affine_expr.py +++ b/mlir/test/python/ir/affine_expr.py @@ -8,9 +8,11 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testAffineExprCapsule +@run def testAffineExprCapsule(): with Context() as ctx: affine_expr = AffineExpr.get_constant(42) @@ -24,10 +26,9 @@ assert affine_expr == affine_expr_2 assert affine_expr_2.context == ctx -run(testAffineExprCapsule) - # CHECK-LABEL: TEST: testAffineExprEq +@run def testAffineExprEq(): with Context(): a1 = AffineExpr.get_constant(42) @@ -44,10 +45,9 @@ # CHECK: False print(a1 == "foo") -run(testAffineExprEq) - # CHECK-LABEL: TEST: testAffineExprContext +@run def
testAffineExprContext(): with Context(): a1 = AffineExpr.get_constant(42) @@ -61,6 +61,7 @@ # CHECK-LABEL: TEST: testAffineExprConstant +@run def testAffineExprConstant(): with Context(): a1 = AffineExpr.get_constant(42) @@ -77,10 +78,9 @@ assert a1 == a2 -run(testAffineExprConstant) - # CHECK-LABEL: TEST: testAffineExprDim +@run def testAffineExprDim(): with Context(): d1 = AffineExpr.get_dim(1) @@ -100,10 +100,9 @@ assert d1 == d11 assert d1 != d2 -run(testAffineExprDim) - # CHECK-LABEL: TEST: testAffineExprSymbol +@run def testAffineExprSymbol(): with Context(): s1 = AffineExpr.get_symbol(1) @@ -123,10 +122,9 @@ assert s1 == s11 assert s1 != s2 -run(testAffineExprSymbol) - # CHECK-LABEL: TEST: testAffineAddExpr +@run def testAffineAddExpr(): with Context(): d1 = AffineDimExpr.get(1) @@ -143,10 +141,9 @@ assert d12.lhs == d1 assert d12.rhs == d2 -run(testAffineAddExpr) - # CHECK-LABEL: TEST: testAffineMulExpr +@run def testAffineMulExpr(): with Context(): d1 = AffineDimExpr.get(1) @@ -163,10 +160,9 @@ assert expr.lhs == d1 assert expr.rhs == c2 -run(testAffineMulExpr) - # CHECK-LABEL: TEST: testAffineModExpr +@run def testAffineModExpr(): with Context(): d1 = AffineDimExpr.get(1) @@ -183,10 +179,9 @@ assert expr.lhs == d1 assert expr.rhs == c2 -run(testAffineModExpr) - # CHECK-LABEL: TEST: testAffineFloorDivExpr +@run def testAffineFloorDivExpr(): with Context(): d1 = AffineDimExpr.get(1) @@ -198,10 +193,9 @@ assert expr.lhs == d1 assert expr.rhs == c2 -run(testAffineFloorDivExpr) - # CHECK-LABEL: TEST: testAffineCeilDivExpr +@run def testAffineCeilDivExpr(): with Context(): d1 = AffineDimExpr.get(1) @@ -213,10 +207,9 @@ assert expr.lhs == d1 assert expr.rhs == c2 -run(testAffineCeilDivExpr) - # CHECK-LABEL: TEST: testAffineExprSub +@run def testAffineExprSub(): with Context(): d1 = AffineDimExpr.get(1) @@ -232,9 +225,8 @@ # CHECK: -1 print(rhs.rhs) -run(testAffineExprSub) - - +# CHECK-LABEL: TEST: testClassHierarchy +@run def testClassHierarchy(): with
Context(): d1 = AffineDimExpr.get(1) @@ -272,4 +264,28 @@ # CHECK: Cannot cast affine expression to AffineBinaryExpr print(e) -run(testClassHierarchy) +# CHECK-LABEL: TEST: testIsInstance +@run +def testIsInstance(): + with Context(): + d1 = AffineDimExpr.get(1) + c2 = AffineConstantExpr.get(2) + add = AffineAddExpr.get(d1, c2) + mul = AffineMulExpr.get(d1, c2) + + # CHECK: True + print(AffineDimExpr.isinstance(d1)) + # CHECK: False + print(AffineConstantExpr.isinstance(d1)) + # CHECK: True + print(AffineConstantExpr.isinstance(c2)) + # CHECK: False + print(AffineMulExpr.isinstance(c2)) + # CHECK: True + print(AffineAddExpr.isinstance(add)) + # CHECK: False + print(AffineMulExpr.isinstance(add)) + # CHECK: True + print(AffineMulExpr.isinstance(mul)) + # CHECK: False + print(AffineAddExpr.isinstance(mul)) diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -89,6 +89,18 @@ print("a1 == a2:", a1 == a2) +# CHECK-LABEL: TEST: testAttrIsInstance +@run +def testAttrIsInstance(): + with Context(): + a1 = Attribute.parse("42") + a2 = Attribute.parse("[42]") + assert IntegerAttr.isinstance(a1) + assert not IntegerAttr.isinstance(a2) + assert not ArrayAttr.isinstance(a1) + assert ArrayAttr.isinstance(a2) + + # CHECK-LABEL: TEST: testAttrEqDoesNotRaise @run def testAttrEqDoesNotRaise(): diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py --- a/mlir/test/python/ir/blocks.py +++ b/mlir/test/python/ir/blocks.py @@ -51,3 +51,22 @@ print(module.operation) # Ensure region back references are coherent.
assert entry_block.region == middle_block.region == successor_block.region + + +# CHECK-LABEL: TEST: testFirstBlockCreation +# CHECK: func @test(%{{.*}}: f32) +# CHECK: return +@run +def testFirstBlockCreation(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + func = builtin.FuncOp("test", ([f32], [])) + entry_block = Block.create_at_start(func.operation.regions[0], [f32]) + with InsertionPoint(entry_block): + std.ReturnOp([]) + + print(module) + assert module.operation.verify() + assert func.body.blocks[0] == entry_block diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -11,10 +11,12 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f # Verify iterator based traversal of the op/region/block hierarchy. # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators +@run def testTraverseOpRegionBlockIterators(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -69,11 +71,9 @@ walk_operations("", op) -run(testTraverseOpRegionBlockIterators) - - # Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices +@run def testTraverseOpRegionBlockIndices(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -111,10 +111,30 @@ walk_operations("", module.operation) -run(testTraverseOpRegionBlockIndices) +# CHECK-LABEL: TEST: testBlockAndRegionOwners +@run +def testBlockAndRegionOwners(): + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" + builtin.module { + builtin.func @f() { + std.return + } + } + """, ctx) + + assert module.operation.regions[0].owner == module.operation  # Region.owner is the operation holding the region. + assert module.operation.regions[0].blocks[0].owner == module.operation  # Block.owner resolves to that same operation. + + func = module.body.operations[0]  # The builtin.func @f nested in the module body. + assert func.operation.regions[0].owner == func + assert func.operation.regions[0].blocks[0].owner == func # CHECK-LABEL: TEST: testBlockArgumentList +@run def testBlockArgumentList(): with Context() as ctx: module = Module.parse( @@ -158,10 +178,8 @@ print("Type: ", t) -run(testBlockArgumentList) - - # CHECK-LABEL: TEST: testOperationOperands +@run def testOperationOperands(): with Context() as ctx: ctx.allow_unregistered_dialects = True @@ -181,10 +199,10 @@ print(f"Operand {i}, type {operand.type}") -run(testOperationOperands) # CHECK-LABEL: TEST: testOperationOperandsSlice +@run def testOperationOperandsSlice(): with Context() as ctx: ctx.allow_unregistered_dialects = True @@ -239,10 +257,10 @@ print(operand) -run(testOperationOperandsSlice) # CHECK-LABEL: TEST: testOperationOperandsSet +@run def testOperationOperandsSet(): with Context() as ctx, Location.unknown(ctx): ctx.allow_unregistered_dialects = True @@ -271,10 +289,10 @@ print(consumer.operands[0]) -run(testOperationOperandsSet) # CHECK-LABEL: TEST: testDetachedOperation +@run def testDetachedOperation(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -295,10 +313,8 @@ # TODO: Check successors once enough infra exists to do it properly.
-run(testDetachedOperation) - - # CHECK-LABEL: TEST: testOperationInsertionPoint +@run def testOperationInsertionPoint(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -335,10 +351,8 @@ assert False, "expected insert of attached op to raise" -run(testOperationInsertionPoint) - - # CHECK-LABEL: TEST: testOperationWithRegion +@run def testOperationWithRegion(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -377,10 +391,8 @@ print(module) -run(testOperationWithRegion) - - # CHECK-LABEL: TEST: testOperationResultList +@run def testOperationResultList(): ctx = Context() module = Module.parse( @@ -407,10 +419,10 @@ print(f"Result type {t}") -run(testOperationResultList) # CHECK-LABEL: TEST: testOperationResultListSlice +@run def testOperationResultListSlice(): with Context() as ctx: ctx.allow_unregistered_dialects = True @@ -458,10 +470,10 @@ print(f"Result {res.result_number}, type {res.type}") -run(testOperationResultListSlice) # CHECK-LABEL: TEST: testOperationAttributes +@run def testOperationAttributes(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -506,10 +518,10 @@ assert False, "expected IndexError on accessing an out-of-bounds attribute" -run(testOperationAttributes) # CHECK-LABEL: TEST: testOperationPrint +@run def testOperationPrint(): ctx = Context() module = Module.parse( @@ -553,10 +565,10 @@ use_local_scope=True) -run(testOperationPrint) # CHECK-LABEL: TEST: testKnownOpView +@run def testKnownOpView(): with Context(), Location.unknown(): Context.current.allow_unregistered_dialects = True @@ -586,10 +598,8 @@ print(repr(custom)) -run(testKnownOpView) - - # CHECK-LABEL: TEST: testSingleResultProperty +@run def testSingleResultProperty(): with Context(), Location.unknown(): Context.current.allow_unregistered_dialects = True @@ -620,10 +630,8 @@ print(module.body.operations[2]) -run(testSingleResultProperty) - - # CHECK-LABEL: TEST: testPrintInvalidOperation +@run def testPrintInvalidOperation(): ctx = Context() with
Location.unknown(ctx): @@ -639,10 +647,8 @@ print(f".verify = {module.operation.verify()}") -run(testPrintInvalidOperation) - - # CHECK-LABEL: TEST: testCreateWithInvalidAttributes +@run def testCreateWithInvalidAttributes(): ctx = Context() with Location.unknown(ctx): @@ -670,10 +676,8 @@ print(e) -run(testCreateWithInvalidAttributes) - - # CHECK-LABEL: TEST: testOperationName +@run def testOperationName(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -691,10 +695,8 @@ print(op.operation.name) -run(testOperationName) - - # CHECK-LABEL: TEST: testCapsuleConversions +@run def testCapsuleConversions(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -706,10 +708,8 @@ assert m2 is m -run(testCapsuleConversions) - - # CHECK-LABEL: TEST: testOperationErase +@run def testOperationErase(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -728,6 +728,3 @@ # Ensure we can create another operation Operation.create("custom.op2") - - -run(testOperationErase) diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -9,9 +9,11 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testCapsuleConversions +@run def testCapsuleConversions(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -24,10 +26,8 @@ assert value2 == value -run(testCapsuleConversions) - - # CHECK-LABEL: TEST: testOpResultOwner +@run def testOpResultOwner(): ctx = Context() ctx.allow_unregistered_dialects = True @@ -37,4 +37,21 @@ assert op.result.owner == op -run(testOpResultOwner) +# CHECK-LABEL: TEST: testValueIsInstance +@run +def testValueIsInstance(): + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" + func @foo(%arg0: f32) { + %0 = "some_dialect.some_op"() : () -> f64 + return + }""", ctx) + func = module.body.operations[0]  # func @foo from the parsed IR. + assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])  # %arg0 is an argument of the entry block.
assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0]) + + op = func.regions[0].blocks[0].operations[0] + assert not BlockArgument.isinstance(op.results[0]) + assert OpResult.isinstance(op.results[0])