[mlir][Vector] Preserve the source memory space in vector.type_cast

inferVectorTypeCastResultType used to build the 0-d result memref without
passing a memory space, so the memory space of the source memref was lost.
Forward t.getMemorySpace() so that, e.g., memref<8x8x8xf32, 3> casts to
memref<vector<8x8x8xf32>, 3>, and add a VectorToLLVM test checking that the
lowered memref descriptor keeps addrspace(3) on its pointers.

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1463,7 +1463,8 @@
 //===----------------------------------------------------------------------===//
 
 static MemRefType inferVectorTypeCastResultType(MemRefType t) {
-  return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType()));
+  return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType()),
+                         {}, t.getMemorySpace());
 }
 
 void TypeCastOp::build(OpBuilder &builder, OperationState &result,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -418,6 +418,21 @@
 //       CHECK:   llvm.mlir.constant(0 : index
 //       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
 
+func @vector_type_cast_non_zero_addrspace(%arg0: memref<8x8x8xf32, 3>) -> memref<vector<8x8x8xf32>, 3> {
+  %0 = vector.type_cast %arg0: memref<8x8x8xf32, 3> to memref<vector<8x8x8xf32>, 3>
+  return %0 : memref<vector<8x8x8xf32>, 3>
+}
+// CHECK-LABEL: llvm.func @vector_type_cast_non_zero_addrspace
+//       CHECK:   llvm.mlir.undef : !llvm<"{ [8 x [8 x <8 x float>]] addrspace(3)*, [8 x [8 x <8 x float>]] addrspace(3)*, i64 }">
+//       CHECK:   %[[allocated:.*]] = llvm.extractvalue {{.*}}[0] : !llvm<"{ float addrspace(3)*, float addrspace(3)*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   %[[allocatedBit:.*]] = llvm.bitcast %[[allocated]] : !llvm<"float addrspace(3)*"> to !llvm<"[8 x [8 x <8 x float>]] addrspace(3)*">
+//       CHECK:   llvm.insertvalue %[[allocatedBit]], {{.*}}[0] : !llvm<"{ [8 x [8 x <8 x float>]] addrspace(3)*, [8 x [8 x <8 x float>]] addrspace(3)*, i64 }">
+//       CHECK:   %[[aligned:.*]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float addrspace(3)*, float addrspace(3)*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:   %[[alignedBit:.*]] = llvm.bitcast %[[aligned]] : !llvm<"float addrspace(3)*"> to !llvm<"[8 x [8 x <8 x float>]] addrspace(3)*">
+//       CHECK:   llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]] addrspace(3)*, [8 x [8 x <8 x float>]] addrspace(3)*, i64 }">
+//       CHECK:   llvm.mlir.constant(0 : index
+//       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]] addrspace(3)*, [8 x [8 x <8 x float>]] addrspace(3)*, i64 }">
+
 func @vector_print_scalar_i32(%arg0: i32) {
   vector.print %arg0 : i32
   return