Share the ExtractSliceOp-matching and conflict-analysis helpers between the
InsertSliceOp and ParallelInsertSliceOp bufferization models.

areEquivalentExtractSliceOps, hasMatchingExtractSliceOp and the new
isNotConflictingInsertSliceLikeOp are templated on the insert op type (OpTy).
Both isNotConflicting implementations now forward to the shared template, and
the ParallelInsertSliceOp-specific copies of the helpers are removed.

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -552,29 +552,30 @@
 
 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
 /// equivalent operand / result and same offset/sizes/strides specification).
-///
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and
-/// exposed as interfaces on the proper types.
+template <typename OpTy>
 static bool areEquivalentExtractSliceOps(const AnalysisState &state,
-                                         ExtractSliceOp st, InsertSliceOp sti) {
-  if (!st || !sti)
+                                         ExtractSliceOp extractSliceOp,
+                                         OpTy insertSliceOp) {
+  if (!extractSliceOp || !insertSliceOp)
     return false;
-  if (sti != sti &&
-      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
+  if (extractSliceOp != insertSliceOp &&
+      !state.areEquivalentBufferizedValues(extractSliceOp.getSource(),
+                                           insertSliceOp.getDest()))
     return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+  if (!sameOffsetsSizesAndStrides(extractSliceOp, insertSliceOp,
+                                  isEqualConstantIntOrValue))
     return false;
   return true;
 }
 
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
+template <typename OpTy>
 static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
-                                      InsertSliceOp insertOp) {
+                                      OpTy insertSliceOp) {
   auto condition = [&](Value val) {
-    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+    if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
         return true;
     return false;
   };
@@ -583,6 +584,83 @@
                       condition);
 }
 
+template <typename OpTy>
+static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
+                                              OpOperand *uConflictingWrite,
+                                              const AnalysisState &state) {
+  Operation *readingOp = uRead->getOwner();
+  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+  // uRead is an InsertSliceOp...
+  if (auto insertSliceOp = dyn_cast<OpTy>(readingOp)) {
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+    // %1 = linalg.fill %cst, %0 {inplace= [true] }
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace= [true] }
+
+    // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+    if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+        hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+                                  insertSliceOp))
+      // Case 1: The main insight is that InsertSliceOp reads only part of
+      // the destination tensor. The overwritten area is not read. If
+      // uConflictingWrite writes into exactly the memory location that is
+      // being read by uRead, this is not a conflict.
+      //
+      // In the above example:
+      // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+      //
+      // The read of %t does not conflict with the write of the FillOp
+      // (same aliases!) because the area that the FillOp operates on is
+      // exactly the one that is *not* read via %t.
+      return true;
+
+    if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+        uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+        hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+      // Case 2: The read of the source tensor and the write to the dest
+      // tensor via an InsertSliceOp is not a conflict if the read is
+      // reading exactly that part of an equivalent tensor that the
+      // InsertSliceOp is writing.
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      return true;
+  }
+
+  // If uConflictingWrite is an InsertSliceOp...
+  if (auto insertSliceOp = dyn_cast<OpTy>(conflictingWritingOp))
+    // As an example, consider the following IR.
+    //
+    // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+    // %1 = linalg.fill %cst, %0 {inplace= [true] }
+    // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+    //     {inplace= [true] }
+    // %3 = vector.transfer_read %1, %cst
+    //
+    // In the above example:
+    // uRead             = OpOperand 0 (%1) of vector.transfer_read
+    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+    // lastWrite         = %1
+    //
+    // This is not a conflict because the InsertSliceOp overwrites the
+    // memory segment of %1 with the exact same data. (Effectively, there
+    // is no memory write here.)
+    if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+        state.areEquivalentBufferizedValues(uRead->get(),
+                                            insertSliceOp.getSource()) &&
+        hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
+                                  insertSliceOp))
+      return true;
+
+  return false;
+}
+
 /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
 /// certain circumstances, this op can also be a no-op.
 struct InsertSliceOpInterface
