diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -387,6 +387,35 @@ // clang-format on } +/// Determine which OpOperand* will alias with `result` if the op is bufferized +/// in place. +/// Return None if the defining op of `result` does not have known +/// bufferization aliasing behavior, which indicates that the op must allocate +/// all of its tensor results. +/// TODO: in the future this may need to evolve towards a list of OpOperand*. +static Optional getAliasingOpOperand(OpResult result) { + if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp())) + return None; + return TypeSwitch(result.getDefiningOp()) + .Case([&](LinalgOp op) { + return op.getOutputTensorOperands()[result.getResultNumber()]; + }) + .Case([&](ExtractSliceOp op) { return &op->getOpOperand(0); }) + .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); }) + .Case([&](vector::TransferWriteOp op) { return &op->getOpOperand(1); }) + // In the case of scf::ForOp, this currently assumes the iter_args / yield + // are 1-1. This may fail and is verified at the end. + // TODO: update this. + .Case([&](scf::ForOp op) { + return &op.getIterOpOperands()[result.getResultNumber()]; + }) + .Default([&](Operation *op) { + op->dump(); + llvm_unreachable("unexpected defining op"); + return nullptr; + }); +} + /// Determine which OpResult will alias with `opOperand` if the op is bufferized /// in place. This is a superset of `getInplaceableOpResult`. 
/// Return None if the owner of `opOperand` does not have known @@ -508,7 +537,7 @@ void bufferizeOutOfPlace(OpResult result); /// Return true if it is possible to find an inplace write W among the uses of - /// aliasInfo[rootWrite], and a read R among the uses of aliasInfo[rootRead], + /// aliasInfo[result], and a read R among the uses of aliasInfo[result], /// such that W and R interfere. /// Such a (W, R) pair is an interference to the inplace bufferization of /// rootWrite when: @@ -518,15 +547,9 @@ /// C interleaved between W and R (i.e. W -> C -> R where -> denotes /// dominance). bool - wouldCreateReadAfterWriteInterference(Value rootWrite, Value rootRead, - Operation *opToBufferize, + wouldCreateReadAfterWriteInterference(OpResult result, const DominanceInfo &domInfo) const; - /// Return true if we find any read to opOperand.get() or any of its aliases, - /// that does not dominate opOperand.getOwner(). - bool existsNonDominatingRead(OpOperand &opOperand, - const DominanceInfo &domInfo) const; - /// Return true if `v1` and `v2` bufferize to equivalent buffers. bool areEquivalentBufferizedValues(Value v1, Value v2) const { return equivalentInfo.getLeaderValue(v1) == @@ -608,8 +631,8 @@ /// /// Capture possible cases where `aliasingWriteOp(alias(%rootWrite))` has no /// visible effect on `aliasingReadOp(alias(%rootRead))`. - bool isClobberedWriteBeforeRead(Operation *opToBufferize, Value rootRead, - Value rootWrite, OpOperand &aliasingRead, + bool isClobberedWriteBeforeRead(Operation *opToBufferize, + OpOperand &aliasingRead, OpOperand &aliasingWrite, const DominanceInfo &domInfo) const; @@ -732,59 +755,81 @@ setInPlaceOpResult(result, InPlaceSpec::False); } -/// Return true if merging the alias sets of `rootWrite` and `rootRead` would -/// result in a semantic change in the program (i.e. RAW violation). 
-/// -/// This is the case when one can find an inplace write W among the aliases -/// `rootWrite`, that may become an interference if W were to be bufferized -/// inplace. A potential interference would be with respect to a read R among -/// the aliases of `rootRead`. -/// +static bool isTheSSADef(OpOperand &write, OpOperand &read) { + return getInplaceableOpResult(write) == read.get(); +} + +/// Return true if it is possible to find an inplace write W among the uses of +/// aliasInfo[result], and a read R among the uses of aliasInfo[result], +/// such that W and R interfere. /// Such a (W, R) pair is an interference to the inplace bufferization of -/// rootWrite when R does not properly dominate W (i.e. W may come before R -/// along some control-flow path). +/// `result` when: +/// 1. R is not known to properly dominate W (i.e. the effects of the write may +/// be visible from R). +/// 2. one cannot find an intermediate clobbering write `C` to W, such that +/// C interleaved between W and R (i.e. W -> C -> R where -> denotes +/// dominance). bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference( - Value rootWrite, Value rootRead, Operation *opToBufferize, - const DominanceInfo &domInfo) const { + OpResult result, const DominanceInfo &domInfo) const { + Optional maybeAliasingOperand = getAliasingOpOperand(result); + if (!maybeAliasingOperand) + return false; + + Operation *opToBufferize = result.getDefiningOp(); + Value root = (*maybeAliasingOperand)->get(); LDBG("----Start wouldCreateReadAfterWriteInterference\n"); + LDBG("--------rootValue: " << root << "\n"); - // Collect all the inplace write uses of some alias of `rootWrite`. + // Collect: + // 1. all the inplace write uses of some alias of `root`. + // 2. all the write uses that belong to `opToBufferize`. + // opToBufferize is not yet inplace, we want to determine if it can be inplace + // so we also consider all its write uses, not just the inplace ones. 
DenseSet usesWrite; - auto &aliasListWrite = getAliasInfoRef(rootWrite); - for (Value vWrite : aliasListWrite) { + for (Value vWrite : getAliasInfoRef(root)) { for (auto &uWrite : vWrite.getUses()) { - if (!bufferizesToMemoryWrite(uWrite, InPlaceSpec::True)) + if (!bufferizesToMemoryWrite(uWrite)) continue; - usesWrite.insert(&uWrite); + if (uWrite.getOwner() == opToBufferize || + bufferizesToMemoryWrite(uWrite, InPlaceSpec::True)) + usesWrite.insert(&uWrite); } } - - // Collect all the read uses of some alias of `rootRead`. + for (Value vWrite : getAliasInfoRef(result)) + for (auto &uWrite : vWrite.getUses()) + if (bufferizesToMemoryWrite(uWrite, InPlaceSpec::True)) + usesWrite.insert(&uWrite); + + // Collect all the reads of some alias of `root`. + // opToBufferize is not yet inplace, we want to determine if it can be inplace + // so we also consider all read uses of its result. DenseSet usesRead; - auto &aliasListRead = getAliasInfoRef(rootRead); - for (Value vRead : aliasListRead) { - for (auto &uRead : vRead.getUses()) { - if (!bufferizesToMemoryRead(uRead)) - continue; - usesRead.insert(&uRead); - } - } + auto &aliasListRead = getAliasInfoRef(root); + for (Value vRead : aliasListRead) + for (auto &uRead : vRead.getUses()) + if (bufferizesToMemoryRead(uRead)) + usesRead.insert(&uRead); + for (Value vRead : getAliasInfoRef(result)) + for (auto &uRead : vRead.getUses()) + if (bufferizesToMemoryRead(uRead)) + usesRead.insert(&uRead); for (OpOperand *uRead : usesRead) { Operation *aliasingReadOp = uRead->getOwner(); LDBG("----++++aliasRead #" << uRead->getOperandNumber() << " in: " << *aliasingReadOp << '\n'); for (OpOperand *uWrite : usesWrite) { - // Don't consider self-use of the same operand. - // Uses within the same op is fine though. + // Don't consider self-use of the same operand for interference. + // Multiple different uses within the same op is fair game though. 
if (uWrite == uRead) continue; + Operation *aliasingWriteOp = uWrite->getOwner(); LDBG("---- aliasWrite #" << uWrite->getOperandNumber() << " in: " << *aliasingWriteOp << '\n'); - // If read and written value already alias, no interference would be added - // by bufferizing inplace. - if (getAliasInfoRef(uRead->get()).contains(uWrite->get())) + // If the candidate write is the one that produces the read value (in the + // SSA def-use sense), this is not considered an interference. + if (getInplaceableOpResult(*uWrite) == uRead->get()) continue; // If aliasingReadOp properly dominates aliasingWriteOp, the read cannot // be affected by the write: there is no interference. @@ -797,12 +842,8 @@ << "): " << *aliasingReadOp << '\n'); LDBG(" Interfering write (op #" << uWrite->getOperandNumber() << "): " << *aliasingWriteOp << '\n'); - LDBG(" aliases rootRead: " << rootRead << '\n'); - LDBG(" aliases rootWrite: " << rootWrite << '\n'); LDBG("---->opportunity to clobber RaW interference\n"); - if (isClobberedWriteBeforeRead(opToBufferize, rootRead, rootWrite, *uRead, - *uWrite, domInfo)) { - + if (isClobberedWriteBeforeRead(opToBufferize, *uRead, *uWrite, domInfo)) { LDBG("---->clobbered! -> skip\n"); continue; } @@ -815,35 +856,6 @@ return false; } -/// Return true if we find any read to opOperand.get() or any of its aliases, -/// that does not dominate opOperand.getOwner(). 
-bool BufferizationAliasInfo::existsNonDominatingRead( - OpOperand &opOperand, const DominanceInfo &domInfo) const { - LDBG("----Start existsNonDominatingRead\n"); - Operation *op = opOperand.getOwner(); - for (Value alias : getAliasInfoRef(opOperand.get())) { - for (OpOperand &wantReadUse : alias.getUses()) { - LDBG("--------current operand #" << wantReadUse.getOperandNumber() << ": " - << *(wantReadUse.getOwner()) << '\n'); - if (!bufferizesToMemoryRead(wantReadUse)) { - LDBG("------------not a read -> skip\n"); - continue; - } - if (&wantReadUse == &opOperand) { - LDBG("------------self-read is not an interference -> skip\n"); - continue; - } - if (domInfo.properlyDominates(wantReadUse.getOwner(), op)) { - LDBG("------------read properly dominates -> skip\n"); - continue; - } - LDBG("----found interfering read of " << wantReadUse.get() << '\n'); - return true; - } - } - return false; -} - /// Return true if the source of a `insertSliceOp` bufferizes to an /// equivalent ExtractSliceOp. bool BufferizationAliasInfo::isSourceEquivalentToAMatchingExtractSliceOp( @@ -977,20 +989,13 @@ /// 3. Clobbers the write that would be interfering with the read. 
/// bool BufferizationAliasInfo::isClobberedWriteBeforeRead( - Operation *opToBufferize, Value rootRead, Value rootWrite, - OpOperand &aliasingRead, OpOperand &aliasingWrite, + Operation *opToBufferize, OpOperand &aliasingRead, OpOperand &aliasingWrite, const DominanceInfo &domInfo) const { Operation *aliasingReadOp = aliasingRead.getOwner(); Operation *aliasingWriteOp = aliasingWrite.getOwner(); assert(!domInfo.properlyDominates(aliasingReadOp, aliasingWriteOp) && "Unexpected aliasingReadOp properly dominates aliasingWriteOp"); - assert(((rootRead.isa() && - rootRead.getDefiningOp() == opToBufferize) || - (rootWrite.isa() && - rootWrite.getDefiningOp() == opToBufferize)) && - "Expected rootRead or rootWrite to be produced by opToBufferize"); - // Bail if the write does not dominate the read: it may clobber but only on // a strict subset of paths, which is not enough for safety. if (!domInfo.dominates(aliasingWriteOp, aliasingReadOp)) { @@ -1579,14 +1584,9 @@ // an interfering write? OpResult r = extractSliceOp->getResult(0); OpOperand &s = extractSliceOp->getOpOperand(0); - bool foundInterference = wouldCreateAliasingWriteToNonWriteableBuffer || - // Do not consider (s, s) and (r, r) as all the - // aliasings already exist by construction; we are - // interested in new interfering aliases only. - aliasInfo.wouldCreateReadAfterWriteInterference( - s.get(), r, extractSliceOp, domInfo) || - aliasInfo.wouldCreateReadAfterWriteInterference( - r, s.get(), extractSliceOp, domInfo); + bool foundInterference = + wouldCreateAliasingWriteToNonWriteableBuffer || + aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo); if (foundInterference) aliasInfo.bufferizeOutOfPlace(r); else @@ -1616,8 +1616,10 @@ << result << '\n'); // `result` must bufferize to a writeable buffer to be a candidate. - // This means the use->def chain not backpropagate to a function that is - // not inplaceable or to a constant op to be considered. 
+ // This means the operand must not alias either: + // 1. a function bbArg that is not inplaceable or + // 2. a constant op. + // to be considered for inplace bufferization bool wouldCreateAliasingWriteToNonWriteableBuffer = aliasInfo.aliasesNonWriteableBuffer(operand); if (wouldCreateAliasingWriteToNonWriteableBuffer) @@ -1625,15 +1627,10 @@ else LDBG("->bufferizes to writeable inplace buffer\n"); - Value s = operand.get(), r = result; + assert(result == getInplaceableOpResult(operand)); bool foundInterference = wouldCreateAliasingWriteToNonWriteableBuffer || - aliasInfo.existsNonDominatingRead(operand, domInfo) || - // Do not consider (s, s) and (r, r) as all the aliasings already - // exist by construction; we are interested in new interfering aliases - // only. - aliasInfo.wouldCreateReadAfterWriteInterference(s, r, op, domInfo) || - aliasInfo.wouldCreateReadAfterWriteInterference(r, s, op, domInfo); + aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); if (foundInterference) aliasInfo.bufferizeOutOfPlace(result);