NOTE(review): this patch was recovered from a copy that had been flattened onto
two lines and appears to have had every `<...>` span stripped. The line
structure below is rebuilt to match the hunk headers, and the op template
arguments in castDataPtr are restored from the surrounding evidence (the
removed comment and the llvm.addrspacecast -> llvm.bitcast change in the test).
The type parameters on the FileCheck lines of the test hunk (`!llvm.ptr`,
`memref`, and the stray `>` / `, 3>`) are still missing; restore them from the
upstream test file before applying. Context-line indentation assumes LLVM
style -- confirm against the target files.

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -208,15 +208,12 @@
   return success();
 }
 
-// Casts a strided element pointer to a vector pointer. The vector pointer
-// would always be on address space 0, therefore addrspacecast shall be
-// used when source/dst memrefs are not on address space 0.
+// Casts a strided element pointer to a vector pointer. The vector pointer
+// will be in the same address space as the incoming memref type.
 static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
                          Value ptr, MemRefType memRefType, Type vt) {
-  auto pType = LLVM::LLVMPointerType::get(vt);
-  if (memRefType.getMemorySpace() == 0)
-    return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
-  return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
+  auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpace());
+  return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
 }
 
 static LogicalResult
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1146,18 +1146,18 @@
 // 1. Check address space for GEP is correct.
 // CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr
-// CHECK: %[[vecPtr:.*]] = llvm.addrspacecast %[[gep]] :
-// CHECK-SAME: !llvm.ptr to !llvm.ptr>
+// CHECK: %[[vecPtr:.*]] = llvm.bitcast %[[gep]] :
+// CHECK-SAME: !llvm.ptr to !llvm.ptr, 3>
 //
 // 2. Check address space of the memref is correct.
 // CHECK: %[[c0:.*]] = constant 0 : index
 // CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[c0]] : memref
 //
-// 3. Check address apce for GEP is correct.
+// 3. Check address space for GEP is correct.
 // CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
 // CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr
-// CHECK: %[[vecPtr_b:.*]] = llvm.addrspacecast %[[gep_b]] :
-// CHECK-SAME: !llvm.ptr to !llvm.ptr>
+// CHECK: %[[vecPtr_b:.*]] = llvm.bitcast %[[gep_b]] :
+// CHECK-SAME: !llvm.ptr to !llvm.ptr, 3>
 
 // -----
 