diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -261,6 +261,8 @@
       return getOperation()->getNumOperands() - getNumControlOperands();
     }
     /// Get the region iter arg that corresponds to an OpOperand.
+    /// This helper prevents leaking internal op implementation details to
+    /// clients by hiding the operand / block argument mapping.
     BlockArgument getRegionIterArgForOpOperand(OpOperand &opOperand) {
       assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
              "expected an iter args operand");
@@ -270,6 +272,8 @@
           opOperand.getOperandNumber() - getNumControlOperands()];
     }
     /// Get the OpOperand& that corresponds to a region iter arg.
+    /// This helper prevents leaking internal op implementation details to
+    /// clients by hiding the operand / block argument mapping.
    OpOperand &getOpOperandForRegionIterArg(BlockArgument bbArg) {
       assert(bbArg.getArgNumber() >= getNumInductionVars() &&
              "expected a bbArg that is not an induction variable");
@@ -278,6 +282,27 @@
       return getOperation()->getOpOperand(
           getNumControlOperands() + bbArg.getArgNumber() - getNumInductionVars());
     }
+    /// Get the OpResult that corresponds to an OpOperand.
+    /// Assert that opOperand is an iter arg.
+    /// This helper prevents leaking internal op implementation details to
+    /// clients by hiding the operand / block argument mapping.
+    OpResult getResultForOpOperand(OpOperand &opOperand) {
+      assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
+             "expected an iter args operand");
+      assert(opOperand.getOwner() == getOperation() &&
+             "opOperand does not belong to this scf::ForOp operation");
+      return getOperation()->getResult(
+          opOperand.getOperandNumber() - getNumControlOperands());
+    }
+    /// Get the OpOperand& that corresponds to an OpResult.
+    /// This helper prevents leaking internal op implementation details to
+    /// clients by hiding the operand / block argument mapping.
+    OpOperand &getOpOperandForResult(OpResult opResult) {
+      assert(opResult.getDefiningOp() == getOperation() &&
+             "opResult does not belong to the scf::ForOp operation");
+      return getOperation()->getOpOperand(
+          getNumControlOperands() + opResult.getResultNumber());
+    }
     /// Return operands used when entering the region at 'index'. These operands
     /// correspond to the loop iterator operands, i.e., those excluding the
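A quick client-side illustration of the contract of these four accessors. This is a hypothetical sketch (the function name and the round-trip checks are mine, not part of the patch): the helpers are mutual inverses, so a client can hop between an iter operand, its region iter arg, and the matching loop result without ever touching the operand-numbering scheme.

```cpp
#include "mlir/Dialect/SCF/SCF.h"
#include <cassert>

using namespace mlir;

// Sketch: exercise the scf::ForOp mapping helpers added above.
void checkForOpMappings(scf::ForOp forOp) {
  for (OpOperand &opOperand : forOp->getOpOperands()) {
    // Skip the lb/ub/step control operands; only iter args have a mapping.
    if (opOperand.getOperandNumber() < forOp.getNumControlOperands())
      continue;
    // Iter operand -> iteration bbArg -> back to the same operand.
    BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
    assert(&forOp.getOpOperandForRegionIterArg(bbArg) == &opOperand);
    // Iter operand -> matching loop result -> back to the same operand.
    OpResult result = forOp.getResultForOpOperand(opOperand);
    assert(&forOp.getOpOperandForResult(result) == &opOperand);
  }
}
```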
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -109,6 +109,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
@@ -206,10 +207,10 @@
       .Default(None);
 }
 
-/// Mark whether OpResult can actually be bufferized inplace. If `inPlace` is
-/// `InPlaceSpec::True`, the use-def chain analysis has guaranteed that no
-/// subsequent write would occur to the bufferized tensor value (i.e. the result
-/// can be bufferized inPlace).
+/// Mark whether OpResult can actually be bufferized inplace.
+/// If `inPlace` is `InPlaceSpec::True`, the use-def chain analysis has
+/// guaranteed that no subsequent write would occur to the bufferized
+/// tensor value (i.e. the result can be bufferized inPlace).
 static void setInPlaceOpResult(OpResult opResult,
                                InPlaceSpec inPlace = InPlaceSpec::True) {
   if (!opResult)
@@ -252,16 +253,26 @@
 }
 
 /// Get inPlace information for `bbArg`.
-/// If it does not come from a function, return InPlaceSpec::False.
+/// FuncOp allows argument attributes; we use those to encode the information.
+/// BlockArguments of other ops are handled via their owner's parent op.
 static InPlaceSpec getInPlace(BlockArgument bbArg) {
-  auto funcOp = dyn_cast<FuncOp>(bbArg.getOwner()->getParentOp());
-  if (!funcOp)
-    return InPlaceSpec::False;
-  auto attr = funcOp.getArgAttrOfType<BoolAttr>(
-      bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName);
-  if (!attr)
-    return InPlaceSpec::None;
-  return attr.getValue() ? InPlaceSpec::True : InPlaceSpec::False;
+  if (auto funcOp = dyn_cast<FuncOp>(bbArg.getOwner()->getParentOp())) {
+    BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
+        bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName);
+    if (!inplaceAttr)
+      return InPlaceSpec::None;
+    return inplaceAttr.getValue() ? InPlaceSpec::True : InPlaceSpec::False;
+  }
+  // Interestingly, scf::ForOp's bbArg can **always** be viewed inplace from
+  // the perspective of ops nested under it:
+  //   1. Either the matching iter operand is not bufferized inplace and an
+  //      alloc + optional copy makes the bbArg itself inplaceable.
+  //   2. Or the matching iter operand is bufferized inplace and bbArg just
+  //      bufferizes to that too.
+  if (auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp()))
+    return InPlaceSpec::True;
+  // Unknown cases.
+  return InPlaceSpec::None;
 }
 
 LLVM_ATTRIBUTE_UNUSED static InPlaceSpec getInPlace(Value v) {
@@ -293,11 +304,13 @@
 static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
   return
       // clang-format off
-      isa<CallOpInterface, tensor::CastOp, ConstantOp, ExtractSliceOp,
-          InsertSliceOp, InitTensorOp, LinalgOp, ReturnOp,
-          VectorTransferOpInterface>(op)
+      isa<CallOpInterface, tensor::CastOp, ConstantOp, ExtractSliceOp,
+          scf::ForOp, InsertSliceOp, InitTensorOp, LinalgOp, ReturnOp,
+          VectorTransferOpInterface,
+          scf::YieldOp>(op)
      // clang-format on
      || (none_of(op->getResultTypes(),
                  [](Type t) { return t.isa<TensorType>(); }) &&
@@ -305,6 +318,15 @@
          none_of(op->getOperandTypes(),
                  [](Type t) { return t.isa<TensorType>(); }));
 }
 
+/// Return the OpResult that may bufferize into the same buffer as `opOperand`
+/// when the op is bufferized inplace.
+/// Return null if no such result exists.
+static OpResult getInplaceableOpResult(scf::ForOp forOp, OpOperand &opOperand) {
+  if (!opOperand.get().getType().isa<RankedTensorType>())
+    return OpResult();
+  return forOp.getResultForOpOperand(opOperand);
+}
+
 /// Return the OpResult that may bufferize into the same buffer as `opOperand`
 /// when the op is bufferized inplace.
 /// Return null if no such result exists.
@@ -355,7 +377,8 @@
       // clang-format off
         // Ops that perform destructive updates on operand(s) to produce
         // result(s).
-        .Case<LinalgOp, InsertSliceOp, VectorTransferOpInterface>(
+        .Case<scf::ForOp,
+              LinalgOp, InsertSliceOp, VectorTransferOpInterface>(
             [&](auto op) { return getInplaceableOpResult(op, opOperand); })
@@ -377,12 +400,15 @@
   if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner()))
     return None;
   return TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
-      // ReturnOp has no result.
-      .Case([&](ReturnOp op) { return OpResult(); })
+      // These terminators legitimately have no result.
+      .Case<ReturnOp, scf::YieldOp>(
+          [&](auto op) { return OpResult(); })
       // ExtractSliceOp is different: its result is not inplaceable on op.source
       // but when bufferized inplace, the result is an aliasing subregion of
       // op.source.
       .Case([&](ExtractSliceOp op) { return op->getResult(0); })
+      // All other ops, including scf::ForOp, return the result of
+      // `getInplaceableOpResult`.
       .Default(
           [&](Operation *op) { return getInplaceableOpResult(opOperand); });
 }
@@ -398,6 +424,10 @@
   // may.
   if (isa<ExtractSliceOp>(opOperand.getOwner()))
     return false;
+  // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
+  // its matching bbArg may.
+  if (isa<scf::ForOp>(opOperand.getOwner()))
+    return false;
   if (auto linalgOp = dyn_cast<LinalgOp>(opOperand.getOwner()))
     return linalgOp.isInputTensor(&opOperand) ||
            linalgOp.isInitTensor(&opOperand);
@@ -422,8 +452,8 @@
   // This does not bufferize to a write.
   if (!*maybeOpResult)
     return false;
-  // A ReturnOp is not a write.
-  if (isa<ReturnOp>(opOperand.getOwner()))
+  // These terminators are not writes.
+  if (isa<ReturnOp, scf::YieldOp>(opOperand.getOwner()))
     return false;
   // ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
   // may.
@@ -472,10 +502,14 @@
   /// to some use that would bufferize to a write to a buffer.
   bool aliasesInPlaceWrite(ExtractSliceOp extractSliceOp) const;
 
+  /// Set the inPlace bufferization spec to true.
   /// Merge result's and operand's aliasing sets and iterate to a fixed point.
   void bufferizeInPlace(OpResult result, OpOperand &operand,
                         BufferRelation bufferRelation = BufferRelation::None);
 
+  /// Set the inPlace bufferization spec to false.
+  void bufferizeOutOfPlace(OpResult result);
+
   /// Return true if it is possible to find an inplace write W among the uses of
   /// aliasInfo[rootWrite], and a read R among the uses of aliasInfo[rootRead],
   /// such that W and R interfere.
@@ -496,7 +530,13 @@
   bool existsNonDominatingRead(OpOperand &opOperand,
                                const DominanceInfo &domInfo) const;
 
-  /// Return true if the source of a `insertSliceOp` bufferizes to an
+  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+  bool areEquivalentBufferizedValues(Value v1, Value v2) const {
+    return equivalentInfo.getLeaderValue(v1) ==
+           equivalentInfo.getLeaderValue(v2);
+  }
+
+  /// Return true if the source of an `insertSliceOp` bufferizes to an
   /// equivalent ExtractSliceOp.
   bool isSourceEquivalentToAMatchingExtractSliceOp(
       InsertSliceOp insertSliceOp) const;
@@ -601,14 +641,6 @@
 } // namespace
 
 BufferizationAliasInfo::BufferizationAliasInfo(FuncOp funcOp) {
-  for (auto bbArg : funcOp.getArguments()) {
-    if (!bbArg.getType().isa<TensorType>())
-      continue;
-    DenseSet<Value> selfSet;
-    selfSet.insert(bbArg);
-    aliasInfo.try_emplace(bbArg, selfSet);
-    equivalentInfo.insert(bbArg);
-  }
   funcOp.walk([&](Operation *op) {
     for (Value v : op->getResults()) {
       if (!v.getType().isa<TensorType>())
@@ -620,6 +652,18 @@
       aliasInfo.try_emplace(v, selfSet);
       equivalentInfo.insert(v);
     }
+    for (Region &r : op->getRegions()) {
+      for (Block &b : r.getBlocks()) {
+        for (auto bbArg : b.getArguments()) {
+          if (!bbArg.getType().isa<TensorType>())
+            continue;
+          DenseSet<Value> selfSet;
+          selfSet.insert(bbArg);
+          aliasInfo.try_emplace(bbArg, selfSet);
+          equivalentInfo.insert(bbArg);
+        }
+      }
+    }
   });
 }
@@ -634,13 +678,10 @@
   for (Value v : getAliasInfoRef(operand.get())) {
     LDBG("-----------examine: " << v << '\n');
     if (auto bbArg = v.dyn_cast<BlockArgument>()) {
-      // Uses of function arguments that may be written-to can be skipped.
-      if (isa<FuncOp>(bbArg.getOwner()->getParentOp()) &&
-          getInPlace(bbArg) == InPlaceSpec::True) {
+      if (getInPlace(bbArg) == InPlaceSpec::True) {
         LDBG("-----------bbArg is writeable -> skip: " << bbArg << '\n');
         continue;
       }
-      // Conservatively dump any other block argument for now.
       LDBG("-----------notWriteable: " << v << '\n');
       return true;
     }
@@ -675,14 +716,23 @@
   return false;
 }
 
+/// Set the inPlace bufferization spec to true.
 /// Merge result's and operand's aliasing sets and iterate to a fixed point.
 void BufferizationAliasInfo::bufferizeInPlace(OpResult result,
                                               OpOperand &operand,
                                               BufferRelation bufferRelation) {
+  setInPlaceOpResult(result, InPlaceSpec::True);
   if (mergeAliases(result, operand.get()))
     mergeAliasesToFixedPoint();
   if (bufferRelation == BufferRelation::Equivalent)
     equivalentInfo.unionSets(result, operand.get());
+  // Dump the updated analysis.
+  LLVM_DEBUG(dump());
+}
+
+/// Set the inPlace bufferization spec to false.
+void BufferizationAliasInfo::bufferizeOutOfPlace(OpResult result) {
+  setInPlaceOpResult(result, InPlaceSpec::False);
 }
@@ -1217,6 +1267,44 @@
   return success();
 }
 
+static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
+                               BlockAndValueMapping &bvm,
+                               const BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  Location loc = forOp.getLoc();
+
+  LLVM_DEBUG(DBGS() << "bufferize: " << *forOp << "\n");
+
+  // If inPlace, just forward the buffer.
+  // Otherwise alloc and copy.
+  b.setInsertionPoint(forOp);
+  for (OpResult opResult : forOp->getResults()) {
+    // TODO: Atm we bail on unranked TensorType because we don't know how to
+    // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+    if (!opResult.getType().isa<RankedTensorType>())
+      return failure();
+    OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
+    Value operand = opOperand.get();
+    Value operandBuffer = lookup(bvm, operand);
+    Value resultBuffer = operandBuffer;
+    if (getInPlace(opResult) != InPlaceSpec::True) {
+      resultBuffer = createNewAllocDeallocPairForShapedValue(b, loc, operand);
+      // If the tensor comes from `linalg::InitTensorOp`, the value is
+      // uninitialized and we do not need to copy.
+      // TODO: "matching bbArg does not bufferize to a read" is a more general
+      // condition than this one.
+      if (!operand.getDefiningOp<linalg::InitTensorOp>())
+        b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
+    }
+    BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
+    map(bvm, bbArg, resultBuffer);
+    map(bvm, opResult, resultBuffer);
+  }
+
+  return success();
+}
+
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
@@ -1429,6 +1517,31 @@
   return success();
 }
 
+static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
+                               BlockAndValueMapping &bvm,
+                               const BufferizationAliasInfo &aliasInfo) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(yieldOp);
+
+  scf::ForOp forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+  assert(forOp && "only support scf::ForOp parent for scf::YieldOp");
+  for (OpOperand &operand : yieldOp->getOpOperands()) {
+    auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+    if (!tensorType)
+      continue;
+    OpOperand &forOperand = forOp.getOpOperandForResult(
+        forOp->getResult(operand.getOperandNumber()));
+    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+    if (getInPlace(bbArg) == InPlaceSpec::True)
+      operand.set(bbArg);
+    else
+      operand.set(
+          b.create<memref::TensorLoadOp>(yieldOp.getLoc(), lookup(bvm, bbArg)));
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
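The two `bufferize` overloads above cooperate through `bvm`: the scf::ForOp overload maps both the iteration bbArg and the matching OpResult to one buffer, which the scf::YieldOp overload later retrieves. A minimal sketch of that bookkeeping with placeholder values; this is illustrative only, not the pass's actual helper:

```cpp
#include "mlir/IR/BlockAndValueMapping.h"
#include <cassert>

using namespace mlir;

// Sketch: the tensor -> buffer bookkeeping shared by the two overloads above.
// `bbArg`, `opResult` and `buffer` are placeholders for illustration.
void recordLoopCarriedBuffer(BlockAndValueMapping &bvm, Value bbArg,
                             Value opResult, Value buffer) {
  // bufferize(scf::ForOp) maps both tensor SSA values to the same buffer...
  bvm.map(bbArg, buffer);
  bvm.map(opResult, buffer);
  // ...so bufferize(scf::YieldOp) can later look either of them up.
  assert(bvm.lookup(bbArg) == buffer && bvm.lookup(opResult) == buffer);
}
```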
@@ -1447,11 +1560,12 @@
 ///
 /// An analysis is required to ensure inplace bufferization would not result in
 /// RaW dependence violations.
-static void bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
-                                        BufferizationAliasInfo &aliasInfo,
-                                        const DominanceInfo &domInfo) {
+static LogicalResult
+bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
+                            BufferizationAliasInfo &aliasInfo,
+                            const DominanceInfo &domInfo) {
   LDBG('\n');
-  LDBG("Try to bufferize extract_slice inplace: " << *extractSliceOp << '\n');
+  LDBG("Inplace analysis for extract_slice: " << *extractSliceOp << '\n');
 
   // If `extractSliceOp` were to be bufferized inplace, it cannot end up
   // aliasing a write into a non-writeable buffer.
@@ -1461,35 +1575,38 @@
   if (wouldCreateAliasingWriteToNonWriteableBuffer)
     LDBG("->the corresponding buffer is not writeable\n");
-  LDBG("->bufferizes to writeable inplace buffer\n");
+  else
+    LDBG("->bufferizes to writeable inplace buffer\n");
 
   // In any of extractSliceOp.result's aliases, can we find 2 such that we hit
   // an interfering write?
-  Value s = extractSliceOp.source(), r = extractSliceOp.result();
+  OpResult r = extractSliceOp->getResult(0);
+  OpOperand &s = extractSliceOp->getOpOperand(0);
   bool foundInterference =
       wouldCreateAliasingWriteToNonWriteableBuffer ||
       // Do not consider (s, s) and (r, r) as all the
       // aliasings already exist by construction; we are
       // interested in new interfering aliases only.
       aliasInfo.wouldCreateReadAfterWriteInterference(
-          s, r, extractSliceOp, domInfo) ||
+          s.get(), r, extractSliceOp, domInfo) ||
       aliasInfo.wouldCreateReadAfterWriteInterference(
-          r, s, extractSliceOp, domInfo);
-  if (foundInterference) {
-    setInPlaceOpResult(extractSliceOp->getResult(0), InPlaceSpec::False);
-  } else {
-    setInPlaceOpResult(extractSliceOp->getResult(0), InPlaceSpec::True);
-    aliasInfo.bufferizeInPlace(extractSliceOp->getResult(0),
-                               extractSliceOp->getOpOperand(0));
-  }
-  LDBG("Done bufferizing extract_slice\n");
+          r, s.get(), extractSliceOp, domInfo);
+  if (foundInterference)
+    aliasInfo.bufferizeOutOfPlace(r);
+  else
+    aliasInfo.bufferizeInPlace(r, s);
+
+  LDBG("Done inplace analysis for extract_slice\n");
+
+  return success();
 }
 
 /// Analyze the (opOperand, result) pair to determine whether the result can
 /// be bufferized inPlace. If successful, InPlaceSpec::True is set for
 /// `result`. Otherwise, InPlaceSpec::False is set for `result`.
-static void bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
-                                        BufferizationAliasInfo &aliasInfo,
-                                        const DominanceInfo &domInfo) {
+static LogicalResult
+bufferizableInPlaceAnalysis(OpOperand &operand, OpResult result,
+                            BufferizationAliasInfo &aliasInfo,
+                            const DominanceInfo &domInfo) {
   Operation *op = result.getDefiningOp();
   assert(result && !isa<ExtractSliceOp>(op) &&
          "expected OpResult not coming from an ExtractSliceOp");
@@ -1497,9 +1614,9 @@
   int64_t resultNumber = result.getResultNumber();
   (void)resultNumber;
   LDBG('\n');
-  LDBG("Try to bufferize inplace result #"
-       << resultNumber << " (operand #" << operand.getOperandNumber() << ") in "
-       << result << '\n');
+  LDBG("Inplace analysis for result #" << resultNumber << " (operand #"
+                                       << operand.getOperandNumber() << ") in "
+                                       << result << '\n');
 
   // `result` must bufferize to a writeable buffer to be a candidate.
   // This means the use->def chain does not backpropagate to a function that is
@@ -1508,7 +1625,8 @@
       aliasInfo.aliasesNonWriteableBuffer(operand);
   if (wouldCreateAliasingWriteToNonWriteableBuffer)
     LDBG("->the corresponding buffer is not writeable\n");
-  LDBG("->bufferizes to writeable inplace buffer\n");
+  else
+    LDBG("->bufferizes to writeable inplace buffer\n");
 
   Value s = operand.get(), r = result;
   bool foundInterference =
@@ -1520,22 +1638,56 @@
       aliasInfo.wouldCreateReadAfterWriteInterference(s, r, op, domInfo) ||
       aliasInfo.wouldCreateReadAfterWriteInterference(r, s, op, domInfo);
 
-  if (foundInterference) {
-    setInPlaceOpResult(result, InPlaceSpec::False);
-  } else {
-    setInPlaceOpResult(result, InPlaceSpec::True);
+  if (foundInterference)
+    aliasInfo.bufferizeOutOfPlace(result);
+  else
     // TODO: Atm, all inplace bufferizations yield equivalent tensors. Support
     // more cases on a per-need basis.
     aliasInfo.bufferizeInPlace(
         result, operand, BufferizationAliasInfo::BufferRelation::Equivalent);
+
+  LDBG("Done inplace analysis for result #" << resultNumber << '\n');
+
+  return success();
+}
+
+/// Return `failure()` if a yielded value does not bufferize to a buffer
+/// equivalent to that of the matching enclosing scf::ForOp operand.
+/// scf::YieldOp is not explicitly bufferized, so we need to perform a
+/// separate sanity check for now.
+static LogicalResult
+bufferizationSanityCheck(scf::YieldOp yieldOp,
+                         const BufferizationAliasInfo &aliasInfo) {
+  auto parentForOp = yieldOp->getParentOfType<scf::ForOp>();
+  if (!parentForOp)
+    return failure();
+
+  for (OpOperand &operand : yieldOp->getOpOperands()) {
+    OpResult matchingForOpResult =
+        parentForOp->getResult(operand.getOperandNumber());
+    // Nothing to do if the operand bufferizes out of place.
+    if (getInPlace(matchingForOpResult) != InPlaceSpec::True)
+      continue;
+    OpOperand &matchingForOpOperand =
+        parentForOp.getOpOperandForResult(matchingForOpResult);
+    BlockArgument matchingForOpIterArg =
+        parentForOp.getRegionIterArgForOpOperand(matchingForOpOperand);
+    if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg,
+                                                 operand.get())) {
+      yieldOp->emitError()
+          << "Yield operand #" << operand.getOperandNumber()
+          << " does not bufferize to an equivalent buffer to the matching"
+          << " enclosing scf::for operand -> Fail the pass\n";
+      return failure();
+    }
   }
-  LDBG("Done bufferizing result #" << resultNumber << '\n');
+
+  return success();
 }
 
 /// Analyze the `funcOp` body to determine which OpResults are inplaceable.
-static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
-                                           BufferizationAliasInfo &aliasInfo,
-                                           const DominanceInfo &domInfo) {
+static LogicalResult
+inPlaceAnalysisFuncOpInternals(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+                               const DominanceInfo &domInfo) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
   assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
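`areEquivalentBufferizedValues` compares union-find leaders, so two values only count as equivalent if a chain of `BufferRelation::Equivalent` merges joined their classes. A standalone sketch of that semantics using `llvm::EquivalenceClasses`, with `int` elements standing in for `Value`s:

```cpp
#include "llvm/ADT/EquivalenceClasses.h"
#include <cassert>

// Sketch: leader-based equivalence, as used by areEquivalentBufferizedValues.
int main() {
  llvm::EquivalenceClasses<int> equivalentInfo;
  equivalentInfo.insert(1); // e.g. a ForOp iter arg
  equivalentInfo.insert(2); // e.g. the value yielded for it
  equivalentInfo.insert(3); // an unrelated value

  // bufferizeInPlace(..., BufferRelation::Equivalent) performs this union.
  equivalentInfo.unionSets(1, 2);

  // The sanity check then compares class leaders.
  assert(equivalentInfo.getLeaderValue(1) == equivalentInfo.getLeaderValue(2));
  assert(equivalentInfo.getLeaderValue(1) != equivalentInfo.getLeaderValue(3));
  return 0;
}
```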
@@ -1565,9 +1717,10 @@
   // Walk InsertSliceOp in reverse for better interference behavior.
   for (InsertSliceOp insertSliceOp : reverse(insertSliceOps)) {
     OpOperand &destOpOperand = insertSliceOp->getOpOperand(1);
-    bufferizableInPlaceAnalysis(destOpOperand,
-                                getInplaceableOpResult(destOpOperand),
-                                aliasInfo, domInfo);
+    if (failed(bufferizableInPlaceAnalysis(
+            destOpOperand, getInplaceableOpResult(destOpOperand), aliasInfo,
+            domInfo)))
+      return failure();
   }
 
   // Bufferize all ops except ExtractSliceOp and InsertSliceOp which are handled
@@ -1576,15 +1729,25 @@
   for (Operation *op : reverse(nonSliceOps))
     for (OpOperand &opOperand : op->getOpOperands())
       if (OpResult result = getInplaceableOpResult(opOperand))
-        bufferizableInPlaceAnalysis(opOperand, result, aliasInfo, domInfo);
+        if (failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo,
+                                               domInfo)))
+          return failure();
 
   // Finally, bufferize ExtractSliceOp.
   // Walk ExtractSliceOps in reverse for better clobbering behavior: it is
   // easier to detect clobbers of smaller slices before larger ones.
   for (ExtractSliceOp extractSliceOp : reverse(extractSliceOps))
-    bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo);
+    if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
+      return failure();
+
+  // Sanity checks.
+  auto walkResult = funcOp.walk([&](scf::YieldOp yieldOp) -> WalkResult {
+    return bufferizationSanityCheck(yieldOp, aliasInfo);
+  });
 
   LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
+
+  return success(!walkResult.wasInterrupted());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1600,7 +1763,8 @@
   /// Start by bufferizing `funcOp` arguments.
   if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
-  WalkResult result = funcOp.walk([&](Operation *op) {
+  // Walk in PreOrder to ensure ops with regions are handled before their body.
+  WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
     LogicalResult status =
         TypeSwitch<Operation *, LogicalResult>(op)
             // Skip BufferCast and TensorLoad ops.
@@ -1609,12 +1773,17 @@
            .Case<memref::BufferCastOp,
                  memref::TensorLoadOp>(
                [&](auto) { return success(); })
-            .Case<ConstantOp, InitTensorOp, LinalgOp, ReturnOp, ExtractSliceOp,
-                  InsertSliceOp, VectorTransferOpInterface>(
-                [&](auto op) { return bufferize(b, op, bvm, aliasInfo); })
+            .Case<scf::ForOp, ConstantOp, InitTensorOp, LinalgOp, ReturnOp,
+                  ExtractSliceOp, InsertSliceOp,
+                  VectorTransferOpInterface,
+                  scf::YieldOp>(
+                [&](auto op) {
+                  LDBG("Begin bufferize:\n" << op << '\n');
+                  return bufferize(b, op, bvm, aliasInfo);
+                })
             // clang-format on
             .Default([&](Operation *op) {
               auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -1652,7 +1821,12 @@
   // Analysis phase.
   DominanceInfo domInfo(funcOp);
   BufferizationAliasInfo aliasInfo(funcOp);
-  inPlaceAnalysisFuncOpInternals(funcOp, aliasInfo, domInfo);
+  // If the analysis fails, just return; the IR is expected to be left in a
+  // state where no OpResult is marked inPlace.
+  if (failed(inPlaceAnalysisFuncOpInternals(funcOp, aliasInfo, domInfo))) {
+    signalPassFailure();
+    return;
+  }
 
   if (testAnalysisOnly)
     return;
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis-invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis-invalid.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -linalg-comprehensive-func-bufferize=test-analysis-only -split-input-file -verify-diagnostics
+
+// -----
+
+func @scf_for(%A : tensor<?xf32>,
+              %B : tensor<?xf32> {linalg.inplaceable = true},
+              %C : tensor<4xf32>,
+              %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // Throw a wrench in the system by swapping the yielded values: this
+    // results in a ping-pong of values at each iteration, on which we
+    // currently want to fail.
+
+    // expected-error @+1 {{Yield operand #1 does not bufferize to an equivalent buffer}}
+    scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}
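The diagnostic expected above is surfaced by the analysis walk added earlier: every scf.yield is visited and the walk is interrupted on the first failing check (a failed LogicalResult converts to WalkResult::interrupt()). A sketch of that interruptible-walk pattern; `checkYield` is a hypothetical stand-in for `bufferizationSanityCheck`:

```cpp
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Sketch: interrupt a walk on the first failing per-op check.
LogicalResult
verifyAllYields(FuncOp funcOp,
                function_ref<LogicalResult(scf::YieldOp)> checkYield) {
  WalkResult walkResult = funcOp.walk([&](scf::YieldOp yieldOp) -> WalkResult {
    // A failed LogicalResult converts to WalkResult::interrupt().
    return checkYield(yieldOp);
  });
  return success(!walkResult.wasInterrupted());
}
```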
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize-analysis.mlir
@@ -412,3 +412,63 @@
   return %rA, %rB, %rC: tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// Simple loop cases
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @scf_for_yield_only
+func @scf_for_yield_only(%A : tensor<?xf32>,
+                         %B : tensor<?xf32> {linalg.inplaceable = true},
+                         %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK: scf.for
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
+  %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+    scf.yield %t : tensor<?xf32>
+  }
+
+  // CHECK: scf.for
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["true"]}
+  %r1 = scf.for %i = %lb to %ub step %step iter_args(%t = %B) -> (tensor<?xf32>) {
+    scf.yield %t : tensor<?xf32>
+  }
+
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_with_tensor.insert_slice
+func @scf_for_with_tensor.insert_slice(%A : tensor<?xf32>,
+                                       %B : tensor<?xf32> {linalg.inplaceable = true},
+                                       %C : tensor<4xf32>,
+                                       %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK: scf.for
+  // scf.for bbArgs are always inplaceable as seen from ops inside the body:
+  //   1. Either the matching tensor is not inplaceable and an alloc occurs,
+  //      which makes the bbArg inplaceable.
+  //   2. Or the matching tensor is already inplaceable and so is the bbArg.
+  // CHECK-NEXT: tensor.insert_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  // CHECK-NEXT: tensor.insert_slice
+  // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
+  // CHECK-NEXT: scf.yield
+  // CHECK-NEXT: {__inplace_results_attr__ = ["false", "true"]}
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
+  }
+
+  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}
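The `__inplace_results_attr__` arrays checked in this test are written by `setInPlaceOpResult`. A simplified, hypothetical sketch of maintaining such a per-result string-array attribute; only the attribute name is taken from the patch, the update logic here is an assumption:

```cpp
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Sketch: record an inplace verdict for one result as a string array
// attribute on the owning op. Not the pass's exact implementation.
void markResultInPlace(OpResult result, bool inPlace) {
  const char kAttrName[] = "__inplace_results_attr__";
  Operation *op = result.getOwner();
  // Start from the existing verdicts, defaulting every result to "none".
  SmallVector<StringRef, 4> verdicts(op->getNumResults(), "none");
  if (auto existing = op->getAttrOfType<ArrayAttr>(kAttrName))
    for (unsigned i = 0, e = existing.size(); i != e; ++i)
      verdicts[i] = existing[i].cast<StringAttr>().getValue();
  verdicts[result.getResultNumber()] = inPlace ? "true" : "false";
  Builder b(op->getContext());
  op->setAttr(kAttrName, b.getStrArrayAttr(verdicts));
}
```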
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -273,3 +273,81 @@
   // CHECK: return %[[RES]]
   return %r0: tensor<4xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// Simple loop cases
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func @scf_for_yield_only
+func @scf_for_yield_only(%A : tensor<?xf32>,
+                         %B : tensor<?xf32> {linalg.inplaceable = true},
+                         %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast
+  // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast
+  // CHECK: linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC_FOR_A]])
+
+  // The first scf.for remains but just turns into dead code.
+  %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+    scf.yield %t : tensor<?xf32>
+  }
+
+  // The second scf.for remains but just turns into dead code.
+  %r1 = scf.for %i = %lb to %ub step %step iter_args(%t = %B) -> (tensor<?xf32>) {
+    scf.yield %t : tensor<?xf32>
+  }
+
+  // The cross-function-call alloc/dealloc pattern must be hoisted out.
+  // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref<?xf32>
+  // CHECK: %[[rA:.*]] = memref.tensor_load %[[ALLOC_FOR_A]]
+  // Returning the tensor_load of the buffer cast makes the %r1 loop dead.
+  // CHECK: %[[rB:.*]] = memref.tensor_load %[[BUFFER_CAST_B]]
+  // CHECK: return %[[rA]], %[[rB]] : tensor<?xf32>, tensor<?xf32>
+  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_for_with_tensor.insert_slice
+func @scf_for_with_tensor.insert_slice(
+    %A : tensor<?xf32>,
+    %B : tensor<?xf32> {linalg.inplaceable = true},
+    %C : tensor<4xf32>,
+    %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc
+  // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast
+  // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast
+  // CHECK: %[[BUFFER_CAST_C:.*]] = memref.buffer_cast
+  // CHECK: linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC_FOR_A]])
+
+  // CHECK: scf.for {{.*}} iter_args(%[[bbA:.*]] = %{{.*}}, %[[bbB:.*]] = %{{.*}})
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    // CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
+    // %ttA bufferizes to a direct copy of %BUFFER_CAST_C into %svA.
+    // CHECK: linalg.copy(%[[BUFFER_CAST_C]], %[[svA]])
+    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // %ttB bufferizes to a direct copy of %BUFFER_CAST_C into %BUFFER_CAST_B.
+    // CHECK: %[[svB:.*]] = memref.subview %[[BUFFER_CAST_B]][0] [4] [1]
+    // CHECK: linalg.copy(%[[BUFFER_CAST_C]], %[[svB]])
+    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // Yielding %bbA and %bbB will canonicalize away into oblivion.
+    // CHECK: scf.yield %[[bbA]], %[[bbB]] : tensor<?xf32>, tensor<?xf32>
+    scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
+  }
+
+  // CHECK: memref.dealloc %[[ALLOC_FOR_A]] : memref<?xf32>
+  // CHECK: %[[rA:.*]] = memref.tensor_load %[[ALLOC_FOR_A]] : memref<?xf32>
+  // CHECK: %[[rB:.*]] = memref.tensor_load %[[BUFFER_CAST_B]] : memref<?xf32>
+  // CHECK: return %[[rA]], %[[rB]] : tensor<?xf32>, tensor<?xf32>
+  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}
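The ALLOC_FOR_A / dealloc pairs checked in these tests come from `createNewAllocDeallocPairForShapedValue`. A rough sketch of what such a helper could look like for statically shaped values; this is an assumption-laden illustration, not the pass's code (the real helper also handles dynamic sizes and picks different insertion points):

```cpp
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: allocate a buffer for `shapedValue` and schedule its dealloc just
// before the surrounding block's terminator. Static shapes only.
static Value createAllocDeallocPair(OpBuilder &b, Location loc,
                                    Value shapedValue) {
  auto tensorType = shapedValue.getType().cast<RankedTensorType>();
  assert(tensorType.hasStaticShape() && "sketch handles static shapes only");
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
  Value buffer = b.create<memref::AllocOp>(loc, memrefType);
  // Deallocate at the end of the current block.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(b.getInsertionBlock()->getTerminator());
  b.create<memref::DeallocOp>(loc, buffer);
  return buffer;
}
```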