diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -215,6 +215,29 @@ /*defaultImplementation=*/[{ return false; }] + >, + InterfaceMethod< + /*desc=*/[{ + Return `true` if the `uRead` and `uWrite` do not constitute a RaW + conflict. If they are conflicting or if it is unknown whether they are + conflicting, return `false`. This method will never be called with + OpOperands that do not have a tensor type. At least one of the two + given OpOperands belongs to this operation. + + This method can be implemented to specify custom RaW analysis rules. + If this method returns `true` the given OpOperands are not considered + to be conflicting and do not force out-of-place bufferization. (There + may still be other conflicts that do.) + }], + /*retType=*/"bool", + /*methodName=*/"isNotConflicting", + /*args=*/(ins "OpOperand *":$uRead, + "OpOperand *":$uWrite, + "const BufferizationAliasInfo &":$aliasInfo), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }] > ]; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -281,24 +281,6 @@ // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// -/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. -/// equivalent operand / result and same offset/sizes/strides specification). 
-/// -/// This is one particular type of relationship between ops on tensors that -/// reduce to an equivalence on buffers. This should be generalized and -/// exposed as interfaces on the proper types. -static bool -areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo, - ExtractSliceOp st, InsertSliceOp sti) { - if (!st || !sti) - return false; - if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest())) - return false; - if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) - return false; - return true; -} - /// Return true if opOperand has been decided to bufferize in-place. static bool isInplaceMemoryWrite(OpOperand &opOperand, const BufferizationAliasInfo &aliasInfo) { @@ -368,21 +350,6 @@ return foundInplaceWrite; } -/// Return true if `value` is originating from an ExtractSliceOp that matches -/// the given InsertSliceOp. -static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, - Value value, InsertSliceOp insertOp) { - auto condition = [&](Value val) { - if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) - if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp)) - return true; - return false; - }; - - return llvm::all_of(findValueInReverseUseDefChain(value, condition), - condition); -} - /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors /// properly dominates `b` and `b` is not inside `a`. static bool happensBefore(Operation *a, Operation *b, @@ -450,6 +417,21 @@ if (uConflictingWrite == uRead) continue; + // No conflict if the op interface says so. + if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(readingOp)) + if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, + aliasInfo)) + continue; + + if (conflictingWritingOp != readingOp) + if (auto bufferizableOp = + dyn_cast<BufferizableOpInterface>(conflictingWritingOp)) + if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, + aliasInfo)) + continue; + + // Special rules for branches. + // TODO: Use an interface. 
if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp)) continue; @@ -478,73 +460,6 @@ if (getAliasingOpResult(*uConflictingWrite) == lastWrite) continue; - // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If - // uRead is an InsertSliceOp... - if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) { - // As an example, consider the following IR. - // - // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] } - // %1 = linalg.fill %cst, %0 {inplace= [true] } - // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] - // {inplace= [true] } - - // TODO: Use insertSliceOp.getDestOpOperand etc. when available. - if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(), - insertSliceOp)) - // Case 1: The main insight is that InsertSliceOp reads only part of - // the destination tensor. The overwritten area is not read. If - // uConflictingWrite writes into exactly the memory location that is - // being read by uRead, this is not a conflict. - // - // In the above example: - // uRead = OpOperand 1 (%t) of tensor.insert_slice - // uConflictingWrite = OpOperand 1 (%0) of linalg.fill - // - // The read of %t does not conflict with the write of the FillOp - // (same aliases!) because the area that the FillOp operates on is - // exactly the one that is *not* read via %t. - continue; - - if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && - uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp)) - // Case 2: The read of the source tensor and the write to the dest - // tensor via an InsertSliceOp is not a conflict if the read is - // reading exactly that part of an equivalent tensor that the - // InsertSliceOp is writing. 
- // - // In the above example: - // uRead = OpOperand 0 (%1) of tensor.insert_slice - // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice - continue; - } - - // If uConflictingWrite is an InsertSliceOp... - if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) - // As an example, consider the following IR. - // - // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] } - // %1 = linalg.fill %cst, %0 {inplace= [true] } - // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] - // {inplace= [true] } - // %3 = vector.transfer_read %1, %cst - // - // In the above example: - // uRead = OpOperand 0 (%1) of vector.transfer_read - // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice - // lastWrite = %1 - // - // This is not a conflict because the InsertSliceOp overwrites the - // memory segment of %1 with the exact same data. (Effectively, there - // is no memory write here.) - if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - aliasInfo.areEquivalentBufferizedValues(uRead->get(), - insertSliceOp.source()) && - hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(), - insertSliceOp)) - continue; - // All requirements are met. Conflict found! LDBG("CONFLICT CONFIRMED!\n\n"); return true; @@ -2321,6 +2236,24 @@ } }; +/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. +/// equivalent operand / result and same offset/sizes/strides specification). +/// +/// This is one particular type of relationship between ops on tensors that +/// reduce to an equivalence on buffers. This should be generalized and +/// exposed as interfaces on the proper types. 
+static bool +areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo, + ExtractSliceOp st, InsertSliceOp sti) { + if (!st || !sti) + return false; + if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest())) + return false; + if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) + return false; + return true; +} + /// Return true if the source of a `insertSliceOp` bufferizes to an /// equivalent ExtractSliceOp that bufferizes inplace. static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp( @@ -2345,6 +2278,21 @@ return foundOp; } +/// Return true if `value` is originating from an ExtractSliceOp that matches +/// the given InsertSliceOp. +static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, + Value value, InsertSliceOp insertOp) { + auto condition = [&](Value val) { + if (auto extractOp = val.getDefiningOp<ExtractSliceOp>()) + if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp)) + return true; + return false; + }; + + return llvm::all_of(findValueInReverseUseDefChain(value, condition), + condition); +} + struct InsertSliceOpInterface : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface, tensor::InsertSliceOp> { @@ -2371,6 +2319,82 @@ return BufferRelation::Equivalent; } + bool isNotConflicting(Operation *op, OpOperand *uRead, + OpOperand *uConflictingWrite, + const BufferizationAliasInfo &aliasInfo) const { + Operation *readingOp = uRead->getOwner(); + Operation *conflictingWritingOp = uConflictingWrite->getOwner(); + + // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If + // uRead is an InsertSliceOp... + if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) { + // As an example, consider the following IR. + // + // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } + // %1 = linalg.fill %cst, %0 {inplace= [true] } + // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] + // {inplace= [true] } + + // TODO: Use insertSliceOp.getDestOpOperand etc. when available. 
+ if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && + hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(), + insertSliceOp)) + // Case 1: The main insight is that InsertSliceOp reads only part of + // the destination tensor. The overwritten area is not read. If + // uConflictingWrite writes into exactly the memory location that is + // being read by uRead, this is not a conflict. + // + // In the above example: + // uRead = OpOperand 1 (%t) of tensor.insert_slice + // uConflictingWrite = OpOperand 1 (%0) of linalg.fill + // + // The read of %t does not conflict with the write of the FillOp + // (same aliases!) because the area that the FillOp operates on is + // exactly the one that is *not* read via %t. + return true; + + if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && + uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && + hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp)) + // Case 2: The read of the source tensor and the write to the dest + // tensor via an InsertSliceOp is not a conflict if the read is + // reading exactly that part of an equivalent tensor that the + // InsertSliceOp is writing. + // + // In the above example: + // uRead = OpOperand 0 (%1) of tensor.insert_slice + // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice + return true; + } + + // If uConflictingWrite is an InsertSliceOp... + if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) + // As an example, consider the following IR. 
+ // + // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } + // %1 = linalg.fill %cst, %0 {inplace= [true] } + // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] + // {inplace= [true] } + // %3 = vector.transfer_read %1, %cst + // + // In the above example: + // uRead = OpOperand 0 (%1) of vector.transfer_read + // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice + // lastWrite = %1 + // + // This is not a conflict because the InsertSliceOp overwrites the + // memory segment of %1 with the exact same data. (Effectively, there + // is no memory write here.) + if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && + aliasInfo.areEquivalentBufferizedValues(uRead->get(), + insertSliceOp.source()) && + hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(), + insertSliceOp)) + return true; + + return false; + } + LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { // insert_slice ops arise from tiling and bufferizing them out-of-place is