diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py --- a/mlir/examples/python/linalg_matmul.py +++ b/mlir/examples/python/linalg_matmul.py @@ -31,9 +31,9 @@ def build_matmul_buffers_func(func_name, m, k, n, dtype): - lhs_type = MemRefType.get(dtype, [m, k]) - rhs_type = MemRefType.get(dtype, [k, n]) - result_type = MemRefType.get(dtype, [m, n]) + lhs_type = MemRefType.get([m, k], dtype) + rhs_type = MemRefType.get([k, n], dtype) + result_type = MemRefType.get([m, n], dtype) # TODO: There should be a one-liner for this. func_type = FunctionType.get([lhs_type, rhs_type, result_type], []) _, entry = FuncOp(func_name, func_type) @@ -49,8 +49,6 @@ def build_matmul_tensors_func(func_name, m, k, n, dtype): - # TODO: MemRefType and TensorTypes should not have inverted dtype/shapes - # from each other. lhs_type = RankedTensorType.get([m, k], dtype) rhs_type = RankedTensorType.get([k, n], dtype) result_type = RankedTensorType.get([m, n], dtype) diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2832,7 +2832,7 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyType &elementType, std::vector shape, + [](std::vector shape, PyType &elementType, std::vector layout, unsigned memorySpace, DefaultingPyLocation loc) { SmallVector maps; @@ -2856,7 +2856,7 @@ } return PyMemRefType(elementType.getContext(), t); }, - py::arg("element_type"), py::arg("shape"), + py::arg("shape"), py::arg("element_type"), py::arg("layout") = py::list(), py::arg("memory_space") = 0, py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly("layout", &PyMemRefType::getLayout, diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -326,7 +326,7 @@ f32 = F32Type.get() shape = [2, 3] loc = Location.unknown() - memref = MemRefType.get(f32, shape, memory_space=2) + memref = MemRefType.get(shape, f32, memory_space=2) # CHECK: memref type: memref<2x3xf32, 2> print("memref type:", memref) # CHECK: number of affine layout maps: 0 @@ -335,7 +335,7 @@ print("memory space:", memref.memory_space) layout = AffineMap.get_permutation([1, 0]) - memref_layout = MemRefType.get(f32, shape, [layout]) + memref_layout = MemRefType.get(shape, f32, [layout]) # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>> print("memref type:", memref_layout) assert len(memref_layout.layout) == 1 @@ -346,7 +346,7 @@ none = NoneType.get() try: - memref_invalid = MemRefType.get(none, shape) + memref_invalid = MemRefType.get(shape, none) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point, integer, vector # CHECK: or complex type.