diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -562,12 +562,19 @@ //===----------------------------------------------------------------------===// OpFoldResult ToTensorOp::fold(FoldAdaptor) { - if (auto toMemref = getMemref().getDefiningOp()) - // Approximate alias analysis by conservatively folding only when no there - // is no interleaved operation. + if (auto toMemref = getMemref().getDefiningOp()) { + // Approximate alias analysis by conservatively folding only when we know + // no user of the memref (a potential mutator) may execute before this op. + bool canResultMemrefMutate = + llvm::any_of(toMemref->getUsers(), [&](Operation *op) { + if (op->getBlock() == this->getOperation()->getBlock()) + return op->isBeforeInBlock(*this); + return true; + }); if (toMemref->getBlock() == this->getOperation()->getBlock() && - toMemref->getNextNode() == this->getOperation()) + !canResultMemrefMutate) return toMemref.getTensor(); + } return {}; } diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -15,6 +15,38 @@ // ----- +// Folding of to_tensor(to_memref(t)) -> t is allowed when the buffer is not mutated before the to_tensor. 
+// CHECK-LABEL: func.func @tensor_load_of_unmutated_buffer( +func.func @tensor_load_of_unmutated_buffer(%arg0: tensor) -> (tensor, tensor) { + %0 = bufferization.to_memref %arg0 : memref + %1 = "tosa.cast"(%arg0) : (tensor) -> tensor + %2 = bufferization.to_tensor %0 : memref + return %1, %2 : tensor, tensor +} + +// CHECK-SAME: %[[arg0:.*]]: tensor) -> (tensor, tensor) { +// CHECK: %[[cast:.*]] = "tosa.cast"(%[[arg0]]) : (tensor) -> tensor +// CHECK: return %[[cast]], %[[arg0]] : tensor, tensor + +// ----- + +// Folding of to_tensor(to_memref(t)) -> t is not allowed since the buffer may be mutated before the to_tensor. +// CHECK-LABEL: func.func @tensor_load_of_mutated_buffer( +func.func @tensor_load_of_mutated_buffer(%arg0: tensor) -> (tensor) { + %0 = bufferization.to_memref %arg0 : memref + %1 = "use"(%0) : (memref) -> memref + %2 = bufferization.to_tensor %0 : memref + return %2 : tensor +} + +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = bufferization.to_memref %[[VAL_0]] : memref +// CHECK: %[[VAL_2:.*]] = "use"(%[[VAL_1]]) : (memref) -> memref +// CHECK: %[[VAL_3:.*]] = bufferization.to_tensor %[[VAL_1]] : memref +// CHECK: return %[[VAL_3]] : tensor + +// ----- + // Basic folding of to_memref(to_tensor(m)) -> m // CHECK-LABEL: func @buffer_cast_of_tensor_load( func.func @buffer_cast_of_tensor_load(%arg0: memref) -> memref {