diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -256,6 +256,7 @@
 /// themselves (e.g., ExtractSliceOp).
 bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
     Value value) const {
+  assert(value.getType().isa<TensorType>() && "expected TensorType");
   SmallVector<OpOperand *> workingSet;
   for (OpOperand &use : value.getUses())
     workingSet.push_back(&use);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -6,87 +6,68 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Module bufferization is an extension of Comprehensive Bufferize that
+// Module Bufferization is an extension of Comprehensive Bufferize that
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
-// implementations for FuncOp, CallOp and ReturnOp, along with a few helper
-// functions that control the order in which functions are bufferized.
+// implementations for FuncOp, CallOp and ReturnOp.
 //
-// Three cases can occur during bufferization of FuncOps.
+// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`.
+// This function analyzes the given module and determines the order of
+// analysis and bufferization: Functions that are called are processed before
+// their respective callers.
 //
-// i.  inplaceable function arguments may be reused in place after the
-//     function itself has been bufferized. This is encoded by IR resembling:
+// After analyzing a FuncOp, additional information about its bbArgs is
+// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
 //
-// ```
-//   #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//   func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-//       -> tensor<?xf32> {
-//     %0 = bufferization.to_memref %A : memref<?xf32, #map>
-//     // ... uses of %0
-//     %res = bufferization.to_tensor %0 : memref<?xf32, #map>
-//     return %res : tensor<?xf32>
-//   }
-// ```
+// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
+//   tensor return value (if any).
+// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
+//   read/written.
 //
-//  this is the cue for the bufferization of the function foo (and calls
-//  to it) may bufferize to `func @foo(%A: memref<?xf32>)`.
-//  To fully achieve bufferization, an additional analysis is needed to
-//  determine whether function argument/operand pairs bufferize to a
-//  single inplace buffer argument (i.e. functions may return tensors in
-//  arbitrary order that may not match argument numbers).
+// Only tensors that are equivalent to some FuncOp bbArg may be returned.
+// Bufferization currently fails if other tensors (in particular tensors that
+// bufferize out-of-place and result in a new buffer allocation) are returned.
+// In the future, such allocations could be hoisted to the caller.
 //
-// ii. results that don't map to an inplaceable function argument are
-//     generally allocated.
-//     Since memref semantics wrt ownership of the
-//     underlying memory region are not well-defined, comprehensive
-//     bufferization chooses to perform allocations in a scoped fashion:
-//     returning memrefs is always considered illegal.
-//     Such scenarios are encoded by IR resembling:
+// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
+// ```
+// func @foo() -> tensor<?xf32> {
+//   %0 = linalg.init_tensor [...] : tensor<?xf32>
+//   return %0 : tensor<?xf32>
+// }
+// ```
 //
-// ```
-//   #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//   func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-//       -> tensor<?xf32> {
-//     %0 = bufferization.to_memref %A : memref<?xf32, #map>
-//     %1 = memref.dim %0, %c0 : memref<?xf32, #map>
-//     %2 = memref.alloc(%1) : memref<?xf32>
-//     %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
-//     // ... uses of %3
-//     memref.dealloc %2 : memref<?xf32>
-//     %res = bufferization.to_tensor %3 : memref<?xf32, #map>
-//     return %res : tensor<?xf32>
-//   }
-// ```
+// Module Bufferization implements the following calling convention.
 //
-//  this is the cue for the bufferization of the function foo (and calls
-//  to it) that it must bufferize to `func @foo(%A: memref<?xf32>,
-//                                              %B: memref<?xf32>)` (i.e. make a cloned
-//  allocation of the result tensor)
-//  To fully achieve bufferization, the alloc/dealloc pair must be lifted
-//  out of the function at each call site.
+// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
+//   be written to in-place.
+// * If a tensor operand of a CallOp is read after the CallOp, the operand of
+//   the CallOp must bufferize out-of-place.
 //
-// iii. as an optimization over ii., it may be possible to reuse an argument
-//      and only want to return a slice.
-//      This may forego allocation by letting *all* callers decide whether to
-//      pass a new *aliasing* memref function argument (i.e. a subview).
-//      Without loss of generality, callers may agree to allocate a new buffer
-//      to avoid this aliasing. Such scenarios are encoded by IR resembling:
+// Example: The tensor.insert op bufferizes in-place because it is allowed to
+// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
+// out-of-place because `%t0` is modified by the callee but read by the
+// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
+// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`.
+// ```
+// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
+//   %f = ... : f32
+//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
+//   return %0 : tensor<?xf32>
+// }
 //
-// ```
-//   #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-//   func @foo(%arg0: tensor<?xf32> {linalg.inplaceable = true})
-//       -> tensor<4xf32> {
-//     %0 = bufferization.to_memref %arg0 : memref<?xf32, #map>
-//     %1 = memref.subview %0[0] [4] [1] : memref<?xf32, #map> to
-//          memref<4xf32, #map>
-//     // ... inplace computes into %1
-//     %3 = bufferization.to_tensor %1 : memref<4xf32, #map>
-//     return %3 : tensor<4xf32>
-//   }
-// ```
+// func @caller() -> () {
+//   %t0 = ... : tensor<?xf32>
+//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
+//   %2 = tensor.extract %t0[...] : tensor<?xf32>
+// }
+// ```
 //
-// Note: In the future, it may be worthwhile to design special bufferization
-// ops to encode the desired semantics at function boundaries for i., ii. and
-// iii.
+// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot
+// analyze the function body. In such a case, the CallOp analysis conservatively
+// assumes that each tensor OpOperand is both read and written.
+//
+// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
+// as "not reading" and/or "not writing".

 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
@@ -110,8 +91,19 @@
   /// indices.
   DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;

+  /// A set of all read BlockArguments of FuncOps.
+  DenseSet<BlockArgument> readBbArgs;
+
+  /// A set of all written-to BlockArguments of FuncOps.
+  DenseSet<BlockArgument> writtenBbArgs;
+
+  /// A set of all fully analyzed FuncOps.
+  DenseSet<FuncOp> analyzedFuncOps;
+
   // A list of functions in the order in which they are analyzed + bufferized.
   SmallVector<FuncOp> orderedFuncOps;

+  // A mapping of FuncOps to their callers.
   DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
 };
 } // namespace
@@ -197,6 +189,56 @@
     return success();
   }
 };
+
+/// Return true if the buffer of the given tensor value is written to. May be
+/// called only for values inside already analyzed functions.
+static bool isValueWritten(Value value, const BufferizationState &state,
+                           const BufferizationAliasInfo &aliasInfo) {
+  assert(value.getType().isa<TensorType>() && "expected TensorType");
+  bool isWritten = false;
+  aliasInfo.applyOnAliases(value, [&](Value val) {
+    for (OpOperand &use : val.getUses())
+      if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use))
+        isWritten = true;
+  });
+  return isWritten;
+}
+
+/// Determine which FuncOp bbArgs are read and which are written. If this
+/// PostAnalysisStep is run on a function with unknown ops, it will
+/// conservatively assume that such ops bufferize to a read + write.
+struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep {
+  LogicalResult run(Operation *op, BufferizationState &state,
+                    BufferizationAliasInfo &aliasInfo,
+                    SmallVector<Operation *> &newOps) override {
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+
+    // Support only single return-terminated block in the function.
+    auto funcOp = cast<FuncOp>(op);
+
+    // If the function has no body, conservatively assume that all args are
+    // read + written.
+    if (funcOp.getBody().empty()) {
+      for (BlockArgument bbArg : funcOp.getArguments()) {
+        moduleState.readBbArgs.insert(bbArg);
+        moduleState.writtenBbArgs.insert(bbArg);
+      }
+
+      return success();
+    }
+
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (!bbArg.getType().isa<TensorType>())
+        continue;
+      if (state.isValueRead(bbArg))
+        moduleState.readBbArgs.insert(bbArg);
+      if (isValueWritten(bbArg, state, aliasInfo))
+        moduleState.writtenBbArgs.insert(bbArg);
+    }
+
+    return success();
+  }
+};
 } // namespace

 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
@@ -566,43 +608,97 @@
 static Optional<int64_t>
 getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state,
                         int64_t returnValIdx) {
-  if (!state.equivalentFuncArgs.count(funcOp))
+  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
+  if (funcOpIt == state.equivalentFuncArgs.end())
     // No equivalence info stores for funcOp.
     return None;

-  const DenseMap<int64_t, int64_t> &equivFuncArgs =
-      state.equivalentFuncArgs.lookup(funcOp);
-  if (!equivFuncArgs.count(returnValIdx))
+  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
+  if (retValIt == funcOpIt->getSecond().end())
     // Return value has no equivalent bbArg.
     return None;

-  return equivFuncArgs.lookup(returnValIdx);
+  return retValIt->getSecond();
 }

 struct CallOpInterface
     : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
-    // CallOpInterface alone doesn't bufferize to a memory read, one of the uses
-    // of the matching bbArg may.
-    // It is the responsibility of the caller to
-    // inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
-    // conservative.
-    return true;
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    const ModuleBufferizationState &moduleState =
+        getModuleBufferizationState(state);
+    if (!moduleState.analyzedFuncOps.contains(funcOp))
+      // FuncOp not analyzed yet. Assume that OpOperand is read.
+      return true;
+
+    return moduleState.readBbArgs.contains(
+        funcOp.getArgument(opOperand.getOperandNumber()));
   }

   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
-    return false;
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    const ModuleBufferizationState &moduleState =
+        getModuleBufferizationState(state);
+    if (!moduleState.analyzedFuncOps.contains(funcOp))
+      // FuncOp not analyzed yet. Assume that OpOperand is written.
+      return true;
+
+    return moduleState.writtenBbArgs.contains(
+        funcOp.getArgument(opOperand.getOperandNumber()));
   }

   OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
-    // CallOpInterface is special, it needs to wait for the callee to be
-    // bufferized and needs to inspect the BufferAliasInfo object. It can't
-    // make a proper determination by itself and needs to be conservative.
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+    const ModuleBufferizationState &moduleState =
+        getModuleBufferizationState(state);
+
+    for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
+         ++resultIdx)
+      if (Optional<int64_t> maybeArgNumber =
+              getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx))
+        if (*maybeArgNumber == opOperand.getOperandNumber())
+          return callOp->getOpResult(resultIdx);
+
+    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
+    // supported and will fail bufferization.
     return OpResult();
   }

+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const BufferizationState &state) const {
+    CallOp callOp = cast<CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+    const ModuleBufferizationState &moduleState =
+        getModuleBufferizationState(state);
+
+    if (Optional<int64_t> maybeArgNumber = getEquivalentFuncArgIdx(
+            funcOp, moduleState, opResult.getResultNumber()))
+      return {&op->getOpOperand(*maybeArgNumber)};
+
+    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
+    // supported and will fail bufferization.
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
   /// In a first approximation, all the function arguments of a FuncOp are
   /// marked inplaceable. For now, it is the responsibility of the `callOp`
   /// bufferization to allow FuncOp that are inplaceable to write inPlace.
@@ -658,11 +754,12 @@
           getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
         // Return operands that are equivalent to some bbArg, are not
         // returned.
-        Value buffer =
-            *state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx),
-                             /*forceInPlace=*/true);
-        replacementValues[returnValIdx] = buffer;
-        newOperands[*bbArgIdx] = buffer;
+        FailureOr<Value> bufferOrFailure =
+            state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
+        if (failed(bufferOrFailure))
+          return failure();
+        replacementValues[returnValIdx] = *bufferOrFailure;
+        newOperands[*bbArgIdx] = *bufferOrFailure;
         continue;
       }

@@ -691,11 +788,15 @@
       // Retrieve buffers for tensor operands. Tensor operand buffers, who's
       // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
      // already stored in `newOperands` during Step 1.
-      Value buffer = newOperands[idx] ? newOperands[idx]
-                                      : *state.getBuffer(rewriter, opOperand,
-                                                         /*forceInPlace=*/true);
+      Value buffer = newOperands[idx];
+      if (!buffer) {
+        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
+        if (failed(bufferOrFailure))
+          return failure();
+        buffer = *bufferOrFailure;
+      }

-      // Caller / callee type mistmatch is handled with a CastOp.
+      // Caller / callee type mismatch is handled with a CastOp.
       auto memRefType = bufferizedFuncType.getInput(idx);
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
@@ -770,8 +871,6 @@
     auto funcOp = cast<FuncOp>(op);
     BlockArgument bbArg = value.dyn_cast<BlockArgument>();
     assert(bbArg && "expected BlockArgument");
-    const ModuleBufferizationState &moduleState =
-        getModuleBufferizationState(state);

     // "linalg.inplaceable" overrides other writability decisions. This is
     // currently used for testing only.
@@ -780,16 +879,8 @@
             BufferizableOpInterface::kInplaceableAttrName))
       return inplaceAttr.getValue();

-    // In a first approximation:
-    // =========================
-    // If the function is called, we can allocate on the caller side which lets
-    // us force inplace arguments at function boundaries.
-    // TODO: do not rely on this behavior.
-    if (moduleState.callerMap.find(funcOp) != moduleState.callerMap.end())
-      return true;
-
-    // All other function arguments are not writable.
-    return false;
+    // All function arguments are writable by default.
+    return true;
   }

   bool isAllocationHoistingBarrier(Operation *op) const { return true; }
@@ -852,6 +943,8 @@
   // Collect bbArg/return value information after the analysis.
   options->postAnalysisSteps.emplace_back(
       std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+  options->postAnalysisSteps.emplace_back(
+      std::make_unique<FuncOpBbArgReadWriteAnalysis>());

   // Gather equivalence info for CallOps.
   equivalenceAnalysis(funcOp, aliasInfo, moduleState);
@@ -860,6 +953,9 @@
     if (failed(analyzeOp(funcOp, state)))
       return failure();

+    // Mark op as fully analyzed.
+    moduleState.analyzedFuncOps.insert(funcOp);
+
     // Add annotations to function arguments.
     if (options->testAnalysisOnly)
       annotateOpsWithBufferizationMarkers(funcOp, state);
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -630,7 +630,7 @@
   // of %r1 is read.
   // CHECK: scf.for
   // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
   // CHECK-NEXT: scf.yield
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
   // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "false"]}
@@ -642,7 +642,7 @@
   // %r1 bufferizes inplace fine.
  // CHECK: scf.for
  // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
  // CHECK-NEXT: scf.yield
  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
  // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
@@ -655,7 +655,7 @@
  // of %r3 is read.
  // CHECK: linalg.tiled_loop
  // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
  // CHECK-NEXT: linalg.yield
  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
  // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "false"]}
@@ -669,7 +669,7 @@
  // %r3 bufferizes inplace fine.
  // CHECK: linalg.tiled_loop
  // CHECK-NEXT: call
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
  // CHECK-NEXT: linalg.yield
  // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
  // CHECK: } {__inplace_operands_attr__ = ["none", "none", "none", "true"]}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -410,7 +410,9 @@
 // CHECK: %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32>
   %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>

-// CHECK: %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: linalg.copy(%[[A]], %[[alloc]])
 // CHECK: call @some_external_func(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
   call @some_external_func(%A) : (tensor<4xi32>) -> ()

@@ -430,7 +432,9 @@
 // CHECK: %[[A:.*]] = memref.get_global @__constant_4xi32 : memref<4xi32>
   %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>

-// CHECK: %[[B:.*]] = memref.cast %[[A]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
+// CHECK: linalg.copy(%[[A]], %[[alloc]])
 // CHECK: call @some_external_func_within_scf_execute(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
   scf.execute_region {
     call @some_external_func_within_scf_execute(%A) : (tensor<4xi32>) -> ()
@@ -488,16 +492,19 @@
                                %lb : index, %ub : index, %step : index)
   -> (tensor<?xf32>, tensor<?xf32>)
 {
-// CHECK-NEXT: call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
+// CHECK: call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
   %r0:2 = call @scf_for_with_tensor_insert_slice(%A, %B, %C, %lb, %ub, %step)
     : (tensor<?xf32>, tensor<?xf32>, tensor<4xf32>, index, index, index)
       -> (tensor<?xf32>, tensor<?xf32>)

-  // %r0#0 is actually %B after inplaceable results are swapped in the callee.
-// CHECK-NEXT: call @some_external_func(%[[B]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
+  // %r0#0 requires a copy because we have no idea what the function is doing.
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+// CHECK: linalg.copy(%[[B]], %[[alloc]])
+// CHECK-NEXT: call @some_external_func(%[[casted]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
   call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()

-// CHECK-NEXT: return
+// CHECK: return
   return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
 }

@@ -745,8 +752,16 @@
 func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
             %B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false},
             %C : tensor<?xf32> {linalg.inplaceable = false}) {
-// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref<?xf32> to memref<?xf32, #[[$DYN_1D_MAP]]>
-// CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]])
+// CHECK: %[[ALLOC_C:.*]] = memref.alloc
+// CHECK: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
+// CHECK: %[[ALLOC_B:.*]] = memref.alloc
+// CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
+// CHECK: %[[ALLOC_A:.*]] = memref.alloc
+// CHECK: linalg.copy(%[[A]], %[[ALLOC_A]])
+// CHECK: linalg.copy(%[[B]], %[[ALLOC_B]])
+// CHECK: linalg.copy(%[[C]], %[[ALLOC_C]])
+// CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+// CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]])
   call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
   return
 }

@@ -1057,7 +1072,10 @@
   %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
     // TODO: There should be a memory copy here. This is a bug in CallOp
     // bufferization.
-    // CHECK: call @inner_func_2(%[[arg0]])
+    // CHECK: %[[alloc:.*]] = memref.alloc
+    // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
+    // CHECK: linalg.copy(%[[arg0]], %[[alloc]])
+    // CHECK: call @inner_func_2(%[[casted]])
     %3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
     scf.yield %t1 : tensor<?xf32>
   }
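
The following hand-written sketch illustrates the calling convention described in the ModuleBufferization.cpp header after bufferization; it is not output of the pass. The function names follow the `@callee`/`@caller` example from the header comment, and the layout map, element type, and exact alloc/copy/dealloc placement are illustrative assumptions. Because `@callee`'s returned tensor is equivalent to its bbArg, the bufferized callee returns nothing and writes through its argument; because `@caller` still reads `%t0` after the call, the CallOp operand bufferizes out-of-place, producing the memref.alloc / memref.cast / linalg.copy sequence that the updated CHECK lines above expect.

```mlir
// Illustrative layout map, mirroring the 1-D strided map used in the comments.
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Bufferized callee (sketch): the returned tensor was equivalent to %t1, so
// the function now returns nothing and writes through its memref argument.
func @callee(%t1: memref<?xf32, #map>) {
  %f = arith.constant 1.0 : f32
  %c0 = arith.constant 0 : index
  memref.store %f, %t1[%c0] : memref<?xf32, #map>
  return
}

// Bufferized caller (sketch): %t0 is read (memref.load) after the call, so the
// call operand is copied into a fresh allocation that the callee may modify,
// while the original buffer stays readable.
func @caller(%t0: memref<?xf32, #map>, %d: index) -> f32 {
  %c0 = arith.constant 0 : index
  %alloc = memref.alloc(%d) : memref<?xf32>
  %casted = memref.cast %alloc : memref<?xf32> to memref<?xf32, #map>
  linalg.copy(%t0, %alloc) : memref<?xf32, #map>, memref<?xf32>
  call @callee(%casted) : (memref<?xf32, #map>) -> ()
  %v = memref.load %t0[%c0] : memref<?xf32, #map>
  memref.dealloc %alloc : memref<?xf32>
  return %v : f32
}
```

If `FuncOpBbArgReadWriteAnalysis` had concluded that the callee never writes its tensor bbArg, the operand would not be considered written and the copy would be unnecessary, which is the main benefit of recording per-bbArg read/write information instead of treating every CallOp operand conservatively.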