diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -245,10 +245,6 @@ /// themselves (e.g., ExtractSliceOp). bool isValueRead(Value value); -/// Return the relationship between the operand and the its corresponding -/// OpResult that it may alias with. Return None if the op is not bufferizable. -BufferRelation bufferRelation(OpOperand &opOperand); - /// Starting from `value`, follow the use-def chain in reverse, always selecting /// the aliasing OpOperands. Find and return Values for which `condition` /// evaluates to true. OpOperands of such matching Values are not traversed any @@ -430,7 +426,8 @@ return OpResult(); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { return BufferRelation::None; } diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -153,19 +153,23 @@ >, InterfaceMethod< /*desc=*/[{ - Return the buffer relation between the given OpOperand and its - aliasing OpResult when bufferized in-place. Most OpOperands have an - "equivalence" relation. + Return the buffer relation between the given OpResult and its aliasing + OpOperands when bufferized in-place. Most OpResults have an + "equivalence" relation. This method will never be called on OpResults + that do not have a tensor type.
It will also never be called on + OpResults that do not have at least one aliasing OpOperand. TODO: Support other relations such as "OpOperand is included in OpResult". }], /*retType=*/"BufferRelation", /*methodName=*/"bufferRelation", - /*args=*/(ins "OpOperand &":$opOperand), + /*args=*/(ins "OpResult":$opResult, + "const BufferizationAliasInfo &":$aliasInfo), /*methodBody=*/"", /*defaultImplementation=*/[{ - // Does not have to be implemented for ops without tensor OpOperands. + // Does not have to be implemented for ops without tensor OpResults + // that have an aliasing OpOperand. llvm_unreachable("bufferRelation not implemented"); }] >, diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h @@ -9,6 +9,8 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCF_INTERFACE_IMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCF_INTERFACE_IMPL_H +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" + namespace mlir { class DialectRegistry; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -167,8 +167,6 @@ markInPlace(result); aliasInfo.unionSets(result, operand.get()); - if (bufferRelation(operand) == BufferRelation::Equivalent) - equivalentInfo.unionSets(result, operand.get()); } /// Set the inPlace bufferization spec to false. @@ -303,19 +301,6 @@ return false; } -/// Return the relationship between the operand and the its corresponding -/// OpResult that it may alias with.
Return None if the op is not bufferizable. -BufferRelation -mlir::linalg::comprehensive_bufferize::bufferRelation(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast<BufferizableOpInterface>(opOperand.getOwner())) - return bufferizableOp.bufferRelation(opOperand); - - // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return None. - return BufferRelation::None; -} - // Starting from `value`, follow the use-def chain in reverse, always selecting // the aliasing OpOperands. Find and return Values for which `condition` // evaluates to true. OpOperands of such matching Values are not traversed any diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -634,6 +634,40 @@ return inPlaceAnalysis(ops, aliasInfo, domInfo, analysisFuzzerSeed); } +/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. +static void equivalenceAnalysis(SmallVector<Operation *> &ops, + BufferizationAliasInfo &aliasInfo) { + for (Operation *op : ops) + if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) + for (OpResult opResult : op->getOpResults()) + if (opResult.getType().isa<TensorType>()) + if (aliasInfo.isInPlace(opResult)) { + SmallVector<OpOperand *> opOperands = + bufferizableOp.getAliasingOpOperand(opResult); + if (!opOperands.empty()) + if (bufferizableOp.bufferRelation(opResult, aliasInfo) == + BufferRelation::Equivalent) + for (OpOperand *opOperand : opOperands) + aliasInfo.unionEquivalenceClasses(opResult, opOperand->get()); + } +} + +/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained +/// in `op`. +static void equivalenceAnalysis(Operation *op, + BufferizationAliasInfo &aliasInfo) { + // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
+ SmallVector<Operation *> ops; + op->walk([&](Operation *nestedOp) { + // No tensors => no buffers. + if (none_of(nestedOp->getResultTypes(), isaTensor)) + return; + ops.push_back(nestedOp); + }); + + equivalenceAnalysis(ops, aliasInfo); +} + /// Assert that the current bufferization decisions are consistent. static LogicalResult checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo, @@ -700,6 +734,7 @@ if (failed( inPlaceAnalysis(op, aliasInfo, domInfo, options.analysisFuzzerSeed))) return failure(); + equivalenceAnalysis(op, aliasInfo); for (const std::unique_ptr<PostAnalysisStep> &step : options.postAnalysisSteps) { @@ -709,6 +744,7 @@ // Analyze ops that were created by the PostAnalysisStep. if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo))) return failure(); + equivalenceAnalysis(newOps, aliasInfo); } // Annotate operations if we only want to report the analysis. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -132,7 +132,8 @@ return genericOp->getResult(outputOperandIndex - numOutputBuffers); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { return BufferRelation::Equivalent; } @@ -205,7 +206,8 @@ return tiledLoopOp.getTiedOpResult(opOperand); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { return BufferRelation::Equivalent; } @@ -432,8 +434,9 @@ // TODO: Support cases such as extract_slice(init_tensor).
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult); - if (!llvm::all_of(opOperands, [](OpOperand *operand) { - return bufferRelation(*operand) == BufferRelation::Equivalent; + if (!llvm::all_of(opOperands, [&](OpOperand *operand) { + return aliasInfo.areEquivalentBufferizedValues(operand->get(), + opResult); })) return true; return false; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -460,10 +460,6 @@ return OpResult(); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { - return BufferRelation::Equivalent; - } - /// In a first approximation, all the function arguments of a FuncOp are /// marked inplaceable. For now, it is the responsibility of the `callOp` /// bufferization to allow FuncOp that are inplaceable to write inPlace. diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -68,6 +68,11 @@ "scf.execute_region with tensor result not supported"); return comprehensive_bufferize::bufferize(&executeRegionOp.region(), state); } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { + return BufferRelation::Equivalent; + } }; struct IfOpInterface @@ -149,6 +154,19 @@ return success(); } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { + // IfOp results are equivalent to their corresponding yield values if both + // yield values are equivalent to each other.
+ auto bufferizableOp = cast<BufferizableOpInterface>(op); + SmallVector<OpOperand *> yieldValues = + bufferizableOp.getAliasingOpOperand(opResult); + assert(yieldValues.size() == 2 && "expected 2 yield values"); + bool equivalentYields = aliasInfo.areEquivalentBufferizedValues( + yieldValues[0]->get(), yieldValues[1]->get()); + return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None; + } }; struct ForOpInterface @@ -181,8 +199,17 @@ return forOp.getResultForOpOperand(opOperand); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { - return BufferRelation::Equivalent; + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { + // ForOp results are equivalent to their corresponding init_args if the + // corresponding iter_args and yield values are equivalent. + auto forOp = cast<scf::ForOp>(op); + OpOperand &forOperand = forOp.getOpOperandForResult(opResult); + auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); + auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back()); + bool equivalentYield = aliasInfo.areEquivalentBufferizedValues( + bbArg, yieldOp->getOperand(opResult.getResultNumber())); + return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None; } bool isWritable(Operation *op, Value value) const { @@ -237,10 +264,8 @@ OpOperand &forOperand = forOp.getOpOperandForResult( forOp->getResult(operand.getOperandNumber())); auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - Value yieldedBuffer = state.lookupBuffer(operand.get()); - Value bbArgBuffer = state.lookupBuffer(bbArg); - if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, - bbArgBuffer)) { + if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(), + bbArg)) { // TODO: this could get resolved with copies but it can also turn into // swaps so we need to be careful about order of copies.
return yieldOp->emitError() @@ -272,10 +297,6 @@ return OpResult(); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { - return BufferRelation::Equivalent; - } - LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { auto yieldOp = cast<scf::YieldOp>(op); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -57,7 +57,8 @@ return op->getResult(0); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { return BufferRelation::Equivalent; } @@ -148,7 +149,8 @@ : OpResult(); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { return BufferRelation::None; } @@ -268,7 +270,8 @@ return success(); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { return BufferRelation::Equivalent; } }; @@ -345,7 +348,8 @@ : OpResult(); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { return BufferRelation::Equivalent; } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++
b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -79,7 +79,8 @@ return op->getOpResult(0); } - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationAliasInfo &aliasInfo) const { return BufferRelation::Equivalent; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -71,8 +71,7 @@ // Enable InitTensorOp elimination. options.addPostAnalysisStep< linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); - // TODO: Find a way to enable this step automatically when bufferizing - // tensor dialect ops. + // TODO: Find a way to enable these steps automatically. options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>(); options.allowReturnMemref = this->allowReturnMemref;