Lower memref.dealloc of unranked memrefs to a call to `free`.
For an unranked operand the allocated pointer is read through the unranked
descriptor's memRefDescPtr; ranked operands keep using MemRefDescriptor.

diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -382,14 +382,27 @@
     // Insert the `free` declaration if it is not already present.
     LLVM::LLVMFuncOp freeFunc =
         getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
-    MemRefDescriptor memref(adaptor.getMemref());
-    Value allocatedPtr = memref.allocatedPtr(rewriter, op.getLoc());
-    Value casted = allocatedPtr;
+    Value allocatedPtr;
+    if (auto unrankedTy =
+            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
+      Type elementType = unrankedTy.getElementType();
+      Type llvmElementTy = getTypeConverter()->convertType(elementType);
+      LLVM::LLVMPointerType elementPtrTy = getTypeConverter()->getPointerType(
+          llvmElementTy, unrankedTy.getMemorySpaceAsInt());
+      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
+          rewriter, op.getLoc(),
+          UnrankedMemRefDescriptor(adaptor.getMemref())
+              .memRefDescPtr(rewriter, op.getLoc()),
+          elementPtrTy);
+    } else {
+      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
+                         .allocatedPtr(rewriter, op.getLoc());
+    }
     if (!getTypeConverter()->useOpaquePointers())
-      casted = rewriter.create<LLVM::BitcastOp>(op.getLoc(), getVoidPtrType(),
-                                                allocatedPtr);
+      allocatedPtr = rewriter.create<LLVM::BitcastOp>(
+          op.getLoc(), getVoidPtrType(), allocatedPtr);
 
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, casted);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir
@@ -42,6 +42,17 @@
 
 // -----
 
+// CHECK-LABEL: func @unranked_dealloc
+func.func @unranked_dealloc(%arg0: memref<*xf32>) {
+// CHECK: %[[memref:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i64, ptr)>
+// CHECK: %[[ptr:.*]] = llvm.load %[[memref]]
+// CHECK-NEXT: llvm.call @free(%[[ptr]])
+  memref.dealloc %arg0 : memref<*xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @dynamic_alloc(
 // CHECK: %[[Marg:.*]]: index, %[[Narg:.*]]: index)
 func.func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {