diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -340,6 +340,19 @@
   return r;
 }
 
+/// Return `true` if the given tensor value is a memory write. Most values are
+/// tensor writes, but ops that define a tensor SSA value without specifying
+/// its contents (e.g., init_tensor) are not.
+static bool isMemoryWrite(Value value, const AnalysisState &state) {
+  auto opResult = value.dyn_cast<OpResult>();
+  if (!opResult)
+    return true;
+  auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value);
+  if (!bufferizableOp)
+    return true;
+  return bufferizableOp.isMemoryWrite(opResult, state);
+}
+
 /// Annotate IR with details about the detected RaW conflict.
 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
                              Value lastWrite) {
@@ -386,10 +399,11 @@
     AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
   const BufferizationOptions &options = state.getOptions();
 
-  // Gather all written aliases.
+  // Gather all written aliases. Skip over aliases that are not actual writes.
   SmallVector<Value> writtenAliases;
   for (OpOperand *uWrite : usesWrite)
-    writtenAliases.push_back(uWrite->get());
+    if (isMemoryWrite(uWrite->get(), state))
+      writtenAliases.push_back(uWrite->get());
   // Find the inner-most enclosing repetitive region of each alias. If this is
   // the same region for every alias, save it in `repetitiveRegionOfWrites`.
   Optional<Region *> repetitiveRegionOfWrites =
@@ -451,9 +465,14 @@
       // Note: iter_args of loops are not aliases of their respective block
      // arguments, so op dominance can be used when analyzing ops that operate
       // on them.
+      //
+      // Note: If `writtenAliases` is empty, there are no memory writes outside
+      // of the repetitive region of conflictingWritingOp, which means that all
+      // relevant aliases are inside the same repetitive region.
       bool canUseOpDominance =
+          writtenAliases.empty() ||
           repetitiveRegionOfWrites ==
-          getEnclosingRepetitiveRegion(conflictingWritingOp);
+              getEnclosingRepetitiveRegion(conflictingWritingOp);
 
       // No conflict if the readingOp dominates conflictingWritingOp, i.e., the
       // write is not visible when reading.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1243,3 +1243,62 @@
 
   return %r0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @write_to_same_init_tensor_in_place(
+func.func @write_to_same_init_tensor_in_place(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %lb : index, %ub : index, %step : index, %sz: index, %sz2: index)
+  -> (tensor<?xf32>)
+{
+  %B = linalg.init_tensor [%sz2] : tensor<?xf32>
+
+  // CHECK: scf.for {{.*}} {
+  %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+    %i2 = arith.index_cast %i : index to i32
+    %i3 = arith.sitofp %i2 : i32 to f32
+    // %B is written multiple times inside a loop, but it is an init_tensor.
+    // CHECK: tensor.insert
+    // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "none"]}
+    %B2 = tensor.insert %i3 into %B[%i] : tensor<?xf32>
+    // CHECK: tensor.insert_slice
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]}
+    %A2 = tensor.insert_slice %B2 into %t[%i][%sz][1] : tensor<?xf32> into tensor<?xf32>
+    scf.yield %A2 : tensor<?xf32>
+  }
+  // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
+
+  return %r0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @write_to_same_init_tensor_out_of_place(
+func.func @write_to_same_init_tensor_out_of_place(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %lb : index, %ub : index, %step : index, %sz: index, %sz2: index, %f: f32)
+  -> (tensor<?xf32>)
+{
+  %B = linalg.init_tensor [%sz2] : tensor<?xf32>
+  %C = tensor.insert %f into %B[%lb] : tensor<?xf32>
+
+  // CHECK: scf.for {{.*}} {
+  %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+    %i2 = arith.index_cast %i : index to i32
+    %i3 = arith.sitofp %i2 : i32 to f32
+    // %C is written multiple times inside a loop. Even though %C aliases with
+    // an init_tensor, out-of-place bufferization is necessary because there is
+    // another alias (%C) outside of the loop.
+    // CHECK: tensor.insert
+    // CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]}
+    %B2 = tensor.insert %i3 into %C[%i] : tensor<?xf32>
+    // CHECK: tensor.insert_slice
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]}
+    %A2 = tensor.insert_slice %B2 into %t[%i][%sz][1] : tensor<?xf32> into tensor<?xf32>
+    scf.yield %A2 : tensor<?xf32>
+  }
+  // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
+
+  return %r0 : tensor<?xf32>
+}
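
Note (not part of the patch): a minimal MLIR sketch of the distinction the new isMemoryWrite helper draws between ops that merely define a tensor and ops that actually write to it. The function name @memory_write_example and its arguments are hypothetical; linalg.init_tensor and tensor.insert are the ops exercised by the tests above.

func.func @memory_write_example(%sz: index, %idx: index, %f: f32) -> tensor<?xf32> {
  // Not a memory write: init_tensor only defines a tensor SSA value of size
  // %sz without specifying its contents, so writes into %0 are skipped when
  // gathering written aliases.
  %0 = linalg.init_tensor [%sz] : tensor<?xf32>
  // A memory write: %1 has defined contents at %idx. Overwriting an alias of
  // %1 inside a loop can clobber that value and may force an out-of-place
  // bufferization, as in the second test above.
  %1 = tensor.insert %f into %0[%idx] : tensor<?xf32>
  return %1 : tensor<?xf32>
}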