Patch summary (free text ahead of the first "diff --git" header):
  * IRTypes.cpp: PyMemRefType::isaFunction switches from
    mlirTypeIsARankedTensor to mlirTypeIsAMemRef.
  * _memref_ops_ext.py: return_type becomes the element_type of the memref
    operand's MemRefType instead of the memref type itself.
  * memref.py test: additionally asserts module.operation.verify().

NOTE(review): this patch arrived collapsed onto a single line. Line breaks are
restored from the hunk headers; leading indentation and the blank trailing
context line of the first hunk are reconstructed; confirm against the tree.
NOTE(review): "PyConcreteType {" and "memref," look like they may have lost
angle-bracketed text during extraction; left untouched, verify before applying.

diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -406,7 +406,7 @@
 /// Ranked MemRef Type subclass - MemRefType.
 class PyMemRefType : public PyConcreteType {
 public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
   static constexpr const char *pyClassName = "MemRefType";
   using PyConcreteType::PyConcreteType;
 
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ b/mlir/python/mlir/dialects/_memref_ops_ext.py
@@ -33,5 +33,5 @@
     memref_resolved = _get_op_result_or_value(memref)
     indices_resolved = [] if indices is None else _get_op_results_or_values(
         indices)
-    return_type = memref_resolved.type
+    return_type = MemRefType(memref_resolved.type).element_type
     super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip)
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -71,3 +71,4 @@
     # CHECK: func @f1(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
     # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
     print(module)
+    assert module.operation.verify()