diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -440,6 +440,21 @@ ValueRange indices, ConversionPatternRewriter &rewriter, llvm::Module &module) const; + /// Returns the type of a pointer to an element of `type`, in the memory + /// space of `type`. + Type getElementPtrType(MemRefType type) const; + + /// Determines sizes to be used in the memref descriptor. Static dimensions + /// become index constants; dynamic dimensions are taken, in order, from the + /// leading entries of `dynSizes`. A rank-0 memref gets a single size of 1. + /// On return, `cumulativeSizeInBytes` holds the product of all sizes times + /// the element size, and `one` holds an index constant equal to 1. + void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, + ArrayRef<Value> dynSizes, + ConversionPatternRewriter &rewriter, + SmallVectorImpl<Value> &sizes, + Value &cumulativeSizeInBytes, Value &one) const; + protected: /// Reference to the type converter, with potential extensions. LLVMTypeConverter &typeConverter; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -925,6 +925,50 @@ offset, rewriter); } +Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { + auto elementType = type.getElementType(); + auto structElementType = typeConverter.convertType(elementType); + return structElementType.cast<LLVM::LLVMType>().getPointerTo( + type.getMemorySpace()); +} + +void ConvertToLLVMPattern::getMemRefDescriptorSizes( + Location loc, MemRefType memRefType, ArrayRef<Value> dynSizes, + ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, + Value &cumulativeSizeInBytes, Value &one) const { + sizes.reserve(memRefType.getRank()); + unsigned dynIdx = 0; + for (int64_t s : memRefType.getShape()) + sizes.push_back(s == -1 ? dynSizes[dynIdx++] + : createIndexConstant(rewriter, loc, s)); + if (sizes.empty()) + sizes.push_back(createIndexConstant(rewriter, loc, 1)); + + // Compute the total number of memref elements (scaled to bytes below).
+ cumulativeSizeInBytes = sizes.front(); + for (unsigned i = 1, e = sizes.size(); i < e; ++i) + cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>( + loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, sizes[i]}); + + // Compute the size of an individual element. This emits the MLIR equivalent + // of the following sizeof(...) implementation in LLVM IR: + // %0 = getelementptr %elementType* null, %indexType 1 + // %1 = ptrtoint %elementType* %0 to %indexType + // which is a common pattern of getting the size of a type in bytes. + auto elementType = memRefType.getElementType(); + auto convertedPtrType = typeConverter.convertType(elementType) + .cast<LLVM::LLVMType>() + .getPointerTo(); + auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); + one = createIndexConstant(rewriter, loc, 1); + auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, + ArrayRef<Value>{nullPtr, one}); + auto elementSize = + rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); + cumulativeSizeInBytes = rewriter.create<LLVM::MulOp>( + loc, getIndexType(), ArrayRef<Value>{cumulativeSizeInBytes, elementSize}); +} + /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. @@ -1698,7 +1742,7 @@ Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType, Value allocatedTypePtr, Value allocatedBytePtr, Value accessAlignment, uint64_t offset, ArrayRef<int64_t> strides, ArrayRef<Value> sizes) const { - auto elementPtrType = getElementPtrType(memRefType); + auto elementPtrType = this->getElementPtrType(memRefType); auto structType = typeConverter.convertType(memRefType); auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); @@ -1756,52 +1800,6 @@ return memRefDescriptor; } - /// Determines sizes to be used in the memref descriptor.
- void getSizes(Location loc, MemRefType memRefType, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter, - SmallVectorImpl<Value> &sizes, Value &cumulativeSize, - Value &one) const { - sizes.reserve(memRefType.getRank()); - unsigned i = 0; - for (int64_t s : memRefType.getShape()) - sizes.push_back(s == -1 ? operands[i++] - : createIndexConstant(rewriter, loc, s)); - if (sizes.empty()) - sizes.push_back(createIndexConstant(rewriter, loc, 1)); - - // Compute the total number of memref elements. - cumulativeSize = sizes.front(); - for (unsigned i = 1, e = sizes.size(); i < e; ++i) - cumulativeSize = rewriter.create<LLVM::MulOp>( - loc, getIndexType(), ArrayRef<Value>{cumulativeSize, sizes[i]}); - - // Compute the size of an individual element. This emits the MLIR equivalent - // of the following sizeof(...) implementation in LLVM IR: - // %0 = getelementptr %elementType* null, %indexType 1 - // %1 = ptrtoint %elementType* %0 to %indexType - // which is a common pattern of getting the size of a type in bytes. - auto elementType = memRefType.getElementType(); - auto convertedPtrType = typeConverter.convertType(elementType) - .template cast<LLVM::LLVMType>() - .getPointerTo(); - auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); - one = createIndexConstant(rewriter, loc, 1); - auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, - ArrayRef<Value>{nullPtr, one}); - auto elementSize = - rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); - cumulativeSize = rewriter.create<LLVM::MulOp>( - loc, getIndexType(), ArrayRef<Value>{cumulativeSize, elementSize}); - } - - /// Returns the type of a pointer to an element of the memref. - Type getElementPtrType(MemRefType memRefType) const { - auto elementType = memRefType.getElementType(); - auto structElementType = typeConverter.convertType(elementType); - return structElementType.template cast<LLVM::LLVMType>().getPointerTo( - memRefType.getMemorySpace()); - } - /// Returns the memref's element size in bytes. // TODO: there are other places where this is used. Expose publicly?
static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { @@ -1856,7 +1854,7 @@ MemRefType memRefType, Value one, Value &accessAlignment, Value &allocatedBytePtr, ConversionPatternRewriter &rewriter) const { - auto elementPtrType = getElementPtrType(memRefType); + auto elementPtrType = this->getElementPtrType(memRefType); // With alloca, one gets a pointer to the element type right away. // For stack allocations. @@ -1960,7 +1958,8 @@ // zero-dimensional memref, assume a scalar (size 1). SmallVector<Value, 4> sizes; Value cumulativeSize, one; - getSizes(loc, memRefType, operands, rewriter, sizes, cumulativeSize, one); + this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes, + cumulativeSize, one); // Allocate the underlying buffer. // Value holding the alignment that has to be performed post allocation