diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -412,6 +412,7 @@ LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1); +protected: /// Returns the LLVM dialect. LLVM::LLVMDialect &getDialect() const; @@ -419,6 +420,10 @@ /// defined by the used type converter. LLVM::LLVMType getIndexType() const; + /// Gets the MLIR type wrapping the LLVM integer type whose bit width + /// corresponds to that of a LLVM pointer type. + LLVM::LLVMType getIntPtrType(unsigned addressSpace = 0) const; + /// Gets the MLIR type wrapping the LLVM void type. LLVM::LLVMType getVoidType() const; @@ -470,6 +475,13 @@ ArrayRef shape, ConversionPatternRewriter &rewriter) const; + /// Creates and populates the memref descriptor struct given all its fields. + MemRefDescriptor + createMemRefDescriptor(Location loc, MemRefType memRefType, + Value allocatedPtr, Value alignedPtr, uint64_t offset, + ArrayRef strides, ArrayRef sizes, + ConversionPatternRewriter &rewriter) const; + protected: /// Reference to the type converter, with potential extensions. LLVMTypeConverter &typeConverter; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -876,6 +876,13 @@ return typeConverter.getIndexType(); } +LLVM::LLVMType +ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { + return LLVM::LLVMType::getIntNTy( + &typeConverter.getContext(), + typeConverter.getPointerBitwidth(addressSpace)); +} + LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { return LLVM::LLVMType::getVoidTy(&typeConverter.getContext()); } @@ -977,19 +984,70 @@ } Value ConvertToLLVMPattern::getCumulativeSizeInBytes( - Location loc, Type elementType, ArrayRef sizes, + Location loc, Type elementType, ArrayRef shape, ConversionPatternRewriter &rewriter) const { // Compute the total number of memref elements. Value cumulativeSizeInBytes = - sizes.empty() ? createIndexConstant(rewriter, loc, 1) : sizes.front(); - for (unsigned i = 1, e = sizes.size(); i < e; ++i) + shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front(); + for (unsigned i = 1, e = shape.size(); i < e; ++i) cumulativeSizeInBytes = rewriter.create( - loc, getIndexType(), ArrayRef{cumulativeSizeInBytes, sizes[i]}); + loc, getIndexType(), ArrayRef{cumulativeSizeInBytes, shape[i]}); auto elementSize = this->getSizeInBytes(loc, elementType, rewriter); return rewriter.create( loc, getIndexType(), ArrayRef{cumulativeSizeInBytes, elementSize}); } +/// Creates and populates the memref descriptor struct given all its fields. +MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( + Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, + uint64_t offset, ArrayRef strides, ArrayRef sizes, + ConversionPatternRewriter &rewriter) const { + auto structType = typeConverter.convertType(memRefType); + auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); + + // Field 1: Allocated pointer, used for malloc/free. + memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); + + // Field 2: Actual aligned pointer to payload. 
+  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
+
+  // Field 3: Offset in aligned pointer.
+  memRefDescriptor.setOffset(rewriter, loc,
+                             this->createIndexConstant(rewriter, loc, offset));
+
+  if (memRefType.getRank() == 0)
+    // No size/stride descriptor in memref, return the descriptor value.
+    return memRefDescriptor;
+
+  // Fields 4 and 5: sizes and strides of the strided MemRef.
+  // Store all sizes in the descriptor. Only dynamic sizes are passed in as
+  // operands to AllocOp.
+  Value runningStride = nullptr;
+  // Iterate strides in reverse order, compute runningStride and strideValues.
+  auto nStrides = strides.size();
+  SmallVector<Value, 4> strideValues(nStrides, nullptr);
+  for (unsigned i = 0; i < nStrides; ++i) {
+    int64_t index = nStrides - 1 - i;
+    if (strides[index] == MemRefType::getDynamicStrideOrOffset())
+      // Identity layout map is enforced in the match function, so we compute:
+      //   `runningStride *= sizes[index + 1]`
+      runningStride = runningStride
+                          ? rewriter.create<LLVM::MulOp>(loc, runningStride,
+                                                         sizes[index + 1])
+                          : this->createIndexConstant(rewriter, loc, 1);
+    else
+      runningStride = this->createIndexConstant(rewriter, loc, strides[index]);
+    strideValues[index] = runningStride;
+  }
+  // Fill size and stride descriptors in memref.
+  for (auto indexedSize : llvm::enumerate(sizes)) {
+    int64_t index = indexedSize.index();
+    memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value());
+    memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]);
+  }
+  return memRefDescriptor;
+}
+
 /// Only retain those attributes that are not constructed by
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
 /// attributes.
@@ -1714,251 +1772,84 @@
 }

 /// Lowering for AllocOp and AllocaOp.
-template <typename AllocLikeOp>
-struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<AllocLikeOp> {
-  using ConvertOpToLLVMPattern<AllocLikeOp>::createIndexConstant;
-  using ConvertOpToLLVMPattern<AllocLikeOp>::getIndexType;
-  using ConvertOpToLLVMPattern<AllocLikeOp>::typeConverter;
-  using ConvertOpToLLVMPattern<AllocLikeOp>::getVoidPtrType;
+struct AllocLikeOpLowering : public ConvertToLLVMPattern {
+  using ConvertToLLVMPattern::createIndexConstant;
+  using ConvertToLLVMPattern::getIndexType;
+  using ConvertToLLVMPattern::getVoidPtrType;
+  using ConvertToLLVMPattern::typeConverter;
+
+  explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
+      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}
+
+protected:
+  // Returns 'input' aligned up to 'alignment'. Computes
+  //   bumped = input + alignment - 1
+  //   aligned = bumped - bumped % alignment
+  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
+                             Value input, Value alignment) {
+    Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
+    Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
+    Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
+    Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
+    return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
+  }
+
+  // Creates a call to an allocation function with params and casts the
+  // resulting void pointer to ptrType.
+ Value createAllocCall(Location loc, StringRef name, Type ptrType, + ArrayRef params, ModuleOp module, + ConversionPatternRewriter &rewriter) const { + SmallVector paramTypes; + auto allocFuncOp = module.lookupSymbol(name); + if (!allocFuncOp) { + for (const Value ¶m : params) + paramTypes.push_back(param.getType().cast()); + auto allocFuncType = + LLVM::LLVMType::getFunctionTy(getVoidPtrType(), paramTypes, + /*isVarArg=*/false); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + allocFuncOp = rewriter.create(rewriter.getUnknownLoc(), + name, allocFuncType); + } + auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp); + auto allocatedPtr = rewriter + .create(loc, getVoidPtrType(), + allocFuncSymbol, params) + .getResult(0); + return rewriter.create(loc, ptrType, allocatedPtr); + } - explicit AllocLikeOpLowering(LLVMTypeConverter &converter) - : ConvertOpToLLVMPattern(converter) {} + /// Allocates the underlying buffer. Returns the allocated pointer and the + /// aligned pointer. + virtual std::tuple + allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, + Value cumulativeSize, Operation *op) const = 0; + +private: + static MemRefType getMemRefResultType(Operation *op) { + return op->getResult(0).getType().cast(); + } LogicalResult match(Operation *op) const override { - MemRefType memRefType = cast(op).getType(); + MemRefType memRefType = getMemRefResultType(op); if (isSupportedMemRefType(memRefType)) return success(); int64_t offset; SmallVector strides; - auto successStrides = getStridesAndOffset(memRefType, strides, offset); - if (failed(successStrides)) + if (failed(getStridesAndOffset(memRefType, strides, offset))) return failure(); // Dynamic strides are ok if they can be deduced from dynamic sizes (which - // is guaranteed when succeeded(successStrides)). Dynamic offset however can - // never be alloc'ed. + // is guaranteed when getStridesAndOffset succeeded. Dynamic offset however + // can never be alloc'ed. if (offset == MemRefType::getDynamicStrideOrOffset()) return failure(); return success(); } - // Returns bump = (alignment - (input % alignment))% alignment, which is the - // increment necessary to align `input` to `alignment` boundary. - // TODO: this can be made more efficient by just using a single addition - // and two bit shifts: (ptr + align - 1)/align, align is always power of 2. - Value createBumpToAlign(Location loc, OpBuilder b, Value input, - Value alignment) const { - Value modAlign = b.create(loc, input, alignment); - Value diff = b.create(loc, alignment, modAlign); - Value shift = b.create(loc, diff, alignment); - return shift; - } - - /// Creates and populates the memref descriptor struct given all its fields. - /// This method also performs any post allocation alignment needed for heap - /// allocations when `accessAlignment` is non null. This is used with - /// allocators that do not support alignment. - MemRefDescriptor createMemRefDescriptor( - Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType, - Value allocatedTypePtr, Value allocatedBytePtr, Value accessAlignment, - uint64_t offset, ArrayRef strides, ArrayRef sizes) const { - auto elementPtrType = this->getElementPtrType(memRefType); - auto structType = typeConverter.convertType(memRefType); - auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); - - // Field 1: Allocated pointer, used for malloc/free. 
- memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedTypePtr); - - // Field 2: Actual aligned pointer to payload. - Value alignedBytePtr = allocatedTypePtr; - if (accessAlignment) { - // offset = (align - (ptr % align))% align - Value intVal = rewriter.create( - loc, this->getIndexType(), allocatedBytePtr); - Value offset = createBumpToAlign(loc, rewriter, intVal, accessAlignment); - Value aligned = rewriter.create( - loc, allocatedBytePtr.getType(), allocatedBytePtr, offset); - alignedBytePtr = rewriter.create( - loc, elementPtrType, ArrayRef(aligned)); - } - memRefDescriptor.setAlignedPtr(rewriter, loc, alignedBytePtr); - - // Field 3: Offset in aligned pointer. - memRefDescriptor.setOffset(rewriter, loc, - createIndexConstant(rewriter, loc, offset)); - - if (memRefType.getRank() == 0) - // No size/stride descriptor in memref, return the descriptor value. - return memRefDescriptor; - - // Fields 4 and 5: sizes and strides of the strided MemRef. - // Store all sizes in the descriptor. Only dynamic sizes are passed in as - // operands to AllocOp. - Value runningStride = nullptr; - // Iterate strides in reverse order, compute runningStride and strideValues. - auto nStrides = strides.size(); - SmallVector strideValues(nStrides, nullptr); - for (unsigned i = 0; i < nStrides; ++i) { - int64_t index = nStrides - 1 - i; - if (strides[index] == MemRefType::getDynamicStrideOrOffset()) - // Identity layout map is enforced in the match function, so we compute: - // `runningStride *= sizes[index + 1]` - runningStride = runningStride - ? rewriter.create(loc, runningStride, - sizes[index + 1]) - : createIndexConstant(rewriter, loc, 1); - else - runningStride = createIndexConstant(rewriter, loc, strides[index]); - strideValues[index] = runningStride; - } - // Fill size and stride descriptors in memref. - for (auto indexedSize : llvm::enumerate(sizes)) { - int64_t index = indexedSize.index(); - memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); - memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); - } - return memRefDescriptor; - } - - /// Returns the memref's element size in bytes. - // TODO: there are other places where this is used. Expose publicly? - static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { - auto elementType = memRefType.getElementType(); - - unsigned sizeInBits; - if (elementType.isIntOrFloat()) { - sizeInBits = elementType.getIntOrFloatBitWidth(); - } else { - auto vectorType = elementType.cast(); - sizeInBits = - vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); - } - return llvm::divideCeil(sizeInBits, 8); - } - - /// Returns the alignment to be used for the allocation call itself. - /// aligned_alloc requires the allocation size to be a power of two, and the - /// allocation size to be a multiple of alignment, - Optional getAllocationAlignment(AllocOp allocOp) const { - // No alignment can be used for the 'malloc' call itself. - if (!typeConverter.getOptions().useAlignedAlloc) - return None; - - if (Optional alignment = allocOp.alignment()) - return *alignment; - - // Whenever we don't have alignment set, we will use an alignment - // consistent with the element type; since the allocation size has to be a - // power of two, we will bump to the next power of two if it already isn't. 
- auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType()); - return std::max(kMinAlignedAllocAlignment, - llvm::PowerOf2Ceil(eltSizeBytes)); - } - - /// Returns true if the memref size in bytes is known to be a multiple of - /// factor. - static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) { - uint64_t sizeDivisor = getMemRefEltSizeInBytes(type); - for (unsigned i = 0, e = type.getRank(); i < e; i++) { - if (type.isDynamic(type.getDimSize(i))) - continue; - sizeDivisor = sizeDivisor * type.getDimSize(i); - } - return sizeDivisor % factor == 0; - } - - /// Allocates the underlying buffer using the right call. `allocatedBytePtr` - /// is set to null for stack allocations. `accessAlignment` is set if - /// alignment is needed post allocation (for eg. in conjunction with malloc). - Value allocateBuffer(Location loc, Value cumulativeSize, Operation *op, - MemRefType memRefType, Value one, Value &accessAlignment, - Value &allocatedBytePtr, - ConversionPatternRewriter &rewriter) const { - auto elementPtrType = this->getElementPtrType(memRefType); - - // With alloca, one gets a pointer to the element type right away. - // For stack allocations. - if (auto allocaOp = dyn_cast(op)) { - allocatedBytePtr = nullptr; - accessAlignment = nullptr; - return rewriter.create( - loc, elementPtrType, cumulativeSize, - allocaOp.alignment() ? *allocaOp.alignment() : 0); - } - - // Heap allocations. - AllocOp allocOp = cast(op); - - Optional allocationAlignment = getAllocationAlignment(allocOp); - // Whether to use std lib function aligned_alloc that supports alignment. - bool useAlignedAlloc = allocationAlignment.hasValue(); - - // Insert the malloc/aligned_alloc declaration if it is not already present. - const auto *allocFuncName = useAlignedAlloc ? "aligned_alloc" : "malloc"; - auto module = allocOp.getParentOfType(); - auto allocFunc = module.lookupSymbol(allocFuncName); - if (!allocFunc) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart( - op->getParentOfType().getBody()); - SmallVector callArgTypes = {getIndexType()}; - // aligned_alloc(size_t alignment, size_t size) - if (useAlignedAlloc) - callArgTypes.push_back(getIndexType()); - allocFunc = rewriter.create( - rewriter.getUnknownLoc(), allocFuncName, - LLVM::LLVMType::getFunctionTy(getVoidPtrType(), callArgTypes, - /*isVarArg=*/false)); - } - - // Allocate the underlying buffer and store a pointer to it in the MemRef - // descriptor. - SmallVector callArgs; - if (useAlignedAlloc) { - // Use aligned_alloc. - assert(allocationAlignment && "allocation alignment should be present"); - auto alignedAllocAlignmentValue = rewriter.create( - loc, typeConverter.convertType(rewriter.getIntegerType(64)), - rewriter.getI64IntegerAttr(allocationAlignment.getValue())); - // aligned_alloc requires size to be a multiple of alignment; we will pad - // the size to the next multiple if necessary. - if (!isMemRefSizeMultipleOf(memRefType, allocationAlignment.getValue())) { - Value bump = createBumpToAlign(loc, rewriter, cumulativeSize, - alignedAllocAlignmentValue); - cumulativeSize = - rewriter.create(loc, cumulativeSize, bump); - } - callArgs = {alignedAllocAlignmentValue, cumulativeSize}; - } else { - // Adjust the allocation size to consider alignment. 
- if (Optional alignment = allocOp.alignment()) { - accessAlignment = createIndexConstant(rewriter, loc, *alignment); - } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { - // In the case where no alignment is specified, we may want to override - // `malloc's` behavior. `malloc` typically aligns at the size of the - // biggest scalar on a target HW. For non-scalars, use the natural - // alignment of the LLVM type given by the LLVM DataLayout. - accessAlignment = - this->getSizeInBytes(loc, memRefType.getElementType(), rewriter); - } - if (accessAlignment) - cumulativeSize = - rewriter.create(loc, cumulativeSize, accessAlignment); - callArgs.push_back(cumulativeSize); - } - auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc); - allocatedBytePtr = rewriter - .create(loc, getVoidPtrType(), - allocFuncSymbol, callArgs) - .getResult(0); - // For heap allocations, the allocated pointer is a cast of the byte pointer - // to the type pointer. - return rewriter.create(loc, elementPtrType, - allocatedBytePtr); - } - // An `alloc` is converted into a definition of a memref descriptor value and // a call to `malloc` to allocate the underlying data buffer. The memref // descriptor is of the LLVM structure type where: @@ -1976,7 +1867,7 @@ // an llvm.alloca to allocate the underlying data buffer. void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - MemRefType memRefType = cast(op).getType(); + MemRefType memRefType = getMemRefResultType(op); auto loc = op->getLoc(); // Get actual sizes of the memref as values: static sizes are constant @@ -1987,17 +1878,12 @@ Value cumulativeSize = this->getCumulativeSizeInBytes( loc, memRefType.getElementType(), sizes, rewriter); + // Allocate the underlying buffer. - // Value holding the alignment that has to be performed post allocation - // (in conjunction with allocators that do not support alignment, eg. - // malloc); nullptr if no such adjustment needs to be performed. - Value accessAlignment; - // Byte pointer to the allocated buffer. - Value allocatedBytePtr; - Value allocatedTypePtr = - allocateBuffer(loc, cumulativeSize, op, memRefType, - createIndexConstant(rewriter, loc, 1), accessAlignment, - allocatedBytePtr, rewriter); + Value allocatedPtr; + Value alignedPtr; + std::tie(allocatedPtr, alignedPtr) = + this->allocateBuffer(rewriter, loc, cumulativeSize, op); int64_t offset; SmallVector strides; @@ -2014,25 +1900,163 @@ "unexpected number of strides"); // Create the MemRef descriptor. - auto memRefDescriptor = createMemRefDescriptor( - loc, rewriter, memRefType, allocatedTypePtr, allocatedBytePtr, - accessAlignment, offset, strides, sizes); + auto memRefDescriptor = + this->createMemRefDescriptor(loc, memRefType, allocatedPtr, alignedPtr, + offset, strides, sizes, rewriter); // Return the final value of the descriptor. rewriter.replaceOp(op, {memRefDescriptor}); } +}; -protected: - /// The minimum alignment to use with aligned_alloc (has to be a power of 2). - uint64_t kMinAlignedAllocAlignment = 16UL; +struct AllocOpLowering : public AllocLikeOpLowering { + AllocOpLowering(LLVMTypeConverter &converter) + : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {} + + std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, + Location loc, Value cumulativeSize, + Operation *op) const override { + // Heap allocations. 
+ AllocOp allocOp = cast(op); + MemRefType memRefType = allocOp.getType(); + + Value alignment; + if (auto alignmentAttr = allocOp.alignment()) { + alignment = createIndexConstant(rewriter, loc, *alignmentAttr); + } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { + // In the case where no alignment is specified, we may want to override + // `malloc's` behavior. `malloc` typically aligns at the size of the + // biggest scalar on a target HW. For non-scalars, use the natural + // alignment of the LLVM type given by the LLVM DataLayout. + alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); + } + + if (alignment) { + // Adjust the allocation size to consider alignment. + cumulativeSize = + rewriter.create(loc, cumulativeSize, alignment); + } + + // Allocate the underlying buffer and store a pointer to it in the MemRef + // descriptor. + Type elementPtrType = this->getElementPtrType(memRefType); + Value allocatedPtr = + createAllocCall(loc, "malloc", elementPtrType, {cumulativeSize}, + allocOp.getParentOfType(), rewriter); + + Value alignedPtr = allocatedPtr; + if (alignment) { + auto intPtrType = getIntPtrType(memRefType.getMemorySpace()); + // Compute the aligned type pointer. + Value allocatedInt = + rewriter.create(loc, intPtrType, allocatedPtr); + Value alignmentInt = + createAligned(rewriter, loc, allocatedInt, alignment); + alignedPtr = + rewriter.create(loc, elementPtrType, alignmentInt); + } + + return std::make_tuple(allocatedPtr, alignedPtr); + } }; -struct AllocOpLowering : public AllocLikeOpLowering { - explicit AllocOpLowering(LLVMTypeConverter &converter) - : AllocLikeOpLowering(converter) {} +struct AlignedAllocOpLowering : public AllocLikeOpLowering { + AlignedAllocOpLowering(LLVMTypeConverter &converter) + : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {} + + /// Returns the memref's element size in bytes. + // TODO: there are other places where this is used. Expose publicly? + static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + + unsigned sizeInBits; + if (elementType.isIntOrFloat()) { + sizeInBits = elementType.getIntOrFloatBitWidth(); + } else { + auto vectorType = elementType.cast(); + sizeInBits = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + } + return llvm::divideCeil(sizeInBits, 8); + } + + /// Returns true if the memref size in bytes is known to be a multiple of + /// factor. + static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) { + uint64_t sizeDivisor = getMemRefEltSizeInBytes(type); + for (unsigned i = 0, e = type.getRank(); i < e; i++) { + if (type.isDynamic(type.getDimSize(i))) + continue; + sizeDivisor = sizeDivisor * type.getDimSize(i); + } + return sizeDivisor % factor == 0; + } + + /// Returns the alignment to be used for the allocation call itself. + /// aligned_alloc requires the allocation size to be a power of two, and the + /// allocation size to be a multiple of alignment, + int64_t getAllocationAlignment(AllocOp allocOp) const { + if (Optional alignment = allocOp.alignment()) + return *alignment; + + // Whenever we don't have alignment set, we will use an alignment + // consistent with the element type; since the allocation size has to be a + // power of two, we will bump to the next power of two if it already isn't. 
+ auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType()); + return std::max(kMinAlignedAllocAlignment, + llvm::PowerOf2Ceil(eltSizeBytes)); + } + + std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, + Location loc, Value cumulativeSize, + Operation *op) const override { + // Heap allocations. + AllocOp allocOp = cast(op); + MemRefType memRefType = allocOp.getType(); + int64_t alignment = getAllocationAlignment(allocOp); + Value allocAlignment = createIndexConstant(rewriter, loc, alignment); + + // aligned_alloc requires size to be a multiple of alignment; we will pad + // the size to the next multiple if necessary. + if (!isMemRefSizeMultipleOf(memRefType, alignment)) + cumulativeSize = + createAligned(rewriter, loc, cumulativeSize, allocAlignment); + + Type elementPtrType = this->getElementPtrType(memRefType); + Value allocatedPtr = createAllocCall( + loc, "aligned_alloc", elementPtrType, {allocAlignment, cumulativeSize}, + allocOp.getParentOfType(), rewriter); + + return std::make_tuple(allocatedPtr, allocatedPtr); + } + + /// The minimum alignment to use with aligned_alloc (has to be a power of 2). + static constexpr uint64_t kMinAlignedAllocAlignment = 16UL; }; -using AllocaOpLowering = AllocLikeOpLowering; +struct AllocaOpLowering : public AllocLikeOpLowering { + AllocaOpLowering(LLVMTypeConverter &converter) + : AllocLikeOpLowering(AllocaOp::getOperationName(), converter) {} + + /// Allocates the underlying buffer using the right call. `allocatedBytePtr` + /// is set to null for stack allocations. `accessAlignment` is set if + /// alignment is needed post allocation (for eg. in conjunction with malloc). + std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, + Location loc, Value cumulativeSize, + Operation *op) const override { + + // With alloca, one gets a pointer to the element type right away. + // For stack allocations. + auto allocaOp = cast(op); + auto elementPtrType = this->getElementPtrType(allocaOp.getType()); + + auto allocatedElementPtr = rewriter.create( + loc, elementPtrType, cumulativeSize, + allocaOp.alignment() ? *allocaOp.alignment() : 0); + + return std::make_tuple(allocatedElementPtr, allocatedElementPtr); + } +}; /// Copies the shaped descriptor part to (if `toDynamic` is set) or from /// (otherwise) the dynamically allocated memory for any operands that were @@ -3153,12 +3177,13 @@ // This relies on LLVM's CSE optimization (potentially after SROA), since // after CSE all memref.alignedPtr instances get de-duplicated into the same // pointer SSA value. 
- Value zero = - createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0); - Value mask = createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), + auto intPtrType = + getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace()); + Value zero = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0); + Value mask = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, alignment - 1); Value ptrValue = - rewriter.create(op->getLoc(), getIndexType(), ptr); + rewriter.create(op->getLoc(), intPtrType, ptr); rewriter.create( op->getLoc(), rewriter.create( @@ -3429,9 +3454,12 @@ RankOpLowering, StoreOpLowering, SubViewOpLowering, - ViewOpLowering, - AllocOpLowering>(converter); + ViewOpLowering>(converter); // clang-format on + if (converter.getOptions().useAlignedAlloc) + patterns.insert(converter); + else + patterns.insert(converter); } void mlir::populateStdToLLVMFuncOpConversionPattern( diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -36,7 +36,6 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> @@ -77,7 +76,6 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -107,7 +105,6 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x !llvm.float : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -153,8 +150,7 @@ // ALIGNED-ALLOC-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // ALIGNED-ALLOC-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 -// ALIGNED-ALLOC-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : i64) 
: !llvm.i64 +// ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm.ptr // ALIGNED-ALLOC-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr to !llvm.ptr %0 = alloc() {alignment = 32} : memref<32x18xf32> @@ -164,26 +160,27 @@ %1 = alloc() {alignment = 64} : memref<4096xf32> // Alignment is to element type boundaries (minimum 16 bytes). - // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64 + // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]] %2 = alloc() : memref<4096xvector<8xf32>> // The minimum alignment is 16 bytes unless explicitly specified. - // ALIGNED-ALLOC: %[[c16:.*]] = llvm.mlir.constant(16 : i64) : !llvm.i64 + // ALIGNED-ALLOC: %[[c16:.*]] = llvm.mlir.constant(16 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c16]], %3 = alloc() : memref<4096xvector<2xf32>> - // ALIGNED-ALLOC: %[[c8:.*]] = llvm.mlir.constant(8 : i64) : !llvm.i64 + // ALIGNED-ALLOC: %[[c8:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c8]], %4 = alloc() {alignment = 8} : memref<1024xvector<4xf32>> // Bump the memref allocation size if its size is not a multiple of alignment. - // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64 - // ALIGNED-ALLOC-NEXT: llvm.urem + // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 + // ALIGNED-ALLOC-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: llvm.sub + // ALIGNED-ALLOC-NEXT: llvm.add // ALIGNED-ALLOC-NEXT: llvm.urem - // ALIGNED-ALLOC-NEXT: %[[SIZE_ALIGNED:.*]] = llvm.add + // ALIGNED-ALLOC-NEXT: %[[SIZE_ALIGNED:.*]] = llvm.sub // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]], %[[SIZE_ALIGNED]]) %5 = alloc() {alignment = 32} : memref<100xf32> // Bump alignment to the next power of two if it isn't. 
- // ALIGNED-ALLOC: %[[c128:.*]] = llvm.mlir.constant(128 : i64) : !llvm.i64 + // ALIGNED-ALLOC: %[[c128:.*]] = llvm.mlir.constant(128 : index) : !llvm.i64 // ALIGNED-ALLOC: llvm.call @aligned_alloc(%[[c128]] %6 = alloc(%N) : memref> return %0 : memref<32x18xf32> diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -76,7 +76,6 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> @@ -91,7 +90,6 @@ // BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 -// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: llvm.call @malloc(%{{.*}}) : (!llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // BAREPTR-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> @@ -130,19 +128,19 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // CHECK-NEXT: %[[allocsize:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[ptr]] : !llvm.ptr to !llvm.i64 +// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK-NEXT: %[[bump:.*]] = llvm.sub %[[alignment]], %[[one_1]] : !llvm.i64 +// CHECK-NEXT: %[[bumped:.*]] = llvm.add %[[allocatedAsInt]], %[[bump]] : !llvm.i64 +// CHECK-NEXT: %[[mod:.*]] = llvm.urem %[[bumped]], %[[alignment]] : !llvm.i64 +// CHECK-NEXT: %[[aligned:.*]] = llvm.sub %[[bumped]], %[[mod]] : !llvm.i64 +// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.inttoptr %[[aligned]] : !llvm.i64 to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -// CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm.ptr to !llvm.i64 -// CHECK-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 -// CHECK-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64 -// CHECK-NEXT: %[[aligned:.*]] = llvm.getelementptr 
%[[allocated]][%[[alignAdj3]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -153,19 +151,19 @@ // BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // BAREPTR-NEXT: llvm.mul %{{.*}}, %[[sizeof]] : !llvm.i64 -// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // BAREPTR-NEXT: %[[allocsize:.*]] = llvm.add {{.*}}, %[[alignment]] : !llvm.i64 // BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr +// BAREPTR-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[ptr]] : !llvm.ptr to !llvm.i64 +// BAREPTR-NEXT: %[[one_2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// BAREPTR-NEXT: %[[bump:.*]] = llvm.sub %[[alignment]], %[[one_2]] : !llvm.i64 +// BAREPTR-NEXT: %[[bumped:.*]] = llvm.add %[[allocatedAsInt]], %[[bump]] : !llvm.i64 +// BAREPTR-NEXT: %[[mod:.*]] = llvm.urem %[[bumped]], %[[alignment]] : !llvm.i64 +// BAREPTR-NEXT: %[[aligned:.*]] = llvm.sub %[[bumped]], %[[mod]] : !llvm.i64 +// BAREPTR-NEXT: %[[alignedBitCast:.*]] = llvm.inttoptr %[[aligned]] : !llvm.i64 to !llvm.ptr // BAREPTR-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -// BAREPTR-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[allocated]] : !llvm.ptr to !llvm.i64 -// BAREPTR-NEXT: %[[alignAdj1:.*]] = llvm.urem %[[allocatedAsInt]], %[[alignment]] : !llvm.i64 -// BAREPTR-NEXT: %[[alignAdj2:.*]] = llvm.sub %[[alignment]], %[[alignAdj1]] : !llvm.i64 -// BAREPTR-NEXT: %[[alignAdj3:.*]] = llvm.urem %[[alignAdj2]], %[[alignment]] : !llvm.i64 -// BAREPTR-NEXT: %[[aligned:.*]] = llvm.getelementptr %[[allocated]][%[[alignAdj3]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// BAREPTR-NEXT: %[[alignedBitCast:.*]] = llvm.bitcast %[[aligned]] : !llvm.ptr to !llvm.ptr // BAREPTR-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> @@ -186,7 +184,6 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr to !llvm.ptr @@ -198,7 +195,6 @@ // BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: 
%[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // BAREPTR-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 -// BAREPTR-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[bytes]]) : (!llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr to !llvm.ptr %0 = alloc() : memref<32x18xf32> @@ -217,7 +213,6 @@ // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 -// CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[bytes]] x !llvm.float : (!llvm.i64) -> !llvm.ptr %0 = alloca() : memref<32x18xf32>
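For reference, the five fields that the new createMemRefDescriptor helper populates are the standard lowered memref descriptor, visible in the CHECK lines above as !llvm.struct<(ptr, ptr, i64, array<N x i64>, array<N x i64>)>. The C++ mirror below is only an illustrative sketch with a made-up name (MLIR's runtime counterpart is StridedMemRefType in mlir/ExecutionEngine/CRunnerUtils.h); it is not part of the patch.

#include <cstdint>

// Hypothetical mirror of the lowered descriptor
// !llvm.struct<(ptr<T>, ptr<T>, i64, array<N x i64>, array<N x i64>)>.
template <typename T, int N>
struct MemRefDescriptorSketch {
  T *allocatedPtr;         // field 0: pointer returned by malloc/aligned_alloc/alloca, used by free
  T *alignedPtr;           // field 1: payload pointer, realigned when an alignment was requested
  std::int64_t offset;     // field 2: offset into the aligned pointer, in elements
  std::int64_t sizes[N];   // field 3: size of each dimension
  std::int64_t strides[N]; // field 4: stride of each dimension, in elements
};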
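The rewritten malloc path (AllocOpLowering) over-allocates by the requested alignment and then realigns the raw pointer using the arithmetic emitted by createAligned (llvm.ptrtoint, sub/add/urem/sub, llvm.inttoptr), replacing the old urem/sub/getelementptr sequence. A minimal standalone C++ sketch of that computation, with hypothetical helper names:

#include <cstdint>
#include <cstdlib>

// createAligned's formula, for alignment > 0:
//   bumped  = input + alignment - 1
//   aligned = bumped - bumped % alignment
std::uint64_t alignUp(std::uint64_t input, std::uint64_t alignment) {
  std::uint64_t bumped = input + alignment - 1;
  return bumped - bumped % alignment;
}

// Pointer-level equivalent of the lowered malloc path.
void *mallocAligned(std::size_t sizeInBytes, std::size_t alignment,
                    void **allocatedPtr /* stored in field 0, used by free() */) {
  *allocatedPtr = std::malloc(sizeInBytes + alignment);               // llvm.call @malloc
  std::uintptr_t raw = reinterpret_cast<std::uintptr_t>(*allocatedPtr); // llvm.ptrtoint
  std::uintptr_t aligned = alignUp(raw, alignment);                   // sub, add, urem, sub
  return reinterpret_cast<void *>(aligned);                           // llvm.inttoptr
}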
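The new AlignedAllocOpLowering instead emits a call to the C11/C++17 aligned_alloc(alignment, size), which (per the comment in the patch) expects the size to be a multiple of the power-of-two alignment; the pattern pads the byte size with the same align-up formula and defaults the alignment to at least 16 bytes, rounded up to the next power of two of the element size. A self-contained sketch of that contract (the helper name is made up):

#include <cstdlib>

// aligned_alloc requires `alignment` to be a power of two and, as the patch
// assumes, `size` to be a multiple of `alignment`; pad the size accordingly.
void *allocWithAlignedAlloc(std::size_t sizeInBytes, std::size_t alignment) {
  std::size_t padded = (sizeInBytes + alignment - 1) / alignment * alignment;
  return std::aligned_alloc(alignment, padded); // llvm.call @aligned_alloc
}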
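The assume-alignment hunk above only switches the integer type of the pointer test from the index type to getIntPtrType of the memref's address space; the predicate handed to LLVM's assume intrinsic is still ptrtoint(alignedPtr) & (alignment - 1) == 0. A rough C++ analogue of that predicate, expressed with a plain assert rather than a compiler assumption (hypothetical helper, not part of the patch):

#include <cassert>
#include <cstdint>

// The fact the lowering asserts: the aligned pointer has no bits set below
// the (power-of-two) alignment boundary.
inline void assumeAligned(const void *ptr, std::uintptr_t alignment) {
  assert((reinterpret_cast<std::uintptr_t>(ptr) & (alignment - 1)) == 0 &&
         "aligned pointer does not satisfy the assumed alignment");
}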