diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -97,12 +97,10 @@
   DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
 
   /// A set of all read BlockArguments of FuncOps.
-  // Note: BlockArgument knows about its owner, so we do not need to store
-  // FuncOps here.
-  DenseSet<BlockArgument> readBbArgs;
+  DenseMap<FuncOp, DenseSet<int64_t>> readBbArgs;
 
   /// A set of all written-to BlockArguments of FuncOps.
-  DenseSet<BlockArgument> writtenBbArgs;
+  DenseMap<FuncOp, DenseSet<int64_t>> writtenBbArgs;
 
   /// Keep track of which FuncOps are fully analyzed or currently being
   /// analyzed.
@@ -264,12 +262,16 @@
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   auto funcOp = cast<FuncOp>(op);
 
+  // Initialize data structure.
+  (void)moduleState.readBbArgs[funcOp];
+  (void)moduleState.writtenBbArgs[funcOp];
+
   // If the function has no body, conservatively assume that all args are
   // read + written.
   if (funcOp.getBody().empty()) {
     for (BlockArgument bbArg : funcOp.getArguments()) {
-      moduleState.readBbArgs.insert(bbArg);
-      moduleState.writtenBbArgs.insert(bbArg);
+      moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
+      moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
     }
 
     return success();
@@ -283,9 +285,9 @@
     if (state.getOptions().testAnalysisOnly)
       annotateFuncArg(funcOp, bbArg, isRead, isWritten);
     if (isRead)
-      moduleState.readBbArgs.insert(bbArg);
+      moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber());
    if (isWritten)
-      moduleState.writtenBbArgs.insert(bbArg);
+      moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber());
   }
 
   return success();
@@ -703,8 +705,10 @@
       // FuncOp not analyzed yet. Assume that OpOperand is read.
       return true;
 
-    return moduleState.readBbArgs.contains(
-        funcOp.getArgument(opOperand.getOperandNumber()));
+    auto it = moduleState.readBbArgs.find(funcOp);
+    assert(it != moduleState.readBbArgs.end() &&
+           "expected analysis info for analyzed FuncOps");
+    return it->second.contains(opOperand.getOperandNumber());
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -719,8 +723,10 @@
       // FuncOp not analyzed yet. Assume that OpOperand is written.
       return true;
 
-    return moduleState.writtenBbArgs.contains(
-        funcOp.getArgument(opOperand.getOperandNumber()));
+    auto it = moduleState.writtenBbArgs.find(funcOp);
+    assert(it != moduleState.writtenBbArgs.end() &&
+           "expected analysis info for analyzed FuncOps");
+    return it->second.contains(opOperand.getOperandNumber());
   }
 
   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
@@ -834,8 +840,17 @@
         continue;
       }
 
-      return callOp->emitError(
-          "call to FuncOp that returns non-equivalent tensors not supported");
+      if (!state.getOptions().allowReturnMemref)
+        return callOp->emitError(
+            "call to FuncOp that returns non-equivalent tensors not supported");
+
+      // Returning a memref. This memref is not equivalent to any bbArg. It is
+      // likely a newly allocated buffer. We may want to hoist such allocations
+      // to the call site in the future.
+      retValMapping[returnValIdx] = resultTypes.size();
+      // resultTypes.push_back(getMemRefType(returnType.cast<TensorType>(),
+      //                                     state.getOptions()));
+      resultTypes.push_back(funcOp.getType().getResult(resultTypes.size()));
     }
 
     // 2. Compute bufferized FunctionType.
@@ -1039,23 +1054,21 @@
   if (preparedOptions.testAnalysisOnly)
     return success();
 
-  // Bufferize function bodies.
+  // Bufferize functions.
   for (FuncOp funcOp : moduleState.orderedFuncOps) {
     // No body => no analysis.
-    if (funcOp.body().empty())
-      continue;
+    if (!funcOp.body().empty())
+      if (failed(bufferizeOp(funcOp, state)))
+        return failure();
 
-    if (failed(bufferizeOp(funcOp, state)))
-      return failure();
-  }
-
-  // Bufferize function boundaries.
-  for (FuncOp funcOp : moduleState.orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
     if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state)))
      return failure();
+  }
 
+  // Check result.
+  for (FuncOp funcOp : moduleState.orderedFuncOps) {
     if (!preparedOptions.allowReturnMemref &&
         llvm::any_of(funcOp.getType().getResults(), [](Type t) {
           return t.isa<MemRefType>();
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -210,11 +210,10 @@
 
 // -----
 
+// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
 func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
 
 func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
-  // expected-error @+2 {{call to FuncOp that returns non-equivalent tensors not supported}}
-  // expected-error @+1 {{op was not bufferized}}
   call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
   return
 }
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1355,3 +1355,23 @@
   // CHECK: return %[[f]], %[[select]]
   return %f, %w : f32, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @create_tensor() -> memref<10xf32> {
+// CHECK:   %[[alloc:.*]] = memref.alloc
+// CHECK:   return %[[alloc]]
+func @create_tensor() -> tensor<10xf32> {
+  %0 = linalg.init_tensor [10] : tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK: func @caller(
+// CHECK:   %[[call:.*]] = call @create_tensor() : () -> memref<10xf32>
+// CHECK:   %[[extracted:.*]] = memref.load %[[call]]
+// CHECK:   return %[[extracted]]
+func @caller(%idx: index) -> f32 {
+  %0 = call @create_tensor() : () -> (tensor<10xf32>)
+  %1 = tensor.extract %0[%idx] : tensor<10xf32>
+  return %1 : f32
+}
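
For illustration only, not part of the patch: a sketch of the bufferized IR that the new @create_tensor/@caller test expects, reconstructed from its CHECK lines. The exact memref layout attributes emitted by the pass may differ.

// Illustrative sketch of the expected bufferized output (assumes the default
// static layout; reconstructed from the CHECK lines above).
func @create_tensor() -> memref<10xf32> {
  // linalg.init_tensor bufferizes to a fresh allocation; since the result is
  // not equivalent to any bbArg, the buffer itself is returned.
  %alloc = memref.alloc() : memref<10xf32>
  return %alloc : memref<10xf32>
}

func @caller(%idx: index) -> f32 {
  // The call now returns a memref instead of a tensor, and tensor.extract
  // becomes a memref.load on the returned buffer.
  %0 = call @create_tensor() : () -> memref<10xf32>
  %1 = memref.load %0[%idx] : memref<10xf32>
  return %1 : f32
}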