diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -225,7 +225,13 @@ * same context as element type. The type is owned by the context. */ MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet( MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAttribute const *affineMaps, unsigned memorySpace); + MlirAffineMap const *affineMaps, unsigned memorySpace); + +/** Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType on + * illegal arguments, emitting appropriate diagnostics. */ +MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked( + MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, + MlirAffineMap const *affineMaps, unsigned memorySpace, MlirLocation loc); /** Creates a MemRef type with the given rank, shape, memory space and element * type in the same context as the element type. The type has no affine maps, diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2535,6 +2535,8 @@ } }; +class PyMemRefLayoutMapList; + /// Ranked MemRef Type subclass - MemRefType. class PyMemRefType : public PyConcreteType<PyMemRefType> { public: @@ -2542,16 +2544,22 @@ static constexpr const char *pyClassName = "MemRefType"; using PyConcreteType::PyConcreteType; + PyMemRefLayoutMapList getLayout(); + static void bindDerived(ClassTy &c) { - // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding - // once the affine map binding is completed. c.def_static( - "get_contiguous_memref", - // TODO: Make the location optional and create a default location. 
+ "get", [](PyType &elementType, std::vector<int64_t> shape, - unsigned memorySpace, DefaultingPyLocation loc) { - MlirType t = mlirMemRefTypeContiguousGetChecked( - elementType, shape.size(), shape.data(), memorySpace, loc); + std::vector<PyAffineMap> layout, unsigned memorySpace, + DefaultingPyLocation loc) { + SmallVector<MlirAffineMap> maps; + maps.reserve(layout.size()); + for (PyAffineMap &map : layout) + maps.push_back(map); + + MlirType t = mlirMemRefTypeGetChecked(elementType, shape.size(), + shape.data(), maps.size(), + maps.data(), memorySpace, loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2565,15 +2573,11 @@ } return PyMemRefType(elementType.getContext(), t); }, - py::arg("element_type"), py::arg("shape"), py::arg("memory_space"), + py::arg("element_type"), py::arg("shape"), + py::arg("layout") = py::list(), py::arg("memory_space") = 0, py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly( - "num_affine_maps", - [](PyMemRefType &self) -> intptr_t { - return mlirMemRefTypeGetNumAffineMaps(self); - }, - "Returns the number of affine layout maps in the given MemRef " - "type.") + .def_property_readonly("layout", &PyMemRefType::getLayout, + "The list of layout maps of the MemRef type.") .def_property_readonly( "memory_space", [](PyMemRefType &self) -> unsigned { @@ -2583,6 +2587,41 @@ } }; +/// A list of affine layout maps in a memref type. Internally, these are stored +/// as consecutive elements, random access is cheap. Both the type and the maps +/// are owned by the context, no need to worry about lifetime extension. +class PyMemRefLayoutMapList + : public Sliceable<PyMemRefLayoutMapList, PyAffineMap> { +public: + static constexpr const char *pyClassName = "MemRefLayoutMapList"; + + PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? 
mlirMemRefTypeGetNumAffineMaps(type) : length, + step), + memref(type) {} + + intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); } + + PyAffineMap getElement(intptr_t index) { + return PyAffineMap(memref.getContext(), + mlirMemRefTypeGetAffineMap(memref, index)); + } + + PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyMemRefLayoutMapList(memref, startIndex, length, step); + } + +private: + PyMemRefType memref; +}; + +PyMemRefLayoutMapList PyMemRefType::getLayout() { + return PyMemRefLayoutMapList(*this); +} + /// Unranked MemRef Type subclass - UnrankedMemRefType. class PyUnrankedMemRefType : public PyConcreteType<PyUnrankedMemRefType> { @@ -3631,6 +3670,7 @@ PyRankedTensorType::bind(m); PyUnrankedTensorType::bind(m); PyMemRefType::bind(m); + PyMemRefLayoutMapList::bind(m); PyUnrankedMemRefType::bind(m); PyTupleType::bind(m); PyFunctionType::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -231,6 +231,17 @@ unwrap(elementType), maps, memorySpace)); } +MlirType mlirMemRefTypeGetChecked(MlirType elementType, intptr_t rank, + const int64_t *shape, intptr_t numMaps, + MlirAffineMap const *affineMaps, + unsigned memorySpace, MlirLocation loc) { + SmallVector<AffineMap, 1> maps; + (void)unwrapList(numMaps, affineMaps, maps); + return wrap(MemRefType::getChecked( + unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)), + unwrap(elementType), maps, memorySpace)); +} + MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, unsigned memorySpace) { diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -326,17 +326,27 @@ f32 = F32Type.get() shape = [2, 3] loc = Location.unknown() - memref = MemRefType.get_contiguous_memref(f32, shape, 2) + memref = 
MemRefType.get(f32, shape, memory_space=2) # CHECK: memref type: memref<2x3xf32, 2> print("memref type:", memref) # CHECK: number of affine layout maps: 0 - print("number of affine layout maps:", memref.num_affine_maps) + print("number of affine layout maps:", len(memref.layout)) # CHECK: memory space: 2 print("memory space:", memref.memory_space) + layout = AffineMap.get_permutation([1, 0]) + memref_layout = MemRefType.get(f32, shape, [layout]) + # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>> + print("memref type:", memref_layout) + assert len(memref_layout.layout) == 1 + # CHECK: memref layout: (d0, d1) -> (d1, d0) + print("memref layout:", memref_layout.layout[0]) + # CHECK: memory space: 0 + print("memory space:", memref_layout.memory_space) + none = NoneType.get() try: - memref_invalid = MemRefType.get_contiguous_memref(none, shape, 2) + memref_invalid = MemRefType.get(none, shape) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point, integer, vector # CHECK: or complex type.