diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -394,6 +394,13 @@ bool changed = true; while (changed) { changed = false; + // First move loop invariant ops outside of their loop. This needs to be + // done before as we cannot move ops without interrupting the function walk. + func.walk([&](LoopLikeOpInterface loopLike) { + if (failed(moveLoopInvariantCode(loopLike))) + llvm_unreachable( + "Unexpected failure to move invariant code out of loop"); + }); func.walk([&](vector::TransferReadOp transferRead) { if (!transferRead.getShapedType().isa()) @@ -407,11 +414,6 @@ if (!loop) return WalkResult::advance(); - if (failed(moveLoopInvariantCode( - cast(loop.getOperation())))) - llvm_unreachable( - "Unexpected failure to move invariant code out of loop"); - LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() << "\n");