diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -612,6 +612,33 @@
           [&](Operation *op) { return getInplaceableOpResult(opOperand); });
 }
 
+// Forward declaration: the scf::ForOp overload below calls back into this.
+static bool bufferizesToMemoryRead(OpOperand &opOperand);
+
+/// Return true if the scf::ForOp operand `opOperand` bufferizes to a memory
+/// read. scf::ForOp alone doesn't bufferize to a memory read, one of the uses
+/// of its matching bbArg may. Uses reached through ExtractSliceOp results are
+/// inspected as well, because those results alias the bbArg.
+static bool bufferizesToMemoryRead(scf::ForOp forOp, OpOperand &opOperand) {
+  SmallVector<OpOperand *> workingSet;
+  for (OpOperand &use : forOp.getRegionIterArgForOpOperand(opOperand).getUses())
+    workingSet.push_back(&use);
+
+  while (!workingSet.empty()) {
+    OpOperand *uMaybeReading = workingSet.pop_back_val();
+    // Look through ExtractSliceOps: they do not read by themselves but just
+    // add a new alias, so the uses of their result are visited too.
+    if (auto extractSliceOp =
+            dyn_cast<ExtractSliceOp>(uMaybeReading->getOwner()))
+      for (OpOperand &use : extractSliceOp.result().getUses())
+        workingSet.push_back(&use);
+    if (bufferizesToMemoryRead(*uMaybeReading))
+      return true;
+  }
+
+  return false;
+}
+
 /// Return true if `opOperand` bufferizes to a memory read.
 static bool bufferizesToMemoryRead(OpOperand &opOperand) {
   // Unknown op that returns a tensor. The inplace analysis does not support
@@ -622,15 +649,8 @@
   // may.
   if (isa<ExtractSliceOp>(opOperand.getOwner()))
     return false;
-  // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
-  // matching bbArg may.
-  if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner())) {
-    for (OpOperand &use :
-         forOp.getRegionIterArgForOpOperand(opOperand).getUses())
-      if (bufferizesToMemoryRead(use))
-        return true;
-    return false;
-  }
+  if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner()))
+    return bufferizesToMemoryRead(forOp, opOperand);
   // TiledLoop alone doesn't bufferize to a memory read, one of the uses of its
   // matching bbArg may.
   if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(opOperand.getOwner())) {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -912,3 +912,104 @@
 
   return %15 : tensor<62x90xf32>
 }
+
+// -----
+
+#accesses = [
+  affine_map<(i) -> (i)>
+]
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+// CHECK-LABEL: func @reading_scf_for
+func @reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                      %s: index, %v: vector<5xf32>) -> (tensor<?xf32>, vector<5xf32>) {
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+
+  // Write to %t1.
+  // CHECK:      vector.transfer_write
+  // CHECK-SAME: __inplace_results_attr__ = ["false"]
+  %t3 = vector.transfer_write %v, %t1[%s] : vector<5xf32>, tensor<?xf32>
+
+  // Read the old value of %t1 inside the loop via an alias.
+  // CHECK: scf.for
+  %r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
+    // CHECK:      tensor.extract_slice
+    // CHECK-SAME: __inplace_results_attr__ = ["true"]
+    %e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
+
+    // Read from %t1 via alias %e.
+    %v2 = vector.transfer_read %e[%s], %cst : tensor<?xf32>, vector<5xf32>
+    scf.yield %e, %v2 : tensor<?xf32>, vector<5xf32>
+  }
+  // CHECK: __inplace_results_attr__ = ["true", "none"]
+
+  // Use %t3 in some way without reading it, so that it does not get DCE'd.
+  // CHECK:      linalg.generic
+  // CHECK-SAME: __inplace_results_attr__ = ["true"]
+  %o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
+      ^bb(%0: f32) :
+        linalg.yield %cst : f32
+    } -> (tensor<?xf32>)
+
+  return %o, %v3 : tensor<?xf32>, vector<5xf32>
+}
+
+// -----
+
+#accesses = [
+  affine_map<(i) -> (i)>
+]
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"]
+}
+
+// CHECK-LABEL: func @non_reading_scf_for
+func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                          %s: index, %v: vector<5xf32>) -> (tensor<?xf32>, vector<5xf32>) {
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+
+  // Write to %t1.
+  // CHECK:      vector.transfer_write
+  // CHECK-SAME: __inplace_results_attr__ = ["true"]
+  %t3 = vector.transfer_write %v, %t1[%s] : vector<5xf32>, tensor<?xf32>
+
+  // This loop does not read from %t1. It only writes to it.
+  // CHECK: scf.for
+  %r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
+    // CHECK:      tensor.extract_slice
+    // CHECK-SAME: __inplace_results_attr__ = ["true"]
+    %e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
+
+    // Write to %t1 via alias. (Overwrite %t3.)
+    // CHECK:      linalg.generic
+    // CHECK-SAME: __inplace_results_attr__ = ["true"]
+    %o2 = linalg.generic #trait outs (%e : tensor<?xf32>) {
+        ^bb(%0: f32) :
+          linalg.yield %cst : f32
+      } -> (tensor<?xf32>)
+
+    // Read overwritten value. This is not a read of %t1.
+    %v2 = vector.transfer_read %o2[%s], %cst : tensor<?xf32>, vector<5xf32>
+    scf.yield %o2, %v2 : tensor<?xf32>, vector<5xf32>
+  }
+
+  // Use %t3 in some way without reading it, so that it does not get DCE'd.
+  // CHECK:      linalg.generic
+  // CHECK-SAME: __inplace_results_attr__ = ["true"]
+  %o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
+      ^bb(%0: f32) :
+        linalg.yield %cst : f32
+    } -> (tensor<?xf32>)
+
+  return %o, %v3 : tensor<?xf32>, vector<5xf32>
+}