diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -179,8 +179,7 @@ /// equivalence classes to support bufferization. class BufferizationAliasInfo { public: - explicit BufferizationAliasInfo(Operation *rootOp, - const BufferizationOptions &options); + explicit BufferizationAliasInfo(Operation *rootOp); // BufferizationAliasInfo should be passed as a reference. BufferizationAliasInfo(const BufferizationAliasInfo &) = delete; @@ -271,68 +270,6 @@ /// Return `true` if the given value is a BlockArgument of a FuncOp. bool isFunctionArgument(Value value); -/// Determine which OpOperand* will alias with `result` if the op is bufferized -/// in place. Return an empty vector if the op is not bufferizable. -SmallVector getAliasingOpOperand(OpResult result); - -/// Determine which OpResult will alias with `opOperand` if the op is bufferized -/// in place. Return an empty OpResult if the op is not bufferizable. -OpResult getAliasingOpResult(OpOperand &opOperand); - -/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the -/// op is not bufferizable. -bool bufferizesToMemoryRead(OpOperand &opOperand); - -/// Return true if `opOperand` bufferizes to a memory write. Return -/// `true` if the op is not bufferizable. -bool bufferizesToMemoryWrite(OpOperand &opOperand); - -/// Return true if `opOperand` does neither read nor write but bufferizes to an -/// alias. Return false if the op is not bufferizable. -bool bufferizesToAliasOnly(OpOperand &opOperand); - -/// Return true if the given value is read by an op that bufferizes to a memory -/// read. Also takes into account ops that create an alias but do not read by -/// themselves (e.g., ExtractSliceOp). 
-bool isValueRead(Value value); - -/// Starting from `value`, follow the use-def chain in reverse, always selecting -/// the aliasing OpOperands. Find and return Values for which `condition` -/// evaluates to true. OpOperands of such matching Values are not traversed any -/// further. -/// -/// When reaching the end of a chain (BlockArgument or Value without aliasing -/// OpOperands), also return the last Value of that chain. -/// -/// Example: -/// -/// 8 -/// | -/// 6* 7* +-----+----+ -/// | | | | -/// 2* 3 4* 5 -/// | | | | -/// +----------+----------+----------+ -/// | -/// 1 -/// -/// In the above example, Values with a star satisfy the condition. When -/// starting the traversal from Value 1, the resulting SetVector is: -/// { 2, 7, 8, 5 } -llvm::SetVector -findValueInReverseUseDefChain(Value value, const BufferizationOptions &options, - std::function condition); - -/// Find the Value of the last preceding write of a given Value. -/// -/// Note: Unknown ops are handled conservatively and assumed to be writes. -/// Furthermore, BlockArguments are also assumed to be writes. There is no -/// analysis across block boundaries. -/// -/// Note: When reaching an end of the reverse SSA use-def chain, that value -/// is returned regardless of whether it is a memory write or not. -Value findLastPrecedingWrite(Value value, const BufferizationOptions &options); - /// Dialect-specific bufferization state. Analysis/bufferization information /// that is specific to ops from a certain dialect can be stored in derived /// variants of this struct. @@ -359,12 +296,74 @@ /// * `replaceOp` replaces an op with new values. class BufferizationState { public: - BufferizationState(Operation *op, const BufferizationOptions &options) - : aliasInfo(op, options), options(options), builder(op->getContext()) {} + BufferizationState(Operation *op, const BufferizationOptions &options); // BufferizationState should be passed as a reference. 
BufferizationState(const BufferizationState &) = delete; + /// Determine which OpOperand* will alias with `result` if the op is + /// bufferized in place. Return an empty vector if the op is not bufferizable. + SmallVector getAliasingOpOperand(OpResult result); + + /// Determine which OpResult will alias with `opOperand` if the op is + /// bufferized in place. Return an empty OpResult if the op is not + /// bufferizable. + OpResult getAliasingOpResult(OpOperand &opOperand); + + /// Return true if `opOperand` bufferizes to a memory read. Return `true` if + /// the op is not bufferizable. + bool bufferizesToMemoryRead(OpOperand &opOperand); + + /// Return true if `opOperand` bufferizes to a memory write. Return `true` if + /// the op is not bufferizable. + bool bufferizesToMemoryWrite(OpOperand &opOperand); + + /// Return true if `opOperand` does neither read nor write but bufferizes to + /// an alias. Return false if the op is not bufferizable. + bool bufferizesToAliasOnly(OpOperand &opOperand); + + /// Return true if the given value is read by an op that bufferizes to a + /// memory read. Also takes into account ops that create an alias but do not + /// read by themselves (e.g., ExtractSliceOp). + bool isValueRead(Value value); + + /// Starting from `value`, follow the use-def chain in reverse, always + /// selecting the aliasing OpOperands. Find and return Values for which + /// `condition` evaluates to true. OpOperands of such matching Values are not + /// traversed any further. + /// + /// When reaching the end of a chain (BlockArgument or Value without aliasing + /// OpOperands), also return the last Value of that chain. + /// + /// Example: + /// + /// 8 + /// | + /// 6* 7* +-----+----+ + /// | | | | + /// 2* 3 4* 5 + /// | | | | + /// +----------+----------+----------+ + /// | + /// 1 + /// + /// In the above example, Values with a star satisfy the condition.
When + /// starting the traversal from Value 1, the resulting SetVector is: + /// { 2, 7, 8, 5 } + llvm::SetVector + findValueInReverseUseDefChain(Value value, + std::function condition); + + /// Find the Value of the last preceding write of a given Value. + /// + /// Note: Unknown ops are handled conservatively and assumed to be writes. + /// Furthermore, BlockArguments are also assumed to be writes. There is no + /// analysis across block boundaries. + /// + /// Note: When reaching an end of the reverse SSA use-def chain, that value + /// is returned regardless of whether it is a memory write or not. + Value findLastPrecedingWrite(Value value); + /// Creates a memref allocation. Optional createAlloc(OpBuilder &b, Location loc, MemRefType type, ArrayRef dynShape); @@ -494,25 +493,30 @@ struct AllocationHoistingBarrierOnly : public BufferizableOpInterface::ExternalModel< AllocationHoistingBarrierOnly, OpTy> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { + SmallVector + getAliasingOpOperand(Operation *op, OpResult opResult, + BufferizationState &state) const { return {}; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return OpResult(); } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { return BufferRelation::None; } diff --git 
a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -32,7 +32,8 @@ }], /*retType=*/"bool", /*methodName=*/"bufferizesToMemoryRead", - /*args=*/(ins "OpOperand &":$opOperand), + /*args=*/(ins "OpOperand &":$opOperand, + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. @@ -60,7 +61,8 @@ }], /*retType=*/"bool", /*methodName=*/"bufferizesToMemoryWrite", - /*args=*/(ins "OpOperand &":$opOperand), + /*args=*/(ins "OpOperand &":$opOperand, + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. @@ -82,19 +84,21 @@ }], /*retType=*/"bool", /*methodName=*/"isMemoryWrite", - /*args=*/(ins "OpResult":$opResult), + /*args=*/(ins "OpResult":$opResult, + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ auto bufferizableOp = cast($_op.getOperation()); SmallVector opOperands = - bufferizableOp.getAliasingOpOperand(opResult); + bufferizableOp.getAliasingOpOperand(opResult, state); if (opOperands.empty()) return true; return llvm::any_of( opOperands, [&](OpOperand *operand) { - return bufferizableOp.bufferizesToMemoryWrite(*operand); + return bufferizableOp.bufferizesToMemoryWrite(*operand, + state); }); }] >, @@ -111,7 +115,8 @@ }], /*retType=*/"bool", /*methodName=*/"mustBufferizeInPlace", - /*args=*/(ins "OpResult":$opResult), + /*args=*/(ins "OpResult":$opResult, + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ return false; @@ -125,7 +130,8 @@ }], /*retType=*/"OpResult", /*methodName=*/"getAliasingOpResult", - /*args=*/(ins 
"OpOperand &":$opOperand), + /*args=*/(ins "OpOperand &":$opOperand, + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. @@ -148,7 +154,8 @@ }], /*retType=*/"SmallVector", /*methodName=*/"getAliasingOpOperand", - /*args=*/(ins "OpResult":$opResult), + /*args=*/(ins "OpResult":$opResult, + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opResult.getType().isa() && @@ -159,7 +166,8 @@ for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) { if (!opOperand.get().getType().isa()) continue; - if (bufferizableOp.getAliasingOpResult(opOperand) == opResult) + if (bufferizableOp.getAliasingOpResult(opOperand, state) == + opResult) result.push_back(&opOperand); } return result; @@ -179,7 +187,8 @@ /*retType=*/"BufferRelation", /*methodName=*/"bufferRelation", /*args=*/(ins "OpResult":$opResult, - "const BufferizationAliasInfo &":$aliasInfo), + "const BufferizationAliasInfo &":$aliasInfo, + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpResults @@ -282,13 +291,14 @@ /// be called on OpOperands that do not have a tensor type. /// /// Examples of such ops are `tensor.extract_slice` and `tensor.cast`. - bool bufferizesToAliasOnly(OpOperand &opOperand) { + bool bufferizesToAliasOnly(OpOperand &opOperand, + BufferizationState &state) { auto bufferizableOp = cast(getOperation()); - return !bufferizableOp.bufferizesToMemoryRead(opOperand) - && !bufferizableOp.bufferizesToMemoryWrite(opOperand) + return !bufferizableOp.bufferizesToMemoryRead(opOperand, state) + && !bufferizableOp.bufferizesToMemoryWrite(opOperand, state) && static_cast( - bufferizableOp.getAliasingOpResult(opOperand)); + bufferizableOp.getAliasingOpResult(opOperand, state)); } // TODO: The following two attributes should belong to the tensor dialect. 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -78,8 +78,7 @@ // BufferizationAliasInfo //===----------------------------------------------------------------------===// -BufferizationAliasInfo::BufferizationAliasInfo( - Operation *rootOp, const BufferizationOptions &options) { +BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { rootOp->walk([&](Operation *op) { for (Value v : op->getResults()) if (v.getType().isa()) @@ -90,26 +89,6 @@ if (bbArg.getType().isa()) createAliasInfoEntry(bbArg); }); - - // Set up alias sets for OpResults that must bufferize in-place. This should - // be done before making any other bufferization decisions. - rootOp->walk([&](BufferizableOpInterface bufferizableOp) { - if (!options.isOpAllowed(bufferizableOp)) - return WalkResult::skip(); - for (OpResult opResult : bufferizableOp->getOpResults()) { - if (opResult.getType().isa()) - if (bufferizableOp.mustBufferizeInPlace(opResult)) { - SmallVector operands = - bufferizableOp.getAliasingOpOperand(opResult); - assert(!operands.empty() && - "expected that OpResult has aliasing OpOperand"); - for (OpOperand *operand : operands) - aliasInfo.unionSets(operand->get(), opResult); - markInPlace(opResult); - } - } - return WalkResult::advance(); - }); } /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the @@ -219,30 +198,32 @@ /// Determine which OpOperand* will alias with `result` if the op is bufferized /// in place. Return an empty vector if the op is not bufferizable. 
SmallVector -mlir::linalg::comprehensive_bufferize::getAliasingOpOperand(OpResult result) { +mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand( + OpResult result) { if (Operation *op = result.getDefiningOp()) if (auto bufferizableOp = dyn_cast(op)) - return bufferizableOp.getAliasingOpOperand(result); + return bufferizableOp.getAliasingOpOperand(result, *this); return {}; } /// Determine which OpResult will alias with `opOperand` if the op is bufferized /// in place. Return an empty OpResult if the op is not bufferizable. -OpResult mlir::linalg::comprehensive_bufferize::getAliasingOpResult( +OpResult +mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult( OpOperand &opOperand) { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) - return bufferizableOp.getAliasingOpResult(opOperand); + return bufferizableOp.getAliasingOpResult(opOperand, *this); return OpResult(); } /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the /// op is not bufferizable. -bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryRead( - OpOperand &opOperand) { +bool mlir::linalg::comprehensive_bufferize::BufferizationState:: + bufferizesToMemoryRead(OpOperand &opOperand) { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToMemoryRead(opOperand); + return bufferizableOp.bufferizesToMemoryRead(opOperand, *this); // Unknown op that returns a tensor. The inplace analysis does not support it. // Conservatively return true. @@ -251,11 +232,11 @@ /// Return true if `opOperand` bufferizes to a memory write. Return /// `true` if the op is not bufferizable. 
-bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryWrite( - OpOperand &opOperand) { +bool mlir::linalg::comprehensive_bufferize::BufferizationState:: + bufferizesToMemoryWrite(OpOperand &opOperand) { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToMemoryWrite(opOperand); + return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this); // Unknown op that returns a tensor. The inplace analysis does not support it. // Conservatively return true. @@ -264,11 +245,11 @@ /// Return true if `opOperand` does neither read nor write but bufferizes to an /// alias. Return false if the op is not bufferizable. -bool mlir::linalg::comprehensive_bufferize::bufferizesToAliasOnly( - OpOperand &opOperand) { +bool mlir::linalg::comprehensive_bufferize::BufferizationState:: + bufferizesToAliasOnly(OpOperand &opOperand) { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToAliasOnly(opOperand); + return bufferizableOp.bufferizesToAliasOnly(opOperand, *this); // Unknown op that returns a tensor. The inplace analysis does not support it. // Conservatively return false. @@ -278,7 +259,8 @@ /// Return true if the given value is read by an op that bufferizes to a memory /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). -bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) { +bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead( + Value value) { SmallVector workingSet; for (OpOperand &use : value.getUses()) workingSet.push_back(&use); @@ -301,9 +283,9 @@ // evaluates to true. OpOperands of such matching Values are not traversed any // further. 
llvm::SetVector -mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain( - Value value, const BufferizationOptions &options, - std::function condition) { +mlir::linalg::comprehensive_bufferize::BufferizationState:: + findValueInReverseUseDefChain(Value value, + std::function condition) { llvm::SetVector result, workingSet; workingSet.insert(value); @@ -329,17 +311,17 @@ } // Find the Value of the last preceding write of a given Value. -Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite( - Value value, const BufferizationOptions &options) { +Value mlir::linalg::comprehensive_bufferize::BufferizationState:: + findLastPrecedingWrite(Value value) { SetVector result = - findValueInReverseUseDefChain(value, options, [&](Value value) { + findValueInReverseUseDefChain(value, [&](Value value) { Operation *op = value.getDefiningOp(); if (!op) return true; auto bufferizableOp = options.dynCastBufferizableOp(op); if (!bufferizableOp) return true; - return bufferizableOp.isMemoryWrite(value.cast()); + return bufferizableOp.isMemoryWrite(value.cast(), *this); }); // To simplify the analysis, `scf.if` ops are considered memory writes. There @@ -350,6 +332,30 @@ return result.front(); } +mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState( + Operation *op, const BufferizationOptions &options) + : aliasInfo(op), options(options), builder(op->getContext()) { + // Set up alias sets for OpResults that must bufferize in-place. This should + // be done before making any other bufferization decisions. 
+ op->walk([&](BufferizableOpInterface bufferizableOp) { + if (!options.isOpAllowed(bufferizableOp)) + return WalkResult::skip(); + for (OpResult opResult : bufferizableOp->getOpResults()) { + if (opResult.getType().isa()) + if (bufferizableOp.mustBufferizeInPlace(opResult, *this)) { + SmallVector operands = + bufferizableOp.getAliasingOpOperand(opResult, *this); + assert(!operands.empty() && + "expected that OpResult has aliasing OpOperand"); + for (OpOperand *operand : operands) + aliasInfo.unionAliasSets(operand->get(), opResult); + aliasInfo.markInPlace(opResult); + } + } + return WalkResult::advance(); + }); +} + /// Return the result buffer (memref) for a given OpResult (tensor). Allocate /// a new buffer and copy over data from the existing buffer if out-of-place /// bufferization is necessary. @@ -394,9 +400,9 @@ // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA // use-def chain, it returns that value, regardless of whether it is a // memory write or not. - Value lastWrite = findLastPrecedingWrite(operand, options); + Value lastWrite = findLastPrecedingWrite(operand); if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) - if (!bufferizableOp.isMemoryWrite(lastWrite.cast())) + if (!bufferizableOp.isMemoryWrite(lastWrite.cast(), *this)) skipCopy = true; // Do not copy if the copied data is never read. 
if (!isValueRead(result)) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp @@ -39,12 +39,14 @@ struct ToMemrefOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { // It is unknown whether the resulting MemRef will be read or not. return true; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return OpResult(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -162,7 +162,8 @@ /// Return true if opOperand has been decided to bufferize in-place. static bool isInplaceMemoryWrite(OpOperand &opOperand, - const BufferizationAliasInfo &aliasInfo) { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) { // The analysis does not know what happens to the result of a ToMemrefOp, so // we assume that it is written to. // TODO: This is a conservative implementation. This rule will have to be @@ -170,11 +171,11 @@ if (isa(opOperand.getOwner())) return true; // OpOperands without an aliasing OpResult do not write. 
- OpResult opResult = getAliasingOpResult(opOperand); + OpResult opResult = state.getAliasingOpResult(opOperand); if (!opResult) return false; // OpOperands that do not bufferize to a memory write do not write in-place. - if (!bufferizesToMemoryWrite(opOperand)) + if (!state.bufferizesToMemoryWrite(opOperand)) return false; // Check current bufferization decisions. return aliasInfo.isInPlace(opResult); @@ -209,11 +210,12 @@ /// Return true if the buffer to which `operand` would bufferize is equivalent /// to some buffer write. static bool aliasesInPlaceWrite(Value value, - const BufferizationAliasInfo &aliasInfo) { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) { bool foundInplaceWrite = false; aliasInfo.applyOnAliases(value, [&](Value v) { for (auto &use : v.getUses()) { - if (isInplaceMemoryWrite(use, aliasInfo)) { + if (isInplaceMemoryWrite(use, aliasInfo, state)) { foundInplaceWrite = true; return; } @@ -295,7 +297,7 @@ // In the above example, if uRead is the OpOperand of reading_op, lastWrite // is %0. Note that operations that create an alias but do not write (such // as ExtractSliceOp) are skipped. - Value lastWrite = findLastPrecedingWrite(uRead->get(), options); + Value lastWrite = state.findLastPrecedingWrite(uRead->get()); // Look for conflicting memory writes. Potential conflicts are writes to an // alias that have been decided to bufferize inplace. @@ -352,7 +354,7 @@ // No conflict if the conflicting write and the last write are the same // use. - if (getAliasingOpResult(*uConflictingWrite) == lastWrite) + if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite) continue; // All requirements are met. Conflict found! 
@@ -402,7 +404,7 @@ bool checkConsistencyOnly = false) { #ifndef NDEBUG if (result) { - SmallVector opOperands = getAliasingOpOperand(result); + SmallVector opOperands = state.getAliasingOpOperand(result); assert(llvm::find(opOperands, &operand) != opOperands.end() && "operand and result do not match"); } else { @@ -416,7 +418,7 @@ aliasInfo.applyOnAliases(root, [&](Value alias) { for (auto &use : alias.getUses()) // Read to a value that aliases root. - if (bufferizesToMemoryRead(use)) + if (state.bufferizesToMemoryRead(use)) res.insert(&use); }); }; @@ -426,7 +428,7 @@ aliasInfo.applyOnAliases(root, [&](Value alias) { for (auto &use : alias.getUses()) // Inplace write to a value that aliases root. - if (isInplaceMemoryWrite(use, aliasInfo)) + if (isInplaceMemoryWrite(use, aliasInfo, state)) res.insert(&use); }); }; @@ -439,7 +441,7 @@ getAliasingInplaceWrites(usesWrite, operand.get()); if (result) getAliasingInplaceWrites(usesWrite, result); - if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand)) + if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state, @@ -453,7 +455,7 @@ const BufferizationAliasInfo &aliasInfo, BufferizationState &state) { #ifndef NDEBUG - SmallVector opOperands = getAliasingOpOperand(opResult); + SmallVector opOperands = state.getAliasingOpOperand(opResult); assert(llvm::find(opOperands, &opOperand) != opOperands.end() && "operand and result do not match"); #endif // NDEBUG @@ -467,9 +469,9 @@ return false; // This is a problem only if the buffer is written to via some alias. 
- bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo) || - aliasesInPlaceWrite(opOperand.get(), aliasInfo) || - bufferizesToMemoryWrite(opOperand); + bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo, state) || + aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) || + state.bufferizesToMemoryWrite(opOperand); if (!hasWrite) return false; @@ -485,7 +487,7 @@ OpOperand &operand, OpResult result, BufferizationAliasInfo &aliasInfo, BufferizationState &state, const DominanceInfo &domInfo) { #ifndef NDEBUG - SmallVector opOperands = getAliasingOpOperand(result); + SmallVector opOperands = state.getAliasingOpOperand(result); assert(llvm::find(opOperands, &operand) != opOperands.end() && "operand and result do not match"); #endif // NDEBUG @@ -539,7 +541,8 @@ for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) - if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand)) + if (OpResult opResult = + bufferizableOp.getAliasingOpResult(opOperand, state)) if (failed(bufferizableInPlaceAnalysisImpl( opOperand, opResult, aliasInfo, state, domInfo))) return failure(); @@ -569,16 +572,16 @@ /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. 
static void equivalenceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, - const BufferizationOptions &options) { + BufferizationState &state) { for (Operation *op : ops) - if (auto bufferizableOp = options.dynCastBufferizableOp(op)) + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) for (OpResult opResult : op->getOpResults()) if (opResult.getType().isa()) if (aliasInfo.isInPlace(opResult)) { SmallVector opOperands = - bufferizableOp.getAliasingOpOperand(opResult); + bufferizableOp.getAliasingOpOperand(opResult, state); if (!opOperands.empty()) - if (bufferizableOp.bufferRelation(opResult, aliasInfo) == + if (bufferizableOp.bufferRelation(opResult, aliasInfo, state) == BufferRelation::Equivalent) for (OpOperand *opOperand : opOperands) aliasInfo.unionEquivalenceClasses(opResult, opOperand->get()); @@ -589,7 +592,7 @@ /// in `op`. static void equivalenceAnalysis(Operation *op, BufferizationAliasInfo &aliasInfo, - const BufferizationOptions &options) { + BufferizationState &state) { // Traverse ops in PostOrder: Nested ops first, then enclosing ops. SmallVector ops; op->walk([&](Operation *op) { @@ -599,7 +602,7 @@ ops.push_back(op); }); - equivalenceAnalysis(ops, aliasInfo, options); + equivalenceAnalysis(ops, aliasInfo, state); } /// Assert that the current bufferization decisions are consistent. 
@@ -613,7 +616,8 @@ if (auto bufferizableOp = options.dynCastBufferizableOp(op)) for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) { - OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand); + OpResult opResult = + bufferizableOp.getAliasingOpResult(opOperand, state); if (wouldCreateReadAfterWriteInterference( opOperand, opResult, domInfo, state, aliasInfo, /*checkConsistencyOnly=*/true)) { @@ -669,7 +673,7 @@ if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo, options.analysisFuzzerSeed))) return failure(); - equivalenceAnalysis(op, aliasInfo, options); + equivalenceAnalysis(op, aliasInfo, state); auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) { for (const std::unique_ptr &step : steps) { @@ -679,7 +683,7 @@ // Analyze ops that were created by the PostAnalysisStep. if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) return failure(); - equivalenceAnalysis(newOps, aliasInfo, options); + equivalenceAnalysis(newOps, aliasInfo, state); } return success(); }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -140,18 +140,22 @@ struct LinalgOpInterface : public BufferizableOpInterface::ExternalModel, OpTy> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { auto genericOp = cast(op); return genericOp.payloadUsesValueFromOperand(&opOperand); } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { auto bufferizableOp = cast(op); - return 
static_cast(bufferizableOp.getAliasingOpResult(opOperand)); + return static_cast( + bufferizableOp.getAliasingOpResult(opOperand, state)); } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { + SmallVector + getAliasingOpOperand(Operation *op, OpResult opResult, + BufferizationState &state) const { auto genericOp = cast(op); DenseMap pairs = computeAliasingPairs(genericOp); for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) @@ -160,14 +164,16 @@ return {}; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { auto genericOp = cast(op); DenseMap pairs = computeAliasingPairs(genericOp); return pairs[&opOperand]; } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { return BufferRelation::Equivalent; } @@ -180,7 +186,8 @@ struct InitTensorOpInterface : public BufferizableOpInterface::ExternalModel { - bool isMemoryWrite(Operation *op, OpResult opResult) const { + bool isMemoryWrite(Operation *op, OpResult opResult, + BufferizationState &state) const { // InitTensorOps allocate but do not write. return false; } @@ -203,27 +210,32 @@ struct TiledLoopOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { // TiledLoop alone doesn't bufferize to a memory read, one of the uses of // its matching bbArg may. 
auto tiledLoopOp = cast(op); - return isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand)); + return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand)); } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { // TiledLoop alone doesn't bufferize to a memory write, one of the uses of // its matching bbArg may. auto bufferizableOp = cast(op); - return static_cast(bufferizableOp.getAliasingOpResult(opOperand)); + return static_cast( + bufferizableOp.getAliasingOpResult(opOperand, state)); } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { auto tiledLoopOp = cast(op); return tiledLoopOp.getTiedOpResult(opOperand); } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { return BufferRelation::Equivalent; } @@ -331,15 +343,18 @@ struct YieldOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return OpResult(); } @@ -391,7 +406,6 @@ std::function rewriteFunc, SmallVector &newOps) { OpBuilder b(op->getContext()); - const BufferizationOptions &options = state.getOptions(); WalkResult status = 
op->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { @@ -400,7 +414,7 @@ continue; SetVector maybeInitTensor = - findValueInReverseUseDefChain(operand.get(), options, [&](Value val) { + state.findValueInReverseUseDefChain(operand.get(), [&](Value val) { // Continue traversal until this function returns true. OpResult opResult = val.dyn_cast(); if (!opResult) @@ -410,7 +424,7 @@ // Only equivalent tensors are supported at the moment. // TODO: Support cases such as extract_slice(init_tensor). SmallVector opOperands = - getAliasingOpOperand(opResult); + state.getAliasingOpOperand(opResult); if (!llvm::all_of(opOperands, [&](OpOperand *operand) { return aliasInfo.areEquivalentBufferizedValues(operand->get(), opResult); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -490,7 +490,8 @@ struct CallOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { // CallOpInterface alone doesn't bufferize to a memory read, one of the uses // of the matching bbArg may. It is the responsibility of the caller to // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be @@ -498,7 +499,8 @@ return true; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { // CallOpInterface is special, it needs to wait for the callee to be // bufferized and needs to inspect the BufferAliasInfo object. It can't // make a proper determination by itself and needs to be conservative. 
@@ -618,15 +620,18 @@ struct ReturnOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return OpResult(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -22,8 +22,9 @@ struct ExecuteRegionOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { + SmallVector + getAliasingOpOperand(Operation *op, OpResult opResult, + BufferizationState &state) const { // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be // any SSA value that is in scope. To allow for use-def chain traversal // through ExecuteRegionOps in the analysis, the corresponding yield value @@ -39,7 +40,8 @@ return {&yieldOp->getOpOperand(resultNum)}; } - bool mustBufferizeInPlace(Operation *op, OpResult opResult) const { + bool mustBufferizeInPlace(Operation *op, OpResult opResult, + BufferizationState &state) const { // ExecuteRegionOp results always bufferize in-place. Since they have no // OpOperands, they are mostly ignored by the analysis once alias sets are // set up. 
@@ -48,7 +50,8 @@ // TODO: For better bufferization results, this could return `true` only if // there is a memory write in the region. - bool isMemoryWrite(Operation *op, OpResult opResult) const { + bool isMemoryWrite(Operation *op, OpResult opResult, + BufferizationState &state) const { // Similar to scf.if, results of this op are always considered memory writes // in the analysis. This is a useful pattern for all ops that have tensor // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is @@ -71,15 +74,17 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { return BufferRelation::Equivalent; } }; struct IfOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { + SmallVector + getAliasingOpOperand(Operation *op, OpResult opResult, + BufferizationState &state) const { // IfOps do not have tensor OpOperands. The yielded value can be any SSA // value that is in scope. To allow for use-def chain traversal through // IfOps in the analysis, both corresponding yield values from the then/else @@ -95,7 +100,8 @@ // there is a memory write in one (or both) of the branches. Since this is not // allowed at the moment, we should never encounter scf.ifs that yield // unmodified tensors. Such scf.yield ops could just fold away. - bool isMemoryWrite(Operation *op, OpResult opResult) const { + bool isMemoryWrite(Operation *op, OpResult opResult, + BufferizationState &state) const { // IfOp results are always considered memory writes in the analysis. This // design decision simplifies the analysis considerably. 
E.g., consider the // following test case: @@ -121,7 +127,8 @@ return true; } - bool mustBufferizeInPlace(Operation *op, OpResult opResult) const { + bool mustBufferizeInPlace(Operation *op, OpResult opResult, + BufferizationState &state) const { // IfOp results always bufferize in-place. Since they have no OpOperands, // they are mostly ignored by the analysis once alias sets are set up. return true; @@ -203,12 +210,13 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { // IfOp results are equivalent to their corresponding yield values if both // yield values are equivalent to each other. auto bufferizableOp = cast(op); SmallVector yieldValues = - bufferizableOp.getAliasingOpOperand(opResult); + bufferizableOp.getAliasingOpOperand(opResult, state); assert(yieldValues.size() == 2 && "expected 2 yield values"); bool equivalentYields = aliasInfo.areEquivalentBufferizedValues( yieldValues[0]->get(), yieldValues[1]->get()); @@ -219,21 +227,24 @@ struct ForOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of // its matching bbArg may. auto forOp = cast(op); - return isValueRead(forOp.getRegionIterArgForOpOperand(opOperand)); + return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand)); } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { // Tensor iter_args of scf::ForOps are always considered as a write. This is // to simplify the analysis. // TODO: Consider doing sth. like isValueWritten. 
return true; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { auto forOp = cast(op); if (!opOperand.get().getType().isa()) return OpResult(); @@ -241,7 +252,8 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { // ForOp results are equivalent to their corresponding init_args if the // corresponding iter_args and yield values are equivalent. auto forOp = cast(op); @@ -410,15 +422,18 @@ struct YieldOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return OpResult(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -40,20 +40,24 @@ struct CastOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand 
&opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return op->getResult(0); } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { return BufferRelation::Equivalent; } @@ -86,15 +90,18 @@ struct DimOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return OpResult(); } @@ -112,22 +119,26 @@ struct ExtractSliceOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return &opOperand == &op->getOpOperand(0) /*source*/ ? 
op->getResult(0) : OpResult(); } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { return BufferRelation::None; } @@ -160,7 +171,7 @@ /// If not inplaceable, copy. if (!inplace) { // Do not copy if the copied data is never read. - if (isValueRead(extractSliceOp.result())) + if (state.isValueRead(extractSliceOp.result())) state.createMemCpy(b, extractSliceOp.getLoc(), subView, alloc); subView = alloc; } @@ -173,15 +184,18 @@ struct ExtractOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return OpResult(); } @@ -198,22 +212,26 @@ struct InsertOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return true; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { assert(&opOperand == &op->getOpOperand(1) /*dest*/ && "expected dest OpOperand"); return 
op->getOpResult(0); } - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { + SmallVector + getAliasingOpOperand(Operation *op, OpResult opResult, + BufferizationState &state) const { return {&op->getOpOperand(1) /*dest*/}; } @@ -229,7 +247,8 @@ } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { return BufferRelation::Equivalent; } }; @@ -272,8 +291,8 @@ /// Return true if `value` is originating from an ExtractSliceOp that matches /// the given InsertSliceOp. static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, - const BufferizationOptions &options, - Value value, InsertSliceOp insertOp) { + BufferizationState &state, Value value, + InsertSliceOp insertOp) { auto condition = [&](Value val) { if (auto extractOp = val.getDefiningOp()) if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp)) @@ -281,29 +300,33 @@ return false; }; - return llvm::all_of(findValueInReverseUseDefChain(value, options, condition), + return llvm::all_of(state.findValueInReverseUseDefChain(value, condition), condition); } struct InsertSliceOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return &opOperand == &op->getOpOperand(1) /*dest*/; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return &opOperand == &op->getOpOperand(1) /*dest*/ ? 
op->getResult(0) : OpResult(); } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { return BufferRelation::Equivalent; } @@ -325,8 +348,8 @@ // TODO: Use insertSliceOp.getDestOpOperand etc. when available. if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), - uConflictingWrite->get(), insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, state, uConflictingWrite->get(), + insertSliceOp)) // Case 1: The main insight is that InsertSliceOp reads only part of // the destination tensor. The overwritten area is not read. If // uConflictingWrite writes into exactly the memory location that is @@ -343,7 +366,7 @@ if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), uRead->get(), + hasMatchingExtractSliceOp(aliasInfo, state, uRead->get(), insertSliceOp)) // Case 2: The read of the source tensor and the write to the dest // tensor via an InsertSliceOp is not a conflict if the read is @@ -377,8 +400,8 @@ if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && aliasInfo.areEquivalentBufferizedValues(uRead->get(), insertSliceOp.source()) && - hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), - insertSliceOp.source(), insertSliceOp)) + hasMatchingExtractSliceOp(aliasInfo, state, insertSliceOp.source(), + insertSliceOp)) return true; return false; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -20,19 +20,22 @@ struct TransferReadOpInterface : public 
BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { return OpResult(); } @@ -56,26 +59,30 @@ struct TransferWriteOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return true; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return op->getOpResult(0); } BufferRelation bufferRelation(Operation *op, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) const { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) const { return BufferRelation::Equivalent; }