diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -542,12 +542,21 @@
     return false;
   // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
   // matching bbArg may.
-  if (isa<scf::ForOp>(opOperand.getOwner()))
+  if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner())) {
+    for (OpOperand &use :
+         forOp.getRegionIterArgForOpOperand(opOperand).getUses())
+      if (bufferizesToMemoryRead(use))
+        return true;
     return false;
+  }
   // TiledLoop alone doesn't bufferize to a memory read, one of the uses of its
   // matching bbArg may.
-  if (isa<TiledLoopOp>(opOperand.getOwner()))
+  if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(opOperand.getOwner())) {
+    for (OpOperand &use : tiledLoopOp.getTiedBlockArgument(opOperand).getUses())
+      if (bufferizesToMemoryRead(use))
+        return true;
     return false;
+  }
   // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
   // of the matching bbArg may. It is the responsibility of the caller to
   // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
@@ -1685,6 +1694,8 @@
       b.create<CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
     }
     BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
+    aliasInfo.createAliasInfoEntry(resultBuffer);
+    aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
     map(bvm, bbArg, resultBuffer);
     map(bvm, opResult, resultBuffer);
   }
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -474,6 +474,61 @@
 
 // -----
 
+func private @some_use(tensor<?xf32>) -> ()
+
+// CHECK-LABEL: func @scf_for_deps
+func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
+                   %B : tensor<?xf32> {linalg.inplaceable = true},
+                   %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // %r0 must be out of place because one use of %t in the subsequent production
+  // of %r1 is read.
+  // CHECK: scf.for
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
+  %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+    scf.yield %t : tensor<?xf32>
+  }
+
+  // %r1 bufferizes inplace fine.
+  // CHECK: scf.for
+  // CHECK-NEXT: call
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["true"]}
+  %r1 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+    call @some_use(%t) : (tensor<?xf32>) -> ()
+    scf.yield %t : tensor<?xf32>
+  }
+
+  // %r2 must be out of place because one use of %t in the subsequent production
+  // of %r3 is read.
+  // CHECK: linalg.tiled_loop
+  // CHECK-NEXT: linalg.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
+  %r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
+      ins()
+      outs(%t = %B: tensor<?xf32>) {
+    linalg.yield %t : tensor<?xf32>
+  }
+
+  // %r3 bufferizes inplace fine.
+  // CHECK: linalg.tiled_loop
+  // CHECK-NEXT: call
+  // CHECK-NEXT: linalg.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["true"]}
+  %r3 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
+      ins()
+      outs(%t = %B: tensor<?xf32>) {
+    call @some_use(%t) : (tensor<?xf32>) -> ()
+    linalg.yield %t : tensor<?xf32>
+  }
+
+  return %r1, %r3: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Cross function boundary cases.
 //===----------------------------------------------------------------------===//