diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -21,8 +21,14 @@ class ModuleOp; +namespace arith { +class ArithDialect; +} // namespace arith namespace func { class FuncDialect; } // namespace func +namespace scf { +class SCFDialect; +} // namespace scf namespace tensor { class TensorDialect; } // namespace tensor @@ -67,6 +73,10 @@ /// easier to reason about operations. std::unique_ptr createExpandStridedMetadataPass(); +/// Creates an operation pass to expand `memref.realloc` operations into their +/// components. +std::unique_ptr createExpandReallocPass(bool emitDeallocs = true); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -202,5 +202,48 @@ "affine::AffineDialect", "memref::MemRefDialect" ]; } + +def ExpandRealloc : Pass<"expand-realloc"> { + let summary = "Expand memref.realloc operations into their components"; + let description = [{ + The `memref.realloc` operation performs a conditional allocation and copy to + increase the size of a buffer if necessary. 
This pass converts a `realloc` + operation into this sequence of simpler operations such that other passes + at a later stage in the compilation pipeline do not have to consider the + `realloc` operation anymore (e.g., the buffer deallocation pass and the + conversion pass to LLVM).
+ + Example of an expansion: + ```mlir + %realloc = memref.realloc %alloc (%size) : memref to memref + ``` + is expanded to + ```mlir + %c0 = arith.constant 0 : index + %dim = memref.dim %alloc, %c0 : memref + %is_old_smaller = arith.cmpi ult, %dim, %size + %realloc = scf.if %is_old_smaller -> (memref) { + %new_alloc = memref.alloc(%size) : memref + %subview = memref.subview %new_alloc[0] [%dim] [1] + memref.copy %alloc, %subview + memref.dealloc %alloc + scf.yield %new_alloc : memref + } else { + %reinterpret_cast = memref.reinterpret_cast %alloc to + offset: [0], sizes: [%size], strides: [1] + scf.yield %reinterpret_cast : memref + } + ``` + }]; + let options = [ + Option<"emitDeallocs", "emit-deallocs", "bool", /*default=*/"true", + "Emit deallocation operations for the original MemRef">, + ]; + let constructor = "mlir::memref::createExpandReallocPass()"; + let dependentDialects = [ + "arith::ArithDialect", "scf::SCFDialect", "memref::MemRefDialect" + ]; +} + #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -63,6 +63,10 @@ /// `memref.extract_strided_metadata` of its source. void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns); +/// Appends patterns for expanding `memref.realloc` operations. +void populateExpandReallocPatterns(RewritePatternSet &patterns, + bool emitDeallocs = true); + /// Appends patterns for emulating wide integer memref operations with ops over /// narrower integer types.
void populateMemRefWideIntEmulationPatterns( diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -116,165 +116,6 @@ } }; -/// The base class for lowering realloc op, to support the implementation of -/// realloc via allocation methods that may or may not support alignment. -/// A derived class should provide an implementation of allocateBuffer using -/// the underline allocation methods. -struct ReallocOpLoweringBase : public AllocationOpLLVMLowering { - using OpAdaptor = typename memref::ReallocOp::Adaptor; - - ReallocOpLoweringBase(const LLVMTypeConverter &converter) - : AllocationOpLLVMLowering(memref::ReallocOp::getOperationName(), - converter) {} - - /// Allocates the new buffer. Returns the allocated pointer and the - /// aligned pointer. - virtual std::tuple - allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, - Value sizeBytes, memref::ReallocOp op) const = 0; - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto reallocOp = cast(op); - return matchAndRewrite(reallocOp, OpAdaptor(operands, reallocOp), rewriter); - } - - // A `realloc` is converted as follows: - // If new_size > old_size - // 1. allocates a new buffer - // 2. copies the content of the old buffer to the new buffer - // 3. release the old buffer - // 3. updates the buffer pointers in the memref descriptor - // Update the size in the memref descriptor - // Alignment request is handled by allocating `alignment` more bytes than - // requested and shifting the aligned pointer relative to the allocated - // memory.
- LogicalResult matchAndRewrite(memref::ReallocOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - OpBuilder::InsertionGuard guard(rewriter); - Location loc = op.getLoc(); - - auto computeNumElements = - [&](MemRefType type, function_ref getDynamicSize) -> Value { - // Compute number of elements. - Type indexType = ConvertToLLVMPattern::getIndexType(); - Value numElements = - type.isDynamicDim(0) - ? getDynamicSize() - : createIndexAttrConstant(rewriter, loc, indexType, - type.getDimSize(0)); - if (numElements.getType() != indexType) - numElements = typeConverter->materializeTargetConversion( - rewriter, loc, indexType, numElements); - return numElements; - }; - - MemRefDescriptor desc(adaptor.getSource()); - Value oldDesc = desc; - - // Split the block right before the current op into two blocks. - Block *currentBlock = rewriter.getInsertionBlock(); - Block *block = - rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); - // Add a block argument by creating an empty block with the argument type - // and then merging the block into the empty block. - Block *endBlock = rewriter.createBlock( - block->getParent(), Region::iterator(block), oldDesc.getType(), loc); - rewriter.mergeBlocks(block, endBlock, {}); - // Add a new block for the true branch of the conditional statement we will - // add.
- Block *trueBlock = rewriter.createBlock( - currentBlock->getParent(), std::next(Region::iterator(currentBlock))); - - rewriter.setInsertionPointToEnd(currentBlock); - Value src = op.getSource(); - auto srcType = dyn_cast(src.getType()); - Value srcNumElements = computeNumElements( - srcType, [&]() -> Value { return desc.size(rewriter, loc, 0); }); - auto dstType = cast(op.getType()); - Value dstNumElements = computeNumElements( - dstType, [&]() -> Value { return op.getDynamicResultSize(); }); - Value cond = rewriter.create( - loc, IntegerType::get(rewriter.getContext(), 1), - LLVM::ICmpPredicate::ugt, dstNumElements, srcNumElements); - rewriter.create(loc, cond, trueBlock, ArrayRef(), - endBlock, ValueRange{oldDesc}); - - rewriter.setInsertionPointToStart(trueBlock); - Value sizeInBytes = getSizeInBytes(loc, dstType.getElementType(), rewriter); - // Compute total byte size. - auto dstByteSize = - rewriter.create(loc, dstNumElements, sizeInBytes); - // Since the src and dst memref are guarantee to have the same - // element type by the verifier, it is safe here to reuse the - // type size computed from dst memref. - auto srcByteSize = - rewriter.create(loc, srcNumElements, sizeInBytes); - // Allocate a new buffer. - auto [dstRawPtr, dstAlignedPtr] = - allocateBuffer(rewriter, loc, dstByteSize, op); - // Copy the data from the old buffer to the new buffer. - Value srcAlignedPtr = desc.alignedPtr(rewriter, loc); - auto toVoidPtr = [&](Value ptr) -> Value { - if (getTypeConverter()->useOpaquePointers()) - return ptr; - return rewriter.create(loc, getVoidPtrType(), ptr); - }; - rewriter.create(loc, toVoidPtr(dstAlignedPtr), - toVoidPtr(srcAlignedPtr), srcByteSize, - /*isVolatile=*/false); - // Deallocate the old buffer.
- LLVM::LLVMFuncOp freeFunc = - getFreeFn(getTypeConverter(), op->getParentOfType()); - rewriter.create(loc, freeFunc, - toVoidPtr(desc.allocatedPtr(rewriter, loc))); - // Replace the old buffer addresses in the MemRefDescriptor with the new - // buffer addresses. - desc.setAllocatedPtr(rewriter, loc, dstRawPtr); - desc.setAlignedPtr(rewriter, loc, dstAlignedPtr); - rewriter.create(loc, Value(desc), endBlock); - - rewriter.setInsertionPoint(op); - // Update the memref size. - MemRefDescriptor newDesc(endBlock->getArgument(0)); - newDesc.setSize(rewriter, loc, 0, dstNumElements); - rewriter.replaceOp(op, {newDesc}); - return success(); - } - -private: - using ConvertToLLVMPattern::matchAndRewrite; -}; - -struct ReallocOpLowering : public ReallocOpLoweringBase { - ReallocOpLowering(const LLVMTypeConverter &converter) - : ReallocOpLoweringBase(converter) {} - std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, - Location loc, Value sizeBytes, - memref::ReallocOp op) const override { - return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op, - getAlignment(rewriter, loc, op)); - } -}; - -struct AlignedReallocOpLowering : public ReallocOpLoweringBase { - AlignedReallocOpLowering(const LLVMTypeConverter &converter) - : ReallocOpLoweringBase(converter) {} - std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, - Location loc, Value sizeBytes, - memref::ReallocOp op) const override { - Value ptr = allocateBufferAutoAlign( - rewriter, loc, sizeBytes, op, &defaultLayout, - alignedAllocationGetAlignment(rewriter, loc, op, &defaultLayout)); - return std::make_tuple(ptr, ptr); - } - -private: - /// Default layout to use in absence of the corresponding analysis.
- DataLayout defaultLayout; -}; - struct AllocaScopeOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -1899,11 +1740,9 @@ // clang-format on auto allocLowering = converter.getOptions().allocLowering; if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) - patterns.add(converter); + patterns.add(converter); else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) - patterns.add( - converter); + patterns.add(converter); } namespace { diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ BufferizableOpInterfaceImpl.cpp ComposeSubView.cpp ExpandOps.cpp + ExpandRealloc.cpp ExpandStridedMetadata.cpp EmulateWideInt.cpp EmulateNarrowType.cpp diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp @@ -0,0 +1,175 @@ +//===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace memref { +#define GEN_PASS_DEF_EXPANDREALLOC +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" +} // namespace memref +} // namespace mlir + +using namespace mlir; + +namespace { + +/// The `realloc` operation performs a conditional allocation and copy to +/// increase the size of a buffer if necessary. This pattern converts the +/// `realloc` operation into this sequence of simpler operations. +/// +/// Example of an expansion: +/// ```mlir +/// %realloc = memref.realloc %alloc (%size) : memref to memref +/// ``` +/// is expanded to +/// ```mlir +/// %c0 = arith.constant 0 : index +/// %dim = memref.dim %alloc, %c0 : memref +/// %is_old_smaller = arith.cmpi ult, %dim, %size +/// %realloc = scf.if %is_old_smaller -> (memref) { +/// %new_alloc = memref.alloc(%size) : memref +/// %subview = memref.subview %new_alloc[0] [%dim] [1] +/// memref.copy %alloc, %subview +/// memref.dealloc %alloc +/// scf.yield %new_alloc : memref +/// } else { +/// %reinterpret_cast = memref.reinterpret_cast %alloc to +/// offset: [0], sizes: [%size], strides: [1] +/// scf.yield %reinterpret_cast : memref +/// } +/// ``` +struct ExpandReallocOpPattern : public OpRewritePattern { + ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs) + : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {} + + LogicalResult matchAndRewrite(memref::ReallocOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + assert(op.getType().getRank() == 1 && + "result MemRef must have exactly one rank");
assert(op.getSource().getType().getRank() == 1 && + "source MemRef must have exactly one rank"); + assert(op.getType().getLayout().isIdentity() && + "result MemRef must have identity layout (or none)"); + assert(op.getSource().getType().getLayout().isIdentity() && + "source MemRef must have identity layout (or none)"); + + // Get the size of the original buffer. + int64_t inputSize = + op.getSource().getType().cast().getDimSize(0); + OpFoldResult currSize = rewriter.getIndexAttr(inputSize); + if (ShapedType::isDynamic(inputSize)) { + Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc, + rewriter.getIndexAttr(0)); + currSize = rewriter.create(loc, op.getSource(), dimZero) + .getResult(); + } + + // Get the requested size that the new buffer should have. + int64_t outputSize = + op.getResult().getType().cast().getDimSize(0); + OpFoldResult targetSize = ShapedType::isDynamic(outputSize) + ? OpFoldResult{op.getDynamicResultSize()} + : rewriter.getIndexAttr(outputSize); + + // Only allocate a new buffer and copy over the values in the old buffer if + // the old buffer is smaller than the requested size. + Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize); + Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize); + Value cond = rewriter.create(loc, arith::CmpIPredicate::ult, + lhs, rhs); + auto ifOp = rewriter.create( + loc, cond, + [&](OpBuilder &builder, Location loc) { + // Allocate the new buffer. If it is a dynamic memref we need to pass + // an additional operand for the size at runtime, otherwise the static + // size is encoded in the result type.
+ SmallVector dynamicSizeOperands; + if (op.getDynamicResultSize()) + dynamicSizeOperands.push_back(op.getDynamicResultSize()); + + Value newAlloc = builder.create( + loc, op.getResult().getType(), dynamicSizeOperands, + op.getAlignmentAttr()); + + // Take a subview of the new (bigger) buffer such that we can copy the + // old values over (the copy operation requires both operands to have + // the same shape). + Value subview = builder.create( + loc, newAlloc, ArrayRef{rewriter.getIndexAttr(0)}, + ArrayRef{currSize}, + ArrayRef{rewriter.getIndexAttr(1)}); + builder.create(loc, op.getSource(), subview); + + // Insert the deallocation of the old buffer only if requested + // (enabled by default). + if (emitDeallocs) + builder.create(loc, op.getSource()); + + builder.create(loc, newAlloc); + }, + [&](OpBuilder &builder, Location loc) { + // We need to reinterpret-cast here because either the input or output + // type might be static, which means we need to cast from static to + // dynamic or vice-versa. If both are static and the original buffer + // is already bigger than the requested size, the cast represents a + // subview operation.
+ Value casted = builder.create( + loc, op.getResult().getType().cast(), op.getSource(), + rewriter.getIndexAttr(0), ArrayRef{targetSize}, + ArrayRef{rewriter.getIndexAttr(1)}); + builder.create(loc, casted); + }); + + rewriter.replaceOp(op, ifOp.getResult(0)); + return success(); + } + +private: + const bool emitDeallocs; // If set, the old buffer is deallocated after the copy. +}; + +struct ExpandReallocPass + : public memref::impl::ExpandReallocBase { + ExpandReallocPass(bool emitDeallocs) + : memref::impl::ExpandReallocBase() { + this->emitDeallocs.setValue(emitDeallocs); + } + void runOnOperation() override { + MLIRContext &ctx = getContext(); + + RewritePatternSet patterns(&ctx); + memref::populateExpandReallocPatterns(patterns, emitDeallocs.getValue()); + ConversionTarget target(ctx); + + target.addLegalDialect(); + target.addIllegalOp(); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns, + bool emitDeallocs) { + patterns.add(patterns.getContext(), emitDeallocs); +} + +std::unique_ptr mlir::memref::createExpandReallocPass(bool emitDeallocs) { + return std::make_unique(emitDeallocs); +} diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -59,6 +59,7 @@ // it to this pipeline.
pm.addNestedPass(createConvertLinalgToLoopsPass()); pm.addNestedPass(createConvertVectorToSCFPass()); + pm.addNestedPass(memref::createExpandReallocPass()); // Expands memref.realloc into scf.if-based IR; keep it before SCF-to-CF. pm.addNestedPass(createConvertSCFToCFPass()); pm.addPass(memref::createExpandStridedMetadataPass()); pm.addPass(createLowerAffinePass()); diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir @@ -644,114 +644,3 @@ memref.cast %0 : memref<1 x memref<* x f32>> to memref<* x memref<* x f32>> return } - -// ----- - -// CHECK-LABEL: func.func @realloc_dynamic( -// CHECK-SAME: %[[arg0:.*]]: memref, -// CHECK-SAME: %[[arg1:.*]]: index) -> memref { -func.func @realloc_dynamic(%in: memref, %d: index) -> memref{ -// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] -// CHECK: %[[src_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0] -// CHECK: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64 -// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] : i64 -// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] -// CHECK: ^bb1: -// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] -// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 -// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] -// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]] -// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]]) -// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] -// CHECK: "llvm.intr.memcpy"(%[[new_buffer_raw]], %[[old_buffer_aligned]], %[[src_size]]) <{isVolatile = false}> -// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] -// CHECK: llvm.call
@free(%[[old_buffer_unaligned]]) -// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_raw]], %[[descriptor]][0] -// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_raw]], %[[descriptor_update1]][1] -// CHECK: llvm.br ^bb2(%[[descriptor_update2]] -// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): -// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] -// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] -// CHECK: return %[[descriptor_update5]] : memref - - %out = memref.realloc %in(%d) : memref to memref - return %out : memref -} - -// ----- - -// CHECK-LABEL: func.func @realloc_dynamic_alignment( -// CHECK-SAME: %[[arg0:.*]]: memref, -// CHECK-SAME: %[[arg1:.*]]: index) -> memref { -// ALIGNED-ALLOC-LABEL: func.func @realloc_dynamic_alignment( -// ALIGNED-ALLOC-SAME: %[[arg0:.*]]: memref, -// ALIGNED-ALLOC-SAME: %[[arg1:.*]]: index) -> memref { -func.func @realloc_dynamic_alignment(%in: memref, %d: index) -> memref{ -// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] -// CHECK: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0] -// CHECK: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64 -// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64 -// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] -// CHECK: ^bb1: -// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] -// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 -// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] -// CHECK: %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]] -// CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64 -// CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]] -// CHECK:
%[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]]) -// CHECK: %[[new_buffer_int:.*]] = llvm.ptrtoint %[[new_buffer_raw]] : !llvm.ptr -// CHECK: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64 -// CHECK: %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]] -// CHECK: %[[ptr_alignment_m1:.*]] = llvm.add %[[new_buffer_int]], %[[alignment_m1]] -// CHECK: %[[padding:.*]] = llvm.urem %[[ptr_alignment_m1]], %[[alignment]] -// CHECK: %[[new_buffer_aligned_int:.*]] = llvm.sub %[[ptr_alignment_m1]], %[[padding]] -// CHECK: %[[new_buffer_aligned:.*]] = llvm.inttoptr %[[new_buffer_aligned_int]] : i64 to !llvm.ptr -// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] -// CHECK: "llvm.intr.memcpy"(%[[new_buffer_aligned]], %[[old_buffer_aligned]], %[[src_size]]) <{isVolatile = false}> -// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] -// CHECK: llvm.call @free(%[[old_buffer_unaligned]]) -// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_raw]], %[[descriptor]][0] -// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor_update1]][1] -// CHECK: llvm.br ^bb2(%[[descriptor_update2]] -// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): -// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] -// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] -// CHECK: return %[[descriptor_update5]] : memref - -// ALIGNED-ALLOC: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] -// ALIGNED-ALLOC: %[[drc_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0] -// ALIGNED-ALLOC: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64 -// ALIGNED-ALLOC: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[drc_dim]] : i64 -// ALIGNED-ALLOC: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] -//
ALIGNED-ALLOC: ^bb1: -// ALIGNED-ALLOC: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr -// ALIGNED-ALLOC: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] -// ALIGNED-ALLOC: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 -// ALIGNED-ALLOC: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] -// ALIGNED-ALLOC: %[[src_size:.*]] = llvm.mul %[[drc_dim]], %[[dst_es]] -// ALIGNED-ALLOC-DAG: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64 -// ALIGNED-ALLOC-DAG: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64 -// ALIGNED-ALLOC: %[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]] -// ALIGNED-ALLOC: %[[size_alignment_m1:.*]] = llvm.add %[[dst_size]], %[[alignment_m1]] -// ALIGNED-ALLOC: %[[padding:.*]] = llvm.urem %[[size_alignment_m1]], %[[alignment]] -// ALIGNED-ALLOC: %[[adjust_dst_size:.*]] = llvm.sub %[[size_alignment_m1]], %[[padding]] -// ALIGNED-ALLOC: %[[new_buffer_raw:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[adjust_dst_size]]) -// ALIGNED-ALLOC: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] -// ALIGNED-ALLOC: "llvm.intr.memcpy"(%[[new_buffer_raw]], %[[old_buffer_aligned]], %[[src_size]]) <{isVolatile = false}> -// ALIGNED-ALLOC: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] -// ALIGNED-ALLOC: llvm.call @free(%[[old_buffer_unaligned]]) -// ALIGNED-ALLOC: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_raw]], %[[descriptor]][0] -// ALIGNED-ALLOC: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_raw]], %[[descriptor_update1]][1] -// ALIGNED-ALLOC: llvm.br ^bb2(%[[descriptor_update2]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -// ALIGNED-ALLOC: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): -// ALIGNED-ALLOC: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] -// ALIGNED-ALLOC: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast
%[[descriptor_update4]] -// ALIGNED-ALLOC: return %[[descriptor_update5]] : memref - - %out = memref.realloc %in(%d) {alignment = 8} : memref to memref - return %out : memref -} - diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir --- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir @@ -345,84 +345,6 @@ return %1 : memref } -// ----- - -// CHECK-LABEL: func.func @realloc_static( -// CHECK-SAME: %[[arg0:.*]]: memref<2xi32>) -> memref<4xi32> { -func.func @realloc_static(%in: memref<2xi32>) -> memref<4xi32>{ -// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : memref<2xi32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -// CHECK: %[[src_dim:.*]] = llvm.mlir.constant(2 : index) : i64 -// CHECK: %[[dst_dim:.*]] = llvm.mlir.constant(4 : index) : i64 -// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] -// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] -// CHECK: ^bb1: -// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] -// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 -// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] -// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]] -// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]]) -// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] -// CHECK: "llvm.intr.memcpy"(%[[new_buffer_raw]], %[[old_buffer_aligned]], %[[src_size]]) <{isVolatile = false}> -// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] -// CHECK: llvm.call @free(%[[old_buffer_unaligned]]) -// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_raw]], %[[descriptor]][0] -// CHECK: %[[descriptor_update2:.*]] =
llvm.insertvalue %[[new_buffer_raw]], %[[descriptor_update1]][1] -// CHECK: llvm.br ^bb2(%[[descriptor_update2]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>) -// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): -// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] -// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] -// CHECK: return %[[descriptor_update5]] : memref<4xi32> - - %out = memref.realloc %in : memref<2xi32> to memref<4xi32> - return %out : memref<4xi32> -} - -// ----- - -// CHECK-LABEL: func.func @realloc_static_alignment( -// CHECK-SAME: %[[arg0:.*]]: memref<2xf32>) -> memref<4xf32> { -func.func @realloc_static_alignment(%in: memref<2xf32>) -> memref<4xf32>{ -// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : memref<2xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -// CHECK: %[[src_dim:.*]] = llvm.mlir.constant(2 : index) : i64 -// CHECK: %[[dst_dim:.*]] = llvm.mlir.constant(4 : index) : i64 -// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] : i64 -// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] -// CHECK: ^bb1: -// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] -// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 -// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] -// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]] -// CHECK: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : i64 -// CHECK: %[[adjust_dst_size:.*]] = llvm.add %[[dst_size]], %[[alignment]] -// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[adjust_dst_size]]) -// CHECK: %[[new_buffer_int:.*]] = llvm.ptrtoint %[[new_buffer_raw]] : !llvm.ptr -// CHECK: %[[const_1:.*]] = llvm.mlir.constant(1 : index) : i64 -// CHECK:
%[[alignment_m1:.*]] = llvm.sub %[[alignment]], %[[const_1]] -// CHECK: %[[ptr_alignment_m1:.*]] = llvm.add %[[new_buffer_int]], %[[alignment_m1]] -// CHECK: %[[padding:.*]] = llvm.urem %[[ptr_alignment_m1]], %[[alignment]] -// CHECK: %[[new_buffer_aligned_int:.*]] = llvm.sub %[[ptr_alignment_m1]], %[[padding]] -// CHECK: %[[new_buffer_aligned:.*]] = llvm.inttoptr %[[new_buffer_aligned_int]] : i64 to !llvm.ptr -// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] -// CHECK: "llvm.intr.memcpy"(%[[new_buffer_aligned]], %[[old_buffer_aligned]], %[[src_size]]) <{isVolatile = false}> -// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] -// CHECK: llvm.call @free(%[[old_buffer_unaligned]]) -// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer_raw]], %[[descriptor]][0] -// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer_aligned]], %[[descriptor_update1]][1] -// CHECK: llvm.br ^bb2(%[[descriptor_update2]] -// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>): -// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] -// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] -// CHECK: return %[[descriptor_update5]] : memref<4xf32> - - - %out = memref.realloc %in {alignment = 8} : memref<2xf32> to memref<4xf32> - return %out : memref<4xf32> -} - -// ----- - // CHECK-LABEL: @memref_memory_space_cast func.func @memref_memory_space_cast(%input : memref) -> memref { %cast = memref.memory_space_cast %input : memref to memref diff --git a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir --- a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir @@ -373,42 +373,3 @@ // CHECK: ^bb3: // CHECK: return - -// ----- - -// CHECK-LABEL: func.func
@realloc_dynamic( -// CHECK-SAME: %[[arg0:.*]]: memref, -// CHECK-SAME: %[[arg1:.*]]: index) -> memref { -func.func @realloc_dynamic(%in: memref, %d: index) -> memref{ -// CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %[[arg0]] -// CHECK: %[[src_dim:.*]] = llvm.extractvalue %[[descriptor]][3, 0] -// CHECK: %[[dst_dim:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : index to i64 -// CHECK: %[[cond:.*]] = llvm.icmp "ugt" %[[dst_dim]], %[[src_dim]] : i64 -// CHECK: llvm.cond_br %[[cond]], ^bb1, ^bb2(%[[descriptor]] -// CHECK: ^bb1: -// CHECK: %[[dst_null:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK: %[[dst_gep:.*]] = llvm.getelementptr %[[dst_null]][1] -// CHECK: %[[dst_es:.*]] = llvm.ptrtoint %[[dst_gep]] : !llvm.ptr to i64 -// CHECK: %[[dst_size:.*]] = llvm.mul %[[dst_dim]], %[[dst_es]] -// CHECK: %[[src_size:.*]] = llvm.mul %[[src_dim]], %[[dst_es]] -// CHECK: %[[new_buffer_raw:.*]] = llvm.call @malloc(%[[dst_size]]) -// CHECK: %[[new_buffer:.*]] = llvm.bitcast %[[new_buffer_raw]] : !llvm.ptr to !llvm.ptr -// CHECK: %[[old_buffer_aligned:.*]] = llvm.extractvalue %[[descriptor]][1] -// CHECK-DAG: %[[new_buffer_void:.*]] = llvm.bitcast %[[new_buffer]] : !llvm.ptr to !llvm.ptr -// CHECK-DAG: %[[old_buffer_void:.*]] = llvm.bitcast %[[old_buffer_aligned]] : !llvm.ptr to !llvm.ptr -// CHECK: "llvm.intr.memcpy"(%[[new_buffer_void]], %[[old_buffer_void]], %[[src_size]]) <{isVolatile = false}> -// CHECK: %[[old_buffer_unaligned:.*]] = llvm.extractvalue %[[descriptor]][0] -// CHECK: %[[old_buffer_unaligned_void:.*]] = llvm.bitcast %[[old_buffer_unaligned]] : !llvm.ptr to !llvm.ptr -// CHECK: llvm.call @free(%[[old_buffer_unaligned_void]]) -// CHECK: %[[descriptor_update1:.*]] = llvm.insertvalue %[[new_buffer]], %[[descriptor]][0] -// CHECK: %[[descriptor_update2:.*]] = llvm.insertvalue %[[new_buffer]], %[[descriptor_update1]][1] -// CHECK: llvm.br ^bb2(%[[descriptor_update2]] -// CHECK: ^bb2(%[[descriptor_update3:.*]]: !llvm.struct<(ptr, ptr, i64,
array<1 x i64>, array<1 x i64>)>): -// CHECK: %[[descriptor_update4:.*]] = llvm.insertvalue %[[dst_dim]], %[[descriptor_update3]][3, 0] -// CHECK: %[[descriptor_update5:.*]] = builtin.unrealized_conversion_cast %[[descriptor_update4]] -// CHECK: return %[[descriptor_update5]] : memref - - %out = memref.realloc %in(%d) : memref to memref - return %out : memref -} - diff --git a/mlir/test/Dialect/MemRef/expand-realloc.mlir b/mlir/test/Dialect/MemRef/expand-realloc.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/MemRef/expand-realloc.mlir @@ -0,0 +1,107 @@ +// RUN: mlir-opt --expand-realloc %s --split-input-file --verify-diagnostics | FileCheck %s +// RUN: mlir-opt --expand-realloc=emit-deallocs=false %s --split-input-file --verify-diagnostics | FileCheck %s --check-prefix=NODEALLOC + +func.func @reallow_lowering_example(%init_size: index, %new_size: index) -> memref { + %alloc = memref.alloc(%init_size) : memref + %realloc = memref.realloc %alloc (%new_size) {alignment = 8}: memref to memref + return %realloc : memref +} + +// CHECK-LABEL: func @reallow_lowering_example +// CHECK-SAME: ([[INIT_SIZE:%.+]]: index, [[NEW_SIZE:%.+]]: index) +// CHECK-NEXT: [[OLD_ALLOC:%.+]] = memref.alloc([[INIT_SIZE]]) : memref +// CHECK-NEXT: [[C0:%.+]] = arith.constant 0 +// CHECK-NEXT: [[CURR_SIZE:%.+]] = memref.dim [[OLD_ALLOC]], [[C0]] +// CHECK-NEXT: [[COND:%.+]] = arith.cmpi ult, [[CURR_SIZE]], [[NEW_SIZE]] +// CHECK-NEXT: [[REALLOC:%.+]] = scf.if [[COND]] +// CHECK-NEXT: [[NEW_ALLOC:%.+]] = memref.alloc([[NEW_SIZE]]) {alignment = 8 : i64} : memref +// CHECK-NEXT: [[SUBVIEW:%.+]] = memref.subview [[NEW_ALLOC]][0] [[[CURR_SIZE]]] [1] +// CHECK-NEXT: memref.copy [[OLD_ALLOC]], [[SUBVIEW]] +// CHECK-NEXT: memref.dealloc [[OLD_ALLOC]] +// CHECK-NEXT: scf.yield [[NEW_ALLOC]] +// CHECK: [[CAST:%.+]] = memref.reinterpret_cast [[OLD_ALLOC]] to offset: [0], sizes: [[[NEW_SIZE]]], strides: [1] : memref to memref +// CHECK-NEXT: scf.yield [[CAST]] +// CHECK: return 
[[REALLOC]] + +// NODEALLOC-LABEL: func @reallow_lowering_example +// NODEALLOC-NOT: dealloc + +// ----- + +func.func @reallow_lowering_example() -> memref<4xf32> { + %alloc = memref.alloc() : memref<2xf32> + %realloc = memref.realloc %alloc {alignment = 8}: memref<2xf32> to memref<4xf32> + return %realloc : memref<4xf32> +} + +// CHECK-LABEL: func @reallow_lowering_example +// CHECK-NEXT: [[OLD_ALLOC:%.+]] = memref.alloc() : memref<2xf32> +// CHECK-NEXT: [[CURR_SIZE:%.+]] = arith.constant 2 +// CHECK-NEXT: [[NEW_SIZE:%.+]] = arith.constant 4 +// CHECK-NEXT: [[COND:%.+]] = arith.cmpi ult, [[CURR_SIZE]], [[NEW_SIZE]] +// CHECK-NEXT: [[REALLOC:%.+]] = scf.if [[COND]] +// CHECK-NEXT: [[NEW_ALLOC:%.+]] = memref.alloc() {alignment = 8 : i64} : memref<4xf32> +// CHECK-NEXT: [[SUBVIEW:%.+]] = memref.subview [[NEW_ALLOC]][0] [2] [1] +// CHECK-NEXT: memref.copy [[OLD_ALLOC]], [[SUBVIEW]] +// CHECK-NEXT: memref.dealloc [[OLD_ALLOC]] +// CHECK-NEXT: scf.yield [[NEW_ALLOC]] +// CHECK: [[CAST:%.+]] = memref.reinterpret_cast [[OLD_ALLOC]] to offset: [0], sizes: [4], strides: [1] : memref<2xf32> to memref<4xf32> +// CHECK-NEXT: scf.yield [[CAST]] +// CHECK: return [[REALLOC]] + +// NODEALLOC-LABEL: func @reallow_lowering_example +// NODEALLOC-NOT: dealloc + +// ----- + +func.func @reallow_lowering_example(%init_size: index) -> memref<4xf32> { + %alloc = memref.alloc(%init_size) : memref + %realloc = memref.realloc %alloc : memref to memref<4xf32> + return %realloc : memref<4xf32> +} + +// CHECK-LABEL: func @reallow_lowering_example +// CHECK-SAME: ([[INIT_SIZE:%.+]]: index) +// CHECK-NEXT: [[OLD_ALLOC:%.+]] = memref.alloc([[INIT_SIZE]]) : memref +// CHECK-NEXT: [[C0:%.+]] = arith.constant 0 +// CHECK-NEXT: [[CURR_SIZE:%.+]] = memref.dim [[OLD_ALLOC]], [[C0]] +// CHECK-NEXT: [[NEW_SIZE:%.+]] = arith.constant 4 +// CHECK-NEXT: [[COND:%.+]] = arith.cmpi ult, [[CURR_SIZE]], [[NEW_SIZE]] +// CHECK-NEXT: [[REALLOC:%.+]] = scf.if [[COND]] +// CHECK-NEXT: [[NEW_ALLOC:%.+]] = 
memref.alloc() : memref<4xf32> +// CHECK-NEXT: [[SUBVIEW:%.+]] = memref.subview [[NEW_ALLOC]][0] [[[CURR_SIZE]]] [1] +// CHECK-NEXT: memref.copy [[OLD_ALLOC]], [[SUBVIEW]] +// CHECK-NEXT: memref.dealloc [[OLD_ALLOC]] +// CHECK-NEXT: scf.yield [[NEW_ALLOC]] +// CHECK: [[CAST:%.+]] = memref.reinterpret_cast [[OLD_ALLOC]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf32> +// CHECK-NEXT: scf.yield [[CAST]] +// CHECK: return [[REALLOC]] + +// NODEALLOC-LABEL: func @reallow_lowering_example +// NODEALLOC-NOT: dealloc + +// ----- + +func.func @reallow_lowering_example(%new_size: index) -> memref { + %alloc = memref.alloc() : memref<2xf32> + %realloc = memref.realloc %alloc(%new_size) : memref<2xf32> to memref + return %realloc : memref +} + +// CHECK-LABEL: func @reallow_lowering_example +// CHECK-SAME: ([[NEW_SIZE:%.+]]: index) +// CHECK-NEXT: [[OLD_ALLOC:%.+]] = memref.alloc() : memref<2xf32> +// CHECK-NEXT: [[CURR_SIZE:%.+]] = arith.constant 2 +// CHECK-NEXT: [[COND:%.+]] = arith.cmpi ult, [[CURR_SIZE]], [[NEW_SIZE]] +// CHECK-NEXT: [[REALLOC:%.+]] = scf.if [[COND]] +// CHECK-NEXT: [[NEW_ALLOC:%.+]] = memref.alloc([[NEW_SIZE]]) : memref +// CHECK-NEXT: [[SUBVIEW:%.+]] = memref.subview [[NEW_ALLOC]][0] [2] [1] +// CHECK-NEXT: memref.copy [[OLD_ALLOC]], [[SUBVIEW]] +// CHECK-NEXT: memref.dealloc [[OLD_ALLOC]] +// CHECK-NEXT: scf.yield [[NEW_ALLOC]] +// CHECK: [[CAST:%.+]] = memref.reinterpret_cast [[OLD_ALLOC]] to offset: [0], sizes: [[[NEW_SIZE]]], strides: [1] : memref<2xf32> to memref +// CHECK-NEXT: scf.yield [[CAST]] +// CHECK: return [[REALLOC]] + +// NODEALLOC-LABEL: func @reallow_lowering_example +// NODEALLOC-NOT: dealloc diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-realloc.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt %s -convert-vector-to-scf 
-convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts |\ +// RUN: mlir-opt %s -convert-vector-to-scf -expand-realloc -expand-strided-metadata -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts |\ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_c_runner_utils -// RUN: mlir-opt %s -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm='use-aligned-alloc=1' -convert-func-to-llvm -arith-expand -reconcile-unrealized-casts |\ +// RUN: mlir-opt %s -convert-vector-to-scf -expand-realloc -expand-strided-metadata -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm='use-aligned-alloc=1' -convert-func-to-llvm -arith-expand -reconcile-unrealized-casts |\ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_c_runner_utils | FileCheck %s diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -11612,6 +11612,7 @@ ":NVGPUDialect", ":Pass", ":RuntimeVerifiableOpInterface", + ":SCFDialect", ":Support", ":TensorDialect", ":Transforms",