diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td @@ -60,64 +60,19 @@ llvm_unreachable("bufferizesToMemoryWrite not implemented"); }] >, - // TODO: Simplify this interface by removing `bufferizesToAliasOnly` and - // `getInplaceableOpResult`. Instead, always use `getAliasingOpResult`. If - // `getAliasingOpResult` returns a non-null value, we know that an alias - // is created. If `bufferizesToMemoryRead` and `bufferizesToMemoryWrite` - // return `false`, we know that the operands "bufferizes to alias only". - InterfaceMethod< - /*desc=*/[{ - Return `true` if the given OpOperand creates an alias but does neither - read nor write. This implies that `bufferizesToMemoryRead` and - `bufferizesToMemoryWrite` must return `false`. This method will never - be called on OpOperands that do not have a tensor type. - - Examples of such ops are `tensor.extract_slice` and `tensor.cast`. - }], - /*retType=*/"bool", - /*methodName=*/"bufferizesToAliasOnly", - /*args=*/(ins "OpOperand &":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - // For better debugging, run `bufferizesToMemoryWrite`, which fires an - // assertion when called on an op that should not have tensor - // OpOperands. - (void) cast($_op.getOperation()) - .bufferizesToMemoryWrite(opOperand); - // Return `false` by default, as most ops are not "alias only". - return false; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the OpResult that can bufferize in-place with a given - OpOperand. Return a null value if the OpOperand cannot bufferize - in-place. This method will never be called on OpOperands that do not - have a tensor type. - }], - /*retType=*/"OpResult", - /*methodName=*/"getInplaceableOpResult", - /*args=*/(ins "OpOperand &":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - // Does not have to be implemented for ops without tensor OpOperands. - llvm_unreachable("getInplaceableOpResult not implemented"); - }] - >, InterfaceMethod< /*desc=*/[{ Return the OpResult that aliases with a given OpOperand when - bufferized in-place. This is a superset of `getInplaceableOpResult`. - This method will never be called on OpOperands that do not have a - tensor type. + bufferized in-place. This method will never be called on OpOperands + that do not have a tensor type. }], /*retType=*/"OpResult", /*methodName=*/"getAliasingOpResult", /*args=*/(ins "OpOperand &":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return cast($_op.getOperation()) - .getInplaceableOpResult(opOperand); + // Does not have to be implemented for ops without tensor OpOperands. + llvm_unreachable("getAliasingOpResult not implemented"); }] >, InterfaceMethod< diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -427,8 +427,7 @@ } /// Determine which OpResult will alias with `opOperand` if the op is bufferized -/// in place. This is a superset of `getInplaceableOpResult`. Return an empty -/// OpResult if the op is not bufferizable. +/// in place. Return an empty OpResult if the op is not bufferizable. static OpResult getAliasingOpResult(OpOperand &opOperand) { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) @@ -436,21 +435,29 @@ return OpResult(); } -/// Return `true` if the given OpOperand does not bufferize to a memory read or -/// write, but creates an alias when bufferized inplace. Return `false` if the +/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the /// op is not bufferizable. -static bool bufferizesToAliasOnly(OpOperand &opOperand) { +static bool bufferizesToMemoryRead(OpOperand &opOperand) { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToAliasOnly(opOperand); + return bufferizableOp.bufferizesToMemoryRead(opOperand); // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return false. - return false; + // Conservatively return true. + return true; } -// Predeclaration of function. -static bool bufferizesToMemoryRead(OpOperand &opOperand); +/// Return true if `opOperand` bufferizes to a memory write. Return +/// `true` if the op is not bufferizable. +static bool bufferizesToMemoryWrite(OpOperand &opOperand) { + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferizesToMemoryWrite(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return true. + return true; +} /// Return true if the given value is read by an op that bufferizes to a memory /// read. Also takes into account ops that create an alias but do not read by @@ -462,8 +469,9 @@ while (!workingSet.empty()) { OpOperand *uMaybeReading = workingSet.pop_back_val(); - // Skip over all ops that create an alias but do not read. - if (bufferizesToAliasOnly(*uMaybeReading)) + // Skip over all ops that neither read nor write (but create an alias). + if (!bufferizesToMemoryRead(*uMaybeReading) && + !bufferizesToMemoryWrite(*uMaybeReading)) for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses()) workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) @@ -473,30 +481,6 @@ return false; } -/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the -/// op is not bufferizable. -static bool bufferizesToMemoryRead(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToMemoryRead(opOperand); - - // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return true. - return true; -} - -/// Return true if `opOperand` bufferizes to a memory write. Return -/// `true` if the op is not bufferizable. -static bool bufferizesToMemoryWrite(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToMemoryWrite(opOperand); - - // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return true. - return true; -} - /// Return the relationship between the operand and the its corresponding /// OpResult that it may alias with. Return None if the op is not bufferizable. static BufferRelation bufferRelation(OpOperand &opOperand) { @@ -1496,8 +1480,12 @@ return success(); } -/// This analysis function is used for OpOperands that alias with an OpResult -/// but are not inplaceable on it. E.g., ExtractSliceOp. +/// Determine if `operand` can be bufferized in-place with one of the op's +/// results. If so, set InPlaceSpec::True on the result. Otherwise, set +/// InPlaceSpec::False on the result. +/// +/// Even if an op does not read or write, it may still create an alias when +/// bufferized in-place. An example of such ops is tensor.extract_slice. /// /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace: /// @@ -1512,27 +1500,13 @@ /// An analysis is required to ensure inplace bufferization would not result in /// RaW dependence violations. static LogicalResult -bufferizableInPlaceAnalysisAliasOnlyOp(OpOperand &operand, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { - auto bufferizableOp = dyn_cast(operand.getOwner()); - assert(bufferizableOp && "expected op with known bufferization behavior"); - OpResult result = bufferizableOp.getAliasingOpResult(operand); - assert(result && "expected that the OpOperand has an aliasing OpResult"); - return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); -} - -/// Determine if `operand` can be bufferized in-place with one of the op's -/// results. If so, set InPlaceSpec::True on the result. Otherwise, set -/// InPlaceSpec::False on the result. -static LogicalResult bufferizableInPlaceAnalysis(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { auto bufferizableOp = dyn_cast(operand.getOwner()); if (!bufferizableOp) return success(); - if (OpResult result = bufferizableOp.getInplaceableOpResult(operand)) + if (OpResult result = bufferizableOp.getAliasingOpResult(operand)) return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); return success(); } @@ -1553,21 +1527,12 @@ } // Walk ops in reverse for better interference analysis. - for (Operation *op : reverse(ops)) { + for (Operation *op : reverse(ops)) for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) { + if (opOperand.get().getType().isa()) if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) return failure(); - // Special logic to analyze OpOperands that are not inplaceable on an - // OpResult but may create an alias. - if (bufferizesToAliasOnly(opOperand)) - if (failed(bufferizableInPlaceAnalysisAliasOnlyOp( - opOperand, aliasInfo, domInfo))) - return failure(); - } - } - return success(); } @@ -2226,8 +2191,8 @@ aliasInfo.createAliasInfoEntry(extractOp.result()); // Run analysis on the ExtractSliceOp. - if (failed(bufferizableInPlaceAnalysisAliasOnlyOp( - extractOp->getOpOperand(0), aliasInfo, domInfo))) + if (failed(bufferizableInPlaceAnalysis(extractOp->getOpOperand(0), + aliasInfo, domInfo))) return WalkResult::interrupt(); // Advance to the next operation. @@ -2424,7 +2389,7 @@ // matching OpOperands. for (OpOperand *opOperand : op.getOutputOperands()) { OpResult opResult = cast(op.getOperation()) - .getInplaceableOpResult(*opOperand); + .getAliasingOpResult(*opOperand); assert(opResult && "could not find correspond OpResult"); bool skipCopy = !op.payloadUsesValueFromOperand(opOperand); Value resultBuffer = @@ -2508,7 +2473,7 @@ return {genericOp.getOutputTensorOperands()[opResult.getResultNumber()]}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { auto genericOp = cast(op); if (!opOperand.get().getType().isa()) return OpResult(); @@ -2594,7 +2559,7 @@ opResult.getResultNumber())}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { auto tiledLoopOp = cast(op); return tiledLoopOp.getTiedOpResult(opOperand); } @@ -2739,7 +2704,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -2836,7 +2801,7 @@ return {&forOp.getIterOpOperands()[opResult.getResultNumber()]}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { auto forOp = cast(op); if (!opOperand.get().getType().isa()) return OpResult(); @@ -2897,7 +2862,7 @@ return isa(op->getParentOp()); } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -2985,7 +2950,7 @@ return {}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { // CallOpInterface is special, it needs to wait for the callee to be // bufferized and needs to inspect the BufferAliasInfo object. It can't // make a proper determination by itself and needs to be conservative. @@ -3016,7 +2981,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -3063,19 +3028,11 @@ return false; } - bool bufferizesToAliasOnly(Operation *op, OpOperand &opOperand) const { - return true; - } - SmallVector getAliasingOpOperand(Operation *op, OpResult opResult) const { return {&op->getOpOperand(0)}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { - return OpResult(); - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return op->getResult(0); } @@ -3131,7 +3088,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -3166,19 +3123,11 @@ return false; } - bool bufferizesToAliasOnly(Operation *op, OpOperand &opOperand) const { - return true; - } - SmallVector getAliasingOpOperand(Operation *op, OpResult opResult) const { return {&op->getOpOperand(0) /*source*/}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { - return OpResult(); - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return &opOperand == &op->getOpOperand(0) /*source*/ ? op->getResult(0) @@ -3256,7 +3205,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -3294,7 +3243,7 @@ return {&op->getOpOperand(1) /*dest*/}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return &opOperand == &op->getOpOperand(1) /*dest*/ ? op->getResult(0) : OpResult(); @@ -3386,7 +3335,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -3431,7 +3380,7 @@ return {&op->getOpOperand(1)}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return op->getOpResult(0);