diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -470,6 +470,10 @@
   Value getSizeInBytes(Location loc, Type type,
                        ConversionPatternRewriter &rewriter) const;
 
+  /// Computes total number of elements for the given shape.
+  Value getNumElements(Location loc, ArrayRef<Value> shape,
+                       ConversionPatternRewriter &rewriter) const;
+
   /// Computes total size in bytes of to store the given shape.
   Value getCumulativeSizeInBytes(Location loc, Type elementType,
                                  ArrayRef<Value> shape,
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -979,18 +979,23 @@
   return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
 }
 
-Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
-    Location loc, Type elementType, ArrayRef<Value> shape,
+Value ConvertToLLVMPattern::getNumElements(
+    Location loc, ArrayRef<Value> shape,
     ConversionPatternRewriter &rewriter) const {
   // Compute the total number of memref elements.
-  Value cumulativeSizeInBytes =
+  Value numElements =
       shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
   for (unsigned i = 1, e = shape.size(); i < e; ++i)
-    cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>(
-        loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, shape[i]});
-  auto elementSize = this->getSizeInBytes(loc, elementType, rewriter);
-  return rewriter.create<LLVM::MulOp>(
-      loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize});
+    numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
+  return numElements;
+}
+
+Value ConvertToLLVMPattern::getCumulativeSizeInBytes(
+    Location loc, Type elementType, ArrayRef<Value> shape,
+    ConversionPatternRewriter &rewriter) const {
+  Value elementSize = this->getSizeInBytes(loc, elementType, rewriter);
+  Value numElements = this->getNumElements(loc, shape, rewriter);
+  return rewriter.create<LLVM::MulOp>(loc, numElements, elementSize);
 }
 
 /// Creates and populates the memref descriptor struct given all its fields.