Patch summary (free-form preamble; everything before the first "diff --git"
line is not part of any hunk).

The GPU-to-LLVM lowering hands buffers to the runtime wrappers as
`llvmPointerType`, a pointer to i8 created without an explicit address space.
Until now the aligned pointer of a memref descriptor was only bitcast to that
type. This patch adds the helper `bitAndAddrspaceCast`, which first emits an
address-space cast when the source and destination pointer types disagree on
the address space and then bitcasts to the destination type. The helper is
used for the `src`/`dst` operands of the gpu.memcpy lowering and the `dst`
operand of the gpu.memset lowering. `llvmPointerType` is declared with its
concrete `LLVM::LLVMPointerType` type so it can be passed to the helper.

The tests use a destination of type `memref<7xf32, 1>` and now expect an
`llvm.addrspacecast` feeding the destination bitcast; for the memcpy source
(`memref<7xf32>`) they check that no `llvm.addrspacecast` precedes the bitcast.

diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -92,7 +92,7 @@
   MLIRContext *context = &this->getTypeConverter()->getContext();
 
   Type llvmVoidType = LLVM::LLVMVoidType::get(context);
-  Type llvmPointerType =
+  LLVM::LLVMPointerType llvmPointerType =
       LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
   Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
   Type llvmInt8Type = IntegerType::get(context, 8);
@@ -807,6 +807,26 @@
   return success();
 }
 
+/// Casts `sourcePtr` to `destinationType`. If the two pointer types live in
+/// different address spaces, an address-space cast to the destination's
+/// address space (keeping the source element type) is emitted first, since
+/// the final bitcast only changes the pointee type.
+static Value bitAndAddrspaceCast(Location loc,
+                                 ConversionPatternRewriter &rewriter,
+                                 LLVM::LLVMPointerType destinationType,
+                                 Value sourcePtr,
+                                 LLVMTypeConverter &typeConverter) {
+  auto sourceTy = sourcePtr.getType().cast<LLVM::LLVMPointerType>();
+  if (destinationType.getAddressSpace() != sourceTy.getAddressSpace())
+    sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+        loc,
+        typeConverter.getPointerType(sourceTy.getElementType(),
+                                     destinationType.getAddressSpace()),
+        sourcePtr);
+
+  return rewriter.create<LLVM::BitcastOp>(loc, destinationType, sourcePtr);
+}
+
 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -829,11 +849,13 @@
   auto sizeBytes =
       rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
 
-  auto src = rewriter.create<LLVM::BitcastOp>(
-      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
-  auto dst = rewriter.create<LLVM::BitcastOp>(
-      loc, llvmPointerType,
-      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc));
+  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
+                                 srcDesc.alignedPtr(rewriter, loc),
+                                 *getTypeConverter());
+  auto dst = bitAndAddrspaceCast(
+      loc, rewriter, llvmPointerType,
+      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
+      *getTypeConverter());
 
   auto stream = adaptor.getAsyncDependencies().front();
   memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
@@ -866,8 +888,9 @@
 
   auto value =
       rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.getValue());
-  auto dst = rewriter.create<LLVM::BitcastOp>(
-      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));
+  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
+                                 dstDesc.alignedPtr(rewriter, loc),
+                                 *getTypeConverter());
 
   auto stream = adaptor.getAsyncDependencies().front();
   memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -7,8 +7,10 @@
     // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
     %t0 = gpu.wait async
     // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+    // CHECK-NOT: llvm.addrspacecast
     // CHECK: %[[src:.*]] = llvm.bitcast
-    // CHECK: %[[dst:.*]] = llvm.bitcast
+    // CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
+    // CHECK: %[[dst:.*]] = llvm.bitcast %[[addr_cast]]
     // CHECK: llvm.call @mgpuMemcpy(%[[dst]], %[[src]], %[[size_bytes]], %[[t0]])
     %t1 = gpu.memcpy async [%t0] %dst, %src : memref<7xf32, 1>, memref<7xf32>
     // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
diff --git a/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
--- a/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
@@ -8,7 +8,8 @@
     %t0 = gpu.wait async
     // CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
     // CHECK: %[[value:.*]] = llvm.bitcast
-    // CHECK: %[[dst:.*]] = llvm.bitcast
+    // CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
+    // CHECK: %[[dst:.*]] = llvm.bitcast %[[addr_cast]]
     // CHECK: llvm.call @mgpuMemset32(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
     %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf32, 1>, f32
     // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])