diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2255,6 +2255,10 @@ // Mapping of PyAttribute. //---------------------------------------------------------------------------- py::class_<PyAttribute>(m, "Attribute") + // Delegate to the PyAttribute copy constructor, which will also lifetime + // extend the backing context which owns the MlirAttribute. + .def(py::init<PyAttribute &>(), py::arg("cast_from_attr"), + "Casts the passed attribute to the generic Attribute") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) @@ -2358,6 +2362,10 @@ // Mapping of PyType. //---------------------------------------------------------------------------- py::class_<PyType>(m, "Type") + // Delegate to the PyType copy constructor, which will also lifetime + // extend the backing context which owns the MlirType. + .def(py::init<PyType &>(), py::arg("cast_from_type"), + "Casts the passed type to the generic Type") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) .def_static( diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py --- a/mlir/test/Bindings/Python/ir_attributes.py +++ b/mlir/test/Bindings/Python/ir_attributes.py @@ -8,9 +8,11 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testParsePrint +@run def testParsePrint(): with Context() as ctx: t = Attribute.parse('"hello"') @@ -22,12 +24,11 @@ # CHECK: Attribute("hello") print(repr(t)) -run(testParsePrint) - # CHECK-LABEL: TEST: testParseError # TODO: Hook the diagnostic manager to capture a more meaningful error # message.
+@run def testParseError(): with Context(): try: @@ -38,10 +39,9 @@ else: print("Exception not produced") -run(testParseError) - # CHECK-LABEL: TEST: testAttrEq +@run def testAttrEq(): with Context(): a1 = Attribute.parse('"attr1"') @@ -56,10 +56,19 @@ # CHECK: a1 == None: False print("a1 == None:", a1 == None) -run(testAttrEq) + +# CHECK-LABEL: TEST: testAttrCast +@run +def testAttrCast(): + with Context(): + a1 = Attribute.parse('"attr1"') + a2 = Attribute(a1) + # CHECK: a1 == a2: True + print("a1 == a2:", a1 == a2) # CHECK-LABEL: TEST: testAttrEqDoesNotRaise +@run def testAttrEqDoesNotRaise(): with Context(): a1 = Attribute.parse('"attr1"') @@ -71,10 +80,9 @@ # CHECK: True print(a1 != None) -run(testAttrEqDoesNotRaise) - # CHECK-LABEL: TEST: testAttrCapsule +@run def testAttrCapsule(): with Context() as ctx: a1 = Attribute.parse('"attr1"') @@ -85,10 +93,9 @@ assert a2 == a1 assert a2.context is ctx -run(testAttrCapsule) - # CHECK-LABEL: TEST: testStandardAttrCasts +@run def testStandardAttrCasts(): with Context(): a1 = Attribute.parse('"attr1"') @@ -104,10 +111,9 @@ else: print("Exception not produced") -run(testStandardAttrCasts) - # CHECK-LABEL: TEST: testAffineMapAttr +@run def testAffineMapAttr(): with Context() as ctx: d0 = AffineDimExpr.get(0) @@ -122,10 +128,9 @@ attr_parsed = Attribute.parse(str(attr_built)) assert attr_built == attr_parsed -run(testAffineMapAttr) - # CHECK-LABEL: TEST: testFloatAttr +@run def testFloatAttr(): with Context(), Location.unknown(): fattr = FloatAttr(Attribute.parse("42.0 : f32")) @@ -149,10 +154,9 @@ else: print("Exception not produced") -run(testFloatAttr) - # CHECK-LABEL: TEST: testIntegerAttr +@run def testIntegerAttr(): with Context() as ctx: iattr = IntegerAttr(Attribute.parse("42")) @@ -166,10 +170,9 @@ print("default_get:", IntegerAttr.get( IntegerType.get_signless(32), 42)) -run(testIntegerAttr) - # CHECK-LABEL: TEST: testBoolAttr +@run def testBoolAttr(): with Context() as ctx: battr =
BoolAttr(Attribute.parse("true")) @@ -180,10 +183,9 @@ # CHECK: default_get: true print("default_get:", BoolAttr.get(True)) -run(testBoolAttr) - # CHECK-LABEL: TEST: testFlatSymbolRefAttr +@run def testFlatSymbolRefAttr(): with Context() as ctx: sattr = FlatSymbolRefAttr(Attribute.parse('@symbol')) @@ -194,10 +196,9 @@ # CHECK: default_get: @foobar print("default_get:", FlatSymbolRefAttr.get("foobar")) -run(testFlatSymbolRefAttr) - # CHECK-LABEL: TEST: testStringAttr +@run def testStringAttr(): with Context() as ctx: sattr = StringAttr(Attribute.parse('"stringattr"')) @@ -211,10 +212,9 @@ print("typed_get:", StringAttr.get_typed( IntegerType.get_signless(32), "12345")) -run(testStringAttr) - # CHECK-LABEL: TEST: testNamedAttr +@run def testNamedAttr(): with Context(): a = Attribute.parse('"stringattr"') @@ -226,10 +226,9 @@ # CHECK: named: NamedAttribute(foobar="stringattr") print("named:", named) -run(testNamedAttr) - # CHECK-LABEL: TEST: testDenseIntAttr +@run def testDenseIntAttr(): with Context(): raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>") @@ -263,10 +262,8 @@ print(ShapedType(a.type).element_type) -run(testDenseIntAttr) - - # CHECK-LABEL: TEST: testDenseFPAttr +@run def testDenseFPAttr(): with Context(): raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>") @@ -286,10 +283,8 @@ print(ShapedType(a.type).element_type) -run(testDenseFPAttr) - - # CHECK-LABEL: TEST: testDictAttr +@run def testDictAttr(): with Context(): dict_attr = { @@ -327,10 +322,8 @@ assert False, "expected IndexError on accessing an out-of-bounds attribute" - -run(testDictAttr) - # CHECK-LABEL: TEST: testTypeAttr +@run def testTypeAttr(): with Context(): raw = Attribute.parse("vector<4xf32>") @@ -341,10 +334,8 @@ print(ShapedType(type_attr.value).element_type) -run(testTypeAttr) - - # CHECK-LABEL: TEST: testArrayAttr +@run def testArrayAttr(): with Context(): raw = Attribute.parse("[42, true, vector<4xf32>]") @@ -391,5 +382,4 @@ except RuntimeError as
e: # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute print("Error: ", e) -run(testArrayAttr) diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -8,9 +8,11 @@ f() gc.collect() assert Context._get_live_count() == 0 + return f # CHECK-LABEL: TEST: testParsePrint +@run def testParsePrint(): ctx = Context() t = Type.parse("i32", ctx) @@ -22,12 +24,11 @@ # CHECK: Type(i32) print(repr(t)) -run(testParsePrint) - # CHECK-LABEL: TEST: testParseError # TODO: Hook the diagnostic manager to capture a more meaningful error # message. +@run def testParseError(): ctx = Context() try: @@ -38,10 +39,9 @@ else: print("Exception not produced") -run(testParseError) - # CHECK-LABEL: TEST: testTypeEq +@run def testTypeEq(): ctx = Context() t1 = Type.parse("i32", ctx) @@ -56,10 +56,19 @@ # CHECK: t1 == None: False print("t1 == None:", t1 == None) -run(testTypeEq) + +# CHECK-LABEL: TEST: testTypeCast +@run +def testTypeCast(): + ctx = Context() + t1 = Type.parse("i32", ctx) + t2 = Type(t1) + # CHECK: t1 == t2: True + print("t1 == t2:", t1 == t2) # CHECK-LABEL: TEST: testTypeIsInstance +@run def testTypeIsInstance(): ctx = Context() t1 = Type.parse("i32", ctx) @@ -71,10 +80,9 @@ # CHECK: True print(F32Type.isinstance(t2)) -run(testTypeIsInstance) - # CHECK-LABEL: TEST: testTypeEqDoesNotRaise +@run def testTypeEqDoesNotRaise(): ctx = Context() t1 = Type.parse("i32", ctx) @@ -86,10 +94,9 @@ # CHECK: True print(t1 != None) -run(testTypeEqDoesNotRaise) - # CHECK-LABEL: TEST: testTypeCapsule +@run def testTypeCapsule(): with Context() as ctx: t1 = Type.parse("i32", ctx) @@ -100,10 +107,9 @@ assert t2 == t1 assert t2.context is ctx -run(testTypeCapsule) - # CHECK-LABEL: TEST: testStandardTypeCasts +@run def testStandardTypeCasts(): ctx = Context() t1 = Type.parse("i32", ctx) @@ -119,10 +125,9 @@ else: print("Exception not produced")
-run(testStandardTypeCasts) - # CHECK-LABEL: TEST: testIntegerType +@run def testIntegerType(): with Context() as ctx: i32 = IntegerType(Type.parse("i32")) @@ -158,17 +163,16 @@ # CHECK: unsigned: ui64 print("unsigned:", IntegerType.get_unsigned(64)) -run(testIntegerType) - # CHECK-LABEL: TEST: testIndexType +@run def testIndexType(): with Context() as ctx: # CHECK: index type: index print("index type:", IndexType.get()) -run(testIndexType) # CHECK-LABEL: TEST: testFloatType +@run def testFloatType(): with Context(): # CHECK: float: bf16 @@ -180,17 +184,17 @@ # CHECK: float: f64 print("float:", F64Type.get()) -run(testFloatType) # CHECK-LABEL: TEST: testNoneType +@run def testNoneType(): with Context(): # CHECK: none type: none print("none type:", NoneType.get()) -run(testNoneType) # CHECK-LABEL: TEST: testComplexType +@run def testComplexType(): with Context() as ctx: complex_i32 = ComplexType(Type.parse("complex<i32>")) @@ -210,13 +214,12 @@ else: print("Exception not produced") -run(testComplexType) - # CHECK-LABEL: TEST: testConcreteShapedType # Shaped type is not a kind of builtin types, it is the base class for vectors, # memrefs and tensors, so this test case uses an instance of vector to test the # shaped type. The class hierarchy is preserved on the python side.
+@run def testAbstractShapedType(): ctx = Context() vector = ShapedType(Type.parse("vector<2x3xf32>", ctx)) # CHECK: element type: f32 print("element type:", vector.element_type) -run(testAbstractShapedType) # CHECK-LABEL: TEST: testVectorType +@run def testVectorType(): with Context(), Location.unknown(): f32 = F32Type.get() @@ -269,9 +272,9 @@ else: print("Exception not produced") -run(testVectorType) # CHECK-LABEL: TEST: testRankedTensorType +@run def testRankedTensorType(): with Context(), Location.unknown(): f32 = F32Type.get() @@ -291,9 +294,9 @@ else: print("Exception not produced") -run(testRankedTensorType) # CHECK-LABEL: TEST: testUnrankedTensorType +@run def testUnrankedTensorType(): with Context(), Location.unknown(): f32 = F32Type.get() @@ -333,9 +336,9 @@ else: print("Exception not produced") -run(testUnrankedTensorType) # CHECK-LABEL: TEST: testMemRefType +@run def testMemRefType(): with Context(), Location.unknown(): f32 = F32Type.get() @@ -369,9 +372,9 @@ else: print("Exception not produced") -run(testMemRefType) # CHECK-LABEL: TEST: testUnrankedMemRefType +@run def testUnrankedMemRefType(): with Context(), Location.unknown(): f32 = F32Type.get() @@ -411,9 +414,9 @@ else: print("Exception not produced") -run(testUnrankedMemRefType) # CHECK-LABEL: TEST: testTupleType +@run def testTupleType(): with Context() as ctx: i32 = IntegerType(Type.parse("i32")) @@ -428,10 +431,9 @@ # CHECK: pos-th type in the tuple type: f32 print("pos-th type in the tuple type:", tuple_type.get_type(1)) -run(testTupleType) - # CHECK-LABEL: TEST: testFunctionType +@run def testFunctionType(): with Context() as ctx: input_types = [IntegerType.get_signless(32), @@ -442,6 +444,3 @@ print("INPUTS:", func.inputs) # CHECK: RESULTS: [Type(index)] print("RESULTS:", func.results) - - -run(testFunctionType)