Add a `drop-equivalent-func-results` option to One-Shot Bufferize.

Previously, a function result whose buffer is equivalent to one of the
function's block arguments was always dropped: call sites reuse the
buffer of the matching call operand, and the callee does not return it.
With this patch that behaviour is controlled by a flag that defaults to
`true` (unchanged behaviour). When the flag is `false`, such results fall
through to the regular result handling instead: the `allowReturnAllocs`
check at call sites, and `returnValues.push_back(...)` in the callee.

Files touched:
  OneShotAnalysis.h
      Adds `bool dropEquivalentFuncResults = true;` next to
      `allowReturnAllocs` in the One-Shot options struct.
  Passes.td
      Exposes the flag as pass option `drop-equivalent-func-results`
      (default "true"), listed before `allow-return-allocs`.
  Bufferize.cpp
      Copies the pass option into `opt` when the pass was created without
      explicit options.
  FuncBufferizableOpInterfaceImpl.cpp
      Wraps the "equivalent to some bbArg" short-circuit in
      `if (options.dropEquivalentFuncResults)` for both the call-site
      rewrite and the callee's return rewrite. The callee path now
      static_casts `state.getOptions()` to `OneShotBufferizationOptions`
      so that it can read the flag.
  one-shot-module-bufferize-allow-return-allocs.mlir
      Adds a RUN line with `drop-equivalent-func-results=false` checked
      under prefix EQUIV, plus an `@return_arg` test case.

NOTE(review): this copy of the patch looks mangled. Several patch lines
share one physical line (newlines collapsed), and most text between angle
brackets seems to have been stripped (e.g. `Optional bbArgIdx`,
`FailureOr bufferOrFailure`, `static_cast(state.getOptions())`,
`rewriter.create(`, `.cast()`, `SmallVector argTypes`, `tensor`). It will
not apply as-is; regenerate it from the original commit rather than
hand-repairing the hunks.

NOTE(review): the callee path downcasts with static_cast to
OneShotBufferizationOptions. That is only sound if every caller passes
One-Shot options. The call-site path already reads
`options.allowReturnAllocs`, so it presumably relies on the same
assumption. Confirm, or add an assertion.

NOTE(review): with drop-equivalent-func-results=false, equivalent results
at call sites now reach the `if (!options.allowReturnAllocs)` check that
follows the guarded block. The new RUN line only combines the flag with
allow-return-allocs. Confirm whether drop-equivalent-func-results=false
without allow-return-allocs is meant to be rejected, and say so in the
option's description if it is.

NOTE(review): in the callee path, results that are no longer dropped go
through `returnValues.push_back(*state.getBuffer(rewriter, returnOperand));`.
That dereferences the getBuffer result without checking it. The call-site
path does check it, via `failed(bufferOrFailure)`. Consider propagating
the failure there as well.

NOTE(review): the EQUIV expectations for @return_arg are not visible in
this excerpt. Make sure the test file has at least one EQUIV check line;
FileCheck reports an error when a requested check prefix has no check
lines.

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -47,6 +47,10 @@ /// Specifies whether returning newly allocated memrefs should be allowed. /// Otherwise, a pass failure is triggered. bool allowReturnAllocs = false; + + /// Specifies whether buffer return values that are equivalent to a FuncOp + /// bbArg should be dropped. + bool dropEquivalentFuncResults = true; }; /// The BufferizationAliasInfo class maintains a list of buffer aliases and diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -230,6 +230,9 @@ `test-analysis-only`. }]; let options = [ + Option<"dropEquivalentFuncResults", "drop-equivalent-func-results", "bool", + /*default=*/"true", + "Drop buffer return values that are equivalent to a FuncOp arg.">, Option<"allowReturnAllocs", "allow-return-allocs", "bool", /*default=*/"false", "Allows returning/yielding new allocations from a block.">, diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -169,6 +169,7 @@ if (!options) { // Make new bufferization options if none were provided when creating the // pass. 
+ opt.dropEquivalentFuncResults = dropEquivalentFuncResults; opt.allowReturnAllocs = allowReturnAllocs; opt.allowUnknownOps = allowUnknownOps; opt.alwaysAliasingWithDest = alwaysAliasingWithDest; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -269,17 +269,19 @@ continue; } - if (Optional bbArgIdx = - getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) { - // Return operands that are equivalent to some bbArg, are not - // returned. - FailureOr bufferOrFailure = - state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx)); - if (failed(bufferOrFailure)) - return failure(); - replacementValues[returnValIdx] = *bufferOrFailure; - newOperands[*bbArgIdx] = *bufferOrFailure; - continue; + if (options.dropEquivalentFuncResults) { + if (Optional bbArgIdx = + getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) { + // Return operands that are equivalent to some bbArg, are not + // returned. + FailureOr bufferOrFailure = + state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx)); + if (failed(bufferOrFailure)) + return failure(); + replacementValues[returnValIdx] = *bufferOrFailure; + newOperands[*bbArgIdx] = *bufferOrFailure; + continue; + } } if (!options.allowReturnAllocs) @@ -404,7 +406,8 @@ FunctionType funcType = funcOp.getFunctionType(); const FuncAnalysisState &funcState = getFuncAnalysisState(state.getAnalysisState()); - const BufferizationOptions &options = state.getOptions(); + const OneShotBufferizationOptions &options = + static_cast(state.getOptions()); // Construct the bufferized function type. SmallVector argTypes; @@ -479,20 +482,23 @@ } // If return operand is equivalent to some bbArg, no need to return it. 
- if (Optional equivBbArgIdx = getEquivalentFuncArgIdx( - funcOp, funcState, returnOperand.getOperandNumber())) { - rewriter.setInsertionPoint(returnOp); - Location loc = returnOp.getLoc(); - Value toMemrefOp = rewriter.create( - loc, getMemRefType(returnVal.getType().cast(), options), - returnVal); - BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx); - // Note: This copy will fold away. It must be inserted here to ensure - // that `returnVal` still has at least one use and does not fold away. - if (failed( - createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options))) - return funcOp->emitError("could not generate copy for bbArg"); - continue; + if (options.dropEquivalentFuncResults) { + if (Optional equivBbArgIdx = getEquivalentFuncArgIdx( + funcOp, funcState, returnOperand.getOperandNumber())) { + rewriter.setInsertionPoint(returnOp); + Location loc = returnOp.getLoc(); + Value toMemrefOp = rewriter.create( + loc, + getMemRefType(returnVal.getType().cast(), options), + returnVal); + BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx); + // Note: This copy will fold away. It must be inserted here to ensure + // that `returnVal` still has at least one use and does not fold away. 
+ if (failed( + createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options))) + return funcOp->emitError("could not generate copy for bbArg"); + continue; + } } returnValues.push_back(*state.getBuffer(rewriter, returnOperand)); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs drop-equivalent-func-results=false" -split-input-file | FileCheck %s --check-prefix=EQUIV // Run fuzzer with different seeds. // RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null @@ -62,3 +63,16 @@ %r2 = tensor.extract %filled[%idx] : tensor return %r1, %r2 : f32, f32 } + +// ----- + +func.func @return_arg(%A: tensor) -> tensor { + func.return %A : tensor +} +// CHECK-LABEL: func @return_arg +// CHECK-SAME: %[[A:.*]]: memref