diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -85,17 +85,17 @@ if (failed(tryToFold(op, results, processGeneratedConstants))) return failure(); - // Constant folding succeeded. We will start replacing this op's uses and - // eventually erase this op. Invoke the callback provided by the caller to - // perform any pre-replacement action. - if (preReplaceAction) - preReplaceAction(op); - // Check to see if the operation was just updated in place. if (results.empty()) return success(); - // Otherwise, replace all of the result values and erase the operation. + // Constant folding succeeded. We will start replacing this op's uses and + // erase this op. Invoke the callback provided by the caller to perform any + // pre-replacement action. + if (preReplaceAction) + preReplaceAction(op); + + // Replace all of the result values and erase the operation. for (unsigned i = 0, e = results.size(); i != e; ++i) op->getResult(i).replaceAllUsesWith(results[i]); op->erase();