diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -109,12 +109,19 @@ /// Set of tensors that are known to bufferize to writable memory. llvm::DenseSet bufferizeToWritableMemory; - /// Auxiliary structure to store all the values a given value aliases with. - /// These are the conservative cases that can further decompose into - /// "equivalent" buffer relationships. + /// Auxiliary structure to store all the values a given value may alias with. + /// Alias information is "may be" conservative: In the presence of branches, a + /// value may alias with one of multiple other values. The concrete aliasing + /// value may not even be known at compile time. All such values are + /// considered to be aliases. llvm::EquivalenceClasses aliasInfo; - /// Auxiliary structure to store all the equivalent buffer classes. + /// Auxiliary structure to store all the equivalent buffer classes. Equivalent + /// buffer information is "must be" conservative: Only if two values are + /// guaranteed to be equivalent at runtime, they are said to be equivalent. + /// It is possible that, in the presence of branches, it cannot be determined + /// statically if two values are equivalent. In that case, the values are + /// considered to be not equivalent. 
llvm::EquivalenceClasses equivalentInfo; }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -545,8 +545,6 @@ ifOp.elseYield().results(), ifOp.results())) { aliasInfo.unionSets(std::get<0>(it), std::get<1>(it)); aliasInfo.unionSets(std::get<0>(it), std::get<2>(it)); - equivalentInfo.unionSets(std::get<0>(it), std::get<1>(it)); - equivalentInfo.unionSets(std::get<0>(it), std::get<2>(it)); } } }); @@ -1319,6 +1317,9 @@ assert(operandBuffer && "operand buffer not found"); // Make sure that all OpOperands are the same buffer. If this is not the case, // we would have to materialize a memref value. + // TODO: Should be checking for "equivalent buffers" instead of using + // operator== here, but equivalent buffers for scf.if yield values are not + // set up yet. if (!llvm::all_of(aliasingOperands, [&](OpOperand *o) { return lookup(bvm, o->get()) == operandBuffer; })) { diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -35,6 +35,22 @@ // ----- +func @scf_if_not_equivalent( + %cond: i1, %t1: tensor {linalg.inplaceable = true}, + %idx: index) -> tensor { + // expected-error @+1 {{result buffer is ambiguous}} + %r = scf.if %cond -> (tensor) { + scf.yield %t1 : tensor + } else { + // This buffer aliases, but is not equivalent. 
+ %t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor to tensor + scf.yield %t2 : tensor + } + return %r : tensor +} + +// ----- + // expected-error @-3 {{expected callgraph to be free of circular dependencies}} func @foo() { diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -887,4 +887,3 @@ } return %r : tensor } -