[mlir] Print the memory space of unranked memrefs

The printer emitted an unranked memref as `memref<*x` + element type + `>`,
dropping the memory space, so e.g. `memref<*xf32, 1>` was printed as
`memref<*xf32>`. Print the memory space when it is non-zero, add a
round-trip test, and update the diagnostic text expected by invalid-ops.

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1650,6 +1650,9 @@
       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
         os << "memref<*x";
         printType(memrefTy.getElementType());
+        // Only print the memory space if it is the non-default one.
+        if (memrefTy.getMemorySpace())
+          os << ", " << memrefTy.getMemorySpace();
         os << '>';
       })
       .Case<ComplexType>([&](ComplexType complexTy) {
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -703,6 +703,11 @@
   return
 }
 
+// Check that unranked memrefs with non-default memory space roundtrip
+// properly.
+// CHECK-LABEL: @unranked_memref_roundtrip(memref<*xf32, 4>)
+func @unranked_memref_roundtrip(memref<*xf32, 4>)
+
 // CHECK-LABEL: func @memref_view(%arg0
 func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<2048xi8>
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1076,7 +1076,7 @@
 // incompatible memory space
 func @invalid_memref_cast() {
   %0 = alloc() : memref<2x5xf32, 0>
-  // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32>' are cast incompatible}}
+  // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32, 1>' are cast incompatible}}
   %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 1>
   return
 }