diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3892,7 +3892,11 @@
 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
   if (auto tensorToMemref = memref().getDefiningOp<TensorToMemrefOp>())
-    return tensorToMemref.tensor();
+    // Approximate alias analysis by conservatively folding only when there
+    // is no interleaved operation.
+    if (tensorToMemref->getBlock() == this->getOperation()->getBlock() &&
+        tensorToMemref->getNextNode() == this->getOperation())
+      return tensorToMemref.tensor();
   return {};
 }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -368,11 +368,7 @@
   }
   // CHECK-NEXT: %[[R0:.*]] = tensor_load %[[M0]] : memref<128x128xf32>
-
-  // This folds away due to incorrect tensor_load(tensor_to_memref(x)) -> x folding
-  // that does not consider aliasing. As a consequence, cannot check fully atm.
-  // C-HECK-NEXT: %[[R1:.*]] = tensor_load %[[M1]] : memref<128x128xf32>
-  // C-HECK-NEXT: return %[[R0]], %[[R1]] : tensor<128x128xf32>, tensor<128x128xf32>
-  // CHECK-NEXT: return %[[R0]], %{{.*}} : tensor<128x128xf32>, tensor<128x128xf32>
+  // CHECK-NEXT: %[[R1:.*]] = tensor_load %[[M1]] : memref<128x128xf32>
+  // CHECK-NEXT: return %[[R0]], %[[R1]] : tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1 : tensor<128x128xf32>, tensor<128x128xf32>
 }
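
For context, a minimal IR sketch (hypothetical, not part of the patch; function names and values are illustrative only) of the aliasing hazard the new guard avoids, and of the case that still folds. With a write interleaved between tensor_to_memref and tensor_load, folding the load to the original tensor would silently drop the store; when the two ops are adjacent in the same block, no such write can intervene.

// Hypothetical examples illustrating the new same-block / adjacent-op guard.

// Still folds: tensor_to_memref is immediately followed by tensor_load in
// the same block, so no interleaved operation can write through an alias.
func @foldable(%t: tensor<2xf32>) -> tensor<2xf32> {
  %m = tensor_to_memref %t : memref<2xf32>
  %r = tensor_load %m : memref<2xf32>
  return %r : tensor<2xf32>
}

// No longer folds: the interleaved store writes through %m, so replacing
// %r with %t would return the pre-store contents.
func @not_foldable(%t: tensor<2xf32>, %v: f32, %i: index) -> tensor<2xf32> {
  %m = tensor_to_memref %t : memref<2xf32>
  store %v, %m[%i] : memref<2xf32>
  %r = tensor_load %m : memref<2xf32>
  return %r : tensor<2xf32>
}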