diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -270,12 +270,11 @@ PyArrayAttributeIterator &dunderIter() { return *this; } - PyAttribute dunderNext() { + MlirAttribute dunderNext() { // TODO: Throw is an inefficient way to stop iteration. if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) throw py::stop_iteration(); - return PyAttribute(attr.getContext(), - mlirArrayAttrGetElement(attr.get(), nextIndex++)); + return mlirArrayAttrGetElement(attr.get(), nextIndex++); } static void bind(py::module &m) { @@ -290,8 +289,8 @@ int nextIndex = 0; }; - PyAttribute getItem(intptr_t i) { - return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); + MlirAttribute getItem(intptr_t i) { + return mlirArrayAttrGetElement(*this, i); } static void bindDerived(ClassTy &c) { @@ -843,13 +842,11 @@ return mlirDenseElementsAttrIsSplat(self); }) .def("get_splat_value", - [](PyDenseElementsAttribute &self) -> PyAttribute { - if (!mlirDenseElementsAttrIsSplat(self)) { + [](PyDenseElementsAttribute &self) { + if (!mlirDenseElementsAttrIsSplat(self)) throw py::value_error( "get_splat_value called on a non-splat attribute"); - } - return PyAttribute(self.getContext(), - mlirDenseElementsAttrGetSplatValue(self)); + return mlirDenseElementsAttrGetSplatValue(self); }) .def_buffer(&PyDenseElementsAttribute::accessBuffer); } @@ -1018,10 +1015,9 @@ c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { MlirAttribute attr = mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) { + if (mlirAttributeIsNull(attr)) throw py::key_error("attempt to access a non-existent attribute"); - } - return PyAttribute(self.getContext(), attr); + return attr; }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { diff --git 
a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1908,19 +1908,17 @@ erase(py::cast(operation)); } -PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { operation->checkValid(); symbol.getOperation().checkValid(); MlirAttribute symbolAttr = mlirOperationGetAttributeByName( symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); if (mlirAttributeIsNull(symbolAttr)) throw py::value_error("Expected operation to have a symbol name."); - return PyAttribute( - symbol.getOperation().getContext(), - mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); + return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); } -PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { // Op must already be a symbol. 
PyOperation &operation = symbol.getOperation(); operation.checkValid(); @@ -1929,7 +1927,7 @@ mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) throw py::value_error("Expected operation to have a symbol name."); - return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); + return existingNameAttr; } void PySymbolTable::setSymbolName(PyOperationBase &symbol, @@ -1947,7 +1945,7 @@ mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); } -PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { PyOperation &operation = symbol.getOperation(); operation.checkValid(); MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); @@ -1955,7 +1953,7 @@ mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) throw py::value_error("Expected operation to have a symbol visibility."); - return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); + return existingVisAttr; } void PySymbolTable::setVisibility(PyOperationBase &symbol, @@ -2287,13 +2285,13 @@ PyOpAttributeMap(PyOperationRef operation) : operation(std::move(operation)) {} - PyAttribute dunderGetItemNamed(const std::string &name) { + MlirAttribute dunderGetItemNamed(const std::string &name) { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { throw py::key_error("attempt to access a non-existent attribute"); } - return PyAttribute(operation->getContext(), attr); + return attr; } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -1174,14 +1174,14 @@ /// Inserts the given operation into the symbol table. 
The operation must have /// the symbol trait. - PyAttribute insert(PyOperationBase &symbol); + MlirAttribute insert(PyOperationBase &symbol); /// Gets and sets the name of a symbol op. - static PyAttribute getSymbolName(PyOperationBase &symbol); + static MlirAttribute getSymbolName(PyOperationBase &symbol); static void setSymbolName(PyOperationBase &symbol, const std::string &name); /// Gets and sets the visibility of a symbol op. - static PyAttribute getVisibility(PyOperationBase &symbol); + static MlirAttribute getVisibility(PyOperationBase &symbol); static void setVisibility(PyOperationBase &symbol, const std::string &visibility); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -505,11 +505,12 @@ py::arg("encoding") = py::none(), py::arg("loc") = py::none(), "Create a ranked tensor type"); c.def_property_readonly( - "encoding", [](PyRankedTensorType &self) -> std::optional<PyAttribute> { + "encoding", + [](PyRankedTensorType &self) -> std::optional<MlirAttribute> { MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); if (mlirAttributeIsNull(encoding)) return std::nullopt; - return PyAttribute(self.getContext(), encoding); + return encoding; }); } }; @@ -570,9 +571,8 @@ py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly( "layout", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute layout = mlirMemRefTypeGetLayout(self); - return PyAttribute(self.getContext(), layout); + [](PyMemRefType &self) -> MlirAttribute { + return mlirMemRefTypeGetLayout(self); }, "The layout of the MemRef type.") .def_property_readonly( @@ -584,9 +584,11 @@ "The layout of the MemRef type as an affine map.") .def_property_readonly( "memory_space", - [](PyMemRefType &self) -> PyAttribute { + [](PyMemRefType &self) -> std::optional<MlirAttribute> { MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); + if 
(mlirAttributeIsNull(a)) + return std::nullopt; + return a; }, "Returns the memory space of the given MemRef type."); } @@ -622,9 +624,11 @@ py::arg("loc") = py::none(), "Create a unranked memref type") .def_property_readonly( "memory_space", - [](PyUnrankedMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); + [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> { + MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); + if (mlirAttributeIsNull(a)) + return std::nullopt; + return a; }, "Returns the memory space of the given Unranked MemRef type."); } diff --git a/mlir/test/python/ir/array_attributes.py b/mlir/test/python/ir/array_attributes.py --- a/mlir/test/python/ir/array_attributes.py +++ b/mlir/test/python/ir/array_attributes.py @@ -47,7 +47,11 @@ print(attr) # CHECK: is_splat: True print("is_splat:", attr.is_splat) - assert attr.get_splat_value() == element + + # CHECK: splat_value: IntegerAttr(555 : i32) + splat_value = attr.get_splat_value() + print("splat_value:", repr(splat_value)) + assert splat_value == element # CHECK-LABEL: TEST: testGetDenseElementsSplatFloat diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -441,11 +441,11 @@ assert len(a) == 2 - # CHECK: 42 : i32 - print(a["integerattr"]) + # CHECK: integerattr: IntegerAttr(42 : i32) + print("integerattr:", repr(a["integerattr"])) - # CHECK: "string" - print(a["stringattr"]) + # CHECK: stringattr: StringAttr("string") + print("stringattr:", repr(a["stringattr"])) # CHECK: True print("stringattr" in a) @@ -488,14 +488,14 @@ @run def testArrayAttr(): with Context(): - raw = Attribute.parse("[42, true, vector<4xf32>]") - # CHECK: attr: [42, true, vector<4xf32>] - print("raw attr:", raw) - # CHECK: - 42 - # CHECK: - true - # CHECK: - vector<4xf32> - for attr in ArrayAttr(raw): - print("- ", attr) + arr = 
Attribute.parse("[42, true, vector<4xf32>]") + # CHECK: arr: [42, true, vector<4xf32>] + print("arr:", arr) + # CHECK: - IntegerAttr(42 : i64) + # CHECK: - BoolAttr(true) + # CHECK: - TypeAttr(vector<4xf32>) + for attr in arr: + print("- ", repr(attr)) with Context(): intAttr = Attribute.parse("42") @@ -504,18 +504,18 @@ raw = ArrayAttr.get([vecAttr, boolAttr, intAttr]) # CHECK: attr: [vector<4xf32>, true, 42] print("raw attr:", raw) - # CHECK: - vector<4xf32> - # CHECK: - true - # CHECK: - 42 - arr = ArrayAttr(raw) + # CHECK: - TypeAttr(vector<4xf32>) + # CHECK: - BoolAttr(true) + # CHECK: - IntegerAttr(42 : i64) + arr = raw for attr in arr: - print("- ", attr) - # CHECK: attr[0]: vector<4xf32> - print("attr[0]:", arr[0]) - # CHECK: attr[1]: true - print("attr[1]:", arr[1]) - # CHECK: attr[2]: 42 - print("attr[2]:", arr[2]) + print("- ", repr(attr)) + # CHECK: attr[0]: TypeAttr(vector<4xf32>) + print("attr[0]:", repr(arr[0])) + # CHECK: attr[1]: BoolAttr(true) + print("attr[1]:", repr(arr[1])) + # CHECK: attr[2]: IntegerAttr(42 : i64) + print("attr[2]:", repr(arr[2])) try: print("attr[3]:", arr[3]) except IndexError as e: diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -329,11 +329,13 @@ else: print("Exception not produced") + tensor = RankedTensorType.get(shape, f32, StringAttr.get("encoding")) + assert tensor.shape == shape + assert tensor.encoding.value == "encoding" + # Encoding should be None. 
assert RankedTensorType.get(shape, f32).encoding is None - tensor = RankedTensorType.get(shape, f32) - assert tensor.shape == shape # CHECK-LABEL: TEST: testUnrankedTensorType @@ -388,12 +390,12 @@ memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) # CHECK: memref type: memref<2x3xf32, 2> print("memref type:", memref_f32) - # CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)> - print("memref layout:", memref_f32.layout) + # CHECK: memref layout: AffineMapAttr(affine_map<(d0, d1) -> (d0, d1)>) + print("memref layout:", repr(memref_f32.layout)) # CHECK: memref affine map: (d0, d1) -> (d0, d1) print("memref affine map:", memref_f32.affine_map) - # CHECK: memory space: 2 - print("memory space:", memref_f32.memory_space) + # CHECK: memory space: IntegerAttr(2 : i64) + print("memory space:", repr(memref_f32.memory_space)) layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0])) memref_layout = MemRefType.get(shape, f32, layout=layout) @@ -403,7 +405,7 @@ print("memref layout:", memref_layout.layout) # CHECK: memref affine map: (d0, d1) -> (d1, d0) print("memref affine map:", memref_layout.affine_map) - # CHECK: memory space: <<NULL ATTRIBUTE>> + # CHECK: memory space: None print("memory space:", memref_layout.memory_space) none = NoneType.get() @@ -428,6 +430,8 @@ unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2")) # CHECK: unranked memref type: memref<*xf32, 2> print("unranked memref type:", unranked_memref) + # CHECK: memory space: IntegerAttr(2 : i64) + print("memory space:", repr(unranked_memref.memory_space)) try: invalid_rank = unranked_memref.rank except ValueError as e: diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -532,9 +532,9 @@ ) op = module.body.operations[0] assert len(op.attributes) == 3 - iattr = IntegerAttr(op.attributes["some.attribute"]) - fattr = FloatAttr(op.attributes["other.attribute"]) - sattr = 
StringAttr(op.attributes["dependent"]) + iattr = op.attributes["some.attribute"] + fattr = op.attributes["other.attribute"] + sattr = op.attributes["dependent"] # CHECK: Attribute type i8, value 1 print(f"Attribute type {iattr.type}, value {iattr.value}") # CHECK: Attribute type f64, value 3.0 diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py --- a/mlir/test/python/ir/symbol_table.py +++ b/mlir/test/python/ir/symbol_table.py @@ -75,6 +75,7 @@ updated_name = symbol_table.insert(foo2) assert foo2.name.value != "foo" assert foo2.name == updated_name + assert isinstance(updated_name, StringAttr) # CHECK: module # CHECK: func private @foo() @@ -112,10 +113,10 @@ # CHECK: call @bam() # CHECK: func private @bam print(m) - # CHECK: Foo symbol: "foo" - # CHECK: Bar symbol: "bam" - print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}") - print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}") + # CHECK: Foo symbol: StringAttr("foo") + # CHECK: Bar symbol: StringAttr("bam") + print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}") + print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}") # CHECK-LABEL: testSymbolTableVisibility @@ -130,8 +131,8 @@ """ ) foo = m.operation.regions[0].blocks[0].operations[0] - # CHECK: Existing visibility: "private" - print(f"Existing visibility: {SymbolTable.get_visibility(foo)}") + # CHECK: Existing visibility: StringAttr("private") + print(f"Existing visibility: {repr(SymbolTable.get_visibility(foo))}") SymbolTable.set_visibility(foo, "public") # CHECK: func public @foo print(m)