diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -272,6 +272,7 @@ // After iterating `backwardSlice` we obtain: // indexEdges = [%i, %j, %ubi, %ubj] // backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k] + SetVector operationsToRemove; for (Operation *op : llvm::reverse(backwardSlice)) { // Add the index operands of `padTensorOp` and `sliceOp` to start the // exploration of the index computation. @@ -308,11 +309,12 @@ } continue; } - // Remove all other operation not used by the index computation except for - // constant operations that may be padding values used by `padTensorOp`. + // Remove all other operations not used by the index computation. An + // exception are constant operations that may be used by `padTensorOp`. if (!isa(op)) - backwardSlice.remove(op); + operationsToRemove.insert(op); } + backwardSlice.set_subtract(operationsToRemove); return success(); }