diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -187,13 +187,23 @@
   /// Finds all associated dealloc nodes for the alloc nodes using alias
   /// information.
-  DeallocSetT findAssociatedDeallocs(AllocOp alloc) const {
+  DeallocSetT findAssociatedDeallocs(OpResult allocResult) const {
     DeallocSetT result;
-    auto possibleValues = aliases.resolve(alloc);
+    auto possibleValues = aliases.resolve(allocResult);
     for (Value alias : possibleValues)
-      for (Operation *user : alias.getUsers()) {
-        if (isa<DeallocOp>(user))
-          result.insert(user);
+      for (OpOperand &use : alias.getUses()) {
+        // Only operations that describe their memory effects can free a value.
+        auto effectInstance =
+            dyn_cast<MemoryEffectOpInterface>(use.getOwner());
+        if (!effectInstance)
+          continue;
+        // Record the owner if it frees the aliased value.
+        SmallVector<MemoryEffects::EffectInstance, 2> effects;
+        effectInstance.getEffectsOnValue(alias, effects);
+        if (llvm::any_of(effects, [](MemoryEffects::EffectInstance &it) {
+              return isa<MemoryEffects::Free>(it.getEffect());
+            }))
+          result.insert(use.getOwner());
       }
     return result;
   }
 
@@ -328,8 +338,8 @@
 
 /// The actual buffer placement pass that moves alloc and dealloc nodes into
 /// the right positions. It uses the algorithm described at the top of the file.
-// TODO: create a templated version that allows to match dialect-specific
-// alloc/dealloc nodes and to insert dialect-specific dealloc node.
+// TODO: insert dialect-specific dealloc nodes; missing deallocs are currently
+// always materialized as std DeallocOp.
 struct BufferPlacementPass
     : mlir::PassWrapper<BufferPlacementPass, FunctionPass> {
   void runOnFunction() override {
@@ -337,27 +347,50 @@
     auto &analysis = getAnalysis<BufferPlacementAnalysis>();
 
     // Compute an initial placement of all nodes.
-    llvm::SmallDenseMap<Value, BufferPlacementPositions, 16> placements;
-    getFunction().walk([&](AllocOp alloc) {
-      placements[alloc] = analysis.computeAllocAndDeallocPositions(
-          alloc.getOperation()->getResult(0));
+    llvm::SmallVector<std::pair<OpResult, BufferPlacementPositions>, 16>
+        placements;
+    getFunction().walk([&](MemoryEffectOpInterface op) {
+      // Collect the allocation effects whose value is a result of `op`.
+      SmallVector<MemoryEffects::EffectInstance, 2> effects;
+      op.getEffects(effects);
+      SmallVector<MemoryEffects::EffectInstance, 2> allocateResultEffects;
+      llvm::copy_if(effects, std::back_inserter(allocateResultEffects),
+                    [&](MemoryEffects::EffectInstance &it) {
+                      Value value = it.getValue();
+                      return isa<MemoryEffects::Allocate>(it.getEffect()) &&
+                             value && value.getDefiningOp() == op.getOperation();
+                    });
+      // Only operations that allocate exactly one result are moved. Moving an
+      // operation with several allocated results would require reconciling
+      // several (possibly conflicting) placements for the same operation.
+      if (allocateResultEffects.size() != 1)
+        return WalkResult::advance();
+
+      // Record the allocation result together with its computed placement.
+      OpResult allocResult =
+          allocateResultEffects.front().getValue().cast<OpResult>();
+      placements.emplace_back(
+          allocResult, analysis.computeAllocAndDeallocPositions(allocResult));
       return WalkResult::advance();
     });
 
-    // Move alloc (and dealloc - if any) nodes into the right places
-    // and insert dealloc nodes if necessary.
-    getFunction().walk([&](AllocOp alloc) {
+    // Move alloc (and dealloc - if any) nodes into the right places and insert
+    // dealloc nodes if necessary.
+    for (auto &entry : placements) {
       // Find already associated dealloc nodes.
+      OpResult alloc = entry.first;
       auto deallocs = analysis.findAssociatedDeallocs(alloc);
       if (deallocs.size() > 1) {
         emitError(alloc.getLoc(),
                   "Not supported number of associated dealloc operations");
-        return WalkResult::interrupt();
+        // Emitting a diagnostic alone does not fail the pass.
+        signalPassFailure();
+        return;
       }
 
       // Move alloc node to the right place.
-      BufferPlacementPositions &positions = placements[alloc];
-      Operation *allocOperation = alloc.getOperation();
+      BufferPlacementPositions &positions = entry.second;
+      Operation *allocOperation = alloc.getOwner();
       allocOperation->moveBefore(positions.getAllocPosition());
 
       // If there is an existing dealloc, move it to the right place.
@@ -367,12 +400,11 @@
         (*deallocs.begin())->moveBefore(nextOp);
       } else {
         // If there is no dealloc node, insert one in the right place.
-        OpBuilder builder(alloc);
+        OpBuilder builder(alloc.getOwner());
         builder.setInsertionPointAfter(positions.getDeallocPosition());
         builder.create<DeallocOp>(allocOperation->getLoc(), alloc);
       }
-      return WalkResult::advance();
-    });
+    }
   };
 };
 