diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h --- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h +++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h @@ -13,23 +13,100 @@ namespace mlir { -/// Lowering for AllocOp and AllocaOp. -struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern { +/// Lowering for memory allocation ops. +struct AllocationOpLLVMLowering : public ConvertToLLVMPattern { using ConvertToLLVMPattern::createIndexConstant; using ConvertToLLVMPattern::getIndexType; using ConvertToLLVMPattern::getVoidPtrType; - explicit AllocLikeOpLLVMLowering(StringRef opName, - LLVMTypeConverter &converter) + explicit AllocationOpLLVMLowering(StringRef opName, + LLVMTypeConverter &converter) : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {} protected: - // Returns 'input' aligned up to 'alignment'. Computes - // bumped = input + alignement - 1 - // aligned = bumped - bumped % alignment + /// Computes the aligned value for 'input' as follows: + /// bumped = input + alignment - 1 + /// aligned = bumped - bumped % alignment static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment); + static MemRefType getMemRefResultType(Operation *op) { + return op->getResult(0).getType().cast(); + } + + /// Computes the alignment for the given memory allocation op. + template + Value getAlignment(ConversionPatternRewriter &rewriter, Location loc, + OpType op) const { + MemRefType memRefType = op.getType(); + Value alignment; + if (auto alignmentAttr = op.getAlignment()) { + alignment = createIndexConstant(rewriter, loc, *alignmentAttr); + } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { + // In the case where no alignment is specified, we may want to override + // `malloc's` behavior. 
`malloc` typically aligns at the size of the + // biggest scalar on a target HW. For non-scalars, use the natural + // alignment of the LLVM type given by the LLVM DataLayout. + alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); + } + return alignment; + } + + /// Computes the alignment for aligned_alloc used to allocate the buffer for + /// the memory allocation op. + /// + /// Aligned_alloc requires the alignment to be a power of two, and the + /// allocation size to be a multiple of the alignment. + template + int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter, + Location loc, OpType op, + const DataLayout *defaultLayout) const { + if (Optional alignment = op.getAlignment()) + return *alignment; + + // Whenever we don't have alignment set, we will use an alignment + // consistent with the element type; since the alignment has to be a + // power of two, we will bump to the next power of two if it isn't. + unsigned eltSizeBytes = + getMemRefEltSizeInBytes(op.getType(), op, defaultLayout); + return std::max(kMinAlignedAllocAlignment, + llvm::PowerOf2Ceil(eltSizeBytes)); + } + + /// Allocates a memory buffer using an allocation method that doesn't + /// guarantee alignment. Returns the pointer and its aligned value. + std::tuple + allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc, + Value sizeBytes, Operation *op, + Value alignment) const; + + /// Allocates a memory buffer using an aligned allocation method. + Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter, + Location loc, Value sizeBytes, Operation *op, + const DataLayout *defaultLayout, + int64_t alignment) const; + +private: + /// Computes the byte size for the MemRef element type. + unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op, + const DataLayout *defaultLayout) const; + + /// Returns true if the memref size in bytes is known to be a multiple of + /// factor. 
+ bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op, + const DataLayout *defaultLayout) const; + + /// The minimum alignment to use with aligned_alloc (has to be a power of 2). + static constexpr uint64_t kMinAlignedAllocAlignment = 16UL; +}; + +/// Lowering for AllocOp and AllocaOp. +struct AllocLikeOpLLVMLowering : public AllocationOpLLVMLowering { + explicit AllocLikeOpLLVMLowering(StringRef opName, + LLVMTypeConverter &converter) + : AllocationOpLLVMLowering(opName, converter) {} + +protected: /// Allocates the underlying buffer. Returns the allocated pointer and the /// aligned pointer. virtual std::tuple @@ -37,10 +114,6 @@ Value sizeBytes, Operation *op) const = 0; private: - static MemRefType getMemRefResultType(Operation *op) { - return op->getResult(0).getType().cast(); - } - // An `alloc` is converted into a definition of a memref descriptor value and // a call to `malloc` to allocate the underlying data buffer. The memref // descriptor is of the LLVM structure type where: diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -178,6 +178,99 @@ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// ReallocOp +//===----------------------------------------------------------------------===// + + +def MemRef_ReallocOp : MemRef_Op<"realloc"> { + let summary = "memory reallocation operation"; + let description = [{ + The `realloc` operation changes the size of a memory region. The memory + region is specified by a 1D source memref and the size of the new memory + region is specified by a 1D result memref type and an optional dynamic Value + of `Index` type. The source and the result memref must be in the same memory + space and have the same element type. 
+ + The operation may move the memory region to a new location. In this case, + the content of the memory block is preserved up to the lesser of the new + and old sizes. If the new size is larger, the value of the extended memory + is undefined. This is consistent with the ISO C realloc. + + The operation returns an SSA value for the memref. + + Example: + + ```mlir + %0 = memref.realloc %src : memref<64xf32> to memref<124xf32> + ``` + + The source memref may have a dynamic shape, in which case, the compiler will + generate code to extract its size from the runtime data structure for the + memref. + + ```mlir + %1 = memref.realloc %src : memref<?xf32> to memref<124xf32> + ``` + + If the result memref has a dynamic shape, a result dimension operand is + needed to specify its dynamic dimension. In the example below, the SSA value + '%d' specifies the unknown dimension of the result memref. + + ```mlir + %2 = memref.realloc %src(%d) : memref<?xf32> to memref<?xf32> + ``` + + An optional `alignment` attribute may be specified to ensure that the + region of memory that will be indexed is aligned at the specified byte + boundary. This is consistent with the fact that memref.alloc supports such + an optional alignment attribute. Note that in the ISO C standard, neither malloc + nor realloc supports alignment, though there is aligned_alloc but not + aligned_realloc. + + ```mlir + %3 = memref.realloc %src {alignment = 8} : memref<64xf32> to memref<124xf32> + ``` + + Referencing the memref through the old SSA value after realloc is undefined + behavior. 
+ + ```mlir + %new = memref.realloc %old : memref<64xf32> to memref<124xf32> + %4 = memref.load %new[%index] // ok + %5 = memref.load %old[%index] // undefined behavior + ``` + }]; + + let arguments = (ins MemRefRankOf<[AnyType], [1]>:$source, + Optional:$dynamicResultSize, + ConfinedAttr, + [IntMinValue<0>]>:$alignment); + + let results = (outs MemRefRankOf<[AnyType], [1]>); + + let builders = [ + OpBuilder<(ins "MemRefType":$resultType, + "Value":$source, + CArg<"Value", "Value()">:$dynamicResultSize), [{ + return build($_builder, $_state, resultType, source, dynamicResultSize, + IntegerAttr()); + }]>]; + + let extraClassDeclaration = [{ + /// The result of a realloc is always a memref. + MemRefType getType() { return getResult().getType().cast(); } + }]; + + let assemblyFormat = [{ + $source (`(` $dynamicResultSize^ `)`)? attr-dict + `:` type($source) `to` type(results) + }]; + + let hasCanonicalizer = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // AllocaOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -7,11 +7,40 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h" +#include "mlir/Analysis/DataLayoutAnalysis.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" using namespace mlir; -Value AllocLikeOpLLVMLowering::createAligned( +namespace { +// TODO: Fix the LLVM utilities for looking up functions to take Operation* +// with SymbolTable trait instead of ModuleOp and make similar change here. 
This +// allows call sites to use getParentWithTrait instead +// of getParentOfType to pass down the operation. +LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter, + ModuleOp module, Type indexType) { + bool useGenericFn = typeConverter->getOptions().useGenericFunctions; + + if (useGenericFn) + return LLVM::lookupOrCreateGenericAllocFn(module, indexType); + + return LLVM::lookupOrCreateMallocFn(module, indexType); +} + +LLVM::LLVMFuncOp getAlignedAllocFn(LLVMTypeConverter *typeConverter, + ModuleOp module, Type indexType) { + bool useGenericFn = typeConverter->getOptions().useGenericFunctions; + + if (useGenericFn) + return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType); + + return LLVM::lookupOrCreateAlignedAllocFn(module, indexType); +} + +} // end namespace + +Value AllocationOpLLVMLowering::createAligned( ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment) { Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1); @@ -21,6 +50,88 @@ return rewriter.create(loc, bumped, mod); } +std::tuple AllocationOpLLVMLowering::allocateBufferManuallyAlign( + ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, + Operation *op, Value alignment) const { + if (alignment) { + // Adjust the allocation size to consider alignment. + sizeBytes = rewriter.create(loc, sizeBytes, alignment); + } + + MemRefType memRefType = getMemRefResultType(op); + // Allocate the underlying buffer. + Type elementPtrType = this->getElementPtrType(memRefType); + LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn( + getTypeConverter(), op->getParentOfType(), getIndexType()); + auto results = rewriter.create(loc, allocFuncOp, sizeBytes); + Value allocatedPtr = rewriter.create(loc, elementPtrType, + results.getResult()); + + Value alignedPtr = allocatedPtr; + if (alignment) { + // Compute the aligned pointer. 
+ Value allocatedInt = + rewriter.create(loc, getIndexType(), allocatedPtr); + Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); + alignedPtr = + rewriter.create(loc, elementPtrType, alignmentInt); + } + + return std::make_tuple(allocatedPtr, alignedPtr); +} + +unsigned AllocationOpLLVMLowering::getMemRefEltSizeInBytes( + MemRefType memRefType, Operation *op, + const DataLayout *defaultLayout) const { + const DataLayout *layout = defaultLayout; + if (const DataLayoutAnalysis *analysis = + getTypeConverter()->getDataLayoutAnalysis()) { + layout = &analysis->getAbove(op); + } + Type elementType = memRefType.getElementType(); + if (auto memRefElementType = elementType.dyn_cast()) + return getTypeConverter()->getMemRefDescriptorSize(memRefElementType, + *layout); + if (auto memRefElementType = elementType.dyn_cast()) + return getTypeConverter()->getUnrankedMemRefDescriptorSize( + memRefElementType, *layout); + return layout->getTypeSize(elementType); +} + +bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf( + MemRefType type, uint64_t factor, Operation *op, + const DataLayout *defaultLayout) const { + uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op, defaultLayout); + for (unsigned i = 0, e = type.getRank(); i < e; i++) { + if (ShapedType::isDynamic(type.getDimSize(i))) + continue; + sizeDivisor = sizeDivisor * type.getDimSize(i); + } + return sizeDivisor % factor == 0; +} + +Value AllocationOpLLVMLowering::allocateBufferAutoAlign( + ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, + Operation *op, const DataLayout *defaultLayout, int64_t alignment) const { + Value allocAlignment = createIndexConstant(rewriter, loc, alignment); + + MemRefType memRefType = getMemRefResultType(op); + // Function aligned_alloc requires size to be a multiple of alignment; we pad + // the size to the next multiple if necessary. 
+ if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout)) + sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); + + Type elementPtrType = this->getElementPtrType(memRefType); + LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn( + getTypeConverter(), op->getParentOfType(), getIndexType()); + auto results = rewriter.create( + loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes})); + Value allocatedPtr = rewriter.create(loc, elementPtrType, + results.getResult()); + + return allocatedPtr; +} + LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite( Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -36,63 +36,25 @@ return !ShapedType::isDynamicStrideOrOffset(strideOrOffset); } +LLVM::LLVMFuncOp getFreeFn(LLVMTypeConverter *typeConverter, ModuleOp module) { + bool useGenericFn = typeConverter->getOptions().useGenericFunctions; + + if (useGenericFn) + return LLVM::lookupOrCreateGenericFreeFn(module); + + return LLVM::lookupOrCreateFreeFn(module); +} + struct AllocOpLowering : public AllocLikeOpLLVMLowering { AllocOpLowering(LLVMTypeConverter &converter) : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), converter) {} - - LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const { - bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions; - - if (useGenericFn) - return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType()); - - return LLVM::lookupOrCreateMallocFn(module, getIndexType()); - } - std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op) const override { - // Heap allocations. 
- memref::AllocOp allocOp = cast(op); - MemRefType memRefType = allocOp.getType(); - - Value alignment; - if (auto alignmentAttr = allocOp.getAlignment()) { - alignment = createIndexConstant(rewriter, loc, *alignmentAttr); - } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { - // In the case where no alignment is specified, we may want to override - // `malloc's` behavior. `malloc` typically aligns at the size of the - // biggest scalar on a target HW. For non-scalars, use the natural - // alignment of the LLVM type given by the LLVM DataLayout. - alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); - } - - if (alignment) { - // Adjust the allocation size to consider alignment. - sizeBytes = rewriter.create(loc, sizeBytes, alignment); - } - - // Allocate the underlying buffer and store a pointer to it in the MemRef - // descriptor. - Type elementPtrType = this->getElementPtrType(memRefType); - auto allocFuncOp = getAllocFn(allocOp->getParentOfType()); - auto results = rewriter.create(loc, allocFuncOp, sizeBytes); - Value allocatedPtr = rewriter.create(loc, elementPtrType, - results.getResult()); - - Value alignedPtr = allocatedPtr; - if (alignment) { - // Compute the aligned type pointer. - Value allocatedInt = - rewriter.create(loc, getIndexType(), allocatedPtr); - Value alignmentInt = - createAligned(rewriter, loc, allocatedInt, alignment); - alignedPtr = - rewriter.create(loc, elementPtrType, alignmentInt); - } - - return std::make_tuple(allocatedPtr, alignedPtr); + return allocateBufferManuallyAlign( + rewriter, loc, sizeBytes, op, + getAlignment(rewriter, loc, cast(op))); } }; @@ -100,90 +62,17 @@ AlignedAllocOpLowering(LLVMTypeConverter &converter) : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(), converter) {} - - /// Returns the memref's element size in bytes using the data layout active at - /// `op`. - // TODO: there are other places where this is used. Expose publicly? 
- unsigned getMemRefEltSizeInBytes(MemRefType memRefType, Operation *op) const { - const DataLayout *layout = &defaultLayout; - if (const DataLayoutAnalysis *analysis = - getTypeConverter()->getDataLayoutAnalysis()) { - layout = &analysis->getAbove(op); - } - Type elementType = memRefType.getElementType(); - if (auto memRefElementType = elementType.dyn_cast()) - return getTypeConverter()->getMemRefDescriptorSize(memRefElementType, - *layout); - if (auto memRefElementType = elementType.dyn_cast()) - return getTypeConverter()->getUnrankedMemRefDescriptorSize( - memRefElementType, *layout); - return layout->getTypeSize(elementType); - } - - /// Returns true if the memref size in bytes is known to be a multiple of - /// factor assuming the data layout active at `op`. - bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, - Operation *op) const { - uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op); - for (unsigned i = 0, e = type.getRank(); i < e; i++) { - if (ShapedType::isDynamic(type.getDimSize(i))) - continue; - sizeDivisor = sizeDivisor * type.getDimSize(i); - } - return sizeDivisor % factor == 0; - } - - /// Returns the alignment to be used for the allocation call itself. - /// aligned_alloc requires the allocation size to be a power of two, and the - /// allocation size to be a multiple of alignment, - int64_t getAllocationAlignment(memref::AllocOp allocOp) const { - if (Optional alignment = allocOp.getAlignment()) - return *alignment; - - // Whenever we don't have alignment set, we will use an alignment - // consistent with the element type; since the allocation size has to be a - // power of two, we will bump to the next power of two if it already isn't. 
- auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType(), allocOp); - return std::max(kMinAlignedAllocAlignment, - llvm::PowerOf2Ceil(eltSizeBytes)); - } - - LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const { - bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions; - - if (useGenericFn) - return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType()); - - return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType()); - } - std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op) const override { - // Heap allocations. - memref::AllocOp allocOp = cast(op); - MemRefType memRefType = allocOp.getType(); - int64_t alignment = getAllocationAlignment(allocOp); - Value allocAlignment = createIndexConstant(rewriter, loc, alignment); - - // aligned_alloc requires size to be a multiple of alignment; we will pad - // the size to the next multiple if necessary. - if (!isMemRefSizeMultipleOf(memRefType, alignment, op)) - sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); - - Type elementPtrType = this->getElementPtrType(memRefType); - auto allocFuncOp = getAllocFn(allocOp->getParentOfType()); - auto results = rewriter.create( - loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes})); - Value allocatedPtr = rewriter.create(loc, elementPtrType, - results.getResult()); - - return std::make_tuple(allocatedPtr, allocatedPtr); + Value ptr = allocateBufferAutoAlign( + rewriter, loc, sizeBytes, op, &defaultLayout, + alignedAllocationGetAlignment(rewriter, loc, cast(op), + &defaultLayout)); + return std::make_tuple(ptr, ptr); } - /// The minimum alignment to use with aligned_alloc (has to be a power of 2). - static constexpr uint64_t kMinAlignedAllocAlignment = 16UL; - +private: /// Default layout to use in absence of the corresponding analysis. 
DataLayout defaultLayout; }; @@ -212,6 +101,160 @@ } }; +/// The base class for lowering realloc op, to support the implementation of +/// realloc via allocation methods that may or may not support alignment. +/// A derived class should provide an implementation of allocateBuffer using +/// the underlying allocation methods. +struct ReallocOpLoweringBase : public AllocationOpLLVMLowering { + using OpAdaptor = typename memref::ReallocOp::Adaptor; + + ReallocOpLoweringBase(LLVMTypeConverter &converter) + : AllocationOpLLVMLowering(memref::ReallocOp::getOperationName(), + converter) {} + + /// Allocates the new buffer. Returns the allocated pointer and the + /// aligned pointer. + virtual std::tuple + allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, + Value sizeBytes, memref::ReallocOp op) const = 0; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(cast(op), + OpAdaptor(operands, op->getAttrDictionary()), + rewriter); + } + + // A `realloc` is converted as follows: + // If new_size > old_size + // 1. allocates a new buffer + // 2. copies the content of the old buffer to the new buffer + // 3. releases the old buffer + // 4. updates the buffer pointers in the memref descriptor + // Update the size in the memref descriptor + // Alignment request is handled by allocating `alignment` more bytes than + // requested and shifting the aligned pointer relative to the allocated + // memory. + LogicalResult matchAndRewrite(memref::ReallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + OpBuilder::InsertionGuard guard(rewriter); + Location loc = op.getLoc(); + + auto computeNumElements = + [&](MemRefType type, function_ref getDynamicSize) -> Value { + // Compute number of elements. + int64_t size = type.getShape()[0]; + Value numElements = ((size == ShapedType::kDynamicSize) + ? 
getDynamicSize() + : createIndexConstant(rewriter, loc, size)); + Type indexType = getIndexType(); + if (numElements.getType() != indexType) + numElements = typeConverter->materializeTargetConversion( + rewriter, loc, indexType, numElements); + return numElements; + }; + + MemRefDescriptor desc(adaptor.getSource()); + Value oldDesc = desc; + + // Split the block right before the current op into two blocks. + Block *currentBlock = rewriter.getInsertionBlock(); + Block *block = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + // Add a block argument by creating an empty block with the argument type + // and then merging the block into the empty block. + Block *endBlock = rewriter.createBlock( + block->getParent(), Region::iterator(block), oldDesc.getType(), loc); + rewriter.mergeBlocks(block, endBlock, {}); + // Add a new block for the true branch of the conditional statement we will + // add. + Block *trueBlock = rewriter.createBlock( + currentBlock->getParent(), std::next(Region::iterator(currentBlock))); + + rewriter.setInsertionPointToEnd(currentBlock); + Value src = op.getSource(); + auto srcType = src.getType().dyn_cast(); + Value srcNumElements = computeNumElements( + srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); }); + auto dstType = op.getType().cast(); + Value dstNumElements = computeNumElements( + dstType, [&]() -> Value { return op.getDynamicResultSize(); }); + Value cond = rewriter.create( + loc, IntegerType::get(rewriter.getContext(), 1), + LLVM::ICmpPredicate::ugt, dstNumElements, srcNumElements); + rewriter.create(loc, cond, trueBlock, ArrayRef(), + endBlock, ValueRange{oldDesc}); + + rewriter.setInsertionPointToStart(trueBlock); + Value sizeInBytes = getSizeInBytes(loc, dstType.getElementType(), rewriter); + // Compute total byte size. + auto dstByteSize = + rewriter.create(loc, dstNumElements, sizeInBytes); + // Allocate a new buffer. 
+ auto [dstRawPtr, dstAlignedPtr] = + allocateBuffer(rewriter, loc, dstByteSize, op); + // Copy the data from the old buffer to the new buffer. + Value srcAlignedPtr = desc.alignedPtr(rewriter, loc); + Value isVolatile = + rewriter.create(loc, rewriter.getBoolAttr(false)); + auto toVoidPtr = [&](Value ptr) -> Value { + return rewriter.create(loc, getVoidPtrType(), ptr); + }; + rewriter.create(loc, toVoidPtr(dstAlignedPtr), + toVoidPtr(srcAlignedPtr), dstByteSize, + isVolatile); + // Deallocate the old buffer. + LLVM::LLVMFuncOp freeFunc = + getFreeFn(getTypeConverter(), op->getParentOfType()); + rewriter.create(loc, freeFunc, + toVoidPtr(desc.allocatedPtr(rewriter, loc))); + // Replace the old buffer addresses in the MemRefDescriptor with the new + // buffer addresses. + desc.setAllocatedPtr(rewriter, loc, dstRawPtr); + desc.setAlignedPtr(rewriter, loc, dstAlignedPtr); + rewriter.create(loc, Value(desc), endBlock); + + rewriter.setInsertionPoint(op); + // Update the memref size. + MemRefDescriptor newDesc(endBlock->getArgument(0)); + newDesc.setSize(rewriter, loc, 0, dstNumElements); + rewriter.replaceOp(op, {newDesc}); + return success(); + } + +private: + using ConvertToLLVMPattern::matchAndRewrite; +}; + +struct ReallocOpLowering : public ReallocOpLoweringBase { + ReallocOpLowering(LLVMTypeConverter &converter) + : ReallocOpLoweringBase(converter) {} + std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, + Location loc, Value sizeBytes, + memref::ReallocOp op) const override { + return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op, + getAlignment(rewriter, loc, op)); + } +}; + +struct AlignedReallocOpLowering : public ReallocOpLoweringBase { + AlignedReallocOpLowering(LLVMTypeConverter &converter) + : ReallocOpLoweringBase(converter) {} + std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, + Location loc, Value sizeBytes, + memref::ReallocOp op) const override { + Value ptr = allocateBufferAutoAlign( + rewriter, loc, sizeBytes, 
op, &defaultLayout, + alignedAllocationGetAlignment(rewriter, loc, op, &defaultLayout)); + return std::make_tuple(ptr, ptr); + } + +private: + /// Default layout to use in absence of the corresponding analysis. + DataLayout defaultLayout; +}; + struct AllocaScopeOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -316,20 +359,12 @@ explicit DeallocOpLowering(LLVMTypeConverter &converter) : ConvertOpToLLVMPattern(converter) {} - LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const { - bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions; - - if (useGenericFn) - return LLVM::lookupOrCreateGenericFreeFn(module); - - return LLVM::lookupOrCreateFreeFn(module); - } - LogicalResult matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Insert the `free` declaration if it is not already present. - auto freeFunc = getFreeFn(op->getParentOfType()); + LLVM::LLVMFuncOp freeFunc = + getFreeFn(getTypeConverter(), op->getParentOfType()); MemRefDescriptor memref(adaptor.getMemref()); Value casted = rewriter.create( op.getLoc(), getVoidPtrType(), @@ -2060,9 +2095,11 @@ // clang-format on auto allocLowering = converter.getOptions().allocLowering; if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) - patterns.add(converter); + patterns.add(converter); else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) - patterns.add(converter); + patterns.add( + converter); } namespace { diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -246,6 +246,52 @@ context); } +//===----------------------------------------------------------------------===// +// ReallocOp +//===----------------------------------------------------------------------===// + +LogicalResult ReallocOp::verify() { + auto sourceType = 
getOperand(0).getType().cast(); + MemRefType resultType = getType(); + + // The source memref should have identity layout (or none). + if (!sourceType.getLayout().isIdentity()) + return emitError("unsupported layout for source memref type ") + << sourceType; + + // The result memref should have identity layout (or none). + if (!resultType.getLayout().isIdentity()) + return emitError("unsupported layout for result memref type ") + << resultType; + + // The source memref and the result memref should be in the same memory space. + if (sourceType.getMemorySpace() != resultType.getMemorySpace()) + return emitError("different memory spaces specified for source memref " + "type ") + << sourceType << " and result memref type " << resultType; + + // The source memref and the result memref should have the same element type. + if (sourceType.getElementType() != resultType.getElementType()) + return emitError("different element types specified for source memref " + "type ") + << sourceType << " and result memref type " << resultType; + + // Verify that we have the dynamic dimension operand when it is needed. 
+ if (resultType.getNumDynamicDims() && !getDynamicResultSize()) + return emitError("missing dimension operand for result type ") + << resultType; + if (!resultType.getNumDynamicDims() && getDynamicResultSize()) + return emitError("unnecessary dimension operand for result type ") + << resultType; + + return success(); +} + +void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>(context); +} + //===----------------------------------------------------------------------===// // AllocaScopeOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir @@ -626,3 +626,125 @@ return } +// ----- + +// CHECK-LABEL: func.func @realloc_dynamic( +// CHECK-SAME: %[[arg0:.*]]: memref, +// CHECK-SAME: %[[arg1:.*]]: index) -> memref { +func.func @realloc_dynamic(%in: memref, %d: index) -> memref{ +// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] +// CHECK: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0] +// CHECK: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64 +// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64 +// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] +// CHECK: ^bb1: +// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] +// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 +// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]]) +// CHECK: %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr to !llvm.ptr +// CHECK: 
%[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] +// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1 +// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr to !llvm.ptr +// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr +// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] +// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr +// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]]) +// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer]], %[[descriptor]][0] +// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer]], %[[descriptor_update1]][1] +// CHECK: llvm.br ^bb2(%[[descriptor_update2]] +// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): +// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] +// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] +// CHECK: return %[[descriptor_update5]] : memref + + %out = memref.realloc %in(%d) : memref to memref + return %out : memref +} + +// ----- + +// CHECK-LABEL: func.func @realloc_dynamic_alignment( +// CHECK-SAME: %[[arg0:.*]]: memref, +// CHECK-SAME: %[[arg1:.*]]: index) -> memref { +// ALIGNED-ALLOC-LABEL: func.func @realloc_dynamic_alignment( +// ALIGNED-ALLOC-SAME: %[[arg0:.*]]: memref, +// ALIGNED-ALLOC-SAME: %[[arg1:.*]]: index) -> memref { +func.func @realloc_dynamic_alignment(%in: memref, %d: index) -> memref{ +// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] +// CHECK: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0] +// CHECK: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64 +// CHECK: 
%[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64 +// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] +// CHECK: ^bb1: +// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] +// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 +// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64 +// CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]] +// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]]) +// CHECK: %[[new_buffer_unaligned:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr to !llvm.ptr +// CHECK: %[[new_buffer_int:.*]] = llvm.ptrtoint %[[new_buffer_unaligned]] : !llvm.ptr +// CHECK: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]] +// CHECK: %[[ptr_alignment_m1:.*]] = llvm.add %[[new_buffer_int]], %[[alignment_m1]] +// CHECK: %[[padding:.*]] = llvm.urem %[[ptr_alignment_m1]], %[[alignment]] +// CHECK: %[[new_buffer_aligned_int:.*]] = llvm.sub %[[ptr_alignment_m1]], %[[padding]] +// CHECK: %[[new_buffer_aligned:.*]] = llvm.inttoptr %[[new_buffer_aligned_int]] : i64 to !llvm.ptr +// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] +// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1 +// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr to !llvm.ptr +// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr +// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] +// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr +// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]]) 
+// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_unaligned]], %[[descriptor]][0] +// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor_update1]][1] +// CHECK: llvm.br ^bb2(%[[descriptor_update2]] +// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): +// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] +// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] +// CHECK: return %[[descriptor_update5]] : memref + +// ALIGNED-ALLOC: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] +// ALIGNED-ALLOC: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0] +// ALIGNED-ALLOC: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64 +// ALIGNED-ALLOC: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64 +// ALIGNED-ALLOC: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] +// ALIGNED-ALLOC: ^bb1: +// ALIGNED-ALLOC: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr +// ALIGNED-ALLOC: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] +// ALIGNED-ALLOC: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 +// ALIGNED-ALLOC: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// ALIGNED-ALLOC-DAG: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64 +// ALIGNED-ALLOC-DAG: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64 +// ALIGNED-ALLOC: %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]] +// ALIGNED-ALLOC: %[[size_alignment_m1:.*]] = llvm.add %[[dst_size]], %[[alignment_m1]] +// ALIGNED-ALLOC: %[[padding:.*]] = llvm.urem %[[size_alignment_m1]], %[[alignment]] +// ALIGNED-ALLOC: %[[adjust_dst_size:.*]] = llvm.sub %[[size_alignment_m1]], %[[padding]] +// ALIGNED-ALLOC: %[[new_buffer_raw:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[adjust_dst_size]]) +// ALIGNED-ALLOC: 
%[[new_buffer_aligned:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr to !llvm.ptr +// ALIGNED-ALLOC: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] +// ALIGNED-ALLOC: %[[volatile:.*]] = llvm.mlir.constant(false) : i1 +// ALIGNED-ALLOC-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr to !llvm.ptr +// ALIGNED-ALLOC-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr +// ALIGNED-ALLOC: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// ALIGNED-ALLOC: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] +// ALIGNED-ALLOC: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr +// ALIGNED-ALLOC: llvm.call @free(%[[old_buffer_unaligned_void]]) +// ALIGNED-ALLOC: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor]][0] +// ALIGNED-ALLOC: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor_update1]][1] +// ALIGNED-ALLOC: llvm.br ^bb2(%[[descriptor_update2]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) +// ALIGNED-ALLOC: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): +// ALIGNED-ALLOC: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] +// ALIGNED-ALLOC: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] +// ALIGNED-ALLOC: return %[[descriptor_update5]] : memref + + %out = memref.realloc %in(%d) {alignment = 8} : memref to memref + return %out : memref +} + diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir @@ -338,3 +338,87 @@ %1 = 
memref.reshape %arg0(%shape) : (memref, memref<1xindex>) -> memref return %1 : memref } + +// ----- + +// CHECK-LABEL: func.func @realloc_static( +// CHECK-SAME: %[[arg0:.*]]: memref<2xi32>) -> memref<4xi32> { +func.func @realloc_static(%in: memref<2xi32>) -> memref<4xi32>{ +// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : memref<2xi32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[src_dim:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK: %[[dst_dim:.*]] = llvm.mlir.constant(4 : index) : i64 +// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] +// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] +// CHECK: ^bb1: +// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] +// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 +// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]]) +// CHECK: %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr to !llvm.ptr +// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] +// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1 +// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr to !llvm.ptr +// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr +// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] +// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr +// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]]) +// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer]], %[[descriptor]][0] +// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer]], 
%[[descriptor_update1]][1] +// CHECK: llvm.br ^bb2(%[[descriptor_update2]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) +// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): +// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] +// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] +// CHECK: return %[[descriptor_update5]] : memref<4xi32> + + %out = memref.realloc %in : memref<2xi32> to memref<4xi32> + return %out : memref<4xi32> +} + +// ----- + +// CHECK-LABEL: func.func @realloc_static_alignment( +// CHECK-SAME: %[[arg0:.*]]: memref<2xf32>) -> memref<4xf32> { +func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{ +// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : memref<2xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK: %[[src_dim:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK: %[[dst_dim:.*]] = llvm.mlir.constant(4 : index) : i64 +// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] : i64 +// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] +// CHECK: ^bb1: +// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] +// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 +// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] +// CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64 +// CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]] +// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]]) +// CHECK: %[[new_buffer_unaligned:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr to !llvm.ptr +// CHECK: %[[new_buffer_int:.*]] = llvm.ptrtoint %[[new_buffer_unaligned]] : !llvm.ptr +// CHECK: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK: 
%[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]] +// CHECK: %[[ptr_alignment_m1:.*]] = llvm.add %[[new_buffer_int]], %[[alignment_m1]] +// CHECK: %[[padding:.*]] = llvm.urem %[[ptr_alignment_m1]], %[[alignment]] +// CHECK: %[[new_buffer_aligned_int:.*]] = llvm.sub %[[ptr_alignment_m1]], %[[padding]] +// CHECK: %[[new_buffer_aligned:.*]] = llvm.inttoptr %[[new_buffer_aligned_int]] : i64 to !llvm.ptr +// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] +// CHECK: %[[volatile:.*]] = llvm.mlir.constant(false) : i1 +// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer_aligned]] : !llvm.ptr to !llvm.ptr +// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr +// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[dst_size]], %[[volatile]]) +// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] +// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr +// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]]) +// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_unaligned]], %[[descriptor]][0] +// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor_update1]][1] +// CHECK: llvm.br ^bb2(%[[descriptor_update2]] +// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): +// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] +// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] +// CHECK: return %[[descriptor_update5]] : memref<4xf32> + + + %out = memref.realloc %in {alignment = 8} : memref<2xf32> to memref<4xf32> + return %out : memref<4xf32> +} diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- 
a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -787,3 +787,16 @@ // CHECK-SAME: %[[ARG1:.+]]: index // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][0, 0] [1, %[[ARG1]]] [1, 1] // CHECK-SAME: memref<8x?xf32> to memref> + +// ----- + +// CHECK-LABEL: func @memref_realloc_dead +// CHECK-SAME: %[[SRC:[0-9a-z]+]]: memref<2xf32> +// CHECK-NOT: memref.realloc +// CHECK: return %[[SRC]] +func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{ + %0 = memref.realloc %src : memref<2xf32> to memref<4xf32> + %i2 = arith.constant 2 : index + memref.store %v, %0[%i2] : memref<4xf32> + return %src : memref<2xf32> +} diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -992,3 +992,37 @@ } return } + +// ----- + +#map0 = affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)> +func.func @memref_realloc_layout(%src : memref<256xf32, #map0>) -> memref{ + // expected-error@+1 {{unsupported layout}} + %0 = memref.realloc %src : memref<256xf32, #map0> to memref + return %0 : memref +} + +// ----- + +func.func @memref_realloc_sizes_1(%src : memref<2xf32>) -> memref{ + // expected-error@+1 {{missing dimension operand}} + %0 = memref.realloc %src : memref<2xf32> to memref + return %0 : memref +} + +// ----- + +func.func @memref_realloc_sizes_2(%src : memref, %d : index) + -> memref<4xf32>{ + // expected-error@+1 {{unnecessary dimension operand}} + %0 = memref.realloc %src(%d) : memref to memref<4xf32> + return %0 : memref<4xf32> +} + +// ----- + +func.func @memref_realloc_type(%src : memref<256xf32>) -> memref{ + // expected-error@+1 {{different element types}} + %0 = memref.realloc %src : memref<256xf32> to memref + return %0 : memref +} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ 
b/mlir/test/Dialect/MemRef/ops.mlir @@ -347,3 +347,30 @@ return %m2: memref> } + +// ----- + +// CHECK-LABEL: func @memref_realloc_ss +func.func @memref_realloc_ss(%src : memref<2xf32>) -> memref<4xf32>{ + %0 = memref.realloc %src : memref<2xf32> to memref<4xf32> + return %0 : memref<4xf32> +} + +// CHECK-LABEL: func @memref_realloc_sd +func.func @memref_realloc_sd(%src : memref<2xf32>, %d : index) -> memref{ + %0 = memref.realloc %src(%d) : memref<2xf32> to memref + return %0 : memref +} + +// CHECK-LABEL: func @memref_realloc_ds +func.func @memref_realloc_ds(%src : memref) -> memref<4xf32>{ + %0 = memref.realloc %src: memref to memref<4xf32> + return %0 : memref<4xf32> +} + +// CHECK-LABEL: func @memref_realloc_dd +func.func @memref_realloc_dd(%src : memref, %d: index) + -> memref{ + %0 = memref.realloc %src(%d) : memref to memref + return %0 : memref +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir @@ -0,0 +1,57 @@ +// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts |\ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext +// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-memref-to-llvm='use-aligned-alloc=1' -convert-func-to-llvm -arith-expand -reconcile-unrealized-casts |\ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | FileCheck %s + +func.func @entry() { + // Set up memory. 
+ %c0 = arith.constant 0: index + %c1 = arith.constant 1: index + %c8 = arith.constant 8: index + %A = memref.alloc() : memref<8xf32> + scf.for %i = %c0 to %c8 step %c1 { + %i32 = arith.index_cast %i : index to i32 + %fi = arith.sitofp %i32 : i32 to f32 + memref.store %fi, %A[%i] : memref<8xf32> + } + + %d0 = arith.constant -1.0 : f32 + %Av = vector.transfer_read %A[%c0], %d0: memref<8xf32>, vector<8xf32> + vector.print %Av : vector<8xf32> + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 ) + + // Realloc with static sizes. + %B = memref.realloc %A : memref<8xf32> to memref<10xf32> + + %c10 = arith.constant 10: index + scf.for %i = %c8 to %c10 step %c1 { + %i32 = arith.index_cast %i : index to i32 + %fi = arith.sitofp %i32 : i32 to f32 + memref.store %fi, %B[%i] : memref<10xf32> + } + + %Bv = vector.transfer_read %B[%c0], %d0: memref<10xf32>, vector<10xf32> + vector.print %Bv : vector<10xf32> + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 ) + + // Realloc with dynamic sizes. + %Bd = memref.cast %B : memref<10xf32> to memref + %c13 = arith.constant 13: index + %Cd = memref.realloc %Bd(%c13) : memref to memref + %C = memref.cast %Cd : memref to memref<13xf32> + + scf.for %i = %c10 to %c13 step %c1 { + %i32 = arith.index_cast %i : index to i32 + %fi = arith.sitofp %i32 : i32 to f32 + memref.store %fi, %C[%i] : memref<13xf32> + } + + %Cv = vector.transfer_read %C[%c0], %d0: memref<13xf32>, vector<13xf32> + vector.print %Cv : vector<13xf32> + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ) + + memref.dealloc %C : memref<13xf32> + return +}