Patch preamble: free text placed before the first `diff --git` header, which patch tools skip.

Hoisting.cpp: each `user` that is neither memory-effect free nor matched by
the existing `isa` check is now checked with `loop->isAncestor(user)`. When
that returns false, the user is skipped with `continue;` instead of falling
through to `return false;`. Presumably this stops a side-effecting op that
sits outside the loop, such as the trailing dealloc added in the test below,
from preventing the transformation. Confirm against the enclosing function,
whose full body this patch does not show.

hoisting.mlir: @hoist_vector_transfer_read gains
`memref.dealloc %memref0 : memref<32x64xf32>` after the loop. A new CHECK line
expects `memref.dealloc %[[ALLOC]] : memref<32x64xf32>` between the CHECK for
the loop's closing `}` and the CHECK for `return`.

NOTE(review): this patch has been collapsed onto a single line. Unified diffs
are line-oriented, so it is not a valid patch and will not apply as-is. There
is further extraction damage:
  * `isa(user)` has lost its explicit template argument. llvm::isa cannot
    deduce it, so that context line no longer matches any compilable source.
  * The first hunk header, `@@ -71,6 +71,8 @@`, implies two added lines,
    but only one `+` marker survives; `continue;` after the isAncestor
    check should also carry one.
  * All context-line indentation is gone.
Regenerate the patch from the source tree with `git diff` / `git format-patch`
rather than hand-repairing it.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -71,6 +71,8 @@ } if (isMemoryEffectFree(user) || isa(user)) continue; + if (!loop->isAncestor(user)) continue; return false; } return true; diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -696,6 +696,7 @@ // CHECK: "some_use"(%[[D0]], %[[D1]], %[[CAST]]) : (vector<32x64xf32>, vector<32x128xf32>, memref<32x128xf32, // CHECK-SAME: strided<[128, 1], offset: ?>>) -> () // CHECK: } +// CHECK: memref.dealloc %[[ALLOC]] : memref<32x64xf32> // CHECK: return func.func @hoist_vector_transfer_read() { %c0 = arith.constant 0 : index @@ -710,6 +711,7 @@ %3 = vector.transfer_read %memref0[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<32x64xf32>, vector<32x64xf32> "some_use"(%3, %2, %subview2) : (vector<32x64xf32>, vector<32x128xf32>, memref<32x128xf32, strided<[128, 1], offset: ?>>) -> () } + memref.dealloc %memref0 : memref<32x64xf32> return }