diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -944,16 +944,26 @@
 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                 BufferizationAliasInfo &aliasInfo,
                                 AnalysisState &state) {
-  for (Operation *op : ops)
-    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
-      for (OpResult opResult : op->getOpResults())
-        if (opResult.getType().isa<TensorType>())
-          for (OpOperand *opOperand :
-               bufferizableOp.getAliasingOpOperands(opResult, state))
-            if (state.isInPlace(*opOperand))
-              if (bufferizableOp.bufferRelation(opResult, state) ==
-                  BufferRelation::Equivalent)
-                aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
+  for (Operation *op : ops) {
+    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
+      for (OpOperand &opOperand : op->getOpOperands()) {
+        if (opOperand.get().getType().isa<TensorType>()) {
+          if (!state.isInPlace(opOperand))
+            // Out-of-place OpOperands bufferize to new allocations and do not
+            // union equivalence sets.
+            continue;
+          AliasingOpResultList aliases = state.getAliasingOpResults(opOperand);
+          for (OpResult alias : aliases) {
+            // Aliases are results of `op` itself, so no need to re-cast their
+            // defining op to BufferizableOpInterface.
+            if (bufferizableOp.bufferRelation(alias, state) ==
+                BufferRelation::Equivalent)
+              aliasInfo.unionEquivalenceClasses(alias, opOperand.get());
+          }
+        }
+      }
+    }
+  }
 }
 
 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
@@ -965,7 +975,7 @@
   SmallVector<Operation *> ops;
   op->walk([&](Operation *op) {
     // No tensors => no buffers.
-    if (none_of(op->getResultTypes(), isaTensor))
+    if (none_of(op->getOperandTypes(), isaTensor))
       return;
     ops.push_back(op);
   });