diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -103,9 +103,12 @@ *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType()); auto elementPtrType = getTypeConverter()->getPointerType(elementType, addrSpace); + auto sizeOfElement = getSizeInBytes(loc, elementType, rewriter); + auto numElements = + rewriter.create(loc, sizeBytes, sizeOfElement); auto allocatedElementPtr = rewriter.create( - loc, elementPtrType, elementType, sizeBytes, + loc, elementPtrType, elementType, numElements, allocaOp.getAlignment().value_or(0)); return std::make_tuple(allocatedElementPtr, allocatedElementPtr);