diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -330,6 +330,11 @@ /// the op is not bufferizable. bool bufferizesToMemoryWrite(OpOperand &opOperand) const; + /// Return true if the given `value` bufferizes to a memory write. Return + /// true if the value is a block argument. Return `true` if the defining op is + /// not bufferizable. Otherwise, consult the BufferizableOpInterface. + bool bufferizesToMemoryWrite(Value value) const; + /// Return true if `opOperand` does neither read nor write but bufferizes to /// an alias. Return false if the op is not bufferizable. bool bufferizesToAliasOnly(OpOperand &opOperand) const; @@ -349,7 +354,8 @@ /// traversed any further. /// /// When reaching the end of a chain (BlockArgument or Value without aliasing - /// OpOperands), also return the last Value of that chain. + /// OpOperands), also return the last Value of that chain if + /// `alwaysIncludeLeaves` is set. /// /// Example: /// @@ -368,10 +374,9 @@ /// { 2, 7, 8, 5 } /// /// If `followEquivalentOnly` is set, only equivalent OpOperands are selected. - SetVector - findValueInReverseUseDefChain(Value value, - llvm::function_ref condition, - bool followEquivalentOnly = false) const; + SetVector findValueInReverseUseDefChain( + Value value, llvm::function_ref condition, + bool followEquivalentOnly = false, bool alwaysIncludeLeaves = true) const; /// Find the Values of the last preceding write of a given Value. /// @@ -530,6 +535,12 @@ defaultGetBufferType(Value value, const BufferizationOptions &options, const DenseMap &fixedTypes); +/// This is the default implementation of +/// BufferizableOpInterface::resultBufferizesToMemoryWrite. Should not be called +/// from other places. 
+bool defaultResultBufferizesToMemoryWrite(OpResult opResult, + const AnalysisState &state); + /// This is the default implementation of /// BufferizableOpInterface::isRepetitiveRegion. Should not be called from other /// places. diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -92,35 +92,37 @@ }] >, InterfaceMethod< - /*desc=*/[{ - Return `true` if the given OpResult is a memory write. This is the - case if in the following cases: + /*desc=*/[{ + Return `true` if the given OpResult bufferizes to a memory write. + This is the same property as `bufferizesToMemoryWrite`, but from the + perspective of OpResults. + + This method will never be called on OpResults that do not have a + tensor type. - * The corresponding aliasing OpOperand bufferizes to a memory write. - * Or: There is no corresponding aliasing OpOperand. + This method has a default implementation. By default, it returns + `true` if: - If the OpResult has multiple aliasing OpOperands, this method - returns `true` if at least one of them bufferizes to a memory write. + * There is no corresponding aliasing OpOperand. + * Or: At least one aliasing OpOperand bufferizes to a memory write. + * Or: At least one aliasing OpOperand's value is defined inside the + defining op of the given OpResult and it is a memory write. + + Note: According to the third rule, an aliasing OpOperand value that is + defined inside this op and is bufferizing to a memory write makes the + given OpResult bufferize to a memory write. 
}], - /*retType=*/"bool", - /*methodName=*/"isMemoryWrite", - /*args=*/(ins "::mlir::OpResult":$opResult, - "const ::mlir::bufferization::AnalysisState &":$state), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto bufferizableOp = - cast($_op.getOperation()); - SmallVector opOperands = - bufferizableOp.getAliasingOpOperand(opResult, state); - if (opOperands.empty()) - return true; - return llvm::any_of( - opOperands, - [&](OpOperand *operand) { - return bufferizableOp.bufferizesToMemoryWrite(*operand, - state); - }); - }] + /*retType=*/"bool", + /*methodName=*/"resultBufferizesToMemoryWrite", + /*args=*/(ins "::mlir::OpResult":$opResult, + "const ::mlir::bufferization::AnalysisState &":$state), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(opResult.getDefiningOp() == $_op.getOperation() && + "invalid OpResult"); + return bufferization::detail::defaultResultBufferizesToMemoryWrite( + opResult, state); + }] >, InterfaceMethod< /*desc=*/[{ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -89,7 +89,8 @@ LogicalResult bufferize(RewriterBase &rewriter, const BufferizationOptions &options); - bool isMemoryWrite(OpResult opResult, const AnalysisState &state); + bool resultBufferizesToMemoryWrite(OpResult opResult, + const AnalysisState &state); bool bufferizesToAllocation(OpResult opResult) { return true; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -372,6 +372,16 @@ return false; } +bool AnalysisState::bufferizesToMemoryWrite(Value value) const { + auto opResult = value.dyn_cast(); + if 
(!opResult) + return true; + auto bufferizableOp = getOptions().dynCastBufferizableOp(value); + if (!bufferizableOp) + return true; + return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this); +} + /// Return true if the given value is read by an op that bufferizes to a memory /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). @@ -401,7 +411,7 @@ // further. llvm::SetVector AnalysisState::findValueInReverseUseDefChain( Value value, llvm::function_ref condition, - bool followEquivalentOnly) const { + bool followEquivalentOnly, bool alwaysIncludeLeaves) const { llvm::SetVector result, workingSet; workingSet.insert(value); @@ -426,7 +436,8 @@ (followEquivalentOnly && bufferizableOp.bufferRelation(opResult, *this) != BufferRelation::Equivalent)) { - result.insert(value); + if (alwaysIncludeLeaves) + result.insert(value); continue; } @@ -440,15 +451,8 @@ // Find the Values of the last preceding write of a given Value. llvm::SetVector AnalysisState::findLastPrecedingWrite(Value value) const { - return findValueInReverseUseDefChain(value, [&](Value value) { - Operation *op = value.getDefiningOp(); - if (!op) - return true; - auto bufferizableOp = options.dynCastBufferizableOp(op); - if (!bufferizableOp) - return true; - return bufferizableOp.isMemoryWrite(value.cast(), *this); - }); + return findValueInReverseUseDefChain( + value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }); } AnalysisState::AnalysisState(const BufferizationOptions &options) @@ -585,6 +589,72 @@ .getResult(); } +bool bufferization::detail::defaultResultBufferizesToMemoryWrite( + OpResult opResult, const AnalysisState &state) { + auto bufferizableOp = cast(opResult.getDefiningOp()); + SmallVector opOperands = + bufferizableOp.getAliasingOpOperand(opResult, state); + + // OpResults that have no aliasing OpOperand usually bufferize memory writes. + // E.g.: tensor.generate ... 
: tensor<10xf32> fills a newly allocated buffer. + // Counter-example: bufferization.alloc_tensor just allocates and does not + // specify the data of the tensor, so resultBufferizesToMemoryWrite is + // overridden to return false. + if (opOperands.empty()) + return true; + + // If an aliasing OpOperand bufferizes to a memory write, the OpResult may + // bufferize to a memory write. + if (llvm::any_of(opOperands, [&](OpOperand *operand) { + return state.bufferizesToMemoryWrite(*operand); + })) + return true; + + // Check if a nested aliasing OpOperand value bufferizes to a memory write. + // In that case, the OpResult bufferizes to a memory write. E.g.: + // + // %0 = "some_writing_op" : tensor + // %r = scf.if ... -> tensor { + // scf.yield %0 : tensor + // } else { + // %1 = "another_writing_op"(%0) : tensor + // scf.yield %1 : tensor + // } + // "some_reading_op"(%r) + // + // %r bufferizes to a memory write because an aliasing OpOperand value (%1) + // bufferizes to a memory write and the defining op is inside the scf.if. + // + // Note: This treatment of surrounding ops is useful for ops that have a + // region but no OpOperand such as scf.if or scf.execute_region. It simplifies + // the analysis considerably. + // + // "another_writing_op" in the above example should be able to bufferize + // inplace in the absence of another read of %0. However, if the scf.if op + // would not be considered a "write", the analysis would detect the + // following conflict: + // + // * read = some_reading_op + // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.) 
+ // * conflictingWrite = %1 + // + auto isMemoryWriteInsideOp = [&](Value v) { + Operation *op = getOwnerOfValue(v); + if (!opResult.getDefiningOp()->isAncestor(op)) + return false; + return state.bufferizesToMemoryWrite(v); + }; + for (OpOperand *operand : opOperands) { + if (!state + .findValueInReverseUseDefChain( + operand->get(), isMemoryWriteInsideOp, + /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false) + .empty()) + return true; + } + return false; +} + FailureOr bufferization::detail::defaultGetBufferType( Value value, const BufferizationOptions &options, const DenseMap &fixedTypes) { diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -205,8 +205,8 @@ return success(); } -bool AllocTensorOp::isMemoryWrite(OpResult opResult, - const AnalysisState &state) { +bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult, + const AnalysisState &state) { // AllocTensorOps do not write unless they have a `copy` value. return static_cast(getCopy()); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -276,10 +276,7 @@ // memory write or not. 
SetVector lastWrites = findLastPrecedingWrite(opResult); bool isUndefined = llvm::none_of(lastWrites, [&](Value lastWrite) { - if (auto bufferizableOp = getOptions().dynCastBufferizableOp(lastWrite)) - return bufferizableOp.isMemoryWrite(lastWrite.cast(), - *this); - return true; + return this->bufferizesToMemoryWrite(lastWrite); }); if (isUndefined) for (OpOperand &use : opResult.getUses()) @@ -380,19 +377,6 @@ return nullptr; } -/// Return `true` if the given tensor value is a memory write. Most values are -/// tensor writes, but ops that define a tensor SSA value without specifying its -/// contents (e.g., alloc_tensor) are not. -static bool isMemoryWrite(Value value, const AnalysisState &state) { - auto opResult = value.dyn_cast(); - if (!opResult) - return true; - auto bufferizableOp = state.getOptions().dynCastBufferizableOp(value); - if (!bufferizableOp) - return true; - return bufferizableOp.isMemoryWrite(opResult, state); -} - /// Return `true` if op dominance can be used to rule out read-after-write /// conflicts wrt. the given reads and writes. /// @@ -495,7 +479,7 @@ // In case of a read, take the region which the read value is defined. for (OpOperand *uRead : usesRead) { // Optimization: Skip reads of values that have no defined contents. - if (!isMemoryWrite(uRead->get(), state)) + if (!state.bufferizesToMemoryWrite(uRead->get())) continue; Region *r = getEnclosingRepetitiveRegion(uRead->get(), options); if (!commonEnclosingRegion.has_value()) { diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -123,18 +123,6 @@ return {&yieldOp->getOpOperand(resultNum)}; } - // TODO: For better bufferization results, this could return `true` only if - // there is a memory write in the region. 
- bool isMemoryWrite(Operation *op, OpResult opResult, - const AnalysisState &state) const { - // Similar to scf.if, results of this op are always considered memory writes - // in the analysis. This is a useful pattern for all ops that have tensor - // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is - // implemented in terms of `bufferizesToMemoryWrite`, which does not work on - // ops without OpOperands. - return true; - } - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto executeRegionOp = cast(op); @@ -190,37 +178,6 @@ &ifOp.elseYield()->getOpOperand(resultNum)}; } - // TODO: For better bufferization results, this could return `true` only if - // there is a memory write in one (or both) of the branches. Since this is not - // allowed at the moment, we should never encounter scf.ifs that yield - // unmodified tensors. Such scf.yield ops could just fold away. - bool isMemoryWrite(Operation *op, OpResult opResult, - const AnalysisState &state) const { - // IfOp results are always considered memory writes in the analysis. This - // design decision simplifies the analysis considerably. E.g., consider the - // following test case: - // - // %0 = "some_writing_op" : tensor - // %r = scf.if %c -> (tensor) { - // scf.yield %0 - // } else { - // %1 = "another_writing_op"(%0) : tensor - // } - // "some_reading_op"(%r) - // - // "another_writing_op" in the above example should be able to bufferize - // inplace in the absence of another read of %0. However, if the scf.if op - // would not be considered a "write", the analysis would detect the - // following conflict: - // - // * read = some_reading_op - // * lastWrite = %0 (Note: The last write of %r would be a set: {%0, %1}.) - // * conflictingWrite = %1 - // - // For more details, check the "scf.IfOp" section of the design document. 
- return true; - } - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { OpBuilder::InsertionGuard g(rewriter); diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -46,18 +46,6 @@ return {&yieldOp->getOpOperand(resultNum)}; } - // TODO: For better bufferization results, this could return `true` only if - // there is a memory write in the region. - bool isMemoryWrite(Operation *op, OpResult opResult, - const AnalysisState &state) const { - // Similar to scf.if, results of this op are always considered memory writes - // in the analysis. This is a useful pattern for all ops that have tensor - // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is - // implemented in terms of `bufferizesToMemoryWrite`, which does not work on - // ops without OpOperands. - return true; - } - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto assumingOp = cast(op); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -83,8 +83,8 @@ struct NewOpInterface : public BufferizableOpInterface::ExternalModel { - bool isMemoryWrite(Operation *op, OpResult opResult, - const AnalysisState &state) const { + bool resultBufferizesToMemoryWrite(Operation *op, OpResult opResult, + const AnalysisState &state) const { // NewOps allocate but do not write. return false; }