@@ -613,77 +691,8 @@
   bool isNotConflicting(Operation *op, OpOperand *uRead,
                         OpOperand *uConflictingWrite,
                         const AnalysisState &state) const {
-    Operation *readingOp = uRead->getOwner();
-    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-    // uRead is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-
-      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
-                                    insertSliceOp))
-        // Case 1: The main insight is that InsertSliceOp reads only part of
-        // the destination tensor. The overwritten area is not read. If
-        // uConflictingWrite writes into exactly the memory location that is
-        // being read by uRead, this is not a conflict.
-        //
-        // In the above example:
-        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-        //
-        // The read of %t does not conflict with the write of the FillOp
-        // (same aliases!) because the area that the FillOp operates on is
-        // exactly the one that is *not* read via %t.
-        return true;
-
-      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
-        // Case 2: The read of the source tensor and the write to the dest
-        // tensor via an InsertSliceOp is not a conflict if the read is
-        // reading exactly that part of an equivalent tensor that the
-        // InsertSliceOp is writing.
-        //
-        // In the above example:
-        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-        return true;
-    }
-
-    // If uConflictingWrite is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-      // %3 = vector.transfer_read %1, %cst
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of vector.transfer_read
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      // lastWrite         = %1
-      //
-      // This is not a conflict because the InsertSliceOp overwrites the
-      // memory segment of %1 with the exact same data. (Effectively, there
-      // is no memory write here.)
-      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          state.areEquivalentBufferizedValues(uRead->get(),
-                                              insertSliceOp.getSource()) &&
-          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
-                                    insertSliceOp))
-        return true;
-
-    return false;
+    return isNotConflictingInsertSliceLikeOp<InsertSliceOp>(
+        op, uRead, uConflictingWrite, state);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -805,36 +814,6 @@
   }
 };
 
-/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-static bool areEquivalentExtractSliceOps(const AnalysisState &state,
-                                         ExtractSliceOp st,
-                                         ParallelInsertSliceOp sti) {
-  if (!st || !sti)
-    return false;
-  if (st != sti &&
-      !state.areEquivalentBufferizedValues(st.getSource(), sti.getDest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
-                                      ParallelInsertSliceOp insertOp) {
-  auto condition = [&](Value val) {
-    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
-        return true;
-    return false;
-  };
-
-  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
-                      condition);
-}
-
 /// Analysis of ParallelInsertSliceOp.
 struct ParallelInsertSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<
@@ -978,83 +957,11 @@
     return success();
   }
 
-  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
-  // the code.
   bool isNotConflicting(Operation *op, OpOperand *uRead,
                         OpOperand *uConflictingWrite,
                         const AnalysisState &state) const {
-    Operation *readingOp = uRead->getOwner();
-    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
-
-    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-    // uRead is an InsertSliceOp...
-    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-
-      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
-                                    insertSliceOp))
-        // Case 1: The main insight is that InsertSliceOp reads only part of
-        // the destination tensor. The overwritten area is not read. If
-        // uConflictingWrite writes into exactly the memory location that is
-        // being read by uRead, this is not a conflict.
-        //
-        // In the above example:
-        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-        //
-        // The read of %t does not conflict with the write of the FillOp
-        // (same aliases!) because the area that the FillOp operates on is
-        // exactly the one that is *not* read via %t.
-        return true;
-
-      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
-        // Case 2: The read of the source tensor and the write to the dest
-        // tensor via an InsertSliceOp is not a conflict if the read is
-        // reading exactly that part of an equivalent tensor that the
-        // InsertSliceOp is writing.
-        //
-        // In the above example:
-        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-        return true;
-    }
-
-    // If uConflictingWrite is an InsertSliceOp...
-    if (auto insertSliceOp =
-            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
-      // As an example, consider the following IR.
-      //
-      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
-      // %1 = linalg.fill %cst, %0 {inplace= [true] }
-      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-      //     {inplace= [true] }
-      // %3 = vector.transfer_read %1, %cst
-      //
-      // In the above example:
-      // uRead             = OpOperand 0 (%1) of vector.transfer_read
-      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-      // lastWrite         = %1
-      //
-      // This is not a conflict because the InsertSliceOp overwrites the
-      // memory segment of %1 with the exact same data. (Effectively, there
-      // is no memory write here.)
-      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          state.areEquivalentBufferizedValues(uRead->get(),
-                                              insertSliceOp.getSource()) &&
-          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
-                                    insertSliceOp))
-        return true;
-
-    return false;
+    return isNotConflictingInsertSliceLikeOp<ParallelInsertSliceOp>(
+        op, uRead, uConflictingWrite, state);
   }
 };
 