diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -8,7 +8,6 @@ #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/SetOperations.h" @@ -58,74 +57,89 @@ this->dependencies[value].insert(dep); }; - // Add additional dependencies created by view changes to the alias list. - op->walk([&](ViewLikeOpInterface viewInterface) { - dependencies[viewInterface.getViewSource()].insert( - viewInterface->getResult(0)); - }); + op->walk([&](Operation *op) { + // TODO: We should have an op interface instead of a hard-coded list of + // interfaces/ops. - // Query all branch interfaces to link block argument dependencies. - op->walk([&](BranchOpInterface branchInterface) { - Block *parentBlock = branchInterface->getBlock(); - for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); - it != e; ++it) { - // Query the branch op interface to get the successor operands. - auto successorOperands = - branchInterface.getSuccessorOperands(it.getIndex()); - // Build the actual mapping of values to their immediate dependencies. - registerDependencies(successorOperands.getForwardedOperands(), - (*it)->getArguments().drop_front( - successorOperands.getProducedOperandCount())); + // Add additional dependencies created by view changes to the alias list. + if (auto viewInterface = dyn_cast(op)) { + dependencies[viewInterface.getViewSource()].insert( + viewInterface->getResult(0)); + return WalkResult::advance(); } - }); - // Query the RegionBranchOpInterface to find potential successor regions. 
- op->walk([&](RegionBranchOpInterface regionInterface) { - // Extract all entry regions and wire all initial entry successor inputs. - SmallVector entrySuccessors; - regionInterface.getSuccessorRegions(/*index=*/std::nullopt, - entrySuccessors); - for (RegionSuccessor &entrySuccessor : entrySuccessors) { - // Wire the entry region's successor arguments with the initial - // successor inputs. - assert(entrySuccessor.getSuccessor() && - "Invalid entry region without an attached successor region"); - registerDependencies( - regionInterface.getSuccessorEntryOperands( - entrySuccessor.getSuccessor()->getRegionNumber()), - entrySuccessor.getSuccessorInputs()); + if (auto branchInterface = dyn_cast(op)) { + // Query all branch interfaces to link block argument dependencies. + Block *parentBlock = branchInterface->getBlock(); + for (auto it = parentBlock->succ_begin(), e = parentBlock->succ_end(); + it != e; ++it) { + // Query the branch op interface to get the successor operands. + auto successorOperands = + branchInterface.getSuccessorOperands(it.getIndex()); + // Build the actual mapping of values to their immediate dependencies. + registerDependencies(successorOperands.getForwardedOperands(), + (*it)->getArguments().drop_front( + successorOperands.getProducedOperandCount())); + } + return WalkResult::advance(); } - // Wire flow between regions and from region exits. - for (Region &region : regionInterface->getRegions()) { - // Iterate over all successor region entries that are reachable from the - // current region. - SmallVector successorRegions; - regionInterface.getSuccessorRegions(region.getRegionNumber(), - successorRegions); - for (RegionSuccessor &successorRegion : successorRegions) { - // Determine the current region index (if any).
- std::optional regionIndex; - Region *regionSuccessor = successorRegion.getSuccessor(); - if (regionSuccessor) - regionIndex = regionSuccessor->getRegionNumber(); - // Iterate over all immediate terminator operations and wire the - // successor inputs with the successor operands of each terminator. - for (Block &block : region) { - auto successorOperands = getRegionBranchSuccessorOperands( - block.getTerminator(), regionIndex); - if (successorOperands) { - registerDependencies(*successorOperands, - successorRegion.getSuccessorInputs()); + if (auto regionInterface = dyn_cast(op)) { + // Query the RegionBranchOpInterface to find potential successor regions. + // Extract all entry regions and wire all initial entry successor inputs. + SmallVector entrySuccessors; + regionInterface.getSuccessorRegions(/*index=*/std::nullopt, + entrySuccessors); + for (RegionSuccessor &entrySuccessor : entrySuccessors) { + // Wire the entry region's successor arguments with the initial + // successor inputs. + assert(entrySuccessor.getSuccessor() && + "Invalid entry region without an attached successor region"); + registerDependencies( + regionInterface.getSuccessorEntryOperands( + entrySuccessor.getSuccessor()->getRegionNumber()), + entrySuccessor.getSuccessorInputs()); + } + + // Wire flow between regions and from region exits. + for (Region &region : regionInterface->getRegions()) { + // Iterate over all successor region entries that are reachable from the + // current region. + SmallVector successorRegions; + regionInterface.getSuccessorRegions(region.getRegionNumber(), + successorRegions); + for (RegionSuccessor &successorRegion : successorRegions) { + // Determine the current region index (if any).
+ std::optional regionIndex; + Region *regionSuccessor = successorRegion.getSuccessor(); + if (regionSuccessor) + regionIndex = regionSuccessor->getRegionNumber(); + // Iterate over all immediate terminator operations and wire the + // successor inputs with the successor operands of each terminator. + for (Block &block : region) { + auto successorOperands = getRegionBranchSuccessorOperands( + block.getTerminator(), regionIndex); + if (successorOperands) { + registerDependencies(*successorOperands, + successorRegion.getSuccessorInputs()); + } } } } + + return WalkResult::advance(); } - }); - // TODO: This should be an interface. - op->walk([&](arith::SelectOp selectOp) { - registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()}); - registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()}); + // Unknown op: Assume that all operands alias with all results. + for (Value operand : op->getOperands()) { + if (!isa(operand.getType())) + continue; + for (Value result : op->getResults()) { + if (!isa(result.getType())) + continue; + registerDependencies({operand}, {result}); + } + } + return WalkResult::advance(); }); } diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir @@ -1317,6 +1317,27 @@ // ----- +func.func @f(%arg0: memref) -> memref { + return %arg0 : memref +} + +// CHECK-LABEL: func @function_call +// CHECK: memref.alloc +// CHECK: memref.alloc +// CHECK: call +// CHECK: test.copy +// CHECK: memref.dealloc +// CHECK: memref.dealloc +func.func @function_call() { + %alloc = memref.alloc() : memref + %alloc2 = memref.alloc() : memref + %ret = call @f(%alloc) : (memref) -> memref + test.copy(%ret, %alloc2) : (memref, memref) + return +} + +// ----- + // Memref allocated in `then` region and passed 
back to the parent if op. #set = affine_set<() : (0 >= 0)> // CHECK-LABEL: func @test_affine_if_1