[mlir][Linalg] Ignore uses outside the loop when hoisting vector transfer pairs

hoistRedundantVectorTransfers inspects every use of the transfer_read's
source memref before hoisting a transfer_read/transfer_write pair out of
`loop`. The uses to inspect were selected with
`dom.properlyDominates(loop, use.getOwner())`. That predicate also holds for
ops that come after `loop`, so a use placed after the loop was treated like
a use inside it.

Select the uses with `loop->isAncestor(use.getOwner())` instead, so that only
uses nested inside `loop` are inspected. The hoisted transfer_write is
emitted right after the loop, ahead of any such later use. The test adds
"unrelated_use" ops after the inner and outer loops and checks that the
pairs are still hoisted, with each use following the hoisted
transfer_write.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -457,7 +457,8 @@
       if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
         return WalkResult::advance();
       for (auto &use : transferRead.source().getUses()) {
-        if (!dom.properlyDominates(loop, use.getOwner()))
+        // Only uses nested inside `loop` need checking; uses after it do not.
+        if (!loop->isAncestor(use.getOwner()))
           continue;
         if (use.getOwner() == transferRead.getOperation() ||
             use.getOwner() == transferWrite.getOperation())
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -39,9 +39,11 @@
 // CHECK:     scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
 // CHECK:   }
 // CHECK:   vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
+// CHECK:   "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
 // CHECK:   scf.yield {{.*}} : vector<1xf32>
 // CHECK: }
 // CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
+// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
   scf.for %i = %lb to %ub step %step {
     scf.for %j = %lb to %ub step %step {
       %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
@@ -66,7 +68,9 @@
       vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
       "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
     }
+    "unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
   }
+  "unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
   return
 }
