diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -322,6 +322,26 @@ SmallVector &newOps) = 0; }; +/// Return a MemRefType with the same shape and element type as `shapedType` +/// and the specified `layout` and `memorySpace`. With the default (empty) +/// layout, the result is contiguous (i.e. has a canonical layout map). +MemRefType getContiguousMemRefType(ShapedType shapedType, + MemRefLayoutAttrInterface layout = {}, + Attribute memorySpace = {}); + +/// Return a MemRefType with the same shape and element type as `type` and the +/// specified `layout` and `memorySpace` if `type` is ranked. Otherwise, return +/// an UnrankedMemRefType in `memorySpace`; `layout` must then be empty. +Type getContiguousOrUnrankedMemRefType(Type type, + MemRefLayoutAttrInterface layout = {}, + Attribute memorySpace = {}); + +/// Return a MemRefType to which the `tensorType` can be bufferized in a +/// composable fashion. The layout must be the most dynamic possible and +/// canonicalize away once bufferization is finished. +MemRefType getDynamicMemRefType(RankedTensorType tensorType, + unsigned addressSpace = 0); + } // namespace comprehensive_bufferize } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h @@ -1,3 +1,11 @@ +//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALG_INTERFACE_IMPL_H diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h @@ -0,0 +1,27 @@ +//===- TensorInterfaceImpl.h - Tensor Impl. of BufferizableOpInterface ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H +#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H + +namespace mlir { + +class DialectRegistry; + +namespace linalg { +namespace comprehensive_bufferize { +namespace tensor_ext { + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); + +} // namespace tensor_ext +} // namespace comprehensive_bufferize +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" +#include 
"mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "llvm/Support/Debug.h" @@ -525,3 +526,31 @@ op->erase(); obsoleteOps.clear(); } + +MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType( + ShapedType shapedType, MemRefLayoutAttrInterface layout, + Attribute memorySpace) { + return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), + layout, memorySpace); +} + +Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType( + Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) { + if (type.isa()) + return getContiguousMemRefType(type.cast(), layout, + memorySpace); + assert(!layout && "expected empty layout with UnrankedMemRefType"); + return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace); +} + +MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType( + RankedTensorType tensorType, unsigned addressSpace) { + // TODO: address space decisions to connect with the actual alloc. + int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; + SmallVector dynamicStrides(tensorType.getRank(), + ShapedType::kDynamicStrideOrOffset); + AffineMap stridedLayout = makeStridedLinearLayoutMap( + dynamicStrides, dynamicOffset, tensorType.getContext()); + return MemRefType::get(tensorType.getShape(), tensorType.getElementType(), + stridedLayout, addressSpace); +} diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -2,6 +2,7 @@ BufferizableOpInterface.cpp ComprehensiveBufferize.cpp LinalgInterfaceImpl.cpp + TensorInterfaceImpl.cpp ) add_mlir_dialect_library(MLIRBufferizableOpInterface @@ -25,6 +26,16 @@ MLIRTensor ) +add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl + TensorInterfaceImpl.cpp + + LINK_LIBS PUBLIC + 
MLIRBufferizableOpInterface + MLIRIR + MLIRMemRef + MLIRTensor +) + add_mlir_dialect_library(MLIRComprehensiveBufferize ComprehensiveBufferize.cpp @@ -37,6 +48,5 @@ MLIRSCF MLIRStandard MLIRStandardOpsTransforms - MLIRTensor MLIRVector ) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -587,45 +587,6 @@ // Bufferization-specific MemRefType support. //===----------------------------------------------------------------------===// -/// Return a contiguous MemRefType (i.e. with canonical/empty layout map) -/// with the same shape as `shapedType` and specified `layout` and -/// `addressSpace`. -static MemRefType getContiguousMemRefType(ShapedType shapedType, - MemRefLayoutAttrInterface layout = {}, - Attribute memorySpace = {}) { - return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), - layout, memorySpace); -} - -/// Return a contiguous MemRefType (i.e. with canonical/empty layout map) -/// with the same shape as `shapedType` and specified `layout` and -/// `addressSpace` or an UnrankedMemRefType otherwise. -static Type -getContiguousOrUnrankedMemRefType(Type type, - MemRefLayoutAttrInterface layout = {}, - Attribute memorySpace = {}) { - if (type.isa()) - return getContiguousMemRefType(type.cast(), layout, - memorySpace); - assert(!layout && "expected empty layout with UnrankedMemRefType"); - return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace); -} - -/// Return a MemRefType to which the `tensorType` can be bufferized in a -/// composable fashion. The layout must be the most dynamic possible and -/// canonicalize away once bufferization is finished. 
-static MemRefType getDynamicMemRefType(RankedTensorType tensorType, - unsigned addressSpace = 0) { - // TODO: address space decisions to connect with the actual alloc. - int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; - SmallVector dynamicStrides(tensorType.getRank(), - ShapedType::kDynamicStrideOrOffset); - AffineMap stridedLayout = makeStridedLinearLayoutMap( - dynamicStrides, dynamicOffset, tensorType.getContext()); - return MemRefType::get(tensorType.getShape(), tensorType.getElementType(), - stridedLayout, addressSpace); -} - /// Return the FunctionType with `argumentTypes` and `resultTypes` where each /// tensor is replaced by the corresponding buffer type. /// In order for all the callers to agree, this *must* bufferize to the most @@ -1963,420 +1924,6 @@ } // namespace std_ext -namespace tensor_ext { - -struct CastOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - return false; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - return false; - } - - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {&op->getOpOperand(0)}; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - return op->getResult(0); - } - - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { - return BufferRelation::Equivalent; - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - auto castOp = cast(op); - - // Take a guard before anything else. 
- OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(castOp); - - Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state); - if (!resultBuffer) - return failure(); - Type sourceType = resultBuffer.getType(); - auto rankedMemRefType = sourceType.dyn_cast(); - auto unrankedMemRefType = sourceType.dyn_cast(); - assert(rankedMemRefType || unrankedMemRefType); - Attribute memorySpace = rankedMemRefType - ? rankedMemRefType.getMemorySpace() - : unrankedMemRefType.getMemorySpace(); - TensorType tensorType = castOp.getResult().getType().cast(); - MemRefLayoutAttrInterface layout = - rankedMemRefType && tensorType.isa() - ? rankedMemRefType.getLayout() - : MemRefLayoutAttrInterface(); - Type memRefType = getContiguousOrUnrankedMemRefType( - castOp.getResult().getType(), layout, memorySpace); - Value res = - b.create(castOp.getLoc(), memRefType, resultBuffer); - state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult()); - state.mapBuffer(castOp.getResult(), res); - return success(); - } -}; - -struct DimOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - return false; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - return OpResult(); - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - auto dimOp = cast(op); - - // Take a guard before anything else. 
- OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(dimOp); - - if (dimOp.source().getType().isa()) { - Value v = state.lookupBuffer(dimOp.source()); - dimOp.result().replaceAllUsesWith( - b.create(dimOp.getLoc(), v, dimOp.index())); - } - return success(); - } -}; - -struct ExtractSliceOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - return false; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - return false; - } - - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {&op->getOpOperand(0) /*source*/}; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - return &opOperand == &op->getOpOperand(0) /*source*/ - ? op->getResult(0) - : OpResult(); - } - - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { - return BufferRelation::None; - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - auto extractSliceOp = cast(op); - LDBG("bufferize: " << *extractSliceOp << '\n'); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(extractSliceOp); - - Location loc = extractSliceOp.getLoc(); - Value srcMemref = state.lookupBuffer(extractSliceOp.source()); - auto srcMemrefType = srcMemref.getType().cast(); - auto dstTensorType = - extractSliceOp.result().getType().cast(); - - // If not inplaceable, alloc. - bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0)); - Value alloc; - if (!inplace) - alloc = createNewAllocDeallocPairForShapedValue( - b, loc, extractSliceOp.result(), state); - - // Bufferize to subview. 
- auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( - dstTensorType.getRank(), srcMemrefType, - extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(), - extractSliceOp.getMixedStrides()) - .cast(); - Value subView = b.create( - loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(), - extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); - // Insert new alias. - state.aliasInfo.insertNewBufferAlias(subView, srcMemref); - - /// If not inplaceable, copy. - if (!inplace) { - // Do not copy if the copied data is never read. - if (isValueRead(extractSliceOp.result())) - state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView, - alloc); - subView = alloc; - } - - state.mapBuffer(extractSliceOp.result(), subView); - return success(); - } -}; - -struct ExtractOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - return false; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - return OpResult(); - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - auto extractOp = cast(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(extractOp); - - Location loc = extractOp.getLoc(); - Value srcMemref = state.lookupBuffer(extractOp.tensor()); - Value l = b.create(loc, srcMemref, extractOp.indices()); - extractOp.replaceAllUsesWith(l); - return success(); - } -}; - -/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. -/// equivalent operand / result and same offset/sizes/strides specification). -/// -/// This is one particular type of relationship between ops on tensors that -/// reduce to an equivalence on buffers. 
This should be generalized and -/// exposed as interfaces on the proper types. -static bool -areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo, - ExtractSliceOp st, InsertSliceOp sti) { - if (!st || !sti) - return false; - if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest())) - return false; - if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) - return false; - return true; -} - -/// Return true if the source of a `insertSliceOp` bufferizes to an -/// equivalent ExtractSliceOp that bufferizes inplace. -static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp( - const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) { - LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp - << '\n'); - bool foundOp = false; - aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) { - auto extractSliceOp = value.getDefiningOp(); - if (extractSliceOp && - areEquivalentExtractSliceOps(aliasInfo, extractSliceOp, - insertSliceOp) && - aliasInfo.isInPlace(extractSliceOp->getResult(0))) { - LDBG("\tfound: " << extractSliceOp.getOperation() << '\n'); - foundOp = true; - } - }); - - if (!foundOp) - LDBG("\tnot equivalent\n"); - - return foundOp; -} - -/// Return true if `value` is originating from an ExtractSliceOp that matches -/// the given InsertSliceOp. 
-static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo, - Value value, InsertSliceOp insertOp) { - auto condition = [&](Value val) { - if (auto extractOp = val.getDefiningOp()) - if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp)) - return true; - return false; - }; - - return llvm::all_of(findValueInReverseUseDefChain(value, condition), - condition); -} - -struct InsertSliceOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - return &opOperand == &op->getOpOperand(1) /*dest*/; - } - - SmallVector getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {&op->getOpOperand(1) /*dest*/}; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - return &opOperand == &op->getOpOperand(1) /*dest*/ - ? op->getResult(0) - : OpResult(); - } - - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { - return BufferRelation::Equivalent; - } - - bool isNotConflicting(Operation *op, OpOperand *uRead, - OpOperand *uConflictingWrite, - const BufferizationAliasInfo &aliasInfo) const { - Operation *readingOp = uRead->getOwner(); - Operation *conflictingWritingOp = uConflictingWrite->getOwner(); - - // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If - // uRead is an InsertSliceOp... - if (auto insertSliceOp = dyn_cast(readingOp)) { - // As an example, consider the following IR. - // - // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } - // %1 = linalg.fill %cst, %0 {inplace= [true] } - // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] - // {inplace= [true] } - - // TODO: Use insertSliceOp.getDestOpOperand etc. when available. 
- if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(), - insertSliceOp)) - // Case 1: The main insight is that InsertSliceOp reads only part of - // the destination tensor. The overwritten area is not read. If - // uConflictingWrite writes into exactly the memory location that is - // being read by uRead, this is not a conflict. - // - // In the above example: - // uRead = OpOperand 1 (%t) of tensor.insert_slice - // uConflictingWrite = OpOperand 1 (%0) of linalg.fill - // - // The read of %t does not conflict with the write of the FillOp - // (same aliases!) because the area that the FillOp operates on is - // exactly the one that is *not* read via %t. - return true; - - if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ && - uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp)) - // Case 2: The read of the source tensor and the write to the dest - // tensor via an InsertSliceOp is not a conflict if the read is - // reading exactly that part of an equivalent tensor that the - // InsertSliceOp is writing. - // - // In the above example: - // uRead = OpOperand 0 (%1) of tensor.insert_slice - // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice - return true; - } - - // If uConflictingWrite is an InsertSliceOp... - if (auto insertSliceOp = dyn_cast(conflictingWritingOp)) - // As an example, consider the following IR. 
- // - // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] } - // %1 = linalg.fill %cst, %0 {inplace= [true] } - // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1] - // {inplace= [true] } - // %3 = vector.transfer_read %1, %cst - // - // In the above example: - // uRead = OpOperand 0 (%1) of vector.transfer_read - // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice - // lastWrite = %1 - // - // This is not a conflict because the InsertSliceOp overwrites the - // memory segment of %1 with the exact same data. (Effectively, there - // is no memory write here.) - if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ && - aliasInfo.areEquivalentBufferizedValues(uRead->get(), - insertSliceOp.source()) && - hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(), - insertSliceOp)) - return true; - - return false; - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - // insert_slice ops arise from tiling and bufferizing them out-of-place is - // generally a deal breaker. When used with loops, this ends up cloning the - // whole tensor on every single iteration and is a symptom of a - // catastrophically bad scheduling decision. - // TODO: be very loud about it or even consider failing the pass. - auto insertSliceOp = cast(op); - LDBG("bufferize: " << *insertSliceOp << '\n'); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(insertSliceOp); - Location loc = insertSliceOp.getLoc(); - - // When bufferizing out-of-place, `getResultBuffer` allocates. - Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state); - if (!dstMemref) - return failure(); - - // A copy of the source buffer is needed if either: - // - The producer of `source` is not inplace. This is the case where a - // slice is computed out of place into the inplace full tensor. - // - The result is not inplace. 
This is the case where the whole tensor is - // cloned and the clone needs to be updated. - // TODO: Is this necessary? - bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp( - state.aliasInfo, insertSliceOp) || - !state.aliasInfo.isInPlace(insertSliceOp->getResult(0)); - if (needCopy) { - LDBG("insert_slice needs extra source copy: " << insertSliceOp.source() - << " -> copy\n"); - // Take a subview of the dst. - auto dstMemrefType = dstMemref.getType().cast(); - auto subviewMemRefType = - memref::SubViewOp::inferRankReducedResultType( - insertSliceOp.getSourceType().getRank(), dstMemrefType, - insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), - insertSliceOp.getMixedStrides()) - .cast(); - Value subView = b.create( - loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(), - insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); - // Insert new alias. - state.aliasInfo.insertNewBufferAlias(subView, dstMemref); - // Copy tensor. - Value srcMemref = state.lookupBuffer(insertSliceOp.source()); - state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, - subView); - } - - state.mapBuffer(insertSliceOp.result(), dstMemref); - return success(); - } -}; - -} // namespace tensor_ext - namespace vector_ext { struct TransferReadOpInterface @@ -2482,13 +2029,6 @@ registry.addOpInterface(); registry.addOpInterface(); registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); registry.addOpInterface(); registry.addOpInterface { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + return false; + } + + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {&op->getOpOperand(0)}; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) 
const { + return op->getResult(0); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto castOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(castOp); + + Value resultBuffer = getResultBuffer(b, castOp->getResult(0), state); + if (!resultBuffer) + return failure(); + Type sourceType = resultBuffer.getType(); + auto rankedMemRefType = sourceType.dyn_cast(); + auto unrankedMemRefType = sourceType.dyn_cast(); + assert(rankedMemRefType || unrankedMemRefType); + Attribute memorySpace = rankedMemRefType + ? rankedMemRefType.getMemorySpace() + : unrankedMemRefType.getMemorySpace(); + TensorType tensorType = castOp.getResult().getType().cast(); + MemRefLayoutAttrInterface layout = + rankedMemRefType && tensorType.isa() + ? rankedMemRefType.getLayout() + : MemRefLayoutAttrInterface(); + Type memRefType = getContiguousOrUnrankedMemRefType( + castOp.getResult().getType(), layout, memorySpace); + Value res = + b.create(castOp.getLoc(), memRefType, resultBuffer); + state.aliasInfo.insertNewBufferEquivalence(res, castOp.getResult()); + state.mapBuffer(castOp.getResult(), res); + return success(); + } +}; + +struct DimOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + return false; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + return OpResult(); + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto dimOp = cast(op); + + // Take a guard before anything else. 
+ OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(dimOp); + + if (dimOp.source().getType().isa()) { + Value v = state.lookupBuffer(dimOp.source()); + dimOp.result().replaceAllUsesWith( + b.create(dimOp.getLoc(), v, dimOp.index())); + } + return success(); + } +}; + +struct ExtractSliceOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + return false; + } + + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {&op->getOpOperand(0) /*source*/}; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + return &opOperand == &op->getOpOperand(0) /*source*/ + ? op->getResult(0) + : OpResult(); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::None; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto extractSliceOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(extractSliceOp); + + Location loc = extractSliceOp.getLoc(); + Value srcMemref = state.lookupBuffer(extractSliceOp.source()); + auto srcMemrefType = srcMemref.getType().cast(); + auto dstTensorType = + extractSliceOp.result().getType().cast(); + + // If not inplaceable, alloc. + bool inplace = state.aliasInfo.isInPlace(extractSliceOp->getResult(0)); + Value alloc; + if (!inplace) + alloc = state.allocationFns.createAllocDeallocFn( + b, loc, extractSliceOp.result(), state); + + // Bufferize to subview. 
+ auto subviewMemRefType = + memref::SubViewOp::inferRankReducedResultType( + dstTensorType.getRank(), srcMemrefType, + extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(), + extractSliceOp.getMixedStrides()) + .cast(); + Value subView = b.create( + loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(), + extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); + // Insert new alias. + state.aliasInfo.insertNewBufferAlias(subView, srcMemref); + + /// If not inplaceable, copy. + if (!inplace) { + // Do not copy if the copied data is never read. + if (isValueRead(extractSliceOp.result())) + state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView, + alloc); + subView = alloc; + } + + state.mapBuffer(extractSliceOp.result(), subView); + return success(); + } +}; + +struct ExtractOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + return false; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + return OpResult(); + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto extractOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(extractOp); + + Location loc = extractOp.getLoc(); + Value srcMemref = state.lookupBuffer(extractOp.tensor()); + Value l = b.create(loc, srcMemref, extractOp.indices()); + extractOp.replaceAllUsesWith(l); + return success(); + } +}; + +/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. +/// equivalent operand / result and same offset/sizes/strides specification). +/// +/// This is one particular type of relationship between ops on tensors that +/// reduce to an equivalence on buffers. 
This should be generalized and +/// exposed as interfaces on the proper types. +static bool +areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo, + ExtractSliceOp st, InsertSliceOp sti) { + if (!st || !sti) + return false; + if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest())) + return false; + if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue)) + return false; + return true; +} + +/// Return true if the source of a `insertSliceOp` bufferizes to an +/// equivalent ExtractSliceOp that bufferizes inplace. +static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp( + const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) { + bool foundOp = false; + aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) { + auto extractSliceOp = value.getDefiningOp(); + if (extractSliceOp && + areEquivalentExtractSliceOps(aliasInfo, extractSliceOp, + insertSliceOp) && + aliasInfo.isInPlace(extractSliceOp->getResult(0))) { + foundOp = true; + } + }); + return foundOp; +} + +/// Return true if `value` is originating from an ExtractSliceOp that matches +/// the given InsertSliceOp. 
+static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
+                                      Value value, InsertSliceOp insertOp) {
+  // A value "matches" if it is defined by an ExtractSliceOp that
+  // areEquivalentExtractSliceOps() considers equivalent to `insertOp`.
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  // Every value returned by the reverse use-def chain search must itself
+  // satisfy `condition`.
+  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
+                      condition);
+}
+
+/// BufferizableOpInterface external model for tensor.insert_slice.
+/// Both operands are treated as read; only `dest` (OpOperand 1) is written,
+/// and the single result aliases and is equivalent to `dest`.
+struct InsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
+                                                    tensor::InsertSliceOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {&op->getOpOperand(1) /*dest*/};
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/
+               ? op->getResult(0)
+               : OpResult();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const BufferizationAliasInfo &aliasInfo) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+    // uRead is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
+                                    insertSliceOp))
+        // Case 1: The main insight is that InsertSliceOp reads only part of
+        // the destination tensor. The overwritten area is not read. If
+        // uConflictingWrite writes into exactly the memory location that is
+        // being read by uRead, this is not a conflict.
+        //
+        // In the above example:
+        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+        //
+        // The read of %t does not conflict with the write of the FillOp
+        // (same aliases!) because the area that the FillOp operates on is
+        // exactly the one that is *not* read via %t.
+        return true;
+
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
+        // Case 2: The read of the source tensor and the write to the dest
+        // tensor via an InsertSliceOp is not a conflict if the read is
+        // reading exactly that part of an equivalent tensor that the
+        // InsertSliceOp is writing.
+        //
+        // In the above example:
+        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+        return true;
+    }
+
+    // If uConflictingWrite is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp)) {
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+      // %3 = vector.transfer_read %1, %cst
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of vector.transfer_read
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      // lastWrite         = %1
+      //
+      // This is not a conflict because the InsertSliceOp overwrites the
+      // memory segment of %1 with the exact same data. (Effectively, there
+      // is no memory write here.)
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
+                                                  insertSliceOp.source()) &&
+          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
+                                    insertSliceOp))
+        return true;
+    }
+
+    return false;
+  }
+
+  /// Map the result of the insert_slice to the buffer returned by
+  /// getResultBuffer and, when needed, copy the source buffer into a subview
+  /// of that buffer. Returns failure if no result buffer is available.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    // insert_slice ops arise from tiling and bufferizing them out-of-place is
+    // generally a deal breaker. When used with loops, this ends up cloning the
+    // whole tensor on every single iteration and is a symptom of a
+    // catastrophically bad scheduling decision.
+    // TODO: be very loud about it or even consider failing the pass.
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(insertSliceOp);
+    Location loc = insertSliceOp.getLoc();
+
+    // When bufferizing out-of-place, `getResultBuffer` allocates.
+    Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), state);
+    if (!dstMemref)
+      return failure();
+
+    // A copy of the source buffer is needed if either:
+    //   - The producer of `source` is not inplace. This is the case where a
+    //     slice is computed out of place into the inplace full tensor.
+    //   - The result is not inplace. This is the case where the whole tensor is
+    //     cloned and the clone needs to be updated.
+    // TODO: Is this necessary?
+    bool needCopy = !isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+                        state.aliasInfo, insertSliceOp) ||
+                    !state.aliasInfo.isInPlace(insertSliceOp->getResult(0));
+    if (needCopy) {
+      // Take a subview of the dst.
+      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+      auto subviewMemRefType =
+          memref::SubViewOp::inferRankReducedResultType(
+              insertSliceOp.getSourceType().getRank(), dstMemrefType,
+              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+              insertSliceOp.getMixedStrides())
+              .cast<MemRefType>();
+      Value subView = b.create<memref::SubViewOp>(
+          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
+          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+      // Insert new alias.
+      state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
+      // Copy tensor.
+      Value srcMemref = state.lookupBuffer(insertSliceOp.source());
+      state.allocationFns.memCpyFn(b, loc, srcMemref, subView);
+    }
+
+    state.mapBuffer(insertSliceOp.result(), dstMemref);
+    return success();
+  }
+};
+
+} // namespace tensor_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+/// Attach the tensor-dialect BufferizableOpInterface external models to their
+/// ops in `registry`.
+// NOTE(review): the template arguments below were lost in extraction and have
+// been restored; confirm the op/model pairs match the structs in this file.
+void mlir::linalg::comprehensive_bufferize::tensor_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();
+  registry.addOpInterface<tensor::DimOp, tensor_ext::DimOpInterface>();
+  registry.addOpInterface<tensor::ExtractSliceOp,
+                          tensor_ext::ExtractSliceOpInterface>();
+  registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
+  registry.addOpInterface<tensor::InsertSliceOp,
+                          tensor_ext::InsertSliceOpInterface>();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -49,6 +49,7 @@
   MLIRStandardOpsTransforms
   MLIRStandardToLLVM
   MLIRTensor
+  MLIRTensorBufferizableOpInterfaceImpl
   MLIRTransforms
   MLIRTransformUtils
   MLIRVector
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- 
a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -38,6 +39,7 @@ arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>(); registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); + tensor_ext::registerBufferizableOpInterfaceExternalModels(registry); } }; } // end namespace diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6311,6 +6311,25 @@ ], ) +cc_library( + name = "TensorBufferizableOpInterfaceImpl", + srcs = [ + "lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp", + ], + hdrs = [ + "include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h", + ], + includes = ["include"], + deps = [ + ":BufferizableOpInterface", + ":IR", + ":MemRefDialect", + ":Support", + ":TensorDialect", + "//llvm:Support", + ], +) + td_library( name = "LinalgDocTdFiles", srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"], @@ -6530,6 +6549,7 @@ ":StandardOps", ":StandardOpsTransforms", ":Support", + ":TensorBufferizableOpInterfaceImpl", ":TensorDialect", ":TransformUtils", ":VectorOps", @@ -6560,7 +6580,6 @@ ":SCFDialect", ":StandardOps", ":Support", - ":TensorDialect", ":TransformUtils", ":VectorOps", "//llvm:Support",