diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -90,6 +90,11 @@
 
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   bool areEquivalentBufferizedValues(Value v1, Value v2) const {
+    // Return `false` if we have no information about `v1` or `v2`.
+    if (equivalentInfo.findValue(v1) == equivalentInfo.end() ||
+        equivalentInfo.findValue(v2) == equivalentInfo.end())
+      return false;
+
     return equivalentInfo.getLeaderValue(v1) ==
            equivalentInfo.getLeaderValue(v2);
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1748,7 +1748,7 @@
   Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
                                           resultTypes, newOperands);
   newCallOp->setAttrs(callOp->getAttrs());
-  callOp->erase();
+  // Delete the op at the end of bufferization.
   return success();
 }
 
@@ -2496,6 +2496,8 @@
   if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
 
+  // Cannot erase ops during the traversal. Do that afterwards.
+  SmallVector<Operation *> toErase;
   // Bufferize the function body. `bufferizedOps` keeps track ops that were
   // already bufferized with pre-order traversal.
   DenseSet<Operation *> bufferizedOps;
@@ -2522,6 +2524,13 @@
         failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
                            &globalCreator)))
       return failure();
+
+    // Register post-walk erasure, if necessary.
+    if (isa<CallOpInterface>(op))
+      if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+          llvm::any_of(op->getResultTypes(), isaTensor))
+        toErase.push_back(op);
+
     return success();
   };
   if (funcOp.walk(walkFunc).wasInterrupted())
@@ -2529,6 +2538,11 @@
 
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
+  // Erase in reverse walk order so that users are destroyed before the ops
+  // defining their operands (e.g. a call consuming another call's result).
+  for (Operation *op : llvm::reverse(toErase))
+    op->erase();
+
   return success();
 }
 