diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -2075,14 +2075,24 @@ auto tensorType = operand.get().getType().dyn_cast(); if (!tensorType) continue; + OpOperand &forOperand = forOp.getOpOperandForResult( forOp->getResult(operand.getOperandNumber())); auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - if (getInPlace(bbArg) == InPlaceSpec::True) - operand.set(bbArg); - else - operand.set( - b.create(yieldOp.getLoc(), lookup(bvm, bbArg))); + Value yieldedBuffer = lookup(bvm, operand.get()); + Value bbArgBuffer = lookup(bvm, bbArg); + if (!aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, bbArgBuffer)) { + // TODO: this could get resolved with copies but it can also turn into + // swaps so we need to be careful about order of copies. + return yieldOp->emitError() + << "Yield operand #" << operand.getOperandNumber() + << " does not bufferize to an equivalent buffer to the matching" + << " enclosing scf::for operand"; + } + + // Buffers are equivalent so the work is already done and we just yield the + // bbArg so that it later canonicalizes away. + operand.set(bbArg); } return success(); } @@ -2205,38 +2215,6 @@ return success(); } -/// Return `failure()` if either -/// scf::YieldOp are not explicitly bufferized and we need to perform a separate -/// sanity check for now. -static LogicalResult -bufferizationSanityCheck(scf::YieldOp yieldOp, - const BufferizationAliasInfo &aliasInfo) { - auto parentForOp = yieldOp->getParentOfType(); - if (!parentForOp) - return yieldOp->emitError() << "not nested under ForOp"; - - for (OpOperand &operand : yieldOp->getOpOperands()) { - OpResult matchingForOpResult = - parentForOp->getResult(operand.getOperandNumber()); - // Nothing to do if operand bufferizes out of place. 
- if (getInPlace(matchingForOpResult) != InPlaceSpec::True) - continue; - OpOperand &machingForOpOperand = - parentForOp.getOpOperandForResult(matchingForOpResult); - BlockArgument matchingForOpIterArg = - parentForOp.getRegionIterArgForOpOperand(machingForOpOperand); - if (!aliasInfo.areEquivalentBufferizedValues(matchingForOpIterArg, - operand.get())) { - return yieldOp->emitError() - << "Yield operand #" << operand.getOperandNumber() - << " does not bufferize to an equivalent buffer to the matching" - << " enclosing scf::for operand -> Fail the pass\n"; - } - } - - return success(); -} - /// Analyze the `funcOp` body to determine which OpResults are inplaceable. static LogicalResult inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, @@ -2275,13 +2253,14 @@ return failure(); } - // Bufferize all ops except ExtractSliceOp and InsertSliceOp which are handled - // separately. + // Analyze all ops that return tensors, except ExtractSliceOp and + // InsertSliceOp which are handled separately. // Walk other ops in reverse for better interference behavior. for (Operation *op : reverse(nonSliceOps)) for (OpOperand &opOperand : op->getOpOperands()) if (OpResult result = getInplaceableOpResult(opOperand)) - if (failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo, + if (result.getType().isa() && + failed(bufferizableInPlaceAnalysis(opOperand, result, aliasInfo, domInfo))) return failure(); @@ -2292,14 +2271,9 @@ if (failed(bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo))) return failure(); - // Sanity checks. 
- auto walkResult = funcOp.walk([&](scf::YieldOp yieldOp) -> WalkResult { - return bufferizationSanityCheck(yieldOp, aliasInfo); - }); - LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n'); - return success(!walkResult.wasInterrupted()); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -18,7 +18,7 @@ // expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}} func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor, %t2 : tensor) - -> (tensor, tensor) + -> (tensor, tensor) { cond_br %cond1, ^bb1, ^bb2 @@ -64,7 +64,7 @@ // Throw a wrench in the system by swapping yielded values: this result in a // ping-pong of values at each iteration on which we currently want to fail. 
- // expected-error @+1 {{Yield operand #1 does not bufferize to an equivalent buffer}} + // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}} scf.yield %ttB, %ttA : tensor, tensor } @@ -73,6 +73,27 @@ // ----- +func private @fun_with_side_effects(%A: tensor {linalg.inplaceable = true}) + +func @foo(%A: tensor {linalg.inplaceable = true}) -> (tensor) { + call @fun_with_side_effects(%A) : (tensor) -> () + return %A: tensor +} + +func @scf_yield_needs_copy(%A : tensor {linalg.inplaceable = true}, %iters : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%bbarg = %A) -> (tensor) { + %r = call @foo(%A) : (tensor) -> (tensor) + // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}} + scf.yield %r : tensor + } + call @fun_with_side_effects(%res) : (tensor) -> () + return +} + +// ----- + func @extract_slice_fun(%A : tensor {linalg.inplaceable = true}) -> tensor<4xf32> { @@ -92,8 +113,8 @@ func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32> { - %r = scf.if %b -> (tensor<4xf32>) { - // expected-error @+1 {{not nested under ForOp}} + // expected-error @+1 {{unsupported op with tensors}} + %r = scf.if %b -> (tensor<4xf32>) { scf.yield %A : tensor<4xf32> } else { scf.yield %B : tensor<4xf32>