diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -57,9 +57,9 @@ void insertNewBufferEquivalence(Value newValue, Value alias); /// Return true if the buffer to which `operand` would bufferize aliases a - /// buffer that is known to not be writeable. This implies that the matching + /// buffer that is known to not be writable. This implies that the matching /// OpResult cannot be bufferized inplace. - bool aliasesNonWriteableBuffer(OpOperand &operand) const; + bool aliasesNonWritableBuffer(OpOperand &operand) const; /// Return true if the buffer to which `operand` would bufferize is equivalent /// to some buffer write. @@ -124,6 +124,12 @@ /// Apply `fun` to all the members of the equivalence class of `v`. void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const; + /// Return true if the value is known to bufferize to writable memory. + bool bufferizesToWritableMemory(Value v) const; + + /// Specify that the value is known to bufferize to writable memory. + void setBufferizesToWritableMemory(Value v); + /// Print to `os`. void printAliases(raw_ostream &os) const; void printEquivalences(raw_ostream &os) const; @@ -210,6 +216,9 @@ OpOperand &aliasingWrite, const DominanceInfo &domInfo) const; + /// Values known to bufferize to writable memory (see setBufferizesToWritableMemory). + llvm::DenseSet<Value> bufferizeToWritableMemory; + /// Auxiliary structure to store all the values a given value aliases with. /// These are the conservative cases that can further decompose into /// "equivalent" buffer relationships.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -726,15 +726,21 @@ } /// Return true if the buffer to which `operand` would bufferize aliases a -/// buffer that is known to not be writeable. This implies that the matching +/// buffer that is known to not be writable. This implies that the matching /// OpResult cannot be bufferized inplace. -bool BufferizationAliasInfo::aliasesNonWriteableBuffer( +bool BufferizationAliasInfo::aliasesNonWritableBuffer( OpOperand &operand) const { - LDBG("----Start aliasesNonWriteableBuffer\n"); + LDBG("----Start aliasesNonWritableBuffer\n"); LDBG("-------for -> #" << operand.getOperandNumber() << ": " << printOperationInfo(operand.getOwner()) << '\n'); for (Value v : getAliases(operand.get())) { LDBG("-----------examine: " << printValueInfo(v) << '\n'); + if (bufferizesToWritableMemory(v)) { + LDBG("-----------Value is known to be writable -> skip: " + << printValueInfo(v) << '\n'); + continue; + } + if (auto bbArg = v.dyn_cast<BlockArgument>()) { if (getInPlace(bbArg) == InPlaceSpec::True) { LDBG("-----------bbArg is writeable -> skip: " << printValueInfo(bbArg) @@ -747,15 +753,25 @@ if (Operation *op = v.getDefiningOp()) { if (isa<ConstantOp>(op) || !hasKnownBufferizationAliasingBehavior(op)) { - LDBG("-----------notWriteable\n"); + LDBG("-----------notWritable\n"); return true; } } } - LDBG("---->operand is writeable\n"); + LDBG("---->operand is writable\n"); return false; } +/// Return true if the value is known to bufferize to writable memory. +bool BufferizationAliasInfo::bufferizesToWritableMemory(Value v) const { + return bufferizeToWritableMemory.count(v) > 0; +} + +/// Specify that the value is known to bufferize to writable memory.
+void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) { + bufferizeToWritableMemory.insert(v); +} + /// Return true if the buffer to which `operand` would bufferize is equivalent /// to some buffer write. bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const { @@ -2184,22 +2200,22 @@ << printOperationInfo(extractSliceOp) << '\n'); // If `extractSliceOp` were to be bufferized inplace, it cannot end up - // aliasing a write into a non-writeable buffer. - bool wouldCreateAliasingWriteToNonWriteableBuffer = + // aliasing a write into a non-writable buffer. + bool wouldCreateAliasingWriteToNonWritableBuffer = aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) && - aliasInfo.aliasesNonWriteableBuffer(extractSliceOp->getOpOperand(0)); + aliasInfo.aliasesNonWritableBuffer(extractSliceOp->getOpOperand(0)); - if (wouldCreateAliasingWriteToNonWriteableBuffer) - LDBG("->the corresponding buffer is not writeable\n"); + if (wouldCreateAliasingWriteToNonWritableBuffer) + LDBG("->the corresponding buffer is not writable\n"); else - LDBG("->bufferizes to writeable inplace buffer\n"); + LDBG("->bufferizes to writable inplace buffer\n"); // In any of extractSliceOp.result's aliases, can we find 2 such that we hit // an interfering write? OpResult r = extractSliceOp->getResult(0); OpOperand &s = extractSliceOp->getOpOperand(0); bool foundInterference = - wouldCreateAliasingWriteToNonWriteableBuffer || + wouldCreateAliasingWriteToNonWritableBuffer || aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo); if (foundInterference) aliasInfo.bufferizeOutOfPlace(r); @@ -2230,21 +2246,21 @@ << operand.getOperandNumber() << " in " << printValueInfo(result) << '\n'); - // `result` must bufferize to a writeable buffer to be a candidate. + // `result` must bufferize to a writable buffer to be a candidate. // This means the operand must not alias either: // 1. a function bbArg that is not inplaceable or // 2. a constant op.
// to be considered for inplace bufferization - bool wouldCreateAliasingWriteToNonWriteableBuffer = - aliasInfo.aliasesNonWriteableBuffer(operand); - if (wouldCreateAliasingWriteToNonWriteableBuffer) - LDBG("->the corresponding buffer is not writeable\n"); + bool wouldCreateAliasingWriteToNonWritableBuffer = + aliasInfo.aliasesNonWritableBuffer(operand); + if (wouldCreateAliasingWriteToNonWritableBuffer) + LDBG("->the corresponding buffer is not writable\n"); else - LDBG("->bufferizes to writeable inplace buffer\n"); + LDBG("->bufferizes to writable inplace buffer\n"); assert(result == getInplaceableOpResult(operand)); bool foundInterference = - wouldCreateAliasingWriteToNonWriteableBuffer || + wouldCreateAliasingWriteToNonWritableBuffer || aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); if (foundInterference) @@ -2312,6 +2328,15 @@ ops.push_back(op); }); + // Set the function arguments marked with inplaceable to be known as + // bufferizing to writable memory. + for (BlockArgument bbArg : funcOp.getArguments()) { + BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>( + bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName); + if (inplaceAttr && inplaceAttr.getValue()) + aliasInfo.setBufferizesToWritableMemory(bbArg); + } + LogicalResult res = inPlaceAnalysis(ops, aliasInfo, domInfo); LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');