[mlir][linalg][bufferize] Bufferize FuncOp bodies and boundaries in one loop

Summary of what this patch changes (text before the first "diff --git"
header is ignored by git apply and by GNU patch).

ModuleBufferization.cpp

  ModuleAnalysisState::readBbArgs and writtenBbArgs change from sets of
  BlockArguments to maps from each FuncOp to the argument numbers
  (BlockArgument::getArgNumber()) that are read / written. The removed note
  said a BlockArgument knows its owner, so FuncOps need not be stored.
  NOTE(review): the doc comments on these two members still read "A set of
  all read / written-to BlockArguments of FuncOps". They now describe
  per-FuncOp sets of argument indices and should be reworded.

  The FuncOp analysis first creates an (empty) entry for the analyzed FuncOp
  in both maps. It then records read / written argument numbers. Bodiless
  functions still have every argument marked as read and written.
  bufferizesToMemoryRead / bufferizesToMemoryWrite look the callee up with
  find() and assert "expected analysis info for analyzed FuncOps". They then
  test opOperand.getOperandNumber() against the recorded indices.
  Presumably the up-front entry creation is what keeps that assert from
  firing for functions with no read or written args. Confirm.

  CallOp bufferization casts state.getOptions() to
  OneShotBufferizationOptions. Consider a callee result tensor that is not
  equivalent to any bbArg:
    with options.allowReturnAllocs unset, the call still emits "call to
    FuncOp that returns non-equivalent tensors not supported";
    with it set, the result is kept on the bufferized call through
    retValMapping and resultTypes.push_back(funcOp.getType().getResult(
    resultTypes.size())).
  NOTE(review): indexing funcOp.getType() by resultTypes.size() presumably
  requires the callee's boundary to be bufferized already. That would mean
  orderedFuncOps is callee-first, so its type has already dropped the
  equivalent tensor results. Confirm.

  Module driver: the separate loops over moduleState.orderedFuncOps for
  bodies and then boundaries are merged into one loop. For each FuncOp it
  calls bufferizeOp (only if the body is non-empty) and then
  bufferizeFuncOpBoundary. The allowReturnAllocs check on function result
  types moves into a following loop. NOTE(review): the "No body => no
  analysis." comment now sits above a guarded call rather than a continue.

comprehensive-module-bufferize-invalid.mlir

  The call to the bodiless @foo no longer expects errors on the call site.
  The test instead expects "cannot bufferize bodiless function that returns
  a tensor" on the func private declaration.

comprehensive-module-bufferize.mlir

  Adds @create_tensor and @caller. The returned linalg.init_tensor is
  expected to become a memref.alloc returned as memref<10xf32>. The caller
  is expected to call @create_tensor with a memref result and read it via
  memref.load.

NOTE(review): this copy of the patch is damaged and cannot be applied as-is:
  line breaks are collapsed;
  text between angle brackets appears to have been stripped (for example
  "DenseMap>", "cast(op)", "SmallVector resultTypes;", "tensor)").
