diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td @@ -60,11 +60,6 @@ llvm_unreachable("bufferizesToMemoryWrite not implemented"); }] >, - // TODO: Simplify this interface by removing `bufferizesToAliasOnly` and - // `getInplaceableOpResult`. Instead, always use `getAliasingOpResult`. If - // `getAliasingOpResult` returns a non-null value, we know that an alias - // is created. If `bufferizesToMemoryRead` and `bufferizesToMemoryWrite` - // return `false`, we know that the operands "bufferizes to alias only". InterfaceMethod< /*desc=*/[{ Return `true` if the given OpOperand creates an alias but does neither @@ -73,51 +68,37 @@ be called on OpOperands that do not have a tensor type. Examples of such ops are `tensor.extract_slice` and `tensor.cast`. + + Note: This method is not meant to be reimplemented. }], /*retType=*/"bool", /*methodName=*/"bufferizesToAliasOnly", /*args=*/(ins "OpOperand &":$opOperand), /*methodBody=*/"", + // TODO: This should be in methodBody instead of defaultImplementation. + // Due to a bug in TableGen codegen, this does not compile. /*defaultImplementation=*/[{ - // For better debugging, run `bufferizesToMemoryWrite`, which fires an - // assertion when called on an op that should not have tensor - // OpOperands. - (void) cast($_op.getOperation()) - .bufferizesToMemoryWrite(opOperand); - // Return `false` by default, as most ops are not "alias only". 
- return false; + auto bufferizableOp = + cast($_op.getOperation()); + return !bufferizableOp.bufferizesToMemoryRead(opOperand) + && !bufferizableOp.bufferizesToMemoryWrite(opOperand) + && static_cast( + bufferizableOp.getAliasingOpResult(opOperand)); }] >, - InterfaceMethod< - /*desc=*/[{ - Return the OpResult that can bufferize in-place with a given - OpOperand. Return a null value if the OpOperand cannot bufferize - in-place. This method will never be called on OpOperands that do not - have a tensor type. - }], - /*retType=*/"OpResult", - /*methodName=*/"getInplaceableOpResult", - /*args=*/(ins "OpOperand &":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - // Does not have to be implemented for ops without tensor OpOperands. - llvm_unreachable("getInplaceableOpResult not implemented"); - }] - >, InterfaceMethod< /*desc=*/[{ Return the OpResult that aliases with a given OpOperand when - bufferized in-place. This is a superset of `getInplaceableOpResult`. - This method will never be called on OpOperands that do not have a - tensor type. + bufferized in-place. This method will never be called on OpOperands + that do not have a tensor type. }], /*retType=*/"OpResult", /*methodName=*/"getAliasingOpResult", /*args=*/(ins "OpOperand &":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return cast($_op.getOperation()) - .getInplaceableOpResult(opOperand); + // Does not have to be implemented for ops without tensor OpOperands. + llvm_unreachable("getAliasingOpResult not implemented"); }] >, InterfaceMethod< diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -427,8 +427,7 @@ } /// Determine which OpResult will alias with `opOperand` if the op is bufferized -/// in place. This is a superset of `getInplaceableOpResult`. 
Return an empty -/// OpResult if the op is not bufferizable. +/// in place. Return an empty OpResult if the op is not bufferizable. static OpResult getAliasingOpResult(OpOperand &opOperand) { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) @@ -436,9 +435,32 @@ return OpResult(); } -/// Return `true` if the given OpOperand does not bufferize to a memory read or -/// write, but creates an alias when bufferized inplace. Return `false` if the +/// Return `true` if `opOperand` bufferizes to a memory read, or if the /// op is not bufferizable. +static bool bufferizesToMemoryRead(OpOperand &opOperand) { + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferizesToMemoryRead(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return true. + return true; +} + +/// Return `true` if `opOperand` bufferizes to a memory write, or if the +/// op is not bufferizable. +static bool bufferizesToMemoryWrite(OpOperand &opOperand) { + if (auto bufferizableOp = + dyn_cast(opOperand.getOwner())) + return bufferizableOp.bufferizesToMemoryWrite(opOperand); + + // Unknown op that returns a tensor. The inplace analysis does not support it. + // Conservatively return true. + return true; +} + +/// Return `true` if `opOperand` neither reads nor writes memory but bufferizes +/// to an alias. Return `false` if the op is not bufferizable. static bool bufferizesToAliasOnly(OpOperand &opOperand) { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) @@ -449,9 +471,6 @@ return false; } -// Predeclaration of function. -static bool bufferizesToMemoryRead(OpOperand &opOperand); - /// Return true if the given value is read by an op that bufferizes to a memory /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp).
@@ -462,7 +481,7 @@ while (!workingSet.empty()) { OpOperand *uMaybeReading = workingSet.pop_back_val(); - // Skip over all ops that create an alias but do not read. + // Skip over all ops that neither read nor write (but create an alias). if (bufferizesToAliasOnly(*uMaybeReading)) for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses()) workingSet.push_back(&use); @@ -473,30 +492,6 @@ return false; } -/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the -/// op is not bufferizable. -static bool bufferizesToMemoryRead(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToMemoryRead(opOperand); - - // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return true. - return true; -} - -/// Return true if `opOperand` bufferizes to a memory write. Return -/// `true` if the op is not bufferizable. -static bool bufferizesToMemoryWrite(OpOperand &opOperand) { - if (auto bufferizableOp = - dyn_cast(opOperand.getOwner())) - return bufferizableOp.bufferizesToMemoryWrite(opOperand); - - // Unknown op that returns a tensor. The inplace analysis does not support it. - // Conservatively return true. - return true; -} - /// Return the relationship between the operand and the its corresponding /// OpResult that it may alias with. Return None if the op is not bufferizable. static BufferRelation bufferRelation(OpOperand &opOperand) { @@ -1493,8 +1488,12 @@ return success(); } -/// This analysis function is used for OpOperands that alias with an OpResult -/// but are not inplaceable on it. E.g., ExtractSliceOp. +/// Determine if `operand` can be bufferized in-place with one of the op's +/// results. If so, set InPlaceSpec::True on the result. Otherwise, set +/// InPlaceSpec::False on the result. +/// +/// Even if an op does not read or write, it may still create an alias when +/// bufferized in-place. 
An example of such ops is tensor.extract_slice. /// /// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace: /// @@ -1509,27 +1508,13 @@ /// An analysis is required to ensure inplace bufferization would not result in /// RaW dependence violations. static LogicalResult -bufferizableInPlaceAnalysisAliasOnlyOp(OpOperand &operand, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { - auto bufferizableOp = dyn_cast(operand.getOwner()); - assert(bufferizableOp && "expected op with known bufferization behavior"); - OpResult result = bufferizableOp.getAliasingOpResult(operand); - assert(result && "expected that the OpOperand has an aliasing OpResult"); - return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); -} - -/// Determine if `operand` can be bufferized in-place with one of the op's -/// results. If so, set InPlaceSpec::True on the result. Otherwise, set -/// InPlaceSpec::False on the result. -static LogicalResult bufferizableInPlaceAnalysis(OpOperand &operand, BufferizationAliasInfo &aliasInfo, const DominanceInfo &domInfo) { auto bufferizableOp = dyn_cast(operand.getOwner()); if (!bufferizableOp) return success(); - if (OpResult result = bufferizableOp.getInplaceableOpResult(operand)) + if (OpResult result = bufferizableOp.getAliasingOpResult(operand)) return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo); return success(); } @@ -1550,21 +1535,12 @@ } // Walk ops in reverse for better interference analysis. - for (Operation *op : reverse(ops)) { + for (Operation *op : reverse(ops)) for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) { + if (opOperand.get().getType().isa()) if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo))) return failure(); - // Special logic to analyze OpOperands that are not inplaceable on an - // OpResult but may create an alias. 
- if (bufferizesToAliasOnly(opOperand)) - if (failed(bufferizableInPlaceAnalysisAliasOnlyOp( - opOperand, aliasInfo, domInfo))) - return failure(); - } - } - return success(); } @@ -2223,8 +2199,8 @@ aliasInfo.createAliasInfoEntry(extractOp.result()); // Run analysis on the ExtractSliceOp. - if (failed(bufferizableInPlaceAnalysisAliasOnlyOp( - extractOp->getOpOperand(0), aliasInfo, domInfo))) + if (failed(bufferizableInPlaceAnalysis(extractOp->getOpOperand(0), + aliasInfo, domInfo))) return WalkResult::interrupt(); // Advance to the next operation. @@ -2421,7 +2397,7 @@ // matching OpOperands. for (OpOperand *opOperand : op.getOutputOperands()) { OpResult opResult = cast(op.getOperation()) - .getInplaceableOpResult(*opOperand); + .getAliasingOpResult(*opOperand); assert(opResult && "could not find correspond OpResult"); bool skipCopy = !op.payloadUsesValueFromOperand(opOperand); Value resultBuffer = @@ -2505,7 +2481,7 @@ return {genericOp.getOutputTensorOperands()[opResult.getResultNumber()]}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { auto genericOp = cast(op); if (!opOperand.get().getType().isa()) return OpResult(); @@ -2591,7 +2567,7 @@ opResult.getResultNumber())}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { auto tiledLoopOp = cast(op); return tiledLoopOp.getTiedOpResult(opOperand); } @@ -2736,7 +2712,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -2831,7 +2807,7 @@ return {&forOp.getIterOpOperands()[opResult.getResultNumber()]}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { auto 
forOp = cast(op); if (!opOperand.get().getType().isa()) return OpResult(); @@ -2888,7 +2864,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -2972,7 +2948,7 @@ return {}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { // CallOpInterface is special, it needs to wait for the callee to be // bufferized and needs to inspect the BufferAliasInfo object. It can't // make a proper determination by itself and needs to be conservative. @@ -3003,7 +2979,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -3050,19 +3026,11 @@ return false; } - bool bufferizesToAliasOnly(Operation *op, OpOperand &opOperand) const { - return true; - } - SmallVector getAliasingOpOperand(Operation *op, OpResult opResult) const { return {&op->getOpOperand(0)}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { - return OpResult(); - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return op->getResult(0); } @@ -3118,7 +3086,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -3153,19 +3121,11 @@ return false; } - bool bufferizesToAliasOnly(Operation *op, OpOperand &opOperand) const { - return true; - } - SmallVector getAliasingOpOperand(Operation *op, OpResult opResult) const { return {&op->getOpOperand(0) /*source*/}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { - return OpResult(); - } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) 
const { return &opOperand == &op->getOpOperand(0) /*source*/ ? op->getResult(0) @@ -3243,7 +3203,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -3281,7 +3241,7 @@ return {&op->getOpOperand(1) /*dest*/}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return &opOperand == &op->getOpOperand(1) /*dest*/ ? op->getResult(0) : OpResult(); @@ -3373,7 +3333,7 @@ return false; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { return OpResult(); } @@ -3418,7 +3378,7 @@ return {&op->getOpOperand(1)}; } - OpResult getInplaceableOpResult(Operation *op, OpOperand &opOperand) const { + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); return op->getOpResult(0);