diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -187,13 +187,23 @@
 
   /// Finds all associated dealloc nodes for the alloc nodes using alias
   /// information.
-  DeallocSetT findAssociatedDeallocs(AllocOp alloc) const {
+  DeallocSetT findAssociatedDeallocs(OpResult allocResult) const {
     DeallocSetT result;
-    auto possibleValues = aliases.resolve(alloc);
+    auto possibleValues = aliases.resolve(allocResult);
     for (Value alias : possibleValues)
-      for (Operation *user : alias.getUsers()) {
-        if (isa<DeallocOp>(user))
-          result.insert(user);
+      for (Operation *op : alias.getUsers()) {
+        // Check for an existing memory effect interface.
+        auto effectInstance = dyn_cast<MemoryEffectOpInterface>(op);
+        if (!effectInstance)
+          continue;
+        // Check whether the associated value will be freed using the current
+        // operation.
+        SmallVector<MemoryEffects::EffectInstance, 2> effects;
+        effectInstance.getEffectsOnValue(alias, effects);
+        if (llvm::any_of(effects, [](MemoryEffects::EffectInstance &it) {
+              return isa<MemoryEffects::Free>(it.getEffect());
+            }))
+          result.insert(op);
       }
     return result;
   }
@@ -328,8 +338,6 @@
 
 /// The actual buffer placement pass that moves alloc and dealloc nodes into
 /// the right positions. It uses the algorithm described at the top of the file.
-// TODO: create a templated version that allows to match dialect-specific
-// alloc/dealloc nodes and to insert dialect-specific dealloc node.
 struct BufferPlacementPass
     : mlir::PassWrapper<BufferPlacementPass, FunctionPass> {
   void runOnFunction() override {
@@ -337,42 +345,60 @@
     auto &analysis = getAnalysis<BufferPlacementAnalysis>();
 
     // Compute an initial placement of all nodes.
-    llvm::SmallDenseMap<Value, BufferPlacementPositions, 16> placements;
-    getFunction().walk([&](AllocOp alloc) {
-      placements[alloc] = analysis.computeAllocAndDeallocPositions(
-          alloc.getOperation()->getResult(0));
-      return WalkResult::advance();
+    llvm::SmallVector<std::pair<OpResult, BufferPlacementPositions>, 16>
+        placements;
+    getFunction().walk([&](MemoryEffectOpInterface op) {
+      // Try to find a single allocation result.
+      SmallVector<MemoryEffects::EffectInstance, 2> effects;
+      op.getEffects(effects);
+
+      SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
+      llvm::copy_if(effects, std::back_inserter(allocateResultEffects),
+                    [](MemoryEffects::EffectInstance &it) {
+                      Value value = it.getValue();
+                      return isa<MemoryEffects::Allocate>(it.getEffect()) &&
+                             value && value.isa<OpResult>();
+                    });
+      // If there is one result only, we will be able to move the allocation and
+      // (possibly existing) deallocation ops.
+      if (allocateResultEffects.size() == 1) {
+        // Insert allocation result.
+        auto allocResult = allocateResultEffects[0].getValue().cast<OpResult>();
+        placements.emplace_back(
+            allocResult, analysis.computeAllocAndDeallocPositions(allocResult));
+      }
     });
 
-    // Move alloc (and dealloc - if any) nodes into the right places
-    // and insert dealloc nodes if necessary.
-    getFunction().walk([&](AllocOp alloc) {
+    // Move alloc (and dealloc - if any) nodes into the right places and insert
+    // dealloc nodes if necessary.
+    for (auto &entry : placements) {
       // Find already associated dealloc nodes.
+      OpResult alloc = entry.first;
       auto deallocs = analysis.findAssociatedDeallocs(alloc);
       if (deallocs.size() > 1) {
         emitError(alloc.getLoc(),
-                  "Not supported number of associated dealloc operations");
-        return WalkResult::interrupt();
+                  "not supported number of associated dealloc operations");
+        // A diagnostic alone does not fail the pass; mark the pass as failed
+        // explicitly so the pass manager reports the failure.
+        return signalPassFailure();
       }
 
       // Move alloc node to the right place.
-      BufferPlacementPositions &positions = placements[alloc];
-      Operation *allocOperation = alloc.getOperation();
+      BufferPlacementPositions &positions = entry.second;
+      Operation *allocOperation = alloc.getOwner();
       allocOperation->moveBefore(positions.getAllocPosition());
 
       // If there is an existing dealloc, move it to the right place.
+      Operation *nextOp = positions.getDeallocPosition()->getNextNode();
+      assert(nextOp && "Invalid Dealloc operation position");
       if (deallocs.size()) {
-        Operation *nextOp = positions.getDeallocPosition()->getNextNode();
-        assert(nextOp && "Invalid Dealloc operation position");
         (*deallocs.begin())->moveBefore(nextOp);
       } else {
         // If there is no dealloc node, insert one in the right place.
-        OpBuilder builder(alloc);
-        builder.setInsertionPointAfter(positions.getDeallocPosition());
+        OpBuilder builder(nextOp);
         builder.create<DeallocOp>(allocOperation->getLoc(), alloc);
       }
-      return WalkResult::advance();
-    });
+    }
   };
 };
 