diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -50,6 +50,23 @@
   return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
 }
 
+static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
+                                 Location loc, Value allocatedPtr,
+                                 MemRefType memRefType, Type elementPtrType,
+                                 LLVMTypeConverter &typeConverter) {
+  auto allocatedPtrTy = allocatedPtr.getType().cast<LLVM::LLVMPointerType>();
+  if (allocatedPtrTy.getAddressSpace() != memRefType.getMemorySpaceAsInt())
+    allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(allocatedPtrTy.getElementType(),
+                                   memRefType.getMemorySpaceAsInt()),
+        allocatedPtr);
+
+  allocatedPtr =
+      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
+  return allocatedPtr;
+}
+
 std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
     ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
     Operation *op, Value alignment) const {
@@ -64,8 +81,10 @@
   LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
       getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
-  Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
-                                                        results.getResult());
+
+  Value allocatedPtr =
+      castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+                          elementPtrType, *getTypeConverter());
 
   Value alignedPtr = allocatedPtr;
   if (alignment) {
@@ -126,10 +145,9 @@
       getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(
       loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
-  Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType,
-                                                        results.getResult());
-
-  return allocatedPtr;
+  return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
+                             elementPtrType, *getTypeConverter());
 }
 
 LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite(
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -182,6 +182,11 @@
 // CHECK-LABEL: func @address_space(
 func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
+  // CHECK: %[[MEMORY:.*]] = llvm.call @malloc(%{{.*}})
+  // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[MEMORY]] : !llvm.ptr<i8> to !llvm.ptr<i8, 5>
+  // CHECK: %[[BCAST:.*]] = llvm.bitcast %[[CAST]]
+  // CHECK: llvm.insertvalue %[[BCAST]], %{{[[:alnum:]]+}}[0]
+  // CHECK: llvm.insertvalue %[[BCAST]], %{{[[:alnum:]]+}}[1]
   %0 = memref.alloc() : memref<32xf32, affine_map<(d0) -> (d0)>, 5>
   %1 = arith.constant 7 : index
   // CHECK: llvm.load %{{.*}} : !llvm.ptr<f32, 5>