diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -34,6 +34,11 @@ same buffers. The analysis is performed on SSA use-def chains starting from function operands that are annotated with the 'inplaceable' attribute }]; + let options = [ + Option<"testAnalysisOnly", "test-analysis-only", "bool", + /*default=*/"false", + "Only runs inplaceability analysis (for testing purposes only)"> + ]; let constructor = "mlir::createLinalgComprehensiveFuncBufferizePass()"; } diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -32,6 +32,10 @@ namespace detail { LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); + +bool sameOffsetsSizesAndStrides( + OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, + llvm::function_ref cmp); } // namespace detail } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -419,6 +419,23 @@ return $_op.getOperand(getIndexOfDynamicStride(idx)); }] >, + InterfaceMethod< + /*desc=*/[{ + Return true if all `other`'s offsets, sizes and strides match this op's. + Takes a custom `cmp` comparison function on OpFoldResult to avoid taking + a dialect dependence. 
+ }], + /*retTy=*/"bool", + /*methodName=*/"isSameAs", + /*args=*/(ins "OffsetSizeAndStrideOpInterface":$other, + "llvm::function_ref":$cmp), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return detail::sameOffsetsSizesAndStrides( + ::mlir::cast<::mlir::OffsetSizeAndStrideOpInterface>( + $_op.getOperation()), other, cmp); + }] + >, ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -88,6 +88,7 @@ #include "mlir/Transforms/BufferUtils.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" #define DEBUG_TYPE "comprehensive-func-bufferize" @@ -152,12 +153,18 @@ /// analysis to determine which op results reuse the same buffer as some /// operand. OpResult getMatchingOpResult(OpOperand &opOperand) { - OpResult res = - llvm::TypeSwitch(opOperand.getOwner()) - .Case( - [&](auto op) { return getMatchingOpResult(op, opOperand); }) - .Default([&](Operation *op) { return OpResult(); }); - return res; + return llvm::TypeSwitch(opOperand.getOwner()) + // clang-format off + // Ops that perform destructive updates on operand(s) to produce + // result(s). + .Case( + [&](auto op) { return getMatchingOpResult(op, opOperand); }) + // Other known ops: no result reuses an operand's buffer. + .Case([&](auto op) { return OpResult(); }) + .Default([&](Operation *op) { return OpResult(); }); + // clang-format on } //===----------------------------------------------------------------------===// @@ -290,70 +297,6 @@ return bvm.lookup(key); } -//===----------------------------------------------------------------------===// -// Bufferization-specific support. -//===----------------------------------------------------------------------===// - -/// Determine whether any subsequent read of the tensor `opOperand` may occur.
-/// For now, this assumes any use is a read. If any use of the tensor does not -/// properly dominate `opOperand.getOwner()`, then the tensor cannot be -/// bufferized inPlace. -// TODO: For now, this assumes any use is a read. Refine this. -bool hasInterferingTensorRead(OpOperand &opOperand, - const DominanceInfo &domInfo) { - if (!opOperand.get().getType().isa()) - return false; - for (auto &use : opOperand.get().getUses()) { - Operation *user = use.getOwner(); - - // If properly dominate, there is a clear sequence point and we can dismiss - // read. - if (domInfo.properlyDominates(user, opOperand.getOwner())) - continue; - // Otherwise, we need to analyze self-dependencies, for now just let it go. - // TODO: proper self-dependence analysis. - if (domInfo.dominates(user, opOperand.getOwner())) - continue; - if (user == opOperand.getOwner() && - use.getOperandNumber() == opOperand.getOperandNumber()) - continue; - LLVM_DEBUG(DBGS() << "found interfering read operand #" - << opOperand.getOperandNumber() - << " in op: " << *opOperand.getOwner() << "\n"); - return true; - } - LLVM_DEBUG(DBGS() << "no interfering read\n"); - return false; -} - -/// Return false if either: -/// 1. `opOperand` is produced by a constant op. For now this is assumed to be -/// bufferized to a GlobalMemrefOp that cannot be written. Generalize in the -/// future. -/// 2.`opOperand` is a BlockArgument of a FuncOp that is not known to be -/// bufferizable inplace. -/// 3.`opOperand` has an interfering tensor read. -/// Return true otherwise. -bool isBufferizableInPlace(OpOperand &opOperand, const DominanceInfo &domInfo) { - // Constant tensors are deemed not bufferizable for now. - if (auto constantOp = - dyn_cast_or_null(opOperand.get().getDefiningOp())) - return !constantOp.getResult().getType().isa(); - if (auto bbArg = opOperand.get().dyn_cast()) { - // Uses of function arguments that may not be written-to need to be copied.
- // If the function argument itself is not inplaceable, early return false. - // If is is inplaceable, interfering tensor read need to be checked. - // - // TODO: better propagate the fact that we want a single clone inside the - // function. Atm every user that wants to write inplace will create its own - // alloc, irrespective of whether or not interfering reads occur. - if (isa(bbArg.getOwner()->getParentOp())) - if (getInPlace(bbArg) != InPlaceSpec::True) - return false; - } - return !hasInterferingTensorRead(opOperand, domInfo); -} - //===----------------------------------------------------------------------===// // Bufferization-specific MemRefType support. //===----------------------------------------------------------------------===// @@ -399,26 +342,6 @@ stridedLayout, addressSpace); } -//===----------------------------------------------------------------------===// -// Bufferization-specific inPlace pattern matching support. -//===----------------------------------------------------------------------===// - -/// First assign `op` if `slice.back()` isa `T`, then check condition. -/// If anything fails just return failure. Otherwise update `sliceRef` by -/// dropping `sliceRef.back()`, then return success(). -template -static LogicalResult -matchAndDropBack(ArrayRef &sliceRef, T &op, - llvm::function_ref condition = nullptr) { - if (sliceRef.empty()) - return failure(); - op = dyn_cast(sliceRef.back()); - if (!op || (condition && failed(condition(op)))) - return failure(); - sliceRef = sliceRef.drop_back(); - return success(); -} - //===----------------------------------------------------------------------===// // Bufferization-specific scoped alloc/dealloc insertion support. //===----------------------------------------------------------------------===// @@ -470,121 +393,6 @@ return casted; } -//===----------------------------------------------------------------------===// -// Bufferization-specific inPlace analysis support.
-//===----------------------------------------------------------------------===// - -/// Detect the simple terminator pattern: -/// ``` -/// candidate -> ... -> inplaceable_op(candidate) -> term -/// ``` -template -static LogicalResult detectInplaceOpToTerminator(Operation *parentOp, - BlockArgument candidate, - ArrayRef slice) { - assert(parentOp && "Unexpected null parent op"); - if (!isa(parentOp)) - return failure(); - TerminatorOp terminatorOp; - // Match returnOp and update slice. - if (failed(matchAndDropBack(slice, terminatorOp))) { - LLVM_DEBUG(DBGS() << "FAIL: inplaceOpToTerm pattern -> slice must end with " - "a known terminator\n"); - return failure(); - } - return success(); -} - -/// The following uses internal knowledge of the position of tied operand / -/// results. -static void propagateInPlace(const SmallVector &initalWorklist, - const DominanceInfo &domInfo) { - LLVM_DEBUG(DBGS() << "\n\n"); - LLVM_DEBUG(DBGS() << "Start propagateInPlace from initial WL\n"); - LLVM_DEBUG(for (OpOperand *operand - : initalWorklist) DBGS() - << "WL item: " << operand->get() << " used by " - << *operand->getOwner() << "\n"); - SmallVector worklist(initalWorklist); - for (unsigned idx = 0; idx < worklist.size(); ++idx) { - // TODO: bail on subtensor/subtensor_insert and vector.transfer_read/write - // that should have been already captured in destructive update patterns? - OpOperand &operand = *worklist[idx]; - LLVM_DEBUG(DBGS() << "WL item: " << *operand.getOwner() << "\n"); - // If the owner turns out to be a CallOp without - // `kWriteableFuncBufferArgsAttrName` this will be a noop.
- if (isBufferizableInPlace(operand, domInfo)) { - LLVM_DEBUG(DBGS() << "bufferizable inplace\n"); - setInPlaceOpResult(getMatchingOpResult(operand)); - } - LLVM_DEBUG(DBGS() << "propagatedInPlace: " << *operand.getOwner() << "\n"); - // use can have interfering reads that prevent it from being written inPlace - // but the values it produces are still themselves candidates for inPlace at - // their point of use. - for (Value v : operand.getOwner()->getResults()) { - LLVM_DEBUG(DBGS() << "propagate result: " << v << "\n"); - for (auto &use : v.getUses()) { - LLVM_DEBUG(DBGS() << "add use to WL: " << use.get() << "\n"); - worklist.push_back(&use); - } - } - } - LLVM_DEBUG(DBGS() << "\n\n"); -} - -static void propagateInPlace(BlockArgument &bbArg, - const DominanceInfo &domInfo) { - SmallVector worklist; - for (auto &use : bbArg.getUses()) - worklist.push_back(&use); - propagateInPlace(worklist, domInfo); -} - -/// Iterate over bbArgs of `parentOp` and determine if they are the root of a -/// known destructive update chain. Such a destructive update is related to -/// traditional loop nest + memory analysis but provides a simpler SSA use-def -/// chain-based abstraction. -static void destructiveUpdateAnalysis(Block *block, - const DominanceInfo &domInfo) { - Operation *parentOp = block->getParentOp(); - for (BlockArgument candidate : block->getArguments()) { - LLVM_DEBUG(llvm::dbgs() << "\n\n"); - LLVM_DEBUG(DBGS() << "Destructive update analysis on candidate: " - << candidate << "\nof:\n" - << *parentOp << "\n"); - - if (!candidate.getType().isa()) { - LLVM_DEBUG(DBGS() << "Not a tensor\n"); - continue; - } - - // FuncOp arguments must be inplaceable otherwise they cannot be the root of - // a destructive update chain.
- if (isa(parentOp) && getInPlace(candidate) != InPlaceSpec::True) { - LLVM_DEBUG(DBGS() << "Not inplace\n"); - continue; - } - - llvm::SetVector slice; - getForwardSlice(candidate, &slice, - [&](Operation *op) { return op->getBlock() == block; }); - - LLVM_DEBUG(DBGS() << "Slice:\n"); - LLVM_DEBUG(for (auto *op : slice) DBGS() << *op << "\n"); - - bool failedDetectingDestructiveUpdate = - // func / return inplace patterns. - failed(detectInplaceOpToTerminator( - parentOp, candidate, slice.getArrayRef())); - if (failedDetectingDestructiveUpdate) { - LLVM_DEBUG(DBGS() << "Failed to detect a destructive update pattern\n"); - continue; - } - - propagateInPlace(candidate, domInfo); - } -} - //===----------------------------------------------------------------------===// // Bufferization as simple BlockAndValueMapping rewrites. //===----------------------------------------------------------------------===// @@ -748,6 +556,55 @@ return success(); } +/// Bufferize SubTensorOp to subview with optional alloc + copy depending on +/// whether or not it is marked inplaceable. +/// Note that `getMatchingOpResult` on a SubTensorOp always returns null. +/// Hence, taken in isolation, a SubTensorOp always allocates and copies. +static LogicalResult bufferize(OpBuilder &b, SubTensorOp subTensorOp, + BlockAndValueMapping &bvm) { + LLVM_DEBUG(DBGS() << "bufferize: " << *subTensorOp << "\n"); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(subTensorOp); + + Location loc = subTensorOp.getLoc(); + // Bail if source was not bufferized. + Value srcMemref = lookup(bvm, subTensorOp.source()); + if (!srcMemref) + return failure(); + auto srcMemrefType = srcMemref.getType().cast(); + auto dstTensorType = subTensorOp.result().getType().cast(); + + // If not inplaceable, alloc.
+ Value alloc; + auto inPlace = getInPlace(subTensorOp->getResult(0)); + if (inPlace != InPlaceSpec::True) { + alloc = + createNewAllocDeallocPairForShapedValue(b, loc, subTensorOp.result()); + b.setInsertionPointAfter(alloc.getDefiningOp()); + } + + // Bufferize to subview. + auto subviewMemRefType = + memref::SubViewOp::inferRankReducedResultType( + dstTensorType.getRank(), srcMemrefType, subTensorOp.getMixedOffsets(), + subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides()) + .cast(); + Value subView = b.create( + loc, subviewMemRefType, srcMemref, subTensorOp.getMixedOffsets(), + subTensorOp.getMixedSizes(), subTensorOp.getMixedStrides()); + + // If not inplaceable, copy. + if (alloc) { + b.create(subTensorOp.getLoc(), subView, alloc); + subView = alloc; + } + + map(bvm, subTensorOp.result(), subView); + return success(); +} + static LogicalResult bufferize(OpBuilder &b, SubTensorInsertOp subTensorInsertOp, BlockAndValueMapping &bvm) { @@ -765,7 +622,7 @@ if (inPlace != InPlaceSpec::True) { // Since subtensor_insert arise from tiling and introducing loops, this case // is generally a deal breaker. When used with loops, this ends up cloning - // the whole tensor on every single iteration and is a symtpom of a + // the whole tensor on every single iteration and is a symptom of a // catastrophically bad scheduling decision. // TODO: be very loud about it or even consider failing the pass. Value newDstMemref = createNewAllocDeallocPairForShapedValue( @@ -865,13 +722,171 @@ return success(); } +//===----------------------------------------------------------------------===// +// Functions and calls bufferization support. +//===----------------------------------------------------------------------===// + +/// Determine whether any subsequent read of the tensor `opOperand` may occur. +/// For now, this assumes any use is a read. If any use of the tensor does not +/// properly dominate `opOperand.getOwner()`, then the tensor cannot be +/// bufferized inPlace.
+// TODO: For now, this assumes any use is a read. Refine this. +bool hasInterferingTensorRead(OpOperand &opOperand, + const DominanceInfo &domInfo) { + if (!opOperand.get().getType().isa()) + return false; + for (auto &use : opOperand.get().getUses()) { + Operation *user = use.getOwner(); + // If properly dominate, there is a clear sequence point and we can dismiss + // read. + if (domInfo.properlyDominates(user, opOperand.getOwner())) + continue; + // Otherwise, we need to analyze self-dependencies, for now just let it go. + // TODO: proper self-dependence analysis. + if (domInfo.dominates(user, opOperand.getOwner())) + continue; + if (user == opOperand.getOwner() && + use.getOperandNumber() == opOperand.getOperandNumber()) + continue; + LLVM_DEBUG(DBGS() << "found interfering read operand #" + << opOperand.getOperandNumber() + << " in op: " << *opOperand.getOwner() << "\n"); + return true; + } + LLVM_DEBUG(DBGS() << "no interfering read\n"); + return false; +} + +/// Return false if either: +/// 1. `opOperand` is produced by a constant op. For now this is assumed to be +/// bufferized to a GlobalMemrefOp that cannot be written. Generalize in the +/// future. +/// 2. `opOperand` is a BlockArgument of a FuncOp that is not known to be +/// bufferizable inplace. +/// 3. `opOperand` is a BlockArgument of any other op (conservative for now). +/// Return true otherwise. +static bool bufferizeToWriteable(OpOperand &opOperand) { + // Constant tensors are deemed not bufferizable for now. + if (auto constantOp = + dyn_cast_or_null(opOperand.get().getDefiningOp())) + return !constantOp.getResult().getType().isa(); + if (auto bbArg = opOperand.get().dyn_cast()) { + // Uses of function arguments that may not be written-to need to be copied. + // If the function argument itself is not inplaceable, early return false. + // Interfering tensor reads are checked separately (see + // `hasInterferingTensorRead`). + // + // TODO: better propagate the fact that we want a single clone inside the + // function.
Atm every user that wants to write inplace will create its own + // alloc, irrespective of whether or not interfering reads occur. + if (isa(bbArg.getOwner()->getParentOp())) { + if (getInPlace(bbArg) != InPlaceSpec::True) + return false; + } else { + // Conservatively treat block arguments of any other op as not writeable + // for now. + return false; + } + } + return true; +} + +/// Return true if `opOperand` bufferizes to a writeable buffer (see +/// `bufferizeToWriteable`) and has no interfering tensor read (see +/// `hasInterferingTensorRead`). +static bool isBufferizableInPlace(OpOperand &opOperand, + const DominanceInfo &domInfo) { + return bufferizeToWriteable(opOperand) && + !hasInterferingTensorRead(opOperand, domInfo); +} + +/// Return true if `operand` bufferizes to a buffer that is known to never be +/// written. +static bool bufferizeToReadOnly(OpOperand &operand) { + return llvm::TypeSwitch(operand.getOwner()) + .Case([&](LinalgOp linalgOp) { return linalgOp.isInputTensor(&operand); }) + .Default([&](Operation *op) { return false; }); +} + +/// Assume operand is a use of a `subTensorOp`. +/// Return true if this use bufferizes to a buffer that is known to never be +/// written. +/// Note: This function takes into consideration uses of subTensorOp and whether +/// the owner of those uses is inplaceable. An owner that is not (yet) marked +/// inplace is treated as not inplace, which is optimistic: this must run after +/// the inplace analysis of those owners.
+static bool subTensorUseBufferizesToReadOnly(OpOperand &operand) { + assert(operand.get().getDefiningOp() && "expected subtensor op"); + if (auto subTensorInsertOp = + dyn_cast(operand.getOwner())) { + return operand.getOperandNumber() == 0 /* source of the subTensorInsert*/ && + // If the subTensorInsertOp is not inplace, there is no possible + // internal aliasing with subTensorOp, which is inplaceable. + getInPlace(subTensorInsertOp->getResult(0)) != InPlaceSpec::True; + } + return bufferizeToReadOnly(operand); +} + +/// Return true if `dominator.getOwner()` properly dominates the owners of all +/// other uses of `dominator.get()`. +static bool dominatesAllOtherUses(OpOperand &dominator, + const DominanceInfo &domInfo) { + for (OpOperand &use : dominator.get().getUses()) { + // Same use. + if (use.getOwner() == dominator.getOwner() && + use.getOperandNumber() == dominator.getOperandNumber()) + continue; + if (!domInfo.properlyDominates(dominator.getOwner(), use.getOwner())) + return false; + } + return true; +} + +/// SubTensorOp introduces potential aliasing and a combination of things need +/// to occur to determine whether it is inplaceable. +static void analyzeInPlaceSubTensor(SubTensorOp subTensorOp, + const DominanceInfo &domInfo) { + // Case 1: + // a. All uses are known to bufferize to readonly buffers. + // b. The source has no use that is not dominated by subTensorOp. + // This can skip bufferizeToWriteable analysis / function boundary annotation. + if (llvm::all_of(subTensorOp.result().getUses(), + subTensorUseBufferizesToReadOnly) && + dominatesAllOtherUses(subTensorOp->getOpOperand(0), domInfo)) + return setInPlaceOpResult(subTensorOp->getResult(0), InPlaceSpec::True); + + // TODO: Implement more advanced use cases. There is a notion of transitivity + // and interference sets lurking. +} + +/// Analyze the internals of a FuncOp to determine inplaceable ops.
static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp, const DominanceInfo &domInfo) { assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() && "expected a funcOp definition with a body"); - // Start propagating from FuncOp bbArgs. - destructiveUpdateAnalysis(&funcOp.body().front(), domInfo); + // Analyze all ops other than SubTensorOp first: whether a SubTensorOp can + // bufferize inplace depends on the inplace decisions already made for its + // users (see `subTensorUseBufferizesToReadOnly`). + SmallVector<SubTensorOp> subTensorOps; + funcOp.walk([&](Operation *op) { + // Skip SubTensorOp in a first pass. + if (auto subTensorOp = dyn_cast<SubTensorOp>(op)) { + subTensorOps.push_back(subTensorOp); + return; + } + + // All other ops are checked for `isBufferizableInPlace`. + for (OpOperand &opOperand : op->getOpOperands()) { + OpResult result = getMatchingOpResult(opOperand); + if (result && isBufferizableInPlace(opOperand, domInfo)) { + LLVM_DEBUG(DBGS() << "bufferizable inplace operand #" + << opOperand.getOperandNumber() << " in " << *op + << "\n"); + setInPlaceOpResult(result); + } + } + }); + + // Analyze the deferred SubTensorOps in reverse walk order, i.e. within a + // block later ops are analyzed before earlier ones. + for (SubTensorOp subTensorOp : llvm::reverse(subTensorOps)) + analyzeInPlaceSubTensor(subTensorOp, domInfo); } static LogicalResult bufferizeFuncOpInternals( @@ -881,15 +896,22 @@ /// Start by bufferizing `funcOp` arguments. if (failed(bufferize(b, funcOp, bvm))) return failure(); - WalkResult result = funcOp.walk([&](Operation *op) { + WalkResult result = funcOp.walk([&](Operation *op) { LogicalResult status = llvm::TypeSwitch(op) // Skip BufferCast and TensorLoad ops.
- .Case( + // clang-format off + .Case( [&](auto) { return success(); }) - .Case( [&](auto op) { return bufferize(b, op, bvm); }) + // clang-format on .Default([&](Operation *op) { auto isaTensor = [](Type t) { return t.isa(); }; if (llvm::any_of(op->getOperandTypes(), isaTensor) || @@ -925,8 +947,17 @@ DominanceInfo domInfo(funcOp); BlockAndValueMapping bvm; DenseMap> tiedResultsMap; + LLVM_DEBUG(llvm::dbgs() << "\n\n"); + LLVM_DEBUG(DBGS() << "Begin InPlaceAnalysisFuncOpInternals:\n" + << funcOp << "\n"); inPlaceAnalysisFuncOpInternals(funcOp, domInfo); + LLVM_DEBUG(DBGS() << "End InPlaceAnalysisFuncOpInternals:\n" + << funcOp << "\n"); + + if (testAnalysisOnly) + return; + LLVM_DEBUG(llvm::dbgs() << "\n\n"); LLVM_DEBUG(DBGS() << "Begin BufferizeFuncOpInternals:\n" << funcOp << "\n"); auto guard = llvm::make_scope_exit([&] { funcOp.walk( diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -155,3 +155,26 @@ return parseOperandsOrIntegersImpl(parser, values, integers); } + +/// Return true if `a` and `b` have the same number of offsets, sizes and +/// strides, and each corresponding pair compares equal under `cmp`. +bool mlir::detail::sameOffsetsSizesAndStrides( + OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, + llvm::function_ref cmp) { + if (a.static_offsets().size() != b.static_offsets().size()) + return false; + if (a.static_sizes().size() != b.static_sizes().size()) + return false; + if (a.static_strides().size() != b.static_strides().size()) + return false; + for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) + if (!cmp(std::get<0>(it), std::get<1>(it))) + return false; + for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) + if (!cmp(std::get<0>(it), std::get<1>(it))) + return false; + for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) + if (!cmp(std::get<0>(it), std::get<1>(it))) + return false; + return true; +} diff --git 
a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir @@ -1,4 +1,6 @@ // RUN: mlir-opt %s -linalg-comprehensive-func-bufferize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -linalg-comprehensive-func-bufferize=test-analysis-only -split-input-file |\ +// RUN: FileCheck %s --check-prefix=ANALYSIS // CHECK-DAG: #[[$map_2d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> @@ -218,3 +220,141 @@ %r1 = linalg.fill(%A, %f0) : tensor, f32 -> tensor return %r0, %r1: tensor, tensor } + +// ----- + +// CHECK-LABEL: func @subtensor_fun +func @subtensor_fun(%A : tensor {linalg.inplaceable = true}) + -> tensor<4xf32> +{ + // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref + // CHECK: %[[ALLOC:.*]] = memref.alloc + // CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]][0] [4] [1] + // CHECK: linalg.copy(%[[SV]], %[[ALLOC]]) + %r0 = subtensor %A[0][4][1] : tensor to tensor<4xf32> + return %r0: tensor<4xf32> +} + +// ----- + +// ANALYSIS-LABEL: func @subtensor_readonly_use +func @subtensor_readonly_use( + %A : tensor {linalg.inplaceable = true}, + %B : tensor<4x4xf32>, %C : tensor<4x4xf32>) -> tensor<4x4xf32> +{ + // subtensor is only used as a read. + // ANALYSIS: subtensor {{.*}} {__inplace_results_attr__ = ["true"]} + %sA = subtensor %A[0, 0][4, 4][1, 1] : tensor to tensor<4x4xf32> + // matmul output operand is not inplaceable at the function boundary.
+ // ANALYSIS: linalg.matmul {{.*}} + // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]} + %D = linalg.matmul ins(%sA, %B: tensor<4x4xf32>, tensor<4x4xf32>) + outs(%B: tensor<4x4xf32>) + -> tensor<4x4xf32> + return %D: tensor<4x4xf32> +} + +// ----- + +// ANALYSIS-LABEL: func @subtensor_nonmatching_subtensor_insert_inplace +func @subtensor_nonmatching_subtensor_insert_inplace( + %A : tensor {linalg.inplaceable = true}, %idx: index) + -> tensor +{ + // subtensor has no matching subtensor_insert and is not just used by known + // readonly ops. + // ANALYSIS: subtensor {{.*}} + // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]} + %r0 = subtensor %A[0][4][1] : tensor to tensor<4xf32> + // subtensor_insert can bufferize inplace fine. + // ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]} + %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor + return %r1: tensor +} + +// ----- + +// ANALYSIS-LABEL: func @subtensor_nonmatching_subtensor_insert_non_inplace +func @subtensor_nonmatching_subtensor_insert_non_inplace( + %A : tensor {linalg.inplaceable = false}, %idx: index) + -> tensor +{ + // subtensor has no matching subtensor_insert, but the subtensor_insert that + // uses it does not bufferize inplace, so that use is readonly. + // ANALYSIS: subtensor {{.*}} {__inplace_results_attr__ = ["true"]} + %r0 = subtensor %A[0][4][1] : tensor to tensor<4xf32> + // subtensor_insert cannot bufferize inplace. + // ANALYSIS: subtensor_insert {{.*}} + // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]} + %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor + return %r1: tensor +} + +// ----- + +// ANALYSIS-LABEL: func @subtensor_matching_subtensor_insert +func @subtensor_matching_subtensor_insert(%A : tensor {linalg.inplaceable = true}) + -> tensor +{ + // subtensor has a matching subtensor_insert that bufferizes inplace.
+ // TODO: Atm subtensor is not inplaceable but can be. + // In the grander scheme, this will canonicalize away beforehand.
+ // ANALYSIS: subtensor {{.*}} + // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]} + %r0 = subtensor %A[0][4][1] : tensor to tensor<4xf32> + // subtensor_insert can bufferize inplace fine. + // ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]} + %r1 = subtensor_insert %r0 into %A[0][4][1] : tensor<4xf32> into tensor + return %r1: tensor +} + +// ----- + +// ANALYSIS-LABEL: func @subtensor_matching_and_nonmatching_1 +func @subtensor_matching_and_nonmatching_1(%A : tensor {linalg.inplaceable = true}, %idx: index) + -> (tensor, tensor) +{ + // %r1 is not inplaceable and %r2 is a matching subtensor_insert so %r0 could + // be inplaceable. + // In the grander scheme, %r2 will canonicalize away beforehand but %r0 will still + // not be inplaceable as the production of %r1 may involve a self-copy. + // ANALYSIS: subtensor {{.*}} + // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]} + %r0 = subtensor %A[0][4][1] : tensor to tensor<4xf32> + // ANALYSIS: subtensor_insert {{.*}} + // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]} + %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor + // ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]} + %r2 = subtensor_insert %r0 into %A[0][4][1] : tensor<4xf32> into tensor + return %r1, %r2: tensor, tensor +} + +// ----- + +// ANALYSIS-LABEL: func @subtensor_matching_and_nonmatching_2 +func @subtensor_matching_and_nonmatching_2(%A : tensor {linalg.inplaceable = true}, %idx: index) + -> (tensor, tensor) +{ + // %r2 is a matching subtensor_insert but it is not inplaceable because %r1 + // uses %A afterwards; %r1 is inplaceable so %r0 cannot be (the production
+ // of %r1 may involve a self-copy). + // In the grander scheme, %r2 will canonicalize away beforehand, reducing to the + // `subtensor_nonmatching_subtensor_insert_inplace` case. + // ANALYSIS: subtensor {{.*}} + // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]} + %r0 = subtensor %A[0][4][1] : tensor to tensor<4xf32> + // ANALYSIS: subtensor_insert {{.*}} + // ANALYSIS-NOT: {__inplace_results_attr__ = ["true"]} + %r2 = subtensor_insert %r0 into %A[0][4][1] : tensor<4xf32> into tensor + // ANALYSIS: subtensor_insert {{.*}} {__inplace_results_attr__ = ["true"]} + %r1 = subtensor_insert %r0 into %A[%idx][4][1] : tensor<4xf32> into tensor + + return %r1, %r2: tensor, tensor +} + +// ----- + +// TODO: unknown ops, linalg chain success, linalg chain failure. +