Regenerate it from the source tree instead of hand-editing it.

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -98,12 +98,10 @@ DenseMap> equivalentFuncArgs; /// A set of all read BlockArguments of FuncOps. - // Note: BlockArgument knows about its owner, so we do not need to store - // FuncOps here. - DenseSet readBbArgs; + DenseMap> readBbArgs; /// A set of all written-to BlockArguments of FuncOps. - DenseSet writtenBbArgs; + DenseMap> writtenBbArgs; /// Keep track of which FuncOps are fully analyzed or currently being /// analyzed. @@ -263,12 +261,16 @@ ModuleAnalysisState &moduleState = getModuleAnalysisState(state); auto funcOp = cast(op); + // Initialize data structure. + (void)moduleState.readBbArgs[funcOp]; + (void)moduleState.writtenBbArgs[funcOp]; + // If the function has no body, conservatively assume that all args are // read + written. if (funcOp.getBody().empty()) { for (BlockArgument bbArg : funcOp.getArguments()) { - moduleState.readBbArgs.insert(bbArg); - moduleState.writtenBbArgs.insert(bbArg); + moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); + moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); } return success(); @@ -282,9 +284,9 @@ if (state.getOptions().testAnalysisOnly) annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten); if (isRead) - moduleState.readBbArgs.insert(bbArg); + moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); if (isWritten) - moduleState.writtenBbArgs.insert(bbArg); + moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); } return success(); @@ -703,8 +705,10 @@ // FuncOp not analyzed yet. Assume that OpOperand is read.
return true; - return moduleState.readBbArgs.contains( - funcOp.getArgument(opOperand.getOperandNumber())); + auto it = moduleState.readBbArgs.find(funcOp); + assert(it != moduleState.readBbArgs.end() && + "expected analysis info for analyzed FuncOps"); + return it->second.contains(opOperand.getOperandNumber()); } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, @@ -718,8 +722,10 @@ // FuncOp not analyzed yet. Assume that OpOperand is written. return true; - return moduleState.writtenBbArgs.contains( - funcOp.getArgument(opOperand.getOperandNumber())); + auto it = moduleState.writtenBbArgs.find(funcOp); + assert(it != moduleState.writtenBbArgs.end() && + "expected analysis info for analyzed FuncOps"); + return it->second.contains(opOperand.getOperandNumber()); } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, @@ -777,6 +783,8 @@ assert(funcOp && "expected CallOp to a FuncOp"); const ModuleAnalysisState &moduleState = getModuleAnalysisState(state.getAnalysisState()); + const OneShotBufferizationOptions &options = + static_cast(state.getOptions()); // Result types of the bufferized CallOp. SmallVector resultTypes; @@ -828,8 +836,15 @@ continue; } - return callOp->emitError( - "call to FuncOp that returns non-equivalent tensors not supported"); + if (!options.allowReturnAllocs) + return callOp->emitError( + "call to FuncOp that returns non-equivalent tensors not supported"); + + // Returning a memref. This memref is not equivalent to any bbArg. It is + // likely a newly allocated buffer. We may want to hoist such allocations + // to the call site in the future. + retValMapping[returnValIdx] = resultTypes.size(); + resultTypes.push_back(funcOp.getType().getResult(resultTypes.size())); } // 2. Compute bufferized FunctionType. @@ -837,7 +852,7 @@ // Get the bufferized FunctionType for funcOp or construct it if not yet // available.
FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), argumentTypes, resultTypes, state.getOptions()); + funcOp.getContext(), argumentTypes, resultTypes, options); // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. for (OpOperand &opOperand : callOp->getOpOperands()) { @@ -1028,23 +1043,21 @@ if (options.testAnalysisOnly) return success(); - // Bufferize function bodies. + // Bufferize functions. for (FuncOp funcOp : moduleState.orderedFuncOps) { // No body => no analysis. - if (funcOp.getBody().empty()) - continue; - - if (failed(bufferizeOp(funcOp, bufferizationState))) - return failure(); - } + if (!funcOp.getBody().empty()) + if (failed(bufferizeOp(funcOp, bufferizationState))) + return failure(); - // Bufferize function boundaries. - for (FuncOp funcOp : moduleState.orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, bufferizationState))) return failure(); + } + // Check result.
+ for (FuncOp funcOp : moduleState.orderedFuncOps) { if (!options.allowReturnAllocs && llvm::any_of(funcOp.getType().getResults(), [](Type t) { return t.isa(); diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -212,11 +212,10 @@ // ----- +// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} func private @foo(%t : tensor) -> (f32, tensor, f32) func @call_to_unknown_tensor_returning_func(%t : tensor) { - // expected-error @+2 {{call to FuncOp that returns non-equivalent tensors not supported}} - // expected-error @+1 {{op was not bufferized}} call @foo(%t) : (tensor) -> (f32, tensor, f32) return } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1323,3 +1323,23 @@ // CHECK: return %[[r0]], %[[r1]] return %f0, %f1: f32, f32 } + +// ----- + +// CHECK-LABEL: func @create_tensor() -> memref<10xf32> { +// CHECK: %[[alloc:.*]] = memref.alloc +// CHECK: return %[[alloc]] +func @create_tensor() -> tensor<10xf32> { + %0 = linalg.init_tensor [10] : tensor<10xf32> + return %0 : tensor<10xf32> +} + +// CHECK: func @caller( +// CHECK: %[[call:.*]] = call @create_tensor() : () -> memref<10xf32> +// CHECK: %[[extracted:.*]] = memref.load %[[call]] +// CHECK: return %[[extracted]] +func @caller(%idx: index) -> f32 { + %0 = call @create_tensor() : () -> (tensor<10xf32>) + %1 = tensor.extract %0[%idx] : tensor<10xf32> + return %1 : f32 